diff --git a/.gitignore b/.gitignore index c9fea6c..9ac65db 100644 --- a/.gitignore +++ b/.gitignore @@ -161,9 +161,6 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -crosscodeeval/ -temp/ -*.jsonl *.out datasets/ repositories/ @@ -171,8 +168,6 @@ repositories/ cache/ output*/ *.pdf -*.json -*.jsonl old_scripts/ *cache*/ *.zip diff --git a/experiments/repo-qa/data/Qwen_slash_Qwen2.5-Coder-7B-Instruct.jsonl b/experiments/repo-qa/data/Qwen_slash_Qwen2.5-Coder-7B-Instruct.jsonl new file mode 100644 index 0000000..af65d28 --- /dev/null +++ b/experiments/repo-qa/data/Qwen_slash_Qwen2.5-Coder-7B-Instruct.jsonl @@ -0,0 +1,600 @@ +{"repo": "psf/black", "name": "_merge_string_group", "language": "python", "path": "src/black/trans.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To combine adjacent strings into a single string within a line of code, ensuring that the merged result is syntactically correct.\n2. **Input**: A line of code and a list of indices indicating the starting points of groups of strings to be merged.\n3. **Output**: A modified line of code with the specified groups of strings merged, or an error if the merge cannot be performed.\n4. **Procedure**: \n - The function iterates over the provided indices, each representing the start of a group of adjacent strings.\n - For each index, it checks if the group of strings can be merged without violating syntax rules.\n - If valid, it merges the strings and updates a dictionary with the new string and its position.\n - If no strings can be merged, it returns an error.\n - Otherwise, it constructs a new line of code with the merged strings and any remaining elements from the original line, including comments.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " msg_cant_transform = msg_result.err()\n rblc_cant_transform = rblc_result.err()\n cant_transform = CannotTransform(\n \"StringMerger failed to merge any strings in this line.\"\n )\n\n # Chain the errors together using `__cause__`.\n msg_cant_transform.__cause__ = rblc_cant_transform\n cant_transform.__cause__ = msg_cant_transform\n\n yield Err(cant_transform)\n else:\n yield Ok(new_line)\n\n @staticmethod\n def _remove_backslash_line_continuation_chars(\n line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merge strings that were split across multiple lines using\n line-continuation backslashes.\n\n Returns:\n Ok(new_line), if @line contains backslash line-continuation\n characters.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n indices_to_transform = []\n for string_idx in string_indices:\n string_leaf = LL[string_idx]\n if (\n string_leaf.type == token.STRING\n and \"\\\\\\n\" in string_leaf.value\n and not has_triple_quotes(string_leaf.value)\n ):\n indices_to_transform.append(string_idx)\n\n if not indices_to_transform:\n return TErr(\n \"Found no string leaves that contain backslash line continuation\"\n \" characters.\"\n )\n\n new_line = line.clone()\n new_line.comments = line.comments.copy()\n append_leaves(new_line, line, LL)\n\n for string_idx in indices_to_transform:\n new_string_leaf = new_line.leaves[string_idx]\n new_string_leaf.value = new_string_leaf.value.replace(\"\\\\\\n\", \"\")\n\n return 
Ok(new_line)\n\n \ndef _merge_string_group(\n self, line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merges string groups (i.e. set of adjacent strings).\n\n Each index from `string_indices` designates one string group's first\n leaf in `line.leaves`.\n\n Returns:\n Ok(new_line), if ALL of the validation checks found in\n _validate_msg(...) pass.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n # A dict of {string_idx: tuple[num_of_strings, string_leaf]}.\n merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}\n for string_idx in string_indices:\n vresult = self._validate_msg(line, string_idx)\n if isinstance(vresult, Err):\n continue\n merged_string_idx_dict[string_idx] = self._merge_one_string_group(\n LL, string_idx, is_valid_index\n )\n\n if not merged_string_idx_dict:\n return TErr(\"No string group is merged\")\n\n # Build the final line ('new_line') that this method will later return.\n new_line = line.clone()\n previous_merged_string_idx = -1\n previous_merged_num_of_strings = -1\n for i, leaf in enumerate(LL):\n if i in merged_string_idx_dict:\n previous_merged_string_idx = i\n previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]\n new_line.append(string_leaf)\n\n if (\n previous_merged_string_idx\n <= i\n < previous_merged_string_idx + previous_merged_num_of_strings\n ):\n for comment_leaf in line.comments_after(LL[i]):\n new_line.append(comment_leaf, preformatted=True)\n continue\n\n append_leaves(new_line, line, [leaf])\n\n return Ok(new_line)\n\n def _merge_one_string_group(\n self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool]\n ) -> Tuple[int, Leaf]:\n \"\"\"\n Merges one string group where the first string in the group is\n `LL[string_idx]`.\n\n Returns:\n A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the\n number of strings merged and `leaf` is the newly merged string\n to be replaced in the new line.\n \"\"\"\n # If the string group is wrapped inside an Atom node, we must make sure\n # to later replace that Atom with our new (merged) string leaf.\n atom_node = LL[string_idx].parent\n\n # We will place BREAK_MARK in between every two substrings that we\n # merge. We will then later go through our final result and use the\n # various instances of BREAK_MARK we find to add the right values to\n # the custom split map.\n BREAK_MARK = \"@@@@@ BLACK BREAKPOINT MARKER @@@@@\"\n\n QUOTE = LL[string_idx].value[-1]\n\n def make_naked(string: str, string_prefix: str) -> str:\n \"\"\"Strip @string (i.e. 
make it a \"naked\" string)\n\n Pre-conditions:\n * assert_is_leaf_string(@string)\n\n Returns:\n A string that is identical to @string except that\n @string_prefix has been stripped, the surrounding QUOTE\n characters have been removed, and any remaining QUOTE\n characters have been escaped.\n \"\"\"\n assert_is_leaf_string(string)\n if \"f\" in string_prefix:\n f_expressions = (\n string[span[0] + 1 : span[1] - 1] # +-1 to get rid of curly braces\n for span in iter_fexpr_spans(string)\n )\n debug_expressions_contain_visible_quotes = any(\n re.search(r\".*[\\'\\\"].*(?= 0\n ), \"Logic error while filling the custom string breakpoint cache.\"\n\n temp_string = temp_string[mark_idx + len(BREAK_MARK) :]\n breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1\n custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))\n\n string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, \"\"))\n\n if atom_node is not None:\n # If not all children of the atom node are merged (this can happen\n # when there is a standalone comment in the middle) ...\n if non_string_idx - string_idx < len(atom_node.children):\n # We need to replace the old STRING leaves with the new string leaf.\n first_child_idx = LL[string_idx].remove()\n for idx in range(string_idx + 1, non_string_idx):\n LL[idx].remove()\n if first_child_idx is not None:\n atom_node.insert_child(first_child_idx, string_leaf)\n else:\n # Else replace the atom node with the new string leaf.\n replace_child(atom_node, string_leaf)\n\n self.add_custom_splits(string_leaf.value, custom_splits)\n return num_of_strings, string_leaf\n\n @staticmethod\n def _validate_msg(line: Line, string_idx: int) -> TResult[None]:\n \"\"\"Validate (M)erge (S)tring (G)roup\n\n Transform-time string validation logic for _merge_string_group(...).\n\n Returns:\n * Ok(None), if ALL validation checks (listed below) pass.\n OR\n * Err(CannotTransform), if any of the following are true:\n - The target string group does not contain ANY stand-alone comments.\n - The target string is not in a string group (i.e. it has no\n adjacent strings).\n - The string group has more than one inline comment.\n - The string group has an inline comment that appears to be a pragma.\n - The set of all string prefixes in the string group is of\n length greater than one and is not equal to {\"\", \"f\"}.\n - The string group consists of raw strings.\n - The string group is stringified type annotations. We don't want to\n process stringified type annotations since pyright doesn't support\n them spanning multiple string values. (NOTE: mypy, pytype, pyre do\n support them, so we can change if pyright also gains support in the\n future. See https://github.com/microsoft/pyright/issues/4359.)\n \"\"\"\n # We first check for \"inner\" stand-alone comments (i.e. 
stand-alone\n # comments that have a string leaf before them AND after them).\n for inc in [1, -1]:\n i = string_idx\n found_sa_comment = False\n is_valid_index = is_valid_index_factory(line.leaves)\n while is_valid_index(i) and line.leaves[i].type in [\n token.STRING,\n STANDALONE_COMMENT,\n ]:\n if line.leaves[i].type == STANDALONE_COMMENT:\n found_sa_comment = True\n elif found_sa_comment:\n return TErr(\n \"StringMerger does NOT merge string groups which contain \"\n \"stand-alone comments.\"\n )\n\n i += inc\n\n num_of_inline_string_comments = 0\n set_of_prefixes = set()\n num_of_strings = 0\n for leaf in line.leaves[string_idx:]:\n if leaf.type != token.STRING:\n # If the string group is trailed by a comma, we count the\n # comments trailing the comma to be one of the string group's\n # comments.\n if leaf.type == token.COMMA and id(leaf) in line.comments:\n num_of_inline_string_comments += 1\n break\n\n if has_triple_quotes(leaf.value):\n return TErr(\"StringMerger does NOT merge multiline strings.\")\n\n num_of_strings += 1\n prefix = get_string_prefix(leaf.value).lower()\n if \"r\" in prefix:\n return TErr(\"StringMerger does NOT merge raw strings.\")\n\n set_of_prefixes.add(prefix)\n\n if id(leaf) in line.comments:\n num_of_inline_string_comments += 1\n if contains_pragma_comment(line.comments[id(leaf)]):\n return TErr(\"Cannot merge strings which have pragma comments.\")\n\n if num_of_strings < 2:\n return TErr(\n f\"Not enough strings to merge (num_of_strings={num_of_strings}).\"\n )\n\n if num_of_inline_string_comments > 1:\n return TErr(\n f\"Too many inline string comments ({num_of_inline_string_comments}).\"\n )\n\n if len(set_of_prefixes) > 1 and set_of_prefixes != {\"\", \"f\"}:\n return TErr(f\"Too many different prefixes ({set_of_prefixes}).\")\n\n return Ok(None)\n\n\nclass StringParenStripper(StringTransformer):\n \"\"\"StringTransformer that strips surrounding parentheses from strings.\n\n Requirements:\n The line contains a string which is surrounded by parentheses and:\n - The target string is NOT the only argument to a function call.\n - The target string is NOT a \"pointless\" string.\n - If the target string contains a PERCENT, the brackets are not\n preceded or followed by an operator with higher precedence than\n PERCENT.\n\n Transformations:\n The parentheses mentioned in the 'Requirements' section are stripped.\n\n Collaborations:\n StringParenStripper has its own inherent usefulness, but it is also\n relied on to clean up the parentheses created by StringParenWrapper (in\n the event that they are no longer needed).\n \"\"\"\n\n def do_match(self, line: Line) -> TMatchResult:\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n string_indices = []\n\n idx = -1\n while True:\n idx += 1\n if idx >= len(LL):\n break\n leaf = LL[idx]\n\n # Should be a string...\n if leaf.type != token.STRING:\n continue\n\n # If this is a \"pointless\" string...\n if (\n leaf.parent\n and leaf.parent.parent\n and leaf.parent.parent.type == syms.simple_stmt\n ):\n continue\n\n # Should be preceded by a non-empty LPAR...\n if (\n not is_valid_index(idx - 1)\n or LL[idx - 1].type != token.LPAR\n or is_empty_lpar(LL[idx - 1])\n ):\n continue\n\n # That LPAR should NOT be preceded by a function name or a closing\n # bracket (which could be a function which returns a function or a\n # list/dictionary that contains a function)...\n if is_valid_index(idx - 2) and (\n LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS\n ):\n continue\n\n string_idx = 
idx\n\n # Skip the string trailer, if one exists.\n string_parser = StringParser()\n next_idx = string_parser.parse(LL, string_idx)\n\n # if the leaves in the parsed string include a PERCENT, we need to\n # make sure the initial LPAR is NOT preceded by an operator with\n # higher or equal precedence to PERCENT\n if is_valid_index(idx - 2):\n # mypy can't quite follow unless we name this\n before_lpar = LL[idx - 2]\n if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and (\n (\n before_lpar.type\n in {\n token.STAR,\n token.AT,\n token.SLASH,\n token.DOUBLESLASH,\n token.PERCENT,\n token.TILDE,\n token.DOUBLESTAR,\n token.AWAIT,\n token.LSQB,\n token.LPAR,\n }\n )\n or (\n # only unary PLUS/MINUS\n before_lpar.parent\n and before_lpar.parent.type == syms.factor\n and (before_lpar.type in {token.PLUS, token.MINUS})\n )\n ):\n continue\n\n # Should be followed by a non-empty RPAR...\n if (\n is_valid_index(next_idx)\n and LL[next_idx].type == token.RPAR\n and not is_empty_rpar(LL[next_idx])\n ):\n # That RPAR should NOT be followed by anything with higher\n # precedence than PERCENT\n if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in {\n token.DOUBLESTAR,\n token.LSQB,\n token.LPAR,\n token.DOT,\n }:\n continue\n\n string_indices.append(string_idx)\n idx = string_idx\n while idx < len(LL) - 1 and LL[idx + 1].type == token.STRING:\n idx += 1\n\n if string_indices:\n return Ok(string_indices)\n return TErr(\"This line has no strings wrapped in parens.\")\n\n def do_transform(\n self, line: Line, string_indices: List[int]\n ) -> Iterator[TResult[Line]]:\n LL = line.leaves\n\n string_and_rpar_indices: List[int] = []\n for string_idx in string_indices:\n string_parser = StringParser()\n rpar_idx = string_parser.parse(LL, string_idx)\n\n should_transform = True\n for leaf in (LL[string_idx - 1], LL[rpar_idx]):\n if line.comments_after(leaf):\n # Should not strip parentheses which have comments attached\n # to them.\n should_transform = False\n break\n if should_transform:\n string_and_rpar_indices.extend((string_idx, rpar_idx))\n\n if string_and_rpar_indices:\n yield Ok(self._transform_to_new_line(line, string_and_rpar_indices))\n else:\n yield Err(\n CannotTransform(\"All string groups have comments attached to them.\")\n )\n\n def _transform_to_new_line(\n self, line: Line, string_and_rpar_indices: List[int]\n ) -> Line:\n LL = line.leaves\n\n new_line = line.clone()\n new_line.comments = line.comments.copy()\n\n previous_idx = -1\n # We need to sort the indices, since string_idx and its matching\n # rpar_idx may not come in order, e.g. 
in\n # `(\"outer\" % (\"inner\".join(items)))`, the \"inner\" string's\n # string_idx is smaller than \"outer\" string's rpar_idx.\n for idx in sorted(string_and_rpar_indices):\n leaf = LL[idx]\n lpar_or_rpar_idx = idx - 1 if leaf.type == token.STRING else idx\n append_leaves(new_line, line, LL[previous_idx + 1 : lpar_or_rpar_idx])\n if leaf.type == token.STRING:\n string_leaf = Leaf(token.STRING, LL[idx].value)\n LL[lpar_or_rpar_idx].remove() # Remove lpar.\n replace_child(LL[idx], string_leaf)\n new_line.append(string_leaf)\n # replace comments\n old_comments = new_line.comments.pop(id(LL[idx]), [])\n new_line.comments.setdefault(id(string_leaf), []).extend(old_comments)\n else:\n LL[lpar_or_rpar_idx].remove() # This is a rpar.\n\n previous_idx = idx\n\n # Append the leaves after the last idx:\n append_leaves(new_line, line, LL[idx + 1 :])\n\n return new_line\n\n\nclass BaseStringSplitter(StringTransformer):\n \"\"\"\n Abstract class for StringTransformers which transform a Line's strings by splitting\n them or placing them on their own lines where necessary to avoid going over\n the configured line length.\n\n Requirements:\n * The target string value is responsible for the line going over the\n line length limit. It follows that after all of black's other line\n split methods have been exhausted, this line (or one of the resulting\n lines after all line splits are performed) would still be over the\n line_length limit unless we split this string.\n AND\n\n * The target string is NOT a \"pointless\" string (i.e. a string that has\n no parent or siblings).\n AND\n\n * The target string is not followed by an inline comment that appears\n to be a pragma.\n AND\n\n * The target string is not a multiline (i.e. triple-quote) string.\n \"\"\"\n\n STRING_OPERATORS: Final = [\n token.EQEQUAL,\n token.GREATER,\n token.GREATEREQUAL,\n token.LESS,\n token.LESSEQUAL,\n token.NOTEQUAL,\n token.PERCENT,\n token.PLUS,\n token.STAR,\n ]\n\n @abstractmethod\n def do_splitter_match(self, line: Line) -> TMatchResult:\n \"\"\"\n BaseStringSplitter asks its clients to override this method instead of\n `StringTransformer.do_match(...)`.\n\n Follows the same protocol as `StringTransformer.do_match(...)`.\n\n Refer to `help(StringTransformer.do_match)` for more information.\n \"\"\"\n\n def do_match(self, line: Line) -> TMatchResult:\n match_result = self.do_splitter_match(line)\n if isinstance(match_result, Err):\n return match_result\n\n string_indices = match_result.ok()\n assert len(string_indices) == 1, (\n f\"{self.__class__.__name__} should only find one match at a time, found\"\n f\" {len(string_indices)}\"\n )\n string_idx = string_indices[0]\n vresult = self._validate(line, string_idx)\n if isinstance(vresult, Err):\n return vresult\n\n return match_result\n\n def _validate(self, line: Line, string_idx: int) -> TResult[None]:\n \"\"\"\n Checks that @line meets all of the requirements listed in this classes'\n docstring. 
Refer to `help(BaseStringSplitter)` for a detailed\n description of those requirements.\n\n Returns:\n * Ok(None), if ALL of the requirements are met.\n OR\n * Err(CannotTransform), if ANY of the requirements are NOT met.\n \"\"\"\n LL = line.leaves\n\n string_leaf = LL[string_idx]\n\n max_string_length = self._get_max_string_length(line, string_idx)\n if len(string_leaf.value) <= max_string_length:\n return TErr(\n \"The string itself is not what is causing this line to be too long.\"\n )\n\n if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [\n token.STRING,\n token.NEWLINE,\n ]:\n return TErr(\n f\"This string ({string_leaf.value}) appears to be pointless (i.e. has\"\n \" no parent).\"\n )\n\n if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment(\n line.comments[id(line.leaves[string_idx])]\n ):\n return TErr(\n \"Line appears to end with an inline pragma comment. Splitting the line\"\n \" could modify the pragma's behavior.\"\n )\n\n if has_triple_quotes(string_leaf.value):\n return TErr(\"We cannot split multiline strings.\")\n\n return Ok(None)\n\n def _get_max_string_length(self, line: Line, string_idx: int) -> int:\n \"\"\"\n Calculates the max string length used when attempting to determine\n whether or not the target string is responsible for causing the line to\n go over the line length limit.\n\n WARNING: This method is tightly coupled to both StringSplitter and\n (especially) StringParenWrapper. There is probably a better way to\n accomplish what is being done here.\n\n Returns:\n max_string_length: such that `line.leaves[string_idx].value >\n max_string_length` implies that the target string IS responsible\n for causing this line to exceed the line length limit.\n \"\"\"\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n # We use the shorthand \"WMA4\" in comments to abbreviate \"We must\n # account for\". When giving examples, we use STRING to mean some/any\n # valid string.\n #\n # Finally, we use the following convenience variables:\n #\n # P: The leaf that is before the target string leaf.\n # N: The leaf that is after the target string leaf.\n # NN: The leaf that is after N.\n\n # WMA4 the whitespace at the beginning of the line.\n offset = line.depth * 4\n\n if is_valid_index(string_idx - 1):\n p_idx = string_idx - 1\n if (\n LL[string_idx - 1].type == token.LPAR\n and LL[string_idx - 1].value == \"\"\n and string_idx >= 2\n ):\n # If the previous leaf is an empty LPAR placeholder, we should skip it.\n p_idx -= 1\n\n P = LL[p_idx]\n if P.type in self.STRING_OPERATORS:\n # WMA4 a space and a string operator (e.g. `+ STRING` or `== STRING`).\n offset += len(str(P)) + 1\n\n if P.type == token.COMMA:\n # WMA4 a space, a comma, and a closing bracket [e.g. 
`), STRING`].\n offset += 3\n\n if P.type in [token.COLON, token.EQUAL, token.PLUSEQUAL, token.NAME]:\n # This conditional branch is meant to handle dictionary keys,\n # variable assignments, 'return STRING' statement lines, and\n # 'else STRING' ternary expression lines.\n\n # WMA4 a single space.\n offset += 1\n\n # WMA4 the lengths of any leaves that came before that space,\n # but after any closing bracket before that space.\n for leaf in reversed(LL[: p_idx + 1]):\n offset += len(str(leaf))\n if leaf.type in CLOSING_BRACKETS:\n break\n\n if is_valid_index(string_idx + 1):\n N = LL[string_idx + 1]\n if N.type == token.RPAR and N.value == \"\" and len(LL) > string_idx + 2:\n # If the next leaf is an empty RPAR placeholder, we should skip it.\n N = LL[string_idx + 2]\n\n if N.type == token.COMMA:\n # WMA4 a single comma at the end of the string (e.g `STRING,`).\n offset += 1\n\n if is_valid_index(string_idx + 2):\n NN = LL[string_idx + 2]\n\n if N.type == token.DOT and NN.type == token.NAME:\n # This conditional branch is meant to handle method calls invoked\n # off of a string literal up to and including the LPAR character.\n\n # WMA4 the '.' character.\n offset += 1\n\n if (\n is_valid_index(string_idx + 3)\n and LL[string_idx + 3].type == token.LPAR\n ):\n # WMA4 the left parenthesis character.\n offset += 1\n\n # WMA4 the length of the method's name.\n offset += len(NN.value)\n\n has_comments = False\n for comment_leaf in line.comments_after(LL[string_idx]):\n if not has_comments:\n has_comments = True\n # WMA4 two spaces before the '#' character.\n offset += 2\n\n # WMA4 the length of the inline comment.\n offset += len(comment_leaf.value)\n\n max_string_length = count_chars_in_width(str(line), self.line_length - offset)\n return max_string_length\n\n @staticmethod\n def _prefer_paren_wrap_match(LL: List[Leaf]) -> Optional[int]:\n \"\"\"\n Returns:\n string_idx such that @LL[string_idx] is equal to our target (i.e.\n matched) string, if this line matches the \"prefer paren wrap\" statement\n requirements listed in the 'Requirements' section of the StringParenWrapper\n class's docstring.\n OR\n None, otherwise.\n \"\"\"\n # The line must start with a string.\n if LL[0].type != token.STRING:\n return None\n\n matching_nodes = [\n syms.listmaker,\n syms.dictsetmaker,\n syms.testlist_gexp,\n ]\n # If the string is an immediate child of a list/set/tuple literal...\n if (\n parent_type(LL[0]) in matching_nodes\n or parent_type(LL[0].parent) in matching_nodes\n ):\n # And the string is surrounded by commas (or is the first/last child)...\n prev_sibling = LL[0].prev_sibling\n next_sibling = LL[0].next_sibling\n if (\n not prev_sibling\n and not next_sibling\n and parent_type(LL[0]) == syms.atom\n ):\n # If it's an atom string, we need to check the parent atom's siblings.\n parent = LL[0].parent\n assert parent is not None # For type checkers.\n prev_sibling = parent.prev_sibling\n next_sibling = parent.next_sibling\n if (not prev_sibling or prev_sibling.type == token.COMMA) and (\n not next_sibling or next_sibling.type == token.COMMA\n ):\n return 0\n\n return None\n\n\ndef iter_fexpr_spans(s: str) -> Iterator[Tuple[int, int]]:\n \"\"\"\n Yields spans corresponding to expressions in a given f-string.\n Spans are half-open ranges (left inclusive, right exclusive).\n Assumes the input string is a valid f-string, but will not crash if the input\n string is invalid.\n \"\"\"\n stack: List[int] = [] # our curly paren stack\n i = 0\n while i < len(s):\n if s[i] == \"{\":\n # if we're in a string 
part of the f-string, ignore escaped curly braces\n if not stack and i + 1 < len(s) and s[i + 1] == \"{\":\n i += 2\n continue\n stack.append(i)\n i += 1\n continue\n\n if s[i] == \"}\":\n if not stack:\n i += 1\n continue\n j = stack.pop()\n # we've made it back out of the expression! yield the span\n if not stack:\n yield (j, i + 1)\n i += 1\n continue\n\n # if we're in an expression part of the f-string, fast-forward through strings\n # note that backslashes are not legal in the expression portion of f-strings\n if stack:\n delim = None\n if s[i : i + 3] in (\"'''\", '\"\"\"'):\n delim = s[i : i + 3]\n elif s[i] in (\"'\", '\"'):\n delim = s[i]\n if delim:\n i += len(delim)\n while i < len(s) and s[i : i + len(delim)] != delim:\n i += 1\n i += len(delim)\n continue\n i += 1\n\n\ndef fstring_contains_expr(s: str) -> bool:\n return any(iter_fexpr_spans(s))\n\n\ndef _toggle_fexpr_quotes(fstring: str, old_quote: str) -> str:\n \"\"\"\n Toggles quotes used in f-string expressions that are `old_quote`.\n\n f-string expressions can't contain backslashes, so we need to toggle the\n quotes if the f-string itself will end up using the same quote. We can\n simply toggle without escaping because, quotes can't be reused in f-string\n expressions. They will fail to parse.\n\n NOTE: If PEP 701 is accepted, above statement will no longer be true.\n Though if quotes can be reused, we can simply reuse them without updates or\n escaping, once Black figures out how to parse the new grammar.\n \"\"\"\n new_quote = \"'\" if old_quote == '\"' else '\"'\n parts = []\n previous_index = 0\n for start, end in iter_fexpr_spans(fstring):\n parts.append(fstring[previous_index:start])\n parts.append(fstring[start:end].replace(old_quote, new_quote))\n previous_index = end\n parts.append(fstring[previous_index:])\n return \"\".join(parts)\n\n\nclass StringSplitter(BaseStringSplitter, CustomSplitMapMixin):\n \"\"\"\n StringTransformer that splits \"atom\" strings (i.e. strings which exist on\n lines by themselves).\n\n Requirements:\n * The line consists ONLY of a single string (possibly prefixed by a\n string operator [e.g. '+' or '==']), MAYBE a string trailer, and MAYBE\n a trailing comma.\n AND\n * All of the requirements listed in BaseStringSplitter's docstring.\n\n Transformations:\n The string mentioned in the 'Requirements' section is split into as\n many substrings as necessary to adhere to the configured line length.\n\n In the final set of substrings, no substring should be smaller than\n MIN_SUBSTR_SIZE characters.\n\n The string will ONLY be split on spaces (i.e. each new substring should\n start with a space). Note that the string will NOT be split on a space\n which is escaped with a backslash.\n\n If the string is an f-string, it will NOT be split in the middle of an\n f-expression (e.g. in f\"FooBar: {foo() if x else bar()}\", {foo() if x\n else bar()} is an f-expression).\n\n If the string that is being split has an associated set of custom split\n records and those custom splits will NOT result in any line going over\n the configured line length, those custom splits are used. 
Otherwise the\n string is split as late as possible (from left-to-right) while still\n adhering to the transformation rules listed above.\n\n Collaborations:\n StringSplitter relies on StringMerger to construct the appropriate\n CustomSplit objects and add them to the custom split map.\n \"\"\"\n\n MIN_SUBSTR_SIZE: Final = 6\n\n def do_splitter_match(self, line: Line) -> TMatchResult:\n LL = line.leaves\n\n if self._prefer_paren_wrap_match(LL) is not None:\n return TErr(\"Line needs to be wrapped in parens first.\")\n\n is_valid_index = is_valid_index_factory(LL)\n\n idx = 0\n\n # The first two leaves MAY be the 'not in' keywords...\n if (\n is_valid_index(idx)\n and is_valid_index(idx + 1)\n and [LL[idx].type, LL[idx + 1].type] == [token.NAME, token.NAME]\n and str(LL[idx]) + str(LL[idx + 1]) == \"not in\"\n ):\n idx += 2\n # Else the first leaf MAY be a string operator symbol or the 'in' keyword...\n elif is_valid_index(idx) and (\n LL[idx].type in self.STRING_OPERATORS\n or LL[idx].type == token.NAME\n and str(LL[idx]) == \"in\"\n ):\n idx += 1\n\n # The next/first leaf MAY be an empty LPAR...\n if is_valid_index(idx) and is_empty_lpar(LL[idx]):\n idx += 1\n\n # The next/first leaf MUST be a string...\n if not is_valid_index(idx) or LL[idx].type != token.STRING:\n return TErr(\"Line does not start with a string.\")\n\n string_idx = idx\n\n # Skip the string trailer, if one exists.\n string_parser = StringParser()\n idx = string_parser.parse(LL, string_idx)\n\n # That string MAY be followed by an empty RPAR...\n if is_valid_index(idx) and is_empty_rpar(LL[idx]):\n idx += 1\n\n # That string / empty RPAR leaf MAY be followed by a comma...\n if is_valid_index(idx) and LL[idx].type == token.COMMA:\n idx += 1\n\n # But no more leaves are allowed...\n if is_valid_index(idx):\n return TErr(\"This line does not end with a string.\")\n\n return Ok([string_idx])\n\n def do_transform(\n self, line: Line, string_indices: List[int]\n ) -> Iterator[TResult[Line]]:\n LL = line.leaves\n assert len(string_indices) == 1, (\n f\"{self.__class__.__name__} should only find one match at a time, found\"\n f\" {len(string_indices)}\"\n )\n string_idx = string_indices[0]\n\n QUOTE = LL[string_idx].value[-1]\n\n is_valid_index = is_valid_index_factory(LL)\n insert_str_child = insert_str_child_factory(LL[string_idx])\n\n prefix = get_string_prefix(LL[string_idx].value).lower()\n\n # We MAY choose to drop the 'f' prefix from substrings that don't\n # contain any f-expressions, but ONLY if the original f-string\n # contains at least one f-expression. Otherwise, we will alter the AST\n # of the program.\n drop_pointless_f_prefix = (\"f\" in prefix) and fstring_contains_expr(\n LL[string_idx].value\n )\n\n first_string_line = True\n\n string_op_leaves = self._get_string_operator_leaves(LL)\n string_op_leaves_length = (\n sum(len(str(prefix_leaf)) for prefix_leaf in string_op_leaves) + 1\n if string_op_leaves\n else 0\n )\n\n def maybe_append_string_operators(new_line: Line) -> None:\n \"\"\"\n Side Effects:\n If @line starts with a string operator and this is the first\n line we are constructing, this function appends the string\n operator to @new_line and replaces the old string operator leaf\n in the node structure. 
Otherwise this function does nothing.\n \"\"\"\n maybe_prefix_leaves = string_op_leaves if first_string_line else []\n for i, prefix_leaf in enumerate(maybe_prefix_leaves):\n replace_child(LL[i], prefix_leaf)\n new_line.append(prefix_leaf)\n\n ends_with_comma = (\n is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA\n )\n\n def max_last_string_column() -> int:\n \"\"\"\n Returns:\n The max allowed width of the string value used for the last\n line we will construct. Note that this value means the width\n rather than the number of characters (e.g., many East Asian\n characters expand to two columns).\n \"\"\"\n result = self.line_length\n result -= line.depth * 4\n result -= 1 if ends_with_comma else 0\n result -= string_op_leaves_length\n return result\n\n # --- Calculate Max Break Width (for string value)\n # We start with the line length limit\n max_break_width = self.line_length\n # The last index of a string of length N is N-1.\n max_break_width -= 1\n # Leading whitespace is not present in the string value (e.g. Leaf.value).\n max_break_width -= line.depth * 4\n if max_break_width < 0:\n yield TErr(\n f\"Unable to split {LL[string_idx].value} at such high of a line depth:\"\n f\" {line.depth}\"\n )\n return\n\n # Check if StringMerger registered any custom splits.\n custom_splits = self.pop_custom_splits(LL[string_idx].value)\n # We use them ONLY if none of them would produce lines that exceed the\n # line limit.\n use_custom_breakpoints = bool(\n custom_splits\n and all(csplit.break_idx <= max_break_width for csplit in custom_splits)\n )\n\n # Temporary storage for the remaining chunk of the string line that\n # can't fit onto the line currently being constructed.\n rest_value = LL[string_idx].value\n\n def more_splits_should_be_made() -> bool:\n \"\"\"\n Returns:\n True iff `rest_value` (the remaining string value from the last\n split), should be split again.\n \"\"\"\n if use_custom_breakpoints:\n return len(custom_splits) > 1\n else:\n return str_width(rest_value) > max_last_string_column()\n\n string_line_results: List[Ok[Line]] = []\n while more_splits_should_be_made():\n if use_custom_breakpoints:\n # Custom User Split (manual)\n csplit = custom_splits.pop(0)\n break_idx = csplit.break_idx\n else:\n # Algorithmic Split (automatic)\n max_bidx = (\n count_chars_in_width(rest_value, max_break_width)\n - string_op_leaves_length\n )\n maybe_break_idx = self._get_break_idx(rest_value, max_bidx)\n if maybe_break_idx is None:\n # If we are unable to algorithmically determine a good split\n # and this string has custom splits registered to it, we\n # fall back to using them--which means we have to start\n # over from the beginning.\n if custom_splits:\n rest_value = LL[string_idx].value\n string_line_results = []\n first_string_line = True\n use_custom_breakpoints = True\n continue\n\n # Otherwise, we stop splitting here.\n break\n\n break_idx = maybe_break_idx\n\n # --- Construct `next_value`\n next_value = rest_value[:break_idx] + QUOTE\n\n # HACK: The following 'if' statement is a hack to fix the custom\n # breakpoint index in the case of either: (a) substrings that were\n # f-strings but will have the 'f' prefix removed OR (b) substrings\n # that were not f-strings but will now become f-strings because of\n # redundant use of the 'f' prefix (i.e. 
none of the substrings\n # contain f-expressions but one or more of them had the 'f' prefix\n # anyway; in which case, we will prepend 'f' to _all_ substrings).\n #\n # There is probably a better way to accomplish what is being done\n # here...\n #\n # If this substring is an f-string, we _could_ remove the 'f'\n # prefix, and the current custom split did NOT originally use a\n # prefix...\n if (\n use_custom_breakpoints\n and not csplit.has_prefix\n and (\n # `next_value == prefix + QUOTE` happens when the custom\n # split is an empty string.\n next_value == prefix + QUOTE\n or next_value != self._normalize_f_string(next_value, prefix)\n )\n ):\n # Then `csplit.break_idx` will be off by one after removing\n # the 'f' prefix.\n break_idx += 1\n next_value = rest_value[:break_idx] + QUOTE\n\n if drop_pointless_f_prefix:\n next_value = self._normalize_f_string(next_value, prefix)\n\n # --- Construct `next_leaf`\n next_leaf = Leaf(token.STRING, next_value)\n insert_str_child(next_leaf)\n self._maybe_normalize_string_quotes(next_leaf)\n\n # --- Construct `next_line`\n next_line = line.clone()\n maybe_append_string_operators(next_line)\n next_line.append(next_leaf)\n string_line_results.append(Ok(next_line))\n\n rest_value = prefix + QUOTE + rest_value[break_idx:]\n first_string_line = False\n\n yield from string_line_results\n\n if drop_pointless_f_prefix:\n rest_value = self._normalize_f_string(rest_value, prefix)\n\n rest_leaf = Leaf(token.STRING, rest_value)\n insert_str_child(rest_leaf)\n\n # NOTE: I could not find a test case that verifies that the following\n # line is actually necessary, but it seems to be. Otherwise we risk\n # not normalizing the last substring, right?\n self._maybe_normalize_string_quotes(rest_leaf)\n\n last_line = line.clone()\n maybe_append_string_operators(last_line)\n\n # If there are any leaves to the right of the target string...\n if is_valid_index(string_idx + 1):\n # We use `temp_value` here to determine how long the last line\n # would be if we were to append all the leaves to the right of the\n # target string to the last string line.\n temp_value = rest_value\n for leaf in LL[string_idx + 1 :]:\n temp_value += str(leaf)\n if leaf.type == token.LPAR:\n break\n\n # Try to fit them all on the same line with the last substring...\n if (\n str_width(temp_value) <= max_last_string_column()\n or LL[string_idx + 1].type == token.COMMA\n ):\n last_line.append(rest_leaf)\n append_leaves(last_line, line, LL[string_idx + 1 :])\n yield Ok(last_line)\n # Otherwise, place the last substring on one line and everything\n # else on a line below that...\n else:\n last_line.append(rest_leaf)\n yield Ok(last_line)\n\n non_string_line = line.clone()\n append_leaves(non_string_line, line, LL[string_idx + 1 :])\n yield Ok(non_string_line)\n # Else the target string was the last leaf...\n else:\n last_line.append(rest_leaf)\n last_line.comments = line.comments.copy()\n yield Ok(last_line)\n\n def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]:\n \"\"\"\n Yields:\n All ranges of @string which, if @string were to be split there,\n would result in the splitting of an \\\\N{...} expression (which is NOT\n allowed).\n \"\"\"\n # True - the previous backslash was unescaped\n # False - the previous backslash was escaped *or* there was no backslash\n previous_was_unescaped_backslash = False\n it = iter(enumerate(string))\n for idx, c in it:\n if c == \"\\\\\":\n previous_was_unescaped_backslash = not previous_was_unescaped_backslash\n continue\n if not 
previous_was_unescaped_backslash or c != \"N\":\n previous_was_unescaped_backslash = False\n continue\n previous_was_unescaped_backslash = False\n\n begin = idx - 1 # the position of backslash before \\N{...}\n for idx, c in it:\n if c == \"}\":\n end = idx\n break\n else:\n # malformed nameescape expression?\n # should have been detected by AST parsing earlier...\n raise RuntimeError(f\"{self.__class__.__name__} LOGIC ERROR!\")\n yield begin, end\n\n def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]:\n \"\"\"\n Yields:\n All ranges of @string which, if @string were to be split there,\n would result in the splitting of an f-expression (which is NOT\n allowed).\n \"\"\"\n if \"f\" not in get_string_prefix(string).lower():\n return\n yield from iter_fexpr_spans(string)\n\n def _get_illegal_split_indices(self, string: str) -> Set[Index]:\n illegal_indices: Set[Index] = set()\n iterators = [\n self._iter_fexpr_slices(string),\n self._iter_nameescape_slices(string),\n ]\n for it in iterators:\n for begin, end in it:\n illegal_indices.update(range(begin, end + 1))\n return illegal_indices\n\n def _get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:\n \"\"\"\n This method contains the algorithm that StringSplitter uses to\n determine which character to split each string at.\n\n Args:\n @string: The substring that we are attempting to split.\n @max_break_idx: The ideal break index. We will return this value if it\n meets all the necessary conditions. In the likely event that it\n doesn't we will try to find the closest index BELOW @max_break_idx\n that does. If that fails, we will expand our search by also\n considering all valid indices ABOVE @max_break_idx.\n\n Pre-Conditions:\n * assert_is_leaf_string(@string)\n * 0 <= @max_break_idx < len(@string)\n\n Returns:\n break_idx, if an index is able to be found that meets all of the\n conditions listed in the 'Transformations' section of this classes'\n docstring.\n OR\n None, otherwise.\n \"\"\"\n is_valid_index = is_valid_index_factory(string)\n\n assert is_valid_index(max_break_idx)\n assert_is_leaf_string(string)\n\n _illegal_split_indices = self._get_illegal_split_indices(string)\n\n def breaks_unsplittable_expression(i: Index) -> bool:\n \"\"\"\n Returns:\n True iff returning @i would result in the splitting of an\n unsplittable expression (which is NOT allowed).\n \"\"\"\n return i in _illegal_split_indices\n\n def passes_all_checks(i: Index) -> bool:\n \"\"\"\n Returns:\n True iff ALL of the conditions listed in the 'Transformations'\n section of this classes' docstring would be met by returning @i.\n \"\"\"\n is_space = string[i] == \" \"\n is_split_safe = is_valid_index(i - 1) and string[i - 1] in SPLIT_SAFE_CHARS\n\n is_not_escaped = True\n j = i - 1\n while is_valid_index(j) and string[j] == \"\\\\\":\n is_not_escaped = not is_not_escaped\n j -= 1\n\n is_big_enough = (\n len(string[i:]) >= self.MIN_SUBSTR_SIZE\n and len(string[:i]) >= self.MIN_SUBSTR_SIZE\n )\n return (\n (is_space or is_split_safe)\n and is_not_escaped\n and is_big_enough\n and not breaks_unsplittable_expression(i)\n )\n\n # First, we check all indices BELOW @max_break_idx.\n break_idx = max_break_idx\n while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx):\n break_idx -= 1\n\n if not passes_all_checks(break_idx):\n # If that fails, we check all indices ABOVE @max_break_idx.\n #\n # If we are able to find a valid index here, the next line is going\n # to be longer than the specified line length, but it's 
probably\n # better than doing nothing at all.\n break_idx = max_break_idx + 1\n while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx):\n break_idx += 1\n\n if not is_valid_index(break_idx) or not passes_all_checks(break_idx):\n return None\n\n return break_idx\n\n def _maybe_normalize_string_quotes(self, leaf: Leaf) -> None:\n if self.normalize_strings:\n leaf.value = normalize_string_quotes(leaf.value)\n\n def _normalize_f_string(self, string: str, prefix: str) -> str:\n \"\"\"\n Pre-Conditions:\n * assert_is_leaf_string(@string)\n\n Returns:\n * If @string is an f-string that contains no f-expressions, we\n return a string identical to @string except that the 'f' prefix\n has been stripped and all double braces (i.e. '{{' or '}}') have\n been normalized (i.e. turned into '{' or '}').\n OR\n * Otherwise, we return @string.\n \"\"\"\n assert_is_leaf_string(string)\n\n if \"f\" in prefix and not fstring_contains_expr(string):\n new_prefix = prefix.replace(\"f\", \"\")\n\n temp = string[len(prefix) :]\n temp = re.sub(r\"\\{\\{\", \"{\", temp)\n temp = re.sub(r\"\\}\\}\", \"}\", temp)\n new_string = temp\n\n return f\"{new_prefix}{new_string}\"\n else:\n return string\n\n def _get_string_operator_leaves(self, leaves: Iterable[Leaf]) -> List[Leaf]:\n LL = list(leaves)\n\n string_op_leaves = []\n i = 0\n while LL[i].type in self.STRING_OPERATORS + [token.NAME]:\n prefix_leaf = Leaf(LL[i].type, str(LL[i]).strip())\n string_op_leaves.append(prefix_leaf)\n i += 1\n return string_op_leaves\n\n\nclass StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):\n \"\"\"\n StringTransformer that wraps strings in parens and then splits at the LPAR.\n\n Requirements:\n All of the requirements listed in BaseStringSplitter's docstring in\n addition to the requirements listed below:\n\n * The line is a return/yield statement, which returns/yields a string.\n OR\n * The line is part of a ternary expression (e.g. `x = y if cond else\n z`) such that the line starts with `else `, where is\n some string.\n OR\n * The line is an assert statement, which ends with a string.\n OR\n * The line is an assignment statement (e.g. `x = ` or `x +=\n `) such that the variable is being assigned the value of some\n string.\n OR\n * The line is a dictionary key assignment where some valid key is being\n assigned the value of some string.\n OR\n * The line is an lambda expression and the value is a string.\n OR\n * The line starts with an \"atom\" string that prefers to be wrapped in\n parens. It's preferred to be wrapped when it's is an immediate child of\n a list/set/tuple literal, AND the string is surrounded by commas (or is\n the first/last child).\n\n Transformations:\n The chosen string is wrapped in parentheses and then split at the LPAR.\n\n We then have one line which ends with an LPAR and another line that\n starts with the chosen string. The latter line is then split again at\n the RPAR. This results in the RPAR (and possibly a trailing comma)\n being placed on its own line.\n\n NOTE: If any leaves exist to the right of the chosen string (except\n for a trailing comma, which would be placed after the RPAR), those\n leaves are placed inside the parentheses. In effect, the chosen\n string is not necessarily being \"wrapped\" by parentheses. We can,\n however, count on the LPAR being placed directly before the chosen\n string.\n\n In other words, StringParenWrapper creates \"atom\" strings. 
These\n can then be split again by StringSplitter, if necessary.\n\n Collaborations:\n In the event that a string line split by StringParenWrapper is\n changed such that it no longer needs to be given its own line,\n StringParenWrapper relies on StringParenStripper to clean up the\n parentheses it created.\n\n For \"atom\" strings that prefers to be wrapped in parens, it requires\n StringSplitter to hold the split until the string is wrapped in parens.\n \"\"\"\n\n def do_splitter_match(self, line: Line) -> TMatchResult:\n LL = line.leaves\n\n if line.leaves[-1].type in OPENING_BRACKETS:\n return TErr(\n \"Cannot wrap parens around a line that ends in an opening bracket.\"\n )\n\n string_idx = (\n self._return_match(LL)\n or self._else_match(LL)\n or self._assert_match(LL)\n or self._assign_match(LL)\n or self._dict_or_lambda_match(LL)\n or self._prefer_paren_wrap_match(LL)\n )\n\n if string_idx is not None:\n string_value = line.leaves[string_idx].value\n # If the string has neither spaces nor East Asian stops...\n if not any(\n char == \" \" or char in SPLIT_SAFE_CHARS for char in string_value\n ):\n # And will still violate the line length limit when split...\n max_string_width = self.line_length - ((line.depth + 1) * 4)\n if str_width(string_value) > max_string_width:\n # And has no associated custom splits...\n if not self.has_custom_splits(string_value):\n # Then we should NOT put this string on its own line.\n return TErr(\n \"We do not wrap long strings in parentheses when the\"\n \" resultant line would still be over the specified line\"\n \" length and can't be split further by StringSplitter.\"\n )\n return Ok([string_idx])\n\n return TErr(\"This line does not contain any non-atomic strings.\")\n\n @staticmethod\n def _return_match(LL: List[Leaf]) -> Optional[int]:\n \"\"\"\n Returns:\n string_idx such that @LL[string_idx] is equal to our target (i.e.\n matched) string, if this line matches the return/yield statement\n requirements listed in the 'Requirements' section of this classes'\n docstring.\n OR\n None, otherwise.\n \"\"\"\n # If this line is a part of a return/yield statement and the first leaf\n # contains either the \"return\" or \"yield\" keywords...\n if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[\n 0\n ].value in [\"return\", \"yield\"]:\n...\n", "needle_token_start": 508, "needle_token_end": 1066, "code_context_ntokens": 16122, "output": ["```python\ndef _merge_string_group(\n self, line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merges string groups (i.e. set of adjacent strings).\n\n Each index from `string_indices` designates one string group's first\n leaf in `line.leaves`.\n\n Returns:\n Ok(new_line), if ALL of the validation checks found in\n _validate_msg(...) 
pass.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n # A dict of {string_idx: tuple[num_of_strings, string_leaf]}.\n merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}\n for string_idx in string_indices:\n vresult = self._validate_msg(line, string_idx)\n if isinstance(vresult, Err):\n continue\n merged_string_idx_dict[string_idx] = self._merge_one_string_group(\n LL, string_idx, is_valid_index\n )\n\n if not merged_string_idx_dict:\n return TErr(\"No string group is merged\")\n\n # Build the final line ('new_line') that this method will later return.\n new_line = line.clone()\n previous_merged_string_idx = -1\n previous_merged_num_of_strings = -1\n for i, leaf in enumerate(LL):\n if i in merged_string_idx_dict:\n previous_merged_string_idx = i\n previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]\n new_line.append(string_leaf)\n\n if (\n previous_merged_string_idx\n <= i\n < previous_merged_string_idx + previous_merged_num_of_strings\n ):\n for comment_leaf in line.comments_after(LL[i]):\n new_line.append(comment_leaf, preformatted=True)\n continue\n\n append_leaves(new_line, line, [leaf])\n\n return Ok(new_line)\n```"]} +{"repo": "psf/black", "name": "convert_unchanged_lines", "language": "python", "path": "src/black/ranges.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to transform specific lines of code that have not been modified into a special comment format that preserves their original formatting during further processing. This is particularly useful in scenarios where selective formatting is required, ensuring that unchanged parts of the code remain visually consistent and untouched.\n2. **Input**: The function takes two parameters: a node representing the parsed source code and a collection of line ranges that specify which lines have not been altered.\n3. **Output**: There is no direct output from this function as it modifies the input node in place. The changes involve converting specified lines into a format that will not be altered in subsequent formatting steps.\n4. **Procedure**: The function operates in two main phases:\n - First, it identifies top-level statements within the unchanged lines and converts these blocks into a special comment type that preserves their formatting.\n - Second, it processes individual unchanged lines within the node, converting them into the same special comment type. 
This includes normalizing comment prefixes and indentations to ensure consistency, even though the content itself remains unchanged.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/parsing.py\n\"\"\"\nParse Python code and perform AST validation.\n\"\"\"\n\nimport ast\nimport sys\nimport warnings\nfrom typing import Iterable, Iterator, List, Set, Tuple\n\nfrom black.mode import VERSION_TO_FEATURES, Feature, TargetVersion, supports_feature\nfrom black.nodes import syms\nfrom blib2to3 import pygram\nfrom blib2to3.pgen2 import driver\nfrom blib2to3.pgen2.grammar import Grammar\nfrom blib2to3.pgen2.parse import ParseError\nfrom blib2to3.pgen2.tokenize import TokenError\nfrom blib2to3.pytree import Leaf, Node\n\n\nclass InvalidInput(ValueError):\n \"\"\"Raised when input source code fails all parse attempts.\"\"\"\n\n\ndef get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:\n if not target_versions:\n # No target_version specified, so try all grammars.\n return [\n # Python 3.7-3.9\n pygram.python_grammar_async_keywords,\n # Python 3.0-3.6\n pygram.python_grammar,\n # Python 3.10+\n pygram.python_grammar_soft_keywords,\n ]\n\n grammars = []\n # If we have to parse both, try to parse async as a keyword first\n if not supports_feature(\n target_versions, Feature.ASYNC_IDENTIFIERS\n ) and not supports_feature(target_versions, Feature.PATTERN_MATCHING):\n # Python 3.7-3.9\n grammars.append(pygram.python_grammar_async_keywords)\n if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):\n # Python 3.0-3.6\n grammars.append(pygram.python_grammar)\n if any(Feature.PATTERN_MATCHING in VERSION_TO_FEATURES[v] for v in target_versions):\n # Python 3.10+\n grammars.append(pygram.python_grammar_soft_keywords)\n\n # At least one of the above branches must have been taken, because every Python\n # version has exactly one of the two 'ASYNC_*' flags\n return grammars\n\n\ndef lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:\n \"\"\"Given a string with source, return the lib2to3 Node.\"\"\"\n if not src_txt.endswith(\"\\n\"):\n src_txt += \"\\n\"\n\n grammars = get_grammars(set(target_versions))\n errors = {}\n for grammar in grammars:\n drv = driver.Driver(grammar)\n try:\n result = drv.parse_string(src_txt, True)\n break\n\n except ParseError as pe:\n lineno, column = pe.context[1]\n lines = src_txt.splitlines()\n try:\n faulty_line = lines[lineno - 1]\n except IndexError:\n faulty_line = \"\"\n errors[grammar.version] = InvalidInput(\n f\"Cannot parse: {lineno}:{column}: {faulty_line}\"\n )\n\n except TokenError as te:\n # In edge cases these are raised; and typically don't have a \"faulty_line\".\n lineno, column = te.args[1]\n errors[grammar.version] = InvalidInput(\n f\"Cannot parse: {lineno}:{column}: {te.args[0]}\"\n )\n\n else:\n # Choose the latest version when raising the actual parsing error.\n assert len(errors) >= 1\n...\n# Path: src/black/ranges.py\n\"\"\"Functions related to Black's formatting by line ranges feature.\"\"\"\n\nimport difflib\nfrom dataclasses import dataclass\nfrom typing import Collection, Iterator, List, Sequence, Set, Tuple, Union\n\nfrom black.nodes import (\n LN,\n STANDALONE_COMMENT,\n Leaf,\n Node,\n Visitor,\n first_leaf,\n furthest_ancestor_with_last_leaf,\n last_leaf,\n 
syms,\n)\nfrom blib2to3.pgen2.token import ASYNC, NEWLINE\n\n\ndef parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]:\n lines: List[Tuple[int, int]] = []\n for lines_str in line_ranges:\n parts = lines_str.split(\"-\")\n if len(parts) != 2:\n raise ValueError(\n \"Incorrect --line-ranges format, expect 'START-END', found\"\n f\" {lines_str!r}\"\n )\n try:\n start = int(parts[0])\n end = int(parts[1])\n except ValueError:\n raise ValueError(\n \"Incorrect --line-ranges value, expect integer ranges, found\"\n f\" {lines_str!r}\"\n ) from None\n else:\n lines.append((start, end))\n return lines\n\n\ndef is_valid_line_range(lines: Tuple[int, int]) -> bool:\n \"\"\"Returns whether the line range is valid.\"\"\"\n return not lines or lines[0] <= lines[1]\n\n\ndef adjusted_lines(\n lines: Collection[Tuple[int, int]],\n original_source: str,\n modified_source: str,\n) -> List[Tuple[int, int]]:\n \"\"\"Returns the adjusted line ranges based on edits from the original code.\n\n This computes the new line ranges by diffing original_source and\n modified_source, and adjust each range based on how the range overlaps with\n the diffs.\n\n Note the diff can contain lines outside of the original line ranges. This can\n happen when the formatting has to be done in adjacent to maintain consistent\n local results. For example:\n\n 1. def my_func(arg1, arg2,\n 2. arg3,):\n 3. pass\n\n If it restricts to line 2-2, it can't simply reformat line 2, it also has\n to reformat line 1:\n\n 1. def my_func(\n 2. arg1,\n 3. arg2,\n 4. arg3,\n 5. ):\n 6. pass\n\n In this case, we will expand the line ranges to also include the whole diff\n block.\n\n Args:\n lines: a collection of line ranges.\n original_source: the original source.\n modified_source: the modified source.\n \"\"\"\n lines_mappings = _calculate_lines_mappings(original_source, modified_source)\n\n new_lines = []\n # Keep an index of the current search. Since the lines and lines_mappings are\n # sorted, this makes the search complexity linear.\n current_mapping_index = 0\n for start, end in sorted(lines):\n start_mapping_index = _find_lines_mapping_index(\n start,\n lines_mappings,\n current_mapping_index,\n )\n end_mapping_index = _find_lines_mapping_index(\n end,\n lines_mappings,\n start_mapping_index,\n )\n current_mapping_index = start_mapping_index\n if start_mapping_index >= len(lines_mappings) or end_mapping_index >= len(\n lines_mappings\n ):\n # Protect against invalid inputs.\n continue\n start_mapping = lines_mappings[start_mapping_index]\n end_mapping = lines_mappings[end_mapping_index]\n if start_mapping.is_changed_block:\n # When the line falls into a changed block, expands to the whole block.\n new_start = start_mapping.modified_start\n else:\n new_start = (\n start - start_mapping.original_start + start_mapping.modified_start\n )\n if end_mapping.is_changed_block:\n # When the line falls into a changed block, expands to the whole block.\n new_end = end_mapping.modified_end\n else:\n new_end = end - end_mapping.original_start + end_mapping.modified_start\n new_range = (new_start, new_end)\n if is_valid_line_range(new_range):\n new_lines.append(new_range)\n return new_lines\n\n\n\ndef convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None:\n \"\"\"Converts unchanged lines to STANDALONE_COMMENT.\n\n The idea is similar to how `# fmt: on/off` is implemented. It also converts the\n nodes between those markers as a single `STANDALONE_COMMENT` leaf node with\n the unformatted code as its value. 
`STANDALONE_COMMENT` is a \"fake\" token\n that will be formatted as-is with its prefix normalized.\n\n Here we perform two passes:\n\n 1. Visit the top-level statements, and convert them to a single\n `STANDALONE_COMMENT` when unchanged. This speeds up formatting when some\n of the top-level statements aren't changed.\n 2. Convert unchanged \"unwrapped lines\" to `STANDALONE_COMMENT` nodes line by\n line. \"unwrapped lines\" are divided by the `NEWLINE` token. e.g. a\n multi-line statement is *one* \"unwrapped line\" that ends with `NEWLINE`,\n even though this statement itself can span multiple lines, and the\n tokenizer only sees the last '\\n' as the `NEWLINE` token.\n\n NOTE: During pass (2), comment prefixes and indentations are ALWAYS\n normalized even when the lines aren't changed. This is fixable by moving\n more formatting to pass (1). However, it's hard to get it correct when\n incorrect indentations are used. So we defer this to future optimizations.\n \"\"\"\n lines_set: Set[int] = set()\n for start, end in lines:\n lines_set.update(range(start, end + 1))\n visitor = _TopLevelStatementsVisitor(lines_set)\n _ = list(visitor.visit(src_node)) # Consume all results.\n _convert_unchanged_line_by_line(src_node, lines_set)\n\n\ndef _contains_standalone_comment(node: LN) -> bool:\n if isinstance(node, Leaf):\n return node.type == STANDALONE_COMMENT\n else:\n for child in node.children:\n if _contains_standalone_comment(child):\n return True\n return False\n\n\nclass _TopLevelStatementsVisitor(Visitor[None]):\n \"\"\"\n A node visitor that converts unchanged top-level statements to\n STANDALONE_COMMENT.\n\n This is used in addition to _convert_unchanged_line_by_line, to\n speed up formatting when there are unchanged top-level\n classes/functions/statements.\n \"\"\"\n\n def __init__(self, lines_set: Set[int]):\n self._lines_set = lines_set\n\n def visit_simple_stmt(self, node: Node) -> Iterator[None]:\n # This is only called for top-level statements, since `visit_suite`\n # won't visit its children nodes.\n yield from []\n newline_leaf = last_leaf(node)\n if not newline_leaf:\n return\n assert (\n newline_leaf.type == NEWLINE\n ), f\"Unexpectedly found leaf.type={newline_leaf.type}\"\n # We need to find the furthest ancestor with the NEWLINE as the last\n # leaf, since a `suite` can simply be a `simple_stmt` when it puts\n # its body on the same line. Example: `if cond: pass`.\n ancestor = furthest_ancestor_with_last_leaf(newline_leaf)\n if not _get_line_range(ancestor).intersection(self._lines_set):\n _convert_node_to_standalone_comment(ancestor)\n\n def visit_suite(self, node: Node) -> Iterator[None]:\n yield from []\n # If there is a STANDALONE_COMMENT node, it means parts of the node tree\n # have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't\n # be simply converted by calling str(node). So we just don't convert\n # here.\n if _contains_standalone_comment(node):\n return\n # Find the semantic parent of this suite. 
For `async_stmt` and\n # `async_funcdef`, the ASYNC token is defined on a separate level by the\n # grammar.\n semantic_parent = node.parent\n if semantic_parent is not None:\n if (\n semantic_parent.prev_sibling is not None\n and semantic_parent.prev_sibling.type == ASYNC\n ):\n semantic_parent = semantic_parent.parent\n if semantic_parent is not None and not _get_line_range(\n semantic_parent\n ).intersection(self._lines_set):\n _convert_node_to_standalone_comment(semantic_parent)\n\n\ndef _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None:\n \"\"\"Converts unchanged to STANDALONE_COMMENT line by line.\"\"\"\n for leaf in node.leaves():\n if leaf.type != NEWLINE:\n # We only consider \"unwrapped lines\", which are divided by the NEWLINE\n # token.\n continue\n if leaf.parent and leaf.parent.type == syms.match_stmt:\n # The `suite` node is defined as:\n # match_stmt: \"match\" subject_expr ':' NEWLINE INDENT case_block+ DEDENT\n # Here we need to check `subject_expr`. The `case_block+` will be\n # checked by their own NEWLINEs.\n nodes_to_ignore: List[LN] = []\n prev_sibling = leaf.prev_sibling\n while prev_sibling:\n nodes_to_ignore.insert(0, prev_sibling)\n prev_sibling = prev_sibling.prev_sibling\n if not _get_line_range(nodes_to_ignore).intersection(lines_set):\n _convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)\n elif leaf.parent and leaf.parent.type == syms.suite:\n # The `suite` node is defined as:\n # suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT\n # We will check `simple_stmt` and `stmt+` separately against the lines set\n parent_sibling = leaf.parent.prev_sibling\n nodes_to_ignore = []\n while parent_sibling and not parent_sibling.type == syms.suite:\n # NOTE: Multiple suite nodes can exist as siblings in e.g. `if_stmt`.\n nodes_to_ignore.insert(0, parent_sibling)\n parent_sibling = parent_sibling.prev_sibling\n # Special case for `async_stmt` and `async_funcdef` where the ASYNC\n # token is on the grandparent node.\n grandparent = leaf.parent.parent\n if (\n grandparent is not None\n and grandparent.prev_sibling is not None\n and grandparent.prev_sibling.type == ASYNC\n ):\n nodes_to_ignore.insert(0, grandparent.prev_sibling)\n if not _get_line_range(nodes_to_ignore).intersection(lines_set):\n _convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)\n else:\n ancestor = furthest_ancestor_with_last_leaf(leaf)\n # Consider multiple decorators as a whole block, as their\n # newlines have different behaviors than the rest of the grammar.\n if (\n ancestor.type == syms.decorator\n and ancestor.parent\n and ancestor.parent.type == syms.decorators\n ):\n ancestor = ancestor.parent\n if not _get_line_range(ancestor).intersection(lines_set):\n _convert_node_to_standalone_comment(ancestor)\n\n\ndef _convert_node_to_standalone_comment(node: LN) -> None:\n \"\"\"Convert node to STANDALONE_COMMENT by modifying the tree inline.\"\"\"\n parent = node.parent\n if not parent:\n return\n first = first_leaf(node)\n last = last_leaf(node)\n if not first or not last:\n return\n if first is last:\n # This can happen on the following edge cases:\n # 1. A block of `# fmt: off/on` code except the `# fmt: on` is placed\n # on the end of the last line instead of on a new line.\n # 2. A single backslash on its own line followed by a comment line.\n # Ideally we don't want to format them when not requested, but fixing\n # isn't easy. 
These cases are also badly formatted code, so it isn't\n # too bad we reformat them.\n return\n # The prefix contains comments and indentation whitespaces. They are\n # reformatted accordingly to the correct indentation level.\n # This also means the indentation will be changed on the unchanged lines, and\n # this is actually required to not break incremental reformatting.\n prefix = first.prefix\n first.prefix = \"\"\n index = node.remove()\n if index is not None:\n # Remove the '\\n', as STANDALONE_COMMENT will have '\\n' appended when\n # generating the formatted code.\n value = str(node)[:-1]\n parent.insert_child(\n index,\n Leaf(\n STANDALONE_COMMENT,\n value,\n prefix=prefix,\n fmt_pass_converted_first_leaf=first,\n ),\n )\n\n\ndef _convert_nodes_to_standalone_comment(nodes: Sequence[LN], *, newline: Leaf) -> None:\n \"\"\"Convert nodes to STANDALONE_COMMENT by modifying the tree inline.\"\"\"\n if not nodes:\n return\n parent = nodes[0].parent\n first = first_leaf(nodes[0])\n if not parent or not first:\n return\n prefix = first.prefix\n first.prefix = \"\"\n value = \"\".join(str(node) for node in nodes)\n # The prefix comment on the NEWLINE leaf is the trailing comment of the statement.\n if newline.prefix:\n value += newline.prefix\n newline.prefix = \"\"\n index = nodes[0].remove()\n for node in nodes[1:]:\n node.remove()\n if index is not None:\n parent.insert_child(\n index,\n Leaf(\n STANDALONE_COMMENT,\n value,\n prefix=prefix,\n fmt_pass_converted_first_leaf=first,\n ),\n )\n\n\ndef _leaf_line_end(leaf: Leaf) -> int:\n \"\"\"Returns the line number of the leaf node's last line.\"\"\"\n if leaf.type == NEWLINE:\n return leaf.lineno\n else:\n # Leaf nodes like multiline strings can occupy multiple lines.\n return leaf.lineno + str(leaf).count(\"\\n\")\n\n\ndef _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]:\n \"\"\"Returns the line range of this node or list of nodes.\"\"\"\n if isinstance(node_or_nodes, list):\n nodes = node_or_nodes\n if not nodes:\n return set()\n first = first_leaf(nodes[0])\n last = last_leaf(nodes[-1])\n if first and last:\n line_start = first.lineno\n line_end = _leaf_line_end(last)\n return set(range(line_start, line_end + 1))\n else:\n return set()\n else:\n node = node_or_nodes\n if isinstance(node, Leaf):\n return set(range(node.lineno, _leaf_line_end(node) + 1))\n else:\n first = first_leaf(node)\n last = last_leaf(node)\n if first and last:\n return set(range(first.lineno, _leaf_line_end(last) + 1))\n else:\n return set()\n\n\n@dataclass\nclass _LinesMapping:\n \"\"\"1-based lines mapping from original source to modified source.\n\n Lines [original_start, original_end] from original source\n are mapped to [modified_start, modified_end].\n\n The ranges are inclusive on both ends.\n \"\"\"\n\n original_start: int\n original_end: int\n modified_start: int\n modified_end: int\n # Whether this range corresponds to a changed block, or an unchanged block.\n is_changed_block: bool\n\n\ndef _calculate_lines_mappings(\n original_source: str,\n modified_source: str,\n) -> Sequence[_LinesMapping]:\n \"\"\"Returns a sequence of _LinesMapping by diffing the sources.\n\n For example, given the following diff:\n import re\n - def func(arg1,\n - arg2, arg3):\n + def func(arg1, arg2, arg3):\n pass\n It returns the following mappings:\n original -> modified\n (1, 1) -> (1, 1), is_changed_block=False (the \"import re\" line)\n (2, 3) -> (2, 2), is_changed_block=True (the diff)\n (4, 4) -> (3, 3), is_changed_block=False (the \"pass\" line)\n\n You 
can think of this visually as if it brings up a side-by-side diff, and tries\n to map the line ranges from the left side to the right side:\n\n (1, 1)->(1, 1) 1. import re 1. import re\n (2, 3)->(2, 2) 2. def func(arg1, 2. def func(arg1, arg2, arg3):\n 3. arg2, arg3):\n (4, 4)->(3, 3) 4. pass 3. pass\n\n Args:\n original_source: the original source.\n modified_source: the modified source.\n \"\"\"\n matcher = difflib.SequenceMatcher(\n None,\n original_source.splitlines(keepends=True),\n modified_source.splitlines(keepends=True),\n )\n matching_blocks = matcher.get_matching_blocks()\n lines_mappings: List[_LinesMapping] = []\n # matching_blocks is a sequence of \"same block of code ranges\", see\n # https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks\n # Each block corresponds to a _LinesMapping with is_changed_block=False,\n # and the ranges between two blocks corresponds to a _LinesMapping with\n # is_changed_block=True,\n # NOTE: matching_blocks is 0-based, but _LinesMapping is 1-based.\n for i, block in enumerate(matching_blocks):\n if i == 0:\n if block.a != 0 or block.b != 0:\n lines_mappings.append(\n _LinesMapping(\n original_start=1,\n original_end=block.a,\n modified_start=1,\n modified_end=block.b,\n is_changed_block=False,\n )\n )\n else:\n previous_block = matching_blocks[i - 1]\n lines_mappings.append(\n _LinesMapping(\n original_start=previous_block.a + previous_block.size + 1,\n original_end=block.a,\n modified_start=previous_block.b + previous_block.size + 1,\n modified_end=block.b,\n is_changed_block=True,\n )\n )\n if i < len(matching_blocks) - 1:\n lines_mappings.append(\n _LinesMapping(\n original_start=block.a + 1,\n original_end=block.a + block.size,\n modified_start=block.b + 1,\n modified_end=block.b + block.size,\n is_changed_block=False,\n )\n )\n return lines_mappings\n\n\ndef _find_lines_mapping_index(\n original_line: int,\n lines_mappings: Sequence[_LinesMapping],\n start_index: int,\n) -> int:\n \"\"\"Returns the original index of the lines mappings for the original line.\"\"\"\n index = start_index\n while index < len(lines_mappings):\n mapping = lines_mappings[index]\n if mapping.original_start <= original_line <= mapping.original_end:\n return index\n index += 1\n return index\n\n# Path: src/black/__init__.py\nimport io\nimport json\nimport platform\nimport re\nimport sys\nimport tokenize\nimport traceback\nfrom contextlib import contextmanager\nfrom dataclasses import replace\nfrom datetime import datetime, timezone\nfrom enum import Enum\nfrom json.decoder import JSONDecodeError\nfrom pathlib import Path\nfrom typing import (\n Any,\n Collection,\n Dict,\n Generator,\n Iterator,\n List,\n MutableMapping,\n Optional,\n Pattern,\n Sequence,\n Set,\n Sized,\n Tuple,\n Union,\n)\n\nimport click\nfrom click.core import ParameterSource\nfrom mypy_extensions import mypyc_attr\nfrom pathspec import PathSpec\nfrom pathspec.patterns.gitwildmatch import GitWildMatchPatternError\n\nfrom _black_version import version as __version__\nfrom black.cache import Cache\nfrom black.comments import normalize_fmt_off\nfrom black.const import (\n DEFAULT_EXCLUDES,\n DEFAULT_INCLUDES,\n DEFAULT_LINE_LENGTH,\n STDIN_PLACEHOLDER,\n)\nfrom black.files import (\n best_effort_relative_path,\n find_project_root,\n find_pyproject_toml,\n find_user_pyproject_toml,\n gen_python_files,\n get_gitignore,\n parse_pyproject_toml,\n path_is_excluded,\n resolves_outside_root_or_cannot_stat,\n wrap_stream_for_windows,\n)\nfrom black.handle_ipynb_magics 
import (\n PYTHON_CELL_MAGICS,\n TRANSFORMED_MAGICS,\n jupyter_dependencies_are_installed,\n mask_cell,\n put_trailing_semicolon_back,\n remove_trailing_semicolon,\n unmask_cell,\n)\nfrom black.linegen import LN, LineGenerator, transform_line\nfrom black.lines import EmptyLineTracker, LinesBlock\nfrom black.mode import FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, Feature\nfrom black.mode import Mode as Mode # re-exported\nfrom black.mode import Preview, TargetVersion, supports_feature\nfrom black.nodes import (\n STARS,\n is_number_token,\n is_simple_decorator_expression,\n is_string_token,\n syms,\n)\nfrom black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out\nfrom black.parsing import InvalidInput # noqa F401\nfrom black.parsing import lib2to3_parse, parse_ast, stringify_ast\nfrom black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges\nfrom black.report import Changed, NothingChanged, Report\nfrom black.trans import iter_fexpr_spans\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\nCOMPILED = Path(__file__).suffix in (\".pyd\", \".so\")\n\n# types\nFileContent = str\nEncoding = str\nNewLine = str\n\n\nclass WriteBack(Enum):\n NO = 0\n YES = 1\n DIFF = 2\n CHECK = 3\n COLOR_DIFF = 4\n\n @classmethod\n def from_configuration(\n cls, *, check: bool, diff: bool, color: bool = False\n ) -> \"WriteBack\":\n if check and not diff:\n return cls.CHECK\n\n if diff and color:\n return cls.COLOR_DIFF\n\n return cls.DIFF if diff else cls.YES\n\n\n# Legacy name, left for integrations.\nFileMode = Mode\n\n\ndef read_pyproject_toml(\n ctx: click.Context, param: click.Parameter, value: Optional[str]\n) -> Optional[str]:\n \"\"\"Inject Black configuration from \"pyproject.toml\" into defaults in `ctx`.\n\n Returns the path to a successfully found and read configuration file, None\n otherwise.\n \"\"\"\n if not value:\n value = find_pyproject_toml(\n ctx.params.get(\"src\", ()), ctx.params.get(\"stdin_filename\", None)\n )\n if value is None:\n return None\n\n try:\n config = parse_pyproject_toml(value)\n except (OSError, ValueError) as e:\n raise click.FileError(\n filename=value, hint=f\"Error reading configuration file: {e}\"\n ) from None\n\n if not config:\n return None\n else:\n spellcheck_pyproject_toml_keys(ctx, list(config), value)\n # Sanitize the values to be Click friendly. 
For more information please see:\n # https://github.com/psf/black/issues/1458\n # https://github.com/pallets/click/issues/1567\n config = {\n k: str(v) if not isinstance(v, (list, dict)) else v\n for k, v in config.items()\n }\n\n target_version = config.get(\"target_version\")\n if target_version is not None and not isinstance(target_version, list):\n raise click.BadOptionUsage(\n \"target-version\", \"Config key target-version must be a list\"\n )\n\n exclude = config.get(\"exclude\")\n if exclude is not None and not isinstance(exclude, str):\n raise click.BadOptionUsage(\"exclude\", \"Config key exclude must be a string\")\n\n extend_exclude = config.get(\"extend_exclude\")\n if extend_exclude is not None and not isinstance(extend_exclude, str):\n raise click.BadOptionUsage(\n \"extend-exclude\", \"Config key extend-exclude must be a string\"\n )\n\n line_ranges = config.get(\"line_ranges\")\n if line_ranges is not None:\n raise click.BadOptionUsage(\n \"line-ranges\", \"Cannot use line-ranges in the pyproject.toml file.\"\n )\n\n default_map: Dict[str, Any] = {}\n if ctx.default_map:\n default_map.update(ctx.default_map)\n default_map.update(config)\n\n ctx.default_map = default_map\n return value\n\n\ndef spellcheck_pyproject_toml_keys(\n ctx: click.Context, config_keys: List[str], config_file_path: str\n) -> None:\n invalid_keys: List[str] = []\n available_config_options = {param.name for param in ctx.command.params}\n for key in config_keys:\n if key not in available_config_options:\n invalid_keys.append(key)\n if invalid_keys:\n keys_str = \", \".join(map(repr, invalid_keys))\n out(\n f\"Invalid config keys detected: {keys_str} (in {config_file_path})\",\n fg=\"red\",\n )\n\n\ndef target_version_option_callback(\n c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]\n) -> List[TargetVersion]:\n \"\"\"Compute the target versions from a --target-version flag.\n\n This is its own function because mypy couldn't infer the type correctly\n when it was a lambda, causing mypyc trouble.\n \"\"\"\n return [TargetVersion[val.upper()] for val in v]\n\n\ndef enable_unstable_feature_callback(\n c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]\n) -> List[Preview]:\n \"\"\"Compute the features from an --enable-unstable-feature flag.\"\"\"\n return [Preview[val] for val in v]\n\n\ndef re_compile_maybe_verbose(regex: str) -> Pattern[str]:\n \"\"\"Compile a regular expression string in `regex`.\n\n If it contains newlines, use verbose mode.\n \"\"\"\n if \"\\n\" in regex:\n regex = \"(?x)\" + regex\n compiled: Pattern[str] = re.compile(regex)\n return compiled\n\n\ndef validate_regex(\n ctx: click.Context,\n param: click.Parameter,\n value: Optional[str],\n) -> Optional[Pattern[str]]:\n try:\n return re_compile_maybe_verbose(value) if value is not None else None\n except re.error as e:\n raise click.BadParameter(f\"Not a valid regular expression: {e}\") from None\n\n\n@click.command(\n context_settings={\"help_option_names\": [\"-h\", \"--help\"]},\n # While Click does set this field automatically using the docstring, mypyc\n # (annoyingly) strips 'em so we need to set it here too.\n help=\"The uncompromising code formatter.\",\n)\n@click.option(\"-c\", \"--code\", type=str, help=\"Format the code passed in as a string.\")\n@click.option(\n \"-l\",\n \"--line-length\",\n type=int,\n default=DEFAULT_LINE_LENGTH,\n help=\"How many characters per line to allow.\",\n show_default=True,\n)\n@click.option(\n \"-t\",\n \"--target-version\",\n 
type=click.Choice([v.name.lower() for v in TargetVersion]),\n callback=target_version_option_callback,\n multiple=True,\n help=(\n \"Python versions that should be supported by Black's output. You should\"\n \" include all versions that your code supports. By default, Black will infer\"\n \" target versions from the project metadata in pyproject.toml. If this does\"\n \" not yield conclusive results, Black will use per-file auto-detection.\"\n ),\n)\n@click.option(\n \"--pyi\",\n is_flag=True,\n help=(\n \"Format all input files like typing stubs regardless of file extension. This\"\n \" is useful when piping source on standard input.\"\n ),\n)\n@click.option(\n \"--ipynb\",\n is_flag=True,\n help=(\n \"Format all input files like Jupyter Notebooks regardless of file extension.\"\n \" This is useful when piping source on standard input.\"\n ),\n)\n@click.option(\n \"--python-cell-magics\",\n multiple=True,\n help=(\n \"When processing Jupyter Notebooks, add the given magic to the list\"\n f\" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))}).\"\n \" Useful for formatting cells with custom python magics.\"\n ),\n default=[],\n)\n@click.option(\n \"-x\",\n \"--skip-source-first-line\",\n is_flag=True,\n help=\"Skip the first line of the source code.\",\n)\n@click.option(\n \"-S\",\n \"--skip-string-normalization\",\n is_flag=True,\n help=\"Don't normalize string quotes or prefixes.\",\n)\n@click.option(\n \"-C\",\n \"--skip-magic-trailing-comma\",\n is_flag=True,\n help=\"Don't use trailing commas as a reason to split lines.\",\n)\n@click.option(\n \"--preview\",\n is_flag=True,\n help=(\n \"Enable potentially disruptive style changes that may be added to Black's main\"\n \" functionality in the next major release.\"\n ),\n)\n@click.option(\n \"--unstable\",\n is_flag=True,\n help=(\n \"Enable potentially disruptive style changes that have known bugs or are not\"\n \" currently expected to make it into the stable style Black's next major\"\n \" release. Implies --preview.\"\n ),\n)\n@click.option(\n \"--enable-unstable-feature\",\n type=click.Choice([v.name for v in Preview]),\n callback=enable_unstable_feature_callback,\n multiple=True,\n help=(\n \"Enable specific features included in the `--unstable` style. Requires\"\n \" `--preview`. No compatibility guarantees are provided on the behavior\"\n \" or existence of any unstable features.\"\n ),\n)\n@click.option(\n \"--check\",\n is_flag=True,\n help=(\n \"Don't write the files back, just return the status. Return code 0 means\"\n \" nothing would change. Return code 1 means some files would be reformatted.\"\n \" Return code 123 means there was an internal error.\"\n ),\n)\n@click.option(\n \"--diff\",\n is_flag=True,\n help=(\n \"Don't write the files back, just output a diff to indicate what changes\"\n \" Black would've made. They are printed to stdout so capturing them is simple.\"\n ),\n)\n@click.option(\n \"--color/--no-color\",\n is_flag=True,\n help=\"Show (or do not show) colored diff. Only applies when --diff is given.\",\n)\n@click.option(\n \"--line-ranges\",\n multiple=True,\n metavar=\"START-END\",\n help=(\n \"When specified, Black will try its best to only format these lines. This\"\n \" option can be specified multiple times, and a union of the lines will be\"\n \" formatted. Each range must be specified as two integers connected by a `-`:\"\n \" `-`. 
The `` and `` integer indices are 1-based and\"\n \" inclusive on both ends.\"\n ),\n default=(),\n)\n@click.option(\n \"--fast/--safe\",\n is_flag=True,\n help=(\n \"By default, Black performs an AST safety check after formatting your code.\"\n \" The --fast flag turns off this check and the --safe flag explicitly enables\"\n \" it. [default: --safe]\"\n ),\n)\n@click.option(\n \"--required-version\",\n type=str,\n help=(\n \"Require a specific version of Black to be running. This is useful for\"\n \" ensuring that all contributors to your project are using the same\"\n \" version, because different versions of Black may format code a little\"\n \" differently. This option can be set in a configuration file for consistent\"\n \" results across environments.\"\n ),\n)\n@click.option(\n \"--exclude\",\n type=str,\n callback=validate_regex,\n help=(\n \"A regular expression that matches files and directories that should be\"\n \" excluded on recursive searches. An empty value means no paths are excluded.\"\n \" Use forward slashes for directories on all platforms (Windows, too).\"\n \" By default, Black also ignores all paths listed in .gitignore. Changing this\"\n f\" value will override all default exclusions. [default: {DEFAULT_EXCLUDES}]\"\n ),\n show_default=False,\n)\n@click.option(\n \"--extend-exclude\",\n type=str,\n callback=validate_regex,\n help=(\n \"Like --exclude, but adds additional files and directories on top of the\"\n \" default values instead of overriding them.\"\n ),\n)\n@click.option(\n \"--force-exclude\",\n type=str,\n callback=validate_regex,\n help=(\n \"Like --exclude, but files and directories matching this regex will be excluded\"\n \" even when they are passed explicitly as arguments. This is useful when\"\n \" invoking Black programmatically on changed files, such as in a pre-commit\"\n \" hook or editor plugin.\"\n ),\n)\n@click.option(\n \"--stdin-filename\",\n type=str,\n is_eager=True,\n help=(\n \"The name of the file when passing it through stdin. Useful to make sure Black\"\n \" will respect the --force-exclude option on some editors that rely on using\"\n \" stdin.\"\n ),\n)\n@click.option(\n \"--include\",\n type=str,\n default=DEFAULT_INCLUDES,\n callback=validate_regex,\n help=(\n \"A regular expression that matches files and directories that should be\"\n \" included on recursive searches. An empty value means all files are included\"\n \" regardless of the name. Use forward slashes for directories on all platforms\"\n \" (Windows, too). Overrides all exclusions, including from .gitignore and\"\n \" command line options.\"\n ),\n show_default=True,\n)\n@click.option(\n \"-W\",\n \"--workers\",\n type=click.IntRange(min=1),\n default=None,\n help=(\n \"When Black formats multiple files, it may use a process pool to speed up\"\n \" formatting. This option controls the number of parallel workers. This can\"\n \" also be specified via the BLACK_NUM_WORKERS environment variable. Defaults\"\n \" to the number of CPUs in the system.\"\n ),\n)\n@click.option(\n \"-q\",\n \"--quiet\",\n is_flag=True,\n help=(\n \"Stop emitting all non-critical output. Error messages will still be emitted\"\n \" (which can silenced by 2>/dev/null).\"\n ),\n)\n@click.option(\n \"-v\",\n \"--verbose\",\n is_flag=True,\n help=(\n \"Emit messages about files that were not changed or were ignored due to\"\n \" exclusion patterns. 
If Black is using a configuration file, a message\"\n \" detailing which one it is using will be emitted.\"\n ),\n)\n@click.version_option(\n version=__version__,\n message=(\n f\"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\\n\"\n f\"Python ({platform.python_implementation()}) {platform.python_version()}\"\n ),\n)\n@click.argument(\n \"src\",\n nargs=-1,\n type=click.Path(\n exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True\n ),\n is_eager=True,\n metavar=\"SRC ...\",\n)\n@click.option(\n \"--config\",\n type=click.Path(\n exists=True,\n file_okay=True,\n dir_okay=False,\n readable=True,\n allow_dash=False,\n path_type=str,\n ),\n is_eager=True,\n callback=read_pyproject_toml,\n help=\"Read configuration options from a configuration file.\",\n)\n@click.pass_context\ndef main( # noqa: C901\n ctx: click.Context,\n code: Optional[str],\n line_length: int,\n target_version: List[TargetVersion],\n check: bool,\n diff: bool,\n line_ranges: Sequence[str],\n color: bool,\n fast: bool,\n pyi: bool,\n ipynb: bool,\n python_cell_magics: Sequence[str],\n skip_source_first_line: bool,\n skip_string_normalization: bool,\n skip_magic_trailing_comma: bool,\n preview: bool,\n unstable: bool,\n enable_unstable_feature: List[Preview],\n quiet: bool,\n verbose: bool,\n required_version: Optional[str],\n include: Pattern[str],\n exclude: Optional[Pattern[str]],\n extend_exclude: Optional[Pattern[str]],\n force_exclude: Optional[Pattern[str]],\n stdin_filename: Optional[str],\n workers: Optional[int],\n src: Tuple[str, ...],\n config: Optional[str],\n) -> None:\n \"\"\"The uncompromising code formatter.\"\"\"\n ctx.ensure_object(dict)\n\n if src and code is not None:\n out(\n main.get_usage(ctx)\n + \"\\n\\n'SRC' and 'code' cannot be passed simultaneously.\"\n )\n ctx.exit(1)\n if not src and code is None:\n out(main.get_usage(ctx) + \"\\n\\nOne of 'SRC' or 'code' is required.\")\n ctx.exit(1)\n\n # It doesn't do anything if --unstable is also passed, so just allow it.\n if enable_unstable_feature and not (preview or unstable):\n out(\n main.get_usage(ctx)\n + \"\\n\\n'--enable-unstable-feature' requires '--preview'.\"\n )\n ctx.exit(1)\n\n root, method = (\n find_project_root(src, stdin_filename) if code is None else (None, None)\n )\n ctx.obj[\"root\"] = root\n\n if verbose:\n if root:\n out(\n f\"Identified `{root}` as project root containing a {method}.\",\n fg=\"blue\",\n )\n\n if config:\n config_source = ctx.get_parameter_source(\"config\")\n user_level_config = str(find_user_pyproject_toml())\n if config == user_level_config:\n out(\n \"Using configuration from user-level config at \"\n f\"'{user_level_config}'.\",\n fg=\"blue\",\n )\n elif config_source in (\n ParameterSource.DEFAULT,\n ParameterSource.DEFAULT_MAP,\n ):\n out(\"Using configuration from project root.\", fg=\"blue\")\n else:\n out(f\"Using configuration in '{config}'.\", fg=\"blue\")\n if ctx.default_map:\n for param, value in ctx.default_map.items():\n out(f\"{param}: {value}\")\n\n error_msg = \"Oh no! 
\ud83d\udca5 \ud83d\udc94 \ud83d\udca5\"\n if (\n required_version\n and required_version != __version__\n and required_version != __version__.split(\".\")[0]\n ):\n err(\n f\"{error_msg} The required version `{required_version}` does not match\"\n f\" the running version `{__version__}`!\"\n )\n ctx.exit(1)\n if ipynb and pyi:\n err(\"Cannot pass both `pyi` and `ipynb` flags!\")\n ctx.exit(1)\n\n write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)\n if target_version:\n versions = set(target_version)\n else:\n # We'll autodetect later.\n versions = set()\n mode = Mode(\n target_versions=versions,\n line_length=line_length,\n is_pyi=pyi,\n is_ipynb=ipynb,\n skip_source_first_line=skip_source_first_line,\n string_normalization=not skip_string_normalization,\n magic_trailing_comma=not skip_magic_trailing_comma,\n preview=preview,\n unstable=unstable,\n python_cell_magics=set(python_cell_magics),\n enabled_features=set(enable_unstable_feature),\n )\n\n lines: List[Tuple[int, int]] = []\n if line_ranges:\n if ipynb:\n err(\"Cannot use --line-ranges with ipynb files.\")\n ctx.exit(1)\n\n try:\n lines = parse_line_ranges(line_ranges)\n except ValueError as e:\n err(str(e))\n ctx.exit(1)\n\n if code is not None:\n # Run in quiet mode by default with -c; the extra output isn't useful.\n # You can still pass -v to get verbose output.\n quiet = True\n\n report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)\n\n if code is not None:\n reformat_code(\n content=code,\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n lines=lines,\n )\n else:\n assert root is not None # root is only None if code is not None\n try:\n sources = get_sources(\n root=root,\n src=src,\n quiet=quiet,\n verbose=verbose,\n include=include,\n exclude=exclude,\n extend_exclude=extend_exclude,\n force_exclude=force_exclude,\n report=report,\n stdin_filename=stdin_filename,\n )\n except GitWildMatchPatternError:\n ctx.exit(1)\n\n path_empty(\n sources,\n \"No Python files are present to be formatted. Nothing to do \ud83d\ude34\",\n quiet,\n verbose,\n ctx,\n )\n\n if len(sources) == 1:\n reformat_one(\n src=sources.pop(),\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n lines=lines,\n )\n else:\n from black.concurrency import reformat_many\n\n if lines:\n err(\"Cannot use --line-ranges to format multiple files.\")\n ctx.exit(1)\n reformat_many(\n sources=sources,\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n workers=workers,\n )\n\n if verbose or not quiet:\n if code is None and (verbose or report.change_count or report.failure_count):\n out()\n out(error_msg if report.return_code else \"All done! 
\u2728 \ud83c\udf70 \u2728\")\n if code is None:\n click.echo(str(report), err=True)\n ctx.exit(report.return_code)\n\n\ndef get_sources(\n *,\n root: Path,\n src: Tuple[str, ...],\n quiet: bool,\n verbose: bool,\n include: Pattern[str],\n exclude: Optional[Pattern[str]],\n extend_exclude: Optional[Pattern[str]],\n force_exclude: Optional[Pattern[str]],\n report: \"Report\",\n stdin_filename: Optional[str],\n) -> Set[Path]:\n \"\"\"Compute the set of files to be formatted.\"\"\"\n sources: Set[Path] = set()\n\n assert root.is_absolute(), f\"INTERNAL ERROR: `root` must be absolute but is {root}\"\n using_default_exclude = exclude is None\n exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude\n gitignore: Optional[Dict[Path, PathSpec]] = None\n root_gitignore = get_gitignore(root)\n\n for s in src:\n if s == \"-\" and stdin_filename:\n path = Path(stdin_filename)\n is_stdin = True\n else:\n path = Path(s)\n is_stdin = False\n\n # Compare the logic here to the logic in `gen_python_files`.\n if is_stdin or path.is_file():\n if resolves_outside_root_or_cannot_stat(path, root, report):\n if verbose:\n out(f'Skipping invalid source: \"{path}\"', fg=\"red\")\n continue\n\n root_relative_path = best_effort_relative_path(path, root).as_posix()\n root_relative_path = \"/\" + root_relative_path\n\n # Hard-exclude any files that matches the `--force-exclude` regex.\n if path_is_excluded(root_relative_path, force_exclude):\n report.path_ignored(\n path, \"matches the --force-exclude regular expression\"\n )\n continue\n\n if is_stdin:\n path = Path(f\"{STDIN_PLACEHOLDER}{str(path)}\")\n\n if path.suffix == \".ipynb\" and not jupyter_dependencies_are_installed(\n warn=verbose or not quiet\n ):\n continue\n\n if verbose:\n out(f'Found input source: \"{path}\"', fg=\"blue\")\n sources.add(path)\n elif path.is_dir():\n path = root / (path.resolve().relative_to(root))\n if verbose:\n out(f'Found input source directory: \"{path}\"', fg=\"blue\")\n\n if using_default_exclude:\n gitignore = {\n root: root_gitignore,\n path: get_gitignore(path),\n }\n sources.update(\n gen_python_files(\n path.iterdir(),\n root,\n include,\n exclude,\n extend_exclude,\n force_exclude,\n report,\n gitignore,\n verbose=verbose,\n quiet=quiet,\n )\n )\n elif s == \"-\":\n if verbose:\n out(\"Found input source stdin\", fg=\"blue\")\n sources.add(path)\n else:\n err(f\"invalid path: {s}\")\n\n return sources\n\n\ndef path_empty(\n src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context\n) -> None:\n \"\"\"\n Exit if there is no `src` provided for formatting\n \"\"\"\n if not src:\n if verbose or not quiet:\n out(msg)\n ctx.exit(0)\n\n\ndef reformat_code(\n content: str,\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: Report,\n *,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"\n Reformat and print out `content` without spawning child processes.\n Similar to `reformat_one`, but for string content.\n\n `fast`, `write_back`, and `mode` options are passed to\n :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.\n \"\"\"\n path = Path(\"\")\n try:\n changed = Changed.NO\n if format_stdin_to_stdout(\n content=content, fast=fast, write_back=write_back, mode=mode, lines=lines\n ):\n changed = Changed.YES\n report.done(path, changed)\n except Exception as exc:\n if report.verbose:\n traceback.print_exc()\n report.failed(path, str(exc))\n\n\n# diff-shades depends on being to monkeypatch this function to operate. 
I know it's\n# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26\n@mypyc_attr(patchable=True)\ndef reformat_one(\n src: Path,\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: \"Report\",\n *,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"Reformat a single file under `src` without spawning child processes.\n\n `fast`, `write_back`, and `mode` options are passed to\n :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.\n \"\"\"\n try:\n changed = Changed.NO\n\n if str(src) == \"-\":\n is_stdin = True\n elif str(src).startswith(STDIN_PLACEHOLDER):\n is_stdin = True\n # Use the original name again in case we want to print something\n # to the user\n src = Path(str(src)[len(STDIN_PLACEHOLDER) :])\n else:\n is_stdin = False\n\n if is_stdin:\n if src.suffix == \".pyi\":\n mode = replace(mode, is_pyi=True)\n elif src.suffix == \".ipynb\":\n mode = replace(mode, is_ipynb=True)\n if format_stdin_to_stdout(\n fast=fast, write_back=write_back, mode=mode, lines=lines\n ):\n changed = Changed.YES\n else:\n cache = Cache.read(mode)\n if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n if not cache.is_changed(src):\n changed = Changed.CACHED\n if changed is not Changed.CACHED and format_file_in_place(\n src, fast=fast, write_back=write_back, mode=mode, lines=lines\n ):\n changed = Changed.YES\n if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (\n write_back is WriteBack.CHECK and changed is Changed.NO\n ):\n cache.write([src])\n report.done(src, changed)\n except Exception as exc:\n if report.verbose:\n traceback.print_exc()\n report.failed(src, str(exc))\n\n\ndef format_file_in_place(\n src: Path,\n fast: bool,\n mode: Mode,\n write_back: WriteBack = WriteBack.NO,\n lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy\n *,\n lines: Collection[Tuple[int, int]] = (),\n) -> bool:\n \"\"\"Format file under `src` path. Return True if changed.\n\n If `write_back` is DIFF, write a diff to stdout. 
If it is YES, write reformatted\n code to the file.\n `mode` and `fast` options are passed to :func:`format_file_contents`.\n \"\"\"\n if src.suffix == \".pyi\":\n mode = replace(mode, is_pyi=True)\n elif src.suffix == \".ipynb\":\n mode = replace(mode, is_ipynb=True)\n\n then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)\n header = b\"\"\n with open(src, \"rb\") as buf:\n if mode.skip_source_first_line:\n header = buf.readline()\n src_contents, encoding, newline = decode_bytes(buf.read())\n try:\n dst_contents = format_file_contents(\n src_contents, fast=fast, mode=mode, lines=lines\n )\n except NothingChanged:\n return False\n except JSONDecodeError:\n raise ValueError(\n f\"File '{src}' cannot be parsed as valid Jupyter notebook.\"\n ) from None\n src_contents = header.decode(encoding) + src_contents\n dst_contents = header.decode(encoding) + dst_contents\n\n if write_back == WriteBack.YES:\n with open(src, \"w\", encoding=encoding, newline=newline) as f:\n f.write(dst_contents)\n elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n now = datetime.now(timezone.utc)\n src_name = f\"{src}\\t{then}\"\n dst_name = f\"{src}\\t{now}\"\n if mode.is_ipynb:\n diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)\n else:\n diff_contents = diff(src_contents, dst_contents, src_name, dst_name)\n\n if write_back == WriteBack.COLOR_DIFF:\n diff_contents = color_diff(diff_contents)\n\n with lock or nullcontext():\n f = io.TextIOWrapper(\n sys.stdout.buffer,\n encoding=encoding,\n newline=newline,\n write_through=True,\n )\n f = wrap_stream_for_windows(f)\n f.write(diff_contents)\n f.detach()\n\n return True\n\n\ndef format_stdin_to_stdout(\n fast: bool,\n *,\n content: Optional[str] = None,\n write_back: WriteBack = WriteBack.NO,\n mode: Mode,\n lines: Collection[Tuple[int, int]] = (),\n) -> bool:\n \"\"\"Format file on stdin. Return True if changed.\n\n If content is None, it's read from sys.stdin.\n\n If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,\n write a diff to stdout. 
The `mode` argument is passed to\n :func:`format_file_contents`.\n \"\"\"\n then = datetime.now(timezone.utc)\n\n if content is None:\n src, encoding, newline = decode_bytes(sys.stdin.buffer.read())\n else:\n src, encoding, newline = content, \"utf-8\", \"\"\n\n dst = src\n try:\n dst = format_file_contents(src, fast=fast, mode=mode, lines=lines)\n return True\n\n except NothingChanged:\n return False\n\n finally:\n f = io.TextIOWrapper(\n sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True\n )\n if write_back == WriteBack.YES:\n # Make sure there's a newline after the content\n if dst and dst[-1] != \"\\n\":\n dst += \"\\n\"\n f.write(dst)\n elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n now = datetime.now(timezone.utc)\n src_name = f\"STDIN\\t{then}\"\n dst_name = f\"STDOUT\\t{now}\"\n d = diff(src, dst, src_name, dst_name)\n if write_back == WriteBack.COLOR_DIFF:\n d = color_diff(d)\n f = wrap_stream_for_windows(f)\n f.write(d)\n f.detach()\n\n\ndef check_stability_and_equivalence(\n src_contents: str,\n dst_contents: str,\n *,\n mode: Mode,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"Perform stability and equivalence checks.\n\n Raise AssertionError if source and destination contents are not\n equivalent, or if a second pass of the formatter would format the\n content differently.\n \"\"\"\n assert_equivalent(src_contents, dst_contents)\n assert_stable(src_contents, dst_contents, mode=mode, lines=lines)\n\n\ndef format_file_contents(\n...\n", "needle_token_start": 2176, "needle_token_end": 2667, "code_context_ntokens": 16222, "output": ["```python\ndef convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None:\n \"\"\"Converts unchanged lines to STANDALONE_COMMENT.\n\n The idea is similar to how `# fmt: on/off` is implemented. It also converts the\n nodes between those markers as a single `STANDALONE_COMMENT` leaf node with\n the unformatted code as its value. `STANDALONE_COMMENT` is a \"fake\" token\n that will be formatted as-is with its prefix normalized.\n\n Here we perform two passes:\n\n 1. Visit the top-level statements, and convert them to a single\n `STANDALONE_COMMENT` when unchanged. This speeds up formatting when some\n of the top-level statements aren't changed.\n 2. Convert unchanged \"unwrapped lines\" to `STANDALONE_COMMENT` nodes line by\n line. \"unwrapped lines\" are divided by the `NEWLINE` token. e.g. a\n multi-line statement is *one* \"unwrapped line\" that ends with `NEWLINE`,\n even though this statement itself can span multiple lines, and the\n tokenizer only sees the last '\\n' as the `NEWLINE` token.\n\n NOTE: During pass (2), comment prefixes and indentations are ALWAYS\n normalized even when the lines aren't changed. This is fixable by moving\n more formatting to pass (1). However, it's hard to get it correct when\n incorrect indentations are used. 
So we defer this to future optimizations.\n \"\"\"\n lines_set: Set[int] = set()\n for start, end in lines:\n lines_set.update(range(start, end + 1))\n visitor = _TopLevelStatementsVisitor(lines_set)\n _ = list(visitor.visit(src_node)) # Consume all results.\n _convert_unchanged_line_by_line(src_node, lines_set)\n\n\ndef _contains_standalone_comment(node: LN) -> bool:\n if isinstance(node, Leaf):\n return node.type == STANDALONE_COMMENT\n else:\n for child in node.children:\n if _contains_standalone_comment(child):\n return True\n return False\n\n\nclass _TopLevelStatementsVisitor(Visitor[None]):\n \"\"\"\n A node visitor that converts unchanged top-level statements to\n STANDALONE_COMMENT.\n\n This is used in addition to _convert_unchanged_line_by_line, to\n speed up formatting when there are unchanged top-level\n classes/functions/statements.\n \"\"\"\n\n def __init__(self, lines_set: Set[int]):\n self._lines_set = lines_set\n\n def visit_simple_stmt(self, node: Node) -> Iterator[None]:\n # This is only called for top-level statements, since `visit_suite`\n # won't visit its children nodes.\n yield from []\n newline_leaf = last_leaf(node)\n if not newline_leaf:\n return\n assert (\n newline_leaf.type == NEWLINE\n ), f\"Unexpectedly found leaf.type={newline_leaf.type}\"\n # We need to find the furthest ancestor with the NEWLINE as the last\n # leaf, since a `suite` can simply be a `simple_stmt` when it puts\n # its body on the same line. Example: `if cond: pass`.\n ancestor = furthest_ancestor_with_last_leaf(newline_leaf)\n if not _get_line_range(ancestor).intersection(self._lines_set):\n _convert_node_to_standalone_comment(ancestor)\n\n def visit_suite(self, node: Node) -> Iterator[None]:\n yield from []\n # If there is a STANDALONE_COMMENT node, it means parts of the node tree\n # have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't\n # be simply converted by calling str(node). So we just don't convert\n # here.\n if _contains_standalone_comment(node):\n return\n # Find the semantic parent of this suite. For `async_stmt` and\n # `async_funcdef`, the ASYNC token is defined on a separate level by the\n # grammar.\n semantic_parent = node.parent\n if semantic_parent is not None:\n if (\n semantic_parent.prev_sibling is not None\n and semantic_parent.prev_sibling.type == ASYNC\n ):\n semantic_parent = semantic_parent.parent\n if semantic_parent is not None and not _get_line_range(\n semantic_parent\n ).intersection(self._lines_set):\n _convert_node_to_standalone_comment(semantic_parent)\n\n\ndef _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None:\n \"\"\"Converts unchanged to STANDALONE_COMMENT line by line.\"\"\"\n for leaf in node.leaves():\n if leaf.type != NEWLINE:\n # We only consider \"unwrapped lines\", which are divided"]} +{"repo": "psf/black", "name": "make_simple_prefix", "language": "python", "path": "src/black/nodes.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function generates a string used to format text with new lines and optionally a form feed character, typically used to control spacing in formatted output.\n2. **Input**: The function takes three parameters: an integer representing the number of new lines, a boolean indicating whether a form feed character should be included, and an optional string representing the line break character (defaulting to a newline).\n3. 
**Output**: It returns a string composed of the specified number of newline characters, and if requested, includes a form feed character followed by an additional newline.\n4. **Procedure**: The function first checks if the form feed should be included. If true, it concatenates the specified number of newline characters minus one with a form feed character and another newline. If false, it simply returns a string of newline characters multiplied by the specified integer.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/strings.py\n\"\"\"\nSimple formatting on strings. Further string formatting code is in trans.py.\n\"\"\"\n\nimport re\nimport sys\nfrom functools import lru_cache\nfrom typing import Final, List, Match, Pattern\n\nfrom black._width_table import WIDTH_TABLE\nfrom blib2to3.pytree import Leaf\n\nSTRING_PREFIX_CHARS: Final = \"furbFURB\" # All possible string prefix characters.\nSTRING_PREFIX_RE: Final = re.compile(\n r\"^([\" + STRING_PREFIX_CHARS + r\"]*)(.*)$\", re.DOTALL\n)\nFIRST_NON_WHITESPACE_RE: Final = re.compile(r\"\\s*\\t+\\s*(\\S)\")\nUNICODE_ESCAPE_RE: Final = re.compile(\n r\"(?P\\\\+)(?P\"\n...\n# Path: src/black/nodes.py\n\"\"\"\nblib2to3 Node/Leaf transformation-related utility functions.\n\"\"\"\n\nimport sys\nfrom typing import (\n Final,\n Generic,\n Iterator,\n List,\n Literal,\n Optional,\n Set,\n Tuple,\n TypeVar,\n Union,\n)\n\nif sys.version_info >= (3, 10):\n from typing import TypeGuard\nelse:\n from typing_extensions import TypeGuard\n\nfrom mypy_extensions import mypyc_attr\n\nfrom black.cache import CACHE_DIR\nfrom black.mode import Mode, Preview\nfrom black.strings import get_string_prefix, has_triple_quotes\nfrom blib2to3 import pygram\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import NL, Leaf, Node, type_repr\n\npygram.initialize(CACHE_DIR)\nsyms: Final = pygram.python_symbols\n\n\n# types\nT = TypeVar(\"T\")\nLN = Union[Leaf, Node]\nLeafID = int\nNodeType = int\n\n\nWHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}\nSTATEMENT: Final = {\n syms.if_stmt,\n syms.while_stmt,\n syms.for_stmt,\n syms.try_stmt,\n syms.except_clause,\n syms.with_stmt,\n syms.funcdef,\n syms.classdef,\n syms.match_stmt,\n syms.case_block,\n}\nSTANDALONE_COMMENT: Final = 153\ntoken.tok_name[STANDALONE_COMMENT] = \"STANDALONE_COMMENT\"\nLOGIC_OPERATORS: Final = {\"and\", \"or\"}\nCOMPARATORS: Final = {\n token.LESS,\n token.GREATER,\n token.EQEQUAL,\n token.NOTEQUAL,\n token.LESSEQUAL,\n token.GREATEREQUAL,\n}\nMATH_OPERATORS: Final = {\n token.VBAR,\n token.CIRCUMFLEX,\n token.AMPER,\n token.LEFTSHIFT,\n token.RIGHTSHIFT,\n token.PLUS,\n token.MINUS,\n token.STAR,\n token.SLASH,\n token.DOUBLESLASH,\n token.PERCENT,\n token.AT,\n token.TILDE,\n token.DOUBLESTAR,\n}\nSTARS: Final = {token.STAR, token.DOUBLESTAR}\nVARARGS_SPECIALS: Final = STARS | {token.SLASH}\nVARARGS_PARENTS: Final = {\n syms.arglist,\n syms.argument, # double star in arglist\n syms.trailer, # single argument to call\n syms.typedargslist,\n syms.varargslist, # lambdas\n}\nUNPACKING_PARENTS: Final = {\n syms.atom, # single element of a list or set literal\n syms.dictsetmaker,\n syms.listmaker,\n syms.testlist_gexp,\n syms.testlist_star_expr,\n syms.subject_expr,\n syms.pattern,\n}\nTEST_DESCENDANTS: Final = {\n syms.test,\n syms.lambdef,\n 
syms.or_test,\n syms.and_test,\n syms.not_test,\n syms.comparison,\n syms.star_expr,\n syms.expr,\n syms.xor_expr,\n syms.and_expr,\n syms.shift_expr,\n syms.arith_expr,\n syms.trailer,\n syms.term,\n syms.power,\n syms.namedexpr_test,\n}\nTYPED_NAMES: Final = {syms.tname, syms.tname_star}\nASSIGNMENTS: Final = {\n \"=\",\n \"+=\",\n \"-=\",\n \"*=\",\n \"@=\",\n \"/=\",\n \"%=\",\n \"&=\",\n \"|=\",\n \"^=\",\n \"<<=\",\n \">>=\",\n \"**=\",\n \"//=\",\n \":\",\n}\n\nIMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}\nBRACKET: Final = {\n token.LPAR: token.RPAR,\n token.LSQB: token.RSQB,\n token.LBRACE: token.RBRACE,\n}\nOPENING_BRACKETS: Final = set(BRACKET.keys())\nCLOSING_BRACKETS: Final = set(BRACKET.values())\nBRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS\nALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}\n\nRARROW = 55\n\n\n@mypyc_attr(allow_interpreted_subclasses=True)\nclass Visitor(Generic[T]):\n \"\"\"Basic lib2to3 visitor that yields things of type `T` on `visit()`.\"\"\"\n\n def visit(self, node: LN) -> Iterator[T]:\n \"\"\"Main method to visit `node` and its children.\n\n It tries to find a `visit_*()` method for the given `node.type`, like\n `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.\n If no dedicated `visit_*()` method is found, chooses `visit_default()`\n instead.\n\n Then yields objects of type `T` from the selected visitor.\n \"\"\"\n if node.type < 256:\n name = token.tok_name[node.type]\n else:\n name = str(type_repr(node.type))\n # We explicitly branch on whether a visitor exists (instead of\n # using self.visit_default as the default arg to getattr) in order\n # to save needing to create a bound method object and so mypyc can\n # generate a native call to visit_default.\n visitf = getattr(self, f\"visit_{name}\", None)\n if visitf:\n yield from visitf(node)\n else:\n yield from self.visit_default(node)\n\n def visit_default(self, node: LN) -> Iterator[T]:\n \"\"\"Default `visit_*()` implementation. Recurses to children of `node`.\"\"\"\n if isinstance(node, Node):\n for child in node.children:\n yield from self.visit(child)\n\n\ndef whitespace(leaf: Leaf, *, complex_subscript: bool, mode: Mode) -> str: # noqa: C901\n \"\"\"Return whitespace prefix if needed for the given `leaf`.\n\n `complex_subscript` signals whether the given leaf is part of a subscription\n which has non-trivial arguments, like arithmetic expressions or function calls.\n \"\"\"\n NO: Final[str] = \"\"\n SPACE: Final[str] = \" \"\n DOUBLESPACE: Final[str] = \" \"\n t = leaf.type\n p = leaf.parent\n v = leaf.value\n if t in ALWAYS_NO_SPACE:\n return NO\n\n if t == token.COMMENT:\n return DOUBLESPACE\n\n assert p is not None, f\"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}\"\n if t == token.COLON and p.type not in {\n syms.subscript,\n syms.subscriptlist,\n syms.sliceop,\n }:\n return NO\n\n prev = leaf.prev_sibling\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n if t == token.COLON:\n if prevp.type == token.COLON:\n return NO\n\n elif prevp.type != token.COMMA and not complex_subscript:\n return NO\n\n return SPACE\n\n if prevp.type == token.EQUAL:\n if prevp.parent:\n if prevp.parent.type in {\n syms.arglist,\n syms.argument,\n syms.parameters,\n syms.varargslist,\n }:\n return NO\n\n elif prevp.parent.type == syms.typedargslist:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. 
So, we're using\n # that, too.\n return prevp.prefix\n\n elif (\n prevp.type == token.STAR\n and parent_type(prevp) == syms.star_expr\n and parent_type(prevp.parent) == syms.subscriptlist\n ):\n # No space between typevar tuples.\n return NO\n\n elif prevp.type in VARARGS_SPECIALS:\n if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):\n return NO\n\n elif prevp.type == token.COLON:\n if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:\n return SPACE if complex_subscript else NO\n\n elif (\n prevp.parent\n and prevp.parent.type == syms.factor\n and prevp.type in MATH_OPERATORS\n ):\n return NO\n\n elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:\n # no space in decorators\n return NO\n\n elif prev.type in OPENING_BRACKETS:\n return NO\n\n if p.type in {syms.parameters, syms.arglist}:\n # untyped function signatures or calls\n if not prev or prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.varargslist:\n # lambdas\n if prev and prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.typedargslist:\n # typed function signatures\n if not prev:\n return NO\n\n if t == token.EQUAL:\n if prev.type not in TYPED_NAMES:\n return NO\n\n elif prev.type == token.EQUAL:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. So, we're using that, too.\n return prev.prefix\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type in TYPED_NAMES:\n # type names\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type != token.COMMA:\n return NO\n\n elif p.type == syms.trailer:\n # attributes and calls\n if t == token.LPAR or t == token.RPAR:\n return NO\n\n if not prev:\n if t == token.DOT or t == token.LSQB:\n return NO\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.argument:\n # single argument\n if t == token.EQUAL:\n return NO\n\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.LPAR:\n return NO\n\n elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:\n return NO\n\n elif p.type == syms.decorator:\n # decorators\n return NO\n\n elif p.type == syms.dotted_name:\n if prev:\n return NO\n\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.AT or prevp.type == token.DOT:\n return NO\n\n elif p.type == syms.classdef:\n if t == token.LPAR:\n return NO\n\n if prev and prev.type == token.LPAR:\n return NO\n\n elif p.type in {syms.subscript, syms.sliceop}:\n # indexing\n if not prev:\n assert p.parent is not None, \"subscripts are always parented\"\n if p.parent.type == syms.subscriptlist:\n return SPACE\n\n return NO\n\n elif t == token.COLONEQUAL or prev.type == token.COLONEQUAL:\n return SPACE\n\n elif not complex_subscript:\n return NO\n\n elif p.type == syms.atom:\n if prev and t == token.DOT:\n # dots, but not the first one.\n return NO\n\n elif p.type == syms.dictsetmaker:\n # dict unpacking\n if prev and prev.type == token.DOUBLESTAR:\n return NO\n\n elif p.type in {syms.factor, syms.star_expr}:\n # unary ops\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n prevp_parent = prevp.parent\n assert prevp_parent is not None\n if prevp.type == token.COLON and prevp_parent.type in {\n syms.subscript,\n syms.sliceop,\n }:\n return NO\n\n elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:\n return NO\n\n elif t in {token.NAME, token.NUMBER, token.STRING}:\n return NO\n\n elif p.type == syms.import_from:\n if t == 
token.DOT:\n if prev and prev.type == token.DOT:\n return NO\n\n elif t == token.NAME:\n if v == \"import\":\n return SPACE\n\n if prev and prev.type == token.DOT:\n return NO\n\n elif p.type == syms.sliceop:\n return NO\n\n elif p.type == syms.except_clause:\n if t == token.STAR:\n return NO\n\n return SPACE\n\n\n\ndef make_simple_prefix(nl_count: int, form_feed: bool, empty_line: str = \"\\n\") -> str:\n \"\"\"Generate a normalized prefix string.\"\"\"\n if form_feed:\n return (empty_line * (nl_count - 1)) + \"\\f\" + empty_line\n return empty_line * nl_count\n\n\ndef preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:\n \"\"\"Return the first leaf that precedes `node`, if any.\"\"\"\n while node:\n res = node.prev_sibling\n if res:\n if isinstance(res, Leaf):\n return res\n\n try:\n return list(res.leaves())[-1]\n\n except IndexError:\n return None\n\n node = node.parent\n return None\n\n\ndef prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:\n \"\"\"Return if the `node` and its previous siblings match types against the provided\n list of tokens; the provided `node`has its type matched against the last element in\n the list. `None` can be used as the first element to declare that the start of the\n list is anchored at the start of its parent's children.\"\"\"\n if not tokens:\n return True\n if tokens[-1] is None:\n return node is None\n if not node:\n return False\n if node.type != tokens[-1]:\n return False\n return prev_siblings_are(node.prev_sibling, tokens[:-1])\n\n\ndef parent_type(node: Optional[LN]) -> Optional[NodeType]:\n \"\"\"\n Returns:\n @node.parent.type, if @node is not None and has a parent.\n OR\n None, otherwise.\n \"\"\"\n if node is None or node.parent is None:\n return None\n\n return node.parent.type\n\n\ndef child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:\n \"\"\"Return the child of `ancestor` that contains `descendant`.\"\"\"\n node: Optional[LN] = descendant\n while node and node.parent != ancestor:\n node = node.parent\n return node\n\n\ndef replace_child(old_child: LN, new_child: LN) -> None:\n \"\"\"\n Side Effects:\n * If @old_child.parent is set, replace @old_child with @new_child in\n @old_child's underlying Node structure.\n OR\n * Otherwise, this function does nothing.\n \"\"\"\n parent = old_child.parent\n if not parent:\n return\n\n child_idx = old_child.remove()\n if child_idx is not None:\n parent.insert_child(child_idx, new_child)\n\n\ndef container_of(leaf: Leaf) -> LN:\n \"\"\"Return `leaf` or one of its ancestors that is the topmost container of it.\n\n By \"container\" we mean a node where `leaf` is the very first child.\n \"\"\"\n same_prefix = leaf.prefix\n container: LN = leaf\n while container:\n parent = container.parent\n if parent is None:\n break\n\n if parent.children[0].prefix != same_prefix:\n break\n\n if parent.type == syms.file_input:\n break\n\n if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:\n break\n\n container = parent\n return container\n\n\ndef first_leaf_of(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the first leaf of the node tree.\"\"\"\n if isinstance(node, Leaf):\n return node\n if node.children:\n return first_leaf_of(node.children[0])\n else:\n return None\n\n\ndef is_arith_like(node: LN) -> bool:\n \"\"\"Whether node is an arithmetic or a binary arithmetic expression\"\"\"\n return node.type in {\n syms.arith_expr,\n syms.shift_expr,\n syms.xor_expr,\n syms.and_expr,\n }\n\n\ndef is_docstring(leaf: Leaf, mode: Mode) -> bool:\n if leaf.type 
!= token.STRING:\n return False\n\n prefix = get_string_prefix(leaf.value)\n if set(prefix).intersection(\"bBfF\"):\n return False\n\n if (\n Preview.unify_docstring_detection in mode\n and leaf.parent\n and leaf.parent.type == syms.simple_stmt\n and not leaf.parent.prev_sibling\n and leaf.parent.parent\n and leaf.parent.parent.type == syms.file_input\n ):\n return True\n\n if prev_siblings_are(\n leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]\n ):\n return True\n\n # Multiline docstring on the same line as the `def`.\n if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):\n # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python\n # grammar. We're safe to return True without further checks.\n return True\n\n return False\n\n\ndef is_empty_tuple(node: LN) -> bool:\n \"\"\"Return True if `node` holds an empty tuple.\"\"\"\n return (\n node.type == syms.atom\n and len(node.children) == 2\n and node.children[0].type == token.LPAR\n and node.children[1].type == token.RPAR\n )\n\n\ndef is_one_tuple(node: LN) -> bool:\n \"\"\"Return True if `node` holds a tuple with one element, with or without parens.\"\"\"\n if node.type == syms.atom:\n gexp = unwrap_singleton_parenthesis(node)\n if gexp is None or gexp.type != syms.testlist_gexp:\n return False\n\n return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA\n\n return (\n node.type in IMPLICIT_TUPLE\n and len(node.children) == 2\n and node.children[1].type == token.COMMA\n )\n\n\ndef is_tuple_containing_walrus(node: LN) -> bool:\n \"\"\"Return True if `node` holds a tuple that contains a walrus operator.\"\"\"\n if node.type != syms.atom:\n return False\n gexp = unwrap_singleton_parenthesis(node)\n if gexp is None or gexp.type != syms.testlist_gexp:\n return False\n\n return any(child.type == syms.namedexpr_test for child in gexp.children)\n\n\ndef is_one_sequence_between(\n opening: Leaf,\n closing: Leaf,\n leaves: List[Leaf],\n brackets: Tuple[int, int] = (token.LPAR, token.RPAR),\n) -> bool:\n \"\"\"Return True if content between `opening` and `closing` is a one-sequence.\"\"\"\n if (opening.type, closing.type) != brackets:\n return False\n\n depth = closing.bracket_depth + 1\n for _opening_index, leaf in enumerate(leaves):\n if leaf is opening:\n break\n\n else:\n raise LookupError(\"Opening paren not found in `leaves`\")\n\n commas = 0\n _opening_index += 1\n for leaf in leaves[_opening_index:]:\n if leaf is closing:\n break\n\n bracket_depth = leaf.bracket_depth\n if bracket_depth == depth and leaf.type == token.COMMA:\n commas += 1\n if leaf.parent and leaf.parent.type in {\n syms.arglist,\n syms.typedargslist,\n }:\n commas += 1\n break\n\n return commas < 2\n\n\ndef is_walrus_assignment(node: LN) -> bool:\n \"\"\"Return True iff `node` is of the shape ( test := test )\"\"\"\n inner = unwrap_singleton_parenthesis(node)\n return inner is not None and inner.type == syms.namedexpr_test\n\n\ndef is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:\n \"\"\"Return True iff `node` is a trailer valid in a simple decorator\"\"\"\n return node.type == syms.trailer and (\n (\n len(node.children) == 2\n and node.children[0].type == token.DOT\n and node.children[1].type == token.NAME\n )\n # last trailer can be an argument-less parentheses pair\n or (\n last\n and len(node.children) == 2\n and node.children[0].type == token.LPAR\n and node.children[1].type == token.RPAR\n )\n # last trailer can be arguments\n or (\n last\n and len(node.children) == 
3\n and node.children[0].type == token.LPAR\n # and node.children[1].type == syms.argument\n and node.children[2].type == token.RPAR\n )\n )\n\n\ndef is_simple_decorator_expression(node: LN) -> bool:\n \"\"\"Return True iff `node` could be a 'dotted name' decorator\n\n This function takes the node of the 'namedexpr_test' of the new decorator\n grammar and test if it would be valid under the old decorator grammar.\n\n The old grammar was: decorator: @ dotted_name [arguments] NEWLINE\n The new grammar is : decorator: @ namedexpr_test NEWLINE\n \"\"\"\n if node.type == token.NAME:\n return True\n if node.type == syms.power:\n if node.children:\n return (\n node.children[0].type == token.NAME\n and all(map(is_simple_decorator_trailer, node.children[1:-1]))\n and (\n len(node.children) < 2\n or is_simple_decorator_trailer(node.children[-1], last=True)\n )\n )\n return False\n\n\ndef is_yield(node: LN) -> bool:\n \"\"\"Return True if `node` holds a `yield` or `yield from` expression.\"\"\"\n if node.type == syms.yield_expr:\n return True\n\n if is_name_token(node) and node.value == \"yield\":\n return True\n\n if node.type != syms.atom:\n return False\n\n if len(node.children) != 3:\n return False\n\n lpar, expr, rpar = node.children\n if lpar.type == token.LPAR and rpar.type == token.RPAR:\n return is_yield(expr)\n\n return False\n\n\ndef is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:\n \"\"\"Return True if `leaf` is a star or double star in a vararg or kwarg.\n\n If `within` includes VARARGS_PARENTS, this applies to function signatures.\n If `within` includes UNPACKING_PARENTS, it applies to right hand-side\n extended iterable unpacking (PEP 3132) and additional unpacking\n generalizations (PEP 448).\n \"\"\"\n if leaf.type not in VARARGS_SPECIALS or not leaf.parent:\n return False\n\n p = leaf.parent\n if p.type == syms.star_expr:\n # Star expressions are also used as assignment targets in extended\n # iterable unpacking (PEP 3132). 
See what its parent is instead.\n if not p.parent:\n return False\n\n p = p.parent\n\n return p.type in within\n\n\ndef is_multiline_string(leaf: Leaf) -> bool:\n \"\"\"Return True if `leaf` is a multiline string that actually spans many lines.\"\"\"\n return has_triple_quotes(leaf.value) and \"\\n\" in leaf.value\n\n\ndef is_parent_function_or_class(node: Node) -> bool:\n assert node.type in {syms.suite, syms.simple_stmt}\n assert node.parent is not None\n # Note this works for suites / simple_stmts in async def as well\n return node.parent.type in {syms.funcdef, syms.classdef}\n\n\ndef is_function_or_class(node: Node) -> bool:\n return node.type in {syms.funcdef, syms.classdef, syms.async_funcdef}\n\n\ndef is_stub_suite(node: Node) -> bool:\n \"\"\"Return True if `node` is a suite with a stub body.\"\"\"\n if node.parent is not None and not is_parent_function_or_class(node):\n return False\n\n # If there is a comment, we want to keep it.\n if node.prefix.strip():\n return False\n\n if (\n len(node.children) != 4\n or node.children[0].type != token.NEWLINE\n or node.children[1].type != token.INDENT\n or node.children[3].type != token.DEDENT\n ):\n return False\n\n if node.children[3].prefix.strip():\n return False\n\n return is_stub_body(node.children[2])\n\n\ndef is_stub_body(node: LN) -> bool:\n \"\"\"Return True if `node` is a simple statement containing an ellipsis.\"\"\"\n if not isinstance(node, Node) or node.type != syms.simple_stmt:\n return False\n\n if len(node.children) != 2:\n return False\n\n child = node.children[0]\n return (\n not child.prefix.strip()\n and child.type == syms.atom\n and len(child.children) == 3\n and all(leaf == Leaf(token.DOT, \".\") for leaf in child.children)\n )\n\n\ndef is_atom_with_invisible_parens(node: LN) -> bool:\n \"\"\"Given a `LN`, determines whether it's an atom `node` with invisible\n parens. 
Useful in dedupe-ing and normalizing parens.\n \"\"\"\n if isinstance(node, Leaf) or node.type != syms.atom:\n return False\n\n first, last = node.children[0], node.children[-1]\n return (\n isinstance(first, Leaf)\n and first.type == token.LPAR\n and first.value == \"\"\n and isinstance(last, Leaf)\n and last.type == token.RPAR\n and last.value == \"\"\n )\n\n\ndef is_empty_par(leaf: Leaf) -> bool:\n return is_empty_lpar(leaf) or is_empty_rpar(leaf)\n\n\ndef is_empty_lpar(leaf: Leaf) -> bool:\n return leaf.type == token.LPAR and leaf.value == \"\"\n\n\ndef is_empty_rpar(leaf: Leaf) -> bool:\n return leaf.type == token.RPAR and leaf.value == \"\"\n\n\ndef is_import(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf starts an import statement.\"\"\"\n p = leaf.parent\n t = leaf.type\n v = leaf.value\n return bool(\n t == token.NAME\n and (\n (v == \"import\" and p and p.type == syms.import_name)\n or (v == \"from\" and p and p.type == syms.import_from)\n )\n )\n\n\ndef is_with_or_async_with_stmt(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf starts a with or async with statement.\"\"\"\n return bool(\n leaf.type == token.NAME\n and leaf.value == \"with\"\n and leaf.parent\n and leaf.parent.type == syms.with_stmt\n ) or bool(\n leaf.type == token.ASYNC\n and leaf.next_sibling\n and leaf.next_sibling.type == syms.with_stmt\n )\n\n\ndef is_async_stmt_or_funcdef(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf starts an async def/for/with statement.\n\n Note that `async def` can be either an `async_stmt` or `async_funcdef`,\n the latter is used when it has decorators.\n \"\"\"\n return bool(\n leaf.type == token.ASYNC\n and leaf.parent\n and leaf.parent.type in {syms.async_stmt, syms.async_funcdef}\n )\n\n\ndef is_type_comment(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf is a type comment. This function should only\n be used for general type comments (excluding ignore annotations, which should\n use `is_type_ignore_comment`). Note that general type comments are no longer\n used in modern version of Python, this function may be deprecated in the future.\"\"\"\n t = leaf.type\n v = leaf.value\n return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith(\"# type:\")\n\n\ndef is_type_ignore_comment(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf is a type comment with ignore annotation.\"\"\"\n t = leaf.type\n v = leaf.value\n return t in {token.COMMENT, STANDALONE_COMMENT} and is_type_ignore_comment_string(v)\n\n\ndef is_type_ignore_comment_string(value: str) -> bool:\n \"\"\"Return True if the given string match with type comment with\n ignore annotation.\"\"\"\n return value.startswith(\"# type: ignore\")\n\n\ndef wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:\n \"\"\"Wrap `child` in parentheses.\n\n This replaces `child` with an atom holding the parentheses and the old\n child. That requires moving the prefix.\n\n If `visible` is False, the leaves will be valueless (and thus invisible).\n \"\"\"\n lpar = Leaf(token.LPAR, \"(\" if visible else \"\")\n rpar = Leaf(token.RPAR, \")\" if visible else \"\")\n prefix = child.prefix\n child.prefix = \"\"\n index = child.remove() or 0\n new_child = Node(syms.atom, [lpar, child, rpar])\n new_child.prefix = prefix\n parent.insert_child(index, new_child)\n\n\ndef unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:\n \"\"\"Returns `wrapped` if `node` is of the shape ( wrapped ).\n\n Parenthesis can be optional. 
Returns None otherwise\"\"\"\n if len(node.children) != 3:\n return None\n\n lpar, wrapped, rpar = node.children\n if not (lpar.type == token.LPAR and rpar.type == token.RPAR):\n return None\n\n return wrapped\n\n\ndef ensure_visible(leaf: Leaf) -> None:\n \"\"\"Make sure parentheses are visible.\n\n They could be invisible as part of some statements (see\n :func:`normalize_invisible_parens` and :func:`visit_import_from`).\n \"\"\"\n if leaf.type == token.LPAR:\n leaf.value = \"(\"\n elif leaf.type == token.RPAR:\n leaf.value = \")\"\n\n\ndef is_name_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.NAME\n\n\ndef is_lpar_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.LPAR\n\n\ndef is_rpar_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.RPAR\n\n\ndef is_string_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.STRING\n\n\ndef is_number_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.NUMBER\n\n\ndef get_annotation_type(leaf: Leaf) -> Literal[\"return\", \"param\", None]:\n \"\"\"Returns the type of annotation this leaf is part of, if any.\"\"\"\n ancestor = leaf.parent\n while ancestor is not None:\n if ancestor.prev_sibling and ancestor.prev_sibling.type == token.RARROW:\n return \"return\"\n if ancestor.parent and ancestor.parent.type == syms.tname:\n return \"param\"\n ancestor = ancestor.parent\n return None\n\n\ndef is_part_of_annotation(leaf: Leaf) -> bool:\n \"\"\"Returns whether this leaf is part of a type annotation.\"\"\"\n return get_annotation_type(leaf) is not None\n\n\ndef first_leaf(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the first leaf of the ancestor node.\"\"\"\n if isinstance(node, Leaf):\n return node\n elif not node.children:\n return None\n else:\n return first_leaf(node.children[0])\n\n\ndef last_leaf(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the last leaf of the ancestor node.\"\"\"\n if isinstance(node, Leaf):\n return node\n elif not node.children:\n return None\n else:\n return last_leaf(node.children[-1])\n\n\ndef furthest_ancestor_with_last_leaf(leaf: Leaf) -> LN:\n \"\"\"Returns the furthest ancestor that has this leaf node as the last leaf.\"\"\"\n node: LN = leaf\n while node.parent and node.parent.children and node is node.parent.children[-1]:\n node = node.parent\n return node\n\n# Path: src/black/comments.py\nimport re\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Collection, Final, Iterator, List, Optional, Tuple, Union\n\nfrom black.mode import Mode, Preview\nfrom black.nodes import (\n CLOSING_BRACKETS,\n STANDALONE_COMMENT,\n WHITESPACE,\n container_of,\n first_leaf_of,\n make_simple_prefix,\n preceding_leaf,\n syms,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n# types\nLN = Union[Leaf, Node]\n\nFMT_OFF: Final = {\"# fmt: off\", \"# fmt:off\", \"# yapf: disable\"}\nFMT_SKIP: Final = {\"# fmt: skip\", \"# fmt:skip\"}\nFMT_ON: Final = {\"# fmt: on\", \"# fmt:on\", \"# yapf: enable\"}\n\nCOMMENT_EXCEPTIONS = \" !:#'\"\n_COMMENT_PREFIX = \"# \"\n_COMMENT_LIST_SEPARATOR = \";\"\n\n\n@dataclass\nclass ProtoComment:\n \"\"\"Describes a piece of syntax that is a comment.\n\n It's not a :class:`blib2to3.pytree.Leaf` so that:\n\n * it can be cached (`Leaf` objects should not be reused more than once as\n they store their lineno, column, prefix, and parent information);\n * `newlines` and `consumed` fields are kept separate from the `value`. 
This\n simplifies handling of special marker comments like ``# fmt: off/on``.\n \"\"\"\n\n type: int # token.COMMENT or STANDALONE_COMMENT\n value: str # content of the comment\n newlines: int # how many newlines before the comment\n consumed: int # how many characters of the original leaf's prefix did we consume\n form_feed: bool # is there a form feed before the comment\n leading_whitespace: str # leading whitespace before the comment, if any\n\n\ndef generate_comments(leaf: LN) -> Iterator[Leaf]:\n \"\"\"Clean the prefix of the `leaf` and generate comments from it, if any.\n\n Comments in lib2to3 are shoved into the whitespace prefix. This happens\n in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation\n move because it does away with modifying the grammar to include all the\n possible places in which comments can be placed.\n\n The sad consequence for us though is that comments don't \"belong\" anywhere.\n This is why this function generates simple parentless Leaf objects for\n comments. We simply don't know what the correct parent should be.\n\n No matter though, we can live without this. We really only need to\n differentiate between inline and standalone comments. The latter don't\n share the line with any code.\n\n Inline comments are emitted as regular token.COMMENT leaves. Standalone\n are emitted with a fake STANDALONE_COMMENT token identifier.\n \"\"\"\n total_consumed = 0\n for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):\n total_consumed = pc.consumed\n prefix = make_simple_prefix(pc.newlines, pc.form_feed)\n yield Leaf(pc.type, pc.value, prefix=prefix)\n normalize_trailing_prefix(leaf, total_consumed)\n\n\n@lru_cache(maxsize=4096)\ndef list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:\n \"\"\"Return a list of :class:`ProtoComment` objects parsed from the given `prefix`.\"\"\"\n result: List[ProtoComment] = []\n if not prefix or \"#\" not in prefix:\n return result\n\n consumed = 0\n nlines = 0\n ignored_lines = 0\n form_feed = False\n for index, full_line in enumerate(re.split(\"\\r?\\n\", prefix)):\n consumed += len(full_line) + 1 # adding the length of the split '\\n'\n match = re.match(r\"^(\\s*)(\\S.*|)$\", full_line)\n assert match\n whitespace, line = match.groups()\n if not line:\n nlines += 1\n if \"\\f\" in full_line:\n form_feed = True\n if not line.startswith(\"#\"):\n # Escaped newlines outside of a comment are not really newlines at\n # all. 
We treat a single-line comment following an escaped newline\n # as a simple trailing comment.\n if line.endswith(\"\\\\\"):\n ignored_lines += 1\n continue\n\n if index == ignored_lines and not is_endmarker:\n comment_type = token.COMMENT # simple trailing comment\n else:\n comment_type = STANDALONE_COMMENT\n comment = make_comment(line)\n result.append(\n ProtoComment(\n type=comment_type,\n value=comment,\n newlines=nlines,\n consumed=consumed,\n form_feed=form_feed,\n leading_whitespace=whitespace,\n )\n )\n form_feed = False\n nlines = 0\n return result\n\n\ndef normalize_trailing_prefix(leaf: LN, total_consumed: int) -> None:\n \"\"\"Normalize the prefix that's left over after generating comments.\n\n Note: don't use backslashes for formatting or you'll lose your voting rights.\n \"\"\"\n remainder = leaf.prefix[total_consumed:]\n if \"\\\\\" not in remainder:\n nl_count = remainder.count(\"\\n\")\n form_feed = \"\\f\" in remainder and remainder.endswith(\"\\n\")\n leaf.prefix = make_simple_prefix(nl_count, form_feed)\n return\n\n leaf.prefix = \"\"\n\n\ndef make_comment(content: str) -> str:\n \"\"\"Return a consistently formatted comment from the given `content` string.\n\n All comments (except for \"##\", \"#!\", \"#:\", '#'\") should have a single\n space between the hash sign and the content.\n\n If `content` didn't start with a hash sign, one is provided.\n \"\"\"\n content = content.rstrip()\n if not content:\n return \"#\"\n\n if content[0] == \"#\":\n content = content[1:]\n NON_BREAKING_SPACE = \"\u00a0\"\n if (\n content\n and content[0] == NON_BREAKING_SPACE\n and not content.lstrip().startswith(\"type:\")\n ):\n content = \" \" + content[1:] # Replace NBSP by a simple space\n if content and content[0] not in COMMENT_EXCEPTIONS:\n content = \" \" + content\n return \"#\" + content\n\n\ndef normalize_fmt_off(\n node: Node, mode: Mode, lines: Collection[Tuple[int, int]]\n) -> None:\n \"\"\"Convert content between `# fmt: off`/`# fmt: on` into standalone comments.\"\"\"\n try_again = True\n while try_again:\n try_again = convert_one_fmt_off_pair(node, mode, lines)\n\n\ndef convert_one_fmt_off_pair(\n node: Node, mode: Mode, lines: Collection[Tuple[int, int]]\n) -> bool:\n \"\"\"Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.\n\n Returns True if a pair was converted.\n \"\"\"\n for leaf in node.leaves():\n previous_consumed = 0\n for comment in list_comments(leaf.prefix, is_endmarker=False):\n should_pass_fmt = comment.value in FMT_OFF or _contains_fmt_skip_comment(\n comment.value, mode\n )\n if not should_pass_fmt:\n previous_consumed = comment.consumed\n continue\n # We only want standalone comments. 
If there's no previous leaf or\n # the previous leaf is indentation, it's a standalone comment in\n # disguise.\n if should_pass_fmt and comment.type != STANDALONE_COMMENT:\n prev = preceding_leaf(leaf)\n if prev:\n if comment.value in FMT_OFF and prev.type not in WHITESPACE:\n continue\n if (\n _contains_fmt_skip_comment(comment.value, mode)\n and prev.type in WHITESPACE\n ):\n continue\n\n ignored_nodes = list(generate_ignored_nodes(leaf, comment, mode))\n if not ignored_nodes:\n continue\n\n first = ignored_nodes[0] # Can be a container node with the `leaf`.\n parent = first.parent\n prefix = first.prefix\n if comment.value in FMT_OFF:\n first.prefix = prefix[comment.consumed :]\n if _contains_fmt_skip_comment(comment.value, mode):\n first.prefix = \"\"\n standalone_comment_prefix = prefix\n else:\n standalone_comment_prefix = (\n prefix[:previous_consumed] + \"\\n\" * comment.newlines\n )\n hidden_value = \"\".join(str(n) for n in ignored_nodes)\n comment_lineno = leaf.lineno - comment.newlines\n if comment.value in FMT_OFF:\n fmt_off_prefix = \"\"\n if len(lines) > 0 and not any(\n line[0] <= comment_lineno <= line[1] for line in lines\n ):\n # keeping indentation of comment by preserving original whitespaces.\n fmt_off_prefix = prefix.split(comment.value)[0]\n if \"\\n\" in fmt_off_prefix:\n fmt_off_prefix = fmt_off_prefix.split(\"\\n\")[-1]\n standalone_comment_prefix += fmt_off_prefix\n hidden_value = comment.value + \"\\n\" + hidden_value\n if _contains_fmt_skip_comment(comment.value, mode):\n hidden_value += (\n comment.leading_whitespace\n if Preview.no_normalize_fmt_skip_whitespace in mode\n else \" \"\n ) + comment.value\n if hidden_value.endswith(\"\\n\"):\n # That happens when one of the `ignored_nodes` ended with a NEWLINE\n # leaf (possibly followed by a DEDENT).\n hidden_value = hidden_value[:-1]\n first_idx: Optional[int] = None\n for ignored in ignored_nodes:\n index = ignored.remove()\n if first_idx is None:\n first_idx = index\n assert parent is not None, \"INTERNAL ERROR: fmt: on/off handling (1)\"\n assert first_idx is not None, \"INTERNAL ERROR: fmt: on/off handling (2)\"\n parent.insert_child(\n first_idx,\n Leaf(\n STANDALONE_COMMENT,\n hidden_value,\n prefix=standalone_comment_prefix,\n fmt_pass_converted_first_leaf=first_leaf_of(first),\n ),\n )\n return True\n\n return False\n\n\ndef generate_ignored_nodes(\n leaf: Leaf, comment: ProtoComment, mode: Mode\n) -> Iterator[LN]:\n \"\"\"Starting from the container of `leaf`, generate all leaves until `# fmt: on`.\n\n If comment is skip, returns leaf only.\n Stops at the end of the block.\n \"\"\"\n if _contains_fmt_skip_comment(comment.value, mode):\n yield from _generate_ignored_nodes_from_fmt_skip(leaf, comment)\n return\n container: Optional[LN] = container_of(leaf)\n while container is not None and container.type != token.ENDMARKER:\n if is_fmt_on(container):\n return\n\n # fix for fmt: on in children\n if children_contains_fmt_on(container):\n for index, child in enumerate(container.children):\n if isinstance(child, Leaf) and is_fmt_on(child):\n if child.type in CLOSING_BRACKETS:\n # This means `# fmt: on` is placed at a different bracket level\n # than `# fmt: off`. 
This is an invalid use, but as a courtesy,\n # we include this closing bracket in the ignored nodes.\n # The alternative is to fail the formatting.\n yield child\n return\n if (\n child.type == token.INDENT\n and index < len(container.children) - 1\n and children_contains_fmt_on(container.children[index + 1])\n ):\n # This means `# fmt: on` is placed right after an indentation\n # level, and we shouldn't swallow the previous INDENT token.\n return\n if children_contains_fmt_on(child):\n return\n yield child\n else:\n if container.type == token.DEDENT and container.next_sibling is None:\n # This can happen when there is no matching `# fmt: on` comment at the\n # same level as `# fmt: on`. We need to keep this DEDENT.\n return\n yield container\n container = container.next_sibling\n\n\ndef _generate_ignored_nodes_from_fmt_skip(\n leaf: Leaf, comment: ProtoComment\n) -> Iterator[LN]:\n \"\"\"Generate all leaves that should be ignored by the `# fmt: skip` from `leaf`.\"\"\"\n prev_sibling = leaf.prev_sibling\n parent = leaf.parent\n # Need to properly format the leaf prefix to compare it to comment.value,\n # which is also formatted\n comments = list_comments(leaf.prefix, is_endmarker=False)\n if not comments or comment.value != comments[0].value:\n return\n if prev_sibling is not None:\n leaf.prefix = \"\"\n siblings = [prev_sibling]\n while \"\\n\" not in prev_sibling.prefix and prev_sibling.prev_sibling is not None:\n prev_sibling = prev_sibling.prev_sibling\n siblings.insert(0, prev_sibling)\n yield from siblings\n elif (\n parent is not None and parent.type == syms.suite and leaf.type == token.NEWLINE\n ):\n # The `# fmt: skip` is on the colon line of the if/while/def/class/...\n # statements. The ignored nodes should be previous siblings of the\n # parent suite node.\n leaf.prefix = \"\"\n ignored_nodes: List[LN] = []\n parent_sibling = parent.prev_sibling\n while parent_sibling is not None and parent_sibling.type != syms.suite:\n ignored_nodes.insert(0, parent_sibling)\n parent_sibling = parent_sibling.prev_sibling\n # Special case for `async_stmt` where the ASYNC token is on the\n # grandparent node.\n grandparent = parent.parent\n if (\n grandparent is not None\n and grandparent.prev_sibling is not None\n and grandparent.prev_sibling.type == token.ASYNC\n ):\n ignored_nodes.insert(0, grandparent.prev_sibling)\n yield from iter(ignored_nodes)\n\n\ndef is_fmt_on(container: LN) -> bool:\n \"\"\"Determine whether formatting is switched on within a container.\n Determined by whether the last `# fmt:` comment is `on` or `off`.\n \"\"\"\n fmt_on = False\n for comment in list_comments(container.prefix, is_endmarker=False):\n if comment.value in FMT_ON:\n fmt_on = True\n elif comment.value in FMT_OFF:\n fmt_on = False\n return fmt_on\n\n\ndef children_contains_fmt_on(container: LN) -> bool:\n \"\"\"Determine if children have formatting switched on.\"\"\"\n for child in container.children:\n leaf = first_leaf_of(child)\n if leaf is not None and is_fmt_on(leaf):\n return True\n\n return False\n\n\ndef contains_pragma_comment(comment_list: List[Leaf]) -> bool:\n \"\"\"\n Returns:\n True iff one of the comments in @comment_list is a pragma used by one\n of the more common static analysis tools for python (e.g. 
mypy, flake8,\n pylint).\n \"\"\"\n for comment in comment_list:\n if comment.value.startswith((\"# type:\", \"# noqa\", \"# pylint:\")):\n return True\n\n return False\n\n\ndef _contains_fmt_skip_comment(comment_line: str, mode: Mode) -> bool:\n \"\"\"\n Checks if the given comment contains FMT_SKIP alone or paired with other comments.\n Matching styles:\n # fmt:skip <-- single comment\n # noqa:XXX # fmt:skip # a nice line <-- multiple comments (Preview)\n # pylint:XXX; fmt:skip <-- list of comments (; separated, Preview)\n \"\"\"\n semantic_comment_blocks = [\n comment_line,\n *[\n _COMMENT_PREFIX + comment.strip()\n for comment in comment_line.split(_COMMENT_PREFIX)[1:]\n ],\n *[\n _COMMENT_PREFIX + comment.strip()\n for comment in comment_line.strip(_COMMENT_PREFIX).split(\n _COMMENT_LIST_SEPARATOR\n )\n ],\n ]\n\n return any(comment in FMT_SKIP for comment in semantic_comment_blocks)\n\n# Path: src/black/report.py\n\"\"\"\nSummarize Black runs to users.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom pathlib import Path\n\nfrom click import style\n\nfrom black.output import err, out\n\n\nclass Changed(Enum):\n NO = 0\n CACHED = 1\n YES = 2\n\n\nclass NothingChanged(UserWarning):\n \"\"\"Raised when reformatted code is the same as source.\"\"\"\n\n\n@dataclass\nclass Report:\n \"\"\"Provides a reformatting counter. Can be rendered with `str(report)`.\"\"\"\n\n check: bool = False\n diff: bool = False\n quiet: bool = False\n verbose: bool = False\n change_count: int = 0\n same_count: int = 0\n failure_count: int = 0\n\n def done(self, src: Path, changed: Changed) -> None:\n \"\"\"Increment the counter for successful reformatting. Write out a message.\"\"\"\n if changed is Changed.YES:\n reformatted = \"would reformat\" if self.check or self.diff else \"reformatted\"\n if self.verbose or not self.quiet:\n out(f\"{reformatted} {src}\")\n self.change_count += 1\n else:\n if self.verbose:\n if changed is Changed.NO:\n msg = f\"{src} already well formatted, good job.\"\n else:\n msg = f\"{src} wasn't modified on disk since last run.\"\n out(msg, bold=False)\n self.same_count += 1\n\n def failed(self, src: Path, message: str) -> None:\n \"\"\"Increment the counter for failed reformatting. 
Write out a message.\"\"\"\n err(f\"error: cannot format {src}: {message}\")\n self.failure_count += 1\n\n def path_ignored(self, path: Path, message: str) -> None:\n if self.verbose:\n out(f\"{path} ignored: {message}\", bold=False)\n\n @property\n def return_code(self) -> int:\n \"\"\"Return the exit code that the app should use.\n\n This considers the current state of changed files and failures:\n - if there were any failures, return 123;\n - if any files were changed and --check is being used, return 1;\n - otherwise return 0.\n \"\"\"\n # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with\n # 126 we have special return codes reserved by the shell.\n if self.failure_count:\n return 123\n\n elif self.change_count and self.check:\n return 1\n\n return 0\n\n def __str__(self) -> str:\n \"\"\"Render a color report of the current state.\n\n Use `click.unstyle` to remove colors.\n \"\"\"\n if self.check or self.diff:\n reformatted = \"would be reformatted\"\n unchanged = \"would be left unchanged\"\n failed = \"would fail to reformat\"\n else:\n reformatted = \"reformatted\"\n unchanged = \"left unchanged\"\n failed = \"failed to reformat\"\n report = []\n if self.change_count:\n s = \"s\" if self.change_count > 1 else \"\"\n report.append(\n style(f\"{self.change_count} file{s} \", bold=True, fg=\"blue\")\n + style(f\"{reformatted}\", bold=True)\n )\n\n if self.same_count:\n s = \"s\" if self.same_count > 1 else \"\"\n report.append(style(f\"{self.same_count} file{s} \", fg=\"blue\") + unchanged)\n if self.failure_count:\n s = \"s\" if self.failure_count > 1 else \"\"\n report.append(style(f\"{self.failure_count} file{s} {failed}\", fg=\"red\"))\n return \", \".join(report) + \".\"\n\n# Path: src/black/concurrency.py\n\"\"\"\nFormatting many files at once via multiprocessing. Contains entrypoint and utilities.\n\nNOTE: this module is only imported if we need to format several files at once.\n\"\"\"\n\nimport asyncio\nimport logging\nimport os\nimport signal\nimport sys\nimport traceback\nfrom concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor\nfrom multiprocessing import Manager\nfrom pathlib import Path\nfrom typing import Any, Iterable, Optional, Set\n\nfrom mypy_extensions import mypyc_attr\n\nfrom black import WriteBack, format_file_in_place\nfrom black.cache import Cache\nfrom black.mode import Mode\nfrom black.output import err\nfrom black.report import Changed, Report\n\n\ndef maybe_install_uvloop() -> None:\n \"\"\"If our environment has uvloop installed we use it.\n\n This is called only from command-line entry points to avoid\n interfering with the parent process if Black is used as a library.\n \"\"\"\n try:\n import uvloop\n\n uvloop.install()\n except ImportError:\n pass\n\n\ndef cancel(tasks: Iterable[\"asyncio.Future[Any]\"]) -> None:\n \"\"\"asyncio signal handler that cancels all `tasks` and reports to stderr.\"\"\"\n err(\"Aborted!\")\n for task in tasks:\n task.cancel()\n\n\ndef shutdown(loop: asyncio.AbstractEventLoop) -> None:\n \"\"\"Cancel all pending tasks on `loop`, wait for them, and close the loop.\"\"\"\n try:\n # This part is borrowed from asyncio/runners.py in Python 3.7b2.\n to_cancel = [task for task in asyncio.all_tasks(loop) if not task.done()]\n if not to_cancel:\n return\n\n for task in to_cancel:\n task.cancel()\n loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))\n finally:\n # `concurrent.futures.Future` objects cannot be cancelled once they\n # are already running. 
There might be some when the `shutdown()` happened.\n # Silence their logger's spew about the event loop being closed.\n cf_logger = logging.getLogger(\"concurrent.futures\")\n cf_logger.setLevel(logging.CRITICAL)\n loop.close()\n\n\n# diff-shades depends on being to monkeypatch this function to operate. I know it's\n# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26\n@mypyc_attr(patchable=True)\ndef reformat_many(\n sources: Set[Path],\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: Report,\n workers: Optional[int],\n) -> None:\n \"\"\"Reformat multiple files using a ProcessPoolExecutor.\"\"\"\n maybe_install_uvloop()\n\n executor: Executor\n if workers is None:\n workers = int(os.environ.get(\"BLACK_NUM_WORKERS\", 0))\n workers = workers or os.cpu_count() or 1\n if sys.platform == \"win32\":\n # Work around https://bugs.python.org/issue26903\n workers = min(workers, 60)\n try:\n executor = ProcessPoolExecutor(max_workers=workers)\n except (ImportError, NotImplementedError, OSError):\n # we arrive here if the underlying system does not support multi-processing\n # like in AWS Lambda or Termux, in which case we gracefully fallback to\n # a ThreadPoolExecutor with just a single worker (more workers would not do us\n # any good due to the Global Interpreter Lock)\n...\n", "needle_token_start": 3951, "needle_token_end": 4036, "code_context_ntokens": 16048, "output": ["```python\ndef make_simple_prefix(nl_count: int, form_feed: bool, empty_line: str = \"\\n\") -> str:\n \"\"\"Generate a normalized prefix string.\"\"\"\n if form_feed:\n return (empty_line * (nl_count - 1)) + \"\\f\" + empty_line\n return empty_line * nl_count\n```"]} +{"repo": "psf/black", "name": "run_transformer", "language": "python", "path": "src/black/linegen.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to apply a transformation to a line of code based on specified rules and conditions, and then evaluate the transformed line to ensure it meets certain criteria. If the initial transformation does not meet the criteria, a secondary transformation is attempted with additional constraints.\n2. **Input**: The function accepts a line of code, a transformation function, a mode specifying formatting rules, a collection of features that may influence the transformation, and optionally a string representation of the line.\n3. **Output**: It returns a list of transformed lines of code. If the transformation is unsuccessful or unnecessary, the original line or its direct transformation is returned.\n4. 
**Procedure**: \n - Convert the input line to a string if not provided.\n - Apply the transformation function to the line.\n - Check if the transformed line is unchanged; if so, raise an exception.\n - Validate the transformed line against a set of criteria including line length and structural integrity.\n - If the initial transformation does not meet all criteria, clone the original line and attempt a secondary transformation with an additional feature.\n - Return the result of the transformation that best meets the criteria.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " for comment_after in line.comments_after(leaf):\n yield from append_to_line(comment_after)\n\n last_non_comment_leaf = _get_last_non_comment_leaf(line)\n for leaf_idx, leaf in enumerate(line.leaves):\n yield from append_to_line(leaf)\n\n previous_priority = leaf_idx > 0 and bt.delimiters.get(\n id(line.leaves[leaf_idx - 1])\n )\n if (\n previous_priority != delimiter_priority\n or delimiter_priority in MIGRATE_COMMENT_DELIMITERS\n ):\n yield from append_comments(leaf)\n\n lowest_depth = min(lowest_depth, leaf.bracket_depth)\n if trailing_comma_safe and leaf.bracket_depth == lowest_depth:\n trailing_comma_safe = _can_add_trailing_comma(leaf, features)\n\n if last_leaf.type == STANDALONE_COMMENT and leaf_idx == last_non_comment_leaf:\n current_line = _safe_add_trailing_comma(\n trailing_comma_safe, delimiter_priority, current_line\n )\n\n leaf_priority = bt.delimiters.get(id(leaf))\n if leaf_priority == delimiter_priority:\n if (\n leaf_idx + 1 < len(line.leaves)\n and delimiter_priority not in MIGRATE_COMMENT_DELIMITERS\n ):\n yield from append_comments(line.leaves[leaf_idx + 1])\n\n yield current_line\n current_line = Line(\n mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets\n )\n\n if current_line:\n current_line = _safe_add_trailing_comma(\n trailing_comma_safe, delimiter_priority, current_line\n )\n yield current_line\n\n\n@dont_increase_indentation\ndef standalone_comment_split(\n line: Line, features: Collection[Feature], mode: Mode\n) -> Iterator[Line]:\n \"\"\"Split standalone comments from the rest of the line.\"\"\"\n if not line.contains_standalone_comments():\n raise CannotSplit(\"Line does not have any standalone comments\")\n\n current_line = Line(\n mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets\n )\n\n def append_to_line(leaf: Leaf) -> Iterator[Line]:\n \"\"\"Append `leaf` to current line or to new line if appending impossible.\"\"\"\n nonlocal current_line\n try:\n current_line.append_safe(leaf, preformatted=True)\n except ValueError:\n yield current_line\n\n current_line = Line(\n line.mode, depth=line.depth, inside_brackets=line.inside_brackets\n )\n current_line.append(leaf)\n\n for leaf in line.leaves:\n yield from append_to_line(leaf)\n\n for comment_after in line.comments_after(leaf):\n yield from append_to_line(comment_after)\n\n if current_line:\n yield current_line\n\n\ndef normalize_invisible_parens( # noqa: C901\n node: Node, parens_after: Set[str], *, mode: Mode, features: Collection[Feature]\n) -> None:\n \"\"\"Make existing optional parentheses invisible or create new ones.\n\n `parens_after` is a set of string leaf values immediately after which parens\n should be put.\n\n Standardizes on visible parentheses for single-element 
tuples, and keeps\n existing visible parentheses for other tuples and generator expressions.\n \"\"\"\n for pc in list_comments(node.prefix, is_endmarker=False):\n if pc.value in FMT_OFF:\n # This `node` has a prefix with `# fmt: off`, don't mess with parens.\n return\n\n # The multiple context managers grammar has a different pattern, thus this is\n # separate from the for-loop below. This possibly wraps them in invisible parens,\n # and later will be removed in remove_with_parens when needed.\n if node.type == syms.with_stmt:\n _maybe_wrap_cms_in_parens(node, mode, features)\n\n check_lpar = False\n for index, child in enumerate(list(node.children)):\n # Fixes a bug where invisible parens are not properly stripped from\n # assignment statements that contain type annotations.\n if isinstance(child, Node) and child.type == syms.annassign:\n normalize_invisible_parens(\n child, parens_after=parens_after, mode=mode, features=features\n )\n\n # Fixes a bug where invisible parens are not properly wrapped around\n # case blocks.\n if isinstance(child, Node) and child.type == syms.case_block:\n normalize_invisible_parens(\n child, parens_after={\"case\"}, mode=mode, features=features\n )\n\n # Add parentheses around long tuple unpacking in assignments.\n if (\n index == 0\n and isinstance(child, Node)\n and child.type == syms.testlist_star_expr\n ):\n check_lpar = True\n\n if check_lpar:\n if (\n child.type == syms.atom\n and node.type == syms.for_stmt\n and isinstance(child.prev_sibling, Leaf)\n and child.prev_sibling.type == token.NAME\n and child.prev_sibling.value == \"for\"\n ):\n if maybe_make_parens_invisible_in_atom(\n child,\n parent=node,\n remove_brackets_around_comma=True,\n ):\n wrap_in_parentheses(node, child, visible=False)\n elif isinstance(child, Node) and node.type == syms.with_stmt:\n remove_with_parens(child, node)\n elif child.type == syms.atom:\n if maybe_make_parens_invisible_in_atom(\n child,\n parent=node,\n ):\n wrap_in_parentheses(node, child, visible=False)\n elif is_one_tuple(child):\n wrap_in_parentheses(node, child, visible=True)\n elif node.type == syms.import_from:\n _normalize_import_from(node, child, index)\n break\n elif (\n index == 1\n and child.type == token.STAR\n and node.type == syms.except_clause\n ):\n # In except* (PEP 654), the star is actually part of\n # of the keyword. 
So we need to skip the insertion of\n # invisible parentheses to work more precisely.\n continue\n\n elif (\n isinstance(child, Leaf)\n and child.next_sibling is not None\n and child.next_sibling.type == token.COLON\n and child.value == \"case\"\n ):\n # A special patch for \"case case:\" scenario, the second occurrence\n # of case will be not parsed as a Python keyword.\n break\n\n elif not (isinstance(child, Leaf) and is_multiline_string(child)):\n wrap_in_parentheses(node, child, visible=False)\n\n comma_check = child.type == token.COMMA\n\n check_lpar = isinstance(child, Leaf) and (\n child.value in parens_after or comma_check\n )\n\n\ndef _normalize_import_from(parent: Node, child: LN, index: int) -> None:\n # \"import from\" nodes store parentheses directly as part of\n # the statement\n if is_lpar_token(child):\n assert is_rpar_token(parent.children[-1])\n # make parentheses invisible\n child.value = \"\"\n parent.children[-1].value = \"\"\n elif child.type != token.STAR:\n # insert invisible parentheses\n parent.insert_child(index, Leaf(token.LPAR, \"\"))\n parent.append_child(Leaf(token.RPAR, \"\"))\n\n\ndef remove_await_parens(node: Node) -> None:\n if node.children[0].type == token.AWAIT and len(node.children) > 1:\n if (\n node.children[1].type == syms.atom\n and node.children[1].children[0].type == token.LPAR\n ):\n if maybe_make_parens_invisible_in_atom(\n node.children[1],\n parent=node,\n remove_brackets_around_comma=True,\n ):\n wrap_in_parentheses(node, node.children[1], visible=False)\n\n # Since await is an expression we shouldn't remove\n # brackets in cases where this would change\n # the AST due to operator precedence.\n # Therefore we only aim to remove brackets around\n # power nodes that aren't also await expressions themselves.\n # https://peps.python.org/pep-0492/#updated-operator-precedence-table\n # N.B. We've still removed any redundant nested brackets though :)\n opening_bracket = cast(Leaf, node.children[1].children[0])\n closing_bracket = cast(Leaf, node.children[1].children[-1])\n bracket_contents = node.children[1].children[1]\n if isinstance(bracket_contents, Node) and (\n bracket_contents.type != syms.power\n or bracket_contents.children[0].type == token.AWAIT\n or any(\n isinstance(child, Leaf) and child.type == token.DOUBLESTAR\n for child in bracket_contents.children\n )\n ):\n ensure_visible(opening_bracket)\n ensure_visible(closing_bracket)\n\n\ndef _maybe_wrap_cms_in_parens(\n node: Node, mode: Mode, features: Collection[Feature]\n) -> None:\n \"\"\"When enabled and safe, wrap the multiple context managers in invisible parens.\n\n It is only safe when `features` contain Feature.PARENTHESIZED_CONTEXT_MANAGERS.\n \"\"\"\n if (\n Feature.PARENTHESIZED_CONTEXT_MANAGERS not in features\n or len(node.children) <= 2\n # If it's an atom, it's already wrapped in parens.\n or node.children[1].type == syms.atom\n ):\n return\n colon_index: Optional[int] = None\n for i in range(2, len(node.children)):\n if node.children[i].type == token.COLON:\n colon_index = i\n break\n if colon_index is not None:\n lpar = Leaf(token.LPAR, \"\")\n rpar = Leaf(token.RPAR, \"\")\n context_managers = node.children[1:colon_index]\n for child in context_managers:\n child.remove()\n # After wrapping, the with_stmt will look like this:\n # with_stmt\n # NAME 'with'\n # atom\n # LPAR ''\n # testlist_gexp\n # ... 
<-- context_managers\n # /testlist_gexp\n # RPAR ''\n # /atom\n # COLON ':'\n new_child = Node(\n syms.atom, [lpar, Node(syms.testlist_gexp, context_managers), rpar]\n )\n node.insert_child(1, new_child)\n\n\ndef remove_with_parens(node: Node, parent: Node) -> None:\n \"\"\"Recursively hide optional parens in `with` statements.\"\"\"\n # Removing all unnecessary parentheses in with statements in one pass is a tad\n # complex as different variations of bracketed statements result in pretty\n # different parse trees:\n #\n # with (open(\"file\")) as f: # this is an asexpr_test\n # ...\n #\n # with (open(\"file\") as f): # this is an atom containing an\n # ... # asexpr_test\n #\n # with (open(\"file\")) as f, (open(\"file\")) as f: # this is asexpr_test, COMMA,\n # ... # asexpr_test\n #\n # with (open(\"file\") as f, open(\"file\") as f): # an atom containing a\n # ... # testlist_gexp which then\n # # contains multiple asexpr_test(s)\n if node.type == syms.atom:\n if maybe_make_parens_invisible_in_atom(\n node,\n parent=parent,\n remove_brackets_around_comma=True,\n ):\n wrap_in_parentheses(parent, node, visible=False)\n if isinstance(node.children[1], Node):\n remove_with_parens(node.children[1], node)\n elif node.type == syms.testlist_gexp:\n for child in node.children:\n if isinstance(child, Node):\n remove_with_parens(child, node)\n elif node.type == syms.asexpr_test and not any(\n leaf.type == token.COLONEQUAL for leaf in node.leaves()\n ):\n if maybe_make_parens_invisible_in_atom(\n node.children[0],\n parent=node,\n remove_brackets_around_comma=True,\n ):\n wrap_in_parentheses(node, node.children[0], visible=False)\n\n\ndef maybe_make_parens_invisible_in_atom(\n node: LN,\n parent: LN,\n remove_brackets_around_comma: bool = False,\n) -> bool:\n \"\"\"If it's safe, make the parens in the atom `node` invisible, recursively.\n Additionally, remove repeated, adjacent invisible parens from the atom `node`\n as they are redundant.\n\n Returns whether the node should itself be wrapped in invisible parentheses.\n \"\"\"\n if (\n node.type not in (syms.atom, syms.expr)\n or is_empty_tuple(node)\n or is_one_tuple(node)\n or (is_yield(node) and parent.type != syms.expr_stmt)\n or (\n # This condition tries to prevent removing non-optional brackets\n # around a tuple, however, can be a bit overzealous so we provide\n # and option to skip this check for `for` and `with` statements.\n not remove_brackets_around_comma\n and max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY\n )\n or is_tuple_containing_walrus(node)\n ):\n return False\n\n if is_walrus_assignment(node):\n if parent.type in [\n syms.annassign,\n syms.expr_stmt,\n syms.assert_stmt,\n syms.return_stmt,\n syms.except_clause,\n syms.funcdef,\n syms.with_stmt,\n syms.tname,\n # these ones aren't useful to end users, but they do please fuzzers\n syms.for_stmt,\n syms.del_stmt,\n syms.for_stmt,\n ]:\n return False\n\n first = node.children[0]\n last = node.children[-1]\n if is_lpar_token(first) and is_rpar_token(last):\n middle = node.children[1]\n # make parentheses invisible\n if (\n # If the prefix of `middle` includes a type comment with\n # ignore annotation, then we do not remove the parentheses\n not is_type_ignore_comment_string(middle.prefix.strip())\n ):\n first.value = \"\"\n if first.prefix.strip():\n # Preserve comments before first paren\n middle.prefix = first.prefix + middle.prefix\n last.value = \"\"\n maybe_make_parens_invisible_in_atom(\n middle,\n parent=parent,\n 
remove_brackets_around_comma=remove_brackets_around_comma,\n )\n\n if is_atom_with_invisible_parens(middle):\n # Strip the invisible parens from `middle` by replacing\n # it with the child in-between the invisible parens\n middle.replace(middle.children[1])\n if middle.children[-1].prefix.strip():\n # Preserve comments before last paren\n last.prefix = middle.children[-1].prefix + last.prefix\n\n return False\n\n return True\n\n\ndef should_split_line(line: Line, opening_bracket: Leaf) -> bool:\n \"\"\"Should `line` be immediately split with `delimiter_split()` after RHS?\"\"\"\n\n if not (opening_bracket.parent and opening_bracket.value in \"[{(\"):\n return False\n\n # We're essentially checking if the body is delimited by commas and there's more\n # than one of them (we're excluding the trailing comma and if the delimiter priority\n # is still commas, that means there's more).\n exclude = set()\n trailing_comma = False\n try:\n last_leaf = line.leaves[-1]\n if last_leaf.type == token.COMMA:\n trailing_comma = True\n exclude.add(id(last_leaf))\n max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)\n except (IndexError, ValueError):\n return False\n\n return max_priority == COMMA_PRIORITY and (\n (line.mode.magic_trailing_comma and trailing_comma)\n # always explode imports\n or opening_bracket.parent.type in {syms.atom, syms.import_from}\n )\n\n\ndef generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:\n \"\"\"Generate sets of closing bracket IDs that should be omitted in a RHS.\n\n Brackets can be omitted if the entire trailer up to and including\n a preceding closing bracket fits in one line.\n\n Yielded sets are cumulative (contain results of previous yields, too). First\n set is empty, unless the line should explode, in which case bracket pairs until\n the one that needs to explode are omitted.\n \"\"\"\n\n omit: Set[LeafID] = set()\n if not line.magic_trailing_comma:\n yield omit\n\n length = 4 * line.depth\n opening_bracket: Optional[Leaf] = None\n closing_bracket: Optional[Leaf] = None\n inner_brackets: Set[LeafID] = set()\n for index, leaf, leaf_length in line.enumerate_with_length(is_reversed=True):\n length += leaf_length\n if length > line_length:\n break\n\n has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)\n if leaf.type == STANDALONE_COMMENT or has_inline_comment:\n break\n\n if opening_bracket:\n if leaf is opening_bracket:\n opening_bracket = None\n elif leaf.type in CLOSING_BRACKETS:\n prev = line.leaves[index - 1] if index > 0 else None\n if (\n prev\n and prev.type == token.COMMA\n and leaf.opening_bracket is not None\n and not is_one_sequence_between(\n leaf.opening_bracket, leaf, line.leaves\n )\n ):\n # Never omit bracket pairs with trailing commas.\n # We need to explode on those.\n break\n\n inner_brackets.add(id(leaf))\n elif leaf.type in CLOSING_BRACKETS:\n prev = line.leaves[index - 1] if index > 0 else None\n if prev and prev.type in OPENING_BRACKETS:\n # Empty brackets would fail a split so treat them as \"inner\"\n # brackets (e.g. 
only add them to the `omit` set if another\n # pair of brackets was good enough.\n inner_brackets.add(id(leaf))\n continue\n\n if closing_bracket:\n omit.add(id(closing_bracket))\n omit.update(inner_brackets)\n inner_brackets.clear()\n yield omit\n\n if (\n prev\n and prev.type == token.COMMA\n and leaf.opening_bracket is not None\n and not is_one_sequence_between(leaf.opening_bracket, leaf, line.leaves)\n ):\n # Never omit bracket pairs with trailing commas.\n # We need to explode on those.\n break\n\n if leaf.value:\n opening_bracket = leaf.opening_bracket\n closing_bracket = leaf\n\n\ndef run\n_transformer(\n line: Line,\n transform: Transformer,\n mode: Mode,\n features: Collection[Feature],\n *,\n line_str: str = \"\",\n) -> List[Line]:\n if not line_str:\n line_str = line_to_string(line)\n result: List[Line] = []\n for transformed_line in transform(line, features, mode):\n if str(transformed_line).strip(\"\\n\") == line_str:\n raise CannotTransform(\"Line transformer returned an unchanged result\")\n\n result.extend(transform_line(transformed_line, mode=mode, features=features))\n\n features_set = set(features)\n if (\n Feature.FORCE_OPTIONAL_PARENTHESES in features_set\n or transform.__class__.__name__ != \"rhs\"\n or not line.bracket_tracker.invisible\n or any(bracket.value for bracket in line.bracket_tracker.invisible)\n or line.contains_multiline_strings()\n or result[0].contains_uncollapsable_type_comments()\n or result[0].contains_unsplittable_type_ignore()\n or is_line_short_enough(result[0], mode=mode)\n # If any leaves have no parents (which _can_ occur since\n # `transform(line)` potentially destroys the line's underlying node\n # structure), then we can't proceed. Doing so would cause the below\n # call to `append_leaves()` to fail.\n or any(leaf.parent is None for leaf in line.leaves)\n ):\n return result\n\n line_copy = line.clone()\n append_leaves(line_copy, line, line.leaves)\n features_fop = features_set | {Feature.FORCE_OPTIONAL_PARENTHESES}\n second_opinion = run_transformer(\n line_copy, transform, mode, features_fop, line_str=line_str\n )\n if all(is_line_short_enough(ln, mode=mode) for ln in second_opinion):\n result = second_opinion\n return result\n\n# Path: src/black/parsing.py\n\"\"\"\nParse Python code and perform AST validation.\n\"\"\"\n\nimport ast\nimport sys\nimport warnings\nfrom typing import Iterable, Iterator, List, Set, Tuple\n\nfrom black.mode import VERSION_TO_FEATURES, Feature, TargetVersion, supports_feature\nfrom black.nodes import syms\nfrom blib2to3 import pygram\nfrom blib2to3.pgen2 import driver\nfrom blib2to3.pgen2.grammar import Grammar\nfrom blib2to3.pgen2.parse import ParseError\nfrom blib2to3.pgen2.tokenize import TokenError\nfrom blib2to3.pytree import Leaf, Node\n\n\nclass InvalidInput(ValueError):\n \"\"\"Raised when input source code fails all parse attempts.\"\"\"\n\n\ndef get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:\n if not target_versions:\n # No target_version specified, so try all grammars.\n return [\n # Python 3.7-3.9\n pygram.python_grammar_async_keywords,\n # Python 3.0-3.6\n pygram.python_grammar,\n # Python 3.10+\n pygram.python_grammar_soft_keywords,\n ]\n\n grammars = []\n # If we have to parse both, try to parse async as a keyword first\n if not supports_feature(\n target_versions, Feature.ASYNC_IDENTIFIERS\n ) and not supports_feature(target_versions, Feature.PATTERN_MATCHING):\n # Python 3.7-3.9\n grammars.append(pygram.python_grammar_async_keywords)\n if not 
supports_feature(target_versions, Feature.ASYNC_KEYWORDS):\n # Python 3.0-3.6\n grammars.append(pygram.python_grammar)\n if any(Feature.PATTERN_MATCHING in VERSION_TO_FEATURES[v] for v in target_versions):\n # Python 3.10+\n grammars.append(pygram.python_grammar_soft_keywords)\n\n # At least one of the above branches must have been taken, because every Python\n # version has exactly one of the two 'ASYNC_*' flags\n return grammars\n\n\ndef lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:\n \"\"\"Given a string with source, return the lib2to3 Node.\"\"\"\n if not src_txt.endswith(\"\\n\"):\n src_txt += \"\\n\"\n\n grammars = get_grammars(set(target_versions))\n errors = {}\n for grammar in grammars:\n drv = driver.Driver(grammar)\n try:\n result = drv.parse_string(src_txt, True)\n break\n\n except ParseError as pe:\n lineno, column = pe.context[1]\n lines = src_txt.splitlines()\n try:\n faulty_line = lines[lineno - 1]\n except IndexError:\n faulty_line = \"\"\n errors[grammar.version] = InvalidInput(\n f\"Cannot parse: {lineno}:{column}: {faulty_line}\"\n )\n\n except TokenError as te:\n # In edge cases these are raised; and typically don't have a \"faulty_line\".\n lineno, column = te.args[1]\n errors[grammar.version] = InvalidInput(\n f\"Cannot parse: {lineno}:{column}: {te.args[0]}\"\n )\n\n else:\n # Choose the latest version when raising the actual parsing error.\n assert len(errors) >= 1\n exc = errors[max(errors)]\n raise exc from None\n\n if isinstance(result, Leaf):\n result = Node(syms.file_input, [result])\n return result\n\n\ndef matches_grammar(src_txt: str, grammar: Grammar) -> bool:\n drv = driver.Driver(grammar)\n try:\n drv.parse_string(src_txt, True)\n except (ParseError, TokenError, IndentationError):\n return False\n else:\n return True\n\n\ndef lib2to3_unparse(node: Node) -> str:\n \"\"\"Given a lib2to3 node, return its string representation.\"\"\"\n code = str(node)\n return code\n\n\ndef _parse_single_version(\n src: str, version: Tuple[int, int], *, type_comments: bool\n) -> ast.AST:\n filename = \"\"\n with warnings.catch_warnings():\n warnings.simplefilter(\"ignore\", SyntaxWarning)\n warnings.simplefilter(\"ignore\", DeprecationWarning)\n return ast.parse(\n src, filename, feature_version=version, type_comments=type_comments\n )\n\n\ndef parse_ast(src: str) -> ast.AST:\n # TODO: support Python 4+ ;)\n versions = [(3, minor) for minor in range(3, sys.version_info[1] + 1)]\n\n first_error = \"\"\n for version in sorted(versions, reverse=True):\n try:\n return _parse_single_version(src, version, type_comments=True)\n except SyntaxError as e:\n if not first_error:\n first_error = str(e)\n\n # Try to parse without type comments\n for version in sorted(versions, reverse=True):\n try:\n return _parse_single_version(src, version, type_comments=False)\n except SyntaxError:\n pass\n\n raise SyntaxError(first_error)\n\n\ndef _normalize(lineend: str, value: str) -> str:\n # To normalize, we strip any leading and trailing space from\n # each line...\n stripped: List[str] = [i.strip() for i in value.splitlines()]\n normalized = lineend.join(stripped)\n # ...and remove any blank lines at the beginning and end of\n # the whole string\n return normalized.strip()\n\n\ndef stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:\n \"\"\"Simple visitor generating strings to compare ASTs by content.\"\"\"\n\n if (\n isinstance(node, ast.Constant)\n and isinstance(node.value, str)\n and node.kind == \"u\"\n ):\n # It's a quirk of history that we 
strip the u prefix over here. We used to\n # rewrite the AST nodes for Python version compatibility and we never copied\n # over the kind\n node.kind = None\n\n yield f\"{' ' * depth}{node.__class__.__name__}(\"\n\n for field in sorted(node._fields): # noqa: F402\n # TypeIgnore has only one field 'lineno' which breaks this comparison\n if isinstance(node, ast.TypeIgnore):\n break\n\n try:\n value: object = getattr(node, field)\n except AttributeError:\n continue\n\n yield f\"{' ' * (depth + 1)}{field}=\"\n\n if isinstance(value, list):\n for item in value:\n # Ignore nested tuples within del statements, because we may insert\n # parentheses and they change the AST.\n if (\n field == \"targets\"\n and isinstance(node, ast.Delete)\n and isinstance(item, ast.Tuple)\n ):\n for elt in item.elts:\n yield from stringify_ast(elt, depth + 2)\n\n elif isinstance(item, ast.AST):\n yield from stringify_ast(item, depth + 2)\n\n elif isinstance(value, ast.AST):\n yield from stringify_ast(value, depth + 2)\n\n else:\n normalized: object\n if (\n isinstance(node, ast.Constant)\n and field == \"value\"\n and isinstance(value, str)\n ):\n # Constant strings may be indented across newlines, if they are\n # docstrings; fold spaces after newlines when comparing. Similarly,\n # trailing and leading space may be removed.\n normalized = _normalize(\"\\n\", value)\n elif field == \"type_comment\" and isinstance(value, str):\n # Trailing whitespace in type comments is removed.\n normalized = value.rstrip()\n else:\n normalized = value\n yield f\"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}\"\n\n yield f\"{' ' * depth}) # /{node.__class__.__name__}\"\n\n# Path: src/black/ranges.py\n\"\"\"Functions related to Black's formatting by line ranges feature.\"\"\"\n\nimport difflib\nfrom dataclasses import dataclass\nfrom typing import Collection, Iterator, List, Sequence, Set, Tuple, Union\n\nfrom black.nodes import (\n LN,\n STANDALONE_COMMENT,\n Leaf,\n Node,\n Visitor,\n first_leaf,\n furthest_ancestor_with_last_leaf,\n last_leaf,\n syms,\n)\nfrom blib2to3.pgen2.token import ASYNC, NEWLINE\n\n\ndef parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]:\n lines: List[Tuple[int, int]] = []\n for lines_str in line_ranges:\n parts = lines_str.split(\"-\")\n if len(parts) != 2:\n raise ValueError(\n \"Incorrect --line-ranges format, expect 'START-END', found\"\n f\" {lines_str!r}\"\n )\n try:\n start = int(parts[0])\n end = int(parts[1])\n except ValueError:\n raise ValueError(\n \"Incorrect --line-ranges value, expect integer ranges, found\"\n f\" {lines_str!r}\"\n ) from None\n else:\n lines.append((start, end))\n return lines\n\n\ndef is_valid_line_range(lines: Tuple[int, int]) -> bool:\n \"\"\"Returns whether the line range is valid.\"\"\"\n return not lines or lines[0] <= lines[1]\n\n\ndef adjusted_lines(\n lines: Collection[Tuple[int, int]],\n original_source: str,\n modified_source: str,\n) -> List[Tuple[int, int]]:\n \"\"\"Returns the adjusted line ranges based on edits from the original code.\n\n This computes the new line ranges by diffing original_source and\n modified_source, and adjust each range based on how the range overlaps with\n the diffs.\n\n Note the diff can contain lines outside of the original line ranges. This can\n happen when the formatting has to be done in adjacent to maintain consistent\n local results. For example:\n\n 1. def my_func(arg1, arg2,\n 2. arg3,):\n 3. 
pass\n\n If it restricts to line 2-2, it can't simply reformat line 2, it also has\n to reformat line 1:\n\n 1. def my_func(\n 2. arg1,\n 3. arg2,\n 4. arg3,\n 5. ):\n 6. pass\n\n In this case, we will expand the line ranges to also include the whole diff\n block.\n\n Args:\n lines: a collection of line ranges.\n original_source: the original source.\n modified_source: the modified source.\n \"\"\"\n lines_mappings = _calculate_lines_mappings(original_source, modified_source)\n\n new_lines = []\n # Keep an index of the current search. Since the lines and lines_mappings are\n # sorted, this makes the search complexity linear.\n current_mapping_index = 0\n for start, end in sorted(lines):\n start_mapping_index = _find_lines_mapping_index(\n start,\n lines_mappings,\n current_mapping_index,\n )\n end_mapping_index = _find_lines_mapping_index(\n end,\n lines_mappings,\n start_mapping_index,\n )\n current_mapping_index = start_mapping_index\n if start_mapping_index >= len(lines_mappings) or end_mapping_index >= len(\n lines_mappings\n ):\n # Protect against invalid inputs.\n continue\n start_mapping = lines_mappings[start_mapping_index]\n end_mapping = lines_mappings[end_mapping_index]\n if start_mapping.is_changed_block:\n # When the line falls into a changed block, expands to the whole block.\n new_start = start_mapping.modified_start\n else:\n new_start = (\n start - start_mapping.original_start + start_mapping.modified_start\n )\n if end_mapping.is_changed_block:\n # When the line falls into a changed block, expands to the whole block.\n new_end = end_mapping.modified_end\n else:\n new_end = end - end_mapping.original_start + end_mapping.modified_start\n new_range = (new_start, new_end)\n if is_valid_line_range(new_range):\n new_lines.append(new_range)\n return new_lines\n\n\ndef convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None:\n \"\"\"Converts unchanged lines to STANDALONE_COMMENT.\n\n The idea is similar to how `# fmt: on/off` is implemented. It also converts the\n nodes between those markers as a single `STANDALONE_COMMENT` leaf node with\n the unformatted code as its value. `STANDALONE_COMMENT` is a \"fake\" token\n that will be formatted as-is with its prefix normalized.\n\n Here we perform two passes:\n\n 1. Visit the top-level statements, and convert them to a single\n `STANDALONE_COMMENT` when unchanged. This speeds up formatting when some\n of the top-level statements aren't changed.\n 2. Convert unchanged \"unwrapped lines\" to `STANDALONE_COMMENT` nodes line by\n line. \"unwrapped lines\" are divided by the `NEWLINE` token. e.g. a\n multi-line statement is *one* \"unwrapped line\" that ends with `NEWLINE`,\n even though this statement itself can span multiple lines, and the\n tokenizer only sees the last '\\n' as the `NEWLINE` token.\n\n NOTE: During pass (2), comment prefixes and indentations are ALWAYS\n normalized even when the lines aren't changed. This is fixable by moving\n more formatting to pass (1). However, it's hard to get it correct when\n incorrect indentations are used. 
So we defer this to future optimizations.\n \"\"\"\n lines_set: Set[int] = set()\n for start, end in lines:\n lines_set.update(range(start, end + 1))\n visitor = _TopLevelStatementsVisitor(lines_set)\n _ = list(visitor.visit(src_node)) # Consume all results.\n _convert_unchanged_line_by_line(src_node, lines_set)\n\n\ndef _contains_standalone_comment(node: LN) -> bool:\n if isinstance(node, Leaf):\n return node.type == STANDALONE_COMMENT\n else:\n for child in node.children:\n if _contains_standalone_comment(child):\n return True\n return False\n\n\nclass _TopLevelStatementsVisitor(Visitor[None]):\n \"\"\"\n A node visitor that converts unchanged top-level statements to\n STANDALONE_COMMENT.\n\n This is used in addition to _convert_unchanged_line_by_line, to\n speed up formatting when there are unchanged top-level\n classes/functions/statements.\n \"\"\"\n\n def __init__(self, lines_set: Set[int]):\n self._lines_set = lines_set\n\n def visit_simple_stmt(self, node: Node) -> Iterator[None]:\n # This is only called for top-level statements, since `visit_suite`\n # won't visit its children nodes.\n yield from []\n newline_leaf = last_leaf(node)\n if not newline_leaf:\n return\n assert (\n newline_leaf.type == NEWLINE\n ), f\"Unexpectedly found leaf.type={newline_leaf.type}\"\n # We need to find the furthest ancestor with the NEWLINE as the last\n # leaf, since a `suite` can simply be a `simple_stmt` when it puts\n # its body on the same line. Example: `if cond: pass`.\n ancestor = furthest_ancestor_with_last_leaf(newline_leaf)\n if not _get_line_range(ancestor).intersection(self._lines_set):\n _convert_node_to_standalone_comment(ancestor)\n\n def visit_suite(self, node: Node) -> Iterator[None]:\n yield from []\n # If there is a STANDALONE_COMMENT node, it means parts of the node tree\n # have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't\n # be simply converted by calling str(node). So we just don't convert\n # here.\n if _contains_standalone_comment(node):\n return\n # Find the semantic parent of this suite. For `async_stmt` and\n # `async_funcdef`, the ASYNC token is defined on a separate level by the\n # grammar.\n semantic_parent = node.parent\n if semantic_parent is not None:\n if (\n semantic_parent.prev_sibling is not None\n and semantic_parent.prev_sibling.type == ASYNC\n ):\n semantic_parent = semantic_parent.parent\n if semantic_parent is not None and not _get_line_range(\n semantic_parent\n ).intersection(self._lines_set):\n _convert_node_to_standalone_comment(semantic_parent)\n\n\ndef _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None:\n \"\"\"Converts unchanged to STANDALONE_COMMENT line by line.\"\"\"\n for leaf in node.leaves():\n if leaf.type != NEWLINE:\n # We only consider \"unwrapped lines\", which are divided by the NEWLINE\n # token.\n continue\n if leaf.parent and leaf.parent.type == syms.match_stmt:\n # The `suite` node is defined as:\n # match_stmt: \"match\" subject_expr ':' NEWLINE INDENT case_block+ DEDENT\n # Here we need to check `subject_expr`. 
The `case_block+` will be\n # checked by their own NEWLINEs.\n nodes_to_ignore: List[LN] = []\n prev_sibling = leaf.prev_sibling\n while prev_sibling:\n nodes_to_ignore.insert(0, prev_sibling)\n prev_sibling = prev_sibling.prev_sibling\n if not _get_line_range(nodes_to_ignore).intersection(lines_set):\n _convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)\n elif leaf.parent and leaf.parent.type == syms.suite:\n # The `suite` node is defined as:\n # suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT\n # We will check `simple_stmt` and `stmt+` separately against the lines set\n parent_sibling = leaf.parent.prev_sibling\n nodes_to_ignore = []\n while parent_sibling and not parent_sibling.type == syms.suite:\n # NOTE: Multiple suite nodes can exist as siblings in e.g. `if_stmt`.\n nodes_to_ignore.insert(0, parent_sibling)\n parent_sibling = parent_sibling.prev_sibling\n # Special case for `async_stmt` and `async_funcdef` where the ASYNC\n # token is on the grandparent node.\n grandparent = leaf.parent.parent\n if (\n grandparent is not None\n and grandparent.prev_sibling is not None\n and grandparent.prev_sibling.type == ASYNC\n ):\n nodes_to_ignore.insert(0, grandparent.prev_sibling)\n if not _get_line_range(nodes_to_ignore).intersection(lines_set):\n _convert_nodes_to_standalone_comment(nodes_to_ignore, newline=leaf)\n else:\n ancestor = furthest_ancestor_with_last_leaf(leaf)\n # Consider multiple decorators as a whole block, as their\n # newlines have different behaviors than the rest of the grammar.\n if (\n ancestor.type == syms.decorator\n and ancestor.parent\n and ancestor.parent.type == syms.decorators\n ):\n ancestor = ancestor.parent\n if not _get_line_range(ancestor).intersection(lines_set):\n _convert_node_to_standalone_comment(ancestor)\n\n\ndef _convert_node_to_standalone_comment(node: LN) -> None:\n \"\"\"Convert node to STANDALONE_COMMENT by modifying the tree inline.\"\"\"\n parent = node.parent\n if not parent:\n return\n first = first_leaf(node)\n last = last_leaf(node)\n if not first or not last:\n return\n if first is last:\n # This can happen on the following edge cases:\n # 1. A block of `# fmt: off/on` code except the `# fmt: on` is placed\n # on the end of the last line instead of on a new line.\n # 2. A single backslash on its own line followed by a comment line.\n # Ideally we don't want to format them when not requested, but fixing\n # isn't easy. These cases are also badly formatted code, so it isn't\n # too bad we reformat them.\n return\n # The prefix contains comments and indentation whitespaces. 
They are\n # reformatted accordingly to the correct indentation level.\n # This also means the indentation will be changed on the unchanged lines, and\n # this is actually required to not break incremental reformatting.\n prefix = first.prefix\n first.prefix = \"\"\n index = node.remove()\n if index is not None:\n # Remove the '\\n', as STANDALONE_COMMENT will have '\\n' appended when\n # generating the formatted code.\n value = str(node)[:-1]\n parent.insert_child(\n index,\n Leaf(\n STANDALONE_COMMENT,\n value,\n prefix=prefix,\n fmt_pass_converted_first_leaf=first,\n ),\n )\n\n\ndef _convert_nodes_to_standalone_comment(nodes: Sequence[LN], *, newline: Leaf) -> None:\n \"\"\"Convert nodes to STANDALONE_COMMENT by modifying the tree inline.\"\"\"\n if not nodes:\n return\n parent = nodes[0].parent\n first = first_leaf(nodes[0])\n if not parent or not first:\n return\n prefix = first.prefix\n first.prefix = \"\"\n value = \"\".join(str(node) for node in nodes)\n # The prefix comment on the NEWLINE leaf is the trailing comment of the statement.\n if newline.prefix:\n value += newline.prefix\n newline.prefix = \"\"\n index = nodes[0].remove()\n for node in nodes[1:]:\n node.remove()\n if index is not None:\n parent.insert_child(\n index,\n Leaf(\n STANDALONE_COMMENT,\n value,\n prefix=prefix,\n fmt_pass_converted_first_leaf=first,\n ),\n )\n\n\ndef _leaf_line_end(leaf: Leaf) -> int:\n \"\"\"Returns the line number of the leaf node's last line.\"\"\"\n if leaf.type == NEWLINE:\n return leaf.lineno\n else:\n # Leaf nodes like multiline strings can occupy multiple lines.\n return leaf.lineno + str(leaf).count(\"\\n\")\n\n\ndef _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]:\n \"\"\"Returns the line range of this node or list of nodes.\"\"\"\n if isinstance(node_or_nodes, list):\n nodes = node_or_nodes\n if not nodes:\n return set()\n first = first_leaf(nodes[0])\n last = last_leaf(nodes[-1])\n if first and last:\n line_start = first.lineno\n line_end = _leaf_line_end(last)\n return set(range(line_start, line_end + 1))\n else:\n return set()\n else:\n node = node_or_nodes\n if isinstance(node, Leaf):\n return set(range(node.lineno, _leaf_line_end(node) + 1))\n else:\n first = first_leaf(node)\n last = last_leaf(node)\n if first and last:\n return set(range(first.lineno, _leaf_line_end(last) + 1))\n else:\n return set()\n\n\n@dataclass\nclass _LinesMapping:\n \"\"\"1-based lines mapping from original source to modified source.\n\n Lines [original_start, original_end] from original source\n are mapped to [modified_start, modified_end].\n\n The ranges are inclusive on both ends.\n \"\"\"\n\n original_start: int\n original_end: int\n modified_start: int\n modified_end: int\n # Whether this range corresponds to a changed block, or an unchanged block.\n is_changed_block: bool\n\n\ndef _calculate_lines_mappings(\n original_source: str,\n modified_source: str,\n) -> Sequence[_LinesMapping]:\n \"\"\"Returns a sequence of _LinesMapping by diffing the sources.\n\n For example, given the following diff:\n import re\n - def func(arg1,\n - arg2, arg3):\n + def func(arg1, arg2, arg3):\n pass\n It returns the following mappings:\n original -> modified\n (1, 1) -> (1, 1), is_changed_block=False (the \"import re\" line)\n (2, 3) -> (2, 2), is_changed_block=True (the diff)\n (4, 4) -> (3, 3), is_changed_block=False (the \"pass\" line)\n\n You can think of this visually as if it brings up a side-by-side diff, and tries\n to map the line ranges from the left side to the right side:\n\n (1, 1)->(1, 
1) 1. import re 1. import re\n (2, 3)->(2, 2) 2. def func(arg1, 2. def func(arg1, arg2, arg3):\n 3. arg2, arg3):\n (4, 4)->(3, 3) 4. pass 3. pass\n\n Args:\n original_source: the original source.\n modified_source: the modified source.\n \"\"\"\n matcher = difflib.SequenceMatcher(\n None,\n original_source.splitlines(keepends=True),\n modified_source.splitlines(keepends=True),\n )\n matching_blocks = matcher.get_matching_blocks()\n lines_mappings: List[_LinesMapping] = []\n # matching_blocks is a sequence of \"same block of code ranges\", see\n # https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks\n # Each block corresponds to a _LinesMapping with is_changed_block=False,\n # and the ranges between two blocks corresponds to a _LinesMapping with\n # is_changed_block=True,\n # NOTE: matching_blocks is 0-based, but _LinesMapping is 1-based.\n for i, block in enumerate(matching_blocks):\n if i == 0:\n if block.a != 0 or block.b != 0:\n lines_mappings.append(\n _LinesMapping(\n original_start=1,\n original_end=block.a,\n modified_start=1,\n modified_end=block.b,\n is_changed_block=False,\n )\n )\n else:\n previous_block = matching_blocks[i - 1]\n lines_mappings.append(\n _LinesMapping(\n original_start=previous_block.a + previous_block.size + 1,\n original_end=block.a,\n modified_start=previous_block.b + previous_block.size + 1,\n modified_end=block.b,\n is_changed_block=True,\n )\n )\n if i < len(matching_blocks) - 1:\n lines_mappings.append(\n _LinesMapping(\n original_start=block.a + 1,\n original_end=block.a + block.size,\n modified_start=block.b + 1,\n modified_end=block.b + block.size,\n is_changed_block=False,\n )\n )\n return lines_mappings\n\n\ndef _find_lines_mapping_index(\n original_line: int,\n lines_mappings: Sequence[_LinesMapping],\n start_index: int,\n) -> int:\n \"\"\"Returns the original index of the lines mappings for the original line.\"\"\"\n index = start_index\n while index < len(lines_mappings):\n mapping = lines_mappings[index]\n if mapping.original_start <= original_line <= mapping.original_end:\n return index\n index += 1\n return index\n\n# Path: src/black/__init__.py\nimport io\nimport json\nimport platform\nimport re\nimport sys\nimport tokenize\nimport traceback\nfrom contextlib import contextmanager\nfrom dataclasses import replace\nfrom datetime import datetime, timezone\nfrom enum import Enum\nfrom json.decoder import JSONDecodeError\nfrom pathlib import Path\nfrom typing import (\n Any,\n Collection,\n Dict,\n Generator,\n Iterator,\n List,\n MutableMapping,\n Optional,\n Pattern,\n Sequence,\n Set,\n Sized,\n Tuple,\n Union,\n)\n\nimport click\nfrom click.core import ParameterSource\nfrom mypy_extensions import mypyc_attr\nfrom pathspec import PathSpec\nfrom pathspec.patterns.gitwildmatch import GitWildMatchPatternError\n\nfrom _black_version import version as __version__\nfrom black.cache import Cache\nfrom black.comments import normalize_fmt_off\nfrom black.const import (\n DEFAULT_EXCLUDES,\n DEFAULT_INCLUDES,\n DEFAULT_LINE_LENGTH,\n STDIN_PLACEHOLDER,\n)\nfrom black.files import (\n best_effort_relative_path,\n find_project_root,\n find_pyproject_toml,\n find_user_pyproject_toml,\n gen_python_files,\n get_gitignore,\n parse_pyproject_toml,\n path_is_excluded,\n resolves_outside_root_or_cannot_stat,\n wrap_stream_for_windows,\n)\nfrom black.handle_ipynb_magics import (\n PYTHON_CELL_MAGICS,\n TRANSFORMED_MAGICS,\n jupyter_dependencies_are_installed,\n mask_cell,\n put_trailing_semicolon_back,\n 
remove_trailing_semicolon,\n unmask_cell,\n)\nfrom black.linegen import LN, LineGenerator, transform_line\nfrom black.lines import EmptyLineTracker, LinesBlock\nfrom black.mode import FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, Feature\nfrom black.mode import Mode as Mode # re-exported\nfrom black.mode import Preview, TargetVersion, supports_feature\nfrom black.nodes import (\n STARS,\n is_number_token,\n is_simple_decorator_expression,\n is_string_token,\n syms,\n)\nfrom black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out\nfrom black.parsing import InvalidInput # noqa F401\nfrom black.parsing import lib2to3_parse, parse_ast, stringify_ast\nfrom black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges\nfrom black.report import Changed, NothingChanged, Report\nfrom black.trans import iter_fexpr_spans\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\nCOMPILED = Path(__file__).suffix in (\".pyd\", \".so\")\n\n# types\nFileContent = str\nEncoding = str\nNewLine = str\n\n\nclass WriteBack(Enum):\n NO = 0\n YES = 1\n DIFF = 2\n CHECK = 3\n COLOR_DIFF = 4\n\n @classmethod\n def from_configuration(\n cls, *, check: bool, diff: bool, color: bool = False\n ) -> \"WriteBack\":\n if check and not diff:\n return cls.CHECK\n\n if diff and color:\n return cls.COLOR_DIFF\n\n return cls.DIFF if diff else cls.YES\n\n\n# Legacy name, left for integrations.\nFileMode = Mode\n\n\ndef read_pyproject_toml(\n ctx: click.Context, param: click.Parameter, value: Optional[str]\n) -> Optional[str]:\n \"\"\"Inject Black configuration from \"pyproject.toml\" into defaults in `ctx`.\n\n Returns the path to a successfully found and read configuration file, None\n otherwise.\n \"\"\"\n if not value:\n value = find_pyproject_toml(\n ctx.params.get(\"src\", ()), ctx.params.get(\"stdin_filename\", None)\n )\n if value is None:\n return None\n\n try:\n config = parse_pyproject_toml(value)\n except (OSError, ValueError) as e:\n raise click.FileError(\n filename=value, hint=f\"Error reading configuration file: {e}\"\n ) from None\n\n if not config:\n return None\n else:\n spellcheck_pyproject_toml_keys(ctx, list(config), value)\n # Sanitize the values to be Click friendly. 
For more information please see:\n # https://github.com/psf/black/issues/1458\n # https://github.com/pallets/click/issues/1567\n config = {\n k: str(v) if not isinstance(v, (list, dict)) else v\n for k, v in config.items()\n }\n\n target_version = config.get(\"target_version\")\n if target_version is not None and not isinstance(target_version, list):\n raise click.BadOptionUsage(\n \"target-version\", \"Config key target-version must be a list\"\n )\n\n exclude = config.get(\"exclude\")\n if exclude is not None and not isinstance(exclude, str):\n raise click.BadOptionUsage(\"exclude\", \"Config key exclude must be a string\")\n\n extend_exclude = config.get(\"extend_exclude\")\n if extend_exclude is not None and not isinstance(extend_exclude, str):\n raise click.BadOptionUsage(\n \"extend-exclude\", \"Config key extend-exclude must be a string\"\n )\n\n line_ranges = config.get(\"line_ranges\")\n if line_ranges is not None:\n raise click.BadOptionUsage(\n \"line-ranges\", \"Cannot use line-ranges in the pyproject.toml file.\"\n )\n\n default_map: Dict[str, Any] = {}\n if ctx.default_map:\n default_map.update(ctx.default_map)\n default_map.update(config)\n\n ctx.default_map = default_map\n return value\n\n\ndef spellcheck_pyproject_toml_keys(\n ctx: click.Context, config_keys: List[str], config_file_path: str\n) -> None:\n invalid_keys: List[str] = []\n available_config_options = {param.name for param in ctx.command.params}\n for key in config_keys:\n if key not in available_config_options:\n invalid_keys.append(key)\n if invalid_keys:\n keys_str = \", \".join(map(repr, invalid_keys))\n out(\n f\"Invalid config keys detected: {keys_str} (in {config_file_path})\",\n fg=\"red\",\n )\n\n\ndef target_version_option_callback(\n c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]\n) -> List[TargetVersion]:\n \"\"\"Compute the target versions from a --target-version flag.\n\n This is its own function because mypy couldn't infer the type correctly\n when it was a lambda, causing mypyc trouble.\n \"\"\"\n return [TargetVersion[val.upper()] for val in v]\n\n\ndef enable_unstable_feature_callback(\n c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]\n) -> List[Preview]:\n \"\"\"Compute the features from an --enable-unstable-feature flag.\"\"\"\n return [Preview[val] for val in v]\n\n\ndef re_compile_maybe_verbose(regex: str) -> Pattern[str]:\n \"\"\"Compile a regular expression string in `regex`.\n\n If it contains newlines, use verbose mode.\n \"\"\"\n if \"\\n\" in regex:\n regex = \"(?x)\" + regex\n compiled: Pattern[str] = re.compile(regex)\n return compiled\n\n\ndef validate_regex(\n ctx: click.Context,\n param: click.Parameter,\n value: Optional[str],\n) -> Optional[Pattern[str]]:\n try:\n return re_compile_maybe_verbose(value) if value is not None else None\n except re.error as e:\n raise click.BadParameter(f\"Not a valid regular expression: {e}\") from None\n\n\n@click.command(\n context_settings={\"help_option_names\": [\"-h\", \"--help\"]},\n # While Click does set this field automatically using the docstring, mypyc\n # (annoyingly) strips 'em so we need to set it here too.\n help=\"The uncompromising code formatter.\",\n)\n@click.option(\"-c\", \"--code\", type=str, help=\"Format the code passed in as a string.\")\n@click.option(\n \"-l\",\n \"--line-length\",\n type=int,\n default=DEFAULT_LINE_LENGTH,\n help=\"How many characters per line to allow.\",\n show_default=True,\n)\n@click.option(\n \"-t\",\n \"--target-version\",\n 
type=click.Choice([v.name.lower() for v in TargetVersion]),\n callback=target_version_option_callback,\n multiple=True,\n help=(\n \"Python versions that should be supported by Black's output. You should\"\n \" include all versions that your code supports. By default, Black will infer\"\n \" target versions from the project metadata in pyproject.toml. If this does\"\n \" not yield conclusive results, Black will use per-file auto-detection.\"\n ),\n)\n@click.option(\n \"--pyi\",\n is_flag=True,\n help=(\n \"Format all input files like typing stubs regardless of file extension. This\"\n \" is useful when piping source on standard input.\"\n ),\n)\n@click.option(\n \"--ipynb\",\n...\n", "needle_token_start": 5377, "needle_token_end": 5921, "code_context_ntokens": 16166, "output": ["```python\ndef run_transformer(\n line: Line,\n transform: Transformer,\n mode: Mode,\n features: Collection[Feature],\n *,\n line_str: str = \"\",\n) -> List[Line]:\n if not line_str:\n line_str = line_to_string(line)\n result: List[Line] = []\n for transformed_line in transform(line, features, mode):\n if str(transformed_line).strip(\"\\n\") == line_str:\n raise CannotTransform(\"Line transformer returned an unchanged result\")\n\n result.extend(transform_line(transformed_line, mode=mode, features=features))\n\n features_set = set(features)\n if (\n Feature.FORCE_OPTIONAL_PARENTHESES in features_set\n or transform.__class__.__name__ != \"rhs\"\n or not line.bracket_tracker.invisible\n or any(bracket.value for bracket in line.bracket_tracker.invisible)\n or line.contains_multiline_strings()\n or result[0].contains_uncollapsable_type_comments()\n or result[0].contains_unsplittable_type_ignore()\n or is_line_short_enough(result[0], mode=mode)\n # If any leaves have no parents (which _can_ occur since\n # `transform(line)` potentially destroys the line's underlying node\n # structure), then we can't proceed. Doing so would cause the below\n # call to `append_leaves()` to fail.\n or any(leaf.parent is None for leaf in line.leaves)\n ):\n return result\n\n line_copy = line.clone()\n append_leaves(line_copy, line, line.leaves)\n features_fop = features_set | {Feature.FORCE_OPTIONAL_PARENTHESES}\n second_opinion = run_transformer(\n line_copy, transform, mode, features_fop, line_str=line_str\n )\n if all(is_line_short_enough(ln, mode=mode) for ln in second_opinion):\n result = second_opinion\n return result\n```"]} +{"repo": "psf/black", "name": "is_part_of_annotation", "language": "python", "path": "src/black/nodes.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: Determines if a specified syntax element is part of a type annotation in the source code.\n2. **Input**: A syntax element from the parsed source code.\n3. **Output**: A boolean value indicating whether the given syntax element is involved in a type annotation.\n4. **Procedure**: The function checks the type of annotation associated with the syntax element by tracing its position in the syntax tree and examining its relationship with surrounding elements. 
If the element is identified as part of either a return type or parameter type annotation, the function returns true; otherwise, it returns false.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " DOUBLESPACE: Final[str] = \" \"\n t = leaf.type\n p = leaf.parent\n v = leaf.value\n if t in ALWAYS_NO_SPACE:\n return NO\n\n if t == token.COMMENT:\n return DOUBLESPACE\n\n assert p is not None, f\"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}\"\n if t == token.COLON and p.type not in {\n syms.subscript,\n syms.subscriptlist,\n syms.sliceop,\n }:\n return NO\n\n prev = leaf.prev_sibling\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n if t == token.COLON:\n if prevp.type == token.COLON:\n return NO\n\n elif prevp.type != token.COMMA and not complex_subscript:\n return NO\n\n return SPACE\n\n if prevp.type == token.EQUAL:\n if prevp.parent:\n if prevp.parent.type in {\n syms.arglist,\n syms.argument,\n syms.parameters,\n syms.varargslist,\n }:\n return NO\n\n elif prevp.parent.type == syms.typedargslist:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. So, we're using\n # that, too.\n return prevp.prefix\n\n elif (\n prevp.type == token.STAR\n and parent_type(prevp) == syms.star_expr\n and parent_type(prevp.parent) == syms.subscriptlist\n ):\n # No space between typevar tuples.\n return NO\n\n elif prevp.type in VARARGS_SPECIALS:\n if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):\n return NO\n\n elif prevp.type == token.COLON:\n if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:\n return SPACE if complex_subscript else NO\n\n elif (\n prevp.parent\n and prevp.parent.type == syms.factor\n and prevp.type in MATH_OPERATORS\n ):\n return NO\n\n elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:\n # no space in decorators\n return NO\n\n elif prev.type in OPENING_BRACKETS:\n return NO\n\n if p.type in {syms.parameters, syms.arglist}:\n # untyped function signatures or calls\n if not prev or prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.varargslist:\n # lambdas\n if prev and prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.typedargslist:\n # typed function signatures\n if not prev:\n return NO\n\n if t == token.EQUAL:\n if prev.type not in TYPED_NAMES:\n return NO\n\n elif prev.type == token.EQUAL:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. 
So, we're using that, too.\n return prev.prefix\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type in TYPED_NAMES:\n # type names\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type != token.COMMA:\n return NO\n\n elif p.type == syms.trailer:\n # attributes and calls\n if t == token.LPAR or t == token.RPAR:\n return NO\n\n if not prev:\n if t == token.DOT or t == token.LSQB:\n return NO\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.argument:\n # single argument\n if t == token.EQUAL:\n return NO\n\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.LPAR:\n return NO\n\n elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:\n return NO\n\n elif p.type == syms.decorator:\n # decorators\n return NO\n\n elif p.type == syms.dotted_name:\n if prev:\n return NO\n\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.AT or prevp.type == token.DOT:\n return NO\n\n elif p.type == syms.classdef:\n if t == token.LPAR:\n return NO\n\n if prev and prev.type == token.LPAR:\n return NO\n\n elif p.type in {syms.subscript, syms.sliceop}:\n # indexing\n if not prev:\n assert p.parent is not None, \"subscripts are always parented\"\n if p.parent.type == syms.subscriptlist:\n return SPACE\n\n return NO\n\n elif t == token.COLONEQUAL or prev.type == token.COLONEQUAL:\n return SPACE\n\n elif not complex_subscript:\n return NO\n\n elif p.type == syms.atom:\n if prev and t == token.DOT:\n # dots, but not the first one.\n return NO\n\n elif p.type == syms.dictsetmaker:\n # dict unpacking\n if prev and prev.type == token.DOUBLESTAR:\n return NO\n\n elif p.type in {syms.factor, syms.star_expr}:\n # unary ops\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n prevp_parent = prevp.parent\n assert prevp_parent is not None\n if prevp.type == token.COLON and prevp_parent.type in {\n syms.subscript,\n syms.sliceop,\n }:\n return NO\n\n elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:\n return NO\n\n elif t in {token.NAME, token.NUMBER, token.STRING}:\n return NO\n\n elif p.type == syms.import_from:\n if t == token.DOT:\n if prev and prev.type == token.DOT:\n return NO\n\n elif t == token.NAME:\n if v == \"import\":\n return SPACE\n\n if prev and prev.type == token.DOT:\n return NO\n\n elif p.type == syms.sliceop:\n return NO\n\n elif p.type == syms.except_clause:\n if t == token.STAR:\n return NO\n\n return SPACE\n\n\ndef make_simple_prefix(nl_count: int, form_feed: bool, empty_line: str = \"\\n\") -> str:\n \"\"\"Generate a normalized prefix string.\"\"\"\n if form_feed:\n return (empty_line * (nl_count - 1)) + \"\\f\" + empty_line\n return empty_line * nl_count\n\n\ndef preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:\n \"\"\"Return the first leaf that precedes `node`, if any.\"\"\"\n while node:\n res = node.prev_sibling\n if res:\n if isinstance(res, Leaf):\n return res\n\n try:\n return list(res.leaves())[-1]\n\n except IndexError:\n return None\n\n node = node.parent\n return None\n\n\ndef prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:\n \"\"\"Return if the `node` and its previous siblings match types against the provided\n list of tokens; the provided `node`has its type matched against the last element in\n the list. 
`None` can be used as the first element to declare that the start of the\n list is anchored at the start of its parent's children.\"\"\"\n if not tokens:\n return True\n if tokens[-1] is None:\n return node is None\n if not node:\n return False\n if node.type != tokens[-1]:\n return False\n return prev_siblings_are(node.prev_sibling, tokens[:-1])\n\n\ndef parent_type(node: Optional[LN]) -> Optional[NodeType]:\n \"\"\"\n Returns:\n @node.parent.type, if @node is not None and has a parent.\n OR\n None, otherwise.\n \"\"\"\n if node is None or node.parent is None:\n return None\n\n return node.parent.type\n\n\ndef child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:\n \"\"\"Return the child of `ancestor` that contains `descendant`.\"\"\"\n node: Optional[LN] = descendant\n while node and node.parent != ancestor:\n node = node.parent\n return node\n\n\ndef replace_child(old_child: LN, new_child: LN) -> None:\n \"\"\"\n Side Effects:\n * If @old_child.parent is set, replace @old_child with @new_child in\n @old_child's underlying Node structure.\n OR\n * Otherwise, this function does nothing.\n \"\"\"\n parent = old_child.parent\n if not parent:\n return\n\n child_idx = old_child.remove()\n if child_idx is not None:\n parent.insert_child(child_idx, new_child)\n\n\ndef container_of(leaf: Leaf) -> LN:\n \"\"\"Return `leaf` or one of its ancestors that is the topmost container of it.\n\n By \"container\" we mean a node where `leaf` is the very first child.\n \"\"\"\n same_prefix = leaf.prefix\n container: LN = leaf\n while container:\n parent = container.parent\n if parent is None:\n break\n\n if parent.children[0].prefix != same_prefix:\n break\n\n if parent.type == syms.file_input:\n break\n\n if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:\n break\n\n container = parent\n return container\n\n\ndef first_leaf_of(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the first leaf of the node tree.\"\"\"\n if isinstance(node, Leaf):\n return node\n if node.children:\n return first_leaf_of(node.children[0])\n else:\n return None\n\n\ndef is_arith_like(node: LN) -> bool:\n \"\"\"Whether node is an arithmetic or a binary arithmetic expression\"\"\"\n return node.type in {\n syms.arith_expr,\n syms.shift_expr,\n syms.xor_expr,\n syms.and_expr,\n }\n\n\ndef is_docstring(leaf: Leaf, mode: Mode) -> bool:\n if leaf.type != token.STRING:\n return False\n\n prefix = get_string_prefix(leaf.value)\n if set(prefix).intersection(\"bBfF\"):\n return False\n\n if (\n Preview.unify_docstring_detection in mode\n and leaf.parent\n and leaf.parent.type == syms.simple_stmt\n and not leaf.parent.prev_sibling\n and leaf.parent.parent\n and leaf.parent.parent.type == syms.file_input\n ):\n return True\n\n if prev_siblings_are(\n leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]\n ):\n return True\n\n # Multiline docstring on the same line as the `def`.\n if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):\n # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python\n # grammar. 
We're safe to return True without further checks.\n return True\n\n return False\n\n\ndef is_empty_tuple(node: LN) -> bool:\n \"\"\"Return True if `node` holds an empty tuple.\"\"\"\n return (\n node.type == syms.atom\n and len(node.children) == 2\n and node.children[0].type == token.LPAR\n and node.children[1].type == token.RPAR\n )\n\n\ndef is_one_tuple(node: LN) -> bool:\n \"\"\"Return True if `node` holds a tuple with one element, with or without parens.\"\"\"\n if node.type == syms.atom:\n gexp = unwrap_singleton_parenthesis(node)\n if gexp is None or gexp.type != syms.testlist_gexp:\n return False\n\n return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA\n\n return (\n node.type in IMPLICIT_TUPLE\n and len(node.children) == 2\n and node.children[1].type == token.COMMA\n )\n\n\ndef is_tuple_containing_walrus(node: LN) -> bool:\n \"\"\"Return True if `node` holds a tuple that contains a walrus operator.\"\"\"\n if node.type != syms.atom:\n return False\n gexp = unwrap_singleton_parenthesis(node)\n if gexp is None or gexp.type != syms.testlist_gexp:\n return False\n\n return any(child.type == syms.namedexpr_test for child in gexp.children)\n\n\ndef is_one_sequence_between(\n opening: Leaf,\n closing: Leaf,\n leaves: List[Leaf],\n brackets: Tuple[int, int] = (token.LPAR, token.RPAR),\n) -> bool:\n \"\"\"Return True if content between `opening` and `closing` is a one-sequence.\"\"\"\n if (opening.type, closing.type) != brackets:\n return False\n\n depth = closing.bracket_depth + 1\n for _opening_index, leaf in enumerate(leaves):\n if leaf is opening:\n break\n\n else:\n raise LookupError(\"Opening paren not found in `leaves`\")\n\n commas = 0\n _opening_index += 1\n for leaf in leaves[_opening_index:]:\n if leaf is closing:\n break\n\n bracket_depth = leaf.bracket_depth\n if bracket_depth == depth and leaf.type == token.COMMA:\n commas += 1\n if leaf.parent and leaf.parent.type in {\n syms.arglist,\n syms.typedargslist,\n }:\n commas += 1\n break\n\n return commas < 2\n\n\ndef is_walrus_assignment(node: LN) -> bool:\n \"\"\"Return True iff `node` is of the shape ( test := test )\"\"\"\n inner = unwrap_singleton_parenthesis(node)\n return inner is not None and inner.type == syms.namedexpr_test\n\n\ndef is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:\n \"\"\"Return True iff `node` is a trailer valid in a simple decorator\"\"\"\n return node.type == syms.trailer and (\n (\n len(node.children) == 2\n and node.children[0].type == token.DOT\n and node.children[1].type == token.NAME\n )\n # last trailer can be an argument-less parentheses pair\n or (\n last\n and len(node.children) == 2\n and node.children[0].type == token.LPAR\n and node.children[1].type == token.RPAR\n )\n # last trailer can be arguments\n or (\n last\n and len(node.children) == 3\n and node.children[0].type == token.LPAR\n # and node.children[1].type == syms.argument\n and node.children[2].type == token.RPAR\n )\n )\n\n\ndef is_simple_decorator_expression(node: LN) -> bool:\n \"\"\"Return True iff `node` could be a 'dotted name' decorator\n\n This function takes the node of the 'namedexpr_test' of the new decorator\n grammar and test if it would be valid under the old decorator grammar.\n\n The old grammar was: decorator: @ dotted_name [arguments] NEWLINE\n The new grammar is : decorator: @ namedexpr_test NEWLINE\n \"\"\"\n if node.type == token.NAME:\n return True\n if node.type == syms.power:\n if node.children:\n return (\n node.children[0].type == token.NAME\n and 
all(map(is_simple_decorator_trailer, node.children[1:-1]))\n and (\n len(node.children) < 2\n or is_simple_decorator_trailer(node.children[-1], last=True)\n )\n )\n return False\n\n\ndef is_yield(node: LN) -> bool:\n \"\"\"Return True if `node` holds a `yield` or `yield from` expression.\"\"\"\n if node.type == syms.yield_expr:\n return True\n\n if is_name_token(node) and node.value == \"yield\":\n return True\n\n if node.type != syms.atom:\n return False\n\n if len(node.children) != 3:\n return False\n\n lpar, expr, rpar = node.children\n if lpar.type == token.LPAR and rpar.type == token.RPAR:\n return is_yield(expr)\n\n return False\n\n\ndef is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:\n \"\"\"Return True if `leaf` is a star or double star in a vararg or kwarg.\n\n If `within` includes VARARGS_PARENTS, this applies to function signatures.\n If `within` includes UNPACKING_PARENTS, it applies to right hand-side\n extended iterable unpacking (PEP 3132) and additional unpacking\n generalizations (PEP 448).\n \"\"\"\n if leaf.type not in VARARGS_SPECIALS or not leaf.parent:\n return False\n\n p = leaf.parent\n if p.type == syms.star_expr:\n # Star expressions are also used as assignment targets in extended\n # iterable unpacking (PEP 3132). See what its parent is instead.\n if not p.parent:\n return False\n\n p = p.parent\n\n return p.type in within\n\n\ndef is_multiline_string(leaf: Leaf) -> bool:\n \"\"\"Return True if `leaf` is a multiline string that actually spans many lines.\"\"\"\n return has_triple_quotes(leaf.value) and \"\\n\" in leaf.value\n\n\ndef is_parent_function_or_class(node: Node) -> bool:\n assert node.type in {syms.suite, syms.simple_stmt}\n assert node.parent is not None\n # Note this works for suites / simple_stmts in async def as well\n return node.parent.type in {syms.funcdef, syms.classdef}\n\n\ndef is_function_or_class(node: Node) -> bool:\n return node.type in {syms.funcdef, syms.classdef, syms.async_funcdef}\n\n\ndef is_stub_suite(node: Node) -> bool:\n \"\"\"Return True if `node` is a suite with a stub body.\"\"\"\n if node.parent is not None and not is_parent_function_or_class(node):\n return False\n\n # If there is a comment, we want to keep it.\n if node.prefix.strip():\n return False\n\n if (\n len(node.children) != 4\n or node.children[0].type != token.NEWLINE\n or node.children[1].type != token.INDENT\n or node.children[3].type != token.DEDENT\n ):\n return False\n\n if node.children[3].prefix.strip():\n return False\n\n return is_stub_body(node.children[2])\n\n\ndef is_stub_body(node: LN) -> bool:\n \"\"\"Return True if `node` is a simple statement containing an ellipsis.\"\"\"\n if not isinstance(node, Node) or node.type != syms.simple_stmt:\n return False\n\n if len(node.children) != 2:\n return False\n\n child = node.children[0]\n return (\n not child.prefix.strip()\n and child.type == syms.atom\n and len(child.children) == 3\n and all(leaf == Leaf(token.DOT, \".\") for leaf in child.children)\n )\n\n\ndef is_atom_with_invisible_parens(node: LN) -> bool:\n \"\"\"Given a `LN`, determines whether it's an atom `node` with invisible\n parens. 
Useful in dedupe-ing and normalizing parens.\n \"\"\"\n if isinstance(node, Leaf) or node.type != syms.atom:\n return False\n\n first, last = node.children[0], node.children[-1]\n return (\n isinstance(first, Leaf)\n and first.type == token.LPAR\n and first.value == \"\"\n and isinstance(last, Leaf)\n and last.type == token.RPAR\n and last.value == \"\"\n )\n\n\ndef is_empty_par(leaf: Leaf) -> bool:\n return is_empty_lpar(leaf) or is_empty_rpar(leaf)\n\n\ndef is_empty_lpar(leaf: Leaf) -> bool:\n return leaf.type == token.LPAR and leaf.value == \"\"\n\n\ndef is_empty_rpar(leaf: Leaf) -> bool:\n return leaf.type == token.RPAR and leaf.value == \"\"\n\n\ndef is_import(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf starts an import statement.\"\"\"\n p = leaf.parent\n t = leaf.type\n v = leaf.value\n return bool(\n t == token.NAME\n and (\n (v == \"import\" and p and p.type == syms.import_name)\n or (v == \"from\" and p and p.type == syms.import_from)\n )\n )\n\n\ndef is_with_or_async_with_stmt(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf starts a with or async with statement.\"\"\"\n return bool(\n leaf.type == token.NAME\n and leaf.value == \"with\"\n and leaf.parent\n and leaf.parent.type == syms.with_stmt\n ) or bool(\n leaf.type == token.ASYNC\n and leaf.next_sibling\n and leaf.next_sibling.type == syms.with_stmt\n )\n\n\ndef is_async_stmt_or_funcdef(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf starts an async def/for/with statement.\n\n Note that `async def` can be either an `async_stmt` or `async_funcdef`,\n the latter is used when it has decorators.\n \"\"\"\n return bool(\n leaf.type == token.ASYNC\n and leaf.parent\n and leaf.parent.type in {syms.async_stmt, syms.async_funcdef}\n )\n\n\ndef is_type_comment(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf is a type comment. This function should only\n be used for general type comments (excluding ignore annotations, which should\n use `is_type_ignore_comment`). Note that general type comments are no longer\n used in modern version of Python, this function may be deprecated in the future.\"\"\"\n t = leaf.type\n v = leaf.value\n return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith(\"# type:\")\n\n\ndef is_type_ignore_comment(leaf: Leaf) -> bool:\n \"\"\"Return True if the given leaf is a type comment with ignore annotation.\"\"\"\n t = leaf.type\n v = leaf.value\n return t in {token.COMMENT, STANDALONE_COMMENT} and is_type_ignore_comment_string(v)\n\n\ndef is_type_ignore_comment_string(value: str) -> bool:\n \"\"\"Return True if the given string match with type comment with\n ignore annotation.\"\"\"\n return value.startswith(\"# type: ignore\")\n\n\ndef wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:\n \"\"\"Wrap `child` in parentheses.\n\n This replaces `child` with an atom holding the parentheses and the old\n child. That requires moving the prefix.\n\n If `visible` is False, the leaves will be valueless (and thus invisible).\n \"\"\"\n lpar = Leaf(token.LPAR, \"(\" if visible else \"\")\n rpar = Leaf(token.RPAR, \")\" if visible else \"\")\n prefix = child.prefix\n child.prefix = \"\"\n index = child.remove() or 0\n new_child = Node(syms.atom, [lpar, child, rpar])\n new_child.prefix = prefix\n parent.insert_child(index, new_child)\n\n\ndef unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:\n \"\"\"Returns `wrapped` if `node` is of the shape ( wrapped ).\n\n Parenthesis can be optional. 
Returns None otherwise\"\"\"\n if len(node.children) != 3:\n return None\n\n lpar, wrapped, rpar = node.children\n if not (lpar.type == token.LPAR and rpar.type == token.RPAR):\n return None\n\n return wrapped\n\n\ndef ensure_visible(leaf: Leaf) -> None:\n \"\"\"Make sure parentheses are visible.\n\n They could be invisible as part of some statements (see\n :func:`normalize_invisible_parens` and :func:`visit_import_from`).\n \"\"\"\n if leaf.type == token.LPAR:\n leaf.value = \"(\"\n elif leaf.type == token.RPAR:\n leaf.value = \")\"\n\n\ndef is_name_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.NAME\n\n\ndef is_lpar_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.LPAR\n\n\ndef is_rpar_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.RPAR\n\n\ndef is_string_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.STRING\n\n\ndef is_number_token(nl: NL) -> TypeGuard[Leaf]:\n return nl.type == token.NUMBER\n\n\ndef get_annotation_type(leaf: Leaf) -> Literal[\"return\", \"param\", None]:\n \"\"\"Returns the type of annotation this leaf is part of, if any.\"\"\"\n ancestor = leaf.parent\n while ancestor is not None:\n if ancestor.prev_sibling and ancestor.prev_sibling.type == token.RARROW:\n return \"return\"\n if ancestor.parent and ancestor.parent.type == syms.tname:\n return \"param\"\n ancestor = ancestor.parent\n return None\n\n\n\ndef is_part_of_annotation(leaf: Leaf) -> bool:\n \"\"\"Returns whether this leaf is part of a type annotation.\"\"\"\n return get_annotation_type(leaf) is not None\n\n\ndef first_leaf(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the first leaf of the ancestor node.\"\"\"\n if isinstance(node, Leaf):\n return node\n elif not node.children:\n return None\n else:\n return first_leaf(node.children[0])\n\n\ndef last_leaf(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the last leaf of the ancestor node.\"\"\"\n if isinstance(node, Leaf):\n return node\n elif not node.children:\n return None\n else:\n return last_leaf(node.children[-1])\n\n\ndef furthest_ancestor_with_last_leaf(leaf: Leaf) -> LN:\n \"\"\"Returns the furthest ancestor that has this leaf node as the last leaf.\"\"\"\n node: LN = leaf\n while node.parent and node.parent.children and node is node.parent.children[-1]:\n node = node.parent\n return node\n\n# Path: src/black/comments.py\nimport re\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Collection, Final, Iterator, List, Optional, Tuple, Union\n\nfrom black.mode import Mode, Preview\nfrom black.nodes import (\n CLOSING_BRACKETS,\n STANDALONE_COMMENT,\n WHITESPACE,\n container_of,\n first_leaf_of,\n make_simple_prefix,\n preceding_leaf,\n syms,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n# types\nLN = Union[Leaf, Node]\n\nFMT_OFF: Final = {\"# fmt: off\", \"# fmt:off\", \"# yapf: disable\"}\nFMT_SKIP: Final = {\"# fmt: skip\", \"# fmt:skip\"}\nFMT_ON: Final = {\"# fmt: on\", \"# fmt:on\", \"# yapf: enable\"}\n\nCOMMENT_EXCEPTIONS = \" !:#'\"\n_COMMENT_PREFIX = \"# \"\n_COMMENT_LIST_SEPARATOR = \";\"\n\n\n@dataclass\nclass ProtoComment:\n \"\"\"Describes a piece of syntax that is a comment.\n\n It's not a :class:`blib2to3.pytree.Leaf` so that:\n\n * it can be cached (`Leaf` objects should not be reused more than once as\n they store their lineno, column, prefix, and parent information);\n * `newlines` and `consumed` fields are kept separate from the `value`. 
This\n simplifies handling of special marker comments like ``# fmt: off/on``.\n \"\"\"\n\n type: int # token.COMMENT or STANDALONE_COMMENT\n value: str # content of the comment\n newlines: int # how many newlines before the comment\n consumed: int # how many characters of the original leaf's prefix did we consume\n form_feed: bool # is there a form feed before the comment\n leading_whitespace: str # leading whitespace before the comment, if any\n\n\ndef generate_comments(leaf: LN) -> Iterator[Leaf]:\n \"\"\"Clean the prefix of the `leaf` and generate comments from it, if any.\n\n Comments in lib2to3 are shoved into the whitespace prefix. This happens\n in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation\n move because it does away with modifying the grammar to include all the\n possible places in which comments can be placed.\n\n The sad consequence for us though is that comments don't \"belong\" anywhere.\n This is why this function generates simple parentless Leaf objects for\n comments. We simply don't know what the correct parent should be.\n\n No matter though, we can live without this. We really only need to\n differentiate between inline and standalone comments. The latter don't\n share the line with any code.\n\n Inline comments are emitted as regular token.COMMENT leaves. Standalone\n are emitted with a fake STANDALONE_COMMENT token identifier.\n \"\"\"\n total_consumed = 0\n for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):\n total_consumed = pc.consumed\n prefix = make_simple_prefix(pc.newlines, pc.form_feed)\n yield Leaf(pc.type, pc.value, prefix=prefix)\n normalize_trailing_prefix(leaf, total_consumed)\n\n\n@lru_cache(maxsize=4096)\ndef list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:\n \"\"\"Return a list of :class:`ProtoComment` objects parsed from the given `prefix`.\"\"\"\n result: List[ProtoComment] = []\n if not prefix or \"#\" not in prefix:\n return result\n\n consumed = 0\n nlines = 0\n ignored_lines = 0\n form_feed = False\n for index, full_line in enumerate(re.split(\"\\r?\\n\", prefix)):\n consumed += len(full_line) + 1 # adding the length of the split '\\n'\n match = re.match(r\"^(\\s*)(\\S.*|)$\", full_line)\n assert match\n whitespace, line = match.groups()\n if not line:\n nlines += 1\n if \"\\f\" in full_line:\n form_feed = True\n if not line.startswith(\"#\"):\n # Escaped newlines outside of a comment are not really newlines at\n # all. 
We treat a single-line comment following an escaped newline\n # as a simple trailing comment.\n if line.endswith(\"\\\\\"):\n ignored_lines += 1\n continue\n\n if index == ignored_lines and not is_endmarker:\n comment_type = token.COMMENT # simple trailing comment\n else:\n comment_type = STANDALONE_COMMENT\n comment = make_comment(line)\n result.append(\n ProtoComment(\n type=comment_type,\n value=comment,\n newlines=nlines,\n consumed=consumed,\n form_feed=form_feed,\n leading_whitespace=whitespace,\n )\n )\n form_feed = False\n nlines = 0\n return result\n\n\ndef normalize_trailing_prefix(leaf: LN, total_consumed: int) -> None:\n \"\"\"Normalize the prefix that's left over after generating comments.\n\n Note: don't use backslashes for formatting or you'll lose your voting rights.\n \"\"\"\n remainder = leaf.prefix[total_consumed:]\n if \"\\\\\" not in remainder:\n nl_count = remainder.count(\"\\n\")\n form_feed = \"\\f\" in remainder and remainder.endswith(\"\\n\")\n leaf.prefix = make_simple_prefix(nl_count, form_feed)\n return\n\n leaf.prefix = \"\"\n\n\ndef make_comment(content: str) -> str:\n \"\"\"Return a consistently formatted comment from the given `content` string.\n\n All comments (except for \"##\", \"#!\", \"#:\", '#'\") should have a single\n space between the hash sign and the content.\n\n If `content` didn't start with a hash sign, one is provided.\n \"\"\"\n content = content.rstrip()\n if not content:\n return \"#\"\n\n if content[0] == \"#\":\n content = content[1:]\n NON_BREAKING_SPACE = \"\u00a0\"\n if (\n content\n and content[0] == NON_BREAKING_SPACE\n and not content.lstrip().startswith(\"type:\")\n ):\n content = \" \" + content[1:] # Replace NBSP by a simple space\n if content and content[0] not in COMMENT_EXCEPTIONS:\n content = \" \" + content\n return \"#\" + content\n\n\ndef normalize_fmt_off(\n node: Node, mode: Mode, lines: Collection[Tuple[int, int]]\n) -> None:\n \"\"\"Convert content between `# fmt: off`/`# fmt: on` into standalone comments.\"\"\"\n try_again = True\n while try_again:\n try_again = convert_one_fmt_off_pair(node, mode, lines)\n\n\ndef convert_one_fmt_off_pair(\n node: Node, mode: Mode, lines: Collection[Tuple[int, int]]\n) -> bool:\n \"\"\"Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.\n\n Returns True if a pair was converted.\n \"\"\"\n for leaf in node.leaves():\n previous_consumed = 0\n for comment in list_comments(leaf.prefix, is_endmarker=False):\n should_pass_fmt = comment.value in FMT_OFF or _contains_fmt_skip_comment(\n comment.value, mode\n )\n if not should_pass_fmt:\n previous_consumed = comment.consumed\n continue\n # We only want standalone comments. 
If there's no previous leaf or\n # the previous leaf is indentation, it's a standalone comment in\n # disguise.\n if should_pass_fmt and comment.type != STANDALONE_COMMENT:\n prev = preceding_leaf(leaf)\n if prev:\n if comment.value in FMT_OFF and prev.type not in WHITESPACE:\n continue\n if (\n _contains_fmt_skip_comment(comment.value, mode)\n and prev.type in WHITESPACE\n ):\n continue\n\n ignored_nodes = list(generate_ignored_nodes(leaf, comment, mode))\n if not ignored_nodes:\n continue\n\n first = ignored_nodes[0] # Can be a container node with the `leaf`.\n parent = first.parent\n prefix = first.prefix\n if comment.value in FMT_OFF:\n first.prefix = prefix[comment.consumed :]\n if _contains_fmt_skip_comment(comment.value, mode):\n first.prefix = \"\"\n standalone_comment_prefix = prefix\n else:\n standalone_comment_prefix = (\n prefix[:previous_consumed] + \"\\n\" * comment.newlines\n )\n hidden_value = \"\".join(str(n) for n in ignored_nodes)\n comment_lineno = leaf.lineno - comment.newlines\n if comment.value in FMT_OFF:\n fmt_off_prefix = \"\"\n if len(lines) > 0 and not any(\n line[0] <= comment_lineno <= line[1] for line in lines\n ):\n # keeping indentation of comment by preserving original whitespaces.\n fmt_off_prefix = prefix.split(comment.value)[0]\n if \"\\n\" in fmt_off_prefix:\n fmt_off_prefix = fmt_off_prefix.split(\"\\n\")[-1]\n standalone_comment_prefix += fmt_off_prefix\n hidden_value = comment.value + \"\\n\" + hidden_value\n if _contains_fmt_skip_comment(comment.value, mode):\n hidden_value += (\n comment.leading_whitespace\n if Preview.no_normalize_fmt_skip_whitespace in mode\n else \" \"\n ) + comment.value\n if hidden_value.endswith(\"\\n\"):\n # That happens when one of the `ignored_nodes` ended with a NEWLINE\n # leaf (possibly followed by a DEDENT).\n hidden_value = hidden_value[:-1]\n first_idx: Optional[int] = None\n for ignored in ignored_nodes:\n index = ignored.remove()\n if first_idx is None:\n first_idx = index\n assert parent is not None, \"INTERNAL ERROR: fmt: on/off handling (1)\"\n assert first_idx is not None, \"INTERNAL ERROR: fmt: on/off handling (2)\"\n parent.insert_child(\n first_idx,\n Leaf(\n STANDALONE_COMMENT,\n hidden_value,\n prefix=standalone_comment_prefix,\n fmt_pass_converted_first_leaf=first_leaf_of(first),\n ),\n )\n return True\n\n return False\n\n\ndef generate_ignored_nodes(\n leaf: Leaf, comment: ProtoComment, mode: Mode\n) -> Iterator[LN]:\n \"\"\"Starting from the container of `leaf`, generate all leaves until `# fmt: on`.\n\n If comment is skip, returns leaf only.\n Stops at the end of the block.\n \"\"\"\n if _contains_fmt_skip_comment(comment.value, mode):\n yield from _generate_ignored_nodes_from_fmt_skip(leaf, comment)\n return\n container: Optional[LN] = container_of(leaf)\n while container is not None and container.type != token.ENDMARKER:\n if is_fmt_on(container):\n return\n\n # fix for fmt: on in children\n if children_contains_fmt_on(container):\n for index, child in enumerate(container.children):\n if isinstance(child, Leaf) and is_fmt_on(child):\n if child.type in CLOSING_BRACKETS:\n # This means `# fmt: on` is placed at a different bracket level\n # than `# fmt: off`. 
This is an invalid use, but as a courtesy,\n # we include this closing bracket in the ignored nodes.\n # The alternative is to fail the formatting.\n yield child\n return\n if (\n child.type == token.INDENT\n and index < len(container.children) - 1\n and children_contains_fmt_on(container.children[index + 1])\n ):\n # This means `# fmt: on` is placed right after an indentation\n # level, and we shouldn't swallow the previous INDENT token.\n return\n if children_contains_fmt_on(child):\n return\n yield child\n else:\n if container.type == token.DEDENT and container.next_sibling is None:\n # This can happen when there is no matching `# fmt: on` comment at the\n # same level as `# fmt: on`. We need to keep this DEDENT.\n return\n yield container\n container = container.next_sibling\n\n\ndef _generate_ignored_nodes_from_fmt_skip(\n leaf: Leaf, comment: ProtoComment\n) -> Iterator[LN]:\n \"\"\"Generate all leaves that should be ignored by the `# fmt: skip` from `leaf`.\"\"\"\n prev_sibling = leaf.prev_sibling\n parent = leaf.parent\n # Need to properly format the leaf prefix to compare it to comment.value,\n # which is also formatted\n comments = list_comments(leaf.prefix, is_endmarker=False)\n if not comments or comment.value != comments[0].value:\n return\n if prev_sibling is not None:\n leaf.prefix = \"\"\n siblings = [prev_sibling]\n while \"\\n\" not in prev_sibling.prefix and prev_sibling.prev_sibling is not None:\n prev_sibling = prev_sibling.prev_sibling\n siblings.insert(0, prev_sibling)\n yield from siblings\n elif (\n parent is not None and parent.type == syms.suite and leaf.type == token.NEWLINE\n ):\n # The `# fmt: skip` is on the colon line of the if/while/def/class/...\n # statements. The ignored nodes should be previous siblings of the\n # parent suite node.\n leaf.prefix = \"\"\n ignored_nodes: List[LN] = []\n parent_sibling = parent.prev_sibling\n while parent_sibling is not None and parent_sibling.type != syms.suite:\n ignored_nodes.insert(0, parent_sibling)\n parent_sibling = parent_sibling.prev_sibling\n # Special case for `async_stmt` where the ASYNC token is on the\n # grandparent node.\n grandparent = parent.parent\n if (\n grandparent is not None\n and grandparent.prev_sibling is not None\n and grandparent.prev_sibling.type == token.ASYNC\n ):\n ignored_nodes.insert(0, grandparent.prev_sibling)\n yield from iter(ignored_nodes)\n\n\ndef is_fmt_on(container: LN) -> bool:\n \"\"\"Determine whether formatting is switched on within a container.\n Determined by whether the last `# fmt:` comment is `on` or `off`.\n \"\"\"\n fmt_on = False\n for comment in list_comments(container.prefix, is_endmarker=False):\n if comment.value in FMT_ON:\n fmt_on = True\n elif comment.value in FMT_OFF:\n fmt_on = False\n return fmt_on\n\n\ndef children_contains_fmt_on(container: LN) -> bool:\n \"\"\"Determine if children have formatting switched on.\"\"\"\n for child in container.children:\n leaf = first_leaf_of(child)\n if leaf is not None and is_fmt_on(leaf):\n return True\n\n return False\n\n\ndef contains_pragma_comment(comment_list: List[Leaf]) -> bool:\n \"\"\"\n Returns:\n True iff one of the comments in @comment_list is a pragma used by one\n of the more common static analysis tools for python (e.g. 
mypy, flake8,\n pylint).\n \"\"\"\n for comment in comment_list:\n if comment.value.startswith((\"# type:\", \"# noqa\", \"# pylint:\")):\n return True\n\n return False\n\n\ndef _contains_fmt_skip_comment(comment_line: str, mode: Mode) -> bool:\n \"\"\"\n Checks if the given comment contains FMT_SKIP alone or paired with other comments.\n Matching styles:\n # fmt:skip <-- single comment\n # noqa:XXX # fmt:skip # a nice line <-- multiple comments (Preview)\n # pylint:XXX; fmt:skip <-- list of comments (; separated, Preview)\n \"\"\"\n semantic_comment_blocks = [\n comment_line,\n *[\n _COMMENT_PREFIX + comment.strip()\n for comment in comment_line.split(_COMMENT_PREFIX)[1:]\n ],\n *[\n _COMMENT_PREFIX + comment.strip()\n for comment in comment_line.strip(_COMMENT_PREFIX).split(\n _COMMENT_LIST_SEPARATOR\n )\n ],\n ]\n\n return any(comment in FMT_SKIP for comment in semantic_comment_blocks)\n\n# Path: src/black/report.py\n\"\"\"\nSummarize Black runs to users.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom pathlib import Path\n\nfrom click import style\n\nfrom black.output import err, out\n\n\nclass Changed(Enum):\n NO = 0\n CACHED = 1\n YES = 2\n\n\nclass NothingChanged(UserWarning):\n \"\"\"Raised when reformatted code is the same as source.\"\"\"\n\n\n@dataclass\nclass Report:\n \"\"\"Provides a reformatting counter. Can be rendered with `str(report)`.\"\"\"\n\n check: bool = False\n diff: bool = False\n quiet: bool = False\n verbose: bool = False\n change_count: int = 0\n same_count: int = 0\n failure_count: int = 0\n\n def done(self, src: Path, changed: Changed) -> None:\n \"\"\"Increment the counter for successful reformatting. Write out a message.\"\"\"\n if changed is Changed.YES:\n reformatted = \"would reformat\" if self.check or self.diff else \"reformatted\"\n if self.verbose or not self.quiet:\n out(f\"{reformatted} {src}\")\n self.change_count += 1\n else:\n if self.verbose:\n if changed is Changed.NO:\n msg = f\"{src} already well formatted, good job.\"\n else:\n msg = f\"{src} wasn't modified on disk since last run.\"\n out(msg, bold=False)\n self.same_count += 1\n\n def failed(self, src: Path, message: str) -> None:\n \"\"\"Increment the counter for failed reformatting. 
Write out a message.\"\"\"\n err(f\"error: cannot format {src}: {message}\")\n self.failure_count += 1\n\n def path_ignored(self, path: Path, message: str) -> None:\n if self.verbose:\n out(f\"{path} ignored: {message}\", bold=False)\n\n @property\n def return_code(self) -> int:\n \"\"\"Return the exit code that the app should use.\n\n This considers the current state of changed files and failures:\n - if there were any failures, return 123;\n - if any files were changed and --check is being used, return 1;\n - otherwise return 0.\n \"\"\"\n # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with\n # 126 we have special return codes reserved by the shell.\n if self.failure_count:\n return 123\n\n elif self.change_count and self.check:\n return 1\n\n return 0\n\n def __str__(self) -> str:\n \"\"\"Render a color report of the current state.\n\n Use `click.unstyle` to remove colors.\n \"\"\"\n if self.check or self.diff:\n reformatted = \"would be reformatted\"\n unchanged = \"would be left unchanged\"\n failed = \"would fail to reformat\"\n else:\n reformatted = \"reformatted\"\n unchanged = \"left unchanged\"\n failed = \"failed to reformat\"\n report = []\n if self.change_count:\n s = \"s\" if self.change_count > 1 else \"\"\n report.append(\n style(f\"{self.change_count} file{s} \", bold=True, fg=\"blue\")\n + style(f\"{reformatted}\", bold=True)\n )\n\n if self.same_count:\n s = \"s\" if self.same_count > 1 else \"\"\n report.append(style(f\"{self.same_count} file{s} \", fg=\"blue\") + unchanged)\n if self.failure_count:\n s = \"s\" if self.failure_count > 1 else \"\"\n report.append(style(f\"{self.failure_count} file{s} {failed}\", fg=\"red\"))\n return \", \".join(report) + \".\"\n\n# Path: src/black/concurrency.py\n\"\"\"\nFormatting many files at once via multiprocessing. Contains entrypoint and utilities.\n\nNOTE: this module is only imported if we need to format several files at once.\n\"\"\"\n\nimport asyncio\nimport logging\nimport os\nimport signal\nimport sys\nimport traceback\nfrom concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor\nfrom multiprocessing import Manager\nfrom pathlib import Path\nfrom typing import Any, Iterable, Optional, Set\n\nfrom mypy_extensions import mypyc_attr\n\nfrom black import WriteBack, format_file_in_place\nfrom black.cache import Cache\nfrom black.mode import Mode\nfrom black.output import err\nfrom black.report import Changed, Report\n\n\ndef maybe_install_uvloop() -> None:\n \"\"\"If our environment has uvloop installed we use it.\n\n This is called only from command-line entry points to avoid\n interfering with the parent process if Black is used as a library.\n \"\"\"\n try:\n import uvloop\n\n uvloop.install()\n except ImportError:\n pass\n\n\ndef cancel(tasks: Iterable[\"asyncio.Future[Any]\"]) -> None:\n \"\"\"asyncio signal handler that cancels all `tasks` and reports to stderr.\"\"\"\n err(\"Aborted!\")\n for task in tasks:\n task.cancel()\n\n\ndef shutdown(loop: asyncio.AbstractEventLoop) -> None:\n \"\"\"Cancel all pending tasks on `loop`, wait for them, and close the loop.\"\"\"\n try:\n # This part is borrowed from asyncio/runners.py in Python 3.7b2.\n to_cancel = [task for task in asyncio.all_tasks(loop) if not task.done()]\n if not to_cancel:\n return\n\n for task in to_cancel:\n task.cancel()\n loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))\n finally:\n # `concurrent.futures.Future` objects cannot be cancelled once they\n # are already running. 
There might be some when the `shutdown()` happened.\n # Silence their logger's spew about the event loop being closed.\n cf_logger = logging.getLogger(\"concurrent.futures\")\n cf_logger.setLevel(logging.CRITICAL)\n loop.close()\n\n\n# diff-shades depends on being to monkeypatch this function to operate. I know it's\n# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26\n@mypyc_attr(patchable=True)\ndef reformat_many(\n sources: Set[Path],\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: Report,\n workers: Optional[int],\n) -> None:\n \"\"\"Reformat multiple files using a ProcessPoolExecutor.\"\"\"\n maybe_install_uvloop()\n\n executor: Executor\n if workers is None:\n workers = int(os.environ.get(\"BLACK_NUM_WORKERS\", 0))\n workers = workers or os.cpu_count() or 1\n if sys.platform == \"win32\":\n # Work around https://bugs.python.org/issue26903\n workers = min(workers, 60)\n try:\n executor = ProcessPoolExecutor(max_workers=workers)\n except (ImportError, NotImplementedError, OSError):\n # we arrive here if the underlying system does not support multi-processing\n # like in AWS Lambda or Termux, in which case we gracefully fallback to\n # a ThreadPoolExecutor with just a single worker (more workers would not do us\n # any good due to the Global Interpreter Lock)\n executor = ThreadPoolExecutor(max_workers=1)\n\n loop = asyncio.new_event_loop()\n asyncio.set_event_loop(loop)\n try:\n loop.run_until_complete(\n schedule_formatting(\n sources=sources,\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n loop=loop,\n executor=executor,\n )\n )\n finally:\n try:\n shutdown(loop)\n finally:\n asyncio.set_event_loop(None)\n if executor is not None:\n executor.shutdown()\n\n\nasync def schedule_formatting(\n sources: Set[Path],\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: \"Report\",\n loop: asyncio.AbstractEventLoop,\n executor: \"Executor\",\n) -> None:\n \"\"\"Run formatting of `sources` in parallel using the provided `executor`.\n\n (Use ProcessPoolExecutors for actual parallelism.)\n\n `write_back`, `fast`, and `mode` options are passed to\n :func:`format_file_in_place`.\n \"\"\"\n cache = Cache.read(mode)\n if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n sources, cached = cache.filtered_cached(sources)\n for src in sorted(cached):\n report.done(src, Changed.CACHED)\n if not sources:\n return\n\n cancelled = []\n sources_to_cache = []\n lock = None\n if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n # For diff output, we need locks to ensure we don't interleave output\n # from different processes.\n manager = Manager()\n lock = manager.Lock()\n tasks = {\n asyncio.ensure_future(\n loop.run_in_executor(\n executor, format_file_in_place, src, fast, mode, write_back, lock\n )\n ): src\n for src in sorted(sources)\n }\n pending = tasks.keys()\n try:\n loop.add_signal_handler(signal.SIGINT, cancel, pending)\n loop.add_signal_handler(signal.SIGTERM, cancel, pending)\n except NotImplementedError:\n # There are no good alternatives for these on Windows.\n pass\n while pending:\n done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)\n for task in done:\n src = tasks.pop(task)\n if task.cancelled():\n cancelled.append(task)\n elif exc := task.exception():\n if report.verbose:\n traceback.print_exception(type(exc), exc, exc.__traceback__)\n report.failed(src, str(exc))\n else:\n changed = Changed.YES if task.result() else Changed.NO\n # If the file was written back or was successfully 
checked as\n # well-formatted, store this information in the cache.\n if write_back is WriteBack.YES or (\n write_back is WriteBack.CHECK and changed is Changed.NO\n ):\n sources_to_cache.append(src)\n report.done(src, changed)\n if cancelled:\n await asyncio.gather(*cancelled, return_exceptions=True)\n if sources_to_cache:\n cache.write(sources_to_cache)\n\n# Path: src/black/handle_ipynb_magics.py\n\"\"\"Functions to process IPython magics with.\"\"\"\n\nimport ast\nimport collections\nimport dataclasses\nimport secrets\nimport sys\nfrom functools import lru_cache\nfrom importlib.util import find_spec\nfrom typing import Dict, List, Optional, Tuple\n\nif sys.version_info >= (3, 10):\n from typing import TypeGuard\nelse:\n from typing_extensions import TypeGuard\n\nfrom black.output import out\nfrom black.report import NothingChanged\n\nTRANSFORMED_MAGICS = frozenset((\n \"get_ipython().run_cell_magic\",\n \"get_ipython().system\",\n \"get_ipython().getoutput\",\n \"get_ipython().run_line_magic\",\n))\nTOKENS_TO_IGNORE = frozenset((\n \"ENDMARKER\",\n \"NL\",\n \"NEWLINE\",\n \"COMMENT\",\n \"DEDENT\",\n \"UNIMPORTANT_WS\",\n \"ESCAPED_NL\",\n))\nPYTHON_CELL_MAGICS = frozenset((\n \"capture\",\n \"prun\",\n \"pypy\",\n \"python\",\n \"python3\",\n \"time\",\n \"timeit\",\n))\nTOKEN_HEX = secrets.token_hex\n\n\n@dataclasses.dataclass(frozen=True)\nclass Replacement:\n mask: str\n src: str\n\n\n@lru_cache\ndef jupyter_dependencies_are_installed(*, warn: bool) -> bool:\n installed = (\n find_spec(\"tokenize_rt\") is not None and find_spec(\"IPython\") is not None\n )\n if not installed and warn:\n msg = (\n \"Skipping .ipynb files as Jupyter dependencies are not installed.\\n\"\n 'You can fix this by running ``pip install \"black[jupyter]\"``'\n )\n out(msg)\n return installed\n\n\ndef remove_trailing_semicolon(src: str) -> Tuple[str, bool]:\n \"\"\"Remove trailing semicolon from Jupyter notebook cell.\n\n For example,\n\n fig, ax = plt.subplots()\n ax.plot(x_data, y_data); # plot data\n\n would become\n\n fig, ax = plt.subplots()\n ax.plot(x_data, y_data) # plot data\n\n Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses\n ``tokenize_rt`` so that round-tripping works fine.\n \"\"\"\n from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src\n\n tokens = src_to_tokens(src)\n trailing_semicolon = False\n for idx, token in reversed_enumerate(tokens):\n if token.name in TOKENS_TO_IGNORE:\n continue\n if token.name == \"OP\" and token.src == \";\":\n del tokens[idx]\n trailing_semicolon = True\n break\n if not trailing_semicolon:\n return src, False\n return tokens_to_src(tokens), True\n\n\ndef put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:\n \"\"\"Put trailing semicolon back if cell originally had it.\n\n Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses\n ``tokenize_rt`` so that round-tripping works fine.\n \"\"\"\n if not has_trailing_semicolon:\n return src\n from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src\n\n tokens = src_to_tokens(src)\n for idx, token in reversed_enumerate(tokens):\n if token.name in TOKENS_TO_IGNORE:\n continue\n tokens[idx] = token._replace(src=token.src + \";\")\n break\n else: # pragma: nocover\n raise AssertionError(\n \"INTERNAL ERROR: Was not able to reinstate trailing semicolon. \"\n \"Please report a bug on https://github.com/psf/black/issues. 
\"\n ) from None\n return str(tokens_to_src(tokens))\n\n\ndef mask_cell(src: str) -> Tuple[str, List[Replacement]]:\n \"\"\"Mask IPython magics so content becomes parseable Python code.\n\n For example,\n\n %matplotlib inline\n 'foo'\n\n becomes\n\n \"25716f358c32750e\"\n...\n", "needle_token_start": 7162, "needle_token_end": 7209, "code_context_ntokens": 16062, "output": ["```python\ndef is_part_of_annotation(leaf: Leaf) -> bool:\n \"\"\"Returns whether this leaf is part of a type annotation.\"\"\"\n ancestor = leaf.parent\n while ancestor is not None:\n if ancestor.prev_sibling and ancestor.prev_sibling.type == token.RARROW:\n return \"return\"\n if ancestor.parent and ancestor.parent.type == syms.tname:\n return \"param\"\n ancestor = ancestor.parent\n return None\n```"]} +{"repo": "psf/black", "name": "maybe_install_uvloop", "language": "python", "path": "src/black/concurrency.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To enhance the event loop performance by optionally integrating a faster alternative if available in the environment.\n2. **Input**: None.\n3. **Output**: None, but the system's default event loop may be replaced.\n4. **Procedure**: The function attempts to import a specific high-performance event loop implementation. If the import is successful, it sets this implementation as the default for the entire application. If the import fails, the function does nothing, leaving the default event loop unchanged.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/nodes.py\n\"\"\"\nblib2to3 Node/Leaf transformation-related utility functions.\n\"\"\"\n\nimport sys\nfrom typing import (\n Final,\n Generic,\n Iterator,\n List,\n Literal,\n Optional,\n Set,\n Tuple,\n TypeVar,\n Union,\n)\n\nif sys.version_info >= (3, 10):\n from typing import TypeGuard\nelse:\n from typing_extensions import TypeGuard\n\nfrom mypy_extensions import mypyc_attr\n\nfrom black.cache import CACHE_DIR\nfrom black.mode import Mode, Preview\nfrom black.strings import get_string_prefix, has_triple_quotes\nfrom blib2to3 import pygram\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import NL, Leaf, Node, type_repr\n\npygram.initialize(CACHE_DIR)\nsyms: Final = pygram.python_symbols\n\n\n# types\nT = TypeVar(\"T\")\nLN = Union[Leaf, Node]\nLeafID = int\nNodeType = int\n\n\nWHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}\nSTATEMENT: Final = {\n syms.if_stmt,\n syms.while_stmt,\n syms.for_stmt,\n syms.try_stmt,\n syms.except_clause,\n syms.with_stmt,\n syms.funcdef,\n syms.classdef,\n syms.match_stmt,\n syms.case_block,\n}\nSTANDALONE_COMMENT: Final = 153\ntoken.tok_name[STANDALONE_COMMENT] = \"STANDALONE_COMMENT\"\nLOGIC_OPERATORS: Final = {\"and\", \"or\"}\nCOMPARATORS: Final = {\n token.LESS,\n token.GREATER,\n token.EQEQUAL,\n token.NOTEQUAL,\n token.LESSEQUAL,\n token.GREATEREQUAL,\n}\nMATH_OPERATORS: Final = {\n token.VBAR,\n token.CIRCUMFLEX,\n token.AMPER,\n token.LEFTSHIFT,\n token.RIGHTSHIFT,\n token.PLUS,\n token.MINUS,\n token.STAR,\n token.SLASH,\n token.DOUBLESLASH,\n token.PERCENT,\n token.AT,\n token.TILDE,\n token.DOUBLESTAR,\n}\nSTARS: Final = {token.STAR, token.DOUBLESTAR}\nVARARGS_SPECIALS: Final = STARS | {token.SLASH}\nVARARGS_PARENTS: Final = {\n syms.arglist,\n syms.argument, # double star in 
arglist\n syms.trailer, # single argument to call\n syms.typedargslist,\n syms.varargslist, # lambdas\n}\nUNPACKING_PARENTS: Final = {\n syms.atom, # single element of a list or set literal\n syms.dictsetmaker,\n syms.listmaker,\n syms.testlist_gexp,\n syms.testlist_star_expr,\n syms.subject_expr,\n syms.pattern,\n}\nTEST_DESCENDANTS: Final = {\n syms.test,\n syms.lambdef,\n syms.or_test,\n syms.and_test,\n syms.not_test,\n syms.comparison,\n syms.star_expr,\n syms.expr,\n syms.xor_expr,\n syms.and_expr,\n syms.shift_expr,\n syms.arith_expr,\n syms.trailer,\n syms.term,\n syms.power,\n syms.namedexpr_test,\n}\nTYPED_NAMES: Final = {syms.tname, syms.tname_star}\nASSIGNMENTS: Final = {\n \"=\",\n \"+=\",\n \"-=\",\n \"*=\",\n \"@=\",\n \"/=\",\n \"%=\",\n \"&=\",\n \"|=\",\n \"^=\",\n \"<<=\",\n \">>=\",\n \"**=\",\n \"//=\",\n \":\",\n}\n\nIMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}\nBRACKET: Final = {\n token.LPAR: token.RPAR,\n token.LSQB: token.RSQB,\n token.LBRACE: token.RBRACE,\n}\nOPENING_BRACKETS: Final = set(BRACKET.keys())\nCLOSING_BRACKETS: Final = set(BRACKET.values())\nBRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS\nALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}\n\nRARROW = 55\n\n\n@mypyc_attr(allow_interpreted_subclasses=True)\nclass Visitor(Generic[T]):\n \"\"\"Basic lib2to3 visitor that yields things of type `T` on `visit()`.\"\"\"\n\n def visit(self, node: LN) -> Iterator[T]:\n \"\"\"Main method to visit `node` and its children.\n\n It tries to find a `visit_*()` method for the given `node.type`, like\n `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.\n If no dedicated `visit_*()` method is found, chooses `visit_default()`\n instead.\n\n Then yields objects of type `T` from the selected visitor.\n \"\"\"\n if node.type < 256:\n name = token.tok_name[node.type]\n else:\n name = str(type_repr(node.type))\n # We explicitly branch on whether a visitor exists (instead of\n # using self.visit_default as the default arg to getattr) in order\n # to save needing to create a bound method object and so mypyc can\n # generate a native call to visit_default.\n visitf = getattr(self, f\"visit_{name}\", None)\n if visitf:\n yield from visitf(node)\n else:\n yield from self.visit_default(node)\n\n def visit_default(self, node: LN) -> Iterator[T]:\n \"\"\"Default `visit_*()` implementation. 
Recurses to children of `node`.\"\"\"\n if isinstance(node, Node):\n for child in node.children:\n yield from self.visit(child)\n\n\ndef whitespace(leaf: Leaf, *, complex_subscript: bool, mode: Mode) -> str: # noqa: C901\n \"\"\"Return whitespace prefix if needed for the given `leaf`.\n\n `complex_subscript` signals whether the given leaf is part of a subscription\n which has non-trivial arguments, like arithmetic expressions or function calls.\n \"\"\"\n NO: Final[str] = \"\"\n SPACE: Final[str] = \" \"\n DOUBLESPACE: Final[str] = \" \"\n t = leaf.type\n p = leaf.parent\n v = leaf.value\n if t in ALWAYS_NO_SPACE:\n return NO\n\n if t == token.COMMENT:\n return DOUBLESPACE\n\n assert p is not None, f\"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}\"\n if t == token.COLON and p.type not in {\n syms.subscript,\n syms.subscriptlist,\n syms.sliceop,\n }:\n return NO\n\n prev = leaf.prev_sibling\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n if t == token.COLON:\n if prevp.type == token.COLON:\n return NO\n\n elif prevp.type != token.COMMA and not complex_subscript:\n return NO\n\n return SPACE\n\n if prevp.type == token.EQUAL:\n if prevp.parent:\n if prevp.parent.type in {\n syms.arglist,\n syms.argument,\n syms.parameters,\n syms.varargslist,\n }:\n return NO\n\n elif prevp.parent.type == syms.typedargslist:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. So, we're using\n # that, too.\n return prevp.prefix\n\n elif (\n prevp.type == token.STAR\n and parent_type(prevp) == syms.star_expr\n and parent_type(prevp.parent) == syms.subscriptlist\n ):\n # No space between typevar tuples.\n return NO\n\n elif prevp.type in VARARGS_SPECIALS:\n if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):\n return NO\n\n elif prevp.type == token.COLON:\n if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:\n return SPACE if complex_subscript else NO\n\n elif (\n prevp.parent\n and prevp.parent.type == syms.factor\n and prevp.type in MATH_OPERATORS\n ):\n return NO\n\n elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:\n # no space in decorators\n return NO\n\n elif prev.type in OPENING_BRACKETS:\n return NO\n\n if p.type in {syms.parameters, syms.arglist}:\n # untyped function signatures or calls\n if not prev or prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.varargslist:\n # lambdas\n if prev and prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.typedargslist:\n # typed function signatures\n if not prev:\n return NO\n\n if t == token.EQUAL:\n if prev.type not in TYPED_NAMES:\n return NO\n\n elif prev.type == token.EQUAL:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. 
So, we're using that, too.\n return prev.prefix\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type in TYPED_NAMES:\n # type names\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type != token.COMMA:\n return NO\n\n elif p.type == syms.trailer:\n # attributes and calls\n if t == token.LPAR or t == token.RPAR:\n return NO\n\n if not prev:\n if t == token.DOT or t == token.LSQB:\n return NO\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.argument:\n # single argument\n if t == token.EQUAL:\n return NO\n\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.LPAR:\n return NO\n\n elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:\n return NO\n\n elif p.type == syms.decorator:\n # decorators\n return NO\n\n elif p.type == syms.dotted_name:\n if prev:\n return NO\n...\n# Path: src/black/comments.py\nimport re\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Collection, Final, Iterator, List, Optional, Tuple, Union\n\nfrom black.mode import Mode, Preview\nfrom black.nodes import (\n CLOSING_BRACKETS,\n STANDALONE_COMMENT,\n WHITESPACE,\n container_of,\n first_leaf_of,\n make_simple_prefix,\n preceding_leaf,\n syms,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n# types\nLN = Union[Leaf, Node]\n\nFMT_OFF: Final = {\"# fmt: off\", \"# fmt:off\", \"# yapf: disable\"}\nFMT_SKIP: Final = {\"# fmt: skip\", \"# fmt:skip\"}\nFMT_ON: Final = {\"# fmt: on\", \"# fmt:on\", \"# yapf: enable\"}\n\nCOMMENT_EXCEPTIONS = \" !:#'\"\n_COMMENT_PREFIX = \"# \"\n_COMMENT_LIST_SEPARATOR = \";\"\n\n\n@dataclass\nclass ProtoComment:\n \"\"\"Describes a piece of syntax that is a comment.\n\n It's not a :class:`blib2to3.pytree.Leaf` so that:\n\n * it can be cached (`Leaf` objects should not be reused more than once as\n they store their lineno, column, prefix, and parent information);\n * `newlines` and `consumed` fields are kept separate from the `value`. This\n simplifies handling of special marker comments like ``# fmt: off/on``.\n \"\"\"\n\n type: int # token.COMMENT or STANDALONE_COMMENT\n value: str # content of the comment\n newlines: int # how many newlines before the comment\n consumed: int # how many characters of the original leaf's prefix did we consume\n form_feed: bool # is there a form feed before the comment\n leading_whitespace: str # leading whitespace before the comment, if any\n\n\ndef generate_comments(leaf: LN) -> Iterator[Leaf]:\n \"\"\"Clean the prefix of the `leaf` and generate comments from it, if any.\n\n Comments in lib2to3 are shoved into the whitespace prefix. This happens\n in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation\n move because it does away with modifying the grammar to include all the\n possible places in which comments can be placed.\n\n The sad consequence for us though is that comments don't \"belong\" anywhere.\n This is why this function generates simple parentless Leaf objects for\n comments. We simply don't know what the correct parent should be.\n\n No matter though, we can live without this. We really only need to\n differentiate between inline and standalone comments. The latter don't\n share the line with any code.\n\n Inline comments are emitted as regular token.COMMENT leaves. 
Standalone\n are emitted with a fake STANDALONE_COMMENT token identifier.\n \"\"\"\n total_consumed = 0\n for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):\n total_consumed = pc.consumed\n prefix = make_simple_prefix(pc.newlines, pc.form_feed)\n yield Leaf(pc.type, pc.value, prefix=prefix)\n normalize_trailing_prefix(leaf, total_consumed)\n\n\n@lru_cache(maxsize=4096)\ndef list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:\n \"\"\"Return a list of :class:`ProtoComment` objects parsed from the given `prefix`.\"\"\"\n result: List[ProtoComment] = []\n if not prefix or \"#\" not in prefix:\n return result\n\n consumed = 0\n nlines = 0\n ignored_lines = 0\n form_feed = False\n for index, full_line in enumerate(re.split(\"\\r?\\n\", prefix)):\n consumed += len(full_line) + 1 # adding the length of the split '\\n'\n match = re.match(r\"^(\\s*)(\\S.*|)$\", full_line)\n assert match\n whitespace, line = match.groups()\n if not line:\n nlines += 1\n if \"\\f\" in full_line:\n form_feed = True\n if not line.startswith(\"#\"):\n # Escaped newlines outside of a comment are not really newlines at\n # all. We treat a single-line comment following an escaped newline\n # as a simple trailing comment.\n if line.endswith(\"\\\\\"):\n ignored_lines += 1\n continue\n\n if index == ignored_lines and not is_endmarker:\n comment_type = token.COMMENT # simple trailing comment\n else:\n comment_type = STANDALONE_COMMENT\n comment = make_comment(line)\n result.append(\n ProtoComment(\n type=comment_type,\n value=comment,\n newlines=nlines,\n consumed=consumed,\n form_feed=form_feed,\n leading_whitespace=whitespace,\n )\n )\n form_feed = False\n nlines = 0\n return result\n\n\ndef normalize_trailing_prefix(leaf: LN, total_consumed: int) -> None:\n \"\"\"Normalize the prefix that's left over after generating comments.\n\n Note: don't use backslashes for formatting or you'll lose your voting rights.\n \"\"\"\n remainder = leaf.prefix[total_consumed:]\n if \"\\\\\" not in remainder:\n nl_count = remainder.count(\"\\n\")\n form_feed = \"\\f\" in remainder and remainder.endswith(\"\\n\")\n leaf.prefix = make_simple_prefix(nl_count, form_feed)\n return\n\n leaf.prefix = \"\"\n\n\ndef make_comment(content: str) -> str:\n \"\"\"Return a consistently formatted comment from the given `content` string.\n\n All comments (except for \"##\", \"#!\", \"#:\", '#'\") should have a single\n space between the hash sign and the content.\n\n If `content` didn't start with a hash sign, one is provided.\n \"\"\"\n content = content.rstrip()\n if not content:\n return \"#\"\n\n if content[0] == \"#\":\n content = content[1:]\n NON_BREAKING_SPACE = \"\u00a0\"\n if (\n content\n and content[0] == NON_BREAKING_SPACE\n and not content.lstrip().startswith(\"type:\")\n ):\n content = \" \" + content[1:] # Replace NBSP by a simple space\n if content and content[0] not in COMMENT_EXCEPTIONS:\n content = \" \" + content\n return \"#\" + content\n\n\ndef normalize_fmt_off(\n node: Node, mode: Mode, lines: Collection[Tuple[int, int]]\n) -> None:\n \"\"\"Convert content between `# fmt: off`/`# fmt: on` into standalone comments.\"\"\"\n try_again = True\n while try_again:\n try_again = convert_one_fmt_off_pair(node, mode, lines)\n\n\ndef convert_one_fmt_off_pair(\n node: Node, mode: Mode, lines: Collection[Tuple[int, int]]\n) -> bool:\n \"\"\"Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.\n\n Returns True if a pair was converted.\n \"\"\"\n for leaf in 
node.leaves():\n previous_consumed = 0\n for comment in list_comments(leaf.prefix, is_endmarker=False):\n should_pass_fmt = comment.value in FMT_OFF or _contains_fmt_skip_comment(\n comment.value, mode\n )\n if not should_pass_fmt:\n previous_consumed = comment.consumed\n continue\n # We only want standalone comments. If there's no previous leaf or\n # the previous leaf is indentation, it's a standalone comment in\n # disguise.\n if should_pass_fmt and comment.type != STANDALONE_COMMENT:\n prev = preceding_leaf(leaf)\n if prev:\n if comment.value in FMT_OFF and prev.type not in WHITESPACE:\n continue\n if (\n _contains_fmt_skip_comment(comment.value, mode)\n and prev.type in WHITESPACE\n ):\n continue\n\n ignored_nodes = list(generate_ignored_nodes(leaf, comment, mode))\n if not ignored_nodes:\n continue\n\n first = ignored_nodes[0] # Can be a container node with the `leaf`.\n parent = first.parent\n prefix = first.prefix\n if comment.value in FMT_OFF:\n first.prefix = prefix[comment.consumed :]\n if _contains_fmt_skip_comment(comment.value, mode):\n first.prefix = \"\"\n standalone_comment_prefix = prefix\n else:\n standalone_comment_prefix = (\n prefix[:previous_consumed] + \"\\n\" * comment.newlines\n )\n hidden_value = \"\".join(str(n) for n in ignored_nodes)\n comment_lineno = leaf.lineno - comment.newlines\n if comment.value in FMT_OFF:\n fmt_off_prefix = \"\"\n if len(lines) > 0 and not any(\n line[0] <= comment_lineno <= line[1] for line in lines\n ):\n # keeping indentation of comment by preserving original whitespaces.\n fmt_off_prefix = prefix.split(comment.value)[0]\n if \"\\n\" in fmt_off_prefix:\n fmt_off_prefix = fmt_off_prefix.split(\"\\n\")[-1]\n standalone_comment_prefix += fmt_off_prefix\n hidden_value = comment.value + \"\\n\" + hidden_value\n if _contains_fmt_skip_comment(comment.value, mode):\n hidden_value += (\n comment.leading_whitespace\n if Preview.no_normalize_fmt_skip_whitespace in mode\n else \" \"\n ) + comment.value\n if hidden_value.endswith(\"\\n\"):\n # That happens when one of the `ignored_nodes` ended with a NEWLINE\n # leaf (possibly followed by a DEDENT).\n hidden_value = hidden_value[:-1]\n first_idx: Optional[int] = None\n for ignored in ignored_nodes:\n index = ignored.remove()\n if first_idx is None:\n first_idx = index\n assert parent is not None, \"INTERNAL ERROR: fmt: on/off handling (1)\"\n assert first_idx is not None, \"INTERNAL ERROR: fmt: on/off handling (2)\"\n parent.insert_child(\n first_idx,\n Leaf(\n STANDALONE_COMMENT,\n hidden_value,\n prefix=standalone_comment_prefix,\n fmt_pass_converted_first_leaf=first_leaf_of(first),\n ),\n )\n return True\n\n return False\n\n\ndef generate_ignored_nodes(\n leaf: Leaf, comment: ProtoComment, mode: Mode\n) -> Iterator[LN]:\n \"\"\"Starting from the container of `leaf`, generate all leaves until `# fmt: on`.\n\n If comment is skip, returns leaf only.\n Stops at the end of the block.\n \"\"\"\n if _contains_fmt_skip_comment(comment.value, mode):\n yield from _generate_ignored_nodes_from_fmt_skip(leaf, comment)\n return\n container: Optional[LN] = container_of(leaf)\n while container is not None and container.type != token.ENDMARKER:\n if is_fmt_on(container):\n return\n\n # fix for fmt: on in children\n if children_contains_fmt_on(container):\n for index, child in enumerate(container.children):\n if isinstance(child, Leaf) and is_fmt_on(child):\n if child.type in CLOSING_BRACKETS:\n # This means `# fmt: on` is placed at a different bracket level\n # than `# fmt: off`. 
This is an invalid use, but as a courtesy,\n # we include this closing bracket in the ignored nodes.\n # The alternative is to fail the formatting.\n yield child\n return\n if (\n child.type == token.INDENT\n and index < len(container.children) - 1\n and children_contains_fmt_on(container.children[index + 1])\n ):\n # This means `# fmt: on` is placed right after an indentation\n # level, and we shouldn't swallow the previous INDENT token.\n return\n if children_contains_fmt_on(child):\n return\n yield child\n else:\n if container.type == token.DEDENT and container.next_sibling is None:\n # This can happen when there is no matching `# fmt: on` comment at the\n # same level as `# fmt: on`. We need to keep this DEDENT.\n return\n yield container\n container = container.next_sibling\n\n\ndef _generate_ignored_nodes_from_fmt_skip(\n leaf: Leaf, comment: ProtoComment\n) -> Iterator[LN]:\n \"\"\"Generate all leaves that should be ignored by the `# fmt: skip` from `leaf`.\"\"\"\n prev_sibling = leaf.prev_sibling\n parent = leaf.parent\n # Need to properly format the leaf prefix to compare it to comment.value,\n # which is also formatted\n comments = list_comments(leaf.prefix, is_endmarker=False)\n if not comments or comment.value != comments[0].value:\n return\n if prev_sibling is not None:\n leaf.prefix = \"\"\n siblings = [prev_sibling]\n while \"\\n\" not in prev_sibling.prefix and prev_sibling.prev_sibling is not None:\n prev_sibling = prev_sibling.prev_sibling\n siblings.insert(0, prev_sibling)\n yield from siblings\n elif (\n parent is not None and parent.type == syms.suite and leaf.type == token.NEWLINE\n ):\n # The `# fmt: skip` is on the colon line of the if/while/def/class/...\n # statements. The ignored nodes should be previous siblings of the\n # parent suite node.\n leaf.prefix = \"\"\n ignored_nodes: List[LN] = []\n parent_sibling = parent.prev_sibling\n while parent_sibling is not None and parent_sibling.type != syms.suite:\n ignored_nodes.insert(0, parent_sibling)\n parent_sibling = parent_sibling.prev_sibling\n # Special case for `async_stmt` where the ASYNC token is on the\n # grandparent node.\n grandparent = parent.parent\n if (\n grandparent is not None\n and grandparent.prev_sibling is not None\n and grandparent.prev_sibling.type == token.ASYNC\n ):\n ignored_nodes.insert(0, grandparent.prev_sibling)\n yield from iter(ignored_nodes)\n\n\ndef is_fmt_on(container: LN) -> bool:\n \"\"\"Determine whether formatting is switched on within a container.\n Determined by whether the last `# fmt:` comment is `on` or `off`.\n \"\"\"\n fmt_on = False\n for comment in list_comments(container.prefix, is_endmarker=False):\n if comment.value in FMT_ON:\n fmt_on = True\n elif comment.value in FMT_OFF:\n fmt_on = False\n return fmt_on\n\n\ndef children_contains_fmt_on(container: LN) -> bool:\n \"\"\"Determine if children have formatting switched on.\"\"\"\n for child in container.children:\n leaf = first_leaf_of(child)\n if leaf is not None and is_fmt_on(leaf):\n return True\n\n return False\n\n\ndef contains_pragma_comment(comment_list: List[Leaf]) -> bool:\n \"\"\"\n Returns:\n True iff one of the comments in @comment_list is a pragma used by one\n of the more common static analysis tools for python (e.g. 
mypy, flake8,\n pylint).\n \"\"\"\n for comment in comment_list:\n if comment.value.startswith((\"# type:\", \"# noqa\", \"# pylint:\")):\n return True\n\n return False\n\n\ndef _contains_fmt_skip_comment(comment_line: str, mode: Mode) -> bool:\n \"\"\"\n Checks if the given comment contains FMT_SKIP alone or paired with other comments.\n Matching styles:\n # fmt:skip <-- single comment\n # noqa:XXX # fmt:skip # a nice line <-- multiple comments (Preview)\n # pylint:XXX; fmt:skip <-- list of comments (; separated, Preview)\n \"\"\"\n semantic_comment_blocks = [\n comment_line,\n *[\n _COMMENT_PREFIX + comment.strip()\n for comment in comment_line.split(_COMMENT_PREFIX)[1:]\n ],\n *[\n _COMMENT_PREFIX + comment.strip()\n for comment in comment_line.strip(_COMMENT_PREFIX).split(\n _COMMENT_LIST_SEPARATOR\n )\n ],\n ]\n\n return any(comment in FMT_SKIP for comment in semantic_comment_blocks)\n\n# Path: src/black/report.py\n\"\"\"\nSummarize Black runs to users.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom pathlib import Path\n\nfrom click import style\n\nfrom black.output import err, out\n\n\nclass Changed(Enum):\n NO = 0\n CACHED = 1\n YES = 2\n\n\nclass NothingChanged(UserWarning):\n \"\"\"Raised when reformatted code is the same as source.\"\"\"\n\n\n@dataclass\nclass Report:\n \"\"\"Provides a reformatting counter. Can be rendered with `str(report)`.\"\"\"\n\n check: bool = False\n diff: bool = False\n quiet: bool = False\n verbose: bool = False\n change_count: int = 0\n same_count: int = 0\n failure_count: int = 0\n\n def done(self, src: Path, changed: Changed) -> None:\n \"\"\"Increment the counter for successful reformatting. Write out a message.\"\"\"\n if changed is Changed.YES:\n reformatted = \"would reformat\" if self.check or self.diff else \"reformatted\"\n if self.verbose or not self.quiet:\n out(f\"{reformatted} {src}\")\n self.change_count += 1\n else:\n if self.verbose:\n if changed is Changed.NO:\n msg = f\"{src} already well formatted, good job.\"\n else:\n msg = f\"{src} wasn't modified on disk since last run.\"\n out(msg, bold=False)\n self.same_count += 1\n\n def failed(self, src: Path, message: str) -> None:\n \"\"\"Increment the counter for failed reformatting. 
Write out a message.\"\"\"\n err(f\"error: cannot format {src}: {message}\")\n self.failure_count += 1\n\n def path_ignored(self, path: Path, message: str) -> None:\n if self.verbose:\n out(f\"{path} ignored: {message}\", bold=False)\n\n @property\n def return_code(self) -> int:\n \"\"\"Return the exit code that the app should use.\n\n This considers the current state of changed files and failures:\n - if there were any failures, return 123;\n - if any files were changed and --check is being used, return 1;\n - otherwise return 0.\n \"\"\"\n # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with\n # 126 we have special return codes reserved by the shell.\n if self.failure_count:\n return 123\n\n elif self.change_count and self.check:\n return 1\n\n return 0\n\n def __str__(self) -> str:\n \"\"\"Render a color report of the current state.\n\n Use `click.unstyle` to remove colors.\n \"\"\"\n if self.check or self.diff:\n reformatted = \"would be reformatted\"\n unchanged = \"would be left unchanged\"\n failed = \"would fail to reformat\"\n else:\n reformatted = \"reformatted\"\n unchanged = \"left unchanged\"\n failed = \"failed to reformat\"\n report = []\n if self.change_count:\n s = \"s\" if self.change_count > 1 else \"\"\n report.append(\n style(f\"{self.change_count} file{s} \", bold=True, fg=\"blue\")\n + style(f\"{reformatted}\", bold=True)\n )\n\n if self.same_count:\n s = \"s\" if self.same_count > 1 else \"\"\n report.append(style(f\"{self.same_count} file{s} \", fg=\"blue\") + unchanged)\n if self.failure_count:\n s = \"s\" if self.failure_count > 1 else \"\"\n report.append(style(f\"{self.failure_count} file{s} {failed}\", fg=\"red\"))\n return \", \".join(report) + \".\"\n\n# Path: src/black/concurrency.py\n\"\"\"\nFormatting many files at once via multiprocessing. Contains entrypoint and utilities.\n\nNOTE: this module is only imported if we need to format several files at once.\n\"\"\"\n\nimport asyncio\nimport logging\nimport os\nimport signal\nimport sys\nimport traceback\nfrom concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor\nfrom multiprocessing import Manager\nfrom pathlib import Path\nfrom typing import Any, Iterable, Optional, Set\n\nfrom mypy_extensions import mypyc_attr\n\nfrom black import WriteBack, format_file_in_place\nfrom black.cache import Cache\nfrom black.mode import Mode\nfrom black.output import err\nfrom black.report import Changed, Report\n\n\n\ndef maybe_install_uvloop() -> None:\n \"\"\"If our environment has uvloop installed we use it.\n\n This is called only from command-line entry points to avoid\n interfering with the parent process if Black is used as a library.\n \"\"\"\n try:\n import uvloop\n\n uvloop.install()\n except ImportError:\n pass\n\n\ndef cancel(tasks: Iterable[\"asyncio.Future[Any]\"]) -> None:\n \"\"\"asyncio signal handler that cancels all `tasks` and reports to stderr.\"\"\"\n err(\"Aborted!\")\n for task in tasks:\n task.cancel()\n\n\ndef shutdown(loop: asyncio.AbstractEventLoop) -> None:\n \"\"\"Cancel all pending tasks on `loop`, wait for them, and close the loop.\"\"\"\n try:\n # This part is borrowed from asyncio/runners.py in Python 3.7b2.\n to_cancel = [task for task in asyncio.all_tasks(loop) if not task.done()]\n if not to_cancel:\n return\n\n for task in to_cancel:\n task.cancel()\n loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))\n finally:\n # `concurrent.futures.Future` objects cannot be cancelled once they\n # are already running. 
There might be some when the `shutdown()` happened.\n # Silence their logger's spew about the event loop being closed.\n cf_logger = logging.getLogger(\"concurrent.futures\")\n cf_logger.setLevel(logging.CRITICAL)\n loop.close()\n\n\n# diff-shades depends on being to monkeypatch this function to operate. I know it's\n# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26\n@mypyc_attr(patchable=True)\ndef reformat_many(\n sources: Set[Path],\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: Report,\n workers: Optional[int],\n) -> None:\n \"\"\"Reformat multiple files using a ProcessPoolExecutor.\"\"\"\n maybe_install_uvloop()\n\n executor: Executor\n if workers is None:\n workers = int(os.environ.get(\"BLACK_NUM_WORKERS\", 0))\n workers = workers or os.cpu_count() or 1\n if sys.platform == \"win32\":\n # Work around https://bugs.python.org/issue26903\n workers = min(workers, 60)\n try:\n executor = ProcessPoolExecutor(max_workers=workers)\n except (ImportError, NotImplementedError, OSError):\n # we arrive here if the underlying system does not support multi-processing\n # like in AWS Lambda or Termux, in which case we gracefully fallback to\n # a ThreadPoolExecutor with just a single worker (more workers would not do us\n # any good due to the Global Interpreter Lock)\n executor = ThreadPoolExecutor(max_workers=1)\n\n loop = asyncio.new_event_loop()\n asyncio.set_event_loop(loop)\n try:\n loop.run_until_complete(\n schedule_formatting(\n sources=sources,\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n loop=loop,\n executor=executor,\n )\n )\n finally:\n try:\n shutdown(loop)\n finally:\n asyncio.set_event_loop(None)\n if executor is not None:\n executor.shutdown()\n\n\nasync def schedule_formatting(\n sources: Set[Path],\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: \"Report\",\n loop: asyncio.AbstractEventLoop,\n executor: \"Executor\",\n) -> None:\n \"\"\"Run formatting of `sources` in parallel using the provided `executor`.\n\n (Use ProcessPoolExecutors for actual parallelism.)\n\n `write_back`, `fast`, and `mode` options are passed to\n :func:`format_file_in_place`.\n \"\"\"\n cache = Cache.read(mode)\n if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n sources, cached = cache.filtered_cached(sources)\n for src in sorted(cached):\n report.done(src, Changed.CACHED)\n if not sources:\n return\n\n cancelled = []\n sources_to_cache = []\n lock = None\n if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n # For diff output, we need locks to ensure we don't interleave output\n # from different processes.\n manager = Manager()\n lock = manager.Lock()\n tasks = {\n asyncio.ensure_future(\n loop.run_in_executor(\n executor, format_file_in_place, src, fast, mode, write_back, lock\n )\n ): src\n for src in sorted(sources)\n }\n pending = tasks.keys()\n try:\n loop.add_signal_handler(signal.SIGINT, cancel, pending)\n loop.add_signal_handler(signal.SIGTERM, cancel, pending)\n except NotImplementedError:\n # There are no good alternatives for these on Windows.\n pass\n while pending:\n done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)\n for task in done:\n src = tasks.pop(task)\n if task.cancelled():\n cancelled.append(task)\n elif exc := task.exception():\n if report.verbose:\n traceback.print_exception(type(exc), exc, exc.__traceback__)\n report.failed(src, str(exc))\n else:\n changed = Changed.YES if task.result() else Changed.NO\n # If the file was written back or was successfully 
checked as\n # well-formatted, store this information in the cache.\n if write_back is WriteBack.YES or (\n write_back is WriteBack.CHECK and changed is Changed.NO\n ):\n sources_to_cache.append(src)\n report.done(src, changed)\n if cancelled:\n await asyncio.gather(*cancelled, return_exceptions=True)\n if sources_to_cache:\n cache.write(sources_to_cache)\n\n# Path: src/black/handle_ipynb_magics.py\n\"\"\"Functions to process IPython magics with.\"\"\"\n\nimport ast\nimport collections\nimport dataclasses\nimport secrets\nimport sys\nfrom functools import lru_cache\nfrom importlib.util import find_spec\nfrom typing import Dict, List, Optional, Tuple\n\nif sys.version_info >= (3, 10):\n from typing import TypeGuard\nelse:\n from typing_extensions import TypeGuard\n\nfrom black.output import out\nfrom black.report import NothingChanged\n\nTRANSFORMED_MAGICS = frozenset((\n \"get_ipython().run_cell_magic\",\n \"get_ipython().system\",\n \"get_ipython().getoutput\",\n \"get_ipython().run_line_magic\",\n))\nTOKENS_TO_IGNORE = frozenset((\n \"ENDMARKER\",\n \"NL\",\n \"NEWLINE\",\n \"COMMENT\",\n \"DEDENT\",\n \"UNIMPORTANT_WS\",\n \"ESCAPED_NL\",\n))\nPYTHON_CELL_MAGICS = frozenset((\n \"capture\",\n \"prun\",\n \"pypy\",\n \"python\",\n \"python3\",\n \"time\",\n \"timeit\",\n))\nTOKEN_HEX = secrets.token_hex\n\n\n@dataclasses.dataclass(frozen=True)\nclass Replacement:\n mask: str\n src: str\n\n\n@lru_cache\ndef jupyter_dependencies_are_installed(*, warn: bool) -> bool:\n installed = (\n find_spec(\"tokenize_rt\") is not None and find_spec(\"IPython\") is not None\n )\n if not installed and warn:\n msg = (\n \"Skipping .ipynb files as Jupyter dependencies are not installed.\\n\"\n 'You can fix this by running ``pip install \"black[jupyter]\"``'\n )\n out(msg)\n return installed\n\n\ndef remove_trailing_semicolon(src: str) -> Tuple[str, bool]:\n \"\"\"Remove trailing semicolon from Jupyter notebook cell.\n\n For example,\n\n fig, ax = plt.subplots()\n ax.plot(x_data, y_data); # plot data\n\n would become\n\n fig, ax = plt.subplots()\n ax.plot(x_data, y_data) # plot data\n\n Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses\n ``tokenize_rt`` so that round-tripping works fine.\n \"\"\"\n from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src\n\n tokens = src_to_tokens(src)\n trailing_semicolon = False\n for idx, token in reversed_enumerate(tokens):\n if token.name in TOKENS_TO_IGNORE:\n continue\n if token.name == \"OP\" and token.src == \";\":\n del tokens[idx]\n trailing_semicolon = True\n break\n if not trailing_semicolon:\n return src, False\n return tokens_to_src(tokens), True\n\n\ndef put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:\n \"\"\"Put trailing semicolon back if cell originally had it.\n\n Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses\n ``tokenize_rt`` so that round-tripping works fine.\n \"\"\"\n if not has_trailing_semicolon:\n return src\n from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src\n\n tokens = src_to_tokens(src)\n for idx, token in reversed_enumerate(tokens):\n if token.name in TOKENS_TO_IGNORE:\n continue\n tokens[idx] = token._replace(src=token.src + \";\")\n break\n else: # pragma: nocover\n raise AssertionError(\n \"INTERNAL ERROR: Was not able to reinstate trailing semicolon. \"\n \"Please report a bug on https://github.com/psf/black/issues. 
\"\n ) from None\n return str(tokens_to_src(tokens))\n\n\ndef mask_cell(src: str) -> Tuple[str, List[Replacement]]:\n \"\"\"Mask IPython magics so content becomes parseable Python code.\n\n For example,\n\n %matplotlib inline\n 'foo'\n\n becomes\n\n \"25716f358c32750e\"\n 'foo'\n\n The replacements are returned, along with the transformed code.\n \"\"\"\n replacements: List[Replacement] = []\n try:\n ast.parse(src)\n except SyntaxError:\n # Might have IPython magics, will process below.\n pass\n else:\n # Syntax is fine, nothing to mask, early return.\n return src, replacements\n\n from IPython.core.inputtransformer2 import TransformerManager\n\n transformer_manager = TransformerManager()\n transformed = transformer_manager.transform_cell(src)\n transformed, cell_magic_replacements = replace_cell_magics(transformed)\n replacements += cell_magic_replacements\n transformed = transformer_manager.transform_cell(transformed)\n transformed, magic_replacements = replace_magics(transformed)\n if len(transformed.splitlines()) != len(src.splitlines()):\n # Multi-line magic, not supported.\n raise NothingChanged\n replacements += magic_replacements\n return transformed, replacements\n\n\ndef get_token(src: str, magic: str) -> str:\n \"\"\"Return randomly generated token to mask IPython magic with.\n\n For example, if 'magic' was `%matplotlib inline`, then a possible\n token to mask it with would be `\"43fdd17f7e5ddc83\"`. The token\n will be the same length as the magic, and we make sure that it was\n not already present anywhere else in the cell.\n \"\"\"\n assert magic\n nbytes = max(len(magic) // 2 - 1, 1)\n token = TOKEN_HEX(nbytes)\n counter = 0\n while token in src:\n token = TOKEN_HEX(nbytes)\n counter += 1\n if counter > 100:\n raise AssertionError(\n \"INTERNAL ERROR: Black was not able to replace IPython magic. \"\n \"Please report a bug on https://github.com/psf/black/issues. 
\"\n f\"The magic might be helpful: {magic}\"\n ) from None\n if len(token) + 2 < len(magic):\n token = f\"{token}.\"\n return f'\"{token}\"'\n\n\ndef replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]:\n \"\"\"Replace cell magic with token.\n\n Note that 'src' will already have been processed by IPython's\n TransformerManager().transform_cell.\n\n Example,\n\n get_ipython().run_cell_magic('t', '-n1', 'ls =!ls\\\\n')\n\n becomes\n\n \"a794.\"\n ls =!ls\n\n The replacement, along with the transformed code, is returned.\n \"\"\"\n replacements: List[Replacement] = []\n\n tree = ast.parse(src)\n\n cell_magic_finder = CellMagicFinder()\n cell_magic_finder.visit(tree)\n if cell_magic_finder.cell_magic is None:\n return src, replacements\n header = cell_magic_finder.cell_magic.header\n mask = get_token(src, header)\n replacements.append(Replacement(mask=mask, src=header))\n return f\"{mask}\\n{cell_magic_finder.cell_magic.body}\", replacements\n\n\ndef replace_magics(src: str) -> Tuple[str, List[Replacement]]:\n \"\"\"Replace magics within body of cell.\n\n Note that 'src' will already have been processed by IPython's\n TransformerManager().transform_cell.\n\n Example, this\n\n get_ipython().run_line_magic('matplotlib', 'inline')\n 'foo'\n\n becomes\n\n \"5e67db56d490fd39\"\n 'foo'\n\n The replacement, along with the transformed code, are returned.\n \"\"\"\n replacements = []\n magic_finder = MagicFinder()\n magic_finder.visit(ast.parse(src))\n new_srcs = []\n for i, line in enumerate(src.splitlines(), start=1):\n if i in magic_finder.magics:\n offsets_and_magics = magic_finder.magics[i]\n if len(offsets_and_magics) != 1: # pragma: nocover\n raise AssertionError(\n f\"Expecting one magic per line, got: {offsets_and_magics}\\n\"\n \"Please report a bug on https://github.com/psf/black/issues.\"\n )\n col_offset, magic = (\n offsets_and_magics[0].col_offset,\n offsets_and_magics[0].magic,\n )\n mask = get_token(src, magic)\n replacements.append(Replacement(mask=mask, src=magic))\n line = line[:col_offset] + mask\n new_srcs.append(line)\n return \"\\n\".join(new_srcs), replacements\n\n\ndef unmask_cell(src: str, replacements: List[Replacement]) -> str:\n \"\"\"Remove replacements from cell.\n\n For example\n\n \"9b20\"\n foo = bar\n\n becomes\n\n %%time\n foo = bar\n \"\"\"\n for replacement in replacements:\n src = src.replace(replacement.mask, replacement.src)\n return src\n\n\ndef _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]:\n \"\"\"Check if attribute is IPython magic.\n\n Note that the source of the abstract syntax tree\n will already have been processed by IPython's\n TransformerManager().transform_cell.\n \"\"\"\n return (\n isinstance(node, ast.Attribute)\n and isinstance(node.value, ast.Call)\n and isinstance(node.value.func, ast.Name)\n and node.value.func.id == \"get_ipython\"\n )\n\n\ndef _get_str_args(args: List[ast.expr]) -> List[str]:\n str_args = []\n for arg in args:\n assert isinstance(arg, ast.Str)\n str_args.append(arg.s)\n return str_args\n\n\n@dataclasses.dataclass(frozen=True)\nclass CellMagic:\n name: str\n params: Optional[str]\n body: str\n\n @property\n def header(self) -> str:\n if self.params:\n return f\"%%{self.name} {self.params}\"\n return f\"%%{self.name}\"\n\n\n# ast.NodeVisitor + dataclass = breakage under mypyc.\nclass CellMagicFinder(ast.NodeVisitor):\n \"\"\"Find cell magics.\n\n Note that the source of the abstract syntax tree\n will already have been processed by IPython's\n TransformerManager().transform_cell.\n\n For 
example,\n\n %%time\\n\n foo()\n\n would have been transformed to\n\n get_ipython().run_cell_magic('time', '', 'foo()\\\\n')\n\n and we look for instances of the latter.\n \"\"\"\n\n def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:\n self.cell_magic = cell_magic\n\n def visit_Expr(self, node: ast.Expr) -> None:\n \"\"\"Find cell magic, extract header and body.\"\"\"\n if (\n isinstance(node.value, ast.Call)\n and _is_ipython_magic(node.value.func)\n and node.value.func.attr == \"run_cell_magic\"\n ):\n args = _get_str_args(node.value.args)\n self.cell_magic = CellMagic(name=args[0], params=args[1], body=args[2])\n self.generic_visit(node)\n\n\n@dataclasses.dataclass(frozen=True)\nclass OffsetAndMagic:\n col_offset: int\n magic: str\n\n\n# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here\n# as mypyc will generate broken code.\nclass MagicFinder(ast.NodeVisitor):\n \"\"\"Visit cell to look for get_ipython calls.\n\n Note that the source of the abstract syntax tree\n will already have been processed by IPython's\n TransformerManager().transform_cell.\n\n For example,\n\n %matplotlib inline\n\n would have been transformed to\n\n get_ipython().run_line_magic('matplotlib', 'inline')\n\n and we look for instances of the latter (and likewise for other\n types of magics).\n \"\"\"\n\n def __init__(self) -> None:\n self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)\n\n def visit_Assign(self, node: ast.Assign) -> None:\n \"\"\"Look for system assign magics.\n\n For example,\n\n black_version = !black --version\n env = %env var\n\n would have been (respectively) transformed to\n\n black_version = get_ipython().getoutput('black --version')\n env = get_ipython().run_line_magic('env', 'var')\n\n and we look for instances of any of the latter.\n \"\"\"\n if isinstance(node.value, ast.Call) and _is_ipython_magic(node.value.func):\n args = _get_str_args(node.value.args)\n if node.value.func.attr == \"getoutput\":\n src = f\"!{args[0]}\"\n elif node.value.func.attr == \"run_line_magic\":\n src = f\"%{args[0]}\"\n if args[1]:\n src += f\" {args[1]}\"\n else:\n raise AssertionError(\n f\"Unexpected IPython magic {node.value.func.attr!r} found. 
\"\n \"Please report a bug on https://github.com/psf/black/issues.\"\n ) from None\n self.magics[node.value.lineno].append(\n OffsetAndMagic(node.value.col_offset, src)\n )\n self.generic_visit(node)\n\n def visit_Expr(self, node: ast.Expr) -> None:\n \"\"\"Look for magics in body of cell.\n\n For examples,\n\n !ls\n !!ls\n ?ls\n ??ls\n\n would (respectively) get transformed to\n\n get_ipython().system('ls')\n get_ipython().getoutput('ls')\n get_ipython().run_line_magic('pinfo', 'ls')\n get_ipython().run_line_magic('pinfo2', 'ls')\n\n and we look for instances of any of the latter.\n \"\"\"\n if isinstance(node.value, ast.Call) and _is_ipython_magic(node.value.func):\n args = _get_str_args(node.value.args)\n if node.value.func.attr == \"run_line_magic\":\n if args[0] == \"pinfo\":\n src = f\"?{args[1]}\"\n elif args[0] == \"pinfo2\":\n src = f\"??{args[1]}\"\n else:\n src = f\"%{args[0]}\"\n if args[1]:\n src += f\" {args[1]}\"\n elif node.value.func.attr == \"system\":\n src = f\"!{args[0]}\"\n elif node.value.func.attr == \"getoutput\":\n src = f\"!!{args[0]}\"\n else:\n raise NothingChanged # unsupported magic.\n self.magics[node.value.lineno].append(\n OffsetAndMagic(node.value.col_offset, src)\n )\n self.generic_visit(node)\n\n# Path: src/black/files.py\nimport io\nimport os\nimport sys\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Iterable,\n Iterator,\n List,\n Optional,\n Pattern,\n Sequence,\n Tuple,\n Union,\n)\n\nfrom mypy_extensions import mypyc_attr\nfrom packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet\nfrom packaging.version import InvalidVersion, Version\nfrom pathspec import PathSpec\nfrom pathspec.patterns.gitwildmatch import GitWildMatchPatternError\n\nif sys.version_info >= (3, 11):\n try:\n import tomllib\n except ImportError:\n # Help users on older alphas\n if not TYPE_CHECKING:\n import tomli as tomllib\nelse:\n import tomli as tomllib\n\nfrom black.handle_ipynb_magics import jupyter_dependencies_are_installed\nfrom black.mode import TargetVersion\nfrom black.output import err\nfrom black.report import Report\n\nif TYPE_CHECKING:\n import colorama # noqa: F401\n\n\n@lru_cache\ndef _load_toml(path: Union[Path, str]) -> Dict[str, Any]:\n with open(path, \"rb\") as f:\n return tomllib.load(f)\n\n\n@lru_cache\ndef _cached_resolve(path: Path) -> Path:\n return path.resolve()\n\n\n@lru_cache\ndef find_project_root(\n srcs: Sequence[str], stdin_filename: Optional[str] = None\n) -> Tuple[Path, str]:\n \"\"\"Return a directory containing .git, .hg, or pyproject.toml.\n\n That directory will be a common parent of all files and directories\n passed in `srcs`.\n\n If no directory in the tree contains a marker that would specify it's the\n project root, the root of the file system is returned.\n\n Returns a two-tuple with the first element as the project root path and\n the second element as a string describing the method by which the\n project root was discovered.\n \"\"\"\n if stdin_filename is not None:\n srcs = tuple(stdin_filename if s == \"-\" else s for s in srcs)\n if not srcs:\n srcs = [str(_cached_resolve(Path.cwd()))]\n\n path_srcs = [_cached_resolve(Path(Path.cwd(), src)) for src in srcs]\n\n # A list of lists of parents for each 'src'. 
'src' is included as a\n # \"parent\" of itself if it is a directory\n src_parents = [\n list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs\n ]\n\n common_base = max(\n set.intersection(*(set(parents) for parents in src_parents)),\n key=lambda path: path.parts,\n )\n\n for directory in (common_base, *common_base.parents):\n if (directory / \".git\").exists():\n return directory, \".git directory\"\n\n if (directory / \".hg\").is_dir():\n return directory, \".hg directory\"\n\n if (directory / \"pyproject.toml\").is_file():\n pyproject_toml = _load_toml(directory / \"pyproject.toml\")\n if \"black\" in pyproject_toml.get(\"tool\", {}):\n return directory, \"pyproject.toml\"\n\n return directory, \"file system root\"\n\n\ndef find_pyproject_toml(\n path_search_start: Tuple[str, ...], stdin_filename: Optional[str] = None\n) -> Optional[str]:\n \"\"\"Find the absolute filepath to a pyproject.toml if it exists\"\"\"\n path_project_root, _ = find_project_root(path_search_start, stdin_filename)\n path_pyproject_toml = path_project_root / \"pyproject.toml\"\n if path_pyproject_toml.is_file():\n return str(path_pyproject_toml)\n\n try:\n path_user_pyproject_toml = find_user_pyproject_toml()\n return (\n str(path_user_pyproject_toml)\n if path_user_pyproject_toml.is_file()\n else None\n )\n except (PermissionError, RuntimeError) as e:\n # We do not have access to the user-level config directory, so ignore it.\n err(f\"Ignoring user configuration directory due to {e!r}\")\n return None\n\n\n@mypyc_attr(patchable=True)\ndef parse_pyproject_toml(path_config: str) -> Dict[str, Any]:\n \"\"\"Parse a pyproject toml file, pulling out relevant parts for Black.\n\n If parsing fails, will raise a tomllib.TOMLDecodeError.\n \"\"\"\n pyproject_toml = _load_toml(path_config)\n config: Dict[str, Any] = pyproject_toml.get(\"tool\", {}).get(\"black\", {})\n config = {k.replace(\"--\", \"\").replace(\"-\", \"_\"): v for k, v in config.items()}\n\n if \"target_version\" not in config:\n...\n", "needle_token_start": 8833, "needle_token_end": 8922, "code_context_ntokens": 16105, "output": ["```python\ndef maybe_install_uvloop() -> None:\n \"\"\"If our environment has uvloop installed we use it.\n\n This is called only from command-line entry points to avoid\n interfering with the parent process if Black is used as a library.\n \"\"\"\n try:\n import uvloop\n\n uvloop.install()\n except ImportError:\n pass\n```"]} +{"repo": "psf/black", "name": "check_stability_and_equivalence", "language": "python", "path": "src/black/__init__.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To ensure that the transformed content remains consistent after formatting and that the original and transformed contents are equivalent in terms of syntax and structure.\n2. **Input**: Two strings representing the original content and the transformed content, a configuration object specifying formatting rules, and an optional collection of line number ranges to focus the checks on specific parts of the content.\n3. **Output**: There is no return value; however, the function can raise an `AssertionError` if the transformed content does not meet the specified conditions.\n4. **Procedure**: The function first checks if the original and transformed contents are equivalent, ensuring no unintended changes have been introduced. 
It then verifies the stability of the transformation by checking if applying the transformation a second time results in the same output, ensuring idempotence of the formatting process.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/ranges.py\n\"\"\"Functions related to Black's formatting by line ranges feature.\"\"\"\n\nimport difflib\nfrom dataclasses import dataclass\nfrom typing import Collection, Iterator, List, Sequence, Set, Tuple, Union\n\nfrom black.nodes import (\n LN,\n STANDALONE_COMMENT,\n Leaf,\n Node,\n Visitor,\n first_leaf,\n furthest_ancestor_with_last_leaf,\n last_leaf,\n syms,\n)\nfrom blib2to3.pgen2.token import ASYNC, NEWLINE\n\n\ndef parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]:\n lines: List[Tuple[int, int]] = []\n for lines_str in line_ranges:\n parts = lines_str.split(\"-\")\n if len(parts) != 2:\n raise ValueError(\n \"Incorrect --line-ranges format, expect 'START-END', found\"\n f\" {lines_str!r}\"\n )\n try:\n start = int(parts[0])\n end = int(parts[1])\n except ValueError:\n raise ValueError(\n \"Incorrect --line-ranges value, expect integer ranges, found\"\n f\" {lines_str!r}\"\n ) from None\n else:\n lines.append((start, end))\n return lines\n\n\ndef is_valid_line_range(lines: Tuple[int, int]) -> bool:\n \"\"\"Returns whether the line range is valid.\"\"\"\n return not lines or lines[0] <= lines[1]\n\n\ndef adjusted_lines(\n lines: Collection[Tuple[int, int]],\n original_source: str,\n modified_source: str,\n) -> List[Tuple[int, int]]:\n \"\"\"Returns the adjusted line ranges based on edits from the original code.\n\n This computes the new line ranges by diffing original_source and\n modified_source, and adjust each range based on how the range overlaps with\n the diffs.\n\n Note the diff can contain lines outside of the original line ranges. This can\n happen when the formatting has to be done in adjacent to maintain consistent\n local results. For example:\n\n 1. def my_func(arg1, arg2,\n 2. arg3,):\n 3. pass\n\n If it restricts to line 2-2, it can't simply reformat line 2, it also has\n to reformat line 1:\n\n 1. def my_func(\n 2. arg1,\n 3. arg2,\n 4. arg3,\n 5. ):\n 6. pass\n\n In this case, we will expand the line ranges to also include the whole diff\n block.\n\n Args:\n lines: a collection of line ranges.\n original_source: the original source.\n modified_source: the modified source.\n \"\"\"\n lines_mappings = _calculate_lines_mappings(original_source, modified_source)\n\n new_lines = []\n # Keep an index of the current search. 
Since the lines and lines_mappings are\n # sorted, this makes the search complexity linear.\n current_mapping_index = 0\n for start, end in sorted(lines):\n...\n# Path: src/black/__init__.py\nimport io\nimport json\nimport platform\nimport re\nimport sys\nimport tokenize\nimport traceback\nfrom contextlib import contextmanager\nfrom dataclasses import replace\nfrom datetime import datetime, timezone\nfrom enum import Enum\nfrom json.decoder import JSONDecodeError\nfrom pathlib import Path\nfrom typing import (\n Any,\n Collection,\n Dict,\n Generator,\n Iterator,\n List,\n MutableMapping,\n Optional,\n Pattern,\n Sequence,\n Set,\n Sized,\n Tuple,\n Union,\n)\n\nimport click\nfrom click.core import ParameterSource\nfrom mypy_extensions import mypyc_attr\nfrom pathspec import PathSpec\nfrom pathspec.patterns.gitwildmatch import GitWildMatchPatternError\n\nfrom _black_version import version as __version__\nfrom black.cache import Cache\nfrom black.comments import normalize_fmt_off\nfrom black.const import (\n DEFAULT_EXCLUDES,\n DEFAULT_INCLUDES,\n DEFAULT_LINE_LENGTH,\n STDIN_PLACEHOLDER,\n)\nfrom black.files import (\n best_effort_relative_path,\n find_project_root,\n find_pyproject_toml,\n find_user_pyproject_toml,\n gen_python_files,\n get_gitignore,\n parse_pyproject_toml,\n path_is_excluded,\n resolves_outside_root_or_cannot_stat,\n wrap_stream_for_windows,\n)\nfrom black.handle_ipynb_magics import (\n PYTHON_CELL_MAGICS,\n TRANSFORMED_MAGICS,\n jupyter_dependencies_are_installed,\n mask_cell,\n put_trailing_semicolon_back,\n remove_trailing_semicolon,\n unmask_cell,\n)\nfrom black.linegen import LN, LineGenerator, transform_line\nfrom black.lines import EmptyLineTracker, LinesBlock\nfrom black.mode import FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, Feature\nfrom black.mode import Mode as Mode # re-exported\nfrom black.mode import Preview, TargetVersion, supports_feature\nfrom black.nodes import (\n STARS,\n is_number_token,\n is_simple_decorator_expression,\n is_string_token,\n syms,\n)\nfrom black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out\nfrom black.parsing import InvalidInput # noqa F401\nfrom black.parsing import lib2to3_parse, parse_ast, stringify_ast\nfrom black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges\nfrom black.report import Changed, NothingChanged, Report\nfrom black.trans import iter_fexpr_spans\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\nCOMPILED = Path(__file__).suffix in (\".pyd\", \".so\")\n\n# types\nFileContent = str\nEncoding = str\nNewLine = str\n\n\nclass WriteBack(Enum):\n NO = 0\n YES = 1\n DIFF = 2\n CHECK = 3\n COLOR_DIFF = 4\n\n @classmethod\n def from_configuration(\n cls, *, check: bool, diff: bool, color: bool = False\n ) -> \"WriteBack\":\n if check and not diff:\n return cls.CHECK\n\n if diff and color:\n return cls.COLOR_DIFF\n\n return cls.DIFF if diff else cls.YES\n\n\n# Legacy name, left for integrations.\nFileMode = Mode\n\n\ndef read_pyproject_toml(\n ctx: click.Context, param: click.Parameter, value: Optional[str]\n) -> Optional[str]:\n \"\"\"Inject Black configuration from \"pyproject.toml\" into defaults in `ctx`.\n\n Returns the path to a successfully found and read configuration file, None\n otherwise.\n \"\"\"\n if not value:\n value = find_pyproject_toml(\n ctx.params.get(\"src\", ()), ctx.params.get(\"stdin_filename\", None)\n )\n if value is None:\n return None\n\n try:\n config = parse_pyproject_toml(value)\n except (OSError, ValueError) as e:\n 
raise click.FileError(\n filename=value, hint=f\"Error reading configuration file: {e}\"\n ) from None\n\n if not config:\n return None\n else:\n spellcheck_pyproject_toml_keys(ctx, list(config), value)\n # Sanitize the values to be Click friendly. For more information please see:\n # https://github.com/psf/black/issues/1458\n # https://github.com/pallets/click/issues/1567\n config = {\n k: str(v) if not isinstance(v, (list, dict)) else v\n for k, v in config.items()\n }\n\n target_version = config.get(\"target_version\")\n if target_version is not None and not isinstance(target_version, list):\n raise click.BadOptionUsage(\n \"target-version\", \"Config key target-version must be a list\"\n )\n\n exclude = config.get(\"exclude\")\n if exclude is not None and not isinstance(exclude, str):\n raise click.BadOptionUsage(\"exclude\", \"Config key exclude must be a string\")\n\n extend_exclude = config.get(\"extend_exclude\")\n if extend_exclude is not None and not isinstance(extend_exclude, str):\n raise click.BadOptionUsage(\n \"extend-exclude\", \"Config key extend-exclude must be a string\"\n )\n\n line_ranges = config.get(\"line_ranges\")\n if line_ranges is not None:\n raise click.BadOptionUsage(\n \"line-ranges\", \"Cannot use line-ranges in the pyproject.toml file.\"\n )\n\n default_map: Dict[str, Any] = {}\n if ctx.default_map:\n default_map.update(ctx.default_map)\n default_map.update(config)\n\n ctx.default_map = default_map\n return value\n\n\ndef spellcheck_pyproject_toml_keys(\n ctx: click.Context, config_keys: List[str], config_file_path: str\n) -> None:\n invalid_keys: List[str] = []\n available_config_options = {param.name for param in ctx.command.params}\n for key in config_keys:\n if key not in available_config_options:\n invalid_keys.append(key)\n if invalid_keys:\n keys_str = \", \".join(map(repr, invalid_keys))\n out(\n f\"Invalid config keys detected: {keys_str} (in {config_file_path})\",\n fg=\"red\",\n )\n\n\ndef target_version_option_callback(\n c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]\n) -> List[TargetVersion]:\n \"\"\"Compute the target versions from a --target-version flag.\n\n This is its own function because mypy couldn't infer the type correctly\n when it was a lambda, causing mypyc trouble.\n \"\"\"\n return [TargetVersion[val.upper()] for val in v]\n\n\ndef enable_unstable_feature_callback(\n c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]\n) -> List[Preview]:\n \"\"\"Compute the features from an --enable-unstable-feature flag.\"\"\"\n return [Preview[val] for val in v]\n\n\ndef re_compile_maybe_verbose(regex: str) -> Pattern[str]:\n \"\"\"Compile a regular expression string in `regex`.\n\n If it contains newlines, use verbose mode.\n \"\"\"\n if \"\\n\" in regex:\n regex = \"(?x)\" + regex\n compiled: Pattern[str] = re.compile(regex)\n return compiled\n\n\ndef validate_regex(\n ctx: click.Context,\n param: click.Parameter,\n value: Optional[str],\n) -> Optional[Pattern[str]]:\n try:\n return re_compile_maybe_verbose(value) if value is not None else None\n except re.error as e:\n raise click.BadParameter(f\"Not a valid regular expression: {e}\") from None\n\n\n@click.command(\n context_settings={\"help_option_names\": [\"-h\", \"--help\"]},\n # While Click does set this field automatically using the docstring, mypyc\n # (annoyingly) strips 'em so we need to set it here too.\n help=\"The uncompromising code formatter.\",\n)\n@click.option(\"-c\", \"--code\", type=str, help=\"Format the code 
passed in as a string.\")\n@click.option(\n \"-l\",\n \"--line-length\",\n type=int,\n default=DEFAULT_LINE_LENGTH,\n help=\"How many characters per line to allow.\",\n show_default=True,\n)\n@click.option(\n \"-t\",\n \"--target-version\",\n type=click.Choice([v.name.lower() for v in TargetVersion]),\n callback=target_version_option_callback,\n multiple=True,\n help=(\n \"Python versions that should be supported by Black's output. You should\"\n \" include all versions that your code supports. By default, Black will infer\"\n \" target versions from the project metadata in pyproject.toml. If this does\"\n \" not yield conclusive results, Black will use per-file auto-detection.\"\n ),\n)\n@click.option(\n \"--pyi\",\n is_flag=True,\n help=(\n \"Format all input files like typing stubs regardless of file extension. This\"\n \" is useful when piping source on standard input.\"\n ),\n)\n@click.option(\n \"--ipynb\",\n is_flag=True,\n help=(\n \"Format all input files like Jupyter Notebooks regardless of file extension.\"\n \" This is useful when piping source on standard input.\"\n ),\n)\n@click.option(\n \"--python-cell-magics\",\n multiple=True,\n help=(\n \"When processing Jupyter Notebooks, add the given magic to the list\"\n f\" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))}).\"\n \" Useful for formatting cells with custom python magics.\"\n ),\n default=[],\n)\n@click.option(\n \"-x\",\n \"--skip-source-first-line\",\n is_flag=True,\n help=\"Skip the first line of the source code.\",\n)\n@click.option(\n \"-S\",\n \"--skip-string-normalization\",\n is_flag=True,\n help=\"Don't normalize string quotes or prefixes.\",\n)\n@click.option(\n \"-C\",\n \"--skip-magic-trailing-comma\",\n is_flag=True,\n help=\"Don't use trailing commas as a reason to split lines.\",\n)\n@click.option(\n \"--preview\",\n is_flag=True,\n help=(\n \"Enable potentially disruptive style changes that may be added to Black's main\"\n \" functionality in the next major release.\"\n ),\n)\n@click.option(\n \"--unstable\",\n is_flag=True,\n help=(\n \"Enable potentially disruptive style changes that have known bugs or are not\"\n \" currently expected to make it into the stable style Black's next major\"\n \" release. Implies --preview.\"\n ),\n)\n@click.option(\n \"--enable-unstable-feature\",\n type=click.Choice([v.name for v in Preview]),\n callback=enable_unstable_feature_callback,\n multiple=True,\n help=(\n \"Enable specific features included in the `--unstable` style. Requires\"\n \" `--preview`. No compatibility guarantees are provided on the behavior\"\n \" or existence of any unstable features.\"\n ),\n)\n@click.option(\n \"--check\",\n is_flag=True,\n help=(\n \"Don't write the files back, just return the status. Return code 0 means\"\n \" nothing would change. Return code 1 means some files would be reformatted.\"\n \" Return code 123 means there was an internal error.\"\n ),\n)\n@click.option(\n \"--diff\",\n is_flag=True,\n help=(\n \"Don't write the files back, just output a diff to indicate what changes\"\n \" Black would've made. They are printed to stdout so capturing them is simple.\"\n ),\n)\n@click.option(\n \"--color/--no-color\",\n is_flag=True,\n help=\"Show (or do not show) colored diff. Only applies when --diff is given.\",\n)\n@click.option(\n \"--line-ranges\",\n multiple=True,\n metavar=\"START-END\",\n help=(\n \"When specified, Black will try its best to only format these lines. 
This\"\n \" option can be specified multiple times, and a union of the lines will be\"\n \" formatted. Each range must be specified as two integers connected by a `-`:\"\n \" `-`. The `` and `` integer indices are 1-based and\"\n \" inclusive on both ends.\"\n ),\n default=(),\n)\n@click.option(\n \"--fast/--safe\",\n is_flag=True,\n help=(\n \"By default, Black performs an AST safety check after formatting your code.\"\n \" The --fast flag turns off this check and the --safe flag explicitly enables\"\n \" it. [default: --safe]\"\n ),\n)\n@click.option(\n \"--required-version\",\n type=str,\n help=(\n \"Require a specific version of Black to be running. This is useful for\"\n \" ensuring that all contributors to your project are using the same\"\n \" version, because different versions of Black may format code a little\"\n \" differently. This option can be set in a configuration file for consistent\"\n \" results across environments.\"\n ),\n)\n@click.option(\n \"--exclude\",\n type=str,\n callback=validate_regex,\n help=(\n \"A regular expression that matches files and directories that should be\"\n \" excluded on recursive searches. An empty value means no paths are excluded.\"\n \" Use forward slashes for directories on all platforms (Windows, too).\"\n \" By default, Black also ignores all paths listed in .gitignore. Changing this\"\n f\" value will override all default exclusions. [default: {DEFAULT_EXCLUDES}]\"\n ),\n show_default=False,\n)\n@click.option(\n \"--extend-exclude\",\n type=str,\n callback=validate_regex,\n help=(\n \"Like --exclude, but adds additional files and directories on top of the\"\n \" default values instead of overriding them.\"\n ),\n)\n@click.option(\n \"--force-exclude\",\n type=str,\n callback=validate_regex,\n help=(\n \"Like --exclude, but files and directories matching this regex will be excluded\"\n \" even when they are passed explicitly as arguments. This is useful when\"\n \" invoking Black programmatically on changed files, such as in a pre-commit\"\n \" hook or editor plugin.\"\n ),\n)\n@click.option(\n \"--stdin-filename\",\n type=str,\n is_eager=True,\n help=(\n \"The name of the file when passing it through stdin. Useful to make sure Black\"\n \" will respect the --force-exclude option on some editors that rely on using\"\n \" stdin.\"\n ),\n)\n@click.option(\n \"--include\",\n type=str,\n default=DEFAULT_INCLUDES,\n callback=validate_regex,\n help=(\n \"A regular expression that matches files and directories that should be\"\n \" included on recursive searches. An empty value means all files are included\"\n \" regardless of the name. Use forward slashes for directories on all platforms\"\n \" (Windows, too). Overrides all exclusions, including from .gitignore and\"\n \" command line options.\"\n ),\n show_default=True,\n)\n@click.option(\n \"-W\",\n \"--workers\",\n type=click.IntRange(min=1),\n default=None,\n help=(\n \"When Black formats multiple files, it may use a process pool to speed up\"\n \" formatting. This option controls the number of parallel workers. This can\"\n \" also be specified via the BLACK_NUM_WORKERS environment variable. Defaults\"\n \" to the number of CPUs in the system.\"\n ),\n)\n@click.option(\n \"-q\",\n \"--quiet\",\n is_flag=True,\n help=(\n \"Stop emitting all non-critical output. 
Error messages will still be emitted\"\n \" (which can silenced by 2>/dev/null).\"\n ),\n)\n@click.option(\n \"-v\",\n \"--verbose\",\n is_flag=True,\n help=(\n \"Emit messages about files that were not changed or were ignored due to\"\n \" exclusion patterns. If Black is using a configuration file, a message\"\n \" detailing which one it is using will be emitted.\"\n ),\n)\n@click.version_option(\n version=__version__,\n message=(\n f\"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\\n\"\n f\"Python ({platform.python_implementation()}) {platform.python_version()}\"\n ),\n)\n@click.argument(\n \"src\",\n nargs=-1,\n type=click.Path(\n exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True\n ),\n is_eager=True,\n metavar=\"SRC ...\",\n)\n@click.option(\n \"--config\",\n type=click.Path(\n exists=True,\n file_okay=True,\n dir_okay=False,\n readable=True,\n allow_dash=False,\n path_type=str,\n ),\n is_eager=True,\n callback=read_pyproject_toml,\n help=\"Read configuration options from a configuration file.\",\n)\n@click.pass_context\ndef main( # noqa: C901\n ctx: click.Context,\n code: Optional[str],\n line_length: int,\n target_version: List[TargetVersion],\n check: bool,\n diff: bool,\n line_ranges: Sequence[str],\n color: bool,\n fast: bool,\n pyi: bool,\n ipynb: bool,\n python_cell_magics: Sequence[str],\n skip_source_first_line: bool,\n skip_string_normalization: bool,\n skip_magic_trailing_comma: bool,\n preview: bool,\n unstable: bool,\n enable_unstable_feature: List[Preview],\n quiet: bool,\n verbose: bool,\n required_version: Optional[str],\n include: Pattern[str],\n exclude: Optional[Pattern[str]],\n extend_exclude: Optional[Pattern[str]],\n force_exclude: Optional[Pattern[str]],\n stdin_filename: Optional[str],\n workers: Optional[int],\n src: Tuple[str, ...],\n config: Optional[str],\n) -> None:\n \"\"\"The uncompromising code formatter.\"\"\"\n ctx.ensure_object(dict)\n\n if src and code is not None:\n out(\n main.get_usage(ctx)\n + \"\\n\\n'SRC' and 'code' cannot be passed simultaneously.\"\n )\n ctx.exit(1)\n if not src and code is None:\n out(main.get_usage(ctx) + \"\\n\\nOne of 'SRC' or 'code' is required.\")\n ctx.exit(1)\n\n # It doesn't do anything if --unstable is also passed, so just allow it.\n if enable_unstable_feature and not (preview or unstable):\n out(\n main.get_usage(ctx)\n + \"\\n\\n'--enable-unstable-feature' requires '--preview'.\"\n )\n ctx.exit(1)\n\n root, method = (\n find_project_root(src, stdin_filename) if code is None else (None, None)\n )\n ctx.obj[\"root\"] = root\n\n if verbose:\n if root:\n out(\n f\"Identified `{root}` as project root containing a {method}.\",\n fg=\"blue\",\n )\n\n if config:\n config_source = ctx.get_parameter_source(\"config\")\n user_level_config = str(find_user_pyproject_toml())\n if config == user_level_config:\n out(\n \"Using configuration from user-level config at \"\n f\"'{user_level_config}'.\",\n fg=\"blue\",\n )\n elif config_source in (\n ParameterSource.DEFAULT,\n ParameterSource.DEFAULT_MAP,\n ):\n out(\"Using configuration from project root.\", fg=\"blue\")\n else:\n out(f\"Using configuration in '{config}'.\", fg=\"blue\")\n if ctx.default_map:\n for param, value in ctx.default_map.items():\n out(f\"{param}: {value}\")\n\n error_msg = \"Oh no! 
\ud83d\udca5 \ud83d\udc94 \ud83d\udca5\"\n if (\n required_version\n and required_version != __version__\n and required_version != __version__.split(\".\")[0]\n ):\n err(\n f\"{error_msg} The required version `{required_version}` does not match\"\n f\" the running version `{__version__}`!\"\n )\n ctx.exit(1)\n if ipynb and pyi:\n err(\"Cannot pass both `pyi` and `ipynb` flags!\")\n ctx.exit(1)\n\n write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)\n if target_version:\n versions = set(target_version)\n else:\n # We'll autodetect later.\n versions = set()\n mode = Mode(\n target_versions=versions,\n line_length=line_length,\n is_pyi=pyi,\n is_ipynb=ipynb,\n skip_source_first_line=skip_source_first_line,\n string_normalization=not skip_string_normalization,\n magic_trailing_comma=not skip_magic_trailing_comma,\n preview=preview,\n unstable=unstable,\n python_cell_magics=set(python_cell_magics),\n enabled_features=set(enable_unstable_feature),\n )\n\n lines: List[Tuple[int, int]] = []\n if line_ranges:\n if ipynb:\n err(\"Cannot use --line-ranges with ipynb files.\")\n ctx.exit(1)\n\n try:\n lines = parse_line_ranges(line_ranges)\n except ValueError as e:\n err(str(e))\n ctx.exit(1)\n\n if code is not None:\n # Run in quiet mode by default with -c; the extra output isn't useful.\n # You can still pass -v to get verbose output.\n quiet = True\n\n report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)\n\n if code is not None:\n reformat_code(\n content=code,\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n lines=lines,\n )\n else:\n assert root is not None # root is only None if code is not None\n try:\n sources = get_sources(\n root=root,\n src=src,\n quiet=quiet,\n verbose=verbose,\n include=include,\n exclude=exclude,\n extend_exclude=extend_exclude,\n force_exclude=force_exclude,\n report=report,\n stdin_filename=stdin_filename,\n )\n except GitWildMatchPatternError:\n ctx.exit(1)\n\n path_empty(\n sources,\n \"No Python files are present to be formatted. Nothing to do \ud83d\ude34\",\n quiet,\n verbose,\n ctx,\n )\n\n if len(sources) == 1:\n reformat_one(\n src=sources.pop(),\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n lines=lines,\n )\n else:\n from black.concurrency import reformat_many\n\n if lines:\n err(\"Cannot use --line-ranges to format multiple files.\")\n ctx.exit(1)\n reformat_many(\n sources=sources,\n fast=fast,\n write_back=write_back,\n mode=mode,\n report=report,\n workers=workers,\n )\n\n if verbose or not quiet:\n if code is None and (verbose or report.change_count or report.failure_count):\n out()\n out(error_msg if report.return_code else \"All done! 
\u2728 \ud83c\udf70 \u2728\")\n if code is None:\n click.echo(str(report), err=True)\n ctx.exit(report.return_code)\n\n\ndef get_sources(\n *,\n root: Path,\n src: Tuple[str, ...],\n quiet: bool,\n verbose: bool,\n include: Pattern[str],\n exclude: Optional[Pattern[str]],\n extend_exclude: Optional[Pattern[str]],\n force_exclude: Optional[Pattern[str]],\n report: \"Report\",\n stdin_filename: Optional[str],\n) -> Set[Path]:\n \"\"\"Compute the set of files to be formatted.\"\"\"\n sources: Set[Path] = set()\n\n assert root.is_absolute(), f\"INTERNAL ERROR: `root` must be absolute but is {root}\"\n using_default_exclude = exclude is None\n exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude\n gitignore: Optional[Dict[Path, PathSpec]] = None\n root_gitignore = get_gitignore(root)\n\n for s in src:\n if s == \"-\" and stdin_filename:\n path = Path(stdin_filename)\n is_stdin = True\n else:\n path = Path(s)\n is_stdin = False\n\n # Compare the logic here to the logic in `gen_python_files`.\n if is_stdin or path.is_file():\n if resolves_outside_root_or_cannot_stat(path, root, report):\n if verbose:\n out(f'Skipping invalid source: \"{path}\"', fg=\"red\")\n continue\n\n root_relative_path = best_effort_relative_path(path, root).as_posix()\n root_relative_path = \"/\" + root_relative_path\n\n # Hard-exclude any files that matches the `--force-exclude` regex.\n if path_is_excluded(root_relative_path, force_exclude):\n report.path_ignored(\n path, \"matches the --force-exclude regular expression\"\n )\n continue\n\n if is_stdin:\n path = Path(f\"{STDIN_PLACEHOLDER}{str(path)}\")\n\n if path.suffix == \".ipynb\" and not jupyter_dependencies_are_installed(\n warn=verbose or not quiet\n ):\n continue\n\n if verbose:\n out(f'Found input source: \"{path}\"', fg=\"blue\")\n sources.add(path)\n elif path.is_dir():\n path = root / (path.resolve().relative_to(root))\n if verbose:\n out(f'Found input source directory: \"{path}\"', fg=\"blue\")\n\n if using_default_exclude:\n gitignore = {\n root: root_gitignore,\n path: get_gitignore(path),\n }\n sources.update(\n gen_python_files(\n path.iterdir(),\n root,\n include,\n exclude,\n extend_exclude,\n force_exclude,\n report,\n gitignore,\n verbose=verbose,\n quiet=quiet,\n )\n )\n elif s == \"-\":\n if verbose:\n out(\"Found input source stdin\", fg=\"blue\")\n sources.add(path)\n else:\n err(f\"invalid path: {s}\")\n\n return sources\n\n\ndef path_empty(\n src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context\n) -> None:\n \"\"\"\n Exit if there is no `src` provided for formatting\n \"\"\"\n if not src:\n if verbose or not quiet:\n out(msg)\n ctx.exit(0)\n\n\ndef reformat_code(\n content: str,\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: Report,\n *,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"\n Reformat and print out `content` without spawning child processes.\n Similar to `reformat_one`, but for string content.\n\n `fast`, `write_back`, and `mode` options are passed to\n :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.\n \"\"\"\n path = Path(\"\")\n try:\n changed = Changed.NO\n if format_stdin_to_stdout(\n content=content, fast=fast, write_back=write_back, mode=mode, lines=lines\n ):\n changed = Changed.YES\n report.done(path, changed)\n except Exception as exc:\n if report.verbose:\n traceback.print_exc()\n report.failed(path, str(exc))\n\n\n# diff-shades depends on being to monkeypatch this function to operate. 
I know it's\n# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26\n@mypyc_attr(patchable=True)\ndef reformat_one(\n src: Path,\n fast: bool,\n write_back: WriteBack,\n mode: Mode,\n report: \"Report\",\n *,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"Reformat a single file under `src` without spawning child processes.\n\n `fast`, `write_back`, and `mode` options are passed to\n :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.\n \"\"\"\n try:\n changed = Changed.NO\n\n if str(src) == \"-\":\n is_stdin = True\n elif str(src).startswith(STDIN_PLACEHOLDER):\n is_stdin = True\n # Use the original name again in case we want to print something\n # to the user\n src = Path(str(src)[len(STDIN_PLACEHOLDER) :])\n else:\n is_stdin = False\n\n if is_stdin:\n if src.suffix == \".pyi\":\n mode = replace(mode, is_pyi=True)\n elif src.suffix == \".ipynb\":\n mode = replace(mode, is_ipynb=True)\n if format_stdin_to_stdout(\n fast=fast, write_back=write_back, mode=mode, lines=lines\n ):\n changed = Changed.YES\n else:\n cache = Cache.read(mode)\n if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n if not cache.is_changed(src):\n changed = Changed.CACHED\n if changed is not Changed.CACHED and format_file_in_place(\n src, fast=fast, write_back=write_back, mode=mode, lines=lines\n ):\n changed = Changed.YES\n if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (\n write_back is WriteBack.CHECK and changed is Changed.NO\n ):\n cache.write([src])\n report.done(src, changed)\n except Exception as exc:\n if report.verbose:\n traceback.print_exc()\n report.failed(src, str(exc))\n\n\ndef format_file_in_place(\n src: Path,\n fast: bool,\n mode: Mode,\n write_back: WriteBack = WriteBack.NO,\n lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy\n *,\n lines: Collection[Tuple[int, int]] = (),\n) -> bool:\n \"\"\"Format file under `src` path. Return True if changed.\n\n If `write_back` is DIFF, write a diff to stdout. 
If it is YES, write reformatted\n code to the file.\n `mode` and `fast` options are passed to :func:`format_file_contents`.\n \"\"\"\n if src.suffix == \".pyi\":\n mode = replace(mode, is_pyi=True)\n elif src.suffix == \".ipynb\":\n mode = replace(mode, is_ipynb=True)\n\n then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)\n header = b\"\"\n with open(src, \"rb\") as buf:\n if mode.skip_source_first_line:\n header = buf.readline()\n src_contents, encoding, newline = decode_bytes(buf.read())\n try:\n dst_contents = format_file_contents(\n src_contents, fast=fast, mode=mode, lines=lines\n )\n except NothingChanged:\n return False\n except JSONDecodeError:\n raise ValueError(\n f\"File '{src}' cannot be parsed as valid Jupyter notebook.\"\n ) from None\n src_contents = header.decode(encoding) + src_contents\n dst_contents = header.decode(encoding) + dst_contents\n\n if write_back == WriteBack.YES:\n with open(src, \"w\", encoding=encoding, newline=newline) as f:\n f.write(dst_contents)\n elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n now = datetime.now(timezone.utc)\n src_name = f\"{src}\\t{then}\"\n dst_name = f\"{src}\\t{now}\"\n if mode.is_ipynb:\n diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)\n else:\n diff_contents = diff(src_contents, dst_contents, src_name, dst_name)\n\n if write_back == WriteBack.COLOR_DIFF:\n diff_contents = color_diff(diff_contents)\n\n with lock or nullcontext():\n f = io.TextIOWrapper(\n sys.stdout.buffer,\n encoding=encoding,\n newline=newline,\n write_through=True,\n )\n f = wrap_stream_for_windows(f)\n f.write(diff_contents)\n f.detach()\n\n return True\n\n\ndef format_stdin_to_stdout(\n fast: bool,\n *,\n content: Optional[str] = None,\n write_back: WriteBack = WriteBack.NO,\n mode: Mode,\n lines: Collection[Tuple[int, int]] = (),\n) -> bool:\n \"\"\"Format file on stdin. Return True if changed.\n\n If content is None, it's read from sys.stdin.\n\n If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,\n write a diff to stdout. 
The `mode` argument is passed to\n :func:`format_file_contents`.\n \"\"\"\n then = datetime.now(timezone.utc)\n\n if content is None:\n src, encoding, newline = decode_bytes(sys.stdin.buffer.read())\n else:\n src, encoding, newline = content, \"utf-8\", \"\"\n\n dst = src\n try:\n dst = format_file_contents(src, fast=fast, mode=mode, lines=lines)\n return True\n\n except NothingChanged:\n return False\n\n finally:\n f = io.TextIOWrapper(\n sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True\n )\n if write_back == WriteBack.YES:\n # Make sure there's a newline after the content\n if dst and dst[-1] != \"\\n\":\n dst += \"\\n\"\n f.write(dst)\n elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):\n now = datetime.now(timezone.utc)\n src_name = f\"STDIN\\t{then}\"\n dst_name = f\"STDOUT\\t{now}\"\n d = diff(src, dst, src_name, dst_name)\n if write_back == WriteBack.COLOR_DIFF:\n d = color_diff(d)\n f = wrap_stream_for_windows(f)\n f.write(d)\n f.detach()\n\n\ndef check_stability\n_and_equivalence(\n src_contents: str,\n dst_contents: str,\n *,\n mode: Mode,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"Perform stability and equivalence checks.\n\n Raise AssertionError if source and destination contents are not\n equivalent, or if a second pass of the formatter would format the\n content differently.\n \"\"\"\n assert_equivalent(src_contents, dst_contents)\n assert_stable(src_contents, dst_contents, mode=mode, lines=lines)\n\n\ndef format_file_contents(\n src_contents: str,\n *,\n fast: bool,\n mode: Mode,\n lines: Collection[Tuple[int, int]] = (),\n) -> FileContent:\n \"\"\"Reformat contents of a file and return new contents.\n\n If `fast` is False, additionally confirm that the reformatted code is\n valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.\n `mode` is passed to :func:`format_str`.\n \"\"\"\n if mode.is_ipynb:\n dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)\n else:\n dst_contents = format_str(src_contents, mode=mode, lines=lines)\n if src_contents == dst_contents:\n raise NothingChanged\n\n if not fast and not mode.is_ipynb:\n # Jupyter notebooks will already have been checked above.\n check_stability_and_equivalence(\n src_contents, dst_contents, mode=mode, lines=lines\n )\n return dst_contents\n\n\ndef validate_cell(src: str, mode: Mode) -> None:\n \"\"\"Check that cell does not already contain TransformerManager transformations,\n or non-Python cell magics, which might cause tokenizer_rt to break because of\n indentations.\n\n If a cell contains ``!ls``, then it'll be transformed to\n ``get_ipython().system('ls')``. 
However, if the cell originally contained\n ``get_ipython().system('ls')``, then it would get transformed in the same way:\n\n >>> TransformerManager().transform_cell(\"get_ipython().system('ls')\")\n \"get_ipython().system('ls')\\n\"\n >>> TransformerManager().transform_cell(\"!ls\")\n \"get_ipython().system('ls')\\n\"\n\n Due to the impossibility of safely roundtripping in such situations, cells\n containing transformed magics will be ignored.\n \"\"\"\n if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):\n raise NothingChanged\n if (\n src[:2] == \"%%\"\n and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics\n ):\n raise NothingChanged\n\n\ndef format_cell(src: str, *, fast: bool, mode: Mode) -> str:\n \"\"\"Format code in given cell of Jupyter notebook.\n\n General idea is:\n\n - if cell has trailing semicolon, remove it;\n - if cell has IPython magics, mask them;\n - format cell;\n - reinstate IPython magics;\n - reinstate trailing semicolon (if originally present);\n - strip trailing newlines.\n\n Cells with syntax errors will not be processed, as they\n could potentially be automagics or multi-line magics, which\n are currently not supported.\n \"\"\"\n validate_cell(src, mode)\n src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(\n src\n )\n try:\n masked_src, replacements = mask_cell(src_without_trailing_semicolon)\n except SyntaxError:\n raise NothingChanged from None\n masked_dst = format_str(masked_src, mode=mode)\n if not fast:\n check_stability_and_equivalence(masked_src, masked_dst, mode=mode)\n dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements)\n dst = put_trailing_semicolon_back(\n dst_without_trailing_semicolon, has_trailing_semicolon\n )\n dst = dst.rstrip(\"\\n\")\n if dst == src:\n raise NothingChanged from None\n return dst\n\n\ndef validate_metadata(nb: MutableMapping[str, Any]) -> None:\n \"\"\"If notebook is marked as non-Python, don't format it.\n\n All notebook metadata fields are optional, see\n https://nbformat.readthedocs.io/en/latest/format_description.html. So\n if a notebook has empty metadata, we will try to parse it anyway.\n \"\"\"\n language = nb.get(\"metadata\", {}).get(\"language_info\", {}).get(\"name\", None)\n if language is not None and language != \"python\":\n raise NothingChanged from None\n\n\ndef format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:\n \"\"\"Format Jupyter notebook.\n\n Operate cell-by-cell, only on code cells, only for Python notebooks.\n If the ``.ipynb`` originally had a trailing newline, it'll be preserved.\n \"\"\"\n if not src_contents:\n raise NothingChanged\n\n trailing_newline = src_contents[-1] == \"\\n\"\n modified = False\n nb = json.loads(src_contents)\n validate_metadata(nb)\n for cell in nb[\"cells\"]:\n if cell.get(\"cell_type\", None) == \"code\":\n try:\n src = \"\".join(cell[\"source\"])\n dst = format_cell(src, fast=fast, mode=mode)\n except NothingChanged:\n pass\n else:\n cell[\"source\"] = dst.splitlines(keepends=True)\n modified = True\n if modified:\n dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)\n if trailing_newline:\n dst_contents = dst_contents + \"\\n\"\n return dst_contents\n else:\n raise NothingChanged\n\n\ndef format_str(\n src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()\n) -> str:\n \"\"\"Reformat a string and return new contents.\n\n `mode` determines formatting options, such as how many characters per line are\n allowed. 
Example:\n\n >>> import black\n >>> print(black.format_str(\"def f(arg:str='')->None:...\", mode=black.Mode()))\n def f(arg: str = \"\") -> None:\n ...\n\n A more complex example:\n\n >>> print(\n ... black.format_str(\n ... \"def f(arg:str='')->None: hey\",\n ... mode=black.Mode(\n ... target_versions={black.TargetVersion.PY36},\n ... line_length=10,\n ... string_normalization=False,\n ... is_pyi=False,\n ... ),\n ... ),\n ... )\n def f(\n arg: str = '',\n ) -> None:\n hey\n\n \"\"\"\n dst_contents = _format_str_once(src_contents, mode=mode, lines=lines)\n # Forced second pass to work around optional trailing commas (becoming\n # forced trailing commas on pass 2) interacting differently with optional\n # parentheses. Admittedly ugly.\n if src_contents != dst_contents:\n if lines:\n lines = adjusted_lines(lines, src_contents, dst_contents)\n return _format_str_once(dst_contents, mode=mode, lines=lines)\n return dst_contents\n\n\ndef _format_str_once(\n src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()\n) -> str:\n src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)\n dst_blocks: List[LinesBlock] = []\n if mode.target_versions:\n versions = mode.target_versions\n else:\n future_imports = get_future_imports(src_node)\n versions = detect_target_versions(src_node, future_imports=future_imports)\n\n context_manager_features = {\n feature\n for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}\n if supports_feature(versions, feature)\n }\n normalize_fmt_off(src_node, mode, lines)\n if lines:\n # This should be called after normalize_fmt_off.\n convert_unchanged_lines(src_node, lines)\n\n line_generator = LineGenerator(mode=mode, features=context_manager_features)\n elt = EmptyLineTracker(mode=mode)\n split_line_features = {\n feature\n for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}\n if supports_feature(versions, feature)\n }\n block: Optional[LinesBlock] = None\n for current_line in line_generator.visit(src_node):\n block = elt.maybe_empty_lines(current_line)\n dst_blocks.append(block)\n for line in transform_line(\n current_line, mode=mode, features=split_line_features\n ):\n block.content_lines.append(str(line))\n if dst_blocks:\n dst_blocks[-1].after = 0\n dst_contents = []\n for block in dst_blocks:\n dst_contents.extend(block.all_lines())\n if not dst_contents:\n # Use decode_bytes to retrieve the correct source newline (CRLF or LF),\n # and check if normalized_content has more than one line\n normalized_content, _, newline = decode_bytes(src_contents.encode(\"utf-8\"))\n if \"\\n\" in normalized_content:\n return newline\n return \"\"\n return \"\".join(dst_contents)\n\n\ndef decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:\n \"\"\"Return a tuple of (decoded_contents, encoding, newline).\n\n `newline` is either CRLF or LF but `decoded_contents` is decoded with\n universal newlines (i.e. 
only contains LF).\n \"\"\"\n srcbuf = io.BytesIO(src)\n encoding, lines = tokenize.detect_encoding(srcbuf.readline)\n if not lines:\n return \"\", encoding, \"\\n\"\n\n newline = \"\\r\\n\" if b\"\\r\\n\" == lines[0][-2:] else \"\\n\"\n srcbuf.seek(0)\n with io.TextIOWrapper(srcbuf, encoding) as tiow:\n return tiow.read(), encoding, newline\n\n\ndef get_features_used( # noqa: C901\n node: Node, *, future_imports: Optional[Set[str]] = None\n) -> Set[Feature]:\n \"\"\"Return a set of (relatively) new Python features used in this file.\n\n Currently looking for:\n - f-strings;\n - self-documenting expressions in f-strings (f\"{x=}\");\n - underscores in numeric literals;\n - trailing commas after * or ** in function signatures and calls;\n - positional only arguments in function signatures and lambdas;\n - assignment expression;\n - relaxed decorator syntax;\n - usage of __future__ flags (annotations);\n - print / exec statements;\n - parenthesized context managers;\n - match statements;\n - except* clause;\n - variadic generics;\n \"\"\"\n features: Set[Feature] = set()\n if future_imports:\n features |= {\n FUTURE_FLAG_TO_FEATURE[future_import]\n for future_import in future_imports\n if future_import in FUTURE_FLAG_TO_FEATURE\n }\n\n for n in node.pre_order():\n if is_string_token(n):\n value_head = n.value[:2]\n if value_head in {'f\"', 'F\"', \"f'\", \"F'\", \"rf\", \"fr\", \"RF\", \"FR\"}:\n features.add(Feature.F_STRINGS)\n if Feature.DEBUG_F_STRINGS not in features:\n for span_beg, span_end in iter_fexpr_spans(n.value):\n if n.value[span_beg : span_end - 1].rstrip().endswith(\"=\"):\n features.add(Feature.DEBUG_F_STRINGS)\n break\n\n elif is_number_token(n):\n if \"_\" in n.value:\n features.add(Feature.NUMERIC_UNDERSCORES)\n\n elif n.type == token.SLASH:\n if n.parent and n.parent.type in {\n syms.typedargslist,\n syms.arglist,\n syms.varargslist,\n }:\n features.add(Feature.POS_ONLY_ARGUMENTS)\n\n elif n.type == token.COLONEQUAL:\n features.add(Feature.ASSIGNMENT_EXPRESSIONS)\n\n elif n.type == syms.decorator:\n if len(n.children) > 1 and not is_simple_decorator_expression(\n n.children[1]\n ):\n features.add(Feature.RELAXED_DECORATORS)\n\n elif (\n n.type in {syms.typedargslist, syms.arglist}\n and n.children\n and n.children[-1].type == token.COMMA\n ):\n if n.type == syms.typedargslist:\n feature = Feature.TRAILING_COMMA_IN_DEF\n else:\n feature = Feature.TRAILING_COMMA_IN_CALL\n\n for ch in n.children:\n if ch.type in STARS:\n features.add(feature)\n\n if ch.type == syms.argument:\n for argch in ch.children:\n if argch.type in STARS:\n features.add(feature)\n\n elif (\n n.type in {syms.return_stmt, syms.yield_expr}\n and len(n.children) >= 2\n and n.children[1].type == syms.testlist_star_expr\n and any(child.type == syms.star_expr for child in n.children[1].children)\n ):\n features.add(Feature.UNPACKING_ON_FLOW)\n\n elif (\n n.type == syms.annassign\n and len(n.children) >= 4\n and n.children[3].type == syms.testlist_star_expr\n ):\n features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)\n\n elif (\n n.type == syms.with_stmt\n and len(n.children) > 2\n and n.children[1].type == syms.atom\n ):\n atom_children = n.children[1].children\n if (\n len(atom_children) == 3\n and atom_children[0].type == token.LPAR\n and _contains_asexpr(atom_children[1])\n and atom_children[2].type == token.RPAR\n ):\n features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)\n\n elif n.type == syms.match_stmt:\n features.add(Feature.PATTERN_MATCHING)\n\n elif (\n n.type == syms.except_clause\n and len(n.children) 
>= 2\n and n.children[1].type == token.STAR\n ):\n features.add(Feature.EXCEPT_STAR)\n\n elif n.type in {syms.subscriptlist, syms.trailer} and any(\n child.type == syms.star_expr for child in n.children\n ):\n features.add(Feature.VARIADIC_GENERICS)\n\n elif (\n n.type == syms.tname_star\n and len(n.children) == 3\n and n.children[2].type == syms.star_expr\n ):\n features.add(Feature.VARIADIC_GENERICS)\n\n elif n.type in (syms.type_stmt, syms.typeparams):\n features.add(Feature.TYPE_PARAMS)\n\n return features\n\n\ndef _contains_asexpr(node: Union[Node, Leaf]) -> bool:\n \"\"\"Return True if `node` contains an as-pattern.\"\"\"\n if node.type == syms.asexpr_test:\n return True\n elif node.type == syms.atom:\n if (\n len(node.children) == 3\n and node.children[0].type == token.LPAR\n and node.children[2].type == token.RPAR\n ):\n return _contains_asexpr(node.children[1])\n elif node.type == syms.testlist_gexp:\n return any(_contains_asexpr(child) for child in node.children)\n return False\n\n\ndef detect_target_versions(\n node: Node, *, future_imports: Optional[Set[str]] = None\n) -> Set[TargetVersion]:\n \"\"\"Detect the version to target based on the nodes used.\"\"\"\n features = get_features_used(node, future_imports=future_imports)\n return {\n version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]\n }\n\n\ndef get_future_imports(node: Node) -> Set[str]:\n \"\"\"Return a set of __future__ imports in the file.\"\"\"\n imports: Set[str] = set()\n\n def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:\n for child in children:\n if isinstance(child, Leaf):\n if child.type == token.NAME:\n yield child.value\n\n elif child.type == syms.import_as_name:\n orig_name = child.children[0]\n assert isinstance(orig_name, Leaf), \"Invalid syntax parsing imports\"\n assert orig_name.type == token.NAME, \"Invalid syntax parsing imports\"\n yield orig_name.value\n\n elif child.type == syms.import_as_names:\n yield from get_imports_from_children(child.children)\n\n else:\n raise AssertionError(\"Invalid syntax parsing imports\")\n\n for child in node.children:\n if child.type != syms.simple_stmt:\n break\n\n first_child = child.children[0]\n if isinstance(first_child, Leaf):\n # Continue looking if we see a docstring; otherwise stop.\n if (\n len(child.children) == 2\n and first_child.type == token.STRING\n and child.children[1].type == token.NEWLINE\n ):\n continue\n\n break\n\n elif first_child.type == syms.import_from:\n module_name = first_child.children[1]\n if not isinstance(module_name, Leaf) or module_name.value != \"__future__\":\n break\n\n imports |= set(get_imports_from_children(first_child.children[3:]))\n else:\n break\n\n return imports\n\n\ndef assert_equivalent(src: str, dst: str) -> None:\n \"\"\"Raise AssertionError if `src` and `dst` aren't equivalent.\"\"\"\n try:\n src_ast = parse_ast(src)\n except Exception as exc:\n raise AssertionError(\n \"cannot use --safe with this file; failed to parse source file AST: \"\n f\"{exc}\\n\"\n \"This could be caused by running Black with an older Python version \"\n \"that does not support new syntax used in your source file.\"\n ) from exc\n\n try:\n dst_ast = parse_ast(dst)\n except Exception as exc:\n log = dump_to_file(\"\".join(traceback.format_tb(exc.__traceback__)), dst)\n raise AssertionError(\n f\"INTERNAL ERROR: Black produced invalid code: {exc}. \"\n \"Please report a bug on https://github.com/psf/black/issues. 
\"\n f\"This invalid output might be helpful: {log}\"\n ) from None\n\n src_ast_str = \"\\n\".join(stringify_ast(src_ast))\n dst_ast_str = \"\\n\".join(stringify_ast(dst_ast))\n if src_ast_str != dst_ast_str:\n log = dump_to_file(diff(src_ast_str, dst_ast_str, \"src\", \"dst\"))\n raise AssertionError(\n \"INTERNAL ERROR: Black produced code that is not equivalent to the\"\n \" source. Please report a bug on \"\n f\"https://github.com/psf/black/issues. This diff might be helpful: {log}\"\n ) from None\n\n\ndef assert_stable(\n src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = ()\n) -> None:\n \"\"\"Raise AssertionError if `dst` reformats differently the second time.\"\"\"\n if lines:\n # Formatting specified lines requires `adjusted_lines` to map original lines\n # to the formatted lines before re-formatting the previously formatted result.\n # Due to less-ideal diff algorithm, some edge cases produce incorrect new line\n # ranges. Hence for now, we skip the stable check.\n # See https://github.com/psf/black/issues/4033 for context.\n return\n # We shouldn't call format_str() here, because that formats the string\n # twice and may hide a bug where we bounce back and forth between two\n # versions.\n newdst = _format_str_once(dst, mode=mode, lines=lines)\n if dst != newdst:\n log = dump_to_file(\n str(mode),\n diff(src, dst, \"source\", \"first pass\"),\n diff(dst, newdst, \"first pass\", \"second pass\"),\n )\n raise AssertionError(\n \"INTERNAL ERROR: Black produced different code on the second pass of the\"\n...\n", "needle_token_start": 10454, "needle_token_end": 10602, "code_context_ntokens": 16174, "output": ["```python\ndef check_stability_and_equivalence(\n src_contents: str,\n dst_contents: str,\n *,\n mode: Mode,\n lines: Collection[Tuple[int, int]] = (),\n) -> None:\n \"\"\"Perform stability and equivalence checks.\n\n Raise AssertionError if source and destination contents are not\n equivalent, or if a second pass of the formatter would format the\n content differently.\n \"\"\"\n assert_equivalent(src_contents, dst_contents)\n assert_stable(src_contents, dst_contents, mode=mode, lines=lines)\n```"]} +{"repo": "psf/black", "name": "visit", "language": "python", "path": "src/black/nodes.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to traverse a syntax tree node and its children, applying specific methods based on the node's type, and yielding results of a specified type.\n2. **Input**: A node from a syntax tree.\n3. **Output**: An iterator that yields objects of a specified type, derived from processing the node and its children.\n4. **Procedure**: The function first determines the type of the node. If the node type is less than 256, it uses a predefined mapping to get the method name; otherwise, it converts the node type to a string representation. It then attempts to find a method specifically tailored to handle the node type. If such a method exists, it is called and its results are yielded. 
If no specific method is found, a default method is called and its results are yielded.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/_width_table.py\n# Generated by make_width_table.py\n# wcwidth 0.2.6\n# Unicode 15.0.0\nfrom typing import Final, List, Tuple\n\nWIDTH_TABLE: Final[List[Tuple[int, int, int]]] = [\n (0, 0, 0),\n (1, 31, -1),\n (127, 159, -1),\n (768, 879, 0),\n (1155, 1161, 0),\n (1425, 1469, 0),\n (1471, 1471, 0),\n (1473, 1474, 0),\n (1476, 1477, 0),\n (1479, 1479, 0),\n (1552, 1562, 0),\n (1611, 1631, 0),\n (1648, 1648, 0),\n (1750, 1756, 0),\n (1759, 1764, 0),\n (1767, 1768, 0),\n (1770, 1773, 0),\n (1809, 1809, 0),\n (1840, 1866, 0),\n (1958, 1968, 0),\n (2027, 2035, 0),\n (2045, 2045, 0),\n (2070, 2073, 0),\n (2075, 2083, 0),\n (2085, 2087, 0),\n (2089, 2093, 0),\n (2137, 2139, 0),\n (2200, 2207, 0),\n (2250, 2273, 0),\n (2275, 2306, 0),\n (2362, 2362, 0),\n (2364, 2364, 0),\n (2369, 2376, 0),\n (2381, 2381, 0),\n (2385, 2391, 0),\n (2402, 2403, 0),\n (2433, 2433, 0),\n (2492, 2492, 0),\n (2497, 2500, 0),\n (2509, 2509, 0),\n (2530, 2531, 0),\n (2558, 2558, 0),\n (2561, 2562, 0),\n (2620, 2620, 0),\n (2625, 2626, 0),\n (2631, 2632, 0),\n (2635, 2637, 0),\n (2641, 2641, 0),\n (2672, 2673, 0),\n (2677, 2677, 0),\n (2689, 2690, 0),\n (2748, 2748, 0),\n (2753, 2757, 0),\n (2759, 2760, 0),\n (2765, 2765, 0),\n (2786, 2787, 0),\n (2810, 2815, 0),\n (2817, 2817, 0),\n (2876, 2876, 0),\n (2879, 2879, 0),\n (2881, 2884, 0),\n (2893, 2893, 0),\n (2901, 2902, 0),\n (2914, 2915, 0),\n (2946, 2946, 0),\n (3008, 3008, 0),\n (3021, 3021, 0),\n (3072, 3072, 0),\n (3076, 3076, 0),\n (3132, 3132, 0),\n (3134, 3136, 0),\n (3142, 3144, 0),\n (3146, 3149, 0),\n (3157, 3158, 0),\n (3170, 3171, 0),\n (3201, 3201, 0),\n (3260, 3260, 0),\n (3263, 3263, 0),\n (3270, 3270, 0),\n (3276, 3277, 0),\n (3298, 3299, 0),\n (3328, 3329, 0),\n (3387, 3388, 0),\n (3393, 3396, 0),\n (3405, 3405, 0),\n (3426, 3427, 0),\n (3457, 3457, 0),\n (3530, 3530, 0),\n (3538, 3540, 0),\n (3542, 3542, 0),\n (3633, 3633, 0),\n (3636, 3642, 0),\n (3655, 3662, 0),\n (3761, 3761, 0),\n (3764, 3772, 0),\n (3784, 3790, 0),\n (3864, 3865, 0),\n (3893, 3893, 0),\n (3895, 3895, 0),\n (3897, 3897, 0),\n (3953, 3966, 0),\n (3968, 3972, 0),\n (3974, 3975, 0),\n (3981, 3991, 0),\n (3993, 4028, 0),\n (4038, 4038, 0),\n (4141, 4144, 0),\n (4146, 4151, 0),\n (4153, 4154, 0),\n (4157, 4158, 0),\n (4184, 4185, 0),\n (4190, 4192, 0),\n (4209, 4212, 0),\n (4226, 4226, 0),\n (4229, 4230, 0),\n (4237, 4237, 0),\n (4253, 4253, 0),\n (4352, 4447, 2),\n (4957, 4959, 0),\n (5906, 5908, 0),\n (5938, 5939, 0),\n (5970, 5971, 0),\n (6002, 6003, 0),\n (6068, 6069, 0),\n (6071, 6077, 0),\n (6086, 6086, 0),\n (6089, 6099, 0),\n (6109, 6109, 0),\n (6155, 6157, 0),\n (6159, 6159, 0),\n (6277, 6278, 0),\n (6313, 6313, 0),\n (6432, 6434, 0),\n (6439, 6440, 0),\n (6450, 6450, 0),\n (6457, 6459, 0),\n (6679, 6680, 0),\n (6683, 6683, 0),\n (6742, 6742, 0),\n (6744, 6750, 0),\n (6752, 6752, 0),\n (6754, 6754, 0),\n (6757, 6764, 0),\n (6771, 6780, 0),\n (6783, 6783, 0),\n (6832, 6862, 0),\n (6912, 6915, 0),\n (6964, 6964, 0),\n (6966, 6970, 0),\n (6972, 6972, 0),\n (6978, 6978, 0),\n (7019, 7027, 0),\n (7040, 7041, 0),\n (7074, 7077, 0),\n (7080, 7081, 0),\n (7083, 7085, 0),\n (7142, 7142, 0),\n (7144, 7145, 0),\n (7149, 
7149, 0),\n (7151, 7153, 0),\n (7212, 7219, 0),\n (7222, 7223, 0),\n (7376, 7378, 0),\n (7380, 7392, 0),\n (7394, 7400, 0),\n (7405, 7405, 0),\n (7412, 7412, 0),\n (7416, 7417, 0),\n (7616, 7679, 0),\n (8203, 8207, 0),\n (8232, 8238, 0),\n (8288, 8291, 0),\n (8400, 8432, 0),\n (8986, 8987, 2),\n (9001, 9002, 2),\n (9193, 9196, 2),\n (9200, 9200, 2),\n (9203, 9203, 2),\n (9725, 9726, 2),\n (9748, 9749, 2),\n (9800, 9811, 2),\n (9855, 9855, 2),\n (9875, 9875, 2),\n (9889, 9889, 2),\n (9898, 9899, 2),\n (9917, 9918, 2),\n (9924, 9925, 2),\n (9934, 9934, 2),\n (9940, 9940, 2),\n (9962, 9962, 2),\n (9970, 9971, 2),\n (9973, 9973, 2),\n (9978, 9978, 2),\n (9981, 9981, 2),\n (9989, 9989, 2),\n (9994, 9995, 2),\n (10024, 10024, 2),\n (10060, 10060, 2),\n (10062, 10062, 2),\n (10067, 10069, 2),\n (10071, 10071, 2),\n (10133, 10135, 2),\n (10160, 10160, 2),\n (10175, 10175, 2),\n (11035, 11036, 2),\n (11088, 11088, 2),\n (11093, 11093, 2),\n (11503, 11505, 0),\n (11647, 11647, 0),\n (11744, 11775, 0),\n (11904, 11929, 2),\n (11931, 12019, 2),\n (12032, 12245, 2),\n (12272, 12283, 2),\n (12288, 12329, 2),\n (12330, 12333, 0),\n (12334, 12350, 2),\n (12353, 12438, 2),\n (12441, 12442, 0),\n (12443, 12543, 2),\n (12549, 12591, 2),\n (12593, 12686, 2),\n (12688, 12771, 2),\n (12784, 12830, 2),\n (12832, 12871, 2),\n (12880, 19903, 2),\n (19968, 42124, 2),\n (42128, 42182, 2),\n (42607, 42610, 0),\n (42612, 42621, 0),\n (42654, 42655, 0),\n (42736, 42737, 0),\n (43010, 43010, 0),\n (43014, 43014, 0),\n (43019, 43019, 0),\n (43045, 43046, 0),\n (43052, 43052, 0),\n (43204, 43205, 0),\n (43232, 43249, 0),\n (43263, 43263, 0),\n (43302, 43309, 0),\n (43335, 43345, 0),\n (43360, 43388, 2),\n (43392, 43394, 0),\n (43443, 43443, 0),\n (43446, 43449, 0),\n (43452, 43453, 0),\n (43493, 43493, 0),\n (43561, 43566, 0),\n (43569, 43570, 0),\n (43573, 43574, 0),\n (43587, 43587, 0),\n (43596, 43596, 0),\n (43644, 43644, 0),\n (43696, 43696, 0),\n (43698, 43700, 0),\n (43703, 43704, 0),\n (43710, 43711, 0),\n (43713, 43713, 0),\n (43756, 43757, 0),\n (43766, 43766, 0),\n (44005, 44005, 0),\n (44008, 44008, 0),\n (44013, 44013, 0),\n (44032, 55203, 2),\n (63744, 64255, 2),\n (64286, 64286, 0),\n (65024, 65039, 0),\n (65040, 65049, 2),\n (65056, 65071, 0),\n (65072, 65106, 2),\n (65108, 65126, 2),\n (65128, 65131, 2),\n (65281, 65376, 2),\n (65504, 65510, 2),\n (66045, 66045, 0),\n (66272, 66272, 0),\n (66422, 66426, 0),\n (68097, 68099, 0),\n (68101, 68102, 0),\n (68108, 68111, 0),\n (68152, 68154, 0),\n (68159, 68159, 0),\n (68325, 68326, 0),\n (68900, 68903, 0),\n (69291, 69292, 0),\n (69373, 69375, 0),\n (69446, 69456, 0),\n (69506, 69509, 0),\n (69633, 69633, 0),\n (69688, 69702, 0),\n (69744, 69744, 0),\n (69747, 69748, 0),\n (69759, 69761, 0),\n (69811, 69814, 0),\n (69817, 69818, 0),\n (69826, 69826, 0),\n (69888, 69890, 0),\n (69927, 69931, 0),\n (69933, 69940, 0),\n (70003, 70003, 0),\n (70016, 70017, 0),\n (70070, 70078, 0),\n (70089, 70092, 0),\n (70095, 70095, 0),\n (70191, 70193, 0),\n (70196, 70196, 0),\n (70198, 70199, 0),\n (70206, 70206, 0),\n (70209, 70209, 0),\n (70367, 70367, 0),\n (70371, 70378, 0),\n (70400, 70401, 0),\n (70459, 70460, 0),\n (70464, 70464, 0),\n (70502, 70508, 0),\n (70512, 70516, 0),\n (70712, 70719, 0),\n (70722, 70724, 0),\n (70726, 70726, 0),\n (70750, 70750, 0),\n (70835, 70840, 0),\n (70842, 70842, 0),\n (70847, 70848, 0),\n (70850, 70851, 0),\n (71090, 71093, 0),\n (71100, 71101, 0),\n (71103, 71104, 0),\n (71132, 71133, 0),\n (71219, 71226, 0),\n (71229, 71229, 0),\n 
(71231, 71232, 0),\n (71339, 71339, 0),\n (71341, 71341, 0),\n (71344, 71349, 0),\n (71351, 71351, 0),\n (71453, 71455, 0),\n (71458, 71461, 0),\n (71463, 71467, 0),\n (71727, 71735, 0),\n (71737, 71738, 0),\n (71995, 71996, 0),\n (71998, 71998, 0),\n (72003, 72003, 0),\n (72148, 72151, 0),\n (72154, 72155, 0),\n (72160, 72160, 0),\n (72193, 72202, 0),\n (72243, 72248, 0),\n (72251, 72254, 0),\n (72263, 72263, 0),\n (72273, 72278, 0),\n (72281, 72283, 0),\n (72330, 72342, 0),\n (72344, 72345, 0),\n (72752, 72758, 0),\n (72760, 72765, 0),\n (72767, 72767, 0),\n (72850, 72871, 0),\n (72874, 72880, 0),\n (72882, 72883, 0),\n (72885, 72886, 0),\n (73009, 73014, 0),\n (73018, 73018, 0),\n (73020, 73021, 0),\n (73023, 73029, 0),\n (73031, 73031, 0),\n (73104, 73105, 0),\n (73109, 73109, 0),\n (73111, 73111, 0),\n (73459, 73460, 0),\n (73472, 73473, 0),\n (73526, 73530, 0),\n (73536, 73536, 0),\n (73538, 73538, 0),\n (78912, 78912, 0),\n (78919, 78933, 0),\n (92912, 92916, 0),\n (92976, 92982, 0),\n (94031, 94031, 0),\n (94095, 94098, 0),\n (94176, 94179, 2),\n (94180, 94180, 0),\n (94192, 94193, 2),\n (94208, 100343, 2),\n (100352, 101589, 2),\n (101632, 101640, 2),\n (110576, 110579, 2),\n...\n# Path: src/black/strings.py\n\"\"\"\nSimple formatting on strings. Further string formatting code is in trans.py.\n\"\"\"\n\nimport re\nimport sys\nfrom functools import lru_cache\nfrom typing import Final, List, Match, Pattern\n\nfrom black._width_table import WIDTH_TABLE\nfrom blib2to3.pytree import Leaf\n\nSTRING_PREFIX_CHARS: Final = \"furbFURB\" # All possible string prefix characters.\nSTRING_PREFIX_RE: Final = re.compile(\n r\"^([\" + STRING_PREFIX_CHARS + r\"]*)(.*)$\", re.DOTALL\n)\nFIRST_NON_WHITESPACE_RE: Final = re.compile(r\"\\s*\\t+\\s*(\\S)\")\nUNICODE_ESCAPE_RE: Final = re.compile(\n r\"(?P\\\\+)(?P\"\n r\"(u(?P[a-fA-F0-9]{4}))\" # Character with 16-bit hex value xxxx\n r\"|(U(?P[a-fA-F0-9]{8}))\" # Character with 32-bit hex value xxxxxxxx\n r\"|(x(?P[a-fA-F0-9]{2}))\" # Character with hex value hh\n r\"|(N\\{(?P[a-zA-Z0-9 \\-]{2,})\\})\" # Character named name in the Unicode database\n r\")\",\n re.VERBOSE,\n)\n\n\ndef sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:\n \"\"\"Replace `regex` with `replacement` twice on `original`.\n\n This is used by string normalization to perform replaces on\n overlapping matches.\n \"\"\"\n return regex.sub(replacement, regex.sub(replacement, original))\n\n\ndef has_triple_quotes(string: str) -> bool:\n \"\"\"\n Returns:\n True iff @string starts with three quotation characters.\n \"\"\"\n raw_string = string.lstrip(STRING_PREFIX_CHARS)\n return raw_string[:3] in {'\"\"\"', \"'''\"}\n\n\ndef lines_with_leading_tabs_expanded(s: str) -> List[str]:\n \"\"\"\n Splits string into lines and expands only leading tabs (following the normal\n Python rules)\n \"\"\"\n lines = []\n for line in s.splitlines():\n # Find the index of the first non-whitespace character after a string of\n # whitespace that includes at least one tab\n match = FIRST_NON_WHITESPACE_RE.match(line)\n if match:\n first_non_whitespace_idx = match.start(1)\n\n lines.append(\n line[:first_non_whitespace_idx].expandtabs()\n + line[first_non_whitespace_idx:]\n )\n else:\n lines.append(line)\n if s.endswith(\"\\n\"):\n lines.append(\"\")\n return lines\n\n\ndef fix_docstring(docstring: str, prefix: str) -> str:\n # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation\n if not docstring:\n return \"\"\n lines = 
lines_with_leading_tabs_expanded(docstring)\n # Determine minimum indentation (first line doesn't count):\n indent = sys.maxsize\n for line in lines[1:]:\n stripped = line.lstrip()\n if stripped:\n indent = min(indent, len(line) - len(stripped))\n # Remove indentation (first line is special):\n trimmed = [lines[0].strip()]\n if indent < sys.maxsize:\n last_line_idx = len(lines) - 2\n for i, line in enumerate(lines[1:]):\n stripped_line = line[indent:].rstrip()\n if stripped_line or i == last_line_idx:\n trimmed.append(prefix + stripped_line)\n else:\n trimmed.append(\"\")\n return \"\\n\".join(trimmed)\n\n\ndef get_string_prefix(string: str) -> str:\n \"\"\"\n Pre-conditions:\n * assert_is_leaf_string(@string)\n\n Returns:\n @string's prefix (e.g. '', 'r', 'f', or 'rf').\n \"\"\"\n assert_is_leaf_string(string)\n\n prefix = \"\"\n prefix_idx = 0\n while string[prefix_idx] in STRING_PREFIX_CHARS:\n prefix += string[prefix_idx]\n prefix_idx += 1\n\n return prefix\n\n\ndef assert_is_leaf_string(string: str) -> None:\n \"\"\"\n Checks the pre-condition that @string has the format that you would expect\n of `leaf.value` where `leaf` is some Leaf such that `leaf.type ==\n token.STRING`. A more precise description of the pre-conditions that are\n checked are listed below.\n\n Pre-conditions:\n * @string starts with either ', \", ', or \" where\n `set()` is some subset of `set(STRING_PREFIX_CHARS)`.\n * @string ends with a quote character (' or \").\n\n Raises:\n AssertionError(...) if the pre-conditions listed above are not\n satisfied.\n \"\"\"\n dquote_idx = string.find('\"')\n squote_idx = string.find(\"'\")\n if -1 in [dquote_idx, squote_idx]:\n quote_idx = max(dquote_idx, squote_idx)\n else:\n quote_idx = min(squote_idx, dquote_idx)\n\n assert (\n 0 <= quote_idx < len(string) - 1\n ), f\"{string!r} is missing a starting quote character (' or \\\").\"\n assert string[-1] in (\n \"'\",\n '\"',\n ), f\"{string!r} is missing an ending quote character (' or \\\").\"\n assert set(string[:quote_idx]).issubset(\n set(STRING_PREFIX_CHARS)\n ), f\"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}.\"\n\n\ndef normalize_string_prefix(s: str) -> str:\n \"\"\"Make all string prefixes lowercase.\"\"\"\n match = STRING_PREFIX_RE.match(s)\n assert match is not None, f\"failed to match string {s!r}\"\n orig_prefix = match.group(1)\n new_prefix = (\n orig_prefix.replace(\"F\", \"f\")\n .replace(\"B\", \"b\")\n .replace(\"U\", \"\")\n .replace(\"u\", \"\")\n )\n\n # Python syntax guarantees max 2 prefixes and that one of them is \"r\"\n if len(new_prefix) == 2 and \"r\" != new_prefix[0].lower():\n new_prefix = new_prefix[::-1]\n return f\"{new_prefix}{match.group(2)}\"\n\n\n# Re(gex) does actually cache patterns internally but this still improves\n# performance on a long list literal of strings by 5-9% since lru_cache's\n# caching overhead is much lower.\n@lru_cache(maxsize=64)\ndef _cached_compile(pattern: str) -> Pattern[str]:\n return re.compile(pattern)\n\n\ndef normalize_string_quotes(s: str) -> str:\n \"\"\"Prefer double quotes but only if it doesn't cause more escaping.\n\n Adds or removes backslashes as appropriate. 
Doesn't parse and fix\n strings nested in f-strings.\n \"\"\"\n value = s.lstrip(STRING_PREFIX_CHARS)\n if value[:3] == '\"\"\"':\n return s\n\n elif value[:3] == \"'''\":\n orig_quote = \"'''\"\n new_quote = '\"\"\"'\n elif value[0] == '\"':\n orig_quote = '\"'\n new_quote = \"'\"\n else:\n orig_quote = \"'\"\n new_quote = '\"'\n first_quote_pos = s.find(orig_quote)\n if first_quote_pos == -1:\n return s # There's an internal error\n\n prefix = s[:first_quote_pos]\n unescaped_new_quote = _cached_compile(rf\"(([^\\\\]|^)(\\\\\\\\)*){new_quote}\")\n escaped_new_quote = _cached_compile(rf\"([^\\\\]|^)\\\\((?:\\\\\\\\)*){new_quote}\")\n escaped_orig_quote = _cached_compile(rf\"([^\\\\]|^)\\\\((?:\\\\\\\\)*){orig_quote}\")\n body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)]\n if \"r\" in prefix.casefold():\n if unescaped_new_quote.search(body):\n # There's at least one unescaped new_quote in this raw string\n # so converting is impossible\n return s\n\n # Do not introduce or remove backslashes in raw strings\n new_body = body\n else:\n # remove unnecessary escapes\n new_body = sub_twice(escaped_new_quote, rf\"\\1\\2{new_quote}\", body)\n if body != new_body:\n # Consider the string without unnecessary escapes as the original\n body = new_body\n s = f\"{prefix}{orig_quote}{body}{orig_quote}\"\n new_body = sub_twice(escaped_orig_quote, rf\"\\1\\2{orig_quote}\", new_body)\n new_body = sub_twice(unescaped_new_quote, rf\"\\1\\\\{new_quote}\", new_body)\n if \"f\" in prefix.casefold():\n matches = re.findall(\n r\"\"\"\n (?:(? orig_escape_count:\n return s # Do not introduce more escaping\n\n if new_escape_count == orig_escape_count and orig_quote == '\"':\n return s # Prefer double quotes\n\n return f\"{prefix}{new_quote}{new_body}{new_quote}\"\n\n\ndef normalize_unicode_escape_sequences(leaf: Leaf) -> None:\n \"\"\"Replace hex codes in Unicode escape sequences with lowercase representation.\"\"\"\n text = leaf.value\n prefix = get_string_prefix(text)\n if \"r\" in prefix.lower():\n return\n\n def replace(m: Match[str]) -> str:\n groups = m.groupdict()\n back_slashes = groups[\"backslashes\"]\n\n if len(back_slashes) % 2 == 0:\n return back_slashes + groups[\"body\"]\n\n if groups[\"u\"]:\n # \\u\n return back_slashes + \"u\" + groups[\"u\"].lower()\n elif groups[\"U\"]:\n # \\U\n return back_slashes + \"U\" + groups[\"U\"].lower()\n elif groups[\"x\"]:\n # \\x\n return back_slashes + \"x\" + groups[\"x\"].lower()\n else:\n assert groups[\"N\"], f\"Unexpected match: {m}\"\n # \\N{}\n return back_slashes + \"N{\" + groups[\"N\"].upper() + \"}\"\n\n leaf.value = re.sub(UNICODE_ESCAPE_RE, replace, text)\n\n\n@lru_cache(maxsize=4096)\ndef char_width(char: str) -> int:\n \"\"\"Return the width of a single character as it would be displayed in a\n terminal or editor (which respects Unicode East Asian Width).\n\n Full width characters are counted as 2, while half width characters are\n counted as 1. 
Also control characters are counted as 0.\n \"\"\"\n table = WIDTH_TABLE\n codepoint = ord(char)\n highest = len(table) - 1\n lowest = 0\n idx = highest // 2\n while True:\n start_codepoint, end_codepoint, width = table[idx]\n if codepoint < start_codepoint:\n highest = idx - 1\n elif codepoint > end_codepoint:\n lowest = idx + 1\n else:\n return 0 if width < 0 else width\n if highest < lowest:\n break\n idx = (highest + lowest) // 2\n return 1\n\n\ndef str_width(line_str: str) -> int:\n \"\"\"Return the width of `line_str` as it would be displayed in a terminal\n or editor (which respects Unicode East Asian Width).\n\n You could utilize this function to determine, for example, if a string\n is too wide to display in a terminal or editor.\n \"\"\"\n if line_str.isascii():\n # Fast path for a line consisting of only ASCII characters\n return len(line_str)\n return sum(map(char_width, line_str))\n\n\ndef count_chars_in_width(line_str: str, max_width: int) -> int:\n \"\"\"Count the number of characters in `line_str` that would fit in a\n terminal or editor of `max_width` (which respects Unicode East Asian\n Width).\n \"\"\"\n total_width = 0\n for i, char in enumerate(line_str):\n width = char_width(char)\n if width + total_width > max_width:\n return i\n total_width += width\n return len(line_str)\n\n# Path: src/black/nodes.py\n\"\"\"\nblib2to3 Node/Leaf transformation-related utility functions.\n\"\"\"\n\nimport sys\nfrom typing import (\n Final,\n Generic,\n Iterator,\n List,\n Literal,\n Optional,\n Set,\n Tuple,\n TypeVar,\n Union,\n)\n\nif sys.version_info >= (3, 10):\n from typing import TypeGuard\nelse:\n from typing_extensions import TypeGuard\n\nfrom mypy_extensions import mypyc_attr\n\nfrom black.cache import CACHE_DIR\nfrom black.mode import Mode, Preview\nfrom black.strings import get_string_prefix, has_triple_quotes\nfrom blib2to3 import pygram\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import NL, Leaf, Node, type_repr\n\npygram.initialize(CACHE_DIR)\nsyms: Final = pygram.python_symbols\n\n\n# types\nT = TypeVar(\"T\")\nLN = Union[Leaf, Node]\nLeafID = int\nNodeType = int\n\n\nWHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}\nSTATEMENT: Final = {\n syms.if_stmt,\n syms.while_stmt,\n syms.for_stmt,\n syms.try_stmt,\n syms.except_clause,\n syms.with_stmt,\n syms.funcdef,\n syms.classdef,\n syms.match_stmt,\n syms.case_block,\n}\nSTANDALONE_COMMENT: Final = 153\ntoken.tok_name[STANDALONE_COMMENT] = \"STANDALONE_COMMENT\"\nLOGIC_OPERATORS: Final = {\"and\", \"or\"}\nCOMPARATORS: Final = {\n token.LESS,\n token.GREATER,\n token.EQEQUAL,\n token.NOTEQUAL,\n token.LESSEQUAL,\n token.GREATEREQUAL,\n}\nMATH_OPERATORS: Final = {\n token.VBAR,\n token.CIRCUMFLEX,\n token.AMPER,\n token.LEFTSHIFT,\n token.RIGHTSHIFT,\n token.PLUS,\n token.MINUS,\n token.STAR,\n token.SLASH,\n token.DOUBLESLASH,\n token.PERCENT,\n token.AT,\n token.TILDE,\n token.DOUBLESTAR,\n}\nSTARS: Final = {token.STAR, token.DOUBLESTAR}\nVARARGS_SPECIALS: Final = STARS | {token.SLASH}\nVARARGS_PARENTS: Final = {\n syms.arglist,\n syms.argument, # double star in arglist\n syms.trailer, # single argument to call\n syms.typedargslist,\n syms.varargslist, # lambdas\n}\nUNPACKING_PARENTS: Final = {\n syms.atom, # single element of a list or set literal\n syms.dictsetmaker,\n syms.listmaker,\n syms.testlist_gexp,\n syms.testlist_star_expr,\n syms.subject_expr,\n syms.pattern,\n}\nTEST_DESCENDANTS: Final = {\n syms.test,\n syms.lambdef,\n syms.or_test,\n syms.and_test,\n syms.not_test,\n 
syms.comparison,\n syms.star_expr,\n syms.expr,\n syms.xor_expr,\n syms.and_expr,\n syms.shift_expr,\n syms.arith_expr,\n syms.trailer,\n syms.term,\n syms.power,\n syms.namedexpr_test,\n}\nTYPED_NAMES: Final = {syms.tname, syms.tname_star}\nASSIGNMENTS: Final = {\n \"=\",\n \"+=\",\n \"-=\",\n \"*=\",\n \"@=\",\n \"/=\",\n \"%=\",\n \"&=\",\n \"|=\",\n \"^=\",\n \"<<=\",\n \">>=\",\n \"**=\",\n \"//=\",\n \":\",\n}\n\nIMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}\nBRACKET: Final = {\n token.LPAR: token.RPAR,\n token.LSQB: token.RSQB,\n token.LBRACE: token.RBRACE,\n}\nOPENING_BRACKETS: Final = set(BRACKET.keys())\nCLOSING_BRACKETS: Final = set(BRACKET.values())\nBRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS\nALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}\n\nRARROW = 55\n\n\n@mypyc_attr(allow_interpreted_subclasses=True)\nclass Visitor(Generic[T]):\n \"\"\"Basic lib2to3 visitor that yields things of type `T` on `visit()`.\"\"\"\n\n \ndef visit(self, node: LN) -> Iterator[T]:\n \"\"\"Main method to visit `node` and its children.\n\n It tries to find a `visit_*()` method for the given `node.type`, like\n `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.\n If no dedicated `visit_*()` method is found, chooses `visit_default()`\n instead.\n\n Then yields objects of type `T` from the selected visitor.\n \"\"\"\n if node.type < 256:\n name = token.tok_name[node.type]\n else:\n name = str(type_repr(node.type))\n # We explicitly branch on whether a visitor exists (instead of\n # using self.visit_default as the default arg to getattr) in order\n # to save needing to create a bound method object and so mypyc can\n # generate a native call to visit_default.\n visitf = getattr(self, f\"visit_{name}\", None)\n if visitf:\n yield from visitf(node)\n else:\n yield from self.visit_default(node)\n\n def visit_default(self, node: LN) -> Iterator[T]:\n \"\"\"Default `visit_*()` implementation. Recurses to children of `node`.\"\"\"\n if isinstance(node, Node):\n for child in node.children:\n yield from self.visit(child)\n\n\ndef whitespace(leaf: Leaf, *, complex_subscript: bool, mode: Mode) -> str: # noqa: C901\n \"\"\"Return whitespace prefix if needed for the given `leaf`.\n\n `complex_subscript` signals whether the given leaf is part of a subscription\n which has non-trivial arguments, like arithmetic expressions or function calls.\n \"\"\"\n NO: Final[str] = \"\"\n SPACE: Final[str] = \" \"\n DOUBLESPACE: Final[str] = \" \"\n t = leaf.type\n p = leaf.parent\n v = leaf.value\n if t in ALWAYS_NO_SPACE:\n return NO\n\n if t == token.COMMENT:\n return DOUBLESPACE\n\n assert p is not None, f\"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}\"\n if t == token.COLON and p.type not in {\n syms.subscript,\n syms.subscriptlist,\n syms.sliceop,\n }:\n return NO\n\n prev = leaf.prev_sibling\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n if t == token.COLON:\n if prevp.type == token.COLON:\n return NO\n\n elif prevp.type != token.COMMA and not complex_subscript:\n return NO\n\n return SPACE\n\n if prevp.type == token.EQUAL:\n if prevp.parent:\n if prevp.parent.type in {\n syms.arglist,\n syms.argument,\n syms.parameters,\n syms.varargslist,\n }:\n return NO\n\n elif prevp.parent.type == syms.typedargslist:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. 
So, we're using\n # that, too.\n return prevp.prefix\n\n elif (\n prevp.type == token.STAR\n and parent_type(prevp) == syms.star_expr\n and parent_type(prevp.parent) == syms.subscriptlist\n ):\n # No space between typevar tuples.\n return NO\n\n elif prevp.type in VARARGS_SPECIALS:\n if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):\n return NO\n\n elif prevp.type == token.COLON:\n if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:\n return SPACE if complex_subscript else NO\n\n elif (\n prevp.parent\n and prevp.parent.type == syms.factor\n and prevp.type in MATH_OPERATORS\n ):\n return NO\n\n elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:\n # no space in decorators\n return NO\n\n elif prev.type in OPENING_BRACKETS:\n return NO\n\n if p.type in {syms.parameters, syms.arglist}:\n # untyped function signatures or calls\n if not prev or prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.varargslist:\n # lambdas\n if prev and prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.typedargslist:\n # typed function signatures\n if not prev:\n return NO\n\n if t == token.EQUAL:\n if prev.type not in TYPED_NAMES:\n return NO\n\n elif prev.type == token.EQUAL:\n # A bit hacky: if the equal sign has whitespace, it means we\n # previously found it's a typed argument. So, we're using that, too.\n return prev.prefix\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type in TYPED_NAMES:\n # type names\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type != token.COMMA:\n return NO\n\n elif p.type == syms.trailer:\n # attributes and calls\n if t == token.LPAR or t == token.RPAR:\n return NO\n\n if not prev:\n if t == token.DOT or t == token.LSQB:\n return NO\n\n elif prev.type != token.COMMA:\n return NO\n\n elif p.type == syms.argument:\n # single argument\n if t == token.EQUAL:\n return NO\n\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.LPAR:\n return NO\n\n elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:\n return NO\n\n elif p.type == syms.decorator:\n # decorators\n return NO\n\n elif p.type == syms.dotted_name:\n if prev:\n return NO\n\n prevp = preceding_leaf(p)\n if not prevp or prevp.type == token.AT or prevp.type == token.DOT:\n return NO\n\n elif p.type == syms.classdef:\n if t == token.LPAR:\n return NO\n\n if prev and prev.type == token.LPAR:\n return NO\n\n elif p.type in {syms.subscript, syms.sliceop}:\n # indexing\n if not prev:\n assert p.parent is not None, \"subscripts are always parented\"\n if p.parent.type == syms.subscriptlist:\n return SPACE\n\n return NO\n\n elif t == token.COLONEQUAL or prev.type == token.COLONEQUAL:\n return SPACE\n\n elif not complex_subscript:\n return NO\n\n elif p.type == syms.atom:\n if prev and t == token.DOT:\n # dots, but not the first one.\n return NO\n\n elif p.type == syms.dictsetmaker:\n # dict unpacking\n if prev and prev.type == token.DOUBLESTAR:\n return NO\n\n elif p.type in {syms.factor, syms.star_expr}:\n # unary ops\n if not prev:\n prevp = preceding_leaf(p)\n if not prevp or prevp.type in OPENING_BRACKETS:\n return NO\n\n prevp_parent = prevp.parent\n assert prevp_parent is not None\n if prevp.type == token.COLON and prevp_parent.type in {\n syms.subscript,\n syms.sliceop,\n }:\n return NO\n\n elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:\n return NO\n\n elif t in {token.NAME, token.NUMBER, token.STRING}:\n return NO\n\n elif p.type == syms.import_from:\n if t == 
token.DOT:\n if prev and prev.type == token.DOT:\n return NO\n\n elif t == token.NAME:\n if v == \"import\":\n return SPACE\n\n if prev and prev.type == token.DOT:\n return NO\n\n elif p.type == syms.sliceop:\n return NO\n\n elif p.type == syms.except_clause:\n if t == token.STAR:\n return NO\n\n return SPACE\n\n\ndef make_simple_prefix(nl_count: int, form_feed: bool, empty_line: str = \"\\n\") -> str:\n \"\"\"Generate a normalized prefix string.\"\"\"\n if form_feed:\n return (empty_line * (nl_count - 1)) + \"\\f\" + empty_line\n return empty_line * nl_count\n\n\ndef preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:\n \"\"\"Return the first leaf that precedes `node`, if any.\"\"\"\n while node:\n res = node.prev_sibling\n if res:\n if isinstance(res, Leaf):\n return res\n\n try:\n return list(res.leaves())[-1]\n\n except IndexError:\n return None\n\n node = node.parent\n return None\n\n\ndef prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:\n \"\"\"Return if the `node` and its previous siblings match types against the provided\n list of tokens; the provided `node`has its type matched against the last element in\n the list. `None` can be used as the first element to declare that the start of the\n list is anchored at the start of its parent's children.\"\"\"\n if not tokens:\n return True\n if tokens[-1] is None:\n return node is None\n if not node:\n return False\n if node.type != tokens[-1]:\n return False\n return prev_siblings_are(node.prev_sibling, tokens[:-1])\n\n\ndef parent_type(node: Optional[LN]) -> Optional[NodeType]:\n \"\"\"\n Returns:\n @node.parent.type, if @node is not None and has a parent.\n OR\n None, otherwise.\n \"\"\"\n if node is None or node.parent is None:\n return None\n\n return node.parent.type\n\n\ndef child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:\n \"\"\"Return the child of `ancestor` that contains `descendant`.\"\"\"\n node: Optional[LN] = descendant\n while node and node.parent != ancestor:\n node = node.parent\n return node\n\n\ndef replace_child(old_child: LN, new_child: LN) -> None:\n \"\"\"\n Side Effects:\n * If @old_child.parent is set, replace @old_child with @new_child in\n @old_child's underlying Node structure.\n OR\n * Otherwise, this function does nothing.\n \"\"\"\n parent = old_child.parent\n if not parent:\n return\n\n child_idx = old_child.remove()\n if child_idx is not None:\n parent.insert_child(child_idx, new_child)\n\n\ndef container_of(leaf: Leaf) -> LN:\n \"\"\"Return `leaf` or one of its ancestors that is the topmost container of it.\n\n By \"container\" we mean a node where `leaf` is the very first child.\n \"\"\"\n same_prefix = leaf.prefix\n container: LN = leaf\n while container:\n parent = container.parent\n if parent is None:\n break\n\n if parent.children[0].prefix != same_prefix:\n break\n\n if parent.type == syms.file_input:\n break\n\n if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:\n break\n\n container = parent\n return container\n\n\ndef first_leaf_of(node: LN) -> Optional[Leaf]:\n \"\"\"Returns the first leaf of the node tree.\"\"\"\n if isinstance(node, Leaf):\n return node\n if node.children:\n return first_leaf_of(node.children[0])\n else:\n return None\n\n\ndef is_arith_like(node: LN) -> bool:\n \"\"\"Whether node is an arithmetic or a binary arithmetic expression\"\"\"\n return node.type in {\n syms.arith_expr,\n syms.shift_expr,\n syms.xor_expr,\n syms.and_expr,\n }\n\n\ndef is_docstring(leaf: Leaf, mode: Mode) -> bool:\n if leaf.type 
!= token.STRING:\n return False\n\n prefix = get_string_prefix(leaf.value)\n if set(prefix).intersection(\"bBfF\"):\n return False\n\n if (\n Preview.unify_docstring_detection in mode\n and leaf.parent\n and leaf.parent.type == syms.simple_stmt\n and not leaf.parent.prev_sibling\n and leaf.parent.parent\n and leaf.parent.parent.type == syms.file_input\n ):\n return True\n\n if prev_siblings_are(\n leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]\n ):\n return True\n\n # Multiline docstring on the same line as the `def`.\n if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):\n # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python\n # grammar. We're safe to return True without further checks.\n return True\n\n return False\n\n\ndef is_empty_tuple(node: LN) -> bool:\n \"\"\"Return True if `node` holds an empty tuple.\"\"\"\n return (\n node.type == syms.atom\n and len(node.children) == 2\n and node.children[0].type == token.LPAR\n and node.children[1].type == token.RPAR\n )\n\n\ndef is_one_tuple(node: LN) -> bool:\n \"\"\"Return True if `node` holds a tuple with one element, with or without parens.\"\"\"\n if node.type == syms.atom:\n gexp = unwrap_singleton_parenthesis(node)\n if gexp is None or gexp.type != syms.testlist_gexp:\n return False\n\n return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA\n\n return (\n node.type in IMPLICIT_TUPLE\n and len(node.children) == 2\n and node.children[1].type == token.COMMA\n )\n\n\ndef is_tuple_containing_walrus(node: LN) -> bool:\n \"\"\"Return True if `node` holds a tuple that contains a walrus operator.\"\"\"\n if node.type != syms.atom:\n return False\n gexp = unwrap_singleton_parenthesis(node)\n if gexp is None or gexp.type != syms.testlist_gexp:\n return False\n\n return any(child.type == syms.namedexpr_test for child in gexp.children)\n\n\ndef is_one_sequence_between(\n opening: Leaf,\n closing: Leaf,\n leaves: List[Leaf],\n brackets: Tuple[int, int] = (token.LPAR, token.RPAR),\n) -> bool:\n...\n", "needle_token_start": 12070, "needle_token_end": 12356, "code_context_ntokens": 16200, "output": ["```python\ndef visit(self, node: LN) -> Iterator[T]:\n \"\"\"Main method to visit `node` and its children.\n\n It tries to find a `visit_*()` method for the given `node.type`, like\n `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.\n If no dedicated `visit_*()` method is found, chooses `visit_default()`\n instead.\n\n Then yields objects of type `T` from the selected visitor.\n \"\"\"\n if node.type < 256:\n name = token.tok_name[node.type]\n else:\n name = str(type_repr(node.type))\n # We explicitly branch on whether a visitor exists (instead of\n # using self.visit_default as the default arg to getattr) in order\n # to save needing to create a bound method object and so mypyc can\n # generate a native call to visit_default.\n visitf = getattr(self, f\"visit_{name}\", None)\n if visitf:\n yield from visitf(node)\n else:\n yield from self.visit_default(node)\n\ndef visit_default(self, node: LN) -> Iterator[T]:\n \"\"\"Default `visit_*()` implementation. Recurses to children of `node`.\"\"\"\n if isinstance(node, Node):\n for child in node.children:\n yield from self.visit(child)\n```"]} +{"repo": "psf/black", "name": "__post_init__", "language": "python", "path": "src/black/linegen.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. 
**Purpose**: Initializes the state and configuration of an object after its creation, setting up necessary properties and methods for handling different types of statements in code formatting.\n2. **Input**: None explicitly, but operates on the instance itself.\n3. **Output**: None, as it modifies the instance state directly.\n4. **Procedure**: Sets an initial line configuration based on the mode, assigns specialized partial functions to handle various statement types like assertions, loops, and conditional statements, and configures handling for different syntactic elements based on the formatting mode.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/trans.py\n\"\"\"\nString transformers that can split and merge strings.\n\"\"\"\n\nimport re\nfrom abc import ABC, abstractmethod\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import (\n Any,\n Callable,\n ClassVar,\n Collection,\n Dict,\n Final,\n Iterable,\n Iterator,\n List,\n Literal,\n Optional,\n Sequence,\n Set,\n Tuple,\n TypeVar,\n Union,\n)\n\nfrom mypy_extensions import trait\n\nfrom black.comments import contains_pragma_comment\nfrom black.lines import Line, append_leaves\nfrom black.mode import Feature, Mode, Preview\nfrom black.nodes import (\n CLOSING_BRACKETS,\n OPENING_BRACKETS,\n STANDALONE_COMMENT,\n is_empty_lpar,\n is_empty_par,\n is_empty_rpar,\n is_part_of_annotation,\n parent_type,\n replace_child,\n syms,\n)\nfrom black.rusty import Err, Ok, Result\nfrom black.strings import (\n assert_is_leaf_string,\n count_chars_in_width,\n get_string_prefix,\n has_triple_quotes,\n normalize_string_quotes,\n str_width,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n\nclass CannotTransform(Exception):\n \"\"\"Base class for errors raised by Transformers.\"\"\"\n\n\n# types\nT = TypeVar(\"T\")\nLN = Union[Leaf, Node]\nTransformer = Callable[[Line, Collection[Feature], Mode], Iterator[Line]]\nIndex = int\nNodeType = int\nParserState = int\nStringID = int\nTResult = Result[T, CannotTransform] # (T)ransform Result\nTMatchResult = TResult[List[Index]]\n\nSPLIT_SAFE_CHARS = frozenset([\"\\u3001\", \"\\u3002\", \"\\uff0c\"]) # East Asian stops\n\n\ndef TErr(err_msg: str) -> Err[CannotTransform]:\n \"\"\"(T)ransform Err\n\n Convenience function used when working with the TResult type.\n \"\"\"\n cant_transform = CannotTransform(err_msg)\n return Err(cant_transform)\n\n\ndef hug_power_op(\n line: Line, features: Collection[Feature], mode: Mode\n) -> Iterator[Line]:\n \"\"\"A transformer which normalizes spacing around power operators.\"\"\"\n\n # Performance optimization to avoid unnecessary Leaf clones and other ops.\n for leaf in line.leaves:\n if leaf.type == token.DOUBLESTAR:\n break\n else:\n raise CannotTransform(\"No doublestar token was found in the line.\")\n\n def is_simple_lookup(index: int, kind: Literal[1, -1]) -> bool:\n # Brackets and parentheses indicate calls, subscripts, etc. ...\n # basically stuff that doesn't count as \"simple\". Only a NAME lookup\n # or dotted lookup (eg. 
NAME.NAME) is OK.\n if Preview.is_simple_lookup_for_doublestar_expression not in mode:\n return original_is_simple_lookup_func(line, index, kind)\n\n else:\n if kind == -1:\n return handle_is_simple_look_up_prev(\n line, index, {token.RPAR, token.RSQB}\n )\n else:\n return handle_is_simple_lookup_forward(\n line, index, {token.LPAR, token.LSQB}\n )\n\n def is_simple_operand(index: int, kind: Literal[1, -1]) -> bool:\n # An operand is considered \"simple\" if's a NAME, a numeric CONSTANT, a simple\n # lookup (see above), with or without a preceding unary operator.\n start = line.leaves[index]\n if start.type in {token.NAME, token.NUMBER}:\n return is_simple_lookup(index, kind)\n\n if start.type in {token.PLUS, token.MINUS, token.TILDE}:\n if line.leaves[index + 1].type in {token.NAME, token.NUMBER}:\n # kind is always one as bases with a preceding unary op will be checked\n # for simplicity starting from the next token (so it'll hit the check\n # above).\n return is_simple_lookup(index + 1, kind=1)\n\n return False\n\n new_line = line.clone()\n should_hug = False\n for idx, leaf in enumerate(line.leaves):\n new_leaf = leaf.clone()\n if should_hug:\n new_leaf.prefix = \"\"\n should_hug = False\n\n should_hug = (\n (0 < idx < len(line.leaves) - 1)\n and leaf.type == token.DOUBLESTAR\n and is_simple_operand(idx - 1, kind=-1)\n and line.leaves[idx - 1].value != \"lambda\"\n and is_simple_operand(idx + 1, kind=1)\n )\n if should_hug:\n new_leaf.prefix = \"\"\n\n # We have to be careful to make a new line properly:\n # - bracket related metadata must be maintained (handled by Line.append)\n # - comments need to copied over, updating the leaf IDs they're attached to\n new_line.append(new_leaf, preformatted=True)\n for comment_leaf in line.comments_after(leaf):\n new_line.append(comment_leaf, preformatted=True)\n\n yield new_line\n\n\ndef original_is_simple_lookup_func(\n line: Line, index: int, step: Literal[1, -1]\n) -> bool:\n if step == -1:\n disallowed = {token.RPAR, token.RSQB}\n else:\n disallowed = {token.LPAR, token.LSQB}\n\n while 0 <= index < len(line.leaves):\n current = line.leaves[index]\n if current.type in disallowed:\n return False\n if current.type not in {token.NAME, token.DOT} or current.value == \"for\":\n # If the current token isn't disallowed, we'll assume this is\n # simple as only the disallowed tokens are semantically\n # attached to this lookup expression we're checking. Also,\n # stop early if we hit the 'for' bit of a comprehension.\n return True\n\n index += step\n\n return True\n\n\ndef handle_is_simple_look_up_prev(line: Line, index: int, disallowed: Set[int]) -> bool:\n \"\"\"\n Handling the determination of is_simple_lookup for the lines prior to the doublestar\n token. 
This is required because of the need to isolate the chained expression\n to determine the bracket or parenthesis belong to the single expression.\n \"\"\"\n contains_disallowed = False\n chain = []\n\n while 0 <= index < len(line.leaves):\n current = line.leaves[index]\n chain.append(current)\n if not contains_disallowed and current.type in disallowed:\n contains_disallowed = True\n if not is_expression_chained(chain):\n return not contains_disallowed\n\n index -= 1\n\n return True\n\n\ndef handle_is_simple_lookup_forward(\n line: Line, index: int, disallowed: Set[int]\n) -> bool:\n \"\"\"\n Handling decision is_simple_lookup for the lines behind the doublestar token.\n This function is simplified to keep consistent with the prior logic and the forward\n case are more straightforward and do not need to care about chained expressions.\n \"\"\"\n while 0 <= index < len(line.leaves):\n current = line.leaves[index]\n if current.type in disallowed:\n return False\n if current.type not in {token.NAME, token.DOT} or (\n current.type == token.NAME and current.value == \"for\"\n ):\n # If the current token isn't disallowed, we'll assume this is simple as\n # only the disallowed tokens are semantically attached to this lookup\n # expression we're checking. Also, stop early if we hit the 'for' bit\n # of a comprehension.\n return True\n\n index += 1\n\n return True\n\n\ndef is_expression_chained(chained_leaves: List[Leaf]) -> bool:\n \"\"\"\n Function to determine if the variable is a chained call.\n (e.g., foo.lookup, foo().lookup, (foo.lookup())) will be recognized as chained call)\n \"\"\"\n if len(chained_leaves) < 2:\n return True\n\n current_leaf = chained_leaves[-1]\n past_leaf = chained_leaves[-2]\n\n if past_leaf.type == token.NAME:\n return current_leaf.type in {token.DOT}\n elif past_leaf.type in {token.RPAR, token.RSQB}:\n return current_leaf.type in {token.RSQB, token.RPAR}\n elif past_leaf.type in {token.LPAR, token.LSQB}:\n return current_leaf.type in {token.NAME, token.LPAR, token.LSQB}\n else:\n return False\n\n\nclass StringTransformer(ABC):\n \"\"\"\n An implementation of the Transformer protocol that relies on its\n subclasses overriding the template methods `do_match(...)` and\n `do_transform(...)`.\n\n This Transformer works exclusively on strings (for example, by merging\n or splitting them).\n\n The following sections can be found among the docstrings of each concrete\n StringTransformer subclass.\n\n Requirements:\n Which requirements must be met of the given Line for this\n StringTransformer to be applied?\n\n Transformations:\n If the given Line meets all of the above requirements, which string\n transformations can you expect to be applied to it by this\n StringTransformer?\n\n Collaborations:\n What contractual agreements does this StringTransformer have with other\n StringTransfomers? Such collaborations should be eliminated/minimized\n as much as possible.\n \"\"\"\n\n __name__: Final = \"StringTransformer\"\n\n # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with\n # `abc.ABC`.\n def __init__(self, line_length: int, normalize_strings: bool) -> None:\n self.line_length = line_length\n self.normalize_strings = normalize_strings\n\n @abstractmethod\n def do_match(self, line: Line) -> TMatchResult:\n \"\"\"\n Returns:\n * Ok(string_indices) such that for each index, `line.leaves[index]`\n is our target string if a match was able to be made. For\n transformers that don't result in more lines (e.g. 
StringMerger,\n StringParenStripper), multiple matches and transforms are done at\n once to reduce the complexity.\n OR\n * Err(CannotTransform), if no match could be made.\n \"\"\"\n\n @abstractmethod\n def do_transform(\n self, line: Line, string_indices: List[int]\n ) -> Iterator[TResult[Line]]:\n \"\"\"\n Yields:\n * Ok(new_line) where new_line is the new transformed line.\n OR\n * Err(CannotTransform) if the transformation failed for some reason. The\n `do_match(...)` template method should usually be used to reject\n the form of the given Line, but in some cases it is difficult to\n know whether or not a Line meets the StringTransformer's\n requirements until the transformation is already midway.\n\n Side Effects:\n This method should NOT mutate @line directly, but it MAY mutate the\n Line's underlying Node structure. (WARNING: If the underlying Node\n structure IS altered, then this method should NOT be allowed to\n yield an CannotTransform after that point.)\n \"\"\"\n\n def __call__(\n self, line: Line, _features: Collection[Feature], _mode: Mode\n ) -> Iterator[Line]:\n \"\"\"\n StringTransformer instances have a call signature that mirrors that of\n the Transformer type.\n\n Raises:\n CannotTransform(...) if the concrete StringTransformer class is unable\n to transform @line.\n \"\"\"\n # Optimization to avoid calling `self.do_match(...)` when the line does\n # not contain any string.\n if not any(leaf.type == token.STRING for leaf in line.leaves):\n raise CannotTransform(\"There are no strings in this line.\")\n\n match_result = self.do_match(line)\n\n if isinstance(match_result, Err):\n cant_transform = match_result.err()\n raise CannotTransform(\n f\"The string transformer {self.__class__.__name__} does not recognize\"\n \" this line as one that it can transform.\"\n ) from cant_transform\n\n string_indices = match_result.ok()\n\n for line_result in self.do_transform(line, string_indices):\n if isinstance(line_result, Err):\n cant_transform = line_result.err()\n raise CannotTransform(\n \"StringTransformer failed while attempting to transform string.\"\n ) from cant_transform\n line = line_result.ok()\n yield line\n\n\n@dataclass\nclass CustomSplit:\n \"\"\"A custom (i.e. 
manual) string split.\n\n A single CustomSplit instance represents a single substring.\n\n Examples:\n Consider the following string:\n ```\n \"Hi there friend.\"\n \" This is a custom\"\n f\" string {split}.\"\n ```\n\n This string will correspond to the following three CustomSplit instances:\n ```\n CustomSplit(False, 16)\n CustomSplit(False, 17)\n CustomSplit(True, 16)\n ```\n \"\"\"\n\n has_prefix: bool\n break_idx: int\n\n\n@trait\nclass CustomSplitMapMixin:\n \"\"\"\n This mixin class is used to map merged strings to a sequence of\n CustomSplits, which will then be used to re-split the strings iff none of\n the resultant substrings go over the configured max line length.\n \"\"\"\n\n _Key: ClassVar = Tuple[StringID, str]\n _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(\n tuple\n )\n\n @staticmethod\n def _get_key(string: str) -> \"CustomSplitMapMixin._Key\":\n \"\"\"\n Returns:\n A unique identifier that is used internally to map @string to a\n group of custom splits.\n \"\"\"\n return (id(string), string)\n\n def add_custom_splits(\n self, string: str, custom_splits: Iterable[CustomSplit]\n ) -> None:\n \"\"\"Custom Split Map Setter Method\n\n Side Effects:\n Adds a mapping from @string to the custom splits @custom_splits.\n \"\"\"\n key = self._get_key(string)\n self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)\n\n def pop_custom_splits(self, string: str) -> List[CustomSplit]:\n \"\"\"Custom Split Map Getter Method\n\n Returns:\n * A list of the custom splits that are mapped to @string, if any\n exist.\n OR\n * [], otherwise.\n\n Side Effects:\n Deletes the mapping between @string and its associated custom\n splits (which are returned to the caller).\n \"\"\"\n key = self._get_key(string)\n\n custom_splits = self._CUSTOM_SPLIT_MAP[key]\n del self._CUSTOM_SPLIT_MAP[key]\n\n return list(custom_splits)\n\n def has_custom_splits(self, string: str) -> bool:\n \"\"\"\n Returns:\n True iff @string is associated with a set of custom splits.\n \"\"\"\n key = self._get_key(string)\n return key in self._CUSTOM_SPLIT_MAP\n\n\nclass StringMerger(StringTransformer, CustomSplitMapMixin):\n \"\"\"StringTransformer that merges strings together.\n\n Requirements:\n (A) The line contains adjacent strings such that ALL of the validation checks\n listed in StringMerger._validate_msg(...)'s docstring pass.\n OR\n (B) The line contains a string which uses line continuation backslashes.\n\n Transformations:\n Depending on which of the two requirements above where met, either:\n\n (A) The string group associated with the target string is merged.\n OR\n (B) All line-continuation backslashes are removed from the target string.\n\n Collaborations:\n StringMerger provides custom split information to StringSplitter.\n \"\"\"\n\n def do_match(self, line: Line) -> TMatchResult:\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n string_indices = []\n idx = 0\n while is_valid_index(idx):\n leaf = LL[idx]\n if (\n leaf.type == token.STRING\n and is_valid_index(idx + 1)\n and LL[idx + 1].type == token.STRING\n ):\n # Let's check if the string group contains an inline comment\n # If we have a comment inline, we don't merge the strings\n contains_comment = False\n i = idx\n while is_valid_index(i):\n if LL[i].type != token.STRING:\n break\n if line.comments_after(LL[i]):\n contains_comment = True\n break\n i += 1\n\n if not is_part_of_annotation(leaf) and not contains_comment:\n string_indices.append(idx)\n\n # Advance to the next non-STRING leaf.\n idx += 2\n while 
is_valid_index(idx) and LL[idx].type == token.STRING:\n idx += 1\n\n elif leaf.type == token.STRING and \"\\\\\\n\" in leaf.value:\n string_indices.append(idx)\n # Advance to the next non-STRING leaf.\n idx += 1\n while is_valid_index(idx) and LL[idx].type == token.STRING:\n idx += 1\n\n else:\n idx += 1\n\n if string_indices:\n return Ok(string_indices)\n else:\n return TErr(\"This line has no strings that need merging.\")\n\n def do_transform(\n self, line: Line, string_indices: List[int]\n ) -> Iterator[TResult[Line]]:\n new_line = line\n\n rblc_result = self._remove_backslash_line_continuation_chars(\n new_line, string_indices\n )\n if isinstance(rblc_result, Ok):\n new_line = rblc_result.ok()\n\n msg_result = self._merge_string_group(new_line, string_indices)\n if isinstance(msg_result, Ok):\n new_line = msg_result.ok()\n\n if isinstance(rblc_result, Err) and isinstance(msg_result, Err):\n msg_cant_transform = msg_result.err()\n rblc_cant_transform = rblc_result.err()\n cant_transform = CannotTransform(\n \"StringMerger failed to merge any strings in this line.\"\n )\n\n # Chain the errors together using `__cause__`.\n msg_cant_transform.__cause__ = rblc_cant_transform\n cant_transform.__cause__ = msg_cant_transform\n\n yield Err(cant_transform)\n else:\n yield Ok(new_line)\n\n @staticmethod\n def _remove_backslash_line_continuation_chars(\n line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merge strings that were split across multiple lines using\n line-continuation backslashes.\n\n Returns:\n Ok(new_line), if @line contains backslash line-continuation\n characters.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n indices_to_transform = []\n for string_idx in string_indices:\n string_leaf = LL[string_idx]\n if (\n string_leaf.type == token.STRING\n and \"\\\\\\n\" in string_leaf.value\n and not has_triple_quotes(string_leaf.value)\n ):\n indices_to_transform.append(string_idx)\n\n if not indices_to_transform:\n return TErr(\n \"Found no string leaves that contain backslash line continuation\"\n \" characters.\"\n )\n\n new_line = line.clone()\n new_line.comments = line.comments.copy()\n append_leaves(new_line, line, LL)\n\n for string_idx in indices_to_transform:\n new_string_leaf = new_line.leaves[string_idx]\n new_string_leaf.value = new_string_leaf.value.replace(\"\\\\\\n\", \"\")\n\n return Ok(new_line)\n\n def _merge_string_group(\n self, line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merges string groups (i.e. set of adjacent strings).\n\n Each index from `string_indices` designates one string group's first\n leaf in `line.leaves`.\n\n Returns:\n Ok(new_line), if ALL of the validation checks found in\n _validate_msg(...) 
pass.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n # A dict of {string_idx: tuple[num_of_strings, string_leaf]}.\n merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}\n for string_idx in string_indices:\n vresult = self._validate_msg(line, string_idx)\n if isinstance(vresult, Err):\n continue\n merged_string_idx_dict[string_idx] = self._merge_one_string_group(\n LL, string_idx, is_valid_index\n )\n\n if not merged_string_idx_dict:\n return TErr(\"No string group is merged\")\n\n # Build the final line ('new_line') that this method will later return.\n new_line = line.clone()\n previous_merged_string_idx = -1\n previous_merged_num_of_strings = -1\n for i, leaf in enumerate(LL):\n if i in merged_string_idx_dict:\n previous_merged_string_idx = i\n previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]\n new_line.append(string_leaf)\n\n if (\n previous_merged_string_idx\n <= i\n < previous_merged_string_idx + previous_merged_num_of_strings\n ):\n for comment_leaf in line.comments_after(LL[i]):\n new_line.append(comment_leaf, preformatted=True)\n continue\n\n append_leaves(new_line, line, [leaf])\n\n return Ok(new_line)\n\n def _merge_one_string_group(\n self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool]\n ) -> Tuple[int, Leaf]:\n \"\"\"\n Merges one string group where the first string in the group is\n `LL[string_idx]`.\n\n Returns:\n A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the\n number of strings merged and `leaf` is the newly merged string\n to be replaced in the new line.\n \"\"\"\n # If the string group is wrapped inside an Atom node, we must make sure\n # to later replace that Atom with our new (merged) string leaf.\n atom_node = LL[string_idx].parent\n\n # We will place BREAK_MARK in between every two substrings that we\n # merge. We will then later go through our final result and use the\n # various instances of BREAK_MARK we find to add the right values to\n # the custom split map.\n BREAK_MARK = \"@@@@@ BLACK BREAKPOINT MARKER @@@@@\"\n\n QUOTE = LL[string_idx].value[-1]\n\n def make_naked(string: str, string_prefix: str) -> str:\n \"\"\"Strip @string (i.e. 
make it a \"naked\" string)\n\n Pre-conditions:\n * assert_is_leaf_string(@string)\n\n Returns:\n A string that is identical to @string except that\n @string_prefix has been stripped, the surrounding QUOTE\n characters have been removed, and any remaining QUOTE\n characters have been escaped.\n \"\"\"\n assert_is_leaf_string(string)\n if \"f\" in string_prefix:\n f_expressions = (\n string[span[0] + 1 : span[1] - 1] # +-1 to get rid of curly braces\n for span in iter_fexpr_spans(string)\n )\n debug_expressions_contain_visible_quotes = any(\n re.search(r\".*[\\'\\\"].*(?= 0\n ), \"Logic error while filling the custom string breakpoint cache.\"\n\n temp_string = temp_string[mark_idx + len(BREAK_MARK) :]\n breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1\n custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))\n\n string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, \"\"))\n\n if atom_node is not None:\n # If not all children of the atom node are merged (this can happen\n # when there is a standalone comment in the middle) ...\n if non_string_idx - string_idx < len(atom_node.children):\n # We need to replace the old STRING leaves with the new string leaf.\n first_child_idx = LL[string_idx].remove()\n for idx in range(string_idx + 1, non_string_idx):\n LL[idx].remove()\n if first_child_idx is not None:\n atom_node.insert_child(first_child_idx, string_leaf)\n else:\n # Else replace the atom node with the new string leaf.\n replace_child(atom_node, string_leaf)\n\n self.add_custom_splits(string_leaf.value, custom_splits)\n return num_of_strings, string_leaf\n\n @staticmethod\n def _validate_msg(line: Line, string_idx: int) -> TResult[None]:\n \"\"\"Validate (M)erge (S)tring (G)roup\n\n Transform-time string validation logic for _merge_string_group(...).\n\n Returns:\n * Ok(None), if ALL validation checks (listed below) pass.\n OR\n * Err(CannotTransform), if any of the following are true:\n - The target string group does not contain ANY stand-alone comments.\n - The target string is not in a string group (i.e. 
it has no\n adjacent strings).\n - The string group has more than one inline comment.\n...\n# Path: src/black/linegen.py\n\"\"\"\nGenerating lines of code.\n\"\"\"\n\nimport re\nimport sys\nfrom dataclasses import replace\nfrom enum import Enum, auto\nfrom functools import partial, wraps\nfrom typing import Collection, Iterator, List, Optional, Set, Union, cast\n\nfrom black.brackets import (\n COMMA_PRIORITY,\n DOT_PRIORITY,\n STRING_PRIORITY,\n get_leaves_inside_matching_brackets,\n max_delimiter_priority_in_atom,\n)\nfrom black.comments import FMT_OFF, generate_comments, list_comments\nfrom black.lines import (\n Line,\n RHSResult,\n append_leaves,\n can_be_split,\n can_omit_invisible_parens,\n is_line_short_enough,\n line_to_string,\n)\nfrom black.mode import Feature, Mode, Preview\nfrom black.nodes import (\n ASSIGNMENTS,\n BRACKETS,\n CLOSING_BRACKETS,\n OPENING_BRACKETS,\n STANDALONE_COMMENT,\n STATEMENT,\n WHITESPACE,\n Visitor,\n ensure_visible,\n get_annotation_type,\n is_arith_like,\n is_async_stmt_or_funcdef,\n is_atom_with_invisible_parens,\n is_docstring,\n is_empty_tuple,\n is_lpar_token,\n is_multiline_string,\n is_name_token,\n is_one_sequence_between,\n is_one_tuple,\n is_parent_function_or_class,\n is_part_of_annotation,\n is_rpar_token,\n is_stub_body,\n is_stub_suite,\n is_tuple_containing_walrus,\n is_type_ignore_comment_string,\n is_vararg,\n is_walrus_assignment,\n is_yield,\n syms,\n wrap_in_parentheses,\n)\nfrom black.numerics import normalize_numeric_literal\nfrom black.strings import (\n fix_docstring,\n get_string_prefix,\n normalize_string_prefix,\n normalize_string_quotes,\n normalize_unicode_escape_sequences,\n)\nfrom black.trans import (\n CannotTransform,\n StringMerger,\n StringParenStripper,\n StringParenWrapper,\n StringSplitter,\n Transformer,\n hug_power_op,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n# types\nLeafID = int\nLN = Union[Leaf, Node]\n\n\nclass CannotSplit(CannotTransform):\n \"\"\"A readable split that fits the allotted line length is impossible.\"\"\"\n\n\n# This isn't a dataclass because @dataclass + Generic breaks mypyc.\n# See also https://github.com/mypyc/mypyc/issues/827.\nclass LineGenerator(Visitor[Line]):\n \"\"\"Generates reformatted Line objects. Empty lines are not emitted.\n\n Note: destroys the tree it's visiting by mutating prefixes of its leaves\n in ways that will no longer stringify to valid Python code on the tree.\n \"\"\"\n\n def __init__(self, mode: Mode, features: Collection[Feature]) -> None:\n self.mode = mode\n self.features = features\n self.current_line: Line\n self.__post_init__()\n\n def line(self, indent: int = 0) -> Iterator[Line]:\n \"\"\"Generate a line.\n\n If the line is empty, only emit if it makes sense.\n If the line is too long, split it first and then generate.\n\n If any lines were generated, set up a new current_line.\n \"\"\"\n if not self.current_line:\n self.current_line.depth += indent\n return # Line is empty, don't emit. Creating a new one unnecessary.\n\n if len(self.current_line.leaves) == 1 and is_async_stmt_or_funcdef(\n self.current_line.leaves[0]\n ):\n # Special case for async def/for/with statements. `visit_async_stmt`\n # adds an `ASYNC` leaf then visits the child def/for/with statement\n # nodes. 
Line yields from those nodes shouldn't treat the former\n # `ASYNC` leaf as a complete line.\n return\n\n complete_line = self.current_line\n self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)\n yield complete_line\n\n def visit_default(self, node: LN) -> Iterator[Line]:\n \"\"\"Default `visit_*()` implementation. Recurses to children of `node`.\"\"\"\n if isinstance(node, Leaf):\n any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()\n for comment in generate_comments(node):\n if any_open_brackets:\n # any comment within brackets is subject to splitting\n self.current_line.append(comment)\n elif comment.type == token.COMMENT:\n # regular trailing comment\n self.current_line.append(comment)\n yield from self.line()\n\n else:\n # regular standalone comment\n yield from self.line()\n\n self.current_line.append(comment)\n yield from self.line()\n\n if any_open_brackets:\n node.prefix = \"\"\n if self.mode.string_normalization and node.type == token.STRING:\n node.value = normalize_string_prefix(node.value)\n node.value = normalize_string_quotes(node.value)\n if node.type == token.NUMBER:\n normalize_numeric_literal(node)\n if node.type not in WHITESPACE:\n self.current_line.append(node)\n yield from super().visit_default(node)\n\n def visit_test(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit an `x if y else z` test\"\"\"\n\n already_parenthesized = (\n node.prev_sibling and node.prev_sibling.type == token.LPAR\n )\n\n if not already_parenthesized:\n # Similar to logic in wrap_in_parentheses\n lpar = Leaf(token.LPAR, \"\")\n rpar = Leaf(token.RPAR, \"\")\n prefix = node.prefix\n node.prefix = \"\"\n lpar.prefix = prefix\n node.insert_child(0, lpar)\n node.append_child(rpar)\n\n yield from self.visit_default(node)\n\n def visit_INDENT(self, node: Leaf) -> Iterator[Line]:\n \"\"\"Increase indentation level, maybe yield a line.\"\"\"\n # In blib2to3 INDENT never holds comments.\n yield from self.line(+1)\n yield from self.visit_default(node)\n\n def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:\n \"\"\"Decrease indentation level, maybe yield a line.\"\"\"\n # The current line might still wait for trailing comments. At DEDENT time\n # there won't be any (they would be prefixes on the preceding NEWLINE).\n # Emit the line then.\n yield from self.line()\n\n # While DEDENT has no value, its prefix may contain standalone comments\n # that belong to the current indentation level. Get 'em.\n yield from self.visit_default(node)\n\n # Finally, emit the dedent.\n yield from self.line(-1)\n\n def visit_stmt(\n self, node: Node, keywords: Set[str], parens: Set[str]\n ) -> Iterator[Line]:\n \"\"\"Visit a statement.\n\n This implementation is shared for `if`, `while`, `for`, `try`, `except`,\n `def`, `with`, `class`, `assert`, and assignments.\n\n The relevant Python language `keywords` for a given statement will be\n NAME leaves within it. 
This methods puts those on a separate line.\n\n `parens` holds a set of string leaf values immediately after which\n invisible parens should be put.\n \"\"\"\n normalize_invisible_parens(\n node, parens_after=parens, mode=self.mode, features=self.features\n )\n for child in node.children:\n if is_name_token(child) and child.value in keywords:\n yield from self.line()\n\n yield from self.visit(child)\n\n def visit_typeparams(self, node: Node) -> Iterator[Line]:\n yield from self.visit_default(node)\n node.children[0].prefix = \"\"\n\n def visit_typevartuple(self, node: Node) -> Iterator[Line]:\n yield from self.visit_default(node)\n node.children[1].prefix = \"\"\n\n def visit_paramspec(self, node: Node) -> Iterator[Line]:\n yield from self.visit_default(node)\n node.children[1].prefix = \"\"\n\n def visit_dictsetmaker(self, node: Node) -> Iterator[Line]:\n if Preview.wrap_long_dict_values_in_parens in self.mode:\n for i, child in enumerate(node.children):\n if i == 0:\n continue\n if node.children[i - 1].type == token.COLON:\n if (\n child.type == syms.atom\n and child.children[0].type in OPENING_BRACKETS\n and not is_walrus_assignment(child)\n ):\n maybe_make_parens_invisible_in_atom(\n child,\n parent=node,\n remove_brackets_around_comma=False,\n )\n else:\n wrap_in_parentheses(node, child, visible=False)\n yield from self.visit_default(node)\n\n def visit_funcdef(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit function definition.\"\"\"\n yield from self.line()\n\n # Remove redundant brackets around return type annotation.\n is_return_annotation = False\n for child in node.children:\n if child.type == token.RARROW:\n is_return_annotation = True\n elif is_return_annotation:\n if child.type == syms.atom and child.children[0].type == token.LPAR:\n if maybe_make_parens_invisible_in_atom(\n child,\n parent=node,\n remove_brackets_around_comma=False,\n ):\n wrap_in_parentheses(node, child, visible=False)\n else:\n wrap_in_parentheses(node, child, visible=False)\n is_return_annotation = False\n\n for child in node.children:\n yield from self.visit(child)\n\n def visit_match_case(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit either a match or case statement.\"\"\"\n normalize_invisible_parens(\n node, parens_after=set(), mode=self.mode, features=self.features\n )\n\n yield from self.line()\n for child in node.children:\n yield from self.visit(child)\n\n def visit_suite(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit a suite.\"\"\"\n if is_stub_suite(node):\n yield from self.visit(node.children[2])\n else:\n yield from self.visit_default(node)\n\n def visit_simple_stmt(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit a statement without nested statements.\"\"\"\n prev_type: Optional[int] = None\n for child in node.children:\n if (prev_type is None or prev_type == token.SEMI) and is_arith_like(child):\n wrap_in_parentheses(node, child, visible=False)\n prev_type = child.type\n\n if node.parent and node.parent.type in STATEMENT:\n if is_parent_function_or_class(node) and is_stub_body(node):\n yield from self.visit_default(node)\n else:\n yield from self.line(+1)\n yield from self.visit_default(node)\n yield from self.line(-1)\n\n else:\n if not node.parent or not is_stub_suite(node.parent):\n yield from self.line()\n yield from self.visit_default(node)\n\n def visit_async_stmt(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit `async def`, `async for`, `async with`.\"\"\"\n yield from self.line()\n\n children = iter(node.children)\n for child in children:\n yield from 
self.visit(child)\n\n if child.type == token.ASYNC or child.type == STANDALONE_COMMENT:\n # STANDALONE_COMMENT happens when `# fmt: skip` is applied on the async\n # line.\n break\n\n internal_stmt = next(children)\n yield from self.visit(internal_stmt)\n\n def visit_decorators(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit decorators.\"\"\"\n for child in node.children:\n yield from self.line()\n yield from self.visit(child)\n\n def visit_power(self, node: Node) -> Iterator[Line]:\n for idx, leaf in enumerate(node.children[:-1]):\n next_leaf = node.children[idx + 1]\n\n if not isinstance(leaf, Leaf):\n continue\n\n value = leaf.value.lower()\n if (\n leaf.type == token.NUMBER\n and next_leaf.type == syms.trailer\n # Ensure that we are in an attribute trailer\n and next_leaf.children[0].type == token.DOT\n # It shouldn't wrap hexadecimal, binary and octal literals\n and not value.startswith((\"0x\", \"0b\", \"0o\"))\n # It shouldn't wrap complex literals\n and \"j\" not in value\n ):\n wrap_in_parentheses(node, leaf)\n\n remove_await_parens(node)\n\n yield from self.visit_default(node)\n\n def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:\n \"\"\"Remove a semicolon and put the other statement on a separate line.\"\"\"\n yield from self.line()\n\n def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:\n \"\"\"End of file. Process outstanding comments and end with a newline.\"\"\"\n yield from self.visit_default(leaf)\n yield from self.line()\n\n def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:\n if not self.current_line.bracket_tracker.any_open_brackets():\n yield from self.line()\n yield from self.visit_default(leaf)\n\n def visit_factor(self, node: Node) -> Iterator[Line]:\n \"\"\"Force parentheses between a unary op and a binary power:\n\n -2 ** 8 -> -(2 ** 8)\n \"\"\"\n _operator, operand = node.children\n if (\n operand.type == syms.power\n and len(operand.children) == 3\n and operand.children[1].type == token.DOUBLESTAR\n ):\n lpar = Leaf(token.LPAR, \"(\")\n rpar = Leaf(token.RPAR, \")\")\n index = operand.remove() or 0\n node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))\n yield from self.visit_default(node)\n\n def visit_tname(self, node: Node) -> Iterator[Line]:\n \"\"\"\n Add potential parentheses around types in function parameter lists to be made\n into real parentheses in case the type hint is too long to fit on a line\n Examples:\n def foo(a: int, b: float = 7): ...\n\n ->\n\n def foo(a: (int), b: (float) = 7): ...\n \"\"\"\n assert len(node.children) == 3\n if maybe_make_parens_invisible_in_atom(node.children[2], parent=node):\n wrap_in_parentheses(node, node.children[2], visible=False)\n\n yield from self.visit_default(node)\n\n def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:\n if Preview.hex_codes_in_unicode_sequences in self.mode:\n normalize_unicode_escape_sequences(leaf)\n\n if is_docstring(leaf, self.mode) and not re.search(r\"\\\\\\s*\\n\", leaf.value):\n # We're ignoring docstrings with backslash newline escapes because changing\n # indentation of those changes the AST representation of the code.\n if self.mode.string_normalization:\n docstring = normalize_string_prefix(leaf.value)\n # visit_default() does handle string normalization for us, but\n # since this method acts differently depending on quote style (ex.\n # see padding logic below), there's a possibility for unstable\n # formatting as visit_default() is called *after*. 
To avoid a\n # situation where this function formats a docstring differently on\n # the second pass, normalize it early.\n docstring = normalize_string_quotes(docstring)\n else:\n docstring = leaf.value\n prefix = get_string_prefix(docstring)\n docstring = docstring[len(prefix) :] # Remove the prefix\n quote_char = docstring[0]\n # A natural way to remove the outer quotes is to do:\n # docstring = docstring.strip(quote_char)\n # but that breaks on \"\"\"\"\"x\"\"\" (which is '\"\"x').\n # So we actually need to remove the first character and the next two\n # characters but only if they are the same as the first.\n quote_len = 1 if docstring[1] != quote_char else 3\n docstring = docstring[quote_len:-quote_len]\n docstring_started_empty = not docstring\n indent = \" \" * 4 * self.current_line.depth\n\n if is_multiline_string(leaf):\n docstring = fix_docstring(docstring, indent)\n else:\n docstring = docstring.strip()\n\n has_trailing_backslash = False\n if docstring:\n # Add some padding if the docstring starts / ends with a quote mark.\n if docstring[0] == quote_char:\n docstring = \" \" + docstring\n if docstring[-1] == quote_char:\n docstring += \" \"\n if docstring[-1] == \"\\\\\":\n backslash_count = len(docstring) - len(docstring.rstrip(\"\\\\\"))\n if backslash_count % 2:\n # Odd number of tailing backslashes, add some padding to\n # avoid escaping the closing string quote.\n docstring += \" \"\n has_trailing_backslash = True\n elif not docstring_started_empty:\n docstring = \" \"\n\n # We could enforce triple quotes at this point.\n quote = quote_char * quote_len\n\n # It's invalid to put closing single-character quotes on a new line.\n if quote_len == 3:\n # We need to find the length of the last line of the docstring\n # to find if we can add the closing quotes to the line without\n # exceeding the maximum line length.\n # If docstring is one line, we don't put the closing quotes on a\n # separate line because it looks ugly (#3320).\n lines = docstring.splitlines()\n last_line_length = len(lines[-1]) if docstring else 0\n\n # If adding closing quotes would cause the last line to exceed\n # the maximum line length, and the closing quote is not\n # prefixed by a newline then put a line break before\n # the closing quotes\n if (\n len(lines) > 1\n and last_line_length + quote_len > self.mode.line_length\n and len(indent) + quote_len <= self.mode.line_length\n and not has_trailing_backslash\n ):\n if (\n Preview.docstring_check_for_newline in self.mode\n and leaf.value[-1 - quote_len] == \"\\n\"\n ):\n leaf.value = prefix + quote + docstring + quote\n else:\n leaf.value = prefix + quote + docstring + \"\\n\" + indent + quote\n else:\n leaf.value = prefix + quote + docstring + quote\n else:\n leaf.value = prefix + quote + docstring + quote\n\n yield from self.visit_default(leaf)\n\n \ndef __post_init__(self) -> None:\n \"\"\"You are in a twisty little maze of passages.\"\"\"\n self.current_line = Line(mode=self.mode)\n\n v = self.visit_stmt\n \u00d8: Set[str] = set()\n self.visit_assert_stmt = partial(v, keywords={\"assert\"}, parens={\"assert\", \",\"})\n self.visit_if_stmt = partial(\n v, keywords={\"if\", \"else\", \"elif\"}, parens={\"if\", \"elif\"}\n )\n self.visit_while_stmt = partial(v, keywords={\"while\", \"else\"}, parens={\"while\"})\n self.visit_for_stmt = partial(v, keywords={\"for\", \"else\"}, parens={\"for\", \"in\"})\n self.visit_try_stmt = partial(\n v, keywords={\"try\", \"except\", \"else\", \"finally\"}, parens=\u00d8\n )\n self.visit_except_clause = partial(v, 
keywords={\"except\"}, parens={\"except\"})\n self.visit_with_stmt = partial(v, keywords={\"with\"}, parens={\"with\"})\n self.visit_classdef = partial(v, keywords={\"class\"}, parens=\u00d8)\n\n self.visit_expr_stmt = partial(v, keywords=\u00d8, parens=ASSIGNMENTS)\n self.visit_return_stmt = partial(v, keywords={\"return\"}, parens={\"return\"})\n self.visit_import_from = partial(v, keywords=\u00d8, parens={\"import\"})\n self.visit_del_stmt = partial(v, keywords=\u00d8, parens={\"del\"})\n self.visit_async_funcdef = self.visit_async_stmt\n self.visit_decorated = self.visit_decorators\n\n # PEP 634\n self.visit_match_stmt = self.visit_match_case\n self.visit_case_block = self.visit_match_case\n if Preview.remove_redundant_guard_parens in self.mode:\n self.visit_guard = partial(v, keywords=\u00d8, parens={\"if\"})\n\n\ndef _hugging_power_ops_line_to_string(\n line: Line,\n features: Collection[Feature],\n mode: Mode,\n) -> Optional[str]:\n try:\n return line_to_string(next(hug_power_op(line, features, mode)))\n except CannotTransform:\n return None\n\n\ndef transform_line(\n line: Line, mode: Mode, features: Collection[Feature] = ()\n) -> Iterator[Line]:\n \"\"\"Transform a `line`, potentially splitting it into many lines.\n\n They should fit in the allotted `line_length` but might not be able to.\n\n `features` are syntactical features that may be used in the output.\n \"\"\"\n if line.is_comment:\n yield line\n return\n\n line_str = line_to_string(line)\n\n # We need the line string when power operators are hugging to determine if we should\n # split the line. Default to line_str, if no power operator are present on the line.\n line_str_hugging_power_ops = (\n _hugging_power_ops_line_to_string(line, features, mode) or line_str\n )\n\n ll = mode.line_length\n sn = mode.string_normalization\n string_merge = StringMerger(ll, sn)\n string_paren_strip = StringParenStripper(ll, sn)\n string_split = StringSplitter(ll, sn)\n string_paren_wrap = StringParenWrapper(ll, sn)\n\n transformers: List[Transformer]\n if (\n not line.contains_uncollapsable_type_comments()\n and not line.should_split_rhs\n and not line.magic_trailing_comma\n and (\n is_line_short_enough(line, mode=mode, line_str=line_str_hugging_power_ops)\n or line.contains_unsplittable_type_ignore()\n )\n and not (line.inside_brackets and line.contains_standalone_comments())\n and not line.contains_implicit_multiline_string_with_comments()\n ):\n # Only apply basic string preprocessing, since lines shouldn't be split here.\n if Preview.string_processing in mode:\n transformers = [string_merge, string_paren_strip]\n else:\n transformers = []\n elif line.is_def and not should_split_funcdef_with_rhs(line, mode):\n transformers = [left_hand_split]\n else:\n\n def _rhs(\n self: object, line: Line, features: Collection[Feature], mode: Mode\n ) -> Iterator[Line]:\n \"\"\"Wraps calls to `right_hand_split`.\n\n The calls increasingly `omit` right-hand trailers (bracket pairs with\n content), meaning the trailers get glued together to split on another\n bracket pair instead.\n \"\"\"\n for omit in generate_trailers_to_omit(line, mode.line_length):\n lines = list(right_hand_split(line, mode, features, omit=omit))\n # Note: this check is only able to figure out if the first line of the\n # *current* transformation fits in the line length. This is true only\n # for simple cases. All others require running more transforms via\n # `transform_line()`. 
This check doesn't know if those would succeed.\n if is_line_short_enough(lines[0], mode=mode):\n yield from lines\n return\n\n # All splits failed, best effort split with no omits.\n # This mostly happens to multiline strings that are by definition\n # reported as not fitting a single line, as well as lines that contain\n # trailing commas (those have to be exploded).\n yield from right_hand_split(line, mode, features=features)\n\n # HACK: nested functions (like _rhs) compiled by mypyc don't retain their\n # __name__ attribute which is needed in `run_transformer` further down.\n # Unfortunately a nested class breaks mypyc too. So a class must be created\n # via type ... https://github.com/mypyc/mypyc/issues/884\n rhs = type(\"rhs\", (), {\"__call__\": _rhs})()\n\n if Preview.string_processing in mode:\n if line.inside_brackets:\n transformers = [\n string_merge,\n string_paren_strip,\n string_split,\n delimiter_split,\n standalone_comment_split,\n string_paren_wrap,\n rhs,\n ]\n else:\n transformers = [\n string_merge,\n string_paren_strip,\n string_split,\n string_paren_wrap,\n rhs,\n ]\n else:\n if line.inside_brackets:\n transformers = [delimiter_split, standalone_comment_split, rhs]\n else:\n transformers = [rhs]\n # It's always safe to attempt hugging of power operations and pretty much every line\n # could match.\n transformers.append(hug_power_op)\n\n for transform in transformers:\n # We are accumulating lines in `result` because we might want to abort\n # mission and return the original line in the end, or attempt a different\n # split altogether.\n try:\n result = run_transformer(line, transform, mode, features, line_str=line_str)\n except CannotTransform:\n continue\n else:\n yield from result\n break\n\n else:\n yield line\n\n\ndef should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:\n \"\"\"If a funcdef has a magic trailing comma in the return type, then we should first\n split the line with rhs to respect the comma.\n \"\"\"\n return_type_leaves: List[Leaf] = []\n in_return_type = False\n\n for leaf in line.leaves:\n if leaf.type == token.COLON:\n in_return_type = False\n if in_return_type:\n return_type_leaves.append(leaf)\n if leaf.type == token.RARROW:\n in_return_type = True\n\n # using `bracket_split_build_line` will mess with whitespace, so we duplicate a\n # couple lines from it.\n result = Line(mode=line.mode, depth=line.depth)\n leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)\n for leaf in return_type_leaves:\n result.append(\n leaf,\n preformatted=True,\n track_bracket=id(leaf) in leaves_to_track,\n )\n\n # we could also return true if the line is too long, and the return type is longer\n # than the param list. Or if `should_split_rhs` returns True.\n return result.magic_trailing_comma is not None\n\n\nclass _BracketSplitComponent(Enum):\n head = auto()\n body = auto()\n tail = auto()\n\n\ndef left_hand_split(\n line: Line, _features: Collection[Feature], mode: Mode\n) -> Iterator[Line]:\n \"\"\"Split line into many lines, starting with the first matching bracket pair.\n\n Note: this usually looks weird, only use this for function definitions.\n Prefer RHS otherwise. 
This is why this function is not symmetrical with\n :func:`right_hand_split` which also handles optional parentheses.\n \"\"\"\n tail_leaves: List[Leaf] = []\n body_leaves: List[Leaf] = []\n head_leaves: List[Leaf] = []\n current_leaves = head_leaves\n matching_bracket: Optional[Leaf] = None\n for leaf in line.leaves:\n if (\n current_leaves is body_leaves\n and leaf.type in CLOSING_BRACKETS\n and leaf.opening_bracket is matching_bracket\n and isinstance(matching_bracket, Leaf)\n ):\n ensure_visible(leaf)\n ensure_visible(matching_bracket)\n current_leaves = tail_leaves if body_leaves else head_leaves\n current_leaves.append(leaf)\n if current_leaves is head_leaves:\n if leaf.type in OPENING_BRACKETS:\n matching_bracket = leaf\n current_leaves = body_leaves\n...\n", "needle_token_start": 13424, "needle_token_end": 13982, "code_context_ntokens": 16131, "output": ["```python\ndef __post_init__(self) -> None:\n \"\"\"You are in a twisty little maze of passages.\"\"\"\n self.current_line = Line(mode=self.mode)\n\n v = self.visit_stmt\n \u00d8: Set[str] = set()\n self.visit_assert_stmt = partial(v, keywords={\"assert\"}, parens={\"assert\", \",\"})\n self.visit_if_stmt = partial(\n v, keywords={\"if\", \"else\", \"elif\"}, parens={\"if\", \"elif\"}\n )\n self.visit_while_stmt = partial(v, keywords={\"while\", \"else\"}, parens={\"while\"})\n self.visit_for_stmt = partial(v, keywords={\"for\", \"else\"}, parens={\"for\", \"in\"})\n self.visit_try_stmt = partial(\n v, keywords={\"try\", \"except\", \"else\", \"finally\"}, parens=\u00d8\n )\n self.visit_except_clause = partial(v, keywords={\"except\"}, parens={\"except\"})\n self.visit_with_stmt = partial(v, keywords={\"with\"}, parens={\"with\"})\n self.visit_classdef = partial(v, keywords={\"class\"}, parens=\u00d8)\n\n self.visit_expr_stmt = partial(v, keywords=\u00d8, parens=ASSIGNMENTS)\n self.visit_return_stmt = partial(v, keywords={\"return\"}, parens={\"return\"})\n self.visit_import_from = partial(v, keywords=\u00d8, parens={\"import\"})\n self.visit_del_stmt = partial(v, keywords=\u00d8, parens={\"del\"})\n self.visit_async_funcdef = self.visit_async_stmt\n self.visit_decorated = self.visit_decorators\n\n # PEP 634\n self.visit_match_stmt = self.visit_match_case\n self.visit_case_block = self.visit_match_case\n if Preview.remove_redundant_guard_parens in self.mode:\n self.visit_guard = partial(v, keywords=\u00d8, parens={\"if\"})\n```"]} +{"repo": "psf/black", "name": "should_split_funcdef_with_rhs", "language": "python", "path": "src/black/linegen.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To determine if a function definition line should be split based on the presence of a specific punctuation mark in the return type.\n2. **Input**: A line of code and a formatting mode.\n3. **Output**: A boolean value indicating whether the line should be split.\n4. **Procedure**: The function scans through the tokens of the line to identify the section that specifies the return type. It then checks if this section contains a specific punctuation mark that necessitates splitting the line. 
If found, the function returns true, indicating that the line should be split to properly format the return type according to the style rules.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/black/trans.py\n\"\"\"\nString transformers that can split and merge strings.\n\"\"\"\n\nimport re\nfrom abc import ABC, abstractmethod\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import (\n Any,\n Callable,\n ClassVar,\n Collection,\n Dict,\n Final,\n Iterable,\n Iterator,\n List,\n Literal,\n Optional,\n Sequence,\n Set,\n Tuple,\n TypeVar,\n Union,\n)\n\nfrom mypy_extensions import trait\n\nfrom black.comments import contains_pragma_comment\nfrom black.lines import Line, append_leaves\nfrom black.mode import Feature, Mode, Preview\nfrom black.nodes import (\n CLOSING_BRACKETS,\n OPENING_BRACKETS,\n STANDALONE_COMMENT,\n is_empty_lpar,\n is_empty_par,\n is_empty_rpar,\n is_part_of_annotation,\n parent_type,\n replace_child,\n syms,\n)\nfrom black.rusty import Err, Ok, Result\nfrom black.strings import (\n assert_is_leaf_string,\n count_chars_in_width,\n get_string_prefix,\n has_triple_quotes,\n normalize_string_quotes,\n str_width,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n\nclass CannotTransform(Exception):\n \"\"\"Base class for errors raised by Transformers.\"\"\"\n\n\n# types\nT = TypeVar(\"T\")\nLN = Union[Leaf, Node]\nTransformer = Callable[[Line, Collection[Feature], Mode], Iterator[Line]]\nIndex = int\nNodeType = int\nParserState = int\nStringID = int\nTResult = Result[T, CannotTransform] # (T)ransform Result\nTMatchResult = TResult[List[Index]]\n\nSPLIT_SAFE_CHARS = frozenset([\"\\u3001\", \"\\u3002\", \"\\uff0c\"]) # East Asian stops\n\n\ndef TErr(err_msg: str) -> Err[CannotTransform]:\n \"\"\"(T)ransform Err\n\n Convenience function used when working with the TResult type.\n \"\"\"\n cant_transform = CannotTransform(err_msg)\n return Err(cant_transform)\n\n\ndef hug_power_op(\n line: Line, features: Collection[Feature], mode: Mode\n) -> Iterator[Line]:\n \"\"\"A transformer which normalizes spacing around power operators.\"\"\"\n\n # Performance optimization to avoid unnecessary Leaf clones and other ops.\n for leaf in line.leaves:\n if leaf.type == token.DOUBLESTAR:\n break\n else:\n raise CannotTransform(\"No doublestar token was found in the line.\")\n\n def is_simple_lookup(index: int, kind: Literal[1, -1]) -> bool:\n # Brackets and parentheses indicate calls, subscripts, etc. ...\n # basically stuff that doesn't count as \"simple\". Only a NAME lookup\n # or dotted lookup (eg. 
NAME.NAME) is OK.\n if Preview.is_simple_lookup_for_doublestar_expression not in mode:\n return original_is_simple_lookup_func(line, index, kind)\n\n else:\n if kind == -1:\n return handle_is_simple_look_up_prev(\n line, index, {token.RPAR, token.RSQB}\n )\n else:\n return handle_is_simple_lookup_forward(\n line, index, {token.LPAR, token.LSQB}\n )\n\n def is_simple_operand(index: int, kind: Literal[1, -1]) -> bool:\n # An operand is considered \"simple\" if's a NAME, a numeric CONSTANT, a simple\n # lookup (see above), with or without a preceding unary operator.\n start = line.leaves[index]\n if start.type in {token.NAME, token.NUMBER}:\n return is_simple_lookup(index, kind)\n\n if start.type in {token.PLUS, token.MINUS, token.TILDE}:\n if line.leaves[index + 1].type in {token.NAME, token.NUMBER}:\n # kind is always one as bases with a preceding unary op will be checked\n # for simplicity starting from the next token (so it'll hit the check\n # above).\n return is_simple_lookup(index + 1, kind=1)\n\n return False\n\n new_line = line.clone()\n should_hug = False\n for idx, leaf in enumerate(line.leaves):\n new_leaf = leaf.clone()\n if should_hug:\n new_leaf.prefix = \"\"\n should_hug = False\n\n should_hug = (\n (0 < idx < len(line.leaves) - 1)\n and leaf.type == token.DOUBLESTAR\n and is_simple_operand(idx - 1, kind=-1)\n and line.leaves[idx - 1].value != \"lambda\"\n and is_simple_operand(idx + 1, kind=1)\n )\n if should_hug:\n new_leaf.prefix = \"\"\n\n # We have to be careful to make a new line properly:\n # - bracket related metadata must be maintained (handled by Line.append)\n # - comments need to copied over, updating the leaf IDs they're attached to\n new_line.append(new_leaf, preformatted=True)\n for comment_leaf in line.comments_after(leaf):\n new_line.append(comment_leaf, preformatted=True)\n\n yield new_line\n\n\ndef original_is_simple_lookup_func(\n line: Line, index: int, step: Literal[1, -1]\n) -> bool:\n if step == -1:\n disallowed = {token.RPAR, token.RSQB}\n else:\n disallowed = {token.LPAR, token.LSQB}\n\n while 0 <= index < len(line.leaves):\n current = line.leaves[index]\n if current.type in disallowed:\n return False\n if current.type not in {token.NAME, token.DOT} or current.value == \"for\":\n # If the current token isn't disallowed, we'll assume this is\n # simple as only the disallowed tokens are semantically\n # attached to this lookup expression we're checking. Also,\n # stop early if we hit the 'for' bit of a comprehension.\n return True\n\n index += step\n\n return True\n\n\ndef handle_is_simple_look_up_prev(line: Line, index: int, disallowed: Set[int]) -> bool:\n \"\"\"\n Handling the determination of is_simple_lookup for the lines prior to the doublestar\n token. 
This is required because of the need to isolate the chained expression\n to determine the bracket or parenthesis belong to the single expression.\n \"\"\"\n contains_disallowed = False\n chain = []\n\n while 0 <= index < len(line.leaves):\n current = line.leaves[index]\n chain.append(current)\n if not contains_disallowed and current.type in disallowed:\n contains_disallowed = True\n if not is_expression_chained(chain):\n return not contains_disallowed\n\n index -= 1\n\n return True\n\n\ndef handle_is_simple_lookup_forward(\n line: Line, index: int, disallowed: Set[int]\n) -> bool:\n \"\"\"\n Handling decision is_simple_lookup for the lines behind the doublestar token.\n This function is simplified to keep consistent with the prior logic and the forward\n case are more straightforward and do not need to care about chained expressions.\n \"\"\"\n while 0 <= index < len(line.leaves):\n current = line.leaves[index]\n if current.type in disallowed:\n return False\n if current.type not in {token.NAME, token.DOT} or (\n current.type == token.NAME and current.value == \"for\"\n ):\n # If the current token isn't disallowed, we'll assume this is simple as\n # only the disallowed tokens are semantically attached to this lookup\n # expression we're checking. Also, stop early if we hit the 'for' bit\n # of a comprehension.\n return True\n\n index += 1\n\n return True\n\n\ndef is_expression_chained(chained_leaves: List[Leaf]) -> bool:\n \"\"\"\n Function to determine if the variable is a chained call.\n (e.g., foo.lookup, foo().lookup, (foo.lookup())) will be recognized as chained call)\n \"\"\"\n if len(chained_leaves) < 2:\n return True\n\n current_leaf = chained_leaves[-1]\n past_leaf = chained_leaves[-2]\n\n if past_leaf.type == token.NAME:\n return current_leaf.type in {token.DOT}\n elif past_leaf.type in {token.RPAR, token.RSQB}:\n return current_leaf.type in {token.RSQB, token.RPAR}\n elif past_leaf.type in {token.LPAR, token.LSQB}:\n return current_leaf.type in {token.NAME, token.LPAR, token.LSQB}\n else:\n return False\n\n\nclass StringTransformer(ABC):\n \"\"\"\n An implementation of the Transformer protocol that relies on its\n subclasses overriding the template methods `do_match(...)` and\n `do_transform(...)`.\n\n This Transformer works exclusively on strings (for example, by merging\n or splitting them).\n\n The following sections can be found among the docstrings of each concrete\n StringTransformer subclass.\n\n Requirements:\n Which requirements must be met of the given Line for this\n StringTransformer to be applied?\n\n Transformations:\n If the given Line meets all of the above requirements, which string\n transformations can you expect to be applied to it by this\n StringTransformer?\n\n Collaborations:\n What contractual agreements does this StringTransformer have with other\n StringTransfomers? Such collaborations should be eliminated/minimized\n as much as possible.\n \"\"\"\n\n __name__: Final = \"StringTransformer\"\n\n # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with\n # `abc.ABC`.\n def __init__(self, line_length: int, normalize_strings: bool) -> None:\n self.line_length = line_length\n self.normalize_strings = normalize_strings\n\n @abstractmethod\n def do_match(self, line: Line) -> TMatchResult:\n \"\"\"\n Returns:\n * Ok(string_indices) such that for each index, `line.leaves[index]`\n is our target string if a match was able to be made. For\n transformers that don't result in more lines (e.g. 
StringMerger,\n StringParenStripper), multiple matches and transforms are done at\n once to reduce the complexity.\n OR\n * Err(CannotTransform), if no match could be made.\n \"\"\"\n\n @abstractmethod\n def do_transform(\n self, line: Line, string_indices: List[int]\n ) -> Iterator[TResult[Line]]:\n \"\"\"\n Yields:\n * Ok(new_line) where new_line is the new transformed line.\n OR\n * Err(CannotTransform) if the transformation failed for some reason. The\n `do_match(...)` template method should usually be used to reject\n the form of the given Line, but in some cases it is difficult to\n know whether or not a Line meets the StringTransformer's\n requirements until the transformation is already midway.\n\n Side Effects:\n This method should NOT mutate @line directly, but it MAY mutate the\n Line's underlying Node structure. (WARNING: If the underlying Node\n structure IS altered, then this method should NOT be allowed to\n yield an CannotTransform after that point.)\n \"\"\"\n\n def __call__(\n self, line: Line, _features: Collection[Feature], _mode: Mode\n ) -> Iterator[Line]:\n \"\"\"\n StringTransformer instances have a call signature that mirrors that of\n the Transformer type.\n\n Raises:\n CannotTransform(...) if the concrete StringTransformer class is unable\n to transform @line.\n \"\"\"\n # Optimization to avoid calling `self.do_match(...)` when the line does\n # not contain any string.\n if not any(leaf.type == token.STRING for leaf in line.leaves):\n raise CannotTransform(\"There are no strings in this line.\")\n\n match_result = self.do_match(line)\n\n if isinstance(match_result, Err):\n cant_transform = match_result.err()\n raise CannotTransform(\n f\"The string transformer {self.__class__.__name__} does not recognize\"\n \" this line as one that it can transform.\"\n ) from cant_transform\n\n string_indices = match_result.ok()\n\n for line_result in self.do_transform(line, string_indices):\n if isinstance(line_result, Err):\n cant_transform = line_result.err()\n raise CannotTransform(\n \"StringTransformer failed while attempting to transform string.\"\n ) from cant_transform\n line = line_result.ok()\n yield line\n\n\n@dataclass\nclass CustomSplit:\n \"\"\"A custom (i.e. 
manual) string split.\n\n A single CustomSplit instance represents a single substring.\n\n Examples:\n Consider the following string:\n ```\n \"Hi there friend.\"\n \" This is a custom\"\n f\" string {split}.\"\n ```\n\n This string will correspond to the following three CustomSplit instances:\n ```\n CustomSplit(False, 16)\n CustomSplit(False, 17)\n CustomSplit(True, 16)\n ```\n \"\"\"\n\n has_prefix: bool\n break_idx: int\n\n\n@trait\nclass CustomSplitMapMixin:\n \"\"\"\n This mixin class is used to map merged strings to a sequence of\n CustomSplits, which will then be used to re-split the strings iff none of\n the resultant substrings go over the configured max line length.\n \"\"\"\n\n _Key: ClassVar = Tuple[StringID, str]\n _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(\n tuple\n )\n\n @staticmethod\n def _get_key(string: str) -> \"CustomSplitMapMixin._Key\":\n \"\"\"\n Returns:\n A unique identifier that is used internally to map @string to a\n group of custom splits.\n \"\"\"\n return (id(string), string)\n\n def add_custom_splits(\n self, string: str, custom_splits: Iterable[CustomSplit]\n ) -> None:\n \"\"\"Custom Split Map Setter Method\n\n Side Effects:\n Adds a mapping from @string to the custom splits @custom_splits.\n \"\"\"\n key = self._get_key(string)\n self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits)\n\n def pop_custom_splits(self, string: str) -> List[CustomSplit]:\n \"\"\"Custom Split Map Getter Method\n\n Returns:\n * A list of the custom splits that are mapped to @string, if any\n exist.\n OR\n * [], otherwise.\n\n Side Effects:\n Deletes the mapping between @string and its associated custom\n splits (which are returned to the caller).\n \"\"\"\n key = self._get_key(string)\n\n custom_splits = self._CUSTOM_SPLIT_MAP[key]\n del self._CUSTOM_SPLIT_MAP[key]\n\n return list(custom_splits)\n\n def has_custom_splits(self, string: str) -> bool:\n \"\"\"\n Returns:\n True iff @string is associated with a set of custom splits.\n \"\"\"\n key = self._get_key(string)\n return key in self._CUSTOM_SPLIT_MAP\n\n\nclass StringMerger(StringTransformer, CustomSplitMapMixin):\n \"\"\"StringTransformer that merges strings together.\n\n Requirements:\n (A) The line contains adjacent strings such that ALL of the validation checks\n listed in StringMerger._validate_msg(...)'s docstring pass.\n OR\n (B) The line contains a string which uses line continuation backslashes.\n\n Transformations:\n Depending on which of the two requirements above where met, either:\n\n (A) The string group associated with the target string is merged.\n OR\n (B) All line-continuation backslashes are removed from the target string.\n\n Collaborations:\n StringMerger provides custom split information to StringSplitter.\n \"\"\"\n\n def do_match(self, line: Line) -> TMatchResult:\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n string_indices = []\n idx = 0\n while is_valid_index(idx):\n leaf = LL[idx]\n if (\n leaf.type == token.STRING\n and is_valid_index(idx + 1)\n and LL[idx + 1].type == token.STRING\n ):\n # Let's check if the string group contains an inline comment\n # If we have a comment inline, we don't merge the strings\n contains_comment = False\n i = idx\n while is_valid_index(i):\n if LL[i].type != token.STRING:\n break\n if line.comments_after(LL[i]):\n contains_comment = True\n break\n i += 1\n\n if not is_part_of_annotation(leaf) and not contains_comment:\n string_indices.append(idx)\n\n # Advance to the next non-STRING leaf.\n idx += 2\n while 
is_valid_index(idx) and LL[idx].type == token.STRING:\n idx += 1\n\n elif leaf.type == token.STRING and \"\\\\\\n\" in leaf.value:\n string_indices.append(idx)\n # Advance to the next non-STRING leaf.\n idx += 1\n while is_valid_index(idx) and LL[idx].type == token.STRING:\n idx += 1\n\n else:\n idx += 1\n\n if string_indices:\n return Ok(string_indices)\n else:\n return TErr(\"This line has no strings that need merging.\")\n\n def do_transform(\n self, line: Line, string_indices: List[int]\n ) -> Iterator[TResult[Line]]:\n new_line = line\n\n rblc_result = self._remove_backslash_line_continuation_chars(\n new_line, string_indices\n )\n if isinstance(rblc_result, Ok):\n new_line = rblc_result.ok()\n\n msg_result = self._merge_string_group(new_line, string_indices)\n if isinstance(msg_result, Ok):\n new_line = msg_result.ok()\n\n if isinstance(rblc_result, Err) and isinstance(msg_result, Err):\n msg_cant_transform = msg_result.err()\n rblc_cant_transform = rblc_result.err()\n cant_transform = CannotTransform(\n \"StringMerger failed to merge any strings in this line.\"\n )\n\n # Chain the errors together using `__cause__`.\n msg_cant_transform.__cause__ = rblc_cant_transform\n cant_transform.__cause__ = msg_cant_transform\n\n yield Err(cant_transform)\n else:\n yield Ok(new_line)\n\n @staticmethod\n def _remove_backslash_line_continuation_chars(\n line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merge strings that were split across multiple lines using\n line-continuation backslashes.\n\n Returns:\n Ok(new_line), if @line contains backslash line-continuation\n characters.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n indices_to_transform = []\n for string_idx in string_indices:\n string_leaf = LL[string_idx]\n if (\n string_leaf.type == token.STRING\n and \"\\\\\\n\" in string_leaf.value\n and not has_triple_quotes(string_leaf.value)\n ):\n indices_to_transform.append(string_idx)\n\n if not indices_to_transform:\n return TErr(\n \"Found no string leaves that contain backslash line continuation\"\n \" characters.\"\n )\n\n new_line = line.clone()\n new_line.comments = line.comments.copy()\n append_leaves(new_line, line, LL)\n\n for string_idx in indices_to_transform:\n new_string_leaf = new_line.leaves[string_idx]\n new_string_leaf.value = new_string_leaf.value.replace(\"\\\\\\n\", \"\")\n\n return Ok(new_line)\n\n def _merge_string_group(\n self, line: Line, string_indices: List[int]\n ) -> TResult[Line]:\n \"\"\"\n Merges string groups (i.e. set of adjacent strings).\n\n Each index from `string_indices` designates one string group's first\n leaf in `line.leaves`.\n\n Returns:\n Ok(new_line), if ALL of the validation checks found in\n _validate_msg(...) 
pass.\n OR\n Err(CannotTransform), otherwise.\n \"\"\"\n LL = line.leaves\n\n is_valid_index = is_valid_index_factory(LL)\n\n # A dict of {string_idx: tuple[num_of_strings, string_leaf]}.\n merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {}\n for string_idx in string_indices:\n vresult = self._validate_msg(line, string_idx)\n if isinstance(vresult, Err):\n continue\n merged_string_idx_dict[string_idx] = self._merge_one_string_group(\n LL, string_idx, is_valid_index\n )\n\n if not merged_string_idx_dict:\n return TErr(\"No string group is merged\")\n\n # Build the final line ('new_line') that this method will later return.\n new_line = line.clone()\n previous_merged_string_idx = -1\n previous_merged_num_of_strings = -1\n for i, leaf in enumerate(LL):\n if i in merged_string_idx_dict:\n previous_merged_string_idx = i\n previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i]\n new_line.append(string_leaf)\n\n if (\n previous_merged_string_idx\n <= i\n < previous_merged_string_idx + previous_merged_num_of_strings\n ):\n for comment_leaf in line.comments_after(LL[i]):\n new_line.append(comment_leaf, preformatted=True)\n continue\n\n append_leaves(new_line, line, [leaf])\n\n return Ok(new_line)\n\n def _merge_one_string_group(\n self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool]\n ) -> Tuple[int, Leaf]:\n \"\"\"\n Merges one string group where the first string in the group is\n `LL[string_idx]`.\n\n Returns:\n A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the\n number of strings merged and `leaf` is the newly merged string\n to be replaced in the new line.\n \"\"\"\n # If the string group is wrapped inside an Atom node, we must make sure\n # to later replace that Atom with our new (merged) string leaf.\n atom_node = LL[string_idx].parent\n\n # We will place BREAK_MARK in between every two substrings that we\n # merge. We will then later go through our final result and use the\n # various instances of BREAK_MARK we find to add the right values to\n # the custom split map.\n BREAK_MARK = \"@@@@@ BLACK BREAKPOINT MARKER @@@@@\"\n\n QUOTE = LL[string_idx].value[-1]\n\n def make_naked(string: str, string_prefix: str) -> str:\n \"\"\"Strip @string (i.e. 
make it a \"naked\" string)\n\n Pre-conditions:\n * assert_is_leaf_string(@string)\n\n Returns:\n A string that is identical to @string except that\n @string_prefix has been stripped, the surrounding QUOTE\n characters have been removed, and any remaining QUOTE\n characters have been escaped.\n \"\"\"\n assert_is_leaf_string(string)\n if \"f\" in string_prefix:\n f_expressions = (\n string[span[0] + 1 : span[1] - 1] # +-1 to get rid of curly braces\n for span in iter_fexpr_spans(string)\n )\n debug_expressions_contain_visible_quotes = any(\n re.search(r\".*[\\'\\\"].*(?= 0\n ), \"Logic error while filling the custom string breakpoint cache.\"\n\n temp_string = temp_string[mark_idx + len(BREAK_MARK) :]\n breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1\n custom_splits.append(CustomSplit(has_prefix, breakpoint_idx))\n\n string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, \"\"))\n\n if atom_node is not None:\n # If not all children of the atom node are merged (this can happen\n # when there is a standalone comment in the middle) ...\n if non_string_idx - string_idx < len(atom_node.children):\n # We need to replace the old STRING leaves with the new string leaf.\n first_child_idx = LL[string_idx].remove()\n for idx in range(string_idx + 1, non_string_idx):\n LL[idx].remove()\n if first_child_idx is not None:\n...\n# Path: src/black/linegen.py\n\"\"\"\nGenerating lines of code.\n\"\"\"\n\nimport re\nimport sys\nfrom dataclasses import replace\nfrom enum import Enum, auto\nfrom functools import partial, wraps\nfrom typing import Collection, Iterator, List, Optional, Set, Union, cast\n\nfrom black.brackets import (\n COMMA_PRIORITY,\n DOT_PRIORITY,\n STRING_PRIORITY,\n get_leaves_inside_matching_brackets,\n max_delimiter_priority_in_atom,\n)\nfrom black.comments import FMT_OFF, generate_comments, list_comments\nfrom black.lines import (\n Line,\n RHSResult,\n append_leaves,\n can_be_split,\n can_omit_invisible_parens,\n is_line_short_enough,\n line_to_string,\n)\nfrom black.mode import Feature, Mode, Preview\nfrom black.nodes import (\n ASSIGNMENTS,\n BRACKETS,\n CLOSING_BRACKETS,\n OPENING_BRACKETS,\n STANDALONE_COMMENT,\n STATEMENT,\n WHITESPACE,\n Visitor,\n ensure_visible,\n get_annotation_type,\n is_arith_like,\n is_async_stmt_or_funcdef,\n is_atom_with_invisible_parens,\n is_docstring,\n is_empty_tuple,\n is_lpar_token,\n is_multiline_string,\n is_name_token,\n is_one_sequence_between,\n is_one_tuple,\n is_parent_function_or_class,\n is_part_of_annotation,\n is_rpar_token,\n is_stub_body,\n is_stub_suite,\n is_tuple_containing_walrus,\n is_type_ignore_comment_string,\n is_vararg,\n is_walrus_assignment,\n is_yield,\n syms,\n wrap_in_parentheses,\n)\nfrom black.numerics import normalize_numeric_literal\nfrom black.strings import (\n fix_docstring,\n get_string_prefix,\n normalize_string_prefix,\n normalize_string_quotes,\n normalize_unicode_escape_sequences,\n)\nfrom black.trans import (\n CannotTransform,\n StringMerger,\n StringParenStripper,\n StringParenWrapper,\n StringSplitter,\n Transformer,\n hug_power_op,\n)\nfrom blib2to3.pgen2 import token\nfrom blib2to3.pytree import Leaf, Node\n\n# types\nLeafID = int\nLN = Union[Leaf, Node]\n\n\nclass CannotSplit(CannotTransform):\n \"\"\"A readable split that fits the allotted line length is impossible.\"\"\"\n\n\n# This isn't a dataclass because @dataclass + Generic breaks mypyc.\n# See also https://github.com/mypyc/mypyc/issues/827.\nclass LineGenerator(Visitor[Line]):\n \"\"\"Generates reformatted Line 
objects. Empty lines are not emitted.\n\n Note: destroys the tree it's visiting by mutating prefixes of its leaves\n in ways that will no longer stringify to valid Python code on the tree.\n \"\"\"\n\n def __init__(self, mode: Mode, features: Collection[Feature]) -> None:\n self.mode = mode\n self.features = features\n self.current_line: Line\n self.__post_init__()\n\n def line(self, indent: int = 0) -> Iterator[Line]:\n \"\"\"Generate a line.\n\n If the line is empty, only emit if it makes sense.\n If the line is too long, split it first and then generate.\n\n If any lines were generated, set up a new current_line.\n \"\"\"\n if not self.current_line:\n self.current_line.depth += indent\n return # Line is empty, don't emit. Creating a new one unnecessary.\n\n if len(self.current_line.leaves) == 1 and is_async_stmt_or_funcdef(\n self.current_line.leaves[0]\n ):\n # Special case for async def/for/with statements. `visit_async_stmt`\n # adds an `ASYNC` leaf then visits the child def/for/with statement\n # nodes. Line yields from those nodes shouldn't treat the former\n # `ASYNC` leaf as a complete line.\n return\n\n complete_line = self.current_line\n self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)\n yield complete_line\n\n def visit_default(self, node: LN) -> Iterator[Line]:\n \"\"\"Default `visit_*()` implementation. Recurses to children of `node`.\"\"\"\n if isinstance(node, Leaf):\n any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()\n for comment in generate_comments(node):\n if any_open_brackets:\n # any comment within brackets is subject to splitting\n self.current_line.append(comment)\n elif comment.type == token.COMMENT:\n # regular trailing comment\n self.current_line.append(comment)\n yield from self.line()\n\n else:\n # regular standalone comment\n yield from self.line()\n\n self.current_line.append(comment)\n yield from self.line()\n\n if any_open_brackets:\n node.prefix = \"\"\n if self.mode.string_normalization and node.type == token.STRING:\n node.value = normalize_string_prefix(node.value)\n node.value = normalize_string_quotes(node.value)\n if node.type == token.NUMBER:\n normalize_numeric_literal(node)\n if node.type not in WHITESPACE:\n self.current_line.append(node)\n yield from super().visit_default(node)\n\n def visit_test(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit an `x if y else z` test\"\"\"\n\n already_parenthesized = (\n node.prev_sibling and node.prev_sibling.type == token.LPAR\n )\n\n if not already_parenthesized:\n # Similar to logic in wrap_in_parentheses\n lpar = Leaf(token.LPAR, \"\")\n rpar = Leaf(token.RPAR, \"\")\n prefix = node.prefix\n node.prefix = \"\"\n lpar.prefix = prefix\n node.insert_child(0, lpar)\n node.append_child(rpar)\n\n yield from self.visit_default(node)\n\n def visit_INDENT(self, node: Leaf) -> Iterator[Line]:\n \"\"\"Increase indentation level, maybe yield a line.\"\"\"\n # In blib2to3 INDENT never holds comments.\n yield from self.line(+1)\n yield from self.visit_default(node)\n\n def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:\n \"\"\"Decrease indentation level, maybe yield a line.\"\"\"\n # The current line might still wait for trailing comments. At DEDENT time\n # there won't be any (they would be prefixes on the preceding NEWLINE).\n # Emit the line then.\n yield from self.line()\n\n # While DEDENT has no value, its prefix may contain standalone comments\n # that belong to the current indentation level. 
Get 'em.\n yield from self.visit_default(node)\n\n # Finally, emit the dedent.\n yield from self.line(-1)\n\n def visit_stmt(\n self, node: Node, keywords: Set[str], parens: Set[str]\n ) -> Iterator[Line]:\n \"\"\"Visit a statement.\n\n This implementation is shared for `if`, `while`, `for`, `try`, `except`,\n `def`, `with`, `class`, `assert`, and assignments.\n\n The relevant Python language `keywords` for a given statement will be\n NAME leaves within it. This methods puts those on a separate line.\n\n `parens` holds a set of string leaf values immediately after which\n invisible parens should be put.\n \"\"\"\n normalize_invisible_parens(\n node, parens_after=parens, mode=self.mode, features=self.features\n )\n for child in node.children:\n if is_name_token(child) and child.value in keywords:\n yield from self.line()\n\n yield from self.visit(child)\n\n def visit_typeparams(self, node: Node) -> Iterator[Line]:\n yield from self.visit_default(node)\n node.children[0].prefix = \"\"\n\n def visit_typevartuple(self, node: Node) -> Iterator[Line]:\n yield from self.visit_default(node)\n node.children[1].prefix = \"\"\n\n def visit_paramspec(self, node: Node) -> Iterator[Line]:\n yield from self.visit_default(node)\n node.children[1].prefix = \"\"\n\n def visit_dictsetmaker(self, node: Node) -> Iterator[Line]:\n if Preview.wrap_long_dict_values_in_parens in self.mode:\n for i, child in enumerate(node.children):\n if i == 0:\n continue\n if node.children[i - 1].type == token.COLON:\n if (\n child.type == syms.atom\n and child.children[0].type in OPENING_BRACKETS\n and not is_walrus_assignment(child)\n ):\n maybe_make_parens_invisible_in_atom(\n child,\n parent=node,\n remove_brackets_around_comma=False,\n )\n else:\n wrap_in_parentheses(node, child, visible=False)\n yield from self.visit_default(node)\n\n def visit_funcdef(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit function definition.\"\"\"\n yield from self.line()\n\n # Remove redundant brackets around return type annotation.\n is_return_annotation = False\n for child in node.children:\n if child.type == token.RARROW:\n is_return_annotation = True\n elif is_return_annotation:\n if child.type == syms.atom and child.children[0].type == token.LPAR:\n if maybe_make_parens_invisible_in_atom(\n child,\n parent=node,\n remove_brackets_around_comma=False,\n ):\n wrap_in_parentheses(node, child, visible=False)\n else:\n wrap_in_parentheses(node, child, visible=False)\n is_return_annotation = False\n\n for child in node.children:\n yield from self.visit(child)\n\n def visit_match_case(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit either a match or case statement.\"\"\"\n normalize_invisible_parens(\n node, parens_after=set(), mode=self.mode, features=self.features\n )\n\n yield from self.line()\n for child in node.children:\n yield from self.visit(child)\n\n def visit_suite(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit a suite.\"\"\"\n if is_stub_suite(node):\n yield from self.visit(node.children[2])\n else:\n yield from self.visit_default(node)\n\n def visit_simple_stmt(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit a statement without nested statements.\"\"\"\n prev_type: Optional[int] = None\n for child in node.children:\n if (prev_type is None or prev_type == token.SEMI) and is_arith_like(child):\n wrap_in_parentheses(node, child, visible=False)\n prev_type = child.type\n\n if node.parent and node.parent.type in STATEMENT:\n if is_parent_function_or_class(node) and is_stub_body(node):\n yield from self.visit_default(node)\n 
else:\n yield from self.line(+1)\n yield from self.visit_default(node)\n yield from self.line(-1)\n\n else:\n if not node.parent or not is_stub_suite(node.parent):\n yield from self.line()\n yield from self.visit_default(node)\n\n def visit_async_stmt(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit `async def`, `async for`, `async with`.\"\"\"\n yield from self.line()\n\n children = iter(node.children)\n for child in children:\n yield from self.visit(child)\n\n if child.type == token.ASYNC or child.type == STANDALONE_COMMENT:\n # STANDALONE_COMMENT happens when `# fmt: skip` is applied on the async\n # line.\n break\n\n internal_stmt = next(children)\n yield from self.visit(internal_stmt)\n\n def visit_decorators(self, node: Node) -> Iterator[Line]:\n \"\"\"Visit decorators.\"\"\"\n for child in node.children:\n yield from self.line()\n yield from self.visit(child)\n\n def visit_power(self, node: Node) -> Iterator[Line]:\n for idx, leaf in enumerate(node.children[:-1]):\n next_leaf = node.children[idx + 1]\n\n if not isinstance(leaf, Leaf):\n continue\n\n value = leaf.value.lower()\n if (\n leaf.type == token.NUMBER\n and next_leaf.type == syms.trailer\n # Ensure that we are in an attribute trailer\n and next_leaf.children[0].type == token.DOT\n # It shouldn't wrap hexadecimal, binary and octal literals\n and not value.startswith((\"0x\", \"0b\", \"0o\"))\n # It shouldn't wrap complex literals\n and \"j\" not in value\n ):\n wrap_in_parentheses(node, leaf)\n\n remove_await_parens(node)\n\n yield from self.visit_default(node)\n\n def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:\n \"\"\"Remove a semicolon and put the other statement on a separate line.\"\"\"\n yield from self.line()\n\n def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:\n \"\"\"End of file. 
Process outstanding comments and end with a newline.\"\"\"\n yield from self.visit_default(leaf)\n yield from self.line()\n\n def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:\n if not self.current_line.bracket_tracker.any_open_brackets():\n yield from self.line()\n yield from self.visit_default(leaf)\n\n def visit_factor(self, node: Node) -> Iterator[Line]:\n \"\"\"Force parentheses between a unary op and a binary power:\n\n -2 ** 8 -> -(2 ** 8)\n \"\"\"\n _operator, operand = node.children\n if (\n operand.type == syms.power\n and len(operand.children) == 3\n and operand.children[1].type == token.DOUBLESTAR\n ):\n lpar = Leaf(token.LPAR, \"(\")\n rpar = Leaf(token.RPAR, \")\")\n index = operand.remove() or 0\n node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))\n yield from self.visit_default(node)\n\n def visit_tname(self, node: Node) -> Iterator[Line]:\n \"\"\"\n Add potential parentheses around types in function parameter lists to be made\n into real parentheses in case the type hint is too long to fit on a line\n Examples:\n def foo(a: int, b: float = 7): ...\n\n ->\n\n def foo(a: (int), b: (float) = 7): ...\n \"\"\"\n assert len(node.children) == 3\n if maybe_make_parens_invisible_in_atom(node.children[2], parent=node):\n wrap_in_parentheses(node, node.children[2], visible=False)\n\n yield from self.visit_default(node)\n\n def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:\n if Preview.hex_codes_in_unicode_sequences in self.mode:\n normalize_unicode_escape_sequences(leaf)\n\n if is_docstring(leaf, self.mode) and not re.search(r\"\\\\\\s*\\n\", leaf.value):\n # We're ignoring docstrings with backslash newline escapes because changing\n # indentation of those changes the AST representation of the code.\n if self.mode.string_normalization:\n docstring = normalize_string_prefix(leaf.value)\n # visit_default() does handle string normalization for us, but\n # since this method acts differently depending on quote style (ex.\n # see padding logic below), there's a possibility for unstable\n # formatting as visit_default() is called *after*. 
To avoid a\n # situation where this function formats a docstring differently on\n # the second pass, normalize it early.\n docstring = normalize_string_quotes(docstring)\n else:\n docstring = leaf.value\n prefix = get_string_prefix(docstring)\n docstring = docstring[len(prefix) :] # Remove the prefix\n quote_char = docstring[0]\n # A natural way to remove the outer quotes is to do:\n # docstring = docstring.strip(quote_char)\n # but that breaks on \"\"\"\"\"x\"\"\" (which is '\"\"x').\n # So we actually need to remove the first character and the next two\n # characters but only if they are the same as the first.\n quote_len = 1 if docstring[1] != quote_char else 3\n docstring = docstring[quote_len:-quote_len]\n docstring_started_empty = not docstring\n indent = \" \" * 4 * self.current_line.depth\n\n if is_multiline_string(leaf):\n docstring = fix_docstring(docstring, indent)\n else:\n docstring = docstring.strip()\n\n has_trailing_backslash = False\n if docstring:\n # Add some padding if the docstring starts / ends with a quote mark.\n if docstring[0] == quote_char:\n docstring = \" \" + docstring\n if docstring[-1] == quote_char:\n docstring += \" \"\n if docstring[-1] == \"\\\\\":\n backslash_count = len(docstring) - len(docstring.rstrip(\"\\\\\"))\n if backslash_count % 2:\n # Odd number of tailing backslashes, add some padding to\n # avoid escaping the closing string quote.\n docstring += \" \"\n has_trailing_backslash = True\n elif not docstring_started_empty:\n docstring = \" \"\n\n # We could enforce triple quotes at this point.\n quote = quote_char * quote_len\n\n # It's invalid to put closing single-character quotes on a new line.\n if quote_len == 3:\n # We need to find the length of the last line of the docstring\n # to find if we can add the closing quotes to the line without\n # exceeding the maximum line length.\n # If docstring is one line, we don't put the closing quotes on a\n # separate line because it looks ugly (#3320).\n lines = docstring.splitlines()\n last_line_length = len(lines[-1]) if docstring else 0\n\n # If adding closing quotes would cause the last line to exceed\n # the maximum line length, and the closing quote is not\n # prefixed by a newline then put a line break before\n # the closing quotes\n if (\n len(lines) > 1\n and last_line_length + quote_len > self.mode.line_length\n and len(indent) + quote_len <= self.mode.line_length\n and not has_trailing_backslash\n ):\n if (\n Preview.docstring_check_for_newline in self.mode\n and leaf.value[-1 - quote_len] == \"\\n\"\n ):\n leaf.value = prefix + quote + docstring + quote\n else:\n leaf.value = prefix + quote + docstring + \"\\n\" + indent + quote\n else:\n leaf.value = prefix + quote + docstring + quote\n else:\n leaf.value = prefix + quote + docstring + quote\n\n yield from self.visit_default(leaf)\n\n def __post_init__(self) -> None:\n \"\"\"You are in a twisty little maze of passages.\"\"\"\n self.current_line = Line(mode=self.mode)\n\n v = self.visit_stmt\n \u00d8: Set[str] = set()\n self.visit_assert_stmt = partial(v, keywords={\"assert\"}, parens={\"assert\", \",\"})\n self.visit_if_stmt = partial(\n v, keywords={\"if\", \"else\", \"elif\"}, parens={\"if\", \"elif\"}\n )\n self.visit_while_stmt = partial(v, keywords={\"while\", \"else\"}, parens={\"while\"})\n self.visit_for_stmt = partial(v, keywords={\"for\", \"else\"}, parens={\"for\", \"in\"})\n self.visit_try_stmt = partial(\n v, keywords={\"try\", \"except\", \"else\", \"finally\"}, parens=\u00d8\n )\n self.visit_except_clause = partial(v, 
keywords={\"except\"}, parens={\"except\"})\n self.visit_with_stmt = partial(v, keywords={\"with\"}, parens={\"with\"})\n self.visit_classdef = partial(v, keywords={\"class\"}, parens=\u00d8)\n\n self.visit_expr_stmt = partial(v, keywords=\u00d8, parens=ASSIGNMENTS)\n self.visit_return_stmt = partial(v, keywords={\"return\"}, parens={\"return\"})\n self.visit_import_from = partial(v, keywords=\u00d8, parens={\"import\"})\n self.visit_del_stmt = partial(v, keywords=\u00d8, parens={\"del\"})\n self.visit_async_funcdef = self.visit_async_stmt\n self.visit_decorated = self.visit_decorators\n\n # PEP 634\n self.visit_match_stmt = self.visit_match_case\n self.visit_case_block = self.visit_match_case\n if Preview.remove_redundant_guard_parens in self.mode:\n self.visit_guard = partial(v, keywords=\u00d8, parens={\"if\"})\n\n\ndef _hugging_power_ops_line_to_string(\n line: Line,\n features: Collection[Feature],\n mode: Mode,\n) -> Optional[str]:\n try:\n return line_to_string(next(hug_power_op(line, features, mode)))\n except CannotTransform:\n return None\n\n\ndef transform_line(\n line: Line, mode: Mode, features: Collection[Feature] = ()\n) -> Iterator[Line]:\n \"\"\"Transform a `line`, potentially splitting it into many lines.\n\n They should fit in the allotted `line_length` but might not be able to.\n\n `features` are syntactical features that may be used in the output.\n \"\"\"\n if line.is_comment:\n yield line\n return\n\n line_str = line_to_string(line)\n\n # We need the line string when power operators are hugging to determine if we should\n # split the line. Default to line_str, if no power operator are present on the line.\n line_str_hugging_power_ops = (\n _hugging_power_ops_line_to_string(line, features, mode) or line_str\n )\n\n ll = mode.line_length\n sn = mode.string_normalization\n string_merge = StringMerger(ll, sn)\n string_paren_strip = StringParenStripper(ll, sn)\n string_split = StringSplitter(ll, sn)\n string_paren_wrap = StringParenWrapper(ll, sn)\n\n transformers: List[Transformer]\n if (\n not line.contains_uncollapsable_type_comments()\n and not line.should_split_rhs\n and not line.magic_trailing_comma\n and (\n is_line_short_enough(line, mode=mode, line_str=line_str_hugging_power_ops)\n or line.contains_unsplittable_type_ignore()\n )\n and not (line.inside_brackets and line.contains_standalone_comments())\n and not line.contains_implicit_multiline_string_with_comments()\n ):\n # Only apply basic string preprocessing, since lines shouldn't be split here.\n if Preview.string_processing in mode:\n transformers = [string_merge, string_paren_strip]\n else:\n transformers = []\n elif line.is_def and not should_split_funcdef_with_rhs(line, mode):\n transformers = [left_hand_split]\n else:\n\n def _rhs(\n self: object, line: Line, features: Collection[Feature], mode: Mode\n ) -> Iterator[Line]:\n \"\"\"Wraps calls to `right_hand_split`.\n\n The calls increasingly `omit` right-hand trailers (bracket pairs with\n content), meaning the trailers get glued together to split on another\n bracket pair instead.\n \"\"\"\n for omit in generate_trailers_to_omit(line, mode.line_length):\n lines = list(right_hand_split(line, mode, features, omit=omit))\n # Note: this check is only able to figure out if the first line of the\n # *current* transformation fits in the line length. This is true only\n # for simple cases. All others require running more transforms via\n # `transform_line()`. 
This check doesn't know if those would succeed.\n if is_line_short_enough(lines[0], mode=mode):\n yield from lines\n return\n\n # All splits failed, best effort split with no omits.\n # This mostly happens to multiline strings that are by definition\n # reported as not fitting a single line, as well as lines that contain\n # trailing commas (those have to be exploded).\n yield from right_hand_split(line, mode, features=features)\n\n # HACK: nested functions (like _rhs) compiled by mypyc don't retain their\n # __name__ attribute which is needed in `run_transformer` further down.\n # Unfortunately a nested class breaks mypyc too. So a class must be created\n # via type ... https://github.com/mypyc/mypyc/issues/884\n rhs = type(\"rhs\", (), {\"__call__\": _rhs})()\n\n if Preview.string_processing in mode:\n if line.inside_brackets:\n transformers = [\n string_merge,\n string_paren_strip,\n string_split,\n delimiter_split,\n standalone_comment_split,\n string_paren_wrap,\n rhs,\n ]\n else:\n transformers = [\n string_merge,\n string_paren_strip,\n string_split,\n string_paren_wrap,\n rhs,\n ]\n else:\n if line.inside_brackets:\n transformers = [delimiter_split, standalone_comment_split, rhs]\n else:\n transformers = [rhs]\n # It's always safe to attempt hugging of power operations and pretty much every line\n # could match.\n transformers.append(hug_power_op)\n\n for transform in transformers:\n # We are accumulating lines in `result` because we might want to abort\n # mission and return the original line in the end, or attempt a different\n # split altogether.\n try:\n result = run_transformer(line, transform, mode, features, line_str=line_str)\n except CannotTransform:\n continue\n else:\n yield from result\n break\n\n else:\n yield line\n\n\ndef sho\nuld_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:\n \"\"\"If a funcdef has a magic trailing comma in the return type, then we should first\n split the line with rhs to respect the comma.\n \"\"\"\n return_type_leaves: List[Leaf] = []\n in_return_type = False\n\n for leaf in line.leaves:\n if leaf.type == token.COLON:\n in_return_type = False\n if in_return_type:\n return_type_leaves.append(leaf)\n if leaf.type == token.RARROW:\n in_return_type = True\n\n # using `bracket_split_build_line` will mess with whitespace, so we duplicate a\n # couple lines from it.\n result = Line(mode=line.mode, depth=line.depth)\n leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)\n for leaf in return_type_leaves:\n result.append(\n leaf,\n preformatted=True,\n track_bracket=id(leaf) in leaves_to_track,\n )\n\n # we could also return true if the line is too long, and the return type is longer\n # than the param list. Or if `should_split_rhs` returns True.\n return result.magic_trailing_comma is not None\n\n\nclass _BracketSplitComponent(Enum):\n head = auto()\n body = auto()\n tail = auto()\n\n\ndef left_hand_split(\n line: Line, _features: Collection[Feature], mode: Mode\n) -> Iterator[Line]:\n \"\"\"Split line into many lines, starting with the first matching bracket pair.\n\n Note: this usually looks weird, only use this for function definitions.\n Prefer RHS otherwise. 
This is why this function is not symmetrical with\n :func:`right_hand_split` which also handles optional parentheses.\n \"\"\"\n tail_leaves: List[Leaf] = []\n body_leaves: List[Leaf] = []\n head_leaves: List[Leaf] = []\n current_leaves = head_leaves\n matching_bracket: Optional[Leaf] = None\n for leaf in line.leaves:\n if (\n current_leaves is body_leaves\n and leaf.type in CLOSING_BRACKETS\n and leaf.opening_bracket is matching_bracket\n and isinstance(matching_bracket, Leaf)\n ):\n ensure_visible(leaf)\n ensure_visible(matching_bracket)\n current_leaves = tail_leaves if body_leaves else head_leaves\n current_leaves.append(leaf)\n if current_leaves is head_leaves:\n if leaf.type in OPENING_BRACKETS:\n matching_bracket = leaf\n current_leaves = body_leaves\n if not matching_bracket or not tail_leaves:\n raise CannotSplit(\"No brackets found\")\n\n head = bracket_split_build_line(\n head_leaves, line, matching_bracket, component=_BracketSplitComponent.head\n )\n body = bracket_split_build_line(\n body_leaves, line, matching_bracket, component=_BracketSplitComponent.body\n )\n tail = bracket_split_build_line(\n tail_leaves, line, matching_bracket, component=_BracketSplitComponent.tail\n )\n bracket_split_succeeded_or_raise(head, body, tail)\n for result in (head, body, tail):\n if result:\n yield result\n\n\ndef right_hand_split(\n line: Line,\n mode: Mode,\n features: Collection[Feature] = (),\n omit: Collection[LeafID] = (),\n) -> Iterator[Line]:\n \"\"\"Split line into many lines, starting with the last matching bracket pair.\n\n...\n", "needle_token_start": 15148, "needle_token_end": 15502, "code_context_ntokens": 16140, "output": ["```python\ndef should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:\n \"\"\"If a funcdef has a magic trailing comma in the return type, then we should first\n split the line with rhs to respect the comma.\n \"\"\"\n return_type_leaves: List[Leaf] = []\n in_return_type = False\n\n for leaf in line.leaves:\n if leaf.type == token.COLON:\n in_return_type = False\n if in_return_type:\n return_type_leaves.append(leaf)\n if leaf.type == token.RARROW:\n in_return_type = True\n\n # using `bracket_split_build_line` will mess with whitespace, so we duplicate a\n # couple lines from it.\n result = Line(mode=line.mode, depth=line.depth)\n leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)\n for leaf in return_type_leaves:\n result.append(\n leaf,\n preformatted=True,\n track_bracket=id(leaf) in leaves_to_track,\n )\n\n # we could also return true if the line is too long, and the return type is longer\n # than the param list. Or if `should_split_rhs` returns True.\n return result.magic_trailing_comma is not None\n```"]} +{"repo": "python-poetry/poetry", "name": "_find_best_version_for_package", "language": "python", "path": "src/poetry/console/commands/init.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to determine the most suitable version of a software package based on given criteria, ensuring compatibility and meeting specific requirements.\n2. **Input**: It accepts four parameters: the name of the package, an optional version requirement, a flag to include pre-release versions, and an optional source to specify where to fetch the package.\n3. **Output**: Returns a tuple containing the formatted name of the package and a version string prefixed with a caret (^), indicating compatibility with versions greater than or equal to the specified one.\n4. 
**Procedure**: The function utilizes a version selector tool to find the best candidate that matches the provided name and version criteria from a pool of available packages. If no suitable package is found, it raises an error. Otherwise, it returns the formatted package name and the compatible version string.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " info_string += \"\\nShowing the first 10 matches\"\n\n self.line(info_string)\n\n # Default to an empty value to signal no package was selected\n choices.append(\"\")\n\n package = self.choice(\n \"\\nEnter package # to add, or the complete package name if\"\n \" it is not listed\",\n choices,\n attempts=3,\n default=len(choices) - 1,\n )\n\n if not package:\n self.line(\"No package selected\")\n\n # package selected by user, set constraint name to package name\n if package:\n constraint[\"name\"] = package\n\n # no constraint yet, determine the best version automatically\n if package and \"version\" not in constraint:\n question = self.create_question(\n \"Enter the version constraint to require \"\n \"(or leave blank to use the latest version):\"\n )\n question.set_max_attempts(3)\n question.set_validator(lambda x: (x or \"\").strip() or None)\n\n package_constraint = self.ask(question)\n\n if package_constraint is None:\n _, package_constraint = self._find_best_version_for_package(\n package\n )\n\n self.line(\n f\"Using version {package_constraint} for\"\n f\" {package}\"\n )\n\n constraint[\"version\"] = package_constraint\n\n if package:\n result.append(constraint)\n\n if is_interactive:\n package = self.ask(follow_up_question)\n\n return result\n\n result = []\n for requirement in self._parse_requirements(requires):\n if \"git\" in requirement or \"url\" in requirement or \"path\" in requirement:\n result.append(requirement)\n continue\n elif \"version\" not in requirement:\n # determine the best version automatically\n name, version = self._find_best_version_for_package(\n requirement[\"name\"],\n allow_prereleases=allow_prereleases,\n source=source,\n )\n requirement[\"version\"] = version\n requirement[\"name\"] = name\n\n self.line(f\"Using version {version} for {name}\")\n else:\n # check that the specified version/constraint exists\n # before we proceed\n name, _ = self._find_best_version_for_package(\n requirement[\"name\"],\n requirement[\"version\"],\n allow_prereleases=allow_prereleases,\n source=source,\n )\n\n requirement[\"name\"] = name\n\n result.append(requirement)\n\n return result\n\n \ndef _find_best_version_for_package(\n self,\n name: str,\n required_version: str | None = None,\n allow_prereleases: bool = False,\n source: str | None = None,\n ) -> tuple[str, str]:\n from poetry.version.version_selector import VersionSelector\n\n selector = VersionSelector(self._get_pool())\n package = selector.find_best_candidate(\n name, required_version, allow_prereleases=allow_prereleases, source=source\n )\n\n if not package:\n # TODO: find similar\n raise ValueError(f\"Could not find a matching version of package {name}\")\n\n return package.pretty_name, f\"^{package.version.to_string()}\"\n\n def _parse_requirements(self, requirements: list[str]) -> list[dict[str, Any]]:\n from poetry.core.pyproject.exceptions import PyProjectException\n\n try:\n cwd = self.poetry.file.path.parent\n artifact_cache = 
self.poetry.pool.artifact_cache\n except (PyProjectException, RuntimeError):\n cwd = Path.cwd()\n artifact_cache = self._get_pool().artifact_cache\n\n parser = RequirementsParser(\n artifact_cache=artifact_cache,\n env=self.env if isinstance(self, EnvCommand) else None,\n cwd=cwd,\n )\n return [parser.parse(requirement) for requirement in requirements]\n\n def _format_requirements(self, requirements: list[dict[str, str]]) -> Requirements:\n requires: Requirements = {}\n for requirement in requirements:\n name = requirement.pop(\"name\")\n constraint: str | InlineTable\n if \"version\" in requirement and len(requirement) == 1:\n constraint = requirement[\"version\"]\n else:\n constraint = inline_table()\n constraint.trivia.trail = \"\\n\"\n constraint.update(requirement)\n\n requires[name] = constraint\n\n return requires\n\n @staticmethod\n def _validate_author(author: str, default: str) -> str | None:\n from poetry.core.packages.package import AUTHOR_REGEX\n from poetry.core.utils.helpers import combine_unicode\n\n author = combine_unicode(author or default)\n\n if author in [\"n\", \"no\"]:\n return None\n\n m = AUTHOR_REGEX.match(author)\n if not m:\n raise ValueError(\n \"Invalid author string. Must be in the format: \"\n \"John Smith \"\n )\n\n return author\n\n @staticmethod\n def _validate_package(package: str | None) -> str | None:\n if package and len(package.split()) > 2:\n raise ValueError(\"Invalid package definition.\")\n\n return package\n\n def _get_pool(self) -> RepositoryPool:\n from poetry.config.config import Config\n from poetry.repositories import RepositoryPool\n from poetry.repositories.pypi_repository import PyPiRepository\n\n if isinstance(self, EnvCommand):\n return self.poetry.pool\n\n if self._pool is None:\n self._pool = RepositoryPool()\n pool_size = Config.create().installer_max_workers\n self._pool.add_repository(PyPiRepository(pool_size=pool_size))\n\n return self._pool\n\n# Path: src/poetry/console/commands/install.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass InstallCommand(InstallerCommand):\n name = \"install\"\n description = \"Installs the project dependencies.\"\n\n options: ClassVar[list[Option]] = [\n *InstallerCommand._group_dependency_options(),\n option(\n \"no-dev\",\n None,\n \"Do not install the development dependencies.\"\n \" (Deprecated)\",\n ),\n option(\n \"sync\",\n None,\n \"Synchronize the environment with the locked packages and the specified\"\n \" groups.\",\n ),\n option(\n \"no-root\", None, \"Do not install the root package (the current project).\"\n ),\n option(\n \"no-directory\",\n None,\n \"Do not install any directory path dependencies; useful to install\"\n \" dependencies without source code, e.g. 
for caching of Docker layers)\",\n flag=True,\n multiple=False,\n ),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n option(\n \"remove-untracked\",\n None,\n \"Removes packages not present in the lock file.\"\n \" (Deprecated)\",\n ),\n option(\n \"extras\",\n \"E\",\n \"Extra sets of dependencies to install.\",\n flag=False,\n multiple=True,\n ),\n option(\"all-extras\", None, \"Install all extra dependencies.\"),\n option(\"only-root\", None, \"Exclude all dependencies.\"),\n option(\n \"compile\",\n None,\n \"Compile Python source files to bytecode.\"\n \" (This option has no effect if modern-installation is disabled\"\n \" because the old installer always compiles.)\",\n ),\n ]\n\n help = \"\"\"\\\nThe install command reads the poetry.lock file from\nthe current directory, processes it, and downloads and installs all the\nlibraries and dependencies outlined in that file. If the file does not\nexist it will look for pyproject.toml and do the same.\n\npoetry install\n\nBy default, the above command will also install the current project. To install only the\ndependencies and not including the current project, run the command with the\n--no-root option like below:\n\n poetry install --no-root\n\nIf you want to use Poetry only for dependency management but not for packaging,\nyou can set the \"package-mode\" to false in your pyproject.toml file.\n\"\"\"\n\n _loggers: ClassVar[list[str]] = [\n \"poetry.repositories.pypi_repository\",\n \"poetry.inspection.info\",\n ]\n\n @property\n def activated_groups(self) -> set[str]:\n if self.option(\"only-root\"):\n return set()\n else:\n return super().activated_groups\n\n def handle(self) -> int:\n from poetry.core.masonry.utils.module import ModuleOrPackageNotFound\n\n from poetry.masonry.builders.editable import EditableBuilder\n\n if self.option(\"extras\") and self.option(\"all-extras\"):\n self.line_error(\n \"You cannot specify explicit\"\n \" `--extras` while installing\"\n \" using `--all-extras`.\"\n )\n return 1\n\n if self.option(\"only-root\") and any(\n self.option(key) for key in {\"with\", \"without\", \"only\"}\n ):\n self.line_error(\n \"The `--with`,\"\n \" `--without` and\"\n \" `--only` options cannot be used with\"\n \" the `--only-root`\"\n \" option.\"\n )\n return 1\n\n if self.option(\"only-root\") and self.option(\"no-root\"):\n self.line_error(\n \"You cannot specify `--no-root`\"\n \" when using `--only-root`.\"\n )\n return 1\n\n extras: list[str]\n if self.option(\"all-extras\"):\n extras = list(self.poetry.package.extras.keys())\n else:\n extras = []\n for extra in self.option(\"extras\", []):\n extras += extra.split()\n\n self.installer.extras(extras)\n\n with_synchronization = self.option(\"sync\")\n if self.option(\"remove-untracked\"):\n self.line_error(\n \"The `--remove-untracked` option is\"\n \" deprecated, use the `--sync` option\"\n \" instead.\"\n )\n\n with_synchronization = True\n\n self.installer.only_groups(self.activated_groups)\n self.installer.skip_directory(self.option(\"no-directory\"))\n self.installer.dry_run(self.option(\"dry-run\"))\n self.installer.requires_synchronization(with_synchronization)\n self.installer.executor.enable_bytecode_compilation(self.option(\"compile\"))\n self.installer.verbose(self.io.is_verbose())\n\n return_code = self.installer.run()\n\n if return_code != 0:\n return return_code\n\n if self.option(\"no-root\") or not self.poetry.is_package_mode:\n return 0\n\n log_install = (\n 
\"Installing the current project:\"\n f\" {self.poetry.package.pretty_name}\"\n f\" (<{{tag}}>{self.poetry.package.pretty_version})\"\n )\n overwrite = self.io.output.is_decorated() and not self.io.is_debug()\n self.line(\"\")\n self.write(log_install.format(tag=\"c2\"))\n if not overwrite:\n self.line(\"\")\n\n if self.option(\"dry-run\"):\n self.line(\"\")\n return 0\n\n # Prior to https://github.com/python-poetry/poetry-core/pull/629\n # the existence of a module/package was checked when creating the\n # EditableBuilder. Afterwards, the existence is checked after\n # executing the build script (if there is one),\n # i.e. during EditableBuilder.build().\n try:\n builder = EditableBuilder(self.poetry, self.env, self.io)\n builder.build()\n except (ModuleOrPackageNotFound, FileNotFoundError) as e:\n # This is likely due to the fact that the project is an application\n # not following the structure expected by Poetry.\n # No need for an editable install in this case.\n self.line(\"\")\n self.line_error(\n f\"Warning: The current project could not be installed: {e}\\n\"\n \"If you do not want to install the current project\"\n \" use --no-root.\\n\"\n \"If you want to use Poetry only for dependency management\"\n \" but not for packaging, you can disable package mode by setting\"\n \" package-mode = false in your pyproject.toml file.\\n\"\n \"In a future version of Poetry this warning will become an error!\",\n style=\"warning\",\n )\n return 0\n\n if overwrite:\n self.overwrite(log_install.format(tag=\"success\"))\n self.line(\"\")\n\n return 0\n\n# Path: src/poetry/console/commands/installer_command.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.console.commands.group_command import GroupCommand\n\n\nif TYPE_CHECKING:\n from poetry.installation.installer import Installer\n\n\nclass InstallerCommand(GroupCommand, EnvCommand):\n def __init__(self) -> None:\n # Set in poetry.console.application.Application.configure_installer\n self._installer: Installer | None = None\n\n super().__init__()\n\n def reset_poetry(self) -> None:\n super().reset_poetry()\n\n self.installer.set_package(self.poetry.package)\n self.installer.set_locker(self.poetry.locker)\n\n @property\n def installer(self) -> Installer:\n assert self._installer is not None\n return self._installer\n\n def set_installer(self, installer: Installer) -> None:\n self._installer = installer\n\n# Path: src/poetry/console/commands/lock.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass LockCommand(InstallerCommand):\n name = \"lock\"\n description = \"Locks the project dependencies.\"\n\n options: ClassVar[list[Option]] = [\n option(\n \"no-update\", None, \"Do not update locked versions, only refresh lock file.\"\n ),\n option(\n \"check\",\n None,\n \"Check that the poetry.lock file corresponds to the current\"\n \" version of pyproject.toml. 
(Deprecated) Use\"\n \" poetry check --lock instead.\",\n ),\n ]\n\n help = \"\"\"\nThe lock command reads the pyproject.toml file from the\ncurrent directory, processes it, and locks the dependencies in the\\\n poetry.lock\nfile.\n\npoetry lock\n\"\"\"\n\n loggers: ClassVar[list[str]] = [\"poetry.repositories.pypi_repository\"]\n\n def handle(self) -> int:\n if self.option(\"check\"):\n self.line_error(\n \"poetry lock --check is deprecated, use `poetry\"\n \" check --lock` instead.\"\n )\n if self.poetry.locker.is_locked() and self.poetry.locker.is_fresh():\n self.line(\"poetry.lock is consistent with pyproject.toml.\")\n return 0\n self.line_error(\n \"\"\n \"Error: pyproject.toml changed significantly since poetry.lock was last generated. \"\n \"Run `poetry lock [--no-update]` to fix the lock file.\"\n \"\"\n )\n return 1\n\n self.installer.lock(update=not self.option(\"no-update\"))\n\n return self.installer.run()\n\n# Path: src/poetry/console/commands/new.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.init import InitCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass NewCommand(InitCommand):\n name = \"new\"\n description = \"Creates a new Python project at .\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"path\", \"The path to create the project at.\")\n ]\n options: ClassVar[list[Option]] = [\n option(\n \"interactive\",\n \"i\",\n \"Allow interactive specification of project configuration.\",\n flag=True,\n ),\n option(\"name\", None, \"Set the resulting package name.\", flag=False),\n option(\"src\", None, \"Use the src layout for the project.\"),\n option(\n \"readme\",\n None,\n \"Specify the readme file format. One of md (default) or rst\",\n flag=False,\n ),\n *[\n o\n for o in InitCommand.options\n if o.name\n in {\n \"description\",\n \"author\",\n \"python\",\n \"dependency\",\n \"dev-dependency\",\n \"license\",\n }\n ],\n ]\n\n def handle(self) -> int:\n from pathlib import Path\n\n if self.io.input.option(\"directory\"):\n self.line_error(\n \"--directory only makes sense with existing projects, and will\"\n \" be ignored. You should consider the option --path instead.\"\n )\n\n path = Path(self.argument(\"path\"))\n if not path.is_absolute():\n # we do not use resolve here due to compatibility issues\n # for path.resolve(strict=False)\n path = Path.cwd().joinpath(path)\n\n if path.exists() and list(path.glob(\"*\")):\n # Directory is not empty. 
Aborting.\n raise RuntimeError(\n f\"Destination {path} exists and is not empty\"\n )\n\n return self._init_pyproject(\n project_path=path,\n allow_interactive=self.option(\"interactive\"),\n layout_name=\"src\" if self.option(\"src\") else \"standard\",\n readme_format=self.option(\"readme\") or \"md\",\n )\n\n# Path: src/poetry/console/commands/publish.py\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass PublishCommand(Command):\n name = \"publish\"\n description = \"Publishes a package to a remote repository.\"\n\n options: ClassVar[list[Option]] = [\n option(\n \"repository\", \"r\", \"The repository to publish the package to.\", flag=False\n ),\n option(\"username\", \"u\", \"The username to access the repository.\", flag=False),\n option(\"password\", \"p\", \"The password to access the repository.\", flag=False),\n option(\n \"cert\", None, \"Certificate authority to access the repository.\", flag=False\n ),\n option(\n \"client-cert\",\n None,\n \"Client certificate to access the repository.\",\n flag=False,\n ),\n option(\n \"dist-dir\",\n None,\n \"Dist directory where built artifact are stored. Default is `dist`.\",\n default=\"dist\",\n flag=False,\n ),\n option(\"build\", None, \"Build the package before publishing.\"),\n option(\"dry-run\", None, \"Perform all actions except upload the package.\"),\n option(\n \"skip-existing\",\n None,\n \"Ignore errors from files already existing in the repository.\",\n ),\n ]\n\n help = \"\"\"The publish command builds and uploads the package to a remote repository.\n\nBy default, it will upload to PyPI but if you pass the --repository option it will\nupload to it instead.\n\nThe --repository option should match the name of a configured repository using\nthe config command.\n\"\"\"\n\n loggers: ClassVar[list[str]] = [\"poetry.publishing.publisher\"]\n\n def handle(self) -> int:\n from poetry.publishing.publisher import Publisher\n\n if not self.poetry.is_package_mode:\n self.line_error(\"Publishing a package is not possible in non-package mode.\")\n return 1\n\n dist_dir = self.option(\"dist-dir\")\n\n publisher = Publisher(self.poetry, self.io, Path(dist_dir))\n\n # Building package first, if told\n if self.option(\"build\"):\n if publisher.files and not self.confirm(\n f\"There are {len(publisher.files)} files ready for\"\n \" publishing. Build anyway?\"\n ):\n self.line_error(\"Aborted!\")\n\n return 1\n\n self.call(\"build\", args=f\"--output {dist_dir}\")\n\n files = publisher.files\n if not files:\n self.line_error(\n \"No files to publish. 
\"\n \"Run poetry build first or use the --build option.\"\n )\n\n return 1\n\n self.line(\"\")\n\n cert = Path(self.option(\"cert\")) if self.option(\"cert\") else None\n client_cert = (\n Path(self.option(\"client-cert\")) if self.option(\"client-cert\") else None\n )\n\n publisher.publish(\n self.option(\"repository\"),\n self.option(\"username\"),\n self.option(\"password\"),\n cert,\n client_cert,\n self.option(\"dry-run\"),\n self.option(\"skip-existing\"),\n )\n\n return 0\n\n# Path: src/poetry/console/commands/remove.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.packages.dependency_group import MAIN_GROUP\nfrom tomlkit.toml_document import TOMLDocument\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass RemoveCommand(InstallerCommand):\n name = \"remove\"\n description = \"Removes a package from the project dependencies.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"packages\", \"The packages to remove.\", multiple=True)\n ]\n options: ClassVar[list[Option]] = [\n option(\"group\", \"G\", \"The group to remove the dependency from.\", flag=False),\n option(\n \"dev\",\n \"D\",\n \"Remove a package from the development dependencies.\"\n \" (Deprecated)\"\n \" Use --group=dev instead.\",\n ),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n option(\"lock\", None, \"Do not perform operations (only update the lockfile).\"),\n ]\n\n help = \"\"\"The remove command removes a package from the current\nlist of installed packages\n\npoetry remove\"\"\"\n\n loggers: ClassVar[list[str]] = [\n \"poetry.repositories.pypi_repository\",\n \"poetry.inspection.info\",\n ]\n\n def handle(self) -> int:\n packages = self.argument(\"packages\")\n\n if self.option(\"dev\"):\n self.line_error(\n \"The --dev option is deprecated, \"\n \"use the `--group dev` notation instead.\"\n )\n group = \"dev\"\n else:\n group = self.option(\"group\", self.default_group)\n\n content: dict[str, Any] = self.poetry.file.read()\n poetry_content = content[\"tool\"][\"poetry\"]\n\n if group is None:\n removed = []\n group_sections = [\n (group_name, group_section.get(\"dependencies\", {}))\n for group_name, group_section in poetry_content.get(\"group\", {}).items()\n ]\n\n for group_name, section in [\n (MAIN_GROUP, poetry_content[\"dependencies\"]),\n *group_sections,\n ]:\n removed += self._remove_packages(packages, section, group_name)\n if group_name != MAIN_GROUP:\n if not section:\n del poetry_content[\"group\"][group_name]\n else:\n poetry_content[\"group\"][group_name][\"dependencies\"] = section\n elif group == \"dev\" and \"dev-dependencies\" in poetry_content:\n # We need to account for the old `dev-dependencies` section\n removed = self._remove_packages(\n packages, poetry_content[\"dev-dependencies\"], \"dev\"\n )\n\n if not poetry_content[\"dev-dependencies\"]:\n del poetry_content[\"dev-dependencies\"]\n else:\n removed = []\n if \"group\" in poetry_content:\n if group in poetry_content[\"group\"]:\n removed = self._remove_packages(\n packages,\n poetry_content[\"group\"][group].get(\"dependencies\", {}),\n group,\n )\n\n if not 
poetry_content[\"group\"][group]:\n del poetry_content[\"group\"][group]\n\n if \"group\" in poetry_content and not poetry_content[\"group\"]:\n del poetry_content[\"group\"]\n\n removed_set = set(removed)\n not_found = set(packages).difference(removed_set)\n if not_found:\n raise ValueError(\n \"The following packages were not found: \" + \", \".join(sorted(not_found))\n )\n\n # Refresh the locker\n self.poetry.locker.set_local_config(poetry_content)\n self.installer.set_locker(self.poetry.locker)\n self.installer.set_package(self.poetry.package)\n self.installer.dry_run(self.option(\"dry-run\", False))\n self.installer.verbose(self.io.is_verbose())\n self.installer.update(True)\n self.installer.execute_operations(not self.option(\"lock\"))\n self.installer.whitelist(removed_set)\n\n status = self.installer.run()\n\n if not self.option(\"dry-run\") and status == 0:\n assert isinstance(content, TOMLDocument)\n self.poetry.file.write(content)\n\n return status\n\n def _remove_packages(\n self, packages: list[str], section: dict[str, Any], group_name: str\n ) -> list[str]:\n removed = []\n group = self.poetry.package.dependency_group(group_name)\n section_keys = list(section.keys())\n\n for package in packages:\n for existing_package in section_keys:\n if canonicalize_name(existing_package) == canonicalize_name(package):\n del section[existing_package]\n removed.append(package)\n group.remove_dependency(package)\n\n return removed\n\n# Path: src/poetry/console/commands/run.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.utils._compat import WINDOWS\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from poetry.core.masonry.utils.module import Module\n\n\nclass RunCommand(EnvCommand):\n name = \"run\"\n description = \"Runs a command in the appropriate environment.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"args\", \"The command and arguments/options to run.\", multiple=True)\n ]\n\n def handle(self) -> int:\n args = self.argument(\"args\")\n script = args[0]\n scripts = self.poetry.local_config.get(\"scripts\")\n\n if scripts and script in scripts:\n return self.run_script(scripts[script], args)\n\n try:\n return self.env.execute(*args)\n except FileNotFoundError:\n self.line_error(f\"Command not found: {script}\")\n return 1\n\n @property\n def _module(self) -> Module:\n from poetry.core.masonry.utils.module import Module\n\n poetry = self.poetry\n package = poetry.package\n path = poetry.file.path.parent\n module = Module(package.name, path.as_posix(), package.packages)\n\n return module\n\n def run_script(self, script: str | dict[str, str], args: list[str]) -> int:\n \"\"\"Runs an entry point script defined in the section ``[tool.poetry.scripts]``.\n\n When a script exists in the venv bin folder, i.e. after ``poetry install``,\n then ``sys.argv[0]`` must be set to the full path of the executable, so\n ``poetry run foo`` and ``poetry shell``, ``foo`` have the same ``sys.argv[0]``\n that points to the full path.\n\n Otherwise (when an entry point script does not exist), ``sys.argv[0]`` is the\n script name only, i.e. 
``poetry run foo`` has ``sys.argv == ['foo']``.\n \"\"\"\n for script_dir in self.env.script_dirs:\n script_path = script_dir / args[0]\n if WINDOWS:\n script_path = script_path.with_suffix(\".cmd\")\n if script_path.exists():\n args = [str(script_path), *args[1:]]\n break\n else:\n # If we reach this point, the script is not installed\n self._warning_not_installed_script(args[0])\n\n if isinstance(script, dict):\n script = script[\"callable\"]\n\n module, callable_ = script.split(\":\")\n\n src_in_sys_path = \"sys.path.append('src'); \" if self._module.is_in_src() else \"\"\n\n cmd = [\"python\", \"-c\"]\n\n cmd += [\n \"import sys; \"\n \"from importlib import import_module; \"\n f\"sys.argv = {args!r}; {src_in_sys_path}\"\n f\"sys.exit(import_module('{module}').{callable_}())\"\n ]\n\n return self.env.execute(*cmd)\n\n def _warning_not_installed_script(self, script: str) -> None:\n message = f\"\"\"\\\nWarning: '{script}' is an entry point defined in pyproject.toml, but it's not \\\ninstalled as a script. You may get improper `sys.argv[0]`.\n\nThe support to run uninstalled scripts will be removed in a future release.\n\nRun `poetry install` to resolve and get rid of this message.\n\"\"\"\n self.line_error(message, style=\"warning\")\n\n# Path: src/poetry/console/commands/search.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n\n\nclass SearchCommand(Command):\n name = \"search\"\n description = \"Searches for packages on remote repositories.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"tokens\", \"The tokens to search for.\", multiple=True)\n ]\n\n def handle(self) -> int:\n from poetry.repositories.pypi_repository import PyPiRepository\n\n results = PyPiRepository().search(self.argument(\"tokens\"))\n\n for result in results:\n self.line(\"\")\n name = f\"{result.name}\"\n\n name += f\" ({result.version})\"\n\n self.line(name)\n\n if result.description:\n self.line(f\" {result.description}\")\n\n return 0\n\n# Path: src/poetry/console/commands/shell.py\nfrom __future__ import annotations\n\nimport os\nimport sys\n\nfrom typing import TYPE_CHECKING\nfrom typing import cast\n\nfrom poetry.console.commands.env_command import EnvCommand\n\n\nif TYPE_CHECKING:\n from poetry.utils.env import VirtualEnv\n\n\nclass ShellCommand(EnvCommand):\n name = \"shell\"\n description = \"Spawns a shell within the virtual environment.\"\n\n help = f\"\"\"The shell command spawns a shell within the project's virtual environment.\n\nBy default, the current active shell is detected and used. 
Failing that,\nthe shell defined via the environment variable {'COMSPEC' if os.name == 'nt' else 'SHELL'} is used.\n\nIf a virtual environment does not exist, it will be created.\n\"\"\"\n\n def handle(self) -> int:\n from poetry.utils.shell import Shell\n\n # Check if it's already activated or doesn't exist and won't be created\n if self._is_venv_activated():\n self.line(\n f\"Virtual environment already activated: {self.env.path}\"\n )\n\n return 0\n\n self.line(f\"Spawning shell within {self.env.path}\")\n\n # Be sure that we have the right type of environment.\n env = self.env\n assert env.is_venv()\n env = cast(\"VirtualEnv\", env)\n\n # Setting this to avoid spawning unnecessary nested shells\n os.environ[\"POETRY_ACTIVE\"] = \"1\"\n shell = Shell.get()\n shell.activate(env)\n os.environ.pop(\"POETRY_ACTIVE\")\n\n return 0\n\n def _is_venv_activated(self) -> bool:\n return bool(os.environ.get(\"POETRY_ACTIVE\")) or getattr(\n sys, \"real_prefix\", sys.prefix\n ) == str(self.env.path)\n\n# Path: src/poetry/console/commands/show.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.console.commands.group_command import GroupCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n from cleo.io.io import IO\n from cleo.ui.table import Rows\n from packaging.utils import NormalizedName\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n from poetry.core.packages.project_package import ProjectPackage\n\n from poetry.repositories.repository import Repository\n\n\ndef reverse_deps(pkg: Package, repo: Repository) -> dict[str, str]:\n required_by = {}\n for locked in repo.packages:\n dependencies = {d.name: d.pretty_constraint for d in locked.requires}\n\n if pkg.name in dependencies:\n required_by[locked.pretty_name] = dependencies[pkg.name]\n\n return required_by\n\n\nclass ShowCommand(GroupCommand, EnvCommand):\n name = \"show\"\n description = \"Shows information about packages.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"package\", \"The package to inspect\", optional=True)\n ]\n options: ClassVar[list[Option]] = [\n *GroupCommand._group_dependency_options(),\n option(\n \"no-dev\",\n None,\n \"Do not list the development dependencies. 
(Deprecated)\",\n ),\n option(\"tree\", \"t\", \"List the dependencies as a tree.\"),\n option(\n \"why\",\n None,\n \"When showing the full list, or a --tree for a single package,\"\n \" display whether they are a direct dependency or required by other\"\n \" packages\",\n ),\n option(\"latest\", \"l\", \"Show the latest version.\"),\n option(\n \"outdated\",\n \"o\",\n \"Show the latest version but only for packages that are outdated.\",\n ),\n option(\n \"all\",\n \"a\",\n \"Show all packages (even those not compatible with current system).\",\n ),\n option(\"top-level\", \"T\", \"Show only top-level dependencies.\"),\n ]\n\n help = \"\"\"The show command displays detailed information about a package, or\nlists all packages available.\"\"\"\n\n colors: ClassVar[list[str]] = [\"cyan\", \"yellow\", \"green\", \"magenta\", \"blue\"]\n\n def handle(self) -> int:\n package = self.argument(\"package\")\n\n if self.option(\"tree\"):\n self.init_styles(self.io)\n\n if self.option(\"top-level\"):\n if self.option(\"tree\"):\n self.line_error(\n \"Error: Cannot use --tree and --top-level at the same\"\n \" time.\"\n )\n return 1\n if package is not None:\n self.line_error(\n \"Error: Cannot use --top-level when displaying a single\"\n \" package.\"\n )\n return 1\n\n if self.option(\"why\"):\n if self.option(\"tree\") and package is None:\n self.line_error(\n \"Error: --why requires a package when combined with\"\n \" --tree.\"\n )\n\n return 1\n\n if not self.option(\"tree\") and package:\n self.line_error(\n \"Error: --why cannot be used without --tree when displaying\"\n \" a single package.\"\n )\n\n return 1\n\n if self.option(\"outdated\"):\n self.io.input.set_option(\"latest\", True)\n\n if not self.poetry.locker.is_locked():\n self.line_error(\n \"Error: poetry.lock not found. 
Run `poetry lock` to create\"\n \" it.\"\n )\n return 1\n\n locked_repo = self.poetry.locker.locked_repository()\n\n if package:\n return self._display_single_package_information(package, locked_repo)\n\n root = self.project_with_activated_groups_only()\n\n # Show tree view if requested\n if self.option(\"tree\"):\n return self._display_packages_tree_information(locked_repo, root)\n\n return self._display_packages_information(locked_repo, root)\n\n def _display_single_package_information(\n self, package: str, locked_repository: Repository\n ) -> int:\n locked_packages = locked_repository.packages\n canonicalized_package = canonicalize_name(package)\n pkg = None\n\n for locked in locked_packages:\n if locked.name == canonicalized_package:\n pkg = locked\n break\n\n if not pkg:\n raise ValueError(f\"Package {package} not found\")\n\n required_by = reverse_deps(pkg, locked_repository)\n\n if self.option(\"tree\"):\n if self.option(\"why\"):\n # The default case if there's no reverse dependencies is to query\n # the subtree for pkg but if any rev-deps exist we'll query for each\n # of them in turn\n packages = [pkg]\n if required_by:\n packages = [\n p for p in locked_packages for r in required_by if p.name == r\n ]\n else:\n # if no rev-deps exist we'll make this clear as it can otherwise\n # look very odd for packages that also have no or few direct\n # dependencies\n self.io.write_line(f\"Package {package} is a direct dependency.\")\n\n for p in packages:\n self.display_package_tree(\n self.io, p, locked_packages, why_package=pkg\n )\n\n else:\n self.display_package_tree(self.io, pkg, locked_packages)\n\n return 0\n\n rows: Rows = [\n [\"name\", f\" : {pkg.pretty_name}\"],\n [\"version\", f\" : {pkg.pretty_version}\"],\n [\"description\", f\" : {pkg.description}\"],\n ]\n\n self.table(rows=rows, style=\"compact\").render()\n\n if pkg.requires:\n self.line(\"\")\n self.line(\"dependencies\")\n for dependency in pkg.requires:\n self.line(\n f\" - {dependency.pretty_name}\"\n f\" {dependency.pretty_constraint}\"\n )\n\n if required_by:\n self.line(\"\")\n self.line(\"required by\")\n for parent, requires_version in required_by.items():\n self.line(f\" - {parent} {requires_version}\")\n\n return 0\n\n def _display_packages_information(\n self, locked_repository: Repository, root: ProjectPackage\n ) -> int:\n import shutil\n\n from cleo.io.null_io import NullIO\n\n from poetry.puzzle.solver import Solver\n from poetry.repositories.installed_repository import InstalledRepository\n from poetry.repositories.repository_pool import RepositoryPool\n from poetry.utils.helpers import get_package_version_display_string\n\n locked_packages = locked_repository.packages\n pool = RepositoryPool.from_packages(locked_packages, self.poetry.config)\n solver = Solver(\n root,\n pool=pool,\n installed=[],\n locked=locked_packages,\n io=NullIO(),\n )\n solver.provider.load_deferred(False)\n with solver.use_environment(self.env):\n ops = solver.solve().calculate_operations()\n\n required_locked_packages = {op.package for op in ops if not op.skipped}\n\n show_latest = self.option(\"latest\")\n show_all = self.option(\"all\")\n show_top_level = self.option(\"top-level\")\n width = shutil.get_terminal_size().columns\n name_length = version_length = latest_length = required_by_length = 0\n latest_packages = {}\n latest_statuses = {}\n installed_repo = InstalledRepository.load(self.env)\n\n # Computing widths\n for locked in locked_packages:\n if locked not in required_locked_packages and not show_all:\n continue\n\n 
current_length = len(locked.pretty_name)\n if not self.io.output.is_decorated():\n installed_status = self.get_installed_status(\n locked, installed_repo.packages\n )\n\n if installed_status == \"not-installed\":\n current_length += 4\n\n if show_latest:\n latest = self.find_latest_package(locked, root)\n if not latest:\n latest = locked\n\n latest_packages[locked.pretty_name] = latest\n update_status = latest_statuses[locked.pretty_name] = (\n self.get_update_status(latest, locked)\n )\n\n if not self.option(\"outdated\") or update_status != \"up-to-date\":\n name_length = max(name_length, current_length)\n version_length = max(\n version_length,\n len(\n get_package_version_display_string(\n locked, root=self.poetry.file.path.parent\n )\n ),\n )\n latest_length = max(\n latest_length,\n len(\n get_package_version_display_string(\n latest, root=self.poetry.file.path.parent\n )\n ),\n )\n\n if self.option(\"why\"):\n required_by = reverse_deps(locked, locked_repository)\n required_by_length = max(\n required_by_length,\n len(\" from \" + \",\".join(required_by.keys())),\n )\n else:\n name_length = max(name_length, current_length)\n version_length = max(\n version_length,\n len(\n get_package_version_display_string(\n locked, root=self.poetry.file.path.parent\n )\n ),\n )\n\n if self.option(\"why\"):\n required_by = reverse_deps(locked, locked_repository)\n required_by_length = max(\n required_by_length, len(\" from \" + \",\".join(required_by.keys()))\n )\n\n write_version = name_length + version_length + 3 <= width\n write_latest = name_length + version_length + latest_length + 3 <= width\n\n why_end_column = (\n name_length + version_length + latest_length + required_by_length\n )\n write_why = self.option(\"why\") and (why_end_column + 3) <= width\n write_description = (why_end_column + 24) <= width\n\n requires = root.all_requires\n\n for locked in locked_packages:\n color = \"cyan\"\n name = locked.pretty_name\n install_marker = \"\"\n\n if show_top_level and not any(locked.satisfies(r) for r in requires):\n continue\n\n if locked not in required_locked_packages:\n if not show_all:\n continue\n\n color = \"black;options=bold\"\n else:\n installed_status = self.get_installed_status(\n locked, installed_repo.packages\n )\n if installed_status == \"not-installed\":\n color = \"red\"\n\n if not self.io.output.is_decorated():\n # Non installed in non decorated mode\n install_marker = \" (!)\"\n\n if (\n show_latest\n and self.option(\"outdated\")\n and latest_statuses[locked.pretty_name] == \"up-to-date\"\n ):\n continue\n\n line = (\n f\"\"\n f\"{name:{name_length - len(install_marker)}}{install_marker}\"\n )\n if write_version:\n version = get_package_version_display_string(\n locked, root=self.poetry.file.path.parent\n )\n line += f\" {version:{version_length}}\"\n if show_latest:\n latest = latest_packages[locked.pretty_name]\n update_status = latest_statuses[locked.pretty_name]\n\n if write_latest:\n color = \"green\"\n if update_status == \"semver-safe-update\":\n color = \"red\"\n elif update_status == \"update-possible\":\n color = \"yellow\"\n\n version = get_package_version_display_string(\n latest, root=self.poetry.file.path.parent\n )\n line += f\" {version:{latest_length}}\"\n\n if write_why:\n required_by = reverse_deps(locked, locked_repository)\n if required_by:\n content = \",\".join(required_by.keys())\n # subtract 6 for ' from '\n line += f\" from {content:{required_by_length - 6}}\"\n else:\n line += \" \" * required_by_length\n\n if write_description:\n description = 
locked.description\n remaining = (\n width - name_length - version_length - required_by_length - 4\n )\n\n if show_latest:\n remaining -= latest_length\n\n if len(locked.description) > remaining:\n description = description[: remaining - 3] + \"...\"\n\n line += \" \" + description\n\n self.line(line)\n\n return 0\n\n def _display_packages_tree_information(\n self, locked_repository: Repository, root: ProjectPackage\n ) -> int:\n packages = locked_repository.packages\n\n for p in packages:\n for require in root.all_requires:\n if p.name == require.name:\n self.display_package_tree(self.io, p, packages)\n break\n\n return 0\n\n def display_package_tree(\n self,\n io: IO,\n package: Package,\n installed_packages: list[Package],\n why_package: Package | None = None,\n ) -> None:\n io.write(f\"{package.pretty_name}\")\n description = \"\"\n if package.description:\n description = \" \" + package.description\n\n io.write_line(f\" {package.pretty_version}{description}\")\n\n if why_package is not None:\n dependencies = [p for p in package.requires if p.name == why_package.name]\n else:\n dependencies = package.requires\n dependencies = sorted(\n dependencies,\n key=lambda x: x.name,\n )\n\n tree_bar = \"\u251c\"\n total = len(dependencies)\n for i, dependency in enumerate(dependencies, 1):\n if i == total:\n tree_bar = \"\u2514\"\n\n level = 1\n color = self.colors[level]\n info = (\n f\"{tree_bar}\u2500\u2500 <{color}>{dependency.name}\"\n f\" {dependency.pretty_constraint}\"\n )\n self._write_tree_line(io, info)\n\n tree_bar = tree_bar.replace(\"\u2514\", \" \")\n packages_in_tree = [package.name, dependency.name]\n\n self._display_tree(\n io,\n dependency,\n installed_packages,\n packages_in_tree,\n tree_bar,\n level + 1,\n )\n\n def _display_tree(\n self,\n io: IO,\n dependency: Dependency,\n installed_packages: list[Package],\n packages_in_tree: list[NormalizedName],\n previous_tree_bar: str = \"\u251c\",\n level: int = 1,\n ) -> None:\n previous_tree_bar = previous_tree_bar.replace(\"\u251c\", \"\u2502\")\n\n dependencies = []\n for package in installed_packages:\n if package.name == dependency.name:\n dependencies = package.requires\n\n break\n\n dependencies = sorted(\n dependencies,\n key=lambda x: x.name,\n )\n tree_bar = previous_tree_bar + \" \u251c\"\n total = len(dependencies)\n for i, dependency in enumerate(dependencies, 1):\n current_tree = packages_in_tree\n if i == total:\n tree_bar = previous_tree_bar + \" \u2514\"\n\n color_ident = level % len(self.colors)\n color = self.colors[color_ident]\n\n circular_warn = \"\"\n if dependency.name in current_tree:\n circular_warn = \"(circular dependency aborted here)\"\n\n info = (\n f\"{tree_bar}\u2500\u2500 <{color}>{dependency.name}\"\n f\" {dependency.pretty_constraint} {circular_warn}\"\n )\n self._write_tree_line(io, info)\n\n tree_bar = tree_bar.replace(\"\u2514\", \" \")\n\n if dependency.name not in current_tree:\n current_tree.append(dependency.name)\n\n self._display_tree(\n io,\n dependency,\n installed_packages,\n current_tree,\n tree_bar,\n level + 1,\n )\n\n def _write_tree_line(self, io: IO, line: str) -> None:\n if not io.output.supports_utf8():\n line = line.replace(\"\u2514\", \"`-\")\n line = line.replace(\"\u251c\", \"|-\")\n line = line.replace(\"\u2500\u2500\", \"-\")\n line = line.replace(\"\u2502\", \"|\")\n\n io.write_line(line)\n\n def init_styles(self, io: IO) -> None:\n from cleo.formatters.style import Style\n\n for color in self.colors:\n style = Style(color)\n io.output.formatter.set_style(color, 
style)\n io.error_output.formatter.set_style(color, style)\n\n def find_latest_package(\n self, package: Package, root: ProjectPackage\n ) -> Package | None:\n from cleo.io.null_io import NullIO\n\n from poetry.puzzle.provider import Provider\n from poetry.version.version_selector import VersionSelector\n\n # find the latest version allowed in this pool\n requires = root.all_requires\n if package.is_direct_origin():\n for dep in requires:\n if dep.name == package.name and dep.source_type == package.source_type:\n provider = Provider(root, self.poetry.pool, NullIO())\n return provider.search_for_direct_origin_dependency(dep)\n\n allow_prereleases = False\n for dep in requires:\n if dep.name == package.name:\n allow_prereleases = dep.allows_prereleases()\n break\n\n name = package.name\n selector = VersionSelector(self.poetry.pool)\n\n return selector.find_best_candidate(\n name, f\">={package.pretty_version}\", allow_prereleases\n )\n\n def get_update_status(self, latest: Package, package: Package) -> str:\n from poetry.core.constraints.version import parse_constraint\n\n if latest.full_pretty_version == package.full_pretty_version:\n return \"up-to-date\"\n\n constraint = parse_constraint(\"^\" + package.pretty_version)\n\n if constraint.allows(latest.version):\n # It needs an immediate semver-compliant upgrade\n return \"semver-safe-update\"\n\n # it needs an upgrade but has potential BC breaks so is not urgent\n return \"update-possible\"\n\n def get_installed_status(\n self, locked: Package, installed_packages: list[Package]\n ) -> str:\n for package in installed_packages:\n if locked.name == package.name:\n return \"installed\"\n\n return \"not-installed\"\n\n# Path: src/poetry/console/commands/update.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass UpdateCommand(InstallerCommand):\n name = \"update\"\n description = (\n \"Update the dependencies as according to the pyproject.toml file.\"\n )\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"packages\", \"The packages to update\", optional=True, multiple=True)\n ]\n options: ClassVar[list[Option]] = [\n *InstallerCommand._group_dependency_options(),\n option(\n \"no-dev\",\n None,\n \"Do not update the development dependencies.\"\n \" (Deprecated)\",\n ),\n option(\n \"sync\",\n None,\n \"Synchronize the environment with the locked packages and the specified\"\n \" groups.\",\n ),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n option(\"lock\", None, \"Do not perform operations (only update the lockfile).\"),\n ]\n\n loggers: ClassVar[list[str]] = [\"poetry.repositories.pypi_repository\"]\n\n def handle(self) -> int:\n packages = self.argument(\"packages\")\n if packages:\n self.installer.whitelist({name: \"*\" for name in packages})\n\n self.installer.only_groups(self.activated_groups)\n self.installer.dry_run(self.option(\"dry-run\"))\n self.installer.requires_synchronization(self.option(\"sync\"))\n self.installer.execute_operations(not self.option(\"lock\"))\n\n # Force update\n self.installer.update(True)\n\n return self.installer.run()\n\n# Path: src/poetry/console/commands/version.py\nfrom __future__ import 
annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom poetry.core.version.exceptions import InvalidVersion\nfrom tomlkit.toml_document import TOMLDocument\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n from poetry.core.constraints.version import Version\n\n\nclass VersionCommand(Command):\n name = \"version\"\n description = (\n \"Shows the version of the project or bumps it when a valid \"\n \"bump rule is provided.\"\n )\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"version\",\n \"The version number or the rule to update the version.\",\n optional=True,\n ),\n ]\n options: ClassVar[list[Option]] = [\n option(\"short\", \"s\", \"Output the version number only\"),\n option(\n \"dry-run\",\n None,\n \"Do not update pyproject.toml file\",\n ),\n option(\"next-phase\", None, \"Increment the phase of the current version\"),\n ]\n\n help = \"\"\"\\\nThe version command shows the current version of the project or bumps the version of\nthe project and writes the new version back to pyproject.toml if a valid\nbump rule is provided.\n\nThe new version should ideally be a valid semver string or a valid bump rule:\npatch, minor, major, prepatch, preminor, premajor, prerelease.\n\"\"\"\n\n RESERVED: ClassVar[set[str]] = {\n \"major\",\n \"minor\",\n \"patch\",\n \"premajor\",\n \"preminor\",\n \"prepatch\",\n \"prerelease\",\n }\n\n def handle(self) -> int:\n version = self.argument(\"version\")\n\n if version:\n version = self.increment_version(\n self.poetry.package.pretty_version, version, self.option(\"next-phase\")\n )\n\n if self.option(\"short\"):\n self.line(version.to_string())\n else:\n self.line(\n f\"Bumping version from {self.poetry.package.pretty_version}\"\n f\" to {version}\"\n )\n\n if not self.option(\"dry-run\"):\n content: dict[str, Any] = self.poetry.file.read()\n poetry_content = content[\"tool\"][\"poetry\"]\n poetry_content[\"version\"] = version.text\n\n assert isinstance(content, TOMLDocument)\n self.poetry.file.write(content)\n else:\n if self.option(\"short\"):\n self.line(self.poetry.package.pretty_version)\n else:\n self.line(\n f\"{self.poetry.package.pretty_name}\"\n f\" {self.poetry.package.pretty_version}\"\n )\n\n return 0\n\n def increment_version(\n self, version: str, rule: str, next_phase: bool = False\n ) -> Version:\n from poetry.core.constraints.version import Version\n...\n", "needle_token_start": 705, "needle_token_end": 894, "code_context_ntokens": 16019, "output": ["```python\ndef _find_best_version_for_package(\n self,\n name: str,\n required_version: str | None = None,\n allow_prereleases: bool = False,\n source: str | None = None,\n ) -> tuple[str, str]:\n from poetry.version.version_selector import VersionSelector\n\n selector = VersionSelector(self._get_pool())\n package = selector.find_best_candidate(\n name, required_version, allow_prereleases=allow_prereleases, source=source\n )\n\n if not package:\n # TODO: find similar\n raise ValueError(f\"Could not find a matching version of package {name}\")\n\n return package.pretty_name, f\"^{package.version.to_string()}\"\n```"]} +{"repo": "python-poetry/poetry", "name": "get_package_vcs_properties_from_path", "language": "python", "path": "src/poetry/repositories/installed_repository.py", "position_ratio": 0.15, "description": "\nFunction 
Description:\n1. **Purpose**: To retrieve version control system (VCS) details for a given package directory, specifically identifying the type of VCS, its remote origin URL, and the current revision hash.\n2. **Input**: A path to the directory where the package is located.\n3. **Output**: A tuple containing three strings: the type of VCS (e.g., 'git'), the URL of the remote origin, and the current revision hash.\n4. **Procedure**: The function utilizes a VCS utility to extract information from the specified directory. It first checks the VCS type and then gathers the remote origin URL and the current revision hash from the VCS metadata associated with the directory.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/repositories/http_repository.py\nfrom __future__ import annotations\n\nimport functools\nimport hashlib\n\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import Iterator\n\nimport requests\nimport requests.adapters\n\nfrom packaging.metadata import parse_email\nfrom poetry.core.constraints.version import parse_constraint\nfrom poetry.core.packages.dependency import Dependency\nfrom poetry.core.utils.helpers import temporary_directory\nfrom poetry.core.version.markers import parse_marker\n\nfrom poetry.config.config import Config\nfrom poetry.inspection.info import PackageInfo\nfrom poetry.inspection.lazy_wheel import LazyWheelUnsupportedError\nfrom poetry.inspection.lazy_wheel import metadata_from_wheel_url\nfrom poetry.repositories.cached_repository import CachedRepository\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.exceptions import RepositoryError\nfrom poetry.repositories.link_sources.html import HTMLPage\nfrom poetry.utils.authenticator import Authenticator\nfrom poetry.utils.constants import REQUESTS_TIMEOUT\nfrom poetry.utils.helpers import HTTPRangeRequestSupported\nfrom poetry.utils.helpers import download_file\nfrom poetry.utils.helpers import get_highest_priority_hash_type\nfrom poetry.utils.patterns import wheel_file_re\n\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n from poetry.core.packages.utils.link import Link\n\n from poetry.repositories.link_sources.base import LinkSource\n from poetry.utils.authenticator import RepositoryCertificateConfig\n\n\nclass HTTPRepository(CachedRepository):\n def __init__(\n self,\n name: str,\n url: str,\n config: Config | None = None,\n disable_cache: bool = False,\n pool_size: int = requests.adapters.DEFAULT_POOLSIZE,\n ) -> None:\n super().__init__(name, disable_cache, config)\n self._url = url\n if config is None:\n config = Config.create()\n self._authenticator = Authenticator(\n config=config,\n cache_id=name,\n disable_cache=disable_cache,\n pool_size=pool_size,\n )\n self._authenticator.add_repository(name, url)\n self.get_page = functools.lru_cache(maxsize=None)(self._get_page)\n\n self._lazy_wheel = config.get(\"solver.lazy-wheel\", True)\n # We are tracking if a domain supports range requests or not to avoid\n # unnecessary requests.\n # ATTENTION: A domain might support range requests only for some files, so the\n # meaning is as follows:\n # - Domain not in dict: We don't know anything.\n # - True: The domain supports range requests 
for at least some files.\n # - False: The domain does not support range requests for the files we tried.\n self._supports_range_requests: dict[str, bool] = {}\n\n @property\n def session(self) -> Authenticator:\n return self._authenticator\n\n @property\n def url(self) -> str:\n return self._url\n\n @property\n def certificates(self) -> RepositoryCertificateConfig:\n return self._authenticator.get_certs_for_url(self.url)\n\n @property\n def authenticated_url(self) -> str:\n return self._authenticator.authenticated_url(url=self.url)\n\n def _download(\n self, url: str, dest: Path, *, raise_accepts_ranges: bool = False\n ) -> None:\n return download_file(\n url, dest, session=self.session, raise_accepts_ranges=raise_accepts_ranges\n )\n\n @contextmanager\n def _cached_or_downloaded_file(\n self, link: Link, *, raise_accepts_ranges: bool = False\n ) -> Iterator[Path]:\n self._log(f\"Downloading: {link.url}\", level=\"debug\")\n with temporary_directory() as temp_dir:\n filepath = Path(temp_dir) / link.filename\n self._download(\n link.url, filepath, raise_accepts_ranges=raise_accepts_ranges\n )\n yield filepath\n\n def _get_info_from_wheel(self, link: Link) -> PackageInfo:\n from poetry.inspection.info import PackageInfo\n\n netloc = link.netloc\n\n # If \"lazy-wheel\" is enabled and the domain supports range requests\n # or we don't know yet, we try range requests.\n raise_accepts_ranges = self._lazy_wheel\n if self._lazy_wheel and self._supports_range_requests.get(netloc, True):\n try:\n package_info = PackageInfo.from_metadata(\n metadata_from_wheel_url(link.filename, link.url, self.session)\n )\n except LazyWheelUnsupportedError as e:\n # Do not set to False if we already know that the domain supports\n # range requests for some URLs!\n self._log(\n f\"Disabling lazy wheel support for {netloc}: {e}\",\n level=\"debug\",\n )\n raise_accepts_ranges = False\n self._supports_range_requests.setdefault(netloc, False)\n else:\n self._supports_range_requests[netloc] = True\n return package_info\n\n try:\n with self._cached_or_downloaded_file(\n link, raise_accepts_ranges=raise_accepts_ranges\n ) as filepath:\n return PackageInfo.from_wheel(filepath)\n except HTTPRangeRequestSupported:\n # The domain did not support range requests for the first URL(s) we tried,\n # but supports it for some URLs (especially the current URL),\n # so we abort the download, update _supports_range_requests to try\n # range requests for all files and use it for the current URL.\n self._log(\n f\"Abort downloading {link.url} because server supports range requests\",\n level=\"debug\",\n )\n self._supports_range_requests[netloc] = True\n return self._get_info_from_wheel(link)\n\n def _get_info_from_sdist(self, link: Link) -> PackageInfo:\n...\n# Path: src/poetry/repositories/installed_repository.py\nfrom __future__ import annotations\n\nimport itertools\nimport json\nimport logging\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.packages.package import Package\nfrom poetry.core.packages.utils.utils import url_to_path\nfrom poetry.core.utils.helpers import module_name\n\nfrom poetry.repositories.repository import Repository\nfrom poetry.utils._compat import metadata\n\n\nif TYPE_CHECKING:\n from poetry.utils.env import Env\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass InstalledRepository(Repository):\n def __init__(self) -> None:\n super().__init__(\"poetry-installed\")\n\n @classmethod\n def get_package_paths(cls, env: Env, name: str) 
-> set[Path]:\n \"\"\"\n Process a .pth file within the site-packages directories, and return any valid\n paths. We skip executable .pth files as there is no reliable means to do this\n without side-effects to current run-time. Mo check is made that the item refers\n to a directory rather than a file, however, in order to maintain backwards\n compatibility, we allow non-existing paths to be discovered. The latter\n behaviour is different to how Python's site-specific hook configuration works.\n\n Reference: https://docs.python.org/3.8/library/site.html\n\n :param env: The environment to search for the .pth file in.\n :param name: The name of the package to search .pth file for.\n :return: A `Set` of valid `Path` objects.\n \"\"\"\n paths = set()\n\n # we identify the candidate pth files to check, this is done so to handle cases\n # where the pth file for foo-bar might have been installed as either foo-bar.pth\n # or foo_bar.pth (expected) in either pure or platform lib directories.\n candidates = itertools.product(\n {env.purelib, env.platlib},\n {name, module_name(name)},\n )\n\n for lib, module in candidates:\n pth_file = lib.joinpath(module).with_suffix(\".pth\")\n if not pth_file.exists():\n continue\n\n with pth_file.open() as f:\n for line in f:\n line = line.strip()\n if line and not line.startswith((\"#\", \"import \", \"import\\t\")):\n path = Path(line)\n if not path.is_absolute():\n path = lib.joinpath(path).resolve()\n paths.add(path)\n\n src_path = env.path / \"src\" / name\n if not paths and src_path.exists():\n paths.add(src_path)\n\n return paths\n\n @classmethod\n \ndef get_package_vcs_properties_from_path(cls, src: Path) -> tuple[str, str, str]:\n from poetry.vcs.git import Git\n\n info = Git.info(repo=src)\n return \"git\", info.origin, info.revision\n\n @classmethod\n def is_vcs_package(cls, package: Path | Package, env: Env) -> bool:\n # A VCS dependency should have been installed\n # in the src directory.\n src = env.path / \"src\"\n if isinstance(package, Package):\n return src.joinpath(package.name).is_dir()\n\n try:\n package.relative_to(env.path / \"src\")\n except ValueError:\n return False\n else:\n return True\n\n @classmethod\n def create_package_from_distribution(\n cls, distribution: metadata.Distribution, env: Env\n ) -> Package:\n # We first check for a direct_url.json file to determine\n # the type of package.\n path = Path(str(distribution._path)) # type: ignore[attr-defined]\n\n if (\n path.name.endswith(\".dist-info\")\n and path.joinpath(\"direct_url.json\").exists()\n ):\n return cls.create_package_from_pep610(distribution)\n\n is_standard_package = env.is_path_relative_to_lib(path)\n\n source_type = None\n source_url = None\n source_reference = None\n source_resolved_reference = None\n source_subdirectory = None\n if is_standard_package:\n if path.name.endswith(\".dist-info\"):\n paths = cls.get_package_paths(\n env=env, name=distribution.metadata[\"name\"]\n )\n if paths:\n is_editable_package = False\n for src in paths:\n if cls.is_vcs_package(src, env):\n (\n source_type,\n source_url,\n source_reference,\n ) = cls.get_package_vcs_properties_from_path(src)\n break\n\n if not (\n is_editable_package or env.is_path_relative_to_lib(src)\n ):\n is_editable_package = True\n else:\n # TODO: handle multiple source directories?\n if is_editable_package:\n source_type = \"directory\"\n source_url = paths.pop().as_posix()\n elif cls.is_vcs_package(path, env):\n (\n source_type,\n source_url,\n source_reference,\n ) = 
cls.get_package_vcs_properties_from_path(\n env.path / \"src\" / canonicalize_name(distribution.metadata[\"name\"])\n )\n else:\n # If not, it's a path dependency\n source_type = \"directory\"\n source_url = str(path.parent)\n\n package = Package(\n distribution.metadata[\"name\"],\n distribution.metadata[\"version\"],\n source_type=source_type,\n source_url=source_url,\n source_reference=source_reference,\n source_resolved_reference=source_resolved_reference,\n source_subdirectory=source_subdirectory,\n )\n\n package.description = distribution.metadata.get( # type: ignore[attr-defined]\n \"summary\",\n \"\",\n )\n\n return package\n\n @classmethod\n def create_package_from_pep610(cls, distribution: metadata.Distribution) -> Package:\n path = Path(str(distribution._path)) # type: ignore[attr-defined]\n source_type = None\n source_url = None\n source_reference = None\n source_resolved_reference = None\n source_subdirectory = None\n develop = False\n\n url_reference = json.loads(\n path.joinpath(\"direct_url.json\").read_text(encoding=\"utf-8\")\n )\n if \"archive_info\" in url_reference:\n # File or URL distribution\n if url_reference[\"url\"].startswith(\"file:\"):\n # File distribution\n source_type = \"file\"\n source_url = url_to_path(url_reference[\"url\"]).as_posix()\n else:\n # URL distribution\n source_type = \"url\"\n source_url = url_reference[\"url\"]\n elif \"dir_info\" in url_reference:\n # Directory distribution\n source_type = \"directory\"\n source_url = url_to_path(url_reference[\"url\"]).as_posix()\n develop = url_reference[\"dir_info\"].get(\"editable\", False)\n elif \"vcs_info\" in url_reference:\n # VCS distribution\n source_type = url_reference[\"vcs_info\"][\"vcs\"]\n source_url = url_reference[\"url\"]\n source_resolved_reference = url_reference[\"vcs_info\"][\"commit_id\"]\n source_reference = url_reference[\"vcs_info\"].get(\n \"requested_revision\", source_resolved_reference\n )\n source_subdirectory = url_reference.get(\"subdirectory\")\n\n package = Package(\n distribution.metadata[\"name\"],\n distribution.metadata[\"version\"],\n source_type=source_type,\n source_url=source_url,\n source_reference=source_reference,\n source_resolved_reference=source_resolved_reference,\n source_subdirectory=source_subdirectory,\n develop=develop,\n )\n\n package.description = distribution.metadata.get( # type: ignore[attr-defined]\n \"summary\",\n \"\",\n )\n\n return package\n\n @classmethod\n def load(cls, env: Env, with_dependencies: bool = False) -> InstalledRepository:\n \"\"\"\n Load installed packages.\n \"\"\"\n from poetry.core.packages.dependency import Dependency\n\n repo = cls()\n seen = set()\n skipped = set()\n\n for entry in reversed(env.sys_path):\n if not entry.strip():\n logger.debug(\n \"Project environment contains an empty path in sys_path,\"\n \" ignoring.\"\n )\n continue\n\n for distribution in sorted(\n metadata.distributions(path=[entry]),\n key=lambda d: str(d._path), # type: ignore[attr-defined]\n ):\n path = Path(str(distribution._path)) # type: ignore[attr-defined]\n\n if path in skipped:\n continue\n\n name = distribution.metadata.get(\"name\") # type: ignore[attr-defined]\n if name is None:\n logger.warning(\n \"Project environment contains an invalid distribution\"\n \" (%s). 
Consider removing it manually or recreate\"\n \" the environment.\",\n path,\n )\n skipped.add(path)\n continue\n\n name = canonicalize_name(name)\n\n if name in seen:\n continue\n\n package = cls.create_package_from_distribution(distribution, env)\n\n if with_dependencies:\n for require in distribution.metadata.get_all(\"requires-dist\", []):\n dep = Dependency.create_from_pep_508(require)\n package.add_dependency(dep)\n\n seen.add(package.name)\n repo.add_package(package)\n\n return repo\n\n# Path: src/poetry/repositories/legacy_repository.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nimport requests.adapters\n\nfrom poetry.core.packages.package import Package\n\nfrom poetry.inspection.info import PackageInfo\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.http_repository import HTTPRepository\nfrom poetry.repositories.link_sources.html import SimpleRepositoryPage\n\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n from poetry.core.constraints.version import Version\n from poetry.core.constraints.version import VersionConstraint\n from poetry.core.packages.utils.link import Link\n\n from poetry.config.config import Config\n\n\nclass LegacyRepository(HTTPRepository):\n def __init__(\n self,\n name: str,\n url: str,\n config: Config | None = None,\n disable_cache: bool = False,\n pool_size: int = requests.adapters.DEFAULT_POOLSIZE,\n ) -> None:\n if name == \"pypi\":\n raise ValueError(\"The name [pypi] is reserved for repositories\")\n\n super().__init__(name, url.rstrip(\"/\"), config, disable_cache, pool_size)\n\n @property\n def packages(self) -> list[Package]:\n # LegacyRepository._packages is not populated and other implementations\n # implicitly rely on this (e.g. Pool.search via\n # LegacyRepository.search). 
To avoid special-casing Pool or changing\n # behavior, we stub and return an empty list.\n #\n # TODO: Rethinking search behaviour and design.\n # Ref: https://github.com/python-poetry/poetry/issues/2446 and\n # https://github.com/python-poetry/poetry/pull/6669#discussion_r990874908.\n return []\n\n def package(\n self, name: str, version: Version, extras: list[str] | None = None\n ) -> Package:\n \"\"\"\n Retrieve the release information.\n\n This is a heavy task which takes time.\n We have to download a package to get the dependencies.\n We also need to download every file matching this release\n to get the various hashes.\n\n Note that this will be cached so the subsequent operations\n should be much faster.\n \"\"\"\n try:\n index = self._packages.index(Package(name, version))\n\n return self._packages[index]\n except ValueError:\n package = super().package(name, version, extras)\n package._source_type = \"legacy\"\n package._source_url = self._url\n package._source_reference = self.name\n\n return package\n\n def find_links_for_package(self, package: Package) -> list[Link]:\n try:\n page = self.get_page(package.name)\n except PackageNotFound:\n return []\n\n return list(page.links_for_version(package.name, package.version))\n\n def _find_packages(\n self, name: NormalizedName, constraint: VersionConstraint\n ) -> list[Package]:\n \"\"\"\n Find packages on the remote server.\n \"\"\"\n try:\n page = self.get_page(name)\n except PackageNotFound:\n self._log(f\"No packages found for {name}\", level=\"debug\")\n return []\n\n versions = [\n (version, page.yanked(name, version))\n for version in page.versions(name)\n if constraint.allows(version)\n ]\n\n return [\n Package(\n name,\n version,\n source_type=\"legacy\",\n source_reference=self.name,\n source_url=self._url,\n yanked=yanked,\n )\n for version, yanked in versions\n ]\n\n def _get_release_info(\n self, name: NormalizedName, version: Version\n ) -> dict[str, Any]:\n page = self.get_page(name)\n\n links = list(page.links_for_version(name, version))\n yanked = page.yanked(name, version)\n\n return self._links_to_data(\n links,\n PackageInfo(\n name=name,\n version=version.text,\n summary=\"\",\n requires_dist=[],\n requires_python=None,\n files=[],\n yanked=yanked,\n cache_version=str(self.CACHE_VERSION),\n ),\n )\n\n def _get_page(self, name: NormalizedName) -> SimpleRepositoryPage:\n response = self._get_response(f\"/{name}/\")\n if not response:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n return SimpleRepositoryPage(response.url, response.text)\n\n# Path: src/poetry/repositories/lockfile_repository.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.repositories import Repository\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\n\nclass LockfileRepository(Repository):\n \"\"\"\n Special repository that distinguishes packages not only by name and version,\n but also by source type, url, etc.\n \"\"\"\n\n def __init__(self) -> None:\n super().__init__(\"poetry-lockfile\")\n\n def has_package(self, package: Package) -> bool:\n return any(p == package for p in self.packages)\n\n def remove_package(self, package: Package) -> None:\n index = None\n for i, repo_package in enumerate(self.packages):\n if repo_package == package:\n index = i\n break\n\n if index is not None:\n del self._packages[index]\n\n# Path: src/poetry/repositories/pypi_repository.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import TYPE_CHECKING\nfrom typing 
import Any\n\nimport requests\nimport requests.adapters\n\nfrom cachecontrol.controller import logger as cache_control_logger\nfrom poetry.core.packages.package import Package\nfrom poetry.core.packages.utils.link import Link\nfrom poetry.core.version.exceptions import InvalidVersion\n\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.http_repository import HTTPRepository\nfrom poetry.repositories.link_sources.json import SimpleJsonPage\nfrom poetry.repositories.parsers.pypi_search_parser import SearchResultParser\nfrom poetry.utils.constants import REQUESTS_TIMEOUT\n\n\ncache_control_logger.setLevel(logging.ERROR)\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n from poetry.core.constraints.version import Version\n from poetry.core.constraints.version import VersionConstraint\n\nSUPPORTED_PACKAGE_TYPES = {\"sdist\", \"bdist_wheel\"}\n\n\nclass PyPiRepository(HTTPRepository):\n def __init__(\n self,\n url: str = \"https://pypi.org/\",\n disable_cache: bool = False,\n fallback: bool = True,\n pool_size: int = requests.adapters.DEFAULT_POOLSIZE,\n ) -> None:\n super().__init__(\n \"PyPI\",\n url.rstrip(\"/\") + \"/simple/\",\n disable_cache=disable_cache,\n pool_size=pool_size,\n )\n\n self._base_url = url\n self._fallback = fallback\n\n def search(self, query: str) -> list[Package]:\n results = []\n\n response = requests.get(\n self._base_url + \"search\", params={\"q\": query}, timeout=REQUESTS_TIMEOUT\n )\n parser = SearchResultParser()\n parser.feed(response.text)\n\n for result in parser.results:\n try:\n package = Package(result.name, result.version)\n package.description = result.description.strip()\n results.append(package)\n except InvalidVersion:\n self._log(\n f'Unable to parse version \"{result.version}\" for the'\n f\" {result.name} package, skipping\",\n level=\"debug\",\n )\n\n return results\n\n def get_package_info(self, name: NormalizedName) -> dict[str, Any]:\n \"\"\"\n Return the package information given its name.\n\n The information is returned from the cache if it exists\n or retrieved from the remote server.\n \"\"\"\n return self._get_package_info(name)\n\n def _find_packages(\n self, name: NormalizedName, constraint: VersionConstraint\n ) -> list[Package]:\n \"\"\"\n Find packages on the remote server.\n \"\"\"\n try:\n json_page = self.get_page(name)\n except PackageNotFound:\n self._log(f\"No packages found for {name}\", level=\"debug\")\n return []\n\n versions = [\n (version, json_page.yanked(name, version))\n for version in json_page.versions(name)\n if constraint.allows(version)\n ]\n\n return [Package(name, version, yanked=yanked) for version, yanked in versions]\n\n def _get_package_info(self, name: NormalizedName) -> dict[str, Any]:\n headers = {\"Accept\": \"application/vnd.pypi.simple.v1+json\"}\n info = self._get(f\"simple/{name}/\", headers=headers)\n if info is None:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n\n return info\n\n def find_links_for_package(self, package: Package) -> list[Link]:\n json_data = self._get(f\"pypi/{package.name}/{package.version}/json\")\n if json_data is None:\n return []\n\n links = []\n for url in json_data[\"urls\"]:\n if url[\"packagetype\"] in SUPPORTED_PACKAGE_TYPES:\n h = f\"sha256={url['digests']['sha256']}\"\n links.append(Link(url[\"url\"] + \"#\" + h, yanked=self._get_yanked(url)))\n\n return links\n\n def _get_release_info(\n self, name: NormalizedName, version: Version\n ) -> dict[str, Any]:\n from 
poetry.inspection.info import PackageInfo\n\n self._log(f\"Getting info for {name} ({version}) from PyPI\", \"debug\")\n\n json_data = self._get(f\"pypi/{name}/{version}/json\")\n if json_data is None:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n\n info = json_data[\"info\"]\n\n data = PackageInfo(\n name=info[\"name\"],\n version=info[\"version\"],\n summary=info[\"summary\"],\n requires_dist=info[\"requires_dist\"],\n requires_python=info[\"requires_python\"],\n yanked=self._get_yanked(info),\n cache_version=str(self.CACHE_VERSION),\n )\n\n try:\n version_info = json_data[\"urls\"]\n except KeyError:\n version_info = []\n\n files = info.get(\"files\", [])\n for file_info in version_info:\n if file_info[\"packagetype\"] in SUPPORTED_PACKAGE_TYPES:\n files.append(\n {\n \"file\": file_info[\"filename\"],\n \"hash\": \"sha256:\" + file_info[\"digests\"][\"sha256\"],\n }\n )\n data.files = files\n\n if self._fallback and data.requires_dist is None:\n self._log(\n \"No dependencies found, downloading metadata and/or archives\",\n level=\"debug\",\n )\n # No dependencies set (along with other information)\n # This might be due to actually no dependencies\n # or badly set metadata when uploading.\n # So, we need to make sure there is actually no\n # dependencies by introspecting packages.\n page = self.get_page(name)\n links = list(page.links_for_version(name, version))\n info = self._get_info_from_links(links)\n\n data.requires_dist = info.requires_dist\n\n if not data.requires_python:\n data.requires_python = info.requires_python\n\n return data.asdict()\n\n def _get_page(self, name: NormalizedName) -> SimpleJsonPage:\n source = self._base_url + f\"simple/{name}/\"\n info = self.get_package_info(name)\n return SimpleJsonPage(source, info)\n\n def _get(\n self, endpoint: str, headers: dict[str, str] | None = None\n ) -> dict[str, Any] | None:\n try:\n json_response = self.session.get(\n self._base_url + endpoint,\n raise_for_status=False,\n timeout=REQUESTS_TIMEOUT,\n headers=headers,\n )\n except requests.exceptions.TooManyRedirects:\n # Cache control redirect loop.\n # We try to remove the cache and try again\n self.session.delete_cache(self._base_url + endpoint)\n json_response = self.session.get(\n self._base_url + endpoint,\n raise_for_status=False,\n timeout=REQUESTS_TIMEOUT,\n headers=headers,\n )\n\n if json_response.status_code != 200:\n return None\n\n json: dict[str, Any] = json_response.json()\n return json\n\n @staticmethod\n def _get_yanked(json_data: dict[str, Any]) -> str | bool:\n if json_data.get(\"yanked\", False):\n return json_data.get(\"yanked_reason\") or True\n return False\n\n# Path: src/poetry/repositories/repository.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import TYPE_CHECKING\n\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.constraints.version import Version\n\nfrom poetry.repositories.abstract_repository import AbstractRepository\nfrom poetry.repositories.exceptions import PackageNotFound\n\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n from poetry.core.constraints.version import VersionConstraint\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n from poetry.core.packages.utils.link import Link\n\n\nclass Repository(AbstractRepository):\n def __init__(self, name: str, packages: list[Package] | None = None) -> None:\n super().__init__(name)\n self._packages: list[Package] = []\n\n for package in packages or []:\n 
self.add_package(package)\n\n @property\n def packages(self) -> list[Package]:\n return self._packages\n\n def find_packages(self, dependency: Dependency) -> list[Package]:\n packages = []\n ignored_pre_release_packages = []\n\n constraint = dependency.constraint\n allow_prereleases = dependency.allows_prereleases()\n for package in self._find_packages(dependency.name, constraint):\n if package.yanked and not isinstance(constraint, Version):\n # PEP 592: yanked files are always ignored, unless they are the only\n # file that matches a version specifier that \"pins\" to an exact\n # version\n continue\n if (\n package.is_prerelease()\n and not allow_prereleases\n and not package.is_direct_origin()\n ):\n ignored_pre_release_packages.append(package)\n continue\n\n packages.append(package)\n\n self._log(\n f\"{len(packages)} packages found for {dependency.name} {constraint!s}\",\n level=\"debug\",\n )\n\n return packages or ignored_pre_release_packages\n\n def has_package(self, package: Package) -> bool:\n package_id = package.unique_name\n return any(\n package_id == repo_package.unique_name for repo_package in self.packages\n )\n\n def add_package(self, package: Package) -> None:\n self._packages.append(package)\n\n def remove_package(self, package: Package) -> None:\n package_id = package.unique_name\n\n index = None\n for i, repo_package in enumerate(self.packages):\n if package_id == repo_package.unique_name:\n index = i\n break\n\n if index is not None:\n del self._packages[index]\n\n def search(self, query: str) -> list[Package]:\n results: list[Package] = []\n\n for package in self.packages:\n if query in package.name:\n results.append(package)\n\n return results\n\n def _find_packages(\n self, name: NormalizedName, constraint: VersionConstraint\n ) -> list[Package]:\n return [\n package\n for package in self._packages\n if package.name == name and constraint.allows(package.version)\n ]\n\n def _log(self, msg: str, level: str = \"info\") -> None:\n logger = logging.getLogger(f\"{__name__}.{self.__class__.__name__}\")\n getattr(logger, level)(f\"Source ({self.name}): {msg}\")\n\n def __len__(self) -> int:\n return len(self._packages)\n\n def find_links_for_package(self, package: Package) -> list[Link]:\n return []\n\n def package(\n self, name: str, version: Version, extras: list[str] | None = None\n ) -> Package:\n canonicalized_name = canonicalize_name(name)\n for package in self.packages:\n if canonicalized_name == package.name and package.version == version:\n return package\n\n raise PackageNotFound(f\"Package {name} ({version}) not found.\")\n\n# Path: src/poetry/repositories/repository_pool.py\nfrom __future__ import annotations\n\nimport enum\nimport warnings\n\nfrom collections import OrderedDict\nfrom dataclasses import dataclass\nfrom enum import IntEnum\nfrom typing import TYPE_CHECKING\n\nfrom poetry.config.config import Config\nfrom poetry.repositories.abstract_repository import AbstractRepository\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.repository import Repository\nfrom poetry.utils.cache import ArtifactCache\n\n\nif TYPE_CHECKING:\n from poetry.core.constraints.version import Version\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n\n_SENTINEL = object()\n\n\nclass Priority(IntEnum):\n # The order of the members below dictates the actual priority. 
The first member has\n # top priority.\n DEFAULT = enum.auto()\n PRIMARY = enum.auto()\n SECONDARY = enum.auto()\n SUPPLEMENTAL = enum.auto()\n EXPLICIT = enum.auto()\n\n\n@dataclass(frozen=True)\nclass PrioritizedRepository:\n repository: Repository\n priority: Priority\n\n\nclass RepositoryPool(AbstractRepository):\n def __init__(\n self,\n repositories: list[Repository] | None = None,\n ignore_repository_names: object = _SENTINEL,\n *,\n config: Config | None = None,\n ) -> None:\n super().__init__(\"poetry-repository-pool\")\n self._repositories: OrderedDict[str, PrioritizedRepository] = OrderedDict()\n\n if repositories is None:\n repositories = []\n for repository in repositories:\n self.add_repository(repository)\n\n self._artifact_cache = ArtifactCache(\n cache_dir=(config or Config.create()).artifacts_cache_directory\n )\n\n if ignore_repository_names is not _SENTINEL:\n warnings.warn(\n \"The 'ignore_repository_names' argument to 'RepositoryPool.__init__' is\"\n \" deprecated. It has no effect anymore and will be removed in a future\"\n \" version.\",\n DeprecationWarning,\n stacklevel=2,\n )\n\n @staticmethod\n def from_packages(packages: list[Package], config: Config | None) -> RepositoryPool:\n pool = RepositoryPool(config=config)\n for package in packages:\n if package.is_direct_origin():\n continue\n\n repo_name = package.source_reference or \"PyPI\"\n try:\n repo = pool.repository(repo_name)\n except IndexError:\n repo = Repository(repo_name)\n pool.add_repository(repo)\n\n if not repo.has_package(package):\n repo.add_package(package)\n\n return pool\n\n @property\n def repositories(self) -> list[Repository]:\n \"\"\"\n Returns the repositories in the pool,\n in the order they will be searched for packages.\n\n ATTENTION: For backwards compatibility and practical reasons,\n repositories with priority EXPLICIT are NOT included,\n because they will not be searched.\n \"\"\"\n sorted_repositories = self._sorted_repositories\n return [\n prio_repo.repository\n for prio_repo in sorted_repositories\n if prio_repo.priority is not Priority.EXPLICIT\n ]\n\n @property\n def all_repositories(self) -> list[Repository]:\n return [prio_repo.repository for prio_repo in self._sorted_repositories]\n\n @property\n def _sorted_repositories(self) -> list[PrioritizedRepository]:\n return sorted(\n self._repositories.values(), key=lambda prio_repo: prio_repo.priority\n )\n\n @property\n def artifact_cache(self) -> ArtifactCache:\n return self._artifact_cache\n\n def has_default(self) -> bool:\n return self._contains_priority(Priority.DEFAULT)\n\n def has_primary_repositories(self) -> bool:\n return self._contains_priority(Priority.PRIMARY)\n\n def _contains_priority(self, priority: Priority) -> bool:\n return any(\n prio_repo.priority is priority for prio_repo in self._repositories.values()\n )\n\n def has_repository(self, name: str) -> bool:\n return name.lower() in self._repositories\n\n def repository(self, name: str) -> Repository:\n return self._get_prioritized_repository(name).repository\n\n def get_priority(self, name: str) -> Priority:\n return self._get_prioritized_repository(name).priority\n\n def _get_prioritized_repository(self, name: str) -> PrioritizedRepository:\n name = name.lower()\n if self.has_repository(name):\n return self._repositories[name]\n raise IndexError(f'Repository \"{name}\" does not exist.')\n\n def add_repository(\n self,\n repository: Repository,\n default: bool = False,\n secondary: bool = False,\n *,\n priority: Priority = Priority.PRIMARY,\n ) -> 
RepositoryPool:\n \"\"\"\n Adds a repository to the pool.\n \"\"\"\n repository_name = repository.name.lower()\n if self.has_repository(repository_name):\n raise ValueError(\n f\"A repository with name {repository_name} was already added.\"\n )\n\n if default or secondary:\n warnings.warn(\n \"Parameters 'default' and 'secondary' to\"\n \" 'RepositoryPool.add_repository' are deprecated. Please provide\"\n \" the keyword-argument 'priority' instead.\",\n DeprecationWarning,\n stacklevel=2,\n )\n priority = Priority.DEFAULT if default else Priority.SECONDARY\n\n if priority is Priority.DEFAULT and self.has_default():\n raise ValueError(\"Only one repository can be the default.\")\n\n self._repositories[repository_name] = PrioritizedRepository(\n repository, priority\n )\n return self\n\n def remove_repository(self, name: str) -> RepositoryPool:\n if not self.has_repository(name):\n raise IndexError(\n f\"RepositoryPool can not remove unknown repository '{name}'.\"\n )\n del self._repositories[name.lower()]\n return self\n\n def package(\n self,\n name: str,\n version: Version,\n extras: list[str] | None = None,\n repository_name: str | None = None,\n ) -> Package:\n if repository_name:\n return self.repository(repository_name).package(\n name, version, extras=extras\n )\n\n for repo in self.repositories:\n try:\n return repo.package(name, version, extras=extras)\n except PackageNotFound:\n continue\n raise PackageNotFound(f\"Package {name} ({version}) not found.\")\n\n def find_packages(self, dependency: Dependency) -> list[Package]:\n repository_name = dependency.source_name\n if repository_name:\n return self.repository(repository_name).find_packages(dependency)\n\n packages: list[Package] = []\n for repo in self.repositories:\n if packages and self.get_priority(repo.name) is Priority.SUPPLEMENTAL:\n break\n packages += repo.find_packages(dependency)\n return packages\n\n def search(self, query: str) -> list[Package]:\n results: list[Package] = []\n for repo in self.repositories:\n results += repo.search(query)\n return results\n\n# Path: src/poetry/repositories/single_page_repository.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.legacy_repository import LegacyRepository\nfrom poetry.repositories.link_sources.html import SimpleRepositoryPage\n\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n\n\nclass SinglePageRepository(LegacyRepository):\n def _get_page(self, name: NormalizedName) -> SimpleRepositoryPage:\n \"\"\"\n Single page repositories only have one page irrespective of endpoint.\n \"\"\"\n response = self._get_response(\"\")\n if not response:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n return SimpleRepositoryPage(response.url, response.text)\n\n# Path: src/poetry/toml/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.toml.exceptions import TOMLError\nfrom poetry.toml.file import TOMLFile\n\n\n__all__ = [\"TOMLError\", \"TOMLFile\"]\n\n# Path: src/poetry/toml/exceptions.py\nfrom __future__ import annotations\n\nfrom poetry.core.exceptions import PoetryCoreException\nfrom tomlkit.exceptions import TOMLKitError\n\n\nclass TOMLError(TOMLKitError, PoetryCoreException):\n pass\n\n# Path: src/poetry/toml/file.py\nfrom __future__ import annotations\n\nimport warnings\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom tomlkit.toml_file import TOMLFile as BaseTOMLFile\n\n\nif TYPE_CHECKING:\n from pathlib import 
Path\n\n from tomlkit.toml_document import TOMLDocument\n\n\nclass TOMLFile(BaseTOMLFile):\n def __init__(self, path: Path) -> None:\n super().__init__(path)\n self.__path = path\n\n @property\n def path(self) -> Path:\n return self.__path\n\n def exists(self) -> bool:\n return self.__path.exists()\n\n def read(self) -> TOMLDocument:\n from tomlkit.exceptions import TOMLKitError\n\n from poetry.toml import TOMLError\n\n try:\n return super().read()\n except (ValueError, TOMLKitError) as e:\n raise TOMLError(f\"Invalid TOML file {self.path.as_posix()}: {e}\")\n\n def __getattr__(self, item: str) -> Any:\n warnings.warn(\n \"`__getattr__` will be removed from the `TOMLFile` in a future release.\"\n \"\\n\\nInstead of accessing properties of the underlying `Path` as \"\n \"`tomlfile.whatever`, prefer `tomlfile.path.whatever`.\",\n DeprecationWarning,\n stacklevel=2,\n )\n return getattr(self.__path, item)\n\n def __str__(self) -> str:\n return self.__path.as_posix()\n\n# Path: src/poetry/utils/__init__.py\n\n# Path: src/poetry/utils/_compat.py\nfrom __future__ import annotations\n\nimport sys\n\nfrom contextlib import suppress\n\n\n# TODO: use try/except ImportError when\n# https://github.com/python/mypy/issues/1393 is fixed\n\nif sys.version_info < (3, 11):\n # compatibility for python <3.11\n import tomli as tomllib\nelse:\n import tomllib # nopycln: import\n\n\nif sys.version_info < (3, 10):\n # compatibility for python <3.10\n import importlib_metadata as metadata\nelse:\n from importlib import metadata\n\nWINDOWS = sys.platform == \"win32\"\n\n\ndef decode(string: bytes | str, encodings: list[str] | None = None) -> str:\n if not isinstance(string, bytes):\n return string\n\n encodings = encodings or [\"utf-8\", \"latin1\", \"ascii\"]\n\n for encoding in encodings:\n with suppress(UnicodeEncodeError, UnicodeDecodeError):\n return string.decode(encoding)\n\n return string.decode(encodings[0], errors=\"ignore\")\n\n\ndef encode(string: str, encodings: list[str] | None = None) -> bytes:\n if isinstance(string, bytes):\n return string\n\n encodings = encodings or [\"utf-8\", \"latin1\", \"ascii\"]\n\n for encoding in encodings:\n with suppress(UnicodeEncodeError, UnicodeDecodeError):\n return string.encode(encoding)\n\n return string.encode(encodings[0], errors=\"ignore\")\n\n\n__all__ = [\n \"WINDOWS\",\n \"decode\",\n \"encode\",\n \"metadata\",\n \"tomllib\",\n]\n\n# Path: src/poetry/utils/authenticator.py\nfrom __future__ import annotations\n\nimport contextlib\nimport dataclasses\nimport functools\nimport logging\nimport time\nimport urllib.parse\n\nfrom os.path import commonprefix\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nimport requests\nimport requests.adapters\nimport requests.auth\nimport requests.exceptions\n\nfrom cachecontrol import CacheControlAdapter\nfrom cachecontrol.caches import FileCache\nfrom requests_toolbelt import user_agent\n\nfrom poetry.__version__ import __version__\nfrom poetry.config.config import Config\nfrom poetry.exceptions import PoetryException\nfrom poetry.utils.constants import REQUESTS_TIMEOUT\nfrom poetry.utils.constants import RETRY_AFTER_HEADER\nfrom poetry.utils.constants import STATUS_FORCELIST\nfrom poetry.utils.password_manager import HTTPAuthCredential\nfrom poetry.utils.password_manager import PasswordManager\n\n\nif TYPE_CHECKING:\n from cleo.io.io import IO\n\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass(frozen=True)\nclass RepositoryCertificateConfig:\n cert: Path | None = 
dataclasses.field(default=None)\n client_cert: Path | None = dataclasses.field(default=None)\n verify: bool = dataclasses.field(default=True)\n\n @classmethod\n def create(\n cls, repository: str, config: Config | None\n ) -> RepositoryCertificateConfig:\n config = config if config else Config.create()\n\n verify: str | bool = config.get(\n f\"certificates.{repository}.verify\",\n config.get(f\"certificates.{repository}.cert\", True),\n )\n client_cert: str = config.get(f\"certificates.{repository}.client-cert\")\n\n return cls(\n cert=Path(verify) if isinstance(verify, str) else None,\n client_cert=Path(client_cert) if client_cert else None,\n verify=verify if isinstance(verify, bool) else True,\n )\n\n\n@dataclasses.dataclass\nclass AuthenticatorRepositoryConfig:\n name: str\n url: str\n netloc: str = dataclasses.field(init=False)\n path: str = dataclasses.field(init=False)\n\n def __post_init__(self) -> None:\n parsed_url = urllib.parse.urlsplit(self.url)\n self.netloc = parsed_url.netloc\n self.path = parsed_url.path\n\n def certs(self, config: Config) -> RepositoryCertificateConfig:\n return RepositoryCertificateConfig.create(self.name, config)\n\n @property\n def http_credential_keys(self) -> list[str]:\n return [self.url, self.netloc, self.name]\n\n def get_http_credentials(\n self, password_manager: PasswordManager, username: str | None = None\n ) -> HTTPAuthCredential:\n # try with the repository name via the password manager\n credential = HTTPAuthCredential(\n **(password_manager.get_http_auth(self.name) or {})\n )\n\n if credential.password is not None:\n return credential\n\n if password_manager.use_keyring:\n # fallback to url and netloc based keyring entries\n credential = password_manager.get_credential(\n self.url, self.netloc, username=credential.username\n )\n\n return credential\n\n\nclass Authenticator:\n def __init__(\n self,\n config: Config | None = None,\n io: IO | None = None,\n cache_id: str | None = None,\n disable_cache: bool = False,\n pool_size: int = requests.adapters.DEFAULT_POOLSIZE,\n ) -> None:\n self._config = config or Config.create()\n self._io = io\n self._sessions_for_netloc: dict[str, requests.Session] = {}\n self._credentials: dict[str, HTTPAuthCredential] = {}\n self._certs: dict[str, RepositoryCertificateConfig] = {}\n self._configured_repositories: (\n dict[str, AuthenticatorRepositoryConfig] | None\n ) = None\n self._password_manager = PasswordManager(self._config)\n self._cache_control = (\n FileCache(\n self._config.repository_cache_directory\n / (cache_id or \"_default_cache\")\n / \"_http\"\n )\n if not disable_cache\n else None\n )\n self.get_repository_config_for_url = functools.lru_cache(maxsize=None)(\n self._get_repository_config_for_url\n )\n self._pool_size = pool_size\n self._user_agent = user_agent(\"poetry\", __version__)\n\n def create_session(self) -> requests.Session:\n session = requests.Session()\n session.headers[\"User-Agent\"] = self._user_agent\n\n if self._cache_control is None:\n return session\n\n adapter = CacheControlAdapter(\n cache=self._cache_control,\n pool_maxsize=self._pool_size,\n )\n session.mount(\"http://\", adapter)\n session.mount(\"https://\", adapter)\n\n return session\n\n def get_session(self, url: str | None = None) -> requests.Session:\n if not url:\n return self.create_session()\n\n parsed_url = urllib.parse.urlsplit(url)\n netloc = parsed_url.netloc\n\n if netloc not in self._sessions_for_netloc:\n logger.debug(\"Creating new session for %s\", netloc)\n self._sessions_for_netloc[netloc] = 
self.create_session()\n\n return self._sessions_for_netloc[netloc]\n\n def close(self) -> None:\n for session in self._sessions_for_netloc.values():\n if session is not None:\n with contextlib.suppress(AttributeError):\n session.close()\n\n def __del__(self) -> None:\n self.close()\n\n def delete_cache(self, url: str) -> None:\n if self._cache_control is not None:\n self._cache_control.delete(key=url)\n\n def authenticated_url(self, url: str) -> str:\n parsed = urllib.parse.urlparse(url)\n credential = self.get_credentials_for_url(url)\n\n if credential.username is not None and credential.password is not None:\n username = urllib.parse.quote(credential.username, safe=\"\")\n password = urllib.parse.quote(credential.password, safe=\"\")\n\n return (\n f\"{parsed.scheme}://{username}:{password}@{parsed.netloc}{parsed.path}\"\n )\n\n return url\n\n def request(\n self, method: str, url: str, raise_for_status: bool = True, **kwargs: Any\n ) -> requests.Response:\n headers = kwargs.get(\"headers\")\n request = requests.Request(method, url, headers=headers)\n credential = self.get_credentials_for_url(url)\n\n if credential.username is not None or credential.password is not None:\n request = requests.auth.HTTPBasicAuth(\n credential.username or \"\", credential.password or \"\"\n )(request)\n\n session = self.get_session(url=url)\n prepared_request = session.prepare_request(request)\n\n proxies: dict[str, str] = kwargs.get(\"proxies\", {})\n stream: bool | None = kwargs.get(\"stream\")\n\n certs = self.get_certs_for_url(url)\n verify: bool | str | Path = kwargs.get(\"verify\") or certs.cert or certs.verify\n cert: str | Path | None = kwargs.get(\"cert\") or certs.client_cert\n\n if cert is not None:\n cert = str(cert)\n\n verify = str(verify) if isinstance(verify, Path) else verify\n\n settings = session.merge_environment_settings(\n prepared_request.url, proxies, stream, verify, cert\n )\n\n # Send the request.\n send_kwargs = {\n \"timeout\": kwargs.get(\"timeout\", REQUESTS_TIMEOUT),\n \"allow_redirects\": kwargs.get(\"allow_redirects\", True),\n }\n send_kwargs.update(settings)\n\n attempt = 0\n resp = None\n\n while True:\n is_last_attempt = attempt >= 5\n try:\n resp = session.send(prepared_request, **send_kwargs)\n except (requests.exceptions.ConnectionError, OSError) as e:\n if is_last_attempt:\n raise e\n else:\n if resp.status_code not in STATUS_FORCELIST or is_last_attempt:\n if raise_for_status:\n resp.raise_for_status()\n return resp\n\n if not is_last_attempt:\n attempt += 1\n delay = self._get_backoff(resp, attempt)\n logger.debug(\"Retrying HTTP request in %s seconds.\", delay)\n time.sleep(delay)\n continue\n\n # this should never really be hit under any sane circumstance\n raise PoetryException(\"Failed HTTP {} request\", method.upper())\n\n def _get_backoff(self, response: requests.Response | None, attempt: int) -> float:\n if response is not None:\n retry_after = response.headers.get(RETRY_AFTER_HEADER, \"\")\n if retry_after:\n return float(retry_after)\n\n return 0.5 * attempt\n\n def get(self, url: str, **kwargs: Any) -> requests.Response:\n return self.request(\"get\", url, **kwargs)\n\n def head(self, url: str, **kwargs: Any) -> requests.Response:\n kwargs.setdefault(\"allow_redirects\", False)\n return self.request(\"head\", url, **kwargs)\n\n def post(self, url: str, **kwargs: Any) -> requests.Response:\n return self.request(\"post\", url, **kwargs)\n\n def _get_credentials_for_repository(\n self, repository: AuthenticatorRepositoryConfig, username: str | None = None\n 
) -> HTTPAuthCredential:\n # cache repository credentials by repository url to avoid multiple keyring\n # backend queries when packages are being downloaded from the same source\n key = f\"{repository.url}#username={username or ''}\"\n\n if key not in self._credentials:\n self._credentials[key] = repository.get_http_credentials(\n password_manager=self._password_manager, username=username\n )\n\n return self._credentials[key]\n\n def _get_credentials_for_url(\n self, url: str, exact_match: bool = False\n ) -> HTTPAuthCredential:\n repository = self.get_repository_config_for_url(url, exact_match)\n\n credential = (\n self._get_credentials_for_repository(repository=repository)\n if repository is not None\n else HTTPAuthCredential()\n )\n\n if credential.password is None:\n parsed_url = urllib.parse.urlsplit(url)\n netloc = parsed_url.netloc\n credential = self._password_manager.get_credential(\n url, netloc, username=credential.username\n )\n\n return HTTPAuthCredential(\n username=credential.username, password=credential.password\n )\n\n return credential\n\n def get_credentials_for_git_url(self, url: str) -> HTTPAuthCredential:\n parsed_url = urllib.parse.urlsplit(url)\n\n if parsed_url.scheme not in {\"http\", \"https\"}:\n return HTTPAuthCredential()\n\n key = f\"git+{url}\"\n\n if key not in self._credentials:\n self._credentials[key] = self._get_credentials_for_url(url, True)\n\n return self._credentials[key]\n\n def get_credentials_for_url(self, url: str) -> HTTPAuthCredential:\n parsed_url = urllib.parse.urlsplit(url)\n netloc = parsed_url.netloc\n\n if url not in self._credentials:\n if \"@\" not in netloc:\n # no credentials were provided in the url, try finding the\n # best repository configuration\n self._credentials[url] = self._get_credentials_for_url(url)\n else:\n # Split from the right because that's how urllib.parse.urlsplit()\n # behaves if more than one @ is present (which can be checked using\n # the password attribute of urlsplit()'s return value).\n auth, netloc = netloc.rsplit(\"@\", 1)\n # Split from the left because that's how urllib.parse.urlsplit()\n # behaves if more than one : is present (which again can be checked\n # using the password attribute of the return value)\n user, password = auth.split(\":\", 1) if \":\" in auth else (auth, \"\")\n self._credentials[url] = HTTPAuthCredential(\n urllib.parse.unquote(user),\n urllib.parse.unquote(password),\n )\n\n return self._credentials[url]\n\n def get_pypi_token(self, name: str) -> str | None:\n return self._password_manager.get_pypi_token(name)\n\n def get_http_auth(\n self, name: str, username: str | None = None\n ) -> HTTPAuthCredential | None:\n if name == \"pypi\":\n repository = AuthenticatorRepositoryConfig(\n name, \"https://upload.pypi.org/legacy/\"\n )\n else:\n if name not in self.configured_repositories:\n return None\n repository = self.configured_repositories[name]\n\n return self._get_credentials_for_repository(\n repository=repository, username=username\n )\n\n def get_certs_for_repository(self, name: str) -> RepositoryCertificateConfig:\n if name.lower() == \"pypi\" or name not in self.configured_repositories:\n return RepositoryCertificateConfig()\n return self.configured_repositories[name].certs(self._config)\n\n @property\n def configured_repositories(self) -> dict[str, AuthenticatorRepositoryConfig]:\n if self._configured_repositories is None:\n self._configured_repositories = {}\n for repository_name in self._config.get(\"repositories\", []):\n url = 
self._config.get(f\"repositories.{repository_name}.url\")\n self._configured_repositories[repository_name] = (\n AuthenticatorRepositoryConfig(repository_name, url)\n )\n\n return self._configured_repositories\n\n def reset_credentials_cache(self) -> None:\n self.get_repository_config_for_url.cache_clear()\n self._credentials = {}\n\n def add_repository(self, name: str, url: str) -> None:\n self.configured_repositories[name] = AuthenticatorRepositoryConfig(name, url)\n self.reset_credentials_cache()\n\n...\n", "needle_token_start": 2374, "needle_token_end": 2441, "code_context_ntokens": 16013, "output": ["```python\nclass InstalledRepository(Repository):\n @classmethod\n def get_package_vcs_properties_from_path(cls, src: Path) -> tuple[str, str, str]:\n from poetry.vcs.git import Git\n\n info = Git.info(repo=src)\n return \"git\", info.origin, info.revision\n```"]} +{"repo": "python-poetry/poetry", "name": "_validate_group_options", "language": "python", "path": "src/poetry/console/commands/group_command.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function ensures that specified dependency groups are valid according to the project configuration.\n2. **Input**: A dictionary where keys are option types and values are sets of group names.\n3. **Output**: The function does not return any value but raises an error if any group is not recognized.\n4. **Procedure**: The function iterates over each group specified in the input dictionary. For each group, it checks if the group exists in the project's dependency configuration. If a group does not exist, it is added to a list of invalid groups. If there are any invalid groups, an error message is constructed and an exception is raised, indicating which groups are invalid and through which options they were specified.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/console/commands/config.py\nfrom __future__ import annotations\n\nimport json\nimport re\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\nfrom typing import cast\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.config.config import PackageFilterPolicy\nfrom poetry.config.config import boolean_normalizer\nfrom poetry.config.config import boolean_validator\nfrom poetry.config.config import int_normalizer\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n from poetry.config.config_source import ConfigSource\n\n\nclass ConfigCommand(Command):\n name = \"config\"\n description = \"Manages configuration settings.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"key\", \"Setting key.\", optional=True),\n argument(\"value\", \"Setting value.\", optional=True, multiple=True),\n ]\n\n options: ClassVar[list[Option]] = [\n option(\"list\", None, \"List configuration settings.\"),\n option(\"unset\", None, \"Unset configuration setting.\"),\n option(\"local\", None, \"Set/Get from the project's local configuration.\"),\n ]\n\n help = \"\"\"\\\nThis command allows you to edit the poetry config settings and repositories.\n\nTo add a repository:\n\n poetry config repositories.foo 
https://bar.com/simple/\n\nTo remove a repository (repo is a short alias for repositories):\n\n poetry config --unset repo.foo\"\"\"\n\n LIST_PROHIBITED_SETTINGS: ClassVar[set[str]] = {\"http-basic\", \"pypi-token\"}\n\n @property\n def unique_config_values(self) -> dict[str, tuple[Any, Any]]:\n unique_config_values = {\n \"cache-dir\": (str, lambda val: str(Path(val))),\n \"virtualenvs.create\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.in-project\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.options.always-copy\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.options.system-site-packages\": (\n boolean_validator,\n boolean_normalizer,\n ),\n \"virtualenvs.options.no-pip\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.options.no-setuptools\": (\n boolean_validator,\n boolean_normalizer,\n ),\n \"virtualenvs.path\": (str, lambda val: str(Path(val))),\n \"virtualenvs.prefer-active-python\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.prompt\": (str, str),\n \"experimental.system-git-client\": (boolean_validator, boolean_normalizer),\n \"installer.modern-installation\": (boolean_validator, boolean_normalizer),\n \"installer.parallel\": (boolean_validator, boolean_normalizer),\n \"installer.max-workers\": (lambda val: int(val) > 0, int_normalizer),\n \"installer.no-binary\": (\n PackageFilterPolicy.validator,\n PackageFilterPolicy.normalize,\n ),\n \"solver.lazy-wheel\": (boolean_validator, boolean_normalizer),\n \"warnings.export\": (boolean_validator, boolean_normalizer),\n \"keyring.enabled\": (boolean_validator, boolean_normalizer),\n }\n\n return unique_config_values\n\n def handle(self) -> int:\n from pathlib import Path\n\n from poetry.core.pyproject.exceptions import PyProjectException\n\n from poetry.config.config import Config\n from poetry.config.file_config_source import FileConfigSource\n from poetry.locations import CONFIG_DIR\n from poetry.toml.file import TOMLFile\n\n config = Config.create()\n config_file = TOMLFile(CONFIG_DIR / \"config.toml\")\n\n try:\n local_config_file = TOMLFile(self.poetry.file.path.parent / \"poetry.toml\")\n if local_config_file.exists():\n config.merge(local_config_file.read())\n except (RuntimeError, PyProjectException):\n local_config_file = TOMLFile(Path.cwd() / \"poetry.toml\")\n\n if self.option(\"local\"):\n config.set_config_source(FileConfigSource(local_config_file))\n\n if not config_file.exists():\n config_file.path.parent.mkdir(parents=True, exist_ok=True)\n config_file.path.touch(mode=0o0600)\n\n if self.option(\"list\"):\n self._list_configuration(config.all(), config.raw())\n\n return 0\n\n setting_key = self.argument(\"key\")\n if not setting_key:\n return 0\n\n if self.argument(\"value\") and self.option(\"unset\"):\n raise RuntimeError(\"You can not combine a setting value with --unset\")\n\n # show the value if no value is provided\n if not self.argument(\"value\") and not self.option(\"unset\"):\n if setting_key.split(\".\")[0] in self.LIST_PROHIBITED_SETTINGS:\n raise ValueError(f\"Expected a value for {setting_key} setting.\")\n\n m = re.match(r\"^repos?(?:itories)?(?:\\.(.+))?\", self.argument(\"key\"))\n value: str | dict[str, Any]\n if m:\n if not m.group(1):\n value = {}\n if config.get(\"repositories\") is not None:\n value = config.get(\"repositories\")\n else:\n repo = config.get(f\"repositories.{m.group(1)}\")\n if repo is None:\n raise ValueError(f\"There is no {m.group(1)} repository defined\")\n\n value = repo\n\n self.line(str(value))\n else:\n if 
setting_key not in self.unique_config_values:\n raise ValueError(f\"There is no {setting_key} setting.\")\n\n value = config.get(setting_key)\n\n if not isinstance(value, str):\n value = json.dumps(value)\n\n self.line(value)\n\n return 0\n\n values: list[str] = self.argument(\"value\")\n\n if setting_key in self.unique_config_values:\n if self.option(\"unset\"):\n config.config_source.remove_property(setting_key)\n return 0\n\n return self._handle_single_value(\n config.config_source,\n setting_key,\n self.unique_config_values[setting_key],\n values,\n )\n\n # handle repositories\n m = re.match(r\"^repos?(?:itories)?(?:\\.(.+))?\", self.argument(\"key\"))\n if m:\n if not m.group(1):\n raise ValueError(\"You cannot remove the [repositories] section\")\n\n if self.option(\"unset\"):\n repo = config.get(f\"repositories.{m.group(1)}\")\n if repo is None:\n raise ValueError(f\"There is no {m.group(1)} repository defined\")\n\n config.config_source.remove_property(f\"repositories.{m.group(1)}\")\n\n return 0\n\n if len(values) == 1:\n url = values[0]\n\n config.config_source.add_property(f\"repositories.{m.group(1)}.url\", url)\n\n return 0\n\n raise ValueError(\n \"You must pass the url. \"\n \"Example: poetry config repositories.foo https://bar.com\"\n )\n\n # handle auth\n m = re.match(r\"^(http-basic|pypi-token)\\.(.+)\", self.argument(\"key\"))\n if m:\n from poetry.utils.password_manager import PasswordManager\n\n password_manager = PasswordManager(config)\n if self.option(\"unset\"):\n if m.group(1) == \"http-basic\":\n password_manager.delete_http_password(m.group(2))\n elif m.group(1) == \"pypi-token\":\n password_manager.delete_pypi_token(m.group(2))\n\n return 0\n\n if m.group(1) == \"http-basic\":\n if len(values) == 1:\n username = values[0]\n # Only username, so we prompt for password\n password = self.secret(\"Password:\")\n assert isinstance(password, str)\n elif len(values) != 2:\n raise ValueError(\n \"Expected one or two arguments \"\n f\"(username, password), got {len(values)}\"\n )\n else:\n username = values[0]\n password = values[1]\n\n password_manager.set_http_password(m.group(2), username, password)\n elif m.group(1) == \"pypi-token\":\n if len(values) != 1:\n raise ValueError(\n f\"Expected only one argument (token), got {len(values)}\"\n )\n\n token = values[0]\n\n password_manager.set_pypi_token(m.group(2), token)\n\n return 0\n\n # handle certs\n m = re.match(r\"certificates\\.([^.]+)\\.(cert|client-cert)\", self.argument(\"key\"))\n if m:\n repository = m.group(1)\n key = m.group(2)\n\n...\n# Path: src/poetry/console/commands/env_command.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from poetry.utils.env import Env\n\n\nclass EnvCommand(Command):\n def __init__(self) -> None:\n # Set in poetry.console.application.Application.configure_env\n self._env: Env | None = None\n\n super().__init__()\n\n @property\n def env(self) -> Env:\n assert self._env is not None\n return self._env\n\n def set_env(self, env: Env) -> None:\n self._env = env\n\n# Path: src/poetry/console/commands/export.py\nfrom __future__ import annotations\n\nfrom poetry_plugin_export.command import ( # type: ignore[import-untyped]\n ExportCommand as BaseExportCommand,\n)\n\n\nclass ExportCommand(BaseExportCommand): # type: ignore[misc]\n def handle(self) -> int:\n if self.poetry.config.get(\"warnings.export\"):\n self.line_error(\n \"Warning: poetry-plugin-export will not be installed by 
default in a\"\n \" future version of Poetry.\\n\"\n \"In order to avoid a breaking change and make your automation\"\n \" forward-compatible, please install poetry-plugin-export explicitly.\"\n \" See https://python-poetry.org/docs/plugins/#using-plugins for details\"\n \" on how to install a plugin.\\n\"\n \"To disable this warning run 'poetry config warnings.export false'.\",\n style=\"warning\",\n )\n return super().handle() # type: ignore[no-any-return]\n\n# Path: src/poetry/console/commands/group_command.py\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import TYPE_CHECKING\n\nfrom cleo.helpers import option\nfrom poetry.core.packages.dependency_group import MAIN_GROUP\n\nfrom poetry.console.commands.command import Command\nfrom poetry.console.exceptions import GroupNotFound\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n from poetry.core.packages.project_package import ProjectPackage\n\n\nclass GroupCommand(Command):\n @staticmethod\n def _group_dependency_options() -> list[Option]:\n return [\n option(\n \"without\",\n None,\n \"The dependency groups to ignore.\",\n flag=False,\n multiple=True,\n ),\n option(\n \"with\",\n None,\n \"The optional dependency groups to include.\",\n flag=False,\n multiple=True,\n ),\n option(\n \"only\",\n None,\n \"The only dependency groups to include.\",\n flag=False,\n multiple=True,\n ),\n ]\n\n @property\n def non_optional_groups(self) -> set[str]:\n # TODO: this should move into poetry-core\n return {\n group.name\n for group in self.poetry.package._dependency_groups.values()\n if not group.is_optional()\n }\n\n @property\n def default_group(self) -> str | None:\n \"\"\"\n The default group to use when no group is specified. This is useful\n for command that have the `--group` option, eg: add, remove.\n\n Can be overridden to adapt behavior.\n \"\"\"\n return None\n\n @property\n def default_groups(self) -> set[str]:\n \"\"\"\n The groups that are considered by the command by default.\n\n Can be overridden to adapt behavior.\n \"\"\"\n return self.non_optional_groups\n\n @property\n def activated_groups(self) -> set[str]:\n groups = {}\n\n for key in {\"with\", \"without\", \"only\"}:\n groups[key] = {\n group.strip()\n for groups in self.option(key, \"\")\n for group in groups.split(\",\")\n }\n self._validate_group_options(groups)\n\n for opt, new, group in [\n (\"no-dev\", \"only\", MAIN_GROUP),\n (\"dev\", \"with\", \"dev\"),\n ]:\n if self.io.input.has_option(opt) and self.option(opt):\n self.line_error(\n f\"The `--{opt}` option is\"\n f\" deprecated, use the `--{new} {group}`\"\n \" notation instead.\"\n )\n groups[new].add(group)\n\n if groups[\"only\"] and (groups[\"with\"] or groups[\"without\"]):\n self.line_error(\n \"The `--with` and \"\n \"`--without` options are ignored when used\"\n \" along with the `--only` option.\"\n \"\"\n )\n\n return groups[\"only\"] or self.default_groups.union(groups[\"with\"]).difference(\n groups[\"without\"]\n )\n\n def project_with_activated_groups_only(self) -> ProjectPackage:\n return self.poetry.package.with_dependency_groups(\n list(self.activated_groups), only=True\n )\n\n \ndef _validate_group_options(self, group_options: dict[str, set[str]]) -> None:\n \"\"\"\n Raises an error if it detects that a group is not part of pyproject.toml\n \"\"\"\n invalid_options = defaultdict(set)\n for opt, groups in group_options.items():\n for group in groups:\n if not self.poetry.package.has_dependency_group(group):\n 
invalid_options[group].add(opt)\n if invalid_options:\n message_parts = []\n for group in sorted(invalid_options):\n opts = \", \".join(\n f\"--{opt}\"\n for opt in sorted(invalid_options[group])\n )\n message_parts.append(f\"{group} (via {opts})\")\n raise GroupNotFound(f\"Group(s) not found: {', '.join(message_parts)}\")\n\n# Path: src/poetry/console/commands/init.py\nfrom __future__ import annotations\n\nfrom contextlib import suppress\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\nfrom typing import Dict\nfrom typing import Mapping\nfrom typing import Union\n\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\nfrom tomlkit import inline_table\n\nfrom poetry.console.commands.command import Command\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.utils.dependency_specification import RequirementsParser\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n from packaging.utils import NormalizedName\n from poetry.core.packages.package import Package\n from tomlkit.items import InlineTable\n\n from poetry.repositories import RepositoryPool\n\nRequirements = Dict[str, Union[str, Mapping[str, Any]]]\n\n\nclass InitCommand(Command):\n name = \"init\"\n description = (\n \"Creates a basic pyproject.toml file in the current directory.\"\n )\n\n options: ClassVar[list[Option]] = [\n option(\"name\", None, \"Name of the package.\", flag=False),\n option(\"description\", None, \"Description of the package.\", flag=False),\n option(\"author\", None, \"Author name of the package.\", flag=False),\n option(\"python\", None, \"Compatible Python versions.\", flag=False),\n option(\n \"dependency\",\n None,\n \"Package to require, with an optional version constraint, \"\n \"e.g. requests:^2.10.0 or requests=2.11.1.\",\n flag=False,\n multiple=True,\n ),\n option(\n \"dev-dependency\",\n None,\n \"Package to require for development, with an optional version\"\n \" constraint, e.g. 
requests:^2.10.0 or requests=2.11.1.\",\n flag=False,\n multiple=True,\n ),\n option(\"license\", \"l\", \"License of the package.\", flag=False),\n ]\n\n help = \"\"\"\\\nThe init command creates a basic pyproject.toml file in the\\\n current directory.\n\"\"\"\n\n def __init__(self) -> None:\n super().__init__()\n\n self._pool: RepositoryPool | None = None\n\n def handle(self) -> int:\n from pathlib import Path\n\n project_path = Path.cwd()\n\n if self.io.input.option(\"directory\"):\n project_path = Path(self.io.input.option(\"directory\"))\n if not project_path.exists() or not project_path.is_dir():\n self.line_error(\n \"The --directory path is not a directory.\"\n )\n return 1\n\n return self._init_pyproject(project_path=project_path)\n\n def _init_pyproject(\n self,\n project_path: Path,\n allow_interactive: bool = True,\n layout_name: str = \"standard\",\n readme_format: str = \"md\",\n ) -> int:\n from poetry.core.vcs.git import GitConfig\n\n from poetry.config.config import Config\n from poetry.layouts import layout\n from poetry.pyproject.toml import PyProjectTOML\n from poetry.utils.env import EnvManager\n\n is_interactive = self.io.is_interactive() and allow_interactive\n\n pyproject = PyProjectTOML(project_path / \"pyproject.toml\")\n\n if pyproject.file.exists():\n if pyproject.is_poetry_project():\n self.line_error(\n \"A pyproject.toml file with a poetry section already\"\n \" exists.\"\n )\n return 1\n\n if pyproject.data.get(\"build-system\"):\n self.line_error(\n \"A pyproject.toml file with a defined build-system already\"\n \" exists.\"\n )\n return 1\n\n vcs_config = GitConfig()\n\n if is_interactive:\n self.line(\"\")\n self.line(\n \"This command will guide you through creating your\"\n \" pyproject.toml config.\"\n )\n self.line(\"\")\n\n name = self.option(\"name\")\n if not name:\n name = project_path.name.lower()\n\n if is_interactive:\n question = self.create_question(\n f\"Package name [{name}]: \", default=name\n )\n name = self.ask(question)\n\n version = \"0.1.0\"\n\n if is_interactive:\n question = self.create_question(\n f\"Version [{version}]: \", default=version\n )\n version = self.ask(question)\n\n description = self.option(\"description\") or \"\"\n if not description and is_interactive:\n description = self.ask(self.create_question(\"Description []: \", default=\"\"))\n\n author = self.option(\"author\")\n if not author and vcs_config.get(\"user.name\"):\n author = vcs_config[\"user.name\"]\n author_email = vcs_config.get(\"user.email\")\n if author_email:\n author += f\" <{author_email}>\"\n\n if is_interactive:\n question = self.create_question(\n f\"Author [{author}, n to skip]: \", default=author\n )\n question.set_validator(lambda v: self._validate_author(v, author))\n author = self.ask(question)\n\n authors = [author] if author else []\n\n license_name = self.option(\"license\")\n if not license_name and is_interactive:\n license_name = self.ask(self.create_question(\"License []: \", default=\"\"))\n\n python = self.option(\"python\")\n if not python:\n config = Config.create()\n python = (\n \"^\"\n + EnvManager.get_python_version(\n precision=2,\n prefer_active_python=config.get(\"virtualenvs.prefer-active-python\"),\n io=self.io,\n ).to_string()\n )\n\n if is_interactive:\n question = self.create_question(\n f\"Compatible Python versions [{python}]: \",\n default=python,\n )\n python = self.ask(question)\n\n if is_interactive:\n self.line(\"\")\n\n requirements: Requirements = {}\n if self.option(\"dependency\"):\n requirements = 
self._format_requirements(\n self._determine_requirements(self.option(\"dependency\"))\n )\n\n question_text = \"Would you like to define your main dependencies interactively?\"\n help_message = \"\"\"\\\n You can specify a package in the following forms:\n - A single name (requests): this will search for matches on PyPI\n - A name and a constraint (requests@^2.23.0)\n - A git url (git+https://github.com/python-poetry/poetry.git)\n - A git url with a revision\\\n (git+https://github.com/python-poetry/poetry.git#develop)\n - A file path (../my-package/my-package.whl)\n - A directory (../my-package/)\n - A url (https://example.com/packages/my-package-0.1.0.tar.gz)\n \"\"\"\n\n help_displayed = False\n if is_interactive and self.confirm(question_text, True):\n self.line(help_message)\n help_displayed = True\n requirements.update(\n self._format_requirements(self._determine_requirements([]))\n )\n self.line(\"\")\n\n dev_requirements: Requirements = {}\n if self.option(\"dev-dependency\"):\n dev_requirements = self._format_requirements(\n self._determine_requirements(self.option(\"dev-dependency\"))\n )\n\n question_text = (\n \"Would you like to define your development dependencies interactively?\"\n )\n if is_interactive and self.confirm(question_text, True):\n if not help_displayed:\n self.line(help_message)\n\n dev_requirements.update(\n self._format_requirements(self._determine_requirements([]))\n )\n\n self.line(\"\")\n\n layout_ = layout(layout_name)(\n name,\n version,\n description=description,\n author=authors[0] if authors else None,\n readme_format=readme_format,\n license=license_name,\n python=python,\n dependencies=requirements,\n dev_dependencies=dev_requirements,\n )\n\n create_layout = not project_path.exists()\n\n if create_layout:\n layout_.create(project_path, with_pyproject=False)\n\n content = layout_.generate_poetry_content()\n for section, item in content.items():\n pyproject.data.append(section, item)\n\n if is_interactive:\n self.line(\"Generated file\")\n self.line(\"\")\n self.line(pyproject.data.as_string().replace(\"\\r\\n\", \"\\n\"))\n self.line(\"\")\n\n if is_interactive and not self.confirm(\"Do you confirm generation?\", True):\n self.line_error(\"Command aborted\")\n\n return 1\n\n pyproject.save()\n\n if create_layout:\n path = project_path.resolve()\n\n with suppress(ValueError):\n path = path.relative_to(Path.cwd())\n\n self.line(\n f\"Created package {layout_._package_name} in\"\n f\" {path.as_posix()}\"\n )\n\n return 0\n\n def _generate_choice_list(\n self, matches: list[Package], canonicalized_name: NormalizedName\n ) -> list[str]:\n choices = []\n matches_names = [p.name for p in matches]\n exact_match = canonicalized_name in matches_names\n if exact_match:\n choices.append(matches[matches_names.index(canonicalized_name)].pretty_name)\n\n for found_package in matches:\n if len(choices) >= 10:\n break\n\n if found_package.name == canonicalized_name:\n continue\n\n choices.append(found_package.pretty_name)\n\n return choices\n\n def _determine_requirements(\n self,\n requires: list[str],\n allow_prereleases: bool = False,\n source: str | None = None,\n is_interactive: bool | None = None,\n ) -> list[dict[str, Any]]:\n if is_interactive is None:\n is_interactive = self.io.is_interactive()\n\n if not requires:\n result = []\n\n question = self.create_question(\n \"Package to add or search for (leave blank to skip):\"\n )\n question.set_validator(self._validate_package)\n\n follow_up_question = self.create_question(\n \"\\nAdd a package (leave blank to 
skip):\"\n )\n follow_up_question.set_validator(self._validate_package)\n\n package = self.ask(question)\n while package:\n constraint = self._parse_requirements([package])[0]\n if (\n \"git\" in constraint\n or \"url\" in constraint\n or \"path\" in constraint\n or \"version\" in constraint\n ):\n self.line(f\"Adding {package}\")\n result.append(constraint)\n package = self.ask(follow_up_question)\n continue\n\n canonicalized_name = canonicalize_name(constraint[\"name\"])\n matches = self._get_pool().search(canonicalized_name)\n if not matches:\n self.line_error(\"Unable to find package\")\n package = False\n else:\n choices = self._generate_choice_list(matches, canonicalized_name)\n\n info_string = (\n f\"Found {len(matches)} packages matching\"\n f\" {package}\"\n )\n\n if len(matches) > 10:\n info_string += \"\\nShowing the first 10 matches\"\n\n self.line(info_string)\n\n # Default to an empty value to signal no package was selected\n choices.append(\"\")\n\n package = self.choice(\n \"\\nEnter package # to add, or the complete package name if\"\n \" it is not listed\",\n choices,\n attempts=3,\n default=len(choices) - 1,\n )\n\n if not package:\n self.line(\"No package selected\")\n\n # package selected by user, set constraint name to package name\n if package:\n constraint[\"name\"] = package\n\n # no constraint yet, determine the best version automatically\n if package and \"version\" not in constraint:\n question = self.create_question(\n \"Enter the version constraint to require \"\n \"(or leave blank to use the latest version):\"\n )\n question.set_max_attempts(3)\n question.set_validator(lambda x: (x or \"\").strip() or None)\n\n package_constraint = self.ask(question)\n\n if package_constraint is None:\n _, package_constraint = self._find_best_version_for_package(\n package\n )\n\n self.line(\n f\"Using version {package_constraint} for\"\n f\" {package}\"\n )\n\n constraint[\"version\"] = package_constraint\n\n if package:\n result.append(constraint)\n\n if is_interactive:\n package = self.ask(follow_up_question)\n\n return result\n\n result = []\n for requirement in self._parse_requirements(requires):\n if \"git\" in requirement or \"url\" in requirement or \"path\" in requirement:\n result.append(requirement)\n continue\n elif \"version\" not in requirement:\n # determine the best version automatically\n name, version = self._find_best_version_for_package(\n requirement[\"name\"],\n allow_prereleases=allow_prereleases,\n source=source,\n )\n requirement[\"version\"] = version\n requirement[\"name\"] = name\n\n self.line(f\"Using version {version} for {name}\")\n else:\n # check that the specified version/constraint exists\n # before we proceed\n name, _ = self._find_best_version_for_package(\n requirement[\"name\"],\n requirement[\"version\"],\n allow_prereleases=allow_prereleases,\n source=source,\n )\n\n requirement[\"name\"] = name\n\n result.append(requirement)\n\n return result\n\n def _find_best_version_for_package(\n self,\n name: str,\n required_version: str | None = None,\n allow_prereleases: bool = False,\n source: str | None = None,\n ) -> tuple[str, str]:\n from poetry.version.version_selector import VersionSelector\n\n selector = VersionSelector(self._get_pool())\n package = selector.find_best_candidate(\n name, required_version, allow_prereleases=allow_prereleases, source=source\n )\n\n if not package:\n # TODO: find similar\n raise ValueError(f\"Could not find a matching version of package {name}\")\n\n return package.pretty_name, 
f\"^{package.version.to_string()}\"\n\n def _parse_requirements(self, requirements: list[str]) -> list[dict[str, Any]]:\n from poetry.core.pyproject.exceptions import PyProjectException\n\n try:\n cwd = self.poetry.file.path.parent\n artifact_cache = self.poetry.pool.artifact_cache\n except (PyProjectException, RuntimeError):\n cwd = Path.cwd()\n artifact_cache = self._get_pool().artifact_cache\n\n parser = RequirementsParser(\n artifact_cache=artifact_cache,\n env=self.env if isinstance(self, EnvCommand) else None,\n cwd=cwd,\n )\n return [parser.parse(requirement) for requirement in requirements]\n\n def _format_requirements(self, requirements: list[dict[str, str]]) -> Requirements:\n requires: Requirements = {}\n for requirement in requirements:\n name = requirement.pop(\"name\")\n constraint: str | InlineTable\n if \"version\" in requirement and len(requirement) == 1:\n constraint = requirement[\"version\"]\n else:\n constraint = inline_table()\n constraint.trivia.trail = \"\\n\"\n constraint.update(requirement)\n\n requires[name] = constraint\n\n return requires\n\n @staticmethod\n def _validate_author(author: str, default: str) -> str | None:\n from poetry.core.packages.package import AUTHOR_REGEX\n from poetry.core.utils.helpers import combine_unicode\n\n author = combine_unicode(author or default)\n\n if author in [\"n\", \"no\"]:\n return None\n\n m = AUTHOR_REGEX.match(author)\n if not m:\n raise ValueError(\n \"Invalid author string. Must be in the format: \"\n \"John Smith \"\n )\n\n return author\n\n @staticmethod\n def _validate_package(package: str | None) -> str | None:\n if package and len(package.split()) > 2:\n raise ValueError(\"Invalid package definition.\")\n\n return package\n\n def _get_pool(self) -> RepositoryPool:\n from poetry.config.config import Config\n from poetry.repositories import RepositoryPool\n from poetry.repositories.pypi_repository import PyPiRepository\n\n if isinstance(self, EnvCommand):\n return self.poetry.pool\n\n if self._pool is None:\n self._pool = RepositoryPool()\n pool_size = Config.create().installer_max_workers\n self._pool.add_repository(PyPiRepository(pool_size=pool_size))\n\n return self._pool\n\n# Path: src/poetry/console/commands/install.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass InstallCommand(InstallerCommand):\n name = \"install\"\n description = \"Installs the project dependencies.\"\n\n options: ClassVar[list[Option]] = [\n *InstallerCommand._group_dependency_options(),\n option(\n \"no-dev\",\n None,\n \"Do not install the development dependencies.\"\n \" (Deprecated)\",\n ),\n option(\n \"sync\",\n None,\n \"Synchronize the environment with the locked packages and the specified\"\n \" groups.\",\n ),\n option(\n \"no-root\", None, \"Do not install the root package (the current project).\"\n ),\n option(\n \"no-directory\",\n None,\n \"Do not install any directory path dependencies; useful to install\"\n \" dependencies without source code, e.g. 
for caching of Docker layers)\",\n flag=True,\n multiple=False,\n ),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n option(\n \"remove-untracked\",\n None,\n \"Removes packages not present in the lock file.\"\n \" (Deprecated)\",\n ),\n option(\n \"extras\",\n \"E\",\n \"Extra sets of dependencies to install.\",\n flag=False,\n multiple=True,\n ),\n option(\"all-extras\", None, \"Install all extra dependencies.\"),\n option(\"only-root\", None, \"Exclude all dependencies.\"),\n option(\n \"compile\",\n None,\n \"Compile Python source files to bytecode.\"\n \" (This option has no effect if modern-installation is disabled\"\n \" because the old installer always compiles.)\",\n ),\n ]\n\n help = \"\"\"\\\nThe install command reads the poetry.lock file from\nthe current directory, processes it, and downloads and installs all the\nlibraries and dependencies outlined in that file. If the file does not\nexist it will look for pyproject.toml and do the same.\n\npoetry install\n\nBy default, the above command will also install the current project. To install only the\ndependencies and not including the current project, run the command with the\n--no-root option like below:\n\n poetry install --no-root\n\nIf you want to use Poetry only for dependency management but not for packaging,\nyou can set the \"package-mode\" to false in your pyproject.toml file.\n\"\"\"\n\n _loggers: ClassVar[list[str]] = [\n \"poetry.repositories.pypi_repository\",\n \"poetry.inspection.info\",\n ]\n\n @property\n def activated_groups(self) -> set[str]:\n if self.option(\"only-root\"):\n return set()\n else:\n return super().activated_groups\n\n def handle(self) -> int:\n from poetry.core.masonry.utils.module import ModuleOrPackageNotFound\n\n from poetry.masonry.builders.editable import EditableBuilder\n\n if self.option(\"extras\") and self.option(\"all-extras\"):\n self.line_error(\n \"You cannot specify explicit\"\n \" `--extras` while installing\"\n \" using `--all-extras`.\"\n )\n return 1\n\n if self.option(\"only-root\") and any(\n self.option(key) for key in {\"with\", \"without\", \"only\"}\n ):\n self.line_error(\n \"The `--with`,\"\n \" `--without` and\"\n \" `--only` options cannot be used with\"\n \" the `--only-root`\"\n \" option.\"\n )\n return 1\n\n if self.option(\"only-root\") and self.option(\"no-root\"):\n self.line_error(\n \"You cannot specify `--no-root`\"\n \" when using `--only-root`.\"\n )\n return 1\n\n extras: list[str]\n if self.option(\"all-extras\"):\n extras = list(self.poetry.package.extras.keys())\n else:\n extras = []\n for extra in self.option(\"extras\", []):\n extras += extra.split()\n\n self.installer.extras(extras)\n\n with_synchronization = self.option(\"sync\")\n if self.option(\"remove-untracked\"):\n self.line_error(\n \"The `--remove-untracked` option is\"\n \" deprecated, use the `--sync` option\"\n \" instead.\"\n )\n\n with_synchronization = True\n\n self.installer.only_groups(self.activated_groups)\n self.installer.skip_directory(self.option(\"no-directory\"))\n self.installer.dry_run(self.option(\"dry-run\"))\n self.installer.requires_synchronization(with_synchronization)\n self.installer.executor.enable_bytecode_compilation(self.option(\"compile\"))\n self.installer.verbose(self.io.is_verbose())\n\n return_code = self.installer.run()\n\n if return_code != 0:\n return return_code\n\n if self.option(\"no-root\") or not self.poetry.is_package_mode:\n return 0\n\n log_install = (\n 
\"Installing the current project:\"\n f\" {self.poetry.package.pretty_name}\"\n f\" (<{{tag}}>{self.poetry.package.pretty_version})\"\n )\n overwrite = self.io.output.is_decorated() and not self.io.is_debug()\n self.line(\"\")\n self.write(log_install.format(tag=\"c2\"))\n if not overwrite:\n self.line(\"\")\n\n if self.option(\"dry-run\"):\n self.line(\"\")\n return 0\n\n # Prior to https://github.com/python-poetry/poetry-core/pull/629\n # the existence of a module/package was checked when creating the\n # EditableBuilder. Afterwards, the existence is checked after\n # executing the build script (if there is one),\n # i.e. during EditableBuilder.build().\n try:\n builder = EditableBuilder(self.poetry, self.env, self.io)\n builder.build()\n except (ModuleOrPackageNotFound, FileNotFoundError) as e:\n # This is likely due to the fact that the project is an application\n # not following the structure expected by Poetry.\n # No need for an editable install in this case.\n self.line(\"\")\n self.line_error(\n f\"Warning: The current project could not be installed: {e}\\n\"\n \"If you do not want to install the current project\"\n \" use --no-root.\\n\"\n \"If you want to use Poetry only for dependency management\"\n \" but not for packaging, you can disable package mode by setting\"\n \" package-mode = false in your pyproject.toml file.\\n\"\n \"In a future version of Poetry this warning will become an error!\",\n style=\"warning\",\n )\n return 0\n\n if overwrite:\n self.overwrite(log_install.format(tag=\"success\"))\n self.line(\"\")\n\n return 0\n\n# Path: src/poetry/console/commands/installer_command.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.console.commands.group_command import GroupCommand\n\n\nif TYPE_CHECKING:\n from poetry.installation.installer import Installer\n\n\nclass InstallerCommand(GroupCommand, EnvCommand):\n def __init__(self) -> None:\n # Set in poetry.console.application.Application.configure_installer\n self._installer: Installer | None = None\n\n super().__init__()\n\n def reset_poetry(self) -> None:\n super().reset_poetry()\n\n self.installer.set_package(self.poetry.package)\n self.installer.set_locker(self.poetry.locker)\n\n @property\n def installer(self) -> Installer:\n assert self._installer is not None\n return self._installer\n\n def set_installer(self, installer: Installer) -> None:\n self._installer = installer\n\n# Path: src/poetry/console/commands/lock.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass LockCommand(InstallerCommand):\n name = \"lock\"\n description = \"Locks the project dependencies.\"\n\n options: ClassVar[list[Option]] = [\n option(\n \"no-update\", None, \"Do not update locked versions, only refresh lock file.\"\n ),\n option(\n \"check\",\n None,\n \"Check that the poetry.lock file corresponds to the current\"\n \" version of pyproject.toml. 
(Deprecated) Use\"\n \" poetry check --lock instead.\",\n ),\n ]\n\n help = \"\"\"\nThe lock command reads the pyproject.toml file from the\ncurrent directory, processes it, and locks the dependencies in the\\\n poetry.lock\nfile.\n\npoetry lock\n\"\"\"\n\n loggers: ClassVar[list[str]] = [\"poetry.repositories.pypi_repository\"]\n\n def handle(self) -> int:\n if self.option(\"check\"):\n self.line_error(\n \"poetry lock --check is deprecated, use `poetry\"\n \" check --lock` instead.\"\n )\n if self.poetry.locker.is_locked() and self.poetry.locker.is_fresh():\n self.line(\"poetry.lock is consistent with pyproject.toml.\")\n return 0\n self.line_error(\n \"\"\n \"Error: pyproject.toml changed significantly since poetry.lock was last generated. \"\n \"Run `poetry lock [--no-update]` to fix the lock file.\"\n \"\"\n )\n return 1\n\n self.installer.lock(update=not self.option(\"no-update\"))\n\n return self.installer.run()\n\n# Path: src/poetry/console/commands/new.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.init import InitCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass NewCommand(InitCommand):\n name = \"new\"\n description = \"Creates a new Python project at .\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"path\", \"The path to create the project at.\")\n ]\n options: ClassVar[list[Option]] = [\n option(\n \"interactive\",\n \"i\",\n \"Allow interactive specification of project configuration.\",\n flag=True,\n ),\n option(\"name\", None, \"Set the resulting package name.\", flag=False),\n option(\"src\", None, \"Use the src layout for the project.\"),\n option(\n \"readme\",\n None,\n \"Specify the readme file format. One of md (default) or rst\",\n flag=False,\n ),\n *[\n o\n for o in InitCommand.options\n if o.name\n in {\n \"description\",\n \"author\",\n \"python\",\n \"dependency\",\n \"dev-dependency\",\n \"license\",\n }\n ],\n ]\n\n def handle(self) -> int:\n from pathlib import Path\n\n if self.io.input.option(\"directory\"):\n self.line_error(\n \"--directory only makes sense with existing projects, and will\"\n \" be ignored. You should consider the option --path instead.\"\n )\n\n path = Path(self.argument(\"path\"))\n if not path.is_absolute():\n # we do not use resolve here due to compatibility issues\n # for path.resolve(strict=False)\n path = Path.cwd().joinpath(path)\n\n if path.exists() and list(path.glob(\"*\")):\n # Directory is not empty. 
Aborting.\n raise RuntimeError(\n f\"Destination {path} exists and is not empty\"\n )\n\n return self._init_pyproject(\n project_path=path,\n allow_interactive=self.option(\"interactive\"),\n layout_name=\"src\" if self.option(\"src\") else \"standard\",\n readme_format=self.option(\"readme\") or \"md\",\n )\n\n# Path: src/poetry/console/commands/publish.py\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass PublishCommand(Command):\n name = \"publish\"\n description = \"Publishes a package to a remote repository.\"\n\n options: ClassVar[list[Option]] = [\n option(\n \"repository\", \"r\", \"The repository to publish the package to.\", flag=False\n ),\n option(\"username\", \"u\", \"The username to access the repository.\", flag=False),\n option(\"password\", \"p\", \"The password to access the repository.\", flag=False),\n option(\n \"cert\", None, \"Certificate authority to access the repository.\", flag=False\n ),\n option(\n \"client-cert\",\n None,\n \"Client certificate to access the repository.\",\n flag=False,\n ),\n option(\n \"dist-dir\",\n None,\n \"Dist directory where built artifact are stored. Default is `dist`.\",\n default=\"dist\",\n flag=False,\n ),\n option(\"build\", None, \"Build the package before publishing.\"),\n option(\"dry-run\", None, \"Perform all actions except upload the package.\"),\n option(\n \"skip-existing\",\n None,\n \"Ignore errors from files already existing in the repository.\",\n ),\n ]\n\n help = \"\"\"The publish command builds and uploads the package to a remote repository.\n\nBy default, it will upload to PyPI but if you pass the --repository option it will\nupload to it instead.\n\nThe --repository option should match the name of a configured repository using\nthe config command.\n\"\"\"\n\n loggers: ClassVar[list[str]] = [\"poetry.publishing.publisher\"]\n\n def handle(self) -> int:\n from poetry.publishing.publisher import Publisher\n\n if not self.poetry.is_package_mode:\n self.line_error(\"Publishing a package is not possible in non-package mode.\")\n return 1\n\n dist_dir = self.option(\"dist-dir\")\n\n publisher = Publisher(self.poetry, self.io, Path(dist_dir))\n\n # Building package first, if told\n if self.option(\"build\"):\n if publisher.files and not self.confirm(\n f\"There are {len(publisher.files)} files ready for\"\n \" publishing. Build anyway?\"\n ):\n self.line_error(\"Aborted!\")\n\n return 1\n\n self.call(\"build\", args=f\"--output {dist_dir}\")\n\n files = publisher.files\n if not files:\n self.line_error(\n \"No files to publish. 
\"\n \"Run poetry build first or use the --build option.\"\n )\n\n return 1\n\n self.line(\"\")\n\n cert = Path(self.option(\"cert\")) if self.option(\"cert\") else None\n client_cert = (\n Path(self.option(\"client-cert\")) if self.option(\"client-cert\") else None\n )\n\n publisher.publish(\n self.option(\"repository\"),\n self.option(\"username\"),\n self.option(\"password\"),\n cert,\n client_cert,\n self.option(\"dry-run\"),\n self.option(\"skip-existing\"),\n )\n\n return 0\n\n# Path: src/poetry/console/commands/remove.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.packages.dependency_group import MAIN_GROUP\nfrom tomlkit.toml_document import TOMLDocument\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass RemoveCommand(InstallerCommand):\n name = \"remove\"\n description = \"Removes a package from the project dependencies.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"packages\", \"The packages to remove.\", multiple=True)\n ]\n options: ClassVar[list[Option]] = [\n option(\"group\", \"G\", \"The group to remove the dependency from.\", flag=False),\n option(\n \"dev\",\n \"D\",\n \"Remove a package from the development dependencies.\"\n \" (Deprecated)\"\n \" Use --group=dev instead.\",\n ),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n option(\"lock\", None, \"Do not perform operations (only update the lockfile).\"),\n ]\n\n help = \"\"\"The remove command removes a package from the current\nlist of installed packages\n\npoetry remove\"\"\"\n\n loggers: ClassVar[list[str]] = [\n \"poetry.repositories.pypi_repository\",\n \"poetry.inspection.info\",\n ]\n\n def handle(self) -> int:\n packages = self.argument(\"packages\")\n\n if self.option(\"dev\"):\n self.line_error(\n \"The --dev option is deprecated, \"\n \"use the `--group dev` notation instead.\"\n )\n group = \"dev\"\n else:\n group = self.option(\"group\", self.default_group)\n\n content: dict[str, Any] = self.poetry.file.read()\n poetry_content = content[\"tool\"][\"poetry\"]\n\n if group is None:\n removed = []\n group_sections = [\n (group_name, group_section.get(\"dependencies\", {}))\n for group_name, group_section in poetry_content.get(\"group\", {}).items()\n ]\n\n for group_name, section in [\n (MAIN_GROUP, poetry_content[\"dependencies\"]),\n *group_sections,\n ]:\n removed += self._remove_packages(packages, section, group_name)\n if group_name != MAIN_GROUP:\n if not section:\n del poetry_content[\"group\"][group_name]\n else:\n poetry_content[\"group\"][group_name][\"dependencies\"] = section\n elif group == \"dev\" and \"dev-dependencies\" in poetry_content:\n # We need to account for the old `dev-dependencies` section\n removed = self._remove_packages(\n packages, poetry_content[\"dev-dependencies\"], \"dev\"\n )\n\n if not poetry_content[\"dev-dependencies\"]:\n del poetry_content[\"dev-dependencies\"]\n else:\n removed = []\n if \"group\" in poetry_content:\n if group in poetry_content[\"group\"]:\n removed = self._remove_packages(\n packages,\n poetry_content[\"group\"][group].get(\"dependencies\", {}),\n group,\n )\n\n if not 
poetry_content[\"group\"][group]:\n del poetry_content[\"group\"][group]\n\n if \"group\" in poetry_content and not poetry_content[\"group\"]:\n del poetry_content[\"group\"]\n\n removed_set = set(removed)\n not_found = set(packages).difference(removed_set)\n if not_found:\n raise ValueError(\n \"The following packages were not found: \" + \", \".join(sorted(not_found))\n )\n\n # Refresh the locker\n self.poetry.locker.set_local_config(poetry_content)\n self.installer.set_locker(self.poetry.locker)\n self.installer.set_package(self.poetry.package)\n self.installer.dry_run(self.option(\"dry-run\", False))\n self.installer.verbose(self.io.is_verbose())\n self.installer.update(True)\n self.installer.execute_operations(not self.option(\"lock\"))\n self.installer.whitelist(removed_set)\n\n status = self.installer.run()\n\n if not self.option(\"dry-run\") and status == 0:\n assert isinstance(content, TOMLDocument)\n self.poetry.file.write(content)\n\n return status\n\n def _remove_packages(\n self, packages: list[str], section: dict[str, Any], group_name: str\n ) -> list[str]:\n removed = []\n group = self.poetry.package.dependency_group(group_name)\n section_keys = list(section.keys())\n\n for package in packages:\n for existing_package in section_keys:\n if canonicalize_name(existing_package) == canonicalize_name(package):\n del section[existing_package]\n removed.append(package)\n group.remove_dependency(package)\n\n return removed\n\n# Path: src/poetry/console/commands/run.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.utils._compat import WINDOWS\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from poetry.core.masonry.utils.module import Module\n\n\nclass RunCommand(EnvCommand):\n name = \"run\"\n description = \"Runs a command in the appropriate environment.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"args\", \"The command and arguments/options to run.\", multiple=True)\n ]\n\n def handle(self) -> int:\n args = self.argument(\"args\")\n script = args[0]\n scripts = self.poetry.local_config.get(\"scripts\")\n\n if scripts and script in scripts:\n return self.run_script(scripts[script], args)\n\n try:\n return self.env.execute(*args)\n except FileNotFoundError:\n self.line_error(f\"Command not found: {script}\")\n return 1\n\n @property\n def _module(self) -> Module:\n from poetry.core.masonry.utils.module import Module\n\n poetry = self.poetry\n package = poetry.package\n path = poetry.file.path.parent\n module = Module(package.name, path.as_posix(), package.packages)\n\n return module\n\n def run_script(self, script: str | dict[str, str], args: list[str]) -> int:\n \"\"\"Runs an entry point script defined in the section ``[tool.poetry.scripts]``.\n\n When a script exists in the venv bin folder, i.e. after ``poetry install``,\n then ``sys.argv[0]`` must be set to the full path of the executable, so\n ``poetry run foo`` and ``poetry shell``, ``foo`` have the same ``sys.argv[0]``\n that points to the full path.\n\n Otherwise (when an entry point script does not exist), ``sys.argv[0]`` is the\n script name only, i.e. 
``poetry run foo`` has ``sys.argv == ['foo']``.\n \"\"\"\n for script_dir in self.env.script_dirs:\n script_path = script_dir / args[0]\n if WINDOWS:\n script_path = script_path.with_suffix(\".cmd\")\n if script_path.exists():\n args = [str(script_path), *args[1:]]\n break\n else:\n # If we reach this point, the script is not installed\n self._warning_not_installed_script(args[0])\n\n if isinstance(script, dict):\n script = script[\"callable\"]\n\n module, callable_ = script.split(\":\")\n\n src_in_sys_path = \"sys.path.append('src'); \" if self._module.is_in_src() else \"\"\n\n cmd = [\"python\", \"-c\"]\n\n cmd += [\n \"import sys; \"\n \"from importlib import import_module; \"\n f\"sys.argv = {args!r}; {src_in_sys_path}\"\n f\"sys.exit(import_module('{module}').{callable_}())\"\n ]\n\n return self.env.execute(*cmd)\n\n def _warning_not_installed_script(self, script: str) -> None:\n message = f\"\"\"\\\nWarning: '{script}' is an entry point defined in pyproject.toml, but it's not \\\ninstalled as a script. You may get improper `sys.argv[0]`.\n\nThe support to run uninstalled scripts will be removed in a future release.\n\nRun `poetry install` to resolve and get rid of this message.\n\"\"\"\n self.line_error(message, style=\"warning\")\n\n# Path: src/poetry/console/commands/search.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n\n\nclass SearchCommand(Command):\n name = \"search\"\n description = \"Searches for packages on remote repositories.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"tokens\", \"The tokens to search for.\", multiple=True)\n ]\n\n...\n", "needle_token_start": 3886, "needle_token_end": 4116, "code_context_ntokens": 16018, "output": ["```python\ndef _validate_group_options(self, group_options: dict[str, set[str]]) -> None:\n \"\"\"\n Raises an error if it detects that a group is not part of pyproject.toml\n \"\"\"\n invalid_options = defaultdict(set)\n for opt, groups in group_options.items():\n for group in groups:\n if not self.poetry.package.has_dependency_group(group):\n invalid_options[group].add(opt)\n if invalid_options:\n message_parts = []\n for group in sorted(invalid_options):\n opts = \", \".join(\n f\"--{opt}\"\n for opt in sorted(invalid_options[group])\n )\n message_parts.append(f\"{group} (via {opts})\")\n raise GroupNotFound(f\"Group(s) not found: {', '.join(message_parts)}\")\n```"]} +{"repo": "python-poetry/poetry", "name": "candidates", "language": "python", "path": "src/poetry/utils/env/site_packages.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To compile a list of potential directories where Python packages can be installed.\n2. **Input**: Directories specified for pure Python libraries and platform-specific libraries, along with optional fallback directories.\n3. **Output**: A list of unique directory paths that are considered as potential installation targets.\n4. **Procedure**: The method combines the specified directories for pure and platform-specific libraries with any provided fallback directories. 
It ensures each directory is unique in the resulting list to avoid duplication.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/utils/env/env_manager.py\nfrom __future__ import annotations\n\nimport base64\nimport hashlib\nimport os\nimport plistlib\nimport re\nimport shutil\nimport subprocess\nimport sys\n\nfrom functools import cached_property\nfrom pathlib import Path\nfrom subprocess import CalledProcessError\nfrom typing import TYPE_CHECKING\n\nimport tomlkit\nimport virtualenv\n\nfrom cleo.io.null_io import NullIO\nfrom cleo.io.outputs.output import Verbosity\nfrom poetry.core.constraints.version import Version\nfrom poetry.core.constraints.version import parse_constraint\n\nfrom poetry.toml.file import TOMLFile\nfrom poetry.utils._compat import WINDOWS\nfrom poetry.utils._compat import encode\nfrom poetry.utils.env.exceptions import EnvCommandError\nfrom poetry.utils.env.exceptions import IncorrectEnvError\nfrom poetry.utils.env.exceptions import InvalidCurrentPythonVersionError\nfrom poetry.utils.env.exceptions import NoCompatiblePythonVersionFound\nfrom poetry.utils.env.exceptions import PythonVersionNotFound\nfrom poetry.utils.env.generic_env import GenericEnv\nfrom poetry.utils.env.script_strings import GET_ENV_PATH_ONELINER\nfrom poetry.utils.env.script_strings import GET_PYTHON_VERSION_ONELINER\nfrom poetry.utils.env.system_env import SystemEnv\nfrom poetry.utils.env.virtual_env import VirtualEnv\nfrom poetry.utils.helpers import get_real_windows_path\nfrom poetry.utils.helpers import remove_directory\n\n\nif TYPE_CHECKING:\n from cleo.io.io import IO\n\n from poetry.poetry import Poetry\n from poetry.utils.env.base_env import Env\n\n\nclass EnvsFile(TOMLFile):\n \"\"\"\n This file contains one section per project with the project's base env name\n as section name. 
Each section contains the minor and patch version of the\n python executable used to create the currently active virtualenv.\n\n Example:\n\n [poetry-QRErDmmj]\n minor = \"3.9\"\n patch = \"3.9.13\"\n\n [poetry-core-m5r7DkRA]\n minor = \"3.11\"\n patch = \"3.11.6\"\n \"\"\"\n\n def remove_section(self, name: str, minor: str | None = None) -> str | None:\n \"\"\"\n Remove a section from the envs file.\n\n If \"minor\" is given, the section is only removed if its minor value\n matches \"minor\".\n\n Returns the \"minor\" value of the removed section.\n \"\"\"\n envs = self.read()\n current_env = envs.get(name)\n if current_env is not None and (not minor or current_env[\"minor\"] == minor):\n del envs[name]\n self.write(envs)\n minor = current_env[\"minor\"]\n assert isinstance(minor, str)\n return minor\n\n return None\n\n\nclass EnvManager:\n \"\"\"\n Environments manager\n \"\"\"\n\n _env = None\n\n ENVS_FILE = \"envs.toml\"\n\n def __init__(self, poetry: Poetry, io: None | IO = None) -> None:\n self._poetry = poetry\n self._io = io or NullIO()\n\n @staticmethod\n def _full_python_path(python: str) -> Path | None:\n # eg first find pythonXY.bat on windows.\n path_python = shutil.which(python)\n if path_python is None:\n return None\n\n try:\n executable = subprocess.check_output(\n [path_python, \"-c\", \"import sys; print(sys.executable)\"], text=True\n ).strip()\n return Path(executable)\n\n except CalledProcessError:\n return None\n\n @staticmethod\n def _detect_active_python(io: None | IO = None) -> Path | None:\n io = io or NullIO()\n io.write_error_line(\n \"Trying to detect current active python executable as specified in\"\n \" the config.\",\n verbosity=Verbosity.VERBOSE,\n )\n\n executable = EnvManager._full_python_path(\"python\")\n\n if executable is not None:\n io.write_error_line(f\"Found: {executable}\", verbosity=Verbosity.VERBOSE)\n else:\n io.write_error_line(\n \"Unable to detect the current active python executable. 
Falling\"\n \" back to default.\",\n verbosity=Verbosity.VERBOSE,\n )\n\n return executable\n\n @staticmethod\n def get_python_version(\n precision: int = 3,\n prefer_active_python: bool = False,\n...\n# Path: src/poetry/utils/env/exceptions.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.utils._compat import decode\n\n\nif TYPE_CHECKING:\n from subprocess import CalledProcessError\n\n\nclass EnvError(Exception):\n pass\n\n\nclass IncorrectEnvError(EnvError):\n def __init__(self, env_name: str) -> None:\n message = f\"Env {env_name} doesn't belong to this project.\"\n super().__init__(message)\n\n\nclass EnvCommandError(EnvError):\n def __init__(self, e: CalledProcessError) -> None:\n self.e = e\n\n message_parts = [\n f\"Command {e.cmd} errored with the following return code {e.returncode}\"\n ]\n if e.output:\n message_parts.append(f\"Output:\\n{decode(e.output)}\")\n if e.stderr:\n message_parts.append(f\"Error output:\\n{decode(e.stderr)}\")\n super().__init__(\"\\n\\n\".join(message_parts))\n\n\nclass PythonVersionNotFound(EnvError):\n def __init__(self, expected: str) -> None:\n super().__init__(f\"Could not find the python executable {expected}\")\n\n\nclass NoCompatiblePythonVersionFound(EnvError):\n def __init__(self, expected: str, given: str | None = None) -> None:\n if given:\n message = (\n f\"The specified Python version ({given}) \"\n f\"is not supported by the project ({expected}).\\n\"\n \"Please choose a compatible version \"\n \"or loosen the python constraint specified \"\n \"in the pyproject.toml file.\"\n )\n else:\n message = (\n \"Poetry was unable to find a compatible version. \"\n \"If you have one, you can explicitly use it \"\n 'via the \"env use\" command.'\n )\n\n super().__init__(message)\n\n\nclass InvalidCurrentPythonVersionError(EnvError):\n def __init__(self, expected: str, given: str) -> None:\n message = (\n f\"Current Python version ({given}) \"\n f\"is not allowed by the project ({expected}).\\n\"\n 'Please change python executable via the \"env use\" command.'\n )\n\n super().__init__(message)\n\n# Path: src/poetry/utils/env/generic_env.py\nfrom __future__ import annotations\n\nimport json\nimport os\nimport re\nimport subprocess\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom poetry.utils.env.script_strings import GET_PATHS_FOR_GENERIC_ENVS\nfrom poetry.utils.env.virtual_env import VirtualEnv\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n\n from poetry.utils.env.base_env import Env\n\n\nclass GenericEnv(VirtualEnv):\n def __init__(\n self, path: Path, base: Path | None = None, child_env: Env | None = None\n ) -> None:\n self._child_env = child_env\n\n super().__init__(path, base=base)\n\n def find_executables(self) -> None:\n patterns = [(\"python*\", \"pip*\")]\n\n if self._child_env:\n minor_version = (\n f\"{self._child_env.version_info[0]}.{self._child_env.version_info[1]}\"\n )\n major_version = f\"{self._child_env.version_info[0]}\"\n patterns = [\n (f\"python{minor_version}\", f\"pip{minor_version}\"),\n (f\"python{major_version}\", f\"pip{major_version}\"),\n ]\n\n python_executable = None\n pip_executable = None\n\n for python_pattern, pip_pattern in patterns:\n if python_executable and pip_executable:\n break\n\n if not python_executable:\n python_executables = sorted(\n p.name\n for p in self._bin_dir.glob(python_pattern)\n if re.match(r\"python(?:\\d+(?:\\.\\d+)?)?(?:\\.exe)?$\", p.name)\n )\n\n if python_executables:\n executable = python_executables[0]\n if 
executable.endswith(\".exe\"):\n executable = executable[:-4]\n\n python_executable = executable\n\n if not pip_executable:\n pip_executables = sorted(\n p.name\n for p in self._bin_dir.glob(pip_pattern)\n if re.match(r\"pip(?:\\d+(?:\\.\\d+)?)?(?:\\.exe)?$\", p.name)\n )\n if pip_executables:\n pip_executable = pip_executables[0]\n if pip_executable.endswith(\".exe\"):\n pip_executable = pip_executable[:-4]\n\n if python_executable:\n self._executable = python_executable\n\n if pip_executable:\n self._pip_executable = pip_executable\n\n def get_paths(self) -> dict[str, str]:\n output = self.run_python_script(GET_PATHS_FOR_GENERIC_ENVS)\n\n paths: dict[str, str] = json.loads(output)\n return paths\n\n def execute(self, bin: str, *args: str, **kwargs: Any) -> int:\n command = self.get_command_from_bin(bin) + list(args)\n env = kwargs.pop(\"env\", dict(os.environ))\n\n if not self._is_windows:\n return os.execvpe(command[0], command, env=env)\n\n exe = subprocess.Popen(command, env=env, **kwargs)\n exe.communicate()\n\n return exe.returncode\n\n def _run(self, cmd: list[str], **kwargs: Any) -> str:\n return super(VirtualEnv, self)._run(cmd, **kwargs)\n\n def is_venv(self) -> bool:\n return self._path != self._base\n\n# Path: src/poetry/utils/env/mock_env.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom poetry.utils.env.null_env import NullEnv\n\n\nif TYPE_CHECKING:\n from packaging.tags import Tag\n\n\nclass MockEnv(NullEnv):\n def __init__(\n self,\n version_info: tuple[int, int, int] = (3, 7, 0),\n *,\n python_implementation: str = \"CPython\",\n platform: str = \"darwin\",\n platform_machine: str = \"amd64\",\n os_name: str = \"posix\",\n is_venv: bool = False,\n sys_path: list[str] | None = None,\n marker_env: dict[str, Any] | None = None,\n supported_tags: list[Tag] | None = None,\n **kwargs: Any,\n ) -> None:\n super().__init__(**kwargs)\n\n self._version_info = version_info\n self._python_implementation = python_implementation\n self._platform = platform\n self._platform_machine = platform_machine\n self._os_name = os_name\n self._is_venv = is_venv\n self._sys_path = sys_path\n self._mock_marker_env = marker_env\n self._supported_tags = supported_tags\n\n @property\n def platform(self) -> str:\n return self._platform\n\n @property\n def platform_machine(self) -> str:\n return self._platform_machine\n\n @property\n def os(self) -> str:\n return self._os_name\n\n @property\n def sys_path(self) -> list[str]:\n if self._sys_path is None:\n return super().sys_path\n\n return self._sys_path\n\n def get_marker_env(self) -> dict[str, Any]:\n if self._mock_marker_env is not None:\n return self._mock_marker_env\n\n marker_env = super().get_marker_env()\n marker_env[\"python_implementation\"] = self._python_implementation\n marker_env[\"version_info\"] = self._version_info\n marker_env[\"python_version\"] = \".\".join(str(v) for v in self._version_info[:2])\n marker_env[\"python_full_version\"] = \".\".join(str(v) for v in self._version_info)\n marker_env[\"sys_platform\"] = self._platform\n marker_env[\"platform_machine\"] = self._platform_machine\n marker_env[\"interpreter_name\"] = self._python_implementation.lower()\n marker_env[\"interpreter_version\"] = \"cp\" + \"\".join(\n str(v) for v in self._version_info[:2]\n )\n\n return marker_env\n\n def is_venv(self) -> bool:\n return self._is_venv\n\n# Path: src/poetry/utils/env/null_env.py\nfrom __future__ import annotations\n\nimport sys\n\nfrom pathlib import Path\nfrom typing import 
Any\n\nfrom poetry.utils.env.system_env import SystemEnv\n\n\nclass NullEnv(SystemEnv):\n def __init__(\n self, path: Path | None = None, base: Path | None = None, execute: bool = False\n ) -> None:\n if path is None:\n path = Path(sys.prefix)\n\n super().__init__(path, base=base)\n\n self._execute = execute\n self.executed: list[list[str]] = []\n\n @property\n def paths(self) -> dict[str, str]:\n if self._paths is None:\n self._paths = self.get_paths()\n self._paths[\"platlib\"] = str(self._path / \"platlib\")\n self._paths[\"purelib\"] = str(self._path / \"purelib\")\n self._paths[\"scripts\"] = str(self._path / \"scripts\")\n self._paths[\"data\"] = str(self._path / \"data\")\n\n return self._paths\n\n def _run(self, cmd: list[str], **kwargs: Any) -> str:\n self.executed.append(cmd)\n\n if self._execute:\n return super()._run(cmd, **kwargs)\n return \"\"\n\n def execute(self, bin: str, *args: str, **kwargs: Any) -> int:\n self.executed.append([bin, *list(args)])\n\n if self._execute:\n return super().execute(bin, *args, **kwargs)\n return 0\n\n def _bin(self, bin: str) -> str:\n return bin\n\n# Path: src/poetry/utils/env/script_strings.py\nfrom __future__ import annotations\n\nimport packaging.tags\n\n\nGET_SYS_TAGS = f\"\"\"\nimport importlib.util\nimport json\nimport sys\n\nfrom pathlib import Path\n\nspec = importlib.util.spec_from_file_location(\n \"packaging\", Path(r\"{packaging.__file__}\")\n)\npackaging = importlib.util.module_from_spec(spec)\nsys.modules[spec.name] = packaging\n\nspec = importlib.util.spec_from_file_location(\n \"packaging.tags\", Path(r\"{packaging.tags.__file__}\")\n)\npackaging_tags = importlib.util.module_from_spec(spec)\nspec.loader.exec_module(packaging_tags)\n\nprint(\n json.dumps([(t.interpreter, t.abi, t.platform) for t in packaging_tags.sys_tags()])\n)\n\"\"\"\n\nGET_ENVIRONMENT_INFO = \"\"\"\\\nimport json\nimport os\nimport platform\nimport sys\nimport sysconfig\n\nINTERPRETER_SHORT_NAMES = {\n \"python\": \"py\",\n \"cpython\": \"cp\",\n \"pypy\": \"pp\",\n \"ironpython\": \"ip\",\n \"jython\": \"jy\",\n}\n\n\ndef interpreter_version():\n version = sysconfig.get_config_var(\"interpreter_version\")\n if version:\n version = str(version)\n else:\n version = _version_nodot(sys.version_info[:2])\n\n return version\n\n\ndef _version_nodot(version):\n if any(v >= 10 for v in version):\n sep = \"_\"\n else:\n sep = \"\"\n\n return sep.join(map(str, version))\n\n\nif hasattr(sys, \"implementation\"):\n info = sys.implementation.version\n iver = \"{0.major}.{0.minor}.{0.micro}\".format(info)\n kind = info.releaselevel\n if kind != \"final\":\n iver += kind[0] + str(info.serial)\n\n implementation_name = sys.implementation.name\nelse:\n iver = \"0\"\n implementation_name = platform.python_implementation().lower()\n\nenv = {\n \"implementation_name\": implementation_name,\n \"implementation_version\": iver,\n \"os_name\": os.name,\n \"platform_machine\": platform.machine(),\n \"platform_release\": platform.release(),\n \"platform_system\": platform.system(),\n \"platform_version\": platform.version(),\n \"python_full_version\": platform.python_version(),\n \"platform_python_implementation\": platform.python_implementation(),\n \"python_version\": \".\".join(platform.python_version_tuple()[:2]),\n \"sys_platform\": sys.platform,\n \"version_info\": tuple(sys.version_info),\n # Extra information\n \"interpreter_name\": INTERPRETER_SHORT_NAMES.get(\n implementation_name, implementation_name\n ),\n \"interpreter_version\": 
interpreter_version(),\n}\n\nprint(json.dumps(env))\n\"\"\"\n\nGET_BASE_PREFIX = \"\"\"\\\nimport sys\n\nif hasattr(sys, \"real_prefix\"):\n print(sys.real_prefix)\nelif hasattr(sys, \"base_prefix\"):\n print(sys.base_prefix)\nelse:\n print(sys.prefix)\n\"\"\"\n\nGET_PYTHON_VERSION = \"\"\"\\\nimport sys\n\nprint('.'.join([str(s) for s in sys.version_info[:3]]))\n\"\"\"\n\nGET_PYTHON_VERSION_ONELINER = (\n \"import sys; print('.'.join([str(s) for s in sys.version_info[:3]]))\"\n)\nGET_ENV_PATH_ONELINER = \"import sys; print(sys.prefix)\"\n\nGET_SYS_PATH = \"\"\"\\\nimport json\nimport sys\n\nprint(json.dumps(sys.path))\n\"\"\"\n\nGET_PATHS = \"\"\"\\\nimport json\nimport sysconfig\n\nprint(json.dumps(sysconfig.get_paths()))\n\"\"\"\n\nGET_PATHS_FOR_GENERIC_ENVS = \"\"\"\\\nimport json\nimport site\nimport sysconfig\n\npaths = sysconfig.get_paths().copy()\n\nif site.check_enableusersite():\n paths[\"usersite\"] = site.getusersitepackages()\n paths[\"userbase\"] = site.getuserbase()\n\nprint(json.dumps(paths))\n\"\"\"\n\n# Path: src/poetry/utils/env/site_packages.py\nfrom __future__ import annotations\n\nimport contextlib\nimport itertools\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom poetry.utils._compat import metadata\nfrom poetry.utils.helpers import is_dir_writable\nfrom poetry.utils.helpers import paths_csv\nfrom poetry.utils.helpers import remove_directory\n\n\nif TYPE_CHECKING:\n from collections.abc import Iterable\n\n\nclass SitePackages:\n def __init__(\n self,\n purelib: Path,\n platlib: Path | None = None,\n fallbacks: list[Path] | None = None,\n skip_write_checks: bool = False,\n ) -> None:\n self._purelib = purelib\n self._platlib = platlib or purelib\n\n if platlib and platlib.resolve() == purelib.resolve():\n self._platlib = purelib\n\n self._fallbacks = fallbacks or []\n self._skip_write_checks = skip_write_checks\n\n self._candidates: list[Path] = []\n for path in itertools.chain([self._purelib, self._platlib], self._fallbacks):\n if path not in self._candidates:\n self._candidates.append(path)\n\n self._writable_candidates = None if not skip_write_checks else self._candidates\n\n @property\n def path(self) -> Path:\n return self._purelib\n\n @property\n def purelib(self) -> Path:\n return self._purelib\n\n @property\n def platlib(self) -> Path:\n return self._platlib\n\n @property\n \ndef candidates(self) -> list[Path]:\n return self._candidates\n\n @property\n def writable_candidates(self) -> list[Path]:\n if self._writable_candidates is not None:\n return self._writable_candidates\n\n self._writable_candidates = []\n for candidate in self._candidates:\n if not is_dir_writable(path=candidate, create=True):\n continue\n self._writable_candidates.append(candidate)\n\n return self._writable_candidates\n\n def make_candidates(\n self, path: Path, writable_only: bool = False, strict: bool = False\n ) -> list[Path]:\n candidates = self._candidates if not writable_only else self.writable_candidates\n if path.is_absolute():\n for candidate in candidates:\n with contextlib.suppress(ValueError):\n path.relative_to(candidate)\n return [path]\n site_type = \"writable \" if writable_only else \"\"\n raise ValueError(\n f\"{path} is not relative to any discovered {site_type}sites\"\n )\n\n results = [candidate / path for candidate in candidates]\n\n if not results and strict:\n raise RuntimeError(\n f'Unable to find a suitable destination for \"{path}\" in'\n f\" {paths_csv(self._candidates)}\"\n )\n\n return results\n\n def 
distributions(\n self, name: str | None = None, writable_only: bool = False\n ) -> Iterable[metadata.Distribution]:\n path = list(\n map(\n str, self._candidates if not writable_only else self.writable_candidates\n )\n )\n\n yield from metadata.PathDistribution.discover(name=name, path=path)\n\n def find_distribution(\n self, name: str, writable_only: bool = False\n ) -> metadata.Distribution | None:\n for distribution in self.distributions(name=name, writable_only=writable_only):\n return distribution\n return None\n\n def find_distribution_files_with_suffix(\n self, distribution_name: str, suffix: str, writable_only: bool = False\n ) -> Iterable[Path]:\n for distribution in self.distributions(\n name=distribution_name, writable_only=writable_only\n ):\n files = [] if distribution.files is None else distribution.files\n for file in files:\n if file.name.endswith(suffix):\n path = distribution.locate_file(file)\n assert isinstance(path, Path)\n yield path\n\n def find_distribution_files_with_name(\n self, distribution_name: str, name: str, writable_only: bool = False\n ) -> Iterable[Path]:\n for distribution in self.distributions(\n name=distribution_name, writable_only=writable_only\n ):\n files = [] if distribution.files is None else distribution.files\n for file in files:\n if file.name == name:\n path = distribution.locate_file(file)\n assert isinstance(path, Path)\n yield path\n\n def find_distribution_direct_url_json_files(\n self, distribution_name: str, writable_only: bool = False\n ) -> Iterable[Path]:\n return self.find_distribution_files_with_name(\n distribution_name=distribution_name,\n name=\"direct_url.json\",\n writable_only=writable_only,\n )\n\n def remove_distribution_files(self, distribution_name: str) -> list[Path]:\n paths = []\n\n for distribution in self.distributions(\n name=distribution_name, writable_only=True\n ):\n files = [] if distribution.files is None else distribution.files\n for file in files:\n path = distribution.locate_file(file)\n assert isinstance(path, Path)\n path.unlink(missing_ok=True)\n\n distribution_path: Path = distribution._path # type: ignore[attr-defined]\n if distribution_path.exists():\n remove_directory(distribution_path, force=True)\n\n paths.append(distribution_path)\n\n return paths\n\n def _path_method_wrapper(\n self,\n path: Path,\n method: str,\n *args: Any,\n return_first: bool = True,\n writable_only: bool = False,\n **kwargs: Any,\n ) -> tuple[Path, Any] | list[tuple[Path, Any]]:\n candidates = self.make_candidates(\n path, writable_only=writable_only, strict=True\n )\n\n results = []\n\n for candidate in candidates:\n try:\n result = candidate, getattr(candidate, method)(*args, **kwargs)\n if return_first:\n return result\n results.append(result)\n except OSError:\n # TODO: Replace with PermissionError\n pass\n\n if results:\n return results\n\n raise OSError(f\"Unable to access any of {paths_csv(candidates)}\")\n\n def write_text(self, path: Path, *args: Any, **kwargs: Any) -> Path:\n paths = self._path_method_wrapper(path, \"write_text\", *args, **kwargs)\n assert isinstance(paths, tuple)\n return paths[0]\n\n def mkdir(self, path: Path, *args: Any, **kwargs: Any) -> Path:\n paths = self._path_method_wrapper(path, \"mkdir\", *args, **kwargs)\n assert isinstance(paths, tuple)\n return paths[0]\n\n def exists(self, path: Path) -> bool:\n return any(\n value[-1]\n for value in self._path_method_wrapper(path, \"exists\", return_first=False)\n )\n\n def find(\n self,\n path: Path,\n writable_only: bool = False,\n ) -> 
list[Path]:\n return [\n value[0]\n for value in self._path_method_wrapper(\n path, \"exists\", return_first=False, writable_only=writable_only\n )\n if value[-1] is True\n ]\n\n# Path: src/poetry/utils/env/system_env.py\nfrom __future__ import annotations\n\nimport os\nimport platform\nimport site\nimport sys\nimport sysconfig\n\nfrom pathlib import Path\nfrom typing import Any\n\nfrom packaging.tags import Tag\nfrom packaging.tags import interpreter_name\nfrom packaging.tags import interpreter_version\nfrom packaging.tags import sys_tags\n\nfrom poetry.utils.env.base_env import Env\n\n\nclass SystemEnv(Env):\n \"\"\"\n A system (i.e. not a virtualenv) Python environment.\n \"\"\"\n\n @property\n def python(self) -> Path:\n return Path(sys.executable)\n\n @property\n def sys_path(self) -> list[str]:\n return sys.path\n\n def get_version_info(self) -> tuple[Any, ...]:\n return tuple(sys.version_info)\n\n def get_python_implementation(self) -> str:\n return platform.python_implementation()\n\n def get_paths(self) -> dict[str, str]:\n import site\n\n paths = sysconfig.get_paths().copy()\n\n if site.check_enableusersite():\n paths[\"usersite\"] = site.getusersitepackages()\n paths[\"userbase\"] = site.getuserbase()\n\n return paths\n\n def get_supported_tags(self) -> list[Tag]:\n return list(sys_tags())\n\n def get_marker_env(self) -> dict[str, Any]:\n if hasattr(sys, \"implementation\"):\n info = sys.implementation.version\n iver = f\"{info.major}.{info.minor}.{info.micro}\"\n kind = info.releaselevel\n if kind != \"final\":\n iver += kind[0] + str(info.serial)\n\n implementation_name = sys.implementation.name\n else:\n iver = \"0\"\n implementation_name = \"\"\n\n return {\n \"implementation_name\": implementation_name,\n \"implementation_version\": iver,\n \"os_name\": os.name,\n \"platform_machine\": platform.machine(),\n \"platform_release\": platform.release(),\n \"platform_system\": platform.system(),\n \"platform_version\": platform.version(),\n \"python_full_version\": platform.python_version(),\n \"platform_python_implementation\": platform.python_implementation(),\n \"python_version\": \".\".join(platform.python_version().split(\".\")[:2]),\n \"sys_platform\": sys.platform,\n \"version_info\": sys.version_info,\n \"interpreter_name\": interpreter_name(),\n \"interpreter_version\": interpreter_version(),\n }\n\n def is_venv(self) -> bool:\n return self._path != self._base\n\n def _get_lib_dirs(self) -> list[Path]:\n return super()._get_lib_dirs() + [Path(d) for d in site.getsitepackages()]\n\n# Path: src/poetry/utils/env/virtual_env.py\nfrom __future__ import annotations\n\nimport json\nimport os\nimport re\nimport sys\n\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom functools import cached_property\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom packaging.tags import Tag\n\nfrom poetry.utils.env.base_env import Env\nfrom poetry.utils.env.script_strings import GET_BASE_PREFIX\nfrom poetry.utils.env.script_strings import GET_ENVIRONMENT_INFO\nfrom poetry.utils.env.script_strings import GET_PATHS\nfrom poetry.utils.env.script_strings import GET_PYTHON_VERSION\nfrom poetry.utils.env.script_strings import GET_SYS_PATH\nfrom poetry.utils.env.script_strings import GET_SYS_TAGS\nfrom poetry.utils.env.system_env import SystemEnv\n\n\nif TYPE_CHECKING:\n from collections.abc import Iterator\n\n\nclass VirtualEnv(Env):\n \"\"\"\n A virtual Python environment.\n \"\"\"\n\n def __init__(self, path: Path, base: Path | None = 
None) -> None:\n super().__init__(path, base)\n\n # If base is None, it probably means this is\n # a virtualenv created from VIRTUAL_ENV.\n # In this case we need to get sys.base_prefix\n # from inside the virtualenv.\n if base is None:\n output = self.run_python_script(GET_BASE_PREFIX)\n self._base = Path(output.strip())\n\n @property\n def sys_path(self) -> list[str]:\n output = self.run_python_script(GET_SYS_PATH)\n paths: list[str] = json.loads(output)\n return paths\n\n def get_version_info(self) -> tuple[Any, ...]:\n output = self.run_python_script(GET_PYTHON_VERSION)\n assert isinstance(output, str)\n\n return tuple(int(s) for s in output.strip().split(\".\"))\n\n def get_python_implementation(self) -> str:\n implementation: str = self.marker_env[\"platform_python_implementation\"]\n return implementation\n\n def get_supported_tags(self) -> list[Tag]:\n output = self.run_python_script(GET_SYS_TAGS)\n\n return [Tag(*t) for t in json.loads(output)]\n\n def get_marker_env(self) -> dict[str, Any]:\n output = self.run_python_script(GET_ENVIRONMENT_INFO)\n\n env: dict[str, Any] = json.loads(output)\n return env\n\n def get_paths(self) -> dict[str, str]:\n output = self.run_python_script(GET_PATHS)\n paths: dict[str, str] = json.loads(output)\n return paths\n\n def is_venv(self) -> bool:\n return True\n\n def is_sane(self) -> bool:\n # A virtualenv is considered sane if \"python\" exists.\n return os.path.exists(self.python)\n\n def _run(self, cmd: list[str], **kwargs: Any) -> str:\n kwargs[\"env\"] = self.get_temp_environ(environ=kwargs.get(\"env\"))\n return super()._run(cmd, **kwargs)\n\n def get_temp_environ(\n self,\n environ: dict[str, str] | None = None,\n exclude: list[str] | None = None,\n **kwargs: str,\n ) -> dict[str, str]:\n exclude = exclude or []\n exclude.extend([\"PYTHONHOME\", \"__PYVENV_LAUNCHER__\"])\n\n if environ:\n environ = deepcopy(environ)\n for key in exclude:\n environ.pop(key, None)\n else:\n environ = {k: v for k, v in os.environ.items() if k not in exclude}\n\n environ.update(kwargs)\n\n environ[\"PATH\"] = self._updated_path()\n environ[\"VIRTUAL_ENV\"] = str(self._path)\n\n return environ\n\n def execute(self, bin: str, *args: str, **kwargs: Any) -> int:\n kwargs[\"env\"] = self.get_temp_environ(environ=kwargs.get(\"env\"))\n return super().execute(bin, *args, **kwargs)\n\n @contextmanager\n def temp_environ(self) -> Iterator[None]:\n environ = dict(os.environ)\n try:\n yield\n finally:\n os.environ.clear()\n os.environ.update(environ)\n\n def _updated_path(self) -> str:\n return os.pathsep.join([str(self._bin_dir), os.environ.get(\"PATH\", \"\")])\n\n @cached_property\n def includes_system_site_packages(self) -> bool:\n pyvenv_cfg = self._path / \"pyvenv.cfg\"\n return pyvenv_cfg.exists() and (\n re.search(\n r\"^\\s*include-system-site-packages\\s*=\\s*true\\s*$\",\n pyvenv_cfg.read_text(),\n re.IGNORECASE | re.MULTILINE,\n )\n is not None\n )\n\n def is_path_relative_to_lib(self, path: Path) -> bool:\n return super().is_path_relative_to_lib(path) or (\n self.includes_system_site_packages\n and SystemEnv(Path(sys.prefix)).is_path_relative_to_lib(path)\n )\n\n# Path: src/poetry/vcs/git/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.vcs.git.backend import Git\n\n\n__all__ = [\"Git\"]\n\n# Path: src/poetry/vcs/git/backend.py\nfrom __future__ import annotations\n\nimport dataclasses\nimport logging\nimport re\n\nfrom pathlib import Path\nfrom subprocess import CalledProcessError\nfrom typing import TYPE_CHECKING\nfrom urllib.parse import 
urljoin\nfrom urllib.parse import urlparse\nfrom urllib.parse import urlunparse\n\nfrom dulwich import porcelain\nfrom dulwich.client import HTTPUnauthorized\nfrom dulwich.client import get_transport_and_path\nfrom dulwich.config import ConfigFile\nfrom dulwich.config import parse_submodules\nfrom dulwich.errors import NotGitRepository\nfrom dulwich.index import IndexEntry\nfrom dulwich.refs import ANNOTATED_TAG_SUFFIX\nfrom dulwich.repo import Repo\n\nfrom poetry.console.exceptions import PoetryConsoleError\nfrom poetry.utils.authenticator import get_default_authenticator\nfrom poetry.utils.helpers import remove_directory\n\n\nif TYPE_CHECKING:\n from dulwich.client import FetchPackResult\n from dulwich.client import GitClient\n\n\nlogger = logging.getLogger(__name__)\n\n# A relative URL by definition starts with ../ or ./\nRELATIVE_SUBMODULE_REGEX = re.compile(r\"^\\.{1,2}/\")\n\n\ndef is_revision_sha(revision: str | None) -> bool:\n return re.match(r\"^\\b[0-9a-f]{5,40}\\b$\", revision or \"\") is not None\n\n\ndef annotated_tag(ref: str | bytes) -> bytes:\n if isinstance(ref, str):\n ref = ref.encode(\"utf-8\")\n return ref + ANNOTATED_TAG_SUFFIX\n\n\n@dataclasses.dataclass\nclass GitRefSpec:\n branch: str | None = None\n revision: str | None = None\n tag: str | None = None\n ref: bytes = dataclasses.field(default_factory=lambda: b\"HEAD\")\n\n def resolve(self, remote_refs: FetchPackResult) -> None:\n \"\"\"\n Resolve the ref using the provided remote refs.\n \"\"\"\n self._normalise(remote_refs=remote_refs)\n self._set_head(remote_refs=remote_refs)\n\n def _normalise(self, remote_refs: FetchPackResult) -> None:\n \"\"\"\n Internal helper method to determine if given revision is\n 1. a branch or tag; if so, set corresponding properties.\n 2. a short sha; if so, resolve full sha and set as revision\n \"\"\"\n if self.revision:\n ref = f\"refs/tags/{self.revision}\".encode()\n if ref in remote_refs.refs or annotated_tag(ref) in remote_refs.refs:\n # this is a tag, incorrectly specified as a revision, tags take priority\n self.tag = self.revision\n self.revision = None\n elif (\n self.revision.encode(\"utf-8\") in remote_refs.refs\n or f\"refs/heads/{self.revision}\".encode() in remote_refs.refs\n ):\n # this is most likely a ref spec or a branch incorrectly specified\n self.branch = self.revision\n self.revision = None\n elif (\n self.branch\n and f\"refs/heads/{self.branch}\".encode() not in remote_refs.refs\n and (\n f\"refs/tags/{self.branch}\".encode() in remote_refs.refs\n or annotated_tag(f\"refs/tags/{self.branch}\") in remote_refs.refs\n )\n ):\n # this is a tag incorrectly specified as a branch\n self.tag = self.branch\n self.branch = None\n\n if self.revision and self.is_sha_short:\n # revision is a short sha, resolve to full sha\n short_sha = self.revision.encode(\"utf-8\")\n for sha in remote_refs.refs.values():\n if sha.startswith(short_sha):\n self.revision = sha.decode(\"utf-8\")\n break\n\n def _set_head(self, remote_refs: FetchPackResult) -> None:\n \"\"\"\n Internal helper method to populate ref and set it's sha as the remote's head\n and default ref.\n \"\"\"\n self.ref = remote_refs.symrefs[b\"HEAD\"]\n\n if self.revision:\n head = self.revision.encode(\"utf-8\")\n else:\n if self.tag:\n ref = f\"refs/tags/{self.tag}\".encode()\n annotated = annotated_tag(ref)\n self.ref = annotated if annotated in remote_refs.refs else ref\n elif self.branch:\n self.ref = (\n self.branch.encode(\"utf-8\")\n if self.is_ref\n else f\"refs/heads/{self.branch}\".encode()\n )\n head = 
remote_refs.refs[self.ref]\n\n remote_refs.refs[self.ref] = remote_refs.refs[b\"HEAD\"] = head\n\n @property\n def key(self) -> str:\n return self.revision or self.branch or self.tag or self.ref.decode(\"utf-8\")\n\n @property\n def is_sha(self) -> bool:\n return is_revision_sha(revision=self.revision)\n\n @property\n def is_ref(self) -> bool:\n return self.branch is not None and (\n self.branch.startswith(\"refs/\") or self.branch == \"HEAD\"\n )\n\n @property\n def is_sha_short(self) -> bool:\n return self.revision is not None and self.is_sha and len(self.revision) < 40\n\n\n@dataclasses.dataclass\nclass GitRepoLocalInfo:\n repo: dataclasses.InitVar[Repo | Path]\n origin: str = dataclasses.field(init=False)\n revision: str = dataclasses.field(init=False)\n\n def __post_init__(self, repo: Repo | Path) -> None:\n repo = Git.as_repo(repo=repo) if not isinstance(repo, Repo) else repo\n self.origin = Git.get_remote_url(repo=repo, remote=\"origin\")\n self.revision = Git.get_revision(repo=repo)\n\n\nclass Git:\n @staticmethod\n def as_repo(repo: Path) -> Repo:\n return Repo(str(repo))\n\n @staticmethod\n def get_remote_url(repo: Repo, remote: str = \"origin\") -> str:\n with repo:\n config = repo.get_config()\n section = (b\"remote\", remote.encode(\"utf-8\"))\n\n url = \"\"\n if config.has_section(section):\n value = config.get(section, b\"url\")\n url = value.decode(\"utf-8\")\n\n return url\n\n @staticmethod\n def get_revision(repo: Repo) -> str:\n with repo:\n return repo.head().decode(\"utf-8\")\n\n @classmethod\n def info(cls, repo: Repo | Path) -> GitRepoLocalInfo:\n return GitRepoLocalInfo(repo=repo)\n\n @staticmethod\n def get_name_from_source_url(url: str) -> str:\n return re.sub(r\"(.git)?$\", \"\", url.rsplit(\"/\", 1)[-1])\n\n @classmethod\n def _fetch_remote_refs(cls, url: str, local: Repo) -> FetchPackResult:\n \"\"\"\n Helper method to fetch remote refs.\n \"\"\"\n client: GitClient\n path: str\n\n kwargs: dict[str, str] = {}\n credentials = get_default_authenticator().get_credentials_for_git_url(url=url)\n\n if credentials.password and credentials.username:\n # we do this conditionally as otherwise, dulwich might complain if these\n # parameters are passed in for an ssh url\n kwargs[\"username\"] = credentials.username\n kwargs[\"password\"] = credentials.password\n\n config = local.get_config_stack()\n client, path = get_transport_and_path(url, config=config, **kwargs)\n\n with local:\n result: FetchPackResult = client.fetch(\n path,\n local,\n determine_wants=local.object_store.determine_wants_all,\n )\n return result\n\n @staticmethod\n def _clone_legacy(url: str, refspec: GitRefSpec, target: Path) -> Repo:\n \"\"\"\n Helper method to facilitate fallback to using system provided git client via\n subprocess calls.\n \"\"\"\n from poetry.vcs.git.system import SystemGit\n\n logger.debug(\"Cloning '%s' using system git client\", url)\n\n if target.exists():\n remove_directory(path=target, force=True)\n\n revision = refspec.tag or refspec.branch or refspec.revision or \"HEAD\"\n\n try:\n SystemGit.clone(url, target)\n except CalledProcessError:\n raise PoetryConsoleError(\n f\"Failed to clone {url}, check your git configuration and permissions\"\n \" for this repository.\"\n )\n\n if revision:\n revision.replace(\"refs/head/\", \"\")\n revision.replace(\"refs/tags/\", \"\")\n\n try:\n SystemGit.checkout(revision, target)\n except CalledProcessError:\n raise PoetryConsoleError(f\"Failed to checkout {url} at '{revision}'\")\n\n repo = Repo(str(target))\n return repo\n\n 
@classmethod\n def _clone(cls, url: str, refspec: GitRefSpec, target: Path) -> Repo:\n \"\"\"\n Helper method to clone a remove repository at the given `url` at the specified\n ref spec.\n \"\"\"\n local: Repo\n if not target.exists():\n local = Repo.init(str(target), mkdir=True)\n porcelain.remote_add(local, \"origin\", url)\n else:\n local = Repo(str(target))\n\n remote_refs = cls._fetch_remote_refs(url=url, local=local)\n\n logger.debug(\n \"Cloning %s at '%s' to %s\", url, refspec.key, target\n )\n\n try:\n refspec.resolve(remote_refs=remote_refs)\n except KeyError: # branch / ref does not exist\n raise PoetryConsoleError(\n f\"Failed to clone {url} at '{refspec.key}', verify ref exists on\"\n \" remote.\"\n )\n\n # ensure local HEAD matches remote\n local.refs[b\"HEAD\"] = remote_refs.refs[b\"HEAD\"]\n\n if refspec.is_ref:\n # set ref to current HEAD\n local.refs[refspec.ref] = local.refs[b\"HEAD\"]\n\n for base, prefix in {\n (b\"refs/remotes/origin\", b\"refs/heads/\"),\n (b\"refs/tags\", b\"refs/tags\"),\n }:\n local.refs.import_refs(\n base=base,\n other={\n n[len(prefix) :]: v\n for (n, v) in remote_refs.refs.items()\n if n.startswith(prefix) and not n.endswith(ANNOTATED_TAG_SUFFIX)\n },\n )\n\n try:\n with local:\n local.reset_index()\n except (AssertionError, KeyError) as e:\n # this implies the ref we need does not exist or is invalid\n if isinstance(e, KeyError):\n # the local copy is at a bad state, lets remove it\n logger.debug(\n \"Removing local clone (%s) of repository as it is in a\"\n \" broken state.\",\n local.path,\n )\n remove_directory(Path(local.path), force=True)\n\n if isinstance(e, AssertionError) and \"Invalid object name\" not in str(e):\n raise\n\n logger.debug(\n \"\\nRequested ref (%s) was not fetched to local copy and\"\n \" cannot be used. 
The following error was\"\n \" raised:\\n\\n\\t%s\",\n refspec.key,\n e,\n )\n\n raise PoetryConsoleError(\n f\"Failed to clone {url} at '{refspec.key}', verify ref exists on\"\n \" remote.\"\n )\n\n return local\n\n @classmethod\n def _clone_submodules(cls, repo: Repo) -> None:\n \"\"\"\n Helper method to identify configured submodules and clone them recursively.\n \"\"\"\n repo_root = Path(repo.path)\n for submodule in cls._get_submodules(repo):\n path_absolute = repo_root / submodule.path\n source_root = path_absolute.parent\n source_root.mkdir(parents=True, exist_ok=True)\n cls.clone(\n url=submodule.url,\n source_root=source_root,\n name=path_absolute.name,\n revision=submodule.revision,\n clean=path_absolute.exists()\n and not path_absolute.joinpath(\".git\").is_dir(),\n )\n\n @classmethod\n def _get_submodules(cls, repo: Repo) -> list[SubmoduleInfo]:\n modules_config = Path(repo.path, \".gitmodules\")\n\n if not modules_config.exists():\n return []\n\n config = ConfigFile.from_path(str(modules_config))\n\n submodules: list[SubmoduleInfo] = []\n for path, url, name in parse_submodules(config):\n url_str = url.decode(\"utf-8\")\n path_str = path.decode(\"utf-8\")\n name_str = name.decode(\"utf-8\")\n\n if RELATIVE_SUBMODULE_REGEX.search(url_str):\n url_str = urlpathjoin(f\"{cls.get_remote_url(repo)}/\", url_str)\n\n with repo:\n index = repo.open_index()\n\n try:\n entry = index[path]\n except KeyError:\n logger.debug(\n \"Skip submodule %s in %s, path %s not found\",\n name,\n repo.path,\n path,\n )\n continue\n\n assert isinstance(entry, IndexEntry)\n revision = entry.sha.decode(\"utf-8\")\n\n submodules.append(\n SubmoduleInfo(\n path=path_str,\n url=url_str,\n name=name_str,\n revision=revision,\n )\n )\n\n return submodules\n\n @staticmethod\n def is_using_legacy_client() -> bool:\n from poetry.config.config import Config\n\n legacy_client: bool = Config.create().get(\n \"experimental.system-git-client\", False\n )\n return legacy_client\n\n @staticmethod\n def get_default_source_root() -> Path:\n from poetry.config.config import Config\n\n return Path(Config.create().get(\"cache-dir\")) / \"src\"\n\n @classmethod\n def clone(\n cls,\n url: str,\n name: str | None = None,\n branch: str | None = None,\n tag: str | None = None,\n revision: str | None = None,\n source_root: Path | None = None,\n clean: bool = False,\n ) -> Repo:\n source_root = source_root or cls.get_default_source_root()\n source_root.mkdir(parents=True, exist_ok=True)\n\n name = name or cls.get_name_from_source_url(url=url)\n target = source_root / name\n refspec = GitRefSpec(branch=branch, revision=revision, tag=tag)\n\n if target.exists():\n if clean:\n # force clean the local copy if it exists, do not reuse\n remove_directory(target, force=True)\n else:\n # check if the current local copy matches the requested ref spec\n try:\n current_repo = Repo(str(target))\n\n with current_repo:\n current_sha = current_repo.head().decode(\"utf-8\")\n except (NotGitRepository, AssertionError, KeyError):\n # something is wrong with the current checkout, clean it\n remove_directory(target, force=True)\n else:\n if not is_revision_sha(revision=current_sha):\n # head is not a sha, this will cause issues later, lets reset\n remove_directory(target, force=True)\n elif (\n refspec.is_sha\n and refspec.revision is not None\n and current_sha.startswith(refspec.revision)\n ):\n # if revision is used short-circuit remote fetch head matches\n return current_repo\n\n try:\n if not cls.is_using_legacy_client():\n local = 
cls._clone(url=url, refspec=refspec, target=target)\n cls._clone_submodules(repo=local)\n return local\n except HTTPUnauthorized:\n # we do this here to handle http authenticated repositories as dulwich\n # does not currently support using credentials from git-credential helpers.\n # upstream issue: https://github.com/jelmer/dulwich/issues/873\n #\n # this is a little inefficient, however preferred as this is transparent\n # without additional configuration or changes for existing projects that\n # use http basic auth credentials.\n logger.debug(\n \"Unable to fetch from private repository '%s', falling back to\"\n \" system git\",\n url,\n )\n\n # fallback to legacy git client\n return cls._clone_legacy(url=url, refspec=refspec, target=target)\n\n\ndef urlpathjoin(base: str, path: str) -> str:\n \"\"\"\n Allow any URL to be joined with a path\n\n This works around an issue with urllib.parse.urljoin where it only handles\n relative URLs for protocols contained in urllib.parse.uses_relative. As it\n happens common protocols used with git, like ssh or git+ssh are not in that\n list.\n\n Thus we need to implement our own version of urljoin that handles all URLs\n protocols. This is accomplished by using urlparse and urlunparse to split\n the URL into its components, join the path, and then reassemble the URL.\n\n See: https://github.com/python-poetry/poetry/issues/6499#issuecomment-1564712609\n \"\"\"\n parsed_base = urlparse(base)\n new = parsed_base._replace(path=urljoin(parsed_base.path, path))\n return urlunparse(new)\n\n\n@dataclasses.dataclass\nclass SubmoduleInfo:\n path: str\n url: str\n name: str\n revision: str\n\n# Path: src/poetry/vcs/git/system.py\nfrom __future__ import annotations\n\nimport os\nimport subprocess\n\nfrom typing import TYPE_CHECKING\n\nfrom dulwich.client import find_git_command\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n from typing import Any\n\n\nclass SystemGit:\n @classmethod\n def clone(cls, repository: str, dest: Path) -> None:\n cls._check_parameter(repository)\n\n cls.run(\"clone\", \"--recurse-submodules\", \"--\", repository, str(dest))\n\n @classmethod\n def checkout(cls, rev: str, target: Path | None = None) -> None:\n cls._check_parameter(rev)\n cls.run(\"checkout\", rev, folder=target)\n\n @staticmethod\n def run(*args: Any, **kwargs: Any) -> None:\n folder = kwargs.pop(\"folder\", None)\n if folder:\n args = (\n \"--git-dir\",\n (folder / \".git\").as_posix(),\n \"--work-tree\",\n folder.as_posix(),\n *args,\n )\n\n git_command = find_git_command()\n env = os.environ.copy()\n env[\"GIT_TERMINAL_PROMPT\"] = \"0\"\n subprocess.check_call(\n git_command + list(args),\n stderr=subprocess.DEVNULL,\n stdout=subprocess.DEVNULL,\n env=env,\n text=True,\n )\n\n @staticmethod\n def _check_parameter(parameter: str) -> None:\n \"\"\"\n Checks a git parameter to avoid unwanted code execution.\n \"\"\"\n if parameter.strip().startswith(\"-\"):\n raise RuntimeError(f\"Invalid Git parameter: {parameter}\")\n\n# Path: src/poetry/console/commands/cache/__init__.py\n\n# Path: src/poetry/console/commands/cache/clear.py\nfrom __future__ import annotations\n\nimport os\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\n\nfrom poetry.config.config import Config\nfrom poetry.console.commands.command import Command\nfrom poetry.utils.cache import FileCache\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from 
cleo.io.inputs.option import Option\n\n\nclass CacheClearCommand(Command):\n name = \"cache clear\"\n description = \"Clears a Poetry cache by name.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"cache\", description=\"The name of the cache to clear.\")\n ]\n options: ClassVar[list[Option]] = [\n option(\"all\", description=\"Clear all entries in the cache.\")\n ]\n\n def handle(self) -> int:\n cache = self.argument(\"cache\")\n\n parts = cache.split(\":\")\n root = parts[0]\n\n config = Config.create()\n cache_dir = config.repository_cache_directory / root\n\n try:\n cache_dir.relative_to(config.repository_cache_directory)\n except ValueError:\n raise ValueError(f\"{root} is not a valid repository cache\")\n\n cache = FileCache(cache_dir)\n\n if len(parts) == 1:\n if not self.option(\"all\"):\n raise RuntimeError(\n f\"Add the --all option if you want to clear all {parts[0]} caches\"\n )\n\n if not cache_dir.exists():\n self.line(f\"No cache entries for {parts[0]}\")\n return 0\n\n # Calculate number of entries\n entries_count = sum(\n len(files) for _path, _dirs, files in os.walk(str(cache_dir))\n )\n\n delete = self.confirm(f\"Delete {entries_count} entries?\", True)\n if not delete:\n return 0\n\n cache.flush()\n elif len(parts) == 2:\n raise RuntimeError(\n \"Only specifying the package name is not yet supported. \"\n \"Add a specific version to clear\"\n )\n elif len(parts) == 3:\n package = canonicalize_name(parts[1])\n version = parts[2]\n\n if not cache.has(f\"{package}:{version}\"):\n self.line(f\"No cache entries for {package}:{version}\")\n return 0\n\n delete = self.confirm(f\"Delete cache entry {package}:{version}\", True)\n if not delete:\n return 0\n\n cache.forget(f\"{package}:{version}\")\n else:\n raise ValueError(\"Invalid cache key\")\n\n return 0\n\n# Path: src/poetry/console/commands/cache/list.py\nfrom __future__ import annotations\n\nfrom poetry.config.config import Config\nfrom poetry.console.commands.command import Command\n\n\nclass CacheListCommand(Command):\n name = \"cache list\"\n description = \"List Poetry's caches.\"\n\n def handle(self) -> int:\n config = Config.create()\n if config.repository_cache_directory.exists():\n caches = sorted(config.repository_cache_directory.iterdir())\n if caches:\n for cache in caches:\n self.line(f\"{cache.name}\")\n return 0\n\n self.line_error(\"No caches found\")\n return 0\n\n# Path: src/poetry/console/commands/debug/__init__.py\n\n# Path: src/poetry/console/commands/debug/info.py\nfrom __future__ import annotations\n\nimport sys\n\nfrom poetry.console.commands.command import Command\n\n\nclass DebugInfoCommand(Command):\n name = \"debug info\"\n description = \"Shows debug information.\"\n\n def handle(self) -> int:\n poetry_python_version = \".\".join(str(s) for s in sys.version_info[:3])\n\n self.line(\"\")\n self.line(\"Poetry\")\n self.line(\n \"\\n\".join(\n [\n f\"Version: {self.poetry.VERSION}\",\n f\"Python: {poetry_python_version}\",\n ]\n )\n )\n command = self.get_application().get(\"env info\")\n\n exit_code: int = command.run(self.io)\n return exit_code\n\n# Path: src/poetry/console/commands/debug/resolve.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n...\n", "needle_token_start": 5574, "needle_token_end": 5592, "code_context_ntokens": 15996, "output": ["```python\ndef candidates(self) -> list[Path]:\n return self._candidates\n\n@cached_property\ndef 
writable_candidates(self) -> list[Path]:\n if self._writable_candidates is not None:\n return self._writable_candidates\n\n self._writable_candidates = []\n for candidate in self._candidates:\n if not is_dir_writable(path=candidate, create=True):\n continue\n self._writable_candidates.append(candidate)\n\n return self._writable_candidates\n```"]} +{"repo": "python-poetry/poetry", "name": "_display_complete_info", "language": "python", "path": "src/poetry/console/commands/env/info.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to display detailed information about the current Python environment and its base environment, including versions, paths, and other relevant data.\n2. **Input**: It accepts an environment object that encapsulates details about a Python environment.\n3. **Output**: There is no return value; instead, it outputs formatted text directly to the console, presenting various attributes of the environment and its base.\n4. **Procedure**: The function first extracts and formats details such as the Python version, implementation, and paths from the provided environment object. It checks if the environment is a virtual environment and displays additional validity information if true. It then retrieves and displays similar information for the base environment associated with the provided environment.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/vcs/git/backend.py\nfrom __future__ import annotations\n\nimport dataclasses\nimport logging\nimport re\n\nfrom pathlib import Path\nfrom subprocess import CalledProcessError\nfrom typing import TYPE_CHECKING\nfrom urllib.parse import urljoin\nfrom urllib.parse import urlparse\nfrom urllib.parse import urlunparse\n\nfrom dulwich import porcelain\nfrom dulwich.client import HTTPUnauthorized\nfrom dulwich.client import get_transport_and_path\nfrom dulwich.config import ConfigFile\nfrom dulwich.config import parse_submodules\nfrom dulwich.errors import NotGitRepository\nfrom dulwich.index import IndexEntry\nfrom dulwich.refs import ANNOTATED_TAG_SUFFIX\nfrom dulwich.repo import Repo\n\nfrom poetry.console.exceptions import PoetryConsoleError\nfrom poetry.utils.authenticator import get_default_authenticator\nfrom poetry.utils.helpers import remove_directory\n\n\nif TYPE_CHECKING:\n from dulwich.client import FetchPackResult\n from dulwich.client import GitClient\n\n\nlogger = logging.getLogger(__name__)\n\n# A relative URL by definition starts with ../ or ./\nRELATIVE_SUBMODULE_REGEX = re.compile(r\"^\\.{1,2}/\")\n\n\ndef is_revision_sha(revision: str | None) -> bool:\n return re.match(r\"^\\b[0-9a-f]{5,40}\\b$\", revision or \"\") is not None\n\n\ndef annotated_tag(ref: str | bytes) -> bytes:\n if isinstance(ref, str):\n ref = ref.encode(\"utf-8\")\n return ref + ANNOTATED_TAG_SUFFIX\n\n\n@dataclasses.dataclass\nclass GitRefSpec:\n branch: str | None = None\n revision: str | None = None\n tag: str | None = None\n ref: bytes = dataclasses.field(default_factory=lambda: b\"HEAD\")\n\n def resolve(self, remote_refs: FetchPackResult) -> None:\n \"\"\"\n Resolve the ref using the provided remote refs.\n \"\"\"\n self._normalise(remote_refs=remote_refs)\n self._set_head(remote_refs=remote_refs)\n\n def _normalise(self, remote_refs: FetchPackResult) 
-> None:\n \"\"\"\n Internal helper method to determine if given revision is\n 1. a branch or tag; if so, set corresponding properties.\n 2. a short sha; if so, resolve full sha and set as revision\n \"\"\"\n if self.revision:\n ref = f\"refs/tags/{self.revision}\".encode()\n if ref in remote_refs.refs or annotated_tag(ref) in remote_refs.refs:\n # this is a tag, incorrectly specified as a revision, tags take priority\n self.tag = self.revision\n self.revision = None\n elif (\n self.revision.encode(\"utf-8\") in remote_refs.refs\n or f\"refs/heads/{self.revision}\".encode() in remote_refs.refs\n ):\n # this is most likely a ref spec or a branch incorrectly specified\n self.branch = self.revision\n self.revision = None\n elif (\n self.branch\n and f\"refs/heads/{self.branch}\".encode() not in remote_refs.refs\n and (\n f\"refs/tags/{self.branch}\".encode() in remote_refs.refs\n or annotated_tag(f\"refs/tags/{self.branch}\") in remote_refs.refs\n )\n ):\n # this is a tag incorrectly specified as a branch\n self.tag = self.branch\n self.branch = None\n\n if self.revision and self.is_sha_short:\n # revision is a short sha, resolve to full sha\n short_sha = self.revision.encode(\"utf-8\")\n for sha in remote_refs.refs.values():\n if sha.startswith(short_sha):\n self.revision = sha.decode(\"utf-8\")\n break\n\n def _set_head(self, remote_refs: FetchPackResult) -> None:\n \"\"\"\n Internal helper method to populate ref and set it's sha as the remote's head\n and default ref.\n \"\"\"\n self.ref = remote_refs.symrefs[b\"HEAD\"]\n\n if self.revision:\n head = self.revision.encode(\"utf-8\")\n else:\n if self.tag:\n ref = f\"refs/tags/{self.tag}\".encode()\n annotated = annotated_tag(ref)\n self.ref = annotated if annotated in remote_refs.refs else ref\n elif self.branch:\n self.ref = (\n self.branch.encode(\"utf-8\")\n if self.is_ref\n else f\"refs/heads/{self.branch}\".encode()\n )\n head = remote_refs.refs[self.ref]\n\n remote_refs.refs[self.ref] = remote_refs.refs[b\"HEAD\"] = head\n\n @property\n def key(self) -> str:\n return self.revision or self.branch or self.tag or self.ref.decode(\"utf-8\")\n\n @property\n def is_sha(self) -> bool:\n return is_revision_sha(revision=self.revision)\n\n @property\n def is_ref(self) -> bool:\n return self.branch is not None and (\n self.branch.startswith(\"refs/\") or self.branch == \"HEAD\"\n )\n\n @property\n def is_sha_short(self) -> bool:\n return self.revision is not None and self.is_sha and len(self.revision) < 40\n\n\n@dataclasses.dataclass\nclass GitRepoLocalInfo:\n repo: dataclasses.InitVar[Repo | Path]\n origin: str = dataclasses.field(init=False)\n revision: str = dataclasses.field(init=False)\n\n def __post_init__(self, repo: Repo | Path) -> None:\n repo = Git.as_repo(repo=repo) if not isinstance(repo, Repo) else repo\n self.origin = Git.get_remote_url(repo=repo, remote=\"origin\")\n self.revision = Git.get_revision(repo=repo)\n\n\nclass Git:\n @staticmethod\n def as_repo(repo: Path) -> Repo:\n return Repo(str(repo))\n\n @staticmethod\n def get_remote_url(repo: Repo, remote: str = \"origin\") -> str:\n with repo:\n config = repo.get_config()\n section = (b\"remote\", remote.encode(\"utf-8\"))\n\n url = \"\"\n if config.has_section(section):\n value = config.get(section, b\"url\")\n url = value.decode(\"utf-8\")\n\n return url\n\n @staticmethod\n def get_revision(repo: Repo) -> str:\n with repo:\n return repo.head().decode(\"utf-8\")\n\n @classmethod\n def info(cls, repo: Repo | Path) -> GitRepoLocalInfo:\n return GitRepoLocalInfo(repo=repo)\n\n 
@staticmethod\n def get_name_from_source_url(url: str) -> str:\n return re.sub(r\"(.git)?$\", \"\", url.rsplit(\"/\", 1)[-1])\n\n @classmethod\n def _fetch_remote_refs(cls, url: str, local: Repo) -> FetchPackResult:\n \"\"\"\n Helper method to fetch remote refs.\n \"\"\"\n client: GitClient\n path: str\n\n kwargs: dict[str, str] = {}\n credentials = get_default_authenticator().get_credentials_for_git_url(url=url)\n\n if credentials.password and credentials.username:\n # we do this conditionally as otherwise, dulwich might complain if these\n # parameters are passed in for an ssh url\n kwargs[\"username\"] = credentials.username\n kwargs[\"password\"] = credentials.password\n\n config = local.get_config_stack()\n client, path = get_transport_and_path(url, config=config, **kwargs)\n\n with local:\n result: FetchPackResult = client.fetch(\n path,\n local,\n determine_wants=local.object_store.determine_wants_all,\n )\n return result\n\n @staticmethod\n def _clone_legacy(url: str, refspec: GitRefSpec, target: Path) -> Repo:\n \"\"\"\n Helper method to facilitate fallback to using system provided git client via\n subprocess calls.\n \"\"\"\n from poetry.vcs.git.system import SystemGit\n\n logger.debug(\"Cloning '%s' using system git client\", url)\n\n if target.exists():\n remove_directory(path=target, force=True)\n\n revision = refspec.tag or refspec.branch or refspec.revision or \"HEAD\"\n\n try:\n SystemGit.clone(url, target)\n except CalledProcessError:\n raise PoetryConsoleError(\n f\"Failed to clone {url}, check your git configuration and permissions\"\n \" for this repository.\"\n )\n\n if revision:\n revision.replace(\"refs/head/\", \"\")\n revision.replace(\"refs/tags/\", \"\")\n\n try:\n SystemGit.checkout(revision, target)\n except CalledProcessError:\n raise PoetryConsoleError(f\"Failed to checkout {url} at '{revision}'\")\n\n repo = Repo(str(target))\n return repo\n\n @classmethod\n def _clone(cls, url: str, refspec: GitRefSpec, target: Path) -> Repo:\n \"\"\"\n Helper method to clone a remove repository at the given `url` at the specified\n ref spec.\n \"\"\"\n local: Repo\n if not target.exists():\n local = Repo.init(str(target), mkdir=True)\n porcelain.remote_add(local, \"origin\", url)\n else:\n local = Repo(str(target))\n\n remote_refs = cls._fetch_remote_refs(url=url, local=local)\n\n logger.debug(\n \"Cloning %s at '%s' to %s\", url, refspec.key, target\n )\n\n try:\n refspec.resolve(remote_refs=remote_refs)\n except KeyError: # branch / ref does not exist\n raise PoetryConsoleError(\n f\"Failed to clone {url} at '{refspec.key}', verify ref exists on\"\n \" remote.\"\n )\n\n # ensure local HEAD matches remote\n local.refs[b\"HEAD\"] = remote_refs.refs[b\"HEAD\"]\n\n if refspec.is_ref:\n # set ref to current HEAD\n local.refs[refspec.ref] = local.refs[b\"HEAD\"]\n\n for base, prefix in {\n (b\"refs/remotes/origin\", b\"refs/heads/\"),\n (b\"refs/tags\", b\"refs/tags\"),\n }:\n local.refs.import_refs(\n base=base,\n other={\n n[len(prefix) :]: v\n for (n, v) in remote_refs.refs.items()\n if n.startswith(prefix) and not n.endswith(ANNOTATED_TAG_SUFFIX)\n },\n )\n\n try:\n with local:\n local.reset_index()\n except (AssertionError, KeyError) as e:\n # this implies the ref we need does not exist or is invalid\n if isinstance(e, KeyError):\n # the local copy is at a bad state, lets remove it\n logger.debug(\n \"Removing local clone (%s) of repository as it is in a\"\n \" broken state.\",\n local.path,\n )\n remove_directory(Path(local.path), force=True)\n\n if isinstance(e, 
AssertionError) and \"Invalid object name\" not in str(e):\n raise\n\n logger.debug(\n \"\\nRequested ref (%s) was not fetched to local copy and\"\n \" cannot be used. The following error was\"\n \" raised:\\n\\n\\t%s\",\n refspec.key,\n e,\n )\n\n raise PoetryConsoleError(\n f\"Failed to clone {url} at '{refspec.key}', verify ref exists on\"\n \" remote.\"\n )\n\n return local\n\n @classmethod\n def _clone_submodules(cls, repo: Repo) -> None:\n \"\"\"\n Helper method to identify configured submodules and clone them recursively.\n \"\"\"\n repo_root = Path(repo.path)\n for submodule in cls._get_submodules(repo):\n path_absolute = repo_root / submodule.path\n source_root = path_absolute.parent\n source_root.mkdir(parents=True, exist_ok=True)\n cls.clone(\n url=submodule.url,\n source_root=source_root,\n name=path_absolute.name,\n revision=submodule.revision,\n clean=path_absolute.exists()\n and not path_absolute.joinpath(\".git\").is_dir(),\n )\n\n @classmethod\n def _get_submodules(cls, repo: Repo) -> list[SubmoduleInfo]:\n modules_config = Path(repo.path, \".gitmodules\")\n\n if not modules_config.exists():\n return []\n\n config = ConfigFile.from_path(str(modules_config))\n\n submodules: list[SubmoduleInfo] = []\n for path, url, name in parse_submodules(config):\n url_str = url.decode(\"utf-8\")\n path_str = path.decode(\"utf-8\")\n name_str = name.decode(\"utf-8\")\n\n if RELATIVE_SUBMODULE_REGEX.search(url_str):\n url_str = urlpathjoin(f\"{cls.get_remote_url(repo)}/\", url_str)\n\n with repo:\n index = repo.open_index()\n...\n# Path: src/poetry/vcs/git/system.py\nfrom __future__ import annotations\n\nimport os\nimport subprocess\n\nfrom typing import TYPE_CHECKING\n\nfrom dulwich.client import find_git_command\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n from typing import Any\n\n\nclass SystemGit:\n @classmethod\n def clone(cls, repository: str, dest: Path) -> None:\n cls._check_parameter(repository)\n\n cls.run(\"clone\", \"--recurse-submodules\", \"--\", repository, str(dest))\n\n @classmethod\n def checkout(cls, rev: str, target: Path | None = None) -> None:\n cls._check_parameter(rev)\n cls.run(\"checkout\", rev, folder=target)\n\n @staticmethod\n def run(*args: Any, **kwargs: Any) -> None:\n folder = kwargs.pop(\"folder\", None)\n if folder:\n args = (\n \"--git-dir\",\n (folder / \".git\").as_posix(),\n \"--work-tree\",\n folder.as_posix(),\n *args,\n )\n\n git_command = find_git_command()\n env = os.environ.copy()\n env[\"GIT_TERMINAL_PROMPT\"] = \"0\"\n subprocess.check_call(\n git_command + list(args),\n stderr=subprocess.DEVNULL,\n stdout=subprocess.DEVNULL,\n env=env,\n text=True,\n )\n\n @staticmethod\n def _check_parameter(parameter: str) -> None:\n \"\"\"\n Checks a git parameter to avoid unwanted code execution.\n \"\"\"\n if parameter.strip().startswith(\"-\"):\n raise RuntimeError(f\"Invalid Git parameter: {parameter}\")\n\n# Path: src/poetry/console/commands/cache/__init__.py\n\n# Path: src/poetry/console/commands/cache/clear.py\nfrom __future__ import annotations\n\nimport os\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\n\nfrom poetry.config.config import Config\nfrom poetry.console.commands.command import Command\nfrom poetry.utils.cache import FileCache\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass CacheClearCommand(Command):\n name = \"cache 
clear\"\n description = \"Clears a Poetry cache by name.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"cache\", description=\"The name of the cache to clear.\")\n ]\n options: ClassVar[list[Option]] = [\n option(\"all\", description=\"Clear all entries in the cache.\")\n ]\n\n def handle(self) -> int:\n cache = self.argument(\"cache\")\n\n parts = cache.split(\":\")\n root = parts[0]\n\n config = Config.create()\n cache_dir = config.repository_cache_directory / root\n\n try:\n cache_dir.relative_to(config.repository_cache_directory)\n except ValueError:\n raise ValueError(f\"{root} is not a valid repository cache\")\n\n cache = FileCache(cache_dir)\n\n if len(parts) == 1:\n if not self.option(\"all\"):\n raise RuntimeError(\n f\"Add the --all option if you want to clear all {parts[0]} caches\"\n )\n\n if not cache_dir.exists():\n self.line(f\"No cache entries for {parts[0]}\")\n return 0\n\n # Calculate number of entries\n entries_count = sum(\n len(files) for _path, _dirs, files in os.walk(str(cache_dir))\n )\n\n delete = self.confirm(f\"Delete {entries_count} entries?\", True)\n if not delete:\n return 0\n\n cache.flush()\n elif len(parts) == 2:\n raise RuntimeError(\n \"Only specifying the package name is not yet supported. \"\n \"Add a specific version to clear\"\n )\n elif len(parts) == 3:\n package = canonicalize_name(parts[1])\n version = parts[2]\n\n if not cache.has(f\"{package}:{version}\"):\n self.line(f\"No cache entries for {package}:{version}\")\n return 0\n\n delete = self.confirm(f\"Delete cache entry {package}:{version}\", True)\n if not delete:\n return 0\n\n cache.forget(f\"{package}:{version}\")\n else:\n raise ValueError(\"Invalid cache key\")\n\n return 0\n\n# Path: src/poetry/console/commands/cache/list.py\nfrom __future__ import annotations\n\nfrom poetry.config.config import Config\nfrom poetry.console.commands.command import Command\n\n\nclass CacheListCommand(Command):\n name = \"cache list\"\n description = \"List Poetry's caches.\"\n\n def handle(self) -> int:\n config = Config.create()\n if config.repository_cache_directory.exists():\n caches = sorted(config.repository_cache_directory.iterdir())\n if caches:\n for cache in caches:\n self.line(f\"{cache.name}\")\n return 0\n\n self.line_error(\"No caches found\")\n return 0\n\n# Path: src/poetry/console/commands/debug/__init__.py\n\n# Path: src/poetry/console/commands/debug/info.py\nfrom __future__ import annotations\n\nimport sys\n\nfrom poetry.console.commands.command import Command\n\n\nclass DebugInfoCommand(Command):\n name = \"debug info\"\n description = \"Shows debug information.\"\n\n def handle(self) -> int:\n poetry_python_version = \".\".join(str(s) for s in sys.version_info[:3])\n\n self.line(\"\")\n self.line(\"Poetry\")\n self.line(\n \"\\n\".join(\n [\n f\"Version: {self.poetry.VERSION}\",\n f\"Python: {poetry_python_version}\",\n ]\n )\n )\n command = self.get_application().get(\"env info\")\n\n exit_code: int = command.run(self.io)\n return exit_code\n\n# Path: src/poetry/console/commands/debug/resolve.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom cleo.io.outputs.output import Verbosity\n\nfrom poetry.console.commands.init import InitCommand\nfrom poetry.console.commands.show import ShowCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n from cleo.ui.table import Rows\n\n\nclass 
DebugResolveCommand(InitCommand):\n name = \"debug resolve\"\n description = \"Debugs dependency resolution.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"package\", \"The packages to resolve.\", optional=True, multiple=True)\n ]\n options: ClassVar[list[Option]] = [\n option(\n \"extras\",\n \"E\",\n \"Extras to activate for the dependency.\",\n flag=False,\n multiple=True,\n ),\n option(\"python\", None, \"Python version(s) to use for resolution.\", flag=False),\n option(\"tree\", None, \"Display the dependency tree.\"),\n option(\"install\", None, \"Show what would be installed for the current system.\"),\n ]\n\n loggers: ClassVar[list[str]] = [\n \"poetry.repositories.pypi_repository\",\n \"poetry.inspection.info\",\n ]\n\n def handle(self) -> int:\n from cleo.io.null_io import NullIO\n from poetry.core.packages.project_package import ProjectPackage\n\n from poetry.factory import Factory\n from poetry.puzzle.solver import Solver\n from poetry.repositories.repository import Repository\n from poetry.repositories.repository_pool import RepositoryPool\n from poetry.utils.env import EnvManager\n\n packages = self.argument(\"package\")\n\n if not packages:\n package = self.poetry.package\n else:\n # Using current pool for determine_requirements()\n self._pool = self.poetry.pool\n\n package = ProjectPackage(\n self.poetry.package.name, self.poetry.package.version\n )\n\n # Silencing output\n verbosity = self.io.output.verbosity\n self.io.output.set_verbosity(Verbosity.QUIET)\n\n requirements = self._determine_requirements(packages)\n\n self.io.output.set_verbosity(verbosity)\n\n for constraint in requirements:\n name = constraint.pop(\"name\")\n assert isinstance(name, str)\n extras = []\n for extra in self.option(\"extras\"):\n extras += extra.split()\n\n constraint[\"extras\"] = extras\n\n package.add_dependency(Factory.create_dependency(name, constraint))\n\n package.python_versions = self.option(\"python\") or (\n self.poetry.package.python_versions\n )\n\n pool = self.poetry.pool\n\n solver = Solver(package, pool, [], [], self.io)\n\n ops = solver.solve().calculate_operations()\n\n self.line(\"\")\n self.line(\"Resolution results:\")\n self.line(\"\")\n\n if self.option(\"tree\"):\n show_command = self.get_application().find(\"show\")\n assert isinstance(show_command, ShowCommand)\n show_command.init_styles(self.io)\n\n packages = [op.package for op in ops]\n\n requires = package.all_requires\n for pkg in packages:\n for require in requires:\n if pkg.name == require.name:\n show_command.display_package_tree(self.io, pkg, packages)\n break\n\n return 0\n\n table = self.table(style=\"compact\")\n table.style.set_vertical_border_chars(\"\", \" \")\n rows: Rows = []\n\n if self.option(\"install\"):\n env = EnvManager(self.poetry).get()\n pool = RepositoryPool(config=self.poetry.config)\n locked_repository = Repository(\"poetry-locked\")\n for op in ops:\n locked_repository.add_package(op.package)\n\n pool.add_repository(locked_repository)\n\n solver = Solver(package, pool, [], [], NullIO())\n with solver.use_environment(env):\n ops = solver.solve().calculate_operations()\n\n for op in ops:\n if self.option(\"install\") and op.skipped:\n continue\n\n pkg = op.package\n row = [\n f\"{pkg.complete_name}\",\n f\"{pkg.version}\",\n ]\n\n if not pkg.marker.is_any():\n row[2] = str(pkg.marker)\n\n rows.append(row)\n\n table.set_rows(rows)\n table.render()\n\n return 0\n\n# Path: src/poetry/console/commands/env/__init__.py\n\n# Path: src/poetry/console/commands/env/info.py\nfrom 
__future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n from poetry.utils.env import Env\n\n\nclass EnvInfoCommand(Command):\n name = \"env info\"\n description = \"Displays information about the current environment.\"\n\n options: ClassVar[list[Option]] = [\n option(\"path\", \"p\", \"Only display the environment's path.\"),\n option(\n \"executable\", \"e\", \"Only display the environment's python executable path.\"\n ),\n ]\n\n def handle(self) -> int:\n from poetry.utils.env import EnvManager\n\n env = EnvManager(self.poetry).get()\n\n if self.option(\"path\"):\n if not env.is_venv():\n return 1\n\n self.line(str(env.path))\n\n return 0\n\n if self.option(\"executable\"):\n if not env.is_venv():\n return 1\n\n self.line(str(env.python))\n\n return 0\n\n self._display_complete_info(env)\n return 0\n\n \ndef _display_complete_info(self, env: Env) -> None:\n env_python_version = \".\".join(str(s) for s in env.version_info[:3])\n self.line(\"\")\n self.line(\"Virtualenv\")\n listing = [\n f\"Python: {env_python_version}\",\n f\"Implementation: {env.python_implementation}\",\n (\n \"Path: \"\n f\" {env.path if env.is_venv() else 'NA'}\"\n ),\n (\n \"Executable: \"\n f\" {env.python if env.is_venv() else 'NA'}\"\n ),\n ]\n if env.is_venv():\n listing.append(\n \"Valid: \"\n f\" <{'comment' if env.is_sane() else 'error'}>{env.is_sane()}\"\n )\n self.line(\"\\n\".join(listing))\n\n self.line(\"\")\n\n base_env = env.parent_env\n python = \".\".join(str(v) for v in base_env.version_info[:3])\n self.line(\"Base\")\n self.line(\n \"\\n\".join(\n [\n f\"Platform: {env.platform}\",\n f\"OS: {env.os}\",\n f\"Python: {python}\",\n f\"Path: {base_env.path}\",\n f\"Executable: {base_env.python}\",\n ]\n )\n )\n\n# Path: src/poetry/console/commands/env/list.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass EnvListCommand(Command):\n name = \"env list\"\n description = \"Lists all virtualenvs associated with the current project.\"\n\n options: ClassVar[list[Option]] = [\n option(\"full-path\", None, \"Output the full paths of the virtualenvs.\")\n ]\n\n def handle(self) -> int:\n from poetry.utils.env import EnvManager\n\n manager = EnvManager(self.poetry)\n current_env = manager.get()\n\n for venv in manager.list():\n name = venv.path.name\n if self.option(\"full-path\"):\n name = str(venv.path)\n\n if venv == current_env:\n self.line(f\"{name} (Activated)\")\n\n continue\n\n self.line(name)\n\n return 0\n\n# Path: src/poetry/console/commands/env/remove.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass EnvRemoveCommand(Command):\n name = \"env remove\"\n description = \"Remove virtual environments associated with the project.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"python\",\n \"The python executables associated with, or names of the virtual\"\n \" environments which are 
to be removed.\",\n optional=True,\n multiple=True,\n )\n ]\n options: ClassVar[list[Option]] = [\n option(\n \"all\",\n description=(\n \"Remove all managed virtual environments associated with the project.\"\n ),\n ),\n ]\n\n def handle(self) -> int:\n from poetry.utils.env import EnvManager\n\n pythons = self.argument(\"python\")\n all = self.option(\"all\")\n if not (pythons or all):\n self.line(\"No virtualenv provided.\")\n\n manager = EnvManager(self.poetry)\n # TODO: refactor env.py to allow removal with one loop\n for python in pythons:\n venv = manager.remove(python)\n self.line(f\"Deleted virtualenv: {venv.path}\")\n if all:\n for venv in manager.list():\n manager.remove_venv(venv.path)\n self.line(f\"Deleted virtualenv: {venv.path}\")\n # Since we remove all the virtualenvs, we can also remove the entry\n # in the envs file. (Strictly speaking, we should do this explicitly,\n # in case it points to a virtualenv that had been removed manually before.)\n if manager.envs_file.exists():\n manager.envs_file.remove_section(manager.base_env_name)\n\n return 0\n\n# Path: src/poetry/console/commands/env/use.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n\n\nclass EnvUseCommand(Command):\n name = \"env use\"\n description = \"Activates or creates a new virtualenv for the current project.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"python\", \"The python executable to use.\")\n ]\n\n def handle(self) -> int:\n from poetry.utils.env import EnvManager\n\n manager = EnvManager(self.poetry, io=self.io)\n\n if self.argument(\"python\") == \"system\":\n manager.deactivate()\n\n return 0\n\n env = manager.activate(self.argument(\"python\"))\n\n self.line(f\"Using virtualenv: {env.path}\")\n\n return 0\n\n# Path: src/poetry/console/commands/self/__init__.py\n\n# Path: src/poetry/console/commands/self/add.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom poetry.core.constraints.version import Version\n\nfrom poetry.__version__ import __version__\nfrom poetry.console.commands.add import AddCommand\nfrom poetry.console.commands.self.self_command import SelfCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass SelfAddCommand(SelfCommand, AddCommand):\n name = \"self add\"\n description = \"Add additional packages to Poetry's runtime environment.\"\n options: ClassVar[list[Option]] = [\n o\n for o in AddCommand.options\n if o.name in {\"editable\", \"extras\", \"source\", \"dry-run\", \"allow-prereleases\"}\n ]\n help = f\"\"\"\\\nThe self add command installs additional packages to Poetry's runtime \\\nenvironment.\n\nThis is managed in the {SelfCommand.get_default_system_pyproject_file()} \\\nfile.\n\n{AddCommand.examples}\n\"\"\"\n\n @property\n def _hint_update_packages(self) -> str:\n version = Version.parse(__version__)\n flags = \"\"\n\n if not version.is_stable():\n flags = \" --preview\"\n\n return (\n \"\\nIf you want to update it to the latest compatible version, you can use\"\n f\" `poetry self update{flags}`.\\nIf you prefer to upgrade it to the latest\"\n \" available version, you can use `poetry self add package@latest`.\\n\"\n )\n\n# Path: src/poetry/console/commands/self/install.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom 
typing import ClassVar\n\nfrom poetry.core.packages.dependency_group import MAIN_GROUP\n\nfrom poetry.console.commands.install import InstallCommand\nfrom poetry.console.commands.self.self_command import SelfCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass SelfInstallCommand(SelfCommand, InstallCommand):\n name = \"self install\"\n description = (\n \"Install locked packages (incl. addons) required by this Poetry installation.\"\n )\n options: ClassVar[list[Option]] = [\n o for o in InstallCommand.options if o.name in {\"sync\", \"dry-run\"}\n ]\n help = f\"\"\"\\\nThe self install command ensures all additional packages specified are \\\ninstalled in the current runtime environment.\n\nThis is managed in the {SelfCommand.get_default_system_pyproject_file()} \\\nfile.\n\nYou can add more packages using the self add command and remove them using \\\nthe self remove command.\n\"\"\"\n\n @property\n def activated_groups(self) -> set[str]:\n return {MAIN_GROUP, self.default_group}\n\n# Path: src/poetry/console/commands/self/lock.py\nfrom __future__ import annotations\n\nfrom poetry.console.commands.lock import LockCommand\nfrom poetry.console.commands.self.self_command import SelfCommand\n\n\nclass SelfLockCommand(SelfCommand, LockCommand):\n name = \"self lock\"\n description = \"Lock the Poetry installation's system requirements.\"\n help = f\"\"\"\\\nThe self lock command reads this Poetry installation's system requirements as \\\nspecified in the {SelfCommand.get_default_system_pyproject_file()} file.\n\nThe system dependencies are locked in the \\\n{SelfCommand.get_default_system_pyproject_file().parent.joinpath(\"poetry.lock\")} \\\nfile.\n\"\"\"\n\n# Path: src/poetry/console/commands/self/remove.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom poetry.console.commands.remove import RemoveCommand\nfrom poetry.console.commands.self.self_command import SelfCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass SelfRemoveCommand(SelfCommand, RemoveCommand):\n name = \"self remove\"\n description = \"Remove additional packages from Poetry's runtime environment.\"\n options: ClassVar[list[Option]] = [\n o for o in RemoveCommand.options if o.name in {\"dry-run\"}\n ]\n help = f\"\"\"\\\nThe self remove command removes additional package's to Poetry's runtime \\\nenvironment.\n\nThis is managed in the {SelfCommand.get_default_system_pyproject_file()} \\\nfile.\n\"\"\"\n\n# Path: src/poetry/console/commands/self/self_command.py\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom poetry.core.packages.dependency import Dependency\nfrom poetry.core.packages.project_package import ProjectPackage\n\nfrom poetry.__version__ import __version__\nfrom poetry.console.commands.installer_command import InstallerCommand\nfrom poetry.factory import Factory\nfrom poetry.pyproject.toml import PyProjectTOML\nfrom poetry.utils.env import EnvManager\nfrom poetry.utils.env import SystemEnv\nfrom poetry.utils.helpers import directory\n\n\nif TYPE_CHECKING:\n from poetry.poetry import Poetry\n from poetry.utils.env import Env\n\n\nclass SelfCommand(InstallerCommand):\n ADDITIONAL_PACKAGE_GROUP = \"additional\"\n\n @staticmethod\n def get_default_system_pyproject_file() -> Path:\n # We separate this out to avoid unwanted side effect during testing while\n # maintaining dynamic use in help text.\n #\n # This is not ideal, but is the 
simplest solution for now.\n from poetry.locations import CONFIG_DIR\n\n return Path(CONFIG_DIR).joinpath(\"pyproject.toml\")\n\n @property\n def system_pyproject(self) -> Path:\n file = self.get_default_system_pyproject_file()\n file.parent.mkdir(parents=True, exist_ok=True)\n return file\n\n def reset_env(self) -> None:\n self._env = EnvManager.get_system_env(naive=True)\n\n @property\n def env(self) -> Env:\n if not isinstance(self._env, SystemEnv):\n self.reset_env()\n assert self._env is not None\n return self._env\n\n @property\n def default_group(self) -> str:\n return self.ADDITIONAL_PACKAGE_GROUP\n\n @property\n def activated_groups(self) -> set[str]:\n return {self.default_group}\n\n def generate_system_pyproject(self) -> None:\n preserved = {}\n\n if self.system_pyproject.exists():\n content = PyProjectTOML(self.system_pyproject).poetry_config\n\n for key in {\"group\", \"source\"}:\n if key in content:\n preserved[key] = content[key]\n\n package = ProjectPackage(name=\"poetry-instance\", version=__version__)\n package.add_dependency(Dependency(name=\"poetry\", constraint=f\"{__version__}\"))\n\n package.python_versions = \".\".join(str(v) for v in self.env.version_info[:3])\n\n content = Factory.create_pyproject_from_package(package=package)\n\n for key in preserved:\n content[\"tool\"][\"poetry\"][key] = preserved[key] # type: ignore[index]\n\n pyproject = PyProjectTOML(self.system_pyproject)\n pyproject.file.write(content)\n\n def reset_poetry(self) -> None:\n with directory(self.system_pyproject.parent):\n self.generate_system_pyproject()\n self._poetry = Factory().create_poetry(\n self.system_pyproject.parent, io=self.io, disable_plugins=True\n )\n\n @property\n def poetry(self) -> Poetry:\n if self._poetry is None:\n self.reset_poetry()\n\n assert self._poetry is not None\n return self._poetry\n\n def _system_project_handle(self) -> int:\n \"\"\"\n This is a helper method that by default calls the handle method implemented in\n the child class's next MRO sibling. Override this if you want special handling\n either before calling the handle() from the super class or have custom logic\n to handle the command.\n\n The default implementations handles cases where a `self` command delegates\n handling to an existing command. 
Eg: `SelfAddCommand(SelfCommand, AddCommand)`.\n \"\"\"\n return_code: int = super().handle()\n return return_code\n\n def reset(self) -> None:\n \"\"\"\n Reset current command instance's environment and poetry instances to ensure\n use of the system specific ones.\n \"\"\"\n self.reset_env()\n self.reset_poetry()\n\n def handle(self) -> int:\n # We override the base class's handle() method to ensure that poetry and env\n # are reset to work within the system project instead of current context.\n # Further, during execution, the working directory is temporarily changed\n # to parent directory of Poetry system pyproject.toml file.\n #\n # This method **should not** be overridden in child classes as it may have\n # unexpected consequences.\n\n self.reset()\n\n with directory(self.system_pyproject.parent):\n return self._system_project_handle()\n\n# Path: src/poetry/console/commands/self/update.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom cleo.io.inputs.string_input import StringInput\nfrom cleo.io.io import IO\n\nfrom poetry.console.commands.add import AddCommand\nfrom poetry.console.commands.self.self_command import SelfCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass SelfUpdateCommand(SelfCommand):\n name = \"self update\"\n description = \"Updates Poetry to the latest version.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"version\", \"The version to update to.\", optional=True, default=\"latest\"\n )\n ]\n options: ClassVar[list[Option]] = [\n option(\"preview\", None, \"Allow the installation of pre-release versions.\"),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n ]\n help = \"\"\"\\\nThe self update command updates Poetry version in its current runtime \\\nenvironment.\n\"\"\"\n\n def _system_project_handle(self) -> int:\n self.write(\"Updating Poetry version ...\\n\\n\")\n application = self.get_application()\n add_command = application.find(\"add\")\n assert isinstance(add_command, AddCommand)\n add_command.set_env(self.env)\n application.configure_installer_for_command(add_command, self.io)\n\n argv = [\"add\", f\"poetry@{self.argument('version')}\"]\n\n if self.option(\"dry-run\"):\n argv.append(\"--dry-run\")\n\n if self.option(\"preview\"):\n argv.append(\"--allow-prereleases\")\n\n exit_code: int = add_command.run(\n IO(\n StringInput(\" \".join(argv)),\n self.io.output,\n self.io.error_output,\n )\n )\n return exit_code\n\n# Path: src/poetry/console/commands/source/__init__.py\n\n# Path: src/poetry/console/commands/source/add.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom cleo.io.null_io import NullIO\nfrom tomlkit.items import AoT\n\nfrom poetry.config.source import Source\nfrom poetry.console.commands.command import Command\nfrom poetry.repositories.repository_pool import Priority\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass SourceAddCommand(Command):\n name = \"source add\"\n description = \"Add source configuration for project.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"name\",\n \"Source repository name.\",\n ),\n argument(\n 
\"url\",\n \"Source repository URL.\"\n \" Required, except for PyPI, for which it is not allowed.\",\n optional=True,\n ),\n ]\n\n options: ClassVar[list[Option]] = [\n option(\n \"default\",\n \"d\",\n \"Set this source as the default (disable PyPI). A \"\n \"default source will also be the fallback source if \"\n \"you add other sources. (Deprecated, use --priority)\",\n ),\n option(\n \"secondary\",\n \"s\",\n \"Set this source as secondary. (Deprecated, use\"\n \" --priority)\",\n ),\n option(\n \"priority\",\n \"p\",\n \"Set the priority of this source. One of:\"\n f\" {', '.join(p.name.lower() for p in Priority)}. Defaults to\"\n f\" {Priority.PRIMARY.name.lower()}.\",\n flag=False,\n ),\n ]\n\n def handle(self) -> int:\n from poetry.factory import Factory\n from poetry.utils.source import source_to_table\n\n name: str = self.argument(\"name\")\n lower_name = name.lower()\n url: str = self.argument(\"url\")\n is_default: bool = self.option(\"default\", False)\n is_secondary: bool = self.option(\"secondary\", False)\n priority_str: str | None = self.option(\"priority\", None)\n\n if lower_name == \"pypi\":\n name = \"PyPI\"\n if url:\n self.line_error(\n \"The URL of PyPI is fixed and cannot be set.\"\n )\n return 1\n elif not url:\n self.line_error(\n \"A custom source cannot be added without a URL.\"\n )\n return 1\n\n if is_default and is_secondary:\n self.line_error(\n \"Cannot configure a source as both default and\"\n \" secondary.\"\n )\n return 1\n\n if is_default or is_secondary:\n if priority_str is not None:\n self.line_error(\n \"Priority was passed through both --priority and a\"\n \" deprecated flag (--default or --secondary). Please only provide\"\n \" one of these.\"\n )\n return 1\n else:\n self.line_error(\n \"Warning: Priority was set through a deprecated flag\"\n \" (--default or --secondary). Consider using --priority next\"\n \" time.\"\n )\n\n if is_default:\n priority = Priority.DEFAULT\n elif is_secondary:\n priority = Priority.SECONDARY\n elif priority_str is None:\n priority = Priority.PRIMARY\n else:\n priority = Priority[priority_str.upper()]\n\n if priority is Priority.SECONDARY:\n allowed_prios = (\n p for p in Priority if p not in {Priority.DEFAULT, Priority.SECONDARY}\n )\n self.line_error(\n \"Warning: Priority 'secondary' is deprecated. Consider\"\n \" changing the priority to one of the non-deprecated values:\"\n f\" {', '.join(repr(p.name.lower()) for p in allowed_prios)}.\"\n )\n if priority is Priority.DEFAULT:\n self.line_error(\n \"Warning: Priority 'default' is deprecated. You can achieve\"\n \" the same effect by changing the priority to 'primary' and putting\"\n \" the source first.\"\n )\n\n sources = AoT([])\n new_source = Source(name=name, url=url, priority=priority)\n is_new_source = True\n\n for source in self.poetry.get_sources():\n if source.priority is Priority.DEFAULT and priority is Priority.DEFAULT:\n self.line_error(\n f\"Source with name {source.name} is already set to\"\n \" default. Only one default source can be configured at a\"\n \" time.\"\n )\n return 1\n\n if source.name.lower() == lower_name:\n source = new_source\n is_new_source = False\n\n sources.append(source_to_table(source))\n\n if is_new_source:\n self.line(f\"Adding source with name {name}.\")\n sources.append(source_to_table(new_source))\n else:\n self.line(f\"Source with name {name} already exists. Updating.\")\n\n # ensure new source is valid. 
eg: invalid name etc.\n try:\n pool = Factory.create_pool(self.poetry.config, sources, NullIO())\n pool.repository(name)\n except ValueError as e:\n self.line_error(\n f\"Failed to validate addition of {name}: {e}\"\n )\n return 1\n\n self.poetry.pyproject.poetry_config[\"source\"] = sources\n self.poetry.pyproject.save()\n\n return 0\n\n# Path: src/poetry/console/commands/source/remove.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom tomlkit.items import AoT\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n\n\nclass SourceRemoveCommand(Command):\n name = \"source remove\"\n description = \"Remove source configured for the project.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"name\",\n \"Source repository name.\",\n ),\n ]\n\n def handle(self) -> int:\n from poetry.utils.source import source_to_table\n\n name = self.argument(\"name\")\n lower_name = name.lower()\n\n sources = AoT([])\n removed = False\n\n for source in self.poetry.get_sources():\n if source.name.lower() == lower_name:\n self.line(f\"Removing source with name {source.name}.\")\n removed = True\n continue\n sources.append(source_to_table(source))\n\n if not removed:\n self.line_error(\n f\"Source with name {name} was not found.\"\n )\n return 1\n\n self.poetry.pyproject.poetry_config[\"source\"] = sources\n self.poetry.pyproject.save()\n\n return 0\n\n# Path: src/poetry/console/commands/source/show.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.ui.table import Rows\n\n\nclass SourceShowCommand(Command):\n name = \"source show\"\n description = \"Show information about sources configured for the project.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"source\",\n \"Source(s) to show information for. 
Defaults to showing all sources.\",\n optional=True,\n multiple=True,\n ),\n ]\n\n def handle(self) -> int:\n sources = self.poetry.get_sources()\n names = self.argument(\"source\")\n lower_names = [name.lower() for name in names]\n\n if not sources:\n self.line(\"No sources configured for this project.\")\n return 0\n\n if names and not any(s.name.lower() in lower_names for s in sources):\n self.line_error(\n f\"No source found with name(s): {', '.join(names)}\",\n style=\"error\",\n )\n return 1\n\n for source in sources:\n if names and source.name.lower() not in lower_names:\n continue\n\n table = self.table(style=\"compact\")\n rows: Rows = [[\"name\", f\" : {source.name}\"]]\n if source.url:\n rows.append([\"url\", f\" : {source.url}\"])\n rows.append([\"priority\", f\" : {source.priority.name.lower()}\"])\n table.add_rows(rows)\n table.render()\n self.line(\"\")\n\n return 0\n\n# Path: src/poetry/console/io/inputs/__init__.py\n\n# Path: src/poetry/console/io/inputs/run_argv_input.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom cleo.io.inputs.argv_input import ArgvInput\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.definition import Definition\n\n\nclass RunArgvInput(ArgvInput):\n def __init__(\n self,\n argv: list[str] | None = None,\n definition: Definition | None = None,\n ) -> None:\n super().__init__(argv, definition=definition)\n\n self._parameter_options: list[str] = []\n\n @property\n def first_argument(self) -> str | None:\n return \"run\"\n\n def add_parameter_option(self, name: str) -> None:\n self._parameter_options.append(name)\n\n def has_parameter_option(\n self, values: str | list[str], only_params: bool = False\n ) -> bool:\n if not isinstance(values, list):\n values = [values]\n\n for token in self._tokens:\n if only_params and token == \"--\":\n return False\n\n for value in values:\n if value not in self._parameter_options:\n continue\n\n # Options with values:\n # For long options, test for '--option=' at beginning\n # For short options, test for '-o' at beginning\n leading = value + \"=\" if value.startswith(\"--\") else value\n\n if token == value or leading != \"\" and token.startswith(leading):\n return True\n\n return False\n\n def _parse(self) -> None:\n parse_options = True\n self._parsed = self._tokens[:]\n\n try:\n token = self._parsed.pop(0)\n except IndexError:\n token = None\n\n while token is not None:\n if parse_options and token == \"\":\n self._parse_argument(token)\n elif parse_options and token == \"--\":\n parse_options = False\n elif parse_options and token.find(\"--\") == 0:\n if token in self._parameter_options:\n self._parse_long_option(token)\n else:\n self._parse_argument(token)\n elif parse_options and token[0] == \"-\" and token != \"-\":\n if token in self._parameter_options:\n self._parse_short_option(token)\n else:\n self._parse_argument(token)\n else:\n self._parse_argument(token)\n\n try:\n token = self._parsed.pop(0)\n except IndexError:\n token = None\n\n# Path: src/poetry/console/logging/formatters/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.console.logging.formatters.builder_formatter import BuilderLogFormatter\n\n\nFORMATTERS = {\n \"poetry.core.masonry.builders.builder\": BuilderLogFormatter(),\n \"poetry.core.masonry.builders.sdist\": BuilderLogFormatter(),\n \"poetry.core.masonry.builders.wheel\": BuilderLogFormatter(),\n}\n\n# Path: src/poetry/console/logging/formatters/builder_formatter.py\nfrom __future__ import annotations\n\nimport re\n\nfrom 
poetry.console.logging.formatters.formatter import Formatter\n\n\nclass BuilderLogFormatter(Formatter):\n def format(self, msg: str) -> str:\n if msg.startswith(\"Building \"):\n msg = re.sub(\"Building (.+)\", \" - Building \\\\1\", msg)\n elif msg.startswith(\"Built \"):\n msg = re.sub(\"Built (.+)\", \" - Built \\\\1\", msg)\n elif msg.startswith(\"Adding: \"):\n msg = re.sub(\"Adding: (.+)\", \" - Adding: \\\\1\", msg)\n elif msg.startswith(\"Executing build script: \"):\n msg = re.sub(\n \"Executing build script: (.+)\",\n \" - Executing build script: \\\\1\",\n msg,\n )\n\n return msg\n\n# Path: src/poetry/console/logging/formatters/formatter.py\nfrom __future__ import annotations\n\n\nclass Formatter:\n def format(self, record: str) -> str:\n raise NotImplementedError()\n\n# Path: src/poetry/console/commands/self/show/__init__.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.self.self_command import SelfCommand\nfrom poetry.console.commands.show import ShowCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass SelfShowCommand(SelfCommand, ShowCommand):\n name = \"self show\"\n options: ClassVar[list[Option]] = [\n option(\"addons\", None, \"List only add-on packages installed.\"),\n *[o for o in ShowCommand.options if o.name in {\"tree\", \"latest\", \"outdated\"}],\n ]\n description = \"Show packages from Poetry's runtime environment.\"\n help = f\"\"\"\\\nThe self show command behaves similar to the show command, but\nworking within Poetry's runtime environment. This lists all packages installed within\nthe Poetry install environment.\n\nTo show only additional packages that have been added via self add and their\ndependencies use self show --addons.\n\nThis is managed in the {SelfCommand.get_default_system_pyproject_file()} \\\nfile.\n\"\"\"\n\n @property\n def activated_groups(self) -> set[str]:\n if self.option(\"addons\", False):\n return {SelfCommand.ADDITIONAL_PACKAGE_GROUP}\n\n groups: set[str] = super(ShowCommand, self).activated_groups\n return groups\n\n# Path: src/poetry/console/commands/self/show/plugins.py\nfrom __future__ import annotations\n\nimport dataclasses\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.console.commands.self.self_command import SelfCommand\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\n from poetry.utils._compat import metadata\n\n\n@dataclasses.dataclass\nclass PluginPackage:\n package: Package\n plugins: list[metadata.EntryPoint] = dataclasses.field(default_factory=list)\n application_plugins: list[metadata.EntryPoint] = dataclasses.field(\n default_factory=list\n )\n\n def append(self, entry_point: metadata.EntryPoint) -> None:\n from poetry.plugins.application_plugin import ApplicationPlugin\n from poetry.plugins.plugin import Plugin\n\n group = entry_point.group\n\n if group == ApplicationPlugin.group:\n self.application_plugins.append(entry_point)\n elif group == Plugin.group:\n self.plugins.append(entry_point)\n else:\n name = entry_point.name\n raise ValueError(f\"Unknown plugin group ({group}) for {name}\")\n\n\n...\n", "needle_token_start": 6960, "needle_token_end": 7426, "code_context_ntokens": 15965, "output": ["```python\ndef _display_complete_info(self, env: Env) -> None:\n env_python_version = \".\".join(str(s) for s in env.version_info[:3])\n self.line(\"\")\n self.line(\"Virtualenv\")\n listing = [\n f\"Python: {env_python_version}\",\n 
f\"Implementation: {env.python_implementation}\",\n (\n \"Path: \"\n f\" {env.path if env.is_venv() else 'NA'}\"\n ),\n (\n \"Executable: \"\n f\" {env.python if env.is_venv() else 'NA'}\"\n ),\n ]\n if env.is_venv():\n listing.append(\n \"Valid: \"\n f\" <{'comment' if env.is_sane() else 'error'}>{env.is_sane()}\"\n )\n self.line(\"\\n\".join(listing))\n\n self.line(\"\")\n\n base_env = env.parent_env\n python = \".\".join(str(v) for v in base_env.version_info[:3])\n self.line(\"Base\")\n self.line(\n \"\\n\".join(\n [\n f\"Platform: {env.platform}\",\n f\"OS: {env.os}\",\n f\"Python: {python}\",\n f\"Path: {base_env.path}\",\n f\"Executable: {base_env.python}\",\n ]\n )\n )\n```"]} +{"repo": "python-poetry/poetry", "name": "_write_lock_file", "language": "python", "path": "src/poetry/installation/installer.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to update and save the project's dependency lock file based on the current state of a repository that tracks locked dependencies, ensuring that the project's dependencies are consistently managed and reproducible.\n2. **Input**: It takes a repository object that contains the current state of locked dependencies and an optional boolean flag that forces the update regardless of other conditions.\n3. **Output**: There is no return value; the function's effect is the potential modification of the lock file on disk.\n4. **Procedure**: The function first checks if the operation is not a dry run and if an update is necessary or forced. If conditions are met, it updates the lock data with the current state of dependencies from the provided repository. If the lock data is successfully updated, it outputs a message indicating that the lock file is being written.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/installation/executor.py\nfrom __future__ import annotations\n\nimport contextlib\nimport csv\nimport functools\nimport itertools\nimport json\nimport threading\n\nfrom concurrent.futures import ThreadPoolExecutor\nfrom concurrent.futures import wait\nfrom pathlib import Path\nfrom subprocess import CalledProcessError\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom cleo.io.null_io import NullIO\nfrom poetry.core.packages.utils.link import Link\n\nfrom poetry.installation.chef import Chef\nfrom poetry.installation.chef import ChefBuildError\nfrom poetry.installation.chef import ChefInstallError\nfrom poetry.installation.chooser import Chooser\nfrom poetry.installation.operations import Install\nfrom poetry.installation.operations import Uninstall\nfrom poetry.installation.operations import Update\nfrom poetry.installation.wheel_installer import WheelInstaller\nfrom poetry.puzzle.exceptions import SolverProblemError\nfrom poetry.utils._compat import decode\nfrom poetry.utils.authenticator import Authenticator\nfrom poetry.utils.env import EnvCommandError\nfrom poetry.utils.helpers import Downloader\nfrom poetry.utils.helpers import get_file_hash\nfrom poetry.utils.helpers import get_highest_priority_hash_type\nfrom poetry.utils.helpers import pluralize\nfrom poetry.utils.helpers import remove_directory\nfrom poetry.utils.pip import pip_install\n\n\nif TYPE_CHECKING:\n from cleo.io.io import IO\n from 
cleo.io.outputs.section_output import SectionOutput\n from poetry.core.masonry.builders.builder import Builder\n from poetry.core.packages.package import Package\n\n from poetry.config.config import Config\n from poetry.installation.operations.operation import Operation\n from poetry.repositories import RepositoryPool\n from poetry.utils.env import Env\n\n\nclass Executor:\n def __init__(\n self,\n env: Env,\n pool: RepositoryPool,\n config: Config,\n io: IO,\n parallel: bool | None = None,\n disable_cache: bool = False,\n ) -> None:\n self._env = env\n self._io = io\n self._dry_run = False\n self._enabled = True\n self._verbose = False\n self._wheel_installer = WheelInstaller(self._env)\n self._use_modern_installation = config.get(\n \"installer.modern-installation\", True\n )\n if not self._use_modern_installation:\n self._io.write_line(\n \"Warning: Setting `installer.modern-installation` to `false` \"\n \"is deprecated.\"\n )\n self._io.write_line(\n \"The pip-based installer will be removed in a future release.\"\n )\n self._io.write_line(\n \"See https://github.com/python-poetry/poetry/issues/8987.\"\n )\n\n if parallel is None:\n parallel = config.get(\"installer.parallel\", True)\n\n if parallel:\n self._max_workers = config.installer_max_workers\n else:\n self._max_workers = 1\n\n self._artifact_cache = pool.artifact_cache\n self._authenticator = Authenticator(\n config, self._io, disable_cache=disable_cache, pool_size=self._max_workers\n )\n self._chef = Chef(self._artifact_cache, self._env, pool)\n self._chooser = Chooser(pool, self._env, config)\n\n self._executor = ThreadPoolExecutor(max_workers=self._max_workers)\n self._total_operations = 0\n self._executed_operations = 0\n self._executed = {\"install\": 0, \"update\": 0, \"uninstall\": 0}\n self._skipped = {\"install\": 0, \"update\": 0, \"uninstall\": 0}\n self._sections: dict[int, SectionOutput] = {}\n self._yanked_warnings: list[str] = []\n self._lock = threading.Lock()\n self._shutdown = False\n self._hashes: dict[str, str] = {}\n\n @property\n def installations_count(self) -> int:\n return self._executed[\"install\"]\n\n @property\n def updates_count(self) -> int:\n return self._executed[\"update\"]\n\n @property\n def removals_count(self) -> int:\n return self._executed[\"uninstall\"]\n\n @property\n def enabled(self) -> bool:\n return self._enabled\n\n def supports_fancy_output(self) -> bool:\n return self._io.output.is_decorated() and not self._dry_run\n\n def disable(self) -> Executor:\n self._enabled = False\n\n return self\n\n def dry_run(self, dry_run: bool = True) -> Executor:\n self._dry_run = dry_run\n\n return self\n\n def verbose(self, verbose: bool = True) -> Executor:\n self._verbose = verbose\n\n return self\n\n def enable_bytecode_compilation(self, enable: bool = True) -> None:\n self._wheel_installer.enable_bytecode_compilation(enable)\n\n def pip_install(\n self, req: Path, upgrade: bool = False, editable: bool = False\n ) -> int:\n try:\n pip_install(req, self._env, upgrade=upgrade, editable=editable)\n except EnvCommandError as e:\n output = decode(e.e.output)\n if (\n \"KeyboardInterrupt\" in output\n or \"ERROR: Operation cancelled by user\" in output\n ):\n return -2\n raise\n\n return 0\n\n def execute(self, operations: list[Operation]) -> int:\n self._total_operations = len(operations)\n for job_type in self._executed:\n self._executed[job_type] = 0\n self._skipped[job_type] = 0\n\n if operations and (self._enabled or self._dry_run):\n self._display_summary(operations)\n\n self._sections = {}\n 
self._yanked_warnings = []\n\n # pip has to be installed first without parallelism if we install via pip\n for i, op in enumerate(operations):\n if op.package.name == \"pip\":\n wait([self._executor.submit(self._execute_operation, op)])\n del operations[i]\n break\n\n # We group operations by priority\n groups = itertools.groupby(operations, key=lambda o: -o.priority)\n for _, group in groups:\n tasks = []\n serial_operations = []\n for operation in group:\n if self._shutdown:\n break\n\n # Some operations are unsafe, we must execute them serially in a group\n # https://github.com/python-poetry/poetry/issues/3086\n # https://github.com/python-poetry/poetry/issues/2658\n #\n # We need to explicitly check source type here, see:\n # https://github.com/python-poetry/poetry-core/pull/98\n is_parallel_unsafe = operation.job_type == \"uninstall\" or (\n operation.package.develop\n and operation.package.source_type in {\"directory\", \"git\"}\n )\n if not operation.skipped and is_parallel_unsafe:\n serial_operations.append(operation)\n continue\n\n tasks.append(self._executor.submit(self._execute_operation, operation))\n\n try:\n wait(tasks)\n\n for operation in serial_operations:\n wait([self._executor.submit(self._execute_operation, operation)])\n\n except KeyboardInterrupt:\n self._shutdown = True\n\n if self._shutdown:\n # Cancelling further tasks from being executed\n [task.cancel() for task in tasks]\n self._executor.shutdown(wait=True)\n\n break\n\n for warning in self._yanked_warnings:\n self._io.write_error_line(f\"Warning: {warning}\")\n for path, issues in self._wheel_installer.invalid_wheels.items():\n formatted_issues = \"\\n\".join(issues)\n warning = (\n f\"Validation of the RECORD file of {path.name} failed.\"\n \" Please report to the maintainers of that package so they can fix\"\n f\" their build process. 
Details:\\n{formatted_issues}\\n\"\n )\n self._io.write_error_line(f\"Warning: {warning}\")\n\n return 1 if self._shutdown else 0\n\n def _write(self, operation: Operation, line: str) -> None:\n if not self.supports_fancy_output() or not self._should_write_operation(\n operation\n ):\n return\n\n if self._io.is_debug():\n with self._lock:\n section = self._sections[id(operation)]\n section.write_line(line)\n\n return\n\n with self._lock:\n section = self._sections[id(operation)]\n section.clear()\n section.write(line)\n\n def _execute_operation(self, operation: Operation) -> None:\n try:\n op_message = self.get_operation_message(operation)\n if self.supports_fancy_output():\n if id(operation) not in self._sections and self._should_write_operation(\n operation\n ):\n with self._lock:\n self._sections[id(operation)] = self._io.section()\n self._sections[id(operation)].write_line(\n f\" - {op_message}:\"\n \" Pending...\"\n )\n else:\n if self._should_write_operation(operation):\n if not operation.skipped:\n self._io.write_line(\n f\" - {op_message}\"\n )\n else:\n self._io.write_line(\n f\" - {op_message}: \"\n \"Skipped \"\n \"for the following reason: \"\n f\"{operation.skip_reason}\"\n )\n\n try:\n result = self._do_execute_operation(operation)\n except EnvCommandError as e:\n if e.e.returncode == -2:\n result = -2\n else:\n raise\n\n # If we have a result of -2 it means a KeyboardInterrupt\n # in the any python subprocess, so we raise a KeyboardInterrupt\n # error to be picked up by the error handler.\n if result == -2:\n raise KeyboardInterrupt\n except Exception as e:\n try:\n from cleo.ui.exception_trace import ExceptionTrace\n\n io: IO | SectionOutput\n if not self.supports_fancy_output():\n io = self._io\n else:\n message = (\n \" -\"\n f\" {self.get_operation_message(operation, error=True)}:\"\n \" Failed\"\n )\n self._write(operation, message)\n io = self._sections.get(id(operation), self._io)\n\n with self._lock:\n trace = ExceptionTrace(e)\n trace.render(io)\n pkg = operation.package\n if isinstance(e, ChefBuildError):\n pip_command = \"pip wheel --no-cache-dir --use-pep517\"\n if pkg.develop:\n requirement = pkg.source_url\n pip_command += \" --editable\"\n else:\n requirement = (\n pkg.to_dependency().to_pep_508().split(\";\")[0].strip()\n )\n message = (\n \"\"\n \"Note: This error originates from the build backend,\"\n \" and is likely not a problem with poetry\"\n f\" but with {pkg.pretty_name} ({pkg.full_pretty_version})\"\n \" not supporting PEP 517 builds. 
You can verify this by\"\n f\" running '{pip_command} \\\"{requirement}\\\"'.\"\n \"\"\n )\n elif isinstance(e, ChefInstallError):\n message = (\n \"\"\n \"Cannot install build-system.requires\"\n f\" for {pkg.pretty_name}.\"\n \"\"\n )\n elif isinstance(e, SolverProblemError):\n message = (\n \"\"\n \"Cannot resolve build-system.requires\"\n f\" for {pkg.pretty_name}.\"\n \"\"\n )\n else:\n message = f\"Cannot install {pkg.pretty_name}.\"\n\n io.write_line(\"\")\n io.write_line(message)\n io.write_line(\"\")\n finally:\n with self._lock:\n self._shutdown = True\n\n except KeyboardInterrupt:\n try:\n message = (\n \" -\"\n f\" {self.get_operation_message(operation, warning=True)}:\"\n \" Cancelled\"\n )\n if not self.supports_fancy_output():\n self._io.write_line(message)\n else:\n self._write(operation, message)\n finally:\n with self._lock:\n self._shutdown = True\n\n def _do_execute_operation(self, operation: Operation) -> int:\n method = operation.job_type\n\n operation_message = self.get_operation_message(operation)\n if operation.skipped:\n if self.supports_fancy_output():\n self._write(\n operation,\n f\" - {operation_message}: \"\n \"Skipped \"\n \"for the following reason: \"\n f\"{operation.skip_reason}\",\n )\n\n self._skipped[operation.job_type] += 1\n\n return 0\n\n if not self._enabled or self._dry_run:\n return 0\n\n result: int = getattr(self, f\"_execute_{method}\")(operation)\n\n if result != 0:\n return result\n\n operation_message = self.get_operation_message(operation, done=True)\n message = f\" - {operation_message}\"\n self._write(operation, message)\n\n self._increment_operations_count(operation, True)\n\n return result\n\n def _increment_operations_count(self, operation: Operation, executed: bool) -> None:\n with self._lock:\n if executed:\n self._executed_operations += 1\n self._executed[operation.job_type] += 1\n else:\n self._skipped[operation.job_type] += 1\n\n def run_pip(self, *args: Any, **kwargs: Any) -> int:\n try:\n self._env.run_pip(*args, **kwargs)\n except EnvCommandError as e:\n output = decode(e.e.output)\n if (\n \"KeyboardInterrupt\" in output\n or \"ERROR: Operation cancelled by user\" in output\n ):\n return -2\n\n raise\n\n return 0\n\n def get_operation_message(\n self,\n operation: Operation,\n done: bool = False,\n error: bool = False,\n warning: bool = False,\n ) -> str:\n base_tag = \"fg=default\"\n operation_color = \"c2\"\n source_operation_color = \"c2\"\n package_color = \"c1\"\n\n if error:\n operation_color = \"error\"\n elif warning:\n operation_color = \"warning\"\n elif done:\n operation_color = \"success\"\n\n if operation.skipped:\n base_tag = \"fg=default;options=dark\"\n operation_color += \"_dark\"\n source_operation_color += \"_dark\"\n package_color += \"_dark\"\n\n if isinstance(operation, Install):\n return (\n f\"<{base_tag}>Installing\"\n f\" <{package_color}>{operation.package.name}\"\n f\" (<{operation_color}>{operation.package.full_pretty_version})\"\n )\n\n if isinstance(operation, Uninstall):\n return (\n f\"<{base_tag}>Removing\"\n f\" <{package_color}>{operation.package.name}\"\n f\" (<{operation_color}>{operation.package.full_pretty_version})\"\n )\n\n if isinstance(operation, Update):\n initial_version = (initial_pkg := operation.initial_package).version\n target_version = (target_pkg := operation.target_package).version\n update_kind = (\n \"Updating\" if target_version >= initial_version else \"Downgrading\"\n )\n return (\n f\"<{base_tag}>{update_kind}\"\n f\" <{package_color}>{initial_pkg.name} \"\n 
f\"(<{source_operation_color}>\"\n f\"{initial_pkg.full_pretty_version}\"\n f\" -> <{operation_color}>\"\n f\"{target_pkg.full_pretty_version})\"\n )\n return \"\"\n\n def _display_summary(self, operations: list[Operation]) -> None:\n installs = 0\n updates = 0\n uninstalls = 0\n skipped = 0\n for op in operations:\n if op.skipped:\n skipped += 1\n continue\n\n if op.job_type == \"install\":\n installs += 1\n elif op.job_type == \"update\":\n updates += 1\n elif op.job_type == \"uninstall\":\n uninstalls += 1\n\n if not installs and not updates and not uninstalls and not self._verbose:\n self._io.write_line(\"\")\n self._io.write_line(\"No dependencies to install or update\")\n\n return\n\n self._io.write_line(\"\")\n self._io.write(\"Package operations: \")\n self._io.write(f\"{installs} install{pluralize(installs)}, \")\n self._io.write(f\"{updates} update{pluralize(updates)}, \")\n self._io.write(f\"{uninstalls} removal{pluralize(uninstalls)}\")\n if skipped and self._verbose:\n self._io.write(f\", {skipped} skipped\")\n self._io.write_line(\"\")\n self._io.write_line(\"\")\n\n def _execute_install(self, operation: Install | Update) -> int:\n status_code = self._install(operation)\n\n self._save_url_reference(operation)\n\n return status_code\n\n def _execute_update(self, operation: Install | Update) -> int:\n status_code = self._update(operation)\n\n self._save_url_reference(operation)\n\n return status_code\n\n def _execute_uninstall(self, operation: Uninstall) -> int:\n op_msg = self.get_operation_message(operation)\n message = f\" - {op_msg}: Removing...\"\n self._write(operation, message)\n\n return self._remove(operation.package)\n\n def _install(self, operation: Install | Update) -> int:\n package = operation.package\n if package.source_type == \"directory\" and not self._use_modern_installation:\n return self._install_directory_without_wheel_installer(operation)\n\n cleanup_archive: bool = False\n if package.source_type == \"git\":\n archive = self._prepare_git_archive(operation)\n cleanup_archive = operation.package.develop\n elif package.source_type == \"file\":\n archive = self._prepare_archive(operation)\n elif package.source_type == \"directory\":\n archive = self._prepare_archive(operation)\n cleanup_archive = True\n elif package.source_type == \"url\":\n assert package.source_url is not None\n archive = self._download_link(operation, Link(package.source_url))\n else:\n archive = self._download(operation)\n\n operation_message = self.get_operation_message(operation)\n message = (\n f\" - {operation_message}:\"\n \" Installing...\"\n )\n self._write(operation, message)\n\n if not self._use_modern_installation:\n return self.pip_install(archive, upgrade=operation.job_type == \"update\")\n\n try:\n if operation.job_type == \"update\":\n # Uninstall first\n # TODO: Make an uninstaller and find a way to rollback in case\n # the new package can't be installed\n assert isinstance(operation, Update)\n self._remove(operation.initial_package)\n\n self._wheel_installer.install(archive)\n finally:\n if cleanup_archive:\n archive.unlink()\n\n return 0\n\n def _update(self, operation: Install | Update) -> int:\n return self._install(operation)\n\n def _remove(self, package: Package) -> int:\n # If we have a VCS package, remove its source directory\n if package.source_type == \"git\":\n src_dir = self._env.path / \"src\" / package.name\n...\n# Path: src/poetry/installation/installer.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import 
cast\n\nfrom cleo.io.null_io import NullIO\nfrom packaging.utils import canonicalize_name\n\nfrom poetry.installation.executor import Executor\nfrom poetry.installation.operations import Install\nfrom poetry.installation.operations import Uninstall\nfrom poetry.installation.operations import Update\nfrom poetry.repositories import Repository\nfrom poetry.repositories import RepositoryPool\nfrom poetry.repositories.installed_repository import InstalledRepository\nfrom poetry.repositories.lockfile_repository import LockfileRepository\nfrom poetry.utils.extras import get_extra_package_names\n\n\nif TYPE_CHECKING:\n from collections.abc import Iterable\n\n from cleo.io.io import IO\n from packaging.utils import NormalizedName\n from poetry.core.packages.path_dependency import PathDependency\n from poetry.core.packages.project_package import ProjectPackage\n\n from poetry.config.config import Config\n from poetry.installation.operations.operation import Operation\n from poetry.packages import Locker\n from poetry.utils.env import Env\n\n\nclass Installer:\n def __init__(\n self,\n io: IO,\n env: Env,\n package: ProjectPackage,\n locker: Locker,\n pool: RepositoryPool,\n config: Config,\n installed: Repository | None = None,\n executor: Executor | None = None,\n disable_cache: bool = False,\n ) -> None:\n self._io = io\n self._env = env\n self._package = package\n self._locker = locker\n self._pool = pool\n self._config = config\n\n self._dry_run = False\n self._requires_synchronization = False\n self._update = False\n self._verbose = False\n self._groups: Iterable[str] | None = None\n self._skip_directory = False\n self._lock = False\n\n self._whitelist: list[NormalizedName] = []\n\n self._extras: list[NormalizedName] = []\n\n if executor is None:\n executor = Executor(\n self._env, self._pool, config, self._io, disable_cache=disable_cache\n )\n\n self._executor = executor\n\n if installed is None:\n installed = self._get_installed()\n\n self._installed_repository = installed\n\n @property\n def executor(self) -> Executor:\n return self._executor\n\n def set_package(self, package: ProjectPackage) -> Installer:\n self._package = package\n\n return self\n\n def set_locker(self, locker: Locker) -> Installer:\n self._locker = locker\n\n return self\n\n def run(self) -> int:\n # Check if refresh\n if not self._update and self._lock and self._locker.is_locked():\n return self._do_refresh()\n\n # Force update if there is no lock file present\n if not self._update and not self._locker.is_locked():\n self._update = True\n\n if self.is_dry_run():\n self.verbose(True)\n\n return self._do_install()\n\n def dry_run(self, dry_run: bool = True) -> Installer:\n self._dry_run = dry_run\n self._executor.dry_run(dry_run)\n\n return self\n\n def is_dry_run(self) -> bool:\n return self._dry_run\n\n def requires_synchronization(\n self, requires_synchronization: bool = True\n ) -> Installer:\n self._requires_synchronization = requires_synchronization\n\n return self\n\n def verbose(self, verbose: bool = True) -> Installer:\n self._verbose = verbose\n self._executor.verbose(verbose)\n\n return self\n\n def is_verbose(self) -> bool:\n return self._verbose\n\n def only_groups(self, groups: Iterable[str]) -> Installer:\n self._groups = groups\n\n return self\n\n def update(self, update: bool = True) -> Installer:\n self._update = update\n\n return self\n\n def skip_directory(self, skip_directory: bool = False) -> Installer:\n self._skip_directory = skip_directory\n\n return self\n\n def lock(self, update: bool = True) 
-> Installer:\n \"\"\"\n Prepare the installer for locking only.\n \"\"\"\n self.update(update=update)\n self.execute_operations(False)\n self._lock = True\n\n return self\n\n def is_updating(self) -> bool:\n return self._update\n\n def execute_operations(self, execute: bool = True) -> Installer:\n if not execute:\n self._executor.disable()\n\n return self\n\n def whitelist(self, packages: Iterable[str]) -> Installer:\n self._whitelist = [canonicalize_name(p) for p in packages]\n\n return self\n\n def extras(self, extras: list[str]) -> Installer:\n self._extras = [canonicalize_name(extra) for extra in extras]\n\n return self\n\n def _do_refresh(self) -> int:\n from poetry.puzzle.solver import Solver\n\n # Checking extras\n for extra in self._extras:\n if extra not in self._package.extras:\n raise ValueError(f\"Extra [{extra}] is not specified.\")\n\n locked_repository = self._locker.locked_repository()\n solver = Solver(\n self._package,\n self._pool,\n locked_repository.packages,\n locked_repository.packages,\n self._io,\n )\n\n # Always re-solve directory dependencies, otherwise we can't determine\n # if anything has changed (and the lock file contains an invalid version).\n use_latest = [\n p.name for p in locked_repository.packages if p.source_type == \"directory\"\n ]\n\n with solver.provider.use_source_root(\n source_root=self._env.path.joinpath(\"src\")\n ):\n ops = solver.solve(use_latest=use_latest).calculate_operations()\n\n lockfile_repo = LockfileRepository()\n self._populate_lockfile_repo(lockfile_repo, ops)\n\n self._write_lock_file(lockfile_repo, force=True)\n\n return 0\n\n def _do_install(self) -> int:\n from poetry.puzzle.solver import Solver\n\n locked_repository = Repository(\"poetry-locked\")\n if self._update:\n if not self._lock and self._locker.is_locked():\n locked_repository = self._locker.locked_repository()\n\n # If no packages have been whitelisted (The ones we want to update),\n # we whitelist every package in the lock file.\n if not self._whitelist:\n for pkg in locked_repository.packages:\n self._whitelist.append(pkg.name)\n\n # Checking extras\n for extra in self._extras:\n if extra not in self._package.extras:\n raise ValueError(f\"Extra [{extra}] is not specified.\")\n\n self._io.write_line(\"Updating dependencies\")\n solver = Solver(\n self._package,\n self._pool,\n self._installed_repository.packages,\n locked_repository.packages,\n self._io,\n )\n\n with solver.provider.use_source_root(\n source_root=self._env.path.joinpath(\"src\")\n ):\n ops = solver.solve(use_latest=self._whitelist).calculate_operations()\n else:\n self._io.write_line(\"Installing dependencies from lock file\")\n\n locked_repository = self._locker.locked_repository()\n\n if not self._locker.is_fresh():\n raise ValueError(\n \"pyproject.toml changed significantly since poetry.lock was last generated. 
\"\n \"Run `poetry lock [--no-update]` to fix the lock file.\"\n )\n\n locker_extras = {\n canonicalize_name(extra)\n for extra in self._locker.lock_data.get(\"extras\", {})\n }\n for extra in self._extras:\n if extra not in locker_extras:\n raise ValueError(f\"Extra [{extra}] is not specified.\")\n\n # If we are installing from lock\n # Filter the operations by comparing it with what is\n # currently installed\n ops = self._get_operations_from_lock(locked_repository)\n\n lockfile_repo = LockfileRepository()\n uninstalls = self._populate_lockfile_repo(lockfile_repo, ops)\n\n if not self.executor.enabled:\n # If we are only in lock mode, no need to go any further\n self._write_lock_file(lockfile_repo)\n return 0\n\n if self._groups is not None:\n root = self._package.with_dependency_groups(list(self._groups), only=True)\n else:\n root = self._package.without_optional_dependency_groups()\n\n if self._io.is_verbose():\n self._io.write_line(\"\")\n self._io.write_line(\n \"Finding the necessary packages for the current system\"\n )\n\n # We resolve again by only using the lock file\n packages = lockfile_repo.packages + locked_repository.packages\n pool = RepositoryPool.from_packages(packages, self._config)\n\n solver = Solver(\n root,\n pool,\n self._installed_repository.packages,\n locked_repository.packages,\n NullIO(),\n )\n # Everything is resolved at this point, so we no longer need\n # to load deferred dependencies (i.e. VCS, URL and path dependencies)\n solver.provider.load_deferred(False)\n\n with solver.use_environment(self._env):\n ops = solver.solve(use_latest=self._whitelist).calculate_operations(\n with_uninstalls=self._requires_synchronization,\n synchronize=self._requires_synchronization,\n skip_directory=self._skip_directory,\n )\n\n if not self._requires_synchronization:\n # If no packages synchronisation has been requested we need\n # to calculate the uninstall operations\n from poetry.puzzle.transaction import Transaction\n\n transaction = Transaction(\n locked_repository.packages,\n [(package, 0) for package in lockfile_repo.packages],\n installed_packages=self._installed_repository.packages,\n root_package=root,\n )\n\n ops = [\n op\n for op in transaction.calculate_operations(with_uninstalls=True)\n if op.job_type == \"uninstall\"\n ] + ops\n else:\n ops = uninstalls + ops\n\n # We need to filter operations so that packages\n # not compatible with the current system,\n # or optional and not requested, are dropped\n self._filter_operations(ops, lockfile_repo)\n\n # Validate the dependencies\n for op in ops:\n dep = op.package.to_dependency()\n if dep.is_file() or dep.is_directory():\n dep = cast(\"PathDependency\", dep)\n dep.validate(raise_error=not op.skipped)\n\n # Execute operations\n status = self._execute(ops)\n\n if status == 0 and self._update:\n # Only write lock file when installation is success\n self._write_lock_file(lockfile_repo)\n\n return status\n\n \ndef _write_lock_file(self, repo: LockfileRepository, force: bool = False) -> None:\n if not self.is_dry_run() and (force or self._update):\n updated_lock = self._locker.set_lock_data(self._package, repo.packages)\n\n if updated_lock:\n self._io.write_line(\"\")\n self._io.write_line(\"Writing lock file\")\n\n def _execute(self, operations: list[Operation]) -> int:\n return self._executor.execute(operations)\n\n def _populate_lockfile_repo(\n self, repo: LockfileRepository, ops: Iterable[Operation]\n ) -> list[Uninstall]:\n uninstalls = []\n for op in ops:\n if isinstance(op, Uninstall):\n 
uninstalls.append(op)\n continue\n\n package = op.target_package if isinstance(op, Update) else op.package\n if not repo.has_package(package):\n repo.add_package(package)\n\n return uninstalls\n\n def _get_operations_from_lock(\n self, locked_repository: Repository\n ) -> list[Operation]:\n installed_repo = self._installed_repository\n ops: list[Operation] = []\n\n extra_packages = self._get_extra_packages(locked_repository)\n for locked in locked_repository.packages:\n is_installed = False\n for installed in installed_repo.packages:\n if locked.name == installed.name:\n is_installed = True\n if locked.optional and locked.name not in extra_packages:\n # Installed but optional and not requested in extras\n ops.append(Uninstall(locked))\n elif locked.version != installed.version:\n ops.append(Update(installed, locked))\n\n # If it's optional and not in required extras\n # we do not install\n if locked.optional and locked.name not in extra_packages:\n continue\n\n op = Install(locked)\n if is_installed:\n op.skip(\"Already installed\")\n\n ops.append(op)\n\n return ops\n\n def _filter_operations(self, ops: Iterable[Operation], repo: Repository) -> None:\n extra_packages = self._get_extra_packages(repo)\n for op in ops:\n package = op.target_package if isinstance(op, Update) else op.package\n\n if op.job_type == \"uninstall\":\n continue\n\n if not self._env.is_valid_for_marker(package.marker):\n op.skip(\"Not needed for the current environment\")\n continue\n\n # If a package is optional and not requested\n # in any extra we skip it\n if package.optional and package.name not in extra_packages:\n op.skip(\"Not required\")\n\n def _get_extra_packages(self, repo: Repository) -> set[NormalizedName]:\n \"\"\"\n Returns all package names required by extras.\n\n Maybe we just let the solver handle it?\n \"\"\"\n extras: dict[NormalizedName, list[NormalizedName]]\n if self._update:\n extras = {k: [d.name for d in v] for k, v in self._package.extras.items()}\n else:\n raw_extras = self._locker.lock_data.get(\"extras\", {})\n extras = {\n canonicalize_name(extra): [\n canonicalize_name(dependency) for dependency in dependencies\n ]\n for extra, dependencies in raw_extras.items()\n }\n\n return get_extra_package_names(repo.packages, extras, self._extras)\n\n def _get_installed(self) -> InstalledRepository:\n return InstalledRepository.load(self._env)\n\n# Path: src/poetry/installation/wheel_installer.py\nfrom __future__ import annotations\n\nimport logging\nimport platform\nimport sys\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom installer import install\nfrom installer.destinations import SchemeDictionaryDestination\nfrom installer.sources import WheelFile\nfrom installer.sources import _WheelFileValidationError\n\nfrom poetry.__version__ import __version__\nfrom poetry.utils._compat import WINDOWS\n\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n from collections.abc import Collection\n from typing import BinaryIO\n\n from installer.records import RecordEntry\n from installer.scripts import LauncherKind\n from installer.utils import Scheme\n\n from poetry.utils.env import Env\n\n\nclass WheelDestination(SchemeDictionaryDestination):\n \"\"\" \"\"\"\n\n def write_to_fs(\n self,\n scheme: Scheme,\n path: str,\n stream: BinaryIO,\n is_executable: bool,\n ) -> RecordEntry:\n from installer.records import Hash\n from installer.records import RecordEntry\n from installer.utils import copyfileobj_with_hashing\n from installer.utils import make_file_executable\n\n 
target_path = Path(self.scheme_dict[scheme]) / path\n if target_path.exists():\n # Contrary to the base library we don't raise an error here since it can\n # break pkgutil-style and pkg_resource-style namespace packages.\n logger.warning(f\"Installing {target_path} over existing file\")\n\n parent_folder = target_path.parent\n if not parent_folder.exists():\n # Due to the parallel installation it can happen\n # that two threads try to create the directory.\n parent_folder.mkdir(parents=True, exist_ok=True)\n\n with target_path.open(\"wb\") as f:\n hash_, size = copyfileobj_with_hashing(stream, f, self.hash_algorithm)\n\n if is_executable:\n make_file_executable(target_path)\n\n return RecordEntry(path, Hash(self.hash_algorithm, hash_), size)\n\n\nclass WheelInstaller:\n def __init__(self, env: Env) -> None:\n self._env = env\n\n script_kind: LauncherKind\n if not WINDOWS:\n script_kind = \"posix\"\n else:\n if platform.uname()[4].startswith(\"arm\"):\n script_kind = \"win-arm64\" if sys.maxsize > 2**32 else \"win-arm\"\n else:\n script_kind = \"win-amd64\" if sys.maxsize > 2**32 else \"win-ia32\"\n self._script_kind = script_kind\n\n self._bytecode_optimization_levels: Collection[int] = ()\n self.invalid_wheels: dict[Path, list[str]] = {}\n\n def enable_bytecode_compilation(self, enable: bool = True) -> None:\n self._bytecode_optimization_levels = (-1,) if enable else ()\n\n def install(self, wheel: Path) -> None:\n with WheelFile.open(wheel) as source:\n try:\n # Content validation is temporarily disabled because of\n # pypa/installer's out of memory issues with big wheels. See\n # https://github.com/python-poetry/poetry/issues/7983\n source.validate_record(validate_contents=False)\n except _WheelFileValidationError as e:\n self.invalid_wheels[wheel] = e.issues\n\n scheme_dict = self._env.paths.copy()\n scheme_dict[\"headers\"] = str(\n Path(scheme_dict[\"include\"]) / source.distribution\n )\n destination = WheelDestination(\n scheme_dict,\n interpreter=str(self._env.python),\n script_kind=self._script_kind,\n bytecode_optimization_levels=self._bytecode_optimization_levels,\n )\n\n install(\n source=source,\n destination=destination,\n # Additional metadata that is generated by the installation tool.\n additional_metadata={\n \"INSTALLER\": f\"Poetry {__version__}\".encode(),\n },\n )\n\n# Path: src/poetry/json/__init__.py\nfrom __future__ import annotations\n\nimport json\n\nfrom pathlib import Path\nfrom typing import Any\n\nimport fastjsonschema\n\nfrom fastjsonschema.exceptions import JsonSchemaException\nfrom poetry.core.json import SCHEMA_DIR as CORE_SCHEMA_DIR\n\n\nSCHEMA_DIR = Path(__file__).parent / \"schemas\"\n\n\nclass ValidationError(ValueError):\n pass\n\n\ndef validate_object(obj: dict[str, Any]) -> list[str]:\n schema_file = Path(SCHEMA_DIR, \"poetry.json\")\n schema = json.loads(schema_file.read_text(encoding=\"utf-8\"))\n\n validate = fastjsonschema.compile(schema)\n\n errors = []\n try:\n validate(obj)\n except JsonSchemaException as e:\n errors = [e.message]\n\n core_schema = json.loads(\n (CORE_SCHEMA_DIR / \"poetry-schema.json\").read_text(encoding=\"utf-8\")\n )\n\n properties = {*schema[\"properties\"].keys(), *core_schema[\"properties\"].keys()}\n additional_properties = set(obj.keys()) - properties\n for key in additional_properties:\n errors.append(f\"Additional properties are not allowed ('{key}' was unexpected)\")\n\n return errors\n\n# Path: src/poetry/layouts/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.layouts.layout import Layout\nfrom 
poetry.layouts.src import SrcLayout\n\n\n_LAYOUTS = {\"src\": SrcLayout, \"standard\": Layout}\n\n\ndef layout(name: str) -> type[Layout]:\n if name not in _LAYOUTS:\n raise ValueError(\"Invalid layout\")\n\n return _LAYOUTS[name]\n\n# Path: src/poetry/layouts/layout.py\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.utils.helpers import module_name\nfrom tomlkit import inline_table\nfrom tomlkit import loads\nfrom tomlkit import table\nfrom tomlkit.toml_document import TOMLDocument\n\nfrom poetry.pyproject.toml import PyProjectTOML\n\n\nif TYPE_CHECKING:\n from collections.abc import Mapping\n\n from tomlkit.items import InlineTable\n\n\nPOETRY_DEFAULT = \"\"\"\\\n[tool.poetry]\nname = \"\"\nversion = \"\"\ndescription = \"\"\nauthors = []\nlicense = \"\"\nreadme = \"\"\npackages = []\n\n[tool.poetry.dependencies]\n\n[tool.poetry.group.dev.dependencies]\n\"\"\"\n\nBUILD_SYSTEM_MIN_VERSION: str | None = None\nBUILD_SYSTEM_MAX_VERSION: str | None = None\n\n\nclass Layout:\n def __init__(\n self,\n project: str,\n version: str = \"0.1.0\",\n description: str = \"\",\n readme_format: str = \"md\",\n author: str | None = None,\n license: str | None = None,\n python: str = \"*\",\n dependencies: Mapping[str, str | Mapping[str, Any]] | None = None,\n dev_dependencies: Mapping[str, str | Mapping[str, Any]] | None = None,\n ) -> None:\n self._project = canonicalize_name(project)\n self._package_path_relative = Path(\n *(module_name(part) for part in project.split(\".\"))\n )\n self._package_name = \".\".join(self._package_path_relative.parts)\n self._version = version\n self._description = description\n\n self._readme_format = readme_format.lower()\n\n self._license = license\n self._python = python\n self._dependencies = dependencies or {}\n self._dev_dependencies = dev_dependencies or {}\n\n if not author:\n author = \"Your Name \"\n\n self._author = author\n\n @property\n def basedir(self) -> Path:\n return Path()\n\n @property\n def package_path(self) -> Path:\n return self.basedir / self._package_path_relative\n\n def get_package_include(self) -> InlineTable | None:\n package = inline_table()\n\n # If a project is created in the root directory (this is reasonable inside a\n # docker container, eg )\n # then parts will be empty.\n parts = self._package_path_relative.parts\n if not parts:\n return None\n\n include = parts[0]\n package.append(\"include\", include)\n\n if self.basedir != Path():\n package.append(\"from\", self.basedir.as_posix())\n else:\n if module_name(self._project) == include:\n # package include and package name are the same,\n # packages table is redundant here.\n return None\n\n return package\n\n def create(\n self, path: Path, with_tests: bool = True, with_pyproject: bool = True\n ) -> None:\n path.mkdir(parents=True, exist_ok=True)\n\n self._create_default(path)\n self._create_readme(path)\n\n if with_tests:\n self._create_tests(path)\n\n if with_pyproject:\n self._write_poetry(path)\n\n def generate_poetry_content(self) -> TOMLDocument:\n template = POETRY_DEFAULT\n\n content: dict[str, Any] = loads(template)\n\n poetry_content = content[\"tool\"][\"poetry\"]\n poetry_content[\"name\"] = self._project\n poetry_content[\"version\"] = self._version\n poetry_content[\"description\"] = self._description\n poetry_content[\"authors\"].append(self._author)\n\n if self._license:\n poetry_content[\"license\"] = self._license\n else:\n 
poetry_content.remove(\"license\")\n\n poetry_content[\"readme\"] = f\"README.{self._readme_format}\"\n packages = self.get_package_include()\n if packages:\n poetry_content[\"packages\"].append(packages)\n else:\n poetry_content.remove(\"packages\")\n\n poetry_content[\"dependencies\"][\"python\"] = self._python\n\n for dep_name, dep_constraint in self._dependencies.items():\n poetry_content[\"dependencies\"][dep_name] = dep_constraint\n\n if self._dev_dependencies:\n for dep_name, dep_constraint in self._dev_dependencies.items():\n poetry_content[\"group\"][\"dev\"][\"dependencies\"][dep_name] = (\n dep_constraint\n )\n else:\n del poetry_content[\"group\"]\n\n # Add build system\n build_system = table()\n build_system_version = \"\"\n\n if BUILD_SYSTEM_MIN_VERSION is not None:\n build_system_version = \">=\" + BUILD_SYSTEM_MIN_VERSION\n if BUILD_SYSTEM_MAX_VERSION is not None:\n if build_system_version:\n build_system_version += \",\"\n build_system_version += \"<\" + BUILD_SYSTEM_MAX_VERSION\n\n build_system.add(\"requires\", [\"poetry-core\" + build_system_version])\n build_system.add(\"build-backend\", \"poetry.core.masonry.api\")\n\n assert isinstance(content, TOMLDocument)\n content.add(\"build-system\", build_system)\n\n return content\n\n def _create_default(self, path: Path, src: bool = True) -> None:\n package_path = path / self.package_path\n package_path.mkdir(parents=True)\n\n package_init = package_path / \"__init__.py\"\n package_init.touch()\n\n def _create_readme(self, path: Path) -> Path:\n readme_file = path.joinpath(f\"README.{self._readme_format}\")\n readme_file.touch()\n return readme_file\n\n @staticmethod\n def _create_tests(path: Path) -> None:\n tests = path / \"tests\"\n tests.mkdir()\n\n tests_init = tests / \"__init__.py\"\n tests_init.touch(exist_ok=False)\n\n def _write_poetry(self, path: Path) -> None:\n pyproject = PyProjectTOML(path / \"pyproject.toml\")\n content = self.generate_poetry_content()\n for section, item in content.items():\n pyproject.data.append(section, item)\n pyproject.save()\n\n# Path: src/poetry/layouts/src.py\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nfrom poetry.layouts.layout import Layout\n\n\nclass SrcLayout(Layout):\n @property\n def basedir(self) -> Path:\n return Path(\"src\")\n\n# Path: src/poetry/layouts/standard.py\n\n# Path: src/poetry/masonry/__init__.py\n\n# Path: src/poetry/masonry/api.py\nfrom __future__ import annotations\n\nfrom poetry.core.masonry.api import build_sdist\nfrom poetry.core.masonry.api import build_wheel\nfrom poetry.core.masonry.api import get_requires_for_build_sdist\nfrom poetry.core.masonry.api import get_requires_for_build_wheel\nfrom poetry.core.masonry.api import prepare_metadata_for_build_wheel\n\n\n__all__ = [\n \"build_sdist\",\n \"build_wheel\",\n \"get_requires_for_build_sdist\",\n \"get_requires_for_build_wheel\",\n \"prepare_metadata_for_build_wheel\",\n]\n\n# Path: src/poetry/mixology/__init__.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.mixology.version_solver import VersionSolver\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.project_package import ProjectPackage\n\n from poetry.mixology.result import SolverResult\n from poetry.puzzle.provider import Provider\n\n\ndef resolve_version(root: ProjectPackage, provider: Provider) -> SolverResult:\n solver = VersionSolver(root, provider)\n\n return solver.solve()\n\n# Path: src/poetry/mixology/assignment.py\nfrom __future__ import annotations\n\nfrom typing import 
TYPE_CHECKING\n\nfrom poetry.mixology.term import Term\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n\n from poetry.mixology.incompatibility import Incompatibility\n\n\nclass Assignment(Term):\n \"\"\"\n A term in a PartialSolution that tracks some additional metadata.\n \"\"\"\n\n def __init__(\n self,\n dependency: Dependency,\n is_positive: bool,\n decision_level: int,\n index: int,\n cause: Incompatibility | None = None,\n ) -> None:\n super().__init__(dependency, is_positive)\n\n self._decision_level = decision_level\n self._index = index\n self._cause = cause\n\n @property\n def decision_level(self) -> int:\n return self._decision_level\n\n @property\n def index(self) -> int:\n return self._index\n\n @property\n def cause(self) -> Incompatibility | None:\n return self._cause\n\n @classmethod\n def decision(cls, package: Package, decision_level: int, index: int) -> Assignment:\n return cls(package.to_dependency(), True, decision_level, index)\n\n @classmethod\n def derivation(\n cls,\n dependency: Dependency,\n is_positive: bool,\n cause: Incompatibility,\n decision_level: int,\n index: int,\n ) -> Assignment:\n return cls(dependency, is_positive, decision_level, index, cause)\n\n def is_decision(self) -> bool:\n return self._cause is None\n\n# Path: src/poetry/mixology/failure.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.core.constraints.version import parse_constraint\n\nfrom poetry.mixology.incompatibility_cause import ConflictCause\nfrom poetry.mixology.incompatibility_cause import PythonCause\n\n\nif TYPE_CHECKING:\n from poetry.mixology.incompatibility import Incompatibility\n\n\nclass SolveFailure(Exception):\n def __init__(self, incompatibility: Incompatibility) -> None:\n self._incompatibility = incompatibility\n\n @property\n def message(self) -> str:\n return str(self)\n\n def __str__(self) -> str:\n return _Writer(self._incompatibility).write()\n\n\nclass _Writer:\n def __init__(self, root: Incompatibility) -> None:\n self._root = root\n self._derivations: dict[Incompatibility, int] = {}\n self._lines: list[tuple[str, int | None]] = []\n self._line_numbers: dict[Incompatibility, int] = {}\n\n self._count_derivations(self._root)\n\n def write(self) -> str:\n buffer = []\n version_solutions = []\n required_python_version_notification = False\n for incompatibility in self._root.external_incompatibilities:\n if isinstance(incompatibility.cause, PythonCause):\n root_constraint = parse_constraint(\n incompatibility.cause.root_python_version\n )\n constraint = parse_constraint(incompatibility.cause.python_version)\n\n version_solutions.append(\n \"For \"\n f\"{incompatibility.terms[0].dependency.name},\"\n \" a possible solution would be to set the\"\n \" `python` property to\"\n f' \"{root_constraint.intersect(constraint)}\"'\n )\n if not required_python_version_notification:\n buffer.append(\n \"The current project's supported Python range\"\n f\" ({incompatibility.cause.root_python_version}) is not\"\n \" compatible with some of the required packages Python\"\n \" requirement:\"\n )\n required_python_version_notification = True\n\n root_constraint = parse_constraint(\n incompatibility.cause.root_python_version\n )\n constraint = parse_constraint(incompatibility.cause.python_version)\n buffer.append(\n f\" - {incompatibility.terms[0].dependency.name} requires Python\"\n f\" {incompatibility.cause.python_version}, so it will not be\"\n f\" satisfied 
for Python {root_constraint.difference(constraint)}\"\n )\n\n if required_python_version_notification:\n buffer.append(\"\")\n\n if isinstance(self._root.cause, ConflictCause):\n self._visit(self._root)\n else:\n self._write(self._root, f\"Because {self._root}, version solving failed.\")\n\n padding = (\n 0\n if not self._line_numbers\n else len(f\"({list(self._line_numbers.values())[-1]}) \")\n )\n\n last_was_empty = False\n for line in self._lines:\n message = line[0]\n if not message:\n if not last_was_empty:\n buffer.append(\"\")\n\n last_was_empty = True\n continue\n\n last_was_empty = False\n\n number = line[-1]\n if number is not None:\n message = f\"({number})\".ljust(padding) + message\n else:\n message = \" \" * padding + message\n\n buffer.append(message)\n if required_python_version_notification:\n # Add suggested solution\n links = \",\".join(\n f\"\\n https://python-poetry.org/docs/dependency-specification/#{section}\"\n for section in [\n \"python-restricted-dependencies\",\n \"using-environment-markers\",\n ]\n )\n\n description = (\n \"The Python requirement can be specified via the\"\n \" `python` or\"\n \" `markers` properties\"\n )\n if version_solutions:\n description += \"\\n\\n \" + \"\\n\".join(version_solutions)\n\n description = description.strip(\" \")\n\n buffer.append(\n f\"\\n * \"\n f\"Check your dependencies Python requirement:\"\n f\" {description}\\n{links}\\n\",\n )\n return \"\\n\".join(buffer)\n\n def _write(\n self, incompatibility: Incompatibility, message: str, numbered: bool = False\n ) -> None:\n if numbered:\n number = len(self._line_numbers) + 1\n self._line_numbers[incompatibility] = number\n self._lines.append((message, number))\n else:\n self._lines.append((message, None))\n\n def _visit(\n self,\n incompatibility: Incompatibility,\n conclusion: bool = False,\n ) -> None:\n numbered = conclusion or self._derivations[incompatibility] > 1\n conjunction = \"So,\" if conclusion or incompatibility == self._root else \"And\"\n incompatibility_string = str(incompatibility)\n\n cause = incompatibility.cause\n assert isinstance(cause, ConflictCause)\n\n if isinstance(cause.conflict.cause, ConflictCause) and isinstance(\n cause.other.cause, ConflictCause\n ):\n conflict_line = self._line_numbers.get(cause.conflict)\n other_line = self._line_numbers.get(cause.other)\n\n if conflict_line is not None and other_line is not None:\n reason = cause.conflict.and_to_string(\n cause.other, conflict_line, other_line\n )\n self._write(\n incompatibility,\n f\"Because {reason}, {incompatibility_string}.\",\n numbered=numbered,\n )\n elif conflict_line is not None or other_line is not None:\n if conflict_line is not None:\n with_line = cause.conflict\n without_line = cause.other\n line = conflict_line\n elif other_line is not None:\n with_line = cause.other\n without_line = cause.conflict\n line = other_line\n\n self._visit(without_line)\n self._write(\n incompatibility,\n f\"{conjunction} because {with_line!s} ({line}),\"\n f\" {incompatibility_string}.\",\n numbered=numbered,\n )\n else:\n single_line_conflict = self._is_single_line(cause.conflict.cause)\n single_line_other = self._is_single_line(cause.other.cause)\n\n if single_line_other or single_line_conflict:\n first = cause.conflict if single_line_other else cause.other\n second = cause.other if single_line_other else cause.conflict\n self._visit(first)\n self._visit(second)\n self._write(\n incompatibility,\n f\"Thus, {incompatibility_string}.\",\n numbered=numbered,\n )\n else:\n self._visit(cause.conflict, 
conclusion=True)\n self._lines.append((\"\", None))\n\n self._visit(cause.other)\n\n self._write(\n incompatibility,\n...\n", "needle_token_start": 8778, "needle_token_end": 8888, "code_context_ntokens": 16020, "output": ["```python\ndef _write_lock_file(self, repo: LockfileRepository, force: bool = False) -> None:\n if not self.is_dry_run() and (force or self._update):\n updated_lock = self._locker.set_lock_data(self._package, repo.packages)\n\n if updated_lock:\n self._io.write_line(\"\")\n self._io.write_line(\"Writing lock file\")\n```"]} +{"repo": "python-poetry/poetry", "name": "_do_execute_operation", "language": "python", "path": "src/poetry/installation/executor.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to manage the execution of a specified operation, handling its completion, skipping, or failure based on the operation's properties and the system's state.\n2. **Input**: The function takes an operation object that contains details about the job to be executed, including its type and whether it should be skipped.\n3. **Output**: Returns an integer indicating the result of the operation execution, where a value of 0 typically denotes successful completion or intentional skipping.\n4. **Procedure**: The function first checks if the operation is marked as skipped and, if so, logs this status and returns 0. If not skipped, it checks if the system is in a dry run or disabled state, returning 0 if true. Otherwise, it dynamically calls the appropriate method based on the operation type to execute the job. After execution, it logs the completion and updates the count of successful operations, returning the result of the execution.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/inspection/lazy_wheel.py\n\"\"\"Lazy ZIP over HTTP\"\"\"\n\nfrom __future__ import annotations\n\nimport io\nimport logging\nimport re\n\nfrom bisect import bisect_left\nfrom bisect import bisect_right\nfrom contextlib import contextmanager\nfrom tempfile import NamedTemporaryFile\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import BinaryIO\nfrom typing import ClassVar\nfrom typing import TypeVar\nfrom typing import cast\nfrom urllib.parse import urlparse\nfrom zipfile import BadZipFile\nfrom zipfile import ZipFile\n\nfrom packaging.metadata import parse_email\nfrom requests.models import CONTENT_CHUNK_SIZE\nfrom requests.models import HTTPError\nfrom requests.models import Response\nfrom requests.status_codes import codes\n\n\nif TYPE_CHECKING:\n from collections.abc import Iterable\n from collections.abc import Iterator\n from types import TracebackType\n\n from packaging.metadata import RawMetadata\n from requests import Session\n\n from poetry.utils.authenticator import Authenticator\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass LazyWheelUnsupportedError(Exception):\n \"\"\"Raised when a lazy wheel is unsupported.\"\"\"\n\n\nclass HTTPRangeRequestUnsupported(LazyWheelUnsupportedError):\n \"\"\"Raised when the remote server appears unable to support byte ranges.\"\"\"\n\n\nclass HTTPRangeRequestNotRespected(LazyWheelUnsupportedError):\n \"\"\"Raised when the remote server tells us that it supports byte ranges\n but does not respect a respective request.\"\"\"\n\n\nclass 
UnsupportedWheel(LazyWheelUnsupportedError):\n \"\"\"Unsupported wheel.\"\"\"\n\n\nclass InvalidWheel(LazyWheelUnsupportedError):\n \"\"\"Invalid (e.g. corrupt) wheel.\"\"\"\n\n def __init__(self, location: str, name: str) -> None:\n self.location = location\n self.name = name\n\n def __str__(self) -> str:\n return f\"Wheel {self.name} located at {self.location} is invalid.\"\n\n\ndef metadata_from_wheel_url(\n name: str, url: str, session: Session | Authenticator\n) -> RawMetadata:\n \"\"\"Fetch metadata from the given wheel URL.\n\n This uses HTTP range requests to only fetch the portion of the wheel\n containing metadata, just enough for the object to be constructed.\n\n :raises HTTPRangeRequestUnsupported: if range requests are unsupported for ``url``.\n :raises InvalidWheel: if the zip file contents could not be parsed.\n \"\"\"\n try:\n # After context manager exit, wheel.name will point to a deleted file path.\n # Add `delete_backing_file=False` to disable this for debugging.\n with LazyWheelOverHTTP(url, session) as lazy_file:\n metadata_bytes = lazy_file.read_metadata(name)\n\n metadata, _ = parse_email(metadata_bytes)\n return metadata\n\n except (BadZipFile, UnsupportedWheel):\n # We assume that these errors have occurred because the wheel contents\n # themselves are invalid, not because we've messed up our bookkeeping\n # and produced an invalid file.\n raise InvalidWheel(url, name)\n except Exception as e:\n if isinstance(e, LazyWheelUnsupportedError):\n # this is expected when the code handles issues with lazy wheel metadata retrieval correctly\n raise e\n\n logger.debug(\n \"There was an unexpected %s when handling lazy wheel metadata retrieval for %s from %s: %s\",\n type(e).__name__,\n name,\n url,\n e,\n )\n\n # Catch all exception to handle any issues that may have occurred during\n # attempts to use Lazy Wheel.\n raise LazyWheelUnsupportedError(\n f\"Attempts to use lazy wheel metadata retrieval for {name} from {url} failed\"\n ) from e\n\n\nclass MergeIntervals:\n \"\"\"Stateful bookkeeping to merge interval graphs.\"\"\"\n\n def __init__(self, *, left: Iterable[int] = (), right: Iterable[int] = ()) -> None:\n self._left = list(left)\n self._right = list(right)\n\n def __repr__(self) -> str:\n return (\n f\"{type(self).__name__}\"\n f\"(left={tuple(self._left)}, right={tuple(self._right)})\"\n )\n\n def _merge(\n self, start: int, end: int, left: int, right: int\n ) -> Iterator[tuple[int, int]]:\n \"\"\"Return an iterator of intervals to be fetched.\n\n Args:\n start: Start of needed interval\n end: End of needed interval\n left: Index of first overlapping downloaded data\n right: Index after last overlapping downloaded data\n \"\"\"\n lslice, rslice = self._left[left:right], self._right[left:right]\n i = start = min([start] + lslice[:1])\n end = max([end] + rslice[-1:])\n for j, k in zip(lslice, rslice):\n if j > i:\n yield i, j - 1\n i = k + 1\n if i <= end:\n yield i, end\n self._left[left:right], self._right[left:right] = [start], [end]\n\n def minimal_intervals_covering(\n self, start: int, end: int\n ) -> Iterator[tuple[int, int]]:\n \"\"\"Provide the intervals needed to cover from ``start <= x <= end``.\n\n This method mutates internal state so that later calls only return intervals not\n covered by prior calls. The first call to this method will always return exactly\n one interval, which was exactly the one requested. 
Later requests for\n intervals overlapping that first requested interval will yield only the ranges\n not previously covered (which may be empty, e.g. if the same interval is\n requested twice).\n\n This may be used e.g. to download substrings of remote files on demand.\n \"\"\"\n left = bisect_left(self._right, start)\n right = bisect_right(self._left, end)\n yield from self._merge(start, end, left, right)\n\n\nT = TypeVar(\"T\", bound=\"ReadOnlyIOWrapper\")\n\n\nclass ReadOnlyIOWrapper(BinaryIO):\n \"\"\"Implement read-side ``BinaryIO`` methods wrapping an inner ``BinaryIO``.\n\n This wrapper is useful because Python currently does not distinguish read-only\n streams at the type level.\n \"\"\"\n\n def __init__(self, inner: BinaryIO) -> None:\n self._file = inner\n\n def __enter__(self: T) -> T:\n self._file.__enter__()\n return self\n\n def __exit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n self._file.__exit__(exc_type, exc_value, traceback)\n\n def __iter__(self) -> Iterator[bytes]:\n raise NotImplementedError\n\n def __next__(self) -> bytes:\n raise NotImplementedError\n\n @property\n def mode(self) -> str:\n \"\"\"Opening mode, which is always rb.\"\"\"\n return \"rb\"\n\n @property\n def name(self) -> str:\n \"\"\"Path to the underlying file.\"\"\"\n return self._file.name\n\n def seekable(self) -> bool:\n \"\"\"Return whether random access is supported, which is True.\"\"\"\n return True\n\n def close(self) -> None:\n \"\"\"Close the file.\"\"\"\n self._file.close()\n\n @property\n def closed(self) -> bool:\n \"\"\"Whether the file is closed.\"\"\"\n return self._file.closed\n\n def fileno(self) -> int:\n return self._file.fileno()\n\n def flush(self) -> None:\n self._file.flush()\n\n def isatty(self) -> bool:\n return False\n\n def readable(self) -> bool:\n \"\"\"Return whether the file is readable, which is True.\"\"\"\n return True\n\n def read(self, size: int = -1) -> bytes:\n \"\"\"Read up to size bytes from the object and return them.\n\n As a convenience, if size is unspecified or -1,\n all bytes until EOF are returned. Fewer than\n size bytes may be returned if EOF is reached.\n \"\"\"\n return self._file.read(size)\n\n def readline(self, limit: int = -1) -> bytes:\n # Explicit impl needed to satisfy mypy.\n raise NotImplementedError\n\n def readlines(self, hint: int = -1) -> list[bytes]:\n raise NotImplementedError\n\n def seek(self, offset: int, whence: int = 0) -> int:\n \"\"\"Change stream position and return the new absolute position.\n\n Seek to offset relative position indicated by whence:\n * 0: Start of stream (the default). 
pos should be >= 0;\n * 1: Current position - pos may be negative;\n * 2: End of stream - pos usually negative.\n \"\"\"\n return self._file.seek(offset, whence)\n\n def tell(self) -> int:\n \"\"\"Return the current position.\"\"\"\n return self._file.tell()\n\n def truncate(self, size: int | None = None) -> int:\n \"\"\"Resize the stream to the given size in bytes.\n\n If size is unspecified resize to the current position.\n The current stream position isn't changed.\n\n Return the new file size.\n \"\"\"\n return self._file.truncate(size)\n\n def writable(self) -> bool:\n \"\"\"Return False.\"\"\"\n return False\n\n def write(self, s: Any) -> int:\n raise NotImplementedError\n\n def writelines(self, lines: Iterable[Any]) -> None:\n raise NotImplementedError\n\n\nU = TypeVar(\"U\", bound=\"LazyFileOverHTTP\")\n\n\nclass LazyFileOverHTTP(ReadOnlyIOWrapper):\n \"\"\"File-like object representing a fixed-length file over HTTP.\n\n This uses HTTP range requests to lazily fetch the file's content into a temporary\n file. If such requests are not supported by the server, raises\n ``HTTPRangeRequestUnsupported`` in the ``__enter__`` method.\"\"\"\n\n def __init__(\n self,\n url: str,\n session: Session | Authenticator,\n delete_backing_file: bool = True,\n ) -> None:\n super().__init__(cast(BinaryIO, NamedTemporaryFile(delete=delete_backing_file)))\n\n self._merge_intervals: MergeIntervals | None = None\n self._length: int | None = None\n\n self._request_count = 0\n...\n# Path: src/poetry/installation/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.installation.installer import Installer\n\n\n__all__ = [\"Installer\"]\n\n# Path: src/poetry/installation/chef.py\nfrom __future__ import annotations\n\nimport os\nimport tempfile\n\nfrom contextlib import redirect_stdout\nfrom io import StringIO\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom build import BuildBackendException\nfrom build import ProjectBuilder\nfrom build.env import IsolatedEnv as BaseIsolatedEnv\nfrom poetry.core.utils.helpers import temporary_directory\nfrom pyproject_hooks import quiet_subprocess_runner # type: ignore[import-untyped]\n\nfrom poetry.utils._compat import decode\nfrom poetry.utils.env import ephemeral_environment\nfrom poetry.utils.helpers import extractall\n\n\nif TYPE_CHECKING:\n from collections.abc import Collection\n\n from poetry.repositories import RepositoryPool\n from poetry.utils.cache import ArtifactCache\n from poetry.utils.env import Env\n\n\nclass ChefError(Exception): ...\n\n\nclass ChefBuildError(ChefError): ...\n\n\nclass ChefInstallError(ChefError):\n def __init__(self, requirements: Collection[str], output: str, error: str) -> None:\n message = \"\\n\\n\".join(\n (\n f\"Failed to install {', '.join(requirements)}.\",\n f\"Output:\\n{output}\",\n f\"Error:\\n{error}\",\n )\n )\n super().__init__(message)\n self._requirements = requirements\n\n @property\n def requirements(self) -> Collection[str]:\n return self._requirements\n\n\nclass IsolatedEnv(BaseIsolatedEnv):\n def __init__(self, env: Env, pool: RepositoryPool) -> None:\n self._env = env\n self._pool = pool\n\n @property\n def python_executable(self) -> str:\n return str(self._env.python)\n\n def make_extra_environ(self) -> dict[str, str]:\n path = os.environ.get(\"PATH\")\n scripts_dir = str(self._env._bin_dir)\n return {\n \"PATH\": (\n os.pathsep.join([scripts_dir, path])\n if path is not None\n else scripts_dir\n )\n }\n\n def install(self, requirements: Collection[str]) -> None:\n from 
cleo.io.buffered_io import BufferedIO\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.project_package import ProjectPackage\n\n from poetry.config.config import Config\n from poetry.installation.installer import Installer\n from poetry.packages.locker import Locker\n from poetry.repositories.installed_repository import InstalledRepository\n\n # We build Poetry dependencies from the requirements\n package = ProjectPackage(\"__root__\", \"0.0.0\")\n package.python_versions = \".\".join(str(v) for v in self._env.version_info[:3])\n for requirement in requirements:\n dependency = Dependency.create_from_pep_508(requirement)\n package.add_dependency(dependency)\n\n io = BufferedIO()\n installer = Installer(\n io,\n self._env,\n package,\n Locker(self._env.path.joinpath(\"poetry.lock\"), {}),\n self._pool,\n Config.create(),\n InstalledRepository.load(self._env),\n )\n installer.update(True)\n if installer.run() != 0:\n raise ChefInstallError(requirements, io.fetch_output(), io.fetch_error())\n\n\nclass Chef:\n def __init__(\n self, artifact_cache: ArtifactCache, env: Env, pool: RepositoryPool\n ) -> None:\n self._env = env\n self._pool = pool\n self._artifact_cache = artifact_cache\n\n def prepare(\n self, archive: Path, output_dir: Path | None = None, *, editable: bool = False\n ) -> Path:\n if not self._should_prepare(archive):\n return archive\n\n if archive.is_dir():\n destination = output_dir or Path(tempfile.mkdtemp(prefix=\"poetry-chef-\"))\n return self._prepare(archive, destination=destination, editable=editable)\n\n return self._prepare_sdist(archive, destination=output_dir)\n\n def _prepare(\n self, directory: Path, destination: Path, *, editable: bool = False\n ) -> Path:\n from subprocess import CalledProcessError\n\n with ephemeral_environment(\n self._env.python,\n flags={\"no-pip\": True, \"no-setuptools\": True, \"no-wheel\": True},\n ) as venv:\n env = IsolatedEnv(venv, self._pool)\n builder = ProjectBuilder.from_isolated_env(\n env, directory, runner=quiet_subprocess_runner\n )\n env.install(builder.build_system_requires)\n\n stdout = StringIO()\n error: Exception | None = None\n try:\n with redirect_stdout(stdout):\n dist_format = \"wheel\" if not editable else \"editable\"\n env.install(\n builder.build_system_requires\n | builder.get_requires_for_build(dist_format)\n )\n path = Path(\n builder.build(\n dist_format,\n destination.as_posix(),\n )\n )\n except BuildBackendException as e:\n message_parts = [str(e)]\n if isinstance(e.exception, CalledProcessError):\n text = e.exception.stderr or e.exception.stdout\n if text is not None:\n message_parts.append(decode(text))\n else:\n message_parts.append(str(e.exception))\n\n error = ChefBuildError(\"\\n\\n\".join(message_parts))\n\n if error is not None:\n raise error from None\n\n return path\n\n def _prepare_sdist(self, archive: Path, destination: Path | None = None) -> Path:\n from poetry.core.packages.utils.link import Link\n\n suffix = archive.suffix\n zip = suffix == \".zip\"\n\n with temporary_directory() as tmp_dir:\n archive_dir = Path(tmp_dir)\n extractall(source=archive, dest=archive_dir, zip=zip)\n\n elements = list(archive_dir.glob(\"*\"))\n\n if len(elements) == 1 and elements[0].is_dir():\n sdist_dir = elements[0]\n else:\n sdist_dir = archive_dir / archive.name.rstrip(suffix)\n if not sdist_dir.is_dir():\n sdist_dir = archive_dir\n\n if destination is None:\n destination = self._artifact_cache.get_cache_directory_for_link(\n Link(archive.as_uri())\n )\n\n 
destination.mkdir(parents=True, exist_ok=True)\n\n return self._prepare(\n sdist_dir,\n destination,\n )\n\n def _should_prepare(self, archive: Path) -> bool:\n return archive.is_dir() or not self._is_wheel(archive)\n\n @classmethod\n def _is_wheel(cls, archive: Path) -> bool:\n return archive.suffix == \".whl\"\n\n# Path: src/poetry/installation/chooser.py\nfrom __future__ import annotations\n\nimport logging\nimport re\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom poetry.config.config import Config\nfrom poetry.config.config import PackageFilterPolicy\nfrom poetry.repositories.http_repository import HTTPRepository\nfrom poetry.utils.helpers import get_highest_priority_hash_type\nfrom poetry.utils.wheel import Wheel\n\n\nif TYPE_CHECKING:\n from poetry.core.constraints.version import Version\n from poetry.core.packages.package import Package\n from poetry.core.packages.utils.link import Link\n\n from poetry.repositories.repository_pool import RepositoryPool\n from poetry.utils.env import Env\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass Chooser:\n \"\"\"\n A Chooser chooses an appropriate release archive for packages.\n \"\"\"\n\n def __init__(\n self, pool: RepositoryPool, env: Env, config: Config | None = None\n ) -> None:\n self._pool = pool\n self._env = env\n self._config = config or Config.create()\n self._no_binary_policy: PackageFilterPolicy = PackageFilterPolicy(\n self._config.get(\"installer.no-binary\", [])\n )\n\n def choose_for(self, package: Package) -> Link:\n \"\"\"\n Return the url of the selected archive for a given package.\n \"\"\"\n links = []\n for link in self._get_links(package):\n if link.is_wheel:\n if not self._no_binary_policy.allows(package.name):\n logger.debug(\n \"Skipping wheel for %s as requested in no binary policy for\"\n \" package (%s)\",\n link.filename,\n package.name,\n )\n continue\n\n if not Wheel(link.filename).is_supported_by_environment(self._env):\n logger.debug(\n \"Skipping wheel %s as this is not supported by the current\"\n \" environment\",\n link.filename,\n )\n continue\n\n if link.ext in {\".egg\", \".exe\", \".msi\", \".rpm\", \".srpm\"}:\n logger.debug(\"Skipping unsupported distribution %s\", link.filename)\n continue\n\n links.append(link)\n\n if not links:\n raise RuntimeError(f\"Unable to find installation candidates for {package}\")\n\n # Get the best link\n chosen = max(links, key=lambda link: self._sort_key(package, link))\n\n return chosen\n\n def _get_links(self, package: Package) -> list[Link]:\n if package.source_type:\n assert package.source_reference is not None\n repository = self._pool.repository(package.source_reference)\n\n elif not self._pool.has_repository(\"pypi\"):\n repository = self._pool.repositories[0]\n else:\n repository = self._pool.repository(\"pypi\")\n links = repository.find_links_for_package(package)\n\n locked_hashes = {f[\"hash\"] for f in package.files}\n if not locked_hashes:\n return links\n\n selected_links = []\n skipped = []\n locked_hash_names = {h.split(\":\")[0] for h in locked_hashes}\n for link in links:\n if not link.hashes:\n selected_links.append(link)\n continue\n\n link_hash: str | None = None\n if (candidates := locked_hash_names.intersection(link.hashes.keys())) and (\n hash_name := get_highest_priority_hash_type(candidates, link.filename)\n ):\n link_hash = f\"{hash_name}:{link.hashes[hash_name]}\"\n\n elif isinstance(repository, HTTPRepository):\n link_hash = repository.calculate_sha256(link)\n\n if link_hash not in locked_hashes:\n 
skipped.append((link.filename, link_hash))\n logger.debug(\n \"Skipping %s as %s checksum does not match expected value\",\n link.filename,\n link_hash,\n )\n continue\n\n selected_links.append(link)\n\n if links and not selected_links:\n links_str = \", \".join(f\"{link}({h})\" for link, h in skipped)\n raise RuntimeError(\n f\"Retrieved digests for links {links_str} not in poetry.lock\"\n f\" metadata {locked_hashes}\"\n )\n\n return selected_links\n\n def _sort_key(\n self, package: Package, link: Link\n ) -> tuple[int, int, int, Version, tuple[Any, ...], int]:\n \"\"\"\n Function to pass as the `key` argument to a call to sorted() to sort\n InstallationCandidates by preference.\n Returns a tuple such that tuples sorting as greater using Python's\n default comparison operator are more preferred.\n The preference is as follows:\n First and foremost, candidates with allowed (matching) hashes are\n always preferred over candidates without matching hashes. This is\n because e.g. if the only candidate with an allowed hash is yanked,\n we still want to use that candidate.\n Second, excepting hash considerations, candidates that have been\n yanked (in the sense of PEP 592) are always less preferred than\n candidates that haven't been yanked. Then:\n If not finding wheels, they are sorted by version only.\n If finding wheels, then the sort order is by version, then:\n 1. existing installs\n 2. wheels ordered via Wheel.support_index_min(self._supported_tags)\n 3. source archives\n If prefer_binary was set, then all wheels are sorted above sources.\n Note: it was considered to embed this logic into the Link\n comparison operators, but then different sdist links\n with the same version, would have to be considered equal\n \"\"\"\n build_tag: tuple[Any, ...] = ()\n binary_preference = 0\n if link.is_wheel:\n wheel = Wheel(link.filename)\n if not wheel.is_supported_by_environment(self._env):\n raise RuntimeError(\n f\"{wheel.filename} is not a supported wheel for this platform. 
It \"\n \"can't be sorted.\"\n )\n\n # TODO: Binary preference\n pri = -(wheel.get_minimum_supported_index(self._env.supported_tags) or 0)\n if wheel.build_tag is not None:\n match = re.match(r\"^(\\d+)(.*)$\", wheel.build_tag)\n if not match:\n raise ValueError(f\"Unable to parse build tag: {wheel.build_tag}\")\n build_tag_groups = match.groups()\n build_tag = (int(build_tag_groups[0]), build_tag_groups[1])\n else: # sdist\n support_num = len(self._env.supported_tags)\n pri = -support_num\n\n has_allowed_hash = int(self._is_link_hash_allowed_for_package(link, package))\n\n yank_value = int(not link.yanked)\n\n return (\n has_allowed_hash,\n yank_value,\n binary_preference,\n package.version,\n build_tag,\n pri,\n )\n\n def _is_link_hash_allowed_for_package(self, link: Link, package: Package) -> bool:\n if not link.hashes:\n return True\n\n link_hashes = {f\"{name}:{h}\" for name, h in link.hashes.items()}\n locked_hashes = {f[\"hash\"] for f in package.files}\n\n return bool(link_hashes & locked_hashes)\n\n# Path: src/poetry/installation/executor.py\nfrom __future__ import annotations\n\nimport contextlib\nimport csv\nimport functools\nimport itertools\nimport json\nimport threading\n\nfrom concurrent.futures import ThreadPoolExecutor\nfrom concurrent.futures import wait\nfrom pathlib import Path\nfrom subprocess import CalledProcessError\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom cleo.io.null_io import NullIO\nfrom poetry.core.packages.utils.link import Link\n\nfrom poetry.installation.chef import Chef\nfrom poetry.installation.chef import ChefBuildError\nfrom poetry.installation.chef import ChefInstallError\nfrom poetry.installation.chooser import Chooser\nfrom poetry.installation.operations import Install\nfrom poetry.installation.operations import Uninstall\nfrom poetry.installation.operations import Update\nfrom poetry.installation.wheel_installer import WheelInstaller\nfrom poetry.puzzle.exceptions import SolverProblemError\nfrom poetry.utils._compat import decode\nfrom poetry.utils.authenticator import Authenticator\nfrom poetry.utils.env import EnvCommandError\nfrom poetry.utils.helpers import Downloader\nfrom poetry.utils.helpers import get_file_hash\nfrom poetry.utils.helpers import get_highest_priority_hash_type\nfrom poetry.utils.helpers import pluralize\nfrom poetry.utils.helpers import remove_directory\nfrom poetry.utils.pip import pip_install\n\n\nif TYPE_CHECKING:\n from cleo.io.io import IO\n from cleo.io.outputs.section_output import SectionOutput\n from poetry.core.masonry.builders.builder import Builder\n from poetry.core.packages.package import Package\n\n from poetry.config.config import Config\n from poetry.installation.operations.operation import Operation\n from poetry.repositories import RepositoryPool\n from poetry.utils.env import Env\n\n\nclass Executor:\n def __init__(\n self,\n env: Env,\n pool: RepositoryPool,\n config: Config,\n io: IO,\n parallel: bool | None = None,\n disable_cache: bool = False,\n ) -> None:\n self._env = env\n self._io = io\n self._dry_run = False\n self._enabled = True\n self._verbose = False\n self._wheel_installer = WheelInstaller(self._env)\n self._use_modern_installation = config.get(\n \"installer.modern-installation\", True\n )\n if not self._use_modern_installation:\n self._io.write_line(\n \"Warning: Setting `installer.modern-installation` to `false` \"\n \"is deprecated.\"\n )\n self._io.write_line(\n \"The pip-based installer will be removed in a future release.\"\n )\n self._io.write_line(\n \"See 
https://github.com/python-poetry/poetry/issues/8987.\"\n )\n\n if parallel is None:\n parallel = config.get(\"installer.parallel\", True)\n\n if parallel:\n self._max_workers = config.installer_max_workers\n else:\n self._max_workers = 1\n\n self._artifact_cache = pool.artifact_cache\n self._authenticator = Authenticator(\n config, self._io, disable_cache=disable_cache, pool_size=self._max_workers\n )\n self._chef = Chef(self._artifact_cache, self._env, pool)\n self._chooser = Chooser(pool, self._env, config)\n\n self._executor = ThreadPoolExecutor(max_workers=self._max_workers)\n self._total_operations = 0\n self._executed_operations = 0\n self._executed = {\"install\": 0, \"update\": 0, \"uninstall\": 0}\n self._skipped = {\"install\": 0, \"update\": 0, \"uninstall\": 0}\n self._sections: dict[int, SectionOutput] = {}\n self._yanked_warnings: list[str] = []\n self._lock = threading.Lock()\n self._shutdown = False\n self._hashes: dict[str, str] = {}\n\n @property\n def installations_count(self) -> int:\n return self._executed[\"install\"]\n\n @property\n def updates_count(self) -> int:\n return self._executed[\"update\"]\n\n @property\n def removals_count(self) -> int:\n return self._executed[\"uninstall\"]\n\n @property\n def enabled(self) -> bool:\n return self._enabled\n\n def supports_fancy_output(self) -> bool:\n return self._io.output.is_decorated() and not self._dry_run\n\n def disable(self) -> Executor:\n self._enabled = False\n\n return self\n\n def dry_run(self, dry_run: bool = True) -> Executor:\n self._dry_run = dry_run\n\n return self\n\n def verbose(self, verbose: bool = True) -> Executor:\n self._verbose = verbose\n\n return self\n\n def enable_bytecode_compilation(self, enable: bool = True) -> None:\n self._wheel_installer.enable_bytecode_compilation(enable)\n\n def pip_install(\n self, req: Path, upgrade: bool = False, editable: bool = False\n ) -> int:\n try:\n pip_install(req, self._env, upgrade=upgrade, editable=editable)\n except EnvCommandError as e:\n output = decode(e.e.output)\n if (\n \"KeyboardInterrupt\" in output\n or \"ERROR: Operation cancelled by user\" in output\n ):\n return -2\n raise\n\n return 0\n\n def execute(self, operations: list[Operation]) -> int:\n self._total_operations = len(operations)\n for job_type in self._executed:\n self._executed[job_type] = 0\n self._skipped[job_type] = 0\n\n if operations and (self._enabled or self._dry_run):\n self._display_summary(operations)\n\n self._sections = {}\n self._yanked_warnings = []\n\n # pip has to be installed first without parallelism if we install via pip\n for i, op in enumerate(operations):\n if op.package.name == \"pip\":\n wait([self._executor.submit(self._execute_operation, op)])\n del operations[i]\n break\n\n # We group operations by priority\n groups = itertools.groupby(operations, key=lambda o: -o.priority)\n for _, group in groups:\n tasks = []\n serial_operations = []\n for operation in group:\n if self._shutdown:\n break\n\n # Some operations are unsafe, we must execute them serially in a group\n # https://github.com/python-poetry/poetry/issues/3086\n # https://github.com/python-poetry/poetry/issues/2658\n #\n # We need to explicitly check source type here, see:\n # https://github.com/python-poetry/poetry-core/pull/98\n is_parallel_unsafe = operation.job_type == \"uninstall\" or (\n operation.package.develop\n and operation.package.source_type in {\"directory\", \"git\"}\n )\n if not operation.skipped and is_parallel_unsafe:\n serial_operations.append(operation)\n continue\n\n 
tasks.append(self._executor.submit(self._execute_operation, operation))\n\n try:\n wait(tasks)\n\n for operation in serial_operations:\n wait([self._executor.submit(self._execute_operation, operation)])\n\n except KeyboardInterrupt:\n self._shutdown = True\n\n if self._shutdown:\n # Cancelling further tasks from being executed\n [task.cancel() for task in tasks]\n self._executor.shutdown(wait=True)\n\n break\n\n for warning in self._yanked_warnings:\n self._io.write_error_line(f\"Warning: {warning}\")\n for path, issues in self._wheel_installer.invalid_wheels.items():\n formatted_issues = \"\\n\".join(issues)\n warning = (\n f\"Validation of the RECORD file of {path.name} failed.\"\n \" Please report to the maintainers of that package so they can fix\"\n f\" their build process. Details:\\n{formatted_issues}\\n\"\n )\n self._io.write_error_line(f\"Warning: {warning}\")\n\n return 1 if self._shutdown else 0\n\n def _write(self, operation: Operation, line: str) -> None:\n if not self.supports_fancy_output() or not self._should_write_operation(\n operation\n ):\n return\n\n if self._io.is_debug():\n with self._lock:\n section = self._sections[id(operation)]\n section.write_line(line)\n\n return\n\n with self._lock:\n section = self._sections[id(operation)]\n section.clear()\n section.write(line)\n\n def _execute_operation(self, operation: Operation) -> None:\n try:\n op_message = self.get_operation_message(operation)\n if self.supports_fancy_output():\n if id(operation) not in self._sections and self._should_write_operation(\n operation\n ):\n with self._lock:\n self._sections[id(operation)] = self._io.section()\n self._sections[id(operation)].write_line(\n f\" - {op_message}:\"\n \" Pending...\"\n )\n else:\n if self._should_write_operation(operation):\n if not operation.skipped:\n self._io.write_line(\n f\" - {op_message}\"\n )\n else:\n self._io.write_line(\n f\" - {op_message}: \"\n \"Skipped \"\n \"for the following reason: \"\n f\"{operation.skip_reason}\"\n )\n\n try:\n result = self._do_execute_operation(operation)\n except EnvCommandError as e:\n if e.e.returncode == -2:\n result = -2\n else:\n raise\n\n # If we have a result of -2 it means a KeyboardInterrupt\n # in the any python subprocess, so we raise a KeyboardInterrupt\n # error to be picked up by the error handler.\n if result == -2:\n raise KeyboardInterrupt\n except Exception as e:\n try:\n from cleo.ui.exception_trace import ExceptionTrace\n\n io: IO | SectionOutput\n if not self.supports_fancy_output():\n io = self._io\n else:\n message = (\n \" -\"\n f\" {self.get_operation_message(operation, error=True)}:\"\n \" Failed\"\n )\n self._write(operation, message)\n io = self._sections.get(id(operation), self._io)\n\n with self._lock:\n trace = ExceptionTrace(e)\n trace.render(io)\n pkg = operation.package\n if isinstance(e, ChefBuildError):\n pip_command = \"pip wheel --no-cache-dir --use-pep517\"\n if pkg.develop:\n requirement = pkg.source_url\n pip_command += \" --editable\"\n else:\n requirement = (\n pkg.to_dependency().to_pep_508().split(\";\")[0].strip()\n )\n message = (\n \"\"\n \"Note: This error originates from the build backend,\"\n \" and is likely not a problem with poetry\"\n f\" but with {pkg.pretty_name} ({pkg.full_pretty_version})\"\n \" not supporting PEP 517 builds. 
You can verify this by\"\n f\" running '{pip_command} \\\"{requirement}\\\"'.\"\n \"\"\n )\n elif isinstance(e, ChefInstallError):\n message = (\n \"\"\n \"Cannot install build-system.requires\"\n f\" for {pkg.pretty_name}.\"\n \"\"\n )\n elif isinstance(e, SolverProblemError):\n message = (\n \"\"\n \"Cannot resolve build-system.requires\"\n f\" for {pkg.pretty_name}.\"\n \"\"\n )\n else:\n message = f\"Cannot install {pkg.pretty_name}.\"\n\n io.write_line(\"\")\n io.write_line(message)\n io.write_line(\"\")\n finally:\n with self._lock:\n self._shutdown = True\n\n except KeyboardInterrupt:\n try:\n message = (\n \" -\"\n f\" {self.get_operation_message(operation, warning=True)}:\"\n \" Cancelled\"\n )\n if not self.supports_fancy_output():\n self._io.write_line(message)\n else:\n self._write(operation, message)\n finally:\n with self._lock:\n self._shutdown = True\n\n \ndef _do_execute_operation(self, operation: Operation) -> int:\n method = operation.job_type\n\n operation_message = self.get_operation_message(operation)\n if operation.skipped:\n if self.supports_fancy_output():\n self._write(\n operation,\n f\" - {operation_message}: \"\n \"Skipped \"\n \"for the following reason: \"\n f\"{operation.skip_reason}\",\n )\n\n self._skipped[operation.job_type] += 1\n\n return 0\n\n if not self._enabled or self._dry_run:\n return 0\n\n result: int = getattr(self, f\"_execute_{method}\")(operation)\n\n if result != 0:\n return result\n\n operation_message = self.get_operation_message(operation, done=True)\n message = f\" - {operation_message}\"\n self._write(operation, message)\n\n self._increment_operations_count(operation, True)\n\n return result\n\n def _increment_operations_count(self, operation: Operation, executed: bool) -> None:\n with self._lock:\n if executed:\n self._executed_operations += 1\n self._executed[operation.job_type] += 1\n else:\n self._skipped[operation.job_type] += 1\n\n def run_pip(self, *args: Any, **kwargs: Any) -> int:\n try:\n self._env.run_pip(*args, **kwargs)\n except EnvCommandError as e:\n output = decode(e.e.output)\n if (\n \"KeyboardInterrupt\" in output\n or \"ERROR: Operation cancelled by user\" in output\n ):\n return -2\n\n raise\n\n return 0\n\n def get_operation_message(\n self,\n operation: Operation,\n done: bool = False,\n error: bool = False,\n warning: bool = False,\n ) -> str:\n base_tag = \"fg=default\"\n operation_color = \"c2\"\n source_operation_color = \"c2\"\n package_color = \"c1\"\n\n if error:\n operation_color = \"error\"\n elif warning:\n operation_color = \"warning\"\n elif done:\n operation_color = \"success\"\n\n if operation.skipped:\n base_tag = \"fg=default;options=dark\"\n operation_color += \"_dark\"\n source_operation_color += \"_dark\"\n package_color += \"_dark\"\n\n if isinstance(operation, Install):\n return (\n f\"<{base_tag}>Installing\"\n f\" <{package_color}>{operation.package.name}\"\n f\" (<{operation_color}>{operation.package.full_pretty_version})\"\n )\n\n if isinstance(operation, Uninstall):\n return (\n f\"<{base_tag}>Removing\"\n f\" <{package_color}>{operation.package.name}\"\n f\" (<{operation_color}>{operation.package.full_pretty_version})\"\n )\n\n if isinstance(operation, Update):\n initial_version = (initial_pkg := operation.initial_package).version\n target_version = (target_pkg := operation.target_package).version\n update_kind = (\n \"Updating\" if target_version >= initial_version else \"Downgrading\"\n )\n return (\n f\"<{base_tag}>{update_kind}\"\n f\" <{package_color}>{initial_pkg.name} \"\n 
f\"(<{source_operation_color}>\"\n f\"{initial_pkg.full_pretty_version}\"\n f\" -> <{operation_color}>\"\n f\"{target_pkg.full_pretty_version})\"\n )\n return \"\"\n\n def _display_summary(self, operations: list[Operation]) -> None:\n installs = 0\n updates = 0\n uninstalls = 0\n skipped = 0\n for op in operations:\n if op.skipped:\n skipped += 1\n continue\n\n if op.job_type == \"install\":\n installs += 1\n elif op.job_type == \"update\":\n updates += 1\n elif op.job_type == \"uninstall\":\n uninstalls += 1\n\n if not installs and not updates and not uninstalls and not self._verbose:\n self._io.write_line(\"\")\n self._io.write_line(\"No dependencies to install or update\")\n\n return\n\n self._io.write_line(\"\")\n self._io.write(\"Package operations: \")\n self._io.write(f\"{installs} install{pluralize(installs)}, \")\n self._io.write(f\"{updates} update{pluralize(updates)}, \")\n self._io.write(f\"{uninstalls} removal{pluralize(uninstalls)}\")\n if skipped and self._verbose:\n self._io.write(f\", {skipped} skipped\")\n self._io.write_line(\"\")\n self._io.write_line(\"\")\n\n def _execute_install(self, operation: Install | Update) -> int:\n status_code = self._install(operation)\n\n self._save_url_reference(operation)\n\n return status_code\n\n def _execute_update(self, operation: Install | Update) -> int:\n status_code = self._update(operation)\n\n self._save_url_reference(operation)\n\n return status_code\n\n def _execute_uninstall(self, operation: Uninstall) -> int:\n op_msg = self.get_operation_message(operation)\n message = f\" - {op_msg}: Removing...\"\n self._write(operation, message)\n\n return self._remove(operation.package)\n\n def _install(self, operation: Install | Update) -> int:\n package = operation.package\n if package.source_type == \"directory\" and not self._use_modern_installation:\n return self._install_directory_without_wheel_installer(operation)\n\n cleanup_archive: bool = False\n if package.source_type == \"git\":\n archive = self._prepare_git_archive(operation)\n cleanup_archive = operation.package.develop\n elif package.source_type == \"file\":\n archive = self._prepare_archive(operation)\n elif package.source_type == \"directory\":\n archive = self._prepare_archive(operation)\n cleanup_archive = True\n elif package.source_type == \"url\":\n assert package.source_url is not None\n archive = self._download_link(operation, Link(package.source_url))\n else:\n archive = self._download(operation)\n\n operation_message = self.get_operation_message(operation)\n message = (\n f\" - {operation_message}:\"\n \" Installing...\"\n )\n self._write(operation, message)\n\n if not self._use_modern_installation:\n return self.pip_install(archive, upgrade=operation.job_type == \"update\")\n\n try:\n if operation.job_type == \"update\":\n # Uninstall first\n # TODO: Make an uninstaller and find a way to rollback in case\n # the new package can't be installed\n assert isinstance(operation, Update)\n self._remove(operation.initial_package)\n\n self._wheel_installer.install(archive)\n finally:\n if cleanup_archive:\n archive.unlink()\n\n return 0\n\n def _update(self, operation: Install | Update) -> int:\n return self._install(operation)\n\n def _remove(self, package: Package) -> int:\n # If we have a VCS package, remove its source directory\n if package.source_type == \"git\":\n src_dir = self._env.path / \"src\" / package.name\n if src_dir.exists():\n remove_directory(src_dir, force=True)\n\n try:\n return self.run_pip(\"uninstall\", package.name, \"-y\")\n except 
CalledProcessError as e:\n if \"not installed\" in str(e):\n return 0\n\n raise\n\n def _prepare_archive(\n self, operation: Install | Update, *, output_dir: Path | None = None\n ) -> Path:\n package = operation.package\n operation_message = self.get_operation_message(operation)\n\n message = (\n f\" - {operation_message}:\"\n \" Preparing...\"\n )\n self._write(operation, message)\n\n assert package.source_url is not None\n archive = Path(package.source_url)\n if package.source_subdirectory:\n archive = archive / package.source_subdirectory\n if not Path(package.source_url).is_absolute() and package.root_dir:\n archive = package.root_dir / archive\n\n self._populate_hashes_dict(archive, package)\n\n return self._chef.prepare(\n archive, editable=package.develop, output_dir=output_dir\n )\n\n def _prepare_git_archive(self, operation: Install | Update) -> Path:\n from poetry.vcs.git import Git\n\n package = operation.package\n assert package.source_url is not None\n\n if package.source_resolved_reference and not package.develop:\n # Only cache git archives when we know precise reference hash,\n # otherwise we might get stale archives\n cached_archive = self._artifact_cache.get_cached_archive_for_git(\n package.source_url,\n package.source_resolved_reference,\n package.source_subdirectory,\n env=self._env,\n )\n if cached_archive is not None:\n return cached_archive\n\n operation_message = self.get_operation_message(operation)\n\n message = (\n f\" - {operation_message}: Cloning...\"\n )\n self._write(operation, message)\n\n source = Git.clone(\n url=package.source_url,\n source_root=self._env.path / \"src\",\n revision=package.source_resolved_reference or package.source_reference,\n )\n\n # Now we just need to install from the source directory\n original_url = package.source_url\n package._source_url = str(source.path)\n\n output_dir = None\n if package.source_resolved_reference and not package.develop:\n output_dir = self._artifact_cache.get_cache_directory_for_git(\n original_url,\n package.source_resolved_reference,\n package.source_subdirectory,\n )\n\n archive = self._prepare_archive(operation, output_dir=output_dir)\n if not package.develop:\n package._source_url = original_url\n\n if output_dir is not None and output_dir.is_dir():\n # Mark directories with cached git packages, to distinguish from\n # \"normal\" cache\n (output_dir / \".created_from_git_dependency\").touch()\n\n return archive\n\n def _install_directory_without_wheel_installer(\n self, operation: Install | Update\n ) -> int:\n from poetry.factory import Factory\n from poetry.pyproject.toml import PyProjectTOML\n\n package = operation.package\n operation_message = self.get_operation_message(operation)\n\n message = (\n f\" - {operation_message}:\"\n \" Building...\"\n )\n self._write(operation, message)\n\n assert package.source_url is not None\n if package.root_dir:\n req = package.root_dir / package.source_url\n else:\n req = Path(package.source_url).resolve(strict=False)\n\n if package.source_subdirectory:\n req /= package.source_subdirectory\n\n pyproject = PyProjectTOML(req / \"pyproject.toml\")\n\n package_poetry = None\n if pyproject.is_poetry_project():\n with contextlib.suppress(RuntimeError):\n package_poetry = Factory().create_poetry(pyproject.file.path.parent)\n\n if package_poetry is not None:\n builder: Builder\n if package.develop and not package_poetry.package.build_script:\n from poetry.masonry.builders.editable import EditableBuilder\n\n # This is a Poetry package in editable mode\n # we can use the 
EditableBuilder without going through pip\n # to install it, unless it has a build script.\n builder = EditableBuilder(package_poetry, self._env, NullIO())\n builder.build()\n\n return 0\n\n if package_poetry.package.build_script:\n from poetry.core.masonry.builders.sdist import SdistBuilder\n\n builder = SdistBuilder(package_poetry)\n with builder.setup_py():\n return self.pip_install(req, upgrade=True, editable=package.develop)\n\n return self.pip_install(req, upgrade=True, editable=package.develop)\n\n def _download(self, operation: Install | Update) -> Path:\n link = self._chooser.choose_for(operation.package)\n\n if link.yanked:\n # Store yanked warnings in a list and print after installing, so they can't\n # be overlooked. Further, printing them in the concerning section would have\n # the risk of overwriting the warning, so it is only briefly visible.\n message = (\n f\"The file chosen for install of {operation.package.pretty_name} \"\n f\"{operation.package.pretty_version} ({link.show_url}) is yanked.\"\n )\n if link.yanked_reason:\n message += f\" Reason for being yanked: {link.yanked_reason}\"\n self._yanked_warnings.append(message)\n\n return self._download_link(operation, link)\n\n def _download_link(self, operation: Install | Update, link: Link) -> Path:\n package = operation.package\n\n # Get original package for the link provided\n download_func = functools.partial(self._download_archive, operation)\n original_archive = self._artifact_cache.get_cached_archive_for_link(\n link, strict=True, download_func=download_func\n )\n\n # Get potential higher prioritized cached archive, otherwise it will fall back\n # to the original archive.\n archive = self._artifact_cache.get_cached_archive_for_link(\n link,\n strict=False,\n env=self._env,\n )\n if archive is None:\n # Since we previously downloaded an archive, we now should have\n # something cached that we can use here. 
The only case in which\n # archive is None is if the original archive is not valid for the\n # current environment.\n raise RuntimeError(\n f\"Package {link.url} cannot be installed in the current environment\"\n f\" {self._env.marker_env}\"\n )\n\n if archive.suffix != \".whl\":\n message = (\n f\" - {self.get_operation_message(operation)}:\"\n \" Preparing...\"\n )\n self._write(operation, message)\n\n archive = self._chef.prepare(archive, output_dir=original_archive.parent)\n\n # Use the original archive to provide the correct hash.\n self._populate_hashes_dict(original_archive, package)\n\n return archive\n\n def _populate_hashes_dict(self, archive: Path, package: Package) -> None:\n if package.files and archive.name in {f[\"file\"] for f in package.files}:\n archive_hash = self._validate_archive_hash(archive, package)\n self._hashes[package.name] = archive_hash\n\n @staticmethod\n def _validate_archive_hash(archive: Path, package: Package) -> str:\n known_hashes = {f[\"hash\"] for f in package.files if f[\"file\"] == archive.name}\n hash_types = {t.split(\":\")[0] for t in known_hashes}\n hash_type = get_highest_priority_hash_type(hash_types, archive.name)\n\n if hash_type is None:\n raise RuntimeError(\n f\"No usable hash type(s) for {package} from archive\"\n f\" {archive.name} found (known hashes: {known_hashes!s})\"\n )\n\n archive_hash = f\"{hash_type}:{get_file_hash(archive, hash_type)}\"\n\n if archive_hash not in known_hashes:\n raise RuntimeError(\n f\"Hash for {package} from archive {archive.name} not found in\"\n f\" known hashes (was: {archive_hash})\"\n )\n\n return archive_hash\n\n def _download_archive(\n self,\n operation: Install | Update,\n url: str,\n dest: Path,\n ) -> None:\n downloader = Downloader(url, dest, self._authenticator)\n wheel_size = downloader.total_size\n\n operation_message = self.get_operation_message(operation)\n message = (\n f\" - {operation_message}: Downloading...\"\n )\n progress = None\n if self.supports_fancy_output():\n if wheel_size is None:\n self._write(operation, message)\n else:\n from cleo.ui.progress_bar import ProgressBar\n\n progress = ProgressBar(\n self._sections[id(operation)], max=int(wheel_size)\n )\n progress.set_format(message + \" %percent%%\")\n\n if progress:\n with self._lock:\n self._sections[id(operation)].clear()\n progress.start()\n\n for fetched_size in downloader.download_with_progress(chunk_size=4096):\n if progress:\n with self._lock:\n progress.set_progress(fetched_size)\n\n if progress:\n with self._lock:\n progress.finish()\n\n def _should_write_operation(self, operation: Operation) -> bool:\n return (\n not operation.skipped or self._dry_run or self._verbose or not self._enabled\n )\n\n def _save_url_reference(self, operation: Operation) -> None:\n \"\"\"\n Create and store a PEP-610 `direct_url.json` file, if needed.\n \"\"\"\n if operation.job_type not in {\"install\", \"update\"}:\n return\n\n package = operation.package\n\n if not package.source_url or package.source_type == \"legacy\":\n if not self._use_modern_installation:\n # Since we are installing from our own distribution cache pip will write\n # a `direct_url.json` file pointing to the cache distribution.\n #\n # That's not what we want, so we remove the direct_url.json file, if it\n # exists.\n for (\n direct_url_json\n ) in self._env.site_packages.find_distribution_direct_url_json_files(\n distribution_name=package.name, writable_only=True\n ):\n direct_url_json.unlink(missing_ok=True)\n return\n\n url_reference: dict[str, Any] | None = 
None\n\n if package.source_type == \"git\" and not package.develop:\n url_reference = self._create_git_url_reference(package)\n elif package.source_type in (\"directory\", \"git\"):\n url_reference = self._create_directory_url_reference(package)\n elif package.source_type == \"url\":\n url_reference = self._create_url_url_reference(package)\n elif package.source_type == \"file\":\n url_reference = self._create_file_url_reference(package)\n\n if url_reference:\n for dist in self._env.site_packages.distributions(\n name=package.name, writable_only=True\n ):\n dist_path = dist._path # type: ignore[attr-defined]\n assert isinstance(dist_path, Path)\n url = dist_path / \"direct_url.json\"\n url.write_text(json.dumps(url_reference), encoding=\"utf-8\")\n\n record = dist_path / \"RECORD\"\n if record.exists():\n with record.open(mode=\"a\", encoding=\"utf-8\", newline=\"\") as f:\n writer = csv.writer(f)\n path = url.relative_to(record.parent.parent)\n writer.writerow([str(path), \"\", \"\"])\n\n def _create_git_url_reference(self, package: Package) -> dict[str, Any]:\n reference = {\n \"url\": package.source_url,\n \"vcs_info\": {\n \"vcs\": \"git\",\n \"requested_revision\": package.source_reference,\n \"commit_id\": package.source_resolved_reference,\n },\n }\n if package.source_subdirectory:\n reference[\"subdirectory\"] = package.source_subdirectory\n\n return reference\n\n def _create_url_url_reference(self, package: Package) -> dict[str, Any]:\n archive_info = self._get_archive_info(package)\n\n return {\"url\": package.source_url, \"archive_info\": archive_info}\n\n def _create_file_url_reference(self, package: Package) -> dict[str, Any]:\n archive_info = self._get_archive_info(package)\n\n assert package.source_url is not None\n return {\n \"url\": Path(package.source_url).as_uri(),\n \"archive_info\": archive_info,\n }\n\n def _create_directory_url_reference(self, package: Package) -> dict[str, Any]:\n dir_info = {}\n\n if package.develop:\n...\n", "needle_token_start": 10271, "needle_token_end": 10593, "code_context_ntokens": 16057, "output": ["```python\ndef _do_execute_operation(self, operation: Operation) -> int:\n method = operation.job_type\n\n operation_message = self.get_operation_message(operation)\n if operation.skipped:\n if self.supports_fancy_output():\n self._write(\n operation,\n f\" - {operation_message}: \"\n \"Skipped \"\n \"for the following reason: \"\n f\"{operation.skip_reason}\",\n )\n\n self._skipped[operation.job_type] += 1\n\n return 0\n\n if not self._enabled or self._dry_run:\n return 0\n\n result: int = getattr(self, f\"_execute_{method}\")(operation)\n\n if result != 0:\n return result\n\n operation_message = self.get_operation_message(operation, done=True)\n message = f\" - {operation_message}\"\n self._write(operation, message)\n\n self._increment_operations_count(operation, True)\n\n return result\n```"]} +{"repo": "python-poetry/poetry", "name": "_get_directory_from_hash", "language": "python", "path": "src/poetry/utils/cache.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To generate a directory path based on a hashed representation of key components, ensuring organized and consistent storage of cached data.\n2. **Input**: A dictionary containing key-value pairs that uniquely identify a cache item.\n3. **Output**: A `Path` object representing the directory where the cache item should be stored.\n4. 
**Procedure**: The function serializes the input dictionary into a JSON string, ensuring keys are sorted and special characters are handled properly. This string is then hashed using SHA-256. The hash is split into segments to create a nested directory structure, which is combined with the base cache directory to form the full path.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/repositories/pypi_repository.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nimport requests\nimport requests.adapters\n\nfrom cachecontrol.controller import logger as cache_control_logger\nfrom poetry.core.packages.package import Package\nfrom poetry.core.packages.utils.link import Link\nfrom poetry.core.version.exceptions import InvalidVersion\n\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.http_repository import HTTPRepository\nfrom poetry.repositories.link_sources.json import SimpleJsonPage\nfrom poetry.repositories.parsers.pypi_search_parser import SearchResultParser\nfrom poetry.utils.constants import REQUESTS_TIMEOUT\n\n\ncache_control_logger.setLevel(logging.ERROR)\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n from poetry.core.constraints.version import Version\n from poetry.core.constraints.version import VersionConstraint\n\nSUPPORTED_PACKAGE_TYPES = {\"sdist\", \"bdist_wheel\"}\n\n\nclass PyPiRepository(HTTPRepository):\n def __init__(\n self,\n url: str = \"https://pypi.org/\",\n disable_cache: bool = False,\n fallback: bool = True,\n pool_size: int = requests.adapters.DEFAULT_POOLSIZE,\n ) -> None:\n super().__init__(\n \"PyPI\",\n url.rstrip(\"/\") + \"/simple/\",\n disable_cache=disable_cache,\n pool_size=pool_size,\n )\n\n self._base_url = url\n self._fallback = fallback\n\n def search(self, query: str) -> list[Package]:\n results = []\n\n response = requests.get(\n self._base_url + \"search\", params={\"q\": query}, timeout=REQUESTS_TIMEOUT\n )\n parser = SearchResultParser()\n parser.feed(response.text)\n\n for result in parser.results:\n try:\n package = Package(result.name, result.version)\n package.description = result.description.strip()\n results.append(package)\n except InvalidVersion:\n self._log(\n f'Unable to parse version \"{result.version}\" for the'\n f\" {result.name} package, skipping\",\n level=\"debug\",\n )\n\n return results\n\n def get_package_info(self, name: NormalizedName) -> dict[str, Any]:\n \"\"\"\n Return the package information given its name.\n\n The information is returned from the cache if it exists\n or retrieved from the remote server.\n \"\"\"\n return self._get_package_info(name)\n\n def _find_packages(\n self, name: NormalizedName, constraint: VersionConstraint\n ) -> list[Package]:\n \"\"\"\n Find packages on the remote server.\n \"\"\"\n try:\n json_page = self.get_page(name)\n except PackageNotFound:\n self._log(f\"No packages found for {name}\", level=\"debug\")\n return []\n\n versions = [\n (version, json_page.yanked(name, version))\n for version in json_page.versions(name)\n if constraint.allows(version)\n ]\n\n return [Package(name, version, yanked=yanked) for version, yanked in versions]\n\n def _get_package_info(self, name: NormalizedName) -> 
dict[str, Any]:\n headers = {\"Accept\": \"application/vnd.pypi.simple.v1+json\"}\n info = self._get(f\"simple/{name}/\", headers=headers)\n if info is None:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n\n return info\n\n def find_links_for_package(self, package: Package) -> list[Link]:\n json_data = self._get(f\"pypi/{package.name}/{package.version}/json\")\n if json_data is None:\n return []\n\n links = []\n for url in json_data[\"urls\"]:\n if url[\"packagetype\"] in SUPPORTED_PACKAGE_TYPES:\n h = f\"sha256={url['digests']['sha256']}\"\n links.append(Link(url[\"url\"] + \"#\" + h, yanked=self._get_yanked(url)))\n\n return links\n\n def _get_release_info(\n self, name: NormalizedName, version: Version\n ) -> dict[str, Any]:\n from poetry.inspection.info import PackageInfo\n\n self._log(f\"Getting info for {name} ({version}) from PyPI\", \"debug\")\n\n json_data = self._get(f\"pypi/{name}/{version}/json\")\n if json_data is None:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n\n info = json_data[\"info\"]\n\n data = PackageInfo(\n name=info[\"name\"],\n version=info[\"version\"],\n summary=info[\"summary\"],\n requires_dist=info[\"requires_dist\"],\n requires_python=info[\"requires_python\"],\n yanked=self._get_yanked(info),\n cache_version=str(self.CACHE_VERSION),\n )\n\n try:\n...\n# Path: src/poetry/repositories/repository.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import TYPE_CHECKING\n\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.constraints.version import Version\n\nfrom poetry.repositories.abstract_repository import AbstractRepository\nfrom poetry.repositories.exceptions import PackageNotFound\n\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n from poetry.core.constraints.version import VersionConstraint\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n from poetry.core.packages.utils.link import Link\n\n\nclass Repository(AbstractRepository):\n def __init__(self, name: str, packages: list[Package] | None = None) -> None:\n super().__init__(name)\n self._packages: list[Package] = []\n\n for package in packages or []:\n self.add_package(package)\n\n @property\n def packages(self) -> list[Package]:\n return self._packages\n\n def find_packages(self, dependency: Dependency) -> list[Package]:\n packages = []\n ignored_pre_release_packages = []\n\n constraint = dependency.constraint\n allow_prereleases = dependency.allows_prereleases()\n for package in self._find_packages(dependency.name, constraint):\n if package.yanked and not isinstance(constraint, Version):\n # PEP 592: yanked files are always ignored, unless they are the only\n # file that matches a version specifier that \"pins\" to an exact\n # version\n continue\n if (\n package.is_prerelease()\n and not allow_prereleases\n and not package.is_direct_origin()\n ):\n ignored_pre_release_packages.append(package)\n continue\n\n packages.append(package)\n\n self._log(\n f\"{len(packages)} packages found for {dependency.name} {constraint!s}\",\n level=\"debug\",\n )\n\n return packages or ignored_pre_release_packages\n\n def has_package(self, package: Package) -> bool:\n package_id = package.unique_name\n return any(\n package_id == repo_package.unique_name for repo_package in self.packages\n )\n\n def add_package(self, package: Package) -> None:\n self._packages.append(package)\n\n def remove_package(self, package: Package) -> None:\n package_id = package.unique_name\n\n index = 
None\n for i, repo_package in enumerate(self.packages):\n if package_id == repo_package.unique_name:\n index = i\n break\n\n if index is not None:\n del self._packages[index]\n\n def search(self, query: str) -> list[Package]:\n results: list[Package] = []\n\n for package in self.packages:\n if query in package.name:\n results.append(package)\n\n return results\n\n def _find_packages(\n self, name: NormalizedName, constraint: VersionConstraint\n ) -> list[Package]:\n return [\n package\n for package in self._packages\n if package.name == name and constraint.allows(package.version)\n ]\n\n def _log(self, msg: str, level: str = \"info\") -> None:\n logger = logging.getLogger(f\"{__name__}.{self.__class__.__name__}\")\n getattr(logger, level)(f\"Source ({self.name}): {msg}\")\n\n def __len__(self) -> int:\n return len(self._packages)\n\n def find_links_for_package(self, package: Package) -> list[Link]:\n return []\n\n def package(\n self, name: str, version: Version, extras: list[str] | None = None\n ) -> Package:\n canonicalized_name = canonicalize_name(name)\n for package in self.packages:\n if canonicalized_name == package.name and package.version == version:\n return package\n\n raise PackageNotFound(f\"Package {name} ({version}) not found.\")\n\n# Path: src/poetry/repositories/repository_pool.py\nfrom __future__ import annotations\n\nimport enum\nimport warnings\n\nfrom collections import OrderedDict\nfrom dataclasses import dataclass\nfrom enum import IntEnum\nfrom typing import TYPE_CHECKING\n\nfrom poetry.config.config import Config\nfrom poetry.repositories.abstract_repository import AbstractRepository\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.repository import Repository\nfrom poetry.utils.cache import ArtifactCache\n\n\nif TYPE_CHECKING:\n from poetry.core.constraints.version import Version\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n\n_SENTINEL = object()\n\n\nclass Priority(IntEnum):\n # The order of the members below dictates the actual priority. The first member has\n # top priority.\n DEFAULT = enum.auto()\n PRIMARY = enum.auto()\n SECONDARY = enum.auto()\n SUPPLEMENTAL = enum.auto()\n EXPLICIT = enum.auto()\n\n\n@dataclass(frozen=True)\nclass PrioritizedRepository:\n repository: Repository\n priority: Priority\n\n\nclass RepositoryPool(AbstractRepository):\n def __init__(\n self,\n repositories: list[Repository] | None = None,\n ignore_repository_names: object = _SENTINEL,\n *,\n config: Config | None = None,\n ) -> None:\n super().__init__(\"poetry-repository-pool\")\n self._repositories: OrderedDict[str, PrioritizedRepository] = OrderedDict()\n\n if repositories is None:\n repositories = []\n for repository in repositories:\n self.add_repository(repository)\n\n self._artifact_cache = ArtifactCache(\n cache_dir=(config or Config.create()).artifacts_cache_directory\n )\n\n if ignore_repository_names is not _SENTINEL:\n warnings.warn(\n \"The 'ignore_repository_names' argument to 'RepositoryPool.__init__' is\"\n \" deprecated. 
It has no effect anymore and will be removed in a future\"\n \" version.\",\n DeprecationWarning,\n stacklevel=2,\n )\n\n @staticmethod\n def from_packages(packages: list[Package], config: Config | None) -> RepositoryPool:\n pool = RepositoryPool(config=config)\n for package in packages:\n if package.is_direct_origin():\n continue\n\n repo_name = package.source_reference or \"PyPI\"\n try:\n repo = pool.repository(repo_name)\n except IndexError:\n repo = Repository(repo_name)\n pool.add_repository(repo)\n\n if not repo.has_package(package):\n repo.add_package(package)\n\n return pool\n\n @property\n def repositories(self) -> list[Repository]:\n \"\"\"\n Returns the repositories in the pool,\n in the order they will be searched for packages.\n\n ATTENTION: For backwards compatibility and practical reasons,\n repositories with priority EXPLICIT are NOT included,\n because they will not be searched.\n \"\"\"\n sorted_repositories = self._sorted_repositories\n return [\n prio_repo.repository\n for prio_repo in sorted_repositories\n if prio_repo.priority is not Priority.EXPLICIT\n ]\n\n @property\n def all_repositories(self) -> list[Repository]:\n return [prio_repo.repository for prio_repo in self._sorted_repositories]\n\n @property\n def _sorted_repositories(self) -> list[PrioritizedRepository]:\n return sorted(\n self._repositories.values(), key=lambda prio_repo: prio_repo.priority\n )\n\n @property\n def artifact_cache(self) -> ArtifactCache:\n return self._artifact_cache\n\n def has_default(self) -> bool:\n return self._contains_priority(Priority.DEFAULT)\n\n def has_primary_repositories(self) -> bool:\n return self._contains_priority(Priority.PRIMARY)\n\n def _contains_priority(self, priority: Priority) -> bool:\n return any(\n prio_repo.priority is priority for prio_repo in self._repositories.values()\n )\n\n def has_repository(self, name: str) -> bool:\n return name.lower() in self._repositories\n\n def repository(self, name: str) -> Repository:\n return self._get_prioritized_repository(name).repository\n\n def get_priority(self, name: str) -> Priority:\n return self._get_prioritized_repository(name).priority\n\n def _get_prioritized_repository(self, name: str) -> PrioritizedRepository:\n name = name.lower()\n if self.has_repository(name):\n return self._repositories[name]\n raise IndexError(f'Repository \"{name}\" does not exist.')\n\n def add_repository(\n self,\n repository: Repository,\n default: bool = False,\n secondary: bool = False,\n *,\n priority: Priority = Priority.PRIMARY,\n ) -> RepositoryPool:\n \"\"\"\n Adds a repository to the pool.\n \"\"\"\n repository_name = repository.name.lower()\n if self.has_repository(repository_name):\n raise ValueError(\n f\"A repository with name {repository_name} was already added.\"\n )\n\n if default or secondary:\n warnings.warn(\n \"Parameters 'default' and 'secondary' to\"\n \" 'RepositoryPool.add_repository' are deprecated. 
Please provide\"\n \" the keyword-argument 'priority' instead.\",\n DeprecationWarning,\n stacklevel=2,\n )\n priority = Priority.DEFAULT if default else Priority.SECONDARY\n\n if priority is Priority.DEFAULT and self.has_default():\n raise ValueError(\"Only one repository can be the default.\")\n\n self._repositories[repository_name] = PrioritizedRepository(\n repository, priority\n )\n return self\n\n def remove_repository(self, name: str) -> RepositoryPool:\n if not self.has_repository(name):\n raise IndexError(\n f\"RepositoryPool can not remove unknown repository '{name}'.\"\n )\n del self._repositories[name.lower()]\n return self\n\n def package(\n self,\n name: str,\n version: Version,\n extras: list[str] | None = None,\n repository_name: str | None = None,\n ) -> Package:\n if repository_name:\n return self.repository(repository_name).package(\n name, version, extras=extras\n )\n\n for repo in self.repositories:\n try:\n return repo.package(name, version, extras=extras)\n except PackageNotFound:\n continue\n raise PackageNotFound(f\"Package {name} ({version}) not found.\")\n\n def find_packages(self, dependency: Dependency) -> list[Package]:\n repository_name = dependency.source_name\n if repository_name:\n return self.repository(repository_name).find_packages(dependency)\n\n packages: list[Package] = []\n for repo in self.repositories:\n if packages and self.get_priority(repo.name) is Priority.SUPPLEMENTAL:\n break\n packages += repo.find_packages(dependency)\n return packages\n\n def search(self, query: str) -> list[Package]:\n results: list[Package] = []\n for repo in self.repositories:\n results += repo.search(query)\n return results\n\n# Path: src/poetry/repositories/single_page_repository.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.repositories.exceptions import PackageNotFound\nfrom poetry.repositories.legacy_repository import LegacyRepository\nfrom poetry.repositories.link_sources.html import SimpleRepositoryPage\n\n\nif TYPE_CHECKING:\n from packaging.utils import NormalizedName\n\n\nclass SinglePageRepository(LegacyRepository):\n def _get_page(self, name: NormalizedName) -> SimpleRepositoryPage:\n \"\"\"\n Single page repositories only have one page irrespective of endpoint.\n \"\"\"\n response = self._get_response(\"\")\n if not response:\n raise PackageNotFound(f\"Package [{name}] not found.\")\n return SimpleRepositoryPage(response.url, response.text)\n\n# Path: src/poetry/toml/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.toml.exceptions import TOMLError\nfrom poetry.toml.file import TOMLFile\n\n\n__all__ = [\"TOMLError\", \"TOMLFile\"]\n\n# Path: src/poetry/toml/exceptions.py\nfrom __future__ import annotations\n\nfrom poetry.core.exceptions import PoetryCoreException\nfrom tomlkit.exceptions import TOMLKitError\n\n\nclass TOMLError(TOMLKitError, PoetryCoreException):\n pass\n\n# Path: src/poetry/toml/file.py\nfrom __future__ import annotations\n\nimport warnings\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom tomlkit.toml_file import TOMLFile as BaseTOMLFile\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n\n from tomlkit.toml_document import TOMLDocument\n\n\nclass TOMLFile(BaseTOMLFile):\n def __init__(self, path: Path) -> None:\n super().__init__(path)\n self.__path = path\n\n @property\n def path(self) -> Path:\n return self.__path\n\n def exists(self) -> bool:\n return self.__path.exists()\n\n def read(self) -> TOMLDocument:\n from tomlkit.exceptions import TOMLKitError\n\n 
from poetry.toml import TOMLError\n\n try:\n return super().read()\n except (ValueError, TOMLKitError) as e:\n raise TOMLError(f\"Invalid TOML file {self.path.as_posix()}: {e}\")\n\n def __getattr__(self, item: str) -> Any:\n warnings.warn(\n \"`__getattr__` will be removed from the `TOMLFile` in a future release.\"\n \"\\n\\nInstead of accessing properties of the underlying `Path` as \"\n \"`tomlfile.whatever`, prefer `tomlfile.path.whatever`.\",\n DeprecationWarning,\n stacklevel=2,\n )\n return getattr(self.__path, item)\n\n def __str__(self) -> str:\n return self.__path.as_posix()\n\n# Path: src/poetry/utils/__init__.py\n\n# Path: src/poetry/utils/_compat.py\nfrom __future__ import annotations\n\nimport sys\n\nfrom contextlib import suppress\n\n\n# TODO: use try/except ImportError when\n# https://github.com/python/mypy/issues/1393 is fixed\n\nif sys.version_info < (3, 11):\n # compatibility for python <3.11\n import tomli as tomllib\nelse:\n import tomllib # nopycln: import\n\n\nif sys.version_info < (3, 10):\n # compatibility for python <3.10\n import importlib_metadata as metadata\nelse:\n from importlib import metadata\n\nWINDOWS = sys.platform == \"win32\"\n\n\ndef decode(string: bytes | str, encodings: list[str] | None = None) -> str:\n if not isinstance(string, bytes):\n return string\n\n encodings = encodings or [\"utf-8\", \"latin1\", \"ascii\"]\n\n for encoding in encodings:\n with suppress(UnicodeEncodeError, UnicodeDecodeError):\n return string.decode(encoding)\n\n return string.decode(encodings[0], errors=\"ignore\")\n\n\ndef encode(string: str, encodings: list[str] | None = None) -> bytes:\n if isinstance(string, bytes):\n return string\n\n encodings = encodings or [\"utf-8\", \"latin1\", \"ascii\"]\n\n for encoding in encodings:\n with suppress(UnicodeEncodeError, UnicodeDecodeError):\n return string.encode(encoding)\n\n return string.encode(encodings[0], errors=\"ignore\")\n\n\n__all__ = [\n \"WINDOWS\",\n \"decode\",\n \"encode\",\n \"metadata\",\n \"tomllib\",\n]\n\n# Path: src/poetry/utils/authenticator.py\nfrom __future__ import annotations\n\nimport contextlib\nimport dataclasses\nimport functools\nimport logging\nimport time\nimport urllib.parse\n\nfrom os.path import commonprefix\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nimport requests\nimport requests.adapters\nimport requests.auth\nimport requests.exceptions\n\nfrom cachecontrol import CacheControlAdapter\nfrom cachecontrol.caches import FileCache\nfrom requests_toolbelt import user_agent\n\nfrom poetry.__version__ import __version__\nfrom poetry.config.config import Config\nfrom poetry.exceptions import PoetryException\nfrom poetry.utils.constants import REQUESTS_TIMEOUT\nfrom poetry.utils.constants import RETRY_AFTER_HEADER\nfrom poetry.utils.constants import STATUS_FORCELIST\nfrom poetry.utils.password_manager import HTTPAuthCredential\nfrom poetry.utils.password_manager import PasswordManager\n\n\nif TYPE_CHECKING:\n from cleo.io.io import IO\n\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass(frozen=True)\nclass RepositoryCertificateConfig:\n cert: Path | None = dataclasses.field(default=None)\n client_cert: Path | None = dataclasses.field(default=None)\n verify: bool = dataclasses.field(default=True)\n\n @classmethod\n def create(\n cls, repository: str, config: Config | None\n ) -> RepositoryCertificateConfig:\n config = config if config else Config.create()\n\n verify: str | bool = config.get(\n f\"certificates.{repository}.verify\",\n 
config.get(f\"certificates.{repository}.cert\", True),\n )\n client_cert: str = config.get(f\"certificates.{repository}.client-cert\")\n\n return cls(\n cert=Path(verify) if isinstance(verify, str) else None,\n client_cert=Path(client_cert) if client_cert else None,\n verify=verify if isinstance(verify, bool) else True,\n )\n\n\n@dataclasses.dataclass\nclass AuthenticatorRepositoryConfig:\n name: str\n url: str\n netloc: str = dataclasses.field(init=False)\n path: str = dataclasses.field(init=False)\n\n def __post_init__(self) -> None:\n parsed_url = urllib.parse.urlsplit(self.url)\n self.netloc = parsed_url.netloc\n self.path = parsed_url.path\n\n def certs(self, config: Config) -> RepositoryCertificateConfig:\n return RepositoryCertificateConfig.create(self.name, config)\n\n @property\n def http_credential_keys(self) -> list[str]:\n return [self.url, self.netloc, self.name]\n\n def get_http_credentials(\n self, password_manager: PasswordManager, username: str | None = None\n ) -> HTTPAuthCredential:\n # try with the repository name via the password manager\n credential = HTTPAuthCredential(\n **(password_manager.get_http_auth(self.name) or {})\n )\n\n if credential.password is not None:\n return credential\n\n if password_manager.use_keyring:\n # fallback to url and netloc based keyring entries\n credential = password_manager.get_credential(\n self.url, self.netloc, username=credential.username\n )\n\n return credential\n\n\nclass Authenticator:\n def __init__(\n self,\n config: Config | None = None,\n io: IO | None = None,\n cache_id: str | None = None,\n disable_cache: bool = False,\n pool_size: int = requests.adapters.DEFAULT_POOLSIZE,\n ) -> None:\n self._config = config or Config.create()\n self._io = io\n self._sessions_for_netloc: dict[str, requests.Session] = {}\n self._credentials: dict[str, HTTPAuthCredential] = {}\n self._certs: dict[str, RepositoryCertificateConfig] = {}\n self._configured_repositories: (\n dict[str, AuthenticatorRepositoryConfig] | None\n ) = None\n self._password_manager = PasswordManager(self._config)\n self._cache_control = (\n FileCache(\n self._config.repository_cache_directory\n / (cache_id or \"_default_cache\")\n / \"_http\"\n )\n if not disable_cache\n else None\n )\n self.get_repository_config_for_url = functools.lru_cache(maxsize=None)(\n self._get_repository_config_for_url\n )\n self._pool_size = pool_size\n self._user_agent = user_agent(\"poetry\", __version__)\n\n def create_session(self) -> requests.Session:\n session = requests.Session()\n session.headers[\"User-Agent\"] = self._user_agent\n\n if self._cache_control is None:\n return session\n\n adapter = CacheControlAdapter(\n cache=self._cache_control,\n pool_maxsize=self._pool_size,\n )\n session.mount(\"http://\", adapter)\n session.mount(\"https://\", adapter)\n\n return session\n\n def get_session(self, url: str | None = None) -> requests.Session:\n if not url:\n return self.create_session()\n\n parsed_url = urllib.parse.urlsplit(url)\n netloc = parsed_url.netloc\n\n if netloc not in self._sessions_for_netloc:\n logger.debug(\"Creating new session for %s\", netloc)\n self._sessions_for_netloc[netloc] = self.create_session()\n\n return self._sessions_for_netloc[netloc]\n\n def close(self) -> None:\n for session in self._sessions_for_netloc.values():\n if session is not None:\n with contextlib.suppress(AttributeError):\n session.close()\n\n def __del__(self) -> None:\n self.close()\n\n def delete_cache(self, url: str) -> None:\n if self._cache_control is not None:\n 
self._cache_control.delete(key=url)\n\n def authenticated_url(self, url: str) -> str:\n parsed = urllib.parse.urlparse(url)\n credential = self.get_credentials_for_url(url)\n\n if credential.username is not None and credential.password is not None:\n username = urllib.parse.quote(credential.username, safe=\"\")\n password = urllib.parse.quote(credential.password, safe=\"\")\n\n return (\n f\"{parsed.scheme}://{username}:{password}@{parsed.netloc}{parsed.path}\"\n )\n\n return url\n\n def request(\n self, method: str, url: str, raise_for_status: bool = True, **kwargs: Any\n ) -> requests.Response:\n headers = kwargs.get(\"headers\")\n request = requests.Request(method, url, headers=headers)\n credential = self.get_credentials_for_url(url)\n\n if credential.username is not None or credential.password is not None:\n request = requests.auth.HTTPBasicAuth(\n credential.username or \"\", credential.password or \"\"\n )(request)\n\n session = self.get_session(url=url)\n prepared_request = session.prepare_request(request)\n\n proxies: dict[str, str] = kwargs.get(\"proxies\", {})\n stream: bool | None = kwargs.get(\"stream\")\n\n certs = self.get_certs_for_url(url)\n verify: bool | str | Path = kwargs.get(\"verify\") or certs.cert or certs.verify\n cert: str | Path | None = kwargs.get(\"cert\") or certs.client_cert\n\n if cert is not None:\n cert = str(cert)\n\n verify = str(verify) if isinstance(verify, Path) else verify\n\n settings = session.merge_environment_settings(\n prepared_request.url, proxies, stream, verify, cert\n )\n\n # Send the request.\n send_kwargs = {\n \"timeout\": kwargs.get(\"timeout\", REQUESTS_TIMEOUT),\n \"allow_redirects\": kwargs.get(\"allow_redirects\", True),\n }\n send_kwargs.update(settings)\n\n attempt = 0\n resp = None\n\n while True:\n is_last_attempt = attempt >= 5\n try:\n resp = session.send(prepared_request, **send_kwargs)\n except (requests.exceptions.ConnectionError, OSError) as e:\n if is_last_attempt:\n raise e\n else:\n if resp.status_code not in STATUS_FORCELIST or is_last_attempt:\n if raise_for_status:\n resp.raise_for_status()\n return resp\n\n if not is_last_attempt:\n attempt += 1\n delay = self._get_backoff(resp, attempt)\n logger.debug(\"Retrying HTTP request in %s seconds.\", delay)\n time.sleep(delay)\n continue\n\n # this should never really be hit under any sane circumstance\n raise PoetryException(\"Failed HTTP {} request\", method.upper())\n\n def _get_backoff(self, response: requests.Response | None, attempt: int) -> float:\n if response is not None:\n retry_after = response.headers.get(RETRY_AFTER_HEADER, \"\")\n if retry_after:\n return float(retry_after)\n\n return 0.5 * attempt\n\n def get(self, url: str, **kwargs: Any) -> requests.Response:\n return self.request(\"get\", url, **kwargs)\n\n def head(self, url: str, **kwargs: Any) -> requests.Response:\n kwargs.setdefault(\"allow_redirects\", False)\n return self.request(\"head\", url, **kwargs)\n\n def post(self, url: str, **kwargs: Any) -> requests.Response:\n return self.request(\"post\", url, **kwargs)\n\n def _get_credentials_for_repository(\n self, repository: AuthenticatorRepositoryConfig, username: str | None = None\n ) -> HTTPAuthCredential:\n # cache repository credentials by repository url to avoid multiple keyring\n # backend queries when packages are being downloaded from the same source\n key = f\"{repository.url}#username={username or ''}\"\n\n if key not in self._credentials:\n self._credentials[key] = repository.get_http_credentials(\n 
password_manager=self._password_manager, username=username\n )\n\n return self._credentials[key]\n\n def _get_credentials_for_url(\n self, url: str, exact_match: bool = False\n ) -> HTTPAuthCredential:\n repository = self.get_repository_config_for_url(url, exact_match)\n\n credential = (\n self._get_credentials_for_repository(repository=repository)\n if repository is not None\n else HTTPAuthCredential()\n )\n\n if credential.password is None:\n parsed_url = urllib.parse.urlsplit(url)\n netloc = parsed_url.netloc\n credential = self._password_manager.get_credential(\n url, netloc, username=credential.username\n )\n\n return HTTPAuthCredential(\n username=credential.username, password=credential.password\n )\n\n return credential\n\n def get_credentials_for_git_url(self, url: str) -> HTTPAuthCredential:\n parsed_url = urllib.parse.urlsplit(url)\n\n if parsed_url.scheme not in {\"http\", \"https\"}:\n return HTTPAuthCredential()\n\n key = f\"git+{url}\"\n\n if key not in self._credentials:\n self._credentials[key] = self._get_credentials_for_url(url, True)\n\n return self._credentials[key]\n\n def get_credentials_for_url(self, url: str) -> HTTPAuthCredential:\n parsed_url = urllib.parse.urlsplit(url)\n netloc = parsed_url.netloc\n\n if url not in self._credentials:\n if \"@\" not in netloc:\n # no credentials were provided in the url, try finding the\n # best repository configuration\n self._credentials[url] = self._get_credentials_for_url(url)\n else:\n # Split from the right because that's how urllib.parse.urlsplit()\n # behaves if more than one @ is present (which can be checked using\n # the password attribute of urlsplit()'s return value).\n auth, netloc = netloc.rsplit(\"@\", 1)\n # Split from the left because that's how urllib.parse.urlsplit()\n # behaves if more than one : is present (which again can be checked\n # using the password attribute of the return value)\n user, password = auth.split(\":\", 1) if \":\" in auth else (auth, \"\")\n self._credentials[url] = HTTPAuthCredential(\n urllib.parse.unquote(user),\n urllib.parse.unquote(password),\n )\n\n return self._credentials[url]\n\n def get_pypi_token(self, name: str) -> str | None:\n return self._password_manager.get_pypi_token(name)\n\n def get_http_auth(\n self, name: str, username: str | None = None\n ) -> HTTPAuthCredential | None:\n if name == \"pypi\":\n repository = AuthenticatorRepositoryConfig(\n name, \"https://upload.pypi.org/legacy/\"\n )\n else:\n if name not in self.configured_repositories:\n return None\n repository = self.configured_repositories[name]\n\n return self._get_credentials_for_repository(\n repository=repository, username=username\n )\n\n def get_certs_for_repository(self, name: str) -> RepositoryCertificateConfig:\n if name.lower() == \"pypi\" or name not in self.configured_repositories:\n return RepositoryCertificateConfig()\n return self.configured_repositories[name].certs(self._config)\n\n @property\n def configured_repositories(self) -> dict[str, AuthenticatorRepositoryConfig]:\n if self._configured_repositories is None:\n self._configured_repositories = {}\n for repository_name in self._config.get(\"repositories\", []):\n url = self._config.get(f\"repositories.{repository_name}.url\")\n self._configured_repositories[repository_name] = (\n AuthenticatorRepositoryConfig(repository_name, url)\n )\n\n return self._configured_repositories\n\n def reset_credentials_cache(self) -> None:\n self.get_repository_config_for_url.cache_clear()\n self._credentials = {}\n\n def add_repository(self, name: str, 
url: str) -> None:\n self.configured_repositories[name] = AuthenticatorRepositoryConfig(name, url)\n self.reset_credentials_cache()\n\n def get_certs_for_url(self, url: str) -> RepositoryCertificateConfig:\n if url not in self._certs:\n self._certs[url] = self._get_certs_for_url(url)\n return self._certs[url]\n\n def _get_repository_config_for_url(\n self, url: str, exact_match: bool = False\n ) -> AuthenticatorRepositoryConfig | None:\n parsed_url = urllib.parse.urlsplit(url)\n candidates_netloc_only = []\n candidates_path_match = []\n\n for repository in self.configured_repositories.values():\n if exact_match:\n if parsed_url.path == repository.path:\n return repository\n continue\n\n if repository.netloc == parsed_url.netloc:\n if parsed_url.path.startswith(repository.path) or commonprefix(\n (parsed_url.path, repository.path)\n ):\n candidates_path_match.append(repository)\n continue\n candidates_netloc_only.append(repository)\n\n if candidates_path_match:\n candidates = candidates_path_match\n elif candidates_netloc_only:\n candidates = candidates_netloc_only\n else:\n return None\n\n if len(candidates) > 1:\n logger.debug(\n \"Multiple source configurations found for %s - %s\",\n parsed_url.netloc,\n \", \".join(c.name for c in candidates),\n )\n # prefer the more specific path\n candidates.sort(\n key=lambda c: len(commonprefix([parsed_url.path, c.path])), reverse=True\n )\n\n return candidates[0]\n\n def _get_certs_for_url(self, url: str) -> RepositoryCertificateConfig:\n selected = self.get_repository_config_for_url(url)\n if selected:\n return selected.certs(config=self._config)\n return RepositoryCertificateConfig()\n\n\n_authenticator: Authenticator | None = None\n\n\ndef get_default_authenticator() -> Authenticator:\n global _authenticator\n\n if _authenticator is None:\n _authenticator = Authenticator()\n\n return _authenticator\n\n# Path: src/poetry/utils/cache.py\nfrom __future__ import annotations\n\nimport dataclasses\nimport hashlib\nimport json\nimport logging\nimport shutil\nimport threading\nimport time\n\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import Generic\nfrom typing import TypeVar\nfrom typing import overload\n\nfrom poetry.utils._compat import decode\nfrom poetry.utils._compat import encode\nfrom poetry.utils.helpers import get_highest_priority_hash_type\nfrom poetry.utils.wheel import InvalidWheelName\nfrom poetry.utils.wheel import Wheel\n\n\nif TYPE_CHECKING:\n from collections.abc import Callable\n\n from poetry.core.packages.utils.link import Link\n\n from poetry.utils.env import Env\n\n\n# Used by FileCache for items that do not expire.\nMAX_DATE = 9999999999\nT = TypeVar(\"T\")\n\nlogger = logging.getLogger(__name__)\n\n\ndef _expiration(minutes: int) -> int:\n \"\"\"\n Calculates the time in seconds since epoch that occurs 'minutes' from now.\n\n :param minutes: The number of minutes to count forward\n \"\"\"\n return round(time.time()) + minutes * 60\n\n\n_HASHES = {\n \"md5\": (hashlib.md5, 2),\n \"sha1\": (hashlib.sha1, 4),\n \"sha256\": (hashlib.sha256, 8),\n}\n\n\n@dataclasses.dataclass(frozen=True)\nclass CacheItem(Generic[T]):\n \"\"\"\n Stores data and metadata for cache items.\n \"\"\"\n\n data: T\n expires: int | None = None\n\n @property\n def expired(self) -> bool:\n \"\"\"\n Return true if the cache item has exceeded its expiration period.\n \"\"\"\n return self.expires is not None and time.time() >= 
self.expires\n\n\n@dataclasses.dataclass(frozen=True)\nclass FileCache(Generic[T]):\n \"\"\"\n Cachy-compatible minimal file cache. Stores subsequent data in a JSON format.\n\n :param path: The path that the cache starts at.\n :param hash_type: The hash to use for encoding keys/building directories.\n \"\"\"\n\n path: Path\n hash_type: str = \"sha256\"\n\n def __post_init__(self) -> None:\n if self.hash_type not in _HASHES:\n raise ValueError(\n f\"FileCache.hash_type is unknown value: '{self.hash_type}'.\"\n )\n\n def get(self, key: str) -> T | None:\n return self._get_payload(key)\n\n def has(self, key: str) -> bool:\n \"\"\"\n Determine if a file exists and has not expired in the cache.\n :param key: The cache key\n :returns: True if the key exists in the cache\n \"\"\"\n return self.get(key) is not None\n\n def put(self, key: str, value: Any, minutes: int | None = None) -> None:\n \"\"\"\n Store an item in the cache.\n\n :param key: The cache key\n :param value: The cache value\n :param minutes: The lifetime in minutes of the cached value\n \"\"\"\n payload: CacheItem[Any] = CacheItem(\n value, expires=_expiration(minutes) if minutes is not None else None\n )\n path = self._path(key)\n path.parent.mkdir(parents=True, exist_ok=True)\n with path.open(\"wb\") as f:\n f.write(self._serialize(payload))\n\n def forget(self, key: str) -> None:\n \"\"\"\n Remove an item from the cache.\n\n :param key: The cache key\n \"\"\"\n path = self._path(key)\n if path.exists():\n path.unlink()\n\n def flush(self) -> None:\n \"\"\"\n Clear the cache.\n \"\"\"\n shutil.rmtree(self.path)\n\n def remember(\n self, key: str, callback: T | Callable[[], T], minutes: int | None = None\n ) -> T:\n \"\"\"\n Get an item from the cache, or use a default from callback.\n\n :param key: The cache key\n :param callback: Callback function providing default value\n :param minutes: The lifetime in minutes of the cached value\n \"\"\"\n value = self.get(key)\n if value is None:\n value = callback() if callable(callback) else callback\n self.put(key, value, minutes)\n return value\n\n def _get_payload(self, key: str) -> T | None:\n path = self._path(key)\n\n if not path.exists():\n return None\n\n with path.open(\"rb\") as f:\n file_content = f.read()\n\n try:\n payload = self._deserialize(file_content)\n except (json.JSONDecodeError, ValueError):\n self.forget(key)\n logger.warning(\"Corrupt cache file was detected and cleaned up.\")\n return None\n\n if payload.expired:\n self.forget(key)\n return None\n else:\n return payload.data\n\n def _path(self, key: str) -> Path:\n hash_type, parts_count = _HASHES[self.hash_type]\n h = hash_type(encode(key)).hexdigest()\n parts = [h[i : i + 2] for i in range(0, len(h), 2)][:parts_count]\n return Path(self.path, *parts, h)\n\n def _serialize(self, payload: CacheItem[T]) -> bytes:\n expires = payload.expires or MAX_DATE\n data = json.dumps(payload.data)\n return encode(f\"{expires:010d}{data}\")\n\n def _deserialize(self, data_raw: bytes) -> CacheItem[T]:\n data_str = decode(data_raw)\n data = json.loads(data_str[10:])\n expires = int(data_str[:10])\n return CacheItem(data, expires)\n\n\nclass ArtifactCache:\n def __init__(self, *, cache_dir: Path) -> None:\n self._cache_dir = cache_dir\n self._archive_locks: defaultdict[Path, threading.Lock] = defaultdict(\n threading.Lock\n )\n\n def get_cache_directory_for_link(self, link: Link) -> Path:\n key_parts = {\"url\": link.url_without_fragment}\n\n if hash_name := get_highest_priority_hash_type(\n set(link.hashes.keys()), link.filename\n 
):\n key_parts[hash_name] = link.hashes[hash_name]\n\n if link.subdirectory_fragment:\n key_parts[\"subdirectory\"] = link.subdirectory_fragment\n\n return self._get_directory_from_hash(key_parts)\n\n \ndef _get_directory_from_hash(self, key_parts: object) -> Path:\n key = hashlib.sha256(\n json.dumps(\n key_parts, sort_keys=True, separators=(\",\", \":\"), ensure_ascii=True\n ).encode(\"ascii\")\n ).hexdigest()\n\n split_key = [key[:2], key[2:4], key[4:6], key[6:]]\n return self._cache_dir.joinpath(*split_key)\n\n def get_cache_directory_for_git(\n self, url: str, ref: str, subdirectory: str | None\n ) -> Path:\n key_parts = {\"url\": url, \"ref\": ref}\n if subdirectory:\n key_parts[\"subdirectory\"] = subdirectory\n\n return self._get_directory_from_hash(key_parts)\n\n @overload\n def get_cached_archive_for_link(\n self,\n link: Link,\n *,\n strict: bool,\n env: Env | None = ...,\n download_func: Callable[[str, Path], None],\n ) -> Path: ...\n\n @overload\n def get_cached_archive_for_link(\n self,\n link: Link,\n *,\n strict: bool,\n env: Env | None = ...,\n download_func: None = ...,\n ) -> Path | None: ...\n\n def get_cached_archive_for_link(\n self,\n link: Link,\n *,\n strict: bool,\n env: Env | None = None,\n download_func: Callable[[str, Path], None] | None = None,\n ) -> Path | None:\n cache_dir = self.get_cache_directory_for_link(link)\n\n cached_archive = self._get_cached_archive(\n cache_dir, strict=strict, filename=link.filename, env=env\n )\n if cached_archive is None and strict and download_func is not None:\n cached_archive = cache_dir / link.filename\n with self._archive_locks[cached_archive]:\n # Check again if the archive exists (under the lock) to avoid\n # duplicate downloads because it may have already been downloaded\n # by another thread in the meantime\n if not cached_archive.exists():\n cache_dir.mkdir(parents=True, exist_ok=True)\n try:\n download_func(link.url, cached_archive)\n except BaseException:\n cached_archive.unlink(missing_ok=True)\n raise\n\n return cached_archive\n\n def get_cached_archive_for_git(\n self, url: str, reference: str, subdirectory: str | None, env: Env\n ) -> Path | None:\n cache_dir = self.get_cache_directory_for_git(url, reference, subdirectory)\n\n return self._get_cached_archive(cache_dir, strict=False, env=env)\n\n def _get_cached_archive(\n self,\n cache_dir: Path,\n *,\n strict: bool,\n filename: str | None = None,\n env: Env | None = None,\n ) -> Path | None:\n # implication \"not strict -> env must not be None\"\n assert strict or env is not None\n # implication \"strict -> filename must not be None\"\n assert not strict or filename is not None\n\n archives = self._get_cached_archives(cache_dir)\n if not archives:\n return None\n\n candidates: list[tuple[float | None, Path]] = []\n\n for archive in archives:\n if strict:\n # in strict mode return the original cached archive instead of the\n # prioritized archive type.\n if filename == archive.name:\n return archive\n continue\n\n assert env is not None\n\n if archive.suffix != \".whl\":\n candidates.append((float(\"inf\"), archive))\n continue\n\n try:\n wheel = Wheel(archive.name)\n except InvalidWheelName:\n continue\n\n if not wheel.is_supported_by_environment(env):\n continue\n\n candidates.append(\n (wheel.get_minimum_supported_index(env.supported_tags), archive),\n )\n\n if not candidates:\n return None\n\n return min(candidates)[1]\n\n def _get_cached_archives(self, cache_dir: Path) -> list[Path]:\n archive_types = [\"whl\", \"tar.gz\", \"tar.bz2\", \"bz2\", \"zip\"]\n 
paths: list[Path] = []\n for archive_type in archive_types:\n paths += cache_dir.glob(f\"*.{archive_type}\")\n\n return paths\n\n# Path: src/poetry/utils/constants.py\nfrom __future__ import annotations\n\nimport os\n\n\n# Timeout for HTTP requests using the requests library.\nREQUESTS_TIMEOUT = int(os.getenv(\"POETRY_REQUESTS_TIMEOUT\", 15))\n\nRETRY_AFTER_HEADER = \"retry-after\"\n\n# Server response codes to retry requests on.\nSTATUS_FORCELIST = [429, 500, 501, 502, 503, 504]\n\n# Path: src/poetry/utils/dependency_specification.py\nfrom __future__ import annotations\n\nimport contextlib\nimport os\nimport re\nimport urllib.parse\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Dict\nfrom typing import List\nfrom typing import TypeVar\nfrom typing import Union\nfrom typing import cast\n\nfrom poetry.core.packages.dependency import Dependency\nfrom tomlkit.items import InlineTable\n\nfrom poetry.packages.direct_origin import DirectOrigin\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.vcs_dependency import VCSDependency\n\n from poetry.utils.cache import ArtifactCache\n from poetry.utils.env import Env\n\n\nDependencySpec = Dict[str, Union[str, bool, Dict[str, Union[str, bool]], List[str]]]\nBaseSpec = TypeVar(\"BaseSpec\", DependencySpec, InlineTable)\n\nGIT_URL_SCHEMES = {\"git+http\", \"git+https\", \"git+ssh\"}\n\n\ndef dependency_to_specification(\n dependency: Dependency, specification: BaseSpec\n) -> BaseSpec:\n if dependency.is_vcs():\n dependency = cast(\"VCSDependency\", dependency)\n assert dependency.source_url is not None\n specification[dependency.vcs] = dependency.source_url\n if dependency.reference:\n specification[\"rev\"] = dependency.reference\n elif dependency.is_file() or dependency.is_directory():\n assert dependency.source_url is not None\n specification[\"path\"] = dependency.source_url\n elif dependency.is_url():\n assert dependency.source_url is not None\n specification[\"url\"] = dependency.source_url\n elif dependency.pretty_constraint != \"*\" and not dependency.constraint.is_empty():\n specification[\"version\"] = dependency.pretty_constraint\n\n if not dependency.marker.is_any():\n specification[\"markers\"] = str(dependency.marker)\n\n if dependency.extras:\n specification[\"extras\"] = sorted(dependency.extras)\n\n return specification\n\n\nclass RequirementsParser:\n def __init__(\n self,\n *,\n artifact_cache: ArtifactCache,\n env: Env | None = None,\n cwd: Path | None = None,\n ) -> None:\n self._direct_origin = DirectOrigin(artifact_cache)\n self._env = env\n self._cwd = cwd or Path.cwd()\n\n def parse(self, requirement: str) -> DependencySpec:\n requirement = requirement.strip()\n\n specification = self._parse_pep508(requirement)\n\n if specification is not None:\n return specification\n\n extras = []\n extras_m = re.search(r\"\\[([\\w\\d,-_ ]+)\\]$\", requirement)\n if extras_m:\n extras = [e.strip() for e in extras_m.group(1).split(\",\")]\n requirement, _ = requirement.split(\"[\")\n\n specification = (\n self._parse_url(requirement)\n or self._parse_path(requirement)\n or self._parse_simple(requirement)\n )\n\n if specification:\n if extras:\n specification.setdefault(\"extras\", extras)\n return specification\n\n raise ValueError(f\"Invalid dependency specification: {requirement}\")\n\n def _parse_pep508(self, requirement: str) -> DependencySpec | None:\n if \" ; \" not in requirement and re.search(r\"@[\\^~!=<>\\d]\", requirement):\n # this is of the form package@, do not attempt to parse it\n return 
None\n\n with contextlib.suppress(ValueError):\n dependency = Dependency.create_from_pep_508(requirement)\n specification: DependencySpec = {}\n specification = dependency_to_specification(dependency, specification)\n\n if specification:\n specification[\"name\"] = dependency.name\n return specification\n\n return None\n\n def _parse_git_url(self, requirement: str) -> DependencySpec | None:\n from poetry.core.vcs.git import Git\n from poetry.core.vcs.git import ParsedUrl\n\n parsed = ParsedUrl.parse(requirement)\n url = Git.normalize_url(requirement)\n\n pair = {\"name\": parsed.name, \"git\": url.url}\n\n if parsed.rev:\n pair[\"rev\"] = url.revision\n\n if parsed.subdirectory:\n pair[\"subdirectory\"] = parsed.subdirectory\n\n source_root = self._env.path.joinpath(\"src\") if self._env else None\n package = self._direct_origin.get_package_from_vcs(\n \"git\",\n url=url.url,\n rev=pair.get(\"rev\"),\n subdirectory=parsed.subdirectory,\n source_root=source_root,\n )\n pair[\"name\"] = package.name\n return pair\n\n def _parse_url(self, requirement: str) -> DependencySpec | None:\n url_parsed = urllib.parse.urlparse(requirement)\n if not (url_parsed.scheme and url_parsed.netloc):\n return None\n\n if url_parsed.scheme in GIT_URL_SCHEMES:\n return self._parse_git_url(requirement)\n\n if url_parsed.scheme in [\"http\", \"https\"]:\n package = self._direct_origin.get_package_from_url(requirement)\n assert package.source_url is not None\n return {\"name\": package.name, \"url\": package.source_url}\n\n return None\n\n def _parse_path(self, requirement: str) -> DependencySpec | None:\n if (os.path.sep in requirement or \"/\" in requirement) and (\n self._cwd.joinpath(requirement).exists()\n or Path(requirement).expanduser().exists()\n and Path(requirement).expanduser().is_absolute()\n ):\n path = Path(requirement).expanduser()\n is_absolute = path.is_absolute()\n\n if not path.is_absolute():\n path = self._cwd.joinpath(requirement)\n\n if path.is_file():\n package = self._direct_origin.get_package_from_file(path.resolve())\n else:\n package = self._direct_origin.get_package_from_directory(path.resolve())\n\n return {\n \"name\": package.name,\n \"path\": (\n path.relative_to(self._cwd).as_posix()\n if not is_absolute\n else path.as_posix()\n ),\n }\n\n return None\n\n def _parse_simple(\n self,\n requirement: str,\n ) -> DependencySpec | None:\n extras: list[str] = []\n pair = re.sub(\n \"^([^@=: ]+)(?:@|==|(?~!])=|:| )(.*)$\", \"\\\\1 \\\\2\", requirement\n )\n pair = pair.strip()\n\n require: DependencySpec = {}\n\n if \" \" in pair:\n name, version = pair.split(\" \", 1)\n extras_m = re.search(r\"\\[([\\w\\d,-_]+)\\]$\", name)\n if extras_m:\n extras = [e.strip() for e in extras_m.group(1).split(\",\")]\n name, _ = name.split(\"[\")\n\n require[\"name\"] = name\n if version != \"latest\":\n require[\"version\"] = version\n else:\n m = re.match(\n r\"^([^><=!: ]+)((?:>=|<=|>|<|!=|~=|~|\\^).*)$\", requirement.strip()\n )\n if m:\n name, constraint = m.group(1), m.group(2)\n extras_m = re.search(r\"\\[([\\w\\d,-_]+)\\]$\", name)\n if extras_m:\n extras = [e.strip() for e in extras_m.group(1).split(\",\")]\n name, _ = name.split(\"[\")\n\n require[\"name\"] = name\n require[\"version\"] = constraint\n else:\n extras_m = re.search(r\"\\[([\\w\\d,-_]+)\\]$\", pair)\n if extras_m:\n extras = [e.strip() for e in extras_m.group(1).split(\",\")]\n pair, _ = pair.split(\"[\")\n\n require[\"name\"] = pair\n\n if extras:\n require[\"extras\"] = extras\n\n return require\n\n# Path: 
src/poetry/utils/extras.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\n\nif TYPE_CHECKING:\n from collections.abc import Collection\n from collections.abc import Iterable\n from collections.abc import Mapping\n\n from packaging.utils import NormalizedName\n from poetry.core.packages.package import Package\n\n\ndef get_extra_package_names(\n packages: Iterable[Package],\n extras: Mapping[NormalizedName, Iterable[NormalizedName]],\n extra_names: Collection[NormalizedName],\n) -> set[NormalizedName]:\n \"\"\"\n Returns all package names required by the given extras.\n\n :param packages: A collection of packages, such as from Repository.packages\n :param extras: A mapping of `extras` names to lists of package names, as defined\n in the `extras` section of `poetry.lock`.\n :param extra_names: A list of strings specifying names of extra groups to resolve.\n \"\"\"\n from packaging.utils import canonicalize_name\n\n if not extra_names:\n return set()\n\n # lookup for packages by name, faster than looping over packages repeatedly\n packages_by_name = {package.name: package for package in packages}\n\n # Depth-first search, with our entry points being the packages directly required by\n # extras.\n seen_package_names = set()\n stack = [\n canonicalize_name(extra_package_name)\n for extra_name in extra_names\n for extra_package_name in extras.get(extra_name, ())\n ]\n\n while stack:\n package_name = stack.pop()\n\n # We expect to find all packages, but can just carry on if we don't.\n package = packages_by_name.get(package_name)\n if package is None or package.name in seen_package_names:\n continue\n\n seen_package_names.add(package.name)\n\n stack += [dependency.name for dependency in package.requires]\n\n return seen_package_names\n\n...\n", "needle_token_start": 11917, "needle_token_end": 12042, "code_context_ntokens": 15975, "output": ["```python\ndef _get_directory_from_hash(self, key_parts: object) -> Path:\n key = hashlib.sha256(\n json.dumps(\n key_parts, sort_keys=True, separators=(\",\", \":\"), ensure_ascii=True\n ).encode(\"ascii\")\n ).hexdigest()\n\n split_key = [key[:2], key[2:4], key[4:6], key[6:]]\n return self._cache_dir.joinpath(*split_key)\n```"]} +{"repo": "python-poetry/poetry", "name": "_validate_readme", "language": "python", "path": "src/poetry/console/commands/check.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: To verify the existence of README files referenced in a project configuration file.\n2. **Input**: A string or list of strings representing README filenames, and a path object pointing to the project configuration file.\n3. **Output**: A list of error messages for each README file that is declared but does not exist in the specified location.\n4. **Procedure**: The function checks if each README file listed exists in the directory of the project configuration file. 
If a file does not exist, an error message is generated and added to the list of errors, which is then returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/utils/password_manager.py\nfrom __future__ import annotations\n\nimport dataclasses\nimport functools\nimport logging\n\nfrom contextlib import suppress\nfrom typing import TYPE_CHECKING\n\n\nif TYPE_CHECKING:\n from keyring.backend import KeyringBackend\n\n from poetry.config.config import Config\n\nlogger = logging.getLogger(__name__)\n\n\nclass PasswordManagerError(Exception):\n pass\n\n\nclass PoetryKeyringError(Exception):\n pass\n\n\n@dataclasses.dataclass\nclass HTTPAuthCredential:\n username: str | None = dataclasses.field(default=None)\n password: str | None = dataclasses.field(default=None)\n\n\nclass PoetryKeyring:\n def __init__(self, namespace: str) -> None:\n self._namespace = namespace\n\n def get_credential(\n self, *names: str, username: str | None = None\n ) -> HTTPAuthCredential:\n import keyring\n\n from keyring.errors import KeyringError\n from keyring.errors import KeyringLocked\n\n for name in names:\n credential = None\n try:\n credential = keyring.get_credential(name, username)\n except KeyringLocked:\n logger.debug(\"Keyring %s is locked\", name)\n except (KeyringError, RuntimeError):\n logger.debug(\"Accessing keyring %s failed\", name, exc_info=True)\n\n if credential:\n return HTTPAuthCredential(\n username=credential.username, password=credential.password\n )\n\n return HTTPAuthCredential(username=username, password=None)\n\n def get_password(self, name: str, username: str) -> str | None:\n import keyring\n import keyring.errors\n\n name = self.get_entry_name(name)\n\n try:\n return keyring.get_password(name, username)\n except (RuntimeError, keyring.errors.KeyringError):\n raise PoetryKeyringError(\n f\"Unable to retrieve the password for {name} from the key ring\"\n )\n\n def set_password(self, name: str, username: str, password: str) -> None:\n import keyring\n import keyring.errors\n\n name = self.get_entry_name(name)\n\n try:\n keyring.set_password(name, username, password)\n except (RuntimeError, keyring.errors.KeyringError) as e:\n raise PoetryKeyringError(\n f\"Unable to store the password for {name} in the key ring: {e}\"\n )\n\n def delete_password(self, name: str, username: str) -> None:\n import keyring.errors\n\n name = self.get_entry_name(name)\n\n try:\n keyring.delete_password(name, username)\n except (RuntimeError, keyring.errors.KeyringError):\n raise PoetryKeyringError(\n f\"Unable to delete the password for {name} from the key ring\"\n )\n\n def get_entry_name(self, name: str) -> str:\n return f\"{self._namespace}-{name}\"\n\n @classmethod\n def is_available(cls) -> bool:\n logger.debug(\"Checking if keyring is available\")\n try:\n import keyring\n import keyring.backend\n except ImportError as e:\n logger.debug(\"An error occurred while importing keyring: %s\", e)\n return False\n\n def backend_name(backend: KeyringBackend) -> str:\n name: str = backend.name\n return name.split(\" \")[0]\n\n def backend_is_valid(backend: KeyringBackend) -> bool:\n name = backend_name(backend)\n if name in (\"chainer\", \"fail\", \"null\"):\n logger.debug(f\"Backend {backend.name!r} is not suitable\")\n return False\n elif \"plaintext\" in 
backend.name.lower():\n logger.debug(f\"Not using plaintext keyring backend {backend.name!r}\")\n return False\n\n return True\n\n backend = keyring.get_keyring()\n if backend_name(backend) == \"chainer\":\n backends = keyring.backend.get_all_keyring()\n valid_backend = next((b for b in backends if backend_is_valid(b)), None)\n else:\n valid_backend = backend if backend_is_valid(backend) else None\n\n if valid_backend is None:\n logger.debug(\"No valid keyring backend was found\")\n return False\n else:\n logger.debug(f\"Using keyring backend {backend.name!r}\")\n return True\n\n\nclass PasswordManager:\n def __init__(self, config: Config) -> None:\n self._config = config\n\n @functools.cached_property\n def use_keyring(self) -> bool:\n return self._config.get(\"keyring.enabled\") and PoetryKeyring.is_available()\n\n @functools.cached_property\n def keyring(self) -> PoetryKeyring:\n if not self.use_keyring:\n raise PoetryKeyringError(\n \"Access to keyring was requested, but it is not available\"\n )\n\n return PoetryKeyring(\"poetry-repository\")\n\n @staticmethod\n def warn_plaintext_credentials_stored() -> None:\n logger.warning(\"Using a plaintext file to store credentials\")\n\n def set_pypi_token(self, repo_name: str, token: str) -> None:\n if not self.use_keyring:\n self.warn_plaintext_credentials_stored()\n self._config.auth_config_source.add_property(\n f\"pypi-token.{repo_name}\", token\n )\n else:\n self.keyring.set_password(repo_name, \"__token__\", token)\n\n def get_pypi_token(self, repo_name: str) -> str | None:\n \"\"\"Get PyPi token.\n\n First checks the environment variables for a token,\n then the configured username/password and the\n available keyring.\n\n :param repo_name: Name of repository.\n :return: Returns a token as a string if found, otherwise None.\n \"\"\"\n token: str | None = self._config.get(f\"pypi-token.{repo_name}\")\n if token:\n return token\n\n if self.use_keyring:\n return self.keyring.get_password(repo_name, \"__token__\")\n else:\n return None\n\n def delete_pypi_token(self, repo_name: str) -> None:\n if not self.use_keyring:\n return self._config.auth_config_source.remove_property(\n f\"pypi-token.{repo_name}\"\n )\n\n self.keyring.delete_password(repo_name, \"__token__\")\n\n def get_http_auth(self, repo_name: str) -> dict[str, str | None] | None:\n username = self._config.get(f\"http-basic.{repo_name}.username\")\n password = self._config.get(f\"http-basic.{repo_name}.password\")\n if not username and not password:\n return None\n\n if not password:\n if self.use_keyring:\n password = self.keyring.get_password(repo_name, username)\n else:\n return None\n\n return {\n \"username\": username,\n \"password\": password,\n }\n\n def set_http_password(self, repo_name: str, username: str, password: str) -> None:\n auth = {\"username\": username}\n\n if not self.use_keyring:\n self.warn_plaintext_credentials_stored()\n...\n# Path: src/poetry/utils/patterns.py\nfrom __future__ import annotations\n\nimport re\n\n\nwheel_file_re = re.compile(\n r\"^(?P(?P.+?)-(?P\\d.*?))\"\n r\"(-(?P\\d.*?))?\"\n r\"-(?P.+?)\"\n r\"-(?P.+?)\"\n r\"-(?P.+?)\"\n r\"\\.whl|\\.dist-info$\",\n re.VERBOSE,\n)\n\nsdist_file_re = re.compile(\n r\"^(?P(?P.+?)-(?P\\d.*?))\"\n r\"(\\.sdist)?\\.(?P(zip|tar(\\.(gz|bz2|xz|Z))?))$\"\n)\n\n# Path: src/poetry/utils/pip.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.exceptions import PoetryException\nfrom poetry.utils.env import EnvCommandError\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n\n 
from poetry.utils.env import Env\n\n\ndef pip_install(\n path: Path,\n environment: Env,\n editable: bool = False,\n deps: bool = False,\n upgrade: bool = False,\n) -> str:\n is_wheel = path.suffix == \".whl\"\n\n # We disable version check here as we are already pinning to version available in\n # either the virtual environment or the virtualenv package embedded wheel. Version\n # checks are a wasteful network call that adds a lot of wait time when installing a\n # lot of packages.\n args = [\n \"install\",\n \"--disable-pip-version-check\",\n \"--isolated\",\n \"--no-input\",\n \"--prefix\",\n str(environment.path),\n ]\n\n if not is_wheel and not editable:\n args.insert(1, \"--use-pep517\")\n\n if upgrade:\n args.append(\"--upgrade\")\n\n if not deps:\n args.append(\"--no-deps\")\n\n if editable:\n if not path.is_dir():\n raise PoetryException(\n \"Cannot install non directory dependencies in editable mode\"\n )\n args.append(\"-e\")\n\n args.append(str(path))\n\n try:\n return environment.run_pip(*args)\n except EnvCommandError as e:\n raise PoetryException(f\"Failed to install {path}\") from e\n\n# Path: src/poetry/utils/setup_reader.py\nfrom __future__ import annotations\n\nimport ast\n\nfrom configparser import ConfigParser\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom poetry.core.constraints.version import Version\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n\n\nclass SetupReaderError(Exception):\n pass\n\n\nclass SetupReader:\n \"\"\"\n Class that reads a setup.py file without executing it.\n \"\"\"\n\n DEFAULT: ClassVar[dict[str, Any]] = {\n \"name\": None,\n \"version\": None,\n \"description\": None,\n \"install_requires\": [],\n \"extras_require\": {},\n \"python_requires\": None,\n }\n\n FILES: ClassVar[list[str]] = [\"setup.py\", \"setup.cfg\"]\n\n @classmethod\n def read_from_directory(cls, directory: Path) -> dict[str, Any]:\n result = cls.DEFAULT.copy()\n for filename in cls.FILES:\n filepath = directory / filename\n if not filepath.exists():\n continue\n\n read_file_func = getattr(cls(), \"read_\" + filename.replace(\".\", \"_\"))\n new_result = read_file_func(filepath)\n\n for key in result:\n if new_result[key]:\n result[key] = new_result[key]\n\n return result\n\n def read_setup_py(self, filepath: Path) -> dict[str, Any]:\n with filepath.open(encoding=\"utf-8\") as f:\n content = f.read()\n\n result: dict[str, Any] = {}\n\n body = ast.parse(content).body\n\n setup_call = self._find_setup_call(body)\n if setup_call is None:\n return self.DEFAULT\n\n # Inspecting keyword arguments\n call, body = setup_call\n result[\"name\"] = self._find_single_string(call, body, \"name\")\n result[\"version\"] = self._find_single_string(call, body, \"version\")\n result[\"description\"] = self._find_single_string(call, body, \"description\")\n result[\"install_requires\"] = self._find_install_requires(call, body)\n result[\"extras_require\"] = self._find_extras_require(call, body)\n result[\"python_requires\"] = self._find_single_string(\n call, body, \"python_requires\"\n )\n\n return result\n\n def read_setup_cfg(self, filepath: Path) -> dict[str, Any]:\n parser = ConfigParser()\n\n parser.read(str(filepath))\n\n name = None\n version = None\n description = None\n if parser.has_option(\"metadata\", \"name\"):\n name = parser.get(\"metadata\", \"name\")\n\n if parser.has_option(\"metadata\", \"version\"):\n version = Version.parse(parser.get(\"metadata\", \"version\")).text\n\n if parser.has_option(\"metadata\", 
\"description\"):\n description = parser.get(\"metadata\", \"description\")\n\n install_requires = []\n extras_require: dict[str, list[str]] = {}\n python_requires = None\n if parser.has_section(\"options\"):\n if parser.has_option(\"options\", \"install_requires\"):\n for dep in parser.get(\"options\", \"install_requires\").split(\"\\n\"):\n dep = dep.strip()\n if not dep:\n continue\n\n install_requires.append(dep)\n\n if parser.has_option(\"options\", \"python_requires\"):\n python_requires = parser.get(\"options\", \"python_requires\")\n\n if parser.has_section(\"options.extras_require\"):\n for group in parser.options(\"options.extras_require\"):\n extras_require[group] = []\n deps = parser.get(\"options.extras_require\", group)\n for dep in deps.split(\"\\n\"):\n dep = dep.strip()\n if not dep:\n continue\n\n extras_require[group].append(dep)\n\n return {\n \"name\": name,\n \"version\": version,\n \"description\": description,\n \"install_requires\": install_requires,\n \"extras_require\": extras_require,\n \"python_requires\": python_requires,\n }\n\n def _find_setup_call(\n self, elements: list[ast.stmt]\n ) -> tuple[ast.Call, list[ast.stmt]] | None:\n funcdefs: list[ast.stmt] = []\n for i, element in enumerate(elements):\n if isinstance(element, ast.If) and i == len(elements) - 1:\n # Checking if the last element is an if statement\n # and if it is 'if __name__ == \"__main__\"' which\n # could contain the call to setup()\n test = element.test\n if not isinstance(test, ast.Compare):\n continue\n\n left = test.left\n if not isinstance(left, ast.Name):\n continue\n\n if left.id != \"__name__\":\n continue\n\n setup_call = self._find_sub_setup_call([element])\n if setup_call is None:\n continue\n\n call, body = setup_call\n return call, body + elements\n\n if not isinstance(element, ast.Expr):\n if isinstance(element, ast.FunctionDef):\n funcdefs.append(element)\n\n continue\n\n value = element.value\n if not isinstance(value, ast.Call):\n continue\n\n func = value.func\n if not (isinstance(func, ast.Name) and func.id == \"setup\") and not (\n isinstance(func, ast.Attribute)\n and getattr(func.value, \"id\", None) == \"setuptools\"\n and func.attr == \"setup\"\n ):\n continue\n\n return value, elements\n\n # Nothing, we inspect the function definitions\n return self._find_sub_setup_call(funcdefs)\n\n def _find_sub_setup_call(\n self, elements: list[ast.stmt]\n ) -> tuple[ast.Call, list[ast.stmt]] | None:\n for element in elements:\n if not isinstance(element, (ast.FunctionDef, ast.If)):\n continue\n\n setup_call = self._find_setup_call(element.body)\n if setup_call is not None:\n sub_call, body = setup_call\n\n body = elements + body\n\n return sub_call, body\n\n return None\n\n def _find_install_requires(self, call: ast.Call, body: list[ast.stmt]) -> list[str]:\n value = self._find_in_call(call, \"install_requires\")\n if value is None:\n # Trying to find in kwargs\n kwargs = self._find_call_kwargs(call)\n\n if kwargs is None or not isinstance(kwargs, ast.Name):\n return []\n\n variable = self._find_variable_in_body(body, kwargs.id)\n\n if isinstance(variable, ast.Dict):\n value = self._find_in_dict(variable, \"install_requires\")\n\n elif (\n isinstance(variable, ast.Call)\n and isinstance(variable.func, ast.Name)\n and variable.func.id == \"dict\"\n ):\n value = self._find_in_call(variable, \"install_requires\")\n\n else:\n raise SetupReaderError(f\"Cannot handle variable {variable}\")\n\n if value is None:\n return []\n\n if isinstance(value, ast.Name):\n value = 
self._find_variable_in_body(body, value.id)\n\n if isinstance(value, ast.Constant) and value.value is None:\n return []\n\n if isinstance(value, ast.List):\n return string_list_values(value)\n\n raise SetupReaderError(f\"Cannot handle value of type {type(value)}\")\n\n def _find_extras_require(\n self, call: ast.Call, body: list[ast.stmt]\n ) -> dict[str, list[str]]:\n value = self._find_in_call(call, \"extras_require\")\n if value is None:\n # Trying to find in kwargs\n kwargs = self._find_call_kwargs(call)\n\n if kwargs is None or not isinstance(kwargs, ast.Name):\n return {}\n\n variable = self._find_variable_in_body(body, kwargs.id)\n if isinstance(variable, ast.Dict):\n value = self._find_in_dict(variable, \"extras_require\")\n\n elif (\n isinstance(variable, ast.Call)\n and isinstance(variable.func, ast.Name)\n and variable.func.id == \"dict\"\n ):\n value = self._find_in_call(variable, \"extras_require\")\n\n else:\n raise SetupReaderError(f\"Cannot handle variable {variable}\")\n\n if value is None:\n return {}\n\n if isinstance(value, ast.Name):\n value = self._find_variable_in_body(body, value.id)\n\n if isinstance(value, ast.Constant) and value.value is None:\n return {}\n\n if isinstance(value, ast.Dict):\n extras_require: dict[str, list[str]] = {}\n val: ast.expr | None\n for key, val in zip(value.keys, value.values):\n if not isinstance(key, ast.Constant) or not isinstance(key.value, str):\n raise SetupReaderError(f\"Cannot handle key {key}\")\n\n if isinstance(val, ast.Name):\n val = self._find_variable_in_body(body, val.id)\n\n if not isinstance(val, ast.List):\n raise SetupReaderError(f\"Cannot handle value of type {type(val)}\")\n\n extras_require[key.value] = string_list_values(val)\n\n return extras_require\n\n raise SetupReaderError(f\"Cannot handle value of type {type(value)}\")\n\n def _find_single_string(\n self, call: ast.Call, body: list[ast.stmt], name: str\n ) -> str | None:\n value = self._find_in_call(call, name)\n if value is None:\n # Trying to find in kwargs\n kwargs = self._find_call_kwargs(call)\n\n if kwargs is None or not isinstance(kwargs, ast.Name):\n return None\n\n variable = self._find_variable_in_body(body, kwargs.id)\n if not isinstance(variable, (ast.Dict, ast.Call)):\n return None\n\n if isinstance(variable, ast.Call):\n if not isinstance(variable.func, ast.Name):\n return None\n\n if variable.func.id != \"dict\":\n return None\n\n value = self._find_in_call(variable, name)\n else:\n value = self._find_in_dict(variable, name)\n\n if value is None:\n return None\n\n if isinstance(value, ast.Constant) and isinstance(value.value, str):\n return value.value\n elif isinstance(value, ast.Name):\n variable = self._find_variable_in_body(body, value.id)\n\n if (\n variable is not None\n and isinstance(variable, ast.Constant)\n and isinstance(variable.value, str)\n ):\n return variable.value\n\n return None\n\n def _find_in_call(self, call: ast.Call, name: str) -> Any | None:\n for keyword in call.keywords:\n if keyword.arg == name:\n return keyword.value\n return None\n\n def _find_call_kwargs(self, call: ast.Call) -> Any | None:\n kwargs = None\n for keyword in call.keywords:\n if keyword.arg is None:\n kwargs = keyword.value\n\n return kwargs\n\n def _find_variable_in_body(\n self, body: list[ast.stmt], name: str\n ) -> ast.expr | None:\n for elem in body:\n if not isinstance(elem, ast.Assign):\n continue\n\n for target in elem.targets:\n if not isinstance(target, ast.Name):\n continue\n\n if target.id == name:\n return elem.value\n\n return None\n\n 
def _find_in_dict(self, dict_: ast.Dict, name: str) -> ast.expr | None:\n for key, val in zip(dict_.keys, dict_.values):\n if (\n isinstance(key, ast.Constant)\n and isinstance(key.value, str)\n and key.value == name\n ):\n return val\n\n return None\n\n\ndef string_list_values(value: ast.List) -> list[str]:\n strings = []\n for element in value.elts:\n if isinstance(element, ast.Constant) and isinstance(element.value, str):\n strings.append(element.value)\n\n else:\n raise SetupReaderError(\"Found non-string element in list\")\n\n return strings\n\n# Path: src/poetry/utils/shell.py\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport signal\nimport subprocess\nimport sys\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nimport pexpect\n\nfrom shellingham import ShellDetectionFailure\nfrom shellingham import detect_shell\n\nfrom poetry.utils._compat import WINDOWS\n\n\nif TYPE_CHECKING:\n from poetry.utils.env import VirtualEnv\n\n\nclass Shell:\n \"\"\"\n Represents the current shell.\n \"\"\"\n\n _shell = None\n\n def __init__(self, name: str, path: str) -> None:\n self._name = name\n self._path = path\n\n @property\n def name(self) -> str:\n return self._name\n\n @property\n def path(self) -> str:\n return self._path\n\n @classmethod\n def get(cls) -> Shell:\n \"\"\"\n Retrieve the current shell.\n \"\"\"\n if cls._shell is not None:\n return cls._shell\n\n try:\n name, path = detect_shell(os.getpid())\n except (RuntimeError, ShellDetectionFailure):\n shell = None\n\n if os.name == \"posix\":\n shell = os.environ.get(\"SHELL\")\n elif os.name == \"nt\":\n shell = os.environ.get(\"COMSPEC\")\n\n if not shell:\n raise RuntimeError(\"Unable to detect the current shell.\")\n\n name, path = Path(shell).stem, shell\n\n cls._shell = cls(name, path)\n\n return cls._shell\n\n def activate(self, env: VirtualEnv) -> int | None:\n activate_script = self._get_activate_script()\n if WINDOWS:\n bin_path = env.path / \"Scripts\"\n # Python innstalled via msys2 on Windows might produce a POSIX-like venv\n # See https://github.com/python-poetry/poetry/issues/8638\n bin_dir = \"Scripts\" if bin_path.exists() else \"bin\"\n else:\n bin_dir = \"bin\"\n activate_path = env.path / bin_dir / activate_script\n\n # mypy requires using sys.platform instead of WINDOWS constant\n # in if statements to properly type check on Windows\n if sys.platform == \"win32\":\n args = None\n if self._name in (\"powershell\", \"pwsh\"):\n args = [\"-NoExit\", \"-File\", str(activate_path)]\n elif self._name == \"cmd\":\n # /K will execute the bat file and\n # keep the cmd process from terminating\n args = [\"/K\", str(activate_path)]\n\n if args:\n completed_proc = subprocess.run([self.path, *args])\n return completed_proc.returncode\n else:\n # If no args are set, execute the shell within the venv\n # This activates it, but there could be some features missing:\n # deactivate command might not work\n # shell prompt will not be modified.\n return env.execute(self._path)\n\n import shlex\n\n terminal = shutil.get_terminal_size()\n cmd = f\"{self._get_source_command()} {shlex.quote(str(activate_path))}\"\n\n with env.temp_environ():\n if self._name == \"nu\":\n args = [\"-e\", cmd]\n elif self._name == \"fish\":\n args = [\"-i\", \"--init-command\", cmd]\n else:\n args = [\"-i\"]\n\n c = pexpect.spawn(\n self._path, args, dimensions=(terminal.lines, terminal.columns)\n )\n\n if self._name in [\"zsh\"]:\n c.setecho(False)\n\n if self._name == \"zsh\":\n # Under ZSH the 
source command should be invoked in zsh's bash emulator\n quoted_activate_path = shlex.quote(str(activate_path))\n c.sendline(f\"emulate bash -c {shlex.quote(f'. {quoted_activate_path}')}\")\n elif self._name == \"xonsh\":\n c.sendline(f\"vox activate {shlex.quote(str(env.path))}\")\n elif self._name in [\"nu\", \"fish\"]:\n # If this is nu or fish, we don't want to send the activation command to the\n # command line since we already ran it via the shell's invocation.\n pass\n else:\n c.sendline(cmd)\n\n def resize(sig: Any, data: Any) -> None:\n terminal = shutil.get_terminal_size()\n c.setwinsize(terminal.lines, terminal.columns)\n\n signal.signal(signal.SIGWINCH, resize)\n\n # Interact with the new shell.\n c.interact(escape_character=None)\n c.close()\n\n sys.exit(c.exitstatus)\n\n def _get_activate_script(self) -> str:\n if self._name == \"fish\":\n suffix = \".fish\"\n elif self._name in (\"csh\", \"tcsh\"):\n suffix = \".csh\"\n elif self._name in (\"powershell\", \"pwsh\"):\n suffix = \".ps1\"\n elif self._name == \"cmd\":\n suffix = \".bat\"\n elif self._name == \"nu\":\n suffix = \".nu\"\n else:\n suffix = \"\"\n\n return \"activate\" + suffix\n\n def _get_source_command(self) -> str:\n if self._name in (\"fish\", \"csh\", \"tcsh\"):\n return \"source\"\n elif self._name == \"nu\":\n return \"overlay use\"\n return \".\"\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}(\"{self._name}\", \"{self._path}\")'\n\n# Path: src/poetry/utils/source.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\n\nif TYPE_CHECKING:\n from tomlkit.items import Table\n\n from poetry.config.source import Source\n\n\ndef source_to_table(source: Source) -> Table:\n from tomlkit import nl\n from tomlkit import table\n\n source_table: Table = table()\n for key, value in source.to_dict().items():\n source_table.add(key, value)\n source_table.add(nl())\n return source_table\n\n# Path: src/poetry/utils/wheel.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import TYPE_CHECKING\n\nfrom packaging.tags import Tag\n\nfrom poetry.utils.patterns import wheel_file_re\n\n\nif TYPE_CHECKING:\n from poetry.utils.env import Env\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass InvalidWheelName(Exception):\n pass\n\n\nclass Wheel:\n def __init__(self, filename: str) -> None:\n wheel_info = wheel_file_re.match(filename)\n if not wheel_info:\n raise InvalidWheelName(f\"{filename} is not a valid wheel filename.\")\n\n self.filename = filename\n self.name = wheel_info.group(\"name\").replace(\"_\", \"-\")\n self.version = wheel_info.group(\"ver\").replace(\"_\", \"-\")\n self.build_tag = wheel_info.group(\"build\")\n self.pyversions = wheel_info.group(\"pyver\").split(\".\")\n self.abis = wheel_info.group(\"abi\").split(\".\")\n self.plats = wheel_info.group(\"plat\").split(\".\")\n\n self.tags = {\n Tag(x, y, z) for x in self.pyversions for y in self.abis for z in self.plats\n }\n\n def get_minimum_supported_index(self, tags: list[Tag]) -> int | None:\n indexes = [tags.index(t) for t in self.tags if t in tags]\n\n return min(indexes) if indexes else None\n\n def is_supported_by_environment(self, env: Env) -> bool:\n return bool(set(env.supported_tags).intersection(self.tags))\n\n# Path: src/poetry/vcs/__init__.py\n\n# Path: src/poetry/version/__init__.py\n\n# Path: src/poetry/version/version_selector.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\n from 
poetry.repositories import RepositoryPool\n\n\nclass VersionSelector:\n def __init__(self, pool: RepositoryPool) -> None:\n self._pool = pool\n\n def find_best_candidate(\n self,\n package_name: str,\n target_package_version: str | None = None,\n allow_prereleases: bool = False,\n source: str | None = None,\n ) -> Package | None:\n \"\"\"\n Given a package name and optional version,\n returns the latest Package that matches\n \"\"\"\n from poetry.factory import Factory\n\n dependency = Factory.create_dependency(\n package_name,\n {\n \"version\": target_package_version or \"*\",\n \"allow-prereleases\": allow_prereleases,\n \"source\": source,\n },\n )\n candidates = self._pool.find_packages(dependency)\n only_prereleases = all(c.version.is_unstable() for c in candidates)\n\n if not candidates:\n return None\n\n package = None\n for candidate in candidates:\n if (\n candidate.is_prerelease()\n and not dependency.allows_prereleases()\n and not only_prereleases\n ):\n continue\n\n # Select highest version of the two\n if package is None or package.version < candidate.version:\n package = candidate\n\n return package\n\n# Path: src/poetry/console/commands/__init__.py\n\n# Path: src/poetry/console/commands/about.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from collections.abc import Callable\n\n\nclass AboutCommand(Command):\n name = \"about\"\n\n description = \"Shows information about Poetry.\"\n\n def handle(self) -> int:\n from poetry.utils._compat import metadata\n\n # The metadata.version that we import for Python 3.7 is untyped, work around\n # that.\n version: Callable[[str], str] = metadata.version\n\n self.line(\n f\"\"\"\\\nPoetry - Package Management for Python\n\nVersion: {version('poetry')}\nPoetry-Core Version: {version('poetry-core')}\n\nPoetry is a dependency manager tracking local dependencies of your projects\\\n and libraries.\nSee https://github.com/python-poetry/poetry for more information.\\\n\"\"\"\n )\n\n return 0\n\n# Path: src/poetry/console/commands/add.py\nfrom __future__ import annotations\n\nimport contextlib\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\nfrom poetry.core.packages.dependency_group import MAIN_GROUP\nfrom tomlkit.toml_document import TOMLDocument\n\nfrom poetry.console.commands.init import InitCommand\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass AddCommand(InstallerCommand, InitCommand):\n name = \"add\"\n description = \"Adds a new dependency to pyproject.toml.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"name\", \"The packages to add.\", multiple=True)\n ]\n options: ClassVar[list[Option]] = [\n option(\n \"group\",\n \"-G\",\n \"The group to add the dependency to.\",\n flag=False,\n default=MAIN_GROUP,\n ),\n option(\n \"dev\",\n \"D\",\n \"Add as a development dependency. 
(Deprecated) Use\"\n \" --group=dev instead.\",\n ),\n option(\"editable\", \"e\", \"Add vcs/path dependencies as editable.\"),\n option(\n \"extras\",\n \"E\",\n \"Extras to activate for the dependency.\",\n flag=False,\n multiple=True,\n ),\n option(\"optional\", None, \"Add as an optional dependency.\"),\n option(\n \"python\",\n None,\n \"Python version for which the dependency must be installed.\",\n flag=False,\n ),\n option(\n \"platform\",\n None,\n \"Platforms for which the dependency must be installed.\",\n flag=False,\n ),\n option(\n \"source\",\n None,\n \"Name of the source to use to install the package.\",\n flag=False,\n ),\n option(\"allow-prereleases\", None, \"Accept prereleases.\"),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything (implicitly enables\"\n \" --verbose).\",\n ),\n option(\"lock\", None, \"Do not perform operations (only update the lockfile).\"),\n ]\n examples = \"\"\"\\\nIf you do not specify a version constraint, poetry will choose a suitable one based on\\\n the available package versions.\n\nYou can specify a package in the following forms:\n - A single name (requests)\n - A name and a constraint (requests@^2.23.0)\n - A git url (git+https://github.com/python-poetry/poetry.git)\n - A git url with a revision\\\n (git+https://github.com/python-poetry/poetry.git#develop)\n - A subdirectory of a git repository\\\n (git+https://github.com/python-poetry/poetry.git#subdirectory=tests/fixtures/sample_project)\n - A git SSH url (git+ssh://github.com/python-poetry/poetry.git)\n - A git SSH url with a revision\\\n (git+ssh://github.com/python-poetry/poetry.git#develop)\n - A file path (../my-package/my-package.whl)\n - A directory (../my-package/)\n - A url (https://example.com/packages/my-package-0.1.0.tar.gz)\n\"\"\"\n help = f\"\"\"\\\nThe add command adds required packages to your pyproject.toml and installs\\\n them.\n\n{examples}\n\"\"\"\n\n loggers: ClassVar[list[str]] = [\n \"poetry.repositories.pypi_repository\",\n \"poetry.inspection.info\",\n ]\n\n def handle(self) -> int:\n from poetry.core.constraints.version import parse_constraint\n from tomlkit import inline_table\n from tomlkit import parse as parse_toml\n from tomlkit import table\n\n from poetry.factory import Factory\n\n packages = self.argument(\"name\")\n if self.option(\"dev\"):\n self.line_error(\n \"The --dev option is deprecated, \"\n \"use the `--group dev` notation instead.\"\n )\n group = \"dev\"\n else:\n group = self.option(\"group\", self.default_group or MAIN_GROUP)\n\n if self.option(\"extras\") and len(packages) > 1:\n raise ValueError(\n \"You can only specify one package when using the --extras option\"\n )\n\n # tomlkit types are awkward to work with, treat content as a mostly untyped\n # dictionary.\n content: dict[str, Any] = self.poetry.file.read()\n poetry_content = content[\"tool\"][\"poetry\"]\n project_name = (\n canonicalize_name(name) if (name := poetry_content.get(\"name\")) else None\n )\n\n if group == MAIN_GROUP:\n if \"dependencies\" not in poetry_content:\n poetry_content[\"dependencies\"] = table()\n\n section = poetry_content[\"dependencies\"]\n else:\n if \"group\" not in poetry_content:\n poetry_content[\"group\"] = table(is_super_table=True)\n\n groups = poetry_content[\"group\"]\n if group not in groups:\n dependencies_toml: dict[str, Any] = parse_toml(\n f\"[tool.poetry.group.{group}.dependencies]\\n\\n\"\n )\n group_table = dependencies_toml[\"tool\"][\"poetry\"][\"group\"][group]\n poetry_content[\"group\"][group] 
= group_table\n\n if \"dependencies\" not in poetry_content[\"group\"][group]:\n poetry_content[\"group\"][group][\"dependencies\"] = table()\n\n section = poetry_content[\"group\"][group][\"dependencies\"]\n\n existing_packages = self.get_existing_packages_from_input(packages, section)\n\n if existing_packages:\n self.notify_about_existing_packages(existing_packages)\n\n packages = [name for name in packages if name not in existing_packages]\n\n if not packages:\n self.line(\"Nothing to add.\")\n return 0\n\n requirements = self._determine_requirements(\n packages,\n allow_prereleases=self.option(\"allow-prereleases\"),\n source=self.option(\"source\"),\n )\n\n for _constraint in requirements:\n version = _constraint.get(\"version\")\n if version is not None:\n # Validate version constraint\n assert isinstance(version, str)\n parse_constraint(version)\n\n constraint: dict[str, Any] = inline_table()\n for name, value in _constraint.items():\n if name == \"name\":\n continue\n\n constraint[name] = value\n\n if self.option(\"optional\"):\n constraint[\"optional\"] = True\n\n if self.option(\"allow-prereleases\"):\n constraint[\"allow-prereleases\"] = True\n\n if self.option(\"extras\"):\n extras = []\n for extra in self.option(\"extras\"):\n extras += extra.split()\n\n constraint[\"extras\"] = extras\n\n if self.option(\"editable\"):\n if \"git\" in _constraint or \"path\" in _constraint:\n constraint[\"develop\"] = True\n else:\n self.line_error(\n \"\\n\"\n \"Failed to add packages. \"\n \"Only vcs/path dependencies support editable installs. \"\n f\"{_constraint['name']} is neither.\"\n )\n self.line_error(\"\\nNo changes were applied.\")\n return 1\n\n if self.option(\"python\"):\n constraint[\"python\"] = self.option(\"python\")\n\n if self.option(\"platform\"):\n constraint[\"platform\"] = self.option(\"platform\")\n\n if self.option(\"source\"):\n constraint[\"source\"] = self.option(\"source\")\n\n if len(constraint) == 1 and \"version\" in constraint:\n constraint = constraint[\"version\"]\n\n constraint_name = _constraint[\"name\"]\n assert isinstance(constraint_name, str)\n\n canonical_constraint_name = canonicalize_name(constraint_name)\n\n if canonical_constraint_name == project_name:\n self.line_error(\n f\"Cannot add dependency on {constraint_name} to\"\n \" project with the same name.\"\n )\n self.line_error(\"\\nNo changes were applied.\")\n return 1\n\n for key in section:\n if canonicalize_name(key) == canonical_constraint_name:\n section[key] = constraint\n break\n else:\n section[constraint_name] = constraint\n\n with contextlib.suppress(ValueError):\n self.poetry.package.dependency_group(group).remove_dependency(\n constraint_name\n )\n\n self.poetry.package.add_dependency(\n Factory.create_dependency(\n constraint_name,\n constraint,\n groups=[group],\n root_dir=self.poetry.file.path.parent,\n )\n )\n\n # Refresh the locker\n self.poetry.locker.set_local_config(poetry_content)\n self.installer.set_locker(self.poetry.locker)\n\n # Cosmetic new line\n self.line(\"\")\n\n self.installer.set_package(self.poetry.package)\n self.installer.dry_run(self.option(\"dry-run\"))\n self.installer.verbose(self.io.is_verbose())\n self.installer.update(True)\n self.installer.execute_operations(not self.option(\"lock\"))\n\n self.installer.whitelist([r[\"name\"] for r in requirements])\n\n status = self.installer.run()\n\n if status == 0 and not self.option(\"dry-run\"):\n assert isinstance(content, TOMLDocument)\n self.poetry.file.write(content)\n\n return status\n\n def 
get_existing_packages_from_input(\n self, packages: list[str], section: dict[str, Any]\n ) -> list[str]:\n existing_packages = []\n\n for name in packages:\n for key in section:\n if canonicalize_name(key) == canonicalize_name(name):\n existing_packages.append(name)\n\n return existing_packages\n\n @property\n def _hint_update_packages(self) -> str:\n return (\n \"\\nIf you want to update it to the latest compatible version, you can use\"\n \" `poetry update package`.\\nIf you prefer to upgrade it to the latest\"\n \" available version, you can use `poetry add package@latest`.\\n\"\n )\n\n def notify_about_existing_packages(self, existing_packages: list[str]) -> None:\n self.line(\n \"The following packages are already present in the pyproject.toml and will\"\n \" be skipped:\\n\"\n )\n for name in existing_packages:\n self.line(f\" - {name}\")\n self.line(self._hint_update_packages)\n\n# Path: src/poetry/console/commands/build.py\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.utils.env import build_environment\nfrom poetry.utils.helpers import remove_directory\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.option import Option\n\n\nclass BuildCommand(EnvCommand):\n name = \"build\"\n description = \"Builds a package, as a tarball and a wheel by default.\"\n\n options: ClassVar[list[Option]] = [\n option(\"format\", \"f\", \"Limit the format to either sdist or wheel.\", flag=False),\n option(\n \"clean\",\n \"Clean output directory before building.\",\n flag=True,\n ),\n option(\n \"local-version\",\n \"l\",\n \"Add or replace a local version label to the build.\",\n flag=False,\n ),\n option(\n \"output\",\n \"o\",\n \"Set output directory for build artifacts. 
Default is `dist`.\",\n default=\"dist\",\n flag=False,\n ),\n ]\n\n loggers: ClassVar[list[str]] = [\n \"poetry.core.masonry.builders.builder\",\n \"poetry.core.masonry.builders.sdist\",\n \"poetry.core.masonry.builders.wheel\",\n ]\n\n def _build(\n self,\n fmt: str,\n executable: str | Path | None = None,\n *,\n target_dir: Path | None = None,\n ) -> None:\n from poetry.masonry.builders import BUILD_FORMATS\n\n if fmt in BUILD_FORMATS:\n builders = [BUILD_FORMATS[fmt]]\n elif fmt == \"all\":\n builders = list(BUILD_FORMATS.values())\n else:\n raise ValueError(f\"Invalid format: {fmt}\")\n\n if local_version_label := self.option(\"local-version\"):\n self.poetry.package.version = self.poetry.package.version.replace(\n local=local_version_label\n )\n\n for builder in builders:\n builder(self.poetry, executable=executable).build(target_dir)\n\n def handle(self) -> int:\n if not self.poetry.is_package_mode:\n self.line_error(\"Building a package is not possible in non-package mode.\")\n return 1\n\n with build_environment(poetry=self.poetry, env=self.env, io=self.io) as env:\n fmt = self.option(\"format\") or \"all\"\n dist_dir = Path(self.option(\"output\"))\n package = self.poetry.package\n self.line(\n f\"Building {package.pretty_name} ({package.version})\"\n )\n\n if not dist_dir.is_absolute():\n dist_dir = self.poetry.pyproject_path.parent / dist_dir\n\n if self.option(\"clean\"):\n remove_directory(path=dist_dir, force=True)\n\n self._build(fmt, executable=env.python, target_dir=dist_dir)\n\n return 0\n\n# Path: src/poetry/console/commands/check.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from pathlib import Path\n\n from cleo.io.inputs.option import Option\n\n\nclass CheckCommand(Command):\n name = \"check\"\n description = (\n \"Validates the content of the pyproject.toml file and its\"\n \" consistency with the poetry.lock file.\"\n )\n\n options: ClassVar[list[Option]] = [\n option(\n \"lock\",\n None,\n \"Checks that poetry.lock exists for the current\"\n \" version of pyproject.toml.\",\n ),\n ]\n\n def _validate_classifiers(\n self, project_classifiers: set[str]\n ) -> tuple[list[str], list[str]]:\n \"\"\"Identify unrecognized and deprecated trove classifiers.\n\n A fully-qualified classifier is a string delimited by `` :: `` separators. To\n make the error message more readable we need to have visual clues to\n materialize the start and end of a classifier string. That way the user can\n easily copy and paste it from the messages while reducing mistakes because of\n extra spaces.\n\n We use ``!r`` (``repr()``) for classifiers and list of classifiers for\n consistency. That way all strings will be rendered with the same kind of quotes\n (i.e. 
simple tick: ``'``).\n \"\"\"\n from trove_classifiers import classifiers\n from trove_classifiers import deprecated_classifiers\n\n errors = []\n warnings = []\n\n unrecognized = sorted(\n project_classifiers - set(classifiers) - set(deprecated_classifiers)\n )\n # Allow \"Private ::\" classifiers as recommended on PyPI and the packaging guide\n # to allow users to avoid accidentally publishing private packages to PyPI.\n # https://pypi.org/classifiers/\n unrecognized = [u for u in unrecognized if not u.startswith(\"Private ::\")]\n if unrecognized:\n errors.append(f\"Unrecognized classifiers: {unrecognized!r}.\")\n\n deprecated = sorted(\n project_classifiers.intersection(set(deprecated_classifiers))\n )\n if deprecated:\n for old_classifier in deprecated:\n new_classifiers = deprecated_classifiers[old_classifier]\n if new_classifiers:\n message = (\n f\"Deprecated classifier {old_classifier!r}. \"\n f\"Must be replaced by {new_classifiers!r}.\"\n )\n else:\n message = (\n f\"Deprecated classifier {old_classifier!r}. Must be removed.\"\n )\n warnings.append(message)\n\n return errors, warnings\n\n \ndef _validate_readme(self, readme: str | list[str], poetry_file: Path) -> list[str]:\n \"\"\"Check existence of referenced readme files\"\"\"\n\n readmes = [readme] if isinstance(readme, str) else readme\n\n errors = []\n for name in readmes:\n if not (poetry_file.parent / name).exists():\n errors.append(f\"Declared README file does not exist: {name}\")\n return errors\n\n def _validate_dependencies_source(self, config: dict[str, Any]) -> list[str]:\n \"\"\"Check dependencies's source are valid\"\"\"\n sources = {k[\"name\"] for k in config.get(\"source\", [])}\n\n dependency_declarations: list[\n dict[str, str | dict[str, str] | list[dict[str, str]]]\n ] = []\n # scan dependencies and group dependencies settings in pyproject.toml\n if \"dependencies\" in config:\n dependency_declarations.append(config[\"dependencies\"])\n\n for group in config.get(\"group\", {}).values():\n if \"dependencies\" in group:\n dependency_declarations.append(group[\"dependencies\"])\n\n all_referenced_sources: set[str] = set()\n\n for dependency_declaration in dependency_declarations:\n for declaration in dependency_declaration.values():\n if isinstance(declaration, list):\n for item in declaration:\n if \"source\" in item:\n all_referenced_sources.add(item[\"source\"])\n elif isinstance(declaration, dict) and \"source\" in declaration:\n all_referenced_sources.add(declaration[\"source\"])\n\n return [\n f'Invalid source \"{source}\" referenced in dependencies.'\n for source in sorted(all_referenced_sources - sources)\n ]\n\n def handle(self) -> int:\n from poetry.core.pyproject.toml import PyProjectTOML\n\n from poetry.factory import Factory\n\n # Load poetry config and display errors, if any\n poetry_file = self.poetry.file.path\n config = PyProjectTOML(poetry_file).poetry_config\n check_result = Factory.validate(config, strict=True)\n\n # Validate trove classifiers\n project_classifiers = set(config.get(\"classifiers\", []))\n errors, warnings = self._validate_classifiers(project_classifiers)\n check_result[\"errors\"].extend(errors)\n check_result[\"warnings\"].extend(warnings)\n\n # Validate readme (files must exist)\n if \"readme\" in config:\n errors = self._validate_readme(config[\"readme\"], poetry_file)\n check_result[\"errors\"].extend(errors)\n\n check_result[\"errors\"] += self._validate_dependencies_source(config)\n\n # Verify that lock file is consistent\n if self.option(\"lock\") and not 
self.poetry.locker.is_locked():\n check_result[\"errors\"] += [\"poetry.lock was not found.\"]\n if self.poetry.locker.is_locked() and not self.poetry.locker.is_fresh():\n check_result[\"errors\"] += [\n \"pyproject.toml changed significantly since poetry.lock was last generated. \"\n \"Run `poetry lock [--no-update]` to fix the lock file.\"\n ]\n\n if not check_result[\"errors\"] and not check_result[\"warnings\"]:\n self.info(\"All set!\")\n\n return 0\n\n for error in check_result[\"errors\"]:\n self.line_error(f\"Error: {error}\")\n\n for error in check_result[\"warnings\"]:\n self.line_error(f\"Warning: {error}\")\n\n return 1\n\n# Path: src/poetry/console/commands/command.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.commands.command import Command as BaseCommand\nfrom cleo.exceptions import CleoValueError\n\n\nif TYPE_CHECKING:\n from poetry.console.application import Application\n from poetry.poetry import Poetry\n\n\nclass Command(BaseCommand):\n loggers: ClassVar[list[str]] = []\n\n _poetry: Poetry | None = None\n\n @property\n def poetry(self) -> Poetry:\n if self._poetry is None:\n return self.get_application().poetry\n\n return self._poetry\n\n def set_poetry(self, poetry: Poetry) -> None:\n self._poetry = poetry\n\n def get_application(self) -> Application:\n from poetry.console.application import Application\n\n application = self.application\n assert isinstance(application, Application)\n return application\n\n def reset_poetry(self) -> None:\n self.get_application().reset_poetry()\n\n def option(self, name: str, default: Any = None) -> Any:\n try:\n return super().option(name)\n except CleoValueError:\n return default\n\n# Path: src/poetry/console/commands/config.py\nfrom __future__ import annotations\n\nimport json\nimport re\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\nfrom typing import cast\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.config.config import PackageFilterPolicy\nfrom poetry.config.config import boolean_normalizer\nfrom poetry.config.config import boolean_validator\nfrom poetry.config.config import int_normalizer\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n from poetry.config.config_source import ConfigSource\n\n\nclass ConfigCommand(Command):\n name = \"config\"\n description = \"Manages configuration settings.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"key\", \"Setting key.\", optional=True),\n argument(\"value\", \"Setting value.\", optional=True, multiple=True),\n ]\n\n options: ClassVar[list[Option]] = [\n option(\"list\", None, \"List configuration settings.\"),\n option(\"unset\", None, \"Unset configuration setting.\"),\n option(\"local\", None, \"Set/Get from the project's local configuration.\"),\n ]\n\n help = \"\"\"\\\nThis command allows you to edit the poetry config settings and repositories.\n\nTo add a repository:\n\n poetry config repositories.foo https://bar.com/simple/\n\nTo remove a repository (repo is a short alias for repositories):\n\n poetry config --unset repo.foo\"\"\"\n\n LIST_PROHIBITED_SETTINGS: ClassVar[set[str]] = {\"http-basic\", \"pypi-token\"}\n\n @property\n def unique_config_values(self) -> dict[str, tuple[Any, Any]]:\n unique_config_values = {\n \"cache-dir\": (str, lambda val: 
str(Path(val))),\n \"virtualenvs.create\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.in-project\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.options.always-copy\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.options.system-site-packages\": (\n boolean_validator,\n boolean_normalizer,\n ),\n \"virtualenvs.options.no-pip\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.options.no-setuptools\": (\n boolean_validator,\n boolean_normalizer,\n ),\n \"virtualenvs.path\": (str, lambda val: str(Path(val))),\n \"virtualenvs.prefer-active-python\": (boolean_validator, boolean_normalizer),\n \"virtualenvs.prompt\": (str, str),\n \"experimental.system-git-client\": (boolean_validator, boolean_normalizer),\n \"installer.modern-installation\": (boolean_validator, boolean_normalizer),\n \"installer.parallel\": (boolean_validator, boolean_normalizer),\n \"installer.max-workers\": (lambda val: int(val) > 0, int_normalizer),\n \"installer.no-binary\": (\n PackageFilterPolicy.validator,\n PackageFilterPolicy.normalize,\n ),\n \"solver.lazy-wheel\": (boolean_validator, boolean_normalizer),\n \"warnings.export\": (boolean_validator, boolean_normalizer),\n \"keyring.enabled\": (boolean_validator, boolean_normalizer),\n }\n\n return unique_config_values\n\n def handle(self) -> int:\n from pathlib import Path\n\n from poetry.core.pyproject.exceptions import PyProjectException\n\n from poetry.config.config import Config\n from poetry.config.file_config_source import FileConfigSource\n from poetry.locations import CONFIG_DIR\n from poetry.toml.file import TOMLFile\n\n config = Config.create()\n config_file = TOMLFile(CONFIG_DIR / \"config.toml\")\n\n try:\n local_config_file = TOMLFile(self.poetry.file.path.parent / \"poetry.toml\")\n if local_config_file.exists():\n config.merge(local_config_file.read())\n except (RuntimeError, PyProjectException):\n local_config_file = TOMLFile(Path.cwd() / \"poetry.toml\")\n\n if self.option(\"local\"):\n...\n", "needle_token_start": 13503, "needle_token_end": 13622, "code_context_ntokens": 15951, "output": ["```python\ndef _validate_readme(self, readme: str | list[str], poetry_file: Path) -> list[str]:\n \"\"\"Check existence of referenced readme files\"\"\"\n\n readmes = [readme] if isinstance(readme, str) else readme\n\n errors = []\n for name in readmes:\n if not (poetry_file.parent / name).exists():\n errors.append(f\"Declared README file does not exist: {name}\")\n return errors\n```"]} +{"repo": "python-poetry/poetry", "name": "base", "language": "python", "path": "src/poetry/utils/env/base_env.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this attribute is to provide the root directory of the Python environment, which may differ from the active environment's directory if the environment is nested or linked to a different base environment.\n2. **Input**: There is no direct input required as it is an attribute that retrieves the stored path.\n3. **Output**: The output is the path to the base directory of the Python environment.\n4. **Procedure**: The procedure involves checking if a base path was explicitly provided during the initialization of the environment object. If not provided, it defaults to the path of the current environment. 
This path is then adjusted for Windows systems to ensure compatibility.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/poetry/console/commands/show.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom packaging.utils import canonicalize_name\n\nfrom poetry.console.commands.env_command import EnvCommand\nfrom poetry.console.commands.group_command import GroupCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n from cleo.io.io import IO\n from cleo.ui.table import Rows\n from packaging.utils import NormalizedName\n from poetry.core.packages.dependency import Dependency\n from poetry.core.packages.package import Package\n from poetry.core.packages.project_package import ProjectPackage\n\n from poetry.repositories.repository import Repository\n\n\ndef reverse_deps(pkg: Package, repo: Repository) -> dict[str, str]:\n required_by = {}\n for locked in repo.packages:\n dependencies = {d.name: d.pretty_constraint for d in locked.requires}\n\n if pkg.name in dependencies:\n required_by[locked.pretty_name] = dependencies[pkg.name]\n\n return required_by\n\n\nclass ShowCommand(GroupCommand, EnvCommand):\n name = \"show\"\n description = \"Shows information about packages.\"\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"package\", \"The package to inspect\", optional=True)\n ]\n options: ClassVar[list[Option]] = [\n *GroupCommand._group_dependency_options(),\n option(\n \"no-dev\",\n None,\n \"Do not list the development dependencies. 
(Deprecated)\",\n ),\n option(\"tree\", \"t\", \"List the dependencies as a tree.\"),\n option(\n \"why\",\n None,\n \"When showing the full list, or a --tree for a single package,\"\n \" display whether they are a direct dependency or required by other\"\n \" packages\",\n ),\n option(\"latest\", \"l\", \"Show the latest version.\"),\n option(\n \"outdated\",\n \"o\",\n \"Show the latest version but only for packages that are outdated.\",\n ),\n option(\n \"all\",\n \"a\",\n \"Show all packages (even those not compatible with current system).\",\n ),\n option(\"top-level\", \"T\", \"Show only top-level dependencies.\"),\n ]\n\n help = \"\"\"The show command displays detailed information about a package, or\nlists all packages available.\"\"\"\n\n colors: ClassVar[list[str]] = [\"cyan\", \"yellow\", \"green\", \"magenta\", \"blue\"]\n\n def handle(self) -> int:\n package = self.argument(\"package\")\n\n if self.option(\"tree\"):\n self.init_styles(self.io)\n\n if self.option(\"top-level\"):\n if self.option(\"tree\"):\n self.line_error(\n \"Error: Cannot use --tree and --top-level at the same\"\n \" time.\"\n )\n return 1\n if package is not None:\n self.line_error(\n \"Error: Cannot use --top-level when displaying a single\"\n \" package.\"\n )\n return 1\n\n if self.option(\"why\"):\n if self.option(\"tree\") and package is None:\n self.line_error(\n \"Error: --why requires a package when combined with\"\n \" --tree.\"\n )\n\n return 1\n\n if not self.option(\"tree\") and package:\n self.line_error(\n \"Error: --why cannot be used without --tree when displaying\"\n \" a single package.\"\n )\n\n return 1\n\n if self.option(\"outdated\"):\n self.io.input.set_option(\"latest\", True)\n\n if not self.poetry.locker.is_locked():\n self.line_error(\n \"Error: poetry.lock not found. 
Run `poetry lock` to create\"\n \" it.\"\n )\n return 1\n\n locked_repo = self.poetry.locker.locked_repository()\n\n if package:\n return self._display_single_package_information(package, locked_repo)\n\n root = self.project_with_activated_groups_only()\n\n # Show tree view if requested\n if self.option(\"tree\"):\n return self._display_packages_tree_information(locked_repo, root)\n\n return self._display_packages_information(locked_repo, root)\n\n def _display_single_package_information(\n self, package: str, locked_repository: Repository\n ) -> int:\n locked_packages = locked_repository.packages\n canonicalized_package = canonicalize_name(package)\n pkg = None\n\n for locked in locked_packages:\n if locked.name == canonicalized_package:\n pkg = locked\n break\n\n if not pkg:\n raise ValueError(f\"Package {package} not found\")\n\n required_by = reverse_deps(pkg, locked_repository)\n\n if self.option(\"tree\"):\n if self.option(\"why\"):\n # The default case if there's no reverse dependencies is to query\n # the subtree for pkg but if any rev-deps exist we'll query for each\n # of them in turn\n packages = [pkg]\n if required_by:\n packages = [\n p for p in locked_packages for r in required_by if p.name == r\n ]\n else:\n # if no rev-deps exist we'll make this clear as it can otherwise\n # look very odd for packages that also have no or few direct\n # dependencies\n self.io.write_line(f\"Package {package} is a direct dependency.\")\n\n for p in packages:\n self.display_package_tree(\n self.io, p, locked_packages, why_package=pkg\n )\n\n else:\n self.display_package_tree(self.io, pkg, locked_packages)\n\n return 0\n\n rows: Rows = [\n [\"name\", f\" : {pkg.pretty_name}\"],\n [\"version\", f\" : {pkg.pretty_version}\"],\n [\"description\", f\" : {pkg.description}\"],\n ]\n\n self.table(rows=rows, style=\"compact\").render()\n\n if pkg.requires:\n self.line(\"\")\n self.line(\"dependencies\")\n for dependency in pkg.requires:\n self.line(\n f\" - {dependency.pretty_name}\"\n f\" {dependency.pretty_constraint}\"\n )\n\n if required_by:\n self.line(\"\")\n self.line(\"required by\")\n for parent, requires_version in required_by.items():\n self.line(f\" - {parent} {requires_version}\")\n\n return 0\n\n def _display_packages_information(\n self, locked_repository: Repository, root: ProjectPackage\n ) -> int:\n import shutil\n\n from cleo.io.null_io import NullIO\n\n from poetry.puzzle.solver import Solver\n from poetry.repositories.installed_repository import InstalledRepository\n from poetry.repositories.repository_pool import RepositoryPool\n from poetry.utils.helpers import get_package_version_display_string\n\n locked_packages = locked_repository.packages\n pool = RepositoryPool.from_packages(locked_packages, self.poetry.config)\n solver = Solver(\n root,\n pool=pool,\n installed=[],\n locked=locked_packages,\n io=NullIO(),\n )\n solver.provider.load_deferred(False)\n with solver.use_environment(self.env):\n ops = solver.solve().calculate_operations()\n\n required_locked_packages = {op.package for op in ops if not op.skipped}\n\n show_latest = self.option(\"latest\")\n show_all = self.option(\"all\")\n show_top_level = self.option(\"top-level\")\n width = shutil.get_terminal_size().columns\n name_length = version_length = latest_length = required_by_length = 0\n latest_packages = {}\n latest_statuses = {}\n installed_repo = InstalledRepository.load(self.env)\n\n # Computing widths\n for locked in locked_packages:\n if locked not in required_locked_packages and not show_all:\n continue\n\n 
current_length = len(locked.pretty_name)\n if not self.io.output.is_decorated():\n installed_status = self.get_installed_status(\n locked, installed_repo.packages\n )\n\n if installed_status == \"not-installed\":\n current_length += 4\n\n if show_latest:\n latest = self.find_latest_package(locked, root)\n if not latest:\n latest = locked\n\n latest_packages[locked.pretty_name] = latest\n update_status = latest_statuses[locked.pretty_name] = (\n self.get_update_status(latest, locked)\n )\n\n if not self.option(\"outdated\") or update_status != \"up-to-date\":\n name_length = max(name_length, current_length)\n version_length = max(\n version_length,\n len(\n get_package_version_display_string(\n locked, root=self.poetry.file.path.parent\n )\n ),\n...\n# Path: src/poetry/console/commands/update.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\n\nfrom poetry.console.commands.installer_command import InstallerCommand\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n\n\nclass UpdateCommand(InstallerCommand):\n name = \"update\"\n description = (\n \"Update the dependencies as according to the pyproject.toml file.\"\n )\n\n arguments: ClassVar[list[Argument]] = [\n argument(\"packages\", \"The packages to update\", optional=True, multiple=True)\n ]\n options: ClassVar[list[Option]] = [\n *InstallerCommand._group_dependency_options(),\n option(\n \"no-dev\",\n None,\n \"Do not update the development dependencies.\"\n \" (Deprecated)\",\n ),\n option(\n \"sync\",\n None,\n \"Synchronize the environment with the locked packages and the specified\"\n \" groups.\",\n ),\n option(\n \"dry-run\",\n None,\n \"Output the operations but do not execute anything \"\n \"(implicitly enables --verbose).\",\n ),\n option(\"lock\", None, \"Do not perform operations (only update the lockfile).\"),\n ]\n\n loggers: ClassVar[list[str]] = [\"poetry.repositories.pypi_repository\"]\n\n def handle(self) -> int:\n packages = self.argument(\"packages\")\n if packages:\n self.installer.whitelist({name: \"*\" for name in packages})\n\n self.installer.only_groups(self.activated_groups)\n self.installer.dry_run(self.option(\"dry-run\"))\n self.installer.requires_synchronization(self.option(\"sync\"))\n self.installer.execute_operations(not self.option(\"lock\"))\n\n # Force update\n self.installer.update(True)\n\n return self.installer.run()\n\n# Path: src/poetry/console/commands/version.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import Any\nfrom typing import ClassVar\n\nfrom cleo.helpers import argument\nfrom cleo.helpers import option\nfrom poetry.core.version.exceptions import InvalidVersion\nfrom tomlkit.toml_document import TOMLDocument\n\nfrom poetry.console.commands.command import Command\n\n\nif TYPE_CHECKING:\n from cleo.io.inputs.argument import Argument\n from cleo.io.inputs.option import Option\n from poetry.core.constraints.version import Version\n\n\nclass VersionCommand(Command):\n name = \"version\"\n description = (\n \"Shows the version of the project or bumps it when a valid \"\n \"bump rule is provided.\"\n )\n\n arguments: ClassVar[list[Argument]] = [\n argument(\n \"version\",\n \"The version number or the rule to update the version.\",\n optional=True,\n ),\n ]\n options: ClassVar[list[Option]] = [\n option(\"short\", \"s\", \"Output the version number only\"),\n 
option(\n \"dry-run\",\n None,\n \"Do not update pyproject.toml file\",\n ),\n option(\"next-phase\", None, \"Increment the phase of the current version\"),\n ]\n\n help = \"\"\"\\\nThe version command shows the current version of the project or bumps the version of\nthe project and writes the new version back to pyproject.toml if a valid\nbump rule is provided.\n\nThe new version should ideally be a valid semver string or a valid bump rule:\npatch, minor, major, prepatch, preminor, premajor, prerelease.\n\"\"\"\n\n RESERVED: ClassVar[set[str]] = {\n \"major\",\n \"minor\",\n \"patch\",\n \"premajor\",\n \"preminor\",\n \"prepatch\",\n \"prerelease\",\n }\n\n def handle(self) -> int:\n version = self.argument(\"version\")\n\n if version:\n version = self.increment_version(\n self.poetry.package.pretty_version, version, self.option(\"next-phase\")\n )\n\n if self.option(\"short\"):\n self.line(version.to_string())\n else:\n self.line(\n f\"Bumping version from {self.poetry.package.pretty_version}\"\n f\" to {version}\"\n )\n\n if not self.option(\"dry-run\"):\n content: dict[str, Any] = self.poetry.file.read()\n poetry_content = content[\"tool\"][\"poetry\"]\n poetry_content[\"version\"] = version.text\n\n assert isinstance(content, TOMLDocument)\n self.poetry.file.write(content)\n else:\n if self.option(\"short\"):\n self.line(self.poetry.package.pretty_version)\n else:\n self.line(\n f\"{self.poetry.package.pretty_name}\"\n f\" {self.poetry.package.pretty_version}\"\n )\n\n return 0\n\n def increment_version(\n self, version: str, rule: str, next_phase: bool = False\n ) -> Version:\n from poetry.core.constraints.version import Version\n\n try:\n parsed = Version.parse(version)\n except InvalidVersion:\n raise ValueError(\"The project's version doesn't seem to follow semver\")\n\n if rule in {\"major\", \"premajor\"}:\n new = parsed.next_major()\n if rule == \"premajor\":\n new = new.first_prerelease()\n elif rule in {\"minor\", \"preminor\"}:\n new = parsed.next_minor()\n if rule == \"preminor\":\n new = new.first_prerelease()\n elif rule in {\"patch\", \"prepatch\"}:\n new = parsed.next_patch()\n if rule == \"prepatch\":\n new = new.first_prerelease()\n elif rule == \"prerelease\":\n if parsed.is_unstable():\n pre = parsed.pre\n assert pre is not None\n pre = pre.next_phase() if next_phase else pre.next()\n new = Version(parsed.epoch, parsed.release, pre)\n else:\n new = parsed.next_patch().first_prerelease()\n else:\n new = Version.parse(rule)\n\n return new\n\n# Path: src/poetry/console/events/__init__.py\n\n# Path: src/poetry/console/events/console_events.py\n\n# Path: src/poetry/console/io/__init__.py\n\n# Path: src/poetry/console/logging/__init__.py\n\n# Path: src/poetry/console/logging/filters.py\nfrom __future__ import annotations\n\nimport logging\n\n\nPOETRY_FILTER = logging.Filter(name=\"poetry\")\n\n# Path: src/poetry/console/logging/io_formatter.py\nfrom __future__ import annotations\n\nimport logging\nimport sys\nimport textwrap\n\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\n\nfrom poetry.console.logging.filters import POETRY_FILTER\nfrom poetry.console.logging.formatters import FORMATTERS\n\n\nif TYPE_CHECKING:\n from logging import LogRecord\n\n\nclass IOFormatter(logging.Formatter):\n _colors: ClassVar[dict[str, str]] = {\n \"error\": \"fg=red\",\n \"warning\": \"fg=yellow\",\n \"debug\": \"debug\",\n \"info\": \"fg=blue\",\n }\n\n def format(self, record: LogRecord) -> str:\n if not record.exc_info:\n level = 
record.levelname.lower()\n msg = record.msg\n\n if record.name in FORMATTERS:\n msg = FORMATTERS[record.name].format(msg)\n elif level in self._colors:\n msg = f\"<{self._colors[level]}>{msg}\"\n\n record.msg = msg\n\n formatted = super().format(record)\n\n if not POETRY_FILTER.filter(record):\n # prefix all lines from third-party packages for easier debugging\n formatted = textwrap.indent(\n formatted, f\"[{_log_prefix(record)}] \", lambda line: True\n )\n\n return formatted\n\n\ndef _log_prefix(record: LogRecord) -> str:\n prefix = _path_to_package(Path(record.pathname)) or record.module\n if record.name != \"root\":\n prefix = \":\".join([prefix, record.name])\n return prefix\n\n\ndef _path_to_package(path: Path) -> str | None:\n \"\"\"Return main package name from the LogRecord.pathname.\"\"\"\n prefix: Path | None = None\n # Find the most specific prefix in sys.path.\n # We have to search the entire sys.path because a subsequent path might be\n # a sub path of the first match and thereby a better match.\n for syspath in sys.path:\n if (\n prefix and prefix in (p := Path(syspath)).parents and p in path.parents\n ) or (not prefix and (p := Path(syspath)) in path.parents):\n prefix = p\n if not prefix:\n # this is unexpected, but let's play it safe\n return None\n path = path.relative_to(prefix)\n return path.parts[0] # main package name\n\n# Path: src/poetry/console/logging/io_handler.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import TYPE_CHECKING\n\n\nif TYPE_CHECKING:\n from logging import LogRecord\n\n from cleo.io.io import IO\n\n\nclass IOHandler(logging.Handler):\n def __init__(self, io: IO) -> None:\n self._io = io\n\n super().__init__()\n\n def emit(self, record: LogRecord) -> None:\n try:\n msg = self.format(record)\n level = record.levelname.lower()\n err = level in (\"warning\", \"error\", \"exception\", \"critical\")\n if err:\n self._io.write_error_line(msg)\n else:\n self._io.write_line(msg)\n except Exception:\n self.handleError(record)\n\n# Path: src/poetry/installation/operations/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.installation.operations.install import Install\nfrom poetry.installation.operations.uninstall import Uninstall\nfrom poetry.installation.operations.update import Update\n\n\n__all__ = [\"Install\", \"Uninstall\", \"Update\"]\n\n# Path: src/poetry/installation/operations/install.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.installation.operations.operation import Operation\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\n\nclass Install(Operation):\n def __init__(\n self, package: Package, reason: str | None = None, priority: int = 0\n ) -> None:\n super().__init__(reason, priority=priority)\n\n self._package = package\n\n @property\n def package(self) -> Package:\n return self._package\n\n @property\n def job_type(self) -> str:\n return \"install\"\n\n def __str__(self) -> str:\n return (\n \"Installing\"\n f\" {self.package.pretty_name} ({self.format_version(self.package)})\"\n )\n\n def __repr__(self) -> str:\n return (\n \"\"\n )\n\n# Path: src/poetry/installation/operations/operation.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom typing import TypeVar\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\nT = TypeVar(\"T\", bound=\"Operation\")\n\n\nclass Operation:\n def __init__(self, reason: str | None = None, priority: float = 0) -> None:\n self._reason = reason\n\n 
self._skipped = False\n self._skip_reason: str | None = None\n self._priority = priority\n\n @property\n def job_type(self) -> str:\n raise NotImplementedError\n\n @property\n def reason(self) -> str | None:\n return self._reason\n\n @property\n def skipped(self) -> bool:\n return self._skipped\n\n @property\n def skip_reason(self) -> str | None:\n return self._skip_reason\n\n @property\n def priority(self) -> float:\n return self._priority\n\n @property\n def package(self) -> Package:\n raise NotImplementedError()\n\n def format_version(self, package: Package) -> str:\n version: str = package.full_pretty_version\n return version\n\n def skip(self: T, reason: str) -> T:\n self._skipped = True\n self._skip_reason = reason\n\n return self\n\n def unskip(self: T) -> T:\n self._skipped = False\n self._skip_reason = None\n\n return self\n\n# Path: src/poetry/installation/operations/uninstall.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.installation.operations.operation import Operation\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\n\nclass Uninstall(Operation):\n def __init__(\n self,\n package: Package,\n reason: str | None = None,\n priority: float = float(\"inf\"),\n ) -> None:\n super().__init__(reason, priority=priority)\n\n self._package = package\n\n @property\n def package(self) -> Package:\n return self._package\n\n @property\n def job_type(self) -> str:\n return \"uninstall\"\n\n def __str__(self) -> str:\n return (\n \"Uninstalling\"\n f\" {self.package.pretty_name} ({self.format_version(self._package)})\"\n )\n\n def __repr__(self) -> str:\n return (\n \"\"\n )\n\n# Path: src/poetry/installation/operations/update.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom poetry.installation.operations.operation import Operation\n\n\nif TYPE_CHECKING:\n from poetry.core.packages.package import Package\n\n\nclass Update(Operation):\n def __init__(\n self,\n initial: Package,\n target: Package,\n reason: str | None = None,\n priority: int = 0,\n ) -> None:\n self._initial_package = initial\n self._target_package = target\n\n super().__init__(reason, priority=priority)\n\n @property\n def initial_package(self) -> Package:\n return self._initial_package\n\n @property\n def target_package(self) -> Package:\n return self._target_package\n\n @property\n def package(self) -> Package:\n return self._target_package\n\n @property\n def job_type(self) -> str:\n return \"update\"\n\n def __str__(self) -> str:\n init_version = self.format_version(self.initial_package)\n target_version = self.format_version(self.target_package)\n return (\n f\"Updating {self.initial_package.pretty_name} ({init_version}) \"\n f\"to {self.target_package.pretty_name} ({target_version})\"\n )\n\n def __repr__(self) -> str:\n init_version = self.format_version(self.initial_package)\n target_version = self.format_version(self.target_package)\n return (\n f\"\"\n )\n\n# Path: src/poetry/masonry/builders/__init__.py\nfrom __future__ import annotations\n\nfrom poetry.core.masonry.builders.sdist import SdistBuilder\nfrom poetry.core.masonry.builders.wheel import WheelBuilder\n\nfrom poetry.masonry.builders.editable import EditableBuilder\n\n\n__all__ = [\"BUILD_FORMATS\", \"EditableBuilder\"]\n\n\n# might be extended by plugins\nBUILD_FORMATS = {\n \"sdist\": SdistBuilder,\n \"wheel\": WheelBuilder,\n}\n\n# Path: src/poetry/masonry/builders/editable.py\nfrom __future__ import annotations\n\nimport csv\nimport hashlib\nimport 
json\nimport locale\nimport os\n\nfrom base64 import urlsafe_b64encode\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom poetry.core.masonry.builders.builder import Builder\nfrom poetry.core.masonry.builders.sdist import SdistBuilder\nfrom poetry.core.masonry.utils.package_include import PackageInclude\n\nfrom poetry.utils._compat import WINDOWS\nfrom poetry.utils._compat import decode\nfrom poetry.utils.env import build_environment\nfrom poetry.utils.helpers import is_dir_writable\nfrom poetry.utils.pip import pip_install\n\n\nif TYPE_CHECKING:\n from cleo.io.io import IO\n\n from poetry.poetry import Poetry\n from poetry.utils.env import Env\n\nSCRIPT_TEMPLATE = \"\"\"\\\n#!{python}\nimport sys\nfrom {module} import {callable_holder}\n\nif __name__ == '__main__':\n sys.exit({callable_}())\n\"\"\"\n\nWINDOWS_CMD_TEMPLATE = \"\"\"\\\n@echo off\\r\\n\"{python}\" \"%~dp0\\\\{script}\" %*\\r\\n\n\"\"\"\n\n\nclass EditableBuilder(Builder):\n def __init__(self, poetry: Poetry, env: Env, io: IO) -> None:\n self._poetry: Poetry\n super().__init__(poetry)\n\n self._env = env\n self._io = io\n\n def build(self, target_dir: Path | None = None) -> Path:\n self._debug(\n f\" - Building package {self._package.name} in\"\n \" editable mode\"\n )\n\n if self._package.build_script:\n if self._package.build_should_generate_setup():\n self._debug(\n \" - Falling back on using a setup.py\"\n )\n self._setup_build()\n return self._path\n\n self._run_build_script(self._package.build_script)\n\n for removed in self._env.site_packages.remove_distribution_files(\n distribution_name=self._package.name\n ):\n self._debug(\n f\" - Removed {removed.name} directory from\"\n f\" {removed.parent}\"\n )\n\n added_files = []\n added_files += self._add_pth()\n added_files += self._add_scripts()\n self._add_dist_info(added_files)\n\n return self._path\n\n def _run_build_script(self, build_script: str) -> None:\n with build_environment(poetry=self._poetry, env=self._env, io=self._io) as env:\n self._debug(f\" - Executing build script: {build_script}\")\n env.run(\"python\", str(self._path.joinpath(build_script)), call=True)\n\n def _setup_build(self) -> None:\n builder = SdistBuilder(self._poetry)\n setup = self._path / \"setup.py\"\n has_setup = setup.exists()\n\n if has_setup:\n self._io.write_error_line(\n \"A setup.py file already exists. 
Using it.\"\n )\n else:\n with setup.open(\"w\", encoding=\"utf-8\") as f:\n f.write(decode(builder.build_setup()))\n\n try:\n pip_install(self._path, self._env, upgrade=True, editable=True)\n finally:\n if not has_setup:\n os.remove(setup)\n\n def _add_pth(self) -> list[Path]:\n paths = {\n include.base.resolve().as_posix()\n for include in self._module.includes\n if isinstance(include, PackageInclude)\n and (include.is_module() or include.is_package())\n }\n\n content = \"\".join(decode(path + os.linesep) for path in paths)\n pth_file = Path(self._module.name).with_suffix(\".pth\")\n\n # remove any pre-existing pth files for this package\n for file in self._env.site_packages.find(path=pth_file, writable_only=True):\n self._debug(\n f\" - Removing existing {file.name} from {file.parent}\"\n f\" for {self._poetry.file.path.parent}\"\n )\n file.unlink(missing_ok=True)\n\n try:\n pth_file = self._env.site_packages.write_text(\n pth_file, content, encoding=locale.getpreferredencoding()\n )\n self._debug(\n f\" - Adding {pth_file.name} to {pth_file.parent} for\"\n f\" {self._poetry.file.path.parent}\"\n )\n return [pth_file]\n except OSError:\n # TODO: Replace with PermissionError\n self._io.write_error_line(\n f\" - Failed to create {pth_file.name} for\"\n f\" {self._poetry.file.path.parent}\"\n )\n return []\n\n def _add_scripts(self) -> list[Path]:\n added = []\n entry_points = self.convert_entry_points()\n\n for scripts_path in self._env.script_dirs:\n if is_dir_writable(path=scripts_path, create=True):\n break\n else:\n self._io.write_error_line(\n \" - Failed to find a suitable script installation directory for\"\n f\" {self._poetry.file.path.parent}\"\n )\n return []\n\n scripts = entry_points.get(\"console_scripts\", [])\n for script in scripts:\n name, script_with_extras = script.split(\" = \")\n script_without_extras = script_with_extras.split(\"[\")[0]\n try:\n module, callable_ = script_without_extras.split(\":\")\n except ValueError as exc:\n msg = (\n f\"Bad script ({name}): script needs to specify a function within a\"\n \" module like: module(.submodule):function\\nInstead got:\"\n f\" {script_with_extras}\"\n )\n if \"not enough values\" in str(exc):\n msg += (\n \"\\nHint: If the script depends on module-level code, try\"\n \" wrapping it in a main() function and modifying your script\"\n f' like:\\n{name} = \"{script_with_extras}:main\"'\n )\n elif \"too many values\" in str(exc):\n msg += '\\nToo many \":\" found!'\n\n raise ValueError(msg)\n\n callable_holder = callable_.split(\".\", 1)[0]\n\n script_file = scripts_path.joinpath(name)\n self._debug(\n f\" - Adding the {name} script to {scripts_path}\"\n )\n with script_file.open(\"w\", encoding=\"utf-8\") as f:\n f.write(\n decode(\n SCRIPT_TEMPLATE.format(\n python=self._env.python,\n module=module,\n callable_holder=callable_holder,\n callable_=callable_,\n )\n )\n )\n\n script_file.chmod(0o755)\n\n added.append(script_file)\n\n if WINDOWS:\n cmd_script = script_file.with_suffix(\".cmd\")\n cmd = WINDOWS_CMD_TEMPLATE.format(python=self._env.python, script=name)\n self._debug(\n f\" - Adding the {cmd_script.name} script wrapper to\"\n f\" {scripts_path}\"\n )\n\n with cmd_script.open(\"w\", encoding=\"utf-8\") as f:\n f.write(decode(cmd))\n\n added.append(cmd_script)\n\n return added\n\n def _add_dist_info(self, added_files: list[Path]) -> None:\n from poetry.core.masonry.builders.wheel import WheelBuilder\n\n added_files = added_files[:]\n\n builder = WheelBuilder(self._poetry)\n dist_info = 
self._env.site_packages.mkdir(Path(builder.dist_info))\n\n self._debug(\n f\" - Adding the {dist_info.name} directory to\"\n f\" {dist_info.parent}\"\n )\n\n with dist_info.joinpath(\"METADATA\").open(\"w\", encoding=\"utf-8\") as f:\n builder._write_metadata_file(f)\n\n added_files.append(dist_info.joinpath(\"METADATA\"))\n\n with dist_info.joinpath(\"INSTALLER\").open(\"w\", encoding=\"utf-8\") as f:\n f.write(\"poetry\")\n\n added_files.append(dist_info.joinpath(\"INSTALLER\"))\n\n if self.convert_entry_points():\n with dist_info.joinpath(\"entry_points.txt\").open(\n \"w\", encoding=\"utf-8\"\n ) as f:\n builder._write_entry_points(f)\n\n added_files.append(dist_info.joinpath(\"entry_points.txt\"))\n\n # write PEP 610 metadata\n direct_url_json = dist_info.joinpath(\"direct_url.json\")\n direct_url_json.write_text(\n json.dumps(\n {\n \"dir_info\": {\"editable\": True},\n \"url\": self._poetry.file.path.parent.absolute().as_uri(),\n }\n )\n )\n added_files.append(direct_url_json)\n\n record = dist_info.joinpath(\"RECORD\")\n with record.open(\"w\", encoding=\"utf-8\", newline=\"\") as f:\n csv_writer = csv.writer(f)\n for path in added_files:\n hash = self._get_file_hash(path)\n size = path.stat().st_size\n csv_writer.writerow((path, f\"sha256={hash}\", size))\n\n # RECORD itself is recorded with no hash or size\n csv_writer.writerow((record, \"\", \"\"))\n\n def _get_file_hash(self, filepath: Path) -> str:\n hashsum = hashlib.sha256()\n with filepath.open(\"rb\") as src:\n while True:\n buf = src.read(1024 * 8)\n if not buf:\n break\n hashsum.update(buf)\n\n src.seek(0)\n\n return urlsafe_b64encode(hashsum.digest()).decode(\"ascii\").rstrip(\"=\")\n\n def _debug(self, msg: str) -> None:\n if self._io.is_debug():\n self._io.write_line(msg)\n\n# Path: src/poetry/repositories/link_sources/__init__.py\n\n# Path: src/poetry/repositories/link_sources/base.py\nfrom __future__ import annotations\n\nimport logging\nimport re\n\nfrom functools import cached_property\nfrom typing import TYPE_CHECKING\nfrom typing import ClassVar\nfrom typing import DefaultDict\nfrom typing import List\n\nfrom poetry.core.constraints.version import Version\nfrom poetry.core.packages.package import Package\nfrom poetry.core.version.exceptions import InvalidVersion\n\nfrom poetry.utils.patterns import sdist_file_re\nfrom poetry.utils.patterns import wheel_file_re\n\n\nif TYPE_CHECKING:\n from collections.abc import Iterator\n\n from packaging.utils import NormalizedName\n from poetry.core.packages.utils.link import Link\n\n LinkCache = DefaultDict[NormalizedName, DefaultDict[Version, List[Link]]]\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass LinkSource:\n VERSION_REGEX = re.compile(r\"(?i)([a-z0-9_\\-.]+?)-(?=\\d)([a-z0-9_.!+-]+)\")\n CLEAN_REGEX = re.compile(r\"[^a-z0-9$&+,/:;=?@.#%_\\\\|-]\", re.I)\n SUPPORTED_FORMATS: ClassVar[list[str]] = [\n \".tar.gz\",\n \".whl\",\n \".zip\",\n \".tar.bz2\",\n \".tar.xz\",\n \".tar.Z\",\n \".tar\",\n ]\n\n def __init__(self, url: str) -> None:\n self._url = url\n\n @property\n def url(self) -> str:\n return self._url\n\n def versions(self, name: NormalizedName) -> Iterator[Version]:\n yield from self._link_cache[name]\n\n @property\n def packages(self) -> Iterator[Package]:\n for link in self.links:\n pkg = self.link_package_data(link)\n\n if pkg:\n yield pkg\n\n @property\n def links(self) -> Iterator[Link]:\n for links_per_version in self._link_cache.values():\n for links in links_per_version.values():\n yield from links\n\n @classmethod\n def link_package_data(cls, 
link: Link) -> Package | None:\n name: str | None = None\n version_string: str | None = None\n version: Version | None = None\n m = wheel_file_re.match(link.filename) or sdist_file_re.match(link.filename)\n\n if m:\n name = m.group(\"name\")\n version_string = m.group(\"ver\")\n else:\n info, ext = link.splitext()\n match = cls.VERSION_REGEX.match(info)\n if match:\n name = match.group(1)\n version_string = match.group(2)\n\n if version_string:\n try:\n version = Version.parse(version_string)\n except InvalidVersion:\n logger.debug(\n \"Skipping url (%s) due to invalid version (%s)\", link.url, version\n )\n return None\n\n pkg = None\n if name and version:\n pkg = Package(name, version, source_url=link.url)\n return pkg\n\n def links_for_version(\n self, name: NormalizedName, version: Version\n ) -> Iterator[Link]:\n yield from self._link_cache[name][version]\n\n def clean_link(self, url: str) -> str:\n \"\"\"Makes sure a link is fully encoded. That is, if a ' ' shows up in\n the link, it will be rewritten to %20 (while not over-quoting\n % or other characters).\"\"\"\n return self.CLEAN_REGEX.sub(lambda match: f\"%{ord(match.group(0)):02x}\", url)\n\n def yanked(self, name: NormalizedName, version: Version) -> str | bool:\n reasons = set()\n for link in self.links_for_version(name, version):\n if link.yanked:\n if link.yanked_reason:\n reasons.add(link.yanked_reason)\n else:\n # release is not yanked if at least one file is not yanked\n return False\n # if all files are yanked (or there are no files) the release is yanked\n if reasons:\n return \"\\n\".join(sorted(reasons))\n return True\n\n @cached_property\n def _link_cache(self) -> LinkCache:\n raise NotImplementedError()\n\n# Path: src/poetry/repositories/link_sources/html.py\nfrom __future__ import annotations\n\nimport urllib.parse\n\nfrom collections import defaultdict\nfrom functools import cached_property\nfrom html import unescape\nfrom typing import TYPE_CHECKING\n\nfrom poetry.core.packages.utils.link import Link\n\nfrom poetry.repositories.link_sources.base import LinkSource\nfrom poetry.repositories.parsers.html_page_parser import HTMLPageParser\n\n\nif TYPE_CHECKING:\n from poetry.repositories.link_sources.base import LinkCache\n\n\nclass HTMLPage(LinkSource):\n def __init__(self, url: str, content: str) -> None:\n super().__init__(url=url)\n\n parser = HTMLPageParser()\n parser.feed(content)\n self._parsed = parser.anchors\n self._base_url: str | None = parser.base_url\n\n @cached_property\n def _link_cache(self) -> LinkCache:\n links: LinkCache = defaultdict(lambda: defaultdict(list))\n for anchor in self._parsed:\n if href := anchor.get(\"href\"):\n url = self.clean_link(\n urllib.parse.urljoin(self._base_url or self._url, href)\n )\n pyrequire = anchor.get(\"data-requires-python\")\n pyrequire = unescape(pyrequire) if pyrequire else None\n yanked_value = anchor.get(\"data-yanked\")\n yanked: str | bool\n if yanked_value:\n yanked = unescape(yanked_value)\n else:\n yanked = \"data-yanked\" in anchor\n\n # see https://peps.python.org/pep-0714/#clients\n # and https://peps.python.org/pep-0658/#specification\n metadata: str | bool\n for metadata_key in (\"data-core-metadata\", \"data-dist-info-metadata\"):\n metadata_value = anchor.get(metadata_key)\n if metadata_value:\n metadata = unescape(metadata_value)\n else:\n metadata = metadata_key in anchor\n if metadata:\n break\n link = Link(\n url, requires_python=pyrequire, yanked=yanked, metadata=metadata\n )\n\n if link.ext not in self.SUPPORTED_FORMATS:\n continue\n\n pkg 
= self.link_package_data(link)\n if pkg:\n links[pkg.name][pkg.version].append(link)\n\n return links\n\n\nclass SimpleRepositoryPage(HTMLPage):\n def __init__(self, url: str, content: str) -> None:\n if not url.endswith(\"/\"):\n url += \"/\"\n super().__init__(url=url, content=content)\n\n# Path: src/poetry/repositories/link_sources/json.py\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom functools import cached_property\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom poetry.core.packages.utils.link import Link\n\nfrom poetry.repositories.link_sources.base import LinkSource\n\n\nif TYPE_CHECKING:\n from poetry.repositories.link_sources.base import LinkCache\n\n\nclass SimpleJsonPage(LinkSource):\n \"\"\"Links as returned by PEP 691 compatible JSON-based Simple API.\"\"\"\n\n def __init__(self, url: str, content: dict[str, Any]) -> None:\n super().__init__(url=url)\n self.content = content\n\n @cached_property\n def _link_cache(self) -> LinkCache:\n links: LinkCache = defaultdict(lambda: defaultdict(list))\n for file in self.content[\"files\"]:\n url = file[\"url\"]\n requires_python = file.get(\"requires-python\")\n yanked = file.get(\"yanked\", False)\n\n # see https://peps.python.org/pep-0714/#clients\n # and https://peps.python.org/pep-0691/#project-detail\n metadata: dict[str, str] | bool = False\n for metadata_key in (\"core-metadata\", \"dist-info-metadata\"):\n if metadata_key in file:\n metadata_value = file[metadata_key]\n if metadata_value and isinstance(metadata_value, dict):\n metadata = metadata_value\n else:\n metadata = bool(metadata_value)\n break\n\n link = Link(\n url, requires_python=requires_python, yanked=yanked, metadata=metadata\n )\n\n if link.ext not in self.SUPPORTED_FORMATS:\n continue\n\n pkg = self.link_package_data(link)\n if pkg:\n links[pkg.name][pkg.version].append(link)\n\n return links\n\n# Path: src/poetry/repositories/parsers/__init__.py\n\n# Path: src/poetry/repositories/parsers/html_page_parser.py\nfrom __future__ import annotations\n\nfrom html.parser import HTMLParser\n\n\nclass HTMLPageParser(HTMLParser):\n def __init__(self) -> None:\n super().__init__()\n self.base_url: str | None = None\n self.anchors: list[dict[str, str | None]] = []\n\n def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:\n if tag == \"base\" and self.base_url is None:\n base_url = dict(attrs).get(\"href\")\n if base_url is not None:\n self.base_url = base_url\n elif tag == \"a\":\n self.anchors.append(dict(attrs))\n\n# Path: src/poetry/repositories/parsers/pypi_search_parser.py\nfrom __future__ import annotations\n\nimport functools\n\nfrom dataclasses import dataclass\nfrom html.parser import HTMLParser\nfrom typing import Callable\n\n\n# The following code was originally written for PDM project\n# https://github.com/pdm-project/pdm/blob/1f4f48a35cdded064def85df117bebf713f7c17a/src/pdm/models/search.py\n# and later changed to fit Poetry needs\n\n\n@dataclass\nclass Result:\n name: str = \"\"\n version: str = \"\"\n description: str = \"\"\n\n\nclass SearchResultParser(HTMLParser):\n \"\"\"A simple HTML parser for pypi.org search results.\"\"\"\n\n def __init__(self) -> None:\n super().__init__()\n self.results: list[Result] = []\n self._current: Result | None = None\n self._nest_anchors = 0\n self._data_callback: Callable[[str], None] | None = None\n\n @staticmethod\n def _match_class(attrs: list[tuple[str, str | None]], name: str) -> bool:\n attrs_map = dict(attrs)\n return name in 
(attrs_map.get(\"class\") or \"\").split()\n\n def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:\n if not self._current:\n if tag == \"a\" and self._match_class(attrs, \"package-snippet\"):\n self._current = Result()\n self._nest_anchors = 1\n else:\n if tag == \"span\" and self._match_class(attrs, \"package-snippet__name\"):\n self._data_callback = functools.partial(setattr, self._current, \"name\")\n elif tag == \"span\" and self._match_class(attrs, \"package-snippet__version\"):\n self._data_callback = functools.partial(\n setattr, self._current, \"version\"\n )\n elif tag == \"p\" and self._match_class(\n attrs, \"package-snippet__description\"\n ):\n self._data_callback = functools.partial(\n setattr, self._current, \"description\"\n )\n elif tag == \"a\":\n self._nest_anchors += 1\n\n def handle_data(self, data: str) -> None:\n if self._data_callback is not None:\n self._data_callback(data)\n self._data_callback = None\n\n def handle_endtag(self, tag: str) -> None:\n if tag != \"a\" or self._current is None:\n return\n self._nest_anchors -= 1\n if self._nest_anchors == 0:\n if self._current.name and self._current.version:\n self.results.append(self._current)\n self._current = None\n\n# Path: src/poetry/utils/env/__init__.py\nfrom __future__ import annotations\n\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom poetry.core.utils.helpers import temporary_directory\n\nfrom poetry.utils.env.base_env import Env\nfrom poetry.utils.env.env_manager import EnvManager\nfrom poetry.utils.env.exceptions import EnvCommandError\nfrom poetry.utils.env.exceptions import EnvError\nfrom poetry.utils.env.exceptions import IncorrectEnvError\nfrom poetry.utils.env.exceptions import InvalidCurrentPythonVersionError\nfrom poetry.utils.env.exceptions import NoCompatiblePythonVersionFound\nfrom poetry.utils.env.exceptions import PythonVersionNotFound\nfrom poetry.utils.env.generic_env import GenericEnv\nfrom poetry.utils.env.mock_env import MockEnv\nfrom poetry.utils.env.null_env import NullEnv\nfrom poetry.utils.env.script_strings import GET_BASE_PREFIX\nfrom poetry.utils.env.script_strings import GET_ENV_PATH_ONELINER\nfrom poetry.utils.env.script_strings import GET_ENVIRONMENT_INFO\nfrom poetry.utils.env.script_strings import GET_PATHS\nfrom poetry.utils.env.script_strings import GET_PATHS_FOR_GENERIC_ENVS\nfrom poetry.utils.env.script_strings import GET_PYTHON_VERSION\nfrom poetry.utils.env.script_strings import GET_PYTHON_VERSION_ONELINER\nfrom poetry.utils.env.script_strings import GET_SYS_PATH\nfrom poetry.utils.env.script_strings import GET_SYS_TAGS\nfrom poetry.utils.env.site_packages import SitePackages\nfrom poetry.utils.env.system_env import SystemEnv\nfrom poetry.utils.env.virtual_env import VirtualEnv\n\n\nif TYPE_CHECKING:\n from collections.abc import Iterator\n\n from cleo.io.io import IO\n from poetry.core.poetry import Poetry as CorePoetry\n\n\n@contextmanager\ndef ephemeral_environment(\n executable: Path | None = None,\n flags: dict[str, str | bool] | None = None,\n) -> Iterator[VirtualEnv]:\n with temporary_directory() as tmp_dir:\n # TODO: cache PEP 517 build environment corresponding to each project venv\n venv_dir = Path(tmp_dir) / \".venv\"\n EnvManager.build_venv(\n path=venv_dir,\n executable=executable,\n flags=flags,\n )\n yield VirtualEnv(venv_dir, venv_dir)\n\n\n@contextmanager\ndef build_environment(\n poetry: CorePoetry, env: Env | None = None, io: IO | None = None\n) -> Iterator[Env]:\n 
\"\"\"\n If a build script is specified for the project, there could be additional build\n time dependencies, eg: cython, setuptools etc. In these cases, we create an\n ephemeral build environment with all requirements specified under\n `build-system.requires` and return this. Otherwise, the given default project\n environment is returned.\n \"\"\"\n if not env or poetry.package.build_script:\n with ephemeral_environment(executable=env.python if env else None) as venv:\n if io:\n requires = [\n f\"{requirement}\"\n for requirement in poetry.pyproject.build_system.requires\n ]\n\n io.write_error_line(\n \"Preparing build environment with build-system requirements\"\n f\" {', '.join(requires)}\"\n )\n\n output = venv.run_pip(\n \"install\",\n \"--disable-pip-version-check\",\n \"--ignore-installed\",\n \"--no-input\",\n *poetry.pyproject.build_system.requires,\n )\n\n if io and io.is_debug() and output:\n io.write_error(output)\n\n yield venv\n else:\n yield env\n\n\n__all__ = [\n \"GET_BASE_PREFIX\",\n \"GET_ENVIRONMENT_INFO\",\n \"GET_PATHS\",\n \"GET_PYTHON_VERSION\",\n \"GET_SYS_PATH\",\n \"GET_SYS_TAGS\",\n \"GET_ENV_PATH_ONELINER\",\n \"GET_PYTHON_VERSION_ONELINER\",\n \"GET_PATHS_FOR_GENERIC_ENVS\",\n \"EnvError\",\n \"EnvCommandError\",\n \"IncorrectEnvError\",\n \"InvalidCurrentPythonVersionError\",\n \"NoCompatiblePythonVersionFound\",\n \"PythonVersionNotFound\",\n \"Env\",\n \"EnvManager\",\n \"GenericEnv\",\n \"MockEnv\",\n \"NullEnv\",\n \"SystemEnv\",\n \"VirtualEnv\",\n \"SitePackages\",\n \"build_environment\",\n \"ephemeral_environment\",\n]\n\n# Path: src/poetry/utils/env/base_env.py\nfrom __future__ import annotations\n\nimport contextlib\nimport os\nimport re\nimport subprocess\nimport sys\nimport sysconfig\n\nfrom pathlib import Path\nfrom subprocess import CalledProcessError\nfrom typing import TYPE_CHECKING\nfrom typing import Any\n\nfrom virtualenv.seed.wheels.embed import get_embed_wheel\n\nfrom poetry.utils.env.exceptions import EnvCommandError\nfrom poetry.utils.env.site_packages import SitePackages\nfrom poetry.utils.helpers import get_real_windows_path\n\n\nif TYPE_CHECKING:\n from packaging.tags import Tag\n from poetry.core.version.markers import BaseMarker\n from virtualenv.seed.wheels.util import Wheel\n\n from poetry.utils.env.generic_env import GenericEnv\n\n\nclass Env:\n \"\"\"\n An abstract Python environment.\n \"\"\"\n\n def __init__(self, path: Path, base: Path | None = None) -> None:\n self._is_windows = sys.platform == \"win32\"\n self._is_mingw = sysconfig.get_platform().startswith(\"mingw\")\n self._is_conda = bool(os.environ.get(\"CONDA_DEFAULT_ENV\"))\n\n if self._is_windows:\n path = get_real_windows_path(path)\n base = get_real_windows_path(base) if base else None\n\n bin_dir = \"bin\" if not self._is_windows or self._is_mingw else \"Scripts\"\n self._path = path\n self._bin_dir = self._path / bin_dir\n\n self._executable = \"python\"\n self._pip_executable = \"pip\"\n\n self.find_executables()\n\n self._base = base or path\n\n self._marker_env: dict[str, Any] | None = None\n self._site_packages: SitePackages | None = None\n self._paths: dict[str, str] | None = None\n self._supported_tags: list[Tag] | None = None\n self._purelib: Path | None = None\n self._platlib: Path | None = None\n self._script_dirs: list[Path] | None = None\n\n self._embedded_pip_path: Path | None = None\n\n @property\n def path(self) -> Path:\n return self._path\n\n @property\n \ndef base(self) -> Path:\n return self._base\n\n @property\n def version_info(self) -> 
tuple[int, int, int, str, int]:\n version_info: tuple[int, int, int, str, int] = self.marker_env[\"version_info\"]\n return version_info\n\n @property\n def python_implementation(self) -> str:\n implementation: str = self.marker_env[\"platform_python_implementation\"]\n return implementation\n\n @property\n def python(self) -> Path:\n \"\"\"\n Path to current python executable\n \"\"\"\n return Path(self._bin(self._executable))\n\n @property\n def marker_env(self) -> dict[str, Any]:\n if self._marker_env is None:\n self._marker_env = self.get_marker_env()\n\n return self._marker_env\n\n @property\n def parent_env(self) -> GenericEnv:\n from poetry.utils.env.generic_env import GenericEnv\n\n return GenericEnv(self.base, child_env=self)\n\n def _find_python_executable(self) -> None:\n bin_dir = self._bin_dir\n\n if self._is_windows and self._is_conda:\n bin_dir = self._path\n\n python_executables = sorted(\n p.name\n for p in bin_dir.glob(\"python*\")\n if re.match(r\"python(?:\\d+(?:\\.\\d+)?)?(?:\\.exe)?$\", p.name)\n )\n if python_executables:\n executable = python_executables[0]\n if executable.endswith(\".exe\"):\n executable = executable[:-4]\n\n self._executable = executable\n\n def _find_pip_executable(self) -> None:\n pip_executables = sorted(\n p.name\n for p in self._bin_dir.glob(\"pip*\")\n if re.match(r\"pip(?:\\d+(?:\\.\\d+)?)?(?:\\.exe)?$\", p.name)\n )\n if pip_executables:\n pip_executable = pip_executables[0]\n if pip_executable.endswith(\".exe\"):\n pip_executable = pip_executable[:-4]\n\n self._pip_executable = pip_executable\n\n def find_executables(self) -> None:\n self._find_python_executable()\n self._find_pip_executable()\n\n def get_embedded_wheel(self, distribution: str) -> Path:\n wheel: Wheel = get_embed_wheel(\n distribution, f\"{self.version_info[0]}.{self.version_info[1]}\"\n )\n path: Path = wheel.path\n return path\n\n @property\n def pip_embedded(self) -> Path:\n if self._embedded_pip_path is None:\n self._embedded_pip_path = self.get_embedded_wheel(\"pip\") / \"pip\"\n return self._embedded_pip_path\n\n @property\n def pip(self) -> Path:\n \"\"\"\n Path to current pip executable\n \"\"\"\n...\n", "needle_token_start": 15193, "needle_token_end": 15207, "code_context_ntokens": 15994, "output": ["```python\n@property\ndef base(self) -> Path:\n return self._base\n```"]} +{"repo": "locustio/locust", "name": "test_user_classes_with_same_name_is_error", "language": "python", "path": "locust/test/test_env.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To verify that an error is raised when two user classes with identical names are used in the same environment, ensuring unique identification of user classes.\n2. **Input**: Two user classes imported from different modules but with the same class name.\n3. **Output**: A `ValueError` exception indicating that user classes with duplicate names cannot be used together.\n4. **Procedure**: The test attempts to create an environment with the two conflicting user classes. 
It checks that the correct exception is raised and that the error message specifies the exact issue with the class names and their originating modules.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: locust/test/test_dispatch.py\nfrom __future__ import annotations\n\nfrom locust import User\nfrom locust.dispatch import UsersDispatcher\nfrom locust.runners import WorkerNode\nfrom locust.test.util import clear_all_functools_lru_cache\n\nimport time\nimport unittest\nfrom operator import attrgetter\n\n_TOLERANCE = 0.025\n\n\nclass TestRampUpUsersFromZero(unittest.TestCase):\n def test_ramp_up_users_to_3_workers_with_spawn_rate_of_0_5(self):\n \"\"\"Final distribution should be {\"User1\": 3, \"User2\": 3, \"User3\": 3}\"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n worker_node1 = WorkerNode(\"1\")\n worker_node2 = WorkerNode(\"2\")\n worker_node3 = WorkerNode(\"3\")\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher = UsersDispatcher(\n worker_nodes=[worker_node1, worker_node2, worker_node3], user_classes=[User1, User2, User3]\n )\n users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=0.5)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n...\n# Path: locust/test/test_env.py\nfrom locust import (\n constant,\n)\nfrom locust.dispatch import UsersDispatcher\nfrom locust.env import Environment, LoadTestShape\nfrom locust.user import (\n User,\n task,\n)\nfrom locust.user.task import TaskSet\n\nfrom .fake_module1_for_env_test import MyUserWithSameName as MyUserWithSameName1\nfrom .fake_module2_for_env_test import MyUserWithSameName as MyUserWithSameName2\nfrom .testcases import LocustTestCase\n\n\nclass TestEnvironment(LocustTestCase):\n def test_user_classes_count(self):\n class MyUser1(User):\n wait_time = constant(0)\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(0)\n\n @task\n def my_task(self):\n pass\n\n environment = Environment(user_classes=[MyUser1, MyUser2])\n\n self.assertDictEqual({\"MyUser1\": MyUser1, \"MyUser2\": MyUser2}, environment.user_classes_by_name)\n\n \ndef test_user_classes_with_same_name_is_error(self):\n with self.assertRaises(ValueError) as e:\n environment = Environment(user_classes=[MyUserWithSameName1, MyUserWithSameName2])\n\n self.assertEqual(\n e.exception.args[0],\n \"The following user classes have the same class name: locust.test.fake_module1_for_env_test.MyUserWithSameName, locust.test.fake_module2_for_env_test.MyUserWithSameName\",\n )\n\n def test_assign_equal_weights(self):\n def verify_tasks(u, target_tasks):\n self.assertEqual(len(u.tasks), len(target_tasks))\n tasks = [t.__name__ for t in u.tasks]\n self.assertEqual(len(tasks), len(set(tasks)))\n self.assertEqual(set(tasks), set(target_tasks))\n\n # Base case\n class MyUser1(User):\n wait_time = constant(0)\n\n @task(4)\n def my_task(self):\n pass\n\n @task(1)\n def my_task_2(self):\n pass\n\n environment = Environment(user_classes=[MyUser1])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"my_task\", \"my_task_2\"])\n\n # Testing nested task sets\n class MyUser2(User):\n @task\n class TopLevelTaskSet(TaskSet):\n @task\n 
class IndexTaskSet(TaskSet):\n @task(10)\n def index(self):\n self.client.get(\"/\")\n\n @task\n def stop(self):\n self.client.get(\"/hi\")\n\n @task(2)\n def stats(self):\n self.client.get(\"/stats/requests\")\n\n environment = Environment(user_classes=[MyUser2])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"index\", \"stop\", \"stats\"])\n\n # Testing task assignment via instance variable\n def outside_task():\n pass\n\n def outside_task_2():\n pass\n\n class SingleTaskSet(TaskSet):\n tasks = [outside_task, outside_task, outside_task_2]\n\n class MyUser3(User):\n tasks = [SingleTaskSet, outside_task]\n\n environment = Environment(user_classes=[MyUser3])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"outside_task\", \"outside_task_2\"])\n\n # Testing task assignment via dict\n class DictTaskSet(TaskSet):\n def dict_task_1():\n pass\n\n def dict_task_2():\n pass\n\n def dict_task_3():\n pass\n\n tasks = {\n dict_task_1: 5,\n dict_task_2: 3,\n dict_task_3: 1,\n }\n\n class MyUser4(User):\n tasks = [DictTaskSet, SingleTaskSet, SingleTaskSet]\n\n # Assign user tasks in dict\n environment = Environment(user_classes=[MyUser4])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"outside_task\", \"outside_task_2\", \"dict_task_1\", \"dict_task_2\", \"dict_task_3\"])\n\n class MyUser5(User):\n tasks = {\n DictTaskSet: 5,\n SingleTaskSet: 3,\n outside_task: 6,\n }\n\n environment = Environment(user_classes=[MyUser5])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"outside_task\", \"outside_task_2\", \"dict_task_1\", \"dict_task_2\", \"dict_task_3\"])\n\n def test_user_classes_with_zero_weight_are_removed(self):\n class MyUser1(User):\n wait_time = constant(0)\n weight = 0\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(0)\n weight = 1\n\n @task\n def my_task(self):\n pass\n\n environment = Environment(user_classes=[MyUser1, MyUser2])\n\n self.assertEqual(len(environment.user_classes), 1)\n self.assertIs(environment.user_classes[0], MyUser2)\n\n def test_all_user_classes_with_zero_weight_raises_exception(self):\n class MyUser1(User):\n wait_time = constant(0)\n weight = 0\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(0)\n weight = 0\n\n @task\n def my_task(self):\n pass\n\n with self.assertRaises(ValueError) as e:\n environment = Environment(user_classes=[MyUser1, MyUser2])\n\n self.assertEqual(\n e.exception.args[0],\n \"There are no users with weight > 0.\",\n )\n\n def test_shape_class_attribute(self):\n class SubLoadTestShape(LoadTestShape):\n \"\"\"Inherited from locust.env.LoadTestShape\"\"\"\n\n with self.assertRaisesRegex(\n ValueError, r\"instance of LoadTestShape or subclass LoadTestShape\", msg=\"exception message is mismatching\"\n ):\n Environment(user_classes=[MyUserWithSameName1], shape_class=SubLoadTestShape)\n\n def test_dispatcher_class_attribute(self):\n environment = Environment(user_classes=[MyUserWithSameName1])\n\n self.assertEqual(environment.dispatcher_class, UsersDispatcher)\n\n class MyUsersDispatcher(UsersDispatcher):\n pass\n\n environment = Environment(user_classes=[MyUserWithSameName1], dispatcher_class=MyUsersDispatcher)\n\n self.assertEqual(environment.dispatcher_class, MyUsersDispatcher)\n\n def test_update_user_class(self):\n class MyUser1(User):\n @task\n def my_task(self):\n pass\n\n @task\n def my_task_2(self):\n 
pass\n\n class MyUser2(User):\n @task\n def my_task(self):\n pass\n\n environment = Environment(\n user_classes=[MyUser1, MyUser2],\n available_user_classes={\"User1\": MyUser1, \"User2\": MyUser2},\n available_user_tasks={\"User1\": MyUser1.tasks, \"User2\": MyUser2.tasks},\n )\n\n environment.update_user_class({\"user_class_name\": \"User1\", \"host\": \"http://localhost\", \"tasks\": [\"my_task_2\"]})\n\n self.assertEqual(\n environment.available_user_classes[\"User1\"].json(),\n {\"host\": \"http://localhost\", \"tasks\": [\"my_task_2\"], \"fixed_count\": 0, \"weight\": 1},\n )\n\n# Path: locust/test/test_fasthttp.py\nfrom locust import FastHttpUser\nfrom locust.argument_parser import parse_options\nfrom locust.contrib.fasthttp import FastHttpSession\nfrom locust.exception import CatchResponseError, InterruptTaskSet, LocustError, ResponseError\nfrom locust.user import TaskSet, task\nfrom locust.util.load_locustfile import is_user_class\n\nimport socket\nimport time\nfrom tempfile import NamedTemporaryFile\n\nimport gevent\nfrom geventhttpclient.client import HTTPClientPool\n\nfrom .testcases import LocustTestCase, WebserverTestCase\nfrom .util import create_tls_cert\n\n\nclass TestFastHttpSession(WebserverTestCase):\n def get_client(self):\n return FastHttpSession(self.environment, base_url=\"http://127.0.0.1:%i\" % self.port, user=None)\n\n def test_get(self):\n s = self.get_client()\n r = s.get(\"/ultra_fast\")\n self.assertEqual(200, r.status_code)\n\n def test_connection_error(self):\n s = FastHttpSession(self.environment, \"http://localhost:1\", user=None)\n r = s.get(\"/\", headers={\"X-Test-Headers\": \"hello\"})\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n self.assertEqual(1, len(self.runner.stats.errors))\n self.assertTrue(isinstance(r.error, ConnectionRefusedError))\n self.assertTrue(isinstance(next(iter(self.runner.stats.errors.values())).error, ConnectionRefusedError))\n self.assertEqual(r.url, \"http://localhost:1/\")\n self.assertEqual(r.request.url, r.url)\n self.assertEqual(r.request.headers.get(\"X-Test-Headers\", \"\"), \"hello\")\n\n def test_404(self):\n s = self.get_client()\n r = s.get(\"/does_not_exist\")\n self.assertEqual(404, r.status_code)\n self.assertEqual(1, self.runner.stats.get(\"/does_not_exist\", \"GET\").num_failures)\n\n def test_204(self):\n s = self.get_client()\n r = s.get(\"/status/204\")\n self.assertEqual(204, r.status_code)\n self.assertEqual(1, self.runner.stats.get(\"/status/204\", \"GET\").num_requests)\n self.assertEqual(0, self.runner.stats.get(\"/status/204\", \"GET\").num_failures)\n self.assertEqual(r.url, \"http://127.0.0.1:%i/status/204\" % self.port)\n self.assertEqual(r.request.url, r.url)\n\n def test_streaming_response(self):\n \"\"\"\n Test a request to an endpoint that returns a streaming response\n \"\"\"\n s = self.get_client()\n r = s.get(\"/streaming/30\")\n\n # verify that the time reported includes the download time of the whole streamed response\n self.assertGreater(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n self.runner.stats.clear_all()\n\n # verify that response time does NOT include whole download time, when using stream=True\n r = s.get(\"/streaming/30\", stream=True)\n self.assertGreaterEqual(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 0)\n self.assertLess(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n\n # download the content of the streaming response (so we don't get an ugly 
exception in the log)\n _ = r.content\n\n def test_slow_redirect(self):\n s = self.get_client()\n url = \"/redirect?url=/redirect&delay=0.5\"\n r = s.get(url)\n stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, stats.num_requests)\n self.assertGreater(stats.avg_response_time, 500)\n\n def test_post_redirect(self):\n s = self.get_client()\n url = \"/redirect\"\n r = s.post(url)\n self.assertEqual(200, r.status_code)\n post_stats = self.runner.stats.get(url, method=\"POST\")\n get_stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, post_stats.num_requests)\n self.assertEqual(0, get_stats.num_requests)\n\n def test_cookie(self):\n s = self.get_client()\n r = s.post(\"/set_cookie?name=testcookie&value=1337\")\n self.assertEqual(200, r.status_code)\n r = s.get(\"/get_cookie?name=testcookie\")\n self.assertEqual(\"1337\", r.content.decode())\n self.assertEqual(\"1337\", r.text)\n\n def test_head(self):\n s = self.get_client()\n r = s.head(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n\n def test_delete(self):\n s = self.get_client()\n r = s.delete(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"DELETE\", r.content.decode())\n\n def test_patch(self):\n s = self.get_client()\n r = s.patch(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"PATCH\", r.content.decode())\n\n def test_options(self):\n s = self.get_client()\n r = s.options(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n self.assertEqual(\n {\"OPTIONS\", \"DELETE\", \"PUT\", \"GET\", \"POST\", \"HEAD\", \"PATCH\"},\n set(r.headers[\"allow\"].split(\", \")),\n )\n\n def test_json_payload(self):\n s = self.get_client()\n r = s.post(\"/request_method\", json={\"foo\": \"bar\"})\n self.assertEqual(200, r.status_code)\n self.assertEqual(r.request.body, '{\"foo\": \"bar\"}')\n self.assertEqual(r.request.headers.get(\"Content-Type\"), \"application/json\")\n\n def test_catch_response_fail_successful_request(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_pass_failed_request(self):\n s = self.get_client()\n with s.get(\"/fail\", catch_response=True) as r:\n r.success()\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_multiple_failure_and_success(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n r.success()\n r.failure(\"nooo\")\n r.success()\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_pass_failed_request_with_other_exception_within_block(self):\n class OtherException(Exception):\n pass\n\n s = self.get_client()\n try:\n with s.get(\"/fail\", catch_response=True) as r:\n r.success()\n raise OtherException(\"wtf\")\n except OtherException as e:\n pass\n else:\n self.fail(\"OtherException should have been raised\")\n\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def 
test_catch_response_default_success(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n pass\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(0, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_default_fail(self):\n s = self.get_client()\n with s.get(\"/fail\", catch_response=True) as r:\n pass\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(1, self.environment.stats.total.num_failures)\n\n def test_error_message_with_name_replacement(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n self.assertIsNotNone(kw[\"exception\"])\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n before_request = time.time()\n s.request(\"get\", \"/wrong_url/01\", name=\"replaced_url_name\", context={\"foo\": \"bar\"})\n after_request = time.time()\n # self.assertIn(\"for url: replaced_url_name\", str(kwargs[\"exception\"])) # this is actually broken for FastHttpUser right now...\n self.assertAlmostEqual(before_request, kwargs[\"start_time\"], delta=0.01)\n self.assertAlmostEqual(after_request, kwargs[\"start_time\"] + kwargs[\"response_time\"] / 1000, delta=0.01)\n self.assertEqual(s.base_url + \"/wrong_url/01\", kwargs[\"url\"]) # url is unaffected by name\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_custom_ssl_context_fail_with_bad_context(self):\n \"\"\"\n Test FastHttpSession with a custom SSLContext factory that will fail as\n we can not set verify_mode to CERT_NONE when check_hostname is enabled\n \"\"\"\n\n def create_custom_context():\n context = gevent.ssl.create_default_context()\n context.check_hostname = True\n context.verify_mode = gevent.ssl.CERT_NONE\n return context\n\n s = FastHttpSession(\n self.environment,\n \"https://127.0.0.1:%i\" % self.port,\n ssl_context_factory=create_custom_context,\n user=None,\n )\n with self.assertRaises(ValueError) as e:\n s.get(\"/\")\n self.assertEqual(e.exception.args, (\"Cannot set verify_mode to CERT_NONE when check_hostname is enabled.\",))\n\n def test_custom_ssl_context_passed_correct_to_client_pool(self):\n \"\"\"\n Test FastHttpSession with a custom SSLContext factory with a options.name\n that will be passed correctly to the ClientPool. 
It will also test a 2nd\n factory which is not the correct one.\n \"\"\"\n\n def custom_ssl_context():\n context = gevent.ssl.create_default_context()\n context.check_hostname = False\n context.verify_mode = gevent.ssl.CERT_NONE\n context.options.name = \"FAKEOPTION\"\n return context\n\n def custom_context_with_wrong_option():\n context = gevent.ssl.create_default_context()\n context.check_hostname = False\n context.verify_mode = gevent.ssl.CERT_NONE\n context.options.name = \"OPTIONFAKED\"\n return context\n\n s = FastHttpSession(\n self.environment,\n \"https://127.0.0.1:%i\" % self.port,\n ssl_context_factory=custom_ssl_context,\n user=None,\n )\n self.assertEqual(s.client.clientpool.client_args[\"ssl_context_factory\"], custom_ssl_context)\n self.assertNotEqual(s.client.clientpool.client_args[\"ssl_context_factory\"], custom_context_with_wrong_option)\n\n\nclass TestRequestStatsWithWebserver(WebserverTestCase):\n def test_request_stats_content_length(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.get(\"/ultra_fast\")\n self.assertEqual(\n self.runner.stats.get(\"/ultra_fast\", \"GET\").avg_content_length, len(\"This is an ultra fast response\")\n )\n locust.client.get(\"/ultra_fast\")\n self.assertEqual(\n self.runner.stats.get(\"/ultra_fast\", \"GET\").avg_content_length, len(\"This is an ultra fast response\")\n )\n\n def test_request_stats_no_content_length(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n path = \"/no_content_length\"\n r = l.client.get(path)\n self.assertEqual(\n self.runner.stats.get(path, \"GET\").avg_content_length,\n len(\"This response does not have content-length in the header\"),\n )\n\n def test_request_stats_no_content_length_streaming(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n path = \"/no_content_length\"\n r = l.client.get(path, stream=True)\n self.assertEqual(0, self.runner.stats.get(path, \"GET\").avg_content_length)\n\n def test_request_stats_named_endpoint(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.get(\"/ultra_fast\", name=\"my_custom_name\")\n self.assertEqual(1, self.runner.stats.get(\"my_custom_name\", \"GET\").num_requests)\n\n def test_request_stats_query_variables(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.get(\"/ultra_fast?query=1\")\n self.assertEqual(1, self.runner.stats.get(\"/ultra_fast?query=1\", \"GET\").num_requests)\n\n def test_request_stats_put(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.put(\"/put\")\n self.assertEqual(1, self.runner.stats.get(\"/put\", \"PUT\").num_requests)\n\n def test_request_connection_error(self):\n class MyUser(FastHttpUser):\n host = \"http://localhost:1\"\n\n locust = MyUser(self.environment)\n response = locust.client.get(\"/\")\n self.assertEqual(response.status_code, 0)\n self.assertEqual(1, self.runner.stats.get(\"/\", \"GET\").num_failures)\n self.assertEqual(1, self.runner.stats.get(\"/\", \"GET\").num_requests)\n\n\nclass TestFastHttpUserClass(WebserverTestCase):\n def test_is_abstract(self):\n self.assertTrue(FastHttpUser.abstract)\n self.assertFalse(is_user_class(FastHttpUser))\n\n def 
test_class_context(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n def context(self):\n return {\"user\": self.username}\n\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n user = MyUser(self.environment)\n user.username = \"foo\"\n user.client.request(\"get\", \"/request_method\")\n self.assertDictEqual({\"user\": \"foo\"}, kwargs[\"context\"])\n self.assertEqual(\"GET\", kwargs[\"response\"].text)\n user.client.request(\"get\", \"/request_method\", context={\"user\": \"bar\"})\n self.assertDictEqual({\"user\": \"bar\"}, kwargs[\"context\"])\n\n def test_get_request(self):\n self.response = \"\"\n\n def t1(l):\n self.response = l.client.get(\"/ultra_fast\")\n\n class MyUser(FastHttpUser):\n tasks = [t1]\n host = \"http://127.0.0.1:%i\" % self.port\n\n my_locust = MyUser(self.environment)\n t1(my_locust)\n self.assertEqual(self.response.text, \"This is an ultra fast response\")\n\n def test_client_request_headers(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n r = locust.client.get(\"/request_header_test\", headers={\"X-Header-Test\": \"hello\"})\n self.assertEqual(\"hello\", r.text)\n self.assertEqual(\"hello\", r.headers.get(\"X-Header-Test\"))\n self.assertEqual(\"hello\", r.request.headers.get(\"X-Header-Test\"))\n\n def test_client_get(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"GET\", locust.client.get(\"/request_method\").text)\n\n def test_client_get_absolute_url(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"GET\", locust.client.get(\"http://127.0.0.1:%i/request_method\" % self.port).text)\n\n def test_client_post(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"POST\", locust.client.post(\"/request_method\", {\"arg\": \"hello world\"}).text)\n self.assertEqual(\"hello world\", locust.client.post(\"/post\", {\"arg\": \"hello world\"}).text)\n\n def test_client_put(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"PUT\", locust.client.put(\"/request_method\", {\"arg\": \"hello world\"}).text)\n self.assertEqual(\"hello world\", locust.client.put(\"/put\", {\"arg\": \"hello world\"}).text)\n\n def test_client_delete(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"DELETE\", locust.client.delete(\"/request_method\").text)\n self.assertEqual(200, locust.client.delete(\"/request_method\").status_code)\n\n def test_client_head(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(200, locust.client.head(\"/request_method\").status_code)\n\n def test_complex_content_type(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n\n self.assertEqual(\"stuff\", locust.client.get(\"/content_type_missing_charset\").text)\n self.assertEqual(\"stuff\", locust.client.get(\"/content_type_regular\").text)\n self.assertEqual(\"stuff\", locust.client.get(\"/content_type_with_extra_stuff\").text)\n\n def 
test_log_request_name_argument(self):\n self.response = \"\"\n\n class MyUser(FastHttpUser):\n tasks = []\n host = \"http://127.0.0.1:%i\" % self.port\n\n @task()\n def t1(l):\n self.response = l.client.get(\"/ultra_fast\", name=\"new name!\")\n\n my_locust = MyUser(self.environment)\n my_locust.t1()\n\n self.assertEqual(1, self.runner.stats.get(\"new name!\", \"GET\").num_requests)\n self.assertEqual(0, self.runner.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n\n def test_redirect_url_original_path_as_name(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n l.client.get(\"/redirect\")\n\n self.assertEqual(1, len(self.runner.stats.entries))\n self.assertEqual(1, self.runner.stats.get(\"/redirect\", \"GET\").num_requests)\n self.assertEqual(0, self.runner.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n\n def test_network_timeout_setting(self):\n class MyUser(FastHttpUser):\n network_timeout = 0.5\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n\n timeout = gevent.Timeout(\n seconds=0.6,\n exception=AssertionError(\n \"Request took longer than 0.6 even though FastHttpUser.network_timeout was set to 0.5\"\n ),\n )\n timeout.start()\n r = l.client.get(\"/redirect?url=/redirect&delay=5.0\")\n timeout.cancel()\n\n self.assertTrue(isinstance(r.error.original, socket.timeout))\n self.assertEqual(1, self.runner.stats.get(\"/redirect?url=/redirect&delay=5.0\", \"GET\").num_failures)\n\n def test_max_redirect_setting(self):\n class MyUser(FastHttpUser):\n max_redirects = 1 # max_redirects and max_retries are funny names, because they are actually max attempts\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n l.client.get(\"/redirect\")\n self.assertEqual(1, self.runner.stats.get(\"/redirect\", \"GET\").num_failures)\n\n def test_allow_redirects_override(self):\n class MyLocust(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyLocust(self.environment)\n resp = l.client.get(\"/redirect\", allow_redirects=False)\n self.assertTrue(resp.headers[\"location\"].endswith(\"/ultra_fast\"))\n resp = l.client.get(\"/redirect\") # ensure redirect still works\n self.assertFalse(\"location\" in resp.headers)\n\n def test_slow_redirect(self):\n s = FastHttpSession(self.environment, \"http://127.0.0.1:%i\" % self.port, user=None)\n url = \"/redirect?url=/redirect&delay=0.5\"\n r = s.get(url)\n stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, stats.num_requests)\n self.assertGreater(stats.avg_response_time, 500)\n\n def test_client_basic_auth(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n class MyAuthorizedUser(FastHttpUser):\n host = \"http://locust:menace@127.0.0.1:%i\" % self.port\n\n class MyUnauthorizedUser(FastHttpUser):\n host = \"http://locust:wrong@127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n unauthorized = MyUnauthorizedUser(self.environment)\n authorized = MyAuthorizedUser(self.environment)\n response = authorized.client.get(\"/basic_auth\")\n self.assertEqual(200, response.status_code)\n self.assertEqual(\"Authorized\", response.text)\n self.assertEqual(401, locust.client.get(\"/basic_auth\").status_code)\n self.assertEqual(401, unauthorized.client.get(\"/basic_auth\").status_code)\n\n def test_shared_client_pool(self):\n shared_client_pool = HTTPClientPool(concurrency=1)\n\n class MyUserA(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n client_pool = 
shared_client_pool\n\n class MyUserB(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n client_pool = shared_client_pool\n\n user_a = MyUserA(self.environment)\n user_b = MyUserB(self.environment)\n\n user_a.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_a.client.get(\"/ultra_fast\")\n\n self.assertEqual(1, self.connections_count)\n self.assertEqual(4, self.requests_count)\n\n def test_client_pool_per_user_instance(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n user_a = MyUser(self.environment)\n user_b = MyUser(self.environment)\n\n user_a.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_a.client.get(\"/ultra_fast\")\n\n self.assertEqual(2, self.connections_count)\n self.assertEqual(4, self.requests_count)\n\n def test_client_pool_concurrency(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n @task\n def t(self):\n def concurrent_request(url):\n response = self.client.get(url)\n assert response.status_code == 200\n\n pool = gevent.pool.Pool()\n urls = [\"/slow?delay=0.2\"] * 20 # these urls are all the same, but they could be different\n for url in urls:\n pool.spawn(concurrent_request, url)\n pool.join()\n\n user = MyUser(self.environment)\n before_requests = time.time()\n user.t()\n after_requests = time.time()\n expected_delta = 0.4 # 20 requests with concurrency 10 and response time 0.2\n self.assertAlmostEqual(before_requests + expected_delta, after_requests, delta=0.1)\n\n\nclass TestFastHttpCatchResponse(WebserverTestCase):\n def setUp(self):\n super().setUp()\n\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n self.user = MyUser(self.environment)\n\n self.num_failures = 0\n self.num_success = 0\n\n def on_request(exception, **kwargs):\n if exception:\n self.num_failures += 1\n self.last_failure_exception = exception\n else:\n self.num_success += 1\n\n self.environment.events.request.add_listener(on_request)\n\n def test_catch_response(self):\n self.assertEqual(500, self.user.client.get(\"/fail\").status_code)\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n with self.user.client.get(\"/ultra_fast\", catch_response=True) as response:\n pass\n self.assertEqual(1, self.num_failures)\n self.assertEqual(1, self.num_success)\n self.assertIn(\"ultra fast\", str(response.content))\n\n with self.user.client.get(\"/ultra_fast\", catch_response=True) as response:\n raise ResponseError(\"Not working\")\n\n self.assertEqual(2, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_catch_response_http_fail(self):\n with self.user.client.get(\"/fail\", catch_response=True) as response:\n pass\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n def test_catch_response_http_manual_fail(self):\n with self.user.client.get(\"/ultra_fast\", catch_response=True) as response:\n response.failure(\"Haha!\")\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n self.assertTrue(\n isinstance(self.last_failure_exception, CatchResponseError),\n \"Failure event handler should have been passed a CatchResponseError instance\",\n )\n\n def test_catch_response_http_manual_success(self):\n with self.user.client.get(\"/fail\", catch_response=True) as response:\n response.success()\n self.assertEqual(0, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def 
test_catch_response_allow_404(self):\n with self.user.client.get(\"/does/not/exist\", catch_response=True) as response:\n self.assertEqual(404, response.status_code)\n if response.status_code == 404:\n response.success()\n self.assertEqual(0, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_interrupt_taskset_with_catch_response(self):\n class MyTaskSet(TaskSet):\n @task\n def interrupted_task(self):\n with self.client.get(\"/ultra_fast\", catch_response=True) as r:\n raise InterruptTaskSet()\n\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n tasks = [MyTaskSet]\n\n l = MyUser(self.environment)\n ts = MyTaskSet(l)\n self.assertRaises(InterruptTaskSet, lambda: ts.interrupted_task())\n self.assertEqual(0, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n def test_catch_response_connection_error_success(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:1\"\n\n l = MyUser(self.environment)\n with l.client.get(\"/\", catch_response=True) as r:\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n r.success()\n self.assertEqual(1, self.num_success)\n self.assertEqual(0, self.num_failures)\n\n def test_catch_response_connection_error_fail(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:1\"\n\n l = MyUser(self.environment)\n with l.client.get(\"/\", catch_response=True) as r:\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n r.failure(\"Manual fail\")\n self.assertEqual(0, self.num_success)\n self.assertEqual(1, self.num_failures)\n\n def test_catch_response_missing_with_block(self):\n # incorrect usage, missing with-block\n r = self.user.client.get(\"/fail\", catch_response=True)\n self.assertRaises(LocustError, r.success)\n self.assertRaises(LocustError, r.failure, \"\")\n\n def test_missing_catch_response_true(self):\n # incorrect usage, missing catch_response=True\n with self.user.client.get(\"/fail\") as resp:\n self.assertRaises(LocustError, resp.success)\n\n def test_rest_success(self):\n self.last_failure_exception = None\n with self.user.rest(\"POST\", \"/rest\", json={\"foo\": \"bar\"}) as response:\n assert response.js[\"foo\"] == \"bar\"\n\n self.assertEqual(0, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_rest_fail(self):\n with self.user.rest(\"POST\", \"/rest\", json={\"foo\": \"bar\"}) as response:\n assert response.js[\"foo\"] == \"NOPE\"\n\n self.assertTrue(\n isinstance(self.last_failure_exception, CatchResponseError),\n \"Failure event handler should have been passed a CatchResponseError instance\",\n )\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n\nclass TestFastHttpSsl(LocustTestCase):\n def setUp(self):\n super().setUp()\n tls_cert, tls_key = create_tls_cert(\"127.0.0.1\")\n self.tls_cert_file = NamedTemporaryFile()\n self.tls_key_file = NamedTemporaryFile()\n with open(self.tls_cert_file.name, \"w\") as f:\n f.write(tls_cert.decode())\n with open(self.tls_key_file.name, \"w\") as f:\n f.write(tls_key.decode())\n\n self.web_ui = self.environment.create_web_ui(\n \"127.0.0.1\",\n 0,\n tls_cert=self.tls_cert_file.name,\n tls_key=self.tls_key_file.name,\n )\n gevent.sleep(0.01)\n self.web_port = self.web_ui.server.server_port\n\n def tearDown(self):\n super().tearDown()\n self.web_ui.stop()\n\n def test_ssl_request_insecure(self):\n s = FastHttpSession(self.environment, \"https://127.0.0.1:%i\" % self.web_port, insecure=True, user=None)\n r = s.get(\"/\")\n 
self.assertEqual(200, r.status_code)\n self.assertIn(\"Locust for None\", r.content.decode(\"utf-8\"))\n self.assertIn(\"
Script: None
\", r.text)\n\n# Path: locust/test/test_http.py\nfrom locust.clients import HttpSession\nfrom locust.exception import LocustError, ResponseError\nfrom locust.user.users import HttpUser\n\nimport time\n\nfrom requests.exceptions import InvalidSchema, InvalidURL, MissingSchema, RequestException\n\nfrom .testcases import WebserverTestCase\n\n\nclass TestHttpSession(WebserverTestCase):\n def get_client(self, base_url=None):\n if base_url is None:\n base_url = \"http://127.0.0.1:%i\" % self.port\n return HttpSession(\n base_url=base_url,\n request_event=self.environment.events.request,\n user=None,\n )\n\n def test_get(self):\n s = self.get_client()\n r = s.get(\"/ultra_fast\")\n self.assertEqual(200, r.status_code)\n\n def test_connection_error(self):\n s = self.get_client(base_url=\"http://localhost:1\")\n r = s.get(\"/\", timeout=0.1)\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n self.assertRaises(RequestException, r.raise_for_status)\n\n def test_wrong_url(self):\n for url, exception in (\n (\"http://\\x94\", InvalidURL),\n (\"telnet://127.0.0.1\", InvalidSchema),\n (\"127.0.0.1\", MissingSchema),\n ):\n s = self.get_client(base_url=url)\n try:\n self.assertRaises(exception, s.get, \"/\")\n except KeyError:\n self.fail(f\"Invalid URL {url} was not propagated\")\n\n def test_streaming_response(self):\n \"\"\"\n Test a request to an endpoint that returns a streaming response\n \"\"\"\n s = self.get_client()\n r = s.get(\"/streaming/30\")\n\n # verify that the time reported includes the download time of the whole streamed response\n self.assertGreater(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n self.runner.stats.clear_all()\n\n # verify that response time does NOT include whole download time, when using stream=True\n r = s.get(\"/streaming/30\", stream=True)\n self.assertGreater(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 0)\n self.assertLess(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n\n # download the content of the streaming response (so we don't get an ugly exception in the log)\n _ = r.content\n\n def test_slow_redirect(self):\n s = self.get_client()\n url = \"/redirect?url=/redirect&delay=0.5\"\n r = s.get(url)\n stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, stats.num_requests)\n self.assertGreater(stats.avg_response_time, 500)\n\n def test_post_redirect(self):\n s = self.get_client()\n url = \"/redirect\"\n r = s.post(url)\n self.assertEqual(200, r.status_code)\n post_stats = self.runner.stats.get(url, method=\"POST\")\n get_stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, post_stats.num_requests)\n self.assertEqual(0, get_stats.num_requests)\n\n def test_cookie(self):\n s = self.get_client()\n r = s.post(\"/set_cookie?name=testcookie&value=1337\")\n self.assertEqual(200, r.status_code)\n r = s.get(\"/get_cookie?name=testcookie\")\n self.assertEqual(\"1337\", r.content.decode())\n\n def test_head(self):\n s = self.get_client()\n r = s.head(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n\n def test_delete(self):\n s = self.get_client()\n r = s.delete(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"DELETE\", r.content.decode())\n\n def test_options(self):\n s = self.get_client()\n r = s.options(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n 
self.assertEqual(\n {\"OPTIONS\", \"DELETE\", \"PUT\", \"GET\", \"POST\", \"HEAD\", \"PATCH\"},\n set(r.headers[\"allow\"].split(\", \")),\n )\n\n def test_error_message(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n s.request(\"get\", \"/wrong_url\", context={\"foo\": \"bar\"})\n self.assertIn(\"/wrong_url\", str(kwargs[\"exception\"]))\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_context_in_success(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(exception, **kw):\n self.assertIsNone(exception)\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n s.request(\"get\", \"/request_method\", context={\"foo\": \"bar\"})\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_response_parameter(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n s.request(\"get\", \"/request_method\")\n self.assertEqual(\"GET\", kwargs[\"response\"].text)\n s.request(\"get\", \"/wrong_url\")\n self.assertEqual(\"Not Found\", kwargs[\"response\"].text)\n\n def test_error_message_with_name_replacement(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n self.assertIsNotNone(kw[\"exception\"])\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n before_request = time.time()\n s.request(\"get\", \"/wrong_url/01\", name=\"replaced_url_name\", context={\"foo\": \"bar\"})\n after_request = time.time()\n self.assertIn(\"for url: replaced_url_name\", str(kwargs[\"exception\"]))\n self.assertAlmostEqual(before_request, kwargs[\"start_time\"], delta=0.01)\n self.assertAlmostEqual(after_request, kwargs[\"start_time\"] + kwargs[\"response_time\"] / 1000, delta=0.01)\n self.assertEqual(s.base_url + \"/wrong_url/01\", kwargs[\"url\"]) # url is unaffected by name\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_get_with_params(self):\n s = self.get_client()\n r = s.get(\"/get_arg\", params={\"arg\": \"test_123\"})\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"test_123\", r.text)\n\n def test_catch_response_fail_successful_request(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_fail_successful_request_with_non_string_error_message(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure({\"other types are also wrapped as exceptions\": True})\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_pass_failed_request(self):\n s = self.get_client()\n with s.get(\"/fail\", catch_response=True) as r:\n r.success()\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_multiple_failure_and_success(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n r.success()\n r.failure(\"nooo\")\n r.success()\n self.assertEqual(1, 
self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_timeout(self):\n s = self.get_client()\n with s.get(\"/slow\", catch_response=True, timeout=0.1) as r:\n self.assertAlmostEqual(r.request_meta[\"response_time\"], 100, delta=50)\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(1, self.environment.stats.total.num_failures)\n\n def test_catch_response_pass_failed_request_with_other_exception_within_block(self):\n class OtherException(Exception):\n pass\n\n s = self.get_client()\n try:\n with s.get(\"/fail\", catch_response=True) as r:\n r.success()\n raise OtherException(\"wtf\")\n except OtherException as e:\n pass\n else:\n self.fail(\"OtherException should have been raised\")\n\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_response_error(self):\n s = self.get_client()\n try:\n with s.get(\"/fail\", catch_response=True) as r:\n raise ResponseError(\"response error\")\n except ResponseError as e:\n self.fail(\"ResponseError should not have been raised\")\n\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(1, self.environment.stats.total.num_failures)\n\n def test_catch_response_default_success(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n pass\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(0, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_default_fail(self):\n s = self.get_client()\n with s.get(\"/fail\", catch_response=True) as r:\n pass\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(1, self.environment.stats.total.num_failures)\n\n def test_catch_response_with_name_replacement(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n self.assertIsNotNone(kw[\"exception\"])\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n\n with s.get(\"/wrong_url/01\", name=\"replaced_url_name\") as r:\n pass\n\n self.assertIn(\"for url: replaced_url_name\", str(kwargs[\"exception\"]))\n self.assertEqual(s.base_url + \"/wrong_url/01\", kwargs[\"url\"]) # url is unaffected by name\n\n def test_catch_response_missing_with_block(self):\n s = self.get_client()\n # incorrect usage, missing with-block\n r = s.get(\"/fail\", catch_response=True)\n self.assertRaises(LocustError, r.success)\n self.assertRaises(LocustError, r.failure, \"\")\n\n def test_missing_catch_response_true(self):\n s = self.get_client()\n # incorrect usage, missing catch_response=True\n with s.get(\"/fail\") as resp:\n self.assertRaises(LocustError, resp.success)\n\n def test_event_measure(self):\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(**kw)\n\n self.environment.events.request.add_listener(on_request)\n\n with self.environment.events.request.measure(\"GET\", \"/test\") as request_meta:\n time.sleep(0.001)\n\n self.assertTrue(1 <= kwargs[\"response_time\"] <= 1.5, kwargs[\"response_time\"])\n self.assertEqual(kwargs[\"name\"], \"/test\")\n self.assertIsNone(kwargs[\"exception\"])\n\n with self.environment.events.request.measure(\"GET\", \"/test\") as request_meta:\n request_meta[\"foo\"] = \"bar\"\n raise Exception(\"nooo\")\n\n self.assertEqual(kwargs[\"name\"], \"/test\")\n self.assertEqual(kwargs[\"foo\"], \"bar\")\n 
self.assertEqual(str(kwargs[\"exception\"]), \"nooo\")\n\n def test_user_context(self):\n class TestUser(HttpUser):\n host = f\"http://127.0.0.1:{self.port}\"\n\n def context(self):\n return {\"user\": self.username}\n\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n\n user = TestUser(self.environment)\n user.username = \"foo\"\n user.client.request(\"get\", \"/request_method\")\n self.assertDictEqual({\"user\": \"foo\"}, kwargs[\"context\"])\n self.assertEqual(\"GET\", kwargs[\"response\"].text)\n user.client.request(\"get\", \"/request_method\", context={\"user\": \"bar\"}) # override User context\n...\n", "needle_token_start": 709, "needle_token_end": 852, "code_context_ntokens": 16084, "output": ["```python\ndef test_user_classes_with_same_name_is_error(self):\n with self.assertRaises(ValueError) as e:\n environment = Environment(user_classes=[MyUserWithSameName1, MyUserWithSameName2])\n\n self.assertEqual(\n e.exception.args[0],\n \"The following user classes have the same class name: locust.test.fake_module1_for_env_test.MyUserWithSameName, locust.test.fake_module2_for_env_test.MyUserWithSameName\",\n )\n```"]} +{"repo": "locustio/locust", "name": "rename_request", "language": "python", "path": "locust/clients.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To temporarily change the label used for grouping performance statistics of web requests in a session, facilitating the aggregation of different URLs under a common name in statistical reports.\n2. **Input**: A string representing the new label name for the grouping of requests.\n3. **Output**: This method does not return any value; it acts as a context manager that temporarily changes an internal state.\n4. **Procedure**: The method sets an internal label to the provided name, executes the block of code within its context, and then resets the label to its original state, ensuring that all requests made within the context block are grouped under the specified name in performance statistics.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: locust/user/task.py\nfrom __future__ import annotations\n\nfrom locust.exception import InterruptTaskSet, MissingWaitTimeError, RescheduleTask, RescheduleTaskImmediately, StopUser\n\nimport logging\nimport random\nimport traceback\nfrom time import time\nfrom typing import (\n TYPE_CHECKING,\n Callable,\n Protocol,\n Type,\n TypeVar,\n final,\n overload,\n runtime_checkable,\n)\n\nimport gevent\nfrom gevent import GreenletExit\n\nif TYPE_CHECKING:\n from locust import User\n\n\nlogger = logging.getLogger(__name__)\nTaskT = TypeVar(\"TaskT\", Callable[..., None], Type[\"TaskSet\"])\n\nLOCUST_STATE_RUNNING, LOCUST_STATE_WAITING, LOCUST_STATE_STOPPING = [\"running\", \"waiting\", \"stopping\"]\n\n\n@runtime_checkable\nclass TaskHolder(Protocol[TaskT]):\n tasks: list[TaskT]\n\n\n@overload\ndef task(weight: TaskT) -> TaskT: ...\n\n\n@overload\ndef task(weight: int) -> Callable[[TaskT], TaskT]: ...\n\n\ndef task(weight: TaskT | int = 1) -> TaskT | Callable[[TaskT], TaskT]:\n \"\"\"\n Used as a convenience decorator to be able to declare tasks for a User or a TaskSet\n inline in the class. 
Example::\n\n class ForumPage(TaskSet):\n @task(100)\n def read_thread(self):\n pass\n\n @task(7)\n def create_thread(self):\n pass\n\n @task(25)\n class ForumThread(TaskSet):\n @task\n def get_author(self):\n pass\n\n @task\n def get_created(self):\n pass\n \"\"\"\n\n def decorator_func(func):\n if func.__name__ in [\"on_stop\", \"on_start\"]:\n logging.warning(\n \"You have tagged your on_stop/start function with @task. This will make the method get called both as a task AND on stop/start.\"\n ) # this is usually not what the user intended\n if func.__name__ == \"run\":\n raise Exception(\n \"User.run() is a method used internally by Locust, and you must not override it or register it as a task\"\n )\n func.locust_task_weight = weight\n return func\n\n \"\"\"\n Check if task was used without parentheses (not called), like this::\n\n @task\n def my_task(self)\n pass\n \"\"\"\n if callable(weight):\n func = weight\n weight = 1\n return decorator_func(func)\n else:\n return decorator_func\n\n\ndef tag(*tags: str) -> Callable[[TaskT], TaskT]:\n \"\"\"\n Decorator for tagging tasks and TaskSets with the given tag name. You can\n then limit the test to only execute tasks that are tagged with any of the\n tags provided by the :code:`--tags` command-line argument. Example::\n\n class ForumPage(TaskSet):\n @tag('thread')\n @task(100)\n def read_thread(self):\n pass\n\n @tag('thread')\n @tag('post')\n @task(7)\n def create_thread(self):\n pass\n\n @tag('post')\n @task(11)\n def comment(self):\n pass\n \"\"\"\n\n def decorator_func(decorated):\n if hasattr(decorated, \"tasks\"):\n decorated.tasks = list(map(tag(*tags), decorated.tasks))\n else:\n if \"locust_tag_set\" not in decorated.__dict__:\n decorated.locust_tag_set = set()\n decorated.locust_tag_set |= set(tags)\n return decorated\n\n if len(tags) == 0 or callable(tags[0]):\n raise ValueError(\"No tag name was supplied\")\n\n return decorator_func\n\n\ndef get_tasks_from_base_classes(bases, class_dict):\n \"\"\"\n Function used by both TaskSetMeta and UserMeta for collecting all declared tasks\n on the TaskSet/User class and all its base classes\n \"\"\"\n new_tasks = []\n for base in bases:\n if hasattr(base, \"tasks\") and base.tasks:\n new_tasks += base.tasks\n\n if \"tasks\" in class_dict and class_dict[\"tasks\"] is not None:\n tasks = class_dict[\"tasks\"]\n if isinstance(tasks, dict):\n tasks = tasks.items()\n\n for task in tasks:\n if isinstance(task, tuple):\n task, count = task\n for _ in range(count):\n new_tasks.append(task)\n else:\n new_tasks.append(task)\n\n for item in class_dict.values():\n if \"locust_task_weight\" in dir(item):\n for i in range(item.locust_task_weight):\n new_tasks.append(item)\n\n return new_tasks\n\n\ndef filter_tasks_by_tags(\n task_holder: type[TaskHolder],\n...\n# Path: locust/clients.py\nfrom __future__ import annotations\n\nimport re\nimport time\nfrom contextlib import contextmanager\nfrom typing import Generator\nfrom urllib.parse import urlparse, urlunparse\n\nimport requests\nfrom requests import Request, Response\nfrom requests.adapters import HTTPAdapter\nfrom requests.auth import HTTPBasicAuth\nfrom requests.exceptions import InvalidSchema, InvalidURL, MissingSchema, RequestException\nfrom urllib3 import PoolManager\n\nfrom .exception import CatchResponseError, LocustError, ResponseError\n\nabsolute_http_url_regexp = re.compile(r\"^https?://\", re.I)\n\n\nclass LocustResponse(Response):\n def raise_for_status(self):\n if hasattr(self, \"error\") and self.error:\n raise self.error\n 
Response.raise_for_status(self)\n\n\nclass HttpSession(requests.Session):\n \"\"\"\n Class for performing web requests and holding (session-) cookies between requests (in order\n to be able to log in and out of websites). Each request is logged so that locust can display\n statistics.\n\n This is a slightly extended version of `python-request `_'s\n :py:class:`requests.Session` class and mostly this class works exactly the same. However\n the methods for making requests (get, post, delete, put, head, options, patch, request)\n can now take a *url* argument that's only the path part of the URL, in which case the host\n part of the URL will be prepended with the HttpSession.base_url which is normally inherited\n from a User class' host attribute.\n\n Each of the methods for making requests also takes two additional optional arguments which\n are Locust specific and doesn't exist in python-requests. These are:\n\n :param name: (optional) An argument that can be specified to use as label in Locust's statistics instead of the URL path.\n This can be used to group different URL's that are requested into a single entry in Locust's statistics.\n :param catch_response: (optional) Boolean argument that, if set, can be used to make a request return a context manager\n to work as argument to a with statement. This will allow the request to be marked as a fail based on the content of the\n response, even if the response code is ok (2xx). The opposite also works, one can use catch_response to catch a request\n and then mark it as successful even if the response code was not (i.e 500 or 404).\n \"\"\"\n\n def __init__(self, base_url, request_event, user, *args, pool_manager: PoolManager | None = None, **kwargs):\n super().__init__(*args, **kwargs)\n\n self.base_url = base_url\n self.request_event = request_event\n self.user = user\n\n # User can group name, or use the group context manager to gather performance statistics under a specific name\n # This is an alternative to passing in the \"name\" parameter to the requests function\n self.request_name: str | None = None\n\n # Check for basic authentication\n parsed_url = urlparse(self.base_url)\n if parsed_url.username and parsed_url.password:\n netloc = parsed_url.hostname\n if parsed_url.port:\n netloc += \":%d\" % parsed_url.port\n\n # remove username and password from the base_url\n self.base_url = urlunparse(\n (parsed_url.scheme, netloc, parsed_url.path, parsed_url.params, parsed_url.query, parsed_url.fragment)\n )\n # configure requests to use basic auth\n self.auth = HTTPBasicAuth(parsed_url.username, parsed_url.password)\n\n self.mount(\"https://\", LocustHttpAdapter(pool_manager=pool_manager))\n self.mount(\"http://\", LocustHttpAdapter(pool_manager=pool_manager))\n\n def _build_url(self, path):\n \"\"\"prepend url with hostname unless it's already an absolute URL\"\"\"\n if absolute_http_url_regexp.match(path):\n return path\n else:\n return f\"{self.base_url}{path}\"\n\n @contextmanager\n \ndef rename_request(self, name: str) -> Generator[None, None, None]:\n \"\"\"Group requests using the \"with\" keyword\"\"\"\n\n self.request_name = name\n try:\n yield\n finally:\n self.request_name = None\n\n def request(self, method, url, name=None, catch_response=False, context={}, **kwargs):\n \"\"\"\n Constructs and sends a :py:class:`requests.Request`.\n Returns :py:class:`requests.Response` object.\n\n :param method: method for the new :class:`Request` object.\n :param url: URL for the new :class:`Request` object.\n :param name: (optional) An 
argument that can be specified to use as label in Locust's statistics instead of the URL path.\n This can be used to group different URL's that are requested into a single entry in Locust's statistics.\n :param catch_response: (optional) Boolean argument that, if set, can be used to make a request return a context manager\n to work as argument to a with statement. This will allow the request to be marked as a fail based on the content of the\n response, even if the response code is ok (2xx). The opposite also works, one can use catch_response to catch a request\n and then mark it as successful even if the response code was not (i.e 500 or 404).\n :param params: (optional) Dictionary or bytes to be sent in the query string for the :class:`Request`.\n :param data: (optional) Dictionary or bytes to send in the body of the :class:`Request`.\n :param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.\n :param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.\n :param files: (optional) Dictionary of ``'filename': file-like-objects`` for multipart encoding upload.\n :param auth: (optional) Auth tuple or callable to enable Basic/Digest/Custom HTTP Auth.\n :param timeout: (optional) How long in seconds to wait for the server to send data before giving up, as a float,\n or a (`connect timeout, read timeout `_) tuple.\n :type timeout: float or tuple\n :param allow_redirects: (optional) Set to True by default.\n :type allow_redirects: bool\n :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.\n :param stream: (optional) whether to immediately download the response content. Defaults to ``False``.\n :param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided.\n :param cert: (optional) if String, path to ssl client cert file (.pem). 
If Tuple, ('cert', 'key') pair.\n \"\"\"\n\n # if group name has been set and no name parameter has been passed in; set the name parameter to group_name\n if self.request_name and not name:\n name = self.request_name\n\n # prepend url with hostname unless it's already an absolute URL\n url = self._build_url(url)\n\n start_time = time.time()\n start_perf_counter = time.perf_counter()\n response = self._send_request_safe_mode(method, url, **kwargs)\n response_time = (time.perf_counter() - start_perf_counter) * 1000\n\n request_before_redirect = (response.history and response.history[0] or response).request\n url = request_before_redirect.url\n\n if not name:\n name = request_before_redirect.path_url\n\n if self.user:\n context = {**self.user.context(), **context}\n\n # store meta data that is used when reporting the request to locust's statistics\n request_meta = {\n \"request_type\": method,\n \"response_time\": response_time,\n \"name\": name,\n \"context\": context,\n \"response\": response,\n \"exception\": None,\n \"start_time\": start_time,\n \"url\": url,\n }\n\n # get the length of the content, but if the argument stream is set to True, we take\n # the size from the content-length header, in order to not trigger fetching of the body\n if kwargs.get(\"stream\", False):\n request_meta[\"response_length\"] = int(response.headers.get(\"content-length\") or 0)\n else:\n request_meta[\"response_length\"] = len(response.content or b\"\")\n\n if catch_response:\n return ResponseContextManager(response, request_event=self.request_event, request_meta=request_meta)\n else:\n with ResponseContextManager(response, request_event=self.request_event, request_meta=request_meta):\n pass\n return response\n\n def _send_request_safe_mode(self, method, url, **kwargs):\n \"\"\"\n Send an HTTP request, and catch any exception that might occur due to connection problems.\n\n Safe mode has been removed from requests 1.x.\n \"\"\"\n try:\n return super().request(method, url, **kwargs)\n except (MissingSchema, InvalidSchema, InvalidURL):\n raise\n except RequestException as e:\n r = LocustResponse()\n r.error = e\n r.status_code = 0 # with this status_code, content returns None\n r.request = Request(method, url).prepare()\n return r\n\n\nclass ResponseContextManager(LocustResponse):\n \"\"\"\n A Response class that also acts as a context manager that provides the ability to manually\n control if an HTTP request should be marked as successful or a failure in Locust's statistics\n\n This class is a subclass of :py:class:`Response ` with two additional\n methods: :py:meth:`success ` and\n :py:meth:`failure `.\n \"\"\"\n\n _manual_result: bool | Exception | None = None\n _entered = False\n\n def __init__(self, response, request_event, request_meta):\n # copy data from response to this object\n self.__dict__ = response.__dict__\n self._request_event = request_event\n self.request_meta = request_meta\n\n def __enter__(self):\n self._entered = True\n return self\n\n def __exit__(self, exc, value, traceback):\n # if the user has already manually marked this response as failure or success\n # we can ignore the default behaviour of letting the response code determine the outcome\n if self._manual_result is not None:\n if self._manual_result is True:\n self.request_meta[\"exception\"] = None\n elif isinstance(self._manual_result, Exception):\n self.request_meta[\"exception\"] = self._manual_result\n self._report_request()\n return exc is None\n\n if exc:\n if isinstance(value, ResponseError):\n 
self.request_meta[\"exception\"] = value\n self._report_request()\n else:\n # we want other unknown exceptions to be raised\n return False\n else:\n # Since we use the Exception message when grouping failures, in order to not get\n # multiple failure entries for different URLs for the same name argument, we need\n # to temporarily override the response.url attribute\n orig_url = self.url\n self.url = self.request_meta[\"name\"]\n\n try:\n self.raise_for_status()\n except requests.exceptions.RequestException as e:\n while (\n isinstance(\n e,\n (\n requests.exceptions.ConnectionError,\n requests.packages.urllib3.exceptions.ProtocolError,\n requests.packages.urllib3.exceptions.MaxRetryError,\n requests.packages.urllib3.exceptions.NewConnectionError,\n ),\n )\n and e.__context__ # Not sure if the above exceptions can ever be the lowest level, but it is good to be sure\n ):\n e = e.__context__\n self.request_meta[\"exception\"] = e\n\n self._report_request()\n self.url = orig_url\n\n return True\n\n def _report_request(self, exc=None):\n self._request_event.fire(**self.request_meta)\n\n def success(self):\n \"\"\"\n Report the response as successful\n\n Example::\n\n with self.client.get(\"/does/not/exist\", catch_response=True) as response:\n if response.status_code == 404:\n response.success()\n \"\"\"\n if not self._entered:\n raise LocustError(\n \"Tried to set status on a request that has not yet been made. Make sure you use a with-block, like this:\\n\\nwith self.client.request(..., catch_response=True) as response:\\n response.success()\"\n )\n self._manual_result = True\n\n def failure(self, exc):\n \"\"\"\n Report the response as a failure.\n\n if exc is anything other than a python exception (like a string) it will\n be wrapped inside a CatchResponseError.\n\n Example::\n\n with self.client.get(\"/\", catch_response=True) as response:\n if response.content == b\"\":\n response.failure(\"No data\")\n \"\"\"\n if not self._entered:\n raise LocustError(\n \"Tried to set status on a request that has not yet been made. Make sure you use a with-block, like this:\\n\\nwith self.client.request(..., catch_response=True) as response:\\n response.failure(...)\"\n )\n if not isinstance(exc, Exception):\n exc = CatchResponseError(exc)\n self._manual_result = exc\n\n\nclass LocustHttpAdapter(HTTPAdapter):\n def __init__(self, pool_manager: PoolManager | None, *args, **kwargs):\n self.poolmanager = pool_manager\n super().__init__(*args, **kwargs)\n\n def init_poolmanager(self, *args, **kwargs):\n if self.poolmanager is None:\n super().init_poolmanager(*args, **kwargs)\n\n\n# Monkey patch Response class to give some guidance\ndef _success(self):\n raise LocustError(\n \"If you want to change the state of the request, you must pass catch_response=True. See http://docs.locust.io/en/stable/writing-a-locustfile.html#validating-responses\"\n )\n\n\ndef _failure(self):\n raise LocustError(\n \"If you want to change the state of the request, you must pass catch_response=True. 
See http://docs.locust.io/en/stable/writing-a-locustfile.html#validating-responses\"\n )\n\n\nResponse.success = _success # type: ignore[attr-defined]\nResponse.failure = _failure # type: ignore[attr-defined]\n\n# Path: locust/user/wait_time.py\nimport random\nfrom time import time\n\n\ndef between(min_wait, max_wait):\n \"\"\"\n Returns a function that will return a random number between min_wait and max_wait.\n\n Example::\n\n class MyUser(User):\n # wait between 3.0 and 10.5 seconds after each task\n wait_time = between(3.0, 10.5)\n \"\"\"\n return lambda instance: min_wait + random.random() * (max_wait - min_wait)\n\n\ndef constant(wait_time):\n \"\"\"\n Returns a function that just returns the number specified by the wait_time argument\n\n Example::\n\n class MyUser(User):\n wait_time = constant(3)\n \"\"\"\n return lambda instance: wait_time\n\n\ndef constant_pacing(wait_time):\n \"\"\"\n Returns a function that will track the run time of the tasks, and for each time it's\n called it will return a wait time that will try to make the total time between task\n execution equal to the time specified by the wait_time argument.\n\n In the following example the task will always be executed once every 10 seconds, no matter\n the task execution time::\n\n class MyUser(User):\n wait_time = constant_pacing(10)\n @task\n def my_task(self):\n time.sleep(random.random())\n\n If a task execution exceeds the specified wait_time, the wait will be 0 before starting\n the next task.\n \"\"\"\n\n def wait_time_func(self):\n if not hasattr(self, \"_cp_last_wait_time\"):\n self._cp_last_wait_time = 0\n run_time = time() - self._cp_last_run - self._cp_last_wait_time\n self._cp_last_wait_time = max(0, wait_time - run_time)\n self._cp_last_run = time()\n return self._cp_last_wait_time\n\n return wait_time_func\n\n\ndef constant_throughput(task_runs_per_second):\n \"\"\"\n Returns a function that will track the run time of the tasks, and for each time it's\n called it will return a wait time that will try to make the number of task runs per second\n execution equal to the time specified by the task_runs_per_second argument.\n\n If you have multiple requests in a task your RPS will of course be higher than the\n specified throughput.\n\n This is the mathematical inverse of constant_pacing.\n\n In the following example the task will always be executed once every 10 seconds, no matter\n the task execution time::\n\n class MyUser(User):\n wait_time = constant_throughput(0.1)\n @task\n def my_task(self):\n time.sleep(random.random())\n\n If a task execution exceeds the specified wait_time, the wait will be 0 before starting\n the next task.\n \"\"\"\n return constant_pacing(1 / task_runs_per_second)\n\n# Path: locust/util/deprecation.py\nimport warnings\n\n# Show deprecation warnings\nwarnings.filterwarnings(\"always\", category=DeprecationWarning, module=\"locust\")\n\n\ndef check_for_deprecated_task_set_attribute(class_dict):\n from locust.user.task import TaskSet\n\n if \"task_set\" in class_dict:\n task_set = class_dict[\"task_set\"]\n if issubclass(task_set, TaskSet) and not hasattr(task_set, \"locust_task_weight\"):\n warnings.warn(\n \"Usage of User.task_set is deprecated since version 1.0. 
Set the tasks attribute instead \"\n \"(tasks = [%s])\" % task_set.__name__,\n DeprecationWarning,\n )\n\n\ndef deprecated_locust_meta_class(deprecation_message):\n class MetaClass(type):\n def __new__(mcs, classname, bases, class_dict):\n if classname in [\"DeprecatedLocustClass\", \"DeprecatedHttpLocustClass\", \"DeprecatedFastHttpLocustClass\"]:\n return super().__new__(mcs, classname, bases, class_dict)\n else:\n raise ImportError(deprecation_message)\n\n return MetaClass\n\n\n# PEP 484 specifies \"Generic metaclasses are not supported\", see https://github.com/python/mypy/issues/3602, ignore typing errors\nclass DeprecatedLocustClass(\n metaclass=deprecated_locust_meta_class( # type: ignore\n \"The Locust class has been renamed to User in version 1.0. \"\n \"For more info see: https://docs.locust.io/en/latest/changelog.html#changelog-1-0\"\n )\n):\n pass\n\n\nclass DeprecatedHttpLocustClass(\n metaclass=deprecated_locust_meta_class( # type: ignore\n \"The HttpLocust class has been renamed to HttpUser in version 1.0. \"\n \"For more info see: https://docs.locust.io/en/latest/changelog.html#changelog-1-0\"\n )\n):\n pass\n\n\nclass DeprecatedFastHttpLocustClass(\n metaclass=deprecated_locust_meta_class( # type: ignore\n \"The FastHttpLocust class has been renamed to FastHttpUser in version 1.0. \"\n \"For more info see: https://docs.locust.io/en/latest/changelog.html#changelog-1-0\"\n )\n):\n pass\n\n# Path: locust/user/users.py\nfrom __future__ import annotations\n\nimport logging\nimport time\nimport traceback\nfrom typing import Callable, final\n\nfrom gevent import GreenletExit, greenlet\nfrom gevent.pool import Group\nfrom urllib3 import PoolManager\n\nlogger = logging.getLogger(__name__)\nfrom locust.clients import HttpSession\nfrom locust.exception import LocustError, StopUser\nfrom locust.user.wait_time import constant\nfrom locust.util import deprecation\n\nfrom .task import (\n LOCUST_STATE_RUNNING,\n LOCUST_STATE_STOPPING,\n LOCUST_STATE_WAITING,\n DefaultTaskSet,\n TaskSet,\n get_tasks_from_base_classes,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass UserMeta(type):\n \"\"\"\n Meta class for the main User class. It's used to allow User classes to specify task execution\n ratio using an {task:int} dict, or a [(task0,int), ..., (taskN,int)] list.\n \"\"\"\n\n def __new__(mcs, classname, bases, class_dict):\n # gather any tasks that is declared on the class (or it's bases)\n tasks = get_tasks_from_base_classes(bases, class_dict)\n class_dict[\"tasks\"] = tasks\n\n if not class_dict.get(\"abstract\"):\n # Not a base class\n class_dict[\"abstract\"] = False\n\n deprecation.check_for_deprecated_task_set_attribute(class_dict)\n\n return type.__new__(mcs, classname, bases, class_dict)\n\n\nclass User(metaclass=UserMeta):\n \"\"\"\n Represents a \"user\" which is to be spawned and attack the system that is to be load tested.\n\n The behaviour of this user is defined by its tasks. Tasks can be declared either directly on the\n class by using the :py:func:`@task decorator ` on methods, or by setting\n the :py:attr:`tasks attribute `.\n\n This class should usually be subclassed by a class that defines some kind of client. For\n example when load testing an HTTP system, you probably want to use the\n :py:class:`HttpUser ` class.\n \"\"\"\n\n host: str | None = None\n \"\"\"Base hostname to swarm. i.e: http://127.0.0.1:1234\"\"\"\n\n min_wait = None\n \"\"\"Deprecated: Use wait_time instead. 
Minimum waiting time between the execution of locust tasks\"\"\"\n\n max_wait = None\n \"\"\"Deprecated: Use wait_time instead. Maximum waiting time between the execution of locust tasks\"\"\"\n\n wait_time = constant(0)\n \"\"\"\n Method that returns the time (in seconds) between the execution of locust tasks.\n Can be overridden for individual TaskSets.\n\n Example::\n\n from locust import User, between\n class MyUser(User):\n wait_time = between(3, 25)\n \"\"\"\n\n wait_function = None\n \"\"\"\n .. warning::\n\n DEPRECATED: Use wait_time instead. Note that the new wait_time method should return seconds and not milliseconds.\n\n Method that returns the time between the execution of locust tasks in milliseconds\n \"\"\"\n\n tasks: list[TaskSet | Callable] = []\n \"\"\"\n Collection of python callables and/or TaskSet classes that the Locust user(s) will run.\n\n If tasks is a list, the task to be performed will be picked randomly.\n\n If tasks is a *(callable,int)* list of two-tuples, or a {callable:int} dict,\n the task to be performed will be picked randomly, but each task will be weighted\n according to its corresponding int value. So in the following case, *ThreadPage* will\n be fifteen times more likely to be picked than *write_post*::\n\n class ForumPage(TaskSet):\n tasks = {ThreadPage:15, write_post:1}\n \"\"\"\n\n weight = 1\n \"\"\"Probability of user class being chosen. The higher the weight, the greater the chance of it being chosen.\"\"\"\n\n fixed_count = 0\n \"\"\"\n If the value > 0, the weight property will be ignored and the 'fixed_count'-instances will be spawned.\n These Users are spawned first. If the total target count (specified by the --users arg) is not enough\n to spawn all instances of each User class with the defined property, the final count of each User is undefined.\n \"\"\"\n\n abstract = True\n \"\"\"If abstract is True, the class is meant to be subclassed, and locust will not spawn users of this class during a test.\"\"\"\n\n def __init__(self, environment):\n super().__init__()\n self.environment = environment\n \"\"\"A reference to the :py:class:`Environment ` in which this user is running\"\"\"\n self._state = None\n self._greenlet: greenlet.Greenlet = None\n self._group: Group\n self._taskset_instance: TaskSet = None\n self._cp_last_run = time.time() # used by constant_pacing wait_time\n\n def on_start(self) -> None:\n \"\"\"\n Called when a User starts running.\n \"\"\"\n pass\n\n def on_stop(self):\n \"\"\"\n Called when a User stops running (is killed)\n \"\"\"\n pass\n\n @final\n def run(self):\n self._state = LOCUST_STATE_RUNNING\n self._taskset_instance = DefaultTaskSet(self)\n try:\n # run the TaskSet on_start method, if it has one\n try:\n self.on_start()\n except Exception as e:\n # unhandled exceptions inside tasks are logged in TaskSet.run, but since we're not yet there...\n logger.error(\"%s\\n%s\", e, traceback.format_exc())\n raise\n\n self._taskset_instance.run()\n except (GreenletExit, StopUser):\n # run the on_stop method, if it has one\n self.on_stop()\n\n def wait(self):\n \"\"\"\n Make the running user sleep for a duration defined by the User.wait_time\n function.\n\n The user can also be killed gracefully while it's sleeping, so calling this\n method within a task makes it possible for a user to be killed mid-task even if you've\n set a stop_timeout. 
If this behaviour is not desired, you should make the user wait using\n gevent.sleep() instead.\n \"\"\"\n self._taskset_instance.wait()\n\n def start(self, group: Group):\n \"\"\"\n Start a greenlet that runs this User instance.\n\n :param group: Group instance where the greenlet will be spawned.\n :type group: gevent.pool.Group\n :returns: The spawned greenlet.\n \"\"\"\n\n def run_user(user):\n \"\"\"\n Main function for User greenlet. It's important that this function takes the user\n instance as an argument, since we use greenlet_instance.args[0] to retrieve a reference to the\n User instance.\n \"\"\"\n user.run()\n\n self._greenlet = group.spawn(run_user, self)\n self._group = group\n return self._greenlet\n\n def stop(self, force=False):\n \"\"\"\n Stop the user greenlet.\n\n :param force: If False (the default) the stopping is done gracefully by setting the state to LOCUST_STATE_STOPPING\n which will make the User instance stop once any currently running task is complete and on_stop\n methods are called. If force is True the greenlet will be killed immediately.\n :returns: True if the greenlet was killed immediately, otherwise False\n \"\"\"\n if force or self._state == LOCUST_STATE_WAITING:\n self._group.killone(self._greenlet)\n return True\n elif self._state == LOCUST_STATE_RUNNING:\n self._state = LOCUST_STATE_STOPPING\n return False\n else:\n raise Exception(f\"Tried to stop User in an unexpected state: {self._state}. This should never happen.\")\n\n @property\n def group(self):\n return self._group\n\n @property\n def greenlet(self):\n return self._greenlet\n\n def context(self) -> dict:\n \"\"\"\n Adds the returned value (a dict) to the context for :ref:`request event `.\n Override this in your User class to customize the context.\n \"\"\"\n return {}\n\n @classmethod\n def json(cls):\n return {\n \"host\": cls.host,\n \"weight\": cls.weight,\n \"fixed_count\": cls.fixed_count,\n \"tasks\": [task.__name__ for task in cls.tasks],\n }\n\n @classmethod\n def fullname(cls) -> str:\n \"\"\"Fully qualified name of the user class, e.g. my_package.my_module.MyUserClass\"\"\"\n return \".\".join(filter(lambda x: x != \"\", (cls.__module__ + \".\" + cls.__qualname__).split(\".\")))\n\n\nclass HttpUser(User):\n \"\"\"\n Represents an HTTP \"user\" which is to be spawned and attack the system that is to be load tested.\n\n The behaviour of this user is defined by its tasks. Tasks can be declared either directly on the\n class by using the :py:func:`@task decorator ` on methods, or by setting\n the :py:attr:`tasks attribute `.\n\n This class creates a *client* attribute on instantiation which is an HTTP client with support\n for keeping a user session between requests.\n \"\"\"\n\n abstract = True\n \"\"\"If abstract is True, the class is meant to be subclassed, and users will not choose this locust during a test\"\"\"\n\n pool_manager: PoolManager | None = None\n \"\"\"Connection pool manager to use. If not given, a new manager is created per single user.\"\"\"\n\n def __init__(self, *args, **kwargs):\n super().__init__(*args, **kwargs)\n if self.host is None:\n raise LocustError(\n \"You must specify the base host. 
Either in the host attribute in the User class, or on the command line using the --host option.\"\n )\n\n self.client = HttpSession(\n base_url=self.host,\n request_event=self.environment.events.request,\n user=self,\n pool_manager=self.pool_manager,\n )\n \"\"\"\n Instance of HttpSession that is created upon instantiation of Locust.\n The client supports cookies, and therefore keeps the session between HTTP requests.\n \"\"\"\n self.client.trust_env = False\n\n# Path: locust/user/__init__.py\nfrom .task import TaskSet, tag, task\nfrom .users import HttpUser, User\n\n# Path: locust/user/inspectuser.py\nfrom __future__ import annotations\n\nimport inspect\nfrom collections import defaultdict\nfrom json import dumps\n\nfrom .task import TaskSet\nfrom .users import User\n\n\ndef print_task_ratio(user_classes, num_users, total):\n \"\"\"\n This function calculates the task ratio of users based on the user total count.\n \"\"\"\n d = get_ratio(user_classes, _calc_distribution(user_classes, num_users), total)\n _print_task_ratio(d)\n\n\ndef print_task_ratio_json(user_classes, num_users):\n d = _calc_distribution(user_classes, num_users)\n task_data = {\n \"per_class\": get_ratio(user_classes, d, False),\n \"total\": get_ratio(user_classes, d, True),\n }\n\n print(dumps(task_data, indent=4))\n\n\ndef _calc_distribution(user_classes, num_users):\n fixed_count = sum(u.fixed_count for u in user_classes if u.fixed_count)\n total_weight = sum(u.weight for u in user_classes if not u.fixed_count)\n num_users = num_users or (total_weight if not fixed_count else 1)\n weighted_count = num_users - fixed_count\n weighted_count = weighted_count if weighted_count > 0 else 0\n user_classes_count = {}\n\n for u in user_classes:\n count = u.fixed_count if u.fixed_count else (u.weight / total_weight) * weighted_count\n user_classes_count[u.__name__] = round(count)\n\n return user_classes_count\n\n\ndef _print_task_ratio(x, level=0):\n padding = 2 * \" \" * level\n for k, v in x.items():\n ratio = v.get(\"ratio\", 1)\n print(\" %-10s %-50s\" % (padding + \"%-6.1f\" % (ratio * 100), padding + k))\n if \"tasks\" in v:\n _print_task_ratio(v[\"tasks\"], level + 1)\n\n\ndef get_ratio(user_classes: list[type[User]], user_spawned: dict[str, int], total: bool) -> dict[str, dict[str, float]]:\n user_count = sum(user_spawned.values()) or 1\n ratio_percent: dict[type[User], float] = {u: user_spawned.get(u.__name__, 0) / user_count for u in user_classes}\n\n task_dict: dict[str, dict[str, float]] = {}\n for u, r in ratio_percent.items():\n d = {\"ratio\": r}\n d[\"tasks\"] = _get_task_ratio(u.tasks, total, r)\n task_dict[u.__name__] = d\n\n return task_dict\n\n\ndef _get_task_ratio(tasks, total, parent_ratio):\n parent_ratio = parent_ratio if total else 1.0\n ratio = defaultdict(int)\n for task in tasks:\n ratio[task] += 1\n\n ratio_percent = {t: r * parent_ratio / len(tasks) for t, r in ratio.items()}\n\n task_dict = {}\n for t, r in ratio_percent.items():\n d = {\"ratio\": r}\n if inspect.isclass(t) and issubclass(t, TaskSet):\n d[\"tasks\"] = _get_task_ratio(t.tasks, total, r)\n task_dict[t.__name__] = d\n\n return task_dict\n\n# Path: locust/html.py\nimport datetime\nimport glob\nimport os\nimport pathlib\nfrom html import escape\nfrom itertools import chain\nfrom json import dumps\n\nfrom jinja2 import Environment, FileSystemLoader\n\nfrom . 
import stats as stats_module\nfrom .runners import STATE_STOPPED, STATE_STOPPING, MasterRunner\nfrom .stats import sort_stats\nfrom .user.inspectuser import get_ratio\n\nPERCENTILES_FOR_HTML_REPORT = [0.50, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0]\n\n\ndef render_template(file, template_path, **kwargs):\n env = Environment(loader=FileSystemLoader(template_path), extensions=[\"jinja2.ext.do\"])\n template = env.get_template(file)\n return template.render(**kwargs)\n\n\ndef get_html_report(\n environment,\n show_download_link=True,\n use_modern_ui=False,\n theme=\"\",\n):\n root_path = os.path.dirname(os.path.abspath(__file__))\n if use_modern_ui:\n static_path = os.path.join(root_path, \"webui\", \"dist\", \"assets\")\n template_path = os.path.join(root_path, \"webui\", \"dist\")\n else:\n static_path = os.path.join(root_path, \"static\")\n template_path = os.path.join(root_path, \"templates\")\n\n stats = environment.runner.stats\n\n start_ts = stats.start_time\n start_time = datetime.datetime.utcfromtimestamp(start_ts).strftime(\"%Y-%m-%d %H:%M:%S\")\n\n end_ts = stats.last_request_timestamp\n if end_ts:\n end_time = datetime.datetime.utcfromtimestamp(end_ts).strftime(\"%Y-%m-%d %H:%M:%S\")\n else:\n end_time = start_time\n\n host = None\n if environment.host:\n host = environment.host\n elif environment.runner.user_classes:\n all_hosts = {l.host for l in environment.runner.user_classes}\n if len(all_hosts) == 1:\n host = list(all_hosts)[0]\n\n requests_statistics = list(chain(sort_stats(stats.entries), [stats.total]))\n failures_statistics = sort_stats(stats.errors)\n exceptions_statistics = [\n {**exc, \"nodes\": \", \".join(exc[\"nodes\"])} for exc in environment.runner.exceptions.values()\n ]\n\n history = stats.history\n\n static_js = []\n if use_modern_ui:\n js_files = [os.path.basename(filepath) for filepath in glob.glob(os.path.join(static_path, \"*.js\"))]\n else:\n js_files = [\"jquery-1.11.3.min.js\", \"echarts.common.min.js\", \"vintage.js\", \"chart.js\", \"tasks.js\"]\n\n for js_file in js_files:\n path = os.path.join(static_path, js_file)\n static_js.append(\"// \" + js_file + \"\\n\")\n with open(path, encoding=\"utf8\") as f:\n static_js.append(f.read())\n static_js.extend([\"\", \"\"])\n\n if not use_modern_ui:\n static_css = []\n css_files = [\"tables.css\"]\n for css_file in css_files:\n path = os.path.join(static_path, \"css\", css_file)\n static_css.append(\"/* \" + css_file + \" */\")\n with open(path, encoding=\"utf8\") as f:\n static_css.append(f.read())\n static_css.extend([\"\", \"\"])\n\n is_distributed = isinstance(environment.runner, MasterRunner)\n user_spawned = (\n environment.runner.reported_user_classes_count if is_distributed else environment.runner.user_classes_count\n )\n\n if environment.runner.state in [STATE_STOPPED, STATE_STOPPING]:\n user_spawned = environment.runner.final_user_classes_count\n\n task_data = {\n \"per_class\": get_ratio(environment.user_classes, user_spawned, False),\n \"total\": get_ratio(environment.user_classes, user_spawned, True),\n }\n\n if use_modern_ui:\n res = render_template(\n \"report.html\",\n template_path,\n template_args={\n \"is_report\": True,\n \"requests_statistics\": [stat.to_dict(escape_string_values=True) for stat in requests_statistics],\n \"failures_statistics\": [stat.to_dict() for stat in failures_statistics],\n \"exceptions_statistics\": [stat for stat in exceptions_statistics],\n \"response_time_statistics\": [\n {\n \"name\": escape(stat.name),\n \"method\": escape(stat.method or \"\"),\n **{\n 
str(percentile): stat.get_response_time_percentile(percentile)\n for percentile in PERCENTILES_FOR_HTML_REPORT\n },\n }\n for stat in requests_statistics\n ],\n \"start_time\": start_time,\n \"end_time\": end_time,\n \"host\": escape(str(host)),\n \"history\": history,\n \"show_download_link\": show_download_link,\n \"locustfile\": escape(str(environment.locustfile)),\n \"tasks\": task_data,\n \"percentiles_to_chart\": stats_module.MODERN_UI_PERCENTILES_TO_CHART,\n },\n theme=theme,\n static_js=\"\\n\".join(static_js),\n )\n else:\n res = render_template(\n \"report.html\",\n template_path,\n int=int,\n round=round,\n escape=escape,\n str=str,\n requests_statistics=requests_statistics,\n failures_statistics=failures_statistics,\n exceptions_statistics=exceptions_statistics,\n start_time=start_time,\n end_time=end_time,\n host=host,\n history=history,\n static_js=\"\\n\".join(static_js),\n static_css=\"\\n\".join(static_css),\n show_download_link=show_download_link,\n locustfile=environment.locustfile,\n tasks=dumps(task_data),\n percentile1=stats_module.PERCENTILES_TO_CHART[0],\n percentile2=stats_module.PERCENTILES_TO_CHART[1],\n )\n\n return res\n\n# Path: locust/util/cache.py\nimport functools\nfrom time import time\n\n\ndef memoize(timeout, dynamic_timeout=False):\n \"\"\"\n Memoization decorator with support for timeout.\n\n If dynamic_timeout is set, the cache timeout is doubled if the cached function\n takes longer time to run than the timeout time\n \"\"\"\n cache = {\"timeout\": timeout}\n\n def decorator(func):\n @functools.wraps(func)\n def wrapper(*args, **kwargs):\n start = time()\n if (not \"time\" in cache) or (start - cache[\"time\"] > cache[\"timeout\"]):\n # cache miss\n cache[\"result\"] = func(*args, **kwargs)\n cache[\"time\"] = time()\n if dynamic_timeout and cache[\"time\"] - start > cache[\"timeout\"]:\n cache[\"timeout\"] *= 2\n return cache[\"result\"]\n\n def clear_cache():\n if \"time\" in cache:\n del cache[\"time\"]\n if \"result\" in cache:\n del cache[\"result\"]\n\n wrapper.clear_cache = clear_cache\n return wrapper\n\n return decorator\n\n# Path: locust/util/timespan.py\nimport re\nfrom datetime import timedelta\n\n\ndef parse_timespan(time_str):\n \"\"\"\n Parse a string representing a time span and return the number of seconds.\n Valid formats are: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\n \"\"\"\n if not time_str:\n raise ValueError(\"Invalid time span format\")\n\n if re.match(r\"^\\d+$\", time_str):\n # if an int is specified we assume they want seconds\n return int(time_str)\n\n timespan_regex = re.compile(r\"((?P\\d+?)h)?((?P\\d+?)m)?((?P\\d+?)s)?\")\n parts = timespan_regex.match(time_str)\n if not parts:\n raise ValueError(\"Invalid time span format. Valid formats: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n parts = parts.groupdict()\n time_params = {name: int(value) for name, value in parts.items() if value}\n if not time_params:\n raise ValueError(\"Invalid time span format. 
Valid formats: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n return int(timedelta(**time_params).total_seconds())\n\n# Path: locust/web.py\nfrom __future__ import annotations\n\nimport csv\nimport json\nimport logging\nimport os.path\nfrom functools import wraps\nfrom html import escape\nfrom io import StringIO\nfrom itertools import chain\nfrom json import dumps\nfrom time import time\nfrom typing import TYPE_CHECKING, Any\n\nimport gevent\nfrom flask import (\n Flask,\n Response,\n jsonify,\n make_response,\n redirect,\n render_template,\n request,\n send_file,\n send_from_directory,\n url_for,\n)\nfrom flask_cors import CORS\nfrom flask_login import LoginManager, login_required\nfrom gevent import pywsgi\n\nfrom . import __version__ as version\nfrom . import argument_parser\nfrom . import stats as stats_module\nfrom .html import get_html_report\nfrom .log import greenlet_exception_logger\nfrom .runners import STATE_MISSING, STATE_RUNNING, MasterRunner\nfrom .stats import StatsCSV, StatsCSVFileWriter, StatsErrorDict, sort_stats\nfrom .user.inspectuser import get_ratio\nfrom .util.cache import memoize\nfrom .util.timespan import parse_timespan\n\nif TYPE_CHECKING:\n from .env import Environment\n\n\nlogger = logging.getLogger(__name__)\ngreenlet_exception_handler = greenlet_exception_logger(logger)\n\nDEFAULT_CACHE_TIME = 2.0\n\n\nclass WebUI:\n \"\"\"\n Sets up and runs a Flask web app that can start and stop load tests using the\n :attr:`environment.runner ` as well as show the load test statistics\n in :attr:`environment.stats `\n \"\"\"\n\n app: Flask | None = None\n \"\"\"\n Reference to the :class:`flask.Flask` app. Can be used to add additional web routes and customize\n the Flask app in other various ways. Example::\n\n from flask import request\n\n @web_ui.app.route(\"/my_custom_route\")\n def my_custom_route():\n return \"your IP is: %s\" % request.remote_addr\n \"\"\"\n\n greenlet: gevent.Greenlet | None = None\n \"\"\"\n Greenlet of the running web server\n \"\"\"\n\n server: pywsgi.WSGIServer | None = None\n \"\"\"Reference to the :class:`pyqsgi.WSGIServer` instance\"\"\"\n\n template_args: dict[str, Any]\n \"\"\"Arguments used to render index.html for the web UI. Must be used with custom templates\n extending index.html.\"\"\"\n\n auth_args: dict[str, Any]\n \"\"\"Arguments used to render auth.html for the web UI auth page. Must be used when configuring auth\"\"\"\n\n def __init__(\n self,\n environment: Environment,\n host: str,\n port: int,\n web_login: bool = False,\n tls_cert: str | None = None,\n tls_key: str | None = None,\n stats_csv_writer: StatsCSV | None = None,\n delayed_start=False,\n userclass_picker_is_active=False,\n modern_ui=False,\n ):\n \"\"\"\n Create WebUI instance and start running the web server in a separate greenlet (self.greenlet)\n\n Arguments:\n environment: Reference to the current Locust Environment\n host: Host/interface that the web server should accept connections to\n port: Port that the web server should listen to\n web_login: Enables a login page for the modern UI\n tls_cert: A path to a TLS certificate\n tls_key: A path to a TLS private key\n delayed_start: Whether or not to delay starting web UI until `start()` is called. 
Delaying web UI start\n allows for adding Flask routes or Blueprints before accepting requests, avoiding errors.\n \"\"\"\n environment.web_ui = self\n self.stats_csv_writer = stats_csv_writer or StatsCSV(environment, stats_module.PERCENTILES_TO_REPORT)\n self.environment = environment\n self.host = host\n self.port = port\n self.tls_cert = tls_cert\n self.tls_key = tls_key\n self.userclass_picker_is_active = userclass_picker_is_active\n self.modern_ui = modern_ui\n self.web_login = web_login\n app = Flask(__name__)\n CORS(app)\n self.app = app\n app.jinja_env.add_extension(\"jinja2.ext.do\")\n app.debug = True\n root_path = os.path.dirname(os.path.abspath(__file__))\n app.root_path = root_path\n self.webui_build_path = os.path.join(root_path, \"webui\", \"dist\")\n self.greenlet: gevent.Greenlet | None = None\n self._swarm_greenlet: gevent.Greenlet | None = None\n self.template_args = {}\n self.auth_args = {}\n\n if self.web_login:\n self.login_manager = LoginManager()\n self.login_manager.init_app(app)\n self.login_manager.login_view = \"login\"\n\n if environment.runner:\n self.update_template_args()\n if not delayed_start:\n self.start()\n\n @app.errorhandler(Exception)\n def handle_exception(error):\n error_message = str(error)\n logger.log(logging.CRITICAL, error_message)\n return make_response(error_message, 500)\n\n @app.route(\"/assets/\")\n def send_assets(path):\n return send_from_directory(os.path.join(self.webui_build_path, \"assets\"), path)\n\n @app.route(\"/\")\n @self.auth_required_if_enabled\n def index() -> str | Response:\n if not environment.runner:\n return make_response(\"Error: Locust Environment does not have any runner\", 500)\n self.update_template_args()\n\n if self.modern_ui:\n self.set_static_modern_ui()\n\n return render_template(\"index.html\", template_args=self.template_args)\n return render_template(\"index.html\", **self.template_args)\n\n @app.route(\"/swarm\", methods=[\"POST\"])\n @self.auth_required_if_enabled\n def swarm() -> Response:\n assert request.method == \"POST\"\n\n # Loading UserClasses & ShapeClasses if Locust is running with UserClass Picker\n if self.userclass_picker_is_active:\n if not self.environment.available_user_classes:\n err_msg = \"UserClass picker is active but there are no available UserClasses\"\n return jsonify({\"success\": False, \"message\": err_msg, \"host\": environment.host})\n\n # Getting Specified User Classes\n form_data_user_class_names = request.form.getlist(\"user_classes\")\n\n # Updating UserClasses\n if form_data_user_class_names:\n user_classes = {}\n for user_class_name, user_class_object in self.environment.available_user_classes.items():\n if user_class_name in form_data_user_class_names:\n user_classes[user_class_name] = user_class_object\n\n else:\n if self.environment.runner and self.environment.runner.state == STATE_RUNNING:\n # Test is already running\n # Using the user classes that have already been selected\n user_classes = {\n key: value\n for (key, value) in self.environment.available_user_classes.items()\n if value in self.environment.user_classes\n }\n else:\n # Starting test with no user class selection\n # Defaulting to using all available user classes\n user_classes = self.environment.available_user_classes\n\n self._update_user_classes(user_classes)\n\n # Updating ShapeClass if specified in WebUI Form\n form_data_shape_class_name = request.form.get(\"shape_class\", \"Default\")\n if form_data_shape_class_name == \"Default\":\n self._update_shape_class(None)\n else:\n 
self._update_shape_class(form_data_shape_class_name)\n\n parsed_options_dict = vars(environment.parsed_options) if environment.parsed_options else {}\n run_time = None\n for key, value in request.form.items():\n if key == \"user_count\": # if we just renamed this field to \"users\" we wouldn't need this\n user_count = int(value)\n elif key == \"spawn_rate\":\n spawn_rate = float(value)\n elif key == \"host\":\n # Replace < > to guard against XSS\n environment.host = str(request.form[\"host\"]).replace(\"<\", \"\").replace(\">\", \"\")\n elif key == \"user_classes\":\n # Set environment.parsed_options.user_classes to the selected user_classes\n parsed_options_dict[key] = request.form.getlist(\"user_classes\")\n elif key == \"run_time\":\n if not value:\n continue\n try:\n run_time = parse_timespan(value)\n except ValueError:\n err_msg = \"Valid run_time formats are : 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\"\n logger.error(err_msg)\n return jsonify({\"success\": False, \"message\": err_msg, \"host\": environment.host})\n elif key in parsed_options_dict:\n # update the value in environment.parsed_options, but dont change the type.\n parsed_options_value = parsed_options_dict[key]\n\n if isinstance(parsed_options_value, bool):\n parsed_options_dict[key] = value == \"true\"\n elif parsed_options_value is None:\n parsed_options_dict[key] = parsed_options_value\n else:\n parsed_options_dict[key] = type(parsed_options_dict[key])(value)\n\n if environment.shape_class and environment.runner is not None:\n environment.runner.start_shape()\n return jsonify(\n {\n \"success\": True,\n \"message\": f\"Swarming started using shape class '{type(environment.shape_class).__name__}'\",\n \"host\": environment.host,\n }\n )\n\n if self._swarm_greenlet is not None:\n self._swarm_greenlet.kill(block=True)\n self._swarm_greenlet = None\n\n if environment.runner is not None:\n self._swarm_greenlet = gevent.spawn(environment.runner.start, user_count, spawn_rate)\n self._swarm_greenlet.link_exception(greenlet_exception_handler)\n response_data = {\n \"success\": True,\n \"message\": \"Swarming started\",\n \"host\": environment.host,\n }\n if run_time:\n gevent.spawn_later(run_time, self._stop_runners).link_exception(greenlet_exception_handler)\n response_data[\"run_time\"] = run_time\n\n if self.userclass_picker_is_active:\n response_data[\"user_classes\"] = sorted(user_classes.keys())\n\n return jsonify(response_data)\n else:\n return jsonify({\"success\": False, \"message\": \"No runner\", \"host\": environment.host})\n\n @app.route(\"/stop\")\n @self.auth_required_if_enabled\n def stop() -> Response:\n if self._swarm_greenlet is not None:\n self._swarm_greenlet.kill(block=True)\n self._swarm_greenlet = None\n if environment.runner is not None:\n environment.runner.stop()\n return jsonify({\"success\": True, \"message\": \"Test stopped\"})\n\n @app.route(\"/stats/reset\")\n @self.auth_required_if_enabled\n def reset_stats() -> str:\n environment.events.reset_stats.fire()\n if environment.runner is not None:\n environment.runner.stats.reset_all()\n environment.runner.exceptions = {}\n return \"ok\"\n\n @app.route(\"/stats/report\")\n @self.auth_required_if_enabled\n def stats_report() -> Response:\n theme = request.args.get(\"theme\", \"\")\n res = get_html_report(\n self.environment,\n show_download_link=not request.args.get(\"download\"),\n use_modern_ui=self.modern_ui,\n theme=theme,\n )\n if request.args.get(\"download\"):\n res = app.make_response(res)\n res.headers[\"Content-Disposition\"] = 
f\"attachment;filename=report_{time()}.html\"\n return res\n\n...\n", "needle_token_start": 2358, "needle_token_end": 2422, "code_context_ntokens": 16064, "output": ["```python\ndef rename_request(self, name: str) -> Generator[None, None, None]:\n \"\"\"Group requests using the \"with\" keyword\"\"\"\n\n self.request_name = name\n try:\n yield\n finally:\n self.request_name = None\n```"]} +{"repo": "locustio/locust", "name": "get_poller", "language": "python", "path": "locust/input_events.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to select and return an appropriate keyboard input polling mechanism based on the operating system the script is running on. It chooses between two different classes that handle keyboard input for Unix-like systems and Windows systems.\n2. **Input**: There are no parameters required for this function.\n3. **Output**: The function returns an instance of a class that is capable of polling keyboard input. The specific class returned depends on the operating system.\n4. **Procedure**: The function checks the name of the operating system. If the operating system is Windows, it returns an instance of the class designed for handling keyboard input on Windows. If the operating system is not Windows, it assumes a Unix-like environment and returns an instance of the class designed for Unix-like systems.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: locust/contrib/fasthttp.py\nfrom __future__ import annotations\n\nfrom locust.env import Environment\nfrom locust.exception import CatchResponseError, LocustError, ResponseError\nfrom locust.user import User\nfrom locust.util.deprecation import DeprecatedFastHttpLocustClass as FastHttpLocust\n\nimport json\nimport json as unshadowed_json # some methods take a named parameter called json\nimport re\nimport socket\nimport time\nimport traceback\nfrom base64 import b64encode\nfrom contextlib import contextmanager\nfrom http.cookiejar import CookieJar\nfrom json.decoder import JSONDecodeError\nfrom ssl import SSLError\nfrom typing import Any, Callable, Generator, cast\nfrom urllib.parse import urlparse, urlunparse\n\nimport gevent\nfrom charset_normalizer import detect\nfrom gevent.timeout import Timeout\nfrom geventhttpclient._parser import HTTPParseError\nfrom geventhttpclient.client import HTTPClientPool\nfrom geventhttpclient.header import Headers\nfrom geventhttpclient.response import HTTPConnectionClosed, HTTPSocketPoolResponse\nfrom geventhttpclient.useragent import CompatRequest, CompatResponse, ConnectionError, UserAgent\n\n# borrow requests's content-type header parsing\nfrom requests.utils import get_encoding_from_headers\n\n# Monkey patch geventhttpclient.useragent.CompatRequest so that Cookiejar works with Python >= 3.3\n# More info: https://github.com/requests/requests/pull/871\nCompatRequest.unverifiable = False\n\n# Workaround for AttributeError: 'CompatRequest' object has no attribute 'type' in Cookiejar\n# https://github.com/locustio/locust/issues/1138\n# Might allow secure cookies over non-secure connections but that is a minor concern in a load testing tool\nCompatRequest.type = \"https\"\n\n# Regexp for checking if an absolute URL was specified\nabsolute_http_url_regexp = re.compile(r\"^https?://\", re.I)\n\n# List of 
exceptions that can be raised by geventhttpclient when sending an HTTP request,\n# and that should result in a Locust failure\nFAILURE_EXCEPTIONS = (\n ConnectionError,\n ConnectionRefusedError,\n ConnectionResetError,\n socket.error,\n SSLError,\n Timeout,\n HTTPConnectionClosed,\n)\n\n\n...\n# Path: locust/debug.py\nfrom __future__ import annotations\n\nimport locust\nimport locust.log\nfrom locust import argument_parser\nfrom locust.env import Environment\nfrom locust.exception import CatchResponseError, RescheduleTask\n\nimport inspect\nimport os\nfrom datetime import datetime, timezone\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n from locust import User\n\n\ndef _print_t(s):\n \"\"\"\n Print something with a tab instead of newline at the end\n \"\"\"\n print(str(s), end=\"\\t\")\n\n\nclass PrintListener:\n \"\"\"\n Print every response (useful when debugging a single locust)\n \"\"\"\n\n def __init__(\n self,\n env: Environment,\n include_length=False,\n include_time=False,\n include_context=False,\n include_payload=False,\n ):\n env.events.request.add_listener(self.on_request)\n\n self.include_length = \"length\\t\" if include_length else \"\"\n self.include_time = \"time \\t\" if include_time else \"\"\n self.include_context = \"context\\t\" if include_context else \"\"\n self.include_payload = \"payload\\t\" if include_payload else \"\"\n\n print(\n f\"\\n{self.include_time}type\\t{'name'.ljust(50)}\\tresp_ms\\t{self.include_length}exception\\t{self.include_context}\\t{self.include_payload}\"\n )\n\n def on_request(\n self,\n request_type,\n name,\n response_time,\n response_length,\n exception,\n context: dict,\n start_time=None,\n response=None,\n **_kwargs,\n ):\n if exception:\n if isinstance(exception, RescheduleTask):\n pass\n if isinstance(exception, CatchResponseError):\n e = str(exception)\n else:\n try:\n e = repr(exception)\n except AttributeError:\n e = f\"{exception.__class__} (and it has no string representation)\"\n errortext = e[:500].replace(\"\\n\", \" \")\n else:\n errortext = \"\"\n\n if response_time is None:\n response_time = -1\n n = name.ljust(30) if name else \"\"\n\n if self.include_time:\n if start_time:\n _print_t(datetime.fromtimestamp(start_time, tz=timezone.utc))\n else:\n _print_t(datetime.now())\n\n _print_t(request_type)\n _print_t(n.ljust(50))\n _print_t(str(round(response_time)).ljust(7))\n\n if self.include_length:\n _print_t(response_length)\n\n _print_t(errortext.ljust(9))\n\n if self.include_context:\n _print_t(context or \"\")\n\n if self.include_payload:\n _print_t(response._request.payload)\n\n print()\n\n\n_env: Environment | None = None # minimal Environment for debugging\n\n\ndef run_single_user(\n user_class: type[User],\n include_length=False,\n include_time=False,\n include_context=False,\n include_payload=False,\n loglevel: str | None = \"WARNING\",\n):\n \"\"\"\n Runs a single User. 
Useful when you want to run a debugger.\n\n It creates in a new locust :py:attr:`Environment ` and triggers any ``init`` or ``test_start`` :ref:`events ` as normal.\n\n It does **not** trigger ``test_stop`` or ``quit`` when you quit the debugger.\n\n It prints some info about every request to stdout, and you can get additional info using the `include_*` flags\n\n It also initiates logging on WARNING level (not INFO, because it could interfere with the printing of requests),\n but you can change that by passing a log level (or disabling logging entirely by passing None)\n \"\"\"\n global _env\n\n if loglevel:\n locust.log.setup_logging(loglevel)\n\n if not _env:\n options = argument_parser.parse_options()\n\n # in case your test goes looking for the file name of your locustfile\n frame = inspect.stack()[1]\n locustfile = os.path.basename(frame[0].f_code.co_filename)\n options.locustfile = locustfile\n\n _env = Environment(events=locust.events, locustfile=locustfile, host=options.host, parsed_options=options)\n\n # log requests to stdout\n PrintListener(\n _env,\n include_length=include_length,\n include_time=include_time,\n include_context=include_context,\n include_payload=include_payload,\n )\n # fire various events (quit and test_stop will never get called, sorry about that)\n _env.events.init.fire(environment=_env, runner=None, web_ui=None)\n # uncaught events will be suppressed, so check if that happened\n if locust.log.unhandled_greenlet_exception:\n raise Exception(\"Unhandled exception in init\")\n\n # do the things that the Runner usually does\n _env.user_classes = [user_class]\n _env._filter_tasks_by_tags()\n _env.events.test_start.fire(environment=_env)\n if _env.host:\n user_class.host = _env.host\n\n # create a single user\n user = user_class(_env)\n setattr(_env, \"single_user_instance\", user) # if you happen to need access to this from the Environment instance\n user.run()\n\n# Path: locust/user/sequential_taskset.py\nfrom locust.exception import LocustError\n\nimport logging\n\nfrom .task import TaskSet, TaskSetMeta\n\n\nclass SequentialTaskSetMeta(TaskSetMeta):\n \"\"\"\n Meta class for SequentialTaskSet. It's used to allow SequentialTaskSet classes to specify\n task execution in both a list as the tasks attribute or using the @task decorator\n\n We use the fact that class_dict order is the order of declaration in Python 3.6\n (See https://www.python.org/dev/peps/pep-0520/)\n \"\"\"\n\n def __new__(mcs, classname, bases, class_dict):\n new_tasks = []\n for base in bases:\n # first get tasks from base classes\n if hasattr(base, \"tasks\") and base.tasks:\n new_tasks += base.tasks\n for key, value in class_dict.items():\n if key == \"tasks\":\n # we want to insert tasks from the tasks attribute at the point of it's declaration\n # compared to methods declared with @task\n if isinstance(value, list):\n new_tasks.extend(value)\n else:\n raise ValueError(\"On SequentialTaskSet the task attribute can only be set to a list\")\n\n if \"locust_task_weight\" in dir(value):\n # method decorated with @task\n for _ in range(value.locust_task_weight):\n new_tasks.append(value)\n\n class_dict[\"tasks\"] = new_tasks\n return type.__new__(mcs, classname, bases, class_dict)\n\n\nclass SequentialTaskSet(TaskSet, metaclass=SequentialTaskSetMeta):\n \"\"\"\n Class defining a sequence of tasks that a User will execute.\n\n Works like TaskSet, but task weight is ignored, and all tasks are executed in order. 
Tasks can\n either be specified by setting the *tasks* attribute to a list of tasks, or by declaring tasks\n as methods using the @task decorator. The order of declaration decides the order of execution.\n\n It's possible to combine a task list in the *tasks* attribute, with some tasks declared using\n the @task decorator. The order of declaration is respected also in that case.\n \"\"\"\n\n def __init__(self, *args, **kwargs):\n super().__init__(*args, **kwargs)\n self._task_index = 0\n\n def get_next_task(self):\n if not self.tasks:\n raise LocustError(\n \"No tasks defined. Use the @task decorator or set the 'tasks' attribute of the SequentialTaskSet\"\n )\n task = self.tasks[self._task_index % len(self.tasks)]\n self._task_index += 1\n return task\n\n# Path: locust/__init__.py\nimport os\n\nif os.getenv(\"LOCUST_PLAYWRIGHT\", None):\n print(\"LOCUST_PLAYWRIGHT setting is no longer needed (because locust-plugins no longer installs trio)\")\n print(\"Uninstall trio package and remove the setting.\")\n try:\n # preserve backwards compatibility for now\n import trio\n except ModuleNotFoundError:\n # dont show a massive callstack if trio is not installed\n os._exit(1)\n\nfrom gevent import monkey\n\nmonkey.patch_all()\n\nfrom ._version import version as __version__\nfrom .contrib.fasthttp import FastHttpUser\nfrom .debug import run_single_user\nfrom .event import Events\nfrom .shape import LoadTestShape\nfrom .user import wait_time\nfrom .user.sequential_taskset import SequentialTaskSet\nfrom .user.task import TaskSet, tag, task\nfrom .user.users import HttpUser, User\nfrom .user.wait_time import between, constant, constant_pacing, constant_throughput\n\nevents = Events()\n\n__all__ = (\n \"SequentialTaskSet\",\n \"wait_time\",\n \"task\",\n \"tag\",\n \"TaskSet\",\n \"HttpUser\",\n \"FastHttpUser\",\n \"User\",\n \"between\",\n \"constant\",\n \"constant_pacing\",\n \"constant_throughput\",\n \"events\",\n \"LoadTestShape\",\n \"run_single_user\",\n)\n\n# Used for raising a DeprecationWarning if old Locust/HttpLocust is used\nfrom .util.deprecation import DeprecatedHttpLocustClass as HttpLocust\nfrom .util.deprecation import DeprecatedLocustClass as Locust\n\n# Path: locust/input_events.py\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport sys\nfrom typing import Callable\n\nimport gevent\n\nif os.name == \"nt\":\n import pywintypes\n from win32api import STD_INPUT_HANDLE\n from win32console import (\n ENABLE_ECHO_INPUT,\n ENABLE_LINE_INPUT,\n ENABLE_PROCESSED_INPUT,\n KEY_EVENT,\n GetStdHandle,\n )\nelse:\n import select\n import termios\n import tty\n\n\nclass InitError(Exception):\n pass\n\n\nclass UnixKeyPoller:\n def __init__(self):\n if sys.stdin.isatty():\n self.stdin = sys.stdin.fileno()\n self.tattr = termios.tcgetattr(self.stdin)\n tty.setcbreak(self.stdin, termios.TCSANOW)\n else:\n raise InitError(\"Terminal was not a tty. 
Keyboard input disabled\")\n\n def cleanup(self):\n termios.tcsetattr(self.stdin, termios.TCSANOW, self.tattr)\n\n def poll(_self):\n dr, dw, de = select.select([sys.stdin], [], [], 0)\n if not dr == []:\n return sys.stdin.read(1)\n return None\n\n\nclass WindowsKeyPoller:\n def __init__(self):\n if sys.stdin.isatty():\n try:\n self.read_handle = GetStdHandle(STD_INPUT_HANDLE)\n self.read_handle.SetConsoleMode(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT | ENABLE_PROCESSED_INPUT)\n self.cur_event_length = 0\n self.cur_keys_length = 0\n self.captured_chars = []\n except pywintypes.error:\n raise InitError(\"Terminal says its a tty but we couldnt enable line input. Keyboard input disabled.\")\n else:\n raise InitError(\"Terminal was not a tty. Keyboard input disabled\")\n\n def cleanup(self):\n pass\n\n def poll(self):\n if self.captured_chars:\n return self.captured_chars.pop(0)\n\n events_peek = self.read_handle.PeekConsoleInput(10000)\n\n if not events_peek:\n return None\n\n if not len(events_peek) == self.cur_event_length:\n for cur_event in events_peek[self.cur_event_length :]:\n if cur_event.EventType == KEY_EVENT:\n if ord(cur_event.Char) and cur_event.KeyDown:\n cur_char = str(cur_event.Char)\n self.captured_chars.append(cur_char)\n\n self.cur_event_length = len(events_peek)\n\n if self.captured_chars:\n return self.captured_chars.pop(0)\n else:\n return None\n\n\n\ndef get_poller():\n if os.name == \"nt\":\n return WindowsKeyPoller()\n else:\n return UnixKeyPoller()\n\n\ndef input_listener(key_to_func_map: dict[str, Callable]):\n def input_listener_func():\n try:\n poller = get_poller()\n except InitError as e:\n logging.debug(e)\n return\n\n try:\n while True:\n input = poller.poll()\n if input:\n for key in key_to_func_map:\n if input == key:\n key_to_func_map[key]()\n else:\n gevent.sleep(0.2)\n except Exception as e:\n logging.warning(f\"Exception in keyboard input poller: {e}\")\n finally:\n poller.cleanup()\n\n return input_listener_func\n\n# Path: locust/util/load_locustfile.py\nfrom __future__ import annotations\n\nimport importlib\nimport importlib.util\nimport inspect\nimport os\nimport sys\n\nfrom ..shape import LoadTestShape\nfrom ..user import User\n\n\ndef is_user_class(item):\n \"\"\"\n Check if a variable is a runnable (non-abstract) User class\n \"\"\"\n return bool(inspect.isclass(item) and issubclass(item, User) and item.abstract is False)\n\n\ndef is_shape_class(item):\n \"\"\"\n Check if a class is a LoadTestShape\n \"\"\"\n return bool(inspect.isclass(item) and issubclass(item, LoadTestShape) and not getattr(item, \"abstract\", True))\n\n\ndef load_locustfile(path) -> tuple[str | None, dict[str, User], list[LoadTestShape]]:\n \"\"\"\n Import given locustfile path and return (docstring, callables).\n\n Specifically, the locustfile's ``__doc__`` attribute (a string) and a\n dictionary of ``{'name': callable}`` containing all callables which pass\n the \"is a Locust\" test.\n \"\"\"\n\n # Start with making sure the current working dir is in the sys.path\n sys.path.insert(0, os.getcwd())\n # Get directory and locustfile name\n directory, locustfile = os.path.split(path)\n # If the directory isn't in the PYTHONPATH, add it so our import will work\n added_to_path = False\n index = None\n if directory not in sys.path:\n sys.path.insert(0, directory)\n added_to_path = True\n # If the directory IS in the PYTHONPATH, move it to the front temporarily,\n # otherwise other locustfiles -- like Locusts's own -- may scoop the intended\n # one.\n else:\n i = 
sys.path.index(directory)\n if i != 0:\n # Store index for later restoration\n index = i\n # Add to front, then remove from original position\n sys.path.insert(0, directory)\n del sys.path[i + 1]\n\n # Perform the import\n module_name = os.path.splitext(locustfile)[0]\n if module_name == \"locust\":\n module_name = \"locustfile\" # Avoid conflict with locust package\n loader = importlib.machinery.SourceFileLoader(module_name, path)\n spec = importlib.util.spec_from_file_location(module_name, path, loader=loader)\n if spec is None:\n sys.stderr.write(f\"Unable to get module spec for {module_name} in {path}\")\n sys.exit(1)\n\n imported = importlib.util.module_from_spec(spec)\n sys.modules[imported.__name__] = imported\n loader.exec_module(imported)\n\n # Remove directory from path if we added it ourselves (just to be neat)\n if added_to_path:\n del sys.path[0]\n # Put back in original index if we moved it\n if index is not None:\n sys.path.insert(index + 1, directory)\n del sys.path[0]\n # Return our two-tuple\n user_classes = {name: value for name, value in vars(imported).items() if is_user_class(value)}\n\n # Find shape class, if any, return it\n shape_classes = [value() for value in vars(imported).values() if is_shape_class(value)]\n\n return imported.__doc__, user_classes, shape_classes\n\n# Path: locust/main.py\nfrom __future__ import annotations\n\nimport locust\n\nimport atexit\nimport errno\nimport gc\nimport inspect\nimport json\nimport logging\nimport os\nimport signal\nimport sys\nimport time\nimport traceback\n\nimport gevent\n\nfrom . import log, stats\nfrom .argument_parser import parse_locustfile_option, parse_options\nfrom .env import Environment\nfrom .html import get_html_report\nfrom .input_events import input_listener\nfrom .log import greenlet_exception_logger, setup_logging\nfrom .stats import (\n StatsCSV,\n StatsCSVFileWriter,\n print_error_report,\n print_percentile_stats,\n print_stats,\n print_stats_json,\n stats_history,\n stats_printer,\n)\nfrom .user.inspectuser import print_task_ratio, print_task_ratio_json\nfrom .util.load_locustfile import load_locustfile\nfrom .util.timespan import parse_timespan\n\ntry:\n # import locust_plugins if it is installed, to allow it to register custom arguments etc\n import locust_plugins # pyright: ignore[reportMissingImports]\nexcept ModuleNotFoundError:\n pass\n\nversion = locust.__version__\n\n# Options to ignore when using a custom shape class without `use_common_options=True`\n# See: https://docs.locust.io/en/stable/custom-load-shape.html#use-common-options\nCOMMON_OPTIONS = {\n \"num_users\": \"users\",\n \"spawn_rate\": \"spawn-rate\",\n \"run_time\": \"run-time\",\n}\n\n\ndef create_environment(\n user_classes,\n options,\n events=None,\n shape_class=None,\n locustfile=None,\n available_user_classes=None,\n available_shape_classes=None,\n available_user_tasks=None,\n):\n \"\"\"\n Create an Environment instance from options\n \"\"\"\n return Environment(\n locustfile=locustfile,\n user_classes=user_classes,\n shape_class=shape_class,\n events=events,\n host=options.host,\n reset_stats=options.reset_stats,\n parsed_options=options,\n available_user_classes=available_user_classes,\n available_shape_classes=available_shape_classes,\n available_user_tasks=available_user_tasks,\n )\n\n\ndef main():\n # find specified locustfile(s) and make sure it exists, using a very simplified\n # command line parser that is only used to parse the -f option.\n locustfiles = parse_locustfile_option()\n locustfiles_length = 
len(locustfiles)\n\n # Grabbing the Locustfile if only one was provided. Otherwise, allowing users to select the locustfile in the UI\n # If --headless or --autostart and multiple locustfiles, all provided UserClasses will be ran\n locustfile = locustfiles[0] if locustfiles_length == 1 else None\n\n # Importing Locustfile(s) - setting available UserClasses and ShapeClasses to choose from in UI\n user_classes: dict[str, locust.User] = {}\n available_user_classes = {}\n available_shape_classes = {}\n available_user_tasks = {}\n shape_class = None\n for _locustfile in locustfiles:\n docstring, _user_classes, shape_classes = load_locustfile(_locustfile)\n\n # Setting Available Shape Classes\n if shape_classes:\n shape_class = shape_classes[0]\n for shape_class in shape_classes:\n shape_class_name = type(shape_class).__name__\n if shape_class_name in available_shape_classes.keys():\n sys.stderr.write(f\"Duplicate shape classes: {shape_class_name}\\n\")\n sys.exit(1)\n\n available_shape_classes[shape_class_name] = shape_class\n\n # Setting Available User Classes\n for key, value in _user_classes.items():\n if key in available_user_classes.keys():\n previous_path = inspect.getfile(user_classes[key])\n new_path = inspect.getfile(value)\n if previous_path == new_path:\n # The same User class was defined in two locustfiles but one probably imported the other, so we just ignore it\n continue\n else:\n sys.stderr.write(\n f\"Duplicate user class names: {key} is defined in both {previous_path} and {new_path}\\n\"\n )\n sys.exit(1)\n\n user_classes[key] = value\n available_user_classes[key] = value\n available_user_tasks[key] = value.tasks or None\n\n if len(stats.PERCENTILES_TO_CHART) != 2:\n logging.error(\"stats.PERCENTILES_TO_CHART parameter should be 2 parameters \\n\")\n sys.exit(1)\n\n if len(stats.MODERN_UI_PERCENTILES_TO_CHART) > 6:\n logging.error(\"stats.MODERN_UI_PERCENTILES_TO_CHART parameter should be a maximum of 6 parameters \\n\")\n sys.exit(1)\n\n def is_valid_percentile(parameter):\n try:\n if 0 < float(parameter) < 1:\n return True\n return False\n except ValueError:\n return False\n\n for percentile in stats.PERCENTILES_TO_CHART:\n if not is_valid_percentile(percentile):\n logging.error(\n \"stats.PERCENTILES_TO_CHART parameter need to be float and value between. 0 < percentile < 1 Eg 0.95\\n\"\n )\n sys.exit(1)\n\n for percentile in stats.PERCENTILES_TO_STATISTICS:\n if not is_valid_percentile(percentile):\n logging.error(\n \"stats.PERCENTILES_TO_STATISTICS parameter need to be float and value between. 0 < percentile < 1 Eg 0.95\\n\"\n )\n sys.exit(1)\n\n # parse all command line options\n options = parse_options()\n\n if options.headful:\n options.headless = False\n\n if options.slave or options.expect_slaves:\n sys.stderr.write(\"The --slave/--expect-slaves parameters have been renamed --worker/--expect-workers\\n\")\n sys.exit(1)\n\n if options.web_auth:\n sys.stderr.write(\n \"The --web-auth parameters has been replaced with --web-login. 
See https://docs.locust.io/en/stable/extending-locust.html#authentication for details\\n\"\n )\n sys.exit(1)\n\n if options.autoquit != -1 and not options.autostart:\n sys.stderr.write(\"--autoquit is only meaningful in combination with --autostart\\n\")\n sys.exit(1)\n\n if options.hatch_rate:\n sys.stderr.write(\"[DEPRECATED] The --hatch-rate parameter has been renamed --spawn-rate\\n\")\n options.spawn_rate = options.hatch_rate\n\n # setup logging\n if not options.skip_log_setup:\n if options.loglevel.upper() in [\"DEBUG\", \"INFO\", \"WARNING\", \"ERROR\", \"CRITICAL\"]:\n setup_logging(options.loglevel, options.logfile)\n else:\n sys.stderr.write(\"Invalid --loglevel. Valid values are: DEBUG/INFO/WARNING/ERROR/CRITICAL\\n\")\n sys.exit(1)\n\n children = []\n\n if options.processes:\n if os.name == \"nt\":\n sys.stderr.write(\"--processes is not supported in Windows (except in WSL)\\n\")\n sys.exit(1)\n if options.processes == -1:\n options.processes = os.cpu_count()\n if not options.processes:\n sys.stderr.write(\"--processes failed to detect number of cpus!?\\n\")\n sys.exit(1)\n elif options.processes < -1:\n sys.stderr.write(f\"Invalid --processes count {options.processes}\\n\")\n sys.exit(1)\n elif options.master:\n sys.stderr.write(\n \"--master cannot be combined with --processes. Remove --master, as it is implicit as long as --worker is not set.\\n\"\n )\n sys.exit(1)\n # Optimize copy-on-write-behavior to save some memory (aprx 26MB -> 15MB rss) in child processes\n gc.collect() # avoid freezing garbage\n gc.freeze() # move all objects to perm gen so ref counts dont get updated\n for _ in range(options.processes):\n child_pid = gevent.fork()\n if child_pid:\n children.append(child_pid)\n logging.debug(f\"Started child worker with pid #{child_pid}\")\n else:\n # child is always a worker, even when it wasnt set on command line\n options.worker = True\n # remove options that dont make sense on worker\n options.run_time = None\n options.autostart = None\n break\n else:\n # we're in the parent process\n if options.worker:\n # ignore the first sigint in parent, and wait for the children to handle sigint\n def sigint_handler(_signal, _frame):\n if getattr(sigint_handler, \"has_run\", False):\n # if parent gets repeated sigint, we kill the children hard\n for child_pid in children:\n try:\n logging.debug(f\"Sending SIGKILL to child with pid {child_pid}\")\n os.kill(child_pid, signal.SIGKILL)\n except ProcessLookupError:\n pass # process already dead\n except Exception:\n logging.error(traceback.format_exc())\n sys.exit(1)\n sigint_handler.has_run = True\n\n signal.signal(signal.SIGINT, sigint_handler)\n exit_code = 0\n # nothing more to do, just wait for the children to exit\n for child_pid in children:\n _, child_status = os.waitpid(child_pid, 0)\n try:\n if sys.version_info >= (3, 9):\n child_exit_code = os.waitstatus_to_exitcode(child_status)\n exit_code = max(exit_code, child_exit_code)\n except AttributeError:\n pass # dammit python 3.8...\n sys.exit(exit_code)\n else:\n options.master = True\n options.expect_workers = options.processes\n\n def kill_workers(children):\n exit_code = 0\n start_time = time.time()\n # give children some time to finish up (in case they had an error parsing arguments etc)\n for child_pid in children[:]:\n while time.time() < start_time + 3:\n try:\n _, child_status = os.waitpid(child_pid, os.WNOHANG)\n children.remove(child_pid)\n try:\n if sys.version_info >= (3, 9):\n child_exit_code = os.waitstatus_to_exitcode(child_status)\n exit_code = 
max(exit_code, child_exit_code)\n except AttributeError:\n pass # dammit python 3.8...\n except OSError as e:\n if e.errno == errno.EINTR:\n time.sleep(0.1)\n else:\n logging.error(traceback.format_exc())\n else:\n break\n for child_pid in children:\n try:\n logging.debug(f\"Sending SIGINT to child with pid {child_pid}\")\n os.kill(child_pid, signal.SIGINT)\n except ProcessLookupError:\n pass # never mind, process was already dead\n for child_pid in children:\n _, child_status = os.waitpid(child_pid, 0)\n try:\n if sys.version_info >= (3, 9):\n child_exit_code = os.waitstatus_to_exitcode(child_status)\n exit_code = max(exit_code, child_exit_code)\n except AttributeError:\n pass # dammit python 3.8...\n if exit_code > 1:\n logging.error(f\"Bad response code from worker children: {exit_code}\")\n # ensure master doesnt finish until output from workers has arrived\n # otherwise the terminal might look weird.\n time.sleep(0.1)\n\n atexit.register(kill_workers, children)\n\n logger = logging.getLogger(__name__)\n greenlet_exception_handler = greenlet_exception_logger(logger)\n\n if options.stop_timeout:\n try:\n options.stop_timeout = parse_timespan(options.stop_timeout)\n except ValueError:\n logger.error(\"Valid --stop-timeout formats are: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n sys.exit(1)\n\n if options.list_commands:\n print(\"Available Users:\")\n for name in user_classes:\n print(\" \" + name)\n sys.exit(0)\n\n if not user_classes:\n logger.error(\"No User class found!\")\n sys.exit(1)\n\n # make sure specified User exists\n if options.user_classes:\n missing = set(options.user_classes) - set(user_classes.keys())\n if missing:\n logger.error(f\"Unknown User(s): {', '.join(missing)}\\n\")\n sys.exit(1)\n else:\n names = set(options.user_classes) & set(user_classes.keys())\n user_classes = [user_classes[n] for n in names]\n else:\n # list() call is needed to consume the dict_view object in Python 3\n user_classes = list(user_classes.values())\n\n if os.name != \"nt\" and not options.master:\n try:\n import resource\n\n minimum_open_file_limit = 10000\n (soft_limit, hard_limit) = resource.getrlimit(resource.RLIMIT_NOFILE)\n\n if soft_limit < minimum_open_file_limit:\n # Increasing the limit to 10000 within a running process should work on at least MacOS.\n # It does not work on all OS:es, but we should be no worse off for trying.\n resource.setrlimit(resource.RLIMIT_NOFILE, [minimum_open_file_limit, hard_limit])\n except BaseException:\n logger.warning(\n f\"\"\"System open file limit '{soft_limit} is below minimum setting '{minimum_open_file_limit}'.\nIt's not high enough for load testing, and the OS didn't allow locust to increase it by itself.\nSee https://github.com/locustio/locust/wiki/Installation#increasing-maximum-number-of-open-files-limit for more info.\"\"\"\n )\n\n if sys.version_info < (3, 9):\n logger.warning(\"Python 3.8 support is deprecated and will be removed soon\")\n\n # create locust Environment\n locustfile_path = None if not locustfile else os.path.basename(locustfile)\n\n environment = create_environment(\n user_classes,\n options,\n events=locust.events,\n shape_class=shape_class,\n locustfile=locustfile_path,\n available_user_classes=available_user_classes,\n available_shape_classes=available_shape_classes,\n available_user_tasks=available_user_tasks,\n )\n\n if options.config_users:\n for json_user_config in options.config_users:\n try:\n if json_user_config.endswith(\".json\"):\n with open(json_user_config) as file:\n user_config = json.load(file)\n else:\n 
user_config = json.loads(json_user_config)\n\n def ensure_user_class_name(config):\n if \"user_class_name\" not in config:\n logger.error(\"The user config must specify a user_class_name\")\n sys.exit(-1)\n\n if isinstance(user_config, list):\n for config in user_config:\n ensure_user_class_name(config)\n\n environment.update_user_class(config)\n else:\n ensure_user_class_name(user_config)\n\n environment.update_user_class(user_config)\n except Exception as e:\n logger.error(f\"The --config-users arugment must be in valid JSON string or file: {e}\")\n sys.exit(-1)\n\n if (\n shape_class\n and not shape_class.use_common_options\n and any(getattr(options, opt, None) for opt in COMMON_OPTIONS)\n ):\n logger.warning(\n \"--run-time, --users or --spawn-rate have no impact on LoadShapes unless the shape class explicitly reads them. \"\n \"See: docs.locust.io/en/stable/custom-load-shape.html#use-common-options\"\n )\n ignored = [f\"--{arg}\" for opt, arg in COMMON_OPTIONS.items() if getattr(options, opt, None)]\n logger.warning(f\"The following option(s) will be ignored: {', '.join(ignored)}\")\n\n if options.show_task_ratio:\n print(\"\\n Task ratio per User class\")\n print(\"-\" * 80)\n print_task_ratio(user_classes, options.num_users, False)\n print(\"\\n Total task ratio\")\n print(\"-\" * 80)\n print_task_ratio(user_classes, options.num_users, True)\n sys.exit(0)\n if options.show_task_ratio_json:\n print_task_ratio_json(user_classes, options.num_users)\n sys.exit(0)\n\n if options.master:\n if options.worker:\n logger.error(\"The --master argument cannot be combined with --worker\")\n sys.exit(-1)\n if options.expect_workers_max_wait and not options.expect_workers:\n logger.error(\"The --expect-workers-max-wait argument only makes sense when combined with --expect-workers\")\n sys.exit(-1)\n runner = environment.create_master_runner(\n master_bind_host=options.master_bind_host,\n master_bind_port=options.master_bind_port,\n )\n elif options.worker:\n try:\n runner = environment.create_worker_runner(options.master_host, options.master_port)\n logger.debug(\"Connected to locust master: %s:%s\", options.master_host, options.master_port)\n except OSError as e:\n logger.error(\"Failed to connect to the Locust master: %s\", e)\n sys.exit(-1)\n else:\n runner = environment.create_local_runner()\n\n # main_greenlet is pointing to runners.greenlet by default, it will point the web greenlet later if in web mode\n main_greenlet = runner.greenlet\n\n if options.run_time:\n if options.worker:\n logger.error(\"--run-time should be specified on the master node, and not on worker nodes\")\n sys.exit(1)\n try:\n options.run_time = parse_timespan(options.run_time)\n except ValueError:\n logger.error(\"Valid --run-time formats are: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n sys.exit(1)\n\n if options.csv_prefix:\n base_csv_file = os.path.basename(options.csv_prefix)\n base_csv_dir = options.csv_prefix[: -len(base_csv_file)]\n if not os.path.exists(base_csv_dir) and len(base_csv_dir) != 0:\n os.makedirs(base_csv_dir)\n stats_csv_writer = StatsCSVFileWriter(\n environment, stats.PERCENTILES_TO_REPORT, options.csv_prefix, options.stats_history_enabled\n )\n else:\n stats_csv_writer = StatsCSV(environment, stats.PERCENTILES_TO_REPORT)\n\n # start Web UI\n if not options.headless and not options.worker:\n # spawn web greenlet\n protocol = \"https\" if options.tls_cert and options.tls_key else \"http\"\n\n if options.web_host == \"*\":\n # special check for \"*\" so that we're consistent with --master-bind-host\n 
web_host = \"\"\n else:\n web_host = options.web_host\n if web_host:\n logger.info(f\"Starting web interface at {protocol}://{web_host}:{options.web_port}\")\n else:\n if os.name == \"nt\":\n logger.info(\n f\"Starting web interface at {protocol}://localhost:{options.web_port} (accepting connections from all network interfaces)\"\n )\n else:\n logger.info(f\"Starting web interface at {protocol}://0.0.0.0:{options.web_port}\")\n\n web_ui = environment.create_web_ui(\n host=web_host,\n port=options.web_port,\n web_login=options.web_login,\n tls_cert=options.tls_cert,\n tls_key=options.tls_key,\n stats_csv_writer=stats_csv_writer,\n delayed_start=True,\n userclass_picker_is_active=options.class_picker,\n modern_ui=not options.legacy_ui,\n )\n else:\n web_ui = None\n\n if options.autostart and options.headless:\n logger.warning(\"The --autostart argument is implied by --headless, no need to set both.\")\n\n if options.autostart and options.worker:\n logger.warning(\"The --autostart argument has no meaning on a worker.\")\n\n def assign_equal_weights(environment, **kwargs):\n environment.assign_equal_weights()\n\n if options.equal_weights:\n environment.events.init.add_listener(assign_equal_weights)\n\n # Fire locust init event which can be used by end-users' code to run setup code that\n # need access to the Environment, Runner or WebUI.\n environment.events.init.fire(environment=environment, runner=runner, web_ui=web_ui)\n\n if web_ui:\n web_ui.start()\n main_greenlet = web_ui.greenlet\n\n def stop_and_optionally_quit():\n if options.autostart and not options.headless:\n logger.info(\"--run-time limit reached, stopping test\")\n runner.stop()\n if options.autoquit != -1:\n logger.debug(\"Autoquit time limit set to %s seconds\" % options.autoquit)\n time.sleep(options.autoquit)\n logger.info(\"--autoquit time reached, shutting down\")\n runner.quit()\n if web_ui:\n web_ui.stop()\n else:\n logger.info(\"--autoquit not specified, leaving web ui running indefinitely\")\n else: # --headless run\n logger.info(\"--run-time limit reached, shutting down\")\n runner.quit()\n\n def spawn_run_time_quit_greenlet():\n gevent.spawn_later(options.run_time, stop_and_optionally_quit).link_exception(greenlet_exception_handler)\n\n headless_master_greenlet = None\n stats_printer_greenlet = None\n if not options.only_summary and (options.print_stats or (options.headless and not options.worker)):\n # spawn stats printing greenlet\n stats_printer_greenlet = gevent.spawn(stats_printer(runner.stats))\n stats_printer_greenlet.link_exception(greenlet_exception_handler)\n\n gevent.spawn(stats_history, runner)\n\n def start_automatic_run():\n if options.master:\n # wait for worker nodes to connect\n start_time = time.monotonic()\n while len(runner.clients.ready) < options.expect_workers:\n if options.expect_workers_max_wait and options.expect_workers_max_wait < time.monotonic() - start_time:\n logger.error(\"Gave up waiting for workers to connect\")\n runner.quit()\n sys.exit(1)\n logging.info(\n \"Waiting for workers to be ready, %s of %s connected\",\n len(runner.clients.ready),\n options.expect_workers,\n )\n # TODO: Handle KeyboardInterrupt and send quit signal to workers that are started.\n # Right now, if the user sends a ctrl+c, the master will not gracefully\n # shutdown resulting in all the already started workers to stay active.\n time.sleep(1)\n if not options.worker:\n # apply headless mode defaults\n if options.num_users is None:\n options.num_users = 1\n if options.spawn_rate is None:\n options.spawn_rate = 
1\n\n # start the test\n if environment.shape_class:\n try:\n environment.runner.start_shape()\n environment.runner.shape_greenlet.join()\n except KeyboardInterrupt:\n logging.info(\"Exiting due to CTRL+C interruption\")\n finally:\n stop_and_optionally_quit()\n else:\n headless_master_greenlet = gevent.spawn(runner.start, options.num_users, options.spawn_rate)\n headless_master_greenlet.link_exception(greenlet_exception_handler)\n\n if options.run_time:\n logger.info(f\"Run time limit set to {options.run_time} seconds\")\n spawn_run_time_quit_greenlet()\n elif not options.worker and not environment.shape_class:\n logger.info(\"No run time limit set, use CTRL+C to interrupt\")\n\n if options.csv_prefix:\n gevent.spawn(stats_csv_writer.stats_writer).link_exception(greenlet_exception_handler)\n\n if options.headless:\n start_automatic_run()\n\n input_listener_greenlet = None\n if not options.worker:\n # spawn input listener greenlet\n input_listener_greenlet = gevent.spawn(\n input_listener(\n {\n \"w\": lambda: runner.start(runner.user_count + 1, 100)\n if runner.state != \"spawning\"\n else logging.warning(\"Already spawning users, can't spawn more right now\"),\n \"W\": lambda: runner.start(runner.user_count + 10, 100)\n if runner.state != \"spawning\"\n else logging.warning(\"Already spawning users, can't spawn more right now\"),\n \"s\": lambda: runner.start(max(0, runner.user_count - 1), 100)\n if runner.state != \"spawning\"\n else logging.warning(\"Spawning users, can't stop right now\"),\n \"S\": lambda: runner.start(max(0, runner.user_count - 10), 100)\n if runner.state != \"spawning\"\n else logging.warning(\"Spawning users, can't stop right now\"),\n },\n )\n )\n input_listener_greenlet.link_exception(greenlet_exception_handler)\n # ensure terminal is reset, even if there is an unhandled exception in locust or someone\n # does something wild, like calling sys.exit() in the locustfile\n atexit.register(input_listener_greenlet.kill, block=True)\n\n def shutdown():\n \"\"\"\n Shut down locust by firing quitting event, printing/writing stats and exiting\n \"\"\"\n logger.debug(\"Running teardowns...\")\n\n if input_listener_greenlet is not None:\n input_listener_greenlet.kill(block=False)\n\n environment.events.quitting.fire(environment=environment, reverse=True)\n\n # determine the process exit code\n if environment.process_exit_code is not None:\n code = environment.process_exit_code\n elif len(runner.errors) or len(runner.exceptions):\n code = options.exit_code_on_error\n elif log.unhandled_greenlet_exception:\n code = 2\n else:\n code = 0\n\n logger.info(f\"Shutting down (exit code {code})\")\n\n if stats_printer_greenlet is not None:\n stats_printer_greenlet.kill(block=False)\n if headless_master_greenlet is not None:\n headless_master_greenlet.kill(block=False)\n logger.debug(\"Cleaning up runner...\")\n if runner is not None:\n runner.quit()\n if options.json:\n print_stats_json(runner.stats)\n elif not isinstance(runner, locust.runners.WorkerRunner):\n print_stats(runner.stats, current=False)\n print_percentile_stats(runner.stats)\n print_error_report(runner.stats)\n environment.events.quit.fire(exit_code=code)\n sys.exit(code)\n\n # install SIGTERM handler\n def sig_term_handler():\n logger.info(\"Got SIGTERM signal\")\n shutdown()\n\n def save_html_report(use_modern_ui=False):\n html_report = get_html_report(environment, show_download_link=False, use_modern_ui=use_modern_ui)\n logger.info(\"writing html report to file: %s\", options.html_file)\n with open(options.html_file, 
\"w\", encoding=\"utf-8\") as file:\n file.write(html_report)\n\n gevent.signal_handler(signal.SIGTERM, sig_term_handler)\n\n try:\n logger.info(f\"Starting Locust {version}\")\n if options.class_picker:\n logger.info(\"Locust is running with the UserClass Picker Enabled\")\n if options.autostart and not options.headless:\n start_automatic_run()\n\n main_greenlet.join()\n if options.html_file:\n save_html_report(not options.legacy_ui)\n except KeyboardInterrupt:\n if options.html_file:\n save_html_report(not options.legacy_ui)\n except Exception:\n raise\n shutdown()\n\n# Path: locust/__main__.py\nfrom .main import main\n\nmain()\n\n# Path: locust/test/fake_module1_for_env_test.py\n\"\"\"Module for locust.test.test_env.TestEnvironment.test_user_classes_with_same_name_is_error\"\"\"\n\nfrom locust import User\n\n\nclass MyUserWithSameName(User):\n pass\n\n# Path: locust/test/fake_module2_for_env_test.py\n\"\"\"Module for locust.test.test_env.TestEnvironment.test_user_classes_with_same_name_is_error\"\"\"\n\nfrom locust import User\n\n\nclass MyUserWithSameName(User):\n pass\n\n# Path: locust/test/__init__.py\ntry:\n import resource\n\n # work around occasional \"zmq.error.ZMQError: Too many open files\"\n # this is done in main.py when running locust proper so we need to do it here as well\n resource.setrlimit(\n resource.RLIMIT_NOFILE,\n (\n 10000,\n resource.RLIM_INFINITY,\n ),\n )\n changed_rlimit = True\nexcept Exception:\n changed_rlimit = False\n\n# Path: locust/test/mock_logging.py\nfrom __future__ import annotations\n\nimport logging\n\nfrom typing import List, Union, Dict\nfrom types import TracebackType\n\nLogMessage = List[Union[str, Dict[str, TracebackType]]]\n\n\nclass MockedLoggingHandler(logging.Handler):\n debug: list[LogMessage] = []\n warning: list[LogMessage] = []\n info: list[LogMessage] = []\n error: list[LogMessage] = []\n critical: list[LogMessage] = []\n\n def emit(self, record):\n if record.exc_info:\n value = {\"message\": record.getMessage(), \"exc_info\": record.exc_info}\n else:\n value = record.getMessage()\n getattr(self.__class__, record.levelname.lower()).append(value)\n\n @classmethod\n def reset(cls):\n for attr in dir(cls):\n if isinstance(getattr(cls, attr), list):\n setattr(cls, attr, [])\n\n# Path: locust/test/util.py\nimport datetime\nimport functools\nimport gc\nimport os\nimport socket\nimport warnings\nfrom contextlib import contextmanager\nfrom tempfile import NamedTemporaryFile\n\nfrom cryptography import x509\nfrom cryptography.hazmat.backends import default_backend\nfrom cryptography.hazmat.primitives import hashes, serialization\nfrom cryptography.hazmat.primitives.asymmetric import rsa\nfrom cryptography.x509.oid import NameOID\n\n\n@contextmanager\ndef temporary_file(content, suffix=\"_locustfile.py\", dir=None):\n f = NamedTemporaryFile(suffix=suffix, delete=False, dir=dir)\n f.write(content.encode(\"utf-8\"))\n f.close()\n try:\n yield f.name\n finally:\n if os.path.exists(f.name):\n os.remove(f.name)\n\n\n@contextmanager\ndef patch_env(name: str, value: str):\n prev_value = os.getenv(name)\n os.environ[name] = value\n try:\n yield\n finally:\n if prev_value is None:\n del os.environ[name]\n else:\n os.environ[name] = prev_value\n\n\ndef get_free_tcp_port():\n \"\"\"\n Find an unused TCP port\n \"\"\"\n s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n s.bind((\"127.0.0.1\", 0))\n s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n port = s.getsockname()[1]\n s.close()\n return port\n\n\ndef create_tls_cert(hostname):\n 
\"\"\"Generate a TLS cert and private key to serve over https\"\"\"\n key = rsa.generate_private_key(public_exponent=2**16 + 1, key_size=2048, backend=default_backend())\n name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, hostname)])\n now = datetime.datetime.now(tz=datetime.timezone.utc)\n cert = (\n x509.CertificateBuilder()\n .subject_name(name)\n .issuer_name(name)\n .public_key(key.public_key())\n .serial_number(1000)\n .not_valid_before(now)\n .not_valid_after(now + datetime.timedelta(days=10 * 365))\n .sign(key, hashes.SHA256(), default_backend())\n )\n cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM)\n key_pem = key.private_bytes(\n encoding=serialization.Encoding.PEM,\n format=serialization.PrivateFormat.TraditionalOpenSSL,\n encryption_algorithm=serialization.NoEncryption(),\n )\n\n return cert_pem, key_pem\n\n\ndef clear_all_functools_lru_cache() -> None:\n # Clear all `functools.lru_cache` to ensure that no state are persisted from one test to another.\n # Taken from https://stackoverflow.com/a/50699209.\n with warnings.catch_warnings():\n warnings.simplefilter(action=\"ignore\", category=ResourceWarning)\n gc.collect()\n wrappers = [a for a in gc.get_objects() if isinstance(a, functools._lru_cache_wrapper)]\n assert len(wrappers) > 0\n for wrapper in wrappers:\n wrapper.cache_clear()\n\n# Path: locust/test/testcases.py\nimport locust\nfrom locust import log\nfrom locust.env import Environment\nfrom locust.event import Events\nfrom locust.test.mock_logging import MockedLoggingHandler\nfrom locust.test.util import clear_all_functools_lru_cache\n\nimport base64\nimport logging\nimport random\nimport sys\nimport unittest\nimport warnings\nfrom io import BytesIO\n\nimport gevent\nimport gevent.pywsgi\nfrom flask import Flask, Response, make_response, redirect, request, send_file, stream_with_context\n\napp = Flask(__name__)\napp.jinja_env.add_extension(\"jinja2.ext.do\")\n\n\n@app.route(\"/ultra_fast\")\ndef ultra_fast():\n return \"This is an ultra fast response\"\n\n\n@app.route(\"/fast\")\ndef fast():\n gevent.sleep(random.choice([0.1, 0.2, 0.3]))\n return \"This is a fast response\"\n\n\n@app.route(\"/slow\")\ndef slow():\n delay = request.args.get(\"delay\")\n if delay:\n gevent.sleep(float(delay))\n else:\n gevent.sleep(random.choice([0.5, 1, 1.5]))\n return \"This is a slow response\"\n\n\n@app.route(\"/consistent\")\ndef consistent():\n gevent.sleep(0.2)\n return \"This is a consistent response\"\n\n\n@app.route(\"/request_method\", methods=[\"POST\", \"GET\", \"HEAD\", \"PUT\", \"DELETE\", \"PATCH\"])\ndef request_method():\n return request.method\n\n\n@app.route(\"/request_header_test\")\ndef request_header_test():\n x_header_test = request.headers[\"X-Header-Test\"]\n response = Response(x_header_test)\n response.headers[\"X-Header-Test\"] = x_header_test\n\n return response\n\n\n@app.route(\"/post\", methods=[\"POST\"])\n@app.route(\"/put\", methods=[\"PUT\"])\ndef manipulate():\n return str(request.form.get(\"arg\", \"\"))\n\n\n@app.route(\"/get_arg\", methods=[\"GET\"])\ndef get_arg():\n return request.args.get(\"arg\")\n\n\n@app.route(\"/fail\")\ndef failed_request():\n return \"This response failed\", 500\n\n\n@app.route(\"/status/204\")\ndef status_204():\n return \"\", 204\n\n\n@app.route(\"/redirect\", methods=[\"GET\", \"POST\"])\ndef do_redirect():\n delay = request.args.get(\"delay\")\n if delay:\n gevent.sleep(float(delay))\n url = request.args.get(\"url\", \"/ultra_fast\")\n return redirect(url)\n\n\n@app.route(\"/basic_auth\")\ndef 
basic_auth():\n auth = base64.b64decode(request.headers.get(\"Authorization\", \"\").replace(\"Basic \", \"\")).decode(\"utf-8\")\n if auth == \"locust:menace\":\n return \"Authorized\"\n resp = make_response(\"401 Authorization Required\", 401)\n resp.headers[\"WWW-Authenticate\"] = 'Basic realm=\"Locust\"'\n return resp\n\n\n@app.route(\"/no_content_length\")\ndef no_content_length():\n r = send_file(\n BytesIO(b\"This response does not have content-length in the header\"),\n etag=False,\n mimetype=\"text/plain\",\n )\n r.headers.remove(\"Content-Length\")\n return r\n\n\n@app.errorhandler(404)\ndef not_found(error):\n return \"Not Found\", 404\n\n\n@app.route(\"/streaming/\")\ndef streaming_response(iterations):\n import time\n\n def generate():\n yield \"
<title>streaming response</title>
\"\n for i in range(iterations):\n yield f\"{i}\\n\"\n time.sleep(0.01)\n yield \"\"\n\n return Response(stream_with_context(generate()), mimetype=\"text/html\")\n\n\n@app.route(\"/set_cookie\", methods=[\"POST\"])\ndef set_cookie():\n response = make_response(\"ok\")\n response.set_cookie(request.args.get(\"name\"), request.args.get(\"value\"))\n return response\n\n\n@app.route(\"/get_cookie\")\ndef get_cookie():\n return make_response(request.cookies.get(request.args.get(\"name\"), \"\"))\n\n\n@app.route(\"/rest\", methods=[\"POST\"])\ndef rest():\n return request.json\n\n\n@app.route(\"/content_type_missing_charset\")\ndef content_type_missing_charset():\n resp = make_response(\"stuff\")\n resp.headers[\"Content-Type\"] = \"Content-Type: application/json;\"\n return resp\n\n\n@app.route(\"/content_type_regular\")\ndef content_type_regular():\n resp = make_response(\"stuff\")\n resp.headers[\"Content-Type\"] = \"Content-Type: application/json; charset=utf-8;\"\n return resp\n\n\n@app.route(\"/content_type_with_extra_stuff\")\ndef content_type_with_extra_stuff():\n resp = make_response(\"stuff\")\n resp.headers[\"Content-Type\"] = \"Content-Type: application/json; charset=utf-8; api-version=3.0\"\n return resp\n\n\n...\n", "needle_token_start": 3984, "needle_token_end": 4020, "code_context_ntokens": 16090, "output": ["```python\ndef get_poller():\n if os.name == \"nt\":\n return WindowsKeyPoller()\n else:\n return UnixKeyPoller()\n```"]} +{"repo": "locustio/locust", "name": "test_ramp_down_users_to_3_workers_with_spawn_rate_of_9", "language": "python", "path": "locust/test/test_dispatch.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To verify that a system correctly reduces the number of active users across three worker nodes to zero using a specific rate of user deactivation.\n2. **Input**: A setup with three worker nodes and three types of users, initially having a total of nine active users distributed across the nodes.\n3. **Output**: Assertions check that all users are deactivated in one step and that no further users remain to be deactivated, confirming the rapid deactivation process.\n4. 
**Procedure**: \n - Initialize the system with three worker nodes and an equal distribution of nine users.\n - Command the system to deactivate all users at a rate of nine users per step.\n - Verify through assertions that all users are deactivated in one step and that attempting to get the next state results in no further users, confirming the end of the deactivation process.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 1},\n \"3\": {\"User1\": 1, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 1},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n 
self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_2(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=2)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_2_4(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n 
weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=2.4)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_3(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=3)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, 
\"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_4(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=4)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n \ndef test_ramp_down_users_to_3_workers_with_spawn_rate_of_9(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch 
= 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=9)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n\n@unittest.skip(reason=\"takes too long. run this manually if you change dispatch logic.\")\nclass TestRampUpThenDownThenUp(unittest.TestCase):\n def test_ramp_up_then_down_then_up(self):\n for user1_weight, user2_weight, user3_weight, user4_weight, user5_weight in [\n (1, 1, 1, 1, 1),\n (1, 2, 3, 4, 5),\n (1, 3, 5, 7, 9),\n ]:\n\n class User1(User):\n weight = user1_weight\n\n class User2(User):\n weight = user2_weight\n\n class User3(User):\n weight = user3_weight\n\n class User4(User):\n weight = user4_weight\n\n class User5(User):\n weight = user5_weight\n\n all_user_classes = [User1, User2, User3, User4, User5]\n\n for number_of_user_classes in range(1, len(all_user_classes) + 1):\n user_classes = all_user_classes[:number_of_user_classes]\n\n for max_user_count, min_user_count in [(30, 15), (54, 21), (14165, 1476)]:\n for worker_count in [1, 3, 5, 9]:\n workers = [WorkerNode(str(i + 1)) for i in range(worker_count)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n\n # Ramp-up to go to `min_user_count` #########\n\n users_dispatcher.new_dispatch(target_user_count=min_user_count, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users_ramp_up_to_min_user_count = list(users_dispatcher)\n\n # Ramp-up to go to `max_user_count` #########\n\n users_dispatcher.new_dispatch(target_user_count=max_user_count, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n list(users_dispatcher)\n\n # Ramp-down go back to `min_user_count` #########\n\n users_dispatcher.new_dispatch(target_user_count=min_user_count, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users_ramp_down_to_min_user_count = list(users_dispatcher)\n\n # Assertions #########\n\n self.assertDictEqual(\n all_dispatched_users_ramp_up_to_min_user_count[-1],\n all_dispatched_users_ramp_down_to_min_user_count[-1],\n )\n\n\nclass TestDispatchUsersToWorkersHavingTheSameUsersAsTheTarget(unittest.TestCase):\n def test_dispatch_users_to_3_workers(self):\n \"\"\"Final distribution should be {\"User1\": 3, \"User2\": 3, \"User3\": 3}\"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n user_count = 9\n\n for spawn_rate in [0.15, 0.5, 1, 2, 2.4, 3, 4, 9]:\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=user_count, spawn_rate=user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=user_count, spawn_rate=spawn_rate)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = 
time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 3, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 3},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n clear_all_functools_lru_cache()\n\n\nclass TestDistributionIsRespectedDuringDispatch(unittest.TestCase):\n def test_dispatch_75_users_to_4_workers_with_spawn_rate_of_5(self):\n \"\"\"\n Test case covering reported issue in https://github.com/locustio/locust/pull/1621#issuecomment-853624275.\n\n The case is to ramp-up from 0 to 75 users with two user classes. `User1` has a weight of 1 and `User2`\n has a weight of 2. The original issue was with 500 users, but to keep the test shorter, we use 75 users.\n\n Final distribution should be {\"User1\": 25, \"User2\": 50}\n \"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 2\n\n worker_node1 = WorkerNode(\"1\")\n worker_node2 = WorkerNode(\"2\")\n worker_node3 = WorkerNode(\"3\")\n worker_node4 = WorkerNode(\"4\")\n\n users_dispatcher = UsersDispatcher(\n worker_nodes=[worker_node1, worker_node2, worker_node3, worker_node4], user_classes=[User1, User2]\n )\n users_dispatcher.new_dispatch(target_user_count=75, spawn_rate=5)\n users_dispatcher._wait_between_dispatch = 0\n\n # total user count = 5\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 3})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 1, \"User2\": 1},\n \"2\": {\"User1\": 1, \"User2\": 0},\n \"3\": {\"User1\": 0, \"User2\": 1},\n \"4\": {\"User1\": 0, \"User2\": 1},\n },\n )\n\n # total user count = 10\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 3, \"User2\": 7})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 1, \"User2\": 2},\n \"2\": {\"User1\": 1, \"User2\": 2},\n \"3\": {\"User1\": 0, \"User2\": 2},\n \"4\": {\"User1\": 1, \"User2\": 1},\n },\n )\n\n # total user count = 15\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 5, \"User2\": 10})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 1, \"User2\": 3},\n \"2\": {\"User1\": 2, \"User2\": 2},\n \"3\": {\"User1\": 1, \"User2\": 3},\n \"4\": {\"User1\": 1, \"User2\": 2},\n },\n )\n\n # total user count = 20\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 7, \"User2\": 13})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 2, \"User2\": 3},\n \"2\": {\"User1\": 2, \"User2\": 3},\n \"3\": {\"User1\": 1, \"User2\": 4},\n \"4\": {\"User1\": 2, \"User2\": 3},\n },\n )\n\n # total user count = 25\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 8, \"User2\": 17})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 2, \"User2\": 5},\n \"2\": {\"User1\": 2, \"User2\": 4},\n \"3\": {\"User1\": 2, \"User2\": 4},\n \"4\": {\"User1\": 2, \"User2\": 4},\n },\n )\n\n # total user count = 30\n dispatched_users = 
next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 10, \"User2\": 20})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 3, \"User2\": 5},\n \"2\": {\"User1\": 3, \"User2\": 5},\n \"3\": {\"User1\": 2, \"User2\": 5},\n \"4\": {\"User1\": 2, \"User2\": 5},\n },\n )\n\n # total user count = 35\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 12, \"User2\": 23})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 3, \"User2\": 6},\n \"2\": {\"User1\": 3, \"User2\": 6},\n \"3\": {\"User1\": 3, \"User2\": 6},\n \"4\": {\"User1\": 3, \"User2\": 5},\n },\n )\n\n # total user count = 40\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 13, \"User2\": 27})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 3, \"User2\": 7},\n \"2\": {\"User1\": 4, \"User2\": 6},\n \"3\": {\"User1\": 3, \"User2\": 7},\n \"4\": {\"User1\": 3, \"User2\": 7},\n },\n )\n\n # total user count = 45\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 15, \"User2\": 30})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 4, \"User2\": 8},\n \"2\": {\"User1\": 4, \"User2\": 7},\n \"3\": {\"User1\": 3, \"User2\": 8},\n \"4\": {\"User1\": 4, \"User2\": 7},\n },\n )\n\n # total user count = 50\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 17, \"User2\": 33})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 4, \"User2\": 9},\n \"2\": {\"User1\": 5, \"User2\": 8},\n \"3\": {\"User1\": 4, \"User2\": 8},\n \"4\": {\"User1\": 4, \"User2\": 8},\n },\n )\n\n # total user count = 55\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 18, \"User2\": 37})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 5, \"User2\": 9},\n \"2\": {\"User1\": 5, \"User2\": 9},\n \"3\": {\"User1\": 4, \"User2\": 10},\n \"4\": {\"User1\": 4, \"User2\": 9},\n },\n )\n\n # total user count = 60\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 20, \"User2\": 40})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 5, \"User2\": 10},\n \"2\": {\"User1\": 5, \"User2\": 10},\n \"3\": {\"User1\": 5, \"User2\": 10},\n \"4\": {\"User1\": 5, \"User2\": 10},\n },\n )\n\n # total user count = 65\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 22, \"User2\": 43})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 6, \"User2\": 11},\n \"2\": {\"User1\": 6, \"User2\": 10},\n \"3\": {\"User1\": 5, \"User2\": 11},\n \"4\": {\"User1\": 5, \"User2\": 11},\n },\n )\n\n # total user count = 70\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 23, \"User2\": 47})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 6, \"User2\": 12},\n \"2\": {\"User1\": 6, \"User2\": 12},\n \"3\": {\"User1\": 5, \"User2\": 12},\n \"4\": {\"User1\": 6, \"User2\": 11},\n },\n )\n\n # total user count = 75, User1 = 25, User2 = 50\n dispatched_users = next(users_dispatcher)\n 
self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 25, \"User2\": 50})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 6, \"User2\": 13},\n \"2\": {\"User1\": 7, \"User2\": 12},\n \"3\": {\"User1\": 6, \"User2\": 13},\n \"4\": {\"User1\": 6, \"User2\": 12},\n },\n )\n\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n\n\nclass TestLargeScale(unittest.TestCase):\n # fmt: off\n weights = [\n 5, 55, 37, 2, 97, 41, 33, 19, 19, 34, 78, 76, 28, 62, 69, 5, 55, 37, 2, 97, 41, 33, 19, 19, 34,\n 78, 76, 28, 62, 69, 41, 33, 19, 19, 34, 78, 76, 28, 62, 69, 41, 33, 19, 19, 34, 78, 76, 28, 62, 69\n ]\n # fmt: on\n numerated_weights = dict(zip(range(len(weights)), weights))\n\n weighted_user_classes = [type(f\"User{i}\", (User,), {\"weight\": w}) for i, w in numerated_weights.items()]\n fixed_user_classes_10k = [type(f\"FixedUser10k{i}\", (User,), {\"fixed_count\": 2000}) for i in range(50)]\n fixed_user_classes_1M = [type(f\"FixedUser1M{i}\", (User,), {\"fixed_count\": 20000}) for i in range(50)]\n mixed_users = weighted_user_classes[:25] + fixed_user_classes_10k[25:]\n\n def test_distribute_users(self):\n for user_classes in [self.weighted_user_classes, self.fixed_user_classes_1M, self.mixed_users]:\n workers = [WorkerNode(str(i)) for i in range(10_000)]\n\n target_user_count = 1_000_000\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n\n ts = time.perf_counter()\n users_on_workers, user_gen, worker_gen, active_users = users_dispatcher._distribute_users(\n target_user_count=target_user_count\n )\n delta = time.perf_counter() - ts\n\n # Because tests are run with coverage, the code will be slower.\n # We set the pass criterion to 7000ms, but in real life, the\n # `_distribute_users` method runs faster than this.\n self.assertLessEqual(1000 * delta, 7000)\n\n self.assertEqual(_user_count(users_on_workers), target_user_count)\n\n def test_ramp_up_from_0_to_100_000_users_with_50_user_classes_and_1000_workers_and_5000_spawn_rate(self):\n for user_classes in [\n self.weighted_user_classes,\n self.fixed_user_classes_1M,\n self.fixed_user_classes_10k,\n self.mixed_users,\n ]:\n workers = [WorkerNode(str(i)) for i in range(1000)]\n\n target_user_count = 100_000\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=target_user_count, spawn_rate=5_000)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users = list(users_dispatcher)\n\n tol = 0.2\n self.assertTrue(\n all(\n dispatch_iteration_duration <= tol\n for dispatch_iteration_duration in users_dispatcher.dispatch_iteration_durations\n ),\n \"One or more dispatch took more than {:.0f}s to compute (max = {}ms)\".format(\n tol * 1000, 1000 * max(users_dispatcher.dispatch_iteration_durations)\n ),\n )\n\n self.assertEqual(_user_count(all_dispatched_users[-1]), target_user_count)\n\n for dispatch_users in all_dispatched_users:\n user_count_on_workers = [\n sum(user_classes_count.values()) for user_classes_count in dispatch_users.values()\n ]\n self.assertLessEqual(\n max(user_count_on_workers) - min(user_count_on_workers),\n 1,\n \"One or more workers have too much users compared to the other workers when user count is {}\".format(\n _user_count(dispatch_users)\n ),\n )\n\n for i, dispatch_users in enumerate(all_dispatched_users):\n aggregated_dispatched_users = _aggregate_dispatched_users(dispatch_users)\n for user_class in [u for u in user_classes if not 
u.fixed_count]:\n target_relative_weight = user_class.weight / sum(\n map(attrgetter(\"weight\"), [u for u in user_classes if not u.fixed_count])\n )\n relative_weight = aggregated_dispatched_users[user_class.__name__] / _user_count(dispatch_users)\n error_percent = 100 * (relative_weight - target_relative_weight) / target_relative_weight\n if i == len(all_dispatched_users) - 1:\n # We want the distribution to be as good as possible at the end of the ramp-up\n tol = 0.5\n else:\n tol = 15\n self.assertLessEqual(\n error_percent,\n tol,\n \"Distribution for user class {} is off by more than {}% when user count is {}\".format(\n user_class, tol, _user_count(dispatch_users)\n ),\n )\n\n def test_ramp_down_from_100_000_to_0_users_with_50_user_classes_and_1000_workers_and_5000_spawn_rate(self):\n for user_classes in [\n self.weighted_user_classes,\n self.fixed_user_classes_1M,\n self.fixed_user_classes_10k,\n self.mixed_users,\n ]:\n initial_user_count = 100_000\n\n workers = [WorkerNode(str(i)) for i in range(1000)]\n\n # Ramp-up\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n # Ramp-down\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=5000)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users = list(users_dispatcher)\n\n tol = 0.2\n self.assertTrue(\n all(\n dispatch_iteration_duration <= tol\n for dispatch_iteration_duration in users_dispatcher.dispatch_iteration_durations\n ),\n \"One or more dispatch took more than {:.0f}ms to compute (max = {}ms)\".format(\n tol * 1000, 1000 * max(users_dispatcher.dispatch_iteration_durations)\n ),\n )\n\n self.assertEqual(_user_count(all_dispatched_users[-1]), 0)\n\n for dispatch_users in all_dispatched_users[:-1]:\n user_count_on_workers = [\n sum(user_classes_count.values()) for user_classes_count in dispatch_users.values()\n ]\n self.assertLessEqual(\n max(user_count_on_workers) - min(user_count_on_workers),\n 1,\n \"One or more workers have too much users compared to the other workers when user count is {}\".format(\n _user_count(dispatch_users)\n ),\n )\n\n for dispatch_users in all_dispatched_users[:-1]:\n aggregated_dispatched_users = _aggregate_dispatched_users(dispatch_users)\n for user_class in [u for u in user_classes if not u.fixed_count]:\n target_relative_weight = user_class.weight / sum(\n map(attrgetter(\"weight\"), [u for u in user_classes if not u.fixed_count])\n )\n relative_weight = aggregated_dispatched_users[user_class.__name__] / _user_count(dispatch_users)\n error_percent = 100 * (relative_weight - target_relative_weight) / target_relative_weight\n tol = 15\n self.assertLessEqual(\n error_percent,\n tol,\n \"Distribution for user class {} is off by more than {}% when user count is {}\".format(\n user_class, tol, _user_count(dispatch_users)\n ),\n )\n\n\nclass TestSmallConsecutiveRamping(unittest.TestCase):\n def test_consecutive_ramp_up_and_ramp_down(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n user_classes = [User1, User2]\n\n worker_node1 = WorkerNode(\"1\")\n worker_node2 = WorkerNode(\"2\")\n\n worker_nodes = [worker_node1, worker_node2]\n\n users_dispatcher = UsersDispatcher(worker_nodes=worker_nodes, user_classes=user_classes)\n\n # user count = 1\n users_dispatcher.new_dispatch(target_user_count=1, spawn_rate=1)\n 
users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 1, \"User2\": 0})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 0)\n\n # user count = 2\n users_dispatcher.new_dispatch(target_user_count=2, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 1, \"User2\": 1})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 1)\n\n # user count = 3\n users_dispatcher.new_dispatch(target_user_count=3, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 1})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 1)\n\n # user count = 4\n users_dispatcher.new_dispatch(target_user_count=4, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 2})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 2)\n\n # user count = 3\n users_dispatcher.new_dispatch(target_user_count=3, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 1})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 1)\n\n # user count = 2\n users_dispatcher.new_dispatch(target_user_count=2, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 1, \"User2\": 1})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 1)\n\n # user count = 1\n users_dispatcher.new_dispatch(target_user_count=1, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 1, \"User2\": 0})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 0)\n\n # user count = 0\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 0, \"User2\": 0})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node1.id), 0)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_node2.id), 0)\n\n\nclass TestRampingMiscellaneous(unittest.TestCase):\n def test_spawn_rate_greater_than_target_user_count(self):\n class 
User1(User):\n weight = 1\n\n user_classes = [User1]\n\n worker_nodes = [WorkerNode(str(i + 1)) for i in range(1)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=worker_nodes, user_classes=user_classes)\n\n users_dispatcher.new_dispatch(target_user_count=1, spawn_rate=100)\n users_dispatcher._wait_between_dispatch = 0\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(dispatched_users, {\"1\": {\"User1\": 1}})\n\n users_dispatcher.new_dispatch(target_user_count=11, spawn_rate=100)\n users_dispatcher._wait_between_dispatch = 0\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(dispatched_users, {\"1\": {\"User1\": 11}})\n\n users_dispatcher.new_dispatch(target_user_count=10, spawn_rate=100)\n users_dispatcher._wait_between_dispatch = 0\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(dispatched_users, {\"1\": {\"User1\": 10}})\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=100)\n users_dispatcher._wait_between_dispatch = 0\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(dispatched_users, {\"1\": {\"User1\": 0}})\n\n\nclass TestRemoveWorker(unittest.TestCase):\n def test_remove_worker_during_ramp_up(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n worker_nodes = [WorkerNode(str(i + 1)) for i in range(3)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=worker_nodes, user_classes=user_classes)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=3)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n # Dispatch iteration 1\n ts = time.perf_counter()\n dispatched_users = next(users_dispatcher)\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 1, \"User2\": 1, \"User3\": 1})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 1)\n\n # Dispatch iteration 2\n ts = time.perf_counter()\n dispatched_users = next(users_dispatcher)\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 2, \"User3\": 2})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 2)\n\n self.assertFalse(users_dispatcher._rebalance)\n\n users_dispatcher.remove_worker(worker_nodes[1])\n\n self.assertTrue(users_dispatcher._rebalance)\n\n # Re-balance\n ts = time.perf_counter()\n dispatched_users = next(users_dispatcher)\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, f\"Expected re-balance dispatch to be instantaneous but got {delta}s\")\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 2, \"User3\": 2})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 3)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 3)\n\n self.assertFalse(users_dispatcher._rebalance)\n\n # Dispatch iteration 
3\n ts = time.perf_counter()\n dispatched_users = next(users_dispatcher)\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 3, \"User2\": 3, \"User3\": 3})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 5)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 4)\n\n def test_remove_two_workers_during_ramp_up(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n worker_nodes = [WorkerNode(str(i + 1)) for i in range(3)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=worker_nodes, user_classes=user_classes)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=3)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n # Dispatch iteration 1\n ts = time.perf_counter()\n dispatched_users = next(users_dispatcher)\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 1, \"User2\": 1, \"User3\": 1})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 1)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 1)\n\n # Dispatch iteration 2\n ts = time.perf_counter()\n dispatched_users = next(users_dispatcher)\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 2, \"User3\": 2})\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 2)\n self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 2)\n\n self.assertFalse(users_dispatcher._rebalance)\n\n users_dispatcher.remove_worker(worker_nodes[1])\n users_dispatcher.remove_worker(worker_nodes[2])\n\n self.assertTrue(users_dispatcher._rebalance)\n\n # Re-balance\n...\n", "needle_token_start": 5418, "needle_token_end": 5899, "code_context_ntokens": 16157, "output": ["```python\ndef test_ramp_down_users_to_3_workers_with_spawn_rate_of_9(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=9)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n 
self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n```"]} +{"repo": "locustio/locust", "name": "test_custom_ssl_context_fail_with_bad_context", "language": "python", "path": "locust/test/test_fasthttp.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: To verify that the system correctly handles an error when attempting to use an SSL context with incompatible settings (hostname checks enabled but certificate verification disabled).\n2. **Input**: A custom SSL context configuration where hostname verification is enabled but the SSL certificate verification mode is set to none.\n3. **Output**: An exception is expected to be raised, specifically a `ValueError` indicating the conflict in SSL context settings.\n4. **Procedure**: \n - A custom SSL context is created with conflicting settings: hostname checks are enabled, but certificate verification is disabled.\n - An HTTP session is initiated using this custom SSL context to make a request to a predefined server address.\n - The test checks that making a request with this configuration raises the expected `ValueError`, confirming that the system properly identifies and rejects the invalid SSL context settings.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: locust/test/test_dispatch.py\nfrom __future__ import annotations\n\nfrom locust import User\nfrom locust.dispatch import UsersDispatcher\nfrom locust.runners import WorkerNode\nfrom locust.test.util import clear_all_functools_lru_cache\n\nimport time\nimport unittest\nfrom operator import attrgetter\n\n_TOLERANCE = 0.025\n\n\nclass TestRampUpUsersFromZero(unittest.TestCase):\n def test_ramp_up_users_to_3_workers_with_spawn_rate_of_0_5(self):\n \"\"\"Final distribution should be {\"User1\": 3, \"User2\": 3, \"User3\": 3}\"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n worker_node1 = WorkerNode(\"1\")\n worker_node2 = WorkerNode(\"2\")\n worker_node3 = WorkerNode(\"3\")\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher = UsersDispatcher(\n worker_nodes=[worker_node1, worker_node2, worker_node3], user_classes=[User1, User2, User3]\n )\n users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=0.5)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 
0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 3, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 3, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 3},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_up_users_to_3_workers_with_spawn_rate_of_1(self):\n \"\"\"Final distribution should be {\"User1\": 3, \"User2\": 3, \"User3\": 3}\"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n worker_node1 = WorkerNode(\"1\")\n worker_node2 = WorkerNode(\"2\")\n worker_node3 = WorkerNode(\"3\")\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher = UsersDispatcher(\n worker_nodes=[worker_node1, worker_node2, worker_node3], user_classes=[User1, User2, User3]\n )\n users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n...\n# Path: locust/test/test_env.py\nfrom locust import (\n constant,\n)\nfrom locust.dispatch import UsersDispatcher\nfrom locust.env import Environment, LoadTestShape\nfrom 
locust.user import (\n User,\n task,\n)\nfrom locust.user.task import TaskSet\n\nfrom .fake_module1_for_env_test import MyUserWithSameName as MyUserWithSameName1\nfrom .fake_module2_for_env_test import MyUserWithSameName as MyUserWithSameName2\nfrom .testcases import LocustTestCase\n\n\nclass TestEnvironment(LocustTestCase):\n def test_user_classes_count(self):\n class MyUser1(User):\n wait_time = constant(0)\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(0)\n\n @task\n def my_task(self):\n pass\n\n environment = Environment(user_classes=[MyUser1, MyUser2])\n\n self.assertDictEqual({\"MyUser1\": MyUser1, \"MyUser2\": MyUser2}, environment.user_classes_by_name)\n\n def test_user_classes_with_same_name_is_error(self):\n with self.assertRaises(ValueError) as e:\n environment = Environment(user_classes=[MyUserWithSameName1, MyUserWithSameName2])\n\n self.assertEqual(\n e.exception.args[0],\n \"The following user classes have the same class name: locust.test.fake_module1_for_env_test.MyUserWithSameName, locust.test.fake_module2_for_env_test.MyUserWithSameName\",\n )\n\n def test_assign_equal_weights(self):\n def verify_tasks(u, target_tasks):\n self.assertEqual(len(u.tasks), len(target_tasks))\n tasks = [t.__name__ for t in u.tasks]\n self.assertEqual(len(tasks), len(set(tasks)))\n self.assertEqual(set(tasks), set(target_tasks))\n\n # Base case\n class MyUser1(User):\n wait_time = constant(0)\n\n @task(4)\n def my_task(self):\n pass\n\n @task(1)\n def my_task_2(self):\n pass\n\n environment = Environment(user_classes=[MyUser1])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"my_task\", \"my_task_2\"])\n\n # Testing nested task sets\n class MyUser2(User):\n @task\n class TopLevelTaskSet(TaskSet):\n @task\n class IndexTaskSet(TaskSet):\n @task(10)\n def index(self):\n self.client.get(\"/\")\n\n @task\n def stop(self):\n self.client.get(\"/hi\")\n\n @task(2)\n def stats(self):\n self.client.get(\"/stats/requests\")\n\n environment = Environment(user_classes=[MyUser2])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"index\", \"stop\", \"stats\"])\n\n # Testing task assignment via instance variable\n def outside_task():\n pass\n\n def outside_task_2():\n pass\n\n class SingleTaskSet(TaskSet):\n tasks = [outside_task, outside_task, outside_task_2]\n\n class MyUser3(User):\n tasks = [SingleTaskSet, outside_task]\n\n environment = Environment(user_classes=[MyUser3])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"outside_task\", \"outside_task_2\"])\n\n # Testing task assignment via dict\n class DictTaskSet(TaskSet):\n def dict_task_1():\n pass\n\n def dict_task_2():\n pass\n\n def dict_task_3():\n pass\n\n tasks = {\n dict_task_1: 5,\n dict_task_2: 3,\n dict_task_3: 1,\n }\n\n class MyUser4(User):\n tasks = [DictTaskSet, SingleTaskSet, SingleTaskSet]\n\n # Assign user tasks in dict\n environment = Environment(user_classes=[MyUser4])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"outside_task\", \"outside_task_2\", \"dict_task_1\", \"dict_task_2\", \"dict_task_3\"])\n\n class MyUser5(User):\n tasks = {\n DictTaskSet: 5,\n SingleTaskSet: 3,\n outside_task: 6,\n }\n\n environment = Environment(user_classes=[MyUser5])\n environment.assign_equal_weights()\n u = environment.user_classes[0]\n verify_tasks(u, [\"outside_task\", \"outside_task_2\", \"dict_task_1\", \"dict_task_2\", 
\"dict_task_3\"])\n\n def test_user_classes_with_zero_weight_are_removed(self):\n class MyUser1(User):\n wait_time = constant(0)\n weight = 0\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(0)\n weight = 1\n\n @task\n def my_task(self):\n pass\n\n environment = Environment(user_classes=[MyUser1, MyUser2])\n\n self.assertEqual(len(environment.user_classes), 1)\n self.assertIs(environment.user_classes[0], MyUser2)\n\n def test_all_user_classes_with_zero_weight_raises_exception(self):\n class MyUser1(User):\n wait_time = constant(0)\n weight = 0\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(0)\n weight = 0\n\n @task\n def my_task(self):\n pass\n\n with self.assertRaises(ValueError) as e:\n environment = Environment(user_classes=[MyUser1, MyUser2])\n\n self.assertEqual(\n e.exception.args[0],\n \"There are no users with weight > 0.\",\n )\n\n def test_shape_class_attribute(self):\n class SubLoadTestShape(LoadTestShape):\n \"\"\"Inherited from locust.env.LoadTestShape\"\"\"\n\n with self.assertRaisesRegex(\n ValueError, r\"instance of LoadTestShape or subclass LoadTestShape\", msg=\"exception message is mismatching\"\n ):\n Environment(user_classes=[MyUserWithSameName1], shape_class=SubLoadTestShape)\n\n def test_dispatcher_class_attribute(self):\n environment = Environment(user_classes=[MyUserWithSameName1])\n\n self.assertEqual(environment.dispatcher_class, UsersDispatcher)\n\n class MyUsersDispatcher(UsersDispatcher):\n pass\n\n environment = Environment(user_classes=[MyUserWithSameName1], dispatcher_class=MyUsersDispatcher)\n\n self.assertEqual(environment.dispatcher_class, MyUsersDispatcher)\n\n def test_update_user_class(self):\n class MyUser1(User):\n @task\n def my_task(self):\n pass\n\n @task\n def my_task_2(self):\n pass\n\n class MyUser2(User):\n @task\n def my_task(self):\n pass\n\n environment = Environment(\n user_classes=[MyUser1, MyUser2],\n available_user_classes={\"User1\": MyUser1, \"User2\": MyUser2},\n available_user_tasks={\"User1\": MyUser1.tasks, \"User2\": MyUser2.tasks},\n )\n\n environment.update_user_class({\"user_class_name\": \"User1\", \"host\": \"http://localhost\", \"tasks\": [\"my_task_2\"]})\n\n self.assertEqual(\n environment.available_user_classes[\"User1\"].json(),\n {\"host\": \"http://localhost\", \"tasks\": [\"my_task_2\"], \"fixed_count\": 0, \"weight\": 1},\n )\n\n# Path: locust/test/test_fasthttp.py\nfrom locust import FastHttpUser\nfrom locust.argument_parser import parse_options\nfrom locust.contrib.fasthttp import FastHttpSession\nfrom locust.exception import CatchResponseError, InterruptTaskSet, LocustError, ResponseError\nfrom locust.user import TaskSet, task\nfrom locust.util.load_locustfile import is_user_class\n\nimport socket\nimport time\nfrom tempfile import NamedTemporaryFile\n\nimport gevent\nfrom geventhttpclient.client import HTTPClientPool\n\nfrom .testcases import LocustTestCase, WebserverTestCase\nfrom .util import create_tls_cert\n\n\nclass TestFastHttpSession(WebserverTestCase):\n def get_client(self):\n return FastHttpSession(self.environment, base_url=\"http://127.0.0.1:%i\" % self.port, user=None)\n\n def test_get(self):\n s = self.get_client()\n r = s.get(\"/ultra_fast\")\n self.assertEqual(200, r.status_code)\n\n def test_connection_error(self):\n s = FastHttpSession(self.environment, \"http://localhost:1\", user=None)\n r = s.get(\"/\", headers={\"X-Test-Headers\": \"hello\"})\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, 
r.content)\n self.assertEqual(1, len(self.runner.stats.errors))\n self.assertTrue(isinstance(r.error, ConnectionRefusedError))\n self.assertTrue(isinstance(next(iter(self.runner.stats.errors.values())).error, ConnectionRefusedError))\n self.assertEqual(r.url, \"http://localhost:1/\")\n self.assertEqual(r.request.url, r.url)\n self.assertEqual(r.request.headers.get(\"X-Test-Headers\", \"\"), \"hello\")\n\n def test_404(self):\n s = self.get_client()\n r = s.get(\"/does_not_exist\")\n self.assertEqual(404, r.status_code)\n self.assertEqual(1, self.runner.stats.get(\"/does_not_exist\", \"GET\").num_failures)\n\n def test_204(self):\n s = self.get_client()\n r = s.get(\"/status/204\")\n self.assertEqual(204, r.status_code)\n self.assertEqual(1, self.runner.stats.get(\"/status/204\", \"GET\").num_requests)\n self.assertEqual(0, self.runner.stats.get(\"/status/204\", \"GET\").num_failures)\n self.assertEqual(r.url, \"http://127.0.0.1:%i/status/204\" % self.port)\n self.assertEqual(r.request.url, r.url)\n\n def test_streaming_response(self):\n \"\"\"\n Test a request to an endpoint that returns a streaming response\n \"\"\"\n s = self.get_client()\n r = s.get(\"/streaming/30\")\n\n # verify that the time reported includes the download time of the whole streamed response\n self.assertGreater(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n self.runner.stats.clear_all()\n\n # verify that response time does NOT include whole download time, when using stream=True\n r = s.get(\"/streaming/30\", stream=True)\n self.assertGreaterEqual(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 0)\n self.assertLess(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n\n # download the content of the streaming response (so we don't get an ugly exception in the log)\n _ = r.content\n\n def test_slow_redirect(self):\n s = self.get_client()\n url = \"/redirect?url=/redirect&delay=0.5\"\n r = s.get(url)\n stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, stats.num_requests)\n self.assertGreater(stats.avg_response_time, 500)\n\n def test_post_redirect(self):\n s = self.get_client()\n url = \"/redirect\"\n r = s.post(url)\n self.assertEqual(200, r.status_code)\n post_stats = self.runner.stats.get(url, method=\"POST\")\n get_stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, post_stats.num_requests)\n self.assertEqual(0, get_stats.num_requests)\n\n def test_cookie(self):\n s = self.get_client()\n r = s.post(\"/set_cookie?name=testcookie&value=1337\")\n self.assertEqual(200, r.status_code)\n r = s.get(\"/get_cookie?name=testcookie\")\n self.assertEqual(\"1337\", r.content.decode())\n self.assertEqual(\"1337\", r.text)\n\n def test_head(self):\n s = self.get_client()\n r = s.head(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n\n def test_delete(self):\n s = self.get_client()\n r = s.delete(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"DELETE\", r.content.decode())\n\n def test_patch(self):\n s = self.get_client()\n r = s.patch(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"PATCH\", r.content.decode())\n\n def test_options(self):\n s = self.get_client()\n r = s.options(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n self.assertEqual(\n {\"OPTIONS\", \"DELETE\", \"PUT\", \"GET\", \"POST\", \"HEAD\", 
\"PATCH\"},\n set(r.headers[\"allow\"].split(\", \")),\n )\n\n def test_json_payload(self):\n s = self.get_client()\n r = s.post(\"/request_method\", json={\"foo\": \"bar\"})\n self.assertEqual(200, r.status_code)\n self.assertEqual(r.request.body, '{\"foo\": \"bar\"}')\n self.assertEqual(r.request.headers.get(\"Content-Type\"), \"application/json\")\n\n def test_catch_response_fail_successful_request(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_pass_failed_request(self):\n s = self.get_client()\n with s.get(\"/fail\", catch_response=True) as r:\n r.success()\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_multiple_failure_and_success(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n r.success()\n r.failure(\"nooo\")\n r.success()\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_pass_failed_request_with_other_exception_within_block(self):\n class OtherException(Exception):\n pass\n\n s = self.get_client()\n try:\n with s.get(\"/fail\", catch_response=True) as r:\n r.success()\n raise OtherException(\"wtf\")\n except OtherException as e:\n pass\n else:\n self.fail(\"OtherException should have been raised\")\n\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(0, self.environment.stats.total.num_failures)\n\n def test_catch_response_default_success(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n pass\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(0, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n\n def test_catch_response_default_fail(self):\n s = self.get_client()\n with s.get(\"/fail\", catch_response=True) as r:\n pass\n self.assertEqual(1, self.environment.stats.total.num_requests)\n self.assertEqual(1, self.environment.stats.total.num_failures)\n\n def test_error_message_with_name_replacement(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n self.assertIsNotNone(kw[\"exception\"])\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n before_request = time.time()\n s.request(\"get\", \"/wrong_url/01\", name=\"replaced_url_name\", context={\"foo\": \"bar\"})\n after_request = time.time()\n # self.assertIn(\"for url: replaced_url_name\", str(kwargs[\"exception\"])) # this is actually broken for FastHttpUser right now...\n self.assertAlmostEqual(before_request, kwargs[\"start_time\"], delta=0.01)\n self.assertAlmostEqual(after_request, kwargs[\"start_time\"] + kwargs[\"response_time\"] / 1000, delta=0.01)\n self.assertEqual(s.base_url + \"/wrong_url/01\", kwargs[\"url\"]) # url is unaffected by name\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n \ndef test_custom_ssl_context_fail_with_bad_context(self):\n \"\"\"\n Test FastHttpSession with a custom SSLContext factory that will fail as\n we can not set verify_mode to CERT_NONE when check_hostname is enabled\n \"\"\"\n\n def create_custom_context():\n context = 
gevent.ssl.create_default_context()\n context.check_hostname = True\n context.verify_mode = gevent.ssl.CERT_NONE\n return context\n\n s = FastHttpSession(\n self.environment,\n \"https://127.0.0.1:%i\" % self.port,\n ssl_context_factory=create_custom_context,\n user=None,\n )\n with self.assertRaises(ValueError) as e:\n s.get(\"/\")\n self.assertEqual(e.exception.args, (\"Cannot set verify_mode to CERT_NONE when check_hostname is enabled.\",))\n\n def test_custom_ssl_context_passed_correct_to_client_pool(self):\n \"\"\"\n Test FastHttpSession with a custom SSLContext factory with a options.name\n that will be passed correctly to the ClientPool. It will also test a 2nd\n factory which is not the correct one.\n \"\"\"\n\n def custom_ssl_context():\n context = gevent.ssl.create_default_context()\n context.check_hostname = False\n context.verify_mode = gevent.ssl.CERT_NONE\n context.options.name = \"FAKEOPTION\"\n return context\n\n def custom_context_with_wrong_option():\n context = gevent.ssl.create_default_context()\n context.check_hostname = False\n context.verify_mode = gevent.ssl.CERT_NONE\n context.options.name = \"OPTIONFAKED\"\n return context\n\n s = FastHttpSession(\n self.environment,\n \"https://127.0.0.1:%i\" % self.port,\n ssl_context_factory=custom_ssl_context,\n user=None,\n )\n self.assertEqual(s.client.clientpool.client_args[\"ssl_context_factory\"], custom_ssl_context)\n self.assertNotEqual(s.client.clientpool.client_args[\"ssl_context_factory\"], custom_context_with_wrong_option)\n\n\nclass TestRequestStatsWithWebserver(WebserverTestCase):\n def test_request_stats_content_length(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.get(\"/ultra_fast\")\n self.assertEqual(\n self.runner.stats.get(\"/ultra_fast\", \"GET\").avg_content_length, len(\"This is an ultra fast response\")\n )\n locust.client.get(\"/ultra_fast\")\n self.assertEqual(\n self.runner.stats.get(\"/ultra_fast\", \"GET\").avg_content_length, len(\"This is an ultra fast response\")\n )\n\n def test_request_stats_no_content_length(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n path = \"/no_content_length\"\n r = l.client.get(path)\n self.assertEqual(\n self.runner.stats.get(path, \"GET\").avg_content_length,\n len(\"This response does not have content-length in the header\"),\n )\n\n def test_request_stats_no_content_length_streaming(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n path = \"/no_content_length\"\n r = l.client.get(path, stream=True)\n self.assertEqual(0, self.runner.stats.get(path, \"GET\").avg_content_length)\n\n def test_request_stats_named_endpoint(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.get(\"/ultra_fast\", name=\"my_custom_name\")\n self.assertEqual(1, self.runner.stats.get(\"my_custom_name\", \"GET\").num_requests)\n\n def test_request_stats_query_variables(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n locust.client.get(\"/ultra_fast?query=1\")\n self.assertEqual(1, self.runner.stats.get(\"/ultra_fast?query=1\", \"GET\").num_requests)\n\n def test_request_stats_put(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n 
locust.client.put(\"/put\")\n self.assertEqual(1, self.runner.stats.get(\"/put\", \"PUT\").num_requests)\n\n def test_request_connection_error(self):\n class MyUser(FastHttpUser):\n host = \"http://localhost:1\"\n\n locust = MyUser(self.environment)\n response = locust.client.get(\"/\")\n self.assertEqual(response.status_code, 0)\n self.assertEqual(1, self.runner.stats.get(\"/\", \"GET\").num_failures)\n self.assertEqual(1, self.runner.stats.get(\"/\", \"GET\").num_requests)\n\n\nclass TestFastHttpUserClass(WebserverTestCase):\n def test_is_abstract(self):\n self.assertTrue(FastHttpUser.abstract)\n self.assertFalse(is_user_class(FastHttpUser))\n\n def test_class_context(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n def context(self):\n return {\"user\": self.username}\n\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n user = MyUser(self.environment)\n user.username = \"foo\"\n user.client.request(\"get\", \"/request_method\")\n self.assertDictEqual({\"user\": \"foo\"}, kwargs[\"context\"])\n self.assertEqual(\"GET\", kwargs[\"response\"].text)\n user.client.request(\"get\", \"/request_method\", context={\"user\": \"bar\"})\n self.assertDictEqual({\"user\": \"bar\"}, kwargs[\"context\"])\n\n def test_get_request(self):\n self.response = \"\"\n\n def t1(l):\n self.response = l.client.get(\"/ultra_fast\")\n\n class MyUser(FastHttpUser):\n tasks = [t1]\n host = \"http://127.0.0.1:%i\" % self.port\n\n my_locust = MyUser(self.environment)\n t1(my_locust)\n self.assertEqual(self.response.text, \"This is an ultra fast response\")\n\n def test_client_request_headers(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n r = locust.client.get(\"/request_header_test\", headers={\"X-Header-Test\": \"hello\"})\n self.assertEqual(\"hello\", r.text)\n self.assertEqual(\"hello\", r.headers.get(\"X-Header-Test\"))\n self.assertEqual(\"hello\", r.request.headers.get(\"X-Header-Test\"))\n\n def test_client_get(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"GET\", locust.client.get(\"/request_method\").text)\n\n def test_client_get_absolute_url(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"GET\", locust.client.get(\"http://127.0.0.1:%i/request_method\" % self.port).text)\n\n def test_client_post(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"POST\", locust.client.post(\"/request_method\", {\"arg\": \"hello world\"}).text)\n self.assertEqual(\"hello world\", locust.client.post(\"/post\", {\"arg\": \"hello world\"}).text)\n\n def test_client_put(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"PUT\", locust.client.put(\"/request_method\", {\"arg\": \"hello world\"}).text)\n self.assertEqual(\"hello world\", locust.client.put(\"/put\", {\"arg\": \"hello world\"}).text)\n\n def test_client_delete(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(\"DELETE\", locust.client.delete(\"/request_method\").text)\n self.assertEqual(200, 
locust.client.delete(\"/request_method\").status_code)\n\n def test_client_head(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n self.assertEqual(200, locust.client.head(\"/request_method\").status_code)\n\n def test_complex_content_type(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n\n self.assertEqual(\"stuff\", locust.client.get(\"/content_type_missing_charset\").text)\n self.assertEqual(\"stuff\", locust.client.get(\"/content_type_regular\").text)\n self.assertEqual(\"stuff\", locust.client.get(\"/content_type_with_extra_stuff\").text)\n\n def test_log_request_name_argument(self):\n self.response = \"\"\n\n class MyUser(FastHttpUser):\n tasks = []\n host = \"http://127.0.0.1:%i\" % self.port\n\n @task()\n def t1(l):\n self.response = l.client.get(\"/ultra_fast\", name=\"new name!\")\n\n my_locust = MyUser(self.environment)\n my_locust.t1()\n\n self.assertEqual(1, self.runner.stats.get(\"new name!\", \"GET\").num_requests)\n self.assertEqual(0, self.runner.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n\n def test_redirect_url_original_path_as_name(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n l.client.get(\"/redirect\")\n\n self.assertEqual(1, len(self.runner.stats.entries))\n self.assertEqual(1, self.runner.stats.get(\"/redirect\", \"GET\").num_requests)\n self.assertEqual(0, self.runner.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n\n def test_network_timeout_setting(self):\n class MyUser(FastHttpUser):\n network_timeout = 0.5\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n\n timeout = gevent.Timeout(\n seconds=0.6,\n exception=AssertionError(\n \"Request took longer than 0.6 even though FastHttpUser.network_timeout was set to 0.5\"\n ),\n )\n timeout.start()\n r = l.client.get(\"/redirect?url=/redirect&delay=5.0\")\n timeout.cancel()\n\n self.assertTrue(isinstance(r.error.original, socket.timeout))\n self.assertEqual(1, self.runner.stats.get(\"/redirect?url=/redirect&delay=5.0\", \"GET\").num_failures)\n\n def test_max_redirect_setting(self):\n class MyUser(FastHttpUser):\n max_redirects = 1 # max_redirects and max_retries are funny names, because they are actually max attempts\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyUser(self.environment)\n l.client.get(\"/redirect\")\n self.assertEqual(1, self.runner.stats.get(\"/redirect\", \"GET\").num_failures)\n\n def test_allow_redirects_override(self):\n class MyLocust(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n l = MyLocust(self.environment)\n resp = l.client.get(\"/redirect\", allow_redirects=False)\n self.assertTrue(resp.headers[\"location\"].endswith(\"/ultra_fast\"))\n resp = l.client.get(\"/redirect\") # ensure redirect still works\n self.assertFalse(\"location\" in resp.headers)\n\n def test_slow_redirect(self):\n s = FastHttpSession(self.environment, \"http://127.0.0.1:%i\" % self.port, user=None)\n url = \"/redirect?url=/redirect&delay=0.5\"\n r = s.get(url)\n stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, stats.num_requests)\n self.assertGreater(stats.avg_response_time, 500)\n\n def test_client_basic_auth(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n class MyAuthorizedUser(FastHttpUser):\n host = \"http://locust:menace@127.0.0.1:%i\" % self.port\n\n class 
MyUnauthorizedUser(FastHttpUser):\n host = \"http://locust:wrong@127.0.0.1:%i\" % self.port\n\n locust = MyUser(self.environment)\n unauthorized = MyUnauthorizedUser(self.environment)\n authorized = MyAuthorizedUser(self.environment)\n response = authorized.client.get(\"/basic_auth\")\n self.assertEqual(200, response.status_code)\n self.assertEqual(\"Authorized\", response.text)\n self.assertEqual(401, locust.client.get(\"/basic_auth\").status_code)\n self.assertEqual(401, unauthorized.client.get(\"/basic_auth\").status_code)\n\n def test_shared_client_pool(self):\n shared_client_pool = HTTPClientPool(concurrency=1)\n\n class MyUserA(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n client_pool = shared_client_pool\n\n class MyUserB(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n client_pool = shared_client_pool\n\n user_a = MyUserA(self.environment)\n user_b = MyUserB(self.environment)\n\n user_a.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_a.client.get(\"/ultra_fast\")\n\n self.assertEqual(1, self.connections_count)\n self.assertEqual(4, self.requests_count)\n\n def test_client_pool_per_user_instance(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n user_a = MyUser(self.environment)\n user_b = MyUser(self.environment)\n\n user_a.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_b.client.get(\"/ultra_fast\")\n user_a.client.get(\"/ultra_fast\")\n\n self.assertEqual(2, self.connections_count)\n self.assertEqual(4, self.requests_count)\n\n def test_client_pool_concurrency(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n @task\n def t(self):\n def concurrent_request(url):\n response = self.client.get(url)\n assert response.status_code == 200\n\n pool = gevent.pool.Pool()\n urls = [\"/slow?delay=0.2\"] * 20 # these urls are all the same, but they could be different\n for url in urls:\n pool.spawn(concurrent_request, url)\n pool.join()\n\n user = MyUser(self.environment)\n before_requests = time.time()\n user.t()\n after_requests = time.time()\n expected_delta = 0.4 # 20 requests with concurrency 10 and response time 0.2\n self.assertAlmostEqual(before_requests + expected_delta, after_requests, delta=0.1)\n\n\nclass TestFastHttpCatchResponse(WebserverTestCase):\n def setUp(self):\n super().setUp()\n\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n self.user = MyUser(self.environment)\n\n self.num_failures = 0\n self.num_success = 0\n\n def on_request(exception, **kwargs):\n if exception:\n self.num_failures += 1\n self.last_failure_exception = exception\n else:\n self.num_success += 1\n\n self.environment.events.request.add_listener(on_request)\n\n def test_catch_response(self):\n self.assertEqual(500, self.user.client.get(\"/fail\").status_code)\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n with self.user.client.get(\"/ultra_fast\", catch_response=True) as response:\n pass\n self.assertEqual(1, self.num_failures)\n self.assertEqual(1, self.num_success)\n self.assertIn(\"ultra fast\", str(response.content))\n\n with self.user.client.get(\"/ultra_fast\", catch_response=True) as response:\n raise ResponseError(\"Not working\")\n\n self.assertEqual(2, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_catch_response_http_fail(self):\n with self.user.client.get(\"/fail\", catch_response=True) as response:\n pass\n 
self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n def test_catch_response_http_manual_fail(self):\n with self.user.client.get(\"/ultra_fast\", catch_response=True) as response:\n response.failure(\"Haha!\")\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n self.assertTrue(\n isinstance(self.last_failure_exception, CatchResponseError),\n \"Failure event handler should have been passed a CatchResponseError instance\",\n )\n\n def test_catch_response_http_manual_success(self):\n with self.user.client.get(\"/fail\", catch_response=True) as response:\n response.success()\n self.assertEqual(0, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_catch_response_allow_404(self):\n with self.user.client.get(\"/does/not/exist\", catch_response=True) as response:\n self.assertEqual(404, response.status_code)\n if response.status_code == 404:\n response.success()\n self.assertEqual(0, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_interrupt_taskset_with_catch_response(self):\n class MyTaskSet(TaskSet):\n @task\n def interrupted_task(self):\n with self.client.get(\"/ultra_fast\", catch_response=True) as r:\n raise InterruptTaskSet()\n\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n tasks = [MyTaskSet]\n\n l = MyUser(self.environment)\n ts = MyTaskSet(l)\n self.assertRaises(InterruptTaskSet, lambda: ts.interrupted_task())\n self.assertEqual(0, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n def test_catch_response_connection_error_success(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:1\"\n\n l = MyUser(self.environment)\n with l.client.get(\"/\", catch_response=True) as r:\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n r.success()\n self.assertEqual(1, self.num_success)\n self.assertEqual(0, self.num_failures)\n\n def test_catch_response_connection_error_fail(self):\n class MyUser(FastHttpUser):\n host = \"http://127.0.0.1:1\"\n\n l = MyUser(self.environment)\n with l.client.get(\"/\", catch_response=True) as r:\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n r.failure(\"Manual fail\")\n self.assertEqual(0, self.num_success)\n self.assertEqual(1, self.num_failures)\n\n def test_catch_response_missing_with_block(self):\n # incorrect usage, missing with-block\n r = self.user.client.get(\"/fail\", catch_response=True)\n self.assertRaises(LocustError, r.success)\n self.assertRaises(LocustError, r.failure, \"\")\n\n def test_missing_catch_response_true(self):\n # incorrect usage, missing catch_response=True\n with self.user.client.get(\"/fail\") as resp:\n self.assertRaises(LocustError, resp.success)\n\n def test_rest_success(self):\n self.last_failure_exception = None\n with self.user.rest(\"POST\", \"/rest\", json={\"foo\": \"bar\"}) as response:\n assert response.js[\"foo\"] == \"bar\"\n\n self.assertEqual(0, self.num_failures)\n self.assertEqual(1, self.num_success)\n\n def test_rest_fail(self):\n with self.user.rest(\"POST\", \"/rest\", json={\"foo\": \"bar\"}) as response:\n assert response.js[\"foo\"] == \"NOPE\"\n\n self.assertTrue(\n isinstance(self.last_failure_exception, CatchResponseError),\n \"Failure event handler should have been passed a CatchResponseError instance\",\n )\n self.assertEqual(1, self.num_failures)\n self.assertEqual(0, self.num_success)\n\n\nclass TestFastHttpSsl(LocustTestCase):\n def setUp(self):\n super().setUp()\n tls_cert, tls_key = 
create_tls_cert(\"127.0.0.1\")\n self.tls_cert_file = NamedTemporaryFile()\n self.tls_key_file = NamedTemporaryFile()\n with open(self.tls_cert_file.name, \"w\") as f:\n f.write(tls_cert.decode())\n with open(self.tls_key_file.name, \"w\") as f:\n f.write(tls_key.decode())\n\n self.web_ui = self.environment.create_web_ui(\n \"127.0.0.1\",\n 0,\n tls_cert=self.tls_cert_file.name,\n tls_key=self.tls_key_file.name,\n )\n gevent.sleep(0.01)\n self.web_port = self.web_ui.server.server_port\n\n def tearDown(self):\n super().tearDown()\n self.web_ui.stop()\n\n def test_ssl_request_insecure(self):\n s = FastHttpSession(self.environment, \"https://127.0.0.1:%i\" % self.web_port, insecure=True, user=None)\n r = s.get(\"/\")\n self.assertEqual(200, r.status_code)\n self.assertIn(\"Locust for None\", r.content.decode(\"utf-8\"))\n self.assertIn(\"
Script: None
\", r.text)\n\n# Path: locust/test/test_http.py\nfrom locust.clients import HttpSession\nfrom locust.exception import LocustError, ResponseError\nfrom locust.user.users import HttpUser\n\nimport time\n\nfrom requests.exceptions import InvalidSchema, InvalidURL, MissingSchema, RequestException\n\nfrom .testcases import WebserverTestCase\n\n\nclass TestHttpSession(WebserverTestCase):\n def get_client(self, base_url=None):\n if base_url is None:\n base_url = \"http://127.0.0.1:%i\" % self.port\n return HttpSession(\n base_url=base_url,\n request_event=self.environment.events.request,\n user=None,\n )\n\n def test_get(self):\n s = self.get_client()\n r = s.get(\"/ultra_fast\")\n self.assertEqual(200, r.status_code)\n\n def test_connection_error(self):\n s = self.get_client(base_url=\"http://localhost:1\")\n r = s.get(\"/\", timeout=0.1)\n self.assertEqual(r.status_code, 0)\n self.assertEqual(None, r.content)\n self.assertRaises(RequestException, r.raise_for_status)\n\n def test_wrong_url(self):\n for url, exception in (\n (\"http://\\x94\", InvalidURL),\n (\"telnet://127.0.0.1\", InvalidSchema),\n (\"127.0.0.1\", MissingSchema),\n ):\n s = self.get_client(base_url=url)\n try:\n self.assertRaises(exception, s.get, \"/\")\n except KeyError:\n self.fail(f\"Invalid URL {url} was not propagated\")\n\n def test_streaming_response(self):\n \"\"\"\n Test a request to an endpoint that returns a streaming response\n \"\"\"\n s = self.get_client()\n r = s.get(\"/streaming/30\")\n\n # verify that the time reported includes the download time of the whole streamed response\n self.assertGreater(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n self.runner.stats.clear_all()\n\n # verify that response time does NOT include whole download time, when using stream=True\n r = s.get(\"/streaming/30\", stream=True)\n self.assertGreater(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 0)\n self.assertLess(self.runner.stats.get(\"/streaming/30\", method=\"GET\").avg_response_time, 250)\n\n # download the content of the streaming response (so we don't get an ugly exception in the log)\n _ = r.content\n\n def test_slow_redirect(self):\n s = self.get_client()\n url = \"/redirect?url=/redirect&delay=0.5\"\n r = s.get(url)\n stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, stats.num_requests)\n self.assertGreater(stats.avg_response_time, 500)\n\n def test_post_redirect(self):\n s = self.get_client()\n url = \"/redirect\"\n r = s.post(url)\n self.assertEqual(200, r.status_code)\n post_stats = self.runner.stats.get(url, method=\"POST\")\n get_stats = self.runner.stats.get(url, method=\"GET\")\n self.assertEqual(1, post_stats.num_requests)\n self.assertEqual(0, get_stats.num_requests)\n\n def test_cookie(self):\n s = self.get_client()\n r = s.post(\"/set_cookie?name=testcookie&value=1337\")\n self.assertEqual(200, r.status_code)\n r = s.get(\"/get_cookie?name=testcookie\")\n self.assertEqual(\"1337\", r.content.decode())\n\n def test_head(self):\n s = self.get_client()\n r = s.head(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n\n def test_delete(self):\n s = self.get_client()\n r = s.delete(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"DELETE\", r.content.decode())\n\n def test_options(self):\n s = self.get_client()\n r = s.options(\"/request_method\")\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"\", r.content.decode())\n 
self.assertEqual(\n {\"OPTIONS\", \"DELETE\", \"PUT\", \"GET\", \"POST\", \"HEAD\", \"PATCH\"},\n set(r.headers[\"allow\"].split(\", \")),\n )\n\n def test_error_message(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n s.request(\"get\", \"/wrong_url\", context={\"foo\": \"bar\"})\n self.assertIn(\"/wrong_url\", str(kwargs[\"exception\"]))\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_context_in_success(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(exception, **kw):\n self.assertIsNone(exception)\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n s.request(\"get\", \"/request_method\", context={\"foo\": \"bar\"})\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_response_parameter(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n s.request(\"get\", \"/request_method\")\n self.assertEqual(\"GET\", kwargs[\"response\"].text)\n s.request(\"get\", \"/wrong_url\")\n self.assertEqual(\"Not Found\", kwargs[\"response\"].text)\n\n def test_error_message_with_name_replacement(self):\n s = self.get_client()\n kwargs = {}\n\n def on_request(**kw):\n self.assertIsNotNone(kw[\"exception\"])\n kwargs.update(kw)\n\n self.environment.events.request.add_listener(on_request)\n before_request = time.time()\n s.request(\"get\", \"/wrong_url/01\", name=\"replaced_url_name\", context={\"foo\": \"bar\"})\n after_request = time.time()\n self.assertIn(\"for url: replaced_url_name\", str(kwargs[\"exception\"]))\n self.assertAlmostEqual(before_request, kwargs[\"start_time\"], delta=0.01)\n self.assertAlmostEqual(after_request, kwargs[\"start_time\"] + kwargs[\"response_time\"] / 1000, delta=0.01)\n self.assertEqual(s.base_url + \"/wrong_url/01\", kwargs[\"url\"]) # url is unaffected by name\n self.assertDictEqual({\"foo\": \"bar\"}, kwargs[\"context\"])\n\n def test_get_with_params(self):\n s = self.get_client()\n r = s.get(\"/get_arg\", params={\"arg\": \"test_123\"})\n self.assertEqual(200, r.status_code)\n self.assertEqual(\"test_123\", r.text)\n\n def test_catch_response_fail_successful_request(self):\n s = self.get_client()\n with s.get(\"/ultra_fast\", catch_response=True) as r:\n r.failure(\"nope\")\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_requests)\n self.assertEqual(1, self.environment.stats.get(\"/ultra_fast\", \"GET\").num_failures)\n...\n", "needle_token_start": 7129, "needle_token_end": 7363, "code_context_ntokens": 16112, "output": ["```python\ndef test_custom_ssl_context_fail_with_bad_context(self):\n \"\"\"\n Test FastHttpSession with a custom SSLContext factory that will fail as\n we can not set verify_mode to CERT_NONE when check_hostname is enabled\n \"\"\"\n\n def create_custom_context():\n context = gevent.ssl.create_default_context()\n context.check_hostname = True\n context.verify_mode = gevent.ssl.CERT_NONE\n return context\n\n s = FastHttpSession(\n self.environment,\n \"https://127.0.0.1:%i\" % self.port,\n ssl_context_factory=create_custom_context,\n user=None,\n )\n with self.assertRaises(ValueError) as e:\n s.get(\"/\")\n self.assertEqual(e.exception.args, (\"Cannot set verify_mode to CERT_NONE when check_hostname is enabled.\",))\n```"]} +{"repo": "locustio/locust", "name": "test_worker_heartbeat_messages_sent_to_master", "language": 
"python", "path": "locust/test/test_runners.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To verify that the correct heartbeat data is sent from a worker to the master, including the worker's state and resource usage metrics.\n2. **Input**: The test does not require any external inputs as it uses a mocked RPC client and a predefined user class.\n3. **Output**: There is no direct output from the test other than the validation of conditions within the test itself, ensuring that the heartbeat message contains the expected keys and values.\n4. **Procedure**: \n - A worker instance is created with a mocked RPC client and a user class designed for the test.\n - The test waits for a heartbeat message to be sent to the master, ensuring it happens within a specified timeout.\n - It then checks the last heartbeat message to confirm it includes the correct data fields: state, CPU usage, and memory usage.\n - The test concludes by shutting down the worker.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " even number of the connected workers\n \"\"\"\n\n class TestUser(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[TestUser])\n\n for i in range(5):\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client%i\" % i))\n\n master.start(7, 7)\n self.assertEqual(10, len(server.outbox))\n\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n\n self.assertEqual(7, num_users, \"Total number of locusts that would have been spawned is not 7\")\n\n def test_spawn_fewer_locusts_than_workers(self):\n class TestUser(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[TestUser])\n\n for i in range(5):\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client%i\" % i))\n\n master.start(2, 2)\n self.assertEqual(10, len(server.outbox))\n\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n\n self.assertEqual(2, num_users, \"Total number of locusts that would have been spawned is not 2\")\n\n def test_spawn_correct_worker_indexes(self):\n \"\"\"\n Tests that workers would receive a monotonic sequence of ordinal IDs.\n \"\"\"\n\n class TestUser(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[TestUser])\n\n USERS_COUNT = 5\n\n for i in range(USERS_COUNT):\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client%i\" % i))\n\n master.start(USERS_COUNT, USERS_COUNT)\n self.assertEqual(USERS_COUNT * 2, len(server.outbox))\n\n indexes = []\n for _, msg in server.outbox:\n if msg.type == \"ack\":\n indexes.append(msg.data[\"index\"])\n self.assertEqual(USERS_COUNT, len(indexes), \"Total number of locusts/workers is not 5\")\n\n indexes.sort()\n for i in range(USERS_COUNT):\n self.assertEqual(indexes[i], i, \"Worker index mismatch\")\n\n def test_custom_shape_scale_interval(self):\n class MyUser(User):\n @task\n def my_task(self):\n pass\n\n class 
TestShape(LoadTestShape):\n def __init__(self):\n super().__init__()\n self._users_num = [1, 1, 1, 2, 2, 3, 3, 3, 4]\n self._index = 0\n\n def tick(self):\n if self._index >= len(self._users_num):\n return None\n users_num = self._users_num[self._index]\n self._index += 1\n return users_num, users_num\n\n self.environment.shape_class = TestShape()\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[MyUser])\n for i in range(5):\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client%i\" % i))\n\n # Start the shape_worker\n self.environment.shape_class.reset_time()\n master.start_shape()\n\n # Wait for shape_worker to update user_count\n sleep(0.5)\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n self.assertEqual(\n 1, num_users, \"Total number of users in first stage of shape test is not 1: %i\" % num_users\n )\n\n # Wait for shape_worker to update user_count again\n sleep(1.5)\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n self.assertEqual(\n 1, num_users, \"Total number of users in second stage of shape test is not 1: %i\" % num_users\n )\n\n # Wait for shape_worker to update user_count few times but not reach the end yet\n sleep(2.5)\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n self.assertEqual(\n 3, num_users, \"Total number of users in second stage of shape test is not 3: %i\" % num_users\n )\n\n # Wait to ensure shape_worker has stopped the test\n sleep(3)\n self.assertEqual(\"stopped\", master.state, \"The test has not been stopped by the shape class\")\n\n def test_custom_shape_scale_up(self):\n class MyUser(User):\n @task\n def my_task(self):\n pass\n\n class TestShape(LoadTestShape):\n def tick(self):\n run_time = self.get_run_time()\n if run_time < 2:\n return 1, 1\n elif run_time < 4:\n return 2, 2\n else:\n return None\n\n self.environment.shape_class = TestShape()\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[MyUser])\n for i in range(5):\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client%i\" % i))\n\n # Start the shape_worker\n self.environment.shape_class.reset_time()\n master.start_shape()\n sleep(0.5)\n\n # Wait for shape_worker to update user_count\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n self.assertEqual(\n 1, num_users, \"Total number of users in first stage of shape test is not 1: %i\" % num_users\n )\n\n # Wait for shape_worker to update user_count again\n sleep(2)\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n self.assertEqual(\n 3, num_users, \"Total number of users in second stage of shape test is not 3: %i\" % num_users\n )\n\n # Wait to ensure shape_worker has stopped the test\n sleep(3)\n self.assertEqual(\"stopped\", master.state, \"The test has not been stopped by the shape class\")\n\n def test_custom_shape_scale_down(self):\n class MyUser(User):\n @task\n def my_task(self):\n pass\n\n class TestShape(LoadTestShape):\n def tick(self):\n run_time = self.get_run_time()\n if run_time < 2:\n return 5, 5\n elif run_time < 4:\n return 1, 5\n else:\n return None\n\n self.environment.shape_class = 
TestShape()\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[MyUser])\n for i in range(5):\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client%i\" % i))\n\n # Start the shape_worker\n self.environment.shape_class.reset_time()\n master.start_shape()\n sleep(0.5)\n\n # Wait for shape_worker to update user_count\n num_users = sum(\n sum(msg.data[\"user_classes_count\"].values()) for _, msg in server.outbox if msg.type != \"ack\"\n )\n self.assertEqual(\n 5, num_users, \"Total number of users in first stage of shape test is not 5: %i\" % num_users\n )\n\n # Wait for shape_worker to update user_count again\n sleep(2)\n msgs = defaultdict(dict)\n for _, msg in server.outbox:\n if msg.type == \"ack\":\n continue\n msgs[msg.node_id][msg.data[\"timestamp\"]] = sum(msg.data[\"user_classes_count\"].values())\n # Count users for the last received messages\n num_users = sum(v[max(v.keys())] for v in msgs.values())\n self.assertEqual(\n 1, num_users, \"Total number of users in second stage of shape test is not 1: %i\" % num_users\n )\n\n # Wait to ensure shape_worker has stopped the test\n sleep(3)\n self.assertEqual(\"stopped\", master.state, \"The test has not been stopped by the shape class\")\n\n def test_exception_in_task(self):\n class MyUser(User):\n @task\n def will_error(self):\n raise HeyAnException(\":(\")\n\n self.environment.user_classes = [MyUser]\n runner = self.environment.create_local_runner()\n\n l = MyUser(self.environment)\n\n self.assertRaises(HeyAnException, l.run)\n self.assertRaises(HeyAnException, l.run)\n self.assertEqual(1, len(runner.exceptions))\n\n hash_key, exception = runner.exceptions.popitem()\n self.assertIn(\"traceback\", exception)\n self.assertIn(\"HeyAnException\", exception[\"traceback\"])\n self.assertEqual(2, exception[\"count\"])\n\n def test_exception_is_caught(self):\n \"\"\"Test that exceptions are stored, and execution continues\"\"\"\n\n class MyTaskSet(TaskSet):\n def __init__(self, *a, **kw):\n super().__init__(*a, **kw)\n self._task_queue = [self.will_error, self.will_stop]\n\n @task(1)\n def will_error(self):\n raise HeyAnException(\":(\")\n\n @task(1)\n def will_stop(self):\n raise StopUser()\n\n class MyUser(User):\n wait_time = constant(0.01)\n tasks = [MyTaskSet]\n\n # set config to catch exceptions in locust users\n self.environment.catch_exceptions = True\n self.environment.user_classes = [MyUser]\n runner = LocalRunner(self.environment)\n l = MyUser(self.environment)\n\n # make sure HeyAnException isn't raised\n l.run()\n l.run()\n # make sure we got two entries in the error log\n self.assertEqual(2, len(self.mocked_log.error))\n\n # make sure exception was stored\n self.assertEqual(1, len(runner.exceptions))\n hash_key, exception = runner.exceptions.popitem()\n self.assertTrue(\"traceback\" in exception)\n self.assertTrue(\"HeyAnException\" in exception[\"traceback\"])\n self.assertEqual(2, exception[\"count\"])\n\n def test_master_reset_connection(self):\n \"\"\"Test that connection will be reset when network issues found\"\"\"\n with mock.patch(\"locust.runners.FALLBACK_INTERVAL\", new=0.1):\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc(raise_on_close=False)) as server:\n master = self.get_runner()\n self.assertEqual(0, len(master.clients))\n server.mocked_send(Message(\"client_ready\", NETWORK_BROKEN, \"fake_client\"))\n self.assertTrue(master.connection_broken)\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client\"))\n 
sleep(1)\n self.assertFalse(master.connection_broken)\n self.assertEqual(1, len(master.clients))\n master.quit()\n\n def test_reset_connection_after_RPCError(self):\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc(raise_on_close=False)) as server:\n master = self.get_runner()\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client\"))\n sleep(0.2)\n self.assertFalse(master.connection_broken)\n self.assertEqual(1, len(master.clients))\n\n # Trigger RPCError\n server.mocked_send(Message(\"lets_trigger_RPCError\", NETWORK_BROKEN, \"fake_client\"))\n self.assertTrue(master.connection_broken)\n sleep(1)\n self.assertFalse(master.connection_broken)\n master.quit()\n\n def test_attributes_populated_when_calling_start(self):\n class MyUser1(User):\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[MyUser1, MyUser2])\n\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client1\"))\n\n master.start(7, 7)\n self.assertEqual({\"MyUser1\": 4, \"MyUser2\": 3}, master.target_user_classes_count)\n self.assertEqual(7, master.target_user_count)\n self.assertEqual(7, master.spawn_rate)\n\n master.start(10, 10)\n self.assertEqual({\"MyUser1\": 5, \"MyUser2\": 5}, master.target_user_classes_count)\n self.assertEqual(10, master.target_user_count)\n self.assertEqual(10, master.spawn_rate)\n\n master.start(1, 3)\n self.assertEqual({\"MyUser1\": 1, \"MyUser2\": 0}, master.target_user_classes_count)\n self.assertEqual(1, master.target_user_count)\n self.assertEqual(3, master.spawn_rate)\n\n def test_custom_message_send(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner()\n for i in range(5):\n master.clients[i] = WorkerNode(str(i))\n master.send_message(\"test_custom_msg\", {\"test_data\": 123})\n\n self.assertEqual(5, len(server.outbox))\n for _, msg in server.outbox:\n self.assertEqual(\"test_custom_msg\", msg.type)\n self.assertEqual(123, msg.data[\"test_data\"])\n\n def test_custom_message_receive(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n test_custom_msg = [False]\n test_custom_msg_data = [{}]\n\n def on_custom_msg(msg, **kw):\n test_custom_msg[0] = True\n test_custom_msg_data[0] = msg.data\n\n master = self.get_runner()\n master.register_message(\"test_custom_msg\", on_custom_msg)\n\n server.mocked_send(Message(\"test_custom_msg\", {\"test_data\": 123}, \"dummy_id\"))\n\n self.assertTrue(test_custom_msg[0])\n self.assertEqual(123, test_custom_msg_data[0][\"test_data\"])\n\n def test_undefined_custom_message_receive(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n test_custom_msg = [False]\n\n def on_custom_msg(msg, **kw):\n test_custom_msg[0] = True\n\n master = self.get_runner()\n master.register_message(\"test_custom_msg\", on_custom_msg)\n\n server.mocked_send(Message(\"unregistered_custom_msg\", {}, \"dummy_id\"))\n\n self.assertFalse(test_custom_msg[0])\n self.assertEqual(1, len(self.mocked_log.warning))\n msg = self.mocked_log.warning[0]\n self.assertIn(\"Unknown message type received from worker\", msg)\n\n def 
test_wait_for_workers_report_after_ramp_up(self):\n def assert_cache_hits():\n self.assertEqual(master._wait_for_workers_report_after_ramp_up.cache_info().hits, 0)\n master._wait_for_workers_report_after_ramp_up()\n self.assertEqual(master._wait_for_workers_report_after_ramp_up.cache_info().hits, 1)\n\n master = self.get_runner()\n\n master._wait_for_workers_report_after_ramp_up.cache_clear()\n self.assertEqual(master._wait_for_workers_report_after_ramp_up(), 1.0)\n assert_cache_hits()\n\n master._wait_for_workers_report_after_ramp_up.cache_clear()\n with patch_env(\"LOCUST_WAIT_FOR_WORKERS_REPORT_AFTER_RAMP_UP\", \"5.7\"):\n self.assertEqual(master._wait_for_workers_report_after_ramp_up(), 5.7)\n assert_cache_hits()\n\n master._wait_for_workers_report_after_ramp_up.cache_clear()\n with mock.patch(\"locust.runners.WORKER_REPORT_INTERVAL\", new=1.5), patch_env(\n \"LOCUST_WAIT_FOR_WORKERS_REPORT_AFTER_RAMP_UP\", \"5.7 * WORKER_REPORT_INTERVAL\"\n ):\n self.assertEqual(master._wait_for_workers_report_after_ramp_up(), 5.7 * 1.5)\n assert_cache_hits()\n\n master._wait_for_workers_report_after_ramp_up.cache_clear()\n\n def test_master_discard_first_client_ready(self):\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n server.mocked_send(Message(\"client_ready\", __version__, \"dummy_client\"))\n # discard first client_ready msg\n server.queue.get()\n master = self.get_runner()\n server.mocked_send(Message(\"client_ready\", __version__, \"dummy_client\"))\n\n self.assertEqual(1, len(master.clients))\n self.assertEqual(\"ack\", server.outbox[0][1].type)\n self.assertEqual(1, len(server.outbox))\n self.assertEqual(0, server.outbox[0][1].data[\"index\"])\n\n def test_worker_sends_bad_message_to_master(self):\n \"\"\"\n Validate master sends reconnect message to worker when it receives a bad message.\n \"\"\"\n\n class TestUser(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[TestUser])\n server.mocked_send(Message(\"client_ready\", __version__, \"zeh_fake_client1\"))\n self.assertEqual(1, len(master.clients))\n self.assertTrue(\n \"zeh_fake_client1\" in master.clients, \"Could not find fake client in master instance's clients dict\"\n )\n\n master.start(10, 10)\n sleep(0.1)\n server.mocked_send(Message(\"stats\", BAD_MESSAGE, \"zeh_fake_client1\"))\n self.assertEqual(4, len(server.outbox))\n\n # Expected message order in outbox: ack, spawn, reconnect, ack\n self.assertEqual(\n \"reconnect\", server.outbox[2][1].type, \"Master didn't send worker reconnect message when expected.\"\n )\n\n def test_worker_sends_unrecognized_message_to_master(self):\n \"\"\"\n Validate master ignores message from worker when it cannot parse adddress info.\n \"\"\"\n\n class TestUser(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[TestUser])\n server.mocked_send(Message(\"client_ready\", __version__, \"zeh_fake_client1\"))\n self.assertEqual(1, len(master.clients))\n self.assertTrue(\n \"zeh_fake_client1\" in master.clients, \"Could not find fake client in master instance's clients dict\"\n )\n\n master.start(10, 10)\n sleep(0.1)\n server.mocked_send(Message(\"stats\", UNRECOGNIZED_MESSAGE, \"zeh_fake_client1\"))\n self.assertEqual(2, len(server.outbox))\n\n def test_unknown_host_sends_message_to_master(self):\n \"\"\"\n Validate master ignores message that is sent from unknown host\n 
\"\"\"\n\n class TestUser(User):\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n master = self.get_runner(user_classes=[TestUser])\n server.mocked_send(Message(\"client_ready\", __version__, \"zeh_fake_client1\"))\n self.assertEqual(1, len(master.clients))\n self.assertTrue(\n \"zeh_fake_client1\" in master.clients, \"Could not find fake client in master instance's clients dict\"\n )\n\n master.start(10, 10)\n sleep(0.1)\n server.mocked_send(Message(\"stats\", UNRECOGNIZED_HOST_MESSAGE, \"unknown_host\"))\n self.assertEqual(2, len(server.outbox))\n\n\nclass TestWorkerRunner(LocustTestCase):\n def setUp(self):\n super().setUp()\n # self._report_to_master_event_handlers = [h for h in events.report_to_master._handlers]\n\n def tearDown(self):\n # events.report_to_master._handlers = self._report_to_master_event_handlers\n super().tearDown()\n\n def get_runner(self, client, environment=None, user_classes=None, auto_connect=True):\n if auto_connect:\n client.mocked_send(Message(\"ack\", {\"index\": 0}, \"dummy_client_id\"))\n if environment is None:\n environment = self.environment\n user_classes = user_classes or []\n environment.user_classes = user_classes\n return WorkerRunner(environment, master_host=\"localhost\", master_port=5557)\n\n def test_worker_stop_timeout(self):\n class MyTestUser(User):\n _test_state = 0\n\n @task\n def the_task(self):\n MyTestUser._test_state = 1\n gevent.sleep(0.2)\n MyTestUser._test_state = 2\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyTestUser], client=client)\n self.assertEqual(1, len(client.outbox))\n self.assertEqual(\"client_ready\", client.outbox[0].type)\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"host\": \"\",\n \"stop_timeout\": 1,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n # wait for worker to spawn locusts\n self.assertIn(\"spawning\", [m.type for m in client.outbox])\n worker.spawning_greenlet.join()\n self.assertEqual(1, len(worker.user_greenlets))\n # check that locust has started running\n gevent.sleep(0.01)\n self.assertEqual(1, MyTestUser._test_state)\n # send stop message\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n worker.user_greenlets.join()\n # check that locust user got to finish\n self.assertEqual(2, MyTestUser._test_state)\n\n def test_worker_without_stop_timeout(self):\n class MyTestUser(User):\n _test_state = 0\n\n @task\n def the_task(self):\n MyTestUser._test_state = 1\n gevent.sleep(0.2)\n MyTestUser._test_state = 2\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyTestUser], client=client)\n self.assertEqual(1, len(client.outbox))\n self.assertEqual(\"client_ready\", client.outbox[0].type)\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n # print(\"outbox:\", client.outbox)\n # wait for worker to spawn locusts\n self.assertIn(\"spawning\", [m.type for m in client.outbox])\n worker.spawning_greenlet.join()\n self.assertEqual(1, len(worker.user_greenlets))\n # check that locust has started running\n gevent.sleep(0.01)\n self.assertEqual(1, MyTestUser._test_state)\n # send stop 
message\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n worker.user_greenlets.join()\n # check that locust user did not get to finish\n self.assertEqual(1, MyTestUser._test_state)\n\n def test_spawn_message_with_older_timestamp_is_rejected(self):\n class MyUser(User):\n wait_time = constant(1)\n\n def start(self, group: Group):\n # We do this so that the spawning does not finish\n # too quickly\n gevent.sleep(0.1)\n return super().start(group)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyUser\": 10},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n sleep(0.6)\n self.assertEqual(STATE_SPAWNING, worker.state)\n worker.spawning_greenlet.join()\n self.assertEqual(10, worker.user_count)\n\n # Send same timestamp as the first message\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyUser\": 9},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n worker.spawning_greenlet.join()\n # Still 10 users\n self.assertEqual(10, worker.user_count)\n\n # Send older timestamp than the first message\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538583,\n \"user_classes_count\": {\"MyUser\": 2},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n worker.spawning_greenlet.join()\n # Still 10 users\n self.assertEqual(10, worker.user_count)\n\n # Send newer timestamp than the first message\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538585,\n \"user_classes_count\": {\"MyUser\": 2},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n worker.spawning_greenlet.join()\n self.assertEqual(2, worker.user_count)\n\n worker.quit()\n\n def test_worker_messages_sent_to_master(self):\n \"\"\"\n Ensure that worker includes both \"user_count\" and \"user_classes_count\"\n when reporting to the master.\n \"\"\"\n\n class MyUser(User):\n wait_time = constant(1)\n\n def start(self, group: Group):\n # We do this so that the spawning does not finish\n # too quickly\n gevent.sleep(0.1)\n return super().start(group)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyUser\": 10},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n sleep(0.6)\n self.assertEqual(STATE_SPAWNING, worker.state)\n worker.spawning_greenlet.join()\n self.assertEqual(10, worker.user_count)\n\n sleep(2)\n\n message = next((m for m in reversed(client.outbox) if m.type == \"stats\"), None)\n self.assertIsNotNone(message)\n self.assertIn(\"user_count\", message.data)\n self.assertIn(\"user_classes_count\", message.data)\n self.assertEqual(message.data[\"user_count\"], 10)\n self.assertEqual(message.data[\"user_classes_count\"][\"MyUser\"], 10)\n\n message = next((m for m in client.outbox if m.type == \"spawning_complete\"), None)\n 
self.assertIsNotNone(message)\n self.assertIn(\"user_count\", message.data)\n self.assertIn(\"user_classes_count\", message.data)\n self.assertEqual(message.data[\"user_count\"], 10)\n self.assertEqual(message.data[\"user_classes_count\"][\"MyUser\"], 10)\n\n worker.quit()\n\n \ndef test_worker_heartbeat_messages_sent_to_master(self):\n \"\"\"\n Validate content of the heartbeat payload sent to the master.\n \"\"\"\n\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n\n t0 = time.perf_counter()\n while len([m for m in client.outbox if m.type == \"heartbeat\"]) == 0:\n self.assertLessEqual(time.perf_counter() - t0, 3)\n sleep(0.1)\n\n message = next(m for m in reversed(client.outbox) if m.type == \"heartbeat\")\n self.assertEqual(len(message.data), 3)\n self.assertIn(\"state\", message.data)\n self.assertIn(\"current_cpu_usage\", message.data)\n self.assertIn(\"current_memory_usage\", message.data)\n\n worker.quit()\n\n def test_reset_rpc_connection_to_master(self):\n \"\"\"\n Validate worker resets RPC connection to master on \"reconnect\" message.\n \"\"\"\n\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc(raise_on_close=False)) as client:\n client_id = id(client)\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyUser\": 10},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n sleep(0.6)\n self.assertEqual(STATE_RUNNING, worker.state)\n with self.assertLogs(\"locust.runners\") as capture:\n with mock.patch(\"locust.rpc.rpc.Client.close\") as close:\n client.mocked_send(\n Message(\n \"reconnect\",\n None,\n \"dummy_client_id\",\n )\n )\n sleep(0)\n worker.spawning_greenlet.join()\n worker.quit()\n close.assert_called_once()\n self.assertIn(\n \"WARNING:locust.runners:Received reconnect message from master. 
Resetting RPC connection.\",\n capture.output,\n )\n\n def test_change_user_count_during_spawning(self):\n class MyUser(User):\n wait_time = constant(1)\n\n def start(self, group: Group):\n # We do this so that the spawning does not finish\n # too quickly\n gevent.sleep(0.1)\n return super().start(group)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyUser\": 10},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n sleep(0.6)\n self.assertEqual(STATE_SPAWNING, worker.state)\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538585,\n \"user_classes_count\": {\"MyUser\": 9},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n sleep(0)\n worker.spawning_greenlet.join()\n self.assertEqual(9, len(worker.user_greenlets))\n worker.quit()\n\n def test_computed_properties(self):\n class MyUser1(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n class MyUser2(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser1, MyUser2], client=client)\n\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538584,\n \"user_classes_count\": {\"MyUser1\": 10, \"MyUser2\": 10},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n worker.spawning_greenlet.join()\n self.assertDictEqual(worker.user_classes_count, {\"MyUser1\": 10, \"MyUser2\": 10})\n self.assertDictEqual(worker.target_user_classes_count, {\"MyUser1\": 10, \"MyUser2\": 10})\n self.assertEqual(worker.target_user_count, 20)\n\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538585,\n \"user_classes_count\": {\"MyUser1\": 1, \"MyUser2\": 2},\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n worker.spawning_greenlet.join()\n self.assertDictEqual(worker.user_classes_count, {\"MyUser1\": 1, \"MyUser2\": 2})\n self.assertDictEqual(worker.target_user_classes_count, {\"MyUser1\": 1, \"MyUser2\": 2})\n self.assertEqual(worker.target_user_count, 3)\n\n worker.quit()\n\n def test_custom_message_send(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n client.outbox.clear()\n worker.send_message(\"test_custom_msg\", {\"test_data\": 123})\n self.assertEqual(\"test_custom_msg\", client.outbox[0].type)\n self.assertEqual(123, client.outbox[0].data[\"test_data\"])\n worker.quit()\n\n def test_custom_message_receive(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n test_custom_msg = [False]\n test_custom_msg_data = [{}]\n\n def on_custom_msg(msg, **kw):\n test_custom_msg[0] = True\n test_custom_msg_data[0] = msg.data\n\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n 
worker.register_message(\"test_custom_msg\", on_custom_msg)\n\n client.mocked_send(Message(\"test_custom_msg\", {\"test_data\": 123}, \"dummy_client_id\"))\n\n self.assertTrue(test_custom_msg[0])\n self.assertEqual(123, test_custom_msg_data[0][\"test_data\"])\n worker.quit()\n\n def test_undefined_custom_message_receive(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n test_custom_msg = [False]\n\n def on_custom_msg(msg, **kw):\n test_custom_msg[0] = True\n\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n worker.register_message(\"test_custom_msg\", on_custom_msg)\n\n client.mocked_send(Message(\"unregistered_custom_msg\", {}, \"dummy_id\"))\n\n self.assertFalse(test_custom_msg[0])\n self.assertEqual(1, len(self.mocked_log.warning))\n msg = self.mocked_log.warning[0]\n self.assertIn(\"Unknown message type received\", msg)\n\n def test_start_event(self):\n class MyTestUser(User):\n _test_state = 0\n\n @task\n def the_task(self):\n MyTestUser._test_state = 1\n gevent.sleep(0.2)\n MyTestUser._test_state = 2\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n environment = Environment()\n run_count = [0]\n\n @environment.events.test_start.add_listener\n def on_test_start(*args, **kw):\n run_count[0] += 1\n\n worker = self.get_runner(environment=environment, user_classes=[MyTestUser], client=client)\n self.assertEqual(1, len(client.outbox))\n self.assertEqual(\"client_ready\", client.outbox[0].type)\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538585,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"spawn_rate\": 1,\n \"num_users\": 1,\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n # wait for worker to spawn locusts\n self.assertIn(\"spawning\", [m.type for m in client.outbox])\n worker.spawning_greenlet.join()\n self.assertEqual(1, len(worker.user_greenlets))\n self.assertEqual(1, run_count[0])\n\n # check that locust has started running\n gevent.sleep(0.01)\n self.assertEqual(1, MyTestUser._test_state)\n\n # change number of users and check that test_start isn't fired again\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538586,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"spawn_rate\": 1,\n \"num_users\": 1,\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n self.assertEqual(1, run_count[0])\n\n # stop and start to make sure test_start is fired again\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538587,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"spawn_rate\": 1,\n \"num_users\": 1,\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n gevent.sleep(0.01)\n self.assertEqual(2, run_count[0])\n\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n\n def test_stop_event(self):\n class MyTestUser(User):\n _test_state = 0\n\n @task\n def the_task(self):\n MyTestUser._test_state = 1\n gevent.sleep(0.2)\n MyTestUser._test_state = 2\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n environment = Environment()\n run_count = [0]\n\n @environment.events.test_stop.add_listener\n def on_test_stop(*args, **kw):\n run_count[0] += 1\n\n worker = 
self.get_runner(environment=environment, user_classes=[MyTestUser], client=client)\n self.assertEqual(1, len(client.outbox))\n self.assertEqual(\"client_ready\", client.outbox[0].type)\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538585,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"spawn_rate\": 1,\n \"num_users\": 1,\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n\n # wait for worker to spawn locusts\n self.assertIn(\"spawning\", [m.type for m in client.outbox])\n worker.spawning_greenlet.join()\n self.assertEqual(1, len(worker.user_greenlets))\n\n # check that locust has started running\n gevent.sleep(0.01)\n self.assertEqual(1, MyTestUser._test_state)\n\n # stop and make sure test_stop is fired\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n gevent.sleep(0.01)\n self.assertEqual(1, run_count[0])\n\n # stop while stopped and make sure the event isn't fired again\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n gevent.sleep(0.01)\n self.assertEqual(1, run_count[0])\n\n # start and stop to check that the event is fired again\n client.mocked_send(\n Message(\n \"spawn\",\n {\n \"timestamp\": 1605538586,\n \"user_classes_count\": {\"MyTestUser\": 1},\n \"spawn_rate\": 1,\n \"num_users\": 1,\n \"host\": \"\",\n \"stop_timeout\": None,\n \"parsed_options\": {},\n },\n \"dummy_client_id\",\n )\n )\n client.mocked_send(Message(\"stop\", None, \"dummy_client_id\"))\n gevent.sleep(0.01)\n self.assertEqual(2, run_count[0])\n\n def test_worker_connect_success(self):\n class MyTestUser(User):\n @task\n def the_task(self):\n pass\n\n with mock.patch(\"locust.runners.CONNECT_TIMEOUT\", new=1):\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyTestUser], client=client)\n\n self.assertEqual(\"client_ready\", client.outbox[0].type)\n self.assertEqual(1, len(client.outbox))\n self.assertTrue(worker.connected)\n\n def test_worker_connect_failure(self):\n class MyTestUser(User):\n @task\n def the_task(self):\n pass\n\n with mock.patch(\"locust.runners.CONNECT_TIMEOUT\", new=0.01):\n with mock.patch(\"locust.runners.CONNECT_RETRY_COUNT\", new=1):\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n with self.assertRaises(ConnectionError):\n self.get_runner(\n environment=Environment(), user_classes=[MyTestUser], client=client, auto_connect=False\n )\n self.assertEqual(2, len(client.outbox))\n\n\nclass TestMessageSerializing(unittest.TestCase):\n def test_message_serialize(self):\n msg = Message(\"client_ready\", __version__, \"my_id\")\n rebuilt = Message.unserialize(msg.serialize())\n self.assertEqual(msg.type, rebuilt.type)\n self.assertEqual(msg.data, rebuilt.data)\n self.assertEqual(msg.node_id, rebuilt.node_id)\n\n\nclass TestStopTimeout(LocustTestCase):\n def test_stop_timeout(self):\n short_time = 0.05\n\n class MyTaskSet(TaskSet):\n @task\n def my_task(self):\n MyTaskSet.state = \"first\"\n gevent.sleep(short_time)\n MyTaskSet.state = \"second\" # should only run when run time + stop_timeout is > short_time\n gevent.sleep(short_time)\n MyTaskSet.state = \"third\" # should only run when run time + stop_timeout is > short_time * 2\n\n class MyTestUser(User):\n tasks = [MyTaskSet]\n\n environment = Environment(user_classes=[MyTestUser])\n runner = environment.create_local_runner()\n runner.start(1, 1, wait=False)\n gevent.sleep(short_time / 2)\n runner.quit()\n 
self.assertEqual(\"first\", MyTaskSet.state)\n\n # exit with timeout\n environment = Environment(user_classes=[MyTestUser], stop_timeout=short_time / 2)\n runner = environment.create_local_runner()\n runner.start(1, 1, wait=False)\n gevent.sleep(short_time)\n runner.quit()\n self.assertEqual(\"second\", MyTaskSet.state)\n\n # allow task iteration to complete, with some margin\n environment = Environment(user_classes=[MyTestUser], stop_timeout=short_time * 3)\n runner = environment.create_local_runner()\n runner.start(1, 1, wait=False)\n gevent.sleep(short_time)\n timeout = gevent.Timeout(short_time * 2)\n timeout.start()\n try:\n runner.quit()\n runner.greenlet.join()\n except gevent.Timeout:\n self.fail(\"Got Timeout exception. Some locusts must have kept running after iteration finish\")\n finally:\n timeout.cancel()\n self.assertEqual(\"third\", MyTaskSet.state)\n\n def test_stop_timeout_during_on_start(self):\n short_time = 0.05\n\n class MyTaskSet(TaskSet):\n finished_on_start = False\n my_task_run = False\n\n def on_start(self):\n gevent.sleep(short_time)\n MyTaskSet.finished_on_start = True\n\n @task\n def my_task(self):\n MyTaskSet.my_task_run = True\n\n class MyTestUser(User):\n tasks = [MyTaskSet]\n\n environment = create_environment([MyTestUser], mocked_options())\n environment.stop_timeout = short_time\n runner = environment.create_local_runner()\n runner.start(1, 1)\n gevent.sleep(short_time / 2)\n runner.quit()\n\n self.assertTrue(MyTaskSet.finished_on_start)\n self.assertFalse(MyTaskSet.my_task_run)\n\n def test_stop_timeout_exit_during_wait(self):\n short_time = 0.05\n\n class MyTaskSet(TaskSet):\n @task\n def my_task(self):\n pass\n\n class MyTestUser(User):\n tasks = [MyTaskSet]\n wait_time = constant(1)\n\n environment = Environment(user_classes=[MyTestUser], stop_timeout=short_time)\n runner = environment.create_local_runner()\n runner.start(1, 1)\n gevent.sleep(short_time) # sleep to make sure locust has had time to start waiting\n timeout = gevent.Timeout(short_time)\n timeout.start()\n try:\n runner.quit()\n runner.greenlet.join()\n except gevent.Timeout:\n self.fail(\"Got Timeout exception. Waiting locusts should stop immediately, even when using stop_timeout.\")\n finally:\n timeout.cancel()\n\n def test_stop_timeout_with_interrupt(self):\n short_time = 0.05\n\n class MySubTaskSet(TaskSet):\n @task\n def a_task(self):\n gevent.sleep(0)\n self.interrupt(reschedule=True)\n\n class MyTaskSet(TaskSet):\n tasks = [MySubTaskSet]\n\n class MyTestUser(User):\n tasks = [MyTaskSet]\n\n environment = create_environment([MyTestUser], mocked_options())\n environment.stop_timeout = short_time\n runner = environment.create_local_runner()\n runner.start(1, 1, wait=True)\n gevent.sleep(0)\n timeout = gevent.Timeout(short_time)\n timeout.start()\n try:\n runner.quit()\n runner.greenlet.join()\n except gevent.Timeout:\n self.fail(\"Got Timeout exception. 
Interrupted locusts should exit immediately during stop_timeout.\")\n finally:\n timeout.cancel()\n\n def test_stop_timeout_with_interrupt_no_reschedule(self):\n state = [0]\n\n class MySubTaskSet(TaskSet):\n @task\n def a_task(self):\n gevent.sleep(0.1)\n state[0] = 1\n self.interrupt(reschedule=False)\n\n class MyTestUser(User):\n tasks = [MySubTaskSet]\n wait_time = constant(3)\n\n options = mocked_options()\n options.stop_timeout = 0.3\n environment = create_environment([MyTestUser], options)\n runner = environment.create_local_runner()\n runner.start(1, 1, wait=True)\n gevent.sleep(0)\n timeout = gevent.Timeout(0.11)\n timeout.start()\n try:\n runner.quit()\n runner.greenlet.join()\n except gevent.Timeout:\n self.fail(\"Got Timeout exception. Interrupted locusts should exit immediately during stop_timeout.\")\n finally:\n timeout.cancel()\n self.assertEqual(1, state[0])\n\n def test_kill_locusts_with_stop_timeout(self):\n short_time = 0.05\n\n class MyTaskSet(TaskSet):\n @task\n def my_task(self):\n MyTaskSet.state = \"first\"\n gevent.sleep(short_time)\n MyTaskSet.state = \"second\" # should only run when run time + stop_timeout is > short_time\n gevent.sleep(short_time)\n MyTaskSet.state = \"third\" # should only run when run time + stop_timeout is > short_time * 2\n\n class MyTestUser(User):\n tasks = [MyTaskSet]\n\n environment = create_environment([MyTestUser], mocked_options())\n runner = environment.create_local_runner()\n runner.start(1, 1)\n gevent.sleep(short_time / 2)\n runner.stop_users({MyTestUser.__name__: 1})\n self.assertEqual(\"first\", MyTaskSet.state)\n runner.quit()\n environment.runner = None\n\n environment.stop_timeout = short_time / 2 # exit with timeout\n runner = environment.create_local_runner()\n runner.start(1, 1)\n gevent.sleep(short_time)\n runner.stop_users({MyTestUser.__name__: 1})\n self.assertEqual(\"second\", MyTaskSet.state)\n runner.quit()\n environment.runner = None\n\n environment.stop_timeout = short_time * 3 # allow task iteration to complete, with some margin\n runner = environment.create_local_runner()\n runner.start(1, 1)\n gevent.sleep(short_time)\n timeout = gevent.Timeout(short_time * 2)\n timeout.start()\n try:\n runner.stop_users({MyTestUser.__name__: 1})\n runner.user_greenlets.join()\n except gevent.Timeout:\n self.fail(\"Got Timeout exception. 
Some locusts must have kept running after iteration finish\")\n finally:\n timeout.cancel()\n self.assertEqual(\"third\", MyTaskSet.state)\n\n def test_users_can_call_runner_quit_with_stop_timeout(self):\n class BaseUser(User):\n wait_time = constant(1)\n\n @task\n def trigger(self):\n self.environment.runner.quit()\n\n runner = Environment(user_classes=[BaseUser]).create_local_runner()\n runner.environment.stop_timeout = 1\n runner.spawn_users({BaseUser.__name__: 1}, wait=False)\n timeout = gevent.Timeout(0.5)\n timeout.start()\n try:\n runner.greenlet.join()\n except gevent.Timeout:\n self.fail(\"Got Timeout exception, runner must have hung somehow.\")\n finally:\n timeout.cancel()\n\n def test_gracefully_handle_exceptions_in_listener(self):\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n test_stop_run = [0]\n environment = Environment(user_classes=[MyUser])\n\n def on_test_stop_ok(*args, **kwargs):\n test_stop_run[0] += 1\n\n def on_test_stop_fail(*args, **kwargs):\n assert False\n\n environment.events.test_stop.add_listener(on_test_stop_ok)\n environment.events.test_stop.add_listener(on_test_stop_fail)\n environment.events.test_stop.add_listener(on_test_stop_ok)\n\n runner = LocalRunner(environment)\n runner.start(user_count=3, spawn_rate=3, wait=False)\n self.assertEqual(0, test_stop_run[0])\n runner.stop()\n self.assertEqual(2, test_stop_run[0])\n\n def test_stop_timeout_with_ramp_down(self):\n class MyTaskSet(TaskSet):\n @task\n def my_task(self):\n gevent.sleep(1)\n\n class MyTestUser(User):\n tasks = [MyTaskSet]\n\n environment = Environment(user_classes=[MyTestUser], stop_timeout=2)\n runner = environment.create_local_runner()\n\n # Start load test, wait for users to start, then trigger ramp down\n ts = time.perf_counter()\n runner.start(10, 10, wait=False)\n runner.spawning_greenlet.join()\n delta = time.perf_counter() - ts\n self.assertTrue(\n 0 <= delta <= 0.05, f\"Expected user count to increase to 10 instantaneously, instead it took {delta:f}\"\n )\n self.assertTrue(\n...\n", "needle_token_start": 8705, "needle_token_end": 9005, "code_context_ntokens": 16102, "output": ["```python\ndef test_worker_heartbeat_messages_sent_to_master(self):\n \"\"\"\n Validate content of the heartbeat payload sent to the master.\n \"\"\"\n\n class MyUser(User):\n wait_time = constant(1)\n\n @task\n def my_task(self):\n pass\n\n with mock.patch(\"locust.rpc.rpc.Client\", mocked_rpc()) as client:\n worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)\n\n t0 = time.perf_counter()\n while len([m for m in client.outbox if m.type == \"heartbeat\"]) == 0:\n self.assertLessEqual(time.perf_counter() - t0, 3)\n sleep(0.1)\n\n message = next(m for m in reversed(client.outbox) if m.type == \"heartbeat\")\n self.assertEqual(len(message.data), 3)\n self.assertIn(\"state\", message.data)\n self.assertIn(\"current_cpu_usage\", message.data)\n self.assertIn(\"current_memory_usage\", message.data)\n\n worker.quit()\n```"]} +{"repo": "locustio/locust", "name": "test_request_stats_named_endpoint_rename_request", "language": "python", "path": "locust/test/test_stats.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To verify that renaming a request endpoint correctly updates the statistics tracking system with the new name.\n2. **Input**: A GET request to a specified URL under a temporarily assigned custom name.\n3. 
**Output**: The test checks if the statistics entry for the custom name records exactly one request.\n4. **Procedure**: The test initiates a context where the request name is temporarily changed. Within this context, a GET request is made to a specific URL. After exiting the context, the test asserts that the statistics for the custom name correctly reflect the number of requests made under that name.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " log(35, 1)\n log(79, 1)\n log(None, 1)\n log_error(Exception(\"dummy fail\"))\n self.s = self.stats.entries[(\"test_entry\", \"GET\")]\n\n def test_percentile(self):\n s = StatsEntry(self.stats, \"percentile_test\", \"GET\")\n for x in range(100):\n s.log(x, 0)\n\n self.assertEqual(s.get_response_time_percentile(0.5), 50)\n self.assertEqual(s.get_response_time_percentile(0.6), 60)\n self.assertEqual(s.get_response_time_percentile(0.95), 95)\n\n def test_median(self):\n self.assertEqual(self.s.median_response_time, 79)\n\n def test_median_out_of_min_max_bounds(self):\n s = StatsEntry(self.stats, \"median_test\", \"GET\")\n s.log(6034, 0)\n self.assertEqual(s.median_response_time, 6034)\n s.reset()\n s.log(6099, 0)\n self.assertEqual(s.median_response_time, 6099)\n\n def test_total_rps(self):\n self.stats.log_request(\"GET\", \"other_endpoint\", 1337, 1337)\n s2 = self.stats.entries[(\"other_endpoint\", \"GET\")]\n s2.start_time = 2.0\n s2.last_request_timestamp = 6.0\n self.s.start_time = 1.0\n self.s.last_request_timestamp = 4.0\n self.stats.total.start_time = 1.0\n self.stats.total.last_request_timestamp = 6.0\n self.assertEqual(self.s.total_rps, 9 / 5.0)\n self.assertAlmostEqual(s2.total_rps, 1 / 5.0)\n self.assertEqual(self.stats.total.total_rps, 10 / 5.0)\n\n def test_rps_less_than_one_second(self):\n s = StatsEntry(self.stats, \"percentile_test\", \"GET\")\n for i in range(10):\n s.log(i, 0)\n self.assertGreater(s.total_rps, 10)\n\n def test_current_rps(self):\n self.stats.total.last_request_timestamp = int(time.time()) + 4\n self.assertEqual(self.s.current_rps, 4.5)\n\n self.stats.total.last_request_timestamp = int(time.time()) + 25\n self.assertEqual(self.s.current_rps, 0)\n\n def test_current_fail_per_sec(self):\n self.stats.total.last_request_timestamp = int(time.time()) + 4\n self.assertEqual(self.s.current_fail_per_sec, 1.5)\n\n self.stats.total.last_request_timestamp = int(time.time()) + 12\n self.assertEqual(self.s.current_fail_per_sec, 0.3)\n\n self.stats.total.last_request_timestamp = int(time.time()) + 25\n self.assertEqual(self.s.current_fail_per_sec, 0)\n\n def test_num_reqs_fails(self):\n self.assertEqual(self.s.num_requests, 9)\n self.assertEqual(self.s.num_failures, 3)\n\n def test_avg(self):\n self.assertEqual(self.s.avg_response_time, 187.71428571428572)\n\n def test_total_content_length(self):\n self.assertEqual(self.s.total_content_length, 9)\n\n def test_reset(self):\n self.s.reset()\n self.s.log(756, 0)\n self.s.log_error(Exception(\"dummy fail after reset\"))\n self.s.log(85, 0)\n\n self.assertGreater(self.s.total_rps, 2)\n self.assertEqual(self.s.num_requests, 2)\n self.assertEqual(self.s.num_failures, 1)\n self.assertEqual(self.s.avg_response_time, 420.5)\n self.assertEqual(self.s.median_response_time, 85)\n self.assertNotEqual(None, self.s.last_request_timestamp)\n self.s.reset()\n 
self.assertEqual(None, self.s.last_request_timestamp)\n\n def test_avg_only_none(self):\n self.s.reset()\n self.s.log(None, 123)\n self.assertEqual(self.s.avg_response_time, 0)\n self.assertEqual(self.s.median_response_time, 0)\n self.assertEqual(self.s.get_response_time_percentile(0.5), 0)\n\n def test_reset_min_response_time(self):\n self.s.reset()\n self.s.log(756, 0)\n self.assertEqual(756, self.s.min_response_time)\n\n def test_aggregation(self):\n s1 = StatsEntry(self.stats, \"aggregate me!\", \"GET\")\n s1.log(12, 0)\n s1.log(12, 0)\n s1.log(38, 0)\n s1.log_error(\"Dummy exception\")\n\n s2 = StatsEntry(self.stats, \"aggregate me!\", \"GET\")\n s2.log_error(\"Dummy exception\")\n s2.log_error(\"Dummy exception\")\n s2.log(12, 0)\n s2.log(99, 0)\n s2.log(14, 0)\n s2.log(55, 0)\n s2.log(38, 0)\n s2.log(55, 0)\n s2.log(97, 0)\n\n s = StatsEntry(self.stats, \"GET\", \"\")\n s.extend(s1)\n s.extend(s2)\n\n self.assertEqual(s.num_requests, 10)\n self.assertEqual(s.num_failures, 3)\n self.assertEqual(s.median_response_time, 38)\n self.assertEqual(s.avg_response_time, 43.2)\n\n def test_aggregation_with_rounding(self):\n s1 = StatsEntry(self.stats, \"round me!\", \"GET\")\n s1.log(122, 0) # (rounded 120) min\n s1.log(992, 0) # (rounded 990) max\n s1.log(142, 0) # (rounded 140)\n s1.log(552, 0) # (rounded 550)\n s1.log(557, 0) # (rounded 560)\n s1.log(387, 0) # (rounded 390)\n s1.log(557, 0) # (rounded 560)\n s1.log(977, 0) # (rounded 980)\n\n self.assertEqual(s1.num_requests, 8)\n self.assertEqual(s1.median_response_time, 550)\n self.assertEqual(s1.avg_response_time, 535.75)\n self.assertEqual(s1.min_response_time, 122)\n self.assertEqual(s1.max_response_time, 992)\n\n def test_aggregation_with_decimal_rounding(self):\n s1 = StatsEntry(self.stats, \"round me!\", \"GET\")\n s1.log(1.1, 0)\n s1.log(1.99, 0)\n s1.log(3.1, 0)\n self.assertEqual(s1.num_requests, 3)\n self.assertEqual(s1.median_response_time, 2)\n self.assertEqual(s1.avg_response_time, (1.1 + 1.99 + 3.1) / 3)\n self.assertEqual(s1.min_response_time, 1.1)\n self.assertEqual(s1.max_response_time, 3.1)\n\n def test_aggregation_min_response_time(self):\n s1 = StatsEntry(self.stats, \"min\", \"GET\")\n s1.log(10, 0)\n self.assertEqual(10, s1.min_response_time)\n s2 = StatsEntry(self.stats, \"min\", \"GET\")\n s1.extend(s2)\n self.assertEqual(10, s1.min_response_time)\n\n def test_aggregation_last_request_timestamp(self):\n s1 = StatsEntry(self.stats, \"r\", \"GET\")\n s2 = StatsEntry(self.stats, \"r\", \"GET\")\n s1.extend(s2)\n self.assertEqual(None, s1.last_request_timestamp)\n s1 = StatsEntry(self.stats, \"r\", \"GET\")\n s2 = StatsEntry(self.stats, \"r\", \"GET\")\n s1.last_request_timestamp = 666\n s1.extend(s2)\n self.assertEqual(666, s1.last_request_timestamp)\n s1 = StatsEntry(self.stats, \"r\", \"GET\")\n s2 = StatsEntry(self.stats, \"r\", \"GET\")\n s2.last_request_timestamp = 666\n s1.extend(s2)\n self.assertEqual(666, s1.last_request_timestamp)\n s1 = StatsEntry(self.stats, \"r\", \"GET\")\n s2 = StatsEntry(self.stats, \"r\", \"GET\")\n s1.last_request_timestamp = 666\n s1.last_request_timestamp = 700\n s1.extend(s2)\n self.assertEqual(700, s1.last_request_timestamp)\n\n def test_percentile_rounded_down(self):\n s1 = StatsEntry(self.stats, \"rounding down!\", \"GET\")\n s1.log(122, 0) # (rounded 120) min\n actual_percentile = s1.percentile().split()\n\n self.assertEqual(actual_percentile, [\"GET\", \"rounding\", \"down!\"] + [\"120\"] * len(PERCENTILES_TO_REPORT) + [\"1\"])\n\n def test_percentile_rounded_up(self):\n s2 = 
StatsEntry(self.stats, \"rounding up!\", \"GET\")\n s2.log(127, 0) # (rounded 130) min\n actual_percentile = s2.percentile().split()\n self.assertEqual(actual_percentile, [\"GET\", \"rounding\", \"up!\"] + [\"130\"] * len(PERCENTILES_TO_REPORT) + [\"1\"])\n\n def test_custom_percentile_list(self):\n s = StatsEntry(self.stats, \"custom_percentiles\", \"GET\")\n custom_percentile_list = [0.50, 0.90, 0.95, 0.99]\n locust.stats.PERCENTILES_TO_REPORT = custom_percentile_list\n s.log(150, 0)\n actual_percentile = s.percentile().split()\n self.assertEqual(\n actual_percentile, [\"GET\", \"custom_percentiles\"] + [\"150\"] * len(custom_percentile_list) + [\"1\"]\n )\n\n def test_error_grouping(self):\n # reset stats\n self.stats = RequestStats()\n\n self.stats.log_error(\"GET\", \"/some-path\", Exception(\"Exception!\"))\n self.stats.log_error(\"GET\", \"/some-path\", Exception(\"Exception!\"))\n\n self.assertEqual(1, len(self.stats.errors))\n self.assertEqual(2, list(self.stats.errors.values())[0].occurrences)\n\n self.stats.log_error(\"GET\", \"/some-path\", Exception(\"Another exception!\"))\n self.stats.log_error(\"GET\", \"/some-path\", Exception(\"Another exception!\"))\n self.stats.log_error(\"GET\", \"/some-path\", Exception(\"Third exception!\"))\n self.assertEqual(3, len(self.stats.errors))\n\n def test_error_grouping_errors_with_memory_addresses(self):\n # reset stats\n self.stats = RequestStats()\n\n class Dummy:\n pass\n\n self.stats.log_error(\"GET\", \"/\", Exception(f\"Error caused by {Dummy()!r}\"))\n self.assertEqual(1, len(self.stats.errors))\n\n def test_serialize_through_message(self):\n \"\"\"\n Serialize a RequestStats instance, then serialize it through a Message,\n and unserialize the whole thing again. This is done \"IRL\" when stats are sent\n from workers to master.\n \"\"\"\n s1 = StatsEntry(self.stats, \"test\", \"GET\")\n s1.log(10, 0)\n s1.log(20, 0)\n s1.log(40, 0)\n u1 = StatsEntry.unserialize(s1.serialize())\n\n data = Message.unserialize(Message(\"dummy\", s1.serialize(), \"none\").serialize()).data\n u1 = StatsEntry.unserialize(data)\n\n self.assertEqual(20, u1.median_response_time)\n\n\nclass TestStatsPrinting(LocustTestCase):\n def setUp(self):\n super().setUp()\n\n self.stats = RequestStats()\n for i in range(100):\n for method, name, freq in [\n (\n \"GET\",\n \"test_entry\",\n 5,\n ),\n (\n \"DELETE\",\n \"test\" * int((STATS_NAME_WIDTH - STATS_TYPE_WIDTH + 4) / len(\"test\")),\n 3,\n ),\n ]:\n self.stats.log_request(method, name, i, 2000 + i)\n if i % freq == 0:\n self.stats.log_error(method, name, RuntimeError(f\"{method} error\"))\n\n def test_print_percentile_stats(self):\n locust.stats.print_percentile_stats(self.stats)\n info = self.mocked_log.info\n self.assertEqual(8, len(info))\n self.assertEqual(\"Response time percentiles (approximated)\", info[0])\n # check that headline contains same number of column as the value rows\n headlines = info[1].replace(\"# reqs\", \"#reqs\").split()\n self.assertEqual(len(headlines), len(info[3].split()))\n self.assertEqual(len(headlines) - 1, len(info[-2].split())) # Aggregated, no \"Type\"\n self.assertEqual(info[2], info[-3]) # table ascii separators\n\n def test_print_stats(self):\n locust.stats.print_stats(self.stats)\n info = self.mocked_log.info\n self.assertEqual(7, len(info))\n\n headlines = info[0].replace(\"# \", \"#\").split()\n\n # check number of columns in separator\n self.assertEqual(len(headlines), len(info[1].split(\"|\")) + 2)\n # check entry row\n self.assertEqual(len(headlines), 
len(info[2].split()))\n # check aggregated row, which is missing value in \"type\"-column\n self.assertEqual(len(headlines) - 1, len(info[-2].split()))\n # table ascii separators\n self.assertEqual(info[1], info[-3])\n\n def test_print_error_report(self):\n locust.stats.print_error_report(self.stats)\n info = self.mocked_log.info\n self.assertEqual(7, len(info))\n self.assertEqual(\"Error report\", info[0])\n\n headlines = info[1].replace(\"# \", \"#\").split()\n # check number of columns in headlines vs table ascii separator\n self.assertEqual(len(headlines), len(info[2].split(\"|\")))\n # table ascii seprators\n self.assertEqual(info[2], info[-2])\n\n\nclass TestCsvStats(LocustTestCase):\n STATS_BASE_NAME = \"test\"\n STATS_FILENAME = f\"{STATS_BASE_NAME}_stats.csv\"\n STATS_HISTORY_FILENAME = f\"{STATS_BASE_NAME}_stats_history.csv\"\n STATS_FAILURES_FILENAME = f\"{STATS_BASE_NAME}_failures.csv\"\n STATS_EXCEPTIONS_FILENAME = f\"{STATS_BASE_NAME}_exceptions.csv\"\n\n def setUp(self):\n super().setUp()\n self.remove_file_if_exists(self.STATS_FILENAME)\n self.remove_file_if_exists(self.STATS_HISTORY_FILENAME)\n self.remove_file_if_exists(self.STATS_FAILURES_FILENAME)\n self.remove_file_if_exists(self.STATS_EXCEPTIONS_FILENAME)\n\n def tearDown(self):\n self.remove_file_if_exists(self.STATS_FILENAME)\n self.remove_file_if_exists(self.STATS_HISTORY_FILENAME)\n self.remove_file_if_exists(self.STATS_FAILURES_FILENAME)\n self.remove_file_if_exists(self.STATS_EXCEPTIONS_FILENAME)\n\n def remove_file_if_exists(self, filename):\n if os.path.exists(filename):\n os.remove(filename)\n\n def test_write_csv_files(self):\n _write_csv_files(self.environment, self.STATS_BASE_NAME)\n self.assertTrue(os.path.exists(self.STATS_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_HISTORY_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_FAILURES_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_EXCEPTIONS_FILENAME))\n\n def test_write_csv_files_full_history(self):\n _write_csv_files(self.environment, self.STATS_BASE_NAME, full_history=True)\n self.assertTrue(os.path.exists(self.STATS_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_HISTORY_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_FAILURES_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_EXCEPTIONS_FILENAME))\n\n @mock.patch(\"locust.stats.CSV_STATS_INTERVAL_SEC\", new=_TEST_CSV_STATS_INTERVAL_SEC)\n def test_csv_stats_writer(self):\n _write_csv_files(self.environment, self.STATS_BASE_NAME)\n\n self.assertTrue(os.path.exists(self.STATS_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_HISTORY_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_FAILURES_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_EXCEPTIONS_FILENAME))\n\n with open(self.STATS_HISTORY_FILENAME) as f:\n reader = csv.DictReader(f)\n rows = [r for r in reader]\n\n self.assertEqual(2, len(rows))\n self.assertEqual(\"Aggregated\", rows[0][\"Name\"])\n self.assertEqual(\"Aggregated\", rows[1][\"Name\"])\n\n @mock.patch(\"locust.stats.CSV_STATS_INTERVAL_SEC\", new=_TEST_CSV_STATS_INTERVAL_SEC)\n def test_csv_stats_writer_full_history(self):\n stats_writer = StatsCSVFileWriter(\n self.environment, PERCENTILES_TO_REPORT, self.STATS_BASE_NAME, full_history=True\n )\n\n for i in range(10):\n self.runner.stats.log_request(\"GET\", \"/\", 100, content_length=666)\n\n greenlet = gevent.spawn(stats_writer)\n gevent.sleep(10)\n\n for i in range(10):\n self.runner.stats.log_request(\"GET\", \"/\", 10, content_length=666)\n\n gevent.sleep(5)\n\n 
gevent.sleep(_TEST_CSV_STATS_INTERVAL_WAIT_SEC)\n gevent.kill(greenlet)\n stats_writer.close_files()\n\n self.assertTrue(os.path.exists(self.STATS_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_HISTORY_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_FAILURES_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_EXCEPTIONS_FILENAME))\n\n with open(self.STATS_HISTORY_FILENAME) as f:\n reader = csv.DictReader(f)\n rows = [r for r in reader]\n\n self.assertGreaterEqual(len(rows), 130)\n\n self.assertEqual(\"/\", rows[0][\"Name\"])\n self.assertEqual(\"Aggregated\", rows[1][\"Name\"])\n self.assertEqual(\"/\", rows[2][\"Name\"])\n self.assertEqual(\"Aggregated\", rows[3][\"Name\"])\n self.assertEqual(\"20\", rows[-1][\"Total Request Count\"])\n\n saw100 = False\n saw10 = False\n\n for row in rows:\n if not saw100 and row[\"95%\"] == \"100\":\n saw100 = True\n elif saw100 and row[\"95%\"] == \"10\":\n saw10 = True\n break\n\n self.assertTrue(saw100, \"Never saw 95th percentile increase to 100\")\n self.assertTrue(saw10, \"Never saw 95th percentile decrease to 10\")\n\n def test_csv_stats_on_master_from_aggregated_stats(self):\n # Failing test for: https://github.com/locustio/locust/issues/1315\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n environment = Environment()\n stats_writer = StatsCSVFileWriter(\n environment, PERCENTILES_TO_REPORT, self.STATS_BASE_NAME, full_history=True\n )\n master = environment.create_master_runner(master_bind_host=\"*\", master_bind_port=0)\n greenlet = gevent.spawn(stats_writer)\n gevent.sleep(_TEST_CSV_STATS_INTERVAL_WAIT_SEC)\n\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client\"))\n\n master.stats.entries[(\"/\", \"GET\")].log(100, 23455)\n master.stats.entries[(\"/\", \"GET\")].log(800, 23455)\n master.stats.entries[(\"/\", \"GET\")].log(700, 23455)\n\n data = {\"user_count\": 1}\n environment.events.report_to_master.fire(client_id=\"fake_client\", data=data)\n master.stats.clear_all()\n\n server.mocked_send(Message(\"stats\", data, \"fake_client\"))\n s = master.stats.entries[(\"/\", \"GET\")]\n self.assertEqual(700, s.median_response_time)\n\n gevent.kill(greenlet)\n stats_writer.close_files()\n\n self.assertTrue(os.path.exists(self.STATS_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_HISTORY_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_FAILURES_FILENAME))\n self.assertTrue(os.path.exists(self.STATS_EXCEPTIONS_FILENAME))\n\n @mock.patch(\"locust.stats.CSV_STATS_INTERVAL_SEC\", new=_TEST_CSV_STATS_INTERVAL_SEC)\n def test_user_count_in_csv_history_stats(self):\n start_time = int(time.time())\n\n class TestUser(User):\n wait_time = constant(10)\n\n @task\n def t(self):\n self.environment.runner.stats.log_request(\"GET\", \"/\", 10, 10)\n\n environment = Environment(user_classes=[TestUser])\n stats_writer = StatsCSVFileWriter(environment, PERCENTILES_TO_REPORT, self.STATS_BASE_NAME, full_history=True)\n runner = environment.create_local_runner()\n # spawn a user every _TEST_CSV_STATS_INTERVAL_SEC second\n user_count = 15\n spawn_rate = 5\n assert 1 / 5 == _TEST_CSV_STATS_INTERVAL_SEC\n runner_greenlet = gevent.spawn(runner.start, user_count, spawn_rate)\n gevent.sleep(0.1)\n\n greenlet = gevent.spawn(stats_writer)\n gevent.sleep(user_count / spawn_rate)\n gevent.kill(greenlet)\n stats_writer.close_files()\n runner.stop()\n gevent.kill(runner_greenlet)\n\n with open(self.STATS_HISTORY_FILENAME) as f:\n reader = csv.DictReader(f)\n rows = [r for r in reader]\n\n self.assertEqual(2 * 
user_count, len(rows))\n for i in range(int(user_count / spawn_rate)):\n for _ in range(spawn_rate):\n row = rows.pop(0)\n self.assertEqual(\"%i\" % ((i + 1) * spawn_rate), row[\"User Count\"])\n self.assertEqual(\"/\", row[\"Name\"])\n self.assertEqual(\"%i\" % ((i + 1) * spawn_rate), row[\"Total Request Count\"])\n self.assertGreaterEqual(int(row[\"Timestamp\"]), start_time)\n row = rows.pop(0)\n self.assertEqual(\"%i\" % ((i + 1) * spawn_rate), row[\"User Count\"])\n self.assertEqual(\"Aggregated\", row[\"Name\"])\n self.assertEqual(\"%i\" % ((i + 1) * spawn_rate), row[\"Total Request Count\"])\n self.assertGreaterEqual(int(row[\"Timestamp\"]), start_time)\n\n def test_requests_csv_quote_escaping(self):\n with mock.patch(\"locust.rpc.rpc.Server\", mocked_rpc()) as server:\n environment = Environment()\n master = environment.create_master_runner(master_bind_host=\"*\", master_bind_port=0)\n server.mocked_send(Message(\"client_ready\", __version__, \"fake_client\"))\n\n request_name_dict = {\n \"scenario\": \"get cashes\",\n \"path\": \"/cash/[amount]\",\n \"arguments\": [{\"size\": 1}],\n }\n request_name_str = json.dumps(request_name_dict)\n\n master.stats.entries[(request_name_str, \"GET\")].log(100, 23455)\n data = {\"user_count\": 1}\n environment.events.report_to_master.fire(client_id=\"fake_client\", data=data)\n master.stats.clear_all()\n server.mocked_send(Message(\"stats\", data, \"fake_client\"))\n\n _write_csv_files(environment, self.STATS_BASE_NAME, full_history=True)\n with open(self.STATS_FILENAME) as f:\n reader = csv.DictReader(f)\n rows = [r for r in reader]\n csv_request_name = rows[0].get(\"Name\")\n self.assertEqual(request_name_str, csv_request_name)\n\n def test_stats_history(self):\n env1 = Environment(events=locust.events, catch_exceptions=False)\n runner1 = env1.create_master_runner(\"127.0.0.1\", 5558)\n env2 = Environment(events=locust.events, catch_exceptions=False)\n runner2 = env2.create_worker_runner(\"127.0.0.1\", 5558)\n greenlet1 = gevent.spawn(stats_history, runner1)\n greenlet2 = gevent.spawn(stats_history, runner2)\n gevent.sleep(1)\n hs1 = runner1.stats.history\n hs2 = runner2.stats.history\n gevent.kill(greenlet1)\n gevent.kill(greenlet2)\n self.assertEqual(1, len(hs1))\n self.assertEqual(0, len(hs2))\n\n\nclass TestStatsEntryResponseTimesCache(unittest.TestCase):\n def setUp(self, *args, **kwargs):\n super().setUp(*args, **kwargs)\n self.stats = RequestStats()\n\n def test_response_times_cached(self):\n s = StatsEntry(self.stats, \"/\", \"GET\", use_response_times_cache=True)\n self.assertEqual(1, len(s.response_times_cache))\n s.log(11, 1337)\n self.assertEqual(1, len(s.response_times_cache))\n s.last_request_timestamp -= 1\n s.log(666, 1337)\n self.assertEqual(2, len(s.response_times_cache))\n self.assertEqual(\n CachedResponseTimes(\n response_times={11: 1},\n num_requests=1,\n ),\n s.response_times_cache[int(s.last_request_timestamp) - 1],\n )\n\n def test_response_times_not_cached_if_not_enabled(self):\n s = StatsEntry(self.stats, \"/\", \"GET\")\n s.log(11, 1337)\n self.assertEqual(None, s.response_times_cache)\n s.last_request_timestamp -= 1\n s.log(666, 1337)\n self.assertEqual(None, s.response_times_cache)\n\n def test_latest_total_response_times_pruned(self):\n \"\"\"\n Check that RequestStats.latest_total_response_times are pruned when exceeding 20 entries\n \"\"\"\n s = StatsEntry(self.stats, \"/\", \"GET\", use_response_times_cache=True)\n t = int(time.time())\n for i in reversed(range(2, 30)):\n s.response_times_cache[t - i] = 
CachedResponseTimes(response_times={}, num_requests=0)\n self.assertEqual(29, len(s.response_times_cache))\n s.log(17, 1337)\n s.last_request_timestamp -= 1\n s.log(1, 1)\n self.assertEqual(20, len(s.response_times_cache))\n self.assertEqual(\n CachedResponseTimes(response_times={17: 1}, num_requests=1),\n s.response_times_cache.popitem(last=True)[1],\n )\n\n def test_get_current_response_time_percentile(self):\n s = StatsEntry(self.stats, \"/\", \"GET\", use_response_times_cache=True)\n t = int(time.time())\n s.response_times_cache[t - 10] = CachedResponseTimes(\n response_times={i: 1 for i in range(100)}, num_requests=200\n )\n s.response_times_cache[t - 10].response_times[1] = 201\n\n s.response_times = {i: 2 for i in range(100)}\n s.response_times[1] = 202\n s.num_requests = 300\n\n self.assertEqual(95, s.get_current_response_time_percentile(0.95))\n\n def test_get_current_response_time_percentile_outside_cache_window(self):\n s = StatsEntry(self.stats, \"/\", \"GET\", use_response_times_cache=True)\n # an empty response times cache, current time will not be in this cache\n s.response_times_cache = {}\n self.assertEqual(None, s.get_current_response_time_percentile(0.95))\n\n def test_diff_response_times_dicts(self):\n self.assertEqual(\n {1: 5, 6: 8},\n diff_response_time_dicts(\n {1: 6, 6: 16, 2: 2},\n {1: 1, 6: 8, 2: 2},\n ),\n )\n self.assertEqual(\n {},\n diff_response_time_dicts(\n {},\n {},\n ),\n )\n self.assertEqual(\n {10: 15},\n diff_response_time_dicts(\n {10: 15},\n {},\n ),\n )\n self.assertEqual(\n {10: 10},\n diff_response_time_dicts(\n {10: 10},\n {},\n ),\n )\n self.assertEqual(\n {},\n diff_response_time_dicts(\n {1: 1},\n {1: 1},\n ),\n )\n\n\nclass TestStatsEntry(unittest.TestCase):\n def parse_string_output(self, text):\n tokenlist = re.split(r\"[\\s\\(\\)%|]+\", text.strip())\n tokens = {\n \"method\": tokenlist[0],\n \"name\": tokenlist[1],\n \"request_count\": int(tokenlist[2]),\n \"failure_count\": int(tokenlist[3]),\n \"failure_percentage\": float(tokenlist[4]),\n }\n return tokens\n\n def setUp(self, *args, **kwargs):\n super().setUp(*args, **kwargs)\n self.stats = RequestStats()\n\n def test_fail_ratio_with_no_failures(self):\n REQUEST_COUNT = 10\n FAILURE_COUNT = 0\n EXPECTED_FAIL_RATIO = 0.0\n\n s = StatsEntry(self.stats, \"/\", \"GET\")\n s.num_requests = REQUEST_COUNT\n s.num_failures = FAILURE_COUNT\n\n self.assertAlmostEqual(s.fail_ratio, EXPECTED_FAIL_RATIO)\n output_fields = self.parse_string_output(str(s))\n self.assertEqual(output_fields[\"request_count\"], REQUEST_COUNT)\n self.assertEqual(output_fields[\"failure_count\"], FAILURE_COUNT)\n self.assertAlmostEqual(output_fields[\"failure_percentage\"], EXPECTED_FAIL_RATIO * 100)\n\n def test_fail_ratio_with_all_failures(self):\n REQUEST_COUNT = 10\n FAILURE_COUNT = 10\n EXPECTED_FAIL_RATIO = 1.0\n\n s = StatsEntry(self.stats, \"/\", \"GET\")\n s.num_requests = REQUEST_COUNT\n s.num_failures = FAILURE_COUNT\n\n self.assertAlmostEqual(s.fail_ratio, EXPECTED_FAIL_RATIO)\n output_fields = self.parse_string_output(str(s))\n self.assertEqual(output_fields[\"request_count\"], REQUEST_COUNT)\n self.assertEqual(output_fields[\"failure_count\"], FAILURE_COUNT)\n self.assertAlmostEqual(output_fields[\"failure_percentage\"], EXPECTED_FAIL_RATIO * 100)\n\n def test_fail_ratio_with_half_failures(self):\n REQUEST_COUNT = 10\n FAILURE_COUNT = 5\n EXPECTED_FAIL_RATIO = 0.5\n\n s = StatsEntry(self.stats, \"/\", \"GET\")\n s.num_requests = REQUEST_COUNT\n s.num_failures = FAILURE_COUNT\n\n 
self.assertAlmostEqual(s.fail_ratio, EXPECTED_FAIL_RATIO)\n output_fields = self.parse_string_output(str(s))\n self.assertEqual(output_fields[\"request_count\"], REQUEST_COUNT)\n self.assertEqual(output_fields[\"failure_count\"], FAILURE_COUNT)\n self.assertAlmostEqual(output_fields[\"failure_percentage\"], EXPECTED_FAIL_RATIO * 100)\n\n\nclass TestRequestStatsWithWebserver(WebserverTestCase):\n def setUp(self):\n super().setUp()\n\n class MyUser(HttpUser):\n host = \"http://127.0.0.1:%i\" % self.port\n\n self.locust = MyUser(self.environment)\n\n def test_request_stats_content_length(self):\n self.locust.client.get(\"/ultra_fast\")\n self.assertEqual(\n self.runner.stats.entries[(\"/ultra_fast\", \"GET\")].avg_content_length, len(\"This is an ultra fast response\")\n )\n self.locust.client.get(\"/ultra_fast\")\n # test legacy stats.get() function sometimes too\n self.assertEqual(\n self.runner.stats.get(\"/ultra_fast\", \"GET\").avg_content_length, len(\"This is an ultra fast response\")\n )\n\n def test_request_stats_no_content_length(self):\n path = \"/no_content_length\"\n self.locust.client.get(path)\n self.assertEqual(\n self.runner.stats.entries[(path, \"GET\")].avg_content_length,\n len(\"This response does not have content-length in the header\"),\n )\n\n def test_request_stats_no_content_length_streaming(self):\n path = \"/no_content_length\"\n self.locust.client.get(path, stream=True)\n self.assertEqual(0, self.runner.stats.entries[(path, \"GET\")].avg_content_length)\n\n def test_request_stats_named_endpoint(self):\n self.locust.client.get(\"/ultra_fast\", name=\"my_custom_name\")\n self.assertEqual(1, self.runner.stats.entries[(\"my_custom_name\", \"GET\")].num_requests)\n\n def test_request_stats_named_endpoint_request_name(self):\n self.locust.client.request_name = \"my_custom_name_1\"\n self.locust.client.get(\"/ultra_fast\")\n self.assertEqual(1, self.runner.stats.entries[(\"my_custom_name_1\", \"GET\")].num_requests)\n self.locust.client.request_name = None\n\n \ndef test_request_stats_named_endpoint_rename_request(self):\n with self.locust.client.rename_request(\"my_custom_name_3\"):\n self.locust.client.get(\"/ultra_fast\")\n self.assertEqual(1, self.runner.stats.entries[(\"my_custom_name_3\", \"GET\")].num_requests)\n\n def test_request_stats_query_variables(self):\n self.locust.client.get(\"/ultra_fast?query=1\")\n self.assertEqual(1, self.runner.stats.entries[(\"/ultra_fast?query=1\", \"GET\")].num_requests)\n\n def test_request_stats_put(self):\n self.locust.client.put(\"/put\")\n self.assertEqual(1, self.runner.stats.entries[(\"/put\", \"PUT\")].num_requests)\n\n def test_request_connection_error(self):\n class MyUser(HttpUser):\n host = \"http://localhost:1\"\n\n locust = MyUser(self.environment)\n response = locust.client.get(\"/\", timeout=0.1)\n self.assertEqual(response.status_code, 0)\n self.assertEqual(1, self.runner.stats.entries[(\"/\", \"GET\")].num_failures)\n self.assertEqual(1, self.runner.stats.entries[(\"/\", \"GET\")].num_requests)\n\n\nclass MyTaskSet(TaskSet):\n @task(75)\n def root_task(self):\n pass\n\n @task(25)\n class MySubTaskSet(TaskSet):\n @task\n def task1(self):\n pass\n\n @task\n def task2(self):\n pass\n\n\nclass TestInspectUser(unittest.TestCase):\n def test_get_task_ratio_relative(self):\n ratio = _get_task_ratio([MyTaskSet], False, 1.0)\n self.assertEqual(1.0, ratio[\"MyTaskSet\"][\"ratio\"])\n self.assertEqual(0.75, ratio[\"MyTaskSet\"][\"tasks\"][\"root_task\"][\"ratio\"])\n self.assertEqual(0.25, 
ratio[\"MyTaskSet\"][\"tasks\"][\"MySubTaskSet\"][\"ratio\"])\n self.assertEqual(0.5, ratio[\"MyTaskSet\"][\"tasks\"][\"MySubTaskSet\"][\"tasks\"][\"task1\"][\"ratio\"])\n self.assertEqual(0.5, ratio[\"MyTaskSet\"][\"tasks\"][\"MySubTaskSet\"][\"tasks\"][\"task2\"][\"ratio\"])\n\n def test_get_task_ratio_total(self):\n ratio = _get_task_ratio([MyTaskSet], True, 1.0)\n self.assertEqual(1.0, ratio[\"MyTaskSet\"][\"ratio\"])\n self.assertEqual(0.75, ratio[\"MyTaskSet\"][\"tasks\"][\"root_task\"][\"ratio\"])\n self.assertEqual(0.25, ratio[\"MyTaskSet\"][\"tasks\"][\"MySubTaskSet\"][\"ratio\"])\n self.assertEqual(0.125, ratio[\"MyTaskSet\"][\"tasks\"][\"MySubTaskSet\"][\"tasks\"][\"task1\"][\"ratio\"])\n self.assertEqual(0.125, ratio[\"MyTaskSet\"][\"tasks\"][\"MySubTaskSet\"][\"tasks\"][\"task2\"][\"ratio\"])\n\n# Path: locust/test/test_tags.py\nfrom locust import TaskSet, User, tag, task\nfrom locust.env import Environment\nfrom locust.user.task import filter_tasks_by_tags\n\nfrom .testcases import LocustTestCase\n\n\nclass TestTags(LocustTestCase):\n def test_tagging(self):\n @tag(\"tag1\")\n @task\n def tagged():\n pass\n\n self.assertIn(\"locust_tag_set\", dir(tagged))\n self.assertEqual({\"tag1\"}, tagged.locust_tag_set)\n\n @tag(\"tag2\", \"tag3\")\n @task\n def tagged_multiple_args():\n pass\n\n self.assertIn(\"locust_tag_set\", dir(tagged_multiple_args))\n self.assertEqual({\"tag2\", \"tag3\"}, tagged_multiple_args.locust_tag_set)\n\n @tag(\"tag4\")\n @tag(\"tag5\")\n @task\n def tagged_multiple_times():\n pass\n\n self.assertIn(\"locust_tag_set\", dir(tagged_multiple_times))\n self.assertEqual({\"tag4\", \"tag5\"}, tagged_multiple_times.locust_tag_set)\n\n def test_tagging_taskset(self):\n @tag(\"taskset\")\n @task\n class MyTaskSet(TaskSet):\n @task\n def tagged(self):\n pass\n\n @tag(\"task\")\n @task\n def tagged_again(self):\n pass\n\n @tag(\"taskset2\")\n @task\n class NestedTaskSet(TaskSet):\n @task\n def nested_task(self):\n pass\n\n # when tagging taskset, its tasks receive the tag\n self.assertIn(\"locust_tag_set\", dir(MyTaskSet.tagged))\n self.assertEqual({\"taskset\"}, MyTaskSet.tagged.locust_tag_set)\n\n # tagging inner task receives both\n self.assertIn(\"locust_tag_set\", dir(MyTaskSet.tagged_again))\n self.assertEqual({\"taskset\", \"task\"}, MyTaskSet.tagged_again.locust_tag_set)\n\n # when tagging nested taskset, its tasks receives both\n self.assertIn(\"locust_tag_set\", dir(MyTaskSet.NestedTaskSet.nested_task))\n self.assertEqual({\"taskset\", \"taskset2\"}, MyTaskSet.NestedTaskSet.nested_task.locust_tag_set)\n\n def test_tagging_without_args_fails(self):\n @task\n def dummy_task(self):\n pass\n\n # task is tagged without parens\n self.assertRaises(ValueError, lambda: tag(dummy_task))\n\n # task is tagged with empty parens\n self.assertRaises(ValueError, lambda: tag()(dummy_task))\n\n def test_including_tags(self):\n class MyTaskSet(TaskSet):\n @tag(\"include this\", \"other tag\")\n @task\n def included(self):\n pass\n\n @tag(\"dont include this\", \"other tag\")\n @task\n def not_included(self):\n pass\n\n @task\n def dont_include_this_either(self):\n pass\n\n self.assertListEqual(\n MyTaskSet.tasks, [MyTaskSet.included, MyTaskSet.not_included, MyTaskSet.dont_include_this_either]\n )\n\n filter_tasks_by_tags(MyTaskSet, tags={\"include this\"})\n self.assertListEqual(MyTaskSet.tasks, [MyTaskSet.included])\n\n def test_excluding_tags(self):\n class MyTaskSet(TaskSet):\n @tag(\"exclude this\", \"other tag\")\n @task\n def excluded(self):\n pass\n\n @tag(\"dont 
exclude this\", \"other tag\")\n @task\n def not_excluded(self):\n pass\n\n @task\n def dont_exclude_this_either(self):\n pass\n\n self.assertListEqual(\n MyTaskSet.tasks, [MyTaskSet.excluded, MyTaskSet.not_excluded, MyTaskSet.dont_exclude_this_either]\n )\n\n filter_tasks_by_tags(MyTaskSet, exclude_tags={\"exclude this\"})\n self.assertListEqual(MyTaskSet.tasks, [MyTaskSet.not_excluded, MyTaskSet.dont_exclude_this_either])\n\n def test_including_and_excluding(self):\n class MyTaskSet(TaskSet):\n @task\n def not_included_or_excluded(self):\n pass\n\n @tag(\"included\")\n @task\n def included(self):\n pass\n\n @tag(\"excluded\")\n @task\n def excluded(self):\n pass\n\n @tag(\"included\", \"excluded\")\n @task\n def included_and_excluded(self):\n pass\n\n filter_tasks_by_tags(MyTaskSet, tags={\"included\"}, exclude_tags={\"excluded\"})\n self.assertListEqual(MyTaskSet.tasks, [MyTaskSet.included])\n\n def test_including_tasksets(self):\n class MyTaskSet(TaskSet):\n @task\n class MixedNestedTaskSet(TaskSet):\n @tag(\"included\")\n @task\n def included(self):\n pass\n\n @task\n def not_included(self):\n pass\n\n @tag(\"included\")\n @task\n class TaggedNestedTaskSet(TaskSet):\n @task\n def included(self):\n pass\n\n @task\n class NormalNestedTaskSet(TaskSet):\n @task\n def not_included(self):\n pass\n\n filter_tasks_by_tags(MyTaskSet, tags={\"included\"})\n self.assertListEqual(MyTaskSet.tasks, [MyTaskSet.MixedNestedTaskSet, MyTaskSet.TaggedNestedTaskSet])\n self.assertListEqual(MyTaskSet.MixedNestedTaskSet.tasks, [MyTaskSet.MixedNestedTaskSet.included])\n\n def test_excluding_tasksets(self):\n class MyTaskSet(TaskSet):\n @task\n class MixedNestedTaskSet(TaskSet):\n @tag(\"excluded\")\n @task\n def excluded(self):\n pass\n\n @task\n def not_excluded(self):\n pass\n\n @task\n class ExcludedNestedTaskSet(TaskSet):\n @tag(\"excluded\")\n @task\n def excluded(self):\n pass\n\n @tag(\"excluded\")\n @task\n class TaggedNestedTaskSet(TaskSet):\n @task\n def excluded(self):\n pass\n\n @task\n class NormalNestedTaskSet(TaskSet):\n @task\n def not_excluded(self):\n pass\n\n filter_tasks_by_tags(MyTaskSet, exclude_tags={\"excluded\"})\n self.assertListEqual(MyTaskSet.tasks, [MyTaskSet.MixedNestedTaskSet, MyTaskSet.NormalNestedTaskSet])\n self.assertListEqual(MyTaskSet.MixedNestedTaskSet.tasks, [MyTaskSet.MixedNestedTaskSet.not_excluded])\n\n def test_including_tags_with_weights(self):\n class MyTaskSet(TaskSet):\n @tag(\"included\")\n @task(2)\n def include_twice(self):\n pass\n\n @tag(\"included\")\n @task(3)\n def include_3_times(self):\n pass\n\n @tag(\"dont include this\")\n @task(4)\n def dont_include_4_times(self):\n pass\n\n @task(5)\n def dont_include_5_times(self):\n pass\n\n self.assertListEqual(\n MyTaskSet.tasks,\n [\n MyTaskSet.include_twice,\n MyTaskSet.include_twice,\n MyTaskSet.include_3_times,\n MyTaskSet.include_3_times,\n MyTaskSet.include_3_times,\n MyTaskSet.dont_include_4_times,\n MyTaskSet.dont_include_4_times,\n MyTaskSet.dont_include_4_times,\n MyTaskSet.dont_include_4_times,\n MyTaskSet.dont_include_5_times,\n MyTaskSet.dont_include_5_times,\n MyTaskSet.dont_include_5_times,\n MyTaskSet.dont_include_5_times,\n MyTaskSet.dont_include_5_times,\n ],\n )\n\n filter_tasks_by_tags(MyTaskSet, tags={\"included\"})\n\n self.assertListEqual(\n MyTaskSet.tasks,\n [\n MyTaskSet.include_twice,\n MyTaskSet.include_twice,\n MyTaskSet.include_3_times,\n MyTaskSet.include_3_times,\n MyTaskSet.include_3_times,\n ],\n )\n\n def test_excluding_tags_with_weights(self):\n class MyTaskSet(TaskSet):\n 
@tag(\"dont exclude this\")\n @task(2)\n def dont_exclude_twice(self):\n pass\n\n @task(3)\n def dont_exclude_3_times(self):\n pass\n\n @tag(\"excluded\")\n @task(4)\n def exclude_4_times(self):\n pass\n\n @tag(\"excluded\")\n @task(5)\n def exclude_5_times(self):\n pass\n\n self.assertListEqual(\n MyTaskSet.tasks,\n [\n MyTaskSet.dont_exclude_twice,\n MyTaskSet.dont_exclude_twice,\n MyTaskSet.dont_exclude_3_times,\n MyTaskSet.dont_exclude_3_times,\n MyTaskSet.dont_exclude_3_times,\n MyTaskSet.exclude_4_times,\n MyTaskSet.exclude_4_times,\n MyTaskSet.exclude_4_times,\n MyTaskSet.exclude_4_times,\n MyTaskSet.exclude_5_times,\n MyTaskSet.exclude_5_times,\n MyTaskSet.exclude_5_times,\n MyTaskSet.exclude_5_times,\n MyTaskSet.exclude_5_times,\n ],\n )\n\n filter_tasks_by_tags(MyTaskSet, exclude_tags={\"excluded\"})\n\n self.assertListEqual(\n MyTaskSet.tasks,\n [\n MyTaskSet.dont_exclude_twice,\n MyTaskSet.dont_exclude_twice,\n MyTaskSet.dont_exclude_3_times,\n MyTaskSet.dont_exclude_3_times,\n MyTaskSet.dont_exclude_3_times,\n ],\n )\n\n def test_tagged_tasks_shared_across_tasksets(self):\n @tag(\"tagged\")\n def shared_task():\n pass\n\n def untagged_shared_task():\n pass\n\n @tag(\"tagged\")\n class SharedTaskSet(TaskSet):\n @task\n def inner_task(self):\n pass\n\n class IncludeTaskSet(TaskSet):\n tasks = [shared_task, untagged_shared_task, SharedTaskSet]\n\n class ExcludeTaskSet(TaskSet):\n tasks = [shared_task, untagged_shared_task, SharedTaskSet]\n\n filter_tasks_by_tags(IncludeTaskSet, tags={\"tagged\"})\n\n self.assertListEqual(IncludeTaskSet.tasks, [shared_task, SharedTaskSet])\n self.assertListEqual(IncludeTaskSet.tasks[1].tasks, [SharedTaskSet.inner_task])\n\n filter_tasks_by_tags(ExcludeTaskSet, exclude_tags={\"tagged\"})\n\n self.assertListEqual(ExcludeTaskSet.tasks, [untagged_shared_task])\n\n def test_include_tags_under_user(self):\n class MyUser(User):\n @tag(\"include this\")\n @task\n def included(self):\n pass\n\n @tag(\"dont include this\")\n @task\n def not_included(self):\n pass\n\n @task\n def dont_include_this_either(self):\n pass\n\n filter_tasks_by_tags(MyUser, tags={\"include this\"})\n\n self.assertListEqual(MyUser.tasks, [MyUser.included])\n\n def test_exclude_tags_under_user(self):\n class MyUser(User):\n @tag(\"exclude this\")\n @task\n def excluded(self):\n pass\n\n @tag(\"dont exclude this\")\n @task\n def not_excluded(self):\n pass\n\n @task\n def dont_exclude_this_either(self):\n pass\n\n filter_tasks_by_tags(MyUser, exclude_tags={\"exclude this\"})\n\n self.assertListEqual(MyUser.tasks, [MyUser.not_excluded, MyUser.dont_exclude_this_either])\n\n def test_env_include_tags(self):\n class MyTaskSet(TaskSet):\n @tag(\"include this\")\n @task\n def included(self):\n pass\n\n @tag(\"dont include this\")\n @task\n def not_included(self):\n pass\n\n @task\n def dont_include_this_either(self):\n pass\n\n class MyUser(User):\n tasks = [MyTaskSet]\n\n env = Environment(user_classes=[MyUser], tags=[\"include this\"])\n env._filter_tasks_by_tags()\n\n self.assertListEqual(MyUser.tasks, [MyTaskSet])\n self.assertListEqual(MyUser.tasks[0].tasks, [MyTaskSet.included])\n\n def test_env_exclude_tags(self):\n class MyTaskSet(User):\n @tag(\"exclude this\")\n @task\n def excluded(self):\n pass\n\n @tag(\"dont exclude this\")\n @task\n def not_excluded(self):\n pass\n\n @task\n def dont_exclude_this_either(self):\n pass\n\n class MyUser(User):\n tasks = [MyTaskSet]\n\n env = Environment(user_classes=[MyUser], exclude_tags=[\"exclude this\"])\n env._filter_tasks_by_tags()\n\n 
self.assertListEqual(MyUser.tasks, [MyTaskSet])\n self.assertListEqual(MyUser.tasks[0].tasks, [MyTaskSet.not_excluded, MyTaskSet.dont_exclude_this_either])\n\n# Path: locust/test/test_taskratio.py\nfrom locust.user import TaskSet, User, task\nfrom locust.user.inspectuser import _get_task_ratio, get_ratio\n\nimport unittest\n\n\nclass TestTaskRatio(unittest.TestCase):\n def test_task_ratio_command(self):\n class Tasks(TaskSet):\n @task\n def root_task1(self):\n pass\n\n @task\n class SubTasks(TaskSet):\n @task\n def task1(self):\n pass\n\n @task\n def task2(self):\n pass\n\n class MyUser(User):\n tasks = [Tasks]\n\n ratio_dict = _get_task_ratio(Tasks.tasks, True, 1.0)\n\n self.assertEqual(\n {\n \"SubTasks\": {\"tasks\": {\"task1\": {\"ratio\": 0.25}, \"task2\": {\"ratio\": 0.25}}, \"ratio\": 0.5},\n \"root_task1\": {\"ratio\": 0.5},\n },\n ratio_dict,\n )\n\n def test_task_ratio_command_with_locust_weight(self):\n class Tasks(TaskSet):\n @task(1)\n def task1(self):\n pass\n\n @task(3)\n def task3(self):\n pass\n\n class UnlikelyUser(User):\n weight = 1\n tasks = [Tasks]\n\n class MoreLikelyUser(User):\n weight = 3\n tasks = [Tasks]\n\n ratio_dict = get_ratio([UnlikelyUser, MoreLikelyUser], {\"UnlikelyUser\": 1, \"MoreLikelyUser\": 3}, True)\n\n self.assertDictEqual(\n {\n \"UnlikelyUser\": {\n \"ratio\": 0.25,\n \"tasks\": {\n \"Tasks\": {\n \"tasks\": {\n \"task1\": {\"ratio\": 0.25 * 0.25},\n \"task3\": {\"ratio\": 0.25 * 0.75},\n },\n \"ratio\": 0.25,\n }\n },\n },\n \"MoreLikelyUser\": {\n \"ratio\": 0.75,\n \"tasks\": {\n \"Tasks\": {\n \"tasks\": {\n \"task1\": {\"ratio\": 0.75 * 0.25},\n \"task3\": {\"ratio\": 0.75 * 0.75},\n },\n \"ratio\": 0.75,\n },\n },\n },\n },\n ratio_dict,\n )\n unlikely = ratio_dict[\"UnlikelyUser\"][\"tasks\"][\"Tasks\"][\"tasks\"]\n likely = ratio_dict[\"MoreLikelyUser\"][\"tasks\"][\"Tasks\"][\"tasks\"]\n assert (\n unlikely[\"task1\"][\"ratio\"]\n + unlikely[\"task3\"][\"ratio\"]\n + likely[\"task1\"][\"ratio\"]\n + likely[\"task3\"][\"ratio\"]\n == 1\n )\n\n# Path: locust/test/test_users.py\nfrom locust import HttpUser, User\nfrom locust.test.testcases import WebserverTestCase\n\nimport unittest\n\nfrom urllib3 import PoolManager\n\n\nclass TestUserClass(unittest.TestCase):\n class MyClassScopedUser(User):\n pass\n\n def test_fullname_module_scoped(self):\n self.assertEqual(MyModuleScopedUser.fullname(), \"locust.test.test_users.MyModuleScopedUser\")\n\n def test_fullname_class_scoped(self):\n...\n", "needle_token_start": 10468, "needle_token_end": 10560, "code_context_ntokens": 16108, "output": ["```python\ndef test_request_stats_named_endpoint_rename_request(self):\n with self.locust.client.rename_request(\"my_custom_name_3\"):\n self.locust.client.get(\"/ultra_fast\")\n self.assertEqual(1, self.runner.stats.entries[(\"my_custom_name_3\", \"GET\")].num_requests)\n```"]} +{"repo": "locustio/locust", "name": "is_valid_percentile", "language": "python", "path": "locust/main.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To verify if a given parameter is a valid percentile value, specifically ensuring it is a decimal between 0 and 1 (exclusive).\n2. **Input**: A single parameter, expected to be a numeric value or a string that can be converted to a float.\n3. **Output**: A boolean value; `True` if the input is a valid percentile, otherwise `False`.\n4. **Procedure**: The function first attempts to convert the input to a float. If successful, it checks whether the float value is greater than 0 and less than 1. 
If both conditions are met, it returns `True`. If the conversion fails or the conditions are not met, it returns `False`.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: locust/contrib/fasthttp.py\nfrom __future__ import annotations\n\nfrom locust.env import Environment\nfrom locust.exception import CatchResponseError, LocustError, ResponseError\nfrom locust.user import User\nfrom locust.util.deprecation import DeprecatedFastHttpLocustClass as FastHttpLocust\n\nimport json\nimport json as unshadowed_json # some methods take a named parameter called json\nimport re\nimport socket\nimport time\nimport traceback\nfrom base64 import b64encode\nfrom contextlib import contextmanager\nfrom http.cookiejar import CookieJar\nfrom json.decoder import JSONDecodeError\nfrom ssl import SSLError\nfrom typing import Any, Callable, Generator, cast\nfrom urllib.parse import urlparse, urlunparse\n\nimport gevent\nfrom charset_normalizer import detect\nfrom gevent.timeout import Timeout\nfrom geventhttpclient._parser import HTTPParseError\nfrom geventhttpclient.client import HTTPClientPool\nfrom geventhttpclient.header import Headers\nfrom geventhttpclient.response import HTTPConnectionClosed, HTTPSocketPoolResponse\nfrom geventhttpclient.useragent import CompatRequest, CompatResponse, ConnectionError, UserAgent\n\n# borrow requests's content-type header parsing\nfrom requests.utils import get_encoding_from_headers\n\n# Monkey patch geventhttpclient.useragent.CompatRequest so that Cookiejar works with Python >= 3.3\n# More info: https://github.com/requests/requests/pull/871\nCompatRequest.unverifiable = False\n\n# Workaround for AttributeError: 'CompatRequest' object has no attribute 'type' in Cookiejar\n# https://github.com/locustio/locust/issues/1138\n# Might allow secure cookies over non-secure connections but that is a minor concern in a load testing tool\nCompatRequest.type = \"https\"\n\n# Regexp for checking if an absolute URL was specified\nabsolute_http_url_regexp = re.compile(r\"^https?://\", re.I)\n\n# List of exceptions that can be raised by geventhttpclient when sending an HTTP request,\n# and that should result in a Locust failure\nFAILURE_EXCEPTIONS = (\n ConnectionError,\n ConnectionRefusedError,\n ConnectionResetError,\n socket.error,\n SSLError,\n Timeout,\n HTTPConnectionClosed,\n)\n\n\ndef _construct_basic_auth_str(username, password):\n \"\"\"Construct Authorization header value to be used in HTTP Basic Auth\"\"\"\n if isinstance(username, str):\n username = username.encode(\"latin1\")\n if isinstance(password, str):\n password = password.encode(\"latin1\")\n return \"Basic \" + b64encode(b\":\".join((username, password))).strip().decode(\"ascii\")\n\n\ndef insecure_ssl_context_factory():\n context = gevent.ssl.create_default_context()\n context.check_hostname = False\n context.verify_mode = gevent.ssl.CERT_NONE\n return context\n\n\nclass FastHttpSession:\n auth_header = None\n\n def __init__(\n self,\n environment: Environment,\n base_url: str,\n user: User | None,\n insecure=True,\n client_pool: HTTPClientPool | None = None,\n ssl_context_factory: Callable | None = None,\n **kwargs,\n ):\n self.environment = environment\n self.base_url = base_url\n self.cookiejar = CookieJar()\n self.user = user\n if not ssl_context_factory:\n if insecure:\n 
ssl_context_factory = insecure_ssl_context_factory\n else:\n ssl_context_factory = gevent.ssl.create_default_context\n self.client = LocustUserAgent(\n cookiejar=self.cookiejar,\n ssl_context_factory=ssl_context_factory,\n insecure=insecure,\n client_pool=client_pool,\n **kwargs,\n )\n\n # Check for basic authentication\n parsed_url = urlparse(self.base_url)\n if parsed_url.username and parsed_url.password:\n netloc = parsed_url.hostname or \"\"\n if parsed_url.port:\n netloc += \":%d\" % parsed_url.port\n\n # remove username and password from the base_url\n self.base_url = urlunparse(\n (parsed_url.scheme, netloc, parsed_url.path, parsed_url.params, parsed_url.query, parsed_url.fragment)\n )\n # store authentication header (we construct this by using _basic_auth_str() function from requests.auth)\n self.auth_header = _construct_basic_auth_str(parsed_url.username, parsed_url.password)\n\n def _build_url(self, path):\n \"\"\"prepend url with hostname unless it's already an absolute URL\"\"\"\n if absolute_http_url_regexp.match(path):\n return path\n else:\n return f\"{self.base_url}{path}\"\n\n def _send_request_safe_mode(self, method, url, **kwargs):\n \"\"\"\n Send an HTTP request, and catch any exception that might occur due to either\n connection problems, or invalid HTTP status codes\n \"\"\"\n try:\n return self.client.urlopen(url, method=method, **kwargs)\n except FAILURE_EXCEPTIONS as e:\n if hasattr(e, \"response\"):\n r = e.response\n else:\n req = self.client._make_request(\n url,\n method=method,\n headers=kwargs.get(\"headers\"),\n payload=kwargs.get(\"payload\"),\n params=kwargs.get(\"params\"),\n )\n r = ErrorResponse(url=url, request=req)\n r.error = e\n return r\n\n def request(\n self,\n method: str,\n url: str,\n name: str | None = None,\n data: str | dict | None = None,\n catch_response: bool = False,\n stream: bool = False,\n headers: dict | None = None,\n auth=None,\n json: dict | None = None,\n allow_redirects=True,\n context: dict = {},\n **kwargs,\n ) -> ResponseContextManager | FastResponse:\n \"\"\"\n Send and HTTP request\n Returns :py:class:`locust.contrib.fasthttp.FastResponse` object.\n\n :param method: method for the new :class:`Request` object.\n :param url: path that will be concatenated with the base host URL that has been specified.\n Can also be a full URL, in which case the full URL will be requested, and the base host\n is ignored.\n :param name: (optional) An argument that can be specified to use as label in Locust's\n statistics instead of the URL path. This can be used to group different URL's\n that are requested into a single entry in Locust's statistics.\n :param catch_response: (optional) Boolean argument that, if set, can be used to make a request\n return a context manager to work as argument to a with statement. This will allow the\n request to be marked as a fail based on the content of the response, even if the response\n code is ok (2xx). 
The opposite also works, one can use catch_response to catch a request\n and then mark it as successful even if the response code was not (i.e 500 or 404).\n :param data: (optional) String/bytes to send in the body of the request.\n :param json: (optional) Dictionary to send in the body of the request.\n Automatically sets Content-Type and Accept headers to \"application/json\".\n Only used if data is not set.\n :param headers: (optional) Dictionary of HTTP Headers to send with the request.\n :param auth: (optional) Auth (username, password) tuple to enable Basic HTTP Auth.\n :param stream: (optional) If set to true the response body will not be consumed immediately\n and can instead be consumed by accessing the stream attribute on the Response object.\n Another side effect of setting stream to True is that the time for downloading the response\n content will not be accounted for in the request time that is reported by Locust.\n \"\"\"\n # prepend url with hostname unless it's already an absolute URL\n built_url = self._build_url(url)\n\n start_time = time.time() # seconds since epoch\n\n if self.user:\n context = {**self.user.context(), **context}\n\n headers = headers or {}\n if auth:\n headers[\"Authorization\"] = _construct_basic_auth_str(auth[0], auth[1])\n elif self.auth_header:\n headers[\"Authorization\"] = self.auth_header\n if \"Accept-Encoding\" not in headers and \"accept-encoding\" not in headers:\n headers[\"Accept-Encoding\"] = \"gzip, deflate\"\n\n if not data and json is not None:\n data = unshadowed_json.dumps(json)\n if \"Content-Type\" not in headers and \"content-type\" not in headers:\n headers[\"Content-Type\"] = \"application/json\"\n if \"Accept\" not in headers and \"accept\" not in headers:\n headers[\"Accept\"] = \"application/json\"\n\n if not allow_redirects:\n old_redirect_response_codes = self.client.redirect_resonse_codes\n self.client.redirect_resonse_codes = []\n\n start_perf_counter = time.perf_counter()\n # send request, and catch any exceptions\n response = self._send_request_safe_mode(method, built_url, payload=data, headers=headers, **kwargs)\n request_meta = {\n \"request_type\": method,\n \"name\": name or url,\n \"context\": context,\n \"response\": response,\n \"exception\": None,\n \"start_time\": start_time,\n \"url\": built_url, # this is a small deviation from HttpSession, which gets the final (possibly redirected) URL\n }\n\n if not allow_redirects:\n self.client.redirect_resonse_codes = old_redirect_response_codes\n\n # get the length of the content, but if the argument stream is set to True, we take\n # the size from the content-length header, in order to not trigger fetching of the body\n if stream:\n request_meta[\"response_length\"] = int(response.headers.get(\"response_length\") or 0)\n else:\n try:\n request_meta[\"response_length\"] = len(response.content or \"\")\n except HTTPParseError as e:\n request_meta[\"response_time\"] = (time.perf_counter() - start_perf_counter) * 1000\n request_meta[\"response_length\"] = 0\n request_meta[\"exception\"] = e\n self.environment.events.request.fire(**request_meta)\n return response\n\n # Record the consumed time\n # Note: This is intentionally placed after we record the content_size above, since\n # we'll then trigger fetching of the body (unless stream=True)\n request_meta[\"response_time\"] = int((time.perf_counter() - start_perf_counter) * 1000)\n\n if catch_response:\n return ResponseContextManager(response, environment=self.environment, request_meta=request_meta)\n else:\n try:\n 
response.raise_for_status()\n except FAILURE_EXCEPTIONS as e:\n request_meta[\"exception\"] = e\n\n self.environment.events.request.fire(**request_meta)\n return response\n\n def delete(self, url, **kwargs):\n return self.request(\"DELETE\", url, **kwargs)\n\n def get(self, url, **kwargs):\n \"\"\"Sends a GET request\"\"\"\n return self.request(\"GET\", url, **kwargs)\n\n def head(self, url, **kwargs):\n \"\"\"Sends a HEAD request\"\"\"\n return self.request(\"HEAD\", url, **kwargs)\n\n def options(self, url, **kwargs):\n \"\"\"Sends a OPTIONS request\"\"\"\n return self.request(\"OPTIONS\", url, **kwargs)\n\n def patch(self, url, data=None, **kwargs):\n \"\"\"Sends a POST request\"\"\"\n return self.request(\"PATCH\", url, data=data, **kwargs)\n\n def post(self, url, data=None, **kwargs):\n \"\"\"Sends a POST request\"\"\"\n return self.request(\"POST\", url, data=data, **kwargs)\n\n def put(self, url, data=None, **kwargs):\n \"\"\"Sends a PUT request\"\"\"\n return self.request(\"PUT\", url, data=data, **kwargs)\n\n\nclass FastHttpUser(User):\n \"\"\"\n FastHttpUser provides the same API as HttpUser, but uses geventhttpclient instead of python-requests\n as its underlying client. It uses considerably less CPU on the load generator, and should work\n as a simple drop-in-replacement in most cases.\n \"\"\"\n\n # Below are various UserAgent settings. Change these in your subclass to alter FastHttpUser's behaviour.\n # It needs to be done before FastHttpUser is instantiated, changing them later will have no effect\n\n network_timeout: float = 60.0\n \"\"\"Parameter passed to FastHttpSession\"\"\"\n\n connection_timeout: float = 60.0\n \"\"\"Parameter passed to FastHttpSession\"\"\"\n\n max_redirects: int = 5\n \"\"\"Parameter passed to FastHttpSession. Default 5, meaning 4 redirects.\"\"\"\n\n max_retries: int = 1\n \"\"\"Parameter passed to FastHttpSession. Default 1, meaning zero retries.\"\"\"\n\n insecure: bool = True\n \"\"\"Parameter passed to FastHttpSession. Default True, meaning no SSL verification.\"\"\"\n\n default_headers: dict | None = None\n \"\"\"Parameter passed to FastHttpSession. Adds the listed headers to every request.\"\"\"\n\n concurrency: int = 10\n \"\"\"Parameter passed to FastHttpSession. Describes number of concurrent requests allowed by the FastHttpSession. Default 10.\n Note that setting this value has no effect when custom client_pool was given, and you need to spawn a your own gevent pool\n to use it (as Users only have one greenlet). See test_fasthttp.py / test_client_pool_concurrency for an example.\"\"\"\n\n client_pool: HTTPClientPool | None = None\n \"\"\"HTTP client pool to use. If not given, a new pool is created per single user.\"\"\"\n\n ssl_context_factory: Callable | None = None\n \"\"\"A callable that return a SSLContext for overriding the default context created by the FastHttpSession.\"\"\"\n\n abstract = True\n \"\"\"Dont register this as a User class that can be run by itself\"\"\"\n\n _callstack_regex = re.compile(r' File \"(\\/.[^\"]*)\", line (\\d*),(.*)')\n\n def __init__(self, environment):\n super().__init__(environment)\n if self.host is None:\n raise LocustError(\n \"You must specify the base host. 
Either in the host attribute in the User class, or on the command line using the --host option.\"\n )\n\n self.client: FastHttpSession = FastHttpSession(\n self.environment,\n base_url=self.host,\n network_timeout=self.network_timeout,\n connection_timeout=self.connection_timeout,\n max_redirects=self.max_redirects,\n max_retries=self.max_retries,\n insecure=self.insecure,\n concurrency=self.concurrency,\n user=self,\n client_pool=self.client_pool,\n ssl_context_factory=self.ssl_context_factory,\n headers=self.default_headers,\n )\n \"\"\"\n Instance of HttpSession that is created upon instantiation of User.\n The client support cookies, and therefore keeps the session between HTTP requests.\n \"\"\"\n\n @contextmanager\n def rest(\n self, method, url, headers: dict | None = None, **kwargs\n ) -> Generator[RestResponseContextManager, None, None]:\n \"\"\"\n A wrapper for self.client.request that:\n\n * Parses the JSON response to a dict called ``js`` in the response object. Marks the request as failed if the response was not valid JSON.\n * Defaults ``Content-Type`` and ``Accept`` headers to ``application/json``\n * Sets ``catch_response=True`` (so always use a :ref:`with-block `)\n * Catches any unhandled exceptions thrown inside your with-block, marking the sample as failed (instead of exiting the task immediately without even firing the request event)\n \"\"\"\n headers = headers or {}\n if not (\"Content-Type\" in headers or \"content-type\" in headers):\n headers[\"Content-Type\"] = \"application/json\"\n if not (\"Accept\" in headers or \"accept\" in headers):\n headers[\"Accept\"] = \"application/json\"\n with self.client.request(method, url, catch_response=True, headers=headers, **kwargs) as r:\n resp = cast(RestResponseContextManager, r)\n resp.js = None # type: ignore\n if resp.content is None:\n resp.failure(str(resp.error))\n elif resp.text:\n try:\n resp.js = resp.json()\n except JSONDecodeError as e:\n resp.failure(\n f\"Could not parse response as JSON. {resp.text[:250]}, response code {resp.status_code}, error {e}\"\n )\n try:\n yield resp\n except AssertionError as e:\n if e.args:\n if e.args[0].endswith(\",\"):\n short_resp = resp.text[:200] if resp.text else resp.text\n resp.failure(f\"{e.args[0][:-1]}, response was {short_resp}\")\n else:\n resp.failure(e.args[0])\n else:\n resp.failure(\"Assertion failed\")\n\n except Exception as e:\n error_lines = []\n for l in traceback.format_exc().split(\"\\n\"):\n m = self._callstack_regex.match(l)\n if m:\n filename = re.sub(r\"/(home|Users/\\w*)/\", \"~/\", m.group(1))\n error_lines.append(filename + \":\" + m.group(2) + m.group(3))\n short_resp = resp.text[:200] if resp.text else resp.text\n resp.failure(f\"{e.__class__.__name__}: {e} at {', '.join(error_lines)}. 
Response was {short_resp}\")\n\n @contextmanager\n def rest_(self, method, url, name=None, **kwargs) -> Generator[RestResponseContextManager, None, None]:\n \"\"\"\n Some REST api:s use a timestamp as part of their query string (mainly to break through caches).\n This is a convenience method for that, appending a _= parameter automatically\n \"\"\"\n separator = \"&\" if \"?\" in url else \"?\"\n if name is None:\n name = url + separator + \"_=...\"\n with self.rest(method, f\"{url}{separator}_={int(time.time()*1000)}\", name=name, **kwargs) as resp:\n yield resp\n\n\nclass FastRequest(CompatRequest):\n payload: str | None = None\n\n @property\n def body(self) -> str | None:\n return self.payload\n\n\nclass FastResponse(CompatResponse):\n headers: Headers | None = None\n \"\"\"Dict like object containing the response headers\"\"\"\n\n _response: HTTPSocketPoolResponse | None = None\n\n encoding: str | None = None\n \"\"\"In some cases setting the encoding explicitly is needed. If so, do it before calling .text\"\"\"\n\n request: FastRequest | None = None\n\n def __init__(\n self,\n ghc_response: HTTPSocketPoolResponse,\n request: FastRequest | None = None,\n sent_request: str | None = None,\n ):\n super().__init__(ghc_response, request, sent_request)\n\n self.request = request\n\n @property\n def text(self) -> str | None:\n \"\"\"\n Returns the text content of the response as a decoded string\n \"\"\"\n if self.content is None:\n return None\n if self.encoding is None:\n if self.headers is None:\n # No information, try to detect\n self.encoding = detect(self.content)[\"encoding\"]\n else:\n self.encoding = get_encoding_from_headers(self.headers)\n # No information, try to detect\n if not self.encoding:\n self.encoding = detect(self.content)[\"encoding\"]\n if self.encoding is None:\n return None\n return str(self.content, str(self.encoding), errors=\"replace\")\n\n @property\n def url(self) -> str | None:\n \"\"\"\n Get \"response\" URL, which is the same as the request URL. This is a small deviation from HttpSession, which gets the final (possibly redirected) URL.\n \"\"\"\n if self.request is not None:\n return self.request.url\n\n return None\n\n def json(self) -> dict:\n \"\"\"\n Parses the response as json and returns a dict\n \"\"\"\n return json.loads(self.text) # type: ignore\n\n def raise_for_status(self):\n \"\"\"Raise any connection errors that occurred during the request\"\"\"\n if hasattr(self, \"error\") and self.error:\n raise self.error\n\n @property\n def status_code(self) -> int:\n \"\"\"\n We override status_code in order to return None if no valid response was\n returned. E.g. in the case of connection errors\n \"\"\"\n return self._response.get_code() if self._response is not None else 0\n\n @property\n def ok(self):\n \"\"\"Returns True if :attr:`status_code` is less than 400, False if not.\"\"\"\n return self.status_code < 400\n\n def _content(self):\n if self.headers is None:\n return None\n return super()._content()\n\n def success(self):\n raise LocustError(\n \"If you want to change the state of the request, you must pass catch_response=True. See http://docs.locust.io/en/stable/writing-a-locustfile.html#validating-responses\"\n )\n\n def failure(self):\n raise LocustError(\n \"If you want to change the state of the request, you must pass catch_response=True. 
See http://docs.locust.io/en/stable/writing-a-locustfile.html#validating-responses\"\n )\n\n\nclass ErrorResponse:\n \"\"\"\n This is used as a dummy response object when geventhttpclient raises an error\n that doesn't have a real Response object attached. E.g. a socket error or similar\n \"\"\"\n\n headers: Headers | None = None\n content = None\n status_code = 0\n error: Exception | None = None\n text: str | None = None\n request: CompatRequest\n\n def __init__(self, url: str, request: CompatRequest):\n self.url = url\n self.request = request\n\n def raise_for_status(self):\n raise self.error\n\n\nclass LocustUserAgent(UserAgent):\n response_type = FastResponse\n request_type = FastRequest\n valid_response_codes = frozenset([200, 201, 202, 203, 204, 205, 206, 207, 208, 226, 301, 302, 303, 304, 307])\n\n def __init__(self, client_pool: HTTPClientPool | None = None, **kwargs):\n super().__init__(**kwargs)\n\n if client_pool is not None:\n self.clientpool = client_pool\n\n def _urlopen(self, request):\n \"\"\"Override _urlopen() in order to make it use the response_type attribute\"\"\"\n client = self.clientpool.get_client(request.url_split)\n resp = client.request(\n request.method, request.url_split.request_uri, body=request.payload, headers=request.headers\n )\n return self.response_type(resp, request=request, sent_request=resp._sent_request)\n\n\n...\n# Path: locust/debug.py\nfrom __future__ import annotations\n\nimport locust\nimport locust.log\nfrom locust import argument_parser\nfrom locust.env import Environment\nfrom locust.exception import CatchResponseError, RescheduleTask\n\nimport inspect\nimport os\nfrom datetime import datetime, timezone\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n from locust import User\n\n\ndef _print_t(s):\n \"\"\"\n Print something with a tab instead of newline at the end\n \"\"\"\n print(str(s), end=\"\\t\")\n\n\nclass PrintListener:\n \"\"\"\n Print every response (useful when debugging a single locust)\n \"\"\"\n\n def __init__(\n self,\n env: Environment,\n include_length=False,\n include_time=False,\n include_context=False,\n include_payload=False,\n ):\n env.events.request.add_listener(self.on_request)\n\n self.include_length = \"length\\t\" if include_length else \"\"\n self.include_time = \"time \\t\" if include_time else \"\"\n self.include_context = \"context\\t\" if include_context else \"\"\n self.include_payload = \"payload\\t\" if include_payload else \"\"\n\n print(\n f\"\\n{self.include_time}type\\t{'name'.ljust(50)}\\tresp_ms\\t{self.include_length}exception\\t{self.include_context}\\t{self.include_payload}\"\n )\n\n def on_request(\n self,\n request_type,\n name,\n response_time,\n response_length,\n exception,\n context: dict,\n start_time=None,\n response=None,\n **_kwargs,\n ):\n if exception:\n if isinstance(exception, RescheduleTask):\n pass\n if isinstance(exception, CatchResponseError):\n e = str(exception)\n else:\n try:\n e = repr(exception)\n except AttributeError:\n e = f\"{exception.__class__} (and it has no string representation)\"\n errortext = e[:500].replace(\"\\n\", \" \")\n else:\n errortext = \"\"\n\n if response_time is None:\n response_time = -1\n n = name.ljust(30) if name else \"\"\n\n if self.include_time:\n if start_time:\n _print_t(datetime.fromtimestamp(start_time, tz=timezone.utc))\n else:\n _print_t(datetime.now())\n\n _print_t(request_type)\n _print_t(n.ljust(50))\n _print_t(str(round(response_time)).ljust(7))\n\n if self.include_length:\n _print_t(response_length)\n\n 
_print_t(errortext.ljust(9))\n\n if self.include_context:\n _print_t(context or \"\")\n\n if self.include_payload:\n _print_t(response._request.payload)\n\n print()\n\n\n_env: Environment | None = None # minimal Environment for debugging\n\n\ndef run_single_user(\n user_class: type[User],\n include_length=False,\n include_time=False,\n include_context=False,\n include_payload=False,\n loglevel: str | None = \"WARNING\",\n):\n \"\"\"\n Runs a single User. Useful when you want to run a debugger.\n\n It creates in a new locust :py:attr:`Environment ` and triggers any ``init`` or ``test_start`` :ref:`events ` as normal.\n\n It does **not** trigger ``test_stop`` or ``quit`` when you quit the debugger.\n\n It prints some info about every request to stdout, and you can get additional info using the `include_*` flags\n\n It also initiates logging on WARNING level (not INFO, because it could interfere with the printing of requests),\n but you can change that by passing a log level (or disabling logging entirely by passing None)\n \"\"\"\n global _env\n\n if loglevel:\n locust.log.setup_logging(loglevel)\n\n if not _env:\n options = argument_parser.parse_options()\n\n # in case your test goes looking for the file name of your locustfile\n frame = inspect.stack()[1]\n locustfile = os.path.basename(frame[0].f_code.co_filename)\n options.locustfile = locustfile\n\n _env = Environment(events=locust.events, locustfile=locustfile, host=options.host, parsed_options=options)\n\n # log requests to stdout\n PrintListener(\n _env,\n include_length=include_length,\n include_time=include_time,\n include_context=include_context,\n include_payload=include_payload,\n )\n # fire various events (quit and test_stop will never get called, sorry about that)\n _env.events.init.fire(environment=_env, runner=None, web_ui=None)\n # uncaught events will be suppressed, so check if that happened\n if locust.log.unhandled_greenlet_exception:\n raise Exception(\"Unhandled exception in init\")\n\n # do the things that the Runner usually does\n _env.user_classes = [user_class]\n _env._filter_tasks_by_tags()\n _env.events.test_start.fire(environment=_env)\n if _env.host:\n user_class.host = _env.host\n\n # create a single user\n user = user_class(_env)\n setattr(_env, \"single_user_instance\", user) # if you happen to need access to this from the Environment instance\n user.run()\n\n# Path: locust/user/sequential_taskset.py\nfrom locust.exception import LocustError\n\nimport logging\n\nfrom .task import TaskSet, TaskSetMeta\n\n\nclass SequentialTaskSetMeta(TaskSetMeta):\n \"\"\"\n Meta class for SequentialTaskSet. 
It's used to allow SequentialTaskSet classes to specify\n task execution in both a list as the tasks attribute or using the @task decorator\n\n We use the fact that class_dict order is the order of declaration in Python 3.6\n (See https://www.python.org/dev/peps/pep-0520/)\n \"\"\"\n\n def __new__(mcs, classname, bases, class_dict):\n new_tasks = []\n for base in bases:\n # first get tasks from base classes\n if hasattr(base, \"tasks\") and base.tasks:\n new_tasks += base.tasks\n for key, value in class_dict.items():\n if key == \"tasks\":\n # we want to insert tasks from the tasks attribute at the point of it's declaration\n # compared to methods declared with @task\n if isinstance(value, list):\n new_tasks.extend(value)\n else:\n raise ValueError(\"On SequentialTaskSet the task attribute can only be set to a list\")\n\n if \"locust_task_weight\" in dir(value):\n # method decorated with @task\n for _ in range(value.locust_task_weight):\n new_tasks.append(value)\n\n class_dict[\"tasks\"] = new_tasks\n return type.__new__(mcs, classname, bases, class_dict)\n\n\nclass SequentialTaskSet(TaskSet, metaclass=SequentialTaskSetMeta):\n \"\"\"\n Class defining a sequence of tasks that a User will execute.\n\n Works like TaskSet, but task weight is ignored, and all tasks are executed in order. Tasks can\n either be specified by setting the *tasks* attribute to a list of tasks, or by declaring tasks\n as methods using the @task decorator. The order of declaration decides the order of execution.\n\n It's possible to combine a task list in the *tasks* attribute, with some tasks declared using\n the @task decorator. The order of declaration is respected also in that case.\n \"\"\"\n\n def __init__(self, *args, **kwargs):\n super().__init__(*args, **kwargs)\n self._task_index = 0\n\n def get_next_task(self):\n if not self.tasks:\n raise LocustError(\n \"No tasks defined. 
Use the @task decorator or set the 'tasks' attribute of the SequentialTaskSet\"\n )\n task = self.tasks[self._task_index % len(self.tasks)]\n self._task_index += 1\n return task\n\n# Path: locust/__init__.py\nimport os\n\nif os.getenv(\"LOCUST_PLAYWRIGHT\", None):\n print(\"LOCUST_PLAYWRIGHT setting is no longer needed (because locust-plugins no longer installs trio)\")\n print(\"Uninstall trio package and remove the setting.\")\n try:\n # preserve backwards compatibility for now\n import trio\n except ModuleNotFoundError:\n # dont show a massive callstack if trio is not installed\n os._exit(1)\n\nfrom gevent import monkey\n\nmonkey.patch_all()\n\nfrom ._version import version as __version__\nfrom .contrib.fasthttp import FastHttpUser\nfrom .debug import run_single_user\nfrom .event import Events\nfrom .shape import LoadTestShape\nfrom .user import wait_time\nfrom .user.sequential_taskset import SequentialTaskSet\nfrom .user.task import TaskSet, tag, task\nfrom .user.users import HttpUser, User\nfrom .user.wait_time import between, constant, constant_pacing, constant_throughput\n\nevents = Events()\n\n__all__ = (\n \"SequentialTaskSet\",\n \"wait_time\",\n \"task\",\n \"tag\",\n \"TaskSet\",\n \"HttpUser\",\n \"FastHttpUser\",\n \"User\",\n \"between\",\n \"constant\",\n \"constant_pacing\",\n \"constant_throughput\",\n \"events\",\n \"LoadTestShape\",\n \"run_single_user\",\n)\n\n# Used for raising a DeprecationWarning if old Locust/HttpLocust is used\nfrom .util.deprecation import DeprecatedHttpLocustClass as HttpLocust\nfrom .util.deprecation import DeprecatedLocustClass as Locust\n\n# Path: locust/input_events.py\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport sys\nfrom typing import Callable\n\nimport gevent\n\nif os.name == \"nt\":\n import pywintypes\n from win32api import STD_INPUT_HANDLE\n from win32console import (\n ENABLE_ECHO_INPUT,\n ENABLE_LINE_INPUT,\n ENABLE_PROCESSED_INPUT,\n KEY_EVENT,\n GetStdHandle,\n )\nelse:\n import select\n import termios\n import tty\n\n\nclass InitError(Exception):\n pass\n\n\nclass UnixKeyPoller:\n def __init__(self):\n if sys.stdin.isatty():\n self.stdin = sys.stdin.fileno()\n self.tattr = termios.tcgetattr(self.stdin)\n tty.setcbreak(self.stdin, termios.TCSANOW)\n else:\n raise InitError(\"Terminal was not a tty. Keyboard input disabled\")\n\n def cleanup(self):\n termios.tcsetattr(self.stdin, termios.TCSANOW, self.tattr)\n\n def poll(_self):\n dr, dw, de = select.select([sys.stdin], [], [], 0)\n if not dr == []:\n return sys.stdin.read(1)\n return None\n\n\nclass WindowsKeyPoller:\n def __init__(self):\n if sys.stdin.isatty():\n try:\n self.read_handle = GetStdHandle(STD_INPUT_HANDLE)\n self.read_handle.SetConsoleMode(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT | ENABLE_PROCESSED_INPUT)\n self.cur_event_length = 0\n self.cur_keys_length = 0\n self.captured_chars = []\n except pywintypes.error:\n raise InitError(\"Terminal says its a tty but we couldnt enable line input. Keyboard input disabled.\")\n else:\n raise InitError(\"Terminal was not a tty. 
Keyboard input disabled\")\n\n def cleanup(self):\n pass\n\n def poll(self):\n if self.captured_chars:\n return self.captured_chars.pop(0)\n\n events_peek = self.read_handle.PeekConsoleInput(10000)\n\n if not events_peek:\n return None\n\n if not len(events_peek) == self.cur_event_length:\n for cur_event in events_peek[self.cur_event_length :]:\n if cur_event.EventType == KEY_EVENT:\n if ord(cur_event.Char) and cur_event.KeyDown:\n cur_char = str(cur_event.Char)\n self.captured_chars.append(cur_char)\n\n self.cur_event_length = len(events_peek)\n\n if self.captured_chars:\n return self.captured_chars.pop(0)\n else:\n return None\n\n\ndef get_poller():\n if os.name == \"nt\":\n return WindowsKeyPoller()\n else:\n return UnixKeyPoller()\n\n\ndef input_listener(key_to_func_map: dict[str, Callable]):\n def input_listener_func():\n try:\n poller = get_poller()\n except InitError as e:\n logging.debug(e)\n return\n\n try:\n while True:\n input = poller.poll()\n if input:\n for key in key_to_func_map:\n if input == key:\n key_to_func_map[key]()\n else:\n gevent.sleep(0.2)\n except Exception as e:\n logging.warning(f\"Exception in keyboard input poller: {e}\")\n finally:\n poller.cleanup()\n\n return input_listener_func\n\n# Path: locust/util/load_locustfile.py\nfrom __future__ import annotations\n\nimport importlib\nimport importlib.util\nimport inspect\nimport os\nimport sys\n\nfrom ..shape import LoadTestShape\nfrom ..user import User\n\n\ndef is_user_class(item):\n \"\"\"\n Check if a variable is a runnable (non-abstract) User class\n \"\"\"\n return bool(inspect.isclass(item) and issubclass(item, User) and item.abstract is False)\n\n\ndef is_shape_class(item):\n \"\"\"\n Check if a class is a LoadTestShape\n \"\"\"\n return bool(inspect.isclass(item) and issubclass(item, LoadTestShape) and not getattr(item, \"abstract\", True))\n\n\ndef load_locustfile(path) -> tuple[str | None, dict[str, User], list[LoadTestShape]]:\n \"\"\"\n Import given locustfile path and return (docstring, callables).\n\n Specifically, the locustfile's ``__doc__`` attribute (a string) and a\n dictionary of ``{'name': callable}`` containing all callables which pass\n the \"is a Locust\" test.\n \"\"\"\n\n # Start with making sure the current working dir is in the sys.path\n sys.path.insert(0, os.getcwd())\n # Get directory and locustfile name\n directory, locustfile = os.path.split(path)\n # If the directory isn't in the PYTHONPATH, add it so our import will work\n added_to_path = False\n index = None\n if directory not in sys.path:\n sys.path.insert(0, directory)\n added_to_path = True\n # If the directory IS in the PYTHONPATH, move it to the front temporarily,\n # otherwise other locustfiles -- like Locusts's own -- may scoop the intended\n # one.\n else:\n i = sys.path.index(directory)\n if i != 0:\n # Store index for later restoration\n index = i\n # Add to front, then remove from original position\n sys.path.insert(0, directory)\n del sys.path[i + 1]\n\n # Perform the import\n module_name = os.path.splitext(locustfile)[0]\n if module_name == \"locust\":\n module_name = \"locustfile\" # Avoid conflict with locust package\n loader = importlib.machinery.SourceFileLoader(module_name, path)\n spec = importlib.util.spec_from_file_location(module_name, path, loader=loader)\n if spec is None:\n sys.stderr.write(f\"Unable to get module spec for {module_name} in {path}\")\n sys.exit(1)\n\n imported = importlib.util.module_from_spec(spec)\n sys.modules[imported.__name__] = imported\n loader.exec_module(imported)\n\n # Remove 
directory from path if we added it ourselves (just to be neat)\n if added_to_path:\n del sys.path[0]\n # Put back in original index if we moved it\n if index is not None:\n sys.path.insert(index + 1, directory)\n del sys.path[0]\n # Return our two-tuple\n user_classes = {name: value for name, value in vars(imported).items() if is_user_class(value)}\n\n # Find shape class, if any, return it\n shape_classes = [value() for value in vars(imported).values() if is_shape_class(value)]\n\n return imported.__doc__, user_classes, shape_classes\n\n# Path: locust/main.py\nfrom __future__ import annotations\n\nimport locust\n\nimport atexit\nimport errno\nimport gc\nimport inspect\nimport json\nimport logging\nimport os\nimport signal\nimport sys\nimport time\nimport traceback\n\nimport gevent\n\nfrom . import log, stats\nfrom .argument_parser import parse_locustfile_option, parse_options\nfrom .env import Environment\nfrom .html import get_html_report\nfrom .input_events import input_listener\nfrom .log import greenlet_exception_logger, setup_logging\nfrom .stats import (\n StatsCSV,\n StatsCSVFileWriter,\n print_error_report,\n print_percentile_stats,\n print_stats,\n print_stats_json,\n stats_history,\n stats_printer,\n)\nfrom .user.inspectuser import print_task_ratio, print_task_ratio_json\nfrom .util.load_locustfile import load_locustfile\nfrom .util.timespan import parse_timespan\n\ntry:\n # import locust_plugins if it is installed, to allow it to register custom arguments etc\n import locust_plugins # pyright: ignore[reportMissingImports]\nexcept ModuleNotFoundError:\n pass\n\nversion = locust.__version__\n\n# Options to ignore when using a custom shape class without `use_common_options=True`\n# See: https://docs.locust.io/en/stable/custom-load-shape.html#use-common-options\nCOMMON_OPTIONS = {\n \"num_users\": \"users\",\n \"spawn_rate\": \"spawn-rate\",\n \"run_time\": \"run-time\",\n}\n\n\ndef create_environment(\n user_classes,\n options,\n events=None,\n shape_class=None,\n locustfile=None,\n available_user_classes=None,\n available_shape_classes=None,\n available_user_tasks=None,\n):\n \"\"\"\n Create an Environment instance from options\n \"\"\"\n return Environment(\n locustfile=locustfile,\n user_classes=user_classes,\n shape_class=shape_class,\n events=events,\n host=options.host,\n reset_stats=options.reset_stats,\n parsed_options=options,\n available_user_classes=available_user_classes,\n available_shape_classes=available_shape_classes,\n available_user_tasks=available_user_tasks,\n )\n\n\ndef main():\n # find specified locustfile(s) and make sure it exists, using a very simplified\n # command line parser that is only used to parse the -f option.\n locustfiles = parse_locustfile_option()\n locustfiles_length = len(locustfiles)\n\n # Grabbing the Locustfile if only one was provided. 
Otherwise, allowing users to select the locustfile in the UI\n # If --headless or --autostart and multiple locustfiles, all provided UserClasses will be ran\n locustfile = locustfiles[0] if locustfiles_length == 1 else None\n\n # Importing Locustfile(s) - setting available UserClasses and ShapeClasses to choose from in UI\n user_classes: dict[str, locust.User] = {}\n available_user_classes = {}\n available_shape_classes = {}\n available_user_tasks = {}\n shape_class = None\n for _locustfile in locustfiles:\n docstring, _user_classes, shape_classes = load_locustfile(_locustfile)\n\n # Setting Available Shape Classes\n if shape_classes:\n shape_class = shape_classes[0]\n for shape_class in shape_classes:\n shape_class_name = type(shape_class).__name__\n if shape_class_name in available_shape_classes.keys():\n sys.stderr.write(f\"Duplicate shape classes: {shape_class_name}\\n\")\n sys.exit(1)\n\n available_shape_classes[shape_class_name] = shape_class\n\n # Setting Available User Classes\n for key, value in _user_classes.items():\n if key in available_user_classes.keys():\n previous_path = inspect.getfile(user_classes[key])\n new_path = inspect.getfile(value)\n if previous_path == new_path:\n # The same User class was defined in two locustfiles but one probably imported the other, so we just ignore it\n continue\n else:\n sys.stderr.write(\n f\"Duplicate user class names: {key} is defined in both {previous_path} and {new_path}\\n\"\n )\n sys.exit(1)\n\n user_classes[key] = value\n available_user_classes[key] = value\n available_user_tasks[key] = value.tasks or None\n\n if len(stats.PERCENTILES_TO_CHART) != 2:\n logging.error(\"stats.PERCENTILES_TO_CHART parameter should be 2 parameters \\n\")\n sys.exit(1)\n\n if len(stats.MODERN_UI_PERCENTILES_TO_CHART) > 6:\n logging.error(\"stats.MODERN_UI_PERCENTILES_TO_CHART parameter should be a maximum of 6 parameters \\n\")\n sys.exit(1)\n\n \ndef is_valid_percentile(parameter):\n try:\n if 0 < float(parameter) < 1:\n return True\n return False\n except ValueError:\n return False\n\n for percentile in stats.PERCENTILES_TO_CHART:\n if not is_valid_percentile(percentile):\n logging.error(\n \"stats.PERCENTILES_TO_CHART parameter need to be float and value between. 0 < percentile < 1 Eg 0.95\\n\"\n )\n sys.exit(1)\n\n for percentile in stats.PERCENTILES_TO_STATISTICS:\n if not is_valid_percentile(percentile):\n logging.error(\n \"stats.PERCENTILES_TO_STATISTICS parameter need to be float and value between. 0 < percentile < 1 Eg 0.95\\n\"\n )\n sys.exit(1)\n\n # parse all command line options\n options = parse_options()\n\n if options.headful:\n options.headless = False\n\n if options.slave or options.expect_slaves:\n sys.stderr.write(\"The --slave/--expect-slaves parameters have been renamed --worker/--expect-workers\\n\")\n sys.exit(1)\n\n if options.web_auth:\n sys.stderr.write(\n \"The --web-auth parameters has been replaced with --web-login. 
See https://docs.locust.io/en/stable/extending-locust.html#authentication for details\\n\"\n )\n sys.exit(1)\n\n if options.autoquit != -1 and not options.autostart:\n sys.stderr.write(\"--autoquit is only meaningful in combination with --autostart\\n\")\n sys.exit(1)\n\n if options.hatch_rate:\n sys.stderr.write(\"[DEPRECATED] The --hatch-rate parameter has been renamed --spawn-rate\\n\")\n options.spawn_rate = options.hatch_rate\n\n # setup logging\n if not options.skip_log_setup:\n if options.loglevel.upper() in [\"DEBUG\", \"INFO\", \"WARNING\", \"ERROR\", \"CRITICAL\"]:\n setup_logging(options.loglevel, options.logfile)\n else:\n sys.stderr.write(\"Invalid --loglevel. Valid values are: DEBUG/INFO/WARNING/ERROR/CRITICAL\\n\")\n sys.exit(1)\n\n children = []\n\n if options.processes:\n if os.name == \"nt\":\n sys.stderr.write(\"--processes is not supported in Windows (except in WSL)\\n\")\n sys.exit(1)\n if options.processes == -1:\n options.processes = os.cpu_count()\n if not options.processes:\n sys.stderr.write(\"--processes failed to detect number of cpus!?\\n\")\n sys.exit(1)\n elif options.processes < -1:\n sys.stderr.write(f\"Invalid --processes count {options.processes}\\n\")\n sys.exit(1)\n elif options.master:\n sys.stderr.write(\n \"--master cannot be combined with --processes. Remove --master, as it is implicit as long as --worker is not set.\\n\"\n )\n sys.exit(1)\n # Optimize copy-on-write-behavior to save some memory (aprx 26MB -> 15MB rss) in child processes\n gc.collect() # avoid freezing garbage\n gc.freeze() # move all objects to perm gen so ref counts dont get updated\n for _ in range(options.processes):\n child_pid = gevent.fork()\n if child_pid:\n children.append(child_pid)\n logging.debug(f\"Started child worker with pid #{child_pid}\")\n else:\n # child is always a worker, even when it wasnt set on command line\n options.worker = True\n # remove options that dont make sense on worker\n options.run_time = None\n options.autostart = None\n break\n else:\n # we're in the parent process\n if options.worker:\n # ignore the first sigint in parent, and wait for the children to handle sigint\n def sigint_handler(_signal, _frame):\n if getattr(sigint_handler, \"has_run\", False):\n # if parent gets repeated sigint, we kill the children hard\n for child_pid in children:\n try:\n logging.debug(f\"Sending SIGKILL to child with pid {child_pid}\")\n os.kill(child_pid, signal.SIGKILL)\n except ProcessLookupError:\n pass # process already dead\n except Exception:\n logging.error(traceback.format_exc())\n sys.exit(1)\n sigint_handler.has_run = True\n\n signal.signal(signal.SIGINT, sigint_handler)\n exit_code = 0\n # nothing more to do, just wait for the children to exit\n for child_pid in children:\n _, child_status = os.waitpid(child_pid, 0)\n try:\n if sys.version_info >= (3, 9):\n child_exit_code = os.waitstatus_to_exitcode(child_status)\n exit_code = max(exit_code, child_exit_code)\n except AttributeError:\n pass # dammit python 3.8...\n sys.exit(exit_code)\n else:\n options.master = True\n options.expect_workers = options.processes\n\n def kill_workers(children):\n exit_code = 0\n start_time = time.time()\n # give children some time to finish up (in case they had an error parsing arguments etc)\n for child_pid in children[:]:\n while time.time() < start_time + 3:\n try:\n _, child_status = os.waitpid(child_pid, os.WNOHANG)\n children.remove(child_pid)\n try:\n if sys.version_info >= (3, 9):\n child_exit_code = os.waitstatus_to_exitcode(child_status)\n exit_code = 
max(exit_code, child_exit_code)\n except AttributeError:\n pass # dammit python 3.8...\n except OSError as e:\n if e.errno == errno.EINTR:\n time.sleep(0.1)\n else:\n logging.error(traceback.format_exc())\n else:\n break\n for child_pid in children:\n try:\n logging.debug(f\"Sending SIGINT to child with pid {child_pid}\")\n os.kill(child_pid, signal.SIGINT)\n except ProcessLookupError:\n pass # never mind, process was already dead\n for child_pid in children:\n _, child_status = os.waitpid(child_pid, 0)\n try:\n if sys.version_info >= (3, 9):\n child_exit_code = os.waitstatus_to_exitcode(child_status)\n exit_code = max(exit_code, child_exit_code)\n except AttributeError:\n pass # dammit python 3.8...\n if exit_code > 1:\n logging.error(f\"Bad response code from worker children: {exit_code}\")\n # ensure master doesnt finish until output from workers has arrived\n # otherwise the terminal might look weird.\n time.sleep(0.1)\n\n atexit.register(kill_workers, children)\n\n logger = logging.getLogger(__name__)\n greenlet_exception_handler = greenlet_exception_logger(logger)\n\n if options.stop_timeout:\n try:\n options.stop_timeout = parse_timespan(options.stop_timeout)\n except ValueError:\n logger.error(\"Valid --stop-timeout formats are: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n sys.exit(1)\n\n if options.list_commands:\n print(\"Available Users:\")\n for name in user_classes:\n print(\" \" + name)\n sys.exit(0)\n\n if not user_classes:\n logger.error(\"No User class found!\")\n sys.exit(1)\n\n # make sure specified User exists\n if options.user_classes:\n missing = set(options.user_classes) - set(user_classes.keys())\n if missing:\n logger.error(f\"Unknown User(s): {', '.join(missing)}\\n\")\n sys.exit(1)\n else:\n names = set(options.user_classes) & set(user_classes.keys())\n user_classes = [user_classes[n] for n in names]\n else:\n # list() call is needed to consume the dict_view object in Python 3\n user_classes = list(user_classes.values())\n\n if os.name != \"nt\" and not options.master:\n try:\n import resource\n\n minimum_open_file_limit = 10000\n (soft_limit, hard_limit) = resource.getrlimit(resource.RLIMIT_NOFILE)\n\n if soft_limit < minimum_open_file_limit:\n # Increasing the limit to 10000 within a running process should work on at least MacOS.\n # It does not work on all OS:es, but we should be no worse off for trying.\n resource.setrlimit(resource.RLIMIT_NOFILE, [minimum_open_file_limit, hard_limit])\n except BaseException:\n logger.warning(\n f\"\"\"System open file limit '{soft_limit} is below minimum setting '{minimum_open_file_limit}'.\nIt's not high enough for load testing, and the OS didn't allow locust to increase it by itself.\nSee https://github.com/locustio/locust/wiki/Installation#increasing-maximum-number-of-open-files-limit for more info.\"\"\"\n )\n\n if sys.version_info < (3, 9):\n logger.warning(\"Python 3.8 support is deprecated and will be removed soon\")\n\n # create locust Environment\n locustfile_path = None if not locustfile else os.path.basename(locustfile)\n\n environment = create_environment(\n user_classes,\n options,\n events=locust.events,\n shape_class=shape_class,\n locustfile=locustfile_path,\n available_user_classes=available_user_classes,\n available_shape_classes=available_shape_classes,\n available_user_tasks=available_user_tasks,\n )\n\n if options.config_users:\n for json_user_config in options.config_users:\n try:\n if json_user_config.endswith(\".json\"):\n with open(json_user_config) as file:\n user_config = json.load(file)\n else:\n 
user_config = json.loads(json_user_config)\n\n def ensure_user_class_name(config):\n if \"user_class_name\" not in config:\n logger.error(\"The user config must specify a user_class_name\")\n sys.exit(-1)\n\n if isinstance(user_config, list):\n for config in user_config:\n ensure_user_class_name(config)\n\n environment.update_user_class(config)\n else:\n ensure_user_class_name(user_config)\n\n environment.update_user_class(user_config)\n except Exception as e:\n logger.error(f\"The --config-users arugment must be in valid JSON string or file: {e}\")\n sys.exit(-1)\n\n if (\n shape_class\n and not shape_class.use_common_options\n and any(getattr(options, opt, None) for opt in COMMON_OPTIONS)\n ):\n logger.warning(\n \"--run-time, --users or --spawn-rate have no impact on LoadShapes unless the shape class explicitly reads them. \"\n \"See: docs.locust.io/en/stable/custom-load-shape.html#use-common-options\"\n )\n ignored = [f\"--{arg}\" for opt, arg in COMMON_OPTIONS.items() if getattr(options, opt, None)]\n logger.warning(f\"The following option(s) will be ignored: {', '.join(ignored)}\")\n\n if options.show_task_ratio:\n print(\"\\n Task ratio per User class\")\n print(\"-\" * 80)\n print_task_ratio(user_classes, options.num_users, False)\n print(\"\\n Total task ratio\")\n print(\"-\" * 80)\n print_task_ratio(user_classes, options.num_users, True)\n sys.exit(0)\n if options.show_task_ratio_json:\n print_task_ratio_json(user_classes, options.num_users)\n sys.exit(0)\n\n if options.master:\n if options.worker:\n logger.error(\"The --master argument cannot be combined with --worker\")\n sys.exit(-1)\n if options.expect_workers_max_wait and not options.expect_workers:\n logger.error(\"The --expect-workers-max-wait argument only makes sense when combined with --expect-workers\")\n sys.exit(-1)\n runner = environment.create_master_runner(\n master_bind_host=options.master_bind_host,\n master_bind_port=options.master_bind_port,\n )\n elif options.worker:\n try:\n runner = environment.create_worker_runner(options.master_host, options.master_port)\n logger.debug(\"Connected to locust master: %s:%s\", options.master_host, options.master_port)\n except OSError as e:\n logger.error(\"Failed to connect to the Locust master: %s\", e)\n sys.exit(-1)\n else:\n runner = environment.create_local_runner()\n\n # main_greenlet is pointing to runners.greenlet by default, it will point the web greenlet later if in web mode\n main_greenlet = runner.greenlet\n\n if options.run_time:\n if options.worker:\n logger.error(\"--run-time should be specified on the master node, and not on worker nodes\")\n sys.exit(1)\n try:\n options.run_time = parse_timespan(options.run_time)\n except ValueError:\n logger.error(\"Valid --run-time formats are: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n sys.exit(1)\n\n if options.csv_prefix:\n base_csv_file = os.path.basename(options.csv_prefix)\n base_csv_dir = options.csv_prefix[: -len(base_csv_file)]\n if not os.path.exists(base_csv_dir) and len(base_csv_dir) != 0:\n os.makedirs(base_csv_dir)\n stats_csv_writer = StatsCSVFileWriter(\n environment, stats.PERCENTILES_TO_REPORT, options.csv_prefix, options.stats_history_enabled\n )\n else:\n stats_csv_writer = StatsCSV(environment, stats.PERCENTILES_TO_REPORT)\n\n # start Web UI\n if not options.headless and not options.worker:\n # spawn web greenlet\n protocol = \"https\" if options.tls_cert and options.tls_key else \"http\"\n\n if options.web_host == \"*\":\n...\n", "needle_token_start": 12053, "needle_token_end": 12099, 
"code_context_ntokens": 16136, "output": ["```python\ndef is_valid_percentile(parameter):\n try:\n if 0 < float(parameter) < 1:\n return True\n return False\n except ValueError:\n return False\n```"]} +{"repo": "locustio/locust", "name": "_construct_basic_auth_str", "language": "python", "path": "locust/contrib/fasthttp.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: To create an encoded string for HTTP Basic Authentication that can be used in an HTTP request header.\n2. **Input**: Two strings representing a username and a password.\n3. **Output**: A string that combines the username and password in a format suitable for HTTP Basic Authentication, encoded in Base64.\n4. **Procedure**: The function first ensures both the username and password are bytes, encoding them if they are strings. It then concatenates them with a colon in between, encodes the resulting string in Base64, and finally formats it with the prefix required for HTTP Basic Authentication headers.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: locust/html.py\nimport datetime\nimport glob\nimport os\nimport pathlib\nfrom html import escape\nfrom itertools import chain\nfrom json import dumps\n\nfrom jinja2 import Environment, FileSystemLoader\n\nfrom . import stats as stats_module\nfrom .runners import STATE_STOPPED, STATE_STOPPING, MasterRunner\nfrom .stats import sort_stats\nfrom .user.inspectuser import get_ratio\n\nPERCENTILES_FOR_HTML_REPORT = [0.50, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0]\n\n\ndef render_template(file, template_path, **kwargs):\n env = Environment(loader=FileSystemLoader(template_path), extensions=[\"jinja2.ext.do\"])\n template = env.get_template(file)\n return template.render(**kwargs)\n\n\ndef get_html_report(\n environment,\n show_download_link=True,\n use_modern_ui=False,\n theme=\"\",\n):\n root_path = os.path.dirname(os.path.abspath(__file__))\n if use_modern_ui:\n static_path = os.path.join(root_path, \"webui\", \"dist\", \"assets\")\n template_path = os.path.join(root_path, \"webui\", \"dist\")\n else:\n static_path = os.path.join(root_path, \"static\")\n template_path = os.path.join(root_path, \"templates\")\n\n stats = environment.runner.stats\n\n start_ts = stats.start_time\n start_time = datetime.datetime.utcfromtimestamp(start_ts).strftime(\"%Y-%m-%d %H:%M:%S\")\n\n end_ts = stats.last_request_timestamp\n if end_ts:\n end_time = datetime.datetime.utcfromtimestamp(end_ts).strftime(\"%Y-%m-%d %H:%M:%S\")\n else:\n end_time = start_time\n\n host = None\n if environment.host:\n host = environment.host\n elif environment.runner.user_classes:\n all_hosts = {l.host for l in environment.runner.user_classes}\n if len(all_hosts) == 1:\n host = list(all_hosts)[0]\n\n requests_statistics = list(chain(sort_stats(stats.entries), [stats.total]))\n failures_statistics = sort_stats(stats.errors)\n exceptions_statistics = [\n {**exc, \"nodes\": \", \".join(exc[\"nodes\"])} for exc in environment.runner.exceptions.values()\n ]\n\n history = stats.history\n\n static_js = []\n if use_modern_ui:\n js_files = [os.path.basename(filepath) for filepath in glob.glob(os.path.join(static_path, \"*.js\"))]\n else:\n js_files = [\"jquery-1.11.3.min.js\", \"echarts.common.min.js\", \"vintage.js\", \"chart.js\", \"tasks.js\"]\n\n for js_file in js_files:\n path 
= os.path.join(static_path, js_file)\n static_js.append(\"// \" + js_file + \"\\n\")\n with open(path, encoding=\"utf8\") as f:\n static_js.append(f.read())\n static_js.extend([\"\", \"\"])\n\n if not use_modern_ui:\n static_css = []\n css_files = [\"tables.css\"]\n for css_file in css_files:\n path = os.path.join(static_path, \"css\", css_file)\n static_css.append(\"/* \" + css_file + \" */\")\n with open(path, encoding=\"utf8\") as f:\n static_css.append(f.read())\n static_css.extend([\"\", \"\"])\n\n is_distributed = isinstance(environment.runner, MasterRunner)\n user_spawned = (\n environment.runner.reported_user_classes_count if is_distributed else environment.runner.user_classes_count\n )\n\n if environment.runner.state in [STATE_STOPPED, STATE_STOPPING]:\n user_spawned = environment.runner.final_user_classes_count\n\n task_data = {\n \"per_class\": get_ratio(environment.user_classes, user_spawned, False),\n \"total\": get_ratio(environment.user_classes, user_spawned, True),\n }\n\n if use_modern_ui:\n res = render_template(\n \"report.html\",\n template_path,\n template_args={\n \"is_report\": True,\n \"requests_statistics\": [stat.to_dict(escape_string_values=True) for stat in requests_statistics],\n \"failures_statistics\": [stat.to_dict() for stat in failures_statistics],\n \"exceptions_statistics\": [stat for stat in exceptions_statistics],\n \"response_time_statistics\": [\n {\n \"name\": escape(stat.name),\n \"method\": escape(stat.method or \"\"),\n **{\n str(percentile): stat.get_response_time_percentile(percentile)\n for percentile in PERCENTILES_FOR_HTML_REPORT\n },\n }\n for stat in requests_statistics\n ],\n \"start_time\": start_time,\n \"end_time\": end_time,\n \"host\": escape(str(host)),\n \"history\": history,\n \"show_download_link\": show_download_link,\n \"locustfile\": escape(str(environment.locustfile)),\n \"tasks\": task_data,\n \"percentiles_to_chart\": stats_module.MODERN_UI_PERCENTILES_TO_CHART,\n },\n theme=theme,\n static_js=\"\\n\".join(static_js),\n )\n else:\n res = render_template(\n \"report.html\",\n template_path,\n int=int,\n round=round,\n escape=escape,\n str=str,\n requests_statistics=requests_statistics,\n failures_statistics=failures_statistics,\n...\n# Path: locust/util/cache.py\nimport functools\nfrom time import time\n\n\ndef memoize(timeout, dynamic_timeout=False):\n \"\"\"\n Memoization decorator with support for timeout.\n\n If dynamic_timeout is set, the cache timeout is doubled if the cached function\n takes longer time to run than the timeout time\n \"\"\"\n cache = {\"timeout\": timeout}\n\n def decorator(func):\n @functools.wraps(func)\n def wrapper(*args, **kwargs):\n start = time()\n if (not \"time\" in cache) or (start - cache[\"time\"] > cache[\"timeout\"]):\n # cache miss\n cache[\"result\"] = func(*args, **kwargs)\n cache[\"time\"] = time()\n if dynamic_timeout and cache[\"time\"] - start > cache[\"timeout\"]:\n cache[\"timeout\"] *= 2\n return cache[\"result\"]\n\n def clear_cache():\n if \"time\" in cache:\n del cache[\"time\"]\n if \"result\" in cache:\n del cache[\"result\"]\n\n wrapper.clear_cache = clear_cache\n return wrapper\n\n return decorator\n\n# Path: locust/util/timespan.py\nimport re\nfrom datetime import timedelta\n\n\ndef parse_timespan(time_str):\n \"\"\"\n Parse a string representing a time span and return the number of seconds.\n Valid formats are: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\n \"\"\"\n if not time_str:\n raise ValueError(\"Invalid time span format\")\n\n if re.match(r\"^\\d+$\", time_str):\n # 
if an int is specified we assume they want seconds\n return int(time_str)\n\n timespan_regex = re.compile(r\"((?P\\d+?)h)?((?P\\d+?)m)?((?P\\d+?)s)?\")\n parts = timespan_regex.match(time_str)\n if not parts:\n raise ValueError(\"Invalid time span format. Valid formats: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n parts = parts.groupdict()\n time_params = {name: int(value) for name, value in parts.items() if value}\n if not time_params:\n raise ValueError(\"Invalid time span format. Valid formats: 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\")\n return int(timedelta(**time_params).total_seconds())\n\n# Path: locust/web.py\nfrom __future__ import annotations\n\nimport csv\nimport json\nimport logging\nimport os.path\nfrom functools import wraps\nfrom html import escape\nfrom io import StringIO\nfrom itertools import chain\nfrom json import dumps\nfrom time import time\nfrom typing import TYPE_CHECKING, Any\n\nimport gevent\nfrom flask import (\n Flask,\n Response,\n jsonify,\n make_response,\n redirect,\n render_template,\n request,\n send_file,\n send_from_directory,\n url_for,\n)\nfrom flask_cors import CORS\nfrom flask_login import LoginManager, login_required\nfrom gevent import pywsgi\n\nfrom . import __version__ as version\nfrom . import argument_parser\nfrom . import stats as stats_module\nfrom .html import get_html_report\nfrom .log import greenlet_exception_logger\nfrom .runners import STATE_MISSING, STATE_RUNNING, MasterRunner\nfrom .stats import StatsCSV, StatsCSVFileWriter, StatsErrorDict, sort_stats\nfrom .user.inspectuser import get_ratio\nfrom .util.cache import memoize\nfrom .util.timespan import parse_timespan\n\nif TYPE_CHECKING:\n from .env import Environment\n\n\nlogger = logging.getLogger(__name__)\ngreenlet_exception_handler = greenlet_exception_logger(logger)\n\nDEFAULT_CACHE_TIME = 2.0\n\n\nclass WebUI:\n \"\"\"\n Sets up and runs a Flask web app that can start and stop load tests using the\n :attr:`environment.runner ` as well as show the load test statistics\n in :attr:`environment.stats `\n \"\"\"\n\n app: Flask | None = None\n \"\"\"\n Reference to the :class:`flask.Flask` app. Can be used to add additional web routes and customize\n the Flask app in other various ways. Example::\n\n from flask import request\n\n @web_ui.app.route(\"/my_custom_route\")\n def my_custom_route():\n return \"your IP is: %s\" % request.remote_addr\n \"\"\"\n\n greenlet: gevent.Greenlet | None = None\n \"\"\"\n Greenlet of the running web server\n \"\"\"\n\n server: pywsgi.WSGIServer | None = None\n \"\"\"Reference to the :class:`pyqsgi.WSGIServer` instance\"\"\"\n\n template_args: dict[str, Any]\n \"\"\"Arguments used to render index.html for the web UI. Must be used with custom templates\n extending index.html.\"\"\"\n\n auth_args: dict[str, Any]\n \"\"\"Arguments used to render auth.html for the web UI auth page. 
Must be used when configuring auth\"\"\"\n\n def __init__(\n self,\n environment: Environment,\n host: str,\n port: int,\n web_login: bool = False,\n tls_cert: str | None = None,\n tls_key: str | None = None,\n stats_csv_writer: StatsCSV | None = None,\n delayed_start=False,\n userclass_picker_is_active=False,\n modern_ui=False,\n ):\n \"\"\"\n Create WebUI instance and start running the web server in a separate greenlet (self.greenlet)\n\n Arguments:\n environment: Reference to the current Locust Environment\n host: Host/interface that the web server should accept connections to\n port: Port that the web server should listen to\n web_login: Enables a login page for the modern UI\n tls_cert: A path to a TLS certificate\n tls_key: A path to a TLS private key\n delayed_start: Whether or not to delay starting web UI until `start()` is called. Delaying web UI start\n allows for adding Flask routes or Blueprints before accepting requests, avoiding errors.\n \"\"\"\n environment.web_ui = self\n self.stats_csv_writer = stats_csv_writer or StatsCSV(environment, stats_module.PERCENTILES_TO_REPORT)\n self.environment = environment\n self.host = host\n self.port = port\n self.tls_cert = tls_cert\n self.tls_key = tls_key\n self.userclass_picker_is_active = userclass_picker_is_active\n self.modern_ui = modern_ui\n self.web_login = web_login\n app = Flask(__name__)\n CORS(app)\n self.app = app\n app.jinja_env.add_extension(\"jinja2.ext.do\")\n app.debug = True\n root_path = os.path.dirname(os.path.abspath(__file__))\n app.root_path = root_path\n self.webui_build_path = os.path.join(root_path, \"webui\", \"dist\")\n self.greenlet: gevent.Greenlet | None = None\n self._swarm_greenlet: gevent.Greenlet | None = None\n self.template_args = {}\n self.auth_args = {}\n\n if self.web_login:\n self.login_manager = LoginManager()\n self.login_manager.init_app(app)\n self.login_manager.login_view = \"login\"\n\n if environment.runner:\n self.update_template_args()\n if not delayed_start:\n self.start()\n\n @app.errorhandler(Exception)\n def handle_exception(error):\n error_message = str(error)\n logger.log(logging.CRITICAL, error_message)\n return make_response(error_message, 500)\n\n @app.route(\"/assets/\")\n def send_assets(path):\n return send_from_directory(os.path.join(self.webui_build_path, \"assets\"), path)\n\n @app.route(\"/\")\n @self.auth_required_if_enabled\n def index() -> str | Response:\n if not environment.runner:\n return make_response(\"Error: Locust Environment does not have any runner\", 500)\n self.update_template_args()\n\n if self.modern_ui:\n self.set_static_modern_ui()\n\n return render_template(\"index.html\", template_args=self.template_args)\n return render_template(\"index.html\", **self.template_args)\n\n @app.route(\"/swarm\", methods=[\"POST\"])\n @self.auth_required_if_enabled\n def swarm() -> Response:\n assert request.method == \"POST\"\n\n # Loading UserClasses & ShapeClasses if Locust is running with UserClass Picker\n if self.userclass_picker_is_active:\n if not self.environment.available_user_classes:\n err_msg = \"UserClass picker is active but there are no available UserClasses\"\n return jsonify({\"success\": False, \"message\": err_msg, \"host\": environment.host})\n\n # Getting Specified User Classes\n form_data_user_class_names = request.form.getlist(\"user_classes\")\n\n # Updating UserClasses\n if form_data_user_class_names:\n user_classes = {}\n for user_class_name, user_class_object in self.environment.available_user_classes.items():\n if user_class_name in 
form_data_user_class_names:\n user_classes[user_class_name] = user_class_object\n\n else:\n if self.environment.runner and self.environment.runner.state == STATE_RUNNING:\n # Test is already running\n # Using the user classes that have already been selected\n user_classes = {\n key: value\n for (key, value) in self.environment.available_user_classes.items()\n if value in self.environment.user_classes\n }\n else:\n # Starting test with no user class selection\n # Defaulting to using all available user classes\n user_classes = self.environment.available_user_classes\n\n self._update_user_classes(user_classes)\n\n # Updating ShapeClass if specified in WebUI Form\n form_data_shape_class_name = request.form.get(\"shape_class\", \"Default\")\n if form_data_shape_class_name == \"Default\":\n self._update_shape_class(None)\n else:\n self._update_shape_class(form_data_shape_class_name)\n\n parsed_options_dict = vars(environment.parsed_options) if environment.parsed_options else {}\n run_time = None\n for key, value in request.form.items():\n if key == \"user_count\": # if we just renamed this field to \"users\" we wouldn't need this\n user_count = int(value)\n elif key == \"spawn_rate\":\n spawn_rate = float(value)\n elif key == \"host\":\n # Replace < > to guard against XSS\n environment.host = str(request.form[\"host\"]).replace(\"<\", \"\").replace(\">\", \"\")\n elif key == \"user_classes\":\n # Set environment.parsed_options.user_classes to the selected user_classes\n parsed_options_dict[key] = request.form.getlist(\"user_classes\")\n elif key == \"run_time\":\n if not value:\n continue\n try:\n run_time = parse_timespan(value)\n except ValueError:\n err_msg = \"Valid run_time formats are : 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc.\"\n logger.error(err_msg)\n return jsonify({\"success\": False, \"message\": err_msg, \"host\": environment.host})\n elif key in parsed_options_dict:\n # update the value in environment.parsed_options, but dont change the type.\n parsed_options_value = parsed_options_dict[key]\n\n if isinstance(parsed_options_value, bool):\n parsed_options_dict[key] = value == \"true\"\n elif parsed_options_value is None:\n parsed_options_dict[key] = parsed_options_value\n else:\n parsed_options_dict[key] = type(parsed_options_dict[key])(value)\n\n if environment.shape_class and environment.runner is not None:\n environment.runner.start_shape()\n return jsonify(\n {\n \"success\": True,\n \"message\": f\"Swarming started using shape class '{type(environment.shape_class).__name__}'\",\n \"host\": environment.host,\n }\n )\n\n if self._swarm_greenlet is not None:\n self._swarm_greenlet.kill(block=True)\n self._swarm_greenlet = None\n\n if environment.runner is not None:\n self._swarm_greenlet = gevent.spawn(environment.runner.start, user_count, spawn_rate)\n self._swarm_greenlet.link_exception(greenlet_exception_handler)\n response_data = {\n \"success\": True,\n \"message\": \"Swarming started\",\n \"host\": environment.host,\n }\n if run_time:\n gevent.spawn_later(run_time, self._stop_runners).link_exception(greenlet_exception_handler)\n response_data[\"run_time\"] = run_time\n\n if self.userclass_picker_is_active:\n response_data[\"user_classes\"] = sorted(user_classes.keys())\n\n return jsonify(response_data)\n else:\n return jsonify({\"success\": False, \"message\": \"No runner\", \"host\": environment.host})\n\n @app.route(\"/stop\")\n @self.auth_required_if_enabled\n def stop() -> Response:\n if self._swarm_greenlet is not None:\n self._swarm_greenlet.kill(block=True)\n 
self._swarm_greenlet = None\n if environment.runner is not None:\n environment.runner.stop()\n return jsonify({\"success\": True, \"message\": \"Test stopped\"})\n\n @app.route(\"/stats/reset\")\n @self.auth_required_if_enabled\n def reset_stats() -> str:\n environment.events.reset_stats.fire()\n if environment.runner is not None:\n environment.runner.stats.reset_all()\n environment.runner.exceptions = {}\n return \"ok\"\n\n @app.route(\"/stats/report\")\n @self.auth_required_if_enabled\n def stats_report() -> Response:\n theme = request.args.get(\"theme\", \"\")\n res = get_html_report(\n self.environment,\n show_download_link=not request.args.get(\"download\"),\n use_modern_ui=self.modern_ui,\n theme=theme,\n )\n if request.args.get(\"download\"):\n res = app.make_response(res)\n res.headers[\"Content-Disposition\"] = f\"attachment;filename=report_{time()}.html\"\n return res\n\n def _download_csv_suggest_file_name(suggest_filename_prefix: str) -> str:\n \"\"\"Generate csv file download attachment filename suggestion.\n\n Arguments:\n suggest_filename_prefix: Prefix of the filename to suggest for saving the download. Will be appended with timestamp.\n \"\"\"\n\n return f\"{suggest_filename_prefix}_{time()}.csv\"\n\n def _download_csv_response(csv_data: str, filename_prefix: str) -> Response:\n \"\"\"Generate csv file download response with 'csv_data'.\n\n Arguments:\n csv_data: CSV header and data rows.\n filename_prefix: Prefix of the filename to suggest for saving the download. Will be appended with timestamp.\n \"\"\"\n\n response = make_response(csv_data)\n response.headers[\"Content-type\"] = \"text/csv\"\n response.headers[\"Content-disposition\"] = (\n f\"attachment;filename={_download_csv_suggest_file_name(filename_prefix)}\"\n )\n return response\n\n @app.route(\"/stats/requests/csv\")\n @self.auth_required_if_enabled\n def request_stats_csv() -> Response:\n data = StringIO()\n writer = csv.writer(data)\n self.stats_csv_writer.requests_csv(writer)\n return _download_csv_response(data.getvalue(), \"requests\")\n\n @app.route(\"/stats/requests_full_history/csv\")\n @self.auth_required_if_enabled\n def request_stats_full_history_csv() -> Response:\n options = self.environment.parsed_options\n if options and options.stats_history_enabled and isinstance(self.stats_csv_writer, StatsCSVFileWriter):\n return send_file(\n os.path.abspath(self.stats_csv_writer.stats_history_file_name()),\n mimetype=\"text/csv\",\n as_attachment=True,\n download_name=_download_csv_suggest_file_name(\"requests_full_history\"),\n etag=True,\n max_age=0,\n conditional=True,\n last_modified=None,\n )\n\n return make_response(\"Error: Server was not started with option to generate full history.\", 404)\n\n @app.route(\"/stats/failures/csv\")\n @self.auth_required_if_enabled\n def failures_stats_csv() -> Response:\n data = StringIO()\n writer = csv.writer(data)\n self.stats_csv_writer.failures_csv(writer)\n return _download_csv_response(data.getvalue(), \"failures\")\n\n @app.route(\"/stats/requests\")\n @self.auth_required_if_enabled\n @memoize(timeout=DEFAULT_CACHE_TIME, dynamic_timeout=True)\n def request_stats() -> Response:\n stats: list[dict[str, Any]] = []\n errors: list[StatsErrorDict] = []\n\n if environment.runner is None:\n report = {\n \"stats\": stats,\n \"errors\": errors,\n \"total_rps\": 0.0,\n \"total_fail_per_sec\": 0.0,\n \"fail_ratio\": 0.0,\n \"current_response_time_percentile_1\": None,\n \"current_response_time_percentile_2\": None,\n \"state\": STATE_MISSING,\n \"user_count\": 0,\n }\n\n 
if isinstance(environment.runner, MasterRunner):\n report.update({\"workers\": []})\n\n return jsonify(report)\n\n for s in chain(sort_stats(environment.runner.stats.entries), [environment.runner.stats.total]):\n stats.append(s.to_dict())\n\n for e in environment.runner.errors.values():\n err_dict = e.serialize()\n err_dict[\"name\"] = escape(err_dict[\"name\"])\n err_dict[\"error\"] = escape(err_dict[\"error\"])\n errors.append(err_dict)\n\n # Truncate the total number of stats and errors displayed since a large number of rows will cause the app\n # to render extremely slowly. Aggregate stats should be preserved.\n truncated_stats = stats[:500]\n if len(stats) > 500:\n truncated_stats += [stats[-1]]\n\n report = {\"stats\": truncated_stats, \"errors\": errors[:500]}\n total_stats = stats[-1]\n\n if stats:\n report[\"total_rps\"] = total_stats[\"current_rps\"]\n report[\"total_fail_per_sec\"] = total_stats[\"current_fail_per_sec\"]\n report[\"total_avg_response_time\"] = total_stats[\"avg_response_time\"]\n report[\"fail_ratio\"] = environment.runner.stats.total.fail_ratio\n\n if self.modern_ui:\n report[\"current_response_time_percentiles\"] = {\n f\"response_time_percentile_{percentile}\": environment.runner.stats.total.get_current_response_time_percentile(\n percentile\n )\n for percentile in stats_module.MODERN_UI_PERCENTILES_TO_CHART\n }\n else:\n report[\"current_response_time_percentile_1\"] = (\n environment.runner.stats.total.get_current_response_time_percentile(\n stats_module.PERCENTILES_TO_CHART[0]\n )\n )\n report[\"current_response_time_percentile_2\"] = (\n environment.runner.stats.total.get_current_response_time_percentile(\n stats_module.PERCENTILES_TO_CHART[1]\n )\n )\n\n if isinstance(environment.runner, MasterRunner):\n workers = []\n for worker in environment.runner.clients.values():\n workers.append(\n {\n \"id\": worker.id,\n \"state\": worker.state,\n \"user_count\": worker.user_count,\n \"cpu_usage\": worker.cpu_usage,\n \"memory_usage\": worker.memory_usage,\n }\n )\n\n report[\"workers\"] = workers\n\n report[\"state\"] = environment.runner.state\n report[\"user_count\"] = environment.runner.user_count\n\n return jsonify(report)\n\n @app.route(\"/exceptions\")\n @self.auth_required_if_enabled\n def exceptions() -> Response:\n return jsonify(\n {\n \"exceptions\": [\n {\n \"count\": row[\"count\"],\n \"msg\": escape(row[\"msg\"]),\n \"traceback\": escape(row[\"traceback\"]),\n \"nodes\": \", \".join(row[\"nodes\"]),\n }\n for row in (environment.runner.exceptions.values() if environment.runner is not None else [])\n ]\n }\n )\n\n @app.route(\"/exceptions/csv\")\n @self.auth_required_if_enabled\n def exceptions_csv() -> Response:\n data = StringIO()\n writer = csv.writer(data)\n self.stats_csv_writer.exceptions_csv(writer)\n return _download_csv_response(data.getvalue(), \"exceptions\")\n\n @app.route(\"/tasks\")\n @self.auth_required_if_enabled\n def tasks() -> dict[str, dict[str, dict[str, float]]]:\n runner = self.environment.runner\n user_spawned: dict[str, int]\n if runner is None:\n user_spawned = {}\n else:\n user_spawned = (\n runner.reported_user_classes_count\n if isinstance(runner, MasterRunner)\n else runner.user_classes_count\n )\n\n task_data = {\n \"per_class\": get_ratio(self.environment.user_classes, user_spawned, False),\n \"total\": get_ratio(self.environment.user_classes, user_spawned, True),\n }\n return task_data\n\n @app.route(\"/logs\")\n @self.auth_required_if_enabled\n def logs():\n log_reader_handler = [\n handler for handler in 
logging.getLogger(\"root\").handlers if handler.name == \"log_reader\"\n ]\n\n if log_reader_handler:\n logs = log_reader_handler[0].logs\n else:\n logs = []\n\n return jsonify({\"logs\": logs})\n\n @app.route(\"/login\")\n def login():\n if not self.web_login:\n return redirect(url_for(\"index\"))\n\n if self.modern_ui:\n self.set_static_modern_ui()\n\n return render_template(\n \"auth.html\",\n auth_args=self.auth_args,\n )\n else:\n return \"Web Auth is only available on the modern web ui.\"\n\n @app.route(\"/user\", methods=[\"POST\"])\n def update_user():\n assert request.method == \"POST\"\n\n user_settings = json.loads(request.data)\n self.environment.update_user_class(user_settings)\n\n return {}, 201\n\n def start(self):\n self.greenlet = gevent.spawn(self.start_server)\n self.greenlet.link_exception(greenlet_exception_handler)\n\n def start_server(self):\n if self.tls_cert and self.tls_key:\n self.server = pywsgi.WSGIServer(\n (self.host, self.port), self.app, log=None, keyfile=self.tls_key, certfile=self.tls_cert\n )\n else:\n self.server = pywsgi.WSGIServer((self.host, self.port), self.app, log=None)\n self.server.serve_forever()\n\n def stop(self):\n \"\"\"\n Stop the running web server\n \"\"\"\n self.server.stop()\n\n def auth_required_if_enabled(self, view_func):\n \"\"\"\n Decorator that can be used on custom route methods that will turn on Flask Login\n authentication if the ``--web-login`` flag is used. Example::\n\n @web_ui.app.route(\"/my_custom_route\")\n @web_ui.auth_required_if_enabled\n def my_custom_route():\n return \"custom response\"\n \"\"\"\n\n @wraps(view_func)\n def wrapper(*args, **kwargs):\n if self.web_login:\n try:\n return login_required(view_func)(*args, **kwargs)\n except Exception as e:\n return f\"Locust auth exception: {e} See https://docs.locust.io/en/stable/extending-locust.html#authentication for configuring authentication.\"\n else:\n return view_func(*args, **kwargs)\n\n return wrapper\n\n def set_static_modern_ui(self):\n self.app.template_folder = self.webui_build_path\n self.app.static_folder = os.path.join(self.webui_build_path, \"assets\")\n self.app.static_url_path = \"/assets/\"\n\n def update_template_args(self):\n override_host_warning = False\n if self.environment.host:\n host = self.environment.host\n elif self.environment.runner.user_classes:\n all_hosts = {l.host for l in self.environment.runner.user_classes}\n if len(all_hosts) == 1:\n host = list(all_hosts)[0]\n else:\n # since we have multiple User classes with different host attributes, we'll\n # inform that specifying host will override the host for all User classes\n override_host_warning = True\n host = None\n else:\n host = None\n\n options = self.environment.parsed_options\n\n is_distributed = isinstance(self.environment.runner, MasterRunner)\n if is_distributed:\n worker_count = self.environment.runner.worker_count\n else:\n worker_count = 0\n\n stats = self.environment.runner.stats\n extra_options = argument_parser.ui_extra_args_dict()\n\n available_user_classes = None\n users = None\n if self.environment.available_user_classes:\n available_user_classes = sorted(self.environment.available_user_classes)\n users = {\n user_class_name: user_class.json()\n for (user_class_name, user_class) in self.environment.available_user_classes.items()\n }\n\n available_shape_classes = [\"Default\"]\n if self.environment.available_shape_classes:\n available_shape_classes += sorted(self.environment.available_shape_classes.keys())\n\n available_user_tasks = (\n {\n user_class_name: 
[task.__name__ for task in user_class]\n for (user_class_name, user_class) in self.environment.available_user_tasks.items()\n }\n if self.environment.available_user_tasks\n else None\n )\n\n if self.modern_ui:\n percentiles = {\n \"percentiles_to_chart\": stats_module.MODERN_UI_PERCENTILES_TO_CHART,\n \"percentiles_to_statistics\": stats_module.PERCENTILES_TO_STATISTICS,\n }\n else:\n percentiles = {\n \"percentile1\": stats_module.PERCENTILES_TO_CHART[0],\n \"percentile2\": stats_module.PERCENTILES_TO_CHART[1],\n }\n\n self.template_args = {\n \"locustfile\": self.environment.locustfile,\n \"state\": self.environment.runner.state,\n \"is_distributed\": is_distributed,\n \"user_count\": self.environment.runner.user_count,\n \"version\": version,\n \"host\": host if host else \"\",\n \"history\": stats.history if stats.num_requests > 0 else {},\n \"override_host_warning\": override_host_warning,\n \"num_users\": options and options.num_users,\n \"spawn_rate\": options and options.spawn_rate,\n \"worker_count\": worker_count,\n \"hide_common_options\": (\n self.environment.shape_class\n and not (self.userclass_picker_is_active or self.environment.shape_class.use_common_options)\n ),\n \"stats_history_enabled\": options and options.stats_history_enabled,\n \"tasks\": dumps({}),\n \"extra_options\": extra_options,\n \"run_time\": options and options.run_time,\n \"show_userclass_picker\": self.userclass_picker_is_active,\n \"available_user_classes\": available_user_classes,\n \"available_shape_classes\": available_shape_classes,\n \"available_user_tasks\": available_user_tasks,\n \"users\": users,\n **percentiles,\n }\n\n def _update_shape_class(self, shape_class_name):\n if shape_class_name:\n shape_class = self.environment.available_shape_classes[shape_class_name]\n shape_class.runner = self.environment.runner\n else:\n shape_class = None\n\n # Validating ShapeClass\n self.environment.shape_class = shape_class\n self.environment._validate_shape_class_instance()\n\n def _update_user_classes(self, user_classes):\n self.environment.user_classes = list(user_classes.values())\n # populate the locustfile which used in web ui title only\n if self.environment.locustfile is None:\n self.environment.locustfile = \",\".join(self.environment.user_classes_by_name.keys())\n\n # Validating UserClasses\n self.environment._remove_user_classes_with_weight_zero()\n self.environment._validate_user_class_name_uniqueness()\n\n def _stop_runners(self):\n self.environment.runner.stop()\n\n# Path: locust/env.py\nfrom __future__ import annotations\n\nfrom operator import methodcaller\nfrom typing import Callable, TypeVar\n\nfrom configargparse import Namespace\n\nfrom .dispatch import UsersDispatcher\nfrom .event import Events\nfrom .exception import RunnerAlreadyExistsError\nfrom .runners import LocalRunner, MasterRunner, Runner, WorkerRunner\nfrom .shape import LoadTestShape\nfrom .stats import RequestStats, StatsCSV\nfrom .user import User\nfrom .user.task import TaskHolder, TaskSet, filter_tasks_by_tags\nfrom .web import WebUI\n\nRunnerType = TypeVar(\"RunnerType\", bound=Runner)\n\n\nclass Environment:\n def __init__(\n self,\n *,\n user_classes: list[type[User]] | None = None,\n shape_class: LoadTestShape | None = None,\n tags: list[str] | None = None,\n locustfile: str | None = None,\n exclude_tags: list[str] | None = None,\n events: Events | None = None,\n host: str | None = None,\n reset_stats=False,\n stop_timeout: float | None = None,\n catch_exceptions=True,\n parsed_options: Namespace | None = None,\n 
available_user_classes: dict[str, User] | None = None,\n available_shape_classes: dict[str, LoadTestShape] | None = None,\n available_user_tasks: dict[str, list[TaskSet | Callable]] | None = None,\n dispatcher_class: type[UsersDispatcher] = UsersDispatcher,\n ):\n self.runner: Runner | None = None\n \"\"\"Reference to the :class:`Runner ` instance\"\"\"\n\n self.web_ui: WebUI | None = None\n \"\"\"Reference to the WebUI instance\"\"\"\n\n self.process_exit_code: int | None = None\n \"\"\"\n If set it'll be the exit code of the Locust process\n \"\"\"\n\n if events:\n self.events = events\n \"\"\"\n Event hooks used by Locust internally, as well as to extend Locust's functionality\n See :ref:`events` for available events.\n \"\"\"\n else:\n self.events = Events()\n\n self.locustfile = locustfile\n \"\"\"Filename (not path) of locustfile\"\"\"\n self.user_classes: list[type[User]] = user_classes or []\n \"\"\"User classes that the runner will run\"\"\"\n self.shape_class = shape_class\n \"\"\"A shape class to control the shape of the load test\"\"\"\n self.tags = tags\n \"\"\"If set, only tasks that are tagged by tags in this list will be executed. Leave this as None to use the one from parsed_options\"\"\"\n self.exclude_tags = exclude_tags\n \"\"\"If set, only tasks that aren't tagged by tags in this list will be executed. Leave this as None to use the one from parsed_options\"\"\"\n self.stats = RequestStats()\n \"\"\"Reference to RequestStats instance\"\"\"\n self.host = host\n \"\"\"Base URL of the target system\"\"\"\n self.reset_stats = reset_stats\n \"\"\"Determines if stats should be reset once all simulated users have been spawned\"\"\"\n if stop_timeout is not None:\n self.stop_timeout = stop_timeout\n elif parsed_options:\n self.stop_timeout = float(getattr(parsed_options, \"stop_timeout\", 0.0))\n else:\n self.stop_timeout = 0.0\n \"\"\"\n If set, the runner will try to stop the running users gracefully and wait this many seconds\n before killing them hard.\n \"\"\"\n self.catch_exceptions = catch_exceptions\n \"\"\"\n If True exceptions that happen within running users will be caught (and reported in UI/console).\n If False, exceptions will be raised.\n \"\"\"\n self.parsed_options = parsed_options\n \"\"\"Reference to the parsed command line options (used to pre-populate fields in Web UI). 
When using Locust as a library, this should either be `None` or an object created by `argument_parser.parse_args()`\"\"\"\n self.available_user_classes = available_user_classes\n \"\"\"List of the available User Classes to pick from in the UserClass Picker\"\"\"\n self.available_shape_classes = available_shape_classes\n \"\"\"List of the available Shape Classes to pick from in the ShapeClass Picker\"\"\"\n self.available_user_tasks = available_user_tasks\n \"\"\"List of the available Tasks per User Classes to pick from in the Task Picker\"\"\"\n self.dispatcher_class = dispatcher_class\n \"\"\"A user dispatcher class that decides how users are spawned, default :class:`UsersDispatcher `\"\"\"\n\n self._remove_user_classes_with_weight_zero()\n self._validate_user_class_name_uniqueness()\n self._validate_shape_class_instance()\n\n def _create_runner(\n self,\n runner_class: type[RunnerType],\n *args,\n **kwargs,\n ) -> RunnerType:\n if self.runner is not None:\n raise RunnerAlreadyExistsError(f\"Environment.runner already exists ({self.runner})\")\n self.runner = runner_class(self, *args, **kwargs)\n\n # Attach the runner to the shape class so that the shape class can access user count state\n if self.shape_class:\n self.shape_class.runner = self.runner\n\n return self.runner\n\n def create_local_runner(self) -> LocalRunner:\n \"\"\"\n Create a :class:`LocalRunner ` instance for this Environment\n \"\"\"\n return self._create_runner(LocalRunner)\n\n def create_master_runner(self, master_bind_host=\"*\", master_bind_port=5557) -> MasterRunner:\n \"\"\"\n Create a :class:`MasterRunner ` instance for this Environment\n\n :param master_bind_host: Interface/host that the master should use for incoming worker connections.\n Defaults to \"*\" which means all interfaces.\n :param master_bind_port: Port that the master should listen for incoming worker connections on\n \"\"\"\n return self._create_runner(\n MasterRunner,\n master_bind_host=master_bind_host,\n master_bind_port=master_bind_port,\n )\n\n def create_worker_runner(self, master_host: str, master_port: int) -> WorkerRunner:\n \"\"\"\n Create a :class:`WorkerRunner ` instance for this Environment\n\n :param master_host: Host/IP of a running master node\n :param master_port: Port on master node to connect to\n \"\"\"\n # Create a new RequestStats with use_response_times_cache set to False to save some memory\n # and CPU cycles, since the response_times_cache is not needed for Worker nodes\n self.stats = RequestStats(use_response_times_cache=False)\n return self._create_runner(\n WorkerRunner,\n master_host=master_host,\n master_port=master_port,\n )\n\n def create_web_ui(\n self,\n host=\"\",\n port=8089,\n web_login: bool = False,\n tls_cert: str | None = None,\n tls_key: str | None = None,\n stats_csv_writer: StatsCSV | None = None,\n delayed_start=False,\n userclass_picker_is_active=False,\n modern_ui=False,\n ) -> WebUI:\n \"\"\"\n Creates a :class:`WebUI ` instance for this Environment and start running the web server\n\n :param host: Host/interface that the web server should accept connections to. Defaults to \"\"\n which means all interfaces\n :param port: Port that the web server should listen to\n :param web_login: If provided, an authentication page will protect the app\n :param tls_cert: An optional path (str) to a TLS cert. If this is provided the web UI will be\n served over HTTPS\n :param tls_key: An optional path (str) to a TLS private key. 
If this is provided the web UI will be\n served over HTTPS\n :param stats_csv_writer: `StatsCSV ` instance.\n :param delayed_start: Whether or not to delay starting web UI until `start()` is called. Delaying web UI start\n allows for adding Flask routes or Blueprints before accepting requests, avoiding errors.\n \"\"\"\n self.web_ui = WebUI(\n self,\n host,\n port,\n web_login=web_login,\n tls_cert=tls_cert,\n tls_key=tls_key,\n stats_csv_writer=stats_csv_writer,\n delayed_start=delayed_start,\n userclass_picker_is_active=userclass_picker_is_active,\n modern_ui=modern_ui,\n )\n return self.web_ui\n\n def update_user_class(self, user_settings):\n user_class_name = user_settings.get(\"user_class_name\")\n user_class = self.available_user_classes[user_class_name]\n user_tasks = self.available_user_tasks[user_class_name]\n\n for key, value in user_settings.items():\n if key not in [\"user_class_name\", \"tasks\"]:\n setattr(user_class, key, value)\n if key == \"tasks\":\n user_class.tasks = [task for task in user_tasks if task.__name__ in value]\n\n def _filter_tasks_by_tags(self) -> None:\n \"\"\"\n Filter the tasks on all the user_classes recursively, according to the tags and\n exclude_tags attributes\n \"\"\"\n if getattr(self, \"_tasks_filtered\", False):\n return # only filter once\n self._tasks_filtered = True\n\n if self.tags is not None:\n tags = set(self.tags)\n elif self.parsed_options and self.parsed_options.tags:\n tags = set(self.parsed_options.tags)\n else:\n tags = None\n\n if self.exclude_tags is not None:\n exclude_tags = set(self.exclude_tags)\n elif self.parsed_options and self.parsed_options.exclude_tags:\n exclude_tags = set(self.parsed_options.exclude_tags)\n else:\n exclude_tags = None\n\n for user_class in self.user_classes:\n filter_tasks_by_tags(user_class, tags, exclude_tags)\n\n def _remove_user_classes_with_weight_zero(self) -> None:\n \"\"\"\n Remove user classes having a weight of zero.\n \"\"\"\n if len(self.user_classes) == 0:\n # Preserve previous behaviour that allowed no user classes to be specified.\n return\n filtered_user_classes = [\n user_class for user_class in self.user_classes if user_class.weight > 0 or user_class.fixed_count > 0\n ]\n if len(filtered_user_classes) == 0:\n # TODO: Better exception than `ValueError`?\n raise ValueError(\"There are no users with weight > 0.\")\n self.user_classes[:] = filtered_user_classes\n\n def assign_equal_weights(self) -> None:\n \"\"\"\n Update the user classes such that each user runs their specified tasks with equal\n probability.\n \"\"\"\n for u in self.user_classes:\n u.weight = 1\n user_tasks: list[TaskSet | Callable] = []\n tasks_frontier = u.tasks\n while len(tasks_frontier) != 0:\n t = tasks_frontier.pop()\n if isinstance(t, TaskHolder):\n tasks_frontier.extend(t.tasks)\n elif callable(t):\n if t not in user_tasks:\n user_tasks.append(t)\n else:\n raise ValueError(\"Unrecognized task type in user\")\n u.tasks = user_tasks\n\n def _validate_user_class_name_uniqueness(self):\n # Validate there's no class with the same name but in different modules\n if len({user_class.__name__ for user_class in self.user_classes}) != len(self.user_classes):\n raise ValueError(\n \"The following user classes have the same class name: {}\".format(\n \", \".join(map(methodcaller(\"fullname\"), self.user_classes))\n )\n )\n\n def _validate_shape_class_instance(self):\n if self.shape_class is not None and not isinstance(self.shape_class, LoadTestShape):\n raise ValueError(\n \"shape_class should be instance of LoadTestShape 
or subclass LoadTestShape, but got: %s\"\n % self.shape_class\n )\n\n @property\n def user_classes_by_name(self) -> dict[str, type[User]]:\n return {u.__name__: u for u in self.user_classes}\n\n# Path: locust/contrib/fasthttp.py\nfrom __future__ import annotations\n\nfrom locust.env import Environment\nfrom locust.exception import CatchResponseError, LocustError, ResponseError\nfrom locust.user import User\nfrom locust.util.deprecation import DeprecatedFastHttpLocustClass as FastHttpLocust\n\nimport json\nimport json as unshadowed_json # some methods take a named parameter called json\nimport re\nimport socket\nimport time\nimport traceback\nfrom base64 import b64encode\nfrom contextlib import contextmanager\nfrom http.cookiejar import CookieJar\nfrom json.decoder import JSONDecodeError\nfrom ssl import SSLError\nfrom typing import Any, Callable, Generator, cast\nfrom urllib.parse import urlparse, urlunparse\n\nimport gevent\nfrom charset_normalizer import detect\nfrom gevent.timeout import Timeout\nfrom geventhttpclient._parser import HTTPParseError\nfrom geventhttpclient.client import HTTPClientPool\nfrom geventhttpclient.header import Headers\nfrom geventhttpclient.response import HTTPConnectionClosed, HTTPSocketPoolResponse\nfrom geventhttpclient.useragent import CompatRequest, CompatResponse, ConnectionError, UserAgent\n\n# borrow requests's content-type header parsing\nfrom requests.utils import get_encoding_from_headers\n\n# Monkey patch geventhttpclient.useragent.CompatRequest so that Cookiejar works with Python >= 3.3\n# More info: https://github.com/requests/requests/pull/871\nCompatRequest.unverifiable = False\n\n# Workaround for AttributeError: 'CompatRequest' object has no attribute 'type' in Cookiejar\n# https://github.com/locustio/locust/issues/1138\n# Might allow secure cookies over non-secure connections but that is a minor concern in a load testing tool\nCompatRequest.type = \"https\"\n\n# Regexp for checking if an absolute URL was specified\nabsolute_http_url_regexp = re.compile(r\"^https?://\", re.I)\n\n# List of exceptions that can be raised by geventhttpclient when sending an HTTP request,\n# and that should result in a Locust failure\nFAILURE_EXCEPTIONS = (\n ConnectionError,\n ConnectionRefusedError,\n ConnectionResetError,\n socket.error,\n SSLError,\n Timeout,\n HTTPConnectionClosed,\n)\n\n\n\ndef _construct_basic_auth_str(username, password):\n \"\"\"Construct Authorization header value to be used in HTTP Basic Auth\"\"\"\n if isinstance(username, str):\n username = username.encode(\"latin1\")\n if isinstance(password, str):\n password = password.encode(\"latin1\")\n return \"Basic \" + b64encode(b\":\".join((username, password))).strip().decode(\"ascii\")\n\n\ndef insecure_ssl_context_factory():\n context = gevent.ssl.create_default_context()\n context.check_hostname = False\n context.verify_mode = gevent.ssl.CERT_NONE\n return context\n\n\nclass FastHttpSession:\n auth_header = None\n\n def __init__(\n self,\n environment: Environment,\n base_url: str,\n user: User | None,\n insecure=True,\n client_pool: HTTPClientPool | None = None,\n ssl_context_factory: Callable | None = None,\n **kwargs,\n ):\n self.environment = environment\n self.base_url = base_url\n self.cookiejar = CookieJar()\n self.user = user\n if not ssl_context_factory:\n if insecure:\n ssl_context_factory = insecure_ssl_context_factory\n else:\n ssl_context_factory = gevent.ssl.create_default_context\n self.client = LocustUserAgent(\n cookiejar=self.cookiejar,\n 
ssl_context_factory=ssl_context_factory,\n insecure=insecure,\n client_pool=client_pool,\n **kwargs,\n )\n\n # Check for basic authentication\n parsed_url = urlparse(self.base_url)\n if parsed_url.username and parsed_url.password:\n netloc = parsed_url.hostname or \"\"\n if parsed_url.port:\n netloc += \":%d\" % parsed_url.port\n\n # remove username and password from the base_url\n self.base_url = urlunparse(\n (parsed_url.scheme, netloc, parsed_url.path, parsed_url.params, parsed_url.query, parsed_url.fragment)\n )\n # store authentication header (we construct this by using _basic_auth_str() function from requests.auth)\n self.auth_header = _construct_basic_auth_str(parsed_url.username, parsed_url.password)\n\n def _build_url(self, path):\n \"\"\"prepend url with hostname unless it's already an absolute URL\"\"\"\n if absolute_http_url_regexp.match(path):\n return path\n else:\n return f\"{self.base_url}{path}\"\n\n def _send_request_safe_mode(self, method, url, **kwargs):\n \"\"\"\n Send an HTTP request, and catch any exception that might occur due to either\n connection problems, or invalid HTTP status codes\n \"\"\"\n try:\n return self.client.urlopen(url, method=method, **kwargs)\n except FAILURE_EXCEPTIONS as e:\n if hasattr(e, \"response\"):\n r = e.response\n else:\n req = self.client._make_request(\n url,\n method=method,\n headers=kwargs.get(\"headers\"),\n payload=kwargs.get(\"payload\"),\n params=kwargs.get(\"params\"),\n )\n r = ErrorResponse(url=url, request=req)\n r.error = e\n return r\n\n def request(\n self,\n method: str,\n url: str,\n name: str | None = None,\n data: str | dict | None = None,\n catch_response: bool = False,\n stream: bool = False,\n headers: dict | None = None,\n auth=None,\n json: dict | None = None,\n allow_redirects=True,\n context: dict = {},\n **kwargs,\n ) -> ResponseContextManager | FastResponse:\n \"\"\"\n Send and HTTP request\n Returns :py:class:`locust.contrib.fasthttp.FastResponse` object.\n\n :param method: method for the new :class:`Request` object.\n :param url: path that will be concatenated with the base host URL that has been specified.\n Can also be a full URL, in which case the full URL will be requested, and the base host\n is ignored.\n :param name: (optional) An argument that can be specified to use as label in Locust's\n statistics instead of the URL path. This can be used to group different URL's\n that are requested into a single entry in Locust's statistics.\n :param catch_response: (optional) Boolean argument that, if set, can be used to make a request\n return a context manager to work as argument to a with statement. This will allow the\n request to be marked as a fail based on the content of the response, even if the response\n code is ok (2xx). 
The opposite also works, one can use catch_response to catch a request\n and then mark it as successful even if the response code was not (i.e 500 or 404).\n :param data: (optional) String/bytes to send in the body of the request.\n :param json: (optional) Dictionary to send in the body of the request.\n Automatically sets Content-Type and Accept headers to \"application/json\".\n Only used if data is not set.\n :param headers: (optional) Dictionary of HTTP Headers to send with the request.\n :param auth: (optional) Auth (username, password) tuple to enable Basic HTTP Auth.\n :param stream: (optional) If set to true the response body will not be consumed immediately\n and can instead be consumed by accessing the stream attribute on the Response object.\n Another side effect of setting stream to True is that the time for downloading the response\n content will not be accounted for in the request time that is reported by Locust.\n \"\"\"\n # prepend url with hostname unless it's already an absolute URL\n built_url = self._build_url(url)\n\n start_time = time.time() # seconds since epoch\n\n if self.user:\n context = {**self.user.context(), **context}\n\n headers = headers or {}\n if auth:\n headers[\"Authorization\"] = _construct_basic_auth_str(auth[0], auth[1])\n elif self.auth_header:\n headers[\"Authorization\"] = self.auth_header\n if \"Accept-Encoding\" not in headers and \"accept-encoding\" not in headers:\n headers[\"Accept-Encoding\"] = \"gzip, deflate\"\n\n if not data and json is not None:\n data = unshadowed_json.dumps(json)\n if \"Content-Type\" not in headers and \"content-type\" not in headers:\n headers[\"Content-Type\"] = \"application/json\"\n if \"Accept\" not in headers and \"accept\" not in headers:\n headers[\"Accept\"] = \"application/json\"\n\n if not allow_redirects:\n old_redirect_response_codes = self.client.redirect_resonse_codes\n self.client.redirect_resonse_codes = []\n\n start_perf_counter = time.perf_counter()\n # send request, and catch any exceptions\n response = self._send_request_safe_mode(method, built_url, payload=data, headers=headers, **kwargs)\n request_meta = {\n \"request_type\": method,\n \"name\": name or url,\n \"context\": context,\n \"response\": response,\n \"exception\": None,\n \"start_time\": start_time,\n \"url\": built_url, # this is a small deviation from HttpSession, which gets the final (possibly redirected) URL\n }\n\n if not allow_redirects:\n self.client.redirect_resonse_codes = old_redirect_response_codes\n\n # get the length of the content, but if the argument stream is set to True, we take\n # the size from the content-length header, in order to not trigger fetching of the body\n if stream:\n request_meta[\"response_length\"] = int(response.headers.get(\"response_length\") or 0)\n else:\n try:\n request_meta[\"response_length\"] = len(response.content or \"\")\n except HTTPParseError as e:\n request_meta[\"response_time\"] = (time.perf_counter() - start_perf_counter) * 1000\n request_meta[\"response_length\"] = 0\n request_meta[\"exception\"] = e\n self.environment.events.request.fire(**request_meta)\n return response\n\n # Record the consumed time\n # Note: This is intentionally placed after we record the content_size above, since\n # we'll then trigger fetching of the body (unless stream=True)\n request_meta[\"response_time\"] = int((time.perf_counter() - start_perf_counter) * 1000)\n\n if catch_response:\n return ResponseContextManager(response, environment=self.environment, request_meta=request_meta)\n else:\n try:\n 
response.raise_for_status()\n except FAILURE_EXCEPTIONS as e:\n request_meta[\"exception\"] = e\n\n self.environment.events.request.fire(**request_meta)\n return response\n\n def delete(self, url, **kwargs):\n return self.request(\"DELETE\", url, **kwargs)\n\n def get(self, url, **kwargs):\n \"\"\"Sends a GET request\"\"\"\n return self.request(\"GET\", url, **kwargs)\n\n def head(self, url, **kwargs):\n \"\"\"Sends a HEAD request\"\"\"\n return self.request(\"HEAD\", url, **kwargs)\n\n def options(self, url, **kwargs):\n \"\"\"Sends a OPTIONS request\"\"\"\n return self.request(\"OPTIONS\", url, **kwargs)\n\n def patch(self, url, data=None, **kwargs):\n \"\"\"Sends a POST request\"\"\"\n...\n", "needle_token_start": 13678, "needle_token_end": 13783, "code_context_ntokens": 16164, "output": ["```python\ndef _construct_basic_auth_str(username, password):\n \"\"\"Construct Authorization header value to be used in HTTP Basic Auth\"\"\"\n if isinstance(username, str):\n username = username.encode(\"latin1\")\n if isinstance(password, str):\n password = password.encode(\"latin1\")\n return \"Basic \" + b64encode(b\":\".join((username, password))).strip().decode(\"ascii\")\n```"]} +{"repo": "locustio/locust", "name": "test_distribute_users", "language": "python", "path": "locust/test/test_dispatch.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to validate the distribution mechanism of users across multiple worker nodes, ensuring that the total number of users distributed matches a predefined target while measuring the efficiency and performance of the distribution process.\n2. **Input**: The function takes different sets of user classes (with varying distribution weights and fixed counts) and a large number of worker nodes as input parameters.\n3. **Output**: The primary output is a validation of whether the total number of users distributed across all worker nodes matches the expected target. Additionally, it outputs performance metrics indicating the time taken for the distribution process.\n4. **Procedure**: The function initiates by creating a list of worker nodes. It then iterates over different configurations of user classes, setting up a user distribution system for each configuration. 
For each setup, it measures the time taken to distribute a target number of users across the worker nodes, checks the total count of distributed users, and asserts that the distribution time is within acceptable limits.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=spawn_rate)\n self.assertEqual(users_dispatcher._wait_between_dispatch, expected_wait_between_dispatch)\n\n\nclass TestRampDownUsersToZero(unittest.TestCase):\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_0_5(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=0.5)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 3, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n 
self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n # def test_ramp_down_users_on_workers_respecting_weight(self):\n # class User1(User):\n # weight = 1\n #\n # class User2(User):\n # weight = 1\n #\n # class User3(User):\n # weight = 1\n #\n # user_classes = [User1, User2, User3]\n # workers = [WorkerNode(str(i + 1)) for i in range(3)]\n #\n # user_dispatcher = UsersDispatcher(worker_nodes= workers, user_classes = user_classes)\n # user_dispatcher.new_dispatch(target_user_count=7, spawn_rate=7)\n #\n # dispatched_users = next(user_dispatcher)\n # self.assertDictEqual(dispatched_users,\n # {\n # \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n # \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n # \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2}\n # })\n #\n # user_dispatcher.new_dispatch(target_user_count=16, spawn_rate=9)\n # dispatched_users = next(user_dispatcher)\n # self.assertDictEqual(dispatched_users,\n # {\n # \"1\": {\"User1\": 6, \"User2\": 0, \"User3\": 0},\n # \"2\": {\"User1\": 0, \"User2\": 5, \"User3\": 0},\n # \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 5}\n # })\n #\n # user_dispatcher.new_dispatch(target_user_count=3, spawn_rate=15)\n # dispatched_users = next(user_dispatcher)\n # self.assertDictEqual(dispatched_users,\n # {\n # \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n # \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n # \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1}\n # })\n #\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_1(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n 
self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 3, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_4_workers_with_spawn_rate_of_1(self):\n class 
User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(4)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 1},\n \"3\": {\"User1\": 1, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 1},\n \"3\": {\"User1\": 1, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 1},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 1, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= 
sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"4\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_2(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=2)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time 
- _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_2_4(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=2.4)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_3(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 
# Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=3)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 2},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 1, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_4(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=4)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 2, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 2, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 1},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 1, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n def test_ramp_down_users_to_3_workers_with_spawn_rate_of_9(self):\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n 
weight = 1\n\n user_classes = [User1, User2, User3]\n\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n initial_user_count = 9\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=initial_user_count, spawn_rate=initial_user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=0, spawn_rate=9)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 0},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n\n@unittest.skip(reason=\"takes too long. run this manually if you change dispatch logic.\")\nclass TestRampUpThenDownThenUp(unittest.TestCase):\n def test_ramp_up_then_down_then_up(self):\n for user1_weight, user2_weight, user3_weight, user4_weight, user5_weight in [\n (1, 1, 1, 1, 1),\n (1, 2, 3, 4, 5),\n (1, 3, 5, 7, 9),\n ]:\n\n class User1(User):\n weight = user1_weight\n\n class User2(User):\n weight = user2_weight\n\n class User3(User):\n weight = user3_weight\n\n class User4(User):\n weight = user4_weight\n\n class User5(User):\n weight = user5_weight\n\n all_user_classes = [User1, User2, User3, User4, User5]\n\n for number_of_user_classes in range(1, len(all_user_classes) + 1):\n user_classes = all_user_classes[:number_of_user_classes]\n\n for max_user_count, min_user_count in [(30, 15), (54, 21), (14165, 1476)]:\n for worker_count in [1, 3, 5, 9]:\n workers = [WorkerNode(str(i + 1)) for i in range(worker_count)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n\n # Ramp-up to go to `min_user_count` #########\n\n users_dispatcher.new_dispatch(target_user_count=min_user_count, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users_ramp_up_to_min_user_count = list(users_dispatcher)\n\n # Ramp-up to go to `max_user_count` #########\n\n users_dispatcher.new_dispatch(target_user_count=max_user_count, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n list(users_dispatcher)\n\n # Ramp-down go back to `min_user_count` #########\n\n users_dispatcher.new_dispatch(target_user_count=min_user_count, spawn_rate=1)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users_ramp_down_to_min_user_count = list(users_dispatcher)\n\n # Assertions #########\n\n self.assertDictEqual(\n all_dispatched_users_ramp_up_to_min_user_count[-1],\n all_dispatched_users_ramp_down_to_min_user_count[-1],\n )\n\n\nclass TestDispatchUsersToWorkersHavingTheSameUsersAsTheTarget(unittest.TestCase):\n def test_dispatch_users_to_3_workers(self):\n \"\"\"Final distribution should be {\"User1\": 3, \"User2\": 3, \"User3\": 3}\"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 1\n\n class User3(User):\n weight = 1\n\n user_classes = [User1, User2, User3]\n\n user_count = 9\n\n for spawn_rate in [0.15, 0.5, 1, 2, 2.4, 3, 4, 9]:\n workers = [WorkerNode(str(i + 1)) for i in range(3)]\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, 
user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=user_count, spawn_rate=user_count)\n users_dispatcher._wait_between_dispatch = 0\n list(users_dispatcher)\n\n sleep_time = 0.2 # Speed-up test\n\n users_dispatcher.new_dispatch(target_user_count=user_count, spawn_rate=spawn_rate)\n users_dispatcher._wait_between_dispatch = sleep_time\n\n ts = time.perf_counter()\n self.assertDictEqual(\n next(users_dispatcher),\n {\n \"1\": {\"User1\": 3, \"User2\": 0, \"User3\": 0},\n \"2\": {\"User1\": 0, \"User2\": 3, \"User3\": 0},\n \"3\": {\"User1\": 0, \"User2\": 0, \"User3\": 3},\n },\n )\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n ts = time.perf_counter()\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n delta = time.perf_counter() - ts\n self.assertTrue(0 <= delta <= _TOLERANCE, delta)\n\n clear_all_functools_lru_cache()\n\n\nclass TestDistributionIsRespectedDuringDispatch(unittest.TestCase):\n def test_dispatch_75_users_to_4_workers_with_spawn_rate_of_5(self):\n \"\"\"\n Test case covering reported issue in https://github.com/locustio/locust/pull/1621#issuecomment-853624275.\n\n The case is to ramp-up from 0 to 75 users with two user classes. `User1` has a weight of 1 and `User2`\n has a weight of 2. The original issue was with 500 users, but to keep the test shorter, we use 75 users.\n\n Final distribution should be {\"User1\": 25, \"User2\": 50}\n \"\"\"\n\n class User1(User):\n weight = 1\n\n class User2(User):\n weight = 2\n\n worker_node1 = WorkerNode(\"1\")\n worker_node2 = WorkerNode(\"2\")\n worker_node3 = WorkerNode(\"3\")\n worker_node4 = WorkerNode(\"4\")\n\n users_dispatcher = UsersDispatcher(\n worker_nodes=[worker_node1, worker_node2, worker_node3, worker_node4], user_classes=[User1, User2]\n )\n users_dispatcher.new_dispatch(target_user_count=75, spawn_rate=5)\n users_dispatcher._wait_between_dispatch = 0\n\n # total user count = 5\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 2, \"User2\": 3})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 1, \"User2\": 1},\n \"2\": {\"User1\": 1, \"User2\": 0},\n \"3\": {\"User1\": 0, \"User2\": 1},\n \"4\": {\"User1\": 0, \"User2\": 1},\n },\n )\n\n # total user count = 10\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 3, \"User2\": 7})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 1, \"User2\": 2},\n \"2\": {\"User1\": 1, \"User2\": 2},\n \"3\": {\"User1\": 0, \"User2\": 2},\n \"4\": {\"User1\": 1, \"User2\": 1},\n },\n )\n\n # total user count = 15\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 5, \"User2\": 10})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 1, \"User2\": 3},\n \"2\": {\"User1\": 2, \"User2\": 2},\n \"3\": {\"User1\": 1, \"User2\": 3},\n \"4\": {\"User1\": 1, \"User2\": 2},\n },\n )\n\n # total user count = 20\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 7, \"User2\": 13})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 2, \"User2\": 3},\n \"2\": {\"User1\": 2, \"User2\": 3},\n \"3\": {\"User1\": 1, \"User2\": 4},\n \"4\": {\"User1\": 2, \"User2\": 3},\n },\n )\n\n # total user count = 25\n dispatched_users = next(users_dispatcher)\n 
self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 8, \"User2\": 17})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 2, \"User2\": 5},\n \"2\": {\"User1\": 2, \"User2\": 4},\n \"3\": {\"User1\": 2, \"User2\": 4},\n \"4\": {\"User1\": 2, \"User2\": 4},\n },\n )\n\n # total user count = 30\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 10, \"User2\": 20})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 3, \"User2\": 5},\n \"2\": {\"User1\": 3, \"User2\": 5},\n \"3\": {\"User1\": 2, \"User2\": 5},\n \"4\": {\"User1\": 2, \"User2\": 5},\n },\n )\n\n # total user count = 35\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 12, \"User2\": 23})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 3, \"User2\": 6},\n \"2\": {\"User1\": 3, \"User2\": 6},\n \"3\": {\"User1\": 3, \"User2\": 6},\n \"4\": {\"User1\": 3, \"User2\": 5},\n },\n )\n\n # total user count = 40\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 13, \"User2\": 27})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 3, \"User2\": 7},\n \"2\": {\"User1\": 4, \"User2\": 6},\n \"3\": {\"User1\": 3, \"User2\": 7},\n \"4\": {\"User1\": 3, \"User2\": 7},\n },\n )\n\n # total user count = 45\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 15, \"User2\": 30})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 4, \"User2\": 8},\n \"2\": {\"User1\": 4, \"User2\": 7},\n \"3\": {\"User1\": 3, \"User2\": 8},\n \"4\": {\"User1\": 4, \"User2\": 7},\n },\n )\n\n # total user count = 50\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 17, \"User2\": 33})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 4, \"User2\": 9},\n \"2\": {\"User1\": 5, \"User2\": 8},\n \"3\": {\"User1\": 4, \"User2\": 8},\n \"4\": {\"User1\": 4, \"User2\": 8},\n },\n )\n\n # total user count = 55\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 18, \"User2\": 37})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 5, \"User2\": 9},\n \"2\": {\"User1\": 5, \"User2\": 9},\n \"3\": {\"User1\": 4, \"User2\": 10},\n \"4\": {\"User1\": 4, \"User2\": 9},\n },\n )\n\n # total user count = 60\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 20, \"User2\": 40})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 5, \"User2\": 10},\n \"2\": {\"User1\": 5, \"User2\": 10},\n \"3\": {\"User1\": 5, \"User2\": 10},\n \"4\": {\"User1\": 5, \"User2\": 10},\n },\n )\n\n # total user count = 65\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 22, \"User2\": 43})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 6, \"User2\": 11},\n \"2\": {\"User1\": 6, \"User2\": 10},\n \"3\": {\"User1\": 5, \"User2\": 11},\n \"4\": {\"User1\": 5, \"User2\": 11},\n },\n )\n\n # total user count = 70\n dispatched_users = next(users_dispatcher)\n 
self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 23, \"User2\": 47})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 6, \"User2\": 12},\n \"2\": {\"User1\": 6, \"User2\": 12},\n \"3\": {\"User1\": 5, \"User2\": 12},\n \"4\": {\"User1\": 6, \"User2\": 11},\n },\n )\n\n # total user count = 75, User1 = 25, User2 = 50\n dispatched_users = next(users_dispatcher)\n self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {\"User1\": 25, \"User2\": 50})\n self.assertDictEqual(\n dispatched_users,\n {\n \"1\": {\"User1\": 6, \"User2\": 13},\n \"2\": {\"User1\": 7, \"User2\": 12},\n \"3\": {\"User1\": 6, \"User2\": 13},\n \"4\": {\"User1\": 6, \"User2\": 12},\n },\n )\n\n self.assertRaises(StopIteration, lambda: next(users_dispatcher))\n\n\nclass TestLargeScale(unittest.TestCase):\n # fmt: off\n weights = [\n 5, 55, 37, 2, 97, 41, 33, 19, 19, 34, 78, 76, 28, 62, 69, 5, 55, 37, 2, 97, 41, 33, 19, 19, 34,\n 78, 76, 28, 62, 69, 41, 33, 19, 19, 34, 78, 76, 28, 62, 69, 41, 33, 19, 19, 34, 78, 76, 28, 62, 69\n ]\n # fmt: on\n numerated_weights = dict(zip(range(len(weights)), weights))\n\n weighted_user_classes = [type(f\"User{i}\", (User,), {\"weight\": w}) for i, w in numerated_weights.items()]\n fixed_user_classes_10k = [type(f\"FixedUser10k{i}\", (User,), {\"fixed_count\": 2000}) for i in range(50)]\n fixed_user_classes_1M = [type(f\"FixedUser1M{i}\", (User,), {\"fixed_count\": 20000}) for i in range(50)]\n mixed_users = weighted_user_classes[:25] + fixed_user_classes_10k[25:]\n\n \ndef test_distribute_users(self):\n for user_classes in [self.weighted_user_classes, self.fixed_user_classes_1M, self.mixed_users]:\n workers = [WorkerNode(str(i)) for i in range(10_000)]\n\n target_user_count = 1_000_000\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n\n ts = time.perf_counter()\n users_on_workers, user_gen, worker_gen, active_users = users_dispatcher._distribute_users(\n target_user_count=target_user_count\n )\n delta = time.perf_counter() - ts\n\n # Because tests are run with coverage, the code will be slower.\n # We set the pass criterion to 7000ms, but in real life, the\n # `_distribute_users` method runs faster than this.\n self.assertLessEqual(1000 * delta, 7000)\n\n self.assertEqual(_user_count(users_on_workers), target_user_count)\n\n def test_ramp_up_from_0_to_100_000_users_with_50_user_classes_and_1000_workers_and_5000_spawn_rate(self):\n for user_classes in [\n self.weighted_user_classes,\n self.fixed_user_classes_1M,\n self.fixed_user_classes_10k,\n self.mixed_users,\n ]:\n workers = [WorkerNode(str(i)) for i in range(1000)]\n\n target_user_count = 100_000\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=target_user_count, spawn_rate=5_000)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users = list(users_dispatcher)\n\n tol = 0.2\n self.assertTrue(\n all(\n dispatch_iteration_duration <= tol\n for dispatch_iteration_duration in users_dispatcher.dispatch_iteration_durations\n ),\n \"One or more dispatch took more than {:.0f}s to compute (max = {}ms)\".format(\n tol * 1000, 1000 * max(users_dispatcher.dispatch_iteration_durations)\n ),\n )\n\n self.assertEqual(_user_count(all_dispatched_users[-1]), target_user_count)\n\n for dispatch_users in all_dispatched_users:\n user_count_on_workers = [\n sum(user_classes_count.values()) for user_classes_count in dispatch_users.values()\n ]\n 
self.assertLessEqual(\n max(user_count_on_workers) - min(user_count_on_workers),\n 1,\n \"One or more workers have too much users compared to the other workers when user count is {}\".format(\n _user_count(dispatch_users)\n ),\n )\n\n for i, dispatch_users in enumerate(all_dispatched_users):\n aggregated_dispatched_users = _aggregate_dispatched_users(dispatch_users)\n for user_class in [u for u in user_classes if not u.fixed_count]:\n target_relative_weight = user_class.weight / sum(\n map(attrgetter(\"weight\"), [u for u in user_classes if not u.fixed_count])\n )\n...\n", "needle_token_start": 15216, "needle_token_end": 15504, "code_context_ntokens": 16164, "output": ["```python\ndef test_distribute_users(self):\n for user_classes in [self.weighted_user_classes, self.fixed_user_classes_1M, self.mixed_users]:\n workers = [WorkerNode(str(i)) for i in range(10_000)]\n\n target_user_count = 1_000_000\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n\n ts = time.perf_counter()\n users_on_workers, user_gen, worker_gen, active_users = users_dispatcher._distribute_users(\n target_user_count=target_user_count\n )\n delta = time.perf_counter() - ts\n\n # Because tests are run with coverage, the code will be slower.\n # We set the pass criterion to 7000ms, but in real life, the\n # `_distribute_users` method runs faster than this.\n self.assertLessEqual(1000 * delta, 7000)\n\n self.assertEqual(_user_count(users_on_workers), target_user_count)\n\n def test_ramp_up_from_0_to_100_000_users_with_50_user_classes_and_1000_workers_and_5000_spawn_rate(self):\n for user_classes in [\n self.weighted_user_classes,\n self.fixed_user_classes_1M,\n self.fixed_user_classes_10k,\n self.mixed_users,\n ]:\n workers = [WorkerNode(str(i)) for i in range(1000)]\n\n target_user_count = 100_000\n\n users_dispatcher = UsersDispatcher(worker_nodes=workers, user_classes=user_classes)\n users_dispatcher.new_dispatch(target_user_count=target_user_count, spawn_rate=5_000)\n users_dispatcher._wait_between_dispatch = 0\n\n all_dispatched_users = list(users_dispatcher)\n\n tol = 0.2\n self.assertTrue(\n all(\n dispatch_iteration_duration <= tol\n for dispatch_iteration_duration in users_dispatcher.dispatch_iteration_durations\n ),\n \"One or more dispatch took more than {:.0f}s to compute (max = {}ms)\".format(\n tol * 1000, 1000 * max(users_dispatcher.dispatch_iteration_durations)\n ),\n )\n\n self.assertEqual(_user_count(all_dispatched_users[-1]), target_user_count)\n\n for dispatch_users in all_dispatched_users:\n user_count_on_workers = [\n sum(user_classes_count.values()) for user_classes_count in dispatch_users.values()\n ]\n self.assertLessEqual(\n max(user_count_on_workers) - min(user_count_on_workers),\n 1,\n \"One or more workers have too much users compared to the other workers when user count is {}\".format(\n _user_count(dispatch_users)\n ),\n )\n\n for i, dispatch_users in enumerate(all_dispatched_users):\n aggregated_dispatched_users = _aggregate_dispatched_users(dispatch_users)\n for user_class in [u for u in user_classes if not u.fixed_count]:\n target_relative_weight = user_class.weight / sum(\n map(attrgetter(\"weight\"), [u for u in user_classes if not u.fixed_count])\n )\n...\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "maybe_init_stream", "language": "python", "path": "torch_geometric/loader/prefetch.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. 
**Purpose**: The function initializes a GPU stream if the device being used supports GPU operations. This allows for asynchronous operations on the GPU, which can improve performance by overlapping computation and data transfer.\n2. **Input**: There are no direct inputs as the function uses the class's internal state to determine if a stream should be initialized.\n3. **Output**: There is no direct output returned by the function; however, it modifies the internal state of the class by potentially creating a new GPU stream and setting up a context manager for that stream.\n4. **Procedure**: The function first checks if the current device is a GPU. If it is, the function creates a new GPU stream and sets up a context manager that will allow other methods to execute operations within this stream, thereby enabling asynchronous GPU operations.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/loader/neighbor_sampler.py\nfrom typing import Callable, List, NamedTuple, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import SparseTensor\n\n\nclass EdgeIndex(NamedTuple):\n edge_index: Tensor\n e_id: Optional[Tensor]\n size: Tuple[int, int]\n\n def to(self, *args, **kwargs):\n edge_index = self.edge_index.to(*args, **kwargs)\n e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None\n return EdgeIndex(edge_index, e_id, self.size)\n\n\nclass Adj(NamedTuple):\n adj_t: SparseTensor\n e_id: Optional[Tensor]\n size: Tuple[int, int]\n\n def to(self, *args, **kwargs):\n adj_t = self.adj_t.to(*args, **kwargs)\n e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None\n return Adj(adj_t, e_id, self.size)\n\n\nclass NeighborSampler(torch.utils.data.DataLoader):\n r\"\"\"The neighbor sampler from the `\"Inductive Representation Learning on\n Large Graphs\" `_ paper, which allows\n...\n# Path: torch_geometric/loader/prefetch.py\nimport warnings\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Any, Optional\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.typing import WITH_IPEX\n\n\nclass DeviceHelper:\n def __init__(self, device: Optional[torch.device] = None):\n with_cuda = torch.cuda.is_available()\n with_xpu = torch.xpu.is_available() if WITH_IPEX else False\n\n if device is None:\n if with_cuda:\n device = 'cuda'\n elif with_xpu:\n device = 'xpu'\n else:\n device = 'cpu'\n\n self.device = torch.device(device)\n self.is_gpu = self.device.type in ['cuda', 'xpu']\n\n if ((self.device.type == 'cuda' and not with_cuda)\n or (self.device.type == 'xpu' and not with_xpu)):\n warnings.warn(f\"Requested device '{self.device.type}' is not \"\n f\"available, falling back to CPU\")\n self.device = torch.device('cpu')\n\n self.stream = None\n self.stream_context = nullcontext\n self.module = getattr(torch, self.device.type) if self.is_gpu else None\n\n \ndef maybe_init_stream(self) -> None:\n if self.is_gpu:\n self.stream = self.module.Stream()\n self.stream_context = partial(\n self.module.stream,\n stream=self.stream,\n )\n\n def maybe_wait_stream(self) -> None:\n if self.stream is not None:\n self.module.current_stream().wait_stream(self.stream)\n\n\nclass PrefetchLoader:\n r\"\"\"A GPU prefetcher class for asynchronously 
transferring data of a\n :class:`torch.utils.data.DataLoader` from host memory to device memory.\n\n Args:\n loader (torch.utils.data.DataLoader): The data loader.\n device (torch.device, optional): The device to load the data to.\n (default: :obj:`None`)\n \"\"\"\n def __init__(\n self,\n loader: DataLoader,\n device: Optional[torch.device] = None,\n ):\n self.loader = loader\n self.device_helper = DeviceHelper(device)\n\n def non_blocking_transfer(self, batch: Any) -> Any:\n if not self.device_helper.is_gpu:\n return batch\n if isinstance(batch, (list, tuple)):\n return [self.non_blocking_transfer(v) for v in batch]\n if isinstance(batch, dict):\n return {k: self.non_blocking_transfer(v) for k, v in batch.items()}\n\n batch = batch.pin_memory(self.device_helper.device)\n return batch.to(self.device_helper.device, non_blocking=True)\n\n def __iter__(self) -> Any:\n first = True\n self.device_helper.maybe_init_stream()\n\n batch = None\n for next_batch in self.loader:\n\n with self.device_helper.stream_context():\n next_batch = self.non_blocking_transfer(next_batch)\n\n if not first:\n yield batch\n else:\n first = False\n\n self.device_helper.maybe_wait_stream()\n\n batch = next_batch\n\n yield batch\n\n def __len__(self) -> int:\n return len(self.loader)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}({self.loader})'\n\n# Path: torch_geometric/loader/random_node_loader.py\nimport math\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.hetero_data import to_homogeneous_edge_index\n\n\nclass RandomNodeLoader(torch.utils.data.DataLoader):\n r\"\"\"A data loader that randomly samples nodes within a graph and returns\n their induced subgraph.\n\n .. 
note::\n\n For an example of using\n :class:`~torch_geometric.loader.RandomNodeLoader`, see\n `examples/ogbn_proteins_deepgcn.py\n `_.\n\n Args:\n data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n The :class:`~torch_geometric.data.Data` or\n :class:`~torch_geometric.data.HeteroData` graph object.\n num_parts (int): The number of partitions.\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.\n \"\"\"\n def __init__(\n self,\n data: Union[Data, HeteroData],\n num_parts: int,\n **kwargs,\n ):\n self.data = data\n self.num_parts = num_parts\n\n if isinstance(data, HeteroData):\n edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data)\n self.node_dict, self.edge_dict = node_dict, edge_dict\n else:\n edge_index = data.edge_index\n\n self.edge_index = edge_index\n self.num_nodes = data.num_nodes\n\n super().__init__(\n range(self.num_nodes),\n batch_size=math.ceil(self.num_nodes / num_parts),\n collate_fn=self.collate_fn,\n **kwargs,\n )\n\n def collate_fn(self, index):\n if not isinstance(index, Tensor):\n index = torch.tensor(index)\n\n if isinstance(self.data, Data):\n return self.data.subgraph(index)\n\n elif isinstance(self.data, HeteroData):\n node_dict = {\n key: index[(index >= start) & (index < end)] - start\n for key, (start, end) in self.node_dict.items()\n }\n return self.data.subgraph(node_dict)\n\n# Path: torch_geometric/loader/shadow.py\nimport copy\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor\n\n\nclass ShaDowKHopSampler(torch.utils.data.DataLoader):\n r\"\"\"The ShaDow :math:`k`-hop sampler from the `\"Decoupling the Depth and\n Scope of Graph Neural Networks\" `_ paper.\n Given a graph in a :obj:`data` object, the sampler will create shallow,\n localized subgraphs.\n A deep GNN on this local graph then smooths the informative local signals.\n\n .. note::\n\n For an example of using :class:`ShaDowKHopSampler`, see\n `examples/shadow.py `_.\n\n Args:\n data (torch_geometric.data.Data): The graph data object.\n depth (int): The depth/number of hops of the localized subgraph.\n num_neighbors (int): The number of neighbors to sample for each node in\n each hop.\n node_idx (LongTensor or BoolTensor, optional): The nodes that should be\n considered for creating mini-batches.\n If set to :obj:`None`, all nodes will be\n considered.\n replace (bool, optional): If set to :obj:`True`, will sample neighbors\n with replacement. 
(default: :obj:`False`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or\n :obj:`num_workers`.\n \"\"\"\n def __init__(self, data: Data, depth: int, num_neighbors: int,\n node_idx: Optional[Tensor] = None, replace: bool = False,\n **kwargs):\n\n if not WITH_TORCH_SPARSE:\n raise ImportError(\n f\"'{self.__class__.__name__}' requires 'torch-sparse'\")\n\n self.data = copy.copy(data)\n self.depth = depth\n self.num_neighbors = num_neighbors\n self.replace = replace\n\n if data.edge_index is not None:\n self.is_sparse_tensor = False\n row, col = data.edge_index.cpu()\n self.adj_t = SparseTensor(\n row=row, col=col, value=torch.arange(col.size(0)),\n sparse_sizes=(data.num_nodes, data.num_nodes)).t()\n else:\n self.is_sparse_tensor = True\n self.adj_t = data.adj_t.cpu()\n\n if node_idx is None:\n node_idx = torch.arange(self.adj_t.sparse_size(0))\n elif node_idx.dtype == torch.bool:\n node_idx = node_idx.nonzero(as_tuple=False).view(-1)\n self.node_idx = node_idx\n\n super().__init__(node_idx.tolist(), collate_fn=self.__collate__,\n **kwargs)\n\n def __collate__(self, n_id):\n n_id = torch.tensor(n_id)\n\n rowptr, col, value = self.adj_t.csr()\n out = torch.ops.torch_sparse.ego_k_hop_sample_adj(\n rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)\n rowptr, col, n_id, e_id, ptr, root_n_id = out\n\n adj_t = SparseTensor(rowptr=rowptr, col=col,\n value=value[e_id] if value is not None else None,\n sparse_sizes=(n_id.numel(), n_id.numel()),\n is_sorted=True, trust_data=True)\n\n batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),\n ptr=ptr)\n batch.root_n_id = root_n_id\n\n if self.is_sparse_tensor:\n batch.adj_t = adj_t\n else:\n row, col, e_id = adj_t.t().coo()\n batch.edge_index = torch.stack([row, col], dim=0)\n\n for k, v in self.data:\n if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:\n continue\n if k == 'y' and v.size(0) == self.data.num_nodes:\n batch[k] = v[n_id][root_n_id]\n elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:\n batch[k] = v[n_id]\n elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:\n batch[k] = v[e_id]\n else:\n batch[k] = v\n\n return batch\n\n# Path: torch_geometric/loader/temporal_dataloader.py\nfrom typing import List\n\nimport torch\n\nfrom torch_geometric.data import TemporalData\n\n\nclass TemporalDataLoader(torch.utils.data.DataLoader):\n r\"\"\"A data loader which merges succesive events of a\n :class:`torch_geometric.data.TemporalData` to a mini-batch.\n\n Args:\n data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`\n from which to load the data.\n batch_size (int, optional): How many samples per batch to load.\n (default: :obj:`1`)\n neg_sampling_ratio (float, optional): The ratio of sampled negative\n destination nodes to the number of postive destination nodes.\n (default: :obj:`0.0`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`.\n \"\"\"\n def __init__(\n self,\n data: TemporalData,\n batch_size: int = 1,\n neg_sampling_ratio: float = 0.0,\n **kwargs,\n ):\n # Remove for PyTorch Lightning:\n kwargs.pop('dataset', None)\n kwargs.pop('collate_fn', None)\n kwargs.pop('shuffle', None)\n\n self.data = data\n self.events_per_batch = batch_size\n self.neg_sampling_ratio = neg_sampling_ratio\n\n if neg_sampling_ratio > 0:\n self.min_dst = int(data.dst.min())\n self.max_dst = int(data.dst.max())\n\n if kwargs.get('drop_last', False) and len(data) % batch_size != 0:\n 
arange = range(0, len(data) - batch_size, batch_size)\n else:\n arange = range(0, len(data), batch_size)\n\n super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)\n\n def __call__(self, arange: List[int]) -> TemporalData:\n batch = self.data[arange[0]:arange[0] + self.events_per_batch]\n\n n_ids = [batch.src, batch.dst]\n\n if self.neg_sampling_ratio > 0:\n batch.neg_dst = torch.randint(\n low=self.min_dst,\n high=self.max_dst + 1,\n size=(round(self.neg_sampling_ratio * batch.dst.size(0)), ),\n dtype=batch.dst.dtype,\n device=batch.dst.device,\n )\n n_ids += [batch.neg_dst]\n\n batch.n_id = torch.cat(n_ids, dim=0).unique()\n\n return batch\n\n# Path: torch_geometric/loader/zip_loader.py\nfrom typing import Any, Iterator, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import LinkLoader, NodeLoader\nfrom torch_geometric.loader.base import DataLoaderIterator\nfrom torch_geometric.loader.utils import infer_filter_per_worker\n\n\nclass ZipLoader(torch.utils.data.DataLoader):\n r\"\"\"A loader that returns a tuple of data objects by sampling from multiple\n :class:`NodeLoader` or :class:`LinkLoader` instances.\n\n Args:\n loaders (List[NodeLoader] or List[LinkLoader]): The loader instances.\n filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n the returned data in each worker's subprocess.\n If set to :obj:`False`, will filter the returned data in the main\n process.\n If set to :obj:`None`, will automatically infer the decision based\n on whether data partially lives on the GPU\n (:obj:`filter_per_worker=True`) or entirely on the CPU\n (:obj:`filter_per_worker=False`).\n There exists different trade-offs for setting this option.\n Specifically, setting this option to :obj:`True` for in-memory\n datasets will move all features to shared memory, which may result\n in too many open file handles. 
(default: :obj:`None`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n \"\"\"\n def __init__(\n self,\n loaders: Union[List[NodeLoader], List[LinkLoader]],\n filter_per_worker: Optional[bool] = None,\n **kwargs,\n ):\n if filter_per_worker is None:\n filter_per_worker = infer_filter_per_worker(loaders[0].data)\n\n # Remove for PyTorch Lightning:\n kwargs.pop('dataset', None)\n kwargs.pop('collate_fn', None)\n\n for loader in loaders:\n if not callable(getattr(loader, 'collate_fn', None)):\n raise ValueError(\"'{loader.__class__.__name__}' does not have \"\n \"a 'collate_fn' method\")\n if not callable(getattr(loader, 'filter_fn', None)):\n raise ValueError(\"'{loader.__class__.__name__}' does not have \"\n \"a 'filter_fn' method\")\n loader.filter_per_worker = filter_per_worker\n\n iterator = range(min([len(loader.dataset) for loader in loaders]))\n super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)\n\n self.loaders = loaders\n self.filter_per_worker = filter_per_worker\n\n def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:\n if not isinstance(index, Tensor):\n index = torch.tensor(index, dtype=torch.long)\n\n return tuple(loader.collate_fn(index) for loader in self.loaders)\n\n def filter_fn(\n self,\n outs: Tuple[Any, ...],\n ) -> Tuple[Union[Data, HeteroData], ...]:\n loaders = self.loaders\n return tuple(loader.filter_fn(v) for loader, v in zip(loaders, outs))\n\n def _get_iterator(self) -> Iterator:\n if self.filter_per_worker:\n return super()._get_iterator()\n\n # Execute `filter_fn` in the main process:\n return DataLoaderIterator(super()._get_iterator(), self.filter_fn)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}(loaders={self.loaders})'\n\n# Path: torch_geometric/loader/__init__.py\nfrom torch_geometric.deprecation import deprecated\n\nfrom .dataloader import DataLoader\nfrom .node_loader import NodeLoader\nfrom .link_loader import LinkLoader\nfrom .neighbor_loader import NeighborLoader\nfrom .link_neighbor_loader import LinkNeighborLoader\nfrom .hgt_loader import HGTLoader\nfrom .cluster import ClusterData, ClusterLoader\nfrom .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler,\n GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler)\nfrom .shadow import ShaDowKHopSampler\nfrom .random_node_loader import RandomNodeLoader\n# from .ibmb_loader import IBMBBatchLoader, IBMBNodeLoader\nfrom .zip_loader import ZipLoader\nfrom .data_list_loader import DataListLoader\nfrom .dense_data_loader import DenseDataLoader\nfrom .temporal_dataloader import TemporalDataLoader\nfrom .neighbor_sampler import NeighborSampler\nfrom .imbalanced_sampler import ImbalancedSampler\nfrom .dynamic_batch_sampler import DynamicBatchSampler\nfrom .prefetch import PrefetchLoader\nfrom .cache import CachedLoader\nfrom .mixin import AffinityMixin\n\n__all__ = classes = [\n 'DataLoader',\n 'NodeLoader',\n 'LinkLoader',\n 'NeighborLoader',\n 'LinkNeighborLoader',\n 'HGTLoader',\n 'ClusterData',\n 'ClusterLoader',\n 'GraphSAINTSampler',\n 'GraphSAINTNodeSampler',\n 'GraphSAINTEdgeSampler',\n 'GraphSAINTRandomWalkSampler',\n 'ShaDowKHopSampler',\n 'RandomNodeLoader',\n # 'IBMBBatchLoader',\n # 'IBMBNodeLoader',\n 'ZipLoader',\n 'DataListLoader',\n 'DenseDataLoader',\n 'TemporalDataLoader',\n 'NeighborSampler',\n 'ImbalancedSampler',\n 'DynamicBatchSampler',\n 'PrefetchLoader',\n 'CachedLoader',\n 'AffinityMixin',\n]\n\nRandomNodeSampler = 
deprecated(\n details=\"use 'loader.RandomNodeLoader' instead\",\n func_name='loader.RandomNodeSampler',\n)(RandomNodeLoader)\n\n# Path: torch_geometric/data/__init__.py\n# flake8: noqa\n\nfrom .feature_store import FeatureStore, TensorAttr\nfrom .graph_store import GraphStore, EdgeAttr\nfrom .data import Data\nfrom .hetero_data import HeteroData\nfrom .batch import Batch\nfrom .temporal import TemporalData\nfrom .database import Database, SQLiteDatabase, RocksDatabase\nfrom .dataset import Dataset\nfrom .in_memory_dataset import InMemoryDataset\nfrom .on_disk_dataset import OnDiskDataset\nfrom .makedirs import makedirs\nfrom .download import download_url, download_google_url\nfrom .extract import extract_tar, extract_zip, extract_bz2, extract_gz\n\nfrom torch_geometric.lazy_loader import LazyLoader\n\ndata_classes = [\n 'Data',\n 'HeteroData',\n 'Batch',\n 'TemporalData',\n 'Dataset',\n 'InMemoryDataset',\n 'OnDiskDataset',\n]\n\nremote_backend_classes = [\n 'FeatureStore',\n 'GraphStore',\n 'TensorAttr',\n 'EdgeAttr',\n]\n\ndatabase_classes = [\n 'Database',\n 'SQLiteDatabase',\n 'RocksDatabase',\n]\n\nhelper_functions = [\n 'makedirs',\n 'download_url',\n 'download_google_url',\n 'extract_tar',\n 'extract_zip',\n 'extract_bz2',\n 'extract_gz',\n]\n\n__all__ = data_classes + remote_backend_classes + helper_functions\n\nlightning = LazyLoader('lightning', globals(),\n 'torch_geometric.data.lightning')\n\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.loader import NeighborSampler\nfrom torch_geometric.loader import ClusterData\nfrom torch_geometric.loader import ClusterLoader\nfrom torch_geometric.loader import GraphSAINTSampler\nfrom torch_geometric.loader import GraphSAINTNodeSampler\nfrom torch_geometric.loader import GraphSAINTEdgeSampler\nfrom torch_geometric.loader import GraphSAINTRandomWalkSampler\nfrom torch_geometric.loader import ShaDowKHopSampler\nfrom torch_geometric.loader import RandomNodeLoader\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.loader import DataListLoader\nfrom torch_geometric.loader import DenseDataLoader\n\nNeighborSampler = deprecated( # type: ignore\n details=\"use 'loader.NeighborSampler' instead\",\n func_name='data.NeighborSampler',\n)(NeighborSampler)\nClusterData = deprecated( # type: ignore\n details=\"use 'loader.ClusterData' instead\",\n func_name='data.ClusterData',\n)(ClusterData)\nClusterLoader = deprecated( # type: ignore\n details=\"use 'loader.ClusterLoader' instead\",\n func_name='data.ClusterLoader',\n)(ClusterLoader)\nGraphSAINTSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTSampler' instead\",\n func_name='data.GraphSAINTSampler',\n)(GraphSAINTSampler)\nGraphSAINTNodeSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTNodeSampler' instead\",\n func_name='data.GraphSAINTNodeSampler',\n)(GraphSAINTNodeSampler)\nGraphSAINTEdgeSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTEdgeSampler' instead\",\n func_name='data.GraphSAINTEdgeSampler',\n)(GraphSAINTEdgeSampler)\nGraphSAINTRandomWalkSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTRandomWalkSampler' instead\",\n func_name='data.GraphSAINTRandomWalkSampler',\n)(GraphSAINTRandomWalkSampler)\nShaDowKHopSampler = deprecated( # type: ignore\n details=\"use 'loader.ShaDowKHopSampler' instead\",\n func_name='data.ShaDowKHopSampler',\n)(ShaDowKHopSampler)\nRandomNodeSampler = deprecated(\n details=\"use 'loader.RandomNodeLoader' instead\",\n 
func_name='data.RandomNodeSampler',\n)(RandomNodeLoader)\nDataLoader = deprecated( # type: ignore\n details=\"use 'loader.DataLoader' instead\",\n func_name='data.DataLoader',\n)(DataLoader)\nDataListLoader = deprecated( # type: ignore\n details=\"use 'loader.DataListLoader' instead\",\n func_name='data.DataListLoader',\n)(DataListLoader)\nDenseDataLoader = deprecated( # type: ignore\n details=\"use 'loader.DenseDataLoader' instead\",\n func_name='data.DenseDataLoader',\n)(DenseDataLoader)\n\n# Path: torch_geometric/visualization/graph.py\nfrom math import sqrt\nfrom typing import Any, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nBACKENDS = {'graphviz', 'networkx'}\n\n\ndef has_graphviz() -> bool:\n try:\n import graphviz\n except ImportError:\n return False\n\n try:\n graphviz.Digraph().pipe()\n except graphviz.backend.ExecutableNotFound:\n return False\n\n return True\n\n\ndef visualize_graph(\n edge_index: Tensor,\n edge_weight: Optional[Tensor] = None,\n path: Optional[str] = None,\n backend: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n) -> Any:\n r\"\"\"Visualizes the graph given via :obj:`edge_index` and (optional)\n :obj:`edge_weight`.\n\n Args:\n edge_index (torch.Tensor): The edge indices.\n edge_weight (torch.Tensor, optional): The edge weights.\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n backend (str, optional): The graph drawing backend to use for\n visualization (:obj:`\"graphviz\"`, :obj:`\"networkx\"`).\n If set to :obj:`None`, will use the most appropriate\n visualization backend based on available system packages.\n (default: :obj:`None`)\n node_labels (List[str], optional): The labels/IDs of nodes.\n (default: :obj:`None`)\n \"\"\"\n if edge_weight is not None: # Normalize edge weights.\n edge_weight = edge_weight - edge_weight.min()\n edge_weight = edge_weight / edge_weight.max()\n\n if edge_weight is not None: # Discard any edges with zero edge weight:\n mask = edge_weight > 1e-7\n edge_index = edge_index[:, mask]\n edge_weight = edge_weight[mask]\n\n if edge_weight is None:\n edge_weight = torch.ones(edge_index.size(1))\n\n if backend is None:\n backend = 'graphviz' if has_graphviz() else 'networkx'\n\n if backend.lower() == 'networkx':\n return _visualize_graph_via_networkx(edge_index, edge_weight, path,\n node_labels)\n elif backend.lower() == 'graphviz':\n return _visualize_graph_via_graphviz(edge_index, edge_weight, path,\n node_labels)\n\n raise ValueError(f\"Expected graph drawing backend to be in \"\n f\"{BACKENDS} (got '{backend}')\")\n\n\ndef _visualize_graph_via_graphviz(\n edge_index: Tensor,\n edge_weight: Tensor,\n path: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n) -> Any:\n import graphviz\n\n suffix = path.split('.')[-1] if path is not None else None\n g = graphviz.Digraph('graph', format=suffix)\n g.attr('node', shape='circle', fontsize='11pt')\n\n for node in edge_index.view(-1).unique().tolist():\n g.node(str(node) if node_labels is None else node_labels[node])\n\n for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n hex_color = hex(255 - round(255 * w))[2:]\n hex_color = f'{hex_color}0' if len(hex_color) == 1 else hex_color\n if node_labels is not None:\n src = node_labels[src]\n dst = node_labels[dst]\n g.edge(str(src), str(dst), color=f'#{hex_color}{hex_color}{hex_color}')\n\n if path is not None:\n path = '.'.join(path.split('.')[:-1])\n g.render(path, 
cleanup=True)\n else:\n g.view()\n\n return g\n\n\ndef _visualize_graph_via_networkx(\n edge_index: Tensor,\n edge_weight: Tensor,\n path: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n) -> Any:\n import matplotlib.pyplot as plt\n import networkx as nx\n\n g = nx.DiGraph()\n node_size = 800\n\n for node in edge_index.view(-1).unique().tolist():\n g.add_node(node if node_labels is None else node_labels[node])\n\n for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n if node_labels is not None:\n src = node_labels[src]\n dst = node_labels[dst]\n g.add_edge(src, dst, alpha=w)\n\n ax = plt.gca()\n pos = nx.spring_layout(g)\n for src, dst, data in g.edges(data=True):\n ax.annotate(\n '',\n xy=pos[src],\n xytext=pos[dst],\n arrowprops=dict(\n arrowstyle=\"->\",\n alpha=data['alpha'],\n shrinkA=sqrt(node_size) / 2.0,\n shrinkB=sqrt(node_size) / 2.0,\n connectionstyle=\"arc3,rad=0.1\",\n ),\n )\n\n nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,\n node_color='white', margins=0.1)\n nodes.set_edgecolor('black')\n nx.draw_networkx_labels(g, pos, font_size=10)\n\n if path is not None:\n plt.savefig(path)\n else:\n plt.show()\n\n plt.close()\n\n# Path: torch_geometric/visualization/influence.py\nfrom typing import Any\n\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import grad\n\n\ndef influence(model: torch.nn.Module, src: Tensor, *args: Any) -> Tensor:\n x = src.clone().requires_grad_()\n out = model(x, *args).sum(dim=-1)\n\n influences = []\n for j in range(src.size(0)):\n influence = grad([out[j]], [x], retain_graph=True)[0].abs().sum(dim=-1)\n influences.append(influence / influence.sum())\n\n return torch.stack(influences, dim=0)\n\n# Path: torch_geometric/visualization/__init__.py\nr\"\"\"Visualization package.\"\"\"\n\nfrom .graph import visualize_graph\nfrom .influence import influence\n\n__all__ = [\n 'visualize_graph',\n 'influence',\n]\n\n# Path: torch_geometric/explain/explanation.py\nimport copy\nfrom typing import Dict, List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import Data, warn_or_raise\nfrom torch_geometric.data.hetero_data import HeteroData\nfrom torch_geometric.explain.config import ThresholdConfig, ThresholdType\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.visualization import visualize_graph\n\n\nclass ExplanationMixin:\n @property\n def available_explanations(self) -> List[str]:\n \"\"\"Returns the available explanation masks.\"\"\"\n return [key for key in self.keys() if key.endswith('_mask')]\n\n def validate_masks(self, raise_on_error: bool = True) -> bool:\n r\"\"\"Validates the correctness of the :class:`Explanation` masks.\"\"\"\n status = True\n\n for store in self.node_stores:\n if 'node_mask' not in store:\n continue\n\n if store.node_mask.dim() != 2:\n status = False\n warn_or_raise(\n f\"Expected a 'node_mask' with two dimensions (got \"\n f\"{store.node_mask.dim()} dimensions)\", raise_on_error)\n\n if store.node_mask.size(0) not in {1, store.num_nodes}:\n status = False\n warn_or_raise(\n f\"Expected a 'node_mask' with {store.num_nodes} nodes \"\n f\"(got {store.node_mask.size(0)} nodes)\", raise_on_error)\n\n if 'x' in store:\n num_features = store.x.size(-1)\n else:\n num_features = store.node_mask.size(-1)\n\n if store.node_mask.size(1) not in {1, num_features}:\n status = False\n warn_or_raise(\n f\"Expected a 'node_mask' with {num_features} features (\"\n f\"got {store.node_mask.size(1)} features)\", 
raise_on_error)\n\n for store in self.edge_stores:\n if 'edge_mask' not in store:\n continue\n\n if store.edge_mask.dim() != 1:\n status = False\n warn_or_raise(\n f\"Expected an 'edge_mask' with one dimension (got \"\n f\"{store.edge_mask.dim()} dimensions)\", raise_on_error)\n\n if store.edge_mask.size(0) != store.num_edges:\n status = False\n warn_or_raise(\n f\"Expected an 'edge_mask' with {store.num_edges} edges \"\n f\"(got {store.edge_mask.size(0)} edges)\", raise_on_error)\n\n return status\n\n def _threshold_mask(\n self,\n mask: Optional[Tensor],\n threshold_config: ThresholdConfig,\n ) -> Optional[Tensor]:\n\n if mask is None:\n return None\n\n if threshold_config.type == ThresholdType.hard:\n return (mask > threshold_config.value).float()\n\n if threshold_config.type in [\n ThresholdType.topk,\n ThresholdType.topk_hard,\n ]:\n if threshold_config.value >= mask.numel():\n if threshold_config.type == ThresholdType.topk:\n return mask\n else:\n return torch.ones_like(mask)\n\n value, index = torch.topk(\n mask.flatten(),\n k=threshold_config.value,\n )\n\n out = torch.zeros_like(mask.flatten())\n if threshold_config.type == ThresholdType.topk:\n out[index] = value\n else:\n out[index] = 1.0\n return out.view(mask.size())\n\n assert False\n\n def threshold(\n self,\n *args,\n **kwargs,\n ) -> Union['Explanation', 'HeteroExplanation']:\n \"\"\"Thresholds the explanation masks according to the thresholding\n method.\n\n Args:\n *args: Arguments passed to :class:`ThresholdConfig`.\n **kwargs: Keyword arguments passed to :class:`ThresholdConfig`.\n \"\"\"\n threshold_config = ThresholdConfig.cast(*args, **kwargs)\n\n if threshold_config is None:\n return self\n\n # Avoid modification of the original explanation:\n out = copy.copy(self)\n\n for store in out.node_stores:\n store.node_mask = self._threshold_mask(store.get('node_mask'),\n threshold_config)\n\n for store in out.edge_stores:\n store.edge_mask = self._threshold_mask(store.get('edge_mask'),\n threshold_config)\n\n return out\n\n\nclass Explanation(Data, ExplanationMixin):\n r\"\"\"Holds all the obtained explanations of a homogeneous graph.\n\n The explanation object is a :obj:`~torch_geometric.data.Data` object and\n can hold node attributions and edge attributions.\n It can also hold the original graph if needed.\n\n Args:\n node_mask (Tensor, optional): Node-level mask with shape\n :obj:`[num_nodes, 1]`, :obj:`[1, num_features]` or\n :obj:`[num_nodes, num_features]`. (default: :obj:`None`)\n edge_mask (Tensor, optional): Edge-level mask with shape\n :obj:`[num_edges]`. 
(default: :obj:`None`)\n **kwargs (optional): Additional attributes.\n \"\"\"\n def validate(self, raise_on_error: bool = True) -> bool:\n r\"\"\"Validates the correctness of the :class:`Explanation` object.\"\"\"\n status = super().validate(raise_on_error)\n status &= self.validate_masks(raise_on_error)\n return status\n\n def get_explanation_subgraph(self) -> 'Explanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with\n zero attribution are masked out.\n \"\"\"\n node_mask = self.get('node_mask')\n if node_mask is not None:\n node_mask = node_mask.sum(dim=-1) > 0\n edge_mask = self.get('edge_mask')\n if edge_mask is not None:\n edge_mask = edge_mask > 0\n return self._apply_masks(node_mask, edge_mask)\n\n def get_complement_subgraph(self) -> 'Explanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with any\n attribution are masked out.\n \"\"\"\n node_mask = self.get('node_mask')\n if node_mask is not None:\n node_mask = node_mask.sum(dim=-1) == 0\n edge_mask = self.get('edge_mask')\n if edge_mask is not None:\n edge_mask = edge_mask == 0\n return self._apply_masks(node_mask, edge_mask)\n\n def _apply_masks(\n self,\n node_mask: Optional[Tensor] = None,\n edge_mask: Optional[Tensor] = None,\n ) -> 'Explanation':\n out = copy.copy(self)\n\n if edge_mask is not None:\n for key, value in self.items():\n if key == 'edge_index':\n out.edge_index = value[:, edge_mask]\n elif self.is_edge_attr(key):\n out[key] = value[edge_mask]\n\n if node_mask is not None:\n out = out.subgraph(node_mask)\n\n return out\n\n def visualize_feature_importance(\n self,\n path: Optional[str] = None,\n feat_labels: Optional[List[str]] = None,\n top_k: Optional[int] = None,\n ):\n r\"\"\"Creates a bar plot of the node feature importances by summing up\n the node mask across all nodes.\n\n Args:\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n feat_labels (List[str], optional): The labels of features.\n (default :obj:`None`)\n top_k (int, optional): Top k features to plot. If :obj:`None`\n plots all features. 
(default: :obj:`None`)\n \"\"\"\n node_mask = self.get('node_mask')\n if node_mask is None:\n raise ValueError(f\"The attribute 'node_mask' is not available \"\n f\"in '{self.__class__.__name__}' \"\n f\"(got {self.available_explanations})\")\n if node_mask.dim() != 2 or node_mask.size(1) <= 1:\n raise ValueError(f\"Cannot compute feature importance for \"\n f\"object-level 'node_mask' \"\n f\"(got shape {node_mask.size()})\")\n\n if feat_labels is None:\n feat_labels = range(node_mask.size(1))\n\n score = node_mask.sum(dim=0)\n\n return _visualize_score(score, feat_labels, path, top_k)\n\n def visualize_graph(\n self,\n path: Optional[str] = None,\n backend: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n ) -> None:\n r\"\"\"Visualizes the explanation graph with edge opacity corresponding to\n edge importance.\n\n Args:\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n backend (str, optional): The graph drawing backend to use for\n visualization (:obj:`\"graphviz\"`, :obj:`\"networkx\"`).\n If set to :obj:`None`, will use the most appropriate\n visualization backend based on available system packages.\n (default: :obj:`None`)\n node_labels (list[str], optional): The labels/IDs of nodes.\n (default: :obj:`None`)\n \"\"\"\n edge_mask = self.get('edge_mask')\n if edge_mask is None:\n raise ValueError(f\"The attribute 'edge_mask' is not available \"\n f\"in '{self.__class__.__name__}' \"\n f\"(got {self.available_explanations})\")\n visualize_graph(self.edge_index, edge_mask, path, backend, node_labels)\n\n\nclass HeteroExplanation(HeteroData, ExplanationMixin):\n r\"\"\"Holds all the obtained explanations of a heterogeneous graph.\n\n The explanation object is a :obj:`~torch_geometric.data.HeteroData` object\n and can hold node attributions and edge attributions.\n It can also hold the original graph if needed.\n \"\"\"\n def validate(self, raise_on_error: bool = True) -> bool:\n r\"\"\"Validates the correctness of the :class:`Explanation` object.\"\"\"\n status = super().validate(raise_on_error)\n status &= self.validate_masks(raise_on_error)\n return status\n\n def get_explanation_subgraph(self) -> 'HeteroExplanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with\n zero attribution are masked out.\n \"\"\"\n return self._apply_masks(\n node_mask_dict={\n key: mask.sum(dim=-1) > 0\n for key, mask in self.collect('node_mask', True).items()\n },\n edge_mask_dict={\n key: mask > 0\n for key, mask in self.collect('edge_mask', True).items()\n },\n )\n\n def get_complement_subgraph(self) -> 'HeteroExplanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with any\n attribution are masked out.\n \"\"\"\n return self._apply_masks(\n node_mask_dict={\n key: mask.sum(dim=-1) == 0\n for key, mask in self.collect('node_mask', True).items()\n },\n edge_mask_dict={\n key: mask == 0\n for key, mask in self.collect('edge_mask', True).items()\n },\n )\n\n def _apply_masks(\n self,\n node_mask_dict: Dict[NodeType, Tensor],\n edge_mask_dict: Dict[EdgeType, Tensor],\n ) -> 'HeteroExplanation':\n out = copy.copy(self)\n\n for edge_type, edge_mask in edge_mask_dict.items():\n for key, value in self[edge_type].items():\n if key == 'edge_index':\n out[edge_type].edge_index = value[:, edge_mask]\n elif self[edge_type].is_edge_attr(key):\n out[edge_type][key] = value[edge_mask]\n\n return out.subgraph(node_mask_dict)\n\n def 
visualize_feature_importance(\n self,\n path: Optional[str] = None,\n feat_labels: Optional[Dict[NodeType, List[str]]] = None,\n top_k: Optional[int] = None,\n ):\n r\"\"\"Creates a bar plot of the node feature importances by summing up\n node masks across all nodes for each node type.\n\n Args:\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n feat_labels (Dict[NodeType, List[str]], optional): The labels of\n features for each node type. (default :obj:`None`)\n top_k (int, optional): Top k features to plot. If :obj:`None`\n plots all features. (default: :obj:`None`)\n \"\"\"\n node_mask_dict = self.node_mask_dict\n for node_mask in node_mask_dict.values():\n if node_mask.dim() != 2 or node_mask.size(1) <= 1:\n raise ValueError(f\"Cannot compute feature importance for \"\n f\"object-level 'node_mask' \"\n f\"(got shape {node_mask_dict.size()})\")\n\n if feat_labels is None:\n feat_labels = {}\n for node_type, node_mask in node_mask_dict.items():\n feat_labels[node_type] = range(node_mask.size(1))\n\n score = torch.cat(\n [node_mask.sum(dim=0) for node_mask in node_mask_dict.values()],\n dim=0)\n\n all_feat_labels = []\n for node_type in node_mask_dict.keys():\n all_feat_labels += [\n f'{node_type}#{label}' for label in feat_labels[node_type]\n ]\n\n return _visualize_score(score, all_feat_labels, path, top_k)\n\n\ndef _visualize_score(\n score: torch.Tensor,\n labels: List[str],\n path: Optional[str] = None,\n top_k: Optional[int] = None,\n):\n import matplotlib.pyplot as plt\n import pandas as pd\n\n if len(labels) != score.numel():\n raise ValueError(f\"The number of labels (got {len(labels)}) must \"\n f\"match the number of scores (got {score.numel()})\")\n\n score = score.cpu().numpy()\n\n df = pd.DataFrame({'score': score}, index=labels)\n df = df.sort_values('score', ascending=False)\n df = df.round(decimals=3)\n\n if top_k is not None:\n df = df.head(top_k)\n title = f\"Feature importance for top {len(df)} features\"\n else:\n title = f\"Feature importance for {len(df)} features\"\n\n ax = df.plot(\n kind='barh',\n figsize=(10, 7),\n title=title,\n ylabel='Feature label',\n xlim=[0, float(df['score'].max()) + 0.3],\n legend=False,\n )\n plt.gca().invert_yaxis()\n ax.bar_label(container=ax.containers[0], label_type='edge')\n\n if path is not None:\n plt.savefig(path)\n else:\n plt.show()\n\n plt.close()\n\n# Path: torch_geometric/explain/metric/basic.py\nfrom typing import List, Optional, Tuple, Union\n\nfrom torch import Tensor\n\nMETRICS = ['accuracy', 'recall', 'precision', 'f1_score', 'auroc']\n\n\ndef groundtruth_metrics(\n pred_mask: Tensor,\n target_mask: Tensor,\n metrics: Optional[Union[str, List[str]]] = None,\n threshold: float = 0.5,\n) -> Union[float, Tuple[float, ...]]:\n r\"\"\"Compares and evaluates an explanation mask with the ground-truth\n explanation mask.\n\n Args:\n pred_mask (torch.Tensor): The prediction mask to evaluate.\n target_mask (torch.Tensor): The ground-truth target mask.\n metrics (str or List[str], optional): The metrics to return\n (:obj:`\"accuracy\"`, :obj:`\"recall\"`, :obj:`\"precision\"`,\n :obj:`\"f1_score\"`, :obj:`\"auroc\"`). 
(default: :obj:`[\"accuracy\",\n \"recall\", \"precision\", \"f1_score\", \"auroc\"]`)\n threshold (float, optional): The threshold value to perform hard\n thresholding of :obj:`mask` and :obj:`groundtruth`.\n (default: :obj:`0.5`)\n \"\"\"\n import torchmetrics\n\n if metrics is None:\n metrics = METRICS\n\n if isinstance(metrics, str):\n metrics = [metrics]\n\n if not isinstance(metrics, (tuple, list)):\n raise ValueError(f\"Expected metrics to be a string or a list of \"\n f\"strings (got {type(metrics)})\")\n\n pred_mask = pred_mask.view(-1)\n target_mask = (target_mask >= threshold).view(-1)\n\n outs = []\n for metric in metrics:\n if metric not in METRICS:\n raise ValueError(f\"Encountered invalid metric {metric}\")\n\n fn = getattr(torchmetrics.functional, metric)\n if metric in {'auroc'}:\n out = fn(pred_mask, target_mask, 'binary')\n else:\n out = fn(pred_mask, target_mask, 'binary', threshold)\n\n outs.append(float(out))\n\n return tuple(outs) if len(outs) > 1 else outs[0]\n\n# Path: torch_geometric/explain/metric/faithfulness.py\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.explain import Explainer, Explanation\nfrom torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType\n\n\ndef unfaithfulness(\n explainer: Explainer,\n explanation: Explanation,\n top_k: Optional[int] = None,\n) -> float:\n r\"\"\"Evaluates how faithful an :class:`~torch_geometric.explain.Explanation`\n is to an underyling GNN predictor, as described in the\n `\"Evaluating Explainability for Graph Neural Networks\"\n `_ paper.\n\n In particular, the graph explanation unfaithfulness metric is defined as\n\n .. math::\n \\textrm{GEF}(y, \\hat{y}) = 1 - \\exp(- \\textrm{KL}(y || \\hat{y}))\n\n where :math:`y` refers to the prediction probability vector obtained from\n the original graph, and :math:`\\hat{y}` refers to the prediction\n probability vector obtained from the masked subgraph.\n Finally, the Kullback-Leibler (KL) divergence score quantifies the distance\n between the two probability distributions.\n\n Args:\n explainer (Explainer): The explainer to evaluate.\n explanation (Explanation): The explanation to evaluate.\n top_k (int, optional): If set, will only keep the original values of\n the top-:math:`k` node features identified by an explanation.\n If set to :obj:`None`, will use :obj:`explanation.node_mask` as it\n is for masking node features. 
(default: :obj:`None`)\n \"\"\"\n if explainer.model_config.mode == ModelMode.regression:\n raise ValueError(\"Fidelity not defined for 'regression' models\")\n\n if top_k is not None and explainer.node_mask_type == MaskType.object:\n raise ValueError(\"Cannot apply top-k feature selection based on a \"\n \"node mask of type 'object'\")\n\n node_mask = explanation.get('node_mask')\n edge_mask = explanation.get('edge_mask')\n x, edge_index = explanation.x, explanation.edge_index\n kwargs = {key: explanation[key] for key in explanation._model_args}\n\n y = explanation.get('prediction')\n if y is None: # == ExplanationType.phenomenon\n y = explainer.get_prediction(x, edge_index, **kwargs)\n\n if node_mask is not None and top_k is not None:\n feat_importance = node_mask.sum(dim=0)\n _, top_k_index = feat_importance.topk(top_k)\n node_mask = torch.zeros_like(node_mask)\n node_mask[:, top_k_index] = 1.0\n\n y_hat = explainer.get_masked_prediction(x, edge_index, node_mask,\n edge_mask, **kwargs)\n\n if explanation.get('index') is not None:\n y, y_hat = y[explanation.index], y_hat[explanation.index]\n\n if explainer.model_config.return_type == ModelReturnType.raw:\n y, y_hat = y.softmax(dim=-1), y_hat.softmax(dim=-1)\n elif explainer.model_config.return_type == ModelReturnType.log_probs:\n y, y_hat = y.exp(), y_hat.exp()\n\n kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean')\n return 1 - float(torch.exp(-kl_div))\n\n# Path: torch_geometric/explain/metric/fidelity.py\nfrom typing import Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explainer, Explanation\nfrom torch_geometric.explain.config import ExplanationType, ModelMode\n\n\ndef fidelity(\n explainer: Explainer,\n explanation: Explanation,\n) -> Tuple[float, float]:\n r\"\"\"Evaluates the fidelity of an\n :class:`~torch_geometric.explain.Explainer` given an\n :class:`~torch_geometric.explain.Explanation`, as described in the\n `\"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for\n Graph Neural Networks\" `_ paper.\n\n Fidelity evaluates the contribution of the produced explanatory subgraph\n to the initial prediction, either by giving only the subgraph to the model\n (fidelity-) or by removing it from the entire graph (fidelity+).\n The fidelity scores capture how good an explainable model reproduces the\n natural phenomenon or the GNN model logic.\n\n For **phenomenon** explanations, the fidelity scores are given by:\n\n .. math::\n \\textrm{fid}_{+} &= \\frac{1}{N} \\sum_{i = 1}^N\n \\| \\mathbb{1}(\\hat{y}_i = y_i) -\n \\mathbb{1}( \\hat{y}_i^{G_{C \\setminus S}} = y_i) \\|\n\n \\textrm{fid}_{-} &= \\frac{1}{N} \\sum_{i = 1}^N\n \\| \\mathbb{1}(\\hat{y}_i = y_i) -\n \\mathbb{1}( \\hat{y}_i^{G_S} = y_i) \\|\n\n For **model** explanations, the fidelity scores are given by:\n\n .. 
math::\n \\textrm{fid}_{+} &= 1 - \\frac{1}{N} \\sum_{i = 1}^N\n \\mathbb{1}( \\hat{y}_i^{G_{C \\setminus S}} = \\hat{y}_i)\n\n \\textrm{fid}_{-} &= 1 - \\frac{1}{N} \\sum_{i = 1}^N\n \\mathbb{1}( \\hat{y}_i^{G_S} = \\hat{y}_i)\n\n Args:\n explainer (Explainer): The explainer to evaluate.\n explanation (Explanation): The explanation to evaluate.\n \"\"\"\n if explainer.model_config.mode == ModelMode.regression:\n raise ValueError(\"Fidelity not defined for 'regression' models\")\n\n node_mask = explanation.get('node_mask')\n edge_mask = explanation.get('edge_mask')\n kwargs = {key: explanation[key] for key in explanation._model_args}\n\n y = explanation.target\n if explainer.explanation_type == ExplanationType.phenomenon:\n y_hat = explainer.get_prediction(\n explanation.x,\n explanation.edge_index,\n **kwargs,\n )\n y_hat = explainer.get_target(y_hat)\n\n explain_y_hat = explainer.get_masked_prediction(\n explanation.x,\n explanation.edge_index,\n node_mask,\n edge_mask,\n **kwargs,\n )\n explain_y_hat = explainer.get_target(explain_y_hat)\n\n complement_y_hat = explainer.get_masked_prediction(\n explanation.x,\n explanation.edge_index,\n 1. - node_mask if node_mask is not None else None,\n 1. - edge_mask if edge_mask is not None else None,\n **kwargs,\n )\n complement_y_hat = explainer.get_target(complement_y_hat)\n\n if explanation.get('index') is not None:\n y = y[explanation.index]\n if explainer.explanation_type == ExplanationType.phenomenon:\n y_hat = y_hat[explanation.index]\n explain_y_hat = explain_y_hat[explanation.index]\n complement_y_hat = complement_y_hat[explanation.index]\n\n if explainer.explanation_type == ExplanationType.model:\n pos_fidelity = 1. - (complement_y_hat == y).float().mean()\n neg_fidelity = 1. - (explain_y_hat == y).float().mean()\n else:\n pos_fidelity = ((y_hat == y).float() -\n (complement_y_hat == y).float()).abs().mean()\n neg_fidelity = ((y_hat == y).float() -\n (explain_y_hat == y).float()).abs().mean()\n\n...\n", "needle_token_start": 756, "needle_token_end": 819, "code_context_ntokens": 16078, "output": ["```python\ndef maybe_init_stream(self) -> None:\n if self.is_gpu:\n self.stream = self.module.Stream()\n self.stream_context = partial(\n self.module.stream,\n stream=self.stream,\n )\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "positive_sampling", "language": "python", "path": "torch_geometric/nn/conv/supergat_conv.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to randomly select a subset of existing edges from a graph, which can be used to train models that require a sample of positive interactions.\n2. **Input**: It takes a tensor representing the indices of edges in a graph.\n3. **Output**: It returns a tensor containing a subset of the input edge indices, representing the sampled positive edges.\n4. **Procedure**: The function first applies a dropout to the edges based on a predefined sampling ratio, effectively removing a certain percentage of edges randomly. The remaining edges are then output as the sampled positive edges, which are used during the training phase to compare against negative samples.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " negative_slope (float, optional): LeakyReLU angle of the negative\n slope. 
(default: :obj:`0.2`)\n dropout (float, optional): Dropout probability of the normalized\n attention coefficients which exposes each node to a stochastically\n sampled neighborhood during training. (default: :obj:`0`)\n add_self_loops (bool, optional): If set to :obj:`False`, will not add\n self-loops to the input graph. (default: :obj:`True`)\n bias (bool, optional): If set to :obj:`False`, the layer will not learn\n an additive bias. (default: :obj:`True`)\n attention_type (str, optional): Type of attention to use\n (:obj:`'MX'`, :obj:`'SD'`). (default: :obj:`'MX'`)\n neg_sample_ratio (float, optional): The ratio of the number of sampled\n negative edges to the number of positive edges.\n (default: :obj:`0.5`)\n edge_sample_ratio (float, optional): The ratio of samples to use for\n training among the number of training edges. (default: :obj:`1.0`)\n is_undirected (bool, optional): Whether the input graph is undirected.\n If not given, will be automatically computed with the input graph\n when negative sampling is performed. (default: :obj:`False`)\n **kwargs (optional): Additional arguments of\n :class:`torch_geometric.nn.conv.MessagePassing`.\n\n Shapes:\n - **input:**\n node features :math:`(|\\mathcal{V}|, F_{in})`,\n edge indices :math:`(2, |\\mathcal{E}|)`,\n negative edge indices :math:`(2, |\\mathcal{E}^{(-)}|)` *(optional)*\n - **output:** node features :math:`(|\\mathcal{V}|, H * F_{out})`\n \"\"\"\n att_x: OptTensor\n att_y: OptTensor\n\n def __init__(self, in_channels: int, out_channels: int, heads: int = 1,\n concat: bool = True, negative_slope: float = 0.2,\n dropout: float = 0.0, add_self_loops: bool = True,\n bias: bool = True, attention_type: str = 'MX',\n neg_sample_ratio: float = 0.5, edge_sample_ratio: float = 1.0,\n is_undirected: bool = False, **kwargs):\n kwargs.setdefault('aggr', 'add')\n super().__init__(node_dim=0, **kwargs)\n\n self.in_channels = in_channels\n self.out_channels = out_channels\n self.heads = heads\n self.concat = concat\n self.negative_slope = negative_slope\n self.dropout = dropout\n self.add_self_loops = add_self_loops\n self.attention_type = attention_type\n self.neg_sample_ratio = neg_sample_ratio\n self.edge_sample_ratio = edge_sample_ratio\n self.is_undirected = is_undirected\n\n assert attention_type in ['MX', 'SD']\n assert 0.0 < neg_sample_ratio and 0.0 < edge_sample_ratio <= 1.0\n\n self.lin = Linear(in_channels, heads * out_channels, bias=False,\n weight_initializer='glorot')\n\n if self.attention_type == 'MX':\n self.att_l = Parameter(torch.empty(1, heads, out_channels))\n self.att_r = Parameter(torch.empty(1, heads, out_channels))\n else: # self.attention_type == 'SD'\n self.register_parameter('att_l', None)\n self.register_parameter('att_r', None)\n\n self.att_x = self.att_y = None # x/y for self-supervision\n\n if bias and concat:\n self.bias = Parameter(torch.empty(heads * out_channels))\n elif bias and not concat:\n self.bias = Parameter(torch.empty(out_channels))\n else:\n self.register_parameter('bias', None)\n\n self.reset_parameters()\n\n def reset_parameters(self):\n super().reset_parameters()\n self.lin.reset_parameters()\n glorot(self.att_l)\n glorot(self.att_r)\n zeros(self.bias)\n\n def forward(\n self,\n x: Tensor,\n edge_index: Adj,\n neg_edge_index: OptTensor = None,\n batch: OptTensor = None,\n ) -> Tensor:\n r\"\"\"Runs the forward pass of the module.\n\n Args:\n x (torch.Tensor): The input node features.\n edge_index (torch.Tensor or SparseTensor): The edge indices.\n neg_edge_index (torch.Tensor, optional): The 
negative edges to\n train against. If not given, uses negative sampling to\n calculate negative edges. (default: :obj:`None`)\n batch (torch.Tensor, optional): The batch vector\n :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n each element to a specific example.\n Used when sampling negatives on-the-fly in mini-batch\n scenarios. (default: :obj:`None`)\n \"\"\"\n N, H, C = x.size(0), self.heads, self.out_channels\n\n if self.add_self_loops:\n if isinstance(edge_index, SparseTensor):\n edge_index = torch_sparse.fill_diag(edge_index, 1.)\n else:\n edge_index, _ = remove_self_loops(edge_index)\n edge_index, _ = add_self_loops(edge_index, num_nodes=N)\n\n x = self.lin(x).view(-1, H, C)\n\n # propagate_type: (x: Tensor)\n out = self.propagate(edge_index, x=x)\n\n if self.training:\n if isinstance(edge_index, SparseTensor):\n col, row, _ = edge_index.coo()\n edge_index = torch.stack([row, col], dim=0)\n pos_edge_index = self.positive_sampling(edge_index)\n\n pos_att = self.get_attention(\n edge_index_i=pos_edge_index[1],\n x_i=x[pos_edge_index[1]],\n x_j=x[pos_edge_index[0]],\n num_nodes=x.size(0),\n return_logits=True,\n )\n\n if neg_edge_index is None:\n neg_edge_index = self.negative_sampling(edge_index, N, batch)\n\n neg_att = self.get_attention(\n edge_index_i=neg_edge_index[1],\n x_i=x[neg_edge_index[1]],\n x_j=x[neg_edge_index[0]],\n num_nodes=x.size(0),\n return_logits=True,\n )\n\n self.att_x = torch.cat([pos_att, neg_att], dim=0)\n self.att_y = self.att_x.new_zeros(self.att_x.size(0))\n self.att_y[:pos_edge_index.size(1)] = 1.\n\n if self.concat is True:\n out = out.view(-1, self.heads * self.out_channels)\n else:\n out = out.mean(dim=1)\n\n if self.bias is not None:\n out = out + self.bias\n\n return out\n\n def message(self, edge_index_i: Tensor, x_i: Tensor, x_j: Tensor,\n size_i: Optional[int]) -> Tensor:\n alpha = self.get_attention(edge_index_i, x_i, x_j, num_nodes=size_i)\n alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n return x_j * alpha.view(-1, self.heads, 1)\n\n def negative_sampling(self, edge_index: Tensor, num_nodes: int,\n batch: OptTensor = None) -> Tensor:\n\n num_neg_samples = int(self.neg_sample_ratio * self.edge_sample_ratio *\n edge_index.size(1))\n\n if not self.is_undirected and not is_undirected(\n edge_index, num_nodes=num_nodes):\n edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n\n if batch is None:\n neg_edge_index = negative_sampling(edge_index, num_nodes,\n num_neg_samples=num_neg_samples)\n else:\n neg_edge_index = batched_negative_sampling(\n edge_index, batch, num_neg_samples=num_neg_samples)\n\n return neg_edge_index\n\n \ndef positive_sampling(self, edge_index: Tensor) -> Tensor:\n pos_edge_index, _ = dropout_edge(edge_index,\n p=1. 
- self.edge_sample_ratio,\n training=self.training)\n return pos_edge_index\n\n def get_attention(self, edge_index_i: Tensor, x_i: Tensor, x_j: Tensor,\n num_nodes: Optional[int],\n return_logits: bool = False) -> Tensor:\n\n if self.attention_type == 'MX':\n logits = (x_i * x_j).sum(dim=-1)\n if return_logits:\n return logits\n\n alpha = (x_j * self.att_l).sum(-1) + (x_i * self.att_r).sum(-1)\n alpha = alpha * logits.sigmoid()\n\n else: # self.attention_type == 'SD'\n alpha = (x_i * x_j).sum(dim=-1) / math.sqrt(self.out_channels)\n if return_logits:\n return alpha\n\n alpha = F.leaky_relu(alpha, self.negative_slope)\n alpha = softmax(alpha, edge_index_i, num_nodes=num_nodes)\n return alpha\n\n def get_attention_loss(self) -> Tensor:\n r\"\"\"Computes the self-supervised graph attention loss.\"\"\"\n if not self.training:\n return torch.tensor([0], device=self.lin.weight.device)\n\n return F.binary_cross_entropy_with_logits(\n self.att_x.mean(dim=-1),\n self.att_y,\n )\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({self.in_channels}, '\n f'{self.out_channels}, heads={self.heads}, '\n f'type={self.attention_type})')\n\n# Path: torch_geometric/nn/conv/tag_conv.py\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass TAGConv(MessagePassing):\n r\"\"\"The topology adaptive graph convolutional networks operator from the\n `\"Topology Adaptive Graph Convolutional Networks\"\n `_ paper.\n\n .. math::\n \\mathbf{X}^{\\prime} = \\sum_{k=0}^K \\left( \\mathbf{D}^{-1/2} \\mathbf{A}\n \\mathbf{D}^{-1/2} \\right)^k \\mathbf{X} \\mathbf{W}_{k},\n\n where :math:`\\mathbf{A}` denotes the adjacency matrix and\n :math:`D_{ii} = \\sum_{j=0} A_{ij}` its diagonal degree matrix.\n The adjacency matrix can include other values than :obj:`1` representing\n edge weights via the optional :obj:`edge_weight` tensor.\n\n Args:\n in_channels (int): Size of each input sample, or :obj:`-1` to derive\n the size from the first input(s) to the forward method.\n out_channels (int): Size of each output sample.\n K (int, optional): Number of hops :math:`K`. (default: :obj:`3`)\n bias (bool, optional): If set to :obj:`False`, the layer will not learn\n an additive bias. 
(default: :obj:`True`)\n normalize (bool, optional): Whether to apply symmetric normalization.\n (default: :obj:`True`)\n **kwargs (optional): Additional arguments of\n :class:`torch_geometric.nn.conv.MessagePassing`.\n\n Shapes:\n - **input:**\n node_features :math:`(|\\mathcal{V}|, F_{in})`,\n edge_index :math:`(2, |\\mathcal{E}|)`,\n edge_weights :math:`(|\\mathcal{E}|)` *(optional)*\n - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n \"\"\"\n def __init__(self, in_channels: int, out_channels: int, K: int = 3,\n bias: bool = True, normalize: bool = True, **kwargs):\n kwargs.setdefault('aggr', 'add')\n super().__init__(**kwargs)\n\n self.in_channels = in_channels\n self.out_channels = out_channels\n self.K = K\n self.normalize = normalize\n\n self.lins = torch.nn.ModuleList([\n Linear(in_channels, out_channels, bias=False) for _ in range(K + 1)\n ])\n\n if bias:\n self.bias = torch.nn.Parameter(torch.empty(out_channels))\n else:\n self.register_parameter('bias', None)\n\n self.reset_parameters()\n\n def reset_parameters(self):\n super().reset_parameters()\n for lin in self.lins:\n lin.reset_parameters()\n zeros(self.bias)\n\n def forward(self, x: Tensor, edge_index: Adj,\n edge_weight: OptTensor = None) -> Tensor:\n\n if self.normalize:\n if isinstance(edge_index, Tensor):\n edge_index, edge_weight = gcn_norm( # yapf: disable\n edge_index, edge_weight, x.size(self.node_dim),\n improved=False, add_self_loops=False, flow=self.flow,\n dtype=x.dtype)\n\n elif isinstance(edge_index, SparseTensor):\n edge_index = gcn_norm( # yapf: disable\n edge_index, edge_weight, x.size(self.node_dim),\n add_self_loops=False, flow=self.flow, dtype=x.dtype)\n\n out = self.lins[0](x)\n for lin in self.lins[1:]:\n # propagate_type: (x: Tensor, edge_weight: OptTensor)\n x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n out = out + lin.forward(x)\n\n if self.bias is not None:\n out = out + self.bias\n\n return out\n\n def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n return spmm(adj_t, x, reduce=self.aggr)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({self.in_channels}, '\n f'{self.out_channels}, K={self.K})')\n\n# Path: torch_geometric/nn/conv/transformer_conv.py\nimport math\nimport typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import (\n Adj,\n NoneType,\n OptTensor,\n PairTensor,\n SparseTensor,\n)\nfrom torch_geometric.utils import softmax\n\nif typing.TYPE_CHECKING:\n from typing import overload\nelse:\n from torch.jit import _overload_method as overload\n\n\nclass TransformerConv(MessagePassing):\n r\"\"\"The graph transformer operator from the `\"Masked Label Prediction:\n Unified Message Passing Model for Semi-Supervised Classification\"\n `_ paper.\n\n .. math::\n \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i +\n \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i,j} \\mathbf{W}_2 \\mathbf{x}_{j},\n\n where the attention coefficients :math:`\\alpha_{i,j}` are computed via\n multi-head dot product attention:\n\n .. 
math::\n \\alpha_{i,j} = \\textrm{softmax} \\left(\n \\frac{(\\mathbf{W}_3\\mathbf{x}_i)^{\\top} (\\mathbf{W}_4\\mathbf{x}_j)}\n {\\sqrt{d}} \\right)\n\n Args:\n in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n derive the size from the first input(s) to the forward method.\n A tuple corresponds to the sizes of source and target\n dimensionalities.\n out_channels (int): Size of each output sample.\n heads (int, optional): Number of multi-head-attentions.\n (default: :obj:`1`)\n concat (bool, optional): If set to :obj:`False`, the multi-head\n attentions are averaged instead of concatenated.\n (default: :obj:`True`)\n beta (bool, optional): If set, will combine aggregation and\n skip information via\n\n .. math::\n \\mathbf{x}^{\\prime}_i = \\beta_i \\mathbf{W}_1 \\mathbf{x}_i +\n (1 - \\beta_i) \\underbrace{\\left(\\sum_{j \\in \\mathcal{N}(i)}\n \\alpha_{i,j} \\mathbf{W}_2 \\vec{x}_j \\right)}_{=\\mathbf{m}_i}\n\n with :math:`\\beta_i = \\textrm{sigmoid}(\\mathbf{w}_5^{\\top}\n [ \\mathbf{W}_1 \\mathbf{x}_i, \\mathbf{m}_i, \\mathbf{W}_1\n \\mathbf{x}_i - \\mathbf{m}_i ])` (default: :obj:`False`)\n dropout (float, optional): Dropout probability of the normalized\n attention coefficients which exposes each node to a stochastically\n sampled neighborhood during training. (default: :obj:`0`)\n edge_dim (int, optional): Edge feature dimensionality (in case\n there are any). Edge features are added to the keys after\n linear transformation, that is, prior to computing the\n attention dot product. They are also added to final values\n after the same linear transformation. The model is:\n\n .. math::\n \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i +\n \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i,j} \\left(\n \\mathbf{W}_2 \\mathbf{x}_{j} + \\mathbf{W}_6 \\mathbf{e}_{ij}\n \\right),\n\n where the attention coefficients :math:`\\alpha_{i,j}` are now\n computed via:\n\n .. math::\n \\alpha_{i,j} = \\textrm{softmax} \\left(\n \\frac{(\\mathbf{W}_3\\mathbf{x}_i)^{\\top}\n (\\mathbf{W}_4\\mathbf{x}_j + \\mathbf{W}_6 \\mathbf{e}_{ij})}\n {\\sqrt{d}} \\right)\n\n (default :obj:`None`)\n bias (bool, optional): If set to :obj:`False`, the layer will not learn\n an additive bias. (default: :obj:`True`)\n root_weight (bool, optional): If set to :obj:`False`, the layer will\n not add the transformed root node features to the output and the\n option :attr:`beta` is set to :obj:`False`. 
(default: :obj:`True`)\n **kwargs (optional): Additional arguments of\n :class:`torch_geometric.nn.conv.MessagePassing`.\n \"\"\"\n _alpha: OptTensor\n\n def __init__(\n self,\n in_channels: Union[int, Tuple[int, int]],\n out_channels: int,\n heads: int = 1,\n concat: bool = True,\n beta: bool = False,\n dropout: float = 0.,\n edge_dim: Optional[int] = None,\n bias: bool = True,\n root_weight: bool = True,\n **kwargs,\n ):\n kwargs.setdefault('aggr', 'add')\n super().__init__(node_dim=0, **kwargs)\n\n self.in_channels = in_channels\n self.out_channels = out_channels\n self.heads = heads\n self.beta = beta and root_weight\n self.root_weight = root_weight\n self.concat = concat\n self.dropout = dropout\n self.edge_dim = edge_dim\n self._alpha = None\n\n if isinstance(in_channels, int):\n in_channels = (in_channels, in_channels)\n\n self.lin_key = Linear(in_channels[0], heads * out_channels)\n self.lin_query = Linear(in_channels[1], heads * out_channels)\n self.lin_value = Linear(in_channels[0], heads * out_channels)\n if edge_dim is not None:\n self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)\n else:\n self.lin_edge = self.register_parameter('lin_edge', None)\n\n if concat:\n self.lin_skip = Linear(in_channels[1], heads * out_channels,\n bias=bias)\n if self.beta:\n self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)\n else:\n self.lin_beta = self.register_parameter('lin_beta', None)\n else:\n self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)\n if self.beta:\n self.lin_beta = Linear(3 * out_channels, 1, bias=False)\n else:\n self.lin_beta = self.register_parameter('lin_beta', None)\n\n self.reset_parameters()\n\n def reset_parameters(self):\n super().reset_parameters()\n self.lin_key.reset_parameters()\n self.lin_query.reset_parameters()\n self.lin_value.reset_parameters()\n if self.edge_dim:\n self.lin_edge.reset_parameters()\n self.lin_skip.reset_parameters()\n if self.beta:\n self.lin_beta.reset_parameters()\n\n @overload\n def forward(\n self,\n x: Union[Tensor, PairTensor],\n edge_index: Adj,\n edge_attr: OptTensor = None,\n return_attention_weights: NoneType = None,\n ) -> Tensor:\n pass\n\n @overload\n def forward( # noqa: F811\n self,\n x: Union[Tensor, PairTensor],\n edge_index: Tensor,\n edge_attr: OptTensor = None,\n return_attention_weights: bool = None,\n ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n pass\n\n @overload\n def forward( # noqa: F811\n self,\n x: Union[Tensor, PairTensor],\n edge_index: SparseTensor,\n edge_attr: OptTensor = None,\n return_attention_weights: bool = None,\n ) -> Tuple[Tensor, SparseTensor]:\n pass\n\n def forward( # noqa: F811\n self,\n x: Union[Tensor, PairTensor],\n edge_index: Adj,\n edge_attr: OptTensor = None,\n return_attention_weights: Optional[bool] = None,\n ) -> Union[\n Tensor,\n Tuple[Tensor, Tuple[Tensor, Tensor]],\n Tuple[Tensor, SparseTensor],\n ]:\n r\"\"\"Runs the forward pass of the module.\n\n Args:\n x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node\n features.\n edge_index (torch.Tensor or SparseTensor): The edge indices.\n edge_attr (torch.Tensor, optional): The edge features.\n (default: :obj:`None`)\n return_attention_weights (bool, optional): If set to :obj:`True`,\n will additionally return the tuple\n :obj:`(edge_index, attention_weights)`, holding the computed\n attention weights for each edge. 
(default: :obj:`None`)\n \"\"\"\n H, C = self.heads, self.out_channels\n\n if isinstance(x, Tensor):\n x = (x, x)\n\n query = self.lin_query(x[1]).view(-1, H, C)\n key = self.lin_key(x[0]).view(-1, H, C)\n value = self.lin_value(x[0]).view(-1, H, C)\n\n # propagate_type: (query: Tensor, key:Tensor, value: Tensor,\n # edge_attr: OptTensor)\n out = self.propagate(edge_index, query=query, key=key, value=value,\n edge_attr=edge_attr)\n\n alpha = self._alpha\n self._alpha = None\n\n if self.concat:\n out = out.view(-1, self.heads * self.out_channels)\n else:\n out = out.mean(dim=1)\n\n if self.root_weight:\n x_r = self.lin_skip(x[1])\n if self.lin_beta is not None:\n beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))\n beta = beta.sigmoid()\n out = beta * x_r + (1 - beta) * out\n else:\n out = out + x_r\n\n if isinstance(return_attention_weights, bool):\n assert alpha is not None\n if isinstance(edge_index, Tensor):\n return out, (edge_index, alpha)\n elif isinstance(edge_index, SparseTensor):\n return out, edge_index.set_value(alpha, layout='coo')\n else:\n return out\n\n def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,\n edge_attr: OptTensor, index: Tensor, ptr: OptTensor,\n size_i: Optional[int]) -> Tensor:\n\n if self.lin_edge is not None:\n assert edge_attr is not None\n edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,\n self.out_channels)\n key_j = key_j + edge_attr\n\n alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)\n alpha = softmax(alpha, index, ptr, size_i)\n self._alpha = alpha\n alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n\n out = value_j\n if edge_attr is not None:\n out = out + edge_attr\n\n out = out * alpha.view(-1, self.heads, 1)\n return out\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({self.in_channels}, '\n f'{self.out_channels}, heads={self.heads})')\n\n# Path: torch_geometric/nn/conv/utils/cheatsheet.py\nimport importlib\nimport inspect\nimport re\nfrom typing import Optional\n\n\ndef paper_title(cls: str) -> Optional[str]:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n match = re.search('`\\\".+?\\\"', inspect.getdoc(cls), flags=re.DOTALL)\n return None if match is None else match.group().replace('\\n', ' ')[2:-1]\n\n\ndef paper_link(cls: str) -> Optional[str]:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n match = re.search('<.+?>', inspect.getdoc(cls), flags=re.DOTALL)\n return None if match is None else match.group().replace('\\n', ' ')[1:-1]\n\n\ndef supports_sparse_tensor(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'SparseTensor' in str(signature)\n\n\ndef supports_edge_weights(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'edge_weight' in str(signature)\n\n\ndef supports_edge_features(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'edge_attr' in str(signature)\n\n\ndef supports_bipartite_graphs(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'Union[torch.Tensor, Tuple[torch.Tensor' in str(signature)\n\n\ndef supports_static_graphs(cls: str) -> bool:\n cls = 
importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n return 'node_dim=' not in inspect.getsource(cls.__init__)\n\n\ndef supports_lazy_initialization(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n doc = re.sub(' +', ' ', inspect.getdoc(cls).replace('\\n', ' '))\n match = re.search('or :obj:`-1` to derive the size from the first', doc)\n return match is not None\n\n\ndef processes_heterogeneous_graphs(cls: str) -> bool:\n if 'hetero' in cls.lower():\n return True\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'edge_index_dict' in str(signature) or 'edge_type' in str(signature)\n\n\ndef processes_hypergraphs(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'hyperedge_index' in str(signature)\n\n\ndef processes_point_clouds(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return (('edge_index' not in str(signature)\n and 'csc' not in str(signature)) or 'pos' in str(signature))\n\n# Path: torch_geometric/nn/conv/utils/__init__.py\nr\"\"\"GNN utility package.\"\"\"\n\nfrom .cheatsheet import paper_title, paper_link\nfrom .cheatsheet import supports_sparse_tensor\nfrom .cheatsheet import supports_edge_weights\nfrom .cheatsheet import supports_edge_features\nfrom .cheatsheet import supports_bipartite_graphs\nfrom .cheatsheet import supports_static_graphs\nfrom .cheatsheet import supports_lazy_initialization\nfrom .cheatsheet import processes_heterogeneous_graphs\nfrom .cheatsheet import processes_hypergraphs\nfrom .cheatsheet import processes_point_clouds\n\n__all__ = [\n 'paper_title',\n 'paper_link',\n 'supports_sparse_tensor',\n 'supports_edge_weights',\n 'supports_edge_features',\n 'supports_bipartite_graphs',\n 'supports_static_graphs',\n 'supports_lazy_initialization',\n 'processes_heterogeneous_graphs',\n 'processes_hypergraphs',\n 'processes_point_clouds',\n]\n\n# Path: torch_geometric/nn/conv/wl_conv.py\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import Adj\nfrom torch_geometric.utils import (\n degree,\n is_sparse,\n scatter,\n sort_edge_index,\n to_edge_index,\n)\n\n\nclass WLConv(torch.nn.Module):\n r\"\"\"The Weisfeiler Lehman (WL) operator from the `\"A Reduction of a Graph\n to a Canonical Form and an Algebra Arising During this Reduction\"\n `_ paper.\n\n :class:`WLConv` iteratively refines node colorings according to:\n\n .. 
math::\n \\mathbf{x}^{\\prime}_i = \\textrm{hash} \\left( \\mathbf{x}_i, \\{\n \\mathbf{x}_j \\colon j \\in \\mathcal{N}(i) \\} \\right)\n\n Shapes:\n - **input:**\n node coloring :math:`(|\\mathcal{V}|, F_{in})` *(one-hot encodings)*\n or :math:`(|\\mathcal{V}|)` *(integer-based)*,\n edge indices :math:`(2, |\\mathcal{E}|)`\n - **output:** node coloring :math:`(|\\mathcal{V}|)` *(integer-based)*\n \"\"\"\n def __init__(self):\n super().__init__()\n self.hashmap = {}\n\n def reset_parameters(self):\n r\"\"\"Resets all learnable parameters of the module.\"\"\"\n self.hashmap = {}\n\n @torch.no_grad()\n def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n r\"\"\"Runs the forward pass of the module.\"\"\"\n if x.dim() > 1:\n assert (x.sum(dim=-1) == 1).sum() == x.size(0)\n x = x.argmax(dim=-1) # one-hot -> integer.\n assert x.dtype == torch.long\n\n if is_sparse(edge_index):\n col_and_row, _ = to_edge_index(edge_index)\n col = col_and_row[0]\n row = col_and_row[1]\n else:\n edge_index = sort_edge_index(edge_index, num_nodes=x.size(0),\n sort_by_row=False)\n row, col = edge_index[0], edge_index[1]\n\n # `col` is sorted, so we can use it to `split` neighbors to groups:\n deg = degree(col, x.size(0), dtype=torch.long).tolist()\n\n out = []\n for node, neighbors in zip(x.tolist(), x[row].split(deg)):\n idx = hash(tuple([node] + neighbors.sort()[0].tolist()))\n if idx not in self.hashmap:\n self.hashmap[idx] = len(self.hashmap)\n out.append(self.hashmap[idx])\n\n return torch.tensor(out, device=x.device)\n\n def histogram(self, x: Tensor, batch: Optional[Tensor] = None,\n norm: bool = False) -> Tensor:\n r\"\"\"Given a node coloring :obj:`x`, computes the color histograms of\n the respective graphs (separated by :obj:`batch`).\n \"\"\"\n if batch is None:\n batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)\n\n num_colors = len(self.hashmap)\n batch_size = int(batch.max()) + 1\n\n index = batch * num_colors + x\n out = scatter(torch.ones_like(index), index, dim=0,\n dim_size=num_colors * batch_size, reduce='sum')\n out = out.view(batch_size, num_colors)\n\n if norm:\n out = out.to(torch.float)\n out /= out.norm(dim=-1, keepdim=True)\n\n return out\n\n# Path: torch_geometric/nn/conv/wl_conv_continuous.py\nfrom typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import (\n Adj,\n OptPairTensor,\n OptTensor,\n Size,\n SparseTensor,\n)\nfrom torch_geometric.utils import scatter, spmm\n\n\nclass WLConvContinuous(MessagePassing):\n r\"\"\"The Weisfeiler Lehman operator from the `\"Wasserstein\n Weisfeiler-Lehman Graph Kernels\" `_\n paper.\n\n Refinement is done though a degree-scaled mean aggregation and works on\n nodes with continuous attributes:\n\n .. 
math::\n \\mathbf{x}^{\\prime}_i = \\frac{1}{2}\\big(\\mathbf{x}_i +\n \\frac{1}{\\textrm{deg}(i)}\n \\sum_{j \\in \\mathcal{N}(i)} e_{j,i} \\cdot \\mathbf{x}_j \\big)\n\n where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to\n target node :obj:`i` (default: :obj:`1`)\n\n Args:\n **kwargs (optional): Additional arguments of\n :class:`torch_geometric.nn.conv.MessagePassing`.\n\n Shapes:\n - **input:**\n node features :math:`(|\\mathcal{V}|, F)` or\n :math:`((|\\mathcal{V_s}|, F), (|\\mathcal{V_t}|, F))` if bipartite,\n edge indices :math:`(2, |\\mathcal{E}|)`,\n edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n - **output:** node features :math:`(|\\mathcal{V}|, F)` or\n :math:`(|\\mathcal{V}_t|, F)` if bipartite\n \"\"\"\n def __init__(self, **kwargs):\n super().__init__(aggr='add', **kwargs)\n\n def forward(\n self,\n x: Union[Tensor, OptPairTensor],\n edge_index: Adj,\n edge_weight: OptTensor = None,\n size: Size = None,\n ) -> Tensor:\n\n if isinstance(x, Tensor):\n x = (x, x)\n\n # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n out = self.propagate(edge_index, x=x, edge_weight=edge_weight,\n size=size)\n\n if isinstance(edge_index, SparseTensor):\n assert edge_weight is None\n dst_index, _, edge_weight = edge_index.coo()\n else:\n dst_index = edge_index[1]\n\n if edge_weight is None:\n edge_weight = x[0].new_ones(dst_index.numel())\n\n deg = scatter(edge_weight, dst_index, 0, out.size(0), reduce='sum')\n deg_inv = 1. / deg\n deg_inv.masked_fill_(deg_inv == float('inf'), 0)\n out = deg_inv.view(-1, 1) * out\n\n x_dst = x[1]\n if x_dst is not None:\n out = 0.5 * (x_dst + out)\n\n return out\n\n def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n return spmm(adj_t, x[0], reduce=self.aggr)\n\n# Path: torch_geometric/nn/conv/x_conv.py\nfrom math import ceil\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ELU\nfrom torch.nn import BatchNorm1d as BN\nfrom torch.nn import Conv1d\nfrom torch.nn import Linear as L\nfrom torch.nn import Sequential as S\n\nfrom torch_geometric.nn import Reshape\nfrom torch_geometric.nn.inits import reset\n\ntry:\n from torch_cluster import knn_graph\nexcept ImportError:\n knn_graph = None\n\n\nclass XConv(torch.nn.Module):\n r\"\"\"The convolutional operator on :math:`\\mathcal{X}`-transformed points\n from the `\"PointCNN: Convolution On X-Transformed Points\"\n `_ paper.\n\n .. 
math::\n \\mathbf{x}^{\\prime}_i = \\mathrm{Conv}\\left(\\mathbf{K},\n \\gamma_{\\mathbf{\\Theta}}(\\mathbf{P}_i - \\mathbf{p}_i) \\times\n \\left( h_\\mathbf{\\Theta}(\\mathbf{P}_i - \\mathbf{p}_i) \\, \\Vert \\,\n \\mathbf{x}_i \\right) \\right),\n\n where :math:`\\mathbf{K}` and :math:`\\mathbf{P}_i` denote the trainable\n filter and neighboring point positions of :math:`\\mathbf{x}_i`,\n respectively.\n :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`h_{\\mathbf{\\Theta}}` describe\n neural networks, *i.e.* MLPs, where :math:`h_{\\mathbf{\\Theta}}`\n individually lifts each point into a higher-dimensional space, and\n :math:`\\gamma_{\\mathbf{\\Theta}}` computes the :math:`\\mathcal{X}`-\n transformation matrix based on *all* points in a neighborhood.\n\n Args:\n in_channels (int): Size of each input sample.\n out_channels (int): Size of each output sample.\n dim (int): Point cloud dimensionality.\n kernel_size (int): Size of the convolving kernel, *i.e.* number of\n neighbors including self-loops.\n hidden_channels (int, optional): Output size of\n :math:`h_{\\mathbf{\\Theta}}`, *i.e.* dimensionality of lifted\n points. If set to :obj:`None`, will be automatically set to\n :obj:`in_channels / 4`. (default: :obj:`None`)\n dilation (int, optional): The factor by which the neighborhood is\n extended, from which :obj:`kernel_size` neighbors are then\n uniformly sampled. Can be interpreted as the dilation rate of\n classical convolutional operators. (default: :obj:`1`)\n bias (bool, optional): If set to :obj:`False`, the layer will not learn\n an additive bias. (default: :obj:`True`)\n num_workers (int): Number of workers to use for k-NN computation.\n Has no effect in case :obj:`batch` is not :obj:`None`, or the input\n lies on the GPU. (default: :obj:`1`)\n\n Shapes:\n - **input:**\n node features :math:`(|\\mathcal{V}|, F_{in})`,\n positions :math:`(|\\mathcal{V}|, D)`,\n batch vector :math:`(|\\mathcal{V}|)` *(optional)*\n - **output:**\n node features :math:`(|\\mathcal{V}|, F_{out})`\n \"\"\"\n def __init__(self, in_channels: int, out_channels: int, dim: int,\n kernel_size: int, hidden_channels: Optional[int] = None,\n dilation: int = 1, bias: bool = True, num_workers: int = 1):\n super().__init__()\n\n if knn_graph is None:\n raise ImportError('`XConv` requires `torch-cluster`.')\n\n self.in_channels = in_channels\n if hidden_channels is None:\n hidden_channels = in_channels // 4\n assert hidden_channels > 0\n self.hidden_channels = hidden_channels\n self.out_channels = out_channels\n self.dim = dim\n self.kernel_size = kernel_size\n self.dilation = dilation\n self.num_workers = num_workers\n\n C_in, C_delta, C_out = in_channels, hidden_channels, out_channels\n D, K = dim, kernel_size\n\n self.mlp1 = S(\n L(dim, C_delta),\n ELU(),\n BN(C_delta),\n L(C_delta, C_delta),\n ELU(),\n BN(C_delta),\n Reshape(-1, K, C_delta),\n )\n\n self.mlp2 = S(\n L(D * K, K**2),\n ELU(),\n BN(K**2),\n Reshape(-1, K, K),\n Conv1d(K, K**2, K, groups=K),\n ELU(),\n BN(K**2),\n Reshape(-1, K, K),\n Conv1d(K, K**2, K, groups=K),\n BN(K**2),\n Reshape(-1, K, K),\n )\n\n C_in = C_in + C_delta\n depth_multiplier = int(ceil(C_out / C_in))\n self.conv = S(\n Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),\n Reshape(-1, C_in * depth_multiplier),\n L(C_in * depth_multiplier, C_out, bias=bias),\n )\n\n self.reset_parameters()\n\n def reset_parameters(self):\n r\"\"\"Resets all learnable parameters of the module.\"\"\"\n reset(self.mlp1)\n reset(self.mlp2)\n reset(self.conv)\n\n def forward(self, x: Tensor, 
pos: Tensor, batch: Optional[Tensor] = None):\n r\"\"\"Runs the forward pass of the module.\"\"\"\n pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos\n (N, D), K = pos.size(), self.kernel_size\n\n edge_index = knn_graph(pos, K * self.dilation, batch, loop=True,\n flow='target_to_source',\n num_workers=self.num_workers)\n\n if self.dilation > 1:\n edge_index = edge_index[:, ::self.dilation]\n\n row, col = edge_index[0], edge_index[1]\n\n pos = pos[col] - pos[row]\n\n x_star = self.mlp1(pos)\n if x is not None:\n x = x.unsqueeze(-1) if x.dim() == 1 else x\n x = x[col].view(N, K, self.in_channels)\n x_star = torch.cat([x_star, x], dim=-1)\n x_star = x_star.transpose(1, 2).contiguous()\n\n transform_matrix = self.mlp2(pos.view(N, K * D))\n\n x_transformed = torch.matmul(x_star, transform_matrix)\n\n out = self.conv(x_transformed)\n\n return out\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({self.in_channels}, '\n f'{self.out_channels})')\n\n# Path: torch_geometric/nn/conv/__init__.py\nfrom .message_passing import MessagePassing\nfrom .simple_conv import SimpleConv\nfrom .gcn_conv import GCNConv\nfrom .cheb_conv import ChebConv\nfrom .sage_conv import SAGEConv\nfrom .cugraph.sage_conv import CuGraphSAGEConv\nfrom .graph_conv import GraphConv\nfrom .gravnet_conv import GravNetConv\nfrom .gated_graph_conv import GatedGraphConv\nfrom .res_gated_graph_conv import ResGatedGraphConv\nfrom .gat_conv import GATConv\nfrom .cugraph.gat_conv import CuGraphGATConv\nfrom .fused_gat_conv import FusedGATConv\nfrom .gatv2_conv import GATv2Conv\nfrom .transformer_conv import TransformerConv\nfrom .agnn_conv import AGNNConv\nfrom .tag_conv import TAGConv\nfrom .gin_conv import GINConv, GINEConv\nfrom .arma_conv import ARMAConv\nfrom .sg_conv import SGConv\nfrom .appnp import APPNP\nfrom .mf_conv import MFConv\nfrom .rgcn_conv import RGCNConv, FastRGCNConv\nfrom .cugraph.rgcn_conv import CuGraphRGCNConv\nfrom .rgat_conv import RGATConv\nfrom .signed_conv import SignedConv\nfrom .dna_conv import DNAConv\nfrom .point_conv import PointNetConv\nfrom .gmm_conv import GMMConv\nfrom .spline_conv import SplineConv\nfrom .nn_conv import NNConv\nfrom .cg_conv import CGConv\nfrom .edge_conv import EdgeConv, DynamicEdgeConv\nfrom .x_conv import XConv\nfrom .ppf_conv import PPFConv\nfrom .feast_conv import FeaStConv\nfrom .point_transformer_conv import PointTransformerConv\nfrom .hypergraph_conv import HypergraphConv\nfrom .le_conv import LEConv\nfrom .pna_conv import PNAConv\nfrom .cluster_gcn_conv import ClusterGCNConv\nfrom .gen_conv import GENConv\nfrom .gcn2_conv import GCN2Conv\nfrom .pan_conv import PANConv\nfrom .wl_conv import WLConv\nfrom .wl_conv_continuous import WLConvContinuous\nfrom .film_conv import FiLMConv\nfrom .supergat_conv import SuperGATConv\nfrom .fa_conv import FAConv\nfrom .eg_conv import EGConv\nfrom .pdn_conv import PDNConv\nfrom .general_conv import GeneralConv\nfrom .hgt_conv import HGTConv\nfrom .heat_conv import HEATConv\nfrom .hetero_conv import HeteroConv\nfrom .han_conv import HANConv\nfrom .lg_conv import LGConv\nfrom .ssg_conv import SSGConv\nfrom .point_gnn_conv import PointGNNConv\nfrom .gps_conv import GPSConv\nfrom .antisymmetric_conv import AntiSymmetricConv\nfrom .dir_gnn_conv import DirGNNConv\nfrom .mixhop_conv import MixHopConv\n\nimport torch_geometric.nn.conv.utils # noqa\n\n__all__ = [\n 'MessagePassing',\n 'SimpleConv',\n 'GCNConv',\n 'ChebConv',\n 'SAGEConv',\n 'CuGraphSAGEConv',\n 'GraphConv',\n 'GravNetConv',\n 'GatedGraphConv',\n 
'ResGatedGraphConv',\n 'GATConv',\n 'CuGraphGATConv',\n 'FusedGATConv',\n 'GATv2Conv',\n 'TransformerConv',\n 'AGNNConv',\n 'TAGConv',\n 'GINConv',\n 'GINEConv',\n 'ARMAConv',\n 'SGConv',\n 'SSGConv',\n 'APPNP',\n 'MFConv',\n 'RGCNConv',\n 'FastRGCNConv',\n 'CuGraphRGCNConv',\n 'RGATConv',\n 'SignedConv',\n 'DNAConv',\n 'PointNetConv',\n 'GMMConv',\n 'SplineConv',\n 'NNConv',\n 'CGConv',\n 'EdgeConv',\n 'DynamicEdgeConv',\n 'XConv',\n 'PPFConv',\n 'FeaStConv',\n 'PointTransformerConv',\n 'HypergraphConv',\n 'LEConv',\n 'PNAConv',\n 'ClusterGCNConv',\n 'GENConv',\n 'GCN2Conv',\n 'PANConv',\n 'WLConv',\n 'WLConvContinuous',\n 'FiLMConv',\n 'SuperGATConv',\n 'FAConv',\n 'EGConv',\n 'PDNConv',\n 'GeneralConv',\n 'HGTConv',\n 'HEATConv',\n 'HeteroConv',\n 'HANConv',\n 'LGConv',\n 'PointGNNConv',\n 'GPSConv',\n 'AntiSymmetricConv',\n 'DirGNNConv',\n 'MixHopConv',\n]\n\nclasses = __all__\n\nECConv = NNConv\nPointConv = PointNetConv\n\n# Path: torch_geometric/explain/algorithm/attention_explainer.py\nimport logging\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.config import ExplanationType, ModelTaskLevel\nfrom torch_geometric.nn.conv.message_passing import MessagePassing\n\n\nclass AttentionExplainer(ExplainerAlgorithm):\n r\"\"\"An explainer that uses the attention coefficients produced by an\n attention-based GNN (*e.g.*,\n :class:`~torch_geometric.nn.conv.GATConv`,\n :class:`~torch_geometric.nn.conv.GATv2Conv`, or\n :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.\n Attention scores across layers and heads will be aggregated according to\n the :obj:`reduce` argument.\n\n Args:\n reduce (str, optional): The method to reduce the attention scores\n across layers and heads. (default: :obj:`\"max\"`)\n \"\"\"\n def __init__(self, reduce: str = 'max'):\n super().__init__()\n self.reduce = reduce\n\n def forward(\n self,\n model: torch.nn.Module,\n x: Tensor,\n edge_index: Tensor,\n *,\n target: Tensor,\n index: Optional[Union[int, Tensor]] = None,\n **kwargs,\n ) -> Explanation:\n if isinstance(x, dict) or isinstance(edge_index, dict):\n raise ValueError(f\"Heterogeneous graphs not yet supported in \"\n f\"'{self.__class__.__name__}'\")\n\n hard_edge_mask = None\n if self.model_config.task_level == ModelTaskLevel.node:\n # We need to compute the hard edge mask to properly clean up edge\n # attributions not involved during message passing:\n _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,\n num_nodes=x.size(0))\n\n alphas: List[Tensor] = []\n\n def hook(module, msg_kwargs, out):\n if 'alpha' in msg_kwargs[0]:\n alphas.append(msg_kwargs[0]['alpha'].detach())\n elif getattr(module, '_alpha', None) is not None:\n alphas.append(module._alpha.detach())\n\n hook_handles = []\n for module in model.modules(): # Register message forward hooks:\n if (isinstance(module, MessagePassing)\n and module.explain is not False):\n hook_handles.append(module.register_message_forward_hook(hook))\n\n model(x, edge_index, **kwargs)\n\n for handle in hook_handles: # Remove hooks:\n handle.remove()\n\n if len(alphas) == 0:\n raise ValueError(\"Could not collect any attention coefficients. 
\"\n \"Please ensure that your model is using \"\n \"attention-based GNN layers.\")\n\n for i, alpha in enumerate(alphas):\n alpha = alpha[:edge_index.size(1)] # Respect potential self-loops.\n if alpha.dim() == 2:\n alpha = getattr(torch, self.reduce)(alpha, dim=-1)\n if isinstance(alpha, tuple): # Respect `torch.max`:\n alpha = alpha[0]\n elif alpha.dim() > 2:\n raise ValueError(f\"Can not reduce attention coefficients of \"\n f\"shape {list(alpha.size())}\")\n alphas[i] = alpha\n\n if len(alphas) > 1:\n alpha = torch.stack(alphas, dim=-1)\n alpha = getattr(torch, self.reduce)(alpha, dim=-1)\n if isinstance(alpha, tuple): # Respect `torch.max`:\n alpha = alpha[0]\n else:\n alpha = alphas[0]\n\n alpha = self._post_process_mask(alpha, hard_edge_mask,\n apply_sigmoid=False)\n\n return Explanation(edge_mask=alpha)\n\n def supports(self) -> bool:\n explanation_type = self.explainer_config.explanation_type\n if explanation_type != ExplanationType.model:\n logging.error(f\"'{self.__class__.__name__}' only supports \"\n f\"model explanations \"\n f\"got (`explanation_type={explanation_type.value}`)\")\n return False\n\n node_mask_type = self.explainer_config.node_mask_type\n if node_mask_type is not None:\n logging.error(f\"'{self.__class__.__name__}' does not support \"\n f\"explaining input node features \"\n f\"got (`node_mask_type={node_mask_type.value}`)\")\n return False\n\n return True\n\n# Path: torch_geometric/explain/algorithm/base.py\nfrom abc import abstractmethod\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.config import (\n ExplainerConfig,\n ModelConfig,\n ModelReturnType,\n)\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.utils import k_hop_subgraph\n\n\nclass ExplainerAlgorithm(torch.nn.Module):\n r\"\"\"An abstract base class for implementing explainer algorithms.\"\"\"\n @abstractmethod\n def forward(\n self,\n model: torch.nn.Module,\n x: Union[Tensor, Dict[NodeType, Tensor]],\n edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n *,\n target: Tensor,\n index: Optional[Union[int, Tensor]] = None,\n **kwargs,\n ) -> Union[Explanation, HeteroExplanation]:\n r\"\"\"Computes the explanation.\n\n Args:\n model (torch.nn.Module): The model to explain.\n x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input\n node features of a homogeneous or heterogeneous graph.\n edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The\n input edge indices of a homogeneous or heterogeneous graph.\n target (torch.Tensor): The target of the model.\n index (Union[int, Tensor], optional): The index of the model\n output to explain. Can be a single index or a tensor of\n indices. (default: :obj:`None`)\n **kwargs (optional): Additional keyword arguments passed to\n :obj:`model`.\n \"\"\"\n\n @abstractmethod\n def supports(self) -> bool:\n r\"\"\"Checks if the explainer supports the user-defined settings provided\n in :obj:`self.explainer_config`, :obj:`self.model_config`.\n \"\"\"\n pass\n\n ###########################################################################\n...\n", "needle_token_start": 2380, "needle_token_end": 2454, "code_context_ntokens": 16143, "output": ["```python\ndef positive_sampling(self, edge_index: Tensor) -> Tensor:\n pos_edge_index, _ = dropout_edge(edge_index,\n p=1. 
- self.edge_sample_ratio,\n training=self.training)\n return pos_edge_index\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "read_from_memlab", "language": "python", "path": "torch_geometric/profile/profile.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: To extract and convert specific memory usage statistics from profiling data into a more readable format, specifically in megabytes.\n2. **Input**: A profiler object that contains raw memory profiling records.\n3. **Output**: A list of floats, each representing a memory statistic (peak allocated, reserved, and active bytes) converted from bytes to megabytes.\n4. **Procedure**: The function retrieves raw line records from the profiler, specifies the memory statistics to track, and then processes these records to extract the maximum values of the specified statistics. These byte values are then converted to megabytes and returned as a list.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/profile/utils.py\nimport gc\nimport os\nimport os.path as osp\nimport random\nimport subprocess as sp\nimport sys\nimport warnings\nfrom collections.abc import Mapping, Sequence\nfrom typing import Any, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.typing import SparseTensor\n\n\ndef count_parameters(model: torch.nn.Module) -> int:\n r\"\"\"Given a :class:`torch.nn.Module`, count its trainable parameters.\n\n Args:\n model (torch.nn.Model): The model.\n \"\"\"\n return sum([p.numel() for p in model.parameters() if p.requires_grad])\n\n\ndef get_model_size(model: torch.nn.Module) -> int:\n r\"\"\"Given a :class:`torch.nn.Module`, get its actual disk size in bytes.\n\n Args:\n model (torch model): The model.\n \"\"\"\n path = f'{random.randrange(sys.maxsize)}.pt'\n torch.save(model.state_dict(), path)\n model_size = osp.getsize(path)\n os.remove(path)\n return model_size\n\n\ndef get_data_size(data: BaseData) -> int:\n r\"\"\"Given a :class:`torch_geometric.data.Data` object, get its theoretical\n memory usage in bytes.\n\n Args:\n data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n The :class:`~torch_geometric.data.Data` or\n :class:`~torch_geometric.data.HeteroData` graph object.\n \"\"\"\n data_ptrs = set()\n\n def _get_size(obj: Any) -> int:\n if isinstance(obj, Tensor):\n if obj.data_ptr() in data_ptrs:\n return 0\n data_ptrs.add(obj.data_ptr())\n return obj.numel() * obj.element_size()\n elif isinstance(obj, SparseTensor):\n return _get_size(obj.csr())\n elif isinstance(obj, Sequence) and not isinstance(obj, str):\n return sum([_get_size(x) for x in obj])\n elif isinstance(obj, Mapping):\n return sum([_get_size(x) for x in obj.values()])\n else:\n return 0\n\n return sum([_get_size(store) for store in data.stores])\n\n\ndef get_cpu_memory_from_gc() -> int:\n r\"\"\"Returns the used CPU memory in bytes, as reported by the\n :python:`Python` garbage collector.\n \"\"\"\n warnings.filterwarnings('ignore', '.*torch.distributed.reduce_op.*')\n\n mem = 0\n for obj in gc.get_objects():\n try:\n if isinstance(obj, Tensor) and not obj.is_cuda:\n mem += obj.numel() * obj.element_size()\n except Exception:\n pass\n return mem\n\n\ndef get_gpu_memory_from_gc(device: int = 0) -> int: # 
pragma: no cover\n r\"\"\"Returns the used GPU memory in bytes, as reported by the\n :python:`Python` garbage collector.\n\n Args:\n device (int, optional): The GPU device identifier. (default: :obj:`1`)\n \"\"\"\n warnings.filterwarnings('ignore', '.*torch.distributed.reduce_op.*')\n\n mem = 0\n for obj in gc.get_objects():\n try:\n if isinstance(obj, Tensor) and obj.get_device() == device:\n mem += obj.numel() * obj.element_size()\n except Exception:\n pass\n return mem\n\n\ndef get_gpu_memory_from_nvidia_smi( # pragma: no cover\n device: int = 0,\n digits: int = 2,\n) -> Tuple[float, float]:\n r\"\"\"Returns the free and used GPU memory in megabytes, as reported by\n :obj:`nivdia-smi`.\n\n .. note::\n\n :obj:`nvidia-smi` will generally overestimate the amount of memory used\n by the actual program, see `here `__.\n\n Args:\n device (int, optional): The GPU device identifier. (default: :obj:`1`)\n digits (int): The number of decimals to use for megabytes.\n (default: :obj:`2`)\n \"\"\"\n CMD = 'nvidia-smi --query-gpu=memory.free --format=csv'\n free_out = sp.check_output(CMD.split()).decode('utf-8').split('\\n')[1:-1]\n\n CMD = 'nvidia-smi --query-gpu=memory.used --format=csv'\n used_out = sp.check_output(CMD.split()).decode('utf-8').split('\\n')[1:-1]\n\n if device < 0 or device >= len(free_out):\n raise AttributeError(\n f'GPU {device} not available (found {len(free_out)} GPUs)')\n\n free_mem = medibyte_to_megabyte(int(free_out[device].split()[0]), digits)\n used_mem = medibyte_to_megabyte(int(used_out[device].split()[0]), digits)\n\n return free_mem, used_mem\n\n\ndef get_gpu_memory_from_ipex(\n device: int = 0,\n digits=2) -> Tuple[float, float, float]: # pragma: no cover\n r\"\"\"Returns the XPU memory statistics.\n\n...\n# Path: torch_geometric/profile/profile.py\nimport os\nimport pathlib\nimport time\nfrom contextlib import ContextDecorator, contextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, List, Tuple, Union\n\nimport torch\nfrom torch.autograd.profiler import EventList\nfrom torch.profiler import ProfilerActivity, profile\n\nfrom torch_geometric.profile.utils import (\n byte_to_megabyte,\n get_gpu_memory_from_ipex,\n get_gpu_memory_from_nvidia_smi,\n)\n\n\n@dataclass\nclass GPUStats:\n time: float\n max_allocated_gpu: float\n max_reserved_gpu: float\n max_active_gpu: float\n\n\n@dataclass\nclass CUDAStats(GPUStats):\n nvidia_smi_free_cuda: float\n nvidia_smi_used_cuda: float\n\n\n@dataclass\nclass GPUStatsSummary:\n time_mean: float\n time_std: float\n max_allocated_gpu: float\n max_reserved_gpu: float\n max_active_gpu: float\n\n\n@dataclass\nclass CUDAStatsSummary(GPUStatsSummary):\n min_nvidia_smi_free_cuda: float\n max_nvidia_smi_used_cuda: float\n\n\ndef profileit(device: str): # pragma: no cover\n r\"\"\"A decorator to facilitate profiling a function, *e.g.*, obtaining\n training runtime and memory statistics of a specific model on a specific\n dataset.\n Returns a :obj:`GPUStats` if :obj:`device` is :obj:`xpu` or extended\n object :obj:`CUDAStats`, if :obj:`device` is :obj:`cuda`.\n\n Args:\n device (str): Target device for profiling. Options are:\n :obj:`cuda` and obj:`xpu`.\n\n .. 
code-block:: python\n\n @profileit(\"cuda\")\n def train(model, optimizer, x, edge_index, y):\n optimizer.zero_grad()\n out = model(x, edge_index)\n loss = criterion(out, y)\n loss.backward()\n optimizer.step()\n return float(loss)\n\n loss, stats = train(model, x, edge_index, y)\n \"\"\"\n def decorator(func):\n def wrapper(\n *args, **kwargs\n ) -> Union[Tuple[Any, GPUStats], Tuple[Any, CUDAStats]]:\n model = args[0]\n if not isinstance(model, torch.nn.Module):\n raise AttributeError(\n 'First argument for profiling needs to be torch.nn.Module')\n if device not in ['cuda', 'xpu']:\n raise AttributeError(\n \"The profiling decorator supports only CUDA and \"\n \"XPU devices\")\n\n device_id = None\n for arg in list(args) + list(kwargs.values()):\n if isinstance(arg, torch.Tensor):\n device_id = arg.get_device()\n break\n if device_id is None:\n raise AttributeError(\n \"Could not infer GPU device from the args in the \"\n \"function being profiled\")\n if device_id == -1:\n raise RuntimeError(\n \"The profiling decorator does not support profiling \"\n \"on non GPU devices\")\n\n is_cuda = device == 'cuda'\n torch_gpu = torch.cuda if is_cuda else torch.xpu\n\n # `pytorch_memlab` supports only CUDA devices\n if is_cuda:\n from pytorch_memlab import LineProfiler\n\n # Init `pytorch_memlab` for analyzing the model forward pass:\n line_profiler = LineProfiler(target_gpu=device_id)\n line_profiler.enable()\n line_profiler.add_function(args[0].forward)\n\n start = torch_gpu.Event(enable_timing=True)\n end = torch_gpu.Event(enable_timing=True)\n start.record()\n\n out = func(*args, **kwargs)\n\n end.record()\n torch_gpu.synchronize()\n time = start.elapsed_time(end) / 1000\n\n if is_cuda:\n # Get the global memory statistics collected\n # by `pytorch_memlab`:\n memlab = read_from_memlab(line_profiler)\n max_allocated, max_reserved, max_active = memlab\n line_profiler.disable()\n\n # Get additional information from `nvidia-smi`:\n free_cuda, used_cuda = get_gpu_memory_from_nvidia_smi(\n device=device_id)\n\n stats = CUDAStats(time, max_allocated, max_reserved,\n max_active, free_cuda, used_cuda)\n return out, stats\n else:\n stats = GPUStats(time, *get_gpu_memory_from_ipex(device_id))\n return out, stats\n\n return wrapper\n\n return decorator\n\n\nclass timeit(ContextDecorator):\n r\"\"\"A context decorator to facilitate timing a function, *e.g.*, obtaining\n the runtime of a specific model on a specific dataset.\n\n .. code-block:: python\n\n @torch.no_grad()\n def test(model, x, edge_index):\n return model(x, edge_index)\n\n with timeit() as t:\n z = test(model, x, edge_index)\n time = t.duration\n\n Args:\n log (bool, optional): If set to :obj:`False`, will not log any runtime\n to the console. (default: :obj:`True`)\n avg_time_divisor (int, optional): If set to a value greater than\n :obj:`1`, will divide the total time by this value. 
Useful for\n calculating the average of runtimes within a for-loop.\n (default: :obj:`0`)\n \"\"\"\n def __init__(self, log: bool = True, avg_time_divisor: int = 0):\n self.log = log\n self.avg_time_divisor = avg_time_divisor\n\n def __enter__(self):\n if torch.cuda.is_available():\n torch.cuda.synchronize()\n self.t_start = time.time()\n return self\n\n def __exit__(self, *args):\n if torch.cuda.is_available():\n torch.cuda.synchronize()\n self.t_end = time.time()\n self.duration = self.t_end - self.t_start\n if self.avg_time_divisor > 1:\n self.duration = self.duration / self.avg_time_divisor\n if self.log: # pragma: no cover\n print(f'Time: {self.duration:.8f}s', flush=True)\n\n def reset(self):\n r\"\"\"Prints the duration and resets current timer.\"\"\"\n if self.t_start is None:\n raise RuntimeError(\"Timer wasn't started.\")\n else:\n self.__exit__()\n self.__enter__()\n\n\ndef get_stats_summary(\n stats_list: Union[List[GPUStats], List[CUDAStats]]\n) -> Union[GPUStatsSummary, CUDAStatsSummary]: # pragma: no cover\n r\"\"\"Creates a summary of collected runtime and memory statistics.\n Returns a :obj:`GPUStatsSummary` if list of :obj:`GPUStats` was passed,\n otherwise (list of :obj:`CUDAStats` was passed),\n returns a :obj:`CUDAStatsSummary`.\n\n Args:\n stats_list (Union[List[GPUStats], List[CUDAStats]]): A list of\n :obj:`GPUStats` or :obj:`CUDAStats` objects, as returned by\n :meth:`~torch_geometric.profile.profileit`.\n \"\"\"\n # calculate common statistics\n kwargs = dict(\n time_mean=float(torch.tensor([s.time for s in stats_list]).mean()),\n time_std=float(torch.tensor([s.time for s in stats_list]).std()),\n max_allocated_gpu=max([s.max_allocated_gpu for s in stats_list]),\n max_reserved_gpu=max([s.max_reserved_gpu for s in stats_list]),\n max_active_gpu=max([s.max_active_gpu for s in stats_list]))\n\n if all(isinstance(s, CUDAStats) for s in stats_list):\n return CUDAStatsSummary(\n **kwargs,\n min_nvidia_smi_free_cuda=min(\n [s.nvidia_smi_free_cuda for s in stats_list]),\n max_nvidia_smi_used_cuda=max(\n [s.nvidia_smi_used_cuda for s in stats_list]),\n )\n else:\n return GPUStatsSummary(**kwargs)\n\n\n###############################################################################\n\n\n\ndef read_from_memlab(line_profiler: Any) -> List[float]: # pragma: no cover\n from pytorch_memlab.line_profiler.line_records import LineRecords\n\n # See: https://pytorch.org/docs/stable/cuda.html#torch.cuda.memory_stats\n\n track_stats = [ # Different statistic can be collected as needed.\n 'allocated_bytes.all.peak',\n 'reserved_bytes.all.peak',\n 'active_bytes.all.peak',\n ]\n\n records = LineRecords(line_profiler._raw_line_records,\n line_profiler._code_infos)\n stats = records.display(None, track_stats)._line_records\n return [byte_to_megabyte(x) for x in stats.values.max(axis=0).tolist()]\n\n\ndef trace_handler(p):\n print_time_total(p)\n profile_dir = str(pathlib.Path.cwd()) + '/'\n timeline_file = profile_dir + 'timeline' + '.json'\n p.export_chrome_trace(timeline_file)\n\n\ndef print_time_total(p):\n if torch.cuda.is_available():\n profile_sort = 'self_cuda_time_total'\n else:\n profile_sort = 'self_cpu_time_total'\n output = p.key_averages().table(sort_by=profile_sort)\n print(output)\n\n\ndef rename_profile_file(*args):\n profile_dir = str(pathlib.Path.cwd()) + '/'\n timeline_file = profile_dir + 'profile'\n for arg in args:\n timeline_file += '-' + arg\n timeline_file += '.json'\n os.rename('timeline.json', timeline_file)\n\n\n@contextmanager\ndef 
torch_profile(export_chrome_trace=True, csv_data=None, write_csv=None):\n use_cuda = torch.cuda.is_available()\n\n activities = [ProfilerActivity.CPU]\n if use_cuda:\n activities.append(ProfilerActivity.CUDA)\n\n if export_chrome_trace:\n p_trace_handler = trace_handler\n else:\n p_trace_handler = print_time_total\n\n p = profile(activities=activities, on_trace_ready=p_trace_handler)\n\n with p:\n yield\n p.step()\n\n if csv_data is not None and write_csv == 'prof':\n if use_cuda:\n profile_sort = 'self_cuda_time_total'\n else:\n profile_sort = 'self_cpu_time_total'\n events = EventList(\n sorted(\n p.key_averages(),\n key=lambda evt: getattr(evt, profile_sort),\n reverse=True,\n ), use_cuda=use_cuda)\n\n save_profile_data(csv_data, events, use_cuda)\n\n\n@contextmanager\ndef xpu_profile(export_chrome_trace=True):\n with torch.autograd.profiler_legacy.profile(use_xpu=True) as profile:\n yield\n print(profile.key_averages().table(sort_by='self_xpu_time_total'))\n if export_chrome_trace:\n profile.export_chrome_trace('timeline.json')\n\n\ndef format_prof_time(time):\n # Profile time is in micro seconds, so format it appropriately:\n return round(time / 1e6, 3)\n\n\ndef save_profile_data(csv_data, events, use_cuda):\n sum_self_cpu_time_total = sum(\n [event.self_cpu_time_total for event in events])\n sum_cpu_time_total = sum([event.self_cpu_time_total for event in events])\n sum_self_cuda_time_total = sum(\n [event.self_cuda_time_total for event in events]) if use_cuda else 0\n\n for e in events[:5]: # Save top 5 most time consuming operations:\n csv_data['NAME'].append(e.key)\n csv_data['SELF CPU %'].append(\n round(e.self_cpu_time_total * 100.0 / sum_self_cpu_time_total, 3))\n csv_data['SELF CPU'].append(format_prof_time(e.self_cpu_time_total))\n csv_data['CPU TOTAL %'].append(\n round(e.cpu_time_total * 100.0 / sum_cpu_time_total, 3))\n csv_data['CPU TOTAL'].append(format_prof_time(e.cpu_time_total))\n csv_data['CPU TIME AVG'].append(format_prof_time(e.cpu_time_total))\n if use_cuda:\n csv_data['SELF CUDA %'].append(e.self_cuda_time_total * 100.0 /\n sum_self_cuda_time_total)\n csv_data['SELF CUDA'].append(\n format_prof_time(e.self_cuda_time_total))\n csv_data['CUDA TOTAL'].append(format_prof_time(e.cpu_time_total))\n csv_data['CUDA TIME AVG'].append(format_prof_time(\n e.cpu_time_total))\n csv_data['# OF CALLS'].append(e.count)\n\n# Path: torch_geometric/profile/__init__.py\nr\"\"\"GNN profiling package.\"\"\"\n\nfrom .benchmark import benchmark\nfrom .profile import (\n get_stats_summary,\n print_time_total,\n profileit,\n rename_profile_file,\n timeit,\n torch_profile,\n trace_handler,\n xpu_profile,\n)\nfrom .utils import (\n count_parameters,\n get_cpu_memory_from_gc,\n get_data_size,\n get_gpu_memory_from_gc,\n get_gpu_memory_from_ipex,\n get_gpu_memory_from_nvidia_smi,\n get_model_size,\n)\n\n__all__ = [\n 'profileit',\n 'timeit',\n 'get_stats_summary',\n 'trace_handler',\n 'print_time_total',\n 'rename_profile_file',\n 'torch_profile',\n 'xpu_profile',\n 'count_parameters',\n 'get_model_size',\n 'get_data_size',\n 'get_cpu_memory_from_gc',\n 'get_gpu_memory_from_gc',\n 'get_gpu_memory_from_nvidia_smi',\n 'get_gpu_memory_from_ipex',\n 'benchmark',\n]\n\nclasses = __all__\n\n# Path: torch_geometric/seed.py\nimport random\n\nimport numpy as np\nimport torch\n\n\ndef seed_everything(seed: int) -> None:\n r\"\"\"Sets the seed for generating random numbers in :pytorch:`PyTorch`,\n :obj:`numpy` and :python:`Python`.\n\n Args:\n seed (int): The desired seed.\n \"\"\"\n random.seed(seed)\n 
np.random.seed(seed)\n torch.manual_seed(seed)\n torch.cuda.manual_seed_all(seed)\n\n# Path: torch_geometric/transforms/add_metapaths.py\nimport warnings\nfrom typing import List, Optional, Tuple, Union, cast\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.typing import EdgeType\nfrom torch_geometric.utils import coalesce, degree\n\n\n@functional_transform('add_metapaths')\nclass AddMetaPaths(BaseTransform):\n r\"\"\"Adds additional edge types to a\n :class:`~torch_geometric.data.HeteroData` object between the source node\n type and the destination node type of a given :obj:`metapath`, as described\n in the `\"Heterogenous Graph Attention Networks\"\n `_ paper\n (functional name: :obj:`add_metapaths`).\n\n Meta-path based neighbors can exploit different aspects of structure\n information in heterogeneous graphs.\n Formally, a metapath is a path of the form\n\n .. math::\n\n \\mathcal{V}_1 \\xrightarrow{R_1} \\mathcal{V}_2 \\xrightarrow{R_2} \\ldots\n \\xrightarrow{R_{\\ell-1}} \\mathcal{V}_{\\ell}\n\n in which :math:`\\mathcal{V}_i` represents node types, and :math:`R_j`\n represents the edge type connecting two node types.\n The added edge type is given by the sequential multiplication of\n adjacency matrices along the metapath, and is added to the\n :class:`~torch_geometric.data.HeteroData` object as edge type\n :obj:`(src_node_type, \"metapath_*\", dst_node_type)`, where\n :obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\\mathcal{V}_1`\n and :math:`\\mathcal{V}_{\\ell}`, repectively.\n\n In addition, a :obj:`metapath_dict` object is added to the\n :class:`~torch_geometric.data.HeteroData` object which maps the\n metapath-based edge type to its original metapath.\n\n .. code-block:: python\n\n from torch_geometric.datasets import DBLP\n from torch_geometric.data import HeteroData\n from torch_geometric.transforms import AddMetaPaths\n\n data = DBLP(root)[0]\n # 4 node types: \"paper\", \"author\", \"conference\", and \"term\"\n # 6 edge types: (\"paper\",\"author\"), (\"author\", \"paper\"),\n # (\"paper, \"term\"), (\"paper\", \"conference\"),\n # (\"term, \"paper\"), (\"conference\", \"paper\")\n\n # Add two metapaths:\n # 1. From \"paper\" to \"paper\" through \"conference\"\n # 2. From \"author\" to \"conference\" through \"paper\"\n metapaths = [[(\"paper\", \"conference\"), (\"conference\", \"paper\")],\n [(\"author\", \"paper\"), (\"paper\", \"conference\")]]\n data = AddMetaPaths(metapaths)(data)\n\n print(data.edge_types)\n >>> [(\"author\", \"to\", \"paper\"), (\"paper\", \"to\", \"author\"),\n (\"paper\", \"to\", \"term\"), (\"paper\", \"to\", \"conference\"),\n (\"term\", \"to\", \"paper\"), (\"conference\", \"to\", \"paper\"),\n (\"paper\", \"metapath_0\", \"paper\"),\n (\"author\", \"metapath_1\", \"conference\")]\n\n print(data.metapath_dict)\n >>> {(\"paper\", \"metapath_0\", \"paper\"): [(\"paper\", \"conference\"),\n (\"conference\", \"paper\")],\n (\"author\", \"metapath_1\", \"conference\"): [(\"author\", \"paper\"),\n (\"paper\", \"conference\")]}\n\n Args:\n metapaths (List[List[Tuple[str, str, str]]]): The metapaths described\n by a list of lists of\n :obj:`(src_node_type, rel_type, dst_node_type)` tuples.\n drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing\n edge types will be dropped. 
(default: :obj:`False`)\n keep_same_node_type (bool, optional): If set to :obj:`True`, existing\n edge types between the same node type are not dropped even in case\n :obj:`drop_orig_edge_types` is set to :obj:`True`.\n (default: :obj:`False`)\n drop_unconnected_node_types (bool, optional): If set to :obj:`True`,\n will drop node types not connected by any edge type.\n (default: :obj:`False`)\n max_sample (int, optional): If set, will sample at maximum\n :obj:`max_sample` neighbors within metapaths. Useful in order to\n tackle very dense metapath edges. (default: :obj:`None`)\n weighted (bool, optional): If set to :obj:`True`, computes weights for\n each metapath edge and stores them in :obj:`edge_weight`. The\n weight of each metapath edge is computed as the number of metapaths\n from the start to the end of the metapath edge.\n (default :obj:`False`)\n \"\"\"\n def __init__(\n self,\n metapaths: List[List[EdgeType]],\n drop_orig_edge_types: bool = False,\n keep_same_node_type: bool = False,\n drop_unconnected_node_types: bool = False,\n max_sample: Optional[int] = None,\n weighted: bool = False,\n **kwargs: bool,\n ) -> None:\n if 'drop_orig_edges' in kwargs:\n warnings.warn(\"'drop_orig_edges' is dprecated. Use \"\n \"'drop_orig_edge_types' instead\")\n drop_orig_edge_types = kwargs['drop_orig_edges']\n\n if 'drop_unconnected_nodes' in kwargs:\n warnings.warn(\"'drop_unconnected_nodes' is dprecated. Use \"\n \"'drop_unconnected_node_types' instead\")\n drop_unconnected_node_types = kwargs['drop_unconnected_nodes']\n\n for path in metapaths:\n assert len(path) >= 2, f\"Invalid metapath '{path}'\"\n assert all([\n j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1])\n ]), f\"Invalid sequence of node types in '{path}'\"\n\n self.metapaths = metapaths\n self.drop_orig_edge_types = drop_orig_edge_types\n self.keep_same_node_type = keep_same_node_type\n self.drop_unconnected_node_types = drop_unconnected_node_types\n self.max_sample = max_sample\n self.weighted = weighted\n\n def forward(self, data: HeteroData) -> HeteroData:\n edge_types = data.edge_types # Save original edge types.\n data.metapath_dict = {}\n\n for j, metapath in enumerate(self.metapaths):\n for edge_type in metapath:\n assert data._to_canonical(edge_type) in edge_types\n\n edge_type = metapath[0]\n edge_index, edge_weight = self._edge_index(data, edge_type)\n\n if self.max_sample is not None:\n edge_index, edge_weight = self._sample(edge_index, edge_weight)\n\n for i, edge_type in enumerate(metapath[1:]):\n edge_index2, edge_weight2 = self._edge_index(data, edge_type)\n\n edge_index, edge_weight = edge_index.matmul(\n edge_index2, edge_weight, edge_weight2)\n\n if not self.weighted:\n edge_weight = None\n\n if self.max_sample is not None:\n edge_index, edge_weight = self._sample(\n edge_index, edge_weight)\n\n new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])\n data[new_edge_type].edge_index = edge_index\n if self.weighted:\n data[new_edge_type].edge_weight = edge_weight\n data.metapath_dict[new_edge_type] = metapath\n\n postprocess(data, edge_types, self.drop_orig_edge_types,\n self.keep_same_node_type, self.drop_unconnected_node_types)\n\n return data\n\n def _edge_index(\n self,\n data: HeteroData,\n edge_type: EdgeType,\n ) -> Tuple[EdgeIndex, Optional[Tensor]]:\n\n edge_index = EdgeIndex(\n data[edge_type].edge_index,\n sparse_size=data[edge_type].size(),\n )\n edge_index, perm = edge_index.sort_by('row')\n\n if not self.weighted:\n return edge_index, None\n\n edge_weight = 
data[edge_type].get('edge_weight')\n if edge_weight is not None:\n assert edge_weight.dim() == 1\n edge_weight = edge_weight[perm]\n\n return edge_index, edge_weight\n\n def _sample(\n self,\n edge_index: EdgeIndex,\n edge_weight: Optional[Tensor],\n ) -> Tuple[EdgeIndex, Optional[Tensor]]:\n\n deg = degree(edge_index[0], num_nodes=edge_index.get_sparse_size(0))\n prob = (self.max_sample * (1. / deg))[edge_index[0]]\n mask = torch.rand_like(prob) < prob\n\n edge_index = cast(EdgeIndex, edge_index[:, mask])\n assert isinstance(edge_index, EdgeIndex)\n if edge_weight is not None:\n edge_weight = edge_weight[mask]\n\n return edge_index, edge_weight\n\n\n@functional_transform('add_random_metapaths')\nclass AddRandomMetaPaths(BaseTransform):\n r\"\"\"Adds additional edge types similar to :class:`AddMetaPaths`.\n The key difference is that the added edge type is given by\n multiple random walks along the metapath.\n One might want to increase the number of random walks\n via :obj:`walks_per_node` to achieve competitive performance with\n :class:`AddMetaPaths`.\n\n Args:\n metapaths (List[List[Tuple[str, str, str]]]): The metapaths described\n by a list of lists of\n :obj:`(src_node_type, rel_type, dst_node_type)` tuples.\n drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing\n edge types will be dropped. (default: :obj:`False`)\n keep_same_node_type (bool, optional): If set to :obj:`True`, existing\n edge types between the same node type are not dropped even in case\n :obj:`drop_orig_edge_types` is set to :obj:`True`.\n (default: :obj:`False`)\n drop_unconnected_node_types (bool, optional): If set to :obj:`True`,\n will drop node types not connected by any edge type.\n (default: :obj:`False`)\n walks_per_node (int, List[int], optional): The number of random walks\n for each starting node in a metapth. (default: :obj:`1`)\n sample_ratio (float, optional): The ratio of source nodes to start\n random walks from. 
(default: :obj:`1.0`)\n \"\"\"\n def __init__(\n self,\n metapaths: List[List[EdgeType]],\n drop_orig_edge_types: bool = False,\n keep_same_node_type: bool = False,\n drop_unconnected_node_types: bool = False,\n walks_per_node: Union[int, List[int]] = 1,\n sample_ratio: float = 1.0,\n ):\n\n for path in metapaths:\n assert len(path) >= 2, f\"Invalid metapath '{path}'\"\n assert all([\n j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1])\n ]), f\"Invalid sequence of node types in '{path}'\"\n\n self.metapaths = metapaths\n self.drop_orig_edge_types = drop_orig_edge_types\n self.keep_same_node_type = keep_same_node_type\n self.drop_unconnected_node_types = drop_unconnected_node_types\n self.sample_ratio = sample_ratio\n if isinstance(walks_per_node, int):\n walks_per_node = [walks_per_node] * len(metapaths)\n assert len(walks_per_node) == len(metapaths)\n self.walks_per_node = walks_per_node\n\n def forward(self, data: HeteroData) -> HeteroData:\n edge_types = data.edge_types # save original edge types\n data.metapath_dict = {}\n\n for j, metapath in enumerate(self.metapaths):\n for edge_type in metapath:\n assert data._to_canonical(\n edge_type) in edge_types, f\"'{edge_type}' not present\"\n\n src_node = metapath[0][0]\n num_nodes = data[src_node].num_nodes\n num_starts = round(num_nodes * self.sample_ratio)\n row = start = torch.randperm(num_nodes)[:num_starts].repeat(\n self.walks_per_node[j])\n\n for i, edge_type in enumerate(metapath):\n edge_index = EdgeIndex(\n data[edge_type].edge_index,\n sparse_size=data[edge_type].size(),\n )\n col, mask = self.sample(edge_index, start)\n row, col = row[mask], col[mask]\n start = col\n\n new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])\n data[new_edge_type].edge_index = coalesce(torch.vstack([row, col]))\n data.metapath_dict[new_edge_type] = metapath\n\n postprocess(data, edge_types, self.drop_orig_edge_types,\n self.keep_same_node_type, self.drop_unconnected_node_types)\n\n return data\n\n @staticmethod\n def sample(edge_index: EdgeIndex, subset: Tensor) -> Tuple[Tensor, Tensor]:\n \"\"\"Sample neighbors from :obj:`edge_index` for each node in\n :obj:`subset`.\n \"\"\"\n edge_index, _ = edge_index.sort_by('row')\n rowptr = edge_index.get_indptr()\n rowcount = rowptr.diff()[subset]\n\n mask = rowcount > 0\n offset = torch.zeros_like(subset)\n offset[mask] = rowptr[subset[mask]]\n\n rand = torch.rand((rowcount.size(0), 1), device=subset.device)\n rand.mul_(rowcount.to(rand.dtype).view(-1, 1))\n rand = rand.to(torch.long)\n rand.add_(offset.view(-1, 1))\n col = edge_index[1][rand].squeeze()\n return col, mask\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}('\n f'sample_ratio={self.sample_ratio}, '\n f'walks_per_node={self.walks_per_node})')\n\n\ndef postprocess(\n data: HeteroData,\n edge_types: List[EdgeType],\n drop_orig_edge_types: bool,\n keep_same_node_type: bool,\n drop_unconnected_node_types: bool,\n) -> None:\n\n if drop_orig_edge_types:\n for i in edge_types:\n if keep_same_node_type and i[0] == i[-1]:\n continue\n else:\n del data[i]\n\n # Remove nodes not connected by any edge type:\n if drop_unconnected_node_types:\n new_edge_types = data.edge_types\n node_types = data.node_types\n connected_nodes = set()\n for i in new_edge_types:\n connected_nodes.add(i[0])\n connected_nodes.add(i[-1])\n for node in node_types:\n if node not in connected_nodes:\n del data[node]\n\n# Path: torch_geometric/transforms/add_positional_encoding.py\nfrom typing import Any, Optional\n\nimport numpy as np\nimport 
torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import (\n get_laplacian,\n get_self_loop_attr,\n is_torch_sparse_tensor,\n scatter,\n to_edge_index,\n to_scipy_sparse_matrix,\n to_torch_coo_tensor,\n to_torch_csr_tensor,\n)\n\n\ndef add_node_attr(\n data: Data,\n value: Any,\n attr_name: Optional[str] = None,\n) -> Data:\n # TODO Move to `BaseTransform`.\n if attr_name is None:\n if data.x is not None:\n x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x\n data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1)\n else:\n data.x = value\n else:\n data[attr_name] = value\n\n return data\n\n\n@functional_transform('add_laplacian_eigenvector_pe')\nclass AddLaplacianEigenvectorPE(BaseTransform):\n r\"\"\"Adds the Laplacian eigenvector positional encoding from the\n `\"Benchmarking Graph Neural Networks\" `_\n paper to the given graph\n (functional name: :obj:`add_laplacian_eigenvector_pe`).\n\n Args:\n k (int): The number of non-trivial eigenvectors to consider.\n attr_name (str, optional): The attribute name of the data object to add\n positional encodings to. If set to :obj:`None`, will be\n concatenated to :obj:`data.x`.\n (default: :obj:`\"laplacian_eigenvector_pe\"`)\n is_undirected (bool, optional): If set to :obj:`True`, this transform\n expects undirected graphs as input, and can hence speed up the\n computation of eigenvectors. (default: :obj:`False`)\n **kwargs (optional): Additional arguments of\n :meth:`scipy.sparse.linalg.eigs` (when :attr:`is_undirected` is\n :obj:`False`) or :meth:`scipy.sparse.linalg.eigsh` (when\n :attr:`is_undirected` is :obj:`True`).\n \"\"\"\n # Number of nodes from which to use sparse eigenvector computation:\n SPARSE_THRESHOLD: int = 100\n\n def __init__(\n self,\n k: int,\n attr_name: Optional[str] = 'laplacian_eigenvector_pe',\n is_undirected: bool = False,\n **kwargs: Any,\n ) -> None:\n self.k = k\n self.attr_name = attr_name\n self.is_undirected = is_undirected\n self.kwargs = kwargs\n\n def forward(self, data: Data) -> Data:\n assert data.edge_index is not None\n num_nodes = data.num_nodes\n assert num_nodes is not None\n\n edge_index, edge_weight = get_laplacian(\n data.edge_index,\n data.edge_weight,\n normalization='sym',\n num_nodes=num_nodes,\n )\n\n L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)\n\n if num_nodes < self.SPARSE_THRESHOLD:\n from numpy.linalg import eig, eigh\n eig_fn = eig if not self.is_undirected else eigh\n\n eig_vals, eig_vecs = eig_fn(L.todense()) # type: ignore\n else:\n from scipy.sparse.linalg import eigs, eigsh\n eig_fn = eigs if not self.is_undirected else eigsh\n\n eig_vals, eig_vecs = eig_fn( # type: ignore\n L,\n k=self.k + 1,\n which='SR' if not self.is_undirected else 'SA',\n return_eigenvectors=True,\n **self.kwargs,\n )\n\n eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])\n pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1])\n sign = -1 + 2 * torch.randint(0, 2, (self.k, ))\n pe *= sign\n\n data = add_node_attr(data, pe, attr_name=self.attr_name)\n return data\n\n\n@functional_transform('add_random_walk_pe')\nclass AddRandomWalkPE(BaseTransform):\n r\"\"\"Adds the random walk positional encoding from the `\"Graph Neural\n Networks with Learnable Structural and Positional Representations\"\n `_ paper to the given graph\n (functional name: :obj:`add_random_walk_pe`).\n\n Args:\n 
walk_length (int): The number of random walk steps.\n attr_name (str, optional): The attribute name of the data object to add\n positional encodings to. If set to :obj:`None`, will be\n concatenated to :obj:`data.x`.\n (default: :obj:`\"random_walk_pe\"`)\n \"\"\"\n def __init__(\n self,\n walk_length: int,\n attr_name: Optional[str] = 'random_walk_pe',\n ) -> None:\n self.walk_length = walk_length\n self.attr_name = attr_name\n\n def forward(self, data: Data) -> Data:\n assert data.edge_index is not None\n row, col = data.edge_index\n N = data.num_nodes\n assert N is not None\n\n if data.edge_weight is None:\n value = torch.ones(data.num_edges, device=row.device)\n else:\n value = data.edge_weight\n value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row]\n value = 1.0 / value\n\n if N <= 2_000: # Dense code path for faster computation:\n adj = torch.zeros((N, N), device=row.device)\n adj[row, col] = value\n loop_index = torch.arange(N, device=row.device)\n elif torch_geometric.typing.WITH_MKL:\n adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())\n else:\n adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())\n\n def get_pe(out: Tensor) -> Tensor:\n if is_torch_sparse_tensor(out):\n return get_self_loop_attr(*to_edge_index(out), num_nodes=N)\n return out[loop_index, loop_index]\n\n out = adj\n pe_list = [get_pe(out)]\n for _ in range(self.walk_length - 1):\n out = out @ adj\n pe_list.append(get_pe(out))\n\n pe = torch.stack(pe_list, dim=-1)\n data = add_node_attr(data, pe, attr_name=self.attr_name)\n\n return data\n\n# Path: torch_geometric/transforms/add_remaining_self_loops.py\nfrom typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import add_remaining_self_loops\n\n\n@functional_transform('add_remaining_self_loops')\nclass AddRemainingSelfLoops(BaseTransform):\n r\"\"\"Adds remaining self-loops to the given homogeneous or heterogeneous\n graph (functional name: :obj:`add_remaining_self_loops`).\n\n Args:\n attr (str, optional): The name of the attribute of edge weights\n or multi-dimensional edge features to pass to\n :meth:`torch_geometric.utils.add_remaining_self_loops`.\n (default: :obj:`\"edge_weight\"`)\n fill_value (float or Tensor or str, optional): The way to generate\n edge features of self-loops (in case :obj:`attr != None`).\n If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n self-loops will be directly given by :obj:`fill_value`.\n If given as :obj:`str`, edge features of self-loops are computed by\n aggregating all features of edges that point to the specific node,\n according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). 
(default: :obj:`1.`)\n \"\"\"\n def __init__(\n self,\n attr: str = 'edge_weight',\n fill_value: Union[float, Tensor, str] = 1.0,\n ):\n self.attr = attr\n self.fill_value = fill_value\n\n def forward(\n self,\n data: Union[Data, HeteroData],\n ) -> Union[Data, HeteroData]:\n for store in data.edge_stores:\n if store.is_bipartite() or 'edge_index' not in store:\n continue\n\n store.edge_index, store[self.attr] = add_remaining_self_loops(\n store.edge_index,\n edge_attr=store.get(self.attr, None),\n fill_value=self.fill_value,\n num_nodes=store.size(0),\n )\n\n return data\n\n# Path: torch_geometric/transforms/add_self_loops.py\nfrom typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import add_self_loops\n\n\n@functional_transform('add_self_loops')\nclass AddSelfLoops(BaseTransform):\n r\"\"\"Adds self-loops to the given homogeneous or heterogeneous graph\n (functional name: :obj:`add_self_loops`).\n\n Args:\n attr (str, optional): The name of the attribute of edge weights\n or multi-dimensional edge features to pass to\n :meth:`torch_geometric.utils.add_self_loops`.\n (default: :obj:`\"edge_weight\"`)\n fill_value (float or Tensor or str, optional): The way to generate\n edge features of self-loops (in case :obj:`attr != None`).\n If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n self-loops will be directly given by :obj:`fill_value`.\n If given as :obj:`str`, edge features of self-loops are computed by\n aggregating all features of edges that point to the specific node,\n according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`1.`)\n \"\"\"\n def __init__(\n self,\n attr: str = 'edge_weight',\n fill_value: Union[float, Tensor, str] = 1.0,\n ) -> None:\n self.attr = attr\n self.fill_value = fill_value\n\n def forward(\n self,\n data: Union[Data, HeteroData],\n ) -> Union[Data, HeteroData]:\n for store in data.edge_stores:\n if store.is_bipartite() or 'edge_index' not in store:\n continue\n\n store.edge_index, store[self.attr] = add_self_loops(\n store.edge_index,\n edge_attr=store.get(self.attr, None),\n fill_value=self.fill_value,\n num_nodes=store.size(0),\n )\n\n return data\n\n# Path: torch_geometric/transforms/base_transform.py\nimport copy\nfrom abc import ABC\nfrom typing import Any\n\n\nclass BaseTransform(ABC):\n r\"\"\"An abstract base class for writing transforms.\n\n Transforms are a general way to modify and customize\n :class:`~torch_geometric.data.Data` or\n :class:`~torch_geometric.data.HeteroData` objects, either by implicitly\n passing them as an argument to a :class:`~torch_geometric.data.Dataset`, or\n by applying them explicitly to individual\n :class:`~torch_geometric.data.Data` or\n :class:`~torch_geometric.data.HeteroData` objects:\n\n .. 
code-block:: python\n\n import torch_geometric.transforms as T\n from torch_geometric.datasets import TUDataset\n\n transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()])\n\n dataset = TUDataset(path, name='MUTAG', transform=transform)\n data = dataset[0] # Implicitly transform data on every access.\n\n data = TUDataset(path, name='MUTAG')[0]\n data = transform(data) # Explicitly transform data.\n \"\"\"\n def __call__(self, data: Any) -> Any:\n # Shallow-copy the data so that we prevent in-place data modification.\n return self.forward(copy.copy(data))\n\n def forward(self, data: Any) -> Any:\n pass\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}()'\n\n# Path: torch_geometric/transforms/cartesian.py\nfrom typing import Optional, Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('cartesian')\nclass Cartesian(BaseTransform):\n r\"\"\"Saves the relative Cartesian coordinates of linked nodes in its edge\n attributes (functional name: :obj:`cartesian`). Each coordinate gets\n globally normalized to a specified interval (:math:`[0, 1]` by default).\n\n Args:\n norm (bool, optional): If set to :obj:`False`, the output will not be\n normalized. (default: :obj:`True`)\n max_value (float, optional): If set and :obj:`norm=True`, normalization\n will be performed based on this value instead of the maximum value\n found in the data. (default: :obj:`None`)\n cat (bool, optional): If set to :obj:`False`, all existing edge\n attributes will be replaced. (default: :obj:`True`)\n interval ((float, float), optional): A tuple specifying the lower and\n upper bound for normalization. (default: :obj:`(0.0, 1.0)`)\n \"\"\"\n def __init__(\n self,\n norm: bool = True,\n max_value: Optional[float] = None,\n cat: bool = True,\n interval: Tuple[float, float] = (0.0, 1.0),\n ):\n self.norm = norm\n self.max = max_value\n self.cat = cat\n self.interval = interval\n\n def forward(self, data: Data) -> Data:\n assert data.pos is not None\n assert data.edge_index is not None\n (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr\n\n cart = pos[row] - pos[col]\n cart = cart.view(-1, 1) if cart.dim() == 1 else cart\n\n if self.norm and cart.numel() > 0:\n max_val = float(cart.abs().max()) if self.max is None else self.max\n\n length = self.interval[1] - self.interval[0]\n center = (self.interval[0] + self.interval[1]) / 2\n cart = length * cart / (2 * max_val) + center\n\n if pseudo is not None and self.cat:\n pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)\n else:\n data.edge_attr = cart\n\n return data\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}(norm={self.norm}, '\n f'max_value={self.max})')\n\n# Path: torch_geometric/transforms/center.py\nfrom typing import Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('center')\nclass Center(BaseTransform):\n r\"\"\"Centers node positions :obj:`data.pos` around the origin\n (functional name: :obj:`center`).\n \"\"\"\n def forward(\n self,\n data: Union[Data, HeteroData],\n ) -> Union[Data, HeteroData]:\n for store in data.node_stores:\n if hasattr(store, 'pos'):\n store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True)\n 
return data\n\n# Path: torch_geometric/transforms/compose.py\nfrom typing import Callable, List, Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import BaseTransform\n\n\nclass Compose(BaseTransform):\n r\"\"\"Composes several transforms together.\n\n Args:\n transforms (List[Callable]): List of transforms to compose.\n \"\"\"\n def __init__(self, transforms: List[Callable]):\n self.transforms = transforms\n\n def forward(\n self,\n data: Union[Data, HeteroData],\n ) -> Union[Data, HeteroData]:\n for transform in self.transforms:\n if isinstance(data, (list, tuple)):\n data = [transform(d) for d in data]\n else:\n data = transform(data)\n return data\n\n def __repr__(self) -> str:\n args = [f' {transform}' for transform in self.transforms]\n return '{}([\\n{}\\n])'.format(self.__class__.__name__, ',\\n'.join(args))\n\n\nclass ComposeFilters:\n r\"\"\"Composes several filters together.\n\n Args:\n filters (List[Callable]): List of filters to compose.\n \"\"\"\n def __init__(self, filters: List[Callable]):\n self.filters = filters\n\n def __call__(\n self,\n data: Union[Data, HeteroData],\n ) -> bool:\n for filter_fn in self.filters:\n if isinstance(data, (list, tuple)):\n if not all([filter_fn(d) for d in data]):\n return False\n elif not filter_fn(data):\n return False\n return True\n\n def __repr__(self) -> str:\n args = [f' {filter_fn}' for filter_fn in self.filters]\n return '{}([\\n{}\\n])'.format(self.__class__.__name__, ',\\n'.join(args))\n\n# Path: torch_geometric/transforms/constant.py\nfrom typing import List, Optional, Union\n\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('constant')\nclass Constant(BaseTransform):\n r\"\"\"Appends a constant value to each node feature :obj:`x`\n (functional name: :obj:`constant`).\n\n Args:\n value (float, optional): The value to add. (default: :obj:`1.0`)\n cat (bool, optional): If set to :obj:`False`, existing node features\n will be replaced. (default: :obj:`True`)\n node_types (str or List[str], optional): The specified node type(s) to\n append constant values for if used on heterogeneous graphs.\n If set to :obj:`None`, constants will be added to each node feature\n :obj:`x` for all existing node types. 
(default: :obj:`None`)\n \"\"\"\n def __init__(\n self,\n value: float = 1.0,\n cat: bool = True,\n node_types: Optional[Union[str, List[str]]] = None,\n ):\n if isinstance(node_types, str):\n node_types = [node_types]\n\n self.value = value\n self.cat = cat\n self.node_types = node_types\n\n def forward(\n self,\n data: Union[Data, HeteroData],\n ) -> Union[Data, HeteroData]:\n\n for store in data.node_stores:\n if self.node_types is None or store._key in self.node_types:\n num_nodes = store.num_nodes\n assert num_nodes is not None\n c = torch.full((num_nodes, 1), self.value, dtype=torch.float)\n\n if hasattr(store, 'x') and self.cat:\n x = store.x.view(-1, 1) if store.x.dim() == 1 else store.x\n store.x = torch.cat([x, c.to(x.device, x.dtype)], dim=-1)\n else:\n store.x = c\n\n return data\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}(value={self.value})'\n\n# Path: torch_geometric/transforms/delaunay.py\nimport scipy.spatial\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('delaunay')\nclass Delaunay(BaseTransform):\n r\"\"\"Computes the delaunay triangulation of a set of points\n (functional name: :obj:`delaunay`).\n...\n", "needle_token_start": 3903, "needle_token_end": 4135, "code_context_ntokens": 16115, "output": ["```python\ndef read_from_memlab(line_profiler: Any) -> List[float]: # pragma: no cover\n from pytorch_memlab.line_profiler.line_records import LineRecords\n\n # See: https://pytorch.org/docs/stable/cuda.html#torch.cuda.memory_stats\n\n track_stats = [ # Different statistic can be collected as needed.\n 'allocated_bytes.all.peak',\n 'reserved_bytes.all.peak',\n 'active_bytes.all.peak',\n ]\n\n records = LineRecords(line_profiler._raw_line_records,\n line_profiler._code_infos)\n stats = records.display(None, track_stats)._line_records\n return [byte_to_megabyte(x) for x in stats.values.max(axis=0).tolist()]\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "num_node_labels", "language": "python", "path": "torch_geometric/datasets/tu_dataset.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To retrieve the total number of distinct labels that can be assigned to nodes within the graphs of a dataset.\n2. **Input**: This method does not require any external inputs as it operates on the internal state of the dataset object.\n3. **Output**: Returns an integer representing the count of unique node labels available in the dataset.\n4. 
**Procedure**: The method accesses a pre-stored dictionary that contains various dataset properties, including the count of unique node labels, and returns the value associated with this specific property.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/datasets/snap_dataset.py\nimport os\nimport os.path as osp\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport fsspec\nimport numpy as np\nimport torch\n\n...\n# Path: torch_geometric/datasets/suite_sparse.py\nimport os.path as osp\nfrom typing import Callable, Optional\n\nimport fsspec\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs\n\n\nclass SuiteSparseMatrixCollection(InMemoryDataset):\n r\"\"\"A suite of sparse matrix benchmarks known as the `Suite Sparse Matrix\n Collection `_ collected from a wide range of\n applications.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n group (str): The group of the sparse matrix.\n name (str): The name of the sparse matrix.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n\n url = 'https://sparse.tamu.edu/mat/{}/{}.mat'\n\n def __init__(\n self,\n root: str,\n group: str,\n name: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n self.group = group\n self.name = name\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_dir(self) -> str:\n return osp.join(self.root, self.group, self.name, 'raw')\n\n @property\n def processed_dir(self) -> str:\n return osp.join(self.root, self.group, self.name, 'processed')\n\n @property\n def raw_file_names(self) -> str:\n return f'{self.name}.mat'\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n fs.cp(self.url.format(self.group, self.name), self.raw_dir)\n\n def process(self) -> None:\n from scipy.io import loadmat\n\n with fsspec.open(self.raw_paths[0], 'rb') as f:\n mat = loadmat(f)['Problem'][0][0][2].tocsr().tocoo()\n\n row = torch.from_numpy(mat.row).to(torch.long)\n col = torch.from_numpy(mat.col).to(torch.long)\n edge_index = torch.stack([row, col], dim=0)\n\n value = torch.from_numpy(mat.data).to(torch.float)\n edge_attr = None if torch.all(value == 1.0) else value\n\n size: Optional[torch.Size] = torch.Size(mat.shape)\n if mat.shape[0] == mat.shape[1]:\n size = None\n\n num_nodes = mat.shape[0]\n\n data = Data(edge_index=edge_index, edge_attr=edge_attr, size=size,\n num_nodes=num_nodes)\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n def __repr__(self) -> str:\n return 
(f'{self.__class__.__name__}(group={self.group}, '\n f'name={self.name})')\n\n# Path: torch_geometric/datasets/taobao.py\nimport os\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n HeteroData,\n InMemoryDataset,\n download_url,\n extract_zip,\n)\n\n\nclass Taobao(InMemoryDataset):\n r\"\"\"Taobao is a dataset of user behaviors from Taobao offered by Alibaba,\n provided by the `Tianchi Alicloud platform\n `_.\n\n Taobao is a heterogeneous graph for recommendation.\n Nodes represent users with user IDs, items with item IDs, and categories\n with category ID.\n Edges between users and items represent different types of user behaviors\n towards items (alongside with timestamps).\n Edges between items and categories assign each item to its set of\n categories.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.HeteroData` object and returns a\n transformed version. The data object will be transformed before\n every access. (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.HeteroData` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n \"\"\"\n url = ('https://alicloud-dev.oss-cn-hangzhou.aliyuncs.com/'\n 'UserBehavior.csv.zip')\n\n def __init__(\n self,\n root: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0], data_cls=HeteroData)\n\n @property\n def raw_file_names(self) -> str:\n return 'UserBehavior.csv'\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n path = download_url(self.url, self.raw_dir)\n extract_zip(path, self.raw_dir)\n os.remove(path)\n\n def process(self) -> None:\n import pandas as pd\n\n cols = ['userId', 'itemId', 'categoryId', 'behaviorType', 'timestamp']\n df = pd.read_csv(self.raw_paths[0], names=cols)\n\n # Time representation (YYYY.MM.DD-HH:MM:SS -> Integer)\n # start: 1511539200 = 2017.11.25-00:00:00\n # end: 1512316799 = 2017.12.03-23:59:59\n start = 1511539200\n end = 1512316799\n df = df[(df[\"timestamp\"] >= start) & (df[\"timestamp\"] <= end)]\n\n df = df.drop_duplicates()\n\n behavior_dict = {'pv': 0, 'cart': 1, 'buy': 2, 'fav': 3}\n df['behaviorType'] = df['behaviorType'].map(behavior_dict)\n\n num_entries = {}\n for name in ['userId', 'itemId', 'categoryId']:\n # Map IDs to consecutive integers:\n value, df[name] = np.unique(df[[name]].values, return_inverse=True)\n num_entries[name] = value.shape[0]\n\n data = HeteroData()\n\n data['user'].num_nodes = num_entries['userId']\n data['item'].num_nodes = num_entries['itemId']\n data['category'].num_nodes = num_entries['categoryId']\n\n row = torch.from_numpy(df['userId'].values)\n col = torch.from_numpy(df['itemId'].values)\n data['user', 'item'].edge_index = torch.stack([row, col], dim=0)\n data['user', 'item'].time = torch.from_numpy(df['timestamp'].values)\n behavior = torch.from_numpy(df['behaviorType'].values)\n data['user', 'item'].behavior = behavior\n\n df = df[['itemId', 'categoryId']].drop_duplicates()\n row = 
torch.from_numpy(df['itemId'].values)\n col = torch.from_numpy(df['categoryId'].values)\n data['item', 'category'].edge_index = torch.stack([row, col], dim=0)\n\n data = data if self.pre_transform is None else self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n# Path: torch_geometric/datasets/tosca.py\nimport glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n Data,\n InMemoryDataset,\n download_url,\n extract_zip,\n)\nfrom torch_geometric.io import read_txt_array\n\n\nclass TOSCA(InMemoryDataset):\n r\"\"\"The TOSCA dataset from the `\"Numerical Geometry of Non-Ridig Shapes\"\n `_ book, containing 80 meshes.\n Meshes within the same category have the same triangulation and an equal\n number of vertices numbered in a compatible way.\n\n .. note::\n\n Data objects hold mesh faces instead of edge indices.\n To convert the mesh to a graph, use the\n :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n To convert the mesh to a point cloud, use the\n :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n sample a fixed number of points on the mesh faces according to their\n face area.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n categories (list, optional): List of categories to include in the\n dataset. Can include the categories :obj:`\"Cat\"`, :obj:`\"Centaur\"`,\n :obj:`\"David\"`, :obj:`\"Dog\"`, :obj:`\"Gorilla\"`, :obj:`\"Horse\"`,\n :obj:`\"Michael\"`, :obj:`\"Victoria\"`, :obj:`\"Wolf\"`. If set to\n :obj:`None`, the dataset will contain all categories. (default:\n :obj:`None`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n pre_filter (callable, optional): A function that takes in an\n :obj:`torch_geometric.data.Data` object and returns a boolean\n value, indicating whether the data object should be included in the\n final dataset. 
(default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n\n url = 'http://tosca.cs.technion.ac.il/data/toscahires-asci.zip'\n\n categories = [\n 'cat', 'centaur', 'david', 'dog', 'gorilla', 'horse', 'michael',\n 'victoria', 'wolf'\n ]\n\n def __init__(\n self,\n root: str,\n categories: Optional[List[str]] = None,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n pre_filter: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n categories = self.categories if categories is None else categories\n categories = [cat.lower() for cat in categories]\n for cat in categories:\n assert cat in self.categories\n self.categories = categories\n super().__init__(root, transform, pre_transform, pre_filter,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_file_names(self) -> List[str]:\n return ['cat0.vert', 'cat0.tri']\n\n @property\n def processed_file_names(self) -> str:\n name = '_'.join([cat[:2] for cat in self.categories])\n return f'{name}.pt'\n\n def download(self) -> None:\n path = download_url(self.url, self.raw_dir)\n extract_zip(path, self.raw_dir)\n os.unlink(path)\n\n def process(self) -> None:\n data_list = []\n for cat in self.categories:\n paths = glob.glob(osp.join(self.raw_dir, f'{cat}*.tri'))\n paths = [path[:-4] for path in paths]\n paths = sorted(paths, key=lambda e: (len(e), e))\n\n for path in paths:\n pos = read_txt_array(f'{path}.vert')\n face = read_txt_array(f'{path}.tri', dtype=torch.long)\n face = face - face.min() # Ensure zero-based index.\n data = Data(pos=pos, face=face.t().contiguous())\n if self.pre_filter is not None and not self.pre_filter(data):\n continue\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n data_list.append(data)\n\n self.save(data_list, self.processed_paths[0])\n\n# Path: torch_geometric/datasets/tu_dataset.py\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs, read_tu_data\n\n\nclass TUDataset(InMemoryDataset):\n r\"\"\"A variety of graph kernel benchmark datasets, *.e.g.*,\n :obj:`\"IMDB-BINARY\"`, :obj:`\"REDDIT-BINARY\"` or :obj:`\"PROTEINS\"`,\n collected from the `TU Dortmund University\n `_.\n In addition, this dataset wrapper provides `cleaned dataset versions\n `_ as motivated by the\n `\"Understanding Isomorphism Bias in Graph Data Sets\"\n `_ paper, containing only non-isomorphic\n graphs.\n\n .. note::\n Some datasets may not come with any node labels.\n You can then either make use of the argument :obj:`use_node_attr`\n to load additional continuous node attributes (if present) or provide\n synthetic node features using transforms such as\n :class:`torch_geometric.transforms.Constant` or\n :class:`torch_geometric.transforms.OneHotDegree`.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n name (str): The `name\n `_ of the\n dataset.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. 
(default: :obj:`None`)\n pre_filter (callable, optional): A function that takes in an\n :obj:`torch_geometric.data.Data` object and returns a boolean\n value, indicating whether the data object should be included in the\n final dataset. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n use_node_attr (bool, optional): If :obj:`True`, the dataset will\n contain additional continuous node attributes (if present).\n (default: :obj:`False`)\n use_edge_attr (bool, optional): If :obj:`True`, the dataset will\n contain additional continuous edge attributes (if present).\n (default: :obj:`False`)\n cleaned (bool, optional): If :obj:`True`, the dataset will\n contain only non-isomorphic graphs. (default: :obj:`False`)\n\n **STATS:**\n\n .. list-table::\n :widths: 20 10 10 10 10 10\n :header-rows: 1\n\n * - Name\n - #graphs\n - #nodes\n - #edges\n - #features\n - #classes\n * - MUTAG\n - 188\n - ~17.9\n - ~39.6\n - 7\n - 2\n * - ENZYMES\n - 600\n - ~32.6\n - ~124.3\n - 3\n - 6\n * - PROTEINS\n - 1,113\n - ~39.1\n - ~145.6\n - 3\n - 2\n * - COLLAB\n - 5,000\n - ~74.5\n - ~4914.4\n - 0\n - 3\n * - IMDB-BINARY\n - 1,000\n - ~19.8\n - ~193.1\n - 0\n - 2\n * - REDDIT-BINARY\n - 2,000\n - ~429.6\n - ~995.5\n - 0\n - 2\n * - ...\n -\n -\n -\n -\n -\n \"\"\"\n\n url = 'https://www.chrsmrrs.com/graphkerneldatasets'\n cleaned_url = ('https://raw.githubusercontent.com/nd7141/'\n 'graph_datasets/master/datasets')\n\n def __init__(\n self,\n root: str,\n name: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n pre_filter: Optional[Callable] = None,\n force_reload: bool = False,\n use_node_attr: bool = False,\n use_edge_attr: bool = False,\n cleaned: bool = False,\n ) -> None:\n self.name = name\n self.cleaned = cleaned\n super().__init__(root, transform, pre_transform, pre_filter,\n force_reload=force_reload)\n\n out = fs.torch_load(self.processed_paths[0])\n if not isinstance(out, tuple) or len(out) < 3:\n raise RuntimeError(\n \"The 'data' object was created by an older version of PyG. 
\"\n \"If this error occurred while loading an already existing \"\n \"dataset, remove the 'processed/' directory in the dataset's \"\n \"root folder and try again.\")\n assert len(out) == 3 or len(out) == 4\n\n if len(out) == 3: # Backward compatibility.\n data, self.slices, self.sizes = out\n data_cls = Data\n else:\n data, self.slices, self.sizes, data_cls = out\n\n if not isinstance(data, dict): # Backward compatibility.\n self.data = data\n else:\n self.data = data_cls.from_dict(data)\n\n assert isinstance(self._data, Data)\n if self._data.x is not None and not use_node_attr:\n num_node_attributes = self.num_node_attributes\n self._data.x = self._data.x[:, num_node_attributes:]\n if self._data.edge_attr is not None and not use_edge_attr:\n num_edge_attrs = self.num_edge_attributes\n self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:]\n\n @property\n def raw_dir(self) -> str:\n name = f'raw{\"_cleaned\" if self.cleaned else \"\"}'\n return osp.join(self.root, self.name, name)\n\n @property\n def processed_dir(self) -> str:\n name = f'processed{\"_cleaned\" if self.cleaned else \"\"}'\n return osp.join(self.root, self.name, name)\n\n @property\n \ndef num_node_labels(self) -> int:\n return self.sizes['num_node_labels']\n\n @property\n def num_node_attributes(self) -> int:\n return self.sizes['num_node_attributes']\n\n @property\n def num_edge_labels(self) -> int:\n return self.sizes['num_edge_labels']\n\n @property\n def num_edge_attributes(self) -> int:\n return self.sizes['num_edge_attributes']\n\n @property\n def raw_file_names(self) -> List[str]:\n names = ['A', 'graph_indicator']\n return [f'{self.name}_{name}.txt' for name in names]\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n url = self.cleaned_url if self.cleaned else self.url\n fs.cp(f'{url}/{self.name}.zip', self.raw_dir, extract=True)\n for filename in fs.ls(osp.join(self.raw_dir, self.name)):\n fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))\n fs.rm(osp.join(self.raw_dir, self.name))\n\n def process(self) -> None:\n self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)\n\n if self.pre_filter is not None or self.pre_transform is not None:\n data_list = [self.get(idx) for idx in range(len(self))]\n\n if self.pre_filter is not None:\n data_list = [d for d in data_list if self.pre_filter(d)]\n\n if self.pre_transform is not None:\n data_list = [self.pre_transform(d) for d in data_list]\n\n self.data, self.slices = self.collate(data_list)\n self._data_list = None # Reset cache.\n\n assert isinstance(self._data, Data)\n fs.torch_save(\n (self._data.to_dict(), self.slices, sizes, self._data.__class__),\n self.processed_paths[0],\n )\n\n def __repr__(self) -> str:\n return f'{self.name}({len(self)})'\n\n# Path: torch_geometric/datasets/twitch.py\nimport os.path as osp\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass Twitch(InMemoryDataset):\n r\"\"\"The Twitch Gamer networks introduced in the\n `\"Multi-scale Attributed Node Embedding\"\n `_ paper.\n Nodes represent gamers on Twitch and edges are followerships between them.\n Node features represent embeddings of games played by the Twitch users.\n The task is to predict whether a user streams mature content.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n name (str): The name of the dataset (:obj:`\"DE\"`, :obj:`\"EN\"`,\n :obj:`\"ES\"`, 
:obj:`\"FR\"`, :obj:`\"PT\"`, :obj:`\"RU\"`).\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n **STATS:**\n\n .. list-table::\n :widths: 10 10 10 10 10\n :header-rows: 1\n\n * - Name\n - #nodes\n - #edges\n - #features\n - #classes\n * - DE\n - 9,498\n - 315,774\n - 128\n - 2\n * - EN\n - 7,126\n - 77,774\n - 128\n - 2\n * - ES\n - 4,648\n - 123,412\n - 128\n - 2\n * - FR\n - 6,551\n - 231,883\n - 128\n - 2\n * - PT\n - 1,912\n - 64,510\n - 128\n - 2\n * - RU\n - 4,385\n - 78,993\n - 128\n - 2\n \"\"\"\n\n url = 'https://graphmining.ai/datasets/ptg/twitch'\n\n def __init__(\n self,\n root: str,\n name: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n self.name = name\n assert self.name in ['DE', 'EN', 'ES', 'FR', 'PT', 'RU']\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_dir(self) -> str:\n return osp.join(self.root, self.name, 'raw')\n\n @property\n def processed_dir(self) -> str:\n return osp.join(self.root, self.name, 'processed')\n\n @property\n def raw_file_names(self) -> str:\n return f'{self.name}.npz'\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n download_url(f'{self.url}/{self.name}.npz', self.raw_dir)\n\n def process(self) -> None:\n data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n x = torch.from_numpy(data['features']).to(torch.float)\n y = torch.from_numpy(data['target']).to(torch.long)\n\n edge_index = torch.from_numpy(data['edges']).to(torch.long)\n edge_index = edge_index.t().contiguous()\n\n data = Data(x=x, y=y, edge_index=edge_index)\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n# Path: torch_geometric/datasets/upfd.py\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\n\nfrom torch_geometric.data import (\n Data,\n InMemoryDataset,\n download_google_url,\n extract_zip,\n)\nfrom torch_geometric.io import read_txt_array\nfrom torch_geometric.utils import coalesce, cumsum\n\n\nclass UPFD(InMemoryDataset):\n r\"\"\"The tree-structured fake news propagation graph classification dataset\n from the `\"User Preference-aware Fake News Detection\"\n `_ paper.\n It includes two sets of tree-structured fake & real news propagation graphs\n extracted from Twitter.\n For a single graph, the root node represents the source news, and leaf\n nodes represent Twitter users who retweeted the same root news.\n A user node has an edge to the news node if and only if the user retweeted\n the root news directly.\n Two user nodes have an edge if and only if one user retweeted the root news\n from the other user.\n Four different node features are encoded using different encoders.\n Please refer to `GNN-FakeNews\n `_ repo for more details.\n\n .. 
note::\n\n For an example of using UPFD, see `examples/upfd.py\n `_.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n name (str): The name of the graph set (:obj:`\"politifact\"`,\n :obj:`\"gossipcop\"`).\n feature (str): The node feature type (:obj:`\"profile\"`, :obj:`\"spacy\"`,\n :obj:`\"bert\"`, :obj:`\"content\"`).\n If set to :obj:`\"profile\"`, the 10-dimensional node feature\n is composed of ten Twitter user profile attributes.\n If set to :obj:`\"spacy\"`, the 300-dimensional node feature is\n composed of Twitter user historical tweets encoded by\n the `spaCy word2vec encoder\n `_.\n If set to :obj:`\"bert\"`, the 768-dimensional node feature is\n composed of Twitter user historical tweets encoded by the\n `bert-as-service `_.\n If set to :obj:`\"content\"`, the 310-dimensional node feature is\n composed of a 300-dimensional \"spacy\" vector plus a\n 10-dimensional \"profile\" vector.\n split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n If :obj:`\"val\"`, loads the validation dataset.\n If :obj:`\"test\"`, loads the test dataset.\n (default: :obj:`\"train\"`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n pre_filter (callable, optional): A function that takes in an\n :obj:`torch_geometric.data.Data` object and returns a boolean\n value, indicating whether the data object should be included in the\n final dataset. 
(default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n file_ids = {\n 'politifact': '1KOmSrlGcC50PjkvRVbyb_WoWHVql06J-',\n 'gossipcop': '1VskhAQ92PrT4sWEKQ2v2-AJhEcpp4A81',\n }\n\n def __init__(\n self,\n root: str,\n name: str,\n feature: str,\n split: str = \"train\",\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n pre_filter: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n assert name in ['politifact', 'gossipcop']\n assert split in ['train', 'val', 'test']\n\n self.root = root\n self.name = name\n self.feature = feature\n\n super().__init__(root, transform, pre_transform, pre_filter,\n force_reload=force_reload)\n\n path = self.processed_paths[['train', 'val', 'test'].index(split)]\n self.load(path)\n\n @property\n def raw_dir(self) -> str:\n return osp.join(self.root, self.name, 'raw')\n\n @property\n def processed_dir(self) -> str:\n return osp.join(self.root, self.name, 'processed', self.feature)\n\n @property\n def raw_file_names(self) -> List[str]:\n return [\n 'node_graph_id.npy', 'graph_labels.npy', 'A.txt', 'train_idx.npy',\n 'val_idx.npy', 'test_idx.npy', f'new_{self.feature}_feature.npz'\n ]\n\n @property\n def processed_file_names(self) -> List[str]:\n return ['train.pt', 'val.pt', 'test.pt']\n\n def download(self) -> None:\n id = self.file_ids[self.name]\n path = download_google_url(id, self.raw_dir, 'data.zip')\n extract_zip(path, self.raw_dir)\n os.remove(path)\n\n def process(self) -> None:\n x = sp.load_npz(\n osp.join(self.raw_dir, f'new_{self.feature}_feature.npz'))\n x = torch.from_numpy(x.todense()).to(torch.float)\n\n edge_index = read_txt_array(osp.join(self.raw_dir, 'A.txt'), sep=',',\n dtype=torch.long).t()\n edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n y = np.load(osp.join(self.raw_dir, 'graph_labels.npy'))\n y = torch.from_numpy(y).to(torch.long)\n _, y = y.unique(sorted=True, return_inverse=True)\n\n batch = np.load(osp.join(self.raw_dir, 'node_graph_id.npy'))\n batch = torch.from_numpy(batch).to(torch.long)\n\n node_slice = cumsum(batch.bincount())\n edge_slice = cumsum(batch[edge_index[0]].bincount())\n graph_slice = torch.arange(y.size(0) + 1)\n self.slices = {\n 'x': node_slice,\n 'edge_index': edge_slice,\n 'y': graph_slice\n }\n\n edge_index -= node_slice[batch[edge_index[0]]].view(1, -1)\n self.data = Data(x=x, edge_index=edge_index, y=y)\n\n for path, split in zip(self.processed_paths, ['train', 'val', 'test']):\n idx = np.load(osp.join(self.raw_dir, f'{split}_idx.npy')).tolist()\n data_list = [self.get(i) for i in idx]\n if self.pre_filter is not None:\n data_list = [d for d in data_list if self.pre_filter(d)]\n if self.pre_transform is not None:\n data_list = [self.pre_transform(d) for d in data_list]\n self.save(data_list, path)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({len(self)}, name={self.name}, '\n f'feature={self.feature})')\n\n# Path: torch_geometric/datasets/utils/cheatsheet.py\nimport importlib\nimport inspect\nimport re\nfrom typing import Any, List, Optional\n\n\ndef paper_link(cls: str) -> Optional[str]:\n cls = importlib.import_module('torch_geometric.datasets').__dict__[cls]\n doc = inspect.getdoc(cls)\n assert doc is not None\n match = re.search('<.+?>', doc, flags=re.DOTALL)\n return None if match is None else match.group().replace('\\n', ' ')[1:-1]\n\n\ndef get_stats_table(cls: str) -> str:\n cls = 
importlib.import_module('torch_geometric.datasets').__dict__[cls]\n doc = inspect.getdoc(cls)\n assert doc is not None\n match = re.search(r'\\*\\*STATS:\\*\\*\\n.*$', doc, flags=re.DOTALL)\n return '' if match is None else match.group()\n\n\ndef has_stats(cls: str) -> bool:\n return len(get_stats_table(cls)) > 0\n\n\ndef get_type(cls: str) -> str:\n return 'Edge' if '-' in cls else 'Node'\n\n\ndef get_stat(cls: str, name: str, child: Optional[str] = None,\n default: Any = None) -> str:\n if child is None and len(get_children(cls)) > 0:\n return ''\n\n stats_table = get_stats_table(cls)\n\n if len(stats_table) > 0:\n stats_table = '\\n'.join(stats_table.split('\\n')[2:])\n\n match = re.search(f'^.*- {name}', stats_table, flags=re.DOTALL)\n if match is None:\n return default\n\n column = match.group().count(' -')\n\n if child is not None:\n child = child.replace('(', r'\\(').replace(')', r'\\)')\n match = re.search(f'[*] - {child}\\n.*$', stats_table, flags=re.DOTALL)\n assert match is not None\n stats_row = match.group()\n else:\n stats_row = '*' + stats_table.split('*')[2]\n\n return stats_row.split(' -')[column].split('\\n')[0].strip()\n\n\ndef get_children(cls: str) -> List[str]:\n matches = re.findall('[*] -.*', get_stats_table(cls))\n return [match[4:] for match in matches[1:]] if len(matches) > 2 else []\n\n# Path: torch_geometric/datasets/utils/__init__.py\nfrom .cheatsheet import paper_link, has_stats, get_stat, get_children, get_type\n\n__all__ = [\n 'paper_link',\n 'has_stats',\n 'get_stat',\n 'get_children',\n 'get_type',\n]\n\n# Path: torch_geometric/datasets/webkb.py\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import coalesce\n\n\nclass WebKB(InMemoryDataset):\n r\"\"\"The WebKB datasets used in the\n `\"Geom-GCN: Geometric Graph Convolutional Networks\"\n `_ paper.\n Nodes represent web pages and edges represent hyperlinks between them.\n Node features are the bag-of-words representation of web pages.\n The task is to classify the nodes into one of the five categories, student,\n project, course, staff, and faculty.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n name (str): The name of the dataset (:obj:`\"Cornell\"`, :obj:`\"Texas\"`,\n :obj:`\"Wisconsin\"`).\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n **STATS:**\n\n .. 
list-table::\n :widths: 10 10 10 10 10\n :header-rows: 1\n\n * - Name\n - #nodes\n - #edges\n - #features\n - #classes\n * - Cornell\n - 183\n - 298\n - 1,703\n - 5\n * - Texas\n - 183\n - 325\n - 1,703\n - 5\n * - Wisconsin\n - 251\n - 515\n - 1,703\n - 5\n \"\"\"\n\n url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'\n\n def __init__(\n self,\n root: str,\n name: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n self.name = name.lower()\n assert self.name in ['cornell', 'texas', 'wisconsin']\n\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_dir(self) -> str:\n return osp.join(self.root, self.name, 'raw')\n\n @property\n def processed_dir(self) -> str:\n return osp.join(self.root, self.name, 'processed')\n\n @property\n def raw_file_names(self) -> List[str]:\n out = ['out1_node_feature_label.txt', 'out1_graph_edges.txt']\n out += [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)]\n return out\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n for f in self.raw_file_names[:2]:\n download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir)\n for f in self.raw_file_names[2:]:\n download_url(f'{self.url}/splits/{f}', self.raw_dir)\n\n def process(self) -> None:\n with open(self.raw_paths[0], 'r') as f:\n lines = f.read().split('\\n')[1:-1]\n xs = [[float(value) for value in line.split('\\t')[1].split(',')]\n for line in lines]\n x = torch.tensor(xs, dtype=torch.float)\n\n ys = [int(line.split('\\t')[2]) for line in lines]\n y = torch.tensor(ys, dtype=torch.long)\n\n with open(self.raw_paths[1], 'r') as f:\n lines = f.read().split('\\n')[1:-1]\n edge_indices = [[int(value) for value in line.split('\\t')]\n for line in lines]\n edge_index = torch.tensor(edge_indices).t().contiguous()\n edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n train_masks, val_masks, test_masks = [], [], []\n for path in self.raw_paths[2:]:\n tmp = np.load(path)\n train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]\n val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]\n test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]\n train_mask = torch.stack(train_masks, dim=1)\n val_mask = torch.stack(val_masks, dim=1)\n test_mask = torch.stack(test_masks, dim=1)\n\n data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n val_mask=val_mask, test_mask=test_mask)\n data = data if self.pre_transform is None else self.pre_transform(data)\n self.save([data], self.processed_paths[0])\n\n def __repr__(self) -> str:\n return f'{self.name}()'\n\n# Path: torch_geometric/datasets/wikics.py\nimport json\nimport warnings\nfrom itertools import chain\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import to_undirected\n\n\nclass WikiCS(InMemoryDataset):\n r\"\"\"The semi-supervised Wikipedia-based dataset from the\n `\"Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks\"\n `_ paper, containing 11,701 nodes,\n 216,123 edges, 10 classes and 20 different training splits.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. 
The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n is_undirected (bool, optional): Whether the graph is undirected.\n (default: :obj:`True`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n\n url = 'https://github.com/pmernyei/wiki-cs-dataset/raw/master/dataset'\n\n def __init__(\n self,\n root: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n is_undirected: Optional[bool] = None,\n force_reload: bool = False,\n ) -> None:\n if is_undirected is None:\n warnings.warn(\n f\"The {self.__class__.__name__} dataset now returns an \"\n f\"undirected graph by default. Please explicitly specify \"\n f\"'is_undirected=False' to restore the old behavior.\")\n is_undirected = True\n self.is_undirected = is_undirected\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_file_names(self) -> List[str]:\n return ['data.json']\n\n @property\n def processed_file_names(self) -> str:\n return 'data_undirected.pt' if self.is_undirected else 'data.pt'\n\n def download(self) -> None:\n for name in self.raw_file_names:\n download_url(f'{self.url}/{name}', self.raw_dir)\n\n def process(self) -> None:\n with open(self.raw_paths[0], 'r') as f:\n data = json.load(f)\n\n x = torch.tensor(data['features'], dtype=torch.float)\n y = torch.tensor(data['labels'], dtype=torch.long)\n\n edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])]\n edges = list(chain(*edges)) # type: ignore\n edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n if self.is_undirected:\n edge_index = to_undirected(edge_index, num_nodes=x.size(0))\n\n train_mask = torch.tensor(data['train_masks'], dtype=torch.bool)\n train_mask = train_mask.t().contiguous()\n\n val_mask = torch.tensor(data['val_masks'], dtype=torch.bool)\n val_mask = val_mask.t().contiguous()\n\n test_mask = torch.tensor(data['test_mask'], dtype=torch.bool)\n\n stopping_mask = torch.tensor(data['stopping_masks'], dtype=torch.bool)\n stopping_mask = stopping_mask.t().contiguous()\n\n data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask,\n val_mask=val_mask, test_mask=test_mask,\n stopping_mask=stopping_mask)\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n# Path: torch_geometric/datasets/wikidata.py\nimport os\nimport os.path as osp\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n Data,\n InMemoryDataset,\n download_url,\n extract_tar,\n)\n\n\nclass Wikidata5M(InMemoryDataset):\n r\"\"\"The Wikidata-5M dataset from the `\"KEPLER: A Unified Model for\n Knowledge Embedding and Pre-trained Language Representation\"\n `_ paper,\n containing 4,594,485 entities, 822 relations,\n 20,614,279 train triples, 5,163 validation triples, and 5,133 test triples.\n\n `Wikidata-5M `_\n is a large-scale knowledge graph dataset with aligned corpus\n extracted form Wikidata.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n setting (str, optional):\n If :obj:`\"transductive\"`, loads the transductive dataset.\n If :obj:`\"inductive\"`, loads the 
inductive dataset.\n (default: :obj:`\"transductive\"`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n def __init__(\n self,\n root: str,\n setting: str = 'transductive',\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n if setting not in {'transductive', 'inductive'}:\n raise ValueError(f\"Invalid 'setting' argument (got '{setting}')\")\n\n self.setting = setting\n\n self.urls = [\n ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/'\n 'wikidata5m_text.txt.gz?dl=1'),\n 'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download',\n ]\n if self.setting == 'inductive':\n self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/'\n 'wikidata5m_inductive.tar.gz?dl=1')\n else:\n self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/'\n 'wikidata5m_transductive.tar.gz?dl=1')\n\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_file_names(self) -> List[str]:\n return [\n 'wikidata5m_text.txt.gz',\n 'download',\n f'wikidata5m_{self.setting}_train.txt',\n f'wikidata5m_{self.setting}_valid.txt',\n f'wikidata5m_{self.setting}_test.txt',\n ]\n\n @property\n def processed_file_names(self) -> str:\n return f'{self.setting}_data.pt'\n\n def download(self) -> None:\n for url in self.urls:\n download_url(url, self.raw_dir)\n path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz')\n extract_tar(path, self.raw_dir)\n os.remove(path)\n\n def process(self) -> None:\n import gzip\n\n entity_to_id: Dict[str, int] = {}\n with gzip.open(self.raw_paths[0], 'rt') as f:\n for i, line in enumerate(f):\n values = line.strip().split('\\t')\n entity_to_id[values[0]] = i\n\n x = torch.load(self.raw_paths[1])\n\n edge_indices = []\n edge_types = []\n split_indices = []\n\n rel_to_id: Dict[str, int] = {}\n for split, path in enumerate(self.raw_paths[2:]):\n with open(path, 'r') as f:\n for line in f:\n head, rel, tail = line[:-1].split('\\t')\n src = entity_to_id[head]\n dst = entity_to_id[tail]\n edge_indices.append([src, dst])\n if rel not in rel_to_id:\n rel_to_id[rel] = len(rel_to_id)\n edge_types.append(rel_to_id[rel])\n split_indices.append(split)\n\n edge_index = torch.tensor(edge_indices).t().contiguous()\n edge_type = torch.tensor(edge_types)\n split_index = torch.tensor(split_indices)\n\n data = Data(\n x=x,\n edge_index=edge_index,\n edge_type=edge_type,\n train_mask=split_index == 0,\n val_mask=split_index == 1,\n test_mask=split_index == 2,\n )\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n# Path: torch_geometric/datasets/wikipedia_network.py\nimport os.path as osp\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import coalesce\n\n\nclass WikipediaNetwork(InMemoryDataset):\n r\"\"\"The 
Wikipedia networks introduced in the\n `\"Multi-scale Attributed Node Embedding\"\n `_ paper.\n Nodes represent web pages and edges represent hyperlinks between them.\n Node features represent several informative nouns in the Wikipedia pages.\n The task is to predict the average daily traffic of the web page.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n name (str): The name of the dataset (:obj:`\"chameleon\"`,\n :obj:`\"crocodile\"`, :obj:`\"squirrel\"`).\n geom_gcn_preprocess (bool): If set to :obj:`True`, will load the\n pre-processed data as introduced in the `\"Geom-GCN: Geometric\n Graph Convolutional Networks\" _`,\n in which the average monthly traffic of the web page is converted\n into five categories to predict.\n If set to :obj:`True`, the dataset :obj:`\"crocodile\"` is not\n available.\n If set to :obj:`True`, train/validation/test splits will be\n available as masks for multiple splits with shape\n :obj:`[num_nodes, num_splits]`. (default: :obj:`True`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n \"\"\"\n\n raw_url = 'https://graphmining.ai/datasets/ptg/wiki'\n processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/'\n 'geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f')\n\n def __init__(\n self,\n root: str,\n name: str,\n geom_gcn_preprocess: bool = True,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n self.name = name.lower()\n self.geom_gcn_preprocess = geom_gcn_preprocess\n assert self.name in ['chameleon', 'crocodile', 'squirrel']\n if geom_gcn_preprocess and self.name == 'crocodile':\n raise AttributeError(\"The dataset 'crocodile' is not available in \"\n \"case 'geom_gcn_preprocess=True'\")\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_dir(self) -> str:\n if self.geom_gcn_preprocess:\n return osp.join(self.root, self.name, 'geom_gcn', 'raw')\n else:\n return osp.join(self.root, self.name, 'raw')\n\n @property\n def processed_dir(self) -> str:\n if self.geom_gcn_preprocess:\n return osp.join(self.root, self.name, 'geom_gcn', 'processed')\n else:\n return osp.join(self.root, self.name, 'processed')\n\n @property\n def raw_file_names(self) -> Union[List[str], str]:\n if self.geom_gcn_preprocess:\n return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] +\n [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)])\n else:\n return f'{self.name}.npz'\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n if self.geom_gcn_preprocess:\n for filename in self.raw_file_names[:2]:\n...\n", "needle_token_start": 5629, "needle_token_end": 5655, "code_context_ntokens": 16113, "output": ["```python\ndef num_node_labels(self) -> int:\n return self.sizes['num_node_labels']\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "__collate__", "language": "python", "path": 
"torch_geometric/loader/shadow.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to generate mini-batches of subgraphs from a larger graph, focusing on a localized neighborhood around specified nodes. This is particularly useful for training graph neural networks on sections of a graph to reduce computational complexity and focus learning on local structures.\n2. **Input**: A list of node indices for which localized subgraphs are to be generated.\n3. **Output**: A batch object containing the subgraphs corresponding to the input node indices. This batch includes the adjacency information, node features, and any other relevant graph data restricted to the subgraph.\n4. **Procedure**: \n - Convert the input node indices into a tensor if they are not already.\n - Extract the adjacency matrix for the specified nodes using a k-hop sampling method, which considers a specified number of neighbor nodes across a defined number of hops.\n - Construct a new sparse tensor representing the adjacency matrix of the subgraph.\n - Populate a batch object with the subgraph's data, including node features and edge attributes, ensuring that only data relevant to the subgraph's nodes and edges is included.\n - Return the populated batch object, ready for use in graph neural network training or evaluation.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/loader/neighbor_loader.py\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.loader.node_loader import NodeLoader\nfrom torch_geometric.sampler import NeighborSampler\nfrom torch_geometric.sampler.base import SubgraphType\nfrom torch_geometric.typing import EdgeType, InputNodes, OptTensor\n\n\nclass NeighborLoader(NodeLoader):\n r\"\"\"A data loader that performs neighbor sampling as introduced in the\n `\"Inductive Representation Learning on Large Graphs\"\n `_ paper.\n This loader allows for mini-batch training of GNNs on large-scale graphs\n where full-batch training is not feasible.\n\n More specifically, :obj:`num_neighbors` denotes how much neighbors are\n sampled for each node in each iteration.\n :class:`~torch_geometric.loader.NeighborLoader` takes in this list of\n :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for\n each node involved in iteration :obj:`i - 1`.\n\n Sampled nodes are sorted based on the order in which they were sampled.\n In particular, the first :obj:`batch_size` nodes represent the set of\n original mini-batch nodes.\n\n .. 
code-block:: python\n\n from torch_geometric.datasets import Planetoid\n from torch_geometric.loader import NeighborLoader\n\n data = Planetoid(path, name='Cora')[0]\n\n loader = NeighborLoader(\n data,\n # Sample 30 neighbors for each node for 2 iterations\n num_neighbors=[30] * 2,\n # Use a batch size of 128 for sampling training nodes\n batch_size=128,\n input_nodes=data.train_mask,\n )\n\n sampled_data = next(iter(loader))\n print(sampled_data.batch_size)\n >>> 128\n\n By default, the data loader will only include the edges that were\n originally sampled (:obj:`directed = True`).\n This option should only be used in case the number of hops is equivalent to\n the number of GNN layers.\n In case the number of GNN layers is greater than the number of hops,\n consider setting :obj:`directed = False`, which will include all edges\n between all sampled nodes (but is slightly slower as a result).\n\n Furthermore, :class:`~torch_geometric.loader.NeighborLoader` works for both\n **homogeneous** graphs stored via :class:`~torch_geometric.data.Data` as\n well as **heterogeneous** graphs stored via\n :class:`~torch_geometric.data.HeteroData`.\n When operating in heterogeneous graphs, up to :obj:`num_neighbors`\n neighbors will be sampled for each :obj:`edge_type`.\n However, more fine-grained control over\n the amount of sampled neighbors of individual edge types is possible:\n\n .. code-block:: python\n\n from torch_geometric.datasets import OGB_MAG\n from torch_geometric.loader import NeighborLoader\n\n hetero_data = OGB_MAG(path)[0]\n\n loader = NeighborLoader(\n hetero_data,\n # Sample 30 neighbors for each node and edge type for 2 iterations\n num_neighbors={key: [30] * 2 for key in hetero_data.edge_types},\n # Use a batch size of 128 for sampling training nodes of type paper\n batch_size=128,\n input_nodes=('paper', hetero_data['paper'].train_mask),\n )\n\n sampled_hetero_data = next(iter(loader))\n print(sampled_hetero_data['paper'].batch_size)\n >>> 128\n\n .. note::\n\n For an example of using\n :class:`~torch_geometric.loader.NeighborLoader`, see\n `examples/hetero/to_hetero_mag.py `_.\n\n The :class:`~torch_geometric.loader.NeighborLoader` will return subgraphs\n where global node indices are mapped to local indices corresponding to this\n specific subgraph. However, often times it is desired to map the nodes of\n the current subgraph back to the global node indices. The\n :class:`~torch_geometric.loader.NeighborLoader` will include this mapping\n as part of the :obj:`data` object:\n\n .. 
code-block:: python\n\n loader = NeighborLoader(data, ...)\n sampled_data = next(iter(loader))\n print(sampled_data.n_id) # Global node index of each node in batch.\n\n In particular, the data loader will add the following attributes to the\n returned mini-batch:\n\n * :obj:`batch_size` The number of seed nodes (first nodes in the batch)\n * :obj:`n_id` The global node index for every sampled node\n * :obj:`e_id` The global edge index for every sampled edge\n * :obj:`input_id`: The global index of the :obj:`input_nodes`\n * :obj:`num_sampled_nodes`: The number of sampled nodes in each hop\n * :obj:`num_sampled_edges`: The number of sampled edges in each hop\n\n Args:\n data (Any): A :class:`~torch_geometric.data.Data`,\n :class:`~torch_geometric.data.HeteroData`, or\n (:class:`~torch_geometric.data.FeatureStore`,\n :class:`~torch_geometric.data.GraphStore`) data object.\n num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The\n number of neighbors to sample for each node in each iteration.\n If an entry is set to :obj:`-1`, all neighbors will be included.\n In heterogeneous graphs, may also take in a dictionary denoting\n the amount of neighbors to sample for each individual edge type.\n input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The\n indices of nodes for which neighbors are sampled to create\n mini-batches.\n Needs to be either given as a :obj:`torch.LongTensor` or\n :obj:`torch.BoolTensor`.\n If set to :obj:`None`, all nodes will be considered.\n In heterogeneous graphs, needs to be passed as a tuple that holds\n the node type and node indices. (default: :obj:`None`)\n input_time (torch.Tensor, optional): Optional values to override the\n timestamp for the input nodes given in :obj:`input_nodes`. If not\n set, will use the timestamps in :obj:`time_attr` as default (if\n present). 
The :obj:`time_attr` needs to be set for this to work.\n (default: :obj:`None`)\n replace (bool, optional): If set to :obj:`True`, will sample with\n...\n# Path: torch_geometric/loader/neighbor_sampler.py\nfrom typing import Callable, List, NamedTuple, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import SparseTensor\n\n\nclass EdgeIndex(NamedTuple):\n edge_index: Tensor\n e_id: Optional[Tensor]\n size: Tuple[int, int]\n\n def to(self, *args, **kwargs):\n edge_index = self.edge_index.to(*args, **kwargs)\n e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None\n return EdgeIndex(edge_index, e_id, self.size)\n\n\nclass Adj(NamedTuple):\n adj_t: SparseTensor\n e_id: Optional[Tensor]\n size: Tuple[int, int]\n\n def to(self, *args, **kwargs):\n adj_t = self.adj_t.to(*args, **kwargs)\n e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None\n return Adj(adj_t, e_id, self.size)\n\n\nclass NeighborSampler(torch.utils.data.DataLoader):\n r\"\"\"The neighbor sampler from the `\"Inductive Representation Learning on\n Large Graphs\" `_ paper, which allows\n for mini-batch training of GNNs on large-scale graphs where full-batch\n training is not feasible.\n\n Given a GNN with :math:`L` layers and a specific mini-batch of nodes\n :obj:`node_idx` for which we want to compute embeddings, this module\n iteratively samples neighbors and constructs bipartite graphs that simulate\n the actual computation flow of GNNs.\n\n More specifically, :obj:`sizes` denotes how much neighbors we want to\n sample for each node in each layer.\n This module then takes in these :obj:`sizes` and iteratively samples\n :obj:`sizes[l]` for each node involved in layer :obj:`l`.\n In the next layer, sampling is repeated for the union of nodes that were\n already encountered.\n The actual computation graphs are then returned in reverse-mode, meaning\n that we pass messages from a larger set of nodes to a smaller one, until we\n reach the nodes for which we originally wanted to compute embeddings.\n\n Hence, an item returned by :class:`NeighborSampler` holds the current\n :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the\n computation, and a list of bipartite graph objects via the tuple\n :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the\n bipartite edges between source and target nodes, :obj:`e_id` denotes the\n IDs of original edges in the full graph, and :obj:`size` holds the shape\n of the bipartite graph.\n For each bipartite graph, target nodes are also included at the beginning\n of the list of source nodes so that one can easily apply skip-connections\n or add self-loops.\n\n .. warning::\n\n :class:`~torch_geometric.loader.NeighborSampler` is deprecated and will\n be removed in a future release.\n Use :class:`torch_geometric.loader.NeighborLoader` instead.\n\n .. 
note::\n\n For an example of using :obj:`NeighborSampler`, see\n `examples/reddit.py\n `_ or\n `examples/ogbn_products_sage.py\n `_.\n\n Args:\n edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a\n :class:`torch_sparse.SparseTensor` that defines the underlying\n graph connectivity/message passing flow.\n :obj:`edge_index` holds the indices of a (sparse) symmetric\n adjacency matrix.\n If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape\n must be defined as :obj:`[2, num_edges]`, where messages from nodes\n :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`\n (in case :obj:`flow=\"source_to_target\"`).\n If :obj:`edge_index` is of type :class:`torch_sparse.SparseTensor`,\n its sparse indices :obj:`(row, col)` should relate to\n :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.\n The major difference between both formats is that we need to input\n the *transposed* sparse adjacency matrix.\n sizes ([int]): The number of neighbors to sample for each node in each\n layer. If set to :obj:`sizes[l] = -1`, all neighbors are included\n in layer :obj:`l`.\n node_idx (LongTensor, optional): The nodes that should be considered\n for creating mini-batches. If set to :obj:`None`, all nodes will be\n considered.\n num_nodes (int, optional): The number of nodes in the graph.\n (default: :obj:`None`)\n return_e_id (bool, optional): If set to :obj:`False`, will not return\n original edge indices of sampled edges. This is only useful in case\n when operating on graphs without edge features to save memory.\n (default: :obj:`True`)\n transform (callable, optional): A function/transform that takes in\n a sampled mini-batch and returns a transformed version.\n (default: :obj:`None`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n \"\"\"\n def __init__(self, edge_index: Union[Tensor, SparseTensor],\n sizes: List[int], node_idx: Optional[Tensor] = None,\n num_nodes: Optional[int] = None, return_e_id: bool = True,\n transform: Callable = None, **kwargs):\n\n edge_index = edge_index.to('cpu')\n\n # Remove for PyTorch Lightning:\n kwargs.pop('dataset', None)\n kwargs.pop('collate_fn', None)\n\n # Save for Pytorch Lightning < 1.6:\n self.edge_index = edge_index\n self.node_idx = node_idx\n self.num_nodes = num_nodes\n\n self.sizes = sizes\n self.return_e_id = return_e_id\n self.transform = transform\n self.is_sparse_tensor = isinstance(edge_index, SparseTensor)\n self.__val__ = None\n\n # Obtain a *transposed* `SparseTensor` instance.\n if not self.is_sparse_tensor:\n if (num_nodes is None and node_idx is not None\n and node_idx.dtype == torch.bool):\n num_nodes = node_idx.size(0)\n if (num_nodes is None and node_idx is not None\n and node_idx.dtype == torch.long):\n num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1\n if num_nodes is None:\n num_nodes = int(edge_index.max()) + 1\n\n value = torch.arange(edge_index.size(1)) if return_e_id else None\n self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],\n value=value,\n sparse_sizes=(num_nodes, num_nodes)).t()\n else:\n adj_t = edge_index\n if return_e_id:\n self.__val__ = adj_t.storage.value()\n value = torch.arange(adj_t.nnz())\n adj_t = adj_t.set_value(value, layout='coo')\n self.adj_t = adj_t\n\n self.adj_t.storage.rowptr()\n\n if node_idx is None:\n node_idx = torch.arange(self.adj_t.sparse_size(0))\n elif node_idx.dtype == torch.bool:\n node_idx = 
node_idx.nonzero(as_tuple=False).view(-1)\n\n super().__init__(\n node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)\n\n def sample(self, batch):\n if not isinstance(batch, Tensor):\n batch = torch.tensor(batch)\n\n batch_size: int = len(batch)\n\n adjs = []\n n_id = batch\n for size in self.sizes:\n adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)\n e_id = adj_t.storage.value()\n size = adj_t.sparse_sizes()[::-1]\n if self.__val__ is not None:\n adj_t.set_value_(self.__val__[e_id], layout='coo')\n\n if self.is_sparse_tensor:\n adjs.append(Adj(adj_t, e_id, size))\n else:\n row, col, _ = adj_t.coo()\n edge_index = torch.stack([col, row], dim=0)\n adjs.append(EdgeIndex(edge_index, e_id, size))\n\n adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]\n out = (batch_size, n_id, adjs)\n out = self.transform(*out) if self.transform is not None else out\n return out\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}(sizes={self.sizes})'\n\n# Path: torch_geometric/loader/prefetch.py\nimport warnings\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Any, Optional\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.typing import WITH_IPEX\n\n\nclass DeviceHelper:\n def __init__(self, device: Optional[torch.device] = None):\n with_cuda = torch.cuda.is_available()\n with_xpu = torch.xpu.is_available() if WITH_IPEX else False\n\n if device is None:\n if with_cuda:\n device = 'cuda'\n elif with_xpu:\n device = 'xpu'\n else:\n device = 'cpu'\n\n self.device = torch.device(device)\n self.is_gpu = self.device.type in ['cuda', 'xpu']\n\n if ((self.device.type == 'cuda' and not with_cuda)\n or (self.device.type == 'xpu' and not with_xpu)):\n warnings.warn(f\"Requested device '{self.device.type}' is not \"\n f\"available, falling back to CPU\")\n self.device = torch.device('cpu')\n\n self.stream = None\n self.stream_context = nullcontext\n self.module = getattr(torch, self.device.type) if self.is_gpu else None\n\n def maybe_init_stream(self) -> None:\n if self.is_gpu:\n self.stream = self.module.Stream()\n self.stream_context = partial(\n self.module.stream,\n stream=self.stream,\n )\n\n def maybe_wait_stream(self) -> None:\n if self.stream is not None:\n self.module.current_stream().wait_stream(self.stream)\n\n\nclass PrefetchLoader:\n r\"\"\"A GPU prefetcher class for asynchronously transferring data of a\n :class:`torch.utils.data.DataLoader` from host memory to device memory.\n\n Args:\n loader (torch.utils.data.DataLoader): The data loader.\n device (torch.device, optional): The device to load the data to.\n (default: :obj:`None`)\n \"\"\"\n def __init__(\n self,\n loader: DataLoader,\n device: Optional[torch.device] = None,\n ):\n self.loader = loader\n self.device_helper = DeviceHelper(device)\n\n def non_blocking_transfer(self, batch: Any) -> Any:\n if not self.device_helper.is_gpu:\n return batch\n if isinstance(batch, (list, tuple)):\n return [self.non_blocking_transfer(v) for v in batch]\n if isinstance(batch, dict):\n return {k: self.non_blocking_transfer(v) for k, v in batch.items()}\n\n batch = batch.pin_memory(self.device_helper.device)\n return batch.to(self.device_helper.device, non_blocking=True)\n\n def __iter__(self) -> Any:\n first = True\n self.device_helper.maybe_init_stream()\n\n batch = None\n for next_batch in self.loader:\n\n with self.device_helper.stream_context():\n next_batch = self.non_blocking_transfer(next_batch)\n\n if not first:\n yield batch\n else:\n first = False\n\n 
self.device_helper.maybe_wait_stream()\n\n batch = next_batch\n\n yield batch\n\n def __len__(self) -> int:\n return len(self.loader)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}({self.loader})'\n\n# Path: torch_geometric/loader/random_node_loader.py\nimport math\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.hetero_data import to_homogeneous_edge_index\n\n\nclass RandomNodeLoader(torch.utils.data.DataLoader):\n r\"\"\"A data loader that randomly samples nodes within a graph and returns\n their induced subgraph.\n\n .. note::\n\n For an example of using\n :class:`~torch_geometric.loader.RandomNodeLoader`, see\n `examples/ogbn_proteins_deepgcn.py\n `_.\n\n Args:\n data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n The :class:`~torch_geometric.data.Data` or\n :class:`~torch_geometric.data.HeteroData` graph object.\n num_parts (int): The number of partitions.\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.\n \"\"\"\n def __init__(\n self,\n data: Union[Data, HeteroData],\n num_parts: int,\n **kwargs,\n ):\n self.data = data\n self.num_parts = num_parts\n\n if isinstance(data, HeteroData):\n edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data)\n self.node_dict, self.edge_dict = node_dict, edge_dict\n else:\n edge_index = data.edge_index\n\n self.edge_index = edge_index\n self.num_nodes = data.num_nodes\n\n super().__init__(\n range(self.num_nodes),\n batch_size=math.ceil(self.num_nodes / num_parts),\n collate_fn=self.collate_fn,\n **kwargs,\n )\n\n def collate_fn(self, index):\n if not isinstance(index, Tensor):\n index = torch.tensor(index)\n\n if isinstance(self.data, Data):\n return self.data.subgraph(index)\n\n elif isinstance(self.data, HeteroData):\n node_dict = {\n key: index[(index >= start) & (index < end)] - start\n for key, (start, end) in self.node_dict.items()\n }\n return self.data.subgraph(node_dict)\n\n# Path: torch_geometric/loader/shadow.py\nimport copy\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor\n\n\nclass ShaDowKHopSampler(torch.utils.data.DataLoader):\n r\"\"\"The ShaDow :math:`k`-hop sampler from the `\"Decoupling the Depth and\n Scope of Graph Neural Networks\" `_ paper.\n Given a graph in a :obj:`data` object, the sampler will create shallow,\n localized subgraphs.\n A deep GNN on this local graph then smooths the informative local signals.\n\n .. note::\n\n For an example of using :class:`ShaDowKHopSampler`, see\n `examples/shadow.py `_.\n\n Args:\n data (torch_geometric.data.Data): The graph data object.\n depth (int): The depth/number of hops of the localized subgraph.\n num_neighbors (int): The number of neighbors to sample for each node in\n each hop.\n node_idx (LongTensor or BoolTensor, optional): The nodes that should be\n considered for creating mini-batches.\n If set to :obj:`None`, all nodes will be\n considered.\n replace (bool, optional): If set to :obj:`True`, will sample neighbors\n with replacement. 
(default: :obj:`False`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or\n :obj:`num_workers`.\n \"\"\"\n def __init__(self, data: Data, depth: int, num_neighbors: int,\n node_idx: Optional[Tensor] = None, replace: bool = False,\n **kwargs):\n\n if not WITH_TORCH_SPARSE:\n raise ImportError(\n f\"'{self.__class__.__name__}' requires 'torch-sparse'\")\n\n self.data = copy.copy(data)\n self.depth = depth\n self.num_neighbors = num_neighbors\n self.replace = replace\n\n if data.edge_index is not None:\n self.is_sparse_tensor = False\n row, col = data.edge_index.cpu()\n self.adj_t = SparseTensor(\n row=row, col=col, value=torch.arange(col.size(0)),\n sparse_sizes=(data.num_nodes, data.num_nodes)).t()\n else:\n self.is_sparse_tensor = True\n self.adj_t = data.adj_t.cpu()\n\n if node_idx is None:\n node_idx = torch.arange(self.adj_t.sparse_size(0))\n elif node_idx.dtype == torch.bool:\n node_idx = node_idx.nonzero(as_tuple=False).view(-1)\n self.node_idx = node_idx\n\n super().__init__(node_idx.tolist(), collate_fn=self.__collate__,\n **kwargs)\n\n \ndef __collate__(self, n_id):\n n_id = torch.tensor(n_id)\n\n rowptr, col, value = self.adj_t.csr()\n out = torch.ops.torch_sparse.ego_k_hop_sample_adj(\n rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)\n rowptr, col, n_id, e_id, ptr, root_n_id = out\n\n adj_t = SparseTensor(rowptr=rowptr, col=col,\n value=value[e_id] if value is not None else None,\n sparse_sizes=(n_id.numel(), n_id.numel()),\n is_sorted=True, trust_data=True)\n\n batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),\n ptr=ptr)\n batch.root_n_id = root_n_id\n\n if self.is_sparse_tensor:\n batch.adj_t = adj_t\n else:\n row, col, e_id = adj_t.t().coo()\n batch.edge_index = torch.stack([row, col], dim=0)\n\n for k, v in self.data:\n if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:\n continue\n if k == 'y' and v.size(0) == self.data.num_nodes:\n batch[k] = v[n_id][root_n_id]\n elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:\n batch[k] = v[n_id]\n elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:\n batch[k] = v[e_id]\n else:\n batch[k] = v\n\n return batch\n\n# Path: torch_geometric/loader/temporal_dataloader.py\nfrom typing import List\n\nimport torch\n\nfrom torch_geometric.data import TemporalData\n\n\nclass TemporalDataLoader(torch.utils.data.DataLoader):\n r\"\"\"A data loader which merges succesive events of a\n :class:`torch_geometric.data.TemporalData` to a mini-batch.\n\n Args:\n data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`\n from which to load the data.\n batch_size (int, optional): How many samples per batch to load.\n (default: :obj:`1`)\n neg_sampling_ratio (float, optional): The ratio of sampled negative\n destination nodes to the number of postive destination nodes.\n (default: :obj:`0.0`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`.\n \"\"\"\n def __init__(\n self,\n data: TemporalData,\n batch_size: int = 1,\n neg_sampling_ratio: float = 0.0,\n **kwargs,\n ):\n # Remove for PyTorch Lightning:\n kwargs.pop('dataset', None)\n kwargs.pop('collate_fn', None)\n kwargs.pop('shuffle', None)\n\n self.data = data\n self.events_per_batch = batch_size\n self.neg_sampling_ratio = neg_sampling_ratio\n\n if neg_sampling_ratio > 0:\n self.min_dst = int(data.dst.min())\n self.max_dst = int(data.dst.max())\n\n if kwargs.get('drop_last', False) and len(data) % batch_size != 0:\n 
arange = range(0, len(data) - batch_size, batch_size)\n else:\n arange = range(0, len(data), batch_size)\n\n super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)\n\n def __call__(self, arange: List[int]) -> TemporalData:\n batch = self.data[arange[0]:arange[0] + self.events_per_batch]\n\n n_ids = [batch.src, batch.dst]\n\n if self.neg_sampling_ratio > 0:\n batch.neg_dst = torch.randint(\n low=self.min_dst,\n high=self.max_dst + 1,\n size=(round(self.neg_sampling_ratio * batch.dst.size(0)), ),\n dtype=batch.dst.dtype,\n device=batch.dst.device,\n )\n n_ids += [batch.neg_dst]\n\n batch.n_id = torch.cat(n_ids, dim=0).unique()\n\n return batch\n\n# Path: torch_geometric/loader/zip_loader.py\nfrom typing import Any, Iterator, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import LinkLoader, NodeLoader\nfrom torch_geometric.loader.base import DataLoaderIterator\nfrom torch_geometric.loader.utils import infer_filter_per_worker\n\n\nclass ZipLoader(torch.utils.data.DataLoader):\n r\"\"\"A loader that returns a tuple of data objects by sampling from multiple\n :class:`NodeLoader` or :class:`LinkLoader` instances.\n\n Args:\n loaders (List[NodeLoader] or List[LinkLoader]): The loader instances.\n filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n the returned data in each worker's subprocess.\n If set to :obj:`False`, will filter the returned data in the main\n process.\n If set to :obj:`None`, will automatically infer the decision based\n on whether data partially lives on the GPU\n (:obj:`filter_per_worker=True`) or entirely on the CPU\n (:obj:`filter_per_worker=False`).\n There exists different trade-offs for setting this option.\n Specifically, setting this option to :obj:`True` for in-memory\n datasets will move all features to shared memory, which may result\n in too many open file handles. 
(default: :obj:`None`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n \"\"\"\n def __init__(\n self,\n loaders: Union[List[NodeLoader], List[LinkLoader]],\n filter_per_worker: Optional[bool] = None,\n **kwargs,\n ):\n if filter_per_worker is None:\n filter_per_worker = infer_filter_per_worker(loaders[0].data)\n\n # Remove for PyTorch Lightning:\n kwargs.pop('dataset', None)\n kwargs.pop('collate_fn', None)\n\n for loader in loaders:\n if not callable(getattr(loader, 'collate_fn', None)):\n raise ValueError(\"'{loader.__class__.__name__}' does not have \"\n \"a 'collate_fn' method\")\n if not callable(getattr(loader, 'filter_fn', None)):\n raise ValueError(\"'{loader.__class__.__name__}' does not have \"\n \"a 'filter_fn' method\")\n loader.filter_per_worker = filter_per_worker\n\n iterator = range(min([len(loader.dataset) for loader in loaders]))\n super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)\n\n self.loaders = loaders\n self.filter_per_worker = filter_per_worker\n\n def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:\n if not isinstance(index, Tensor):\n index = torch.tensor(index, dtype=torch.long)\n\n return tuple(loader.collate_fn(index) for loader in self.loaders)\n\n def filter_fn(\n self,\n outs: Tuple[Any, ...],\n ) -> Tuple[Union[Data, HeteroData], ...]:\n loaders = self.loaders\n return tuple(loader.filter_fn(v) for loader, v in zip(loaders, outs))\n\n def _get_iterator(self) -> Iterator:\n if self.filter_per_worker:\n return super()._get_iterator()\n\n # Execute `filter_fn` in the main process:\n return DataLoaderIterator(super()._get_iterator(), self.filter_fn)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}(loaders={self.loaders})'\n\n# Path: torch_geometric/loader/__init__.py\nfrom torch_geometric.deprecation import deprecated\n\nfrom .dataloader import DataLoader\nfrom .node_loader import NodeLoader\nfrom .link_loader import LinkLoader\nfrom .neighbor_loader import NeighborLoader\nfrom .link_neighbor_loader import LinkNeighborLoader\nfrom .hgt_loader import HGTLoader\nfrom .cluster import ClusterData, ClusterLoader\nfrom .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler,\n GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler)\nfrom .shadow import ShaDowKHopSampler\nfrom .random_node_loader import RandomNodeLoader\n# from .ibmb_loader import IBMBBatchLoader, IBMBNodeLoader\nfrom .zip_loader import ZipLoader\nfrom .data_list_loader import DataListLoader\nfrom .dense_data_loader import DenseDataLoader\nfrom .temporal_dataloader import TemporalDataLoader\nfrom .neighbor_sampler import NeighborSampler\nfrom .imbalanced_sampler import ImbalancedSampler\nfrom .dynamic_batch_sampler import DynamicBatchSampler\nfrom .prefetch import PrefetchLoader\nfrom .cache import CachedLoader\nfrom .mixin import AffinityMixin\n\n__all__ = classes = [\n 'DataLoader',\n 'NodeLoader',\n 'LinkLoader',\n 'NeighborLoader',\n 'LinkNeighborLoader',\n 'HGTLoader',\n 'ClusterData',\n 'ClusterLoader',\n 'GraphSAINTSampler',\n 'GraphSAINTNodeSampler',\n 'GraphSAINTEdgeSampler',\n 'GraphSAINTRandomWalkSampler',\n 'ShaDowKHopSampler',\n 'RandomNodeLoader',\n # 'IBMBBatchLoader',\n # 'IBMBNodeLoader',\n 'ZipLoader',\n 'DataListLoader',\n 'DenseDataLoader',\n 'TemporalDataLoader',\n 'NeighborSampler',\n 'ImbalancedSampler',\n 'DynamicBatchSampler',\n 'PrefetchLoader',\n 'CachedLoader',\n 'AffinityMixin',\n]\n\nRandomNodeSampler = 
deprecated(\n details=\"use 'loader.RandomNodeLoader' instead\",\n func_name='loader.RandomNodeSampler',\n)(RandomNodeLoader)\n\n# Path: torch_geometric/data/__init__.py\n# flake8: noqa\n\nfrom .feature_store import FeatureStore, TensorAttr\nfrom .graph_store import GraphStore, EdgeAttr\nfrom .data import Data\nfrom .hetero_data import HeteroData\nfrom .batch import Batch\nfrom .temporal import TemporalData\nfrom .database import Database, SQLiteDatabase, RocksDatabase\nfrom .dataset import Dataset\nfrom .in_memory_dataset import InMemoryDataset\nfrom .on_disk_dataset import OnDiskDataset\nfrom .makedirs import makedirs\nfrom .download import download_url, download_google_url\nfrom .extract import extract_tar, extract_zip, extract_bz2, extract_gz\n\nfrom torch_geometric.lazy_loader import LazyLoader\n\ndata_classes = [\n 'Data',\n 'HeteroData',\n 'Batch',\n 'TemporalData',\n 'Dataset',\n 'InMemoryDataset',\n 'OnDiskDataset',\n]\n\nremote_backend_classes = [\n 'FeatureStore',\n 'GraphStore',\n 'TensorAttr',\n 'EdgeAttr',\n]\n\ndatabase_classes = [\n 'Database',\n 'SQLiteDatabase',\n 'RocksDatabase',\n]\n\nhelper_functions = [\n 'makedirs',\n 'download_url',\n 'download_google_url',\n 'extract_tar',\n 'extract_zip',\n 'extract_bz2',\n 'extract_gz',\n]\n\n__all__ = data_classes + remote_backend_classes + helper_functions\n\nlightning = LazyLoader('lightning', globals(),\n 'torch_geometric.data.lightning')\n\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.loader import NeighborSampler\nfrom torch_geometric.loader import ClusterData\nfrom torch_geometric.loader import ClusterLoader\nfrom torch_geometric.loader import GraphSAINTSampler\nfrom torch_geometric.loader import GraphSAINTNodeSampler\nfrom torch_geometric.loader import GraphSAINTEdgeSampler\nfrom torch_geometric.loader import GraphSAINTRandomWalkSampler\nfrom torch_geometric.loader import ShaDowKHopSampler\nfrom torch_geometric.loader import RandomNodeLoader\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.loader import DataListLoader\nfrom torch_geometric.loader import DenseDataLoader\n\nNeighborSampler = deprecated( # type: ignore\n details=\"use 'loader.NeighborSampler' instead\",\n func_name='data.NeighborSampler',\n)(NeighborSampler)\nClusterData = deprecated( # type: ignore\n details=\"use 'loader.ClusterData' instead\",\n func_name='data.ClusterData',\n)(ClusterData)\nClusterLoader = deprecated( # type: ignore\n details=\"use 'loader.ClusterLoader' instead\",\n func_name='data.ClusterLoader',\n)(ClusterLoader)\nGraphSAINTSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTSampler' instead\",\n func_name='data.GraphSAINTSampler',\n)(GraphSAINTSampler)\nGraphSAINTNodeSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTNodeSampler' instead\",\n func_name='data.GraphSAINTNodeSampler',\n)(GraphSAINTNodeSampler)\nGraphSAINTEdgeSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTEdgeSampler' instead\",\n func_name='data.GraphSAINTEdgeSampler',\n)(GraphSAINTEdgeSampler)\nGraphSAINTRandomWalkSampler = deprecated( # type: ignore\n details=\"use 'loader.GraphSAINTRandomWalkSampler' instead\",\n func_name='data.GraphSAINTRandomWalkSampler',\n)(GraphSAINTRandomWalkSampler)\nShaDowKHopSampler = deprecated( # type: ignore\n details=\"use 'loader.ShaDowKHopSampler' instead\",\n func_name='data.ShaDowKHopSampler',\n)(ShaDowKHopSampler)\nRandomNodeSampler = deprecated(\n details=\"use 'loader.RandomNodeLoader' instead\",\n 
func_name='data.RandomNodeSampler',\n)(RandomNodeLoader)\nDataLoader = deprecated( # type: ignore\n details=\"use 'loader.DataLoader' instead\",\n func_name='data.DataLoader',\n)(DataLoader)\nDataListLoader = deprecated( # type: ignore\n details=\"use 'loader.DataListLoader' instead\",\n func_name='data.DataListLoader',\n)(DataListLoader)\nDenseDataLoader = deprecated( # type: ignore\n details=\"use 'loader.DenseDataLoader' instead\",\n func_name='data.DenseDataLoader',\n)(DenseDataLoader)\n\n# Path: torch_geometric/visualization/graph.py\nfrom math import sqrt\nfrom typing import Any, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nBACKENDS = {'graphviz', 'networkx'}\n\n\ndef has_graphviz() -> bool:\n try:\n import graphviz\n except ImportError:\n return False\n\n try:\n graphviz.Digraph().pipe()\n except graphviz.backend.ExecutableNotFound:\n return False\n\n return True\n\n\ndef visualize_graph(\n edge_index: Tensor,\n edge_weight: Optional[Tensor] = None,\n path: Optional[str] = None,\n backend: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n) -> Any:\n r\"\"\"Visualizes the graph given via :obj:`edge_index` and (optional)\n :obj:`edge_weight`.\n\n Args:\n edge_index (torch.Tensor): The edge indices.\n edge_weight (torch.Tensor, optional): The edge weights.\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n backend (str, optional): The graph drawing backend to use for\n visualization (:obj:`\"graphviz\"`, :obj:`\"networkx\"`).\n If set to :obj:`None`, will use the most appropriate\n visualization backend based on available system packages.\n (default: :obj:`None`)\n node_labels (List[str], optional): The labels/IDs of nodes.\n (default: :obj:`None`)\n \"\"\"\n if edge_weight is not None: # Normalize edge weights.\n edge_weight = edge_weight - edge_weight.min()\n edge_weight = edge_weight / edge_weight.max()\n\n if edge_weight is not None: # Discard any edges with zero edge weight:\n mask = edge_weight > 1e-7\n edge_index = edge_index[:, mask]\n edge_weight = edge_weight[mask]\n\n if edge_weight is None:\n edge_weight = torch.ones(edge_index.size(1))\n\n if backend is None:\n backend = 'graphviz' if has_graphviz() else 'networkx'\n\n if backend.lower() == 'networkx':\n return _visualize_graph_via_networkx(edge_index, edge_weight, path,\n node_labels)\n elif backend.lower() == 'graphviz':\n return _visualize_graph_via_graphviz(edge_index, edge_weight, path,\n node_labels)\n\n raise ValueError(f\"Expected graph drawing backend to be in \"\n f\"{BACKENDS} (got '{backend}')\")\n\n\ndef _visualize_graph_via_graphviz(\n edge_index: Tensor,\n edge_weight: Tensor,\n path: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n) -> Any:\n import graphviz\n\n suffix = path.split('.')[-1] if path is not None else None\n g = graphviz.Digraph('graph', format=suffix)\n g.attr('node', shape='circle', fontsize='11pt')\n\n for node in edge_index.view(-1).unique().tolist():\n g.node(str(node) if node_labels is None else node_labels[node])\n\n for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n hex_color = hex(255 - round(255 * w))[2:]\n hex_color = f'{hex_color}0' if len(hex_color) == 1 else hex_color\n if node_labels is not None:\n src = node_labels[src]\n dst = node_labels[dst]\n g.edge(str(src), str(dst), color=f'#{hex_color}{hex_color}{hex_color}')\n\n if path is not None:\n path = '.'.join(path.split('.')[:-1])\n g.render(path, 
cleanup=True)\n else:\n g.view()\n\n return g\n\n\ndef _visualize_graph_via_networkx(\n edge_index: Tensor,\n edge_weight: Tensor,\n path: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n) -> Any:\n import matplotlib.pyplot as plt\n import networkx as nx\n\n g = nx.DiGraph()\n node_size = 800\n\n for node in edge_index.view(-1).unique().tolist():\n g.add_node(node if node_labels is None else node_labels[node])\n\n for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n if node_labels is not None:\n src = node_labels[src]\n dst = node_labels[dst]\n g.add_edge(src, dst, alpha=w)\n\n ax = plt.gca()\n pos = nx.spring_layout(g)\n for src, dst, data in g.edges(data=True):\n ax.annotate(\n '',\n xy=pos[src],\n xytext=pos[dst],\n arrowprops=dict(\n arrowstyle=\"->\",\n alpha=data['alpha'],\n shrinkA=sqrt(node_size) / 2.0,\n shrinkB=sqrt(node_size) / 2.0,\n connectionstyle=\"arc3,rad=0.1\",\n ),\n )\n\n nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,\n node_color='white', margins=0.1)\n nodes.set_edgecolor('black')\n nx.draw_networkx_labels(g, pos, font_size=10)\n\n if path is not None:\n plt.savefig(path)\n else:\n plt.show()\n\n plt.close()\n\n# Path: torch_geometric/visualization/influence.py\nfrom typing import Any\n\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import grad\n\n\ndef influence(model: torch.nn.Module, src: Tensor, *args: Any) -> Tensor:\n x = src.clone().requires_grad_()\n out = model(x, *args).sum(dim=-1)\n\n influences = []\n for j in range(src.size(0)):\n influence = grad([out[j]], [x], retain_graph=True)[0].abs().sum(dim=-1)\n influences.append(influence / influence.sum())\n\n return torch.stack(influences, dim=0)\n\n# Path: torch_geometric/visualization/__init__.py\nr\"\"\"Visualization package.\"\"\"\n\nfrom .graph import visualize_graph\nfrom .influence import influence\n\n__all__ = [\n 'visualize_graph',\n 'influence',\n]\n\n# Path: torch_geometric/explain/explanation.py\nimport copy\nfrom typing import Dict, List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import Data, warn_or_raise\nfrom torch_geometric.data.hetero_data import HeteroData\nfrom torch_geometric.explain.config import ThresholdConfig, ThresholdType\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.visualization import visualize_graph\n\n\nclass ExplanationMixin:\n @property\n def available_explanations(self) -> List[str]:\n \"\"\"Returns the available explanation masks.\"\"\"\n return [key for key in self.keys() if key.endswith('_mask')]\n\n def validate_masks(self, raise_on_error: bool = True) -> bool:\n r\"\"\"Validates the correctness of the :class:`Explanation` masks.\"\"\"\n status = True\n\n for store in self.node_stores:\n if 'node_mask' not in store:\n continue\n\n if store.node_mask.dim() != 2:\n status = False\n warn_or_raise(\n f\"Expected a 'node_mask' with two dimensions (got \"\n f\"{store.node_mask.dim()} dimensions)\", raise_on_error)\n\n if store.node_mask.size(0) not in {1, store.num_nodes}:\n status = False\n warn_or_raise(\n f\"Expected a 'node_mask' with {store.num_nodes} nodes \"\n f\"(got {store.node_mask.size(0)} nodes)\", raise_on_error)\n\n if 'x' in store:\n num_features = store.x.size(-1)\n else:\n num_features = store.node_mask.size(-1)\n\n if store.node_mask.size(1) not in {1, num_features}:\n status = False\n warn_or_raise(\n f\"Expected a 'node_mask' with {num_features} features (\"\n f\"got {store.node_mask.size(1)} features)\", 
raise_on_error)\n\n for store in self.edge_stores:\n if 'edge_mask' not in store:\n continue\n\n if store.edge_mask.dim() != 1:\n status = False\n warn_or_raise(\n f\"Expected an 'edge_mask' with one dimension (got \"\n f\"{store.edge_mask.dim()} dimensions)\", raise_on_error)\n\n if store.edge_mask.size(0) != store.num_edges:\n status = False\n warn_or_raise(\n f\"Expected an 'edge_mask' with {store.num_edges} edges \"\n f\"(got {store.edge_mask.size(0)} edges)\", raise_on_error)\n\n return status\n\n def _threshold_mask(\n self,\n mask: Optional[Tensor],\n threshold_config: ThresholdConfig,\n ) -> Optional[Tensor]:\n\n if mask is None:\n return None\n\n if threshold_config.type == ThresholdType.hard:\n return (mask > threshold_config.value).float()\n\n if threshold_config.type in [\n ThresholdType.topk,\n ThresholdType.topk_hard,\n ]:\n if threshold_config.value >= mask.numel():\n if threshold_config.type == ThresholdType.topk:\n return mask\n else:\n return torch.ones_like(mask)\n\n value, index = torch.topk(\n mask.flatten(),\n k=threshold_config.value,\n )\n\n out = torch.zeros_like(mask.flatten())\n if threshold_config.type == ThresholdType.topk:\n out[index] = value\n else:\n out[index] = 1.0\n return out.view(mask.size())\n\n assert False\n\n def threshold(\n self,\n *args,\n **kwargs,\n ) -> Union['Explanation', 'HeteroExplanation']:\n \"\"\"Thresholds the explanation masks according to the thresholding\n method.\n\n Args:\n *args: Arguments passed to :class:`ThresholdConfig`.\n **kwargs: Keyword arguments passed to :class:`ThresholdConfig`.\n \"\"\"\n threshold_config = ThresholdConfig.cast(*args, **kwargs)\n\n if threshold_config is None:\n return self\n\n # Avoid modification of the original explanation:\n out = copy.copy(self)\n\n for store in out.node_stores:\n store.node_mask = self._threshold_mask(store.get('node_mask'),\n threshold_config)\n\n for store in out.edge_stores:\n store.edge_mask = self._threshold_mask(store.get('edge_mask'),\n threshold_config)\n\n return out\n\n\nclass Explanation(Data, ExplanationMixin):\n r\"\"\"Holds all the obtained explanations of a homogeneous graph.\n\n The explanation object is a :obj:`~torch_geometric.data.Data` object and\n can hold node attributions and edge attributions.\n It can also hold the original graph if needed.\n\n Args:\n node_mask (Tensor, optional): Node-level mask with shape\n :obj:`[num_nodes, 1]`, :obj:`[1, num_features]` or\n :obj:`[num_nodes, num_features]`. (default: :obj:`None`)\n edge_mask (Tensor, optional): Edge-level mask with shape\n :obj:`[num_edges]`. 
(default: :obj:`None`)\n **kwargs (optional): Additional attributes.\n \"\"\"\n def validate(self, raise_on_error: bool = True) -> bool:\n r\"\"\"Validates the correctness of the :class:`Explanation` object.\"\"\"\n status = super().validate(raise_on_error)\n status &= self.validate_masks(raise_on_error)\n return status\n\n def get_explanation_subgraph(self) -> 'Explanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with\n zero attribution are masked out.\n \"\"\"\n node_mask = self.get('node_mask')\n if node_mask is not None:\n node_mask = node_mask.sum(dim=-1) > 0\n edge_mask = self.get('edge_mask')\n if edge_mask is not None:\n edge_mask = edge_mask > 0\n return self._apply_masks(node_mask, edge_mask)\n\n def get_complement_subgraph(self) -> 'Explanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with any\n attribution are masked out.\n \"\"\"\n node_mask = self.get('node_mask')\n if node_mask is not None:\n node_mask = node_mask.sum(dim=-1) == 0\n edge_mask = self.get('edge_mask')\n if edge_mask is not None:\n edge_mask = edge_mask == 0\n return self._apply_masks(node_mask, edge_mask)\n\n def _apply_masks(\n self,\n node_mask: Optional[Tensor] = None,\n edge_mask: Optional[Tensor] = None,\n ) -> 'Explanation':\n out = copy.copy(self)\n\n if edge_mask is not None:\n for key, value in self.items():\n if key == 'edge_index':\n out.edge_index = value[:, edge_mask]\n elif self.is_edge_attr(key):\n out[key] = value[edge_mask]\n\n if node_mask is not None:\n out = out.subgraph(node_mask)\n\n return out\n\n def visualize_feature_importance(\n self,\n path: Optional[str] = None,\n feat_labels: Optional[List[str]] = None,\n top_k: Optional[int] = None,\n ):\n r\"\"\"Creates a bar plot of the node feature importances by summing up\n the node mask across all nodes.\n\n Args:\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n feat_labels (List[str], optional): The labels of features.\n (default :obj:`None`)\n top_k (int, optional): Top k features to plot. If :obj:`None`\n plots all features. 
(default: :obj:`None`)\n \"\"\"\n node_mask = self.get('node_mask')\n if node_mask is None:\n raise ValueError(f\"The attribute 'node_mask' is not available \"\n f\"in '{self.__class__.__name__}' \"\n f\"(got {self.available_explanations})\")\n if node_mask.dim() != 2 or node_mask.size(1) <= 1:\n raise ValueError(f\"Cannot compute feature importance for \"\n f\"object-level 'node_mask' \"\n f\"(got shape {node_mask.size()})\")\n\n if feat_labels is None:\n feat_labels = range(node_mask.size(1))\n\n score = node_mask.sum(dim=0)\n\n return _visualize_score(score, feat_labels, path, top_k)\n\n def visualize_graph(\n self,\n path: Optional[str] = None,\n backend: Optional[str] = None,\n node_labels: Optional[List[str]] = None,\n ) -> None:\n r\"\"\"Visualizes the explanation graph with edge opacity corresponding to\n edge importance.\n\n Args:\n path (str, optional): The path to where the plot is saved.\n If set to :obj:`None`, will visualize the plot on-the-fly.\n (default: :obj:`None`)\n backend (str, optional): The graph drawing backend to use for\n visualization (:obj:`\"graphviz\"`, :obj:`\"networkx\"`).\n If set to :obj:`None`, will use the most appropriate\n visualization backend based on available system packages.\n (default: :obj:`None`)\n node_labels (list[str], optional): The labels/IDs of nodes.\n (default: :obj:`None`)\n \"\"\"\n edge_mask = self.get('edge_mask')\n if edge_mask is None:\n raise ValueError(f\"The attribute 'edge_mask' is not available \"\n f\"in '{self.__class__.__name__}' \"\n f\"(got {self.available_explanations})\")\n visualize_graph(self.edge_index, edge_mask, path, backend, node_labels)\n\n\nclass HeteroExplanation(HeteroData, ExplanationMixin):\n r\"\"\"Holds all the obtained explanations of a heterogeneous graph.\n\n The explanation object is a :obj:`~torch_geometric.data.HeteroData` object\n and can hold node attributions and edge attributions.\n It can also hold the original graph if needed.\n \"\"\"\n def validate(self, raise_on_error: bool = True) -> bool:\n r\"\"\"Validates the correctness of the :class:`Explanation` object.\"\"\"\n status = super().validate(raise_on_error)\n status &= self.validate_masks(raise_on_error)\n return status\n\n def get_explanation_subgraph(self) -> 'HeteroExplanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with\n zero attribution are masked out.\n \"\"\"\n return self._apply_masks(\n node_mask_dict={\n key: mask.sum(dim=-1) > 0\n for key, mask in self.collect('node_mask', True).items()\n },\n edge_mask_dict={\n key: mask > 0\n for key, mask in self.collect('edge_mask', True).items()\n },\n )\n\n def get_complement_subgraph(self) -> 'HeteroExplanation':\n r\"\"\"Returns the induced subgraph, in which all nodes and edges with any\n attribution are masked out.\n \"\"\"\n return self._apply_masks(\n...\n", "needle_token_start": 7004, "needle_token_end": 7512, "code_context_ntokens": 16118, "output": ["```python\nclass ShaDowKHopSampler(torch.utils.data.DataLoader):\n r\"\"\"The ShaDow :math:`k`-hop sampler from the `\"Decoupling the Depth and\n Scope of Graph Neural Networks\" `_ paper.\n Given a graph in a :obj:`data` object, the sampler will create shallow,\n localized subgraphs.\n A deep GNN on this local graph then smooths the informative local signals.\n\n .. 
note::\n\n For an example of using :class:`ShaDowKHopSampler`, see\n `examples/shadow.py `_.\n\n Args:\n data (torch_geometric.data.Data): The graph data object.\n depth (int): The depth/number of hops of the localized subgraph.\n num_neighbors (int): The number of neighbors to sample for each node in\n each hop.\n node_idx (LongTensor or BoolTensor, optional): The nodes that should be\n considered for creating mini-batches.\n If set to :obj:`None`, all nodes will be\n considered.\n replace (bool, optional): If set to :obj:`True`, will sample neighbors\n with replacement. (default: :obj:`False`)\n **kwargs (optional): Additional arguments of\n :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or\n :obj:`num_workers`.\n \"\"\"\n def __init__(self, data: Data, depth: int, num_neighbors: int,\n node_idx: Optional[Tensor] = None, replace: bool = False,\n **kwargs):\n\n if not WITH_TORCH_SPARSE:\n raise ImportError(\n f\"'{self.__class__.__name__}' requires 'torch-sparse'\")\n\n self.data = copy.copy(data)\n self.depth = depth\n self.num_neighbors = num_neighbors\n self.replace = replace\n\n if data.edge_index is not None:\n self.is_sparse_tensor = False\n row, col = data.edge_index.cpu()\n self.adj_t = SparseTensor(\n row=row, col=col, value=torch.arange(col.size(0)),\n sparse_sizes=(data.num_nodes, data.num_nodes)).t()\n else:\n self.is_sparse_tensor = True\n self.adj_t = data.adj_t.cpu()\n\n if node_idx is None:\n node_idx = torch.arange(self.adj_t.sparse_size(0))\n elif node_idx.dtype == torch.bool:\n node_idx = node_idx.nonzero(as_tuple=False).view(-1)\n self.node_idx = node_idx\n\n super().__init__(node_idx.tolist(), collate_fn=self.__collate__,\n **kwargs)\n\n \ndef __collate__(self, n_id):\n n_id = torch.tensor(n_id)\n\n rowptr, col, value = self.adj_t.csr()\n out = torch.ops.torch_sparse.ego_k_hop_sample_adj(\n rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)\n rowptr, col, n_id, e_id, ptr, root_n_id = out\n\n adj_t = SparseTensor(rowptr=rowptr, col=col,\n value=value[e_id] if value is not None else None,\n sparse_sizes=(n_id.numel(), n_id.numel()),\n is_sorted=True, trust_data=True)\n\n batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),\n ptr=ptr)\n batch.root_n_id = root_n_id\n\n if self.is_sparse_tensor:\n batch.adj_t = adj_t\n else:\n row, col, e_id = adj_t.t().coo()\n batch.edge_index = torch.stack([row, col], dim=0)\n\n for k, v in self.data:\n if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:\n continue\n if k == 'y' and v.size(0) == self.data.num_nodes:\n batch[k] = v[n_id][root_n_id]\n elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:\n batch[k] = v[n_id]\n elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:\n batch[k] = v[e_id]\n else:\n batch[k] = v\n\n return batch\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "tree", "language": "python", "path": "torch_geometric/datasets/graph_generator/tree_graph.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function generates a tree graph with specified depth and branching factor, optionally making it undirected, and provides node-level depth information.\n2. **Input**: \n - An integer representing the depth of the tree.\n - An optional integer for the branching factor of each node (default is 2).\n - An optional boolean to determine if the graph should be undirected (default is False).\n - An optional device specification for where the tensors should be computed (default is CPU).\n3. 
**Output**: \n - A tensor representing the edge connections between nodes in the tree.\n - A tensor indicating the depth level of each node in the tree.\n4. **Procedure**: \n - Initialize lists to store edges and node depths.\n - Define a recursive function to add edges and node depths based on the current node and depth, iterating until the specified tree depth is reached.\n - Convert the list of edges into a tensor and, if specified, make the graph undirected by adding reciprocal edges.\n - Return the edge tensor and the node depth tensor.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/datasets/dynamic_faust.py\nfrom itertools import product\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\n\n\nclass DynamicFAUST(InMemoryDataset):\n r\"\"\"The dynamic FAUST humans dataset from the `\"Dynamic FAUST: Registering\n Human Bodies in Motion\"\n `_ paper.\n\n .. note::\n\n Data objects hold mesh faces instead of edge indices.\n To convert the mesh to a graph, use the\n :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n To convert the mesh to a point cloud, use the\n :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n sample a fixed number of points on the mesh faces according to their\n face area.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n subjects (list, optional): List of subjects to include in the\n dataset. Can include the subjects :obj:`\"50002\"`, :obj:`\"50004\"`,\n :obj:`\"50007\"`, :obj:`\"50009\"`, :obj:`\"50020\"`, :obj:`\"50021\"`,\n :obj:`\"50022\"`, :obj:`\"50025\"`, :obj:`\"50026\"`, :obj:`\"50027\"`.\n If set to :obj:`None`, the dataset will contain all subjects.\n (default: :obj:`None`)\n categories (list, optional): List of categories to include in the\n dataset. Can include the categories :obj:`\"chicken_wings\"`,\n :obj:`\"hips\"`, :obj:`\"jiggle_on_toes\"`, :obj:`\"jumping_jacks\"`,\n :obj:`\"knees\"`, :obj:`\"light_hopping_loose\"`,\n :obj:`\"light_hopping_stiff\"`, :obj:`\"one_leg_jump\"`,\n :obj:`\"one_leg_loose\"`, :obj:`\"personal_move\"`, :obj:`\"punching\"`,\n :obj:`\"running_on_spot\"`, :obj:`\"running_on_spot_bugfix\"`,\n :obj:`\"shake_arms\"`, :obj:`\"shake_hips\"`, :obj:`\"shoulders\"`.\n If set to :obj:`None`, the dataset will contain all categories.\n (default: :obj:`None`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n pre_filter (callable, optional): A function that takes in an\n :obj:`torch_geometric.data.Data` object and returns a boolean\n value, indicating whether the data object should be included in the\n final dataset. 
(default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n\n url = 'http://dfaust.is.tue.mpg.de/'\n\n subjects = [\n '50002', '50004', '50007', '50009', '50020', '50021', '50022', '50025',\n '50026', '50027'\n ]\n categories = [\n 'chicken_wings', 'hips', 'jiggle_on_toes', 'jumping_jacks', 'knees',\n 'light_hopping_loose', 'light_hopping_stiff', 'one_leg_jump',\n 'one_leg_loose', 'personal_move', 'punching', 'running_on_spot',\n 'running_on_spot_bugfix', 'shake_arms', 'shake_hips', 'shake_shoulders'\n ]\n\n def __init__(\n self,\n root: str,\n subjects: Optional[List[str]] = None,\n categories: Optional[List[str]] = None,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n pre_filter: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n\n subjects = self.subjects if subjects is None else subjects\n subjects = [sid.lower() for sid in subjects]\n for sid in subjects:\n assert sid in self.subjects\n self.subjects = subjects\n\n categories = self.categories if categories is None else categories\n categories = [cat.lower() for cat in categories]\n for cat in categories:\n assert cat in self.categories\n self.categories = categories\n\n super().__init__(root, transform, pre_transform, pre_filter,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n...\n# Path: torch_geometric/datasets/elliptic.py\nfrom typing import Any, Callable, List, Optional, Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs\n\n\nclass EllipticBitcoinDataset(InMemoryDataset):\n r\"\"\"The Elliptic Bitcoin dataset of Bitcoin transactions from the\n `\"Anti-Money Laundering in Bitcoin: Experimenting with Graph Convolutional\n Networks for Financial Forensics\" `_\n paper.\n\n :class:`EllipticBitcoinDataset` maps Bitcoin transactions to real entities\n belonging to licit categories (exchanges, wallet providers, miners,\n licit services, etc.) versus illicit ones (scams, malware, terrorist\n organizations, ransomware, Ponzi schemes, etc.)\n\n There exists 203,769 node transactions and 234,355 directed edge payments\n flows, with two percent of nodes (4,545) labelled as illicit, and\n twenty-one percent of nodes (42,019) labelled as licit.\n The remaining transactions are unknown.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n **STATS:**\n\n .. 
list-table::\n :widths: 10 10 10 10\n :header-rows: 1\n\n * - #nodes\n - #edges\n - #features\n - #classes\n * - 203,769\n - 234,355\n - 165\n - 2\n \"\"\"\n url = 'https://data.pyg.org/datasets/elliptic'\n\n def __init__(\n self,\n root: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_file_names(self) -> List[str]:\n return [\n 'elliptic_txs_features.csv',\n 'elliptic_txs_edgelist.csv',\n 'elliptic_txs_classes.csv',\n ]\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n for file_name in self.raw_file_names:\n fs.cp(f'{self.url}/{file_name}.zip', self.raw_dir, extract=True)\n\n def _process_df(self, feat_df: Any, edge_df: Any,\n class_df: Any) -> Tuple[Any, Any, Any]:\n return feat_df, edge_df, class_df\n\n def process(self) -> None:\n import pandas as pd\n\n feat_df = pd.read_csv(self.raw_paths[0], header=None)\n edge_df = pd.read_csv(self.raw_paths[1])\n class_df = pd.read_csv(self.raw_paths[2])\n\n columns = {0: 'txId', 1: 'time_step'}\n feat_df = feat_df.rename(columns=columns)\n\n feat_df, edge_df, class_df = self._process_df(\n feat_df,\n edge_df,\n class_df,\n )\n\n x = torch.from_numpy(feat_df.loc[:, 2:].values).to(torch.float)\n\n # There exists 3 different classes in the dataset:\n # 0=licit, 1=illicit, 2=unknown\n mapping = {'unknown': 2, '1': 1, '2': 0}\n class_df['class'] = class_df['class'].map(mapping)\n y = torch.from_numpy(class_df['class'].values)\n\n mapping = {idx: i for i, idx in enumerate(feat_df['txId'].values)}\n edge_df['txId1'] = edge_df['txId1'].map(mapping)\n edge_df['txId2'] = edge_df['txId2'].map(mapping)\n edge_index = torch.from_numpy(edge_df.values).t().contiguous()\n\n # Timestamp based split:\n # train_mask: 1 - 34 time_step, test_mask: 35-49 time_step\n time_step = torch.from_numpy(feat_df['time_step'].values)\n train_mask = (time_step < 35) & (y != 2)\n test_mask = (time_step >= 35) & (y != 2)\n\n data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n test_mask=test_mask)\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n @property\n def num_classes(self) -> int:\n return 2\n\n# Path: torch_geometric/datasets/elliptic_temporal.py\nfrom typing import Any, Callable, Optional, Tuple\n\nfrom torch_geometric.datasets import EllipticBitcoinDataset\n\n\nclass EllipticBitcoinTemporalDataset(EllipticBitcoinDataset):\n r\"\"\"The time-step aware Elliptic Bitcoin dataset of Bitcoin transactions\n from the `\"Anti-Money Laundering in Bitcoin: Experimenting with Graph\n Convolutional Networks for Financial Forensics\"\n `_ paper.\n\n :class:`EllipticBitcoinTemporalDataset` maps Bitcoin transactions to real\n entities belonging to licit categories (exchanges, wallet providers,\n miners, licit services, etc.) versus illicit ones (scams, malware,\n terrorist organizations, ransomware, Ponzi schemes, etc.)\n\n There exists 203,769 node transactions and 234,355 directed edge payments\n flows, with two percent of nodes (4,545) labelled as illicit, and\n twenty-one percent of nodes (42,019) labelled as licit.\n The remaining transactions are unknown.\n\n .. 
note::\n\n In contrast to :class:`EllipticBitcoinDataset`, this dataset returns\n Bitcoin transactions only for a given timestamp :obj:`t`.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n t (int): The Timestep for which nodes should be selected (from :obj:`1`\n to :obj:`49`).\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n **STATS:**\n\n .. list-table::\n :widths: 10 10 10 10\n :header-rows: 1\n\n * - #nodes\n - #edges\n - #features\n - #classes\n * - 203,769\n - 234,355\n - 165\n - 2\n \"\"\"\n def __init__(\n self,\n root: str,\n t: int,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ):\n if t < 1 or t > 49:\n raise ValueError(\"'t' needs to be between 1 and 49\")\n\n self.t = t\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n\n @property\n def processed_file_names(self) -> str:\n return f'data_t_{self.t}.pt'\n\n def _process_df(self, feat_df: Any, edge_df: Any,\n class_df: Any) -> Tuple[Any, Any, Any]:\n\n feat_df = feat_df[feat_df['time_step'] == self.t]\n\n mask = edge_df['txId1'].isin(feat_df['txId'].values)\n edge_df = edge_df[mask]\n\n class_df = class_df.merge(feat_df[['txId']], how='right',\n left_on='txId', right_on='txId')\n\n return feat_df, edge_df, class_df\n\n# Path: torch_geometric/datasets/email_eu_core.py\nimport os\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n Data,\n InMemoryDataset,\n download_url,\n extract_gz,\n)\n\n\nclass EmailEUCore(InMemoryDataset):\n r\"\"\"An e-mail communication network of a large European research\n institution, taken from the `\"Local Higher-order Graph Clustering\"\n `_ paper.\n Nodes indicate members of the institution.\n An edge between a pair of members indicates that they exchanged at least\n one email.\n Node labels indicate membership to one of the 42 departments.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. 
(default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n\n urls = [\n 'https://snap.stanford.edu/data/email-Eu-core.txt.gz',\n 'https://snap.stanford.edu/data/email-Eu-core-department-labels.txt.gz'\n ]\n\n def __init__(\n self,\n root: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_file_names(self) -> List[str]:\n return ['email-Eu-core.txt', 'email-Eu-core-department-labels.txt']\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n for url in self.urls:\n path = download_url(url, self.raw_dir)\n extract_gz(path, self.raw_dir)\n os.unlink(path)\n\n def process(self) -> None:\n import pandas as pd\n\n edge_index = pd.read_csv(self.raw_paths[0], sep=' ', header=None)\n edge_index = torch.from_numpy(edge_index.values).t().contiguous()\n\n y = pd.read_csv(self.raw_paths[1], sep=' ', header=None, usecols=[1])\n y = torch.from_numpy(y.values).view(-1)\n\n data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n# Path: torch_geometric/datasets/entities.py\nimport logging\nimport os\nimport os.path as osp\nfrom collections import Counter\nfrom typing import Any, Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n Data,\n HeteroData,\n InMemoryDataset,\n download_url,\n extract_tar,\n)\nfrom torch_geometric.utils import index_sort\n\n\nclass Entities(InMemoryDataset):\n r\"\"\"The relational entities networks :obj:`\"AIFB\"`, :obj:`\"MUTAG\"`,\n :obj:`\"BGS\"` and :obj:`\"AM\"` from the `\"Modeling Relational Data with Graph\n Convolutional Networks\" `_ paper.\n Training and test splits are given by node indices.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n name (str): The name of the dataset (:obj:`\"AIFB\"`, :obj:`\"MUTAG\"`,\n :obj:`\"BGS\"`, :obj:`\"AM\"`).\n hetero (bool, optional): If set to :obj:`True`, will save the dataset\n as a :class:`~torch_geometric.data.HeteroData` object.\n (default: :obj:`False`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n\n **STATS:**\n\n .. 
list-table::\n :widths: 10 10 10 10 10\n :header-rows: 1\n\n * - Name\n - #nodes\n - #edges\n - #features\n - #classes\n * - AIFB\n - 8,285\n - 58,086\n - 0\n - 4\n * - AM\n - 1,666,764\n - 11,976,642\n - 0\n - 11\n * - MUTAG\n - 23,644\n - 148,454\n - 0\n - 2\n * - BGS\n - 333,845\n - 1,832,398\n - 0\n - 2\n \"\"\"\n\n url = 'https://data.dgl.ai/dataset/{}.tgz'\n\n def __init__(\n self,\n root: str,\n name: str,\n hetero: bool = False,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n self.name = name.lower()\n self.hetero = hetero\n assert self.name in ['aifb', 'am', 'mutag', 'bgs']\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n if hetero:\n self.load(self.processed_paths[0], data_cls=HeteroData)\n else:\n self.load(self.processed_paths[0], data_cls=Data)\n\n @property\n def raw_dir(self) -> str:\n return osp.join(self.root, self.name, 'raw')\n\n @property\n def processed_dir(self) -> str:\n return osp.join(self.root, self.name, 'processed')\n\n @property\n def num_relations(self) -> int:\n return int(self._data.edge_type.max()) + 1 # type: ignore\n\n @property\n def num_classes(self) -> int:\n return int(self._data.train_y.max()) + 1 # type: ignore\n\n @property\n def raw_file_names(self) -> List[str]:\n return [\n f'{self.name}_stripped.nt.gz',\n 'completeDataset.tsv',\n 'trainingSet.tsv',\n 'testSet.tsv',\n ]\n\n @property\n def processed_file_names(self) -> str:\n return 'hetero_data.pt' if self.hetero else 'data.pt'\n\n def download(self) -> None:\n path = download_url(self.url.format(self.name), self.root)\n extract_tar(path, self.raw_dir)\n os.unlink(path)\n\n def process(self) -> None:\n import gzip\n\n import pandas as pd\n import rdflib as rdf\n\n graph_file, task_file, train_file, test_file = self.raw_paths\n\n with hide_stdout():\n g = rdf.Graph()\n with gzip.open(graph_file, 'rb') as f:\n g.parse(file=f, format='nt') # type: ignore\n\n freq = Counter(g.predicates())\n\n relations = sorted(set(g.predicates()), key=lambda p: -freq.get(p, 0))\n subjects = set(g.subjects())\n objects = set(g.objects())\n nodes = list(subjects.union(objects))\n\n N = len(nodes)\n R = 2 * len(relations)\n\n relations_dict = {rel: i for i, rel in enumerate(relations)}\n nodes_dict = {str(node): i for i, node in enumerate(nodes)}\n\n edges = []\n for s, p, o in g.triples((None, None, None)):\n src, dst = nodes_dict[str(s)], nodes_dict[str(o)]\n rel = relations_dict[p]\n edges.append([src, dst, 2 * rel])\n edges.append([dst, src, 2 * rel + 1])\n\n edge = torch.tensor(edges, dtype=torch.long).t().contiguous()\n _, perm = index_sort(N * R * edge[0] + R * edge[1] + edge[2])\n edge = edge[:, perm]\n\n edge_index, edge_type = edge[:2], edge[2]\n\n if self.name == 'am':\n label_header = 'label_cateogory'\n nodes_header = 'proxy'\n elif self.name == 'aifb':\n label_header = 'label_affiliation'\n nodes_header = 'person'\n elif self.name == 'mutag':\n label_header = 'label_mutagenic'\n nodes_header = 'bond'\n elif self.name == 'bgs':\n label_header = 'label_lithogenesis'\n nodes_header = 'rock'\n\n labels_df = pd.read_csv(task_file, sep='\\t')\n labels_set = set(labels_df[label_header].values.tolist())\n labels_dict = {lab: i for i, lab in enumerate(list(labels_set))}\n\n train_labels_df = pd.read_csv(train_file, sep='\\t')\n train_indices, train_labels = [], []\n for nod, lab in zip(train_labels_df[nodes_header].values,\n train_labels_df[label_header].values):\n 
train_indices.append(nodes_dict[nod])\n train_labels.append(labels_dict[lab])\n\n train_idx = torch.tensor(train_indices, dtype=torch.long)\n train_y = torch.tensor(train_labels, dtype=torch.long)\n\n test_labels_df = pd.read_csv(test_file, sep='\\t')\n test_indices, test_labels = [], []\n for nod, lab in zip(test_labels_df[nodes_header].values,\n test_labels_df[label_header].values):\n test_indices.append(nodes_dict[nod])\n test_labels.append(labels_dict[lab])\n\n test_idx = torch.tensor(test_indices, dtype=torch.long)\n test_y = torch.tensor(test_labels, dtype=torch.long)\n\n data = Data(edge_index=edge_index, edge_type=edge_type,\n train_idx=train_idx, train_y=train_y, test_idx=test_idx,\n test_y=test_y, num_nodes=N)\n\n if self.hetero:\n data = data.to_heterogeneous(node_type_names=['v'])\n\n self.save([data], self.processed_paths[0])\n\n def __repr__(self) -> str:\n return f'{self.name.upper()}{self.__class__.__name__}()'\n\n\nclass hide_stdout:\n def __enter__(self) -> None:\n self.level = logging.getLogger().level\n logging.getLogger().setLevel(logging.ERROR)\n\n def __exit__(self, *args: Any) -> None:\n logging.getLogger().setLevel(self.level)\n\n# Path: torch_geometric/datasets/graph_generator/ba_graph.py\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import barabasi_albert_graph\n\n\nclass BAGraph(GraphGenerator):\n r\"\"\"Generates random Barabasi-Albert (BA) graphs.\n See :meth:`~torch_geometric.utils.barabasi_albert_graph` for more\n information.\n\n Args:\n num_nodes (int): The number of nodes.\n num_edges (int): The number of edges from a new node to existing nodes.\n \"\"\"\n def __init__(self, num_nodes: int, num_edges: int):\n super().__init__()\n self.num_nodes = num_nodes\n self.num_edges = num_edges\n\n def __call__(self) -> Data:\n edge_index = barabasi_albert_graph(self.num_nodes, self.num_edges)\n return Data(num_nodes=self.num_nodes, edge_index=edge_index)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '\n f'num_edges={self.num_edges})')\n\n# Path: torch_geometric/datasets/graph_generator/base.py\nfrom abc import ABC, abstractmethod\nfrom typing import Any\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.resolver import resolver\n\n\nclass GraphGenerator(ABC):\n r\"\"\"An abstract base class for generating synthetic graphs.\"\"\"\n @abstractmethod\n def __call__(self) -> Data:\n r\"\"\"To be implemented by :class:`GraphGenerator` subclasses.\"\"\"\n raise NotImplementedError\n\n @staticmethod\n def resolve(query: Any, *args: Any, **kwargs: Any) -> 'GraphGenerator':\n import torch_geometric.datasets.graph_generator as _graph_generators\n graph_generators = [\n gen for gen in vars(_graph_generators).values()\n if isinstance(gen, type) and issubclass(gen, GraphGenerator)\n ]\n return resolver(graph_generators, {}, query, GraphGenerator, 'Graph',\n *args, **kwargs)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}()'\n\n# Path: torch_geometric/datasets/graph_generator/er_graph.py\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import erdos_renyi_graph\n\n\nclass ERGraph(GraphGenerator):\n r\"\"\"Generates random Erdos-Renyi (ER) graphs.\n See :meth:`~torch_geometric.utils.erdos_renyi_graph` for more information.\n\n Args:\n num_nodes (int): The number of nodes.\n edge_prob (float): Probability of an edge.\n 
\"\"\"\n def __init__(self, num_nodes: int, edge_prob: float):\n super().__init__()\n self.num_nodes = num_nodes\n self.edge_prob = edge_prob\n\n def __call__(self) -> Data:\n edge_index = erdos_renyi_graph(self.num_nodes, self.edge_prob)\n return Data(num_nodes=self.num_nodes, edge_index=edge_index)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '\n f'edge_prob={self.edge_prob})')\n\n# Path: torch_geometric/datasets/graph_generator/grid_graph.py\nfrom typing import Optional\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import grid\n\n\nclass GridGraph(GraphGenerator):\n r\"\"\"Generates two-dimensional grid graphs.\n See :meth:`~torch_geometric.utils.grid` for more information.\n\n Args:\n height (int): The height of the grid.\n width (int): The width of the grid.\n dtype (:obj:`torch.dtype`, optional): The desired data type of the\n returned position tensor. (default: :obj:`None`)\n \"\"\"\n def __init__(\n self,\n height: int,\n width: int,\n dtype: Optional[torch.dtype] = None,\n ):\n super().__init__()\n self.height = height\n self.width = width\n self.dtype = dtype\n\n def __call__(self) -> Data:\n edge_index, pos = grid(height=self.height, width=self.width,\n dtype=self.dtype)\n return Data(edge_index=edge_index, pos=pos)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}(height={self.height}, '\n f'width={self.width})')\n\n# Path: torch_geometric/datasets/graph_generator/tree_graph.py\nfrom typing import List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import to_undirected\n\n\n\ndef tree(\n depth: int,\n branch: int = 2,\n undirected: bool = False,\n device: Optional[torch.device] = None,\n) -> Tuple[Tensor, Tensor]:\n \"\"\"Generates a tree graph with the given depth and branch size, along with\n node-level depth indicators.\n\n Args:\n depth (int): The depth of the tree.\n branch (int, optional): The branch size of the tree.\n (default: :obj:`2`)\n undirected (bool, optional): If set to :obj:`True`, the tree graph will\n be undirected. (default: :obj:`False`)\n device (torch.device, optional): The desired device of the returned\n tensors. (default: :obj:`None`)\n \"\"\"\n edges: List[Tuple[int, int]] = []\n depths: List[int] = [0]\n\n def add_edges(node: int, current_depth: int) -> None:\n node_count = len(depths)\n\n if current_depth < depth:\n for i in range(branch):\n edges.append((node, node_count + i))\n depths.append(current_depth + 1)\n\n for i in range(branch):\n add_edges(node=node_count + i, current_depth=current_depth + 1)\n\n add_edges(node=0, current_depth=0)\n\n edge_index = torch.tensor(edges, device=device).t().contiguous()\n if undirected:\n edge_index = to_undirected(edge_index, num_nodes=len(depths))\n\n return edge_index, torch.tensor(depths, device=device)\n\n\nclass TreeGraph(GraphGenerator):\n r\"\"\"Generates tree graphs.\n\n Args:\n depth (int): The depth of the tree.\n branch (int, optional): The branch size of the tree.\n (default: :obj:`2`)\n undirected (bool, optional): If set to :obj:`True`, the tree graph will\n be undirected. 
(default: :obj:`False`)\n \"\"\"\n def __init__(\n self,\n depth: int,\n branch: int = 2,\n undirected: bool = False,\n ) -> None:\n super().__init__()\n self.depth = depth\n self.branch = branch\n self.undirected = undirected\n\n def __call__(self) -> Data:\n edge_index, depth = tree(self.depth, self.branch, self.undirected)\n num_nodes = depth.numel()\n return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}(depth={self.depth}, '\n f'branch={self.branch}, undirected={self.undirected})')\n\n# Path: torch_geometric/datasets/graph_generator/__init__.py\nfrom .base import GraphGenerator\nfrom .ba_graph import BAGraph\nfrom .er_graph import ERGraph\nfrom .grid_graph import GridGraph\nfrom .tree_graph import TreeGraph\n\n__all__ = classes = [\n 'GraphGenerator',\n 'BAGraph',\n 'ERGraph',\n 'GridGraph',\n 'TreeGraph',\n]\n\n# Path: torch_geometric/datasets/motif_generator/base.py\nfrom abc import ABC, abstractmethod\nfrom typing import Any\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.resolver import resolver\n\n\nclass MotifGenerator(ABC):\n r\"\"\"An abstract base class for generating a motif.\"\"\"\n @abstractmethod\n def __call__(self) -> Data:\n r\"\"\"To be implemented by :class:`Motif` subclasses.\"\"\"\n pass\n\n @staticmethod\n def resolve(query: Any, *args: Any, **kwargs: Any) -> 'MotifGenerator':\n import torch_geometric.datasets.motif_generator as _motif_generators\n motif_generators = [\n gen for gen in vars(_motif_generators).values()\n if isinstance(gen, type) and issubclass(gen, MotifGenerator)\n ]\n return resolver(motif_generators, {}, query, MotifGenerator, 'Motif',\n *args, **kwargs)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}()'\n\n# Path: torch_geometric/datasets/motif_generator/custom.py\nfrom typing import Any, Optional\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import MotifGenerator\nfrom torch_geometric.utils import from_networkx\n\n\nclass CustomMotif(MotifGenerator):\n r\"\"\"Generates a motif based on a custom structure coming from a\n :class:`torch_geometric.data.Data` or :class:`networkx.Graph` object.\n\n Args:\n structure (torch_geometric.data.Data or networkx.Graph): The structure\n to use as a motif.\n \"\"\"\n def __init__(self, structure: Any):\n super().__init__()\n\n self.structure: Optional[Data] = None\n\n if isinstance(structure, Data):\n self.structure = structure\n else:\n try:\n import networkx as nx\n if isinstance(structure, nx.Graph):\n self.structure = from_networkx(structure)\n except ImportError:\n pass\n\n if self.structure is None:\n raise ValueError(f\"Expected a motif structure of type \"\n f\"'torch_geometric.data.Data' or 'networkx.Graph'\"\n f\"(got {type(structure)})\")\n\n def __call__(self) -> Data:\n assert isinstance(self.structure, Data)\n return self.structure\n\n# Path: torch_geometric/datasets/motif_generator/cycle.py\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\n\n\nclass CycleMotif(CustomMotif):\n r\"\"\"Generates the cycle motif from the `\"GNNExplainer:\n Generating Explanations for Graph Neural Networks\"\n `__ paper.\n\n Args:\n num_nodes (int): The number of nodes in the cycle.\n \"\"\"\n def __init__(self, num_nodes: int):\n self.num_nodes = num_nodes\n\n row = torch.arange(num_nodes).view(-1, 1).repeat(1, 2).view(-1)\n col1 = torch.arange(-1, num_nodes - 1) % num_nodes\n col2 = 
torch.arange(1, num_nodes + 1) % num_nodes\n col = torch.stack([col1, col2], dim=1).sort(dim=-1)[0].view(-1)\n\n structure = Data(\n num_nodes=num_nodes,\n edge_index=torch.stack([row, col], dim=0),\n )\n super().__init__(structure)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}({self.num_nodes})'\n\n# Path: torch_geometric/datasets/motif_generator/grid.py\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\n\n\nclass GridMotif(CustomMotif):\n r\"\"\"Generates the grid-structured motif from the\n `\"GNNExplainer: Generating Explanations for Graph Neural Networks\"\n `__ paper.\n \"\"\"\n def __init__(self) -> None:\n edge_indices = [\n [0, 1],\n [0, 3],\n [1, 4],\n [3, 4],\n [1, 2],\n [2, 5],\n [4, 5],\n [3, 6],\n [6, 7],\n [4, 7],\n [5, 8],\n [7, 8],\n [1, 0],\n [3, 0],\n [4, 1],\n [4, 3],\n [2, 1],\n [5, 2],\n [5, 4],\n [6, 3],\n [7, 6],\n [7, 4],\n [8, 5],\n [8, 7],\n ]\n structure = Data(\n num_nodes=9,\n edge_index=torch.tensor(edge_indices).t().contiguous(),\n y=torch.tensor([0, 1, 0, 1, 2, 1, 0, 1, 0]),\n )\n super().__init__(structure)\n\n# Path: torch_geometric/datasets/motif_generator/house.py\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\n\n\nclass HouseMotif(CustomMotif):\n r\"\"\"Generates the house-structured motif from the `\"GNNExplainer:\n Generating Explanations for Graph Neural Networks\"\n `__ paper, containing 5 nodes and 6\n undirected edges. Nodes are labeled according to their structural role:\n the top, middle and bottom of the house.\n \"\"\"\n def __init__(self) -> None:\n structure = Data(\n num_nodes=5,\n edge_index=torch.tensor([\n [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4],\n [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1],\n ]),\n y=torch.tensor([0, 0, 1, 1, 2]),\n )\n super().__init__(structure)\n\n# Path: torch_geometric/datasets/motif_generator/__init__.py\nfrom .base import MotifGenerator\nfrom .custom import CustomMotif\nfrom .house import HouseMotif\nfrom .cycle import CycleMotif\nfrom .grid import GridMotif\n\n__all__ = classes = [\n 'MotifGenerator',\n 'CustomMotif',\n 'HouseMotif',\n 'CycleMotif',\n 'GridMotif',\n]\n\n# Path: torch_geometric/datasets/explainer_dataset.py\nfrom typing import Any, Callable, Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import InMemoryDataset\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.datasets.motif_generator import MotifGenerator\nfrom torch_geometric.explain import Explanation\n\n\nclass ExplainerDataset(InMemoryDataset):\n r\"\"\"Generates a synthetic dataset for evaluating explainabilty algorithms,\n as described in the `\"GNNExplainer: Generating Explanations for Graph\n Neural Networks\" `__ paper.\n The :class:`~torch_geometric.datasets.ExplainerDataset` creates synthetic\n graphs coming from a\n :class:`~torch_geometric.datasets.graph_generator.GraphGenerator`, and\n randomly attaches :obj:`num_motifs` many motifs to it coming from a\n :class:`~torch_geometric.datasets.graph_generator.MotifGenerator`.\n Ground-truth node-level and edge-level explainabilty masks are given based\n on whether nodes and edges are part of a certain motif or not.\n\n For example, to generate a random Barabasi-Albert (BA) graph with 300\n nodes, in which we want to randomly attach 80 :obj:`\"house\"` motifs, write:\n\n .. 
code-block:: python\n\n from torch_geometric.datasets import ExplainerDataset\n from torch_geometric.datasets.graph_generator import BAGraph\n\n dataset = ExplainerDataset(\n graph_generator=BAGraph(num_nodes=300, num_edges=5),\n motif_generator='house',\n num_motifs=80,\n )\n\n .. note::\n\n For an example of using :class:`ExplainerDataset`, see\n `examples/explain/gnn_explainer_ba_shapes.py\n `_.\n\n Args:\n graph_generator (GraphGenerator or str): The graph generator to be\n used, *e.g.*,\n :class:`torch.geometric.datasets.graph_generator.BAGraph`\n (or any string that automatically resolves to it).\n motif_generator (MotifGenerator): The motif generator to be used,\n *e.g.*,\n :class:`torch_geometric.datasets.motif_generator.HouseMotif`\n (or any string that automatically resolves to it).\n num_motifs (int): The number of motifs to attach to the graph.\n num_graphs (int, optional): The number of graphs to generate.\n (default: :obj:`1`)\n graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to\n the respective graph generator module in case it gets automatically\n resolved. (default: :obj:`None`)\n motif_generator_kwargs (Dict[str, Any], optional): Arguments passed to\n the respective motif generator module in case it gets automatically\n resolved. (default: :obj:`None`)\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n \"\"\"\n def __init__(\n self,\n graph_generator: Union[GraphGenerator, str],\n motif_generator: Union[MotifGenerator, str],\n num_motifs: int,\n num_graphs: int = 1,\n graph_generator_kwargs: Optional[Dict[str, Any]] = None,\n motif_generator_kwargs: Optional[Dict[str, Any]] = None,\n transform: Optional[Callable] = None,\n ):\n super().__init__(root=None, transform=transform)\n\n if num_motifs <= 0:\n raise ValueError(f\"At least one motif needs to be attached to the \"\n f\"graph (got {num_motifs})\")\n\n self.graph_generator = GraphGenerator.resolve(\n graph_generator,\n **(graph_generator_kwargs or {}),\n )\n self.motif_generator = MotifGenerator.resolve(\n motif_generator,\n **(motif_generator_kwargs or {}),\n )\n self.num_motifs = num_motifs\n\n # TODO (matthias) support on-the-fly graph generation.\n data_list = [self.get_graph() for _ in range(num_graphs)]\n self.data, self.slices = self.collate(data_list)\n\n def get_graph(self) -> Explanation:\n data = self.graph_generator()\n assert data.num_nodes is not None\n assert data.edge_index is not None\n\n edge_indices = [data.edge_index]\n num_nodes = data.num_nodes\n node_masks = [torch.zeros(data.num_nodes)]\n edge_masks = [torch.zeros(data.num_edges)]\n ys = [torch.zeros(num_nodes, dtype=torch.long)]\n\n connecting_nodes = torch.randperm(num_nodes)[:self.num_motifs]\n for i in connecting_nodes.tolist():\n motif = self.motif_generator()\n assert motif.num_nodes is not None\n assert motif.edge_index is not None\n\n # Add motif to the graph.\n edge_indices.append(motif.edge_index + num_nodes)\n node_masks.append(torch.ones(motif.num_nodes))\n edge_masks.append(torch.ones(motif.num_edges))\n\n # Add random motif connection to the graph.\n j = int(torch.randint(0, motif.num_nodes, (1, ))) + num_nodes\n edge_indices.append(torch.tensor([[i, j], [j, i]]))\n edge_masks.append(torch.zeros(2))\n\n if isinstance(motif.y, Tensor):\n ys.append(motif.y + 1 if motif.y.min() == 0 else motif.y)\n else:\n 
ys.append(torch.ones(motif.num_nodes, dtype=torch.long))\n\n num_nodes += motif.num_nodes\n\n return Explanation(\n edge_index=torch.cat(edge_indices, dim=1),\n y=torch.cat(ys, dim=0),\n edge_mask=torch.cat(edge_masks, dim=0),\n node_mask=torch.cat(node_masks, dim=0),\n )\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({len(self)}, '\n f'graph_generator={self.graph_generator}, '\n f'motif_generator={self.motif_generator}, '\n f'num_motifs={self.num_motifs})')\n\n# Path: torch_geometric/datasets/facebook.py\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass FacebookPagePage(InMemoryDataset):\n r\"\"\"The Facebook Page-Page network dataset introduced in the\n `\"Multi-scale Attributed Node Embedding\"\n `_ paper.\n Nodes represent verified pages on Facebook and edges are mutual likes.\n It contains 22,470 nodes, 342,004 edges, 128 node features and 4 classes.\n\n Args:\n root (str): Root directory where the dataset should be saved.\n transform (callable, optional): A function/transform that takes in an\n :obj:`torch_geometric.data.Data` object and returns a transformed\n version. The data object will be transformed before every access.\n (default: :obj:`None`)\n pre_transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n being saved to disk. (default: :obj:`None`)\n force_reload (bool, optional): Whether to re-process the dataset.\n (default: :obj:`False`)\n \"\"\"\n\n url = 'https://graphmining.ai/datasets/ptg/facebook.npz'\n\n def __init__(\n self,\n root: str,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n force_reload: bool = False,\n ) -> None:\n super().__init__(root, transform, pre_transform,\n force_reload=force_reload)\n self.load(self.processed_paths[0])\n\n @property\n def raw_file_names(self) -> str:\n return 'facebook.npz'\n\n @property\n def processed_file_names(self) -> str:\n return 'data.pt'\n\n def download(self) -> None:\n download_url(self.url, self.raw_dir)\n\n def process(self) -> None:\n data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n x = torch.from_numpy(data['features']).to(torch.float)\n y = torch.from_numpy(data['target']).to(torch.long)\n edge_index = torch.from_numpy(data['edges']).to(torch.long)\n edge_index = edge_index.t().contiguous()\n\n data = Data(x=x, y=y, edge_index=edge_index)\n\n if self.pre_transform is not None:\n data = self.pre_transform(data)\n\n self.save([data], self.processed_paths[0])\n\n# Path: torch_geometric/datasets/fake.py\nimport random\nfrom collections import defaultdict\nfrom itertools import product\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData, InMemoryDataset\nfrom torch_geometric.utils import coalesce, remove_self_loops, to_undirected\n\n\nclass FakeDataset(InMemoryDataset):\n r\"\"\"A fake dataset that returns randomly generated\n :class:`~torch_geometric.data.Data` objects.\n\n Args:\n num_graphs (int, optional): The number of graphs. 
(default: :obj:`1`)\n avg_num_nodes (int, optional): The average number of nodes in a graph.\n (default: :obj:`1000`)\n avg_degree (float, optional): The average degree per node.\n (default: :obj:`10.0`)\n num_channels (int, optional): The number of node features.\n (default: :obj:`64`)\n edge_dim (int, optional): The number of edge features.\n (default: :obj:`0`)\n num_classes (int, optional): The number of classes in the dataset.\n (default: :obj:`10`)\n task (str, optional): Whether to return node-level or graph-level\n labels (:obj:`\"node\"`, :obj:`\"graph\"`, :obj:`\"auto\"`).\n If set to :obj:`\"auto\"`, will return graph-level labels if\n :obj:`num_graphs > 1`, and node-level labels other-wise.\n (default: :obj:`\"auto\"`)\n is_undirected (bool, optional): Whether the graphs to generate are\n undirected. (default: :obj:`True`)\n transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.Data` object and returns a\n transformed version. The data object will be transformed before\n every access. (default: :obj:`None`)\n **kwargs (optional): Additional attributes and their shapes\n *e.g.* :obj:`global_features=5`.\n \"\"\"\n def __init__(\n self,\n num_graphs: int = 1,\n avg_num_nodes: int = 1000,\n avg_degree: float = 10.0,\n num_channels: int = 64,\n edge_dim: int = 0,\n num_classes: int = 10,\n task: str = 'auto',\n is_undirected: bool = True,\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n **kwargs: Union[int, Tuple[int, ...]],\n ) -> None:\n super().__init__(None, transform)\n\n if task == 'auto':\n task = 'graph' if num_graphs > 1 else 'node'\n assert task in ['node', 'graph']\n\n self.avg_num_nodes = max(avg_num_nodes, int(avg_degree))\n self.avg_degree = max(avg_degree, 1)\n self.num_channels = num_channels\n self.edge_dim = edge_dim\n self._num_classes = num_classes\n self.task = task\n self.is_undirected = is_undirected\n self.kwargs = kwargs\n\n data_list = [self.generate_data() for _ in range(max(num_graphs, 1))]\n self.data, self.slices = self.collate(data_list)\n\n def generate_data(self) -> Data:\n num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree)\n\n data = Data()\n\n if self._num_classes > 0 and self.task == 'node':\n data.y = torch.randint(self._num_classes, (num_nodes, ))\n elif self._num_classes > 0 and self.task == 'graph':\n data.y = torch.tensor([random.randint(0, self._num_classes - 1)])\n\n data.edge_index = get_edge_index(num_nodes, num_nodes, self.avg_degree,\n self.is_undirected, remove_loops=True)\n\n if self.num_channels > 0:\n x = torch.randn(num_nodes, self.num_channels)\n if self._num_classes > 0 and self.task == 'node':\n assert isinstance(data.y, Tensor)\n x = x + data.y.unsqueeze(1)\n elif self._num_classes > 0 and self.task == 'graph':\n assert isinstance(data.y, Tensor)\n x = x + data.y\n data.x = x\n else:\n data.num_nodes = num_nodes\n\n if self.edge_dim > 1:\n data.edge_attr = torch.rand(data.num_edges, self.edge_dim)\n elif self.edge_dim == 1:\n data.edge_weight = torch.rand(data.num_edges)\n\n for feature_name, feature_shape in self.kwargs.items():\n setattr(data, feature_name, torch.randn(feature_shape))\n\n return data\n\n\nclass FakeHeteroDataset(InMemoryDataset):\n r\"\"\"A fake dataset that returns randomly generated\n :class:`~torch_geometric.data.HeteroData` objects.\n\n Args:\n num_graphs (int, optional): The number of graphs. 
(default: :obj:`1`)\n num_node_types (int, optional): The number of node types.\n (default: :obj:`3`)\n num_edge_types (int, optional): The number of edge types.\n (default: :obj:`6`)\n avg_num_nodes (int, optional): The average number of nodes in a graph.\n (default: :obj:`1000`)\n avg_degree (float, optional): The average degree per node.\n (default: :obj:`10.0`)\n avg_num_channels (int, optional): The average number of node features.\n (default: :obj:`64`)\n edge_dim (int, optional): The number of edge features.\n (default: :obj:`0`)\n num_classes (int, optional): The number of classes in the dataset.\n (default: :obj:`10`)\n task (str, optional): Whether to return node-level or graph-level\n labels (:obj:`\"node\"`, :obj:`\"graph\"`, :obj:`\"auto\"`).\n If set to :obj:`\"auto\"`, will return graph-level labels if\n :obj:`num_graphs > 1`, and node-level labels other-wise.\n (default: :obj:`\"auto\"`)\n transform (callable, optional): A function/transform that takes in\n an :obj:`torch_geometric.data.HeteroData` object and returns a\n transformed version. The data object will be transformed before\n every access. (default: :obj:`None`)\n **kwargs (optional): Additional attributes and their shapes\n *e.g.* :obj:`global_features=5`.\n \"\"\"\n def __init__(\n self,\n num_graphs: int = 1,\n num_node_types: int = 3,\n num_edge_types: int = 6,\n avg_num_nodes: int = 1000,\n avg_degree: float = 10.0,\n avg_num_channels: int = 64,\n edge_dim: int = 0,\n num_classes: int = 10,\n task: str = \"auto\",\n transform: Optional[Callable] = None,\n pre_transform: Optional[Callable] = None,\n **kwargs: Union[int, Tuple[int, ...]],\n ) -> None:\n super().__init__(None, transform)\n\n if task == 'auto':\n task = 'graph' if num_graphs > 1 else 'node'\n assert task in ['node', 'graph']\n\n...\n", "needle_token_start": 8646, "needle_token_end": 9079, "code_context_ntokens": 16120, "output": ["```python\nfrom typing import List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import to_undirected\n\n\ndef tree(\n depth: int,\n branch: int = 2,\n undirected: bool = False,\n device: Optional[torch.device] = None,\n) -> Tuple[Tensor, Tensor]:\n \"\"\"Generates a tree graph with the given depth and branch size, along with\n node-level depth indicators.\n\n Args:\n depth (int): The depth of the tree.\n branch (int, optional): The branch size of the tree.\n (default: :obj:`2`)\n undirected (bool, optional): If set to :obj:`True`, the tree graph will\n be undirected. (default: :obj:`False`)\n device (torch.device, optional): The desired device of the returned\n tensors. 
(default: :obj:`None`)\n \"\"\"\n edges: List[Tuple[int, int]] = []\n depths: List[int] = [0]\n\n def add_edges(node: int, current_depth: int) -> None:\n node_count = len(depths)\n\n if current_depth < depth:\n for i in range(branch):\n edges.append((node, node_count + i))\n depths.append(current_depth + 1)\n\n for i in range(branch):\n add_edges(node=node_count + i, current_depth=current_depth + 1)\n\n add_edges(node=0, current_depth=0)\n\n edge_index = torch.tensor(edges, device=device).t().contiguous()\n if undirected:\n edge_index = to_undirected(edge_index, num_nodes=len(depths))\n\n return edge_index, torch.tensor(depths, device=device)\n\n\nclass TreeGraph(GraphGenerator):\n r\"\"\"Generates tree graphs.\n\n Args:\n depth (int): The depth of the tree.\n branch (int, optional): The branch size of the tree.\n (default: :obj:`2`)\n undirected (bool, optional): If set to :obj:`True`, the tree graph will\n be undirected. (default: :obj:`False`)\n \"\"\"\n def __init__(\n self,\n depth: int,\n branch: int = 2,\n undirected: bool = False,\n ) -> None:\n super().__init__()\n self.depth = depth\n self.branch = branch\n self.undirected = undirected\n\n def __call__(self) -> Data:\n edge_index, depth = tree(self.depth, self.branch, self.undirected)\n num_nodes = depth.numel()\n return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes)\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}(depth={self.depth}, '\n f'branch={self.branch}, undirected={self.undirected})')\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "to_scipy_sparse_matrix", "language": "python", "path": "torch_geometric/utils/convert.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert a graph representation from edge indices and optional edge attributes into a sparse matrix format, which is commonly used for efficient mathematical and scientific computations.\n2. **Input**: The function takes three parameters: a tensor of edge indices indicating the connections between nodes, an optional tensor of edge attributes that can represent weights or multi-dimensional features of the edges, and an optional integer specifying the total number of nodes in the graph.\n3. **Output**: It produces a sparse matrix in COOrdinate format from the scipy library, representing the graph where rows and columns correspond to nodes and non-zero entries represent edges, optionally weighted by the edge attributes.\n4. **Procedure**: The function first extracts the rows and columns from the edge indices tensor. If no edge attributes are provided, it assigns a weight of one to all edges. It then determines the number of nodes, either from the input or by calculating the maximum index in the edge indices. 
Finally, it constructs and returns the sparse matrix using the node indices and edge attributes (or default weights).\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/utils/_subgraph.py\nfrom typing import List, Literal, Optional, Tuple, Union, overload\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor, PairTensor\nfrom torch_geometric.utils import scatter\nfrom torch_geometric.utils.map import map_index\nfrom torch_geometric.utils.mask import index_to_mask\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef get_num_hops(model: torch.nn.Module) -> int:\n r\"\"\"Returns the number of hops the model is aggregating information\n from.\n\n .. note::\n\n This function counts the number of message passing layers as an\n approximation of the total number of hops covered by the model.\n Its output may not necessarily be correct in case message passing\n layers perform multi-hop aggregation, *e.g.*, as in\n :class:`~torch_geometric.nn.conv.ChebConv`.\n\n Example:\n >>> class GNN(torch.nn.Module):\n ... def __init__(self):\n ... super().__init__()\n ... self.conv1 = GCNConv(3, 16)\n ... self.conv2 = GCNConv(16, 16)\n ... self.lin = Linear(16, 2)\n ...\n ... def forward(self, x, edge_index):\n ... x = self.conv1(x, edge_index).relu()\n ... x = self.conv2(x, edge_index).relu()\n ... return self.lin(x)\n >>> get_num_hops(GNN())\n 2\n \"\"\"\n from torch_geometric.nn.conv import MessagePassing\n num_hops = 0\n for module in model.modules():\n if isinstance(module, MessagePassing):\n num_hops += 1\n return num_hops\n\n\n@overload\ndef subgraph(\n subset: Union[Tensor, List[int]],\n edge_index: Tensor,\n edge_attr: OptTensor = ...,\n relabel_nodes: bool = ...,\n num_nodes: Optional[int] = ...,\n) -> Tuple[Tensor, OptTensor]:\n pass\n\n\n@overload\ndef subgraph(\n subset: Union[Tensor, List[int]],\n edge_index: Tensor,\n edge_attr: OptTensor = ...,\n relabel_nodes: bool = ...,\n num_nodes: Optional[int] = ...,\n *,\n return_edge_mask: Literal[False],\n) -> Tuple[Tensor, OptTensor]:\n pass\n\n\n@overload\ndef subgraph(\n subset: Union[Tensor, List[int]],\n edge_index: Tensor,\n edge_attr: OptTensor = ...,\n relabel_nodes: bool = ...,\n num_nodes: Optional[int] = ...,\n *,\n return_edge_mask: Literal[True],\n) -> Tuple[Tensor, OptTensor, Tensor]:\n pass\n\n\ndef subgraph(\n subset: Union[Tensor, List[int]],\n edge_index: Tensor,\n edge_attr: OptTensor = None,\n relabel_nodes: bool = False,\n num_nodes: Optional[int] = None,\n *,\n return_edge_mask: bool = False,\n) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]:\n r\"\"\"Returns the induced subgraph of :obj:`(edge_index, edge_attr)`\n containing the nodes in :obj:`subset`.\n\n Args:\n subset (LongTensor, BoolTensor or [int]): The nodes to keep.\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional\n edge features. (default: :obj:`None`)\n relabel_nodes (bool, optional): If set to :obj:`True`, the resulting\n :obj:`edge_index` will be relabeled to hold consecutive indices\n starting from zero. (default: :obj:`False`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max(edge_index) + 1`. 
(default: :obj:`None`)\n return_edge_mask (bool, optional): If set to :obj:`True`, will return\n the edge mask to filter out additional edge features.\n (default: :obj:`False`)\n\n :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],\n ... [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]])\n >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])\n >>> subset = torch.tensor([3, 4, 5])\n >>> subgraph(subset, edge_index, edge_attr)\n...\n# Path: torch_geometric/utils/_to_dense_batch.py\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import (\n disable_dynamic_shapes,\n is_experimental_mode_enabled,\n)\nfrom torch_geometric.utils import cumsum, scatter\n\n\n@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])\ndef to_dense_batch(\n x: Tensor,\n batch: Optional[Tensor] = None,\n fill_value: float = 0.0,\n max_num_nodes: Optional[int] = None,\n batch_size: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n r\"\"\"Given a sparse batch of node features\n :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}` (with\n :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a\n dense node feature tensor\n :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N_{\\max} \\times F}` (with\n :math:`N_{\\max} = \\max_i^B N_i`).\n In addition, a mask of shape :math:`\\mathbf{M} \\in \\{ 0, 1 \\}^{B \\times\n N_{\\max}}` is returned, holding information about the existence of\n fake-nodes in the dense representation.\n\n Args:\n x (Tensor): Node feature matrix\n :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}`.\n batch (LongTensor, optional): Batch vector\n :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n node to a specific example. Must be ordered. (default: :obj:`None`)\n fill_value (float, optional): The value for invalid entries in the\n resulting dense output tensor. (default: :obj:`0`)\n max_num_nodes (int, optional): The size of the output node dimension.\n (default: :obj:`None`)\n batch_size (int, optional): The batch size. 
(default: :obj:`None`)\n\n :rtype: (:class:`Tensor`, :class:`BoolTensor`)\n\n Examples:\n >>> x = torch.arange(12).view(6, 2)\n >>> x\n tensor([[ 0, 1],\n [ 2, 3],\n [ 4, 5],\n [ 6, 7],\n [ 8, 9],\n [10, 11]])\n\n >>> out, mask = to_dense_batch(x)\n >>> mask\n tensor([[True, True, True, True, True, True]])\n\n >>> batch = torch.tensor([0, 0, 1, 2, 2, 2])\n >>> out, mask = to_dense_batch(x, batch)\n >>> out\n tensor([[[ 0, 1],\n [ 2, 3],\n [ 0, 0]],\n [[ 4, 5],\n [ 0, 0],\n [ 0, 0]],\n [[ 6, 7],\n [ 8, 9],\n [10, 11]]])\n >>> mask\n tensor([[ True, True, False],\n [ True, False, False],\n [ True, True, True]])\n\n >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)\n >>> out\n tensor([[[ 0, 1],\n [ 2, 3],\n [ 0, 0],\n [ 0, 0]],\n [[ 4, 5],\n [ 0, 0],\n [ 0, 0],\n [ 0, 0]],\n [[ 6, 7],\n [ 8, 9],\n [10, 11],\n [ 0, 0]]])\n\n >>> mask\n tensor([[ True, True, False, False],\n [ True, False, False, False],\n [ True, True, True, False]])\n \"\"\"\n if batch is None and max_num_nodes is None:\n mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)\n return x.unsqueeze(0), mask\n\n if batch is None:\n batch = x.new_zeros(x.size(0), dtype=torch.long)\n\n if batch_size is None:\n batch_size = int(batch.max()) + 1\n\n num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,\n dim_size=batch_size, reduce='sum')\n cum_nodes = cumsum(num_nodes)\n\n filter_nodes = False\n dynamic_shapes_disabled = is_experimental_mode_enabled(\n 'disable_dynamic_shapes')\n\n if max_num_nodes is None:\n max_num_nodes = int(num_nodes.max())\n elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:\n filter_nodes = True\n\n tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]\n idx = tmp + (batch * max_num_nodes)\n if filter_nodes:\n mask = tmp < max_num_nodes\n x, idx = x[mask], idx[mask]\n\n size = [batch_size * max_num_nodes] + list(x.size())[1:]\n out = torch.as_tensor(fill_value, device=x.device)\n out = out.to(x.dtype).repeat(size)\n out[idx] = x\n out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])\n\n mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,\n device=x.device)\n mask[idx] = 1\n mask = mask.view(batch_size, max_num_nodes)\n\n return out, mask\n\n# Path: torch_geometric/utils/_train_test_split_edges.py\nimport math\n\nimport torch\n\nimport torch_geometric\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.utils import to_undirected\n\n\n@deprecated(\"use 'transforms.RandomLinkSplit' instead\")\ndef train_test_split_edges(\n data: 'torch_geometric.data.Data',\n val_ratio: float = 0.05,\n test_ratio: float = 0.1,\n) -> 'torch_geometric.data.Data':\n r\"\"\"Splits the edges of a :class:`torch_geometric.data.Data` object\n into positive and negative train/val/test edges.\n As such, it will replace the :obj:`edge_index` attribute with\n :obj:`train_pos_edge_index`, :obj:`train_pos_neg_adj_mask`,\n :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index` and\n :obj:`test_pos_edge_index` attributes.\n If :obj:`data` has edge features named :obj:`edge_attr`, then\n :obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and\n :obj:`test_pos_edge_attr` will be added as well.\n\n .. 
warning::\n\n :meth:`~torch_geometric.utils.train_test_split_edges` is deprecated and\n will be removed in a future release.\n Use :class:`torch_geometric.transforms.RandomLinkSplit` instead.\n\n Args:\n data (Data): The data object.\n val_ratio (float, optional): The ratio of positive validation edges.\n (default: :obj:`0.05`)\n test_ratio (float, optional): The ratio of positive test edges.\n (default: :obj:`0.1`)\n\n :rtype: :class:`torch_geometric.data.Data`\n \"\"\"\n assert 'batch' not in data # No batch-mode.\n\n assert data.num_nodes is not None\n assert data.edge_index is not None\n\n num_nodes = data.num_nodes\n row, col = data.edge_index\n edge_attr = data.edge_attr\n del data.edge_index\n del data.edge_attr\n\n # Return upper triangular portion.\n mask = row < col\n row, col = row[mask], col[mask]\n\n if edge_attr is not None:\n edge_attr = edge_attr[mask]\n\n n_v = int(math.floor(val_ratio * row.size(0)))\n n_t = int(math.floor(test_ratio * row.size(0)))\n\n # Positive edges.\n perm = torch.randperm(row.size(0))\n row, col = row[perm], col[perm]\n if edge_attr is not None:\n edge_attr = edge_attr[perm]\n\n r, c = row[:n_v], col[:n_v]\n data.val_pos_edge_index = torch.stack([r, c], dim=0)\n if edge_attr is not None:\n data.val_pos_edge_attr = edge_attr[:n_v]\n\n r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]\n data.test_pos_edge_index = torch.stack([r, c], dim=0)\n if edge_attr is not None:\n data.test_pos_edge_attr = edge_attr[n_v:n_v + n_t]\n\n r, c = row[n_v + n_t:], col[n_v + n_t:]\n data.train_pos_edge_index = torch.stack([r, c], dim=0)\n if edge_attr is not None:\n out = to_undirected(data.train_pos_edge_index, edge_attr[n_v + n_t:])\n data.train_pos_edge_index, data.train_pos_edge_attr = out\n else:\n data.train_pos_edge_index = to_undirected(data.train_pos_edge_index)\n\n # Negative edges.\n neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)\n neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)\n neg_adj_mask[row, col] = 0\n\n neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()\n perm = torch.randperm(neg_row.size(0))[:n_v + n_t]\n neg_row, neg_col = neg_row[perm], neg_col[perm]\n\n neg_adj_mask[neg_row, neg_col] = 0\n data.train_neg_adj_mask = neg_adj_mask\n\n row, col = neg_row[:n_v], neg_col[:n_v]\n data.val_neg_edge_index = torch.stack([row, col], dim=0)\n\n row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]\n data.test_neg_edge_index = torch.stack([row, col], dim=0)\n\n return data\n\n# Path: torch_geometric/utils/_tree_decomposition.py\nfrom itertools import chain\nfrom typing import Any, List, Literal, Tuple, Union, overload\n\nimport torch\nfrom scipy.sparse.csgraph import minimum_spanning_tree\nfrom torch import Tensor\n\nfrom torch_geometric.utils import (\n from_scipy_sparse_matrix,\n to_scipy_sparse_matrix,\n to_undirected,\n)\n\n\n@overload\ndef tree_decomposition(mol: Any) -> Tuple[Tensor, Tensor, int]:\n pass\n\n\n@overload\ndef tree_decomposition(\n mol: Any,\n return_vocab: Literal[False],\n) -> Tuple[Tensor, Tensor, int]:\n pass\n\n\n@overload\ndef tree_decomposition(\n mol: Any,\n return_vocab: Literal[True],\n) -> Tuple[Tensor, Tensor, int, Tensor]:\n pass\n\n\ndef tree_decomposition(\n mol: Any,\n return_vocab: bool = False,\n) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]]:\n r\"\"\"The tree decomposition algorithm of molecules from the\n `\"Junction Tree Variational Autoencoder for Molecular Graph Generation\"\n `_ paper.\n Returns the graph connectivity of the junction tree, the 
assignment\n mapping of each atom to the clique in the junction tree, and the number\n of cliques.\n\n Args:\n mol (rdkit.Chem.Mol): An :obj:`rdkit` molecule.\n return_vocab (bool, optional): If set to :obj:`True`, will return an\n identifier for each clique (ring, bond, bridged compounds, single).\n (default: :obj:`False`)\n\n :rtype: :obj:`(LongTensor, LongTensor, int)` if :obj:`return_vocab` is\n :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)`\n \"\"\"\n import rdkit.Chem as Chem\n\n # Cliques = rings and bonds.\n cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)]\n xs: List[int] = [0] * len(cliques)\n for bond in mol.GetBonds():\n if not bond.IsInRing():\n cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])\n xs.append(1)\n\n # Generate `atom2cliques` mappings.\n atom2cliques: List[List[int]] = [[] for i in range(mol.GetNumAtoms())]\n for c in range(len(cliques)):\n for atom in cliques[c]:\n atom2cliques[atom].append(c)\n\n # Merge rings that share more than 2 atoms as they form bridged compounds.\n for c1 in range(len(cliques)):\n for atom in cliques[c1]:\n for c2 in atom2cliques[atom]:\n if c1 >= c2 or len(cliques[c1]) <= 2 or len(cliques[c2]) <= 2:\n continue\n if len(set(cliques[c1]) & set(cliques[c2])) > 2:\n cliques[c1] = list(set(cliques[c1]) | set(cliques[c2]))\n xs[c1] = 2\n cliques[c2] = []\n xs[c2] = -1\n cliques = [c for c in cliques if len(c) > 0]\n xs = [x for x in xs if x >= 0]\n\n # Update `atom2cliques` mappings.\n atom2cliques = [[] for i in range(mol.GetNumAtoms())]\n for c in range(len(cliques)):\n for atom in cliques[c]:\n atom2cliques[atom].append(c)\n\n # Add singleton cliques in case there are more than 2 intersecting\n # cliques. We further compute the \"initial\" clique graph.\n edges = {}\n for atom in range(mol.GetNumAtoms()):\n cs = atom2cliques[atom]\n if len(cs) <= 1:\n continue\n\n # Number of bond clusters that the atom lies in.\n bonds = [c for c in cs if len(cliques[c]) == 2]\n # Number of ring clusters that the atom lies in.\n rings = [c for c in cs if len(cliques[c]) > 4]\n\n if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2):\n cliques.append([atom])\n xs.append(3)\n c2 = len(cliques) - 1\n for c1 in cs:\n edges[(c1, c2)] = 1\n\n elif len(rings) > 2:\n cliques.append([atom])\n xs.append(3)\n c2 = len(cliques) - 1\n for c1 in cs:\n edges[(c1, c2)] = 99\n\n else:\n for i in range(len(cs)):\n for j in range(i + 1, len(cs)):\n c1, c2 = cs[i], cs[j]\n count = len(set(cliques[c1]) & set(cliques[c2]))\n edges[(c1, c2)] = min(count, edges.get((c1, c2), 99))\n\n # Update `atom2cliques` mappings.\n atom2cliques = [[] for i in range(mol.GetNumAtoms())]\n for c in range(len(cliques)):\n for atom in cliques[c]:\n atom2cliques[atom].append(c)\n\n if len(edges) > 0:\n edge_index_T, weight = zip(*edges.items())\n edge_index = torch.tensor(edge_index_T).t()\n inv_weight = 100 - torch.tensor(weight)\n graph = to_scipy_sparse_matrix(edge_index, inv_weight, len(cliques))\n junc_tree = minimum_spanning_tree(graph)\n edge_index, _ = from_scipy_sparse_matrix(junc_tree)\n edge_index = to_undirected(edge_index, num_nodes=len(cliques))\n else:\n edge_index = torch.empty((2, 0), dtype=torch.long)\n\n rows = [[i] * len(atom2cliques[i]) for i in range(mol.GetNumAtoms())]\n row = torch.tensor(list(chain.from_iterable(rows)))\n col = torch.tensor(list(chain.from_iterable(atom2cliques)))\n atom2clique = torch.stack([row, col], dim=0).to(torch.long)\n\n if return_vocab:\n vocab = torch.tensor(xs, dtype=torch.long)\n return 
edge_index, atom2clique, len(cliques), vocab\n else:\n return edge_index, atom2clique, len(cliques)\n\n# Path: torch_geometric/utils/_unbatch.py\nfrom typing import List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import cumsum, degree\n\n\ndef unbatch(\n src: Tensor,\n batch: Tensor,\n dim: int = 0,\n batch_size: Optional[int] = None,\n) -> List[Tensor]:\n r\"\"\"Splits :obj:`src` according to a :obj:`batch` vector along dimension\n :obj:`dim`.\n\n Args:\n src (Tensor): The source tensor.\n batch (LongTensor): The batch vector\n :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n entry in :obj:`src` to a specific example. Must be ordered.\n dim (int, optional): The dimension along which to split the :obj:`src`\n tensor. (default: :obj:`0`)\n batch_size (int, optional): The batch size. (default: :obj:`None`)\n\n :rtype: :class:`List[Tensor]`\n\n Example:\n >>> src = torch.arange(7)\n >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])\n >>> unbatch(src, batch)\n (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))\n \"\"\"\n sizes = degree(batch, batch_size, dtype=torch.long).tolist()\n return src.split(sizes, dim)\n\n\ndef unbatch_edge_index(\n edge_index: Tensor,\n batch: Tensor,\n batch_size: Optional[int] = None,\n) -> List[Tensor]:\n r\"\"\"Splits the :obj:`edge_index` according to a :obj:`batch` vector.\n\n Args:\n edge_index (Tensor): The edge_index tensor. Must be ordered.\n batch (LongTensor): The batch vector\n :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n node to a specific example. Must be ordered.\n batch_size (int, optional): The batch size. (default: :obj:`None`)\n\n :rtype: :class:`List[Tensor]`\n\n Example:\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6],\n ... [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]])\n >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])\n >>> unbatch_edge_index(edge_index, batch)\n (tensor([[0, 1, 1, 2, 2, 3],\n [1, 0, 2, 1, 3, 2]]),\n tensor([[0, 1, 1, 2],\n [1, 0, 2, 1]]))\n \"\"\"\n deg = degree(batch, batch_size, dtype=torch.long)\n ptr = cumsum(deg)\n\n edge_batch = batch[edge_index[0]]\n edge_index = edge_index - ptr[edge_batch]\n sizes = degree(edge_batch, batch_size, dtype=torch.long).cpu().tolist()\n return edge_index.split(sizes, dim=1)\n\n# Path: torch_geometric/utils/augmentation.py\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import cumsum, negative_sampling, scatter\n\n\ndef shuffle_node(\n x: Tensor,\n batch: Optional[Tensor] = None,\n training: bool = True,\n) -> Tuple[Tensor, Tensor]:\n r\"\"\"Randomly shuffle the feature matrix :obj:`x` along the\n first dimmension.\n\n The method returns (1) the shuffled :obj:`x`, (2) the permutation\n indicating the orders of original nodes after shuffling.\n\n Args:\n x (FloatTensor): The feature matrix.\n batch (LongTensor, optional): Batch vector\n :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n node to a specific example. Must be ordered. (default: :obj:`None`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n\n :rtype: (:class:`FloatTensor`, :class:`LongTensor`)\n\n Example:\n >>> # Standard case\n >>> x = torch.tensor([[0, 1, 2],\n ... [3, 4, 5],\n ... [6, 7, 8],\n ... 
[9, 10, 11]], dtype=torch.float)\n >>> x, node_perm = shuffle_node(x)\n >>> x\n tensor([[ 3., 4., 5.],\n [ 9., 10., 11.],\n [ 0., 1., 2.],\n [ 6., 7., 8.]])\n >>> node_perm\n tensor([1, 3, 0, 2])\n\n >>> # For batched graphs as inputs\n >>> batch = torch.tensor([0, 0, 1, 1])\n >>> x, node_perm = shuffle_node(x, batch)\n >>> x\n tensor([[ 3., 4., 5.],\n [ 0., 1., 2.],\n [ 9., 10., 11.],\n [ 6., 7., 8.]])\n >>> node_perm\n tensor([1, 0, 3, 2])\n \"\"\"\n if not training:\n perm = torch.arange(x.size(0), device=x.device)\n return x, perm\n if batch is None:\n perm = torch.randperm(x.size(0), device=x.device)\n return x[perm], perm\n num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, reduce='sum')\n ptr = cumsum(num_nodes)\n perm = torch.cat([\n torch.randperm(n, device=x.device) + offset\n for offset, n in zip(ptr[:-1], num_nodes)\n ])\n return x[perm], perm\n\n\ndef mask_feature(\n x: Tensor,\n p: float = 0.5,\n mode: str = 'col',\n fill_value: float = 0.,\n training: bool = True,\n) -> Tuple[Tensor, Tensor]:\n r\"\"\"Randomly masks feature from the feature matrix\n :obj:`x` with probability :obj:`p` using samples from\n a Bernoulli distribution.\n\n The method returns (1) the retained :obj:`x`, (2) the feature\n mask broadcastable with :obj:`x` (:obj:`mode='row'` and :obj:`mode='col'`)\n or with the same shape as :obj:`x` (:obj:`mode='all'`),\n indicating where features are retained.\n\n Args:\n x (FloatTensor): The feature matrix.\n p (float, optional): The masking ratio. (default: :obj:`0.5`)\n mode (str, optional): The masked scheme to use for feature masking.\n (:obj:`\"row\"`, :obj:`\"col\"` or :obj:`\"all\"`).\n If :obj:`mode='col'`, will mask entire features of all nodes\n from the feature matrix. If :obj:`mode='row'`, will mask entire\n nodes from the feature matrix. If :obj:`mode='all'`, will mask\n individual features across all nodes. (default: :obj:`'col'`)\n fill_value (float, optional): The value for masked features in the\n output tensor. (default: :obj:`0`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n\n :rtype: (:class:`FloatTensor`, :class:`BoolTensor`)\n\n Examples:\n >>> # Masked features are column-wise sampled\n >>> x = torch.tensor([[1, 2, 3],\n ... [4, 5, 6],\n ... [7, 8, 9]], dtype=torch.float)\n >>> x, feat_mask = mask_feature(x)\n >>> x\n tensor([[1., 0., 3.],\n [4., 0., 6.],\n [7., 0., 9.]]),\n >>> feat_mask\n tensor([[True, False, True]])\n\n >>> # Masked features are row-wise sampled\n >>> x, feat_mask = mask_feature(x, mode='row')\n >>> x\n tensor([[1., 2., 3.],\n [0., 0., 0.],\n [7., 8., 9.]]),\n >>> feat_mask\n tensor([[True], [False], [True]])\n\n >>> # Masked features are uniformly sampled\n >>> x, feat_mask = mask_feature(x, mode='all')\n >>> x\n tensor([[0., 0., 0.],\n [4., 0., 6.],\n [0., 0., 9.]])\n >>> feat_mask\n tensor([[False, False, False],\n [True, False, True],\n [False, False, True]])\n \"\"\"\n if p < 0. 
or p > 1.:\n raise ValueError(f'Masking ratio has to be between 0 and 1 '\n f'(got {p}')\n if not training or p == 0.0:\n return x, torch.ones_like(x, dtype=torch.bool)\n assert mode in ['row', 'col', 'all']\n\n if mode == 'row':\n mask = torch.rand(x.size(0), device=x.device) >= p\n mask = mask.view(-1, 1)\n elif mode == 'col':\n mask = torch.rand(x.size(1), device=x.device) >= p\n mask = mask.view(1, -1)\n else:\n mask = torch.rand_like(x) >= p\n\n x = x.masked_fill(~mask, fill_value)\n return x, mask\n\n\ndef add_random_edge(\n edge_index: Tensor,\n p: float = 0.5,\n force_undirected: bool = False,\n num_nodes: Optional[Union[int, Tuple[int, int]]] = None,\n training: bool = True,\n) -> Tuple[Tensor, Tensor]:\n r\"\"\"Randomly adds edges to :obj:`edge_index`.\n\n The method returns (1) the retained :obj:`edge_index`, (2) the added\n edge indices.\n\n Args:\n edge_index (LongTensor): The edge indices.\n p (float): Ratio of added edges to the existing edges.\n (default: :obj:`0.5`)\n force_undirected (bool, optional): If set to :obj:`True`,\n added edges will be undirected.\n (default: :obj:`False`)\n num_nodes (int, Tuple[int], optional): The overall number of nodes,\n *i.e.* :obj:`max_val + 1`, or the number of source and\n destination nodes, *i.e.* :obj:`(max_src_val + 1, max_dst_val + 1)`\n of :attr:`edge_index`. (default: :obj:`None`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n\n :rtype: (:class:`LongTensor`, :class:`LongTensor`)\n\n Examples:\n >>> # Standard case\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2]])\n >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5)\n >>> edge_index\n tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3],\n [1, 0, 2, 1, 3, 2, 0, 2, 1]])\n >>> added_edges\n tensor([[2, 1, 3],\n [0, 2, 1]])\n\n >>> # The returned graph is kept undirected\n >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5,\n ... force_undirected=True)\n >>> edge_index\n tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3, 0, 2, 1],\n [1, 0, 2, 1, 3, 2, 0, 2, 1, 2, 1, 3]])\n >>> added_edges\n tensor([[2, 1, 3, 0, 2, 1],\n [0, 2, 1, 2, 1, 3]])\n\n >>> # For bipartite graphs\n >>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],\n ... [2, 3, 1, 4, 2, 1]])\n >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5,\n ... num_nodes=(6, 5))\n >>> edge_index\n tensor([[0, 1, 2, 3, 4, 5, 3, 4, 1],\n [2, 3, 1, 4, 2, 1, 1, 3, 2]])\n >>> added_edges\n tensor([[3, 4, 1],\n [1, 3, 2]])\n \"\"\"\n if p < 0. 
or p > 1.:\n raise ValueError(f\"Ratio of added edges has to be between 0 and 1 \"\n f\"(got '{p}')\")\n if force_undirected and isinstance(num_nodes, (tuple, list)):\n raise RuntimeError(\"'force_undirected' is not supported for \"\n \"bipartite graphs\")\n\n device = edge_index.device\n if not training or p == 0.0:\n edge_index_to_add = torch.tensor([[], []], device=device)\n return edge_index, edge_index_to_add\n\n edge_index_to_add = negative_sampling(\n edge_index=edge_index,\n num_nodes=num_nodes,\n num_neg_samples=round(edge_index.size(1) * p),\n force_undirected=force_undirected,\n )\n\n edge_index = torch.cat([edge_index, edge_index_to_add], dim=1)\n\n return edge_index, edge_index_to_add\n\n# Path: torch_geometric/utils/convert.py\nfrom collections import defaultdict\nfrom typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union\n\nimport scipy.sparse\nimport torch\nfrom torch import Tensor\nfrom torch.utils.dlpack import from_dlpack, to_dlpack\n\nimport torch_geometric\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\n\ndef to_scipy_sparse_matrix(\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> scipy.sparse.coo_matrix:\n r\"\"\"Converts a graph given by edge indices and edge attributes to a scipy\n sparse matrix.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional\n edge features. (default: :obj:`None`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)\n\n Examples:\n >>> edge_index = torch.tensor([\n ... [0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2],\n ... ])\n >>> to_scipy_sparse_matrix(edge_index)\n <4x4 sparse matrix of type ''\n with 6 stored elements in COOrdinate format>\n \"\"\"\n row, col = edge_index.cpu()\n\n if edge_attr is None:\n edge_attr = torch.ones(row.size(0))\n else:\n edge_attr = edge_attr.view(-1).cpu()\n assert edge_attr.size(0) == row.size(0)\n\n N = maybe_num_nodes(edge_index, num_nodes)\n out = scipy.sparse.coo_matrix(\n (edge_attr.numpy(), (row.numpy(), col.numpy())), (N, N))\n return out\n\n\ndef from_scipy_sparse_matrix(\n A: scipy.sparse.spmatrix) -> Tuple[Tensor, Tensor]:\n r\"\"\"Converts a scipy sparse matrix to edge indices and edge attributes.\n\n Args:\n A (scipy.sparse): A sparse matrix.\n\n Examples:\n >>> edge_index = torch.tensor([\n ... [0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2],\n ... 
])\n >>> adj = to_scipy_sparse_matrix(edge_index)\n >>> # `edge_index` and `edge_weight` are both returned\n >>> from_scipy_sparse_matrix(adj)\n (tensor([[0, 1, 1, 2, 2, 3],\n [1, 0, 2, 1, 3, 2]]),\n tensor([1., 1., 1., 1., 1., 1.]))\n \"\"\"\n A = A.tocoo()\n row = torch.from_numpy(A.row).to(torch.long)\n col = torch.from_numpy(A.col).to(torch.long)\n edge_index = torch.stack([row, col], dim=0)\n edge_weight = torch.from_numpy(A.data)\n return edge_index, edge_weight\n\n\ndef to_networkx(\n data: Union[\n 'torch_geometric.data.Data',\n 'torch_geometric.data.HeteroData',\n ],\n node_attrs: Optional[Iterable[str]] = None,\n edge_attrs: Optional[Iterable[str]] = None,\n graph_attrs: Optional[Iterable[str]] = None,\n to_undirected: Optional[Union[bool, str]] = False,\n to_multi: bool = False,\n remove_self_loops: bool = False,\n) -> Any:\n r\"\"\"Converts a :class:`torch_geometric.data.Data` instance to a\n :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or\n a directed :obj:`networkx.DiGraph` otherwise.\n\n Args:\n data (torch_geometric.data.Data or torch_geometric.data.HeteroData): A\n homogeneous or heterogeneous data object.\n node_attrs (iterable of str, optional): The node attributes to be\n copied. (default: :obj:`None`)\n edge_attrs (iterable of str, optional): The edge attributes to be\n copied. (default: :obj:`None`)\n graph_attrs (iterable of str, optional): The graph attributes to be\n copied. (default: :obj:`None`)\n to_undirected (bool or str, optional): If set to :obj:`True`, will\n return a :class:`networkx.Graph` instead of a\n :class:`networkx.DiGraph`.\n By default, will include all edges and make them undirected.\n If set to :obj:`\"upper\"`, the undirected graph will only correspond\n to the upper triangle of the input adjacency matrix.\n If set to :obj:`\"lower\"`, the undirected graph will only correspond\n to the lower triangle of the input adjacency matrix.\n Only applicable in case the :obj:`data` object holds a homogeneous\n graph. (default: :obj:`False`)\n to_multi (bool, optional): if set to :obj:`True`, will return a\n :class:`networkx.MultiGraph` or a :class:`networkx:MultiDiGraph`\n (depending on the :obj:`to_undirected` option), which will not drop\n duplicated edges that may exist in :obj:`data`.\n (default: :obj:`False`)\n remove_self_loops (bool, optional): If set to :obj:`True`, will not\n include self-loops in the resulting graph. (default: :obj:`False`)\n\n Examples:\n >>> edge_index = torch.tensor([\n ... [0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2],\n ... 
])\n >>> data = Data(edge_index=edge_index, num_nodes=4)\n >>> to_networkx(data)\n \n\n \"\"\"\n import networkx as nx\n\n from torch_geometric.data import HeteroData\n\n to_undirected_upper: bool = to_undirected == 'upper'\n to_undirected_lower: bool = to_undirected == 'lower'\n\n to_undirected = to_undirected is True\n to_undirected |= to_undirected_upper or to_undirected_lower\n assert isinstance(to_undirected, bool)\n\n if isinstance(data, HeteroData) and to_undirected:\n raise ValueError(\"'to_undirected' is not supported in \"\n \"'to_networkx' for heterogeneous graphs\")\n\n if to_undirected:\n G = nx.MultiGraph() if to_multi else nx.Graph()\n else:\n G = nx.MultiDiGraph() if to_multi else nx.DiGraph()\n\n def to_networkx_value(value: Any) -> Any:\n return value.tolist() if isinstance(value, Tensor) else value\n\n for key in graph_attrs or []:\n G.graph[key] = to_networkx_value(data[key])\n\n node_offsets = data.node_offsets\n for node_store in data.node_stores:\n start = node_offsets[node_store._key]\n assert node_store.num_nodes is not None\n for i in range(node_store.num_nodes):\n node_kwargs: Dict[str, Any] = {}\n if isinstance(data, HeteroData):\n node_kwargs['type'] = node_store._key\n for key in node_attrs or []:\n node_kwargs[key] = to_networkx_value(node_store[key][i])\n\n G.add_node(start + i, **node_kwargs)\n\n for edge_store in data.edge_stores:\n for i, (v, w) in enumerate(edge_store.edge_index.t().tolist()):\n if to_undirected_upper and v > w:\n continue\n elif to_undirected_lower and v < w:\n continue\n elif remove_self_loops and v == w and not edge_store.is_bipartite(\n ):\n continue\n\n edge_kwargs: Dict[str, Any] = {}\n if isinstance(data, HeteroData):\n v = v + node_offsets[edge_store._key[0]]\n w = w + node_offsets[edge_store._key[-1]]\n edge_kwargs['type'] = edge_store._key\n for key in edge_attrs or []:\n edge_kwargs[key] = to_networkx_value(edge_store[key][i])\n\n G.add_edge(v, w, **edge_kwargs)\n\n return G\n\n\ndef from_networkx(\n G: Any,\n group_node_attrs: Optional[Union[List[str], Literal['all']]] = None,\n group_edge_attrs: Optional[Union[List[str], Literal['all']]] = None,\n) -> 'torch_geometric.data.Data':\n r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n :class:`torch_geometric.data.Data` instance.\n\n Args:\n G (networkx.Graph or networkx.DiGraph): A networkx graph.\n group_node_attrs (List[str] or \"all\", optional): The node attributes to\n be concatenated and added to :obj:`data.x`. (default: :obj:`None`)\n group_edge_attrs (List[str] or \"all\", optional): The edge attributes to\n be concatenated and added to :obj:`data.edge_attr`.\n (default: :obj:`None`)\n\n .. note::\n\n All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must\n be numeric.\n\n Examples:\n >>> edge_index = torch.tensor([\n ... [0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2],\n ... 
])\n >>> data = Data(edge_index=edge_index, num_nodes=4)\n >>> g = to_networkx(data)\n >>> # A `Data` object is returned\n >>> from_networkx(g)\n Data(edge_index=[2, 6], num_nodes=4)\n \"\"\"\n import networkx as nx\n\n from torch_geometric.data import Data\n\n G = G.to_directed() if not nx.is_directed(G) else G\n\n mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))\n edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)\n for i, (src, dst) in enumerate(G.edges()):\n edge_index[0, i] = mapping[src]\n edge_index[1, i] = mapping[dst]\n\n data_dict: Dict[str, Any] = defaultdict(list)\n data_dict['edge_index'] = edge_index\n\n node_attrs: List[str] = []\n if G.number_of_nodes() > 0:\n node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())\n\n edge_attrs: List[str] = []\n if G.number_of_edges() > 0:\n edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())\n\n if group_node_attrs is not None and not isinstance(group_node_attrs, list):\n group_node_attrs = node_attrs\n\n if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):\n group_edge_attrs = edge_attrs\n\n for i, (_, feat_dict) in enumerate(G.nodes(data=True)):\n if set(feat_dict.keys()) != set(node_attrs):\n raise ValueError('Not all nodes contain the same attributes')\n for key, value in feat_dict.items():\n data_dict[str(key)].append(value)\n\n for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):\n if set(feat_dict.keys()) != set(edge_attrs):\n raise ValueError('Not all edges contain the same attributes')\n for key, value in feat_dict.items():\n key = f'edge_{key}' if key in node_attrs else key\n data_dict[str(key)].append(value)\n\n for key, value in G.graph.items():\n if key == 'node_default' or key == 'edge_default':\n continue # Do not load default attributes.\n key = f'graph_{key}' if key in node_attrs else key\n data_dict[str(key)] = value\n\n for key, value in data_dict.items():\n if isinstance(value, (tuple, list)) and isinstance(value[0], Tensor):\n data_dict[key] = torch.stack(value, dim=0)\n else:\n try:\n data_dict[key] = torch.as_tensor(value)\n except Exception:\n pass\n\n data = Data.from_dict(data_dict)\n\n if group_node_attrs is not None:\n xs = []\n for key in group_node_attrs:\n x = data[key]\n x = x.view(-1, 1) if x.dim() <= 1 else x\n xs.append(x)\n del data[key]\n data.x = torch.cat(xs, dim=-1)\n\n if group_edge_attrs is not None:\n xs = []\n for key in group_edge_attrs:\n key = f'edge_{key}' if key in node_attrs else key\n x = data[key]\n x = x.view(-1, 1) if x.dim() <= 1 else x\n xs.append(x)\n del data[key]\n data.edge_attr = torch.cat(xs, dim=-1)\n\n if data.x is None and data.pos is None:\n data.num_nodes = G.number_of_nodes()\n\n return data\n\n\ndef to_networkit(\n edge_index: Tensor,\n edge_weight: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n directed: bool = True,\n) -> Any:\n r\"\"\"Converts a :obj:`(edge_index, edge_weight)` tuple to a\n :class:`networkit.Graph`.\n\n Args:\n edge_index (torch.Tensor): The edge indices of the graph.\n edge_weight (torch.Tensor, optional): The edge weights of the graph.\n (default: :obj:`None`)\n num_nodes (int, optional): The number of nodes in the graph.\n (default: :obj:`None`)\n directed (bool, optional): If set to :obj:`False`, the graph will be\n undirected. 
(default: :obj:`True`)\n \"\"\"\n import networkit as nk\n\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n g = nk.graph.Graph(\n num_nodes,\n weighted=edge_weight is not None,\n directed=directed,\n )\n\n if edge_weight is None:\n edge_weight = torch.ones(edge_index.size(1))\n\n if not directed:\n mask = edge_index[0] <= edge_index[1]\n edge_index = edge_index[:, mask]\n edge_weight = edge_weight[mask]\n\n for (u, v), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n g.addEdge(u, v, w)\n\n return g\n\n\ndef from_networkit(g: Any) -> Tuple[Tensor, Optional[Tensor]]:\n r\"\"\"Converts a :class:`networkit.Graph` to a\n :obj:`(edge_index, edge_weight)` tuple.\n If the :class:`networkit.Graph` is not weighted, the returned\n :obj:`edge_weight` will be :obj:`None`.\n\n Args:\n g (networkkit.graph.Graph): A :obj:`networkit` graph object.\n \"\"\"\n is_directed = g.isDirected()\n is_weighted = g.isWeighted()\n\n edge_indices, edge_weights = [], []\n for u, v, w in g.iterEdgesWeights():\n edge_indices.append([u, v])\n edge_weights.append(w)\n if not is_directed:\n edge_indices.append([v, u])\n edge_weights.append(w)\n\n edge_index = torch.tensor(edge_indices).t().contiguous()\n edge_weight = torch.tensor(edge_weights) if is_weighted else None\n\n return edge_index, edge_weight\n\n\ndef to_trimesh(data: 'torch_geometric.data.Data') -> Any:\n r\"\"\"Converts a :class:`torch_geometric.data.Data` instance to a\n :obj:`trimesh.Trimesh`.\n\n Args:\n data (torch_geometric.data.Data): The data object.\n\n Example:\n >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],\n ... dtype=torch.float)\n >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()\n\n >>> data = Data(pos=pos, face=face)\n >>> to_trimesh(data)\n \n \"\"\"\n import trimesh\n\n assert data.pos is not None\n assert data.face is not None\n\n return trimesh.Trimesh(\n vertices=data.pos.detach().cpu().numpy(),\n faces=data.face.detach().t().cpu().numpy(),\n process=False,\n )\n\n\ndef from_trimesh(mesh: Any) -> 'torch_geometric.data.Data':\n r\"\"\"Converts a :obj:`trimesh.Trimesh` to a\n :class:`torch_geometric.data.Data` instance.\n\n Args:\n mesh (trimesh.Trimesh): A :obj:`trimesh` mesh.\n\n Example:\n >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],\n ... dtype=torch.float)\n >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()\n\n >>> data = Data(pos=pos, face=face)\n >>> mesh = to_trimesh(data)\n >>> from_trimesh(mesh)\n Data(pos=[4, 3], face=[3, 2])\n \"\"\"\n from torch_geometric.data import Data\n\n pos = torch.from_numpy(mesh.vertices).to(torch.float)\n face = torch.from_numpy(mesh.faces).t().contiguous()\n\n return Data(pos=pos, face=face)\n\n\ndef to_cugraph(\n edge_index: Tensor,\n edge_weight: Optional[Tensor] = None,\n relabel_nodes: bool = True,\n directed: bool = True,\n) -> Any:\n r\"\"\"Converts a graph given by :obj:`edge_index` and optional\n :obj:`edge_weight` into a :obj:`cugraph` graph object.\n\n Args:\n edge_index (torch.Tensor): The edge indices of the graph.\n edge_weight (torch.Tensor, optional): The edge weights of the graph.\n (default: :obj:`None`)\n relabel_nodes (bool, optional): If set to :obj:`True`,\n :obj:`cugraph` will remove any isolated nodes, leading to a\n relabeling of nodes. (default: :obj:`True`)\n directed (bool, optional): If set to :obj:`False`, the graph will be\n undirected. 
(default: :obj:`True`)\n \"\"\"\n import cudf\n import cugraph\n\n g = cugraph.Graph(directed=directed)\n df = cudf.from_dlpack(to_dlpack(edge_index.t()))\n\n if edge_weight is not None:\n assert edge_weight.dim() == 1\n df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))\n\n g.from_cudf_edgelist(\n df,\n source=0,\n destination=1,\n edge_attr='2' if edge_weight is not None else None,\n renumber=relabel_nodes,\n )\n\n return g\n\n\ndef from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:\n r\"\"\"Converts a :obj:`cugraph` graph object into :obj:`edge_index` and\n optional :obj:`edge_weight` tensors.\n\n Args:\n g (cugraph.Graph): A :obj:`cugraph` graph object.\n \"\"\"\n df = g.view_edge_list()\n\n src = from_dlpack(df[0].to_dlpack()).long()\n dst = from_dlpack(df[1].to_dlpack()).long()\n edge_index = torch.stack([src, dst], dim=0)\n\n edge_weight = None\n if '2' in df:\n edge_weight = from_dlpack(df['2'].to_dlpack())\n\n return edge_index, edge_weight\n\n\ndef to_dgl(\n data: Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']\n) -> Any:\n r\"\"\"Converts a :class:`torch_geometric.data.Data` or\n :class:`torch_geometric.data.HeteroData` instance to a :obj:`dgl` graph\n object.\n\n Args:\n data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n The data object.\n\n Example:\n...\n", "needle_token_start": 10274, "needle_token_end": 10716, "code_context_ntokens": 16127, "output": ["```python\nfrom scipy.sparse.csgraph import minimum_spanning_tree\nfrom torch import Tensor\n\nfrom torch_geometric.utils import (\n from_scipy_sparse_matrix,\n to_scipy_sparse_matrix,\n to_undirected,\n)\n\n\n@overload\ndef tree_decomposition(mol: Any) -> Tuple[Tensor, Tensor, int]:\n pass\n\n\n@overload\ndef tree_decomposition(\n mol: Any,\n return_vocab: Literal[False],\n) -> Tuple[Tensor, Tensor, int]:\n pass\n\n\n@overload\ndef tree_decomposition(\n mol: Any,\n return_vocab: Literal[True],\n) -> Tuple[Tensor, Tensor, int, Tensor]:\n pass\n\n\ndef tree_decomposition(\n mol: Any,\n return_vocab: bool = False,\n) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]]:\n r\"\"\"The tree decomposition algorithm of molecules from the\n `\"Junction Tree Variational Autoencoder for Molecular Graph Generation\"\n `_ paper.\n Returns the graph connectivity of the junction tree, the assignment\n mapping of each atom to the clique in the junction tree, and the number\n of cliques.\n\n Args:\n mol (rdkit.Chem.Mol): An :obj:`rdkit` molecule.\n return_vocab (bool, optional): If set to :obj:`True`, will return an\n identifier for each clique (ring, bond, bridged compounds, single).\n (default: :obj:`False`)\n\n :rtype: :obj:`(LongTensor, LongTensor, int)` if :obj:`return_vocab` is\n :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)`\n \"\"\"\n import rdkit.Chem as Chem\n\n # Cliques = rings and bonds.\n cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)]\n xs: List[int] = [0] * len(cliques)\n for bond in mol.GetBonds():\n if not bond.IsInRing():\n cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])\n xs.append(1)\n\n # Generate `atom2cliques` mappings.\n atom2cliques: List[List[int]] = [[] for i in range(mol.GetNumAtoms())]\n for c in range(len(cliques)):\n for atom in cliques[c]:\n atom2cliques[atom].append(c)\n\n # Merge rings that share more than 2 atoms as they form bridged compounds.\n for c1 in range(len(cliques)):\n for atom in cliques[c1]:\n for c2 in atom2cliques[atom]:\n if c1 >= c2 or len(cliques[c1]) <= 2 or 
len(cliques[c2]) <= 2:\n continue\n if len(set(cliques[c1]) & set(cliques[c2])) > 2:\n cliques[c1] = list(set(cliques[c1]) | set(cliques[c2]))\n xs[c1] = 2\n cliques[c2] = []\n xs[c2] = -1\n cliques = [c for c in cliques if len(c) > 0]\n xs = [x for x in xs if x >= 0]\n\n # Update `atom2cliques` mappings.\n atom2cliques = [[] for i in range(mol.GetNumAtoms())]\n for c in range(len(cliques)):\n for atom in cliques[c]:\n atom2cliques[atom].append(c)\n\n # Add singleton cliques in case there are more than 2 intersecting\n # cliques. We further compute the \"initial\" clique graph.\n edges = {}\n for atom in range(mol.GetNumAtoms()):\n cs = atom2cliques[atom]\n if len(cs) <= 1:\n continue\n\n # Number of bond clusters that the atom lies in.\n bonds = [c for c in cs if len(cliques[c]) == 2]\n # Number of ring clusters that the atom lies in.\n rings = [c for c in cs if len(cliques[c]) > 4]\n\n if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2):\n cliques.append([atom])\n xs.append(3)\n c2 = len(cliques) - 1\n for c1 in cs:\n edges[(c1, c2)] = 1\n\n elif len(rings) > 2:\n cliques.append([atom])\n xs.append(3)\n c2 = len(cliques) - 1\n for"]} +{"repo": "pyg-team/pytorch_geometric", "name": "set_tuple_item", "language": "python", "path": "torch_geometric/edge_index.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to replace an item in a tuple with a new value at a specified index, since tuples are immutable in Python.\n2. **Input**: The function takes three parameters: a tuple containing any types of values, an integer representing the index at which the replacement should occur, and the new value to be inserted at that index.\n3. **Output**: It returns a new tuple that is a copy of the original tuple but with the specified item replaced by the new value.\n4. **Procedure**: The function first checks if the provided index is within the valid range for the tuple. If the index is out of range, it raises an `IndexError`. If the index is negative, it adjusts it to access the tuple in reverse order. The function then constructs a new tuple by slicing the original tuple up to the specified index, adding the new value, and then appending the remainder of the original tuple after the specified index.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/utils/_to_dense_adj.py\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import cumsum, scatter\n\n\ndef to_dense_adj(\n edge_index: Tensor,\n batch: OptTensor = None,\n edge_attr: OptTensor = None,\n max_num_nodes: Optional[int] = None,\n batch_size: Optional[int] = None,\n) -> Tensor:\n r\"\"\"Converts batched sparse adjacency matrices given by edge indices and\n edge attributes to a single dense batched adjacency matrix.\n\n Args:\n edge_index (LongTensor): The edge indices.\n batch (LongTensor, optional): Batch vector\n :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n node to a specific example. 
(default: :obj:`None`)\n edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n features.\n If :obj:`edge_index` contains duplicated edges, the dense adjacency\n matrix output holds the summed up entries of :obj:`edge_attr` for\n duplicated edges. (default: :obj:`None`)\n max_num_nodes (int, optional): The size of the output node dimension.\n (default: :obj:`None`)\n batch_size (int, optional): The batch size. (default: :obj:`None`)\n\n :rtype: :class:`Tensor`\n\n Examples:\n >>> edge_index = torch.tensor([[0, 0, 1, 2, 3],\n ... [0, 1, 0, 3, 0]])\n >>> batch = torch.tensor([0, 0, 1, 1])\n >>> to_dense_adj(edge_index, batch)\n tensor([[[1., 1.],\n [1., 0.]],\n [[0., 1.],\n [1., 0.]]])\n\n >>> to_dense_adj(edge_index, batch, max_num_nodes=4)\n tensor([[[1., 1., 0., 0.],\n [1., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.]],\n [[0., 1., 0., 0.],\n [1., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.]]])\n\n >>> edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])\n >>> to_dense_adj(edge_index, batch, edge_attr)\n tensor([[[1., 2.],\n [3., 0.]],\n [[0., 4.],\n [5., 0.]]])\n \"\"\"\n if batch is None:\n max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0\n batch = edge_index.new_zeros(max_index)\n\n if batch_size is None:\n batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1\n\n one = batch.new_ones(batch.size(0))\n num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')\n cum_nodes = cumsum(num_nodes)\n\n idx0 = batch[edge_index[0]]\n idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]\n idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]\n\n if max_num_nodes is None:\n max_num_nodes = int(num_nodes.max())\n\n elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)\n or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):\n mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)\n idx0 = idx0[mask]\n idx1 = idx1[mask]\n idx2 = idx2[mask]\n edge_attr = None if edge_attr is None else edge_attr[mask]\n\n if edge_attr is None:\n edge_attr = torch.ones(idx0.numel(), device=edge_index.device)\n\n size = [batch_size, max_num_nodes, max_num_nodes]\n size += list(edge_attr.size())[1:]\n flattened_size = batch_size * max_num_nodes * max_num_nodes\n\n idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2\n adj = scatter(edge_attr, idx, dim=0, dim_size=flattened_size, reduce='sum')\n adj = adj.view(size)\n\n return adj\n\n# Path: torch_geometric/utils/_assortativity.py\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import Adj, SparseTensor\nfrom torch_geometric.utils import coalesce, degree\nfrom torch_geometric.utils._to_dense_adj import to_dense_adj\n\n\ndef assortativity(edge_index: Adj) -> float:\n r\"\"\"The degree assortativity coefficient from the\n `\"Mixing patterns in networks\"\n `_ paper.\n Assortativity in a network refers to the tendency of nodes to\n connect with other similar nodes over dissimilar nodes.\n It is computed from Pearson correlation coefficient of the node degrees.\n\n Args:\n edge_index (Tensor or SparseTensor): The graph connectivity.\n\n Returns:\n The value of the degree assortativity coefficient for the input\n graph :math:`\\in [-1, 1]`\n\n Example:\n >>> edge_index = torch.tensor([[0, 1, 2, 3, 2],\n ... 
[1, 2, 0, 1, 3]])\n >>> assortativity(edge_index)\n -0.666667640209198\n \"\"\"\n if isinstance(edge_index, SparseTensor):\n adj: SparseTensor = edge_index\n row, col, _ = adj.coo()\n else:\n assert isinstance(edge_index, Tensor)\n row, col = edge_index\n\n device = row.device\n out_deg = degree(row, dtype=torch.long)\n in_deg = degree(col, dtype=torch.long)\n degrees = torch.unique(torch.cat([out_deg, in_deg]))\n mapping = row.new_zeros(degrees.max().item() + 1)\n mapping[degrees] = torch.arange(degrees.size(0), device=device)\n\n # Compute degree mixing matrix (joint probability distribution) `M`\n num_degrees = degrees.size(0)\n src_deg = mapping[out_deg[row]]\n dst_deg = mapping[in_deg[col]]\n\n pairs = torch.stack([src_deg, dst_deg], dim=0)\n occurrence = torch.ones(pairs.size(1), device=device)\n pairs, occurrence = coalesce(pairs, occurrence)\n M = to_dense_adj(pairs, edge_attr=occurrence, max_num_nodes=num_degrees)[0]\n # normalization\n M /= M.sum()\n\n # numeric assortativity coefficient, computed by\n # Pearson correlation coefficient of the node degrees\n x = y = degrees.float()\n a, b = M.sum(0), M.sum(1)\n\n vara = (a * x**2).sum() - ((a * x).sum())**2\n varb = (b * x**2).sum() - ((b * x).sum())**2\n xy = torch.outer(x, y)\n ab = torch.outer(a, b)\n out = (xy * (M - ab)).sum() / (vara * varb).sqrt()\n return out.item()\n\n# Path: torch_geometric/edge_index.py\nimport functools\nimport typing\nfrom enum import Enum\nfrom typing import (\n Any,\n Callable,\n Dict,\n List,\n Literal,\n NamedTuple,\n Optional,\n Sequence,\n Set,\n Tuple,\n Type,\n Union,\n get_args,\n overload,\n)\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.typing import SparseTensor\n\nHANDLED_FUNCTIONS: Dict[Callable, Callable] = {}\n\nif torch_geometric.typing.WITH_PT20:\n SUPPORTED_DTYPES: Set[torch.dtype] = {\n torch.int32,\n torch.int64,\n }\nelif not typing.TYPE_CHECKING: # pragma: no cover\n SUPPORTED_DTYPES: Set[torch.dtype] = {\n torch.int64,\n }\n\nReduceType = Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max']\nPYG_REDUCE: Dict[ReduceType, ReduceType] = {\n 'add': 'sum',\n 'amin': 'min',\n 'amax': 'max'\n}\nTORCH_REDUCE: Dict[ReduceType, ReduceType] = {\n 'add': 'sum',\n 'min': 'amin',\n 'max': 'amax'\n}\n\n\nclass SortOrder(Enum):\n ROW = 'row'\n COL = 'col'\n\n\nclass CatMetadata(NamedTuple):\n nnz: List[int]\n sparse_size: List[Tuple[Optional[int], Optional[int]]]\n sort_order: List[Optional[SortOrder]]\n is_undirected: List[bool]\n\n\ndef implements(torch_function: Callable) -> Callable:\n r\"\"\"Registers a :pytorch:`PyTorch` function override.\"\"\"\n @functools.wraps(torch_function)\n def decorator(my_function: Callable) -> Callable:\n HANDLED_FUNCTIONS[torch_function] = my_function\n return my_function\n\n return decorator\n\n\n\ndef set_tuple_item(\n values: Tuple[Any, ...],\n dim: int,\n value: Any,\n) -> Tuple[Any, ...]:\n if dim < -len(values) or dim >= len(values):\n raise IndexError(\"tuple index out of range\")\n\n dim = dim + len(values) if dim < 0 else dim\n return values[:dim] + (value, ) + values[dim + 1:]\n\n\ndef maybe_add(\n value: Sequence[Optional[int]],\n other: Union[int, Sequence[Optional[int]]],\n alpha: int = 1,\n) -> Tuple[Optional[int], ...]:\n\n if isinstance(other, int):\n return tuple(v + alpha * other if v is not None else None\n for v in value)\n\n assert len(value) == len(other)\n return tuple(v + alpha * o if v is not None and o is not None else None\n for v, o 
in zip(value, other))\n\n\ndef maybe_sub(\n value: Sequence[Optional[int]],\n other: Union[int, Sequence[Optional[int]]],\n alpha: int = 1,\n) -> Tuple[Optional[int], ...]:\n\n if isinstance(other, int):\n return tuple(v - alpha * other if v is not None else None\n for v in value)\n\n assert len(value) == len(other)\n return tuple(v - alpha * o if v is not None and o is not None else None\n for v, o in zip(value, other))\n\n\ndef ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:\n index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)\n return index.repeat_interleave(ptr.diff(), output_size=output_size)\n\n\ndef assert_valid_dtype(tensor: Tensor) -> None:\n if tensor.dtype not in SUPPORTED_DTYPES:\n raise ValueError(f\"'EdgeIndex' holds an unsupported data type \"\n f\"(got '{tensor.dtype}', but expected one of \"\n f\"{SUPPORTED_DTYPES})\")\n\n\ndef assert_two_dimensional(tensor: Tensor) -> None:\n if tensor.dim() != 2:\n raise ValueError(f\"'EdgeIndex' needs to be two-dimensional \"\n f\"(got {tensor.dim()} dimensions)\")\n if not torch.jit.is_tracing() and tensor.size(0) != 2:\n raise ValueError(f\"'EdgeIndex' needs to have a shape of \"\n f\"[2, *] (got {list(tensor.size())})\")\n\n\ndef assert_contiguous(tensor: Tensor) -> None:\n if not tensor.is_contiguous():\n raise ValueError(\"'EdgeIndex' needs to be contiguous. Please call \"\n \"`edge_index.contiguous()` before proceeding.\")\n\n\ndef assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None:\n if (not torch.jit.is_tracing() and size[0] is not None\n and size[1] is not None and size[0] != size[1]):\n raise ValueError(f\"'EdgeIndex' is undirected but received a \"\n f\"non-symmetric size (got {list(size)})\")\n\n\ndef assert_sorted(func: Callable) -> Callable:\n @functools.wraps(func)\n def wrapper(*args: Any, **kwargs: Any) -> Any:\n if not args[0].is_sorted:\n cls_name = args[0].__class__.__name__\n raise ValueError(\n f\"Cannot call '{func.__name__}' since '{cls_name}' is not \"\n f\"sorted. Please call `{cls_name}.sort_by(...)` first.\")\n return func(*args, **kwargs)\n\n return wrapper\n\n\nclass EdgeIndex(Tensor):\n r\"\"\"A COO :obj:`edge_index` tensor with additional (meta)data attached.\n\n :class:`EdgeIndex` is a :pytorch:`null` :class:`torch.Tensor`, that holds\n an :obj:`edge_index` representation of shape :obj:`[2, num_edges]`.\n Edges are given as pairwise source and destination node indices in sparse\n COO format.\n\n While :class:`EdgeIndex` sub-classes a general :pytorch:`null`\n :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:\n\n * :obj:`sparse_size`: The underlying sparse matrix size\n * :obj:`sort_order`: The sort order (if present), either by row or column.\n * :obj:`is_undirected`: Whether edges are bidirectional.\n\n Additionally, :class:`EdgeIndex` caches data for fast CSR or CSC conversion\n in case its representation is sorted, such as its :obj:`rowptr` or\n :obj:`colptr`, or the permutation vector for going from CSR to CSC or vice\n versa.\n Caches are filled based on demand (*e.g.*, when calling\n :meth:`EdgeIndex.sort_by`), or when explicitly requested via\n :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its\n lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).\n\n This representation ensures for optimal computation in GNN message passing\n schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`\n workflows.\n\n .. 
code-block:: python\n\n from torch_geometric import EdgeIndex\n\n edge_index = EdgeIndex(\n [[0, 1, 1, 2],\n [1, 0, 2, 1]]\n sparse_size=(3, 3),\n sort_order='row',\n is_undirected=True,\n device='cpu',\n )\n >>> EdgeIndex([[0, 1, 1, 2],\n ... [1, 0, 2, 1]])\n assert edge_index.is_sorted_by_row\n assert edge_index.is_undirected\n\n # Flipping order:\n edge_index = edge_index.flip(0)\n >>> EdgeIndex([[1, 0, 2, 1],\n ... [0, 1, 1, 2]])\n assert edge_index.is_sorted_by_col\n assert edge_index.is_undirected\n\n # Filtering:\n mask = torch.tensor([True, True, True, False])\n edge_index = edge_index[:, mask]\n >>> EdgeIndex([[1, 0, 2],\n ... [0, 1, 1]])\n assert edge_index.is_sorted_by_col\n assert not edge_index.is_undirected\n\n # Sparse-Dense Matrix Multiplication:\n out = edge_index.flip(0) @\u00a0torch.randn(3, 16)\n assert out.size() == (3, 16)\n \"\"\"\n # See \"https://pytorch.org/docs/stable/notes/extending.html\"\n # for a basic tutorial on how to subclass `torch.Tensor`.\n\n # The underlying tensor representation:\n _data: Optional[Tensor] = None\n\n # The size of the underlying sparse matrix:\n _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)\n\n # Whether the `edge_index` represented is non-sorted (`None`), or sorted\n # based on row or column values.\n _sort_order: Optional[SortOrder] = None\n\n # Whether the `edge_index` is undirected:\n # NOTE `is_undirected` allows us to assume symmetric adjacency matrix size\n # and to share compressed pointer representations, however, it does not\n # allow us get rid of CSR/CSC permutation vectors since ordering within\n # neighborhoods is not necessarily deterministic.\n _is_undirected: bool = False\n\n # A cache for its compressed representation:\n _indptr: Optional[Tensor] = None\n\n # A cache for its transposed representation:\n _T_perm: Optional[Tensor] = None\n _T_index: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None)\n _T_indptr: Optional[Tensor] = None\n\n # A cached \"1\"-value vector for `torch.sparse` matrix multiplication:\n _value: Optional[Tensor] = None\n\n # Whenever we perform a concatenation of edge indices, we cache the\n # original metadata to be able to reconstruct individual edge indices:\n _cat_metadata: Optional[CatMetadata] = None\n\n def __new__(\n cls: Type,\n data: Any,\n *args: Any,\n sparse_size: Optional[Tuple[Optional[int], Optional[int]]] = None,\n sort_order: Optional[Union[str, SortOrder]] = None,\n is_undirected: bool = False,\n **kwargs: Any,\n ) -> 'EdgeIndex':\n if not isinstance(data, Tensor):\n data = torch.tensor(data, *args, **kwargs)\n elif len(args) > 0:\n raise TypeError(\n f\"new() received an invalid combination of arguments - got \"\n f\"(Tensor, {', '.join(str(type(arg)) for arg in args)})\")\n elif len(kwargs) > 0:\n raise TypeError(f\"new() received invalid keyword arguments - got \"\n f\"{set(kwargs.keys())})\")\n\n assert isinstance(data, Tensor)\n\n indptr: Optional[Tensor] = None\n\n if isinstance(data, cls): # If passed `EdgeIndex`, inherit metadata:\n indptr = data._indptr\n sparse_size = sparse_size or data.sparse_size()\n sort_order = sort_order or data.sort_order\n is_undirected = is_undirected or data.is_undirected\n\n # Convert `torch.sparse` tensors to `EdgeIndex` representation:\n if data.layout == torch.sparse_coo:\n sort_order = SortOrder.ROW\n sparse_size = sparse_size or (data.size(0), data.size(1))\n data = data.indices()\n\n if data.layout == torch.sparse_csr:\n indptr = data.crow_indices()\n col = data.col_indices()\n\n assert isinstance(indptr, 
Tensor)\n row = ptr2index(indptr, output_size=col.numel())\n\n sort_order = SortOrder.ROW\n sparse_size = sparse_size or (data.size(0), data.size(1))\n if sparse_size[0] is not None and sparse_size[0] != data.size(0):\n indptr = None\n data = torch.stack([row, col], dim=0)\n\n if (torch_geometric.typing.WITH_PT112\n and data.layout == torch.sparse_csc):\n row = data.row_indices()\n indptr = data.ccol_indices()\n\n assert isinstance(indptr, Tensor)\n col = ptr2index(indptr, output_size=row.numel())\n\n sort_order = SortOrder.COL\n sparse_size = sparse_size or (data.size(0), data.size(1))\n if sparse_size[1] is not None and sparse_size[1] != data.size(1):\n indptr = None\n data = torch.stack([row, col], dim=0)\n\n assert_valid_dtype(data)\n assert_two_dimensional(data)\n assert_contiguous(data)\n\n if sparse_size is None:\n sparse_size = (None, None)\n\n if is_undirected:\n assert_symmetric(sparse_size)\n if sparse_size[0] is not None and sparse_size[1] is None:\n sparse_size = (sparse_size[0], sparse_size[0])\n elif sparse_size[0] is None and sparse_size[1] is not None:\n sparse_size = (sparse_size[1], sparse_size[1])\n\n if torch_geometric.typing.WITH_PT112:\n out = super().__new__(cls, data)\n else:\n out = Tensor._make_subclass(cls, data)\n\n # Attach metadata:\n assert isinstance(out, EdgeIndex)\n if torch_geometric.typing.WITH_PT22:\n out._data = data\n out._sparse_size = sparse_size\n out._sort_order = None if sort_order is None else SortOrder(sort_order)\n out._is_undirected = is_undirected\n out._indptr = indptr\n\n if isinstance(data, cls): # If passed `EdgeIndex`, inherit metadata:\n out._T_perm = data._T_perm\n out._T_index = data._T_index\n out._T_indptr = data._T_indptr\n out._value = out._value\n\n # Reset metadata if cache is invalidated:\n num_rows = sparse_size[0]\n if num_rows is not None and num_rows != data.sparse_size(0):\n out._indptr = None\n\n num_cols = sparse_size[1]\n if num_cols is not None and num_cols != data.sparse_size(1):\n out._T_indptr = None\n\n return out\n\n # Validation ##############################################################\n\n def validate(self) -> 'EdgeIndex':\n r\"\"\"Validates the :class:`EdgeIndex` representation.\n\n In particular, it ensures that\n\n * it only holds valid indices.\n * the sort order is correctly set.\n * indices are bidirectional in case it is specified as undirected.\n \"\"\"\n assert_valid_dtype(self)\n assert_two_dimensional(self)\n assert_contiguous(self)\n if self.is_undirected:\n assert_symmetric(self.sparse_size())\n\n if self.numel() > 0 and self.min() < 0:\n raise ValueError(f\"'{self.__class__.__name__}' contains negative \"\n f\"indices (got {int(self.min())})\")\n\n if (self.numel() > 0 and self.num_rows is not None\n and self[0].max() >= self.num_rows):\n raise ValueError(f\"'{self.__class__.__name__}' contains larger \"\n f\"indices than its number of rows \"\n f\"(got {int(self[0].max())}, but expected values \"\n f\"smaller than {self.num_rows})\")\n\n if (self.numel() > 0 and self.num_cols is not None\n and self[1].max() >= self.num_cols):\n raise ValueError(f\"'{self.__class__.__name__}' contains larger \"\n f\"indices than its number of columns \"\n f\"(got {int(self[1].max())}, but expected values \"\n f\"smaller than {self.num_cols})\")\n\n if self.is_sorted_by_row and (self[0].diff() < 0).any():\n raise ValueError(f\"'{self.__class__.__name__}' is not sorted by \"\n f\"row indices\")\n\n if self.is_sorted_by_col and (self[1].diff() < 0).any():\n raise ValueError(f\"'{self.__class__.__name__}' is not 
sorted by \"\n f\"column indices\")\n\n if self.is_undirected:\n flat_index1 = (self[0] * self.get_num_rows() + self[1]).sort()[0]\n...\n", "needle_token_start": 2771, "needle_token_end": 2878, "code_context_ntokens": 6825, "output": ["```python\ndef set_tuple_item(\n values: Tuple[Any, ...],\n dim: int,\n value: Any,\n) -> Tuple[Any, ...]:\n if dim < -len(values) or dim >= len(values):\n raise IndexError(\"tuple index out of range\")\n\n dim = dim + len(values) if dim < 0 else dim\n return values[:dim] + (value, ) + values[dim + 1:]\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "_to_edge_mask", "language": "python", "path": "torch_geometric/explain/algorithm/captum.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function generates a tensor representing a mask for edges in a graph, where each edge is initially considered equally important.\n2. **Input**: A tensor representing the indices of edges in a graph.\n3. **Output**: A tensor of ones with the same length as the number of edges, with gradient tracking enabled.\n4. **Procedure**: The function calculates the number of edges based on the input tensor's second dimension, then creates and returns a tensor filled with ones, sized according to the number of edges, and configured to track gradients and match the device of the input tensor.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/nn/conv/transformer_conv.py\nimport math\nimport typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import (\n Adj,\n NoneType,\n OptTensor,\n PairTensor,\n SparseTensor,\n)\nfrom torch_geometric.utils import softmax\n\nif typing.TYPE_CHECKING:\n from typing import overload\nelse:\n from torch.jit import _overload_method as overload\n\n\nclass TransformerConv(MessagePassing):\n r\"\"\"The graph transformer operator from the `\"Masked Label Prediction:\n Unified Message Passing Model for Semi-Supervised Classification\"\n `_ paper.\n\n .. math::\n \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i +\n \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i,j} \\mathbf{W}_2 \\mathbf{x}_{j},\n\n where the attention coefficients :math:`\\alpha_{i,j}` are computed via\n multi-head dot product attention:\n\n .. math::\n \\alpha_{i,j} = \\textrm{softmax} \\left(\n \\frac{(\\mathbf{W}_3\\mathbf{x}_i)^{\\top} (\\mathbf{W}_4\\mathbf{x}_j)}\n {\\sqrt{d}} \\right)\n\n Args:\n in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n derive the size from the first input(s) to the forward method.\n A tuple corresponds to the sizes of source and target\n dimensionalities.\n out_channels (int): Size of each output sample.\n heads (int, optional): Number of multi-head-attentions.\n (default: :obj:`1`)\n concat (bool, optional): If set to :obj:`False`, the multi-head\n attentions are averaged instead of concatenated.\n (default: :obj:`True`)\n beta (bool, optional): If set, will combine aggregation and\n skip information via\n\n .. 
math::\n \\mathbf{x}^{\\prime}_i = \\beta_i \\mathbf{W}_1 \\mathbf{x}_i +\n (1 - \\beta_i) \\underbrace{\\left(\\sum_{j \\in \\mathcal{N}(i)}\n \\alpha_{i,j} \\mathbf{W}_2 \\vec{x}_j \\right)}_{=\\mathbf{m}_i}\n\n with :math:`\\beta_i = \\textrm{sigmoid}(\\mathbf{w}_5^{\\top}\n [ \\mathbf{W}_1 \\mathbf{x}_i, \\mathbf{m}_i, \\mathbf{W}_1\n \\mathbf{x}_i - \\mathbf{m}_i ])` (default: :obj:`False`)\n...\n# Path: torch_geometric/nn/conv/utils/cheatsheet.py\nimport importlib\nimport inspect\nimport re\nfrom typing import Optional\n\n\ndef paper_title(cls: str) -> Optional[str]:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n match = re.search('`\\\".+?\\\"', inspect.getdoc(cls), flags=re.DOTALL)\n return None if match is None else match.group().replace('\\n', ' ')[2:-1]\n\n\ndef paper_link(cls: str) -> Optional[str]:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n match = re.search('<.+?>', inspect.getdoc(cls), flags=re.DOTALL)\n return None if match is None else match.group().replace('\\n', ' ')[1:-1]\n\n\ndef supports_sparse_tensor(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'SparseTensor' in str(signature)\n\n\ndef supports_edge_weights(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'edge_weight' in str(signature)\n\n\ndef supports_edge_features(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'edge_attr' in str(signature)\n\n\ndef supports_bipartite_graphs(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'Union[torch.Tensor, Tuple[torch.Tensor' in str(signature)\n\n\ndef supports_static_graphs(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n return 'node_dim=' not in inspect.getsource(cls.__init__)\n\n\ndef supports_lazy_initialization(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n doc = re.sub(' +', ' ', inspect.getdoc(cls).replace('\\n', ' '))\n match = re.search('or :obj:`-1` to derive the size from the first', doc)\n return match is not None\n\n\ndef processes_heterogeneous_graphs(cls: str) -> bool:\n if 'hetero' in cls.lower():\n return True\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'edge_index_dict' in str(signature) or 'edge_type' in str(signature)\n\n\ndef processes_hypergraphs(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return 'hyperedge_index' in str(signature)\n\n\ndef processes_point_clouds(cls: str) -> bool:\n cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n signature = inspect.signature(cls.forward)\n return (('edge_index' not in str(signature)\n and 'csc' not in str(signature)) or 'pos' in str(signature))\n\n# Path: torch_geometric/nn/conv/utils/__init__.py\nr\"\"\"GNN utility package.\"\"\"\n\nfrom .cheatsheet import paper_title, paper_link\nfrom .cheatsheet import supports_sparse_tensor\nfrom .cheatsheet import supports_edge_weights\nfrom .cheatsheet import supports_edge_features\nfrom .cheatsheet import supports_bipartite_graphs\nfrom 
.cheatsheet import supports_static_graphs\nfrom .cheatsheet import supports_lazy_initialization\nfrom .cheatsheet import processes_heterogeneous_graphs\nfrom .cheatsheet import processes_hypergraphs\nfrom .cheatsheet import processes_point_clouds\n\n__all__ = [\n 'paper_title',\n 'paper_link',\n 'supports_sparse_tensor',\n 'supports_edge_weights',\n 'supports_edge_features',\n 'supports_bipartite_graphs',\n 'supports_static_graphs',\n 'supports_lazy_initialization',\n 'processes_heterogeneous_graphs',\n 'processes_hypergraphs',\n 'processes_point_clouds',\n]\n\n# Path: torch_geometric/nn/conv/wl_conv.py\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import Adj\nfrom torch_geometric.utils import (\n degree,\n is_sparse,\n scatter,\n sort_edge_index,\n to_edge_index,\n)\n\n\nclass WLConv(torch.nn.Module):\n r\"\"\"The Weisfeiler Lehman (WL) operator from the `\"A Reduction of a Graph\n to a Canonical Form and an Algebra Arising During this Reduction\"\n `_ paper.\n\n :class:`WLConv` iteratively refines node colorings according to:\n\n .. math::\n \\mathbf{x}^{\\prime}_i = \\textrm{hash} \\left( \\mathbf{x}_i, \\{\n \\mathbf{x}_j \\colon j \\in \\mathcal{N}(i) \\} \\right)\n\n Shapes:\n - **input:**\n node coloring :math:`(|\\mathcal{V}|, F_{in})` *(one-hot encodings)*\n or :math:`(|\\mathcal{V}|)` *(integer-based)*,\n edge indices :math:`(2, |\\mathcal{E}|)`\n - **output:** node coloring :math:`(|\\mathcal{V}|)` *(integer-based)*\n \"\"\"\n def __init__(self):\n super().__init__()\n self.hashmap = {}\n\n def reset_parameters(self):\n r\"\"\"Resets all learnable parameters of the module.\"\"\"\n self.hashmap = {}\n\n @torch.no_grad()\n def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n r\"\"\"Runs the forward pass of the module.\"\"\"\n if x.dim() > 1:\n assert (x.sum(dim=-1) == 1).sum() == x.size(0)\n x = x.argmax(dim=-1) # one-hot -> integer.\n assert x.dtype == torch.long\n\n if is_sparse(edge_index):\n col_and_row, _ = to_edge_index(edge_index)\n col = col_and_row[0]\n row = col_and_row[1]\n else:\n edge_index = sort_edge_index(edge_index, num_nodes=x.size(0),\n sort_by_row=False)\n row, col = edge_index[0], edge_index[1]\n\n # `col` is sorted, so we can use it to `split` neighbors to groups:\n deg = degree(col, x.size(0), dtype=torch.long).tolist()\n\n out = []\n for node, neighbors in zip(x.tolist(), x[row].split(deg)):\n idx = hash(tuple([node] + neighbors.sort()[0].tolist()))\n if idx not in self.hashmap:\n self.hashmap[idx] = len(self.hashmap)\n out.append(self.hashmap[idx])\n\n return torch.tensor(out, device=x.device)\n\n def histogram(self, x: Tensor, batch: Optional[Tensor] = None,\n norm: bool = False) -> Tensor:\n r\"\"\"Given a node coloring :obj:`x`, computes the color histograms of\n the respective graphs (separated by :obj:`batch`).\n \"\"\"\n if batch is None:\n batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)\n\n num_colors = len(self.hashmap)\n batch_size = int(batch.max()) + 1\n\n index = batch * num_colors + x\n out = scatter(torch.ones_like(index), index, dim=0,\n dim_size=num_colors * batch_size, reduce='sum')\n out = out.view(batch_size, num_colors)\n\n if norm:\n out = out.to(torch.float)\n out /= out.norm(dim=-1, keepdim=True)\n\n return out\n\n# Path: torch_geometric/nn/conv/wl_conv_continuous.py\nfrom typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import (\n Adj,\n OptPairTensor,\n OptTensor,\n 
Size,\n SparseTensor,\n)\nfrom torch_geometric.utils import scatter, spmm\n\n\nclass WLConvContinuous(MessagePassing):\n r\"\"\"The Weisfeiler Lehman operator from the `\"Wasserstein\n Weisfeiler-Lehman Graph Kernels\" `_\n paper.\n\n Refinement is done though a degree-scaled mean aggregation and works on\n nodes with continuous attributes:\n\n .. math::\n \\mathbf{x}^{\\prime}_i = \\frac{1}{2}\\big(\\mathbf{x}_i +\n \\frac{1}{\\textrm{deg}(i)}\n \\sum_{j \\in \\mathcal{N}(i)} e_{j,i} \\cdot \\mathbf{x}_j \\big)\n\n where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to\n target node :obj:`i` (default: :obj:`1`)\n\n Args:\n **kwargs (optional): Additional arguments of\n :class:`torch_geometric.nn.conv.MessagePassing`.\n\n Shapes:\n - **input:**\n node features :math:`(|\\mathcal{V}|, F)` or\n :math:`((|\\mathcal{V_s}|, F), (|\\mathcal{V_t}|, F))` if bipartite,\n edge indices :math:`(2, |\\mathcal{E}|)`,\n edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n - **output:** node features :math:`(|\\mathcal{V}|, F)` or\n :math:`(|\\mathcal{V}_t|, F)` if bipartite\n \"\"\"\n def __init__(self, **kwargs):\n super().__init__(aggr='add', **kwargs)\n\n def forward(\n self,\n x: Union[Tensor, OptPairTensor],\n edge_index: Adj,\n edge_weight: OptTensor = None,\n size: Size = None,\n ) -> Tensor:\n\n if isinstance(x, Tensor):\n x = (x, x)\n\n # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n out = self.propagate(edge_index, x=x, edge_weight=edge_weight,\n size=size)\n\n if isinstance(edge_index, SparseTensor):\n assert edge_weight is None\n dst_index, _, edge_weight = edge_index.coo()\n else:\n dst_index = edge_index[1]\n\n if edge_weight is None:\n edge_weight = x[0].new_ones(dst_index.numel())\n\n deg = scatter(edge_weight, dst_index, 0, out.size(0), reduce='sum')\n deg_inv = 1. / deg\n deg_inv.masked_fill_(deg_inv == float('inf'), 0)\n out = deg_inv.view(-1, 1) * out\n\n x_dst = x[1]\n if x_dst is not None:\n out = 0.5 * (x_dst + out)\n\n return out\n\n def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n return spmm(adj_t, x[0], reduce=self.aggr)\n\n# Path: torch_geometric/nn/conv/x_conv.py\nfrom math import ceil\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ELU\nfrom torch.nn import BatchNorm1d as BN\nfrom torch.nn import Conv1d\nfrom torch.nn import Linear as L\nfrom torch.nn import Sequential as S\n\nfrom torch_geometric.nn import Reshape\nfrom torch_geometric.nn.inits import reset\n\ntry:\n from torch_cluster import knn_graph\nexcept ImportError:\n knn_graph = None\n\n\nclass XConv(torch.nn.Module):\n r\"\"\"The convolutional operator on :math:`\\mathcal{X}`-transformed points\n from the `\"PointCNN: Convolution On X-Transformed Points\"\n `_ paper.\n\n .. 
math::\n \\mathbf{x}^{\\prime}_i = \\mathrm{Conv}\\left(\\mathbf{K},\n \\gamma_{\\mathbf{\\Theta}}(\\mathbf{P}_i - \\mathbf{p}_i) \\times\n \\left( h_\\mathbf{\\Theta}(\\mathbf{P}_i - \\mathbf{p}_i) \\, \\Vert \\,\n \\mathbf{x}_i \\right) \\right),\n\n where :math:`\\mathbf{K}` and :math:`\\mathbf{P}_i` denote the trainable\n filter and neighboring point positions of :math:`\\mathbf{x}_i`,\n respectively.\n :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`h_{\\mathbf{\\Theta}}` describe\n neural networks, *i.e.* MLPs, where :math:`h_{\\mathbf{\\Theta}}`\n individually lifts each point into a higher-dimensional space, and\n :math:`\\gamma_{\\mathbf{\\Theta}}` computes the :math:`\\mathcal{X}`-\n transformation matrix based on *all* points in a neighborhood.\n\n Args:\n in_channels (int): Size of each input sample.\n out_channels (int): Size of each output sample.\n dim (int): Point cloud dimensionality.\n kernel_size (int): Size of the convolving kernel, *i.e.* number of\n neighbors including self-loops.\n hidden_channels (int, optional): Output size of\n :math:`h_{\\mathbf{\\Theta}}`, *i.e.* dimensionality of lifted\n points. If set to :obj:`None`, will be automatically set to\n :obj:`in_channels / 4`. (default: :obj:`None`)\n dilation (int, optional): The factor by which the neighborhood is\n extended, from which :obj:`kernel_size` neighbors are then\n uniformly sampled. Can be interpreted as the dilation rate of\n classical convolutional operators. (default: :obj:`1`)\n bias (bool, optional): If set to :obj:`False`, the layer will not learn\n an additive bias. (default: :obj:`True`)\n num_workers (int): Number of workers to use for k-NN computation.\n Has no effect in case :obj:`batch` is not :obj:`None`, or the input\n lies on the GPU. (default: :obj:`1`)\n\n Shapes:\n - **input:**\n node features :math:`(|\\mathcal{V}|, F_{in})`,\n positions :math:`(|\\mathcal{V}|, D)`,\n batch vector :math:`(|\\mathcal{V}|)` *(optional)*\n - **output:**\n node features :math:`(|\\mathcal{V}|, F_{out})`\n \"\"\"\n def __init__(self, in_channels: int, out_channels: int, dim: int,\n kernel_size: int, hidden_channels: Optional[int] = None,\n dilation: int = 1, bias: bool = True, num_workers: int = 1):\n super().__init__()\n\n if knn_graph is None:\n raise ImportError('`XConv` requires `torch-cluster`.')\n\n self.in_channels = in_channels\n if hidden_channels is None:\n hidden_channels = in_channels // 4\n assert hidden_channels > 0\n self.hidden_channels = hidden_channels\n self.out_channels = out_channels\n self.dim = dim\n self.kernel_size = kernel_size\n self.dilation = dilation\n self.num_workers = num_workers\n\n C_in, C_delta, C_out = in_channels, hidden_channels, out_channels\n D, K = dim, kernel_size\n\n self.mlp1 = S(\n L(dim, C_delta),\n ELU(),\n BN(C_delta),\n L(C_delta, C_delta),\n ELU(),\n BN(C_delta),\n Reshape(-1, K, C_delta),\n )\n\n self.mlp2 = S(\n L(D * K, K**2),\n ELU(),\n BN(K**2),\n Reshape(-1, K, K),\n Conv1d(K, K**2, K, groups=K),\n ELU(),\n BN(K**2),\n Reshape(-1, K, K),\n Conv1d(K, K**2, K, groups=K),\n BN(K**2),\n Reshape(-1, K, K),\n )\n\n C_in = C_in + C_delta\n depth_multiplier = int(ceil(C_out / C_in))\n self.conv = S(\n Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),\n Reshape(-1, C_in * depth_multiplier),\n L(C_in * depth_multiplier, C_out, bias=bias),\n )\n\n self.reset_parameters()\n\n def reset_parameters(self):\n r\"\"\"Resets all learnable parameters of the module.\"\"\"\n reset(self.mlp1)\n reset(self.mlp2)\n reset(self.conv)\n\n def forward(self, x: Tensor, 
pos: Tensor, batch: Optional[Tensor] = None):\n r\"\"\"Runs the forward pass of the module.\"\"\"\n pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos\n (N, D), K = pos.size(), self.kernel_size\n\n edge_index = knn_graph(pos, K * self.dilation, batch, loop=True,\n flow='target_to_source',\n num_workers=self.num_workers)\n\n if self.dilation > 1:\n edge_index = edge_index[:, ::self.dilation]\n\n row, col = edge_index[0], edge_index[1]\n\n pos = pos[col] - pos[row]\n\n x_star = self.mlp1(pos)\n if x is not None:\n x = x.unsqueeze(-1) if x.dim() == 1 else x\n x = x[col].view(N, K, self.in_channels)\n x_star = torch.cat([x_star, x], dim=-1)\n x_star = x_star.transpose(1, 2).contiguous()\n\n transform_matrix = self.mlp2(pos.view(N, K * D))\n\n x_transformed = torch.matmul(x_star, transform_matrix)\n\n out = self.conv(x_transformed)\n\n return out\n\n def __repr__(self) -> str:\n return (f'{self.__class__.__name__}({self.in_channels}, '\n f'{self.out_channels})')\n\n# Path: torch_geometric/nn/conv/__init__.py\nfrom .message_passing import MessagePassing\nfrom .simple_conv import SimpleConv\nfrom .gcn_conv import GCNConv\nfrom .cheb_conv import ChebConv\nfrom .sage_conv import SAGEConv\nfrom .cugraph.sage_conv import CuGraphSAGEConv\nfrom .graph_conv import GraphConv\nfrom .gravnet_conv import GravNetConv\nfrom .gated_graph_conv import GatedGraphConv\nfrom .res_gated_graph_conv import ResGatedGraphConv\nfrom .gat_conv import GATConv\nfrom .cugraph.gat_conv import CuGraphGATConv\nfrom .fused_gat_conv import FusedGATConv\nfrom .gatv2_conv import GATv2Conv\nfrom .transformer_conv import TransformerConv\nfrom .agnn_conv import AGNNConv\nfrom .tag_conv import TAGConv\nfrom .gin_conv import GINConv, GINEConv\nfrom .arma_conv import ARMAConv\nfrom .sg_conv import SGConv\nfrom .appnp import APPNP\nfrom .mf_conv import MFConv\nfrom .rgcn_conv import RGCNConv, FastRGCNConv\nfrom .cugraph.rgcn_conv import CuGraphRGCNConv\nfrom .rgat_conv import RGATConv\nfrom .signed_conv import SignedConv\nfrom .dna_conv import DNAConv\nfrom .point_conv import PointNetConv\nfrom .gmm_conv import GMMConv\nfrom .spline_conv import SplineConv\nfrom .nn_conv import NNConv\nfrom .cg_conv import CGConv\nfrom .edge_conv import EdgeConv, DynamicEdgeConv\nfrom .x_conv import XConv\nfrom .ppf_conv import PPFConv\nfrom .feast_conv import FeaStConv\nfrom .point_transformer_conv import PointTransformerConv\nfrom .hypergraph_conv import HypergraphConv\nfrom .le_conv import LEConv\nfrom .pna_conv import PNAConv\nfrom .cluster_gcn_conv import ClusterGCNConv\nfrom .gen_conv import GENConv\nfrom .gcn2_conv import GCN2Conv\nfrom .pan_conv import PANConv\nfrom .wl_conv import WLConv\nfrom .wl_conv_continuous import WLConvContinuous\nfrom .film_conv import FiLMConv\nfrom .supergat_conv import SuperGATConv\nfrom .fa_conv import FAConv\nfrom .eg_conv import EGConv\nfrom .pdn_conv import PDNConv\nfrom .general_conv import GeneralConv\nfrom .hgt_conv import HGTConv\nfrom .heat_conv import HEATConv\nfrom .hetero_conv import HeteroConv\nfrom .han_conv import HANConv\nfrom .lg_conv import LGConv\nfrom .ssg_conv import SSGConv\nfrom .point_gnn_conv import PointGNNConv\nfrom .gps_conv import GPSConv\nfrom .antisymmetric_conv import AntiSymmetricConv\nfrom .dir_gnn_conv import DirGNNConv\nfrom .mixhop_conv import MixHopConv\n\nimport torch_geometric.nn.conv.utils # noqa\n\n__all__ = [\n 'MessagePassing',\n 'SimpleConv',\n 'GCNConv',\n 'ChebConv',\n 'SAGEConv',\n 'CuGraphSAGEConv',\n 'GraphConv',\n 'GravNetConv',\n 'GatedGraphConv',\n 
'ResGatedGraphConv',\n 'GATConv',\n 'CuGraphGATConv',\n 'FusedGATConv',\n 'GATv2Conv',\n 'TransformerConv',\n 'AGNNConv',\n 'TAGConv',\n 'GINConv',\n 'GINEConv',\n 'ARMAConv',\n 'SGConv',\n 'SSGConv',\n 'APPNP',\n 'MFConv',\n 'RGCNConv',\n 'FastRGCNConv',\n 'CuGraphRGCNConv',\n 'RGATConv',\n 'SignedConv',\n 'DNAConv',\n 'PointNetConv',\n 'GMMConv',\n 'SplineConv',\n 'NNConv',\n 'CGConv',\n 'EdgeConv',\n 'DynamicEdgeConv',\n 'XConv',\n 'PPFConv',\n 'FeaStConv',\n 'PointTransformerConv',\n 'HypergraphConv',\n 'LEConv',\n 'PNAConv',\n 'ClusterGCNConv',\n 'GENConv',\n 'GCN2Conv',\n 'PANConv',\n 'WLConv',\n 'WLConvContinuous',\n 'FiLMConv',\n 'SuperGATConv',\n 'FAConv',\n 'EGConv',\n 'PDNConv',\n 'GeneralConv',\n 'HGTConv',\n 'HEATConv',\n 'HeteroConv',\n 'HANConv',\n 'LGConv',\n 'PointGNNConv',\n 'GPSConv',\n 'AntiSymmetricConv',\n 'DirGNNConv',\n 'MixHopConv',\n]\n\nclasses = __all__\n\nECConv = NNConv\nPointConv = PointNetConv\n\n# Path: torch_geometric/explain/algorithm/attention_explainer.py\nimport logging\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.config import ExplanationType, ModelTaskLevel\nfrom torch_geometric.nn.conv.message_passing import MessagePassing\n\n\nclass AttentionExplainer(ExplainerAlgorithm):\n r\"\"\"An explainer that uses the attention coefficients produced by an\n attention-based GNN (*e.g.*,\n :class:`~torch_geometric.nn.conv.GATConv`,\n :class:`~torch_geometric.nn.conv.GATv2Conv`, or\n :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.\n Attention scores across layers and heads will be aggregated according to\n the :obj:`reduce` argument.\n\n Args:\n reduce (str, optional): The method to reduce the attention scores\n across layers and heads. (default: :obj:`\"max\"`)\n \"\"\"\n def __init__(self, reduce: str = 'max'):\n super().__init__()\n self.reduce = reduce\n\n def forward(\n self,\n model: torch.nn.Module,\n x: Tensor,\n edge_index: Tensor,\n *,\n target: Tensor,\n index: Optional[Union[int, Tensor]] = None,\n **kwargs,\n ) -> Explanation:\n if isinstance(x, dict) or isinstance(edge_index, dict):\n raise ValueError(f\"Heterogeneous graphs not yet supported in \"\n f\"'{self.__class__.__name__}'\")\n\n hard_edge_mask = None\n if self.model_config.task_level == ModelTaskLevel.node:\n # We need to compute the hard edge mask to properly clean up edge\n # attributions not involved during message passing:\n _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,\n num_nodes=x.size(0))\n\n alphas: List[Tensor] = []\n\n def hook(module, msg_kwargs, out):\n if 'alpha' in msg_kwargs[0]:\n alphas.append(msg_kwargs[0]['alpha'].detach())\n elif getattr(module, '_alpha', None) is not None:\n alphas.append(module._alpha.detach())\n\n hook_handles = []\n for module in model.modules(): # Register message forward hooks:\n if (isinstance(module, MessagePassing)\n and module.explain is not False):\n hook_handles.append(module.register_message_forward_hook(hook))\n\n model(x, edge_index, **kwargs)\n\n for handle in hook_handles: # Remove hooks:\n handle.remove()\n\n if len(alphas) == 0:\n raise ValueError(\"Could not collect any attention coefficients. 
\"\n \"Please ensure that your model is using \"\n \"attention-based GNN layers.\")\n\n for i, alpha in enumerate(alphas):\n alpha = alpha[:edge_index.size(1)] # Respect potential self-loops.\n if alpha.dim() == 2:\n alpha = getattr(torch, self.reduce)(alpha, dim=-1)\n if isinstance(alpha, tuple): # Respect `torch.max`:\n alpha = alpha[0]\n elif alpha.dim() > 2:\n raise ValueError(f\"Can not reduce attention coefficients of \"\n f\"shape {list(alpha.size())}\")\n alphas[i] = alpha\n\n if len(alphas) > 1:\n alpha = torch.stack(alphas, dim=-1)\n alpha = getattr(torch, self.reduce)(alpha, dim=-1)\n if isinstance(alpha, tuple): # Respect `torch.max`:\n alpha = alpha[0]\n else:\n alpha = alphas[0]\n\n alpha = self._post_process_mask(alpha, hard_edge_mask,\n apply_sigmoid=False)\n\n return Explanation(edge_mask=alpha)\n\n def supports(self) -> bool:\n explanation_type = self.explainer_config.explanation_type\n if explanation_type != ExplanationType.model:\n logging.error(f\"'{self.__class__.__name__}' only supports \"\n f\"model explanations \"\n f\"got (`explanation_type={explanation_type.value}`)\")\n return False\n\n node_mask_type = self.explainer_config.node_mask_type\n if node_mask_type is not None:\n logging.error(f\"'{self.__class__.__name__}' does not support \"\n f\"explaining input node features \"\n f\"got (`node_mask_type={node_mask_type.value}`)\")\n return False\n\n return True\n\n# Path: torch_geometric/explain/algorithm/base.py\nfrom abc import abstractmethod\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.config import (\n ExplainerConfig,\n ModelConfig,\n ModelReturnType,\n)\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.utils import k_hop_subgraph\n\n\nclass ExplainerAlgorithm(torch.nn.Module):\n r\"\"\"An abstract base class for implementing explainer algorithms.\"\"\"\n @abstractmethod\n def forward(\n self,\n model: torch.nn.Module,\n x: Union[Tensor, Dict[NodeType, Tensor]],\n edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n *,\n target: Tensor,\n index: Optional[Union[int, Tensor]] = None,\n **kwargs,\n ) -> Union[Explanation, HeteroExplanation]:\n r\"\"\"Computes the explanation.\n\n Args:\n model (torch.nn.Module): The model to explain.\n x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input\n node features of a homogeneous or heterogeneous graph.\n edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The\n input edge indices of a homogeneous or heterogeneous graph.\n target (torch.Tensor): The target of the model.\n index (Union[int, Tensor], optional): The index of the model\n output to explain. Can be a single index or a tensor of\n indices. 
(default: :obj:`None`)\n **kwargs (optional): Additional keyword arguments passed to\n :obj:`model`.\n \"\"\"\n\n @abstractmethod\n def supports(self) -> bool:\n r\"\"\"Checks if the explainer supports the user-defined settings provided\n in :obj:`self.explainer_config`, :obj:`self.model_config`.\n \"\"\"\n pass\n\n ###########################################################################\n\n @property\n def explainer_config(self) -> ExplainerConfig:\n r\"\"\"Returns the connected explainer configuration.\"\"\"\n if not hasattr(self, '_explainer_config'):\n raise ValueError(\n f\"The explanation algorithm '{self.__class__.__name__}' is \"\n f\"not yet connected to any explainer configuration. Please \"\n f\"call `{self.__class__.__name__}.connect(...)` before \"\n f\"proceeding.\")\n return self._explainer_config\n\n @property\n def model_config(self) -> ModelConfig:\n r\"\"\"Returns the connected model configuration.\"\"\"\n if not hasattr(self, '_model_config'):\n raise ValueError(\n f\"The explanation algorithm '{self.__class__.__name__}' is \"\n f\"not yet connected to any model configuration. Please call \"\n f\"`{self.__class__.__name__}.connect(...)` before \"\n f\"proceeding.\")\n return self._model_config\n\n def connect(\n self,\n explainer_config: ExplainerConfig,\n model_config: ModelConfig,\n ):\n r\"\"\"Connects an explainer and model configuration to the explainer\n algorithm.\n \"\"\"\n self._explainer_config = ExplainerConfig.cast(explainer_config)\n self._model_config = ModelConfig.cast(model_config)\n\n if not self.supports():\n raise ValueError(\n f\"The explanation algorithm '{self.__class__.__name__}' does \"\n f\"not support the given explanation settings.\")\n\n # Helper functions ########################################################\n\n @staticmethod\n def _post_process_mask(\n mask: Optional[Tensor],\n hard_mask: Optional[Tensor] = None,\n apply_sigmoid: bool = True,\n ) -> Optional[Tensor]:\n r\"\"\"\"Post processes any mask to not include any attributions of\n elements not involved during message passing.\n \"\"\"\n if mask is None:\n return mask\n\n mask = mask.detach()\n\n if apply_sigmoid:\n mask = mask.sigmoid()\n\n if hard_mask is not None and mask.size(0) == hard_mask.size(0):\n mask[~hard_mask] = 0.\n\n return mask\n\n @staticmethod\n def _get_hard_masks(\n model: torch.nn.Module,\n node_index: Optional[Union[int, Tensor]],\n edge_index: Tensor,\n num_nodes: int,\n ) -> Tuple[Optional[Tensor], Optional[Tensor]]:\n r\"\"\"Returns hard node and edge masks that only include the nodes and\n edges visited during message passing.\n \"\"\"\n if node_index is None:\n return None, None # Consider all nodes and edges.\n\n index, _, _, edge_mask = k_hop_subgraph(\n node_index,\n num_hops=ExplainerAlgorithm._num_hops(model),\n edge_index=edge_index,\n num_nodes=num_nodes,\n flow=ExplainerAlgorithm._flow(model),\n )\n\n node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)\n node_mask[index] = True\n\n return node_mask, edge_mask\n\n @staticmethod\n def _num_hops(model: torch.nn.Module) -> int:\n r\"\"\"Returns the number of hops the :obj:`model` is aggregating\n information from.\n \"\"\"\n num_hops = 0\n for module in model.modules():\n if isinstance(module, MessagePassing):\n num_hops += 1\n return num_hops\n\n @staticmethod\n def _flow(model: torch.nn.Module) -> str:\n r\"\"\"Determines the message passing flow of the :obj:`model`.\"\"\"\n for module in model.modules():\n if isinstance(module, MessagePassing):\n return module.flow\n return 
'source_to_target'\n\n def _loss_binary_classification(self, y_hat: Tensor, y: Tensor) -> Tensor:\n if self.model_config.return_type == ModelReturnType.raw:\n loss_fn = F.binary_cross_entropy_with_logits\n elif self.model_config.return_type == ModelReturnType.probs:\n loss_fn = F.binary_cross_entropy\n else:\n assert False\n\n return loss_fn(y_hat.view_as(y), y.float())\n\n def _loss_multiclass_classification(\n self,\n y_hat: Tensor,\n y: Tensor,\n ) -> Tensor:\n if self.model_config.return_type == ModelReturnType.raw:\n loss_fn = F.cross_entropy\n elif self.model_config.return_type == ModelReturnType.probs:\n loss_fn = F.nll_loss\n y_hat = y_hat.log()\n elif self.model_config.return_type == ModelReturnType.log_probs:\n loss_fn = F.nll_loss\n else:\n assert False\n\n return loss_fn(y_hat, y)\n\n def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor:\n assert self.model_config.return_type == ModelReturnType.raw\n return F.mse_loss(y_hat, y)\n\n def __repr__(self) -> str:\n return f'{self.__class__.__name__}()'\n\n# Path: torch_geometric/explain/algorithm/utils.py\nfrom typing import Dict, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.typing import EdgeType\n\n\ndef set_masks(\n model: torch.nn.Module,\n mask: Union[Tensor, Parameter],\n edge_index: Tensor,\n apply_sigmoid: bool = True,\n):\n r\"\"\"Apply mask to every graph layer in the :obj:`model`.\"\"\"\n loop_mask = edge_index[0] != edge_index[1]\n\n # Loop over layers and set masks on MessagePassing layers:\n for module in model.modules():\n if isinstance(module, MessagePassing):\n # Skip layers that have been explicitly set to `False`:\n if module.explain is False:\n continue\n\n # Convert mask to a param if it was previously registered as one.\n # This is a workaround for the fact that PyTorch does not allow\n # assignments of pure tensors to parameter attributes:\n if (not isinstance(mask, Parameter)\n and '_edge_mask' in module._parameters):\n mask = Parameter(mask)\n\n module.explain = True\n module._edge_mask = mask\n module._loop_mask = loop_mask\n module._apply_sigmoid = apply_sigmoid\n\n\ndef set_hetero_masks(\n model: torch.nn.Module,\n mask_dict: Dict[EdgeType, Union[Tensor, Parameter]],\n edge_index_dict: Dict[EdgeType, Tensor],\n apply_sigmoid: bool = True,\n):\n r\"\"\"Apply masks to every heterogeneous graph layer in the :obj:`model`\n according to edge types.\n \"\"\"\n for module in model.modules():\n if isinstance(module, torch.nn.ModuleDict):\n for edge_type in mask_dict.keys():\n if edge_type in module:\n edge_level_module = module[edge_type]\n elif '__'.join(edge_type) in module:\n edge_level_module = module['__'.join(edge_type)]\n else:\n continue\n\n set_masks(\n edge_level_module,\n mask_dict[edge_type],\n edge_index_dict[edge_type],\n apply_sigmoid=apply_sigmoid,\n )\n\n\ndef clear_masks(model: torch.nn.Module):\n r\"\"\"Clear all masks from the model.\"\"\"\n for module in model.modules():\n if isinstance(module, MessagePassing):\n if module.explain is True:\n module.explain = None\n module._edge_mask = None\n module._loop_mask = None\n module._apply_sigmoid = True\n return module\n\n# Path: torch_geometric/explain/algorithm/captum.py\nfrom enum import Enum\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain.algorithm.utils import (\n clear_masks,\n set_hetero_masks,\n set_masks,\n)\nfrom torch_geometric.explain.config import 
(\n ModelConfig,\n ModelMode,\n ModelReturnType,\n)\nfrom torch_geometric.typing import EdgeType, Metadata, NodeType\n\n\nclass MaskLevelType(Enum):\n \"\"\"Enum class for the mask level type.\"\"\"\n node = 'node'\n edge = 'edge'\n node_and_edge = 'node_and_edge'\n\n @property\n def with_edge(self) -> bool:\n return self in [MaskLevelType.edge, MaskLevelType.node_and_edge]\n\n\nclass CaptumModel(torch.nn.Module):\n def __init__(\n self,\n model: torch.nn.Module,\n mask_type: Union[str, MaskLevelType],\n output_idx: Optional[Union[int, Tensor]] = None,\n model_config: Optional[ModelConfig] = None,\n ):\n super().__init__()\n\n self.mask_type = MaskLevelType(mask_type)\n self.model = model\n self.output_idx = output_idx\n self.model_config = model_config\n\n def forward(self, mask, *args):\n \"\"\"\"\"\" # noqa: D419\n # The mask tensor, which comes from Captum's attribution methods,\n # contains the number of samples in dimension 0. Since we are\n # working with only one sample, we squeeze the tensors below.\n assert mask.shape[0] == 1, \"Dimension 0 of input should be 1\"\n if self.mask_type == MaskLevelType.edge:\n assert len(args) >= 2, \"Expects at least x and edge_index as args.\"\n if self.mask_type == MaskLevelType.node:\n assert len(args) >= 1, \"Expects at least edge_index as args.\"\n if self.mask_type == MaskLevelType.node_and_edge:\n assert args[0].shape[0] == 1, \"Dimension 0 of input should be 1\"\n assert len(args[1:]) >= 1, \"Expects at least edge_index as args.\"\n\n # Set edge mask:\n if self.mask_type == MaskLevelType.edge:\n set_masks(self.model, mask.squeeze(0), args[1],\n apply_sigmoid=False)\n elif self.mask_type == MaskLevelType.node_and_edge:\n set_masks(self.model, args[0].squeeze(0), args[1],\n apply_sigmoid=False)\n args = args[1:]\n\n if self.mask_type == MaskLevelType.edge:\n x = self.model(*args)\n\n else:\n x = self.model(mask.squeeze(0), *args)\n\n return self.postprocess(x)\n\n def postprocess(self, x: Tensor) -> Tensor:\n if self.mask_type.with_edge:\n clear_masks(self.model)\n\n if self.output_idx is not None: # Filter by output index:\n x = x[self.output_idx]\n if (isinstance(self.output_idx, int)\n or self.output_idx.dim() == 0):\n x = x.unsqueeze(0)\n\n # Convert binary classification to multi-class classification:\n if (self.model_config is not None\n and self.model_config.mode == ModelMode.binary_classification):\n assert self.model_config.return_type == ModelReturnType.probs\n x = x.view(-1, 1)\n x = torch.cat([1 - x, x], dim=-1)\n\n return x\n\n\n# TODO(jinu) Is there any point of inheriting from `CaptumModel`\nclass CaptumHeteroModel(CaptumModel):\n def __init__(\n self,\n model: torch.nn.Module,\n mask_type: Union[str, MaskLevelType],\n output_idx: Optional[Union[int, Tensor]],\n metadata: Metadata,\n model_config: Optional[ModelConfig] = None,\n ):\n super().__init__(model, mask_type, output_idx, model_config)\n self.node_types = metadata[0]\n self.edge_types = metadata[1]\n self.num_node_types = len(self.node_types)\n self.num_edge_types = len(self.edge_types)\n\n def _captum_data_to_hetero_data(\n self, *args\n ) -> Tuple[Dict[NodeType, Tensor], Dict[EdgeType, Tensor], Optional[Dict[\n EdgeType, Tensor]]]:\n \"\"\"Converts tuple of tensors to `x_dict`, `edge_index_dict` and\n `edge_mask_dict`.\n \"\"\"\n if self.mask_type == MaskLevelType.node:\n node_tensors = args[:self.num_node_types]\n node_tensors = [mask.squeeze(0) for mask in node_tensors]\n x_dict = dict(zip(self.node_types, node_tensors))\n edge_index_dict = 
args[self.num_node_types]\n elif self.mask_type == MaskLevelType.edge:\n edge_mask_tensors = args[:self.num_edge_types]\n x_dict = args[self.num_edge_types]\n edge_index_dict = args[self.num_edge_types + 1]\n else:\n node_tensors = args[:self.num_node_types]\n node_tensors = [mask.squeeze(0) for mask in node_tensors]\n x_dict = dict(zip(self.node_types, node_tensors))\n edge_mask_tensors = args[self.num_node_types:self.num_node_types +\n self.num_edge_types]\n edge_index_dict = args[self.num_node_types + self.num_edge_types]\n\n if self.mask_type.with_edge:\n edge_mask_tensors = [mask.squeeze(0) for mask in edge_mask_tensors]\n edge_mask_dict = dict(zip(self.edge_types, edge_mask_tensors))\n else:\n edge_mask_dict = None\n return x_dict, edge_index_dict, edge_mask_dict\n\n def forward(self, *args):\n # Validate args:\n if self.mask_type == MaskLevelType.node:\n assert len(args) >= self.num_node_types + 1\n len_remaining_args = len(args) - (self.num_node_types + 1)\n elif self.mask_type == MaskLevelType.edge:\n assert len(args) >= self.num_edge_types + 2\n len_remaining_args = len(args) - (self.num_edge_types + 2)\n else:\n assert len(args) >= self.num_node_types + self.num_edge_types + 1\n len_remaining_args = len(args) - (self.num_node_types +\n self.num_edge_types + 1)\n\n # Get main args:\n (x_dict, edge_index_dict,\n edge_mask_dict) = self._captum_data_to_hetero_data(*args)\n\n if self.mask_type.with_edge:\n set_hetero_masks(self.model, edge_mask_dict, edge_index_dict)\n\n if len_remaining_args > 0:\n # If there are args other than `x_dict` and `edge_index_dict`\n x = self.model(x_dict, edge_index_dict,\n *args[-len_remaining_args:])\n else:\n x = self.model(x_dict, edge_index_dict)\n\n return self.postprocess(x)\n\n\n\ndef _to_edge_mask(edge_index: Tensor) -> Tensor:\n num_edges = edge_index.shape[1]\n return torch.ones(num_edges, requires_grad=True, device=edge_index.device)\n\n\ndef to_captum_input(\n x: Union[Tensor, Dict[NodeType, Tensor]],\n edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n mask_type: Union[str, MaskLevelType],\n *args,\n) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:\n r\"\"\"Given :obj:`x`, :obj:`edge_index` and :obj:`mask_type`, converts it\n to a format to use in `Captum `_ attribution\n methods. Returns :obj:`inputs` and :obj:`additional_forward_args`\n required for :captum:`Captum's` :obj:`attribute` functions.\n See :meth:`~torch_geometric.nn.models.to_captum_model` for example usage.\n\n Args:\n x (torch.Tensor or Dict[NodeType, torch.Tensor]): The node features.\n For heterogeneous graphs this is a dictionary holding node featues\n for each node type.\n edge_index(torch.Tensor or Dict[EdgeType, torch.Tensor]): The edge\n indices. For heterogeneous graphs this is a dictionary holding the\n :obj:`edge index` for each edge type.\n mask_type (str): Denotes the type of mask to be created with\n a Captum explainer. 
Valid inputs are :obj:`\"edge\"`, :obj:`\"node\"`,\n and :obj:`\"node_and_edge\"`.\n *args: Additional forward arguments of the model being explained\n which will be added to :obj:`additional_forward_args`.\n \"\"\"\n mask_type = MaskLevelType(mask_type)\n\n additional_forward_args = []\n if isinstance(x, Tensor) and isinstance(edge_index, Tensor):\n if mask_type == MaskLevelType.node:\n inputs = [x.unsqueeze(0)]\n elif mask_type == MaskLevelType.edge:\n inputs = [_to_edge_mask(edge_index).unsqueeze(0)]\n additional_forward_args.append(x)\n else:\n inputs = [x.unsqueeze(0), _to_edge_mask(edge_index).unsqueeze(0)]\n additional_forward_args.append(edge_index)\n\n elif isinstance(x, Dict) and isinstance(edge_index, Dict):\n node_types = x.keys()\n edge_types = edge_index.keys()\n inputs = []\n if mask_type == MaskLevelType.node:\n for key in node_types:\n inputs.append(x[key].unsqueeze(0))\n elif mask_type == MaskLevelType.edge:\n for key in edge_types:\n inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0))\n additional_forward_args.append(x)\n else:\n for key in node_types:\n inputs.append(x[key].unsqueeze(0))\n for key in edge_types:\n inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0))\n additional_forward_args.append(edge_index)\n\n else:\n raise ValueError(\n \"'x' and 'edge_index' need to be either\"\n f\"'Dict' or 'Tensor' got({type(x)}, {type(edge_index)})\")\n\n additional_forward_args.extend(args)\n\n return tuple(inputs), tuple(additional_forward_args)\n\n\ndef captum_output_to_dicts(\n captum_attrs: Tuple[Tensor, ...],\n mask_type: Union[str, MaskLevelType],\n metadata: Metadata,\n) -> Tuple[Optional[Dict[NodeType, Tensor]], Optional[Dict[EdgeType, Tensor]]]:\n r\"\"\"Convert the output of `Captum `_ attribution\n methods which is a tuple of attributions to two dictionaries with node and\n edge attribution tensors. This function is used while explaining\n :class:`~torch_geometric.data.HeteroData` objects.\n See :meth:`~torch_geometric.nn.models.to_captum_model` for example usage.\n\n Args:\n captum_attrs (tuple[torch.Tensor]): The output of attribution methods.\n mask_type (str): Denotes the type of mask to be created with\n a Captum explainer. Valid inputs are :obj:`\"edge\"`, :obj:`\"node\"`,\n and :obj:`\"node_and_edge\"`:\n\n 1. :obj:`\"edge\"`: :obj:`captum_attrs` contains only edge\n attributions. The returned tuple has no node attributions, and\n an edge attribution dictionary edge types as keys and edge mask\n tensors of shape :obj:`[num_edges]` as values.\n\n 2. :obj:`\"node\"`: :obj:`captum_attrs` contains only node\n attributions. The returned tuple has a node attribution\n dictionary with node types as keys and node mask tensors of\n shape :obj:`[num_nodes, num_features]` as values, and no edge\n attributions.\n\n 3. 
:obj:`\"node_and_edge\"`: :obj:`captum_attrs` contains node and\n edge attributions.\n\n metadata (Metadata): The metadata of the heterogeneous graph.\n \"\"\"\n mask_type = MaskLevelType(mask_type)\n node_types = metadata[0]\n edge_types = metadata[1]\n x_attr_dict, edge_attr_dict = None, None\n captum_attrs = [captum_attr.squeeze(0) for captum_attr in captum_attrs]\n if mask_type == MaskLevelType.node:\n assert len(node_types) == len(captum_attrs)\n x_attr_dict = dict(zip(node_types, captum_attrs))\n elif mask_type == MaskLevelType.edge:\n assert len(edge_types) == len(captum_attrs)\n edge_attr_dict = dict(zip(edge_types, captum_attrs))\n elif mask_type == MaskLevelType.node_and_edge:\n assert len(edge_types) + len(node_types) == len(captum_attrs)\n x_attr_dict = dict(zip(node_types, captum_attrs[:len(node_types)]))\n edge_attr_dict = dict(zip(edge_types, captum_attrs[len(node_types):]))\n return x_attr_dict, edge_attr_dict\n\n\ndef convert_captum_output(\n captum_attrs: Tuple[Tensor, ...],\n mask_type: Union[str, MaskLevelType],\n metadata: Optional[Metadata] = None,\n):\n r\"\"\"Convert the output of `Captum.ai `_ attribution\n methods which is a tuple of attributions to either\n :obj:`(node_mask, edge_mask)` or :obj:`(node_mask_dict, edge_mask_dict)`.\n \"\"\"\n mask_type = MaskLevelType(mask_type)\n if metadata is not None:\n return captum_output_to_dicts(captum_attrs, mask_type, metadata)\n\n node_mask = edge_mask = None\n if mask_type == MaskLevelType.edge:\n edge_mask = captum_attrs[0].squeeze(0)\n elif mask_type == MaskLevelType.node:\n node_mask = captum_attrs[0].squeeze(0)\n else:\n node_mask = captum_attrs[0].squeeze(0)\n edge_mask = captum_attrs[1].squeeze(0)\n\n return node_mask, edge_mask\n\n# Path: torch_geometric/explain/algorithm/captum_explainer.py\nimport inspect\nimport logging\nimport warnings\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.algorithm.captum import (\n CaptumHeteroModel,\n CaptumModel,\n MaskLevelType,\n convert_captum_output,\n to_captum_input,\n)\nfrom torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass CaptumExplainer(ExplainerAlgorithm):\n \"\"\"A `Captum `__-based explainer for identifying compact\n subgraph structures and node features that play a crucial role in the\n predictions made by a GNN.\n\n This explainer algorithm uses :captum:`null` `Captum `_\n to compute attributions.\n\n Currently, the following attribution methods are supported:\n\n * :class:`captum.attr.IntegratedGradients`\n * :class:`captum.attr.Saliency`\n * :class:`captum.attr.InputXGradient`\n * :class:`captum.attr.Deconvolution`\n * :class:`captum.attr.ShapleyValueSampling`\n...\n", "needle_token_start": 13687, "needle_token_end": 13748, "code_context_ntokens": 16140, "output": ["```python\ndef _to_edge_mask(edge_index: Tensor) -> Tensor:\n num_edges = edge_index.shape[1]\n return torch.ones(num_edges, requires_grad=True, device=edge_index.device)\n```"]} +{"repo": "pyg-team/pytorch_geometric", "name": "get_self_loop_attr", "language": "python", "path": "torch_geometric/utils/loop.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. 
**Purpose**: The function retrieves the attributes or weights of self-loops (edges that connect a node to itself) for each node in a graph. If a self-loop is absent for a node, the attribute is set to zero. If no attributes are provided, a default attribute of one is used for each self-loop.\n2. **Input**: \n - A tensor representing the indices of edges in the graph.\n - An optional tensor for the attributes or weights of these edges.\n - An optional integer specifying the total number of nodes in the graph.\n3. **Output**: A tensor containing the attributes or weights of self-loops for each node. If a node does not have a self-loop, the corresponding value is zero.\n4. **Procedure**: \n - First, identify which edges are self-loops by checking where the start and end nodes of each edge are the same.\n - Extract the indices and attributes of these self-loops.\n - Create a tensor initialized with zeros (or ones if no attributes are provided) to represent the attributes of self-loops for all nodes.\n - Place the extracted attributes into this tensor at the positions corresponding to nodes with self-loops.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: torch_geometric/utils/dropout.py\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import cumsum, degree, sort_edge_index, subgraph\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef filter_adj(row: Tensor, col: Tensor, edge_attr: OptTensor,\n mask: Tensor) -> Tuple[Tensor, Tensor, OptTensor]:\n return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]\n\n\n@deprecated(\"use 'dropout_edge' instead\")\ndef dropout_adj(\n edge_index: Tensor,\n edge_attr: OptTensor = None,\n p: float = 0.5,\n force_undirected: bool = False,\n num_nodes: Optional[int] = None,\n training: bool = True,\n) -> Tuple[Tensor, OptTensor]:\n r\"\"\"Randomly drops edges from the adjacency matrix\n :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from\n a Bernoulli distribution.\n\n .. warning::\n\n :class:`~torch_geometric.utils.dropout_adj` is deprecated and will\n be removed in a future release.\n Use :class:`torch_geometric.utils.dropout_edge` instead.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional\n edge features. (default: :obj:`None`)\n p (float, optional): Dropout probability. (default: :obj:`0.5`)\n force_undirected (bool, optional): If set to :obj:`True`, will either\n drop or keep both edges of an undirected edge.\n (default: :obj:`False`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n ... 
[1, 0, 2, 1, 3, 2]])\n >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6])\n >>> dropout_adj(edge_index, edge_attr)\n (tensor([[0, 1, 2, 3],\n [1, 2, 3, 2]]),\n tensor([1, 3, 5, 6]))\n\n >>> # The returned graph is kept undirected\n >>> dropout_adj(edge_index, edge_attr, force_undirected=True)\n (tensor([[0, 1, 2, 1, 2, 3],\n [1, 2, 3, 0, 1, 2]]),\n tensor([1, 3, 5, 1, 3, 5]))\n \"\"\"\n if p < 0. or p > 1.:\n raise ValueError(f'Dropout probability has to be between 0 and 1 '\n f'(got {p}')\n\n if not training or p == 0.0:\n return edge_index, edge_attr\n\n row, col = edge_index\n\n mask = torch.rand(row.size(0), device=edge_index.device) >= p\n\n if force_undirected:\n mask[row > col] = False\n\n row, col, edge_attr = filter_adj(row, col, edge_attr, mask)\n\n if force_undirected:\n edge_index = torch.stack(\n [torch.cat([row, col], dim=0),\n torch.cat([col, row], dim=0)], dim=0)\n if edge_attr is not None:\n edge_attr = torch.cat([edge_attr, edge_attr], dim=0)\n else:\n edge_index = torch.stack([row, col], dim=0)\n\n return edge_index, edge_attr\n\n\ndef dropout_node(\n edge_index: Tensor,\n p: float = 0.5,\n num_nodes: Optional[int] = None,\n training: bool = True,\n relabel_nodes: bool = False,\n) -> Tuple[Tensor, Tensor, Tensor]:\n r\"\"\"Randomly drops nodes from the adjacency matrix\n :obj:`edge_index` with probability :obj:`p` using samples from\n a Bernoulli distribution.\n\n The method returns (1) the retained :obj:`edge_index`, (2) the edge mask\n indicating which edges were retained. (3) the node mask indicating\n which nodes were retained.\n\n Args:\n edge_index (LongTensor): The edge indices.\n p (float, optional): Dropout probability. (default: :obj:`0.5`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n relabel_nodes (bool, optional): If set to `True`, the resulting\n `edge_index` will be relabeled to hold consecutive indices\n starting from zero.\n\n :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2]])\n >>> edge_index, edge_mask, node_mask = dropout_node(edge_index)\n >>> edge_index\n tensor([[0, 1],\n [1, 0]])\n >>> edge_mask\n tensor([ True, True, False, False, False, False])\n >>> node_mask\n tensor([ True, True, False, False])\n \"\"\"\n if p < 0. 
or p > 1.:\n raise ValueError(f'Dropout probability has to be between 0 and 1 '\n f'(got {p}')\n\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n if not training or p == 0.0:\n node_mask = edge_index.new_ones(num_nodes, dtype=torch.bool)\n edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool)\n return edge_index, edge_mask, node_mask\n\n prob = torch.rand(num_nodes, device=edge_index.device)\n node_mask = prob > p\n edge_index, _, edge_mask = subgraph(\n node_mask,\n edge_index,\n relabel_nodes=relabel_nodes,\n num_nodes=num_nodes,\n return_edge_mask=True,\n )\n return edge_index, edge_mask, node_mask\n\n\ndef dropout_edge(edge_index: Tensor, p: float = 0.5,\n force_undirected: bool = False,\n training: bool = True) -> Tuple[Tensor, Tensor]:\n r\"\"\"Randomly drops edges from the adjacency matrix\n :obj:`edge_index` with probability :obj:`p` using samples from\n a Bernoulli distribution.\n\n The method returns (1) the retained :obj:`edge_index`, (2) the edge mask\n or index indicating which edges were retained, depending on the argument\n :obj:`force_undirected`.\n\n Args:\n edge_index (LongTensor): The edge indices.\n p (float, optional): Dropout probability. (default: :obj:`0.5`)\n force_undirected (bool, optional): If set to :obj:`True`, will either\n drop or keep both edges of an undirected edge.\n (default: :obj:`False`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n\n :rtype: (:class:`LongTensor`, :class:`BoolTensor` or :class:`LongTensor`)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n ... [1, 0, 2, 1, 3, 2]])\n >>> edge_index, edge_mask = dropout_edge(edge_index)\n >>> edge_index\n tensor([[0, 1, 2, 2],\n [1, 2, 1, 3]])\n >>> edge_mask # masks indicating which edges are retained\n tensor([ True, False, True, True, True, False])\n\n >>> edge_index, edge_id = dropout_edge(edge_index,\n ... force_undirected=True)\n >>> edge_index\n tensor([[0, 1, 2, 1, 2, 3],\n [1, 2, 3, 0, 1, 2]])\n >>> edge_id # indices indicating which edges are retained\n tensor([0, 2, 4, 0, 2, 4])\n \"\"\"\n if p < 0. or p > 1.:\n raise ValueError(f'Dropout probability has to be between 0 and 1 '\n f'(got {p}')\n\n if not training or p == 0.0:\n edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool)\n return edge_index, edge_mask\n\n row, col = edge_index\n\n edge_mask = torch.rand(row.size(0), device=edge_index.device) >= p\n\n if force_undirected:\n edge_mask[row > col] = False\n\n edge_index = edge_index[:, edge_mask]\n\n if force_undirected:\n edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)\n edge_mask = edge_mask.nonzero().repeat((2, 1)).squeeze()\n\n return edge_index, edge_mask\n\n\ndef dropout_path(edge_index: Tensor, p: float = 0.2, walks_per_node: int = 1,\n walk_length: int = 3, num_nodes: Optional[int] = None,\n is_sorted: bool = False,\n training: bool = True) -> Tuple[Tensor, Tensor]:\n r\"\"\"Drops edges from the adjacency matrix :obj:`edge_index`\n based on random walks. The source nodes to start random walks from are\n sampled from :obj:`edge_index` with probability :obj:`p`, following\n a Bernoulli distribution.\n\n The method returns (1) the retained :obj:`edge_index`, (2) the edge mask\n indicating which edges were retained.\n\n Args:\n edge_index (LongTensor): The edge indices.\n p (float, optional): Sample probability. 
(default: :obj:`0.2`)\n walks_per_node (int, optional): The number of walks per node, same as\n :class:`~torch_geometric.nn.models.Node2Vec`. (default: :obj:`1`)\n walk_length (int, optional): The walk length, same as\n :class:`~torch_geometric.nn.models.Node2Vec`. (default: :obj:`3`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n is_sorted (bool, optional): If set to :obj:`True`, will expect\n :obj:`edge_index` to be already sorted row-wise.\n (default: :obj:`False`)\n training (bool, optional): If set to :obj:`False`, this operation is a\n no-op. (default: :obj:`True`)\n\n :rtype: (:class:`LongTensor`, :class:`BoolTensor`)\n\n Example:\n >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n...\n# Path: torch_geometric/utils/embedding.py\nimport warnings\nfrom typing import Any, List\n\nimport torch\nfrom torch import Tensor\n\n\ndef get_embeddings(\n model: torch.nn.Module,\n *args: Any,\n **kwargs: Any,\n) -> List[Tensor]:\n \"\"\"Returns the output embeddings of all\n :class:`~torch_geometric.nn.conv.MessagePassing` layers in\n :obj:`model`.\n\n Internally, this method registers forward hooks on all\n :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,\n and runs the forward pass of the :obj:`model` by calling\n :obj:`model(*args, **kwargs)`.\n\n Args:\n model (torch.nn.Module): The message passing model.\n *args: Arguments passed to the model.\n **kwargs (optional): Additional keyword arguments passed to the model.\n \"\"\"\n from torch_geometric.nn import MessagePassing\n\n embeddings: List[Tensor] = []\n\n def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:\n # Clone output in case it will be later modified in-place:\n outputs = outputs[0] if isinstance(outputs, tuple) else outputs\n assert isinstance(outputs, Tensor)\n embeddings.append(outputs.clone())\n\n hook_handles = []\n for module in model.modules(): # Register forward hooks:\n if isinstance(module, MessagePassing):\n hook_handles.append(module.register_forward_hook(hook))\n\n if len(hook_handles) == 0:\n warnings.warn(\"The 'model' does not have any 'MessagePassing' layers\")\n\n training = model.training\n model.eval()\n with torch.no_grad():\n model(*args, **kwargs)\n model.train(training)\n\n for handle in hook_handles: # Remove hooks:\n handle.remove()\n\n return embeddings\n\n# Path: torch_geometric/utils/geodesic.py\nimport multiprocessing as mp\nimport warnings\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\n\ndef geodesic_distance( # noqa: D417\n pos: Tensor,\n face: Tensor,\n src: Optional[Tensor] = None,\n dst: Optional[Tensor] = None,\n norm: bool = True,\n max_distance: Optional[float] = None,\n num_workers: int = 0,\n # Backward compatibility for `dest`:\n **kwargs: Optional[Tensor],\n) -> Tensor:\n r\"\"\"Computes (normalized) geodesic distances of a mesh given by :obj:`pos`\n and :obj:`face`. If :obj:`src` and :obj:`dst` are given, this method only\n computes the geodesic distances for the respective source and target\n node-pairs.\n\n .. note::\n\n This function requires the :obj:`gdist` package.\n To install, run :obj:`pip install cython && pip install gdist`.\n\n Args:\n pos (torch.Tensor): The node positions.\n face (torch.Tensor): The face indices.\n src (torch.Tensor, optional): If given, only compute geodesic distances\n for the specified source indices. 
(default: :obj:`None`)\n dst (torch.Tensor, optional): If given, only compute geodesic distances\n for the specified target indices. (default: :obj:`None`)\n norm (bool, optional): Normalizes geodesic distances by\n :math:`\\sqrt{\\textrm{area}(\\mathcal{M})}`. (default: :obj:`True`)\n max_distance (float, optional): If given, only yields results for\n geodesic distances less than :obj:`max_distance`. This will speed\n up runtime dramatically. (default: :obj:`None`)\n num_workers (int, optional): How many subprocesses to use for\n calculating geodesic distances.\n :obj:`0` means that computation takes place in the main process.\n :obj:`-1` means that the available amount of CPU cores is used.\n (default: :obj:`0`)\n\n :rtype: :class:`Tensor`\n\n Example:\n >>> pos = torch.tensor([[0.0, 0.0, 0.0],\n ... [2.0, 0.0, 0.0],\n ... [0.0, 2.0, 0.0],\n ... [2.0, 2.0, 0.0]])\n >>> face = torch.tensor([[0, 0],\n ... [1, 2],\n ... [3, 3]])\n >>> geodesic_distance(pos, face)\n [[0, 1, 1, 1.4142135623730951],\n [1, 0, 1.4142135623730951, 1],\n [1, 1.4142135623730951, 0, 1],\n [1.4142135623730951, 1, 1, 0]]\n \"\"\"\n import gdist\n\n if 'dest' in kwargs:\n dst = kwargs['dest']\n warnings.warn(\"'dest' attribute in 'geodesic_distance' is deprecated \"\n \"and will be removed in a future release. Use the 'dst' \"\n \"argument instead.\")\n\n max_distance = float('inf') if max_distance is None else max_distance\n\n if norm:\n area = (pos[face[1]] - pos[face[0]]).cross(\n pos[face[2]] - pos[face[0]],\n dim=1,\n )\n scale = float((area.norm(p=2, dim=1) / 2).sum().sqrt())\n else:\n scale = 1.0\n\n dtype = pos.dtype\n\n pos = pos.detach().cpu().to(torch.double).numpy()\n face = face.detach().t().cpu().to(torch.int).numpy()\n\n if src is None and dst is None:\n out = gdist.local_gdist_matrix(pos, face,\n max_distance * scale).toarray() / scale\n return torch.from_numpy(out).to(dtype)\n\n if src is None:\n src = torch.arange(pos.shape[0], dtype=torch.int).numpy()\n else:\n src = src.detach().cpu().to(torch.int).numpy()\n assert src is not None\n\n dst = None if dst is None else dst.detach().cpu().to(torch.int).numpy()\n\n def _parallel_loop(\n pos: Tensor,\n face: Tensor,\n src: Tensor,\n dst: Optional[Tensor],\n max_distance: float,\n scale: float,\n i: int,\n dtype: torch.dtype,\n ) -> Tensor:\n s = src[i:i + 1]\n d = None if dst is None else dst[i:i + 1]\n out = gdist.compute_gdist(pos, face, s, d, max_distance * scale)\n out = out / scale\n return torch.from_numpy(out).to(dtype)\n\n num_workers = mp.cpu_count() if num_workers <= -1 else num_workers\n if num_workers > 0:\n with mp.Pool(num_workers) as pool:\n outs = pool.starmap(\n _parallel_loop,\n [(pos, face, src, dst, max_distance, scale, i, dtype)\n for i in range(len(src))])\n else:\n outs = [\n _parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype)\n for i in range(len(src))\n ]\n\n out = torch.cat(outs, dim=0)\n\n if dst is None:\n out = out.view(-1, pos.shape[0])\n\n return out\n\n# Path: torch_geometric/utils/isolated.py\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import remove_self_loops, segregate_self_loops\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef contains_isolated_nodes(\n edge_index: Tensor,\n num_nodes: Optional[int] = None,\n) -> bool:\n r\"\"\"Returns :obj:`True` if the graph given by :attr:`edge_index` contains\n isolated nodes.\n\n Args:\n edge_index (LongTensor): The edge indices.\n num_nodes (int, optional): The number of nodes, 
*i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n :rtype: bool\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... [1, 0, 0]])\n >>> contains_isolated_nodes(edge_index)\n False\n\n >>> contains_isolated_nodes(edge_index, num_nodes=3)\n True\n \"\"\"\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n edge_index, _ = remove_self_loops(edge_index)\n return torch.unique(edge_index.view(-1)).numel() < num_nodes\n\n\ndef remove_isolated_nodes(\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor], Tensor]:\n r\"\"\"Removes the isolated nodes from the graph given by :attr:`edge_index`\n with optional edge attributes :attr:`edge_attr`.\n In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter\n out isolated node features later on.\n Self-loops are preserved for non-isolated nodes.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional\n edge features. (default: :obj:`None`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n :rtype: (LongTensor, Tensor, BoolTensor)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... [1, 0, 0]])\n >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index)\n >>> mask # node mask (2 nodes)\n tensor([True, True])\n\n >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index,\n ... num_nodes=3)\n >>> mask # node mask (3 nodes)\n tensor([True, True, False])\n \"\"\"\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n out = segregate_self_loops(edge_index, edge_attr)\n edge_index, edge_attr, loop_edge_index, loop_edge_attr = out\n\n mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)\n mask[edge_index.view(-1)] = 1\n\n assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device)\n assoc[mask] = torch.arange(mask.sum(), device=assoc.device) # type: ignore\n edge_index = assoc[edge_index]\n\n loop_mask = torch.zeros_like(mask)\n loop_mask[loop_edge_index[0]] = 1\n loop_mask = loop_mask & mask\n loop_assoc = torch.full_like(assoc, -1)\n loop_assoc[loop_edge_index[0]] = torch.arange(loop_edge_index.size(1),\n device=loop_assoc.device)\n loop_idx = loop_assoc[loop_mask]\n loop_edge_index = assoc[loop_edge_index[:, loop_idx]]\n\n edge_index = torch.cat([edge_index, loop_edge_index], dim=1)\n\n if edge_attr is not None:\n assert loop_edge_attr is not None\n loop_edge_attr = loop_edge_attr[loop_idx]\n edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0)\n\n return edge_index, edge_attr, mask\n\n# Path: torch_geometric/utils/laplacian.py\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import add_self_loops, remove_self_loops, scatter\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef get_laplacian(\n edge_index: Tensor,\n edge_weight: OptTensor = None,\n normalization: Optional[str] = None,\n dtype: Optional[torch.dtype] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n r\"\"\"Computes the graph Laplacian of the graph given by :obj:`edge_index`\n and optional :obj:`edge_weight`.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_weight (Tensor, optional): One-dimensional edge weights.\n (default: :obj:`None`)\n normalization (str, optional): The normalization scheme for 
the graph\n Laplacian (default: :obj:`None`):\n\n 1. :obj:`None`: No normalization\n :math:`\\mathbf{L} = \\mathbf{D} - \\mathbf{A}`\n\n 2. :obj:`\"sym\"`: Symmetric normalization\n :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1/2} \\mathbf{A}\n \\mathbf{D}^{-1/2}`\n\n 3. :obj:`\"rw\"`: Random-walk normalization\n :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1} \\mathbf{A}`\n dtype (torch.dtype, optional): The desired data type of returned tensor\n in case :obj:`edge_weight=None`. (default: :obj:`None`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 1, 2],\n ... [1, 0, 2, 1]])\n >>> edge_weight = torch.tensor([1., 2., 2., 4.])\n\n >>> # No normalization\n >>> lap = get_laplacian(edge_index, edge_weight)\n\n >>> # Symmetric normalization\n >>> lap_sym = get_laplacian(edge_index, edge_weight,\n normalization='sym')\n\n >>> # Random-walk normalization\n >>> lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw')\n \"\"\"\n if normalization is not None:\n assert normalization in ['sym', 'rw'] # 'Invalid normalization'\n\n edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)\n\n if edge_weight is None:\n edge_weight = torch.ones(edge_index.size(1), dtype=dtype,\n device=edge_index.device)\n\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n row, col = edge_index[0], edge_index[1]\n deg = scatter(edge_weight, row, 0, dim_size=num_nodes, reduce='sum')\n\n if normalization is None:\n # L = D - A.\n edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n edge_weight = torch.cat([-edge_weight, deg], dim=0)\n elif normalization == 'sym':\n # Compute A_norm = -D^{-1/2} A D^{-1/2}.\n deg_inv_sqrt = deg.pow_(-0.5)\n deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)\n edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n\n # L = I - A_norm.\n assert isinstance(edge_weight, Tensor)\n edge_index, edge_weight = add_self_loops( #\n edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes)\n else:\n # Compute A_norm = -D^{-1} A.\n deg_inv = 1.0 / deg\n deg_inv.masked_fill_(deg_inv == float('inf'), 0)\n edge_weight = deg_inv[row] * edge_weight\n\n # L = I - A_norm.\n assert isinstance(edge_weight, Tensor)\n edge_index, edge_weight = add_self_loops( #\n edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes)\n\n return edge_index, edge_weight\n\n# Path: torch_geometric/utils/loop.py\nimport typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.utils import scatter\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\nfrom torch_geometric.utils.sparse import (\n is_torch_sparse_tensor,\n to_edge_index,\n to_torch_coo_tensor,\n to_torch_csr_tensor,\n)\n\nif typing.TYPE_CHECKING:\n from typing import overload\nelse:\n from torch.jit import _overload as overload\n\n\ndef contains_self_loops(edge_index: Tensor) -> bool:\n r\"\"\"Returns :obj:`True` if the graph given by :attr:`edge_index` contains\n self-loops.\n\n Args:\n edge_index (LongTensor): The edge indices.\n\n :rtype: bool\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... [1, 0, 0]])\n >>> contains_self_loops(edge_index)\n True\n\n >>> edge_index = torch.tensor([[0, 1, 1],\n ... 
[1, 0, 2]])\n >>> contains_self_loops(edge_index)\n False\n \"\"\"\n mask = edge_index[0] == edge_index[1]\n return mask.sum().item() > 0\n\n\n@overload\ndef remove_self_loops(\n edge_index: Tensor,\n edge_attr: None = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef remove_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef remove_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\ndef remove_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n r\"\"\"Removes every self-loop in the graph given by :attr:`edge_index`, so\n that :math:`(i,i) \\not\\in \\mathcal{E}` for every :math:`i \\in \\mathcal{V}`.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional\n edge features. (default: :obj:`None`)\n\n :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n Example:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... [1, 0, 0]])\n >>> edge_attr = [[1, 2], [3, 4], [5, 6]]\n >>> edge_attr = torch.tensor(edge_attr)\n >>> remove_self_loops(edge_index, edge_attr)\n (tensor([[0, 1],\n [1, 0]]),\n tensor([[1, 2],\n [3, 4]]))\n \"\"\"\n size: Optional[Tuple[int, int]] = None\n if not typing.TYPE_CHECKING and torch.jit.is_scripting():\n layout: Optional[int] = None\n else:\n layout: Optional[torch.layout] = None\n\n value: Optional[Tensor] = None\n if is_torch_sparse_tensor(edge_index):\n layout = edge_index.layout\n size = (edge_index.size(0), edge_index.size(1))\n edge_index, value = to_edge_index(edge_index)\n\n is_undirected = False\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n is_undirected = edge_index.is_undirected\n\n mask = edge_index[0] != edge_index[1]\n edge_index = edge_index[:, mask]\n\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n edge_index._is_undirected = is_undirected\n\n if layout is not None:\n assert edge_attr is None\n assert value is not None\n value = value[mask]\n if str(layout) == 'torch.sparse_coo': # str(...) for TorchScript :(\n return to_torch_coo_tensor(edge_index, value, size, True), None\n elif str(layout) == 'torch.sparse_csr':\n return to_torch_csr_tensor(edge_index, value, size, True), None\n raise ValueError(f\"Unexpected sparse tensor layout (got '{layout}')\")\n\n if edge_attr is None:\n return edge_index, None\n else:\n return edge_index, edge_attr[mask]\n\n\n@overload\ndef segregate_self_loops(\n edge_index: Tensor,\n edge_attr: None = None,\n) -> Tuple[Tensor, None, Tensor, None]:\n pass\n\n\n@overload\ndef segregate_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n pass\n\n\n@overload\ndef segregate_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:\n pass\n\n\ndef segregate_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:\n r\"\"\"Segregates self-loops from the graph.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional\n edge features. 
(default: :obj:`None`)\n\n :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`,\n :class:`Tensor`)\n\n Example:\n >>> edge_index = torch.tensor([[0, 0, 1],\n ... [0, 1, 0]])\n >>> (edge_index, edge_attr,\n ... loop_edge_index,\n ... loop_edge_attr) = segregate_self_loops(edge_index)\n >>> loop_edge_index\n tensor([[0],\n [0]])\n \"\"\"\n mask = edge_index[0] != edge_index[1]\n inv_mask = ~mask\n\n is_undirected = False\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n is_undirected = edge_index.is_undirected\n\n loop_edge_index = edge_index[:, inv_mask]\n loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask]\n edge_index = edge_index[:, mask]\n edge_attr = None if edge_attr is None else edge_attr[mask]\n\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n assert isinstance(loop_edge_index, EdgeIndex)\n edge_index._is_undirected = is_undirected\n loop_edge_index._is_undirected = is_undirected\n\n return edge_index, edge_attr, loop_edge_index, loop_edge_attr\n\n\n@overload\ndef add_self_loops(\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[float] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[float] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[str] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[str] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[float] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[float] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[str] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[str] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef 
add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[float] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[float] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[str] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\n@overload\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[str] = None,\n num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\ndef add_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n fill_value: Optional[Union[float, Tensor, str]] = None,\n num_nodes: Optional[Union[int, Tuple[int, int]]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n r\"\"\"Adds a self-loop :math:`(i,i) \\in \\mathcal{E}` to every node\n :math:`i \\in \\mathcal{V}` in the graph given by :attr:`edge_index`.\n In case the graph is weighted or has multi-dimensional edge features\n (:obj:`edge_attr != None`), edge features of self-loops will be added\n according to :obj:`fill_value`.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n features. (default: :obj:`None`)\n fill_value (float or Tensor or str, optional): The way to generate\n edge features of self-loops (in case :obj:`edge_attr != None`).\n If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n self-loops will be directly given by :obj:`fill_value`.\n If given as :obj:`str`, edge features of self-loops are computed by\n aggregating all features of edges that point to the specific node,\n according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`1.`)\n num_nodes (int or Tuple[int, int], optional): The number of nodes,\n *i.e.* :obj:`max_val + 1` of :attr:`edge_index`.\n If given as a tuple, then :obj:`edge_index` is interpreted as a\n bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.\n (default: :obj:`None`)\n\n :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... [1, 0, 0]])\n >>> edge_weight = torch.tensor([0.5, 0.5, 0.5])\n >>> add_self_loops(edge_index)\n (tensor([[0, 1, 0, 0, 1],\n [1, 0, 0, 0, 1]]),\n None)\n\n >>> add_self_loops(edge_index, edge_weight)\n (tensor([[0, 1, 0, 0, 1],\n [1, 0, 0, 0, 1]]),\n tensor([0.5000, 0.5000, 0.5000, 1.0000, 1.0000]))\n\n >>> # edge features of self-loops are filled by constant `2.0`\n >>> add_self_loops(edge_index, edge_weight,\n ... 
fill_value=2.)\n (tensor([[0, 1, 0, 0, 1],\n [1, 0, 0, 0, 1]]),\n tensor([0.5000, 0.5000, 0.5000, 2.0000, 2.0000]))\n\n >>> # Use 'add' operation to merge edge features for self-loops\n >>> add_self_loops(edge_index, edge_weight,\n ... fill_value='add')\n (tensor([[0, 1, 0, 0, 1],\n [1, 0, 0, 0, 1]]),\n tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000]))\n \"\"\"\n if not typing.TYPE_CHECKING and torch.jit.is_scripting():\n layout: Optional[int] = None\n else:\n layout: Optional[torch.layout] = None\n is_sparse = is_torch_sparse_tensor(edge_index)\n\n value: Optional[Tensor] = None\n if is_sparse:\n assert edge_attr is None\n layout = edge_index.layout\n size = (edge_index.size(0), edge_index.size(1))\n N = min(size)\n edge_index, value = to_edge_index(edge_index)\n elif isinstance(num_nodes, (tuple, list)):\n size = (num_nodes[0], num_nodes[1])\n N = min(size)\n else:\n N = maybe_num_nodes(edge_index, num_nodes)\n size = (N, N)\n\n device = edge_index.device\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n loop_index: Tensor = EdgeIndex(\n torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),\n sparse_size=(N, N),\n is_undirected=True,\n )\n else:\n loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)\n\n full_edge_index = torch.cat([edge_index, loop_index], dim=1)\n\n if is_sparse:\n assert edge_attr is None\n assert value is not None\n loop_attr = compute_loop_attr( #\n edge_index, value, N, is_sparse, fill_value)\n value = torch.cat([value, loop_attr], dim=0)\n\n if str(layout) == 'torch.sparse_coo': # str(...) for TorchScript :(\n return to_torch_coo_tensor(full_edge_index, value, size), None\n elif str(layout) == 'torch.sparse_csr':\n return to_torch_csr_tensor(full_edge_index, value, size), None\n raise ValueError(f\"Unexpected sparse tensor layout (got '{layout}')\")\n\n if edge_attr is not None:\n loop_attr = compute_loop_attr( #\n edge_index, edge_attr, N, is_sparse, fill_value)\n edge_attr = torch.cat([edge_attr, loop_attr], dim=0)\n\n return full_edge_index, edge_attr\n\n\n@overload\ndef add_remaining_self_loops(\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[float] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: None = None,\n fill_value: Optional[str] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[float] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n fill_value: Optional[str] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[float] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n 
pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\n@overload\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor],\n fill_value: Optional[str] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n pass\n\n\ndef add_remaining_self_loops( # noqa: F811\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n fill_value: Optional[Union[float, Tensor, str]] = None,\n num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n r\"\"\"Adds remaining self-loop :math:`(i,i) \\in \\mathcal{E}` to every node\n :math:`i \\in \\mathcal{V}` in the graph given by :attr:`edge_index`.\n In case the graph is weighted or has multi-dimensional edge features\n (:obj:`edge_attr != None`), edge features of non-existing self-loops will\n be added according to :obj:`fill_value`.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n features. (default: :obj:`None`)\n fill_value (float or Tensor or str, optional): The way to generate\n edge features of self-loops (in case :obj:`edge_attr != None`).\n If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n self-loops will be directly given by :obj:`fill_value`.\n If given as :obj:`str`, edge features of self-loops are computed by\n aggregating all features of edges that point to the specific node,\n according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`1.`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n Example:\n >>> edge_index = torch.tensor([[0, 1],\n ... 
[1, 0]])\n >>> edge_weight = torch.tensor([0.5, 0.5])\n >>> add_remaining_self_loops(edge_index, edge_weight)\n (tensor([[0, 1, 0, 1],\n [1, 0, 0, 1]]),\n tensor([0.5000, 0.5000, 1.0000, 1.0000]))\n \"\"\"\n N = maybe_num_nodes(edge_index, num_nodes)\n mask = edge_index[0] != edge_index[1]\n\n device = edge_index.device\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n loop_index: Tensor = EdgeIndex(\n torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),\n sparse_size=(N, N),\n is_undirected=True,\n )\n else:\n loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)\n\n if edge_attr is not None:\n\n loop_attr = compute_loop_attr( #\n edge_index, edge_attr, N, False, fill_value)\n\n inv_mask = ~mask\n loop_attr[edge_index[0][inv_mask]] = edge_attr[inv_mask]\n\n edge_attr = torch.cat([edge_attr[mask], loop_attr], dim=0)\n\n is_undirected = False\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n is_undirected = edge_index.is_undirected\n\n edge_index = edge_index[:, mask]\n\n if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n edge_index._is_undirected = is_undirected\n\n edge_index = torch.cat([edge_index, loop_index], dim=1)\n\n return edge_index, edge_attr\n\n\n\ndef get_self_loop_attr(\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tensor:\n r\"\"\"Returns the edge features or weights of self-loops\n :math:`(i, i)` of every node :math:`i \\in \\mathcal{V}` in the\n graph given by :attr:`edge_index`. Edge features of missing self-loops not\n present in :attr:`edge_index` will be filled with zeros. If\n :attr:`edge_attr` is not given, it will be the vector of ones.\n\n .. note::\n This operation is analogous to getting the diagonal elements of the\n dense adjacency matrix.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n features. (default: :obj:`None`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n :rtype: :class:`Tensor`\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... 
[1, 0, 0]])\n >>> edge_weight = torch.tensor([0.2, 0.3, 0.5])\n >>> get_self_loop_attr(edge_index, edge_weight)\n tensor([0.5000, 0.0000])\n\n >>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4)\n tensor([0.5000, 0.0000, 0.0000, 0.0000])\n \"\"\"\n loop_mask = edge_index[0] == edge_index[1]\n loop_index = edge_index[0][loop_mask]\n\n if edge_attr is not None:\n loop_attr = edge_attr[loop_mask]\n else: # A vector of ones:\n loop_attr = torch.ones(loop_index.numel(), device=edge_index.device)\n\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])\n full_loop_attr[loop_index] = loop_attr\n\n return full_loop_attr\n\n\n@overload\ndef compute_loop_attr(\n edge_index: Tensor,\n edge_attr: Tensor,\n num_nodes: int,\n is_sparse: bool,\n fill_value: Optional[float] = None,\n) -> Tensor:\n pass\n\n\n@overload\ndef compute_loop_attr( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n num_nodes: int,\n is_sparse: bool,\n fill_value: Optional[Tensor] = None,\n) -> Tensor:\n pass\n\n\n@overload\ndef compute_loop_attr( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n num_nodes: int,\n is_sparse: bool,\n fill_value: Optional[str] = None,\n) -> Tensor:\n pass\n\n\ndef compute_loop_attr( # noqa: F811\n edge_index: Tensor,\n edge_attr: Tensor,\n num_nodes: int,\n is_sparse: bool,\n fill_value: Optional[Union[float, Tensor, str]] = None,\n) -> Tensor:\n\n if fill_value is None:\n size = (num_nodes, ) + edge_attr.size()[1:]\n return edge_attr.new_ones(size)\n\n elif isinstance(fill_value, (int, float)):\n size = (num_nodes, ) + edge_attr.size()[1:]\n return edge_attr.new_full(size, fill_value)\n\n elif isinstance(fill_value, Tensor):\n size = (num_nodes, ) + edge_attr.size()[1:]\n loop_attr = fill_value.to(edge_attr.device, edge_attr.dtype)\n if edge_attr.dim() != loop_attr.dim():\n...\n", "needle_token_start": 14965, "needle_token_end": 15603, "code_context_ntokens": 16090, "output": ["```python\ndef get_self_loop_attr(\n edge_index: Tensor,\n edge_attr: Optional[Tensor] = None,\n num_nodes: Optional[int] = None,\n) -> Tensor:\n r\"\"\"Returns the edge features or weights of self-loops\n :math:`(i, i)` of every node :math:`i \\in \\mathcal{V}` in the\n graph given by :attr:`edge_index`. Edge features of missing self-loops not\n present in :attr:`edge_index` will be filled with zeros. If\n :attr:`edge_attr` is not given, it will be the vector of ones.\n\n .. note::\n This operation is analogous to getting the diagonal elements of the\n dense adjacency matrix.\n\n Args:\n edge_index (LongTensor): The edge indices.\n edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n features. (default: :obj:`None`)\n num_nodes (int, optional): The number of nodes, *i.e.*\n :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n :rtype: :class:`Tensor`\n\n Examples:\n >>> edge_index = torch.tensor([[0, 1, 0],\n ... 
[1, 0, 0]])\n >>> edge_weight = torch.tensor([0.2, 0.3, 0.5])\n >>> get_self_loop_attr(edge_index, edge_weight)\n tensor([0.5000, 0.0000])\n\n >>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4)\n tensor([0.5000, 0.0000, 0.0000, 0.0000])\n \"\"\"\n loop_mask = edge_index[0] == edge_index[1]\n loop_index = edge_index[0][loop_mask]\n\n if edge_attr is not None:\n loop_attr = edge_attr[loop_mask]\n else: # A vector of ones:\n loop_attr = torch.ones(loop_index.numel(), device=edge_index.device)\n\n num_nodes = maybe_num_nodes(edge_index, num_nodes)\n full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])\n full_loop_attr[loop_index] = loop_attr\n\n return full_loop_attr\n```"]} +{"repo": "openai/openai-python", "name": "_get_arch", "language": "python", "path": "src/openai/cli/_tools/migrate.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function determines the architecture of the system on which the script is running and maps it to a standardized format used for identifying compatible binaries.\n2. **Input**: There are no explicit inputs taken by the function as it retrieves the system architecture internally.\n3. **Output**: Returns a string representing the standardized architecture name.\n4. **Procedure**: The function fetches the current system's architecture, converts it to a lower case format, and then maps it to a corresponding standardized name using a predefined dictionary. If the architecture does not exist in the dictionary, the original architecture name is returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " \"\"\"Installs the Grit CLI and returns the location of the binary\"\"\"\n if sys.platform == \"win32\":\n raise CLIError(\"Windows is not supported yet in the migration CLI\")\n\n platform = \"macos\" if sys.platform == \"darwin\" else \"linux\"\n\n dir_name = _cache_dir() / \"openai-python\"\n install_dir = dir_name / \".install\"\n target_dir = install_dir / \"bin\"\n\n target_path = target_dir / \"marzano\"\n temp_file = target_dir / \"marzano.tmp\"\n\n if target_path.exists():\n _debug(f\"{target_path} already exists\")\n sys.stdout.flush()\n return target_path\n\n _debug(f\"Using Grit CLI path: {target_path}\")\n\n target_dir.mkdir(parents=True, exist_ok=True)\n\n if temp_file.exists():\n temp_file.unlink()\n\n arch = _get_arch()\n _debug(f\"Using architecture {arch}\")\n\n file_name = f\"marzano-{platform}-{arch}\"\n meta_url = f\"https://api.keygen.sh/v1/accounts/{KEYGEN_ACCOUNT}/artifacts/{file_name}\"\n\n sys.stdout.write(f\"Retrieving Grit CLI metadata from {meta_url}\\n\")\n with httpx.Client() as client:\n response = client.get(meta_url) # pyright: ignore[reportUnknownMemberType]\n\n data = response.json()\n errors = data.get(\"errors\")\n if errors:\n for error in errors:\n sys.stdout.write(f\"{error}\\n\")\n\n raise CLIError(\"Could not locate Grit CLI binary - see above errors\")\n\n write_manifest(install_dir, data[\"data\"][\"relationships\"][\"release\"][\"data\"][\"id\"])\n\n link = data[\"data\"][\"links\"][\"redirect\"]\n _debug(f\"Redirect URL {link}\")\n\n download_response = client.get(link) # pyright: ignore[reportUnknownMemberType]\n with open(temp_file, \"wb\") as file:\n for chunk in download_response.iter_bytes():\n file.write(chunk)\n\n unpacked_dir = 
target_dir / \"cli-bin\"\n unpacked_dir.mkdir(parents=True, exist_ok=True)\n\n with tarfile.open(temp_file, \"r:gz\") as archive:\n archive.extractall(unpacked_dir, filter=\"data\")\n\n for item in unpacked_dir.iterdir():\n item.rename(target_dir / item.name)\n\n shutil.rmtree(unpacked_dir)\n os.remove(temp_file)\n os.chmod(target_path, 0o755)\n\n sys.stdout.flush()\n\n return target_path\n\n\n\ndef _get_arch() -> str:\n architecture = platform.machine().lower()\n\n # Map the architecture names to Node.js equivalents\n arch_map = {\n \"x86_64\": \"x64\",\n \"amd64\": \"x64\",\n \"armv7l\": \"arm\",\n \"aarch64\": \"arm64\",\n }\n\n return arch_map.get(architecture, architecture)\n\n\ndef write_manifest(install_path: Path, release: str) -> None:\n manifest = {\n \"installPath\": str(install_path),\n \"binaries\": {\n \"marzano\": {\n \"name\": \"marzano\",\n \"release\": release,\n },\n },\n }\n manifest_path = Path(install_path) / \"manifests.json\"\n with open(manifest_path, \"w\") as f:\n json.dump(manifest, f, indent=2)\n\n# Path: src/openai/cli/_tools/_main.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom argparse import ArgumentParser\n\nfrom . import migrate, fine_tunes\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register_commands(parser: ArgumentParser, subparser: _SubParsersAction[ArgumentParser]) -> None:\n migrate.register(subparser)\n\n namespaced = parser.add_subparsers(title=\"Tools\", help=\"Convenience client side tools\")\n\n fine_tunes.register(namespaced)\n\n# Path: src/openai/cli/_tools/__init__.py\nfrom ._main import register_commands as register_commands\n\n# Path: src/openai/cli/_cli.py\nfrom __future__ import annotations\n\nimport sys\nimport logging\nimport argparse\nfrom typing import Any, List, Type, Optional\nfrom typing_extensions import ClassVar\n\nimport httpx\nimport pydantic\n\nimport openai\n\nfrom . import _tools\nfrom .. 
import _ApiType, __version__\nfrom ._api import register_commands\nfrom ._utils import can_use_http2\nfrom .._types import ProxiesDict\nfrom ._errors import CLIError, display_error\nfrom .._compat import PYDANTIC_V2, ConfigDict, model_parse\nfrom .._models import BaseModel\nfrom .._exceptions import APIError\n\nlogger = logging.getLogger()\nformatter = logging.Formatter(\"[%(asctime)s] %(message)s\")\nhandler = logging.StreamHandler(sys.stderr)\nhandler.setFormatter(formatter)\nlogger.addHandler(handler)\n\n\nclass Arguments(BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(\n extra=\"ignore\",\n )\n else:\n\n class Config(pydantic.BaseConfig): # type: ignore\n extra: Any = pydantic.Extra.ignore # type: ignore\n\n verbosity: int\n version: Optional[str] = None\n\n api_key: Optional[str]\n api_base: Optional[str]\n organization: Optional[str]\n proxy: Optional[List[str]]\n api_type: Optional[_ApiType] = None\n api_version: Optional[str] = None\n\n # azure\n azure_endpoint: Optional[str] = None\n azure_ad_token: Optional[str] = None\n\n # internal, set by subparsers to parse their specific args\n args_model: Optional[Type[BaseModel]] = None\n\n # internal, used so that subparsers can forward unknown arguments\n unknown_args: List[str] = []\n allow_unknown_args: bool = False\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n parser = argparse.ArgumentParser(description=None, prog=\"openai\")\n parser.add_argument(\n \"-v\",\n \"--verbose\",\n action=\"count\",\n dest=\"verbosity\",\n default=0,\n help=\"Set verbosity.\",\n )\n parser.add_argument(\"-b\", \"--api-base\", help=\"What API base url to use.\")\n parser.add_argument(\"-k\", \"--api-key\", help=\"What API key to use.\")\n parser.add_argument(\"-p\", \"--proxy\", nargs=\"+\", help=\"What proxy to use.\")\n parser.add_argument(\n \"-o\",\n \"--organization\",\n help=\"Which organization to run as (will use your default organization if not specified)\",\n )\n parser.add_argument(\n \"-t\",\n \"--api-type\",\n type=str,\n choices=(\"openai\", \"azure\"),\n help=\"The backend API to call, must be `openai` or `azure`\",\n )\n parser.add_argument(\n \"--api-version\",\n help=\"The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'\",\n )\n\n # azure\n parser.add_argument(\n \"--azure-endpoint\",\n help=\"The Azure endpoint, e.g. 
'https://endpoint.openai.azure.com'\",\n )\n parser.add_argument(\n \"--azure-ad-token\",\n help=\"A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id\",\n )\n\n # prints the package version\n parser.add_argument(\n \"-V\",\n \"--version\",\n action=\"version\",\n version=\"%(prog)s \" + __version__,\n )\n\n def help() -> None:\n parser.print_help()\n\n parser.set_defaults(func=help)\n\n subparsers = parser.add_subparsers()\n sub_api = subparsers.add_parser(\"api\", help=\"Direct API calls\")\n\n register_commands(sub_api)\n\n sub_tools = subparsers.add_parser(\"tools\", help=\"Client side tools for convenience\")\n _tools.register_commands(sub_tools, subparsers)\n\n return parser\n\n\ndef main() -> int:\n try:\n _main()\n except (APIError, CLIError, pydantic.ValidationError) as err:\n display_error(err)\n return 1\n except KeyboardInterrupt:\n sys.stderr.write(\"\\n\")\n return 1\n return 0\n\n\ndef _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:\n # argparse by default will strip out the `--` but we want to keep it for unknown arguments\n if \"--\" in sys.argv:\n idx = sys.argv.index(\"--\")\n known_args = sys.argv[1:idx]\n unknown_args = sys.argv[idx:]\n else:\n known_args = sys.argv[1:]\n unknown_args = []\n\n parsed, remaining_unknown = parser.parse_known_args(known_args)\n\n # append any remaining unknown arguments from the initial parsing\n remaining_unknown.extend(unknown_args)\n\n args = model_parse(Arguments, vars(parsed))\n if not args.allow_unknown_args:\n # we have to parse twice to ensure any unknown arguments\n # result in an error if that behaviour is desired\n parser.parse_args()\n\n return parsed, args, remaining_unknown\n\n\ndef _main() -> None:\n parser = _build_parser()\n parsed, args, unknown = _parse_args(parser)\n\n if args.verbosity != 0:\n sys.stderr.write(\"Warning: --verbosity isn't supported yet\\n\")\n\n proxies: ProxiesDict = {}\n if args.proxy is not None:\n for proxy in args.proxy:\n key = \"https://\" if proxy.startswith(\"https\") else \"http://\"\n if key in proxies:\n raise CLIError(f\"Multiple {key} proxies given - only the last one would be used\")\n\n proxies[key] = proxy\n\n http_client = httpx.Client(\n proxies=proxies or None,\n http2=can_use_http2(),\n )\n openai.http_client = http_client\n\n if args.organization:\n openai.organization = args.organization\n\n if args.api_key:\n openai.api_key = args.api_key\n\n if args.api_base:\n openai.base_url = args.api_base\n\n # azure\n if args.api_type is not None:\n openai.api_type = args.api_type\n\n if args.azure_endpoint is not None:\n openai.azure_endpoint = args.azure_endpoint\n\n if args.api_version is not None:\n openai.api_version = args.api_version\n\n if args.azure_ad_token is not None:\n openai.azure_ad_token = args.azure_ad_token\n\n try:\n if args.args_model:\n parsed.func(\n model_parse(\n args.args_model,\n {\n **{\n # we omit None values so that they can be defaulted to `NotGiven`\n # and we'll strip it from the API request\n key: value\n for key, value in vars(parsed).items()\n if value is not None\n },\n \"unknown_args\": unknown,\n },\n )\n )\n else:\n parsed.func()\n finally:\n try:\n http_client.close()\n except Exception:\n pass\n\n\nif __name__ == \"__main__\":\n sys.exit(main())\n\n# Path: src/openai/cli/__init__.py\nfrom ._cli import main as main\n\n# Path: src/openai/__main__.py\nfrom .cli import main\n\nmain()\n\n# Path: src/openai/types/beta/thread.py\n# File 
generated from our OpenAPI spec by Stainless.\n\nfrom typing import Optional\nfrom typing_extensions import Literal\n\nfrom ..._models import BaseModel\n\n__all__ = [\"Thread\"]\n\n\nclass Thread(BaseModel):\n id: str\n \"\"\"The identifier, which can be referenced in API endpoints.\"\"\"\n\n created_at: int\n \"\"\"The Unix timestamp (in seconds) for when the thread was created.\"\"\"\n\n metadata: Optional[object] = None\n \"\"\"Set of 16 key-value pairs that can be attached to an object.\n\n This can be useful for storing additional information about the object in a\n structured format. Keys can be a maximum of 64 characters long and values can be\n a maxium of 512 characters long.\n \"\"\"\n\n object: Literal[\"thread\"]\n \"\"\"The object type, which is always `thread`.\"\"\"\n\n# Path: src/openai/version.py\nfrom ._version import __version__\n\nVERSION: str = __version__\n\n# Path: src/openai/lib/_old_api.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any\nfrom typing_extensions import override\n\nfrom .._utils import LazyProxy\nfrom .._exceptions import OpenAIError\n\nINSTRUCTIONS = \"\"\"\n\nYou tried to access openai.{symbol}, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.\n\nYou can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface. \n\nAlternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`\n\nA detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742\n\"\"\"\n\n\nclass APIRemovedInV1(OpenAIError):\n def __init__(self, *, symbol: str) -> None:\n super().__init__(INSTRUCTIONS.format(symbol=symbol))\n\n\nclass APIRemovedInV1Proxy(LazyProxy[Any]):\n def __init__(self, *, symbol: str) -> None:\n super().__init__()\n self._symbol = symbol\n\n @override\n def __load__(self) -> Any:\n # return the proxy until it is eventually called so that\n # we don't break people that are just checking the attributes\n # of a module\n return self\n\n def __call__(self, *_args: Any, **_kwargs: Any) -> Any:\n raise APIRemovedInV1(symbol=self._symbol)\n\n\nSYMBOLS = [\n \"Edit\",\n \"File\",\n \"Audio\",\n \"Image\",\n \"Model\",\n \"Engine\",\n \"Customer\",\n \"FineTune\",\n \"Embedding\",\n \"Completion\",\n \"Deployment\",\n \"Moderation\",\n \"ErrorObject\",\n \"FineTuningJob\",\n \"ChatCompletion\",\n]\n\n# we explicitly tell type checkers that nothing is exported\n# from this file so that when we re-export the old symbols\n# in `openai/__init__.py` they aren't added to the auto-complete\n# suggestions given by editors\nif TYPE_CHECKING:\n __all__: list[str] = []\nelse:\n __all__ = SYMBOLS\n\n\n__locals = locals()\nfor symbol in SYMBOLS:\n __locals[symbol] = APIRemovedInV1Proxy(symbol=symbol)\n\n# Path: src/openai/lib/_validators.py\n# pyright: basic\nfrom __future__ import annotations\n\nimport os\nimport sys\nfrom typing import Any, TypeVar, Callable, Optional, NamedTuple\nfrom typing_extensions import TypeAlias\n\nfrom .._extras import pandas as pd\n\n\nclass Remediation(NamedTuple):\n name: str\n immediate_msg: Optional[str] = None\n necessary_msg: Optional[str] = None\n necessary_fn: Optional[Callable[[Any], Any]] = None\n optional_msg: Optional[str] = None\n optional_fn: Optional[Callable[[Any], Any]] = None\n error_msg: Optional[str] = None\n\n\nOptionalDataFrameT = TypeVar(\"OptionalDataFrameT\", bound=\"Optional[pd.DataFrame]\")\n\n\ndef num_examples_validator(df: 
pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will only print out the number of examples and recommend to the user to increase the number of examples if less than 100.\n \"\"\"\n MIN_EXAMPLES = 100\n optional_suggestion = (\n \"\"\n if len(df) >= MIN_EXAMPLES\n else \". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples\"\n )\n immediate_msg = f\"\\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}\"\n return Remediation(name=\"num_examples\", immediate_msg=immediate_msg)\n\n\ndef necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:\n \"\"\"\n This validator will ensure that the necessary column is present in the dataframe.\n \"\"\"\n\n def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:\n cols = [c for c in df.columns if str(c).lower() == column]\n df.rename(columns={cols[0]: column.lower()}, inplace=True)\n return df\n\n immediate_msg = None\n necessary_fn = None\n necessary_msg = None\n error_msg = None\n\n if necessary_column not in df.columns:\n if necessary_column in [str(c).lower() for c in df.columns]:\n\n def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:\n return lower_case_column(df, necessary_column)\n\n necessary_fn = lower_case_column_creator\n immediate_msg = f\"\\n- The `{necessary_column}` column/key should be lowercase\"\n necessary_msg = f\"Lower case column name to `{necessary_column}`\"\n else:\n error_msg = f\"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry\"\n\n return Remediation(\n name=\"necessary_column\",\n immediate_msg=immediate_msg,\n necessary_msg=necessary_msg,\n necessary_fn=necessary_fn,\n error_msg=error_msg,\n )\n\n\ndef additional_column_validator(df: pd.DataFrame, fields: list[str] = [\"prompt\", \"completion\"]) -> Remediation:\n \"\"\"\n This validator will remove additional columns from the dataframe.\n \"\"\"\n additional_columns = []\n necessary_msg = None\n immediate_msg = None\n necessary_fn = None # type: ignore\n\n if len(df.columns) > 2:\n additional_columns = [c for c in df.columns if c not in fields]\n warn_message = \"\"\n for ac in additional_columns:\n dups = [c for c in additional_columns if ac in c]\n if len(dups) > 0:\n warn_message += f\"\\n WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file.\"\n immediate_msg = f\"\\n- The input file should contain exactly two columns/keys per row. 
Additional columns/keys present are: {additional_columns}{warn_message}\"\n necessary_msg = f\"Remove additional columns/keys: {additional_columns}\"\n\n def necessary_fn(x: Any) -> Any:\n return x[fields]\n\n return Remediation(\n name=\"additional_column\",\n immediate_msg=immediate_msg,\n necessary_msg=necessary_msg,\n necessary_fn=necessary_fn,\n )\n\n\ndef non_empty_field_validator(df: pd.DataFrame, field: str = \"completion\") -> Remediation:\n \"\"\"\n This validator will ensure that no completion is empty.\n \"\"\"\n necessary_msg = None\n necessary_fn = None # type: ignore\n immediate_msg = None\n\n if df[field].apply(lambda x: x == \"\").any() or df[field].isnull().any():\n empty_rows = (df[field] == \"\") | (df[field].isnull())\n empty_indexes = df.reset_index().index[empty_rows].tolist()\n immediate_msg = f\"\\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}\"\n\n def necessary_fn(x: Any) -> Any:\n return x[x[field] != \"\"].dropna(subset=[field])\n\n necessary_msg = f\"Remove {len(empty_indexes)} rows with empty {field}s\"\n\n return Remediation(\n name=f\"empty_{field}\",\n immediate_msg=immediate_msg,\n necessary_msg=necessary_msg,\n necessary_fn=necessary_fn,\n )\n\n\ndef duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = [\"prompt\", \"completion\"]) -> Remediation:\n \"\"\"\n This validator will suggest to the user to remove duplicate rows if they exist.\n \"\"\"\n duplicated_rows = df.duplicated(subset=fields)\n duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()\n immediate_msg = None\n optional_msg = None\n optional_fn = None # type: ignore\n\n if len(duplicated_indexes) > 0:\n immediate_msg = f\"\\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}\"\n optional_msg = f\"Remove {len(duplicated_indexes)} duplicate rows\"\n\n def optional_fn(x: Any) -> Any:\n return x.drop_duplicates(subset=fields)\n\n return Remediation(\n name=\"duplicated_rows\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n )\n\n\ndef long_examples_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will suggest to the user to remove examples that are too long.\n \"\"\"\n immediate_msg = None\n optional_msg = None\n optional_fn = None # type: ignore\n\n ft_type = infer_task_type(df)\n if ft_type != \"open-ended generation\":\n\n def get_long_indexes(d: pd.DataFrame) -> Any:\n long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)\n return d.reset_index().index[long_examples].tolist()\n\n long_indexes = get_long_indexes(df)\n\n if len(long_indexes) > 0:\n immediate_msg = f\"\\n- There are {len(long_indexes)} examples that are very long. 
These are rows: {long_indexes}\\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens.\"\n optional_msg = f\"Remove {len(long_indexes)} long examples\"\n\n def optional_fn(x: Any) -> Any:\n long_indexes_to_drop = get_long_indexes(x)\n if long_indexes != long_indexes_to_drop:\n sys.stdout.write(\n f\"The indices of the long examples has changed as a result of a previously applied recommendation.\\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\\n\"\n )\n return x.drop(long_indexes_to_drop)\n\n return Remediation(\n name=\"long_examples\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n )\n\n\ndef common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.\n \"\"\"\n error_msg = None\n immediate_msg = None\n optional_msg = None\n optional_fn = None # type: ignore\n\n # Find a suffix which is not contained within the prompt otherwise\n suggested_suffix = \"\\n\\n### =>\\n\\n\"\n suffix_options = [\n \" ->\",\n \"\\n\\n###\\n\\n\",\n \"\\n\\n===\\n\\n\",\n \"\\n\\n---\\n\\n\",\n \"\\n\\n===>\\n\\n\",\n \"\\n\\n--->\\n\\n\",\n ]\n for suffix_option in suffix_options:\n if suffix_option == \" ->\":\n if df.prompt.str.contains(\"\\n\").any():\n continue\n if df.prompt.str.contains(suffix_option, regex=False).any():\n continue\n suggested_suffix = suffix_option\n break\n display_suggested_suffix = suggested_suffix.replace(\"\\n\", \"\\\\n\")\n\n ft_type = infer_task_type(df)\n if ft_type == \"open-ended generation\":\n return Remediation(name=\"common_suffix\")\n\n def add_suffix(x: Any, suffix: Any) -> Any:\n x[\"prompt\"] += suffix\n return x\n\n common_suffix = get_common_xfix(df.prompt, xfix=\"suffix\")\n if (df.prompt == common_suffix).all():\n error_msg = f\"All prompts are identical: `{common_suffix}`\\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different\"\n return Remediation(name=\"common_suffix\", error_msg=error_msg)\n\n if common_suffix != \"\":\n common_suffix_new_line_handled = common_suffix.replace(\"\\n\", \"\\\\n\")\n immediate_msg = f\"\\n- All prompts end with suffix `{common_suffix_new_line_handled}`\"\n if len(common_suffix) > 10:\n immediate_msg += f\". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`\"\n if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():\n immediate_msg += f\"\\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix\"\n\n else:\n immediate_msg = \"\\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. 
If you intend to do open-ended generation, then you should leave the prompts empty\"\n\n if common_suffix == \"\":\n optional_msg = f\"Add a suffix separator `{display_suggested_suffix}` to all prompts\"\n\n def optional_fn(x: Any) -> Any:\n return add_suffix(x, suggested_suffix)\n\n return Remediation(\n name=\"common_completion_suffix\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n error_msg=error_msg,\n )\n\n\ndef common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will suggest to remove a common prefix from the prompt if a long one exist.\n \"\"\"\n MAX_PREFIX_LEN = 12\n\n immediate_msg = None\n optional_msg = None\n optional_fn = None # type: ignore\n\n common_prefix = get_common_xfix(df.prompt, xfix=\"prefix\")\n if common_prefix == \"\":\n return Remediation(name=\"common_prefix\")\n\n def remove_common_prefix(x: Any, prefix: Any) -> Any:\n x[\"prompt\"] = x[\"prompt\"].str[len(prefix) :]\n return x\n\n if (df.prompt == common_prefix).all():\n # already handled by common_suffix_validator\n return Remediation(name=\"common_prefix\")\n\n if common_prefix != \"\":\n immediate_msg = f\"\\n- All prompts start with prefix `{common_prefix}`\"\n if MAX_PREFIX_LEN < len(common_prefix):\n immediate_msg += \". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion\"\n optional_msg = f\"Remove prefix `{common_prefix}` from all prompts\"\n\n def optional_fn(x: Any) -> Any:\n return remove_common_prefix(x, common_prefix)\n\n return Remediation(\n name=\"common_prompt_prefix\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n )\n\n\ndef common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will suggest to remove a common prefix from the completion if a long one exist.\n \"\"\"\n MAX_PREFIX_LEN = 5\n\n common_prefix = get_common_xfix(df.completion, xfix=\"prefix\")\n ws_prefix = len(common_prefix) > 0 and common_prefix[0] == \" \"\n if len(common_prefix) < MAX_PREFIX_LEN:\n return Remediation(name=\"common_prefix\")\n\n def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:\n x[\"completion\"] = x[\"completion\"].str[len(prefix) :]\n if ws_prefix:\n # keep the single whitespace as prefix\n x[\"completion\"] = f\" {x['completion']}\"\n return x\n\n if (df.completion == common_prefix).all():\n # already handled by common_suffix_validator\n return Remediation(name=\"common_prefix\")\n\n immediate_msg = f\"\\n- All completions start with prefix `{common_prefix}`. 
Most of the time you should only add the output data into the completion, without any prefix\"\n optional_msg = f\"Remove prefix `{common_prefix}` from all completions\"\n\n def optional_fn(x: Any) -> Any:\n return remove_common_prefix(x, common_prefix, ws_prefix)\n\n return Remediation(\n name=\"common_completion_prefix\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n )\n\n\ndef common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.\n \"\"\"\n error_msg = None\n immediate_msg = None\n optional_msg = None\n optional_fn = None # type: ignore\n\n ft_type = infer_task_type(df)\n if ft_type == \"open-ended generation\" or ft_type == \"classification\":\n return Remediation(name=\"common_suffix\")\n\n common_suffix = get_common_xfix(df.completion, xfix=\"suffix\")\n if (df.completion == common_suffix).all():\n error_msg = f\"All completions are identical: `{common_suffix}`\\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`\"\n return Remediation(name=\"common_suffix\", error_msg=error_msg)\n\n # Find a suffix which is not contained within the completion otherwise\n suggested_suffix = \" [END]\"\n suffix_options = [\n \"\\n\",\n \".\",\n \" END\",\n \"***\",\n \"+++\",\n \"&&&\",\n \"$$$\",\n \"@@@\",\n \"%%%\",\n ]\n for suffix_option in suffix_options:\n if df.completion.str.contains(suffix_option, regex=False).any():\n continue\n suggested_suffix = suffix_option\n break\n display_suggested_suffix = suggested_suffix.replace(\"\\n\", \"\\\\n\")\n\n def add_suffix(x: Any, suffix: Any) -> Any:\n x[\"completion\"] += suffix\n return x\n\n if common_suffix != \"\":\n common_suffix_new_line_handled = common_suffix.replace(\"\\n\", \"\\\\n\")\n immediate_msg = f\"\\n- All completions end with suffix `{common_suffix_new_line_handled}`\"\n if len(common_suffix) > 10:\n immediate_msg += f\". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`\"\n if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():\n immediate_msg += f\"\\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending\"\n\n else:\n immediate_msg = \"\\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples.\"\n\n if common_suffix == \"\":\n optional_msg = f\"Add a suffix ending `{display_suggested_suffix}` to all completions\"\n\n def optional_fn(x: Any) -> Any:\n return add_suffix(x, suggested_suffix)\n\n return Remediation(\n name=\"common_completion_suffix\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n error_msg=error_msg,\n )\n\n\ndef completions_space_start_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will suggest to add a space at the start of the completion if it doesn't already exist. 
This helps with tokenization.\n \"\"\"\n\n def add_space_start(x: Any) -> Any:\n x[\"completion\"] = x[\"completion\"].apply(lambda s: (\"\" if s.startswith(\" \") else \" \") + s)\n return x\n\n optional_msg = None\n optional_fn = None\n immediate_msg = None\n\n if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != \" \":\n immediate_msg = \"\\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details\"\n optional_msg = \"Add a whitespace character to the beginning of the completion\"\n optional_fn = add_space_start\n return Remediation(\n name=\"completion_space_start\",\n immediate_msg=immediate_msg,\n optional_msg=optional_msg,\n optional_fn=optional_fn,\n )\n\n\ndef lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:\n \"\"\"\n This validator will suggest to lowercase the column values, if more than a third of letters are uppercase.\n \"\"\"\n\n def lower_case(x: Any) -> Any:\n x[column] = x[column].str.lower()\n return x\n\n count_upper = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper())).sum()\n count_lower = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower())).sum()\n\n if count_upper * 2 > count_lower:\n return Remediation(\n name=\"lower_case\",\n immediate_msg=f\"\\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details\",\n optional_msg=f\"Lowercase all your data in column/key `{column}`\",\n optional_fn=lower_case,\n )\n return None\n\n\ndef read_any_format(\n fname: str, fields: list[str] = [\"prompt\", \"completion\"]\n) -> tuple[pd.DataFrame | None, Remediation]:\n \"\"\"\n This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.\n - for .xlsx it will read the first sheet\n - for .txt it will assume completions and split on newline\n \"\"\"\n remediation = None\n necessary_msg = None\n immediate_msg = None\n error_msg = None\n df = None\n\n if os.path.isfile(fname):\n try:\n if fname.lower().endswith(\".csv\") or fname.lower().endswith(\".tsv\"):\n file_extension_str, separator = (\"CSV\", \",\") if fname.lower().endswith(\".csv\") else (\"TSV\", \"\\t\")\n immediate_msg = (\n f\"\\n- Based on your file extension, your file is formatted as a {file_extension_str} file\"\n )\n necessary_msg = f\"Your format `{file_extension_str}` will be converted to `JSONL`\"\n df = pd.read_csv(fname, sep=separator, dtype=str).fillna(\"\")\n elif fname.lower().endswith(\".xlsx\"):\n immediate_msg = \"\\n- Based on your file extension, your file is formatted as an Excel file\"\n necessary_msg = \"Your format `XLSX` will be converted to `JSONL`\"\n xls = pd.ExcelFile(fname)\n sheets = xls.sheet_names\n if len(sheets) > 1:\n immediate_msg += \"\\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. 
WARNING: Reading only the first sheet...\"\n df = pd.read_excel(fname, dtype=str).fillna(\"\")\n elif fname.lower().endswith(\".txt\"):\n immediate_msg = \"\\n- Based on your file extension, you provided a text file\"\n necessary_msg = \"Your format `TXT` will be converted to `JSONL`\"\n with open(fname, \"r\") as f:\n content = f.read()\n df = pd.DataFrame(\n [[\"\", line] for line in content.split(\"\\n\")],\n columns=fields,\n dtype=str,\n ).fillna(\"\")\n elif fname.lower().endswith(\".jsonl\"):\n df = pd.read_json(fname, lines=True, dtype=str).fillna(\"\") # type: ignore\n if len(df) == 1: # type: ignore\n # this is NOT what we expect for a .jsonl file\n immediate_msg = \"\\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format\"\n necessary_msg = \"Your format `JSON` will be converted to `JSONL`\"\n df = pd.read_json(fname, dtype=str).fillna(\"\") # type: ignore\n else:\n pass # this is what we expect for a .jsonl file\n elif fname.lower().endswith(\".json\"):\n try:\n # to handle case where .json file is actually a .jsonl file\n df = pd.read_json(fname, lines=True, dtype=str).fillna(\"\") # type: ignore\n if len(df) == 1: # type: ignore\n # this code path corresponds to a .json file that has one line\n df = pd.read_json(fname, dtype=str).fillna(\"\") # type: ignore\n else:\n # this is NOT what we expect for a .json file\n immediate_msg = \"\\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format\"\n necessary_msg = \"Your format `JSON` will be converted to `JSONL`\"\n except ValueError:\n # this code path corresponds to a .json file that has multiple lines (i.e. it is indented)\n df = pd.read_json(fname, dtype=str).fillna(\"\") # type: ignore\n else:\n error_msg = (\n \"Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL\"\n )\n if \".\" in fname:\n error_msg += f\" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported.\"\n else:\n error_msg += f\" Your file `{fname}` is missing a file extension.\"\n\n except (ValueError, TypeError):\n file_extension_str = fname.split(\".\")[-1].upper()\n error_msg = f\"Your file `{fname}` does not appear to be in valid {file_extension_str} format. 
Please ensure your file is formatted as a valid {file_extension_str} file.\"\n\n else:\n error_msg = f\"File {fname} does not exist.\"\n\n remediation = Remediation(\n name=\"read_any_format\",\n necessary_msg=necessary_msg,\n immediate_msg=immediate_msg,\n error_msg=error_msg,\n )\n return df, remediation\n\n\ndef format_inferrer_validator(df: pd.DataFrame) -> Remediation:\n \"\"\"\n This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.\n It will also suggest to use ada and explain train/validation split benefits.\n \"\"\"\n ft_type = infer_task_type(df)\n immediate_msg = None\n if ft_type == \"classification\":\n immediate_msg = f\"\\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training\"\n return Remediation(name=\"num_examples\", immediate_msg=immediate_msg)\n\n\ndef apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:\n \"\"\"\n This function will apply a necessary remediation to a dataframe, or print an error message if one exists.\n \"\"\"\n if remediation.error_msg is not None:\n sys.stderr.write(f\"\\n\\nERROR in {remediation.name} validator: {remediation.error_msg}\\n\\nAborting...\")\n sys.exit(1)\n if remediation.immediate_msg is not None:\n sys.stdout.write(remediation.immediate_msg)\n if remediation.necessary_fn is not None:\n df = remediation.necessary_fn(df)\n return df\n\n\ndef accept_suggestion(input_text: str, auto_accept: bool) -> bool:\n sys.stdout.write(input_text)\n if auto_accept:\n sys.stdout.write(\"Y\\n\")\n return True\n return input().lower() != \"n\"\n\n\ndef apply_optional_remediation(\n df: pd.DataFrame, remediation: Remediation, auto_accept: bool\n) -> tuple[pd.DataFrame, bool]:\n \"\"\"\n This function will apply an optional remediation to a dataframe, based on the user input.\n \"\"\"\n optional_applied = False\n input_text = f\"- [Recommended] {remediation.optional_msg} [Y/n]: \"\n if remediation.optional_msg is not None:\n if accept_suggestion(input_text, auto_accept):\n assert remediation.optional_fn is not None\n df = remediation.optional_fn(df)\n optional_applied = True\n if remediation.necessary_msg is not None:\n sys.stdout.write(f\"- [Necessary] {remediation.necessary_msg}\\n\")\n return df, optional_applied\n\n\ndef estimate_fine_tuning_time(df: pd.DataFrame) -> None:\n \"\"\"\n Estimate the time it'll take to fine-tune the dataset\n \"\"\"\n ft_format = infer_task_type(df)\n expected_time = 1.0\n if ft_format == \"classification\":\n num_examples = len(df)\n expected_time = num_examples * 1.44\n else:\n size = df.memory_usage(index=True).sum()\n expected_time = size * 0.0515\n\n def format_time(time: float) -> str:\n if time < 60:\n return f\"{round(time, 2)} seconds\"\n elif time < 3600:\n return f\"{round(time / 60, 2)} minutes\"\n elif time < 86400:\n return f\"{round(time / 3600, 2)} hours\"\n else:\n return f\"{round(time / 86400, 2)} days\"\n\n time_string = format_time(expected_time + 140)\n sys.stdout.write(\n f\"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. 
Queue will approximately take half an hour per job ahead of you.\\n\"\n )\n\n\ndef get_outfnames(fname: str, split: bool) -> list[str]:\n suffixes = [\"_train\", \"_valid\"] if split else [\"\"]\n i = 0\n while True:\n index_suffix = f\" ({i})\" if i > 0 else \"\"\n candidate_fnames = [f\"{os.path.splitext(fname)[0]}_prepared{suffix}{index_suffix}.jsonl\" for suffix in suffixes]\n if not any(os.path.isfile(f) for f in candidate_fnames):\n return candidate_fnames\n i += 1\n\n\ndef get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:\n n_classes = df.completion.nunique()\n pos_class = None\n if n_classes == 2:\n pos_class = df.completion.value_counts().index[0]\n return n_classes, pos_class\n\n\ndef write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:\n \"\"\"\n This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.\n For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.\n \"\"\"\n ft_format = infer_task_type(df)\n common_prompt_suffix = get_common_xfix(df.prompt, xfix=\"suffix\")\n common_completion_suffix = get_common_xfix(df.completion, xfix=\"suffix\")\n\n split = False\n input_text = \"- [Recommended] Would you like to split into training and validation set? [Y/n]: \"\n if ft_format == \"classification\":\n if accept_suggestion(input_text, auto_accept):\n split = True\n\n additional_params = \"\"\n common_prompt_suffix_new_line_handled = common_prompt_suffix.replace(\"\\n\", \"\\\\n\")\n common_completion_suffix_new_line_handled = common_completion_suffix.replace(\"\\n\", \"\\\\n\")\n optional_ending_string = (\n f' Make sure to include `stop=[\"{common_completion_suffix_new_line_handled}\"]` so that the generated texts ends at the expected place.'\n if len(common_completion_suffix_new_line_handled) > 0\n else \"\"\n )\n\n input_text = \"\\n\\nYour data will be written to a new JSONL file. 
Proceed [Y/n]: \"\n\n if not any_remediations and not split:\n sys.stdout.write(\n f'\\nYou can use your file for fine-tuning:\\n> openai api fine_tunes.create -t \"{fname}\"{additional_params}\\n\\nAfter you\u2019ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\\n'\n )\n estimate_fine_tuning_time(df)\n\n elif accept_suggestion(input_text, auto_accept):\n fnames = get_outfnames(fname, split)\n if split:\n assert len(fnames) == 2 and \"train\" in fnames[0] and \"valid\" in fnames[1]\n MAX_VALID_EXAMPLES = 1000\n n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))\n df_train = df.sample(n=n_train, random_state=42)\n df_valid = df.drop(df_train.index)\n df_train[[\"prompt\", \"completion\"]].to_json( # type: ignore\n fnames[0], lines=True, orient=\"records\", force_ascii=False\n )\n df_valid[[\"prompt\", \"completion\"]].to_json(fnames[1], lines=True, orient=\"records\", force_ascii=False)\n\n n_classes, pos_class = get_classification_hyperparams(df)\n additional_params += \" --compute_classification_metrics\"\n if n_classes == 2:\n additional_params += f' --classification_positive_class \"{pos_class}\"'\n else:\n additional_params += f\" --classification_n_classes {n_classes}\"\n else:\n assert len(fnames) == 1\n df[[\"prompt\", \"completion\"]].to_json(fnames[0], lines=True, orient=\"records\", force_ascii=False)\n\n # Add -v VALID_FILE if we split the file into train / valid\n files_string = (\"s\" if split else \"\") + \" to `\" + (\"` and `\".join(fnames))\n valid_string = f' -v \"{fnames[1]}\"' if split else \"\"\n separator_reminder = (\n \"\"\n if len(common_prompt_suffix_new_line_handled) == 0\n else f\"After you\u2019ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.\"\n )\n sys.stdout.write(\n f'\\nWrote modified file{files_string}`\\nFeel free to take a look!\\n\\nNow use that file when fine-tuning:\\n> openai api fine_tunes.create -t \"{fnames[0]}\"{valid_string}{additional_params}\\n\\n{separator_reminder}{optional_ending_string}\\n'\n )\n estimate_fine_tuning_time(df)\n else:\n sys.stdout.write(\"Aborting... 
did not write the file\\n\")\n\n\ndef infer_task_type(df: pd.DataFrame) -> str:\n \"\"\"\n Infer the likely fine-tuning task type from the data\n \"\"\"\n CLASSIFICATION_THRESHOLD = 3 # min_average instances of each class\n if sum(df.prompt.str.len()) == 0:\n return \"open-ended generation\"\n\n if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:\n return \"classification\"\n\n return \"conditional generation\"\n\n\ndef get_common_xfix(series: Any, xfix: str = \"suffix\") -> str:\n \"\"\"\n Finds the longest common suffix or prefix of all the values in a series\n \"\"\"\n common_xfix = \"\"\n while True:\n common_xfixes = (\n series.str[-(len(common_xfix) + 1) :] if xfix == \"suffix\" else series.str[: len(common_xfix) + 1]\n ) # first few or last few characters\n if common_xfixes.nunique() != 1: # we found the character at which we don't have a unique xfix anymore\n break\n elif common_xfix == common_xfixes.values[0]: # the entire first row is a prefix of every other row\n break\n else: # the first or last few characters are still common across all rows - let's try to add one more\n common_xfix = common_xfixes.values[0]\n return common_xfix\n\n\nValidator: TypeAlias = \"Callable[[pd.DataFrame], Remediation | None]\"\n\n\ndef get_validators() -> list[Validator]:\n return [\n num_examples_validator,\n lambda x: necessary_column_validator(x, \"prompt\"),\n lambda x: necessary_column_validator(x, \"completion\"),\n additional_column_validator,\n non_empty_field_validator,\n format_inferrer_validator,\n duplicated_rows_validator,\n long_examples_validator,\n lambda x: lower_case_validator(x, \"prompt\"),\n lambda x: lower_case_validator(x, \"completion\"),\n common_prompt_suffix_validator,\n common_prompt_prefix_validator,\n common_completion_prefix_validator,\n common_completion_suffix_validator,\n completions_space_start_validator,\n ]\n\n\ndef apply_validators(\n df: pd.DataFrame,\n fname: str,\n remediation: Remediation | None,\n validators: list[Validator],\n auto_accept: bool,\n write_out_file_func: Callable[..., Any],\n) -> None:\n optional_remediations: list[Remediation] = []\n if remediation is not None:\n optional_remediations.append(remediation)\n for validator in validators:\n remediation = validator(df)\n if remediation is not None:\n optional_remediations.append(remediation)\n df = apply_necessary_remediation(df, remediation)\n\n any_optional_or_necessary_remediations = any(\n [\n remediation\n for remediation in optional_remediations\n if remediation.optional_msg is not None or remediation.necessary_msg is not None\n ]\n )\n any_necessary_applied = any(\n [remediation for remediation in optional_remediations if remediation.necessary_msg is not None]\n )\n any_optional_applied = False\n\n if any_optional_or_necessary_remediations:\n sys.stdout.write(\"\\n\\nBased on the analysis we will perform the following actions:\\n\")\n for remediation in optional_remediations:\n df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)\n any_optional_applied = any_optional_applied or optional_applied\n else:\n sys.stdout.write(\"\\n\\nNo remediations found.\\n\")\n\n any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied\n\n write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)\n\n# Path: src/openai/lib/azure.py\nfrom __future__ import annotations\n\nimport os\nimport inspect\nfrom typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload\nfrom typing_extensions import Self, 
override\n\nimport httpx\n\nfrom .._types import NOT_GIVEN, Omit, Timeout, NotGiven\nfrom .._utils import is_given, is_mapping\nfrom .._client import OpenAI, AsyncOpenAI\nfrom .._models import FinalRequestOptions\nfrom .._streaming import Stream, AsyncStream\nfrom .._exceptions import OpenAIError\nfrom .._base_client import DEFAULT_MAX_RETRIES, BaseClient\n\n_deployments_endpoints = set(\n [\n \"/completions\",\n \"/chat/completions\",\n \"/embeddings\",\n \"/audio/transcriptions\",\n \"/audio/translations\",\n \"/audio/speech\",\n \"/images/generations\",\n ]\n)\n\n\nAzureADTokenProvider = Callable[[], str]\nAsyncAzureADTokenProvider = Callable[[], \"str | Awaitable[str]\"]\n_HttpxClientT = TypeVar(\"_HttpxClientT\", bound=Union[httpx.Client, httpx.AsyncClient])\n_DefaultStreamT = TypeVar(\"_DefaultStreamT\", bound=Union[Stream[Any], AsyncStream[Any]])\n\n\n# we need to use a sentinel API key value for Azure AD\n# as we don't want to make the `api_key` in the main client Optional\n# and Azure AD tokens may be retrieved on a per-request basis\nAPI_KEY_SENTINEL = \"\".join([\"<\", \"missing API key\", \">\"])\n\n\nclass MutuallyExclusiveAuthError(OpenAIError):\n def __init__(self) -> None:\n super().__init__(\n \"The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time\"\n )\n\n\nclass BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):\n @override\n def _build_request(\n self,\n options: FinalRequestOptions,\n ) -> httpx.Request:\n if options.url in _deployments_endpoints and is_mapping(options.json_data):\n model = options.json_data.get(\"model\")\n if model is not None and not \"/deployments\" in str(self.base_url):\n options.url = f\"/deployments/{model}{options.url}\"\n\n return super()._build_request(options)\n\n\nclass AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):\n @overload\n def __init__(\n self,\n *,\n azure_endpoint: str,\n azure_deployment: str | None = None,\n api_version: str | None = None,\n api_key: str | None = None,\n azure_ad_token: str | None = None,\n azure_ad_token_provider: AzureADTokenProvider | None = None,\n organization: str | None = None,\n timeout: float | Timeout | None | NotGiven = NOT_GIVEN,\n max_retries: int = DEFAULT_MAX_RETRIES,\n default_headers: Mapping[str, str] | None = None,\n default_query: Mapping[str, object] | None = None,\n http_client: httpx.Client | None = None,\n _strict_response_validation: bool = False,\n ) -> None:\n ...\n\n @overload\n def __init__(\n self,\n *,\n azure_deployment: str | None = None,\n api_version: str | None = None,\n api_key: str | None = None,\n azure_ad_token: str | None = None,\n azure_ad_token_provider: AzureADTokenProvider | None = None,\n organization: str | None = None,\n timeout: float | Timeout | None | NotGiven = NOT_GIVEN,\n max_retries: int = DEFAULT_MAX_RETRIES,\n default_headers: Mapping[str, str] | None = None,\n default_query: Mapping[str, object] | None = None,\n http_client: httpx.Client | None = None,\n _strict_response_validation: bool = False,\n ) -> None:\n ...\n\n @overload\n def __init__(\n self,\n *,\n base_url: str,\n api_version: str | None = None,\n api_key: str | None = None,\n azure_ad_token: str | None = None,\n azure_ad_token_provider: AzureADTokenProvider | None = None,\n organization: str | None = None,\n timeout: float | Timeout | None | NotGiven = NOT_GIVEN,\n max_retries: int = DEFAULT_MAX_RETRIES,\n default_headers: Mapping[str, str] | None = None,\n default_query: Mapping[str, 
object] | None = None,\n http_client: httpx.Client | None = None,\n _strict_response_validation: bool = False,\n ) -> None:\n ...\n\n def __init__(\n self,\n *,\n api_version: str | None = None,\n...\n", "needle_token_start": 735, "needle_token_end": 843, "code_context_ntokens": 16096, "output": ["```python\ndef _get_arch() -> str:\n architecture = platform.machine().lower()\n\n # Map the architecture names to Node.js equivalents\n arch_map = {\n \"x86_64\": \"x64\",\n \"amd64\": \"x64\",\n \"armv7l\": \"arm\",\n \"aarch64\": \"arm64\",\n }\n\n return arch_map.get(architecture, architecture)\n```"]} +{"repo": "openai/openai-python", "name": "to_custom_raw_response_wrapper", "language": "python", "path": "src/openai/_response.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to modify an API method so that it returns a specific type of API response directly, rather than the default response format.\n2. **Input**: It takes an API method and a specific API response class that should be used to format the response.\n3. **Output**: The output is a modified version of the original API method that, when called, returns an instance of the specified API response class.\n4. **Procedure**: The function first creates a wrapper around the original API method. It modifies the method's behavior by adding specific headers to the request to indicate that the response should be returned as a raw data type and specifies the class that should be used to cast the response. The modified API method is then returned, ready to be used with the new settings applied.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "\nclass MissingStreamClassError(TypeError):\n def __init__(self) -> None:\n super().__init__(\n \"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference\",\n )\n\n\nclass StreamAlreadyConsumed(OpenAIError):\n \"\"\"\n Attempted to read or stream content, but the content has already\n been streamed.\n\n This can happen if you use a method like `.iter_lines()` and then attempt\n to read th entire response body afterwards, e.g.\n\n ```py\n response = await client.post(...)\n async for line in response.iter_lines():\n ... # do something with `line`\n\n content = await response.read()\n # ^ error\n ```\n\n If you want this behaviour you'll need to either manually accumulate the response\n content or call `await response.read()` before iterating over the stream.\n \"\"\"\n\n def __init__(self) -> None:\n message = (\n \"Attempted to read or stream some content, but the content has \"\n \"already been streamed. 
\"\n \"This could be due to attempting to stream the response \"\n \"content more than once.\"\n \"\\n\\n\"\n \"You can fix this by manually accumulating the response content while streaming \"\n \"or by calling `.read()` before starting to stream.\"\n )\n super().__init__(message)\n\n\nclass ResponseContextManager(Generic[_APIResponseT]):\n \"\"\"Context manager for ensuring that a request is not made\n until it is entered and that the response will always be closed\n when the context manager exits\n \"\"\"\n\n def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:\n self._request_func = request_func\n self.__response: _APIResponseT | None = None\n\n def __enter__(self) -> _APIResponseT:\n self.__response = self._request_func()\n return self.__response\n\n def __exit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n if self.__response is not None:\n self.__response.close()\n\n\nclass AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):\n \"\"\"Context manager for ensuring that a request is not made\n until it is entered and that the response will always be closed\n when the context manager exits\n \"\"\"\n\n def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:\n self._api_request = api_request\n self.__response: _AsyncAPIResponseT | None = None\n\n async def __aenter__(self) -> _AsyncAPIResponseT:\n self.__response = await self._api_request\n return self.__response\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n if self.__response is not None:\n await self.__response.close()\n\n\ndef to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support streaming and returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = functools.partial(func, *args, **kwargs)\n\n return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))\n\n return wrapped\n\n\ndef async_to_streamed_response_wrapper(\n func: Callable[P, Awaitable[R]],\n) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support streaming and returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = func(*args, **kwargs)\n\n return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))\n\n return wrapped\n\n\ndef to_custom_streamed_response_wrapper(\n func: Callable[P, object],\n response_cls: type[_APIResponseT],\n) -> Callable[P, ResponseContextManager[_APIResponseT]]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support streaming and returning the given 
response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = functools.partial(func, *args, **kwargs)\n\n return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))\n\n return wrapped\n\n\ndef async_to_custom_streamed_response_wrapper(\n func: Callable[P, Awaitable[object]],\n response_cls: type[_AsyncAPIResponseT],\n) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support streaming and returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = func(*args, **kwargs)\n\n return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))\n\n return wrapped\n\n\ndef to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(APIResponse[R], func(*args, **kwargs))\n\n return wrapped\n\n\ndef async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(AsyncAPIResponse[R], await func(*args, **kwargs))\n\n return wrapped\n\n\n\ndef to_custom_raw_response_wrapper(\n func: Callable[P, object],\n response_cls: type[_APIResponseT],\n) -> Callable[P, _APIResponseT]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. 
`class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(_APIResponseT, func(*args, **kwargs))\n\n return wrapped\n\n\ndef async_to_custom_raw_response_wrapper(\n func: Callable[P, Awaitable[object]],\n response_cls: type[_AsyncAPIResponseT],\n) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))\n\n return wrapped\n\n\ndef extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:\n \"\"\"Given a type like `APIResponse[T]`, returns the generic type variable `T`.\n\n This also handles the case where a concrete subclass is given, e.g.\n ```py\n class MyResponse(APIResponse[bytes]):\n ...\n\n extract_response_type(MyResponse) -> bytes\n ```\n \"\"\"\n return extract_type_var_from_base(\n typ,\n generic_bases=cast(\"tuple[type, ...]\", (BaseAPIResponse, APIResponse, AsyncAPIResponse)),\n index=0,\n )\n\n# Path: src/openai/types/completion_choice.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import Dict, List, Optional\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\n\n__all__ = [\"CompletionChoice\", \"Logprobs\"]\n\n\nclass Logprobs(BaseModel):\n text_offset: Optional[List[int]] = None\n\n token_logprobs: Optional[List[float]] = None\n\n tokens: Optional[List[str]] = None\n\n top_logprobs: Optional[List[Dict[str, float]]] = None\n\n\nclass CompletionChoice(BaseModel):\n finish_reason: Literal[\"stop\", \"length\", \"content_filter\"]\n \"\"\"The reason the model stopped generating tokens.\n\n This will be `stop` if the model hit a natural stop point or a provided stop\n sequence, `length` if the maximum number of tokens specified in the request was\n reached, or `content_filter` if content was omitted due to a flag from our\n content filters.\n \"\"\"\n\n index: int\n\n logprobs: Optional[Logprobs] = None\n\n text: str\n\n# Path: src/openai/types/completion_usage.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom .._models import BaseModel\n\n__all__ = [\"CompletionUsage\"]\n\n\nclass CompletionUsage(BaseModel):\n completion_tokens: int\n \"\"\"Number of tokens in the generated completion.\"\"\"\n\n prompt_tokens: int\n \"\"\"Number of tokens in the prompt.\"\"\"\n\n total_tokens: int\n \"\"\"Total number of tokens used in the request (prompt + completion).\"\"\"\n\n# Path: src/openai/types/completion.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import List, Optional\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\nfrom .completion_usage import CompletionUsage\nfrom 
.completion_choice import CompletionChoice\n\n__all__ = [\"Completion\"]\n\n\nclass Completion(BaseModel):\n id: str\n \"\"\"A unique identifier for the completion.\"\"\"\n\n choices: List[CompletionChoice]\n \"\"\"The list of completion choices the model generated for the input prompt.\"\"\"\n\n created: int\n \"\"\"The Unix timestamp (in seconds) of when the completion was created.\"\"\"\n\n model: str\n \"\"\"The model used for completion.\"\"\"\n\n object: Literal[\"text_completion\"]\n \"\"\"The object type, which is always \"text_completion\" \"\"\"\n\n system_fingerprint: Optional[str] = None\n \"\"\"This fingerprint represents the backend configuration that the model runs with.\n\n Can be used in conjunction with the `seed` request parameter to understand when\n backend changes have been made that might impact determinism.\n \"\"\"\n\n usage: Optional[CompletionUsage] = None\n \"\"\"Usage statistics for the completion request.\"\"\"\n\n# Path: src/openai/types/completion_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Dict, List, Union, Iterable, Optional\nfrom typing_extensions import Literal, Required, TypedDict\n\n__all__ = [\"CompletionCreateParamsBase\", \"CompletionCreateParamsNonStreaming\", \"CompletionCreateParamsStreaming\"]\n\n\nclass CompletionCreateParamsBase(TypedDict, total=False):\n model: Required[Union[str, Literal[\"gpt-3.5-turbo-instruct\", \"davinci-002\", \"babbage-002\"]]]\n \"\"\"ID of the model to use.\n\n You can use the\n [List models](https://platform.openai.com/docs/api-reference/models/list) API to\n see all of your available models, or see our\n [Model overview](https://platform.openai.com/docs/models/overview) for\n descriptions of them.\n \"\"\"\n\n prompt: Required[Union[str, List[str], Iterable[int], Iterable[Iterable[int]], None]]\n \"\"\"\n The prompt(s) to generate completions for, encoded as a string, array of\n strings, array of tokens, or array of token arrays.\n\n Note that <|endoftext|> is the document separator that the model sees during\n training, so if a prompt is not specified the model will generate as if from the\n beginning of a new document.\n \"\"\"\n\n best_of: Optional[int]\n \"\"\"\n Generates `best_of` completions server-side and returns the \"best\" (the one with\n the highest log probability per token). Results cannot be streamed.\n\n When used with `n`, `best_of` controls the number of candidate completions and\n `n` specifies how many to return \u2013 `best_of` must be greater than `n`.\n\n **Note:** Because this parameter generates many completions, it can quickly\n consume your token quota. Use carefully and ensure that you have reasonable\n settings for `max_tokens` and `stop`.\n \"\"\"\n\n echo: Optional[bool]\n \"\"\"Echo back the prompt in addition to the completion\"\"\"\n\n frequency_penalty: Optional[float]\n \"\"\"Number between -2.0 and 2.0.\n\n Positive values penalize new tokens based on their existing frequency in the\n text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)\n \"\"\"\n\n logit_bias: Optional[Dict[str, int]]\n \"\"\"Modify the likelihood of specified tokens appearing in the completion.\n\n Accepts a JSON object that maps tokens (specified by their token ID in the GPT\n tokenizer) to an associated bias value from -100 to 100. 
You can use this\n [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs.\n Mathematically, the bias is added to the logits generated by the model prior to\n sampling. The exact effect will vary per model, but values between -1 and 1\n should decrease or increase likelihood of selection; values like -100 or 100\n should result in a ban or exclusive selection of the relevant token.\n\n As an example, you can pass `{\"50256\": -100}` to prevent the <|endoftext|> token\n from being generated.\n \"\"\"\n\n logprobs: Optional[int]\n \"\"\"\n Include the log probabilities on the `logprobs` most likely output tokens, as\n well the chosen tokens. For example, if `logprobs` is 5, the API will return a\n list of the 5 most likely tokens. The API will always return the `logprob` of\n the sampled token, so there may be up to `logprobs+1` elements in the response.\n\n The maximum value for `logprobs` is 5.\n \"\"\"\n\n max_tokens: Optional[int]\n \"\"\"\n The maximum number of [tokens](/tokenizer) that can be generated in the\n completion.\n\n The token count of your prompt plus `max_tokens` cannot exceed the model's\n context length.\n [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)\n for counting tokens.\n \"\"\"\n\n n: Optional[int]\n \"\"\"How many completions to generate for each prompt.\n\n **Note:** Because this parameter generates many completions, it can quickly\n consume your token quota. Use carefully and ensure that you have reasonable\n settings for `max_tokens` and `stop`.\n \"\"\"\n\n presence_penalty: Optional[float]\n \"\"\"Number between -2.0 and 2.0.\n\n Positive values penalize new tokens based on whether they appear in the text so\n far, increasing the model's likelihood to talk about new topics.\n\n [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)\n \"\"\"\n\n seed: Optional[int]\n \"\"\"\n If specified, our system will make a best effort to sample deterministically,\n such that repeated requests with the same `seed` and parameters should return\n the same result.\n\n Determinism is not guaranteed, and you should refer to the `system_fingerprint`\n response parameter to monitor changes in the backend.\n \"\"\"\n\n stop: Union[Optional[str], List[str], None]\n \"\"\"Up to 4 sequences where the API will stop generating further tokens.\n\n The returned text will not contain the stop sequence.\n \"\"\"\n\n suffix: Optional[str]\n \"\"\"The suffix that comes after a completion of inserted text.\"\"\"\n\n temperature: Optional[float]\n \"\"\"What sampling temperature to use, between 0 and 2.\n\n Higher values like 0.8 will make the output more random, while lower values like\n 0.2 will make it more focused and deterministic.\n\n We generally recommend altering this or `top_p` but not both.\n \"\"\"\n\n top_p: Optional[float]\n \"\"\"\n An alternative to sampling with temperature, called nucleus sampling, where the\n model considers the results of the tokens with top_p probability mass. 
So 0.1\n means only the tokens comprising the top 10% probability mass are considered.\n\n We generally recommend altering this or `temperature` but not both.\n \"\"\"\n\n user: str\n \"\"\"\n A unique identifier representing your end-user, which can help OpenAI to monitor\n and detect abuse.\n [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).\n \"\"\"\n\n\nclass CompletionCreateParamsNonStreaming(CompletionCreateParamsBase):\n stream: Optional[Literal[False]]\n \"\"\"Whether to stream back partial progress.\n\n If set, tokens will be sent as data-only\n [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)\n as they become available, with the stream terminated by a `data: [DONE]`\n message.\n [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n \"\"\"\n\n\nclass CompletionCreateParamsStreaming(CompletionCreateParamsBase):\n stream: Required[Literal[True]]\n \"\"\"Whether to stream back partial progress.\n\n If set, tokens will be sent as data-only\n [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)\n as they become available, with the stream terminated by a `data: [DONE]`\n message.\n [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n \"\"\"\n\n\nCompletionCreateParams = Union[CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming]\n\n# Path: src/openai/types/embedding.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import List\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\n\n__all__ = [\"Embedding\"]\n\n\nclass Embedding(BaseModel):\n embedding: List[float]\n \"\"\"The embedding vector, which is a list of floats.\n\n The length of vector depends on the model as listed in the\n [embedding guide](https://platform.openai.com/docs/guides/embeddings).\n \"\"\"\n\n index: int\n \"\"\"The index of the embedding in the list of embeddings.\"\"\"\n\n object: Literal[\"embedding\"]\n \"\"\"The object type, which is always \"embedding\".\"\"\"\n\n# Path: src/openai/types/create_embedding_response.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import List\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\nfrom .embedding import Embedding\n\n__all__ = [\"CreateEmbeddingResponse\", \"Usage\"]\n\n\nclass Usage(BaseModel):\n prompt_tokens: int\n \"\"\"The number of tokens used by the prompt.\"\"\"\n\n total_tokens: int\n \"\"\"The total number of tokens used by the request.\"\"\"\n\n\nclass CreateEmbeddingResponse(BaseModel):\n data: List[Embedding]\n \"\"\"The list of embeddings generated by the model.\"\"\"\n\n model: str\n \"\"\"The name of the model used to generate the embedding.\"\"\"\n\n object: Literal[\"list\"]\n \"\"\"The object type, which is always \"list\".\"\"\"\n\n usage: Usage\n \"\"\"The usage information for the request.\"\"\"\n\n# Path: src/openai/types/embedding_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import List, Union, Iterable\nfrom typing_extensions import Literal, Required, TypedDict\n\n__all__ = [\"EmbeddingCreateParams\"]\n\n\nclass EmbeddingCreateParams(TypedDict, total=False):\n input: Required[Union[str, List[str], Iterable[int], Iterable[Iterable[int]]]]\n \"\"\"Input text to embed, encoded as a string or 
array of tokens.\n\n To embed multiple inputs in a single request, pass an array of strings or array\n of token arrays. The input must not exceed the max input tokens for the model\n (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any\n array must be 2048 dimensions or less.\n [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)\n for counting tokens.\n \"\"\"\n\n model: Required[Union[str, Literal[\"text-embedding-ada-002\", \"text-embedding-3-small\", \"text-embedding-3-large\"]]]\n \"\"\"ID of the model to use.\n\n You can use the\n [List models](https://platform.openai.com/docs/api-reference/models/list) API to\n see all of your available models, or see our\n [Model overview](https://platform.openai.com/docs/models/overview) for\n descriptions of them.\n \"\"\"\n\n dimensions: int\n \"\"\"The number of dimensions the resulting output embeddings should have.\n\n Only supported in `text-embedding-3` and later models.\n \"\"\"\n\n encoding_format: Literal[\"float\", \"base64\"]\n \"\"\"The format to return the embeddings in.\n\n Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).\n \"\"\"\n\n user: str\n \"\"\"\n A unique identifier representing your end-user, which can help OpenAI to monitor\n and detect abuse.\n [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).\n \"\"\"\n\n# Path: src/openai/types/file_content.py\n# File generated from our OpenAPI spec by Stainless.\n\n\n__all__ = [\"FileContent\"]\n\nFileContent = str\n\n# Path: src/openai/types/file_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing_extensions import Literal, Required, TypedDict\n\nfrom .._types import FileTypes\n\n__all__ = [\"FileCreateParams\"]\n\n\nclass FileCreateParams(TypedDict, total=False):\n file: Required[FileTypes]\n \"\"\"The File object (not file name) to be uploaded.\"\"\"\n\n purpose: Required[Literal[\"fine-tune\", \"assistants\"]]\n \"\"\"The intended purpose of the uploaded file.\n\n Use \"fine-tune\" for\n [Fine-tuning](https://platform.openai.com/docs/api-reference/fine-tuning) and\n \"assistants\" for\n [Assistants](https://platform.openai.com/docs/api-reference/assistants) and\n [Messages](https://platform.openai.com/docs/api-reference/messages). 
This allows\n us to validate the format of the uploaded file is correct for fine-tuning.\n \"\"\"\n\n# Path: src/openai/types/file_deleted.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\n\n__all__ = [\"FileDeleted\"]\n\n\nclass FileDeleted(BaseModel):\n id: str\n\n deleted: bool\n\n object: Literal[\"file\"]\n\n# Path: src/openai/types/file_list_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing_extensions import TypedDict\n\n__all__ = [\"FileListParams\"]\n\n\nclass FileListParams(TypedDict, total=False):\n purpose: str\n \"\"\"Only return files with the given purpose.\"\"\"\n\n# Path: src/openai/types/file_object.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import Optional\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\n\n__all__ = [\"FileObject\"]\n\n\nclass FileObject(BaseModel):\n id: str\n \"\"\"The file identifier, which can be referenced in the API endpoints.\"\"\"\n\n bytes: int\n \"\"\"The size of the file, in bytes.\"\"\"\n\n created_at: int\n \"\"\"The Unix timestamp (in seconds) for when the file was created.\"\"\"\n\n filename: str\n \"\"\"The name of the file.\"\"\"\n\n object: Literal[\"file\"]\n \"\"\"The object type, which is always `file`.\"\"\"\n\n purpose: Literal[\"fine-tune\", \"fine-tune-results\", \"assistants\", \"assistants_output\"]\n \"\"\"The intended purpose of the file.\n\n Supported values are `fine-tune`, `fine-tune-results`, `assistants`, and\n `assistants_output`.\n \"\"\"\n\n status: Literal[\"uploaded\", \"processed\", \"error\"]\n \"\"\"Deprecated.\n\n The current status of the file, which can be either `uploaded`, `processed`, or\n `error`.\n \"\"\"\n\n status_details: Optional[str] = None\n \"\"\"Deprecated.\n\n For details on why a fine-tuning training file failed validation, see the\n `error` field on `fine_tuning.job`.\n \"\"\"\n\n# Path: src/openai/types/image.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import Optional\n\nfrom .._models import BaseModel\n\n__all__ = [\"Image\"]\n\n\nclass Image(BaseModel):\n b64_json: Optional[str] = None\n \"\"\"\n The base64-encoded JSON of the generated image, if `response_format` is\n `b64_json`.\n \"\"\"\n\n revised_prompt: Optional[str] = None\n \"\"\"\n The prompt that was used to generate the image, if there was any revision to the\n prompt.\n \"\"\"\n\n url: Optional[str] = None\n \"\"\"The URL of the generated image, if `response_format` is `url` (default).\"\"\"\n\n# Path: src/openai/types/image_create_variation_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Union, Optional\nfrom typing_extensions import Literal, Required, TypedDict\n\nfrom .._types import FileTypes\n\n__all__ = [\"ImageCreateVariationParams\"]\n\n\nclass ImageCreateVariationParams(TypedDict, total=False):\n image: Required[FileTypes]\n \"\"\"The image to use as the basis for the variation(s).\n\n Must be a valid PNG file, less than 4MB, and square.\n \"\"\"\n\n model: Union[str, Literal[\"dall-e-2\"], None]\n \"\"\"The model to use for image generation.\n\n Only `dall-e-2` is supported at this time.\n \"\"\"\n\n n: Optional[int]\n \"\"\"The number of images to generate.\n\n Must be between 1 and 10. 
For `dall-e-3`, only `n=1` is supported.\n \"\"\"\n\n response_format: Optional[Literal[\"url\", \"b64_json\"]]\n \"\"\"The format in which the generated images are returned.\n\n Must be one of `url` or `b64_json`.\n \"\"\"\n\n size: Optional[Literal[\"256x256\", \"512x512\", \"1024x1024\"]]\n \"\"\"The size of the generated images.\n\n Must be one of `256x256`, `512x512`, or `1024x1024`.\n \"\"\"\n\n user: str\n \"\"\"\n A unique identifier representing your end-user, which can help OpenAI to monitor\n and detect abuse.\n [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).\n \"\"\"\n\n# Path: src/openai/types/image_edit_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Union, Optional\nfrom typing_extensions import Literal, Required, TypedDict\n\nfrom .._types import FileTypes\n\n__all__ = [\"ImageEditParams\"]\n\n\nclass ImageEditParams(TypedDict, total=False):\n image: Required[FileTypes]\n \"\"\"The image to edit.\n\n Must be a valid PNG file, less than 4MB, and square. If mask is not provided,\n image must have transparency, which will be used as the mask.\n \"\"\"\n\n prompt: Required[str]\n \"\"\"A text description of the desired image(s).\n\n The maximum length is 1000 characters.\n \"\"\"\n\n mask: FileTypes\n \"\"\"An additional image whose fully transparent areas (e.g.\n\n where alpha is zero) indicate where `image` should be edited. Must be a valid\n PNG file, less than 4MB, and have the same dimensions as `image`.\n \"\"\"\n\n model: Union[str, Literal[\"dall-e-2\"], None]\n \"\"\"The model to use for image generation.\n\n Only `dall-e-2` is supported at this time.\n \"\"\"\n\n n: Optional[int]\n \"\"\"The number of images to generate. Must be between 1 and 10.\"\"\"\n\n response_format: Optional[Literal[\"url\", \"b64_json\"]]\n \"\"\"The format in which the generated images are returned.\n\n Must be one of `url` or `b64_json`.\n \"\"\"\n\n size: Optional[Literal[\"256x256\", \"512x512\", \"1024x1024\"]]\n \"\"\"The size of the generated images.\n\n Must be one of `256x256`, `512x512`, or `1024x1024`.\n \"\"\"\n\n user: str\n \"\"\"\n A unique identifier representing your end-user, which can help OpenAI to monitor\n and detect abuse.\n [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).\n \"\"\"\n\n# Path: src/openai/types/image_generate_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Union, Optional\nfrom typing_extensions import Literal, Required, TypedDict\n\n__all__ = [\"ImageGenerateParams\"]\n\n\nclass ImageGenerateParams(TypedDict, total=False):\n prompt: Required[str]\n \"\"\"A text description of the desired image(s).\n\n The maximum length is 1000 characters for `dall-e-2` and 4000 characters for\n `dall-e-3`.\n \"\"\"\n\n model: Union[str, Literal[\"dall-e-2\", \"dall-e-3\"], None]\n \"\"\"The model to use for image generation.\"\"\"\n\n n: Optional[int]\n \"\"\"The number of images to generate.\n\n Must be between 1 and 10. 
For `dall-e-3`, only `n=1` is supported.\n \"\"\"\n\n quality: Literal[\"standard\", \"hd\"]\n \"\"\"The quality of the image that will be generated.\n\n `hd` creates images with finer details and greater consistency across the image.\n This param is only supported for `dall-e-3`.\n \"\"\"\n\n response_format: Optional[Literal[\"url\", \"b64_json\"]]\n \"\"\"The format in which the generated images are returned.\n\n Must be one of `url` or `b64_json`.\n \"\"\"\n\n size: Optional[Literal[\"256x256\", \"512x512\", \"1024x1024\", \"1792x1024\", \"1024x1792\"]]\n \"\"\"The size of the generated images.\n\n Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one\n of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.\n \"\"\"\n\n style: Optional[Literal[\"vivid\", \"natural\"]]\n \"\"\"The style of the generated images.\n\n Must be one of `vivid` or `natural`. Vivid causes the model to lean towards\n generating hyper-real and dramatic images. Natural causes the model to produce\n more natural, less hyper-real looking images. This param is only supported for\n `dall-e-3`.\n \"\"\"\n\n user: str\n \"\"\"\n A unique identifier representing your end-user, which can help OpenAI to monitor\n and detect abuse.\n [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).\n \"\"\"\n\n# Path: src/openai/types/images_response.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import List\n\nfrom .image import Image\nfrom .._models import BaseModel\n\n__all__ = [\"ImagesResponse\"]\n\n\nclass ImagesResponse(BaseModel):\n created: int\n\n data: List[Image]\n\n# Path: src/openai/types/model.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\n\n__all__ = [\"Model\"]\n\n\nclass Model(BaseModel):\n id: str\n \"\"\"The model identifier, which can be referenced in the API endpoints.\"\"\"\n\n created: int\n \"\"\"The Unix timestamp (in seconds) when the model was created.\"\"\"\n\n object: Literal[\"model\"]\n \"\"\"The object type, which is always \"model\".\"\"\"\n\n owned_by: str\n \"\"\"The organization that owns the model.\"\"\"\n\n# Path: src/openai/types/model_deleted.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom .._models import BaseModel\n\n__all__ = [\"ModelDeleted\"]\n\n\nclass ModelDeleted(BaseModel):\n id: str\n\n deleted: bool\n\n object: str\n\n# Path: src/openai/types/moderation.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom pydantic import Field as FieldInfo\n\nfrom .._models import BaseModel\n\n__all__ = [\"Moderation\", \"Categories\", \"CategoryScores\"]\n\n\nclass Categories(BaseModel):\n harassment: bool\n \"\"\"\n Content that expresses, incites, or promotes harassing language towards any\n target.\n \"\"\"\n\n harassment_threatening: bool = FieldInfo(alias=\"harassment/threatening\")\n \"\"\"\n Harassment content that also includes violence or serious harm towards any\n target.\n \"\"\"\n\n hate: bool\n \"\"\"\n Content that expresses, incites, or promotes hate based on race, gender,\n ethnicity, religion, nationality, sexual orientation, disability status, or\n caste. 
Hateful content aimed at non-protected groups (e.g., chess players) is\n harassment.\n \"\"\"\n\n hate_threatening: bool = FieldInfo(alias=\"hate/threatening\")\n \"\"\"\n Hateful content that also includes violence or serious harm towards the targeted\n group based on race, gender, ethnicity, religion, nationality, sexual\n orientation, disability status, or caste.\n \"\"\"\n\n self_harm: bool = FieldInfo(alias=\"self-harm\")\n \"\"\"\n Content that promotes, encourages, or depicts acts of self-harm, such as\n suicide, cutting, and eating disorders.\n \"\"\"\n\n self_harm_instructions: bool = FieldInfo(alias=\"self-harm/instructions\")\n \"\"\"\n Content that encourages performing acts of self-harm, such as suicide, cutting,\n and eating disorders, or that gives instructions or advice on how to commit such\n acts.\n \"\"\"\n\n self_harm_intent: bool = FieldInfo(alias=\"self-harm/intent\")\n \"\"\"\n Content where the speaker expresses that they are engaging or intend to engage\n in acts of self-harm, such as suicide, cutting, and eating disorders.\n \"\"\"\n\n sexual: bool\n \"\"\"\n Content meant to arouse sexual excitement, such as the description of sexual\n activity, or that promotes sexual services (excluding sex education and\n wellness).\n \"\"\"\n\n sexual_minors: bool = FieldInfo(alias=\"sexual/minors\")\n \"\"\"Sexual content that includes an individual who is under 18 years old.\"\"\"\n\n violence: bool\n \"\"\"Content that depicts death, violence, or physical injury.\"\"\"\n\n violence_graphic: bool = FieldInfo(alias=\"violence/graphic\")\n \"\"\"Content that depicts death, violence, or physical injury in graphic detail.\"\"\"\n\n\nclass CategoryScores(BaseModel):\n harassment: float\n \"\"\"The score for the category 'harassment'.\"\"\"\n\n harassment_threatening: float = FieldInfo(alias=\"harassment/threatening\")\n \"\"\"The score for the category 'harassment/threatening'.\"\"\"\n\n hate: float\n \"\"\"The score for the category 'hate'.\"\"\"\n\n hate_threatening: float = FieldInfo(alias=\"hate/threatening\")\n \"\"\"The score for the category 'hate/threatening'.\"\"\"\n\n self_harm: float = FieldInfo(alias=\"self-harm\")\n \"\"\"The score for the category 'self-harm'.\"\"\"\n\n self_harm_instructions: float = FieldInfo(alias=\"self-harm/instructions\")\n \"\"\"The score for the category 'self-harm/instructions'.\"\"\"\n\n self_harm_intent: float = FieldInfo(alias=\"self-harm/intent\")\n \"\"\"The score for the category 'self-harm/intent'.\"\"\"\n\n sexual: float\n \"\"\"The score for the category 'sexual'.\"\"\"\n\n sexual_minors: float = FieldInfo(alias=\"sexual/minors\")\n \"\"\"The score for the category 'sexual/minors'.\"\"\"\n\n violence: float\n \"\"\"The score for the category 'violence'.\"\"\"\n\n violence_graphic: float = FieldInfo(alias=\"violence/graphic\")\n \"\"\"The score for the category 'violence/graphic'.\"\"\"\n\n\nclass Moderation(BaseModel):\n categories: Categories\n \"\"\"A list of the categories, and whether they are flagged or not.\"\"\"\n\n category_scores: CategoryScores\n \"\"\"A list of the categories along with their scores as predicted by model.\"\"\"\n\n flagged: bool\n \"\"\"\n Whether the content violates\n [OpenAI's usage policies](/policies/usage-policies).\n \"\"\"\n\n# Path: src/openai/types/moderation_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import List, Union\nfrom typing_extensions import Literal, Required, TypedDict\n\n__all__ = 
[\"ModerationCreateParams\"]\n\n\nclass ModerationCreateParams(TypedDict, total=False):\n input: Required[Union[str, List[str]]]\n \"\"\"The input text to classify\"\"\"\n\n model: Union[str, Literal[\"text-moderation-latest\", \"text-moderation-stable\"]]\n \"\"\"\n Two content moderations models are available: `text-moderation-stable` and\n `text-moderation-latest`.\n\n The default is `text-moderation-latest` which will be automatically upgraded\n over time. This ensures you are always using our most accurate model. If you use\n `text-moderation-stable`, we will provide advanced notice before updating the\n model. Accuracy of `text-moderation-stable` may be slightly lower than for\n `text-moderation-latest`.\n \"\"\"\n\n# Path: src/openai/types/moderation_create_response.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import List\n\nfrom .._models import BaseModel\nfrom .moderation import Moderation\n\n__all__ = [\"ModerationCreateResponse\"]\n\n\nclass ModerationCreateResponse(BaseModel):\n id: str\n \"\"\"The unique identifier for the moderation request.\"\"\"\n\n model: str\n \"\"\"The model used to generate the moderation results.\"\"\"\n\n results: List[Moderation]\n \"\"\"A list of moderation objects.\"\"\"\n\n# Path: src/openai/types/shared/function_parameters.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import Dict\n\n__all__ = [\"FunctionParameters\"]\n\nFunctionParameters = Dict[str, object]\n\n# Path: src/openai/types/shared/function_definition.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import Optional\n\nfrom ..._models import BaseModel\nfrom .function_parameters import FunctionParameters\n\n__all__ = [\"FunctionDefinition\"]\n\n\nclass FunctionDefinition(BaseModel):\n name: str\n \"\"\"The name of the function to be called.\n\n Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length\n of 64.\n \"\"\"\n\n description: Optional[str] = None\n \"\"\"\n A description of what the function does, used by the model to choose when and\n how to call the function.\n \"\"\"\n\n parameters: Optional[FunctionParameters] = None\n \"\"\"The parameters the functions accepts, described as a JSON Schema object.\n\n See the\n [guide](https://platform.openai.com/docs/guides/text-generation/function-calling)\n for examples, and the\n [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for\n documentation about the format.\n\n Omitting `parameters` defines a function with an empty parameter list.\n \"\"\"\n\n# Path: src/openai/types/shared/__init__.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom .function_definition import FunctionDefinition as FunctionDefinition\nfrom .function_parameters import FunctionParameters as FunctionParameters\n\n# Path: src/openai/types/__init__.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom .image import Image as Image\nfrom .model import Model as Model\nfrom .shared import FunctionDefinition as FunctionDefinition, FunctionParameters as FunctionParameters\nfrom .embedding import Embedding as Embedding\nfrom .completion import Completion as Completion\nfrom .moderation import Moderation as Moderation\nfrom .file_object import FileObject as FileObject\nfrom .file_content import FileContent as FileContent\nfrom .file_deleted import FileDeleted as FileDeleted\nfrom .model_deleted import ModelDeleted as ModelDeleted\nfrom .images_response import ImagesResponse as 
ImagesResponse\nfrom .completion_usage import CompletionUsage as CompletionUsage\nfrom .file_list_params import FileListParams as FileListParams\nfrom .completion_choice import CompletionChoice as CompletionChoice\nfrom .image_edit_params import ImageEditParams as ImageEditParams\nfrom .file_create_params import FileCreateParams as FileCreateParams\nfrom .image_generate_params import ImageGenerateParams as ImageGenerateParams\nfrom .embedding_create_params import EmbeddingCreateParams as EmbeddingCreateParams\nfrom .completion_create_params import CompletionCreateParams as CompletionCreateParams\nfrom .moderation_create_params import ModerationCreateParams as ModerationCreateParams\nfrom .create_embedding_response import CreateEmbeddingResponse as CreateEmbeddingResponse\nfrom .moderation_create_response import ModerationCreateResponse as ModerationCreateResponse\nfrom .image_create_variation_params import ImageCreateVariationParams as ImageCreateVariationParams\n\n# Path: src/openai/types/audio/speech_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Union\nfrom typing_extensions import Literal, Required, TypedDict\n\n__all__ = [\"SpeechCreateParams\"]\n\n\nclass SpeechCreateParams(TypedDict, total=False):\n input: Required[str]\n \"\"\"The text to generate audio for. The maximum length is 4096 characters.\"\"\"\n\n model: Required[Union[str, Literal[\"tts-1\", \"tts-1-hd\"]]]\n \"\"\"\n One of the available [TTS models](https://platform.openai.com/docs/models/tts):\n `tts-1` or `tts-1-hd`\n \"\"\"\n\n voice: Required[Literal[\"alloy\", \"echo\", \"fable\", \"onyx\", \"nova\", \"shimmer\"]]\n \"\"\"The voice to use when generating the audio.\n\n Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.\n Previews of the voices are available in the\n [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options).\n \"\"\"\n\n response_format: Literal[\"mp3\", \"opus\", \"aac\", \"flac\", \"pcm\", \"wav\"]\n \"\"\"The format to return audio in.\n\n Supported formats are `mp3`, `opus`, `aac`, `flac`, `pcm`, and `wav`.\n\n The `pcm` audio format, similar to `wav` but without a header, utilizes a 24kHz\n sample rate, mono channel, and 16-bit depth in signed little-endian format.\n \"\"\"\n\n speed: float\n \"\"\"The speed of the generated audio.\n\n Select a value from `0.25` to `4.0`. `1.0` is the default.\n \"\"\"\n\n# Path: src/openai/types/audio/transcription.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom ..._models import BaseModel\n\n__all__ = [\"Transcription\"]\n\n\nclass Transcription(BaseModel):\n text: str\n\n# Path: src/openai/types/audio/transcription_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import List, Union\nfrom typing_extensions import Literal, Required, TypedDict\n\nfrom ..._types import FileTypes\n\n__all__ = [\"TranscriptionCreateParams\"]\n\n\nclass TranscriptionCreateParams(TypedDict, total=False):\n file: Required[FileTypes]\n \"\"\"\n The audio file object (not file name) to transcribe, in one of these formats:\n flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n \"\"\"\n\n model: Required[Union[str, Literal[\"whisper-1\"]]]\n \"\"\"ID of the model to use. 
Only `whisper-1` is currently available.\"\"\"\n\n language: str\n \"\"\"The language of the input audio.\n\n Supplying the input language in\n [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will\n improve accuracy and latency.\n \"\"\"\n\n prompt: str\n \"\"\"An optional text to guide the model's style or continue a previous audio\n segment.\n\n The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)\n should match the audio language.\n \"\"\"\n\n response_format: Literal[\"json\", \"text\", \"srt\", \"verbose_json\", \"vtt\"]\n \"\"\"\n The format of the transcript output, in one of these options: `json`, `text`,\n `srt`, `verbose_json`, or `vtt`.\n \"\"\"\n\n temperature: float\n \"\"\"The sampling temperature, between 0 and 1.\n\n Higher values like 0.8 will make the output more random, while lower values like\n 0.2 will make it more focused and deterministic. If set to 0, the model will use\n [log probability](https://en.wikipedia.org/wiki/Log_probability) to\n automatically increase the temperature until certain thresholds are hit.\n \"\"\"\n\n timestamp_granularities: List[Literal[\"word\", \"segment\"]]\n \"\"\"The timestamp granularities to populate for this transcription.\n\n Any of these options: `word`, or `segment`. Note: There is no additional latency\n for segment timestamps, but generating word timestamps incurs additional\n latency.\n \"\"\"\n\n# Path: src/openai/types/audio/translation.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom ..._models import BaseModel\n\n__all__ = [\"Translation\"]\n\n\nclass Translation(BaseModel):\n text: str\n\n# Path: src/openai/types/audio/translation_create_params.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Union\nfrom typing_extensions import Literal, Required, TypedDict\n\nfrom ..._types import FileTypes\n\n__all__ = [\"TranslationCreateParams\"]\n\n\nclass TranslationCreateParams(TypedDict, total=False):\n file: Required[FileTypes]\n \"\"\"\n The audio file object (not file name) translate, in one of these formats: flac,\n mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n \"\"\"\n\n model: Required[Union[str, Literal[\"whisper-1\"]]]\n \"\"\"ID of the model to use. Only `whisper-1` is currently available.\"\"\"\n\n prompt: str\n \"\"\"An optional text to guide the model's style or continue a previous audio\n segment.\n\n The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting)\n should be in English.\n \"\"\"\n\n response_format: str\n \"\"\"\n The format of the transcript output, in one of these options: `json`, `text`,\n `srt`, `verbose_json`, or `vtt`.\n \"\"\"\n\n temperature: float\n \"\"\"The sampling temperature, between 0 and 1.\n\n Higher values like 0.8 will make the output more random, while lower values like\n 0.2 will make it more focused and deterministic. 
If set to 0, the model will use\n [log probability](https://en.wikipedia.org/wiki/Log_probability) to\n automatically increase the temperature until certain thresholds are hit.\n \"\"\"\n\n# Path: src/openai/types/audio/__init__.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom .translation import Translation as Translation\nfrom .transcription import Transcription as Transcription\nfrom .speech_create_params import SpeechCreateParams as SpeechCreateParams\nfrom .translation_create_params import TranslationCreateParams as TranslationCreateParams\nfrom .transcription_create_params import TranscriptionCreateParams as TranscriptionCreateParams\n\n# Path: src/openai/resources/audio/speech.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Union\nfrom typing_extensions import Literal\n\nimport httpx\n\nfrom ... import _legacy_response\nfrom ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven\nfrom ..._utils import maybe_transform\nfrom ..._compat import cached_property\nfrom ..._resource import SyncAPIResource, AsyncAPIResource\nfrom ..._response import (\n StreamedBinaryAPIResponse,\n AsyncStreamedBinaryAPIResponse,\n to_custom_streamed_response_wrapper,\n async_to_custom_streamed_response_wrapper,\n)\nfrom ...types.audio import speech_create_params\nfrom ..._base_client import (\n make_request_options,\n)\n\n__all__ = [\"Speech\", \"AsyncSpeech\"]\n\n\nclass Speech(SyncAPIResource):\n @cached_property\n def with_raw_response(self) -> SpeechWithRawResponse:\n return SpeechWithRawResponse(self)\n\n @cached_property\n def with_streaming_response(self) -> SpeechWithStreamingResponse:\n return SpeechWithStreamingResponse(self)\n\n def create(\n self,\n *,\n input: str,\n model: Union[str, Literal[\"tts-1\", \"tts-1-hd\"]],\n voice: Literal[\"alloy\", \"echo\", \"fable\", \"onyx\", \"nova\", \"shimmer\"],\n response_format: Literal[\"mp3\", \"opus\", \"aac\", \"flac\", \"pcm\", \"wav\"] | NotGiven = NOT_GIVEN,\n speed: float | NotGiven = NOT_GIVEN,\n # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.\n # The extra values given here take precedence over values defined on the client or passed to this method.\n extra_headers: Headers | None = None,\n extra_query: Query | None = None,\n extra_body: Body | None = None,\n timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,\n ) -> _legacy_response.HttpxBinaryResponseContent:\n \"\"\"\n Generates audio from the input text.\n\n Args:\n input: The text to generate audio for. The maximum length is 4096 characters.\n\n model:\n One of the available [TTS models](https://platform.openai.com/docs/models/tts):\n `tts-1` or `tts-1-hd`\n\n voice: The voice to use when generating the audio. Supported voices are `alloy`,\n `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are\n available in the\n [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options).\n\n response_format: The format to return audio in. Supported formats are `mp3`, `opus`, `aac`,\n `flac`, `pcm`, and `wav`.\n\n The `pcm` audio format, similar to `wav` but without a header, utilizes a 24kHz\n sample rate, mono channel, and 16-bit depth in signed little-endian format.\n\n speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. 
`1.0` is\n the default.\n\n extra_headers: Send extra headers\n\n extra_query: Add additional query parameters to the request\n\n extra_body: Add additional JSON properties to the request\n\n timeout: Override the client-level default timeout for this request, in seconds\n \"\"\"\n extra_headers = {\"Accept\": \"application/octet-stream\", **(extra_headers or {})}\n return self._post(\n \"/audio/speech\",\n body=maybe_transform(\n {\n \"input\": input,\n \"model\": model,\n \"voice\": voice,\n \"response_format\": response_format,\n \"speed\": speed,\n },\n speech_create_params.SpeechCreateParams,\n ),\n options=make_request_options(\n extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout\n ),\n cast_to=_legacy_response.HttpxBinaryResponseContent,\n )\n\n\nclass AsyncSpeech(AsyncAPIResource):\n @cached_property\n def with_raw_response(self) -> AsyncSpeechWithRawResponse:\n return AsyncSpeechWithRawResponse(self)\n\n @cached_property\n def with_streaming_response(self) -> AsyncSpeechWithStreamingResponse:\n return AsyncSpeechWithStreamingResponse(self)\n\n async def create(\n self,\n *,\n input: str,\n model: Union[str, Literal[\"tts-1\", \"tts-1-hd\"]],\n voice: Literal[\"alloy\", \"echo\", \"fable\", \"onyx\", \"nova\", \"shimmer\"],\n response_format: Literal[\"mp3\", \"opus\", \"aac\", \"flac\", \"pcm\", \"wav\"] | NotGiven = NOT_GIVEN,\n speed: float | NotGiven = NOT_GIVEN,\n # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.\n # The extra values given here take precedence over values defined on the client or passed to this method.\n extra_headers: Headers | None = None,\n extra_query: Query | None = None,\n extra_body: Body | None = None,\n timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,\n ) -> _legacy_response.HttpxBinaryResponseContent:\n \"\"\"\n Generates audio from the input text.\n\n Args:\n input: The text to generate audio for. The maximum length is 4096 characters.\n\n model:\n One of the available [TTS models](https://platform.openai.com/docs/models/tts):\n `tts-1` or `tts-1-hd`\n\n voice: The voice to use when generating the audio. Supported voices are `alloy`,\n `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are\n available in the\n [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options).\n\n response_format: The format to return audio in. Supported formats are `mp3`, `opus`, `aac`,\n `flac`, `pcm`, and `wav`.\n\n...\n", "needle_token_start": 2261, "needle_token_end": 2527, "code_context_ntokens": 15891, "output": ["```python\ndef to_custom_raw_response_wrapper(\n func: Callable[P, object],\n response_cls: type[_APIResponseT],\n) -> Callable[P, _APIResponseT]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. 
`class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(_APIResponseT, func(*args, **kwargs))\n\n return wrapped\n```"]} +{"repo": "openai/openai-python", "name": "transcribe", "language": "python", "path": "src/openai/cli/_api/audio.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: Converts spoken language in an audio file into written text using a specified machine learning model.\n2. **Input**: An audio file, a model identifier, and optional parameters such as language, temperature, prompt, and response format.\n3. **Output**: The transcribed text from the audio file displayed in a structured format.\n4. **Procedure**: The function reads the audio file, sends it to a machine learning API using the specified model and optional parameters, and then displays the transcription result.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/__init__.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nimport os as _os\nfrom typing_extensions import override\n\nfrom . import types\nfrom ._types import NoneType, Transport, ProxiesTypes\nfrom ._utils import file_from_path\nfrom ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions\nfrom ._models import BaseModel\nfrom ._version import __title__, __version__\nfrom ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse\nfrom ._exceptions import (\n APIError,\n OpenAIError,\n ConflictError,\n NotFoundError,\n APIStatusError,\n RateLimitError,\n APITimeoutError,\n BadRequestError,\n APIConnectionError,\n AuthenticationError,\n InternalServerError,\n PermissionDeniedError,\n UnprocessableEntityError,\n APIResponseValidationError,\n)\nfrom ._utils._logs import setup_logging as _setup_logging\n\n__all__ = [\n \"types\",\n \"__version__\",\n \"__title__\",\n \"NoneType\",\n \"Transport\",\n \"ProxiesTypes\",\n \"OpenAIError\",\n \"APIError\",\n \"APIStatusError\",\n \"APITimeoutError\",\n \"APIConnectionError\",\n \"APIResponseValidationError\",\n \"BadRequestError\",\n \"AuthenticationError\",\n \"PermissionDeniedError\",\n \"NotFoundError\",\n \"ConflictError\",\n \"UnprocessableEntityError\",\n \"RateLimitError\",\n \"InternalServerError\",\n \"Timeout\",\n \"RequestOptions\",\n \"Client\",\n \"AsyncClient\",\n \"Stream\",\n \"AsyncStream\",\n \"OpenAI\",\n \"AsyncOpenAI\",\n \"file_from_path\",\n \"BaseModel\",\n]\n\nfrom .lib import azure as _azure\nfrom .version import VERSION as VERSION\nfrom .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI\nfrom .lib._old_api import *\n\n_setup_logging()\n\n# Update the __module__ attribute for exported symbols so that\n# error messages point to this module instead of the module\n# it was originally defined in, e.g.\n# openai._exceptions.NotFoundError -> openai.NotFoundError\n__locals = locals()\nfor __name in __all__:\n if not __name.startswith(\"__\"):\n 
try:\n __locals[__name].__module__ = \"openai\"\n except (TypeError, AttributeError):\n # Some of our exported symbols are builtins which we can't set attributes for.\n pass\n\n# ------ Module level client ------\nimport typing as _t\nimport typing_extensions as _te\n\nimport httpx as _httpx\n\nfrom ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES\n\napi_key: str | None = None\n\norganization: str | None = None\n\nbase_url: str | _httpx.URL | None = None\n\ntimeout: float | Timeout | None = DEFAULT_TIMEOUT\n\nmax_retries: int = DEFAULT_MAX_RETRIES\n\ndefault_headers: _t.Mapping[str, str] | None = None\n\ndefault_query: _t.Mapping[str, object] | None = None\n\nhttp_client: _httpx.Client | None = None\n\n_ApiType = _te.Literal[\"openai\", \"azure\"]\n\napi_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get(\"OPENAI_API_TYPE\"))\n\napi_version: str | None = _os.environ.get(\"OPENAI_API_VERSION\")\n\nazure_endpoint: str | None = _os.environ.get(\"AZURE_OPENAI_ENDPOINT\")\n\nazure_ad_token: str | None = _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\")\n\nazure_ad_token_provider: _azure.AzureADTokenProvider | None = None\n\n\nclass _ModuleClient(OpenAI):\n # Note: we have to use type: ignores here as overriding class members\n # with properties is technically unsafe but it is fine for our use case\n\n @property # type: ignore\n @override\n def api_key(self) -> str | None:\n return api_key\n\n @api_key.setter # type: ignore\n def api_key(self, value: str | None) -> None: # type: ignore\n global api_key\n\n api_key = value\n\n @property # type: ignore\n @override\n def organization(self) -> str | None:\n return organization\n\n @organization.setter # type: ignore\n def organization(self, value: str | None) -> None: # type: ignore\n global organization\n\n organization = value\n\n @property\n @override\n def base_url(self) -> _httpx.URL:\n if base_url is not None:\n return _httpx.URL(base_url)\n\n return super().base_url\n\n @base_url.setter\n def base_url(self, url: _httpx.URL | str) -> None:\n super().base_url = url # type: ignore[misc]\n\n @property # type: ignore\n @override\n def timeout(self) -> float | Timeout | None:\n return timeout\n\n @timeout.setter # type: ignore\n def timeout(self, value: float | Timeout | None) -> None: # type: ignore\n global timeout\n\n timeout = value\n\n @property # type: ignore\n @override\n def max_retries(self) -> int:\n return max_retries\n\n @max_retries.setter # type: ignore\n def max_retries(self, value: int) -> None: # type: ignore\n global max_retries\n\n max_retries = value\n\n @property # type: ignore\n @override\n def _custom_headers(self) -> _t.Mapping[str, str] | None:\n return default_headers\n\n @_custom_headers.setter # type: ignore\n def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore\n global default_headers\n\n default_headers = value\n\n @property # type: ignore\n @override\n def _custom_query(self) -> _t.Mapping[str, object] | None:\n return default_query\n\n @_custom_query.setter # type: ignore\n def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore\n global default_query\n\n default_query = value\n\n @property # type: ignore\n @override\n def _client(self) -> _httpx.Client:\n return http_client or super()._client\n\n @_client.setter # type: ignore\n def _client(self, value: _httpx.Client) -> None: # type: ignore\n global http_client\n\n http_client = value\n\n\nclass _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore\n ...\n\n\nclass 
_AmbiguousModuleClientUsageError(OpenAIError):\n def __init__(self) -> None:\n super().__init__(\n \"Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`\"\n )\n\n\ndef _has_openai_credentials() -> bool:\n return _os.environ.get(\"OPENAI_API_KEY\") is not None\n\n\ndef _has_azure_credentials() -> bool:\n return azure_endpoint is not None or _os.environ.get(\"AZURE_OPENAI_API_KEY\") is not None\n\n\ndef _has_azure_ad_credentials() -> bool:\n return (\n _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\") is not None\n or azure_ad_token is not None\n or azure_ad_token_provider is not None\n )\n\n\n_client: OpenAI | None = None\n\n\ndef _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]\n global _client\n\n if _client is None:\n global api_type, azure_endpoint, azure_ad_token, api_version\n\n if azure_endpoint is None:\n azure_endpoint = _os.environ.get(\"AZURE_OPENAI_ENDPOINT\")\n\n if azure_ad_token is None:\n azure_ad_token = _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\")\n\n if api_version is None:\n api_version = _os.environ.get(\"OPENAI_API_VERSION\")\n\n if api_type is None:\n...\n# Path: src/openai/cli/_models.py\nfrom typing import Any\nfrom typing_extensions import ClassVar\n\nimport pydantic\n\nfrom .. import _models\nfrom .._compat import PYDANTIC_V2, ConfigDict\n\n\nclass BaseModel(_models.BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(extra=\"ignore\", arbitrary_types_allowed=True)\n else:\n\n class Config(pydantic.BaseConfig): # type: ignore\n extra: Any = pydantic.Extra.ignore # type: ignore\n arbitrary_types_allowed: bool = True\n\n# Path: src/openai/cli/_progress.py\nfrom __future__ import annotations\n\nimport io\nfrom typing import Callable\nfrom typing_extensions import override\n\n\nclass CancelledError(Exception):\n def __init__(self, msg: str) -> None:\n self.msg = msg\n super().__init__(msg)\n\n @override\n def __str__(self) -> str:\n return self.msg\n\n __repr__ = __str__\n\n\nclass BufferReader(io.BytesIO):\n def __init__(self, buf: bytes = b\"\", desc: str | None = None) -> None:\n super().__init__(buf)\n self._len = len(buf)\n self._progress = 0\n self._callback = progress(len(buf), desc=desc)\n\n def __len__(self) -> int:\n return self._len\n\n @override\n def read(self, n: int | None = -1) -> bytes:\n chunk = io.BytesIO.read(self, n)\n self._progress += len(chunk)\n\n try:\n self._callback(self._progress)\n except Exception as e: # catches exception from the callback\n raise CancelledError(\"The upload was cancelled: {}\".format(e)) from e\n\n return chunk\n\n\ndef progress(total: float, desc: str | None) -> Callable[[float], None]:\n import tqdm\n\n meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc)\n\n def incr(progress: float) -> None:\n meter.n = progress\n if progress == total:\n meter.close()\n else:\n meter.refresh()\n\n return incr\n\n\ndef MB(i: int) -> int:\n return int(i // 1024**2)\n\n# Path: src/openai/cli/_utils.py\nfrom __future__ import annotations\n\nimport sys\n\nimport openai\n\nfrom .. 
import OpenAI, _load_client\nfrom .._compat import model_json\nfrom .._models import BaseModel\n\n\nclass Colors:\n HEADER = \"\\033[95m\"\n OKBLUE = \"\\033[94m\"\n OKGREEN = \"\\033[92m\"\n WARNING = \"\\033[93m\"\n FAIL = \"\\033[91m\"\n ENDC = \"\\033[0m\"\n BOLD = \"\\033[1m\"\n UNDERLINE = \"\\033[4m\"\n\n\ndef get_client() -> OpenAI:\n return _load_client()\n\n\ndef organization_info() -> str:\n organization = openai.organization\n if organization is not None:\n return \"[organization={}] \".format(organization)\n\n return \"\"\n\n\ndef print_model(model: BaseModel) -> None:\n sys.stdout.write(model_json(model, indent=2) + \"\\n\")\n\n\ndef can_use_http2() -> bool:\n try:\n import h2 # type: ignore # noqa\n except ImportError:\n return False\n\n return True\n\n# Path: src/openai/cli/_api/audio.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, Optional, cast\nfrom argparse import ArgumentParser\n\nfrom .._utils import get_client, print_model\nfrom ..._types import NOT_GIVEN\nfrom .._models import BaseModel\nfrom .._progress import BufferReader\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n # transcriptions\n sub = subparser.add_parser(\"audio.transcriptions.create\")\n\n # Required\n sub.add_argument(\"-m\", \"--model\", type=str, default=\"whisper-1\")\n sub.add_argument(\"-f\", \"--file\", type=str, required=True)\n # Optional\n sub.add_argument(\"--response-format\", type=str)\n sub.add_argument(\"--language\", type=str)\n sub.add_argument(\"-t\", \"--temperature\", type=float)\n sub.add_argument(\"--prompt\", type=str)\n sub.set_defaults(func=CLIAudio.transcribe, args_model=CLITranscribeArgs)\n\n # translations\n sub = subparser.add_parser(\"audio.translations.create\")\n\n # Required\n sub.add_argument(\"-f\", \"--file\", type=str, required=True)\n # Optional\n sub.add_argument(\"-m\", \"--model\", type=str, default=\"whisper-1\")\n sub.add_argument(\"--response-format\", type=str)\n # TODO: doesn't seem to be supported by the API\n # sub.add_argument(\"--language\", type=str)\n sub.add_argument(\"-t\", \"--temperature\", type=float)\n sub.add_argument(\"--prompt\", type=str)\n sub.set_defaults(func=CLIAudio.translate, args_model=CLITranslationArgs)\n\n\nclass CLITranscribeArgs(BaseModel):\n model: str\n file: str\n response_format: Optional[str] = None\n language: Optional[str] = None\n temperature: Optional[float] = None\n prompt: Optional[str] = None\n\n\nclass CLITranslationArgs(BaseModel):\n model: str\n file: str\n response_format: Optional[str] = None\n language: Optional[str] = None\n temperature: Optional[float] = None\n prompt: Optional[str] = None\n\n\nclass CLIAudio:\n @staticmethod\n \ndef transcribe(args: CLITranscribeArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload progress\")\n\n model = get_client().audio.transcriptions.create(\n file=(args.file, buffer_reader),\n model=args.model,\n language=args.language or NOT_GIVEN,\n temperature=args.temperature or NOT_GIVEN,\n prompt=args.prompt or NOT_GIVEN,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n response_format=cast(Any, args.response_format),\n )\n print_model(model)\n\n @staticmethod\n def translate(args: CLITranslationArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload 
progress\")\n\n model = get_client().audio.translations.create(\n file=(args.file, buffer_reader),\n model=args.model,\n temperature=args.temperature or NOT_GIVEN,\n prompt=args.prompt or NOT_GIVEN,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n response_format=cast(Any, args.response_format),\n )\n print_model(model)\n\n# Path: src/openai/cli/_api/chat/completions.py\nfrom __future__ import annotations\n\nimport sys\nfrom typing import TYPE_CHECKING, List, Optional, cast\nfrom argparse import ArgumentParser\nfrom typing_extensions import Literal, NamedTuple\n\nfrom ..._utils import get_client\nfrom ..._models import BaseModel\nfrom ...._streaming import Stream\nfrom ....types.chat import (\n ChatCompletionRole,\n ChatCompletionChunk,\n CompletionCreateParams,\n)\nfrom ....types.chat.completion_create_params import (\n CompletionCreateParamsStreaming,\n CompletionCreateParamsNonStreaming,\n)\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"chat.completions.create\")\n\n sub._action_groups.pop()\n req = sub.add_argument_group(\"required arguments\")\n opt = sub.add_argument_group(\"optional arguments\")\n\n req.add_argument(\n \"-g\",\n \"--message\",\n action=\"append\",\n nargs=2,\n metavar=(\"ROLE\", \"CONTENT\"),\n help=\"A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.\",\n required=True,\n )\n req.add_argument(\n \"-m\",\n \"--model\",\n help=\"The model to use.\",\n required=True,\n )\n\n opt.add_argument(\n \"-n\",\n \"--n\",\n help=\"How many completions to generate for the conversation.\",\n type=int,\n )\n opt.add_argument(\"-M\", \"--max-tokens\", help=\"The maximum number of tokens to generate.\", type=int)\n opt.add_argument(\n \"-t\",\n \"--temperature\",\n help=\"\"\"What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.\n\nMutually exclusive with `top_p`.\"\"\",\n type=float,\n )\n opt.add_argument(\n \"-P\",\n \"--top_p\",\n help=\"\"\"An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. 
So 0.1 means only the tokens comprising the top 10%% probability mass are considered.\n\n Mutually exclusive with `temperature`.\"\"\",\n type=float,\n )\n opt.add_argument(\n \"--stop\",\n help=\"A stop sequence at which to stop generating tokens for the message.\",\n )\n opt.add_argument(\"--stream\", help=\"Stream messages as they're ready.\", action=\"store_true\")\n sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)\n\n\nclass CLIMessage(NamedTuple):\n role: ChatCompletionRole\n content: str\n\n\nclass CLIChatCompletionCreateArgs(BaseModel):\n message: List[CLIMessage]\n model: str\n n: Optional[int] = None\n max_tokens: Optional[int] = None\n temperature: Optional[float] = None\n top_p: Optional[float] = None\n stop: Optional[str] = None\n stream: bool = False\n\n\nclass CLIChatCompletion:\n @staticmethod\n def create(args: CLIChatCompletionCreateArgs) -> None:\n params: CompletionCreateParams = {\n \"model\": args.model,\n \"messages\": [\n {\"role\": cast(Literal[\"user\"], message.role), \"content\": message.content} for message in args.message\n ],\n \"n\": args.n,\n \"temperature\": args.temperature,\n \"top_p\": args.top_p,\n \"stop\": args.stop,\n # type checkers are not good at inferring union types so we have to set stream afterwards\n \"stream\": False,\n }\n if args.stream:\n params[\"stream\"] = args.stream # type: ignore\n if args.max_tokens is not None:\n params[\"max_tokens\"] = args.max_tokens\n\n if args.stream:\n return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))\n\n return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))\n\n @staticmethod\n def _create(params: CompletionCreateParamsNonStreaming) -> None:\n completion = get_client().chat.completions.create(**params)\n should_print_header = len(completion.choices) > 1\n for choice in completion.choices:\n if should_print_header:\n sys.stdout.write(\"===== Chat Completion {} =====\\n\".format(choice.index))\n\n content = choice.message.content if choice.message.content is not None else \"None\"\n sys.stdout.write(content)\n\n if should_print_header or not content.endswith(\"\\n\"):\n sys.stdout.write(\"\\n\")\n\n sys.stdout.flush()\n\n @staticmethod\n def _stream_create(params: CompletionCreateParamsStreaming) -> None:\n # cast is required for mypy\n stream = cast( # pyright: ignore[reportUnnecessaryCast]\n Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)\n )\n for chunk in stream:\n should_print_header = len(chunk.choices) > 1\n for choice in chunk.choices:\n if should_print_header:\n sys.stdout.write(\"===== Chat Completion {} =====\\n\".format(choice.index))\n\n content = choice.delta.content or \"\"\n sys.stdout.write(content)\n\n if should_print_header:\n sys.stdout.write(\"\\n\")\n\n sys.stdout.flush()\n\n sys.stdout.write(\"\\n\")\n\n# Path: src/openai/cli/_api/chat/__init__.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom argparse import ArgumentParser\n\nfrom . 
import completions\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n completions.register(subparser)\n\n# Path: src/openai/cli/_errors.py\nfrom __future__ import annotations\n\nimport sys\n\nimport pydantic\n\nfrom ._utils import Colors, organization_info\nfrom .._exceptions import APIError, OpenAIError\n\n\nclass CLIError(OpenAIError):\n ...\n\n\nclass SilentCLIError(CLIError):\n ...\n\n\ndef display_error(err: CLIError | APIError | pydantic.ValidationError) -> None:\n if isinstance(err, SilentCLIError):\n return\n\n sys.stderr.write(\"{}{}Error:{} {}\\n\".format(organization_info(), Colors.FAIL, Colors.ENDC, err))\n\n# Path: src/openai/cli/_api/completions.py\nfrom __future__ import annotations\n\nimport sys\nfrom typing import TYPE_CHECKING, Optional, cast\nfrom argparse import ArgumentParser\nfrom functools import partial\n\nfrom openai.types.completion import Completion\n\nfrom .._utils import get_client\nfrom ..._types import NOT_GIVEN, NotGivenOr\nfrom ..._utils import is_given\nfrom .._errors import CLIError\nfrom .._models import BaseModel\nfrom ..._streaming import Stream\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"completions.create\")\n\n # Required\n sub.add_argument(\n \"-m\",\n \"--model\",\n help=\"The model to use\",\n required=True,\n )\n\n # Optional\n sub.add_argument(\"-p\", \"--prompt\", help=\"An optional prompt to complete from\")\n sub.add_argument(\"--stream\", help=\"Stream tokens as they're ready.\", action=\"store_true\")\n sub.add_argument(\"-M\", \"--max-tokens\", help=\"The maximum number of tokens to generate\", type=int)\n sub.add_argument(\n \"-t\",\n \"--temperature\",\n help=\"\"\"What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.\n\nMutually exclusive with `top_p`.\"\"\",\n type=float,\n )\n sub.add_argument(\n \"-P\",\n \"--top_p\",\n help=\"\"\"An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.\n\n Mutually exclusive with `temperature`.\"\"\",\n type=float,\n )\n sub.add_argument(\n \"-n\",\n \"--n\",\n help=\"How many sub-completions to generate for each prompt.\",\n type=int,\n )\n sub.add_argument(\n \"--logprobs\",\n help=\"Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.\",\n type=int,\n )\n sub.add_argument(\n \"--best_of\",\n help=\"Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). 
Results cannot be streamed.\",\n type=int,\n )\n sub.add_argument(\n \"--echo\",\n help=\"Echo back the prompt in addition to the completion\",\n action=\"store_true\",\n )\n sub.add_argument(\n \"--frequency_penalty\",\n help=\"Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\",\n type=float,\n )\n sub.add_argument(\n \"--presence_penalty\",\n help=\"Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\",\n type=float,\n )\n sub.add_argument(\"--suffix\", help=\"The suffix that comes after a completion of inserted text.\")\n sub.add_argument(\"--stop\", help=\"A stop sequence at which to stop generating tokens.\")\n sub.add_argument(\n \"--user\",\n help=\"A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.\",\n )\n # TODO: add support for logit_bias\n sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)\n\n\nclass CLICompletionCreateArgs(BaseModel):\n model: str\n stream: bool = False\n\n prompt: Optional[str] = None\n n: NotGivenOr[int] = NOT_GIVEN\n stop: NotGivenOr[str] = NOT_GIVEN\n user: NotGivenOr[str] = NOT_GIVEN\n echo: NotGivenOr[bool] = NOT_GIVEN\n suffix: NotGivenOr[str] = NOT_GIVEN\n best_of: NotGivenOr[int] = NOT_GIVEN\n top_p: NotGivenOr[float] = NOT_GIVEN\n logprobs: NotGivenOr[int] = NOT_GIVEN\n max_tokens: NotGivenOr[int] = NOT_GIVEN\n temperature: NotGivenOr[float] = NOT_GIVEN\n presence_penalty: NotGivenOr[float] = NOT_GIVEN\n frequency_penalty: NotGivenOr[float] = NOT_GIVEN\n\n\nclass CLICompletions:\n @staticmethod\n def create(args: CLICompletionCreateArgs) -> None:\n if is_given(args.n) and args.n > 1 and args.stream:\n raise CLIError(\"Can't stream completions with n>1 with the current CLI\")\n\n make_request = partial(\n get_client().completions.create,\n n=args.n,\n echo=args.echo,\n stop=args.stop,\n user=args.user,\n model=args.model,\n top_p=args.top_p,\n prompt=args.prompt,\n suffix=args.suffix,\n best_of=args.best_of,\n logprobs=args.logprobs,\n max_tokens=args.max_tokens,\n temperature=args.temperature,\n presence_penalty=args.presence_penalty,\n frequency_penalty=args.frequency_penalty,\n )\n\n if args.stream:\n return CLICompletions._stream_create(\n # mypy doesn't understand the `partial` function but pyright does\n cast(Stream[Completion], make_request(stream=True)) # pyright: ignore[reportUnnecessaryCast]\n )\n\n return CLICompletions._create(make_request())\n\n @staticmethod\n def _create(completion: Completion) -> None:\n should_print_header = len(completion.choices) > 1\n for choice in completion.choices:\n if should_print_header:\n sys.stdout.write(\"===== Completion {} =====\\n\".format(choice.index))\n\n sys.stdout.write(choice.text)\n\n if should_print_header or not choice.text.endswith(\"\\n\"):\n sys.stdout.write(\"\\n\")\n\n sys.stdout.flush()\n\n @staticmethod\n def _stream_create(stream: Stream[Completion]) -> None:\n for completion in stream:\n should_print_header = len(completion.choices) > 1\n for choice in sorted(completion.choices, key=lambda c: c.index):\n if should_print_header:\n sys.stdout.write(\"===== Chat Completion {} =====\\n\".format(choice.index))\n\n sys.stdout.write(choice.text)\n\n if should_print_header:\n sys.stdout.write(\"\\n\")\n\n sys.stdout.flush()\n\n sys.stdout.write(\"\\n\")\n\n# Path: src/openai/cli/_api/files.py\nfrom __future__ import 
annotations\n\nfrom typing import TYPE_CHECKING, Any, cast\nfrom argparse import ArgumentParser\n\nfrom .._utils import get_client, print_model\nfrom .._models import BaseModel\nfrom .._progress import BufferReader\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"files.create\")\n\n sub.add_argument(\n \"-f\",\n \"--file\",\n required=True,\n help=\"File to upload\",\n )\n sub.add_argument(\n \"-p\",\n \"--purpose\",\n help=\"Why are you uploading this file? (see https://platform.openai.com/docs/api-reference/ for purposes)\",\n required=True,\n )\n sub.set_defaults(func=CLIFile.create, args_model=CLIFileCreateArgs)\n\n sub = subparser.add_parser(\"files.retrieve\")\n sub.add_argument(\"-i\", \"--id\", required=True, help=\"The files ID\")\n sub.set_defaults(func=CLIFile.get, args_model=CLIFileCreateArgs)\n\n sub = subparser.add_parser(\"files.delete\")\n sub.add_argument(\"-i\", \"--id\", required=True, help=\"The files ID\")\n sub.set_defaults(func=CLIFile.delete, args_model=CLIFileCreateArgs)\n\n sub = subparser.add_parser(\"files.list\")\n sub.set_defaults(func=CLIFile.list)\n\n\nclass CLIFileIDArgs(BaseModel):\n id: str\n\n\nclass CLIFileCreateArgs(BaseModel):\n file: str\n purpose: str\n\n\nclass CLIFile:\n @staticmethod\n def create(args: CLIFileCreateArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload progress\")\n\n file = get_client().files.create(\n file=(args.file, buffer_reader),\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n purpose=cast(Any, args.purpose),\n )\n print_model(file)\n\n @staticmethod\n def get(args: CLIFileIDArgs) -> None:\n file = get_client().files.retrieve(file_id=args.id)\n print_model(file)\n\n @staticmethod\n def delete(args: CLIFileIDArgs) -> None:\n file = get_client().files.delete(file_id=args.id)\n print_model(file)\n\n @staticmethod\n def list() -> None:\n files = get_client().files.list()\n for file in files:\n print_model(file)\n\n# Path: src/openai/cli/_api/image.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, cast\nfrom argparse import ArgumentParser\n\nfrom .._utils import get_client, print_model\nfrom ..._types import NOT_GIVEN, NotGiven, NotGivenOr\nfrom .._models import BaseModel\nfrom .._progress import BufferReader\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"images.generate\")\n sub.add_argument(\"-m\", \"--model\", type=str)\n sub.add_argument(\"-p\", \"--prompt\", type=str, required=True)\n sub.add_argument(\"-n\", \"--num-images\", type=int, default=1)\n sub.add_argument(\"-s\", \"--size\", type=str, default=\"1024x1024\", help=\"Size of the output image\")\n sub.add_argument(\"--response-format\", type=str, default=\"url\")\n sub.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)\n\n sub = subparser.add_parser(\"images.edit\")\n sub.add_argument(\"-m\", \"--model\", type=str)\n sub.add_argument(\"-p\", \"--prompt\", type=str, required=True)\n sub.add_argument(\"-n\", \"--num-images\", type=int, default=1)\n sub.add_argument(\n \"-I\",\n \"--image\",\n type=str,\n required=True,\n help=\"Image to modify. 
Should be a local path and a PNG encoded image.\",\n )\n sub.add_argument(\"-s\", \"--size\", type=str, default=\"1024x1024\", help=\"Size of the output image\")\n sub.add_argument(\"--response-format\", type=str, default=\"url\")\n sub.add_argument(\n \"-M\",\n \"--mask\",\n type=str,\n required=False,\n help=\"Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.\",\n )\n sub.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)\n\n sub = subparser.add_parser(\"images.create_variation\")\n sub.add_argument(\"-m\", \"--model\", type=str)\n sub.add_argument(\"-n\", \"--num-images\", type=int, default=1)\n sub.add_argument(\n \"-I\",\n \"--image\",\n type=str,\n required=True,\n help=\"Image to modify. Should be a local path and a PNG encoded image.\",\n )\n sub.add_argument(\"-s\", \"--size\", type=str, default=\"1024x1024\", help=\"Size of the output image\")\n sub.add_argument(\"--response-format\", type=str, default=\"url\")\n sub.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)\n\n\nclass CLIImageCreateArgs(BaseModel):\n prompt: str\n num_images: int\n size: str\n response_format: str\n model: NotGivenOr[str] = NOT_GIVEN\n\n\nclass CLIImageCreateVariationArgs(BaseModel):\n image: str\n num_images: int\n size: str\n response_format: str\n model: NotGivenOr[str] = NOT_GIVEN\n\n\nclass CLIImageEditArgs(BaseModel):\n image: str\n num_images: int\n size: str\n response_format: str\n prompt: str\n mask: NotGivenOr[str] = NOT_GIVEN\n model: NotGivenOr[str] = NOT_GIVEN\n\n\nclass CLIImage:\n @staticmethod\n def create(args: CLIImageCreateArgs) -> None:\n image = get_client().images.generate(\n model=args.model,\n prompt=args.prompt,\n n=args.num_images,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n size=cast(Any, args.size),\n response_format=cast(Any, args.response_format),\n )\n print_model(image)\n\n @staticmethod\n def create_variation(args: CLIImageCreateVariationArgs) -> None:\n with open(args.image, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload progress\")\n\n image = get_client().images.create_variation(\n model=args.model,\n image=(\"image\", buffer_reader),\n n=args.num_images,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n size=cast(Any, args.size),\n response_format=cast(Any, args.response_format),\n )\n print_model(image)\n\n @staticmethod\n def edit(args: CLIImageEditArgs) -> None:\n with open(args.image, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Image upload progress\")\n\n if isinstance(args.mask, NotGiven):\n mask: NotGivenOr[BufferReader] = NOT_GIVEN\n else:\n with open(args.mask, \"rb\") as file_reader:\n mask = BufferReader(file_reader.read(), desc=\"Mask progress\")\n\n image = get_client().images.edit(\n model=args.model,\n prompt=args.prompt,\n image=(\"image\", buffer_reader),\n n=args.num_images,\n mask=(\"mask\", mask) if not isinstance(mask, NotGiven) else mask,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n size=cast(Any, args.size),\n response_format=cast(Any, args.response_format),\n )\n print_model(image)\n\n# Path: src/openai/cli/_api/models.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom argparse import 
ArgumentParser\n\nfrom .._utils import get_client, print_model\nfrom .._models import BaseModel\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"models.list\")\n sub.set_defaults(func=CLIModels.list)\n\n sub = subparser.add_parser(\"models.retrieve\")\n sub.add_argument(\"-i\", \"--id\", required=True, help=\"The model ID\")\n sub.set_defaults(func=CLIModels.get, args_model=CLIModelIDArgs)\n\n sub = subparser.add_parser(\"models.delete\")\n sub.add_argument(\"-i\", \"--id\", required=True, help=\"The model ID\")\n sub.set_defaults(func=CLIModels.delete, args_model=CLIModelIDArgs)\n\n\nclass CLIModelIDArgs(BaseModel):\n id: str\n\n\nclass CLIModels:\n @staticmethod\n def get(args: CLIModelIDArgs) -> None:\n model = get_client().models.retrieve(model=args.id)\n print_model(model)\n\n @staticmethod\n def delete(args: CLIModelIDArgs) -> None:\n model = get_client().models.delete(model=args.id)\n print_model(model)\n\n @staticmethod\n def list() -> None:\n models = get_client().models.list()\n for model in models:\n print_model(model)\n\n# Path: src/openai/cli/_api/_main.py\nfrom __future__ import annotations\n\nfrom argparse import ArgumentParser\n\nfrom . import chat, audio, files, image, models, completions\n\n\ndef register_commands(parser: ArgumentParser) -> None:\n subparsers = parser.add_subparsers(help=\"All API subcommands\")\n\n chat.register(subparsers)\n image.register(subparsers)\n audio.register(subparsers)\n files.register(subparsers)\n models.register(subparsers)\n completions.register(subparsers)\n\n# Path: src/openai/cli/_api/__init__.py\nfrom ._main import register_commands as register_commands\n\n# Path: src/openai/cli/_tools/fine_tunes.py\nfrom __future__ import annotations\n\nimport sys\nfrom typing import TYPE_CHECKING\nfrom argparse import ArgumentParser\n\nfrom .._models import BaseModel\nfrom ...lib._validators import (\n get_validators,\n write_out_file,\n read_any_format,\n apply_validators,\n apply_necessary_remediation,\n)\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"fine_tunes.prepare_data\")\n sub.add_argument(\n \"-f\",\n \"--file\",\n required=True,\n help=\"JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed.\"\n \"This should be the local file path.\",\n )\n sub.add_argument(\n \"-q\",\n \"--quiet\",\n required=False,\n action=\"store_true\",\n help=\"Auto accepts all suggestions, without asking for user input. 
To be used within scripts.\",\n )\n sub.set_defaults(func=prepare_data, args_model=PrepareDataArgs)\n\n\nclass PrepareDataArgs(BaseModel):\n file: str\n\n quiet: bool\n\n\ndef prepare_data(args: PrepareDataArgs) -> None:\n sys.stdout.write(\"Analyzing...\\n\")\n fname = args.file\n auto_accept = args.quiet\n df, remediation = read_any_format(fname)\n apply_necessary_remediation(None, remediation)\n\n validators = get_validators()\n\n assert df is not None\n\n apply_validators(\n df,\n fname,\n remediation,\n validators,\n auto_accept,\n write_out_file_func=write_out_file,\n )\n\n# Path: src/openai/cli/_tools/migrate.py\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport json\nimport shutil\nimport tarfile\nimport platform\nimport subprocess\nfrom typing import TYPE_CHECKING, List\nfrom pathlib import Path\nfrom argparse import ArgumentParser\n\nimport httpx\n\nfrom .._errors import CLIError, SilentCLIError\nfrom .._models import BaseModel\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"migrate\")\n sub.set_defaults(func=migrate, args_model=MigrateArgs, allow_unknown_args=True)\n\n sub = subparser.add_parser(\"grit\")\n sub.set_defaults(func=grit, args_model=GritArgs, allow_unknown_args=True)\n\n\nclass GritArgs(BaseModel):\n # internal\n unknown_args: List[str] = []\n\n\ndef grit(args: GritArgs) -> None:\n grit_path = install()\n\n try:\n subprocess.check_call([grit_path, *args.unknown_args])\n except subprocess.CalledProcessError:\n # stdout and stderr are forwarded by subprocess so an error will already\n # have been displayed\n raise SilentCLIError() from None\n\n\nclass MigrateArgs(BaseModel):\n # internal\n unknown_args: List[str] = []\n\n\ndef migrate(args: MigrateArgs) -> None:\n grit_path = install()\n\n try:\n subprocess.check_call([grit_path, \"apply\", \"openai\", *args.unknown_args])\n except subprocess.CalledProcessError:\n # stdout and stderr are forwarded by subprocess so an error will already\n # have been displayed\n raise SilentCLIError() from None\n\n\n# handles downloading the Grit CLI until they provide their own PyPi package\n\nKEYGEN_ACCOUNT = \"custodian-dev\"\n\n\ndef _cache_dir() -> Path:\n xdg = os.environ.get(\"XDG_CACHE_HOME\")\n if xdg is not None:\n return Path(xdg)\n\n return Path.home() / \".cache\"\n\n\ndef _debug(message: str) -> None:\n if not os.environ.get(\"DEBUG\"):\n return\n\n sys.stdout.write(f\"[DEBUG]: {message}\\n\")\n\n\ndef install() -> Path:\n \"\"\"Installs the Grit CLI and returns the location of the binary\"\"\"\n if sys.platform == \"win32\":\n raise CLIError(\"Windows is not supported yet in the migration CLI\")\n\n platform = \"macos\" if sys.platform == \"darwin\" else \"linux\"\n\n dir_name = _cache_dir() / \"openai-python\"\n install_dir = dir_name / \".install\"\n target_dir = install_dir / \"bin\"\n\n target_path = target_dir / \"marzano\"\n temp_file = target_dir / \"marzano.tmp\"\n\n if target_path.exists():\n _debug(f\"{target_path} already exists\")\n sys.stdout.flush()\n return target_path\n\n _debug(f\"Using Grit CLI path: {target_path}\")\n\n target_dir.mkdir(parents=True, exist_ok=True)\n\n if temp_file.exists():\n temp_file.unlink()\n\n arch = _get_arch()\n _debug(f\"Using architecture {arch}\")\n\n file_name = f\"marzano-{platform}-{arch}\"\n meta_url = f\"https://api.keygen.sh/v1/accounts/{KEYGEN_ACCOUNT}/artifacts/{file_name}\"\n\n sys.stdout.write(f\"Retrieving Grit CLI metadata from 
{meta_url}\\n\")\n with httpx.Client() as client:\n response = client.get(meta_url) # pyright: ignore[reportUnknownMemberType]\n\n data = response.json()\n errors = data.get(\"errors\")\n if errors:\n for error in errors:\n sys.stdout.write(f\"{error}\\n\")\n\n raise CLIError(\"Could not locate Grit CLI binary - see above errors\")\n\n write_manifest(install_dir, data[\"data\"][\"relationships\"][\"release\"][\"data\"][\"id\"])\n\n link = data[\"data\"][\"links\"][\"redirect\"]\n _debug(f\"Redirect URL {link}\")\n\n download_response = client.get(link) # pyright: ignore[reportUnknownMemberType]\n with open(temp_file, \"wb\") as file:\n for chunk in download_response.iter_bytes():\n file.write(chunk)\n\n unpacked_dir = target_dir / \"cli-bin\"\n unpacked_dir.mkdir(parents=True, exist_ok=True)\n\n with tarfile.open(temp_file, \"r:gz\") as archive:\n archive.extractall(unpacked_dir, filter=\"data\")\n\n for item in unpacked_dir.iterdir():\n item.rename(target_dir / item.name)\n\n shutil.rmtree(unpacked_dir)\n os.remove(temp_file)\n os.chmod(target_path, 0o755)\n\n sys.stdout.flush()\n\n return target_path\n\n\ndef _get_arch() -> str:\n architecture = platform.machine().lower()\n\n # Map the architecture names to Node.js equivalents\n arch_map = {\n \"x86_64\": \"x64\",\n \"amd64\": \"x64\",\n \"armv7l\": \"arm\",\n \"aarch64\": \"arm64\",\n }\n\n return arch_map.get(architecture, architecture)\n\n\ndef write_manifest(install_path: Path, release: str) -> None:\n manifest = {\n \"installPath\": str(install_path),\n \"binaries\": {\n \"marzano\": {\n \"name\": \"marzano\",\n \"release\": release,\n },\n },\n }\n manifest_path = Path(install_path) / \"manifests.json\"\n with open(manifest_path, \"w\") as f:\n json.dump(manifest, f, indent=2)\n\n# Path: src/openai/cli/_tools/_main.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom argparse import ArgumentParser\n\nfrom . import migrate, fine_tunes\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register_commands(parser: ArgumentParser, subparser: _SubParsersAction[ArgumentParser]) -> None:\n migrate.register(subparser)\n\n namespaced = parser.add_subparsers(title=\"Tools\", help=\"Convenience client side tools\")\n\n fine_tunes.register(namespaced)\n\n# Path: src/openai/cli/_tools/__init__.py\nfrom ._main import register_commands as register_commands\n\n# Path: src/openai/cli/_cli.py\nfrom __future__ import annotations\n\nimport sys\nimport logging\nimport argparse\nfrom typing import Any, List, Type, Optional\nfrom typing_extensions import ClassVar\n\nimport httpx\nimport pydantic\n\nimport openai\n\nfrom . import _tools\nfrom .. 
import _ApiType, __version__\nfrom ._api import register_commands\nfrom ._utils import can_use_http2\nfrom .._types import ProxiesDict\nfrom ._errors import CLIError, display_error\nfrom .._compat import PYDANTIC_V2, ConfigDict, model_parse\nfrom .._models import BaseModel\nfrom .._exceptions import APIError\n\nlogger = logging.getLogger()\nformatter = logging.Formatter(\"[%(asctime)s] %(message)s\")\nhandler = logging.StreamHandler(sys.stderr)\nhandler.setFormatter(formatter)\nlogger.addHandler(handler)\n\n\nclass Arguments(BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(\n extra=\"ignore\",\n )\n else:\n\n class Config(pydantic.BaseConfig): # type: ignore\n extra: Any = pydantic.Extra.ignore # type: ignore\n\n verbosity: int\n version: Optional[str] = None\n\n api_key: Optional[str]\n api_base: Optional[str]\n organization: Optional[str]\n proxy: Optional[List[str]]\n api_type: Optional[_ApiType] = None\n api_version: Optional[str] = None\n\n # azure\n azure_endpoint: Optional[str] = None\n azure_ad_token: Optional[str] = None\n\n # internal, set by subparsers to parse their specific args\n args_model: Optional[Type[BaseModel]] = None\n\n # internal, used so that subparsers can forward unknown arguments\n unknown_args: List[str] = []\n allow_unknown_args: bool = False\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n parser = argparse.ArgumentParser(description=None, prog=\"openai\")\n parser.add_argument(\n \"-v\",\n \"--verbose\",\n action=\"count\",\n dest=\"verbosity\",\n default=0,\n help=\"Set verbosity.\",\n )\n parser.add_argument(\"-b\", \"--api-base\", help=\"What API base url to use.\")\n parser.add_argument(\"-k\", \"--api-key\", help=\"What API key to use.\")\n parser.add_argument(\"-p\", \"--proxy\", nargs=\"+\", help=\"What proxy to use.\")\n parser.add_argument(\n \"-o\",\n \"--organization\",\n help=\"Which organization to run as (will use your default organization if not specified)\",\n )\n parser.add_argument(\n \"-t\",\n \"--api-type\",\n type=str,\n choices=(\"openai\", \"azure\"),\n help=\"The backend API to call, must be `openai` or `azure`\",\n )\n parser.add_argument(\n \"--api-version\",\n help=\"The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'\",\n )\n\n # azure\n parser.add_argument(\n \"--azure-endpoint\",\n help=\"The Azure endpoint, e.g. 
'https://endpoint.openai.azure.com'\",\n )\n parser.add_argument(\n \"--azure-ad-token\",\n help=\"A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id\",\n )\n\n # prints the package version\n parser.add_argument(\n \"-V\",\n \"--version\",\n action=\"version\",\n version=\"%(prog)s \" + __version__,\n )\n\n def help() -> None:\n parser.print_help()\n\n parser.set_defaults(func=help)\n\n subparsers = parser.add_subparsers()\n sub_api = subparsers.add_parser(\"api\", help=\"Direct API calls\")\n\n register_commands(sub_api)\n\n sub_tools = subparsers.add_parser(\"tools\", help=\"Client side tools for convenience\")\n _tools.register_commands(sub_tools, subparsers)\n\n return parser\n\n\ndef main() -> int:\n try:\n _main()\n except (APIError, CLIError, pydantic.ValidationError) as err:\n display_error(err)\n return 1\n except KeyboardInterrupt:\n sys.stderr.write(\"\\n\")\n return 1\n return 0\n\n\ndef _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:\n # argparse by default will strip out the `--` but we want to keep it for unknown arguments\n if \"--\" in sys.argv:\n idx = sys.argv.index(\"--\")\n known_args = sys.argv[1:idx]\n unknown_args = sys.argv[idx:]\n else:\n known_args = sys.argv[1:]\n unknown_args = []\n\n parsed, remaining_unknown = parser.parse_known_args(known_args)\n\n # append any remaining unknown arguments from the initial parsing\n remaining_unknown.extend(unknown_args)\n\n args = model_parse(Arguments, vars(parsed))\n if not args.allow_unknown_args:\n # we have to parse twice to ensure any unknown arguments\n # result in an error if that behaviour is desired\n parser.parse_args()\n\n return parsed, args, remaining_unknown\n\n\ndef _main() -> None:\n parser = _build_parser()\n parsed, args, unknown = _parse_args(parser)\n\n if args.verbosity != 0:\n sys.stderr.write(\"Warning: --verbosity isn't supported yet\\n\")\n\n proxies: ProxiesDict = {}\n if args.proxy is not None:\n for proxy in args.proxy:\n key = \"https://\" if proxy.startswith(\"https\") else \"http://\"\n if key in proxies:\n raise CLIError(f\"Multiple {key} proxies given - only the last one would be used\")\n\n proxies[key] = proxy\n\n http_client = httpx.Client(\n proxies=proxies or None,\n http2=can_use_http2(),\n )\n openai.http_client = http_client\n\n if args.organization:\n openai.organization = args.organization\n\n if args.api_key:\n openai.api_key = args.api_key\n\n if args.api_base:\n openai.base_url = args.api_base\n\n # azure\n if args.api_type is not None:\n openai.api_type = args.api_type\n\n if args.azure_endpoint is not None:\n openai.azure_endpoint = args.azure_endpoint\n\n if args.api_version is not None:\n openai.api_version = args.api_version\n\n if args.azure_ad_token is not None:\n openai.azure_ad_token = args.azure_ad_token\n\n try:\n if args.args_model:\n parsed.func(\n model_parse(\n args.args_model,\n {\n **{\n # we omit None values so that they can be defaulted to `NotGiven`\n # and we'll strip it from the API request\n key: value\n for key, value in vars(parsed).items()\n if value is not None\n },\n \"unknown_args\": unknown,\n },\n )\n )\n else:\n parsed.func()\n finally:\n try:\n http_client.close()\n except Exception:\n pass\n\n\nif __name__ == \"__main__\":\n sys.exit(main())\n\n# Path: src/openai/cli/__init__.py\nfrom ._cli import main as main\n\n# Path: src/openai/__main__.py\nfrom .cli import main\n\nmain()\n\n# Path: src/openai/types/beta/thread.py\n# File 
generated from our OpenAPI spec by Stainless.\n\nfrom typing import Optional\nfrom typing_extensions import Literal\n\nfrom ..._models import BaseModel\n\n__all__ = [\"Thread\"]\n\n\nclass Thread(BaseModel):\n id: str\n \"\"\"The identifier, which can be referenced in API endpoints.\"\"\"\n\n created_at: int\n \"\"\"The Unix timestamp (in seconds) for when the thread was created.\"\"\"\n\n metadata: Optional[object] = None\n \"\"\"Set of 16 key-value pairs that can be attached to an object.\n\n This can be useful for storing additional information about the object in a\n structured format. Keys can be a maximum of 64 characters long and values can be\n a maxium of 512 characters long.\n \"\"\"\n\n object: Literal[\"thread\"]\n \"\"\"The object type, which is always `thread`.\"\"\"\n\n# Path: src/openai/version.py\nfrom ._version import __version__\n\nVERSION: str = __version__\n\n# Path: src/openai/lib/_old_api.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any\nfrom typing_extensions import override\n\nfrom .._utils import LazyProxy\nfrom .._exceptions import OpenAIError\n\nINSTRUCTIONS = \"\"\"\n\nYou tried to access openai.{symbol}, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.\n\nYou can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface. \n\nAlternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`\n\nA detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742\n\"\"\"\n\n\nclass APIRemovedInV1(OpenAIError):\n def __init__(self, *, symbol: str) -> None:\n super().__init__(INSTRUCTIONS.format(symbol=symbol))\n\n\nclass APIRemovedInV1Proxy(LazyProxy[Any]):\n def __init__(self, *, symbol: str) -> None:\n super().__init__()\n self._symbol = symbol\n\n @override\n def __load__(self) -> Any:\n # return the proxy until it is eventually called so that\n # we don't break people that are just checking the attributes\n # of a module\n return self\n\n...\n", "needle_token_start": 3867, "needle_token_end": 4060, "code_context_ntokens": 15951, "output": ["```python\nclass CLIAudio:\n @staticmethod\n def transcribe(args: CLITranscribeArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload progress\")\n\n model = get_client().audio.transcriptions.create(\n file=(args.file, buffer_reader),\n model=args.model,\n language=args.language or NOT_GIVEN,\n temperature=args.temperature or NOT_GIVEN,\n prompt=args.prompt or NOT_GIVEN,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n response_format=cast(Any, args.response_format),\n )\n print_model(model)\n\n @staticmethod\n def translate(args: CLITranslationArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload progress\")\n\n model = get_client().audio.translations.create(\n file=(args.file, buffer_reader),\n model=args.model,\n temperature=args.temperature or NOT_GIVEN,\n prompt=args.prompt or NOT_GIVEN,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n response_format=cast(Any, args.response_format),\n )\n print_model(model)\n```"]} +{"repo": "openai/openai-python", "name": "make_request_options", "language": "python", "path": "src/openai/_base_client.py", 
"position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To construct a dictionary of request options for HTTP operations, filtering out any parameters that are not provided or are marked as not given.\n2. **Input**: Optional parameters including query parameters, additional headers, extra query parameters, extra body content, an idempotency key, a timeout setting, and a post-parser function. Each of these can be omitted or marked as not given.\n3. **Output**: A dictionary tailored for configuring HTTP requests, containing only the parameters that were explicitly provided and valid.\n4. **Procedure**: The function initializes an empty dictionary and conditionally adds entries based on the presence and validity of input parameters. It merges additional query parameters with base query parameters if both are provided, and ensures that no entries for parameters marked as not given are included in the output dictionary.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: Literal[False] = False,\n ) -> ResponseT:\n ...\n\n @overload\n def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: Literal[True],\n stream_cls: type[_StreamT],\n ) -> _StreamT:\n ...\n\n @overload\n def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: bool,\n stream_cls: type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n ...\n\n def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: bool = False,\n stream_cls: type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n opts = FinalRequestOptions.construct(\n method=\"post\", url=path, json_data=body, files=to_httpx_files(files), **options\n )\n return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))\n\n def patch(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n ) -> ResponseT:\n opts = FinalRequestOptions.construct(method=\"patch\", url=path, json_data=body, **options)\n return self.request(cast_to, opts)\n\n def put(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n files: RequestFiles | None = None,\n options: RequestOptions = {},\n ) -> ResponseT:\n opts = FinalRequestOptions.construct(\n method=\"put\", url=path, json_data=body, files=to_httpx_files(files), **options\n )\n return self.request(cast_to, opts)\n\n def delete(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n ) -> ResponseT:\n opts = FinalRequestOptions.construct(method=\"delete\", url=path, json_data=body, **options)\n return self.request(cast_to, opts)\n\n def get_api_list(\n self,\n path: str,\n *,\n model: Type[object],\n page: Type[SyncPageT],\n body: Body | None = None,\n options: RequestOptions = {},\n method: str = \"get\",\n ) -> 
SyncPageT:\n opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)\n return self._request_api_list(model, page, opts)\n\n\nclass AsyncHttpxClientWrapper(httpx.AsyncClient):\n def __del__(self) -> None:\n try:\n # TODO(someday): support non asyncio runtimes here\n asyncio.get_running_loop().create_task(self.aclose())\n except Exception:\n pass\n\n\nclass AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):\n _client: httpx.AsyncClient\n _default_stream_cls: type[AsyncStream[Any]] | None = None\n\n def __init__(\n self,\n *,\n version: str,\n base_url: str | URL,\n _strict_response_validation: bool,\n max_retries: int = DEFAULT_MAX_RETRIES,\n timeout: float | Timeout | None | NotGiven = NOT_GIVEN,\n transport: AsyncTransport | None = None,\n proxies: ProxiesTypes | None = None,\n limits: Limits | None = None,\n http_client: httpx.AsyncClient | None = None,\n custom_headers: Mapping[str, str] | None = None,\n custom_query: Mapping[str, object] | None = None,\n ) -> None:\n if limits is not None:\n warnings.warn(\n \"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead\",\n category=DeprecationWarning,\n stacklevel=3,\n )\n if http_client is not None:\n raise ValueError(\"The `http_client` argument is mutually exclusive with `connection_pool_limits`\")\n else:\n limits = DEFAULT_LIMITS\n\n if transport is not None:\n warnings.warn(\n \"The `transport` argument is deprecated. The `http_client` argument should be passed instead\",\n category=DeprecationWarning,\n stacklevel=3,\n )\n if http_client is not None:\n raise ValueError(\"The `http_client` argument is mutually exclusive with `transport`\")\n\n if proxies is not None:\n warnings.warn(\n \"The `proxies` argument is deprecated. 
The `http_client` argument should be passed instead\",\n category=DeprecationWarning,\n stacklevel=3,\n )\n if http_client is not None:\n raise ValueError(\"The `http_client` argument is mutually exclusive with `proxies`\")\n\n if not is_given(timeout):\n # if the user passed in a custom http client with a non-default\n # timeout set then we use that timeout.\n #\n # note: there is an edge case here where the user passes in a client\n # where they've explicitly set the timeout to match the default timeout\n # as this check is structural, meaning that we'll think they didn't\n # pass in a timeout and will ignore it\n if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT:\n timeout = http_client.timeout\n else:\n timeout = DEFAULT_TIMEOUT\n\n super().__init__(\n version=version,\n base_url=base_url,\n limits=limits,\n # cast to a valid type because mypy doesn't understand our type narrowing\n timeout=cast(Timeout, timeout),\n proxies=proxies,\n transport=transport,\n max_retries=max_retries,\n custom_query=custom_query,\n custom_headers=custom_headers,\n _strict_response_validation=_strict_response_validation,\n )\n self._client = http_client or AsyncHttpxClientWrapper(\n base_url=base_url,\n # cast to a valid type because mypy doesn't understand our type narrowing\n timeout=cast(Timeout, timeout),\n proxies=proxies,\n transport=transport,\n limits=limits,\n follow_redirects=True,\n )\n\n def is_closed(self) -> bool:\n return self._client.is_closed\n\n async def close(self) -> None:\n \"\"\"Close the underlying HTTPX client.\n\n The client will *not* be usable after this.\n \"\"\"\n await self._client.aclose()\n\n async def __aenter__(self: _T) -> _T:\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n await self.close()\n\n async def _prepare_options(\n self,\n options: FinalRequestOptions, # noqa: ARG002\n ) -> None:\n \"\"\"Hook for mutating the given options\"\"\"\n return None\n\n async def _prepare_request(\n self,\n request: httpx.Request, # noqa: ARG002\n ) -> None:\n \"\"\"This method is used as a callback for mutating the `Request` object\n after it has been constructed.\n This is useful for cases where you want to add certain headers based off of\n the request properties, e.g. 
`url`, `method` etc.\n \"\"\"\n return None\n\n @overload\n async def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n *,\n stream: Literal[False] = False,\n remaining_retries: Optional[int] = None,\n ) -> ResponseT:\n ...\n\n @overload\n async def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n *,\n stream: Literal[True],\n stream_cls: type[_AsyncStreamT],\n remaining_retries: Optional[int] = None,\n ) -> _AsyncStreamT:\n ...\n\n @overload\n async def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n *,\n stream: bool,\n stream_cls: type[_AsyncStreamT] | None = None,\n remaining_retries: Optional[int] = None,\n ) -> ResponseT | _AsyncStreamT:\n ...\n\n async def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n *,\n stream: bool = False,\n stream_cls: type[_AsyncStreamT] | None = None,\n remaining_retries: Optional[int] = None,\n ) -> ResponseT | _AsyncStreamT:\n return await self._request(\n cast_to=cast_to,\n options=options,\n stream=stream,\n stream_cls=stream_cls,\n remaining_retries=remaining_retries,\n )\n\n async def _request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n *,\n stream: bool,\n stream_cls: type[_AsyncStreamT] | None,\n remaining_retries: int | None,\n ) -> ResponseT | _AsyncStreamT:\n cast_to = self._maybe_override_cast_to(cast_to, options)\n await self._prepare_options(options)\n\n retries = self._remaining_retries(remaining_retries, options)\n request = self._build_request(options)\n await self._prepare_request(request)\n\n kwargs: HttpxSendArgs = {}\n if self.custom_auth is not None:\n kwargs[\"auth\"] = self.custom_auth\n\n try:\n response = await self._client.send(\n request,\n stream=stream or self._should_stream_response_body(request=request),\n **kwargs,\n )\n except httpx.TimeoutException as err:\n log.debug(\"Encountered httpx.TimeoutException\", exc_info=True)\n\n if retries > 0:\n return await self._retry_request(\n options,\n cast_to,\n retries,\n stream=stream,\n stream_cls=stream_cls,\n response_headers=None,\n )\n\n log.debug(\"Raising timeout error\")\n raise APITimeoutError(request=request) from err\n except Exception as err:\n log.debug(\"Encountered Exception\", exc_info=True)\n\n if retries > 0:\n return await self._retry_request(\n options,\n cast_to,\n retries,\n stream=stream,\n stream_cls=stream_cls,\n response_headers=None,\n )\n\n log.debug(\"Raising connection error\")\n raise APIConnectionError(request=request) from err\n\n log.debug(\n 'HTTP Request: %s %s \"%i %s\"', request.method, request.url, response.status_code, response.reason_phrase\n )\n\n try:\n response.raise_for_status()\n except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code\n log.debug(\"Encountered httpx.HTTPStatusError\", exc_info=True)\n\n if retries > 0 and self._should_retry(err.response):\n await err.response.aclose()\n return await self._retry_request(\n options,\n cast_to,\n retries,\n err.response.headers,\n stream=stream,\n stream_cls=stream_cls,\n )\n\n # If the response is streamed then we need to explicitly read the response\n # to completion before attempting to access the response text.\n if not err.response.is_closed:\n await err.response.aread()\n\n log.debug(\"Re-raising status error\")\n raise self._make_status_error_from_response(err.response) from None\n\n return await self._process_response(\n cast_to=cast_to,\n options=options,\n response=response,\n stream=stream,\n stream_cls=stream_cls,\n 
)\n\n async def _retry_request(\n self,\n options: FinalRequestOptions,\n cast_to: Type[ResponseT],\n remaining_retries: int,\n response_headers: httpx.Headers | None,\n *,\n stream: bool,\n stream_cls: type[_AsyncStreamT] | None,\n ) -> ResponseT | _AsyncStreamT:\n remaining = remaining_retries - 1\n if remaining == 1:\n log.debug(\"1 retry left\")\n else:\n log.debug(\"%i retries left\", remaining)\n\n timeout = self._calculate_retry_timeout(remaining, options, response_headers)\n log.info(\"Retrying request to %s in %f seconds\", options.url, timeout)\n\n await anyio.sleep(timeout)\n\n return await self._request(\n options=options,\n cast_to=cast_to,\n remaining_retries=remaining,\n stream=stream,\n stream_cls=stream_cls,\n )\n\n async def _process_response(\n self,\n *,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n response: httpx.Response,\n stream: bool,\n stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,\n ) -> ResponseT:\n if response.request.headers.get(RAW_RESPONSE_HEADER) == \"true\":\n return cast(\n ResponseT,\n LegacyAPIResponse(\n raw=response,\n client=self,\n cast_to=cast_to,\n stream=stream,\n stream_cls=stream_cls,\n options=options,\n ),\n )\n\n origin = get_origin(cast_to) or cast_to\n\n if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):\n if not issubclass(origin, AsyncAPIResponse):\n raise TypeError(f\"API Response types must subclass {AsyncAPIResponse}; Received {origin}\")\n\n response_cls = cast(\"type[BaseAPIResponse[Any]]\", cast_to)\n return cast(\n \"ResponseT\",\n response_cls(\n raw=response,\n client=self,\n cast_to=extract_response_type(response_cls),\n stream=stream,\n stream_cls=stream_cls,\n options=options,\n ),\n )\n\n if cast_to == httpx.Response:\n return cast(ResponseT, response)\n\n api_response = AsyncAPIResponse(\n raw=response,\n client=self,\n cast_to=cast(\"type[ResponseT]\", cast_to), # pyright: ignore[reportUnnecessaryCast]\n stream=stream,\n stream_cls=stream_cls,\n options=options,\n )\n if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):\n return cast(ResponseT, api_response)\n\n return await api_response.parse()\n\n def _request_api_list(\n self,\n model: Type[_T],\n page: Type[AsyncPageT],\n options: FinalRequestOptions,\n ) -> AsyncPaginator[_T, AsyncPageT]:\n return AsyncPaginator(client=self, options=options, page_cls=page, model=model)\n\n @overload\n async def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: Literal[False] = False,\n ) -> ResponseT:\n ...\n\n @overload\n async def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: Literal[True],\n stream_cls: type[_AsyncStreamT],\n ) -> _AsyncStreamT:\n ...\n\n @overload\n async def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: bool,\n stream_cls: type[_AsyncStreamT] | None = None,\n ) -> ResponseT | _AsyncStreamT:\n ...\n\n async def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: bool = False,\n stream_cls: type[_AsyncStreamT] | None = None,\n ) -> ResponseT | _AsyncStreamT:\n opts = FinalRequestOptions.construct(method=\"get\", url=path, **options)\n return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)\n\n @overload\n async def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n files: RequestFiles | None = None,\n options: RequestOptions = {},\n stream: 
Literal[False] = False,\n ) -> ResponseT:\n ...\n\n @overload\n async def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n files: RequestFiles | None = None,\n options: RequestOptions = {},\n stream: Literal[True],\n stream_cls: type[_AsyncStreamT],\n ) -> _AsyncStreamT:\n ...\n\n @overload\n async def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n files: RequestFiles | None = None,\n options: RequestOptions = {},\n stream: bool,\n stream_cls: type[_AsyncStreamT] | None = None,\n ) -> ResponseT | _AsyncStreamT:\n ...\n\n async def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n files: RequestFiles | None = None,\n options: RequestOptions = {},\n stream: bool = False,\n stream_cls: type[_AsyncStreamT] | None = None,\n ) -> ResponseT | _AsyncStreamT:\n opts = FinalRequestOptions.construct(\n method=\"post\", url=path, json_data=body, files=await async_to_httpx_files(files), **options\n )\n return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)\n\n async def patch(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n ) -> ResponseT:\n opts = FinalRequestOptions.construct(method=\"patch\", url=path, json_data=body, **options)\n return await self.request(cast_to, opts)\n\n async def put(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n files: RequestFiles | None = None,\n options: RequestOptions = {},\n ) -> ResponseT:\n opts = FinalRequestOptions.construct(\n method=\"put\", url=path, json_data=body, files=await async_to_httpx_files(files), **options\n )\n return await self.request(cast_to, opts)\n\n async def delete(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n ) -> ResponseT:\n opts = FinalRequestOptions.construct(method=\"delete\", url=path, json_data=body, **options)\n return await self.request(cast_to, opts)\n\n def get_api_list(\n self,\n path: str,\n *,\n model: Type[_T],\n page: Type[AsyncPageT],\n body: Body | None = None,\n options: RequestOptions = {},\n method: str = \"get\",\n ) -> AsyncPaginator[_T, AsyncPageT]:\n opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)\n return self._request_api_list(model, page, opts)\n\n\n\ndef make_request_options(\n *,\n query: Query | None = None,\n extra_headers: Headers | None = None,\n extra_query: Query | None = None,\n extra_body: Body | None = None,\n idempotency_key: str | None = None,\n timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,\n post_parser: PostParser | NotGiven = NOT_GIVEN,\n) -> RequestOptions:\n \"\"\"Create a dict of type RequestOptions without keys of NotGiven values.\"\"\"\n options: RequestOptions = {}\n if extra_headers is not None:\n options[\"headers\"] = extra_headers\n\n if extra_body is not None:\n options[\"extra_json\"] = cast(AnyMapping, extra_body)\n\n if query is not None:\n options[\"params\"] = query\n\n if extra_query is not None:\n options[\"params\"] = {**options.get(\"params\", {}), **extra_query}\n\n if not isinstance(timeout, NotGiven):\n options[\"timeout\"] = timeout\n\n if idempotency_key is not None:\n options[\"idempotency_key\"] = idempotency_key\n\n if is_given(post_parser):\n # internal\n options[\"post_parser\"] = post_parser # type: ignore\n\n return options\n\n\nclass OtherPlatform:\n def __init__(self, name: str) -> None:\n 
self.name = name\n\n @override\n def __str__(self) -> str:\n return f\"Other:{self.name}\"\n\n\nPlatform = Union[\n OtherPlatform,\n Literal[\n \"MacOS\",\n \"Linux\",\n \"Windows\",\n \"FreeBSD\",\n \"OpenBSD\",\n \"iOS\",\n \"Android\",\n \"Unknown\",\n ],\n]\n\n\ndef get_platform() -> Platform:\n try:\n system = platform.system().lower()\n platform_name = platform.platform().lower()\n except Exception:\n return \"Unknown\"\n\n if \"iphone\" in platform_name or \"ipad\" in platform_name:\n # Tested using Python3IDE on an iPhone 11 and Pythonista on an iPad 7\n # system is Darwin and platform_name is a string like:\n # - Darwin-21.6.0-iPhone12,1-64bit\n # - Darwin-21.6.0-iPad7,11-64bit\n return \"iOS\"\n\n if system == \"darwin\":\n return \"MacOS\"\n\n if system == \"windows\":\n return \"Windows\"\n\n if \"android\" in platform_name:\n # Tested using Pydroid 3\n # system is Linux and platform_name is a string like 'Linux-5.10.81-android12-9-00001-geba40aecb3b7-ab8534902-aarch64-with-libc'\n return \"Android\"\n\n if system == \"linux\":\n # https://distro.readthedocs.io/en/latest/#distro.id\n distro_id = distro.id()\n if distro_id == \"freebsd\":\n return \"FreeBSD\"\n\n if distro_id == \"openbsd\":\n return \"OpenBSD\"\n\n return \"Linux\"\n\n if platform_name:\n return OtherPlatform(platform_name)\n\n return \"Unknown\"\n\n\n@lru_cache(maxsize=None)\ndef platform_headers(version: str) -> Dict[str, str]:\n return {\n \"X-Stainless-Lang\": \"python\",\n \"X-Stainless-Package-Version\": version,\n \"X-Stainless-OS\": str(get_platform()),\n \"X-Stainless-Arch\": str(get_architecture()),\n \"X-Stainless-Runtime\": get_python_runtime(),\n \"X-Stainless-Runtime-Version\": get_python_version(),\n }\n\n\nclass OtherArch:\n def __init__(self, name: str) -> None:\n self.name = name\n\n @override\n def __str__(self) -> str:\n return f\"other:{self.name}\"\n\n\nArch = Union[OtherArch, Literal[\"x32\", \"x64\", \"arm\", \"arm64\", \"unknown\"]]\n\n\ndef get_python_runtime() -> str:\n try:\n return platform.python_implementation()\n except Exception:\n return \"unknown\"\n\n\ndef get_python_version() -> str:\n try:\n return platform.python_version()\n except Exception:\n return \"unknown\"\n\n\ndef get_architecture() -> Arch:\n try:\n python_bitness, _ = platform.architecture()\n machine = platform.machine().lower()\n except Exception:\n return \"unknown\"\n\n if machine in (\"arm64\", \"aarch64\"):\n return \"arm64\"\n\n # TODO: untested\n if machine == \"arm\":\n return \"arm\"\n\n if machine == \"x86_64\":\n return \"x64\"\n\n # TODO: untested\n if python_bitness == \"32bit\":\n return \"x32\"\n\n if machine:\n return OtherArch(machine)\n\n return \"unknown\"\n\n\ndef _merge_mappings(\n obj1: Mapping[_T_co, Union[_T, Omit]],\n obj2: Mapping[_T_co, Union[_T, Omit]],\n) -> Dict[_T_co, _T]:\n \"\"\"Merge two mappings of the same type, removing any values that are instances of `Omit`.\n\n In cases with duplicate keys the second mapping takes precedence.\n \"\"\"\n merged = {**obj1, **obj2}\n return {key: value for key, value in merged.items() if not isinstance(value, Omit)}\n\n# Path: src/openai/_module_client.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing_extensions import override\n\nfrom . 
import resources, _load_client\nfrom ._utils import LazyProxy\n\n\nclass ChatProxy(LazyProxy[resources.Chat]):\n @override\n def __load__(self) -> resources.Chat:\n return _load_client().chat\n\n\nclass BetaProxy(LazyProxy[resources.Beta]):\n @override\n def __load__(self) -> resources.Beta:\n return _load_client().beta\n\n\nclass FilesProxy(LazyProxy[resources.Files]):\n @override\n def __load__(self) -> resources.Files:\n return _load_client().files\n\n\nclass AudioProxy(LazyProxy[resources.Audio]):\n @override\n def __load__(self) -> resources.Audio:\n return _load_client().audio\n\n\nclass ImagesProxy(LazyProxy[resources.Images]):\n @override\n def __load__(self) -> resources.Images:\n return _load_client().images\n\n\nclass ModelsProxy(LazyProxy[resources.Models]):\n @override\n def __load__(self) -> resources.Models:\n return _load_client().models\n\n\nclass EmbeddingsProxy(LazyProxy[resources.Embeddings]):\n @override\n def __load__(self) -> resources.Embeddings:\n return _load_client().embeddings\n\n\nclass CompletionsProxy(LazyProxy[resources.Completions]):\n @override\n def __load__(self) -> resources.Completions:\n return _load_client().completions\n\n\nclass ModerationsProxy(LazyProxy[resources.Moderations]):\n @override\n def __load__(self) -> resources.Moderations:\n return _load_client().moderations\n\n\nclass FineTuningProxy(LazyProxy[resources.FineTuning]):\n @override\n def __load__(self) -> resources.FineTuning:\n return _load_client().fine_tuning\n\n\nchat: resources.Chat = ChatProxy().__as_proxied__()\nbeta: resources.Beta = BetaProxy().__as_proxied__()\nfiles: resources.Files = FilesProxy().__as_proxied__()\naudio: resources.Audio = AudioProxy().__as_proxied__()\nimages: resources.Images = ImagesProxy().__as_proxied__()\nmodels: resources.Models = ModelsProxy().__as_proxied__()\nembeddings: resources.Embeddings = EmbeddingsProxy().__as_proxied__()\ncompletions: resources.Completions = CompletionsProxy().__as_proxied__()\nmoderations: resources.Moderations = ModerationsProxy().__as_proxied__()\nfine_tuning: resources.FineTuning = FineTuningProxy().__as_proxied__()\n\n# Path: src/openai/_utils/_logs.py\nimport os\nimport logging\n\nlogger: logging.Logger = logging.getLogger(\"openai\")\nhttpx_logger: logging.Logger = logging.getLogger(\"httpx\")\n\n\ndef _basic_config() -> None:\n # e.g. [2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar \"200 OK\"\n logging.basicConfig(\n format=\"[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s\",\n datefmt=\"%Y-%m-%d %H:%M:%S\",\n )\n\n\ndef setup_logging() -> None:\n env = os.environ.get(\"OPENAI_LOG\")\n if env == \"debug\":\n _basic_config()\n logger.setLevel(logging.DEBUG)\n httpx_logger.setLevel(logging.DEBUG)\n elif env == \"info\":\n _basic_config()\n logger.setLevel(logging.INFO)\n httpx_logger.setLevel(logging.INFO)\n\n# Path: src/openai/__init__.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nimport os as _os\nfrom typing_extensions import override\n\nfrom . 
import types\nfrom ._types import NoneType, Transport, ProxiesTypes\nfrom ._utils import file_from_path\nfrom ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions\nfrom ._models import BaseModel\nfrom ._version import __title__, __version__\nfrom ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse\nfrom ._exceptions import (\n APIError,\n OpenAIError,\n ConflictError,\n NotFoundError,\n APIStatusError,\n RateLimitError,\n APITimeoutError,\n BadRequestError,\n APIConnectionError,\n AuthenticationError,\n InternalServerError,\n PermissionDeniedError,\n UnprocessableEntityError,\n APIResponseValidationError,\n)\nfrom ._utils._logs import setup_logging as _setup_logging\n\n__all__ = [\n \"types\",\n \"__version__\",\n \"__title__\",\n \"NoneType\",\n \"Transport\",\n \"ProxiesTypes\",\n \"OpenAIError\",\n \"APIError\",\n \"APIStatusError\",\n \"APITimeoutError\",\n \"APIConnectionError\",\n \"APIResponseValidationError\",\n \"BadRequestError\",\n \"AuthenticationError\",\n \"PermissionDeniedError\",\n \"NotFoundError\",\n \"ConflictError\",\n \"UnprocessableEntityError\",\n \"RateLimitError\",\n \"InternalServerError\",\n \"Timeout\",\n \"RequestOptions\",\n \"Client\",\n \"AsyncClient\",\n \"Stream\",\n \"AsyncStream\",\n \"OpenAI\",\n \"AsyncOpenAI\",\n \"file_from_path\",\n \"BaseModel\",\n]\n\nfrom .lib import azure as _azure\nfrom .version import VERSION as VERSION\nfrom .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI\nfrom .lib._old_api import *\n\n_setup_logging()\n\n# Update the __module__ attribute for exported symbols so that\n# error messages point to this module instead of the module\n# it was originally defined in, e.g.\n# openai._exceptions.NotFoundError -> openai.NotFoundError\n__locals = locals()\nfor __name in __all__:\n if not __name.startswith(\"__\"):\n try:\n __locals[__name].__module__ = \"openai\"\n except (TypeError, AttributeError):\n # Some of our exported symbols are builtins which we can't set attributes for.\n pass\n\n# ------ Module level client ------\nimport typing as _t\nimport typing_extensions as _te\n\nimport httpx as _httpx\n\nfrom ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES\n\napi_key: str | None = None\n\norganization: str | None = None\n\nbase_url: str | _httpx.URL | None = None\n\ntimeout: float | Timeout | None = DEFAULT_TIMEOUT\n\nmax_retries: int = DEFAULT_MAX_RETRIES\n\ndefault_headers: _t.Mapping[str, str] | None = None\n\ndefault_query: _t.Mapping[str, object] | None = None\n\nhttp_client: _httpx.Client | None = None\n\n_ApiType = _te.Literal[\"openai\", \"azure\"]\n\napi_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get(\"OPENAI_API_TYPE\"))\n\napi_version: str | None = _os.environ.get(\"OPENAI_API_VERSION\")\n\nazure_endpoint: str | None = _os.environ.get(\"AZURE_OPENAI_ENDPOINT\")\n\nazure_ad_token: str | None = _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\")\n\nazure_ad_token_provider: _azure.AzureADTokenProvider | None = None\n\n\nclass _ModuleClient(OpenAI):\n # Note: we have to use type: ignores here as overriding class members\n # with properties is technically unsafe but it is fine for our use case\n\n @property # type: ignore\n @override\n def api_key(self) -> str | None:\n return api_key\n\n @api_key.setter # type: ignore\n def api_key(self, value: str | None) -> None: # type: ignore\n global api_key\n\n api_key = value\n\n @property # type: ignore\n @override\n def organization(self) -> str | 
None:\n return organization\n\n @organization.setter # type: ignore\n def organization(self, value: str | None) -> None: # type: ignore\n global organization\n\n organization = value\n\n @property\n @override\n def base_url(self) -> _httpx.URL:\n if base_url is not None:\n return _httpx.URL(base_url)\n\n return super().base_url\n\n @base_url.setter\n def base_url(self, url: _httpx.URL | str) -> None:\n super().base_url = url # type: ignore[misc]\n\n @property # type: ignore\n @override\n def timeout(self) -> float | Timeout | None:\n return timeout\n\n @timeout.setter # type: ignore\n def timeout(self, value: float | Timeout | None) -> None: # type: ignore\n global timeout\n\n timeout = value\n\n @property # type: ignore\n @override\n def max_retries(self) -> int:\n return max_retries\n\n @max_retries.setter # type: ignore\n def max_retries(self, value: int) -> None: # type: ignore\n global max_retries\n\n max_retries = value\n\n @property # type: ignore\n @override\n def _custom_headers(self) -> _t.Mapping[str, str] | None:\n return default_headers\n\n @_custom_headers.setter # type: ignore\n def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore\n global default_headers\n\n default_headers = value\n\n @property # type: ignore\n @override\n def _custom_query(self) -> _t.Mapping[str, object] | None:\n return default_query\n\n @_custom_query.setter # type: ignore\n def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore\n global default_query\n\n default_query = value\n\n @property # type: ignore\n @override\n def _client(self) -> _httpx.Client:\n return http_client or super()._client\n\n @_client.setter # type: ignore\n def _client(self, value: _httpx.Client) -> None: # type: ignore\n global http_client\n\n http_client = value\n\n\nclass _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore\n ...\n\n\nclass _AmbiguousModuleClientUsageError(OpenAIError):\n def __init__(self) -> None:\n super().__init__(\n \"Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`\"\n )\n\n\ndef _has_openai_credentials() -> bool:\n return _os.environ.get(\"OPENAI_API_KEY\") is not None\n\n\ndef _has_azure_credentials() -> bool:\n return azure_endpoint is not None or _os.environ.get(\"AZURE_OPENAI_API_KEY\") is not None\n\n\ndef _has_azure_ad_credentials() -> bool:\n return (\n _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\") is not None\n or azure_ad_token is not None\n or azure_ad_token_provider is not None\n )\n\n\n_client: OpenAI | None = None\n\n\ndef _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]\n global _client\n\n if _client is None:\n global api_type, azure_endpoint, azure_ad_token, api_version\n\n if azure_endpoint is None:\n azure_endpoint = _os.environ.get(\"AZURE_OPENAI_ENDPOINT\")\n\n if azure_ad_token is None:\n azure_ad_token = _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\")\n\n if api_version is None:\n api_version = _os.environ.get(\"OPENAI_API_VERSION\")\n\n if api_type is None:\n has_openai = _has_openai_credentials()\n has_azure = _has_azure_credentials()\n has_azure_ad = _has_azure_ad_credentials()\n\n if has_openai and (has_azure or has_azure_ad):\n raise _AmbiguousModuleClientUsageError()\n\n if (azure_ad_token is not None or azure_ad_token_provider is not None) and _os.environ.get(\n \"AZURE_OPENAI_API_KEY\"\n ) is not None:\n raise _AmbiguousModuleClientUsageError()\n\n if has_azure or has_azure_ad:\n api_type = \"azure\"\n else:\n 
api_type = \"openai\"\n\n if api_type == \"azure\":\n _client = _AzureModuleClient( # type: ignore\n api_version=api_version,\n azure_endpoint=azure_endpoint,\n api_key=api_key,\n azure_ad_token=azure_ad_token,\n azure_ad_token_provider=azure_ad_token_provider,\n organization=organization,\n base_url=base_url,\n timeout=timeout,\n max_retries=max_retries,\n default_headers=default_headers,\n default_query=default_query,\n http_client=http_client,\n )\n return _client\n\n _client = _ModuleClient(\n api_key=api_key,\n organization=organization,\n base_url=base_url,\n timeout=timeout,\n max_retries=max_retries,\n default_headers=default_headers,\n default_query=default_query,\n http_client=http_client,\n )\n return _client\n\n return _client\n\n\ndef _reset_client() -> None: # type: ignore[reportUnusedFunction]\n global _client\n\n _client = None\n\n\nfrom ._module_client import (\n beta as beta,\n chat as chat,\n audio as audio,\n files as files,\n images as images,\n models as models,\n embeddings as embeddings,\n completions as completions,\n fine_tuning as fine_tuning,\n moderations as moderations,\n)\n\n# Path: src/openai/cli/_models.py\nfrom typing import Any\nfrom typing_extensions import ClassVar\n\nimport pydantic\n\nfrom .. import _models\nfrom .._compat import PYDANTIC_V2, ConfigDict\n\n\nclass BaseModel(_models.BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(extra=\"ignore\", arbitrary_types_allowed=True)\n else:\n\n class Config(pydantic.BaseConfig): # type: ignore\n extra: Any = pydantic.Extra.ignore # type: ignore\n arbitrary_types_allowed: bool = True\n\n# Path: src/openai/cli/_progress.py\nfrom __future__ import annotations\n\nimport io\nfrom typing import Callable\nfrom typing_extensions import override\n\n\nclass CancelledError(Exception):\n def __init__(self, msg: str) -> None:\n self.msg = msg\n super().__init__(msg)\n\n @override\n def __str__(self) -> str:\n return self.msg\n\n __repr__ = __str__\n\n\nclass BufferReader(io.BytesIO):\n def __init__(self, buf: bytes = b\"\", desc: str | None = None) -> None:\n super().__init__(buf)\n self._len = len(buf)\n self._progress = 0\n self._callback = progress(len(buf), desc=desc)\n\n def __len__(self) -> int:\n return self._len\n\n @override\n def read(self, n: int | None = -1) -> bytes:\n chunk = io.BytesIO.read(self, n)\n self._progress += len(chunk)\n\n try:\n self._callback(self._progress)\n except Exception as e: # catches exception from the callback\n raise CancelledError(\"The upload was cancelled: {}\".format(e)) from e\n\n return chunk\n\n\ndef progress(total: float, desc: str | None) -> Callable[[float], None]:\n import tqdm\n\n meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc)\n\n def incr(progress: float) -> None:\n meter.n = progress\n if progress == total:\n meter.close()\n else:\n meter.refresh()\n\n return incr\n\n\ndef MB(i: int) -> int:\n return int(i // 1024**2)\n\n# Path: src/openai/cli/_utils.py\nfrom __future__ import annotations\n\nimport sys\n\nimport openai\n\nfrom .. 
import OpenAI, _load_client\nfrom .._compat import model_json\nfrom .._models import BaseModel\n\n\nclass Colors:\n HEADER = \"\\033[95m\"\n OKBLUE = \"\\033[94m\"\n OKGREEN = \"\\033[92m\"\n WARNING = \"\\033[93m\"\n FAIL = \"\\033[91m\"\n ENDC = \"\\033[0m\"\n BOLD = \"\\033[1m\"\n UNDERLINE = \"\\033[4m\"\n\n\ndef get_client() -> OpenAI:\n return _load_client()\n\n\ndef organization_info() -> str:\n organization = openai.organization\n if organization is not None:\n return \"[organization={}] \".format(organization)\n\n return \"\"\n\n\ndef print_model(model: BaseModel) -> None:\n sys.stdout.write(model_json(model, indent=2) + \"\\n\")\n\n\ndef can_use_http2() -> bool:\n try:\n import h2 # type: ignore # noqa\n except ImportError:\n return False\n\n return True\n\n# Path: src/openai/cli/_api/audio.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, Optional, cast\nfrom argparse import ArgumentParser\n\nfrom .._utils import get_client, print_model\nfrom ..._types import NOT_GIVEN\nfrom .._models import BaseModel\nfrom .._progress import BufferReader\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n # transcriptions\n sub = subparser.add_parser(\"audio.transcriptions.create\")\n\n # Required\n sub.add_argument(\"-m\", \"--model\", type=str, default=\"whisper-1\")\n sub.add_argument(\"-f\", \"--file\", type=str, required=True)\n # Optional\n sub.add_argument(\"--response-format\", type=str)\n sub.add_argument(\"--language\", type=str)\n sub.add_argument(\"-t\", \"--temperature\", type=float)\n sub.add_argument(\"--prompt\", type=str)\n sub.set_defaults(func=CLIAudio.transcribe, args_model=CLITranscribeArgs)\n\n # translations\n sub = subparser.add_parser(\"audio.translations.create\")\n\n # Required\n sub.add_argument(\"-f\", \"--file\", type=str, required=True)\n # Optional\n sub.add_argument(\"-m\", \"--model\", type=str, default=\"whisper-1\")\n sub.add_argument(\"--response-format\", type=str)\n # TODO: doesn't seem to be supported by the API\n # sub.add_argument(\"--language\", type=str)\n sub.add_argument(\"-t\", \"--temperature\", type=float)\n sub.add_argument(\"--prompt\", type=str)\n sub.set_defaults(func=CLIAudio.translate, args_model=CLITranslationArgs)\n\n\nclass CLITranscribeArgs(BaseModel):\n model: str\n file: str\n response_format: Optional[str] = None\n language: Optional[str] = None\n temperature: Optional[float] = None\n prompt: Optional[str] = None\n\n\nclass CLITranslationArgs(BaseModel):\n model: str\n file: str\n response_format: Optional[str] = None\n language: Optional[str] = None\n temperature: Optional[float] = None\n prompt: Optional[str] = None\n\n\nclass CLIAudio:\n @staticmethod\n def transcribe(args: CLITranscribeArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload progress\")\n\n model = get_client().audio.transcriptions.create(\n file=(args.file, buffer_reader),\n model=args.model,\n language=args.language or NOT_GIVEN,\n temperature=args.temperature or NOT_GIVEN,\n prompt=args.prompt or NOT_GIVEN,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n response_format=cast(Any, args.response_format),\n )\n print_model(model)\n\n @staticmethod\n def translate(args: CLITranslationArgs) -> None:\n with open(args.file, \"rb\") as file_reader:\n buffer_reader = BufferReader(file_reader.read(), desc=\"Upload 
progress\")\n\n model = get_client().audio.translations.create(\n file=(args.file, buffer_reader),\n model=args.model,\n temperature=args.temperature or NOT_GIVEN,\n prompt=args.prompt or NOT_GIVEN,\n # casts required because the API is typed for enums\n # but we don't want to validate that here for forwards-compat\n response_format=cast(Any, args.response_format),\n )\n print_model(model)\n\n# Path: src/openai/cli/_api/chat/completions.py\nfrom __future__ import annotations\n\nimport sys\nfrom typing import TYPE_CHECKING, List, Optional, cast\nfrom argparse import ArgumentParser\nfrom typing_extensions import Literal, NamedTuple\n\nfrom ..._utils import get_client\nfrom ..._models import BaseModel\nfrom ...._streaming import Stream\nfrom ....types.chat import (\n ChatCompletionRole,\n ChatCompletionChunk,\n CompletionCreateParams,\n)\nfrom ....types.chat.completion_create_params import (\n CompletionCreateParamsStreaming,\n CompletionCreateParamsNonStreaming,\n)\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"chat.completions.create\")\n\n sub._action_groups.pop()\n req = sub.add_argument_group(\"required arguments\")\n opt = sub.add_argument_group(\"optional arguments\")\n\n req.add_argument(\n \"-g\",\n \"--message\",\n action=\"append\",\n nargs=2,\n metavar=(\"ROLE\", \"CONTENT\"),\n help=\"A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.\",\n required=True,\n )\n req.add_argument(\n \"-m\",\n \"--model\",\n help=\"The model to use.\",\n required=True,\n )\n\n opt.add_argument(\n \"-n\",\n \"--n\",\n help=\"How many completions to generate for the conversation.\",\n type=int,\n )\n opt.add_argument(\"-M\", \"--max-tokens\", help=\"The maximum number of tokens to generate.\", type=int)\n opt.add_argument(\n \"-t\",\n \"--temperature\",\n help=\"\"\"What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.\n\nMutually exclusive with `top_p`.\"\"\",\n type=float,\n )\n opt.add_argument(\n \"-P\",\n \"--top_p\",\n help=\"\"\"An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. 
So 0.1 means only the tokens comprising the top 10%% probability mass are considered.\n\n Mutually exclusive with `temperature`.\"\"\",\n type=float,\n )\n opt.add_argument(\n \"--stop\",\n help=\"A stop sequence at which to stop generating tokens for the message.\",\n )\n opt.add_argument(\"--stream\", help=\"Stream messages as they're ready.\", action=\"store_true\")\n sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)\n\n\nclass CLIMessage(NamedTuple):\n role: ChatCompletionRole\n content: str\n\n\nclass CLIChatCompletionCreateArgs(BaseModel):\n message: List[CLIMessage]\n model: str\n n: Optional[int] = None\n max_tokens: Optional[int] = None\n temperature: Optional[float] = None\n top_p: Optional[float] = None\n stop: Optional[str] = None\n stream: bool = False\n\n\nclass CLIChatCompletion:\n @staticmethod\n def create(args: CLIChatCompletionCreateArgs) -> None:\n params: CompletionCreateParams = {\n \"model\": args.model,\n \"messages\": [\n {\"role\": cast(Literal[\"user\"], message.role), \"content\": message.content} for message in args.message\n ],\n \"n\": args.n,\n \"temperature\": args.temperature,\n \"top_p\": args.top_p,\n \"stop\": args.stop,\n # type checkers are not good at inferring union types so we have to set stream afterwards\n \"stream\": False,\n }\n if args.stream:\n params[\"stream\"] = args.stream # type: ignore\n if args.max_tokens is not None:\n params[\"max_tokens\"] = args.max_tokens\n\n if args.stream:\n return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))\n\n return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))\n\n @staticmethod\n def _create(params: CompletionCreateParamsNonStreaming) -> None:\n completion = get_client().chat.completions.create(**params)\n should_print_header = len(completion.choices) > 1\n for choice in completion.choices:\n if should_print_header:\n sys.stdout.write(\"===== Chat Completion {} =====\\n\".format(choice.index))\n\n content = choice.message.content if choice.message.content is not None else \"None\"\n sys.stdout.write(content)\n\n if should_print_header or not content.endswith(\"\\n\"):\n sys.stdout.write(\"\\n\")\n\n sys.stdout.flush()\n\n @staticmethod\n def _stream_create(params: CompletionCreateParamsStreaming) -> None:\n # cast is required for mypy\n stream = cast( # pyright: ignore[reportUnnecessaryCast]\n Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)\n )\n for chunk in stream:\n should_print_header = len(chunk.choices) > 1\n for choice in chunk.choices:\n if should_print_header:\n sys.stdout.write(\"===== Chat Completion {} =====\\n\".format(choice.index))\n\n content = choice.delta.content or \"\"\n sys.stdout.write(content)\n\n if should_print_header:\n sys.stdout.write(\"\\n\")\n\n sys.stdout.flush()\n\n sys.stdout.write(\"\\n\")\n\n# Path: src/openai/cli/_api/chat/__init__.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\nfrom argparse import ArgumentParser\n\nfrom . 
import completions\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n completions.register(subparser)\n\n# Path: src/openai/cli/_errors.py\nfrom __future__ import annotations\n\nimport sys\n\nimport pydantic\n\nfrom ._utils import Colors, organization_info\nfrom .._exceptions import APIError, OpenAIError\n\n\nclass CLIError(OpenAIError):\n ...\n\n\nclass SilentCLIError(CLIError):\n ...\n\n\ndef display_error(err: CLIError | APIError | pydantic.ValidationError) -> None:\n if isinstance(err, SilentCLIError):\n return\n\n sys.stderr.write(\"{}{}Error:{} {}\\n\".format(organization_info(), Colors.FAIL, Colors.ENDC, err))\n\n# Path: src/openai/cli/_api/completions.py\nfrom __future__ import annotations\n\nimport sys\nfrom typing import TYPE_CHECKING, Optional, cast\nfrom argparse import ArgumentParser\nfrom functools import partial\n\nfrom openai.types.completion import Completion\n\nfrom .._utils import get_client\nfrom ..._types import NOT_GIVEN, NotGivenOr\nfrom ..._utils import is_given\nfrom .._errors import CLIError\nfrom .._models import BaseModel\nfrom ..._streaming import Stream\n\nif TYPE_CHECKING:\n from argparse import _SubParsersAction\n\n\ndef register(subparser: _SubParsersAction[ArgumentParser]) -> None:\n sub = subparser.add_parser(\"completions.create\")\n\n # Required\n sub.add_argument(\n \"-m\",\n \"--model\",\n help=\"The model to use\",\n required=True,\n )\n\n # Optional\n sub.add_argument(\"-p\", \"--prompt\", help=\"An optional prompt to complete from\")\n sub.add_argument(\"--stream\", help=\"Stream tokens as they're ready.\", action=\"store_true\")\n sub.add_argument(\"-M\", \"--max-tokens\", help=\"The maximum number of tokens to generate\", type=int)\n sub.add_argument(\n \"-t\",\n \"--temperature\",\n help=\"\"\"What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.\n\nMutually exclusive with `top_p`.\"\"\",\n type=float,\n )\n sub.add_argument(\n \"-P\",\n \"--top_p\",\n help=\"\"\"An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.\n\n Mutually exclusive with `temperature`.\"\"\",\n type=float,\n )\n sub.add_argument(\n \"-n\",\n \"--n\",\n help=\"How many sub-completions to generate for each prompt.\",\n type=int,\n )\n sub.add_argument(\n \"--logprobs\",\n help=\"Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.\",\n type=int,\n )\n sub.add_argument(\n \"--best_of\",\n help=\"Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). 
Results cannot be streamed.\",\n type=int,\n )\n sub.add_argument(\n \"--echo\",\n help=\"Echo back the prompt in addition to the completion\",\n action=\"store_true\",\n )\n sub.add_argument(\n \"--frequency_penalty\",\n help=\"Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\",\n type=float,\n )\n sub.add_argument(\n \"--presence_penalty\",\n help=\"Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\",\n type=float,\n )\n sub.add_argument(\"--suffix\", help=\"The suffix that comes after a completion of inserted text.\")\n sub.add_argument(\"--stop\", help=\"A stop sequence at which to stop generating tokens.\")\n sub.add_argument(\n \"--user\",\n help=\"A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.\",\n )\n # TODO: add support for logit_bias\n sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)\n\n\nclass CLICompletionCreateArgs(BaseModel):\n model: str\n stream: bool = False\n\n prompt: Optional[str] = None\n n: NotGivenOr[int] = NOT_GIVEN\n stop: NotGivenOr[str] = NOT_GIVEN\n user: NotGivenOr[str] = NOT_GIVEN\n echo: NotGivenOr[bool] = NOT_GIVEN\n suffix: NotGivenOr[str] = NOT_GIVEN\n best_of: NotGivenOr[int] = NOT_GIVEN\n top_p: NotGivenOr[float] = NOT_GIVEN\n logprobs: NotGivenOr[int] = NOT_GIVEN\n max_tokens: NotGivenOr[int] = NOT_GIVEN\n temperature: NotGivenOr[float] = NOT_GIVEN\n presence_penalty: NotGivenOr[float] = NOT_GIVEN\n frequency_penalty: NotGivenOr[float] = NOT_GIVEN\n\n\nclass CLICompletions:\n @staticmethod\n def create(args: CLICompletionCreateArgs) -> None:\n if is_given(args.n) and args.n > 1 and args.stream:\n raise CLIError(\"Can't stream completions with n>1 with the current CLI\")\n\n make_request = partial(\n get_client().completions.create,\n n=args.n,\n echo=args.echo,\n stop=args.stop,\n...\n", "needle_token_start": 5496, "needle_token_end": 5836, "code_context_ntokens": 16052, "output": ["```python\ndef make_request_options(\n *,\n query: Query | None = None,\n extra_headers: Headers | None = None,\n extra_query: Query | None = None,\n extra_body: Body | None = None,\n idempotency_key: str | None = None,\n timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,\n post_parser: PostParser | NotGiven = NOT_GIVEN,\n) -> RequestOptions:\n \"\"\"Create a dict of type RequestOptions without keys of NotGiven values.\"\"\"\n options: RequestOptions = {}\n if extra_headers is not None:\n options[\"headers\"] = extra_headers\n\n if extra_body is not None:\n options[\"extra_json\"] = cast(AnyMapping, extra_body)\n\n if query is not None:\n options[\"params\"] = query\n\n if extra_query is not None:\n options[\"params\"] = {**options.get(\"params\", {}), **extra_query}\n\n if not isinstance(timeout, NotGiven):\n options[\"timeout\"] = timeout\n\n if idempotency_key is not None:\n options[\"idempotency_key\"] = idempotency_key\n\n if is_given(post_parser):\n # internal\n options[\"post_parser\"] = post_parser # type: ignore\n\n return options\n```"]} +{"repo": "openai/openai-python", "name": "_primitive_value_to_str", "language": "python", "path": "src/openai/_qs.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: Converts basic data types to their string representations for use in URL query strings.\n2. 
**Input**: A single value which can be a boolean, None, or other primitive data types.\n3. **Output**: A string representation of the input value, tailored for URL encoding.\n4. **Procedure**: The function checks the type of the input value:\n - If the value is `True`, it returns the string \"true\".\n - If the value is `False`, it returns the string \"false\".\n - If the value is `None`, it returns an empty string.\n - For all other data types, it converts the value to a string using Python's built-in `str()` function.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/_utils/__init__.py\nfrom ._sync import asyncify as asyncify\nfrom ._proxy import LazyProxy as LazyProxy\nfrom ._utils import (\n flatten as flatten,\n is_dict as is_dict,\n is_list as is_list,\n is_given as is_given,\n is_tuple as is_tuple,\n is_mapping as is_mapping,\n is_tuple_t as is_tuple_t,\n parse_date as parse_date,\n is_iterable as is_iterable,\n is_sequence as is_sequence,\n coerce_float as coerce_float,\n is_mapping_t as is_mapping_t,\n removeprefix as removeprefix,\n removesuffix as removesuffix,\n extract_files as extract_files,\n is_sequence_t as is_sequence_t,\n required_args as required_args,\n coerce_boolean as coerce_boolean,\n coerce_integer as coerce_integer,\n file_from_path as file_from_path,\n parse_datetime as parse_datetime,\n strip_not_given as strip_not_given,\n deepcopy_minimal as deepcopy_minimal,\n get_async_library as get_async_library,\n...\n# Path: src/openai/_exceptions.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Any, Optional, cast\nfrom typing_extensions import Literal\n\nimport httpx\n\nfrom ._utils import is_dict\n\n__all__ = [\n \"BadRequestError\",\n \"AuthenticationError\",\n \"PermissionDeniedError\",\n \"NotFoundError\",\n \"ConflictError\",\n \"UnprocessableEntityError\",\n \"RateLimitError\",\n \"InternalServerError\",\n]\n\n\nclass OpenAIError(Exception):\n pass\n\n\nclass APIError(OpenAIError):\n message: str\n request: httpx.Request\n\n body: object | None\n \"\"\"The API response body.\n\n If the API responded with a valid JSON structure then this property will be the\n decoded result.\n\n If it isn't a valid JSON structure then this will be the raw response.\n\n If there was no response associated with this error then it will be `None`.\n \"\"\"\n\n code: Optional[str] = None\n param: Optional[str] = None\n type: Optional[str]\n\n def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:\n super().__init__(message)\n self.request = request\n self.message = message\n self.body = body\n\n if is_dict(body):\n self.code = cast(Any, body.get(\"code\"))\n self.param = cast(Any, body.get(\"param\"))\n self.type = cast(Any, body.get(\"type\"))\n else:\n self.code = None\n self.param = None\n self.type = None\n\n\nclass APIResponseValidationError(APIError):\n response: httpx.Response\n status_code: int\n\n def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:\n super().__init__(message or \"Data returned by API invalid for expected schema.\", response.request, body=body)\n self.response = response\n self.status_code = response.status_code\n\n\nclass APIStatusError(APIError):\n \"\"\"Raised when an 
API response has a status code of 4xx or 5xx.\"\"\"\n\n response: httpx.Response\n status_code: int\n\n def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:\n super().__init__(message, response.request, body=body)\n self.response = response\n self.status_code = response.status_code\n\n\nclass APIConnectionError(APIError):\n def __init__(self, *, message: str = \"Connection error.\", request: httpx.Request) -> None:\n super().__init__(message, request, body=None)\n\n\nclass APITimeoutError(APIConnectionError):\n def __init__(self, request: httpx.Request) -> None:\n super().__init__(message=\"Request timed out.\", request=request)\n\n\nclass BadRequestError(APIStatusError):\n status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass AuthenticationError(APIStatusError):\n status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass PermissionDeniedError(APIStatusError):\n status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass NotFoundError(APIStatusError):\n status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass ConflictError(APIStatusError):\n status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass UnprocessableEntityError(APIStatusError):\n status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass RateLimitError(APIStatusError):\n status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass InternalServerError(APIStatusError):\n pass\n\n# Path: src/openai/_models.py\nfrom __future__ import annotations\n\nimport inspect\nfrom typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast\nfrom datetime import date, datetime\nfrom typing_extensions import (\n Unpack,\n Literal,\n ClassVar,\n Protocol,\n Required,\n TypedDict,\n final,\n override,\n runtime_checkable,\n)\n\nimport pydantic\nimport pydantic.generics\nfrom pydantic.fields import FieldInfo\n\nfrom ._types import (\n Body,\n IncEx,\n Query,\n ModelT,\n Headers,\n Timeout,\n NotGiven,\n AnyMapping,\n HttpxRequestFiles,\n)\nfrom ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given\nfrom ._compat import (\n PYDANTIC_V2,\n ConfigDict,\n GenericModel as BaseGenericModel,\n get_args,\n is_union,\n parse_obj,\n get_origin,\n is_literal_type,\n get_model_config,\n get_model_fields,\n field_get_default,\n)\nfrom ._constants import RAW_RESPONSE_HEADER\n\n__all__ = [\"BaseModel\", \"GenericModel\"]\n\n_T = TypeVar(\"_T\")\n\n\n@runtime_checkable\nclass _ConfigProtocol(Protocol):\n allow_population_by_field_name: bool\n\n\nclass BaseModel(pydantic.BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(extra=\"allow\")\n else:\n\n @property\n @override\n def model_fields_set(self) -> set[str]:\n # a forwards-compat shim for pydantic v2\n return self.__fields_set__ # type: ignore\n\n class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]\n extra: Any = pydantic.Extra.allow # type: ignore\n\n @override\n def __str__(self) -> str:\n # mypy complains about an invalid self arg\n return f'{self.__repr_name__()}({self.__repr_str__(\", \")})' # type: ignore[misc]\n\n # Override the 'construct' method in a way that supports recursive parsing without validation.\n # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.\n 
@classmethod\n @override\n def construct(\n cls: Type[ModelT],\n _fields_set: set[str] | None = None,\n **values: object,\n ) -> ModelT:\n m = cls.__new__(cls)\n fields_values: dict[str, object] = {}\n\n config = get_model_config(cls)\n populate_by_name = (\n config.allow_population_by_field_name\n if isinstance(config, _ConfigProtocol)\n else config.get(\"populate_by_name\")\n )\n\n if _fields_set is None:\n _fields_set = set()\n\n model_fields = get_model_fields(cls)\n for name, field in model_fields.items():\n key = field.alias\n if key is None or (key not in values and populate_by_name):\n key = name\n\n if key in values:\n fields_values[name] = _construct_field(value=values[key], field=field, key=key)\n _fields_set.add(name)\n else:\n fields_values[name] = field_get_default(field)\n\n _extra = {}\n for key, value in values.items():\n if key not in model_fields:\n if PYDANTIC_V2:\n _extra[key] = value\n else:\n _fields_set.add(key)\n fields_values[key] = value\n\n object.__setattr__(m, \"__dict__\", fields_values)\n\n if PYDANTIC_V2:\n # these properties are copied from Pydantic's `model_construct()` method\n object.__setattr__(m, \"__pydantic_private__\", None)\n object.__setattr__(m, \"__pydantic_extra__\", _extra)\n object.__setattr__(m, \"__pydantic_fields_set__\", _fields_set)\n else:\n # init_private_attributes() does not exist in v2\n m._init_private_attributes() # type: ignore\n\n # copied from Pydantic v1's `construct()` method\n object.__setattr__(m, \"__fields_set__\", _fields_set)\n\n return m\n\n if not TYPE_CHECKING:\n # type checkers incorrectly complain about this assignment\n # because the type signatures are technically different\n # although not in practice\n model_construct = construct\n\n if not PYDANTIC_V2:\n # we define aliases for some of the new pydantic v2 methods so\n # that we can just document these methods without having to specify\n # a specific pydantic version as some users may not know which\n # pydantic version they are currently using\n\n @override\n def model_dump(\n self,\n *,\n mode: Literal[\"json\", \"python\"] | str = \"python\",\n include: IncEx = None,\n exclude: IncEx = None,\n by_alias: bool = False,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n exclude_none: bool = False,\n round_trip: bool = False,\n warnings: bool = True,\n ) -> dict[str, Any]:\n \"\"\"Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump\n\n Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.\n\n Args:\n mode: The mode in which `to_python` should run.\n If mode is 'json', the dictionary will only contain JSON serializable types.\n If mode is 'python', the dictionary may contain any Python objects.\n include: A list of fields to include in the output.\n exclude: A list of fields to exclude from the output.\n by_alias: Whether to use the field's alias in the dictionary key if defined.\n exclude_unset: Whether to exclude fields that are unset or None from the output.\n exclude_defaults: Whether to exclude fields that are set to their default value from the output.\n exclude_none: Whether to exclude fields that have a value of `None` from the output.\n round_trip: Whether to enable serialization and deserialization round-trip support.\n warnings: Whether to log warnings when invalid fields are encountered.\n\n Returns:\n A dictionary representation of the model.\n \"\"\"\n if mode != \"python\":\n raise ValueError(\"mode is only supported in Pydantic v2\")\n if 
round_trip != False:\n raise ValueError(\"round_trip is only supported in Pydantic v2\")\n if warnings != True:\n raise ValueError(\"warnings is only supported in Pydantic v2\")\n return super().dict( # pyright: ignore[reportDeprecated]\n include=include,\n exclude=exclude,\n by_alias=by_alias,\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n exclude_none=exclude_none,\n )\n\n @override\n def model_dump_json(\n self,\n *,\n indent: int | None = None,\n include: IncEx = None,\n exclude: IncEx = None,\n by_alias: bool = False,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n exclude_none: bool = False,\n round_trip: bool = False,\n warnings: bool = True,\n ) -> str:\n \"\"\"Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json\n\n Generates a JSON representation of the model using Pydantic's `to_json` method.\n\n Args:\n indent: Indentation to use in the JSON output. If None is passed, the output will be compact.\n include: Field(s) to include in the JSON output. Can take either a string or set of strings.\n exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.\n by_alias: Whether to serialize using field aliases.\n exclude_unset: Whether to exclude fields that have not been explicitly set.\n exclude_defaults: Whether to exclude fields that have the default value.\n exclude_none: Whether to exclude fields that have a value of `None`.\n round_trip: Whether to use serialization/deserialization between JSON and class instance.\n warnings: Whether to show any warnings that occurred during serialization.\n\n Returns:\n A JSON string representation of the model.\n \"\"\"\n if round_trip != False:\n raise ValueError(\"round_trip is only supported in Pydantic v2\")\n if warnings != True:\n raise ValueError(\"warnings is only supported in Pydantic v2\")\n return super().json( # type: ignore[reportDeprecated]\n indent=indent,\n include=include,\n exclude=exclude,\n by_alias=by_alias,\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n exclude_none=exclude_none,\n )\n\n\ndef _construct_field(value: object, field: FieldInfo, key: str) -> object:\n if value is None:\n return field_get_default(field)\n\n if PYDANTIC_V2:\n type_ = field.annotation\n else:\n type_ = cast(type, field.outer_type_) # type: ignore\n\n if type_ is None:\n raise RuntimeError(f\"Unexpected field type is None for {key}\")\n\n return construct_type(value=value, type_=type_)\n\n\ndef is_basemodel(type_: type) -> bool:\n \"\"\"Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`\"\"\"\n origin = get_origin(type_) or type_\n if is_union(type_):\n for variant in get_args(type_):\n if is_basemodel(variant):\n return True\n\n return False\n\n return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)\n\n\ndef construct_type(*, value: object, type_: type) -> object:\n \"\"\"Loose coercion to the expected type with construction of nested values.\n\n If the given value does not match the expected type then it is returned as-is.\n \"\"\"\n\n # we need to use the origin class for any types that are subscripted generics\n # e.g. 
Dict[str, object]\n origin = get_origin(type_) or type_\n args = get_args(type_)\n\n if is_union(origin):\n try:\n return validate_type(type_=cast(\"type[object]\", type_), value=value)\n except Exception:\n pass\n\n # if the data is not valid, use the first variant that doesn't fail while deserializing\n for variant in args:\n try:\n return construct_type(value=value, type_=variant)\n except Exception:\n continue\n\n raise RuntimeError(f\"Could not convert data into a valid instance of {type_}\")\n\n if origin == dict:\n if not is_mapping(value):\n return value\n\n _, items_type = get_args(type_) # Dict[_, items_type]\n return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}\n\n if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):\n if is_list(value):\n return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]\n\n if is_mapping(value):\n if issubclass(type_, BaseModel):\n return type_.construct(**value) # type: ignore[arg-type]\n\n return cast(Any, type_).construct(**value)\n\n if origin == list:\n if not is_list(value):\n return value\n\n inner_type = args[0] # List[inner_type]\n return [construct_type(value=entry, type_=inner_type) for entry in value]\n\n if origin == float:\n if isinstance(value, int):\n coerced = float(value)\n if coerced != value:\n return value\n return coerced\n\n return value\n\n if type_ == datetime:\n try:\n return parse_datetime(value) # type: ignore\n except Exception:\n return value\n\n if type_ == date:\n try:\n return parse_date(value) # type: ignore\n except Exception:\n return value\n\n return value\n\n\ndef validate_type(*, type_: type[_T], value: object) -> _T:\n \"\"\"Strict validation that the given value matches the expected type\"\"\"\n if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):\n return cast(_T, parse_obj(type_, value))\n\n return cast(_T, _validate_non_model_type(type_=type_, value=value))\n\n\n# our use of subclasssing here causes weirdness for type checkers,\n# so we just pretend that we don't subclass\nif TYPE_CHECKING:\n GenericModel = BaseModel\nelse:\n\n class GenericModel(BaseGenericModel, BaseModel):\n pass\n\n\nif PYDANTIC_V2:\n from pydantic import TypeAdapter\n\n def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:\n return TypeAdapter(type_).validate_python(value)\n\nelif not TYPE_CHECKING: # TODO: condition is weird\n\n class RootModel(GenericModel, Generic[_T]):\n \"\"\"Used as a placeholder to easily convert runtime types to a Pydantic format\n to provide validation.\n\n For example:\n ```py\n validated = RootModel[int](__root__=\"5\").__root__\n # validated: 5\n ```\n \"\"\"\n\n __root__: _T\n\n def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:\n model = _create_pydantic_model(type_).validate(value)\n return cast(_T, model.__root__)\n\n def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:\n return RootModel[type_] # type: ignore\n\n\nclass FinalRequestOptionsInput(TypedDict, total=False):\n method: Required[str]\n url: Required[str]\n params: Query\n headers: Headers\n max_retries: int\n timeout: float | Timeout | None\n files: HttpxRequestFiles | None\n idempotency_key: str\n json_data: Body\n extra_json: AnyMapping\n\n\n@final\nclass FinalRequestOptions(pydantic.BaseModel):\n method: str\n url: str\n params: Query = {}\n headers: Union[Headers, NotGiven] = NotGiven()\n max_retries: Union[int, NotGiven] = NotGiven()\n timeout: 
Union[float, Timeout, None, NotGiven] = NotGiven()\n files: Union[HttpxRequestFiles, None] = None\n idempotency_key: Union[str, None] = None\n post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()\n\n # It should be noted that we cannot use `json` here as that would override\n # a BaseModel method in an incompatible fashion.\n json_data: Union[Body, None] = None\n extra_json: Union[AnyMapping, None] = None\n\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)\n else:\n\n class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]\n arbitrary_types_allowed: bool = True\n\n def get_max_retries(self, max_retries: int) -> int:\n if isinstance(self.max_retries, NotGiven):\n return max_retries\n return self.max_retries\n\n def _strip_raw_response_header(self) -> None:\n if not is_given(self.headers):\n return\n\n if self.headers.get(RAW_RESPONSE_HEADER):\n self.headers = {**self.headers}\n self.headers.pop(RAW_RESPONSE_HEADER)\n\n # override the `construct` method so that we can run custom transformations.\n # this is necessary as we don't want to do any actual runtime type checking\n # (which means we can't use validators) but we do want to ensure that `NotGiven`\n # values are not present\n #\n # type ignore required because we're adding explicit types to `**values`\n @classmethod\n def construct( # type: ignore\n cls,\n _fields_set: set[str] | None = None,\n **values: Unpack[FinalRequestOptionsInput],\n ) -> FinalRequestOptions:\n kwargs: dict[str, Any] = {\n # we unconditionally call `strip_not_given` on any value\n # as it will just ignore any non-mapping types\n key: strip_not_given(value)\n for key, value in values.items()\n }\n if PYDANTIC_V2:\n return super().model_construct(_fields_set, **kwargs)\n return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]\n\n if not TYPE_CHECKING:\n # type checkers incorrectly complain about this assignment\n model_construct = construct\n\n# Path: src/openai/_qs.py\nfrom __future__ import annotations\n\nfrom typing import Any, List, Tuple, Union, Mapping, TypeVar\nfrom urllib.parse import parse_qs, urlencode\nfrom typing_extensions import Literal, get_args\n\nfrom ._types import NOT_GIVEN, NotGiven, NotGivenOr\nfrom ._utils import flatten\n\n_T = TypeVar(\"_T\")\n\n\nArrayFormat = Literal[\"comma\", \"repeat\", \"indices\", \"brackets\"]\nNestedFormat = Literal[\"dots\", \"brackets\"]\n\nPrimitiveData = Union[str, int, float, bool, None]\n# this should be Data = Union[PrimitiveData, \"List[Data]\", \"Tuple[Data]\", \"Mapping[str, Data]\"]\n# https://github.com/microsoft/pyright/issues/3555\nData = Union[PrimitiveData, List[Any], Tuple[Any], \"Mapping[str, Any]\"]\nParams = Mapping[str, Data]\n\n\nclass Querystring:\n array_format: ArrayFormat\n nested_format: NestedFormat\n\n def __init__(\n self,\n *,\n array_format: ArrayFormat = \"repeat\",\n nested_format: NestedFormat = \"brackets\",\n ) -> None:\n self.array_format = array_format\n self.nested_format = nested_format\n\n def parse(self, query: str) -> Mapping[str, object]:\n # Note: custom format syntax is not supported yet\n return parse_qs(query)\n\n def stringify(\n self,\n params: Params,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> str:\n return urlencode(\n self.stringify_items(\n params,\n array_format=array_format,\n nested_format=nested_format,\n )\n )\n\n def stringify_items(\n self,\n params: 
Params,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> list[tuple[str, str]]:\n opts = Options(\n qs=self,\n array_format=array_format,\n nested_format=nested_format,\n )\n return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])\n\n def _stringify_item(\n self,\n key: str,\n value: Data,\n opts: Options,\n ) -> list[tuple[str, str]]:\n if isinstance(value, Mapping):\n items: list[tuple[str, str]] = []\n nested_format = opts.nested_format\n for subkey, subvalue in value.items():\n items.extend(\n self._stringify_item(\n # TODO: error if unknown format\n f\"{key}.{subkey}\" if nested_format == \"dots\" else f\"{key}[{subkey}]\",\n subvalue,\n opts,\n )\n )\n return items\n\n if isinstance(value, (list, tuple)):\n array_format = opts.array_format\n if array_format == \"comma\":\n return [\n (\n key,\n \",\".join(self._primitive_value_to_str(item) for item in value if item is not None),\n ),\n ]\n elif array_format == \"repeat\":\n items = []\n for item in value:\n items.extend(self._stringify_item(key, item, opts))\n return items\n elif array_format == \"indices\":\n raise NotImplementedError(\"The array indices format is not supported yet\")\n elif array_format == \"brackets\":\n items = []\n key = key + \"[]\"\n for item in value:\n items.extend(self._stringify_item(key, item, opts))\n return items\n else:\n raise NotImplementedError(\n f\"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}\"\n )\n\n serialised = self._primitive_value_to_str(value)\n if not serialised:\n return []\n return [(key, serialised)]\n\n \ndef _primitive_value_to_str(self, value: PrimitiveData) -> str:\n # copied from httpx\n if value is True:\n return \"true\"\n elif value is False:\n return \"false\"\n elif value is None:\n return \"\"\n return str(value)\n\n\n_qs = Querystring()\nparse = _qs.parse\nstringify = _qs.stringify\nstringify_items = _qs.stringify_items\n\n\nclass Options:\n array_format: ArrayFormat\n nested_format: NestedFormat\n\n def __init__(\n self,\n qs: Querystring = _qs,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> None:\n self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format\n self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format\n\n# Path: src/openai/_version.py\n# File generated from our OpenAPI spec by Stainless.\n\n__title__ = \"openai\"\n__version__ = \"1.13.3\" # x-release-please-version\n\n# Path: src/openai/_resource.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nimport time\nfrom typing import TYPE_CHECKING\n\nimport anyio\n\nif TYPE_CHECKING:\n from ._client import OpenAI, AsyncOpenAI\n\n\nclass SyncAPIResource:\n _client: OpenAI\n\n def __init__(self, client: OpenAI) -> None:\n self._client = client\n self._get = client.get\n self._post = client.post\n self._patch = client.patch\n self._put = client.put\n self._delete = client.delete\n self._get_api_list = client.get_api_list\n\n def _sleep(self, seconds: float) -> None:\n time.sleep(seconds)\n\n\nclass AsyncAPIResource:\n _client: AsyncOpenAI\n\n def __init__(self, client: AsyncOpenAI) -> None:\n self._client = client\n self._get = client.get\n self._post = client.post\n self._patch = client.patch\n self._put = client.put\n self._delete = client.delete\n self._get_api_list = 
client.get_api_list\n\n async def _sleep(self, seconds: float) -> None:\n await anyio.sleep(seconds)\n\n# Path: src/openai/_response.py\nfrom __future__ import annotations\n\nimport os\nimport inspect\nimport logging\nimport datetime\nimport functools\nfrom types import TracebackType\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Union,\n Generic,\n TypeVar,\n Callable,\n Iterator,\n AsyncIterator,\n cast,\n overload,\n)\nfrom typing_extensions import Awaitable, ParamSpec, override, get_origin\n\nimport anyio\nimport httpx\nimport pydantic\n\nfrom ._types import NoneType\nfrom ._utils import is_given, extract_type_var_from_base\nfrom ._models import BaseModel, is_basemodel\nfrom ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER\nfrom ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type\nfrom ._exceptions import OpenAIError, APIResponseValidationError\n\nif TYPE_CHECKING:\n from ._models import FinalRequestOptions\n from ._base_client import BaseClient\n\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n_T = TypeVar(\"_T\")\n_APIResponseT = TypeVar(\"_APIResponseT\", bound=\"APIResponse[Any]\")\n_AsyncAPIResponseT = TypeVar(\"_AsyncAPIResponseT\", bound=\"AsyncAPIResponse[Any]\")\n\nlog: logging.Logger = logging.getLogger(__name__)\n\n\nclass BaseAPIResponse(Generic[R]):\n _cast_to: type[R]\n _client: BaseClient[Any, Any]\n _parsed_by_type: dict[type[Any], Any]\n _is_sse_stream: bool\n _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None\n _options: FinalRequestOptions\n\n http_response: httpx.Response\n\n def __init__(\n self,\n *,\n raw: httpx.Response,\n cast_to: type[R],\n client: BaseClient[Any, Any],\n stream: bool,\n stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,\n options: FinalRequestOptions,\n ) -> None:\n self._cast_to = cast_to\n self._client = client\n self._parsed_by_type = {}\n self._is_sse_stream = stream\n self._stream_cls = stream_cls\n self._options = options\n self.http_response = raw\n\n @property\n def headers(self) -> httpx.Headers:\n return self.http_response.headers\n\n @property\n def http_request(self) -> httpx.Request:\n \"\"\"Returns the httpx Request instance associated with the current response.\"\"\"\n return self.http_response.request\n\n @property\n def status_code(self) -> int:\n return self.http_response.status_code\n\n @property\n def url(self) -> httpx.URL:\n \"\"\"Returns the URL for which the request was made.\"\"\"\n return self.http_response.url\n\n @property\n def method(self) -> str:\n return self.http_request.method\n\n @property\n def http_version(self) -> str:\n return self.http_response.http_version\n\n @property\n def elapsed(self) -> datetime.timedelta:\n \"\"\"The time taken for the complete request/response cycle to complete.\"\"\"\n return self.http_response.elapsed\n\n @property\n def is_closed(self) -> bool:\n \"\"\"Whether or not the response body has been closed.\n\n If this is False then there is response data that has not been read yet.\n You must either fully consume the response body or call `.close()`\n before discarding the response to prevent resource leaks.\n \"\"\"\n return self.http_response.is_closed\n\n @override\n def __repr__(self) -> str:\n return (\n f\"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>\"\n )\n\n def _parse(self, *, to: type[_T] | None = None) -> R | _T:\n if self._is_sse_stream:\n if to:\n if not is_stream_class_type(to):\n raise TypeError(f\"Expected custom parse type to be a 
subclass of {Stream} or {AsyncStream}\")\n\n return cast(\n _T,\n to(\n cast_to=extract_stream_chunk_type(\n to,\n failure_message=\"Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]\",\n ),\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n if self._stream_cls:\n return cast(\n R,\n self._stream_cls(\n cast_to=extract_stream_chunk_type(self._stream_cls),\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n stream_cls = cast(\"type[Stream[Any]] | type[AsyncStream[Any]] | None\", self._client._default_stream_cls)\n if stream_cls is None:\n raise MissingStreamClassError()\n\n return cast(\n R,\n stream_cls(\n cast_to=self._cast_to,\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n cast_to = to if to is not None else self._cast_to\n if cast_to is NoneType:\n return cast(R, None)\n\n response = self.http_response\n if cast_to == str:\n return cast(R, response.text)\n\n if cast_to == bytes:\n return cast(R, response.content)\n\n origin = get_origin(cast_to) or cast_to\n\n # handle the legacy binary response case\n if inspect.isclass(cast_to) and cast_to.__name__ == \"HttpxBinaryResponseContent\":\n return cast(R, cast_to(response)) # type: ignore\n\n if origin == APIResponse:\n raise RuntimeError(\"Unexpected state - cast_to is `APIResponse`\")\n\n if inspect.isclass(origin) and issubclass(origin, httpx.Response):\n # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response\n # and pass that class to our request functions. We cannot change the variance to be either\n # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct\n # the response class ourselves but that is something that should be supported directly in httpx\n # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.\n if cast_to != httpx.Response:\n raise ValueError(f\"Subclasses of httpx.Response cannot be passed to `cast_to`\")\n return cast(R, response)\n\n if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):\n raise TypeError(\"Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`\")\n\n if (\n cast_to is not object\n and not origin is list\n and not origin is dict\n and not origin is Union\n and not issubclass(origin, BaseModel)\n ):\n raise RuntimeError(\n f\"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}.\"\n )\n\n # split is required to handle cases where additional information is included\n # in the response, e.g. 
application/json; charset=utf-8\n content_type, *_ = response.headers.get(\"content-type\", \"*\").split(\";\")\n if content_type != \"application/json\":\n if is_basemodel(cast_to):\n try:\n data = response.json()\n except Exception as exc:\n log.debug(\"Could not read JSON from response data due to %s - %s\", type(exc), exc)\n else:\n return self._client._process_response_data(\n data=data,\n cast_to=cast_to, # type: ignore\n response=response,\n )\n\n if self._client._strict_response_validation:\n raise APIResponseValidationError(\n response=response,\n message=f\"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.\",\n body=response.text,\n )\n\n # If the API responds with content that isn't JSON then we just return\n # the (decoded) text without performing any parsing so that you can still\n # handle the response however you need to.\n return response.text # type: ignore\n\n data = response.json()\n\n return self._client._process_response_data(\n data=data,\n cast_to=cast_to, # type: ignore\n response=response,\n )\n\n\nclass APIResponse(BaseAPIResponse[R]):\n @overload\n def parse(self, *, to: type[_T]) -> _T:\n ...\n\n @overload\n def parse(self) -> R:\n ...\n\n def parse(self, *, to: type[_T] | None = None) -> R | _T:\n \"\"\"Returns the rich python representation of this response's data.\n\n For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.\n\n You can customise the type that the response is parsed into through\n the `to` argument, e.g.\n\n ```py\n from openai import BaseModel\n\n\n class MyModel(BaseModel):\n foo: str\n\n\n obj = response.parse(to=MyModel)\n print(obj.foo)\n ```\n\n We support parsing:\n - `BaseModel`\n - `dict`\n - `list`\n - `Union`\n - `str`\n - `httpx.Response`\n \"\"\"\n cache_key = to if to is not None else self._cast_to\n cached = self._parsed_by_type.get(cache_key)\n if cached is not None:\n return cached # type: ignore[no-any-return]\n\n if not self._is_sse_stream:\n self.read()\n\n parsed = self._parse(to=to)\n if is_given(self._options.post_parser):\n parsed = self._options.post_parser(parsed)\n\n self._parsed_by_type[cache_key] = parsed\n return parsed\n\n def read(self) -> bytes:\n \"\"\"Read and return the binary response content.\"\"\"\n try:\n return self.http_response.read()\n except httpx.StreamConsumed as exc:\n # The default error raised by httpx isn't very\n # helpful in our case so we re-raise it with\n # a different error message.\n raise StreamAlreadyConsumed() from exc\n\n def text(self) -> str:\n \"\"\"Read and decode the response content into a string.\"\"\"\n self.read()\n return self.http_response.text\n\n def json(self) -> object:\n \"\"\"Read and decode the JSON response content.\"\"\"\n self.read()\n return self.http_response.json()\n\n def close(self) -> None:\n \"\"\"Close the response and release the connection.\n\n Automatically called if the response body is read to completion.\n \"\"\"\n self.http_response.close()\n\n def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:\n \"\"\"\n A byte-iterator over the decoded response content.\n\n This automatically handles gzip, deflate and brotli encoded responses.\n \"\"\"\n for chunk in self.http_response.iter_bytes(chunk_size):\n yield chunk\n\n def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:\n \"\"\"A str-iterator over the decoded response content\n that handles both gzip, deflate, etc but also detects the content's\n string encoding.\n \"\"\"\n for chunk in 
self.http_response.iter_text(chunk_size):\n yield chunk\n\n def iter_lines(self) -> Iterator[str]:\n \"\"\"Like `iter_text()` but will only yield chunks for each line\"\"\"\n for chunk in self.http_response.iter_lines():\n yield chunk\n\n\nclass AsyncAPIResponse(BaseAPIResponse[R]):\n @overload\n async def parse(self, *, to: type[_T]) -> _T:\n ...\n\n @overload\n async def parse(self) -> R:\n ...\n\n async def parse(self, *, to: type[_T] | None = None) -> R | _T:\n \"\"\"Returns the rich python representation of this response's data.\n\n For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.\n\n You can customise the type that the response is parsed into through\n the `to` argument, e.g.\n\n ```py\n from openai import BaseModel\n\n\n class MyModel(BaseModel):\n foo: str\n\n\n obj = response.parse(to=MyModel)\n print(obj.foo)\n ```\n\n We support parsing:\n - `BaseModel`\n - `dict`\n - `list`\n - `Union`\n - `str`\n - `httpx.Response`\n \"\"\"\n cache_key = to if to is not None else self._cast_to\n cached = self._parsed_by_type.get(cache_key)\n if cached is not None:\n return cached # type: ignore[no-any-return]\n\n if not self._is_sse_stream:\n await self.read()\n\n parsed = self._parse(to=to)\n if is_given(self._options.post_parser):\n parsed = self._options.post_parser(parsed)\n\n self._parsed_by_type[cache_key] = parsed\n return parsed\n\n async def read(self) -> bytes:\n \"\"\"Read and return the binary response content.\"\"\"\n try:\n return await self.http_response.aread()\n except httpx.StreamConsumed as exc:\n # the default error raised by httpx isn't very\n # helpful in our case so we re-raise it with\n # a different error message\n raise StreamAlreadyConsumed() from exc\n\n async def text(self) -> str:\n \"\"\"Read and decode the response content into a string.\"\"\"\n await self.read()\n return self.http_response.text\n\n async def json(self) -> object:\n \"\"\"Read and decode the JSON response content.\"\"\"\n await self.read()\n return self.http_response.json()\n\n async def close(self) -> None:\n \"\"\"Close the response and release the connection.\n\n Automatically called if the response body is read to completion.\n \"\"\"\n await self.http_response.aclose()\n\n async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:\n \"\"\"\n A byte-iterator over the decoded response content.\n\n This automatically handles gzip, deflate and brotli encoded responses.\n \"\"\"\n async for chunk in self.http_response.aiter_bytes(chunk_size):\n yield chunk\n\n async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:\n \"\"\"A str-iterator over the decoded response content\n that handles both gzip, deflate, etc but also detects the content's\n string encoding.\n \"\"\"\n async for chunk in self.http_response.aiter_text(chunk_size):\n yield chunk\n\n async def iter_lines(self) -> AsyncIterator[str]:\n \"\"\"Like `iter_text()` but will only yield chunks for each line\"\"\"\n async for chunk in self.http_response.aiter_lines():\n yield chunk\n\n\nclass BinaryAPIResponse(APIResponse[bytes]):\n \"\"\"Subclass of APIResponse providing helpers for dealing with binary data.\n\n Note: If you want to stream the response data instead of eagerly reading it\n all at once then you should use `.with_streaming_response` when making\n the API request, e.g. 
`.with_streaming_response.get_binary_response()`\n \"\"\"\n\n def write_to_file(\n self,\n file: str | os.PathLike[str],\n ) -> None:\n \"\"\"Write the output to the given file.\n\n Accepts a filename or any path-like object, e.g. pathlib.Path\n\n Note: if you want to stream the data to the file instead of writing\n all at once then you should use `.with_streaming_response` when making\n the API request, e.g. `.with_streaming_response.get_binary_response()`\n \"\"\"\n with open(file, mode=\"wb\") as f:\n for data in self.iter_bytes():\n f.write(data)\n\n\nclass AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):\n \"\"\"Subclass of APIResponse providing helpers for dealing with binary data.\n\n Note: If you want to stream the response data instead of eagerly reading it\n all at once then you should use `.with_streaming_response` when making\n the API request, e.g. `.with_streaming_response.get_binary_response()`\n \"\"\"\n\n async def write_to_file(\n self,\n file: str | os.PathLike[str],\n ) -> None:\n \"\"\"Write the output to the given file.\n\n Accepts a filename or any path-like object, e.g. pathlib.Path\n\n Note: if you want to stream the data to the file instead of writing\n all at once then you should use `.with_streaming_response` when making\n the API request, e.g. `.with_streaming_response.get_binary_response()`\n \"\"\"\n path = anyio.Path(file)\n async with await path.open(mode=\"wb\") as f:\n async for data in self.iter_bytes():\n await f.write(data)\n\n\nclass StreamedBinaryAPIResponse(APIResponse[bytes]):\n def stream_to_file(\n self,\n file: str | os.PathLike[str],\n *,\n chunk_size: int | None = None,\n ) -> None:\n \"\"\"Streams the output to the given file.\n\n Accepts a filename or any path-like object, e.g. pathlib.Path\n \"\"\"\n with open(file, mode=\"wb\") as f:\n for data in self.iter_bytes(chunk_size):\n f.write(data)\n\n\nclass AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):\n async def stream_to_file(\n self,\n file: str | os.PathLike[str],\n *,\n chunk_size: int | None = None,\n ) -> None:\n \"\"\"Streams the output to the given file.\n\n Accepts a filename or any path-like object, e.g. pathlib.Path\n \"\"\"\n path = anyio.Path(file)\n async with await path.open(mode=\"wb\") as f:\n async for data in self.iter_bytes(chunk_size):\n await f.write(data)\n\n\nclass MissingStreamClassError(TypeError):\n def __init__(self) -> None:\n super().__init__(\n \"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference\",\n )\n\n\nclass StreamAlreadyConsumed(OpenAIError):\n \"\"\"\n Attempted to read or stream content, but the content has already\n been streamed.\n\n This can happen if you use a method like `.iter_lines()` and then attempt\n to read th entire response body afterwards, e.g.\n\n ```py\n response = await client.post(...)\n async for line in response.iter_lines():\n ... # do something with `line`\n\n content = await response.read()\n # ^ error\n ```\n\n If you want this behaviour you'll need to either manually accumulate the response\n content or call `await response.read()` before iterating over the stream.\n \"\"\"\n\n def __init__(self) -> None:\n message = (\n \"Attempted to read or stream some content, but the content has \"\n \"already been streamed. 
\"\n \"This could be due to attempting to stream the response \"\n \"content more than once.\"\n \"\\n\\n\"\n \"You can fix this by manually accumulating the response content while streaming \"\n \"or by calling `.read()` before starting to stream.\"\n )\n super().__init__(message)\n\n\nclass ResponseContextManager(Generic[_APIResponseT]):\n \"\"\"Context manager for ensuring that a request is not made\n until it is entered and that the response will always be closed\n when the context manager exits\n \"\"\"\n\n def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:\n self._request_func = request_func\n self.__response: _APIResponseT | None = None\n\n def __enter__(self) -> _APIResponseT:\n self.__response = self._request_func()\n return self.__response\n\n def __exit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n if self.__response is not None:\n self.__response.close()\n\n\nclass AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):\n \"\"\"Context manager for ensuring that a request is not made\n until it is entered and that the response will always be closed\n when the context manager exits\n \"\"\"\n\n def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:\n self._api_request = api_request\n self.__response: _AsyncAPIResponseT | None = None\n\n async def __aenter__(self) -> _AsyncAPIResponseT:\n self.__response = await self._api_request\n return self.__response\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n if self.__response is not None:\n await self.__response.close()\n\n\ndef to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support streaming and returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = functools.partial(func, *args, **kwargs)\n\n return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))\n\n return wrapped\n\n\ndef async_to_streamed_response_wrapper(\n func: Callable[P, Awaitable[R]],\n) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support streaming and returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = func(*args, **kwargs)\n\n return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))\n\n return wrapped\n\n\ndef to_custom_streamed_response_wrapper(\n func: Callable[P, object],\n response_cls: type[_APIResponseT],\n) -> Callable[P, ResponseContextManager[_APIResponseT]]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support streaming and returning the given 
response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = functools.partial(func, *args, **kwargs)\n\n return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))\n\n return wrapped\n\n\ndef async_to_custom_streamed_response_wrapper(\n func: Callable[P, Awaitable[object]],\n response_cls: type[_AsyncAPIResponseT],\n) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support streaming and returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"stream\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n make_request = func(*args, **kwargs)\n\n return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))\n\n return wrapped\n\n\ndef to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(APIResponse[R], func(*args, **kwargs))\n\n return wrapped\n\n\ndef async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(AsyncAPIResponse[R], await func(*args, **kwargs))\n\n return wrapped\n\n\ndef to_custom_raw_response_wrapper(\n func: Callable[P, object],\n response_cls: type[_APIResponseT],\n) -> Callable[P, _APIResponseT]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. 
`class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(_APIResponseT, func(*args, **kwargs))\n\n return wrapped\n\n\ndef async_to_custom_raw_response_wrapper(\n func: Callable[P, Awaitable[object]],\n response_cls: type[_AsyncAPIResponseT],\n) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:\n \"\"\"Higher order function that takes one of our bound API methods and an `APIResponse` class\n and wraps the method to support returning the given response class directly.\n\n Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"raw\"\n extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))\n\n return wrapped\n\n\ndef extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:\n \"\"\"Given a type like `APIResponse[T]`, returns the generic type variable `T`.\n\n This also handles the case where a concrete subclass is given, e.g.\n ```py\n class MyResponse(APIResponse[bytes]):\n ...\n\n extract_response_type(MyResponse) -> bytes\n ```\n \"\"\"\n return extract_type_var_from_base(\n typ,\n generic_bases=cast(\"tuple[type, ...]\", (BaseAPIResponse, APIResponse, AsyncAPIResponse)),\n index=0,\n )\n\n# Path: src/openai/types/completion_choice.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing import Dict, List, Optional\nfrom typing_extensions import Literal\n\nfrom .._models import BaseModel\n\n__all__ = [\"CompletionChoice\", \"Logprobs\"]\n\n\nclass Logprobs(BaseModel):\n text_offset: Optional[List[int]] = None\n\n token_logprobs: Optional[List[float]] = None\n\n tokens: Optional[List[str]] = None\n\n top_logprobs: Optional[List[Dict[str, float]]] = None\n\n\nclass CompletionChoice(BaseModel):\n finish_reason: Literal[\"stop\", \"length\", \"content_filter\"]\n \"\"\"The reason the model stopped generating tokens.\n\n This will be `stop` if the model hit a natural stop point or a provided stop\n sequence, `length` if the maximum number of tokens specified in the request was\n reached, or `content_filter` if content was omitted due to a flag from our\n content filters.\n \"\"\"\n\n...\n", "needle_token_start": 7176, "needle_token_end": 7249, "code_context_ntokens": 16001, "output": ["```python\ndef _primitive_value_to_str(self, value: PrimitiveData) -> str:\n # copied from httpx\n if value is True:\n return \"true\"\n elif value is False:\n return \"false\"\n elif value is None:\n return \"\"\n return str(value)\n```"]} +{"repo": "openai/openai-python", "name": "_transform_value", "language": "python", "path": "src/openai/_utils/_transform.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to modify data based on specific formatting rules defined in type annotations.\n2. 
**Input**: The function takes two parameters: the data to be transformed and the type information which may include formatting annotations.\n3. **Output**: It returns the transformed data, which might be formatted according to the specified rules in the type annotations, or the original data if no formatting is applicable.\n4. **Procedure**: \n - First, the function checks if the type information includes annotations.\n - If annotations are present, it iterates through them to find any formatting instructions.\n - If a formatting instruction is found, it applies the specified format to the data using a helper function.\n - If no relevant annotations are found, the original data is returned unchanged.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/_utils/_proxy.py\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom typing import Generic, TypeVar, Iterable, cast\nfrom typing_extensions import override\n\nT = TypeVar(\"T\")\n\n\nclass LazyProxy(Generic[T], ABC):\n \"\"\"Implements data methods to pretend that an instance is another instance.\n\n This includes forwarding attribute access and othe methods.\n \"\"\"\n\n # Note: we have to special case proxies that themselves return proxies\n # to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`\n\n def __getattr__(self, attr: str) -> object:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied # pyright: ignore\n return getattr(proxied, attr)\n\n @override\n def __repr__(self) -> str:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied.__class__.__name__\n return repr(self.__get_proxied__())\n\n @override\n def __str__(self) -> str:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied.__class__.__name__\n return str(proxied)\n\n @override\n def __dir__(self) -> Iterable[str]:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return []\n return proxied.__dir__()\n\n @property # type: ignore\n @override\n def __class__(self) -> type: # pyright: ignore\n proxied = self.__get_proxied__()\n if issubclass(type(proxied), LazyProxy):\n return type(proxied)\n...\n# Path: src/openai/_utils/_streams.py\nfrom typing import Any\nfrom typing_extensions import Iterator, AsyncIterator\n\n\ndef consume_sync_iterator(iterator: Iterator[Any]) -> None:\n for _ in iterator:\n ...\n\n\nasync def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:\n async for _ in iterator:\n ...\n\n# Path: src/openai/_utils/_sync.py\nfrom __future__ import annotations\n\nimport functools\nfrom typing import TypeVar, Callable, Awaitable\nfrom typing_extensions import ParamSpec\n\nimport anyio\nimport anyio.to_thread\n\nT_Retval = TypeVar(\"T_Retval\")\nT_ParamSpec = ParamSpec(\"T_ParamSpec\")\n\n\n# copied from `asyncer`, https://github.com/tiangolo/asyncer\ndef asyncify(\n function: Callable[T_ParamSpec, T_Retval],\n *,\n cancellable: bool = False,\n limiter: anyio.CapacityLimiter | None = None,\n) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:\n \"\"\"\n Take a blocking function and create an async one that receives the same\n positional and keyword arguments, and that when called, calls the original function\n in a worker thread using 
`anyio.to_thread.run_sync()`. Internally,\n `asyncer.asyncify()` uses the same `anyio.to_thread.run_sync()`, but it supports\n keyword arguments additional to positional arguments and it adds better support for\n autocompletion and inline errors for the arguments of the function called and the\n return value.\n\n If the `cancellable` option is enabled and the task waiting for its completion is\n cancelled, the thread will still run its course but its return value (or any raised\n exception) will be ignored.\n\n Use it like this:\n\n ```Python\n def do_work(arg1, arg2, kwarg1=\"\", kwarg2=\"\") -> str:\n # Do work\n return \"Some result\"\n\n\n result = await to_thread.asyncify(do_work)(\"spam\", \"ham\", kwarg1=\"a\", kwarg2=\"b\")\n print(result)\n ```\n\n ## Arguments\n\n `function`: a blocking regular callable (e.g. a function)\n `cancellable`: `True` to allow cancellation of the operation\n `limiter`: capacity limiter to use to limit the total amount of threads running\n (if omitted, the default limiter is used)\n\n ## Return\n\n An async function that takes the same positional and keyword arguments as the\n original one, that when called runs the same original function in a thread worker\n and returns the result.\n \"\"\"\n\n async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:\n partial_f = functools.partial(function, *args, **kwargs)\n return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter)\n\n return wrapper\n\n# Path: src/openai/_utils/_typing.py\nfrom __future__ import annotations\n\nfrom typing import Any, TypeVar, Iterable, cast\nfrom collections import abc as _c_abc\nfrom typing_extensions import Required, Annotated, get_args, get_origin\n\nfrom .._types import InheritsGeneric\nfrom .._compat import is_union as _is_union\n\n\ndef is_annotated_type(typ: type) -> bool:\n return get_origin(typ) == Annotated\n\n\ndef is_list_type(typ: type) -> bool:\n return (get_origin(typ) or typ) == list\n\n\ndef is_iterable_type(typ: type) -> bool:\n \"\"\"If the given type is `typing.Iterable[T]`\"\"\"\n origin = get_origin(typ) or typ\n return origin == Iterable or origin == _c_abc.Iterable\n\n\ndef is_union_type(typ: type) -> bool:\n return _is_union(get_origin(typ))\n\n\ndef is_required_type(typ: type) -> bool:\n return get_origin(typ) == Required\n\n\ndef is_typevar(typ: type) -> bool:\n # type ignore is required because type checkers\n # think this expression will always return False\n return type(typ) == TypeVar # type: ignore\n\n\n# Extracts T from Annotated[T, ...] 
or from Required[Annotated[T, ...]]\ndef strip_annotated_type(typ: type) -> type:\n if is_required_type(typ) or is_annotated_type(typ):\n return strip_annotated_type(cast(type, get_args(typ)[0]))\n\n return typ\n\n\ndef extract_type_arg(typ: type, index: int) -> type:\n args = get_args(typ)\n try:\n return cast(type, args[index])\n except IndexError as err:\n raise RuntimeError(f\"Expected type {typ} to have a type argument at index {index} but it did not\") from err\n\n\ndef extract_type_var_from_base(\n typ: type,\n *,\n generic_bases: tuple[type, ...],\n index: int,\n failure_message: str | None = None,\n) -> type:\n \"\"\"Given a type like `Foo[T]`, returns the generic type variable `T`.\n\n This also handles the case where a concrete subclass is given, e.g.\n ```py\n class MyResponse(Foo[bytes]):\n ...\n\n extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes\n ```\n\n And where a generic subclass is given:\n ```py\n _T = TypeVar('_T')\n class MyResponse(Foo[_T]):\n ...\n\n extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes\n ```\n \"\"\"\n cls = cast(object, get_origin(typ) or typ)\n if cls in generic_bases:\n # we're given the class directly\n return extract_type_arg(typ, index)\n\n # if a subclass is given\n # ---\n # this is needed as __orig_bases__ is not present in the typeshed stubs\n # because it is intended to be for internal use only, however there does\n # not seem to be a way to resolve generic TypeVars for inherited subclasses\n # without using it.\n if isinstance(cls, InheritsGeneric):\n target_base_class: Any | None = None\n for base in cls.__orig_bases__:\n if base.__origin__ in generic_bases:\n target_base_class = base\n break\n\n if target_base_class is None:\n raise RuntimeError(\n \"Could not find the generic base class;\\n\"\n \"This should never happen;\\n\"\n f\"Does {cls} inherit from one of {generic_bases} ?\"\n )\n\n extracted = extract_type_arg(target_base_class, index)\n if is_typevar(extracted):\n # If the extracted type argument is itself a type variable\n # then that means the subclass itself is generic, so we have\n # to resolve the type argument from the class itself, not\n # the base class.\n #\n # Note: if there is more than 1 type argument, the subclass could\n # change the ordering of the type arguments, this is not currently\n # supported.\n return extract_type_arg(typ, index)\n\n return extracted\n\n raise RuntimeError(failure_message or f\"Could not resolve inner type variable at index {index} for {typ}\")\n\n# Path: src/openai/_files.py\nfrom __future__ import annotations\n\nimport io\nimport os\nimport pathlib\nfrom typing import overload\nfrom typing_extensions import TypeGuard\n\nimport anyio\n\nfrom ._types import (\n FileTypes,\n FileContent,\n RequestFiles,\n HttpxFileTypes,\n HttpxFileContent,\n HttpxRequestFiles,\n)\nfrom ._utils import is_tuple_t, is_mapping_t, is_sequence_t\n\n\ndef is_file_content(obj: object) -> TypeGuard[FileContent]:\n return (\n isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)\n )\n\n\ndef assert_is_file_content(obj: object, *, key: str | None = None) -> None:\n if not is_file_content(obj):\n prefix = f\"Expected entry at `{key}`\" if key is not None else f\"Expected file input `{obj!r}`\"\n raise RuntimeError(\n f\"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. 
See https://github.com/openai/openai-python/tree/main#file-uploads\"\n ) from None\n\n\n@overload\ndef to_httpx_files(files: None) -> None:\n ...\n\n\n@overload\ndef to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:\n ...\n\n\ndef to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:\n if files is None:\n return None\n\n if is_mapping_t(files):\n files = {key: _transform_file(file) for key, file in files.items()}\n elif is_sequence_t(files):\n files = [(key, _transform_file(file)) for key, file in files]\n else:\n raise TypeError(f\"Unexpected file type input {type(files)}, expected mapping or sequence\")\n\n return files\n\n\ndef _transform_file(file: FileTypes) -> HttpxFileTypes:\n if is_file_content(file):\n if isinstance(file, os.PathLike):\n path = pathlib.Path(file)\n return (path.name, path.read_bytes())\n\n return file\n\n if is_tuple_t(file):\n return (file[0], _read_file_content(file[1]), *file[2:])\n\n raise TypeError(f\"Expected file types input to be a FileContent type or to be a tuple\")\n\n\ndef _read_file_content(file: FileContent) -> HttpxFileContent:\n if isinstance(file, os.PathLike):\n return pathlib.Path(file).read_bytes()\n return file\n\n\n@overload\nasync def async_to_httpx_files(files: None) -> None:\n ...\n\n\n@overload\nasync def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:\n ...\n\n\nasync def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:\n if files is None:\n return None\n\n if is_mapping_t(files):\n files = {key: await _async_transform_file(file) for key, file in files.items()}\n elif is_sequence_t(files):\n files = [(key, await _async_transform_file(file)) for key, file in files]\n else:\n raise TypeError(\"Unexpected file type input {type(files)}, expected mapping or sequence\")\n\n return files\n\n\nasync def _async_transform_file(file: FileTypes) -> HttpxFileTypes:\n if is_file_content(file):\n if isinstance(file, os.PathLike):\n path = anyio.Path(file)\n return (path.name, await path.read_bytes())\n\n return file\n\n if is_tuple_t(file):\n return (file[0], await _async_read_file_content(file[1]), *file[2:])\n\n raise TypeError(f\"Expected file types input to be a FileContent type or to be a tuple\")\n\n\nasync def _async_read_file_content(file: FileContent) -> HttpxFileContent:\n if isinstance(file, os.PathLike):\n return await anyio.Path(file).read_bytes()\n\n return file\n\n# Path: src/openai/_utils/_utils.py\nfrom __future__ import annotations\n\nimport os\nimport re\nimport inspect\nimport functools\nfrom typing import (\n Any,\n Tuple,\n Mapping,\n TypeVar,\n Callable,\n Iterable,\n Sequence,\n cast,\n overload,\n)\nfrom pathlib import Path\nfrom typing_extensions import TypeGuard\n\nimport sniffio\n\nfrom .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike\nfrom .._compat import parse_date as parse_date, parse_datetime as parse_datetime\n\n_T = TypeVar(\"_T\")\n_TupleT = TypeVar(\"_TupleT\", bound=Tuple[object, ...])\n_MappingT = TypeVar(\"_MappingT\", bound=Mapping[str, object])\n_SequenceT = TypeVar(\"_SequenceT\", bound=Sequence[object])\nCallableT = TypeVar(\"CallableT\", bound=Callable[..., Any])\n\n\ndef flatten(t: Iterable[Iterable[_T]]) -> list[_T]:\n return [item for sublist in t for item in sublist]\n\n\ndef extract_files(\n # TODO: this needs to take Dict but variance issues.....\n # create protocol type ?\n query: Mapping[str, object],\n *,\n paths: Sequence[Sequence[str]],\n) -> list[tuple[str, FileTypes]]:\n \"\"\"Recursively extract files 
from the given dictionary based on specified paths.\n\n A path may look like this ['foo', 'files', '', 'data'].\n\n Note: this mutates the given dictionary.\n \"\"\"\n files: list[tuple[str, FileTypes]] = []\n for path in paths:\n files.extend(_extract_items(query, path, index=0, flattened_key=None))\n return files\n\n\ndef _extract_items(\n obj: object,\n path: Sequence[str],\n *,\n index: int,\n flattened_key: str | None,\n) -> list[tuple[str, FileTypes]]:\n try:\n key = path[index]\n except IndexError:\n if isinstance(obj, NotGiven):\n # no value was provided - we can safely ignore\n return []\n\n # cyclical import\n from .._files import assert_is_file_content\n\n # We have exhausted the path, return the entry we found.\n assert_is_file_content(obj, key=flattened_key)\n assert flattened_key is not None\n return [(flattened_key, cast(FileTypes, obj))]\n\n index += 1\n if is_dict(obj):\n try:\n # We are at the last entry in the path so we must remove the field\n if (len(path)) == index:\n item = obj.pop(key)\n else:\n item = obj[key]\n except KeyError:\n # Key was not present in the dictionary, this is not indicative of an error\n # as the given path may not point to a required field. We also do not want\n # to enforce required fields as the API may differ from the spec in some cases.\n return []\n if flattened_key is None:\n flattened_key = key\n else:\n flattened_key += f\"[{key}]\"\n return _extract_items(\n item,\n path,\n index=index,\n flattened_key=flattened_key,\n )\n elif is_list(obj):\n if key != \"\":\n return []\n\n return flatten(\n [\n _extract_items(\n item,\n path,\n index=index,\n flattened_key=flattened_key + \"[]\" if flattened_key is not None else \"[]\",\n )\n for item in obj\n ]\n )\n\n # Something unexpected was passed, just ignore it.\n return []\n\n\ndef is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:\n return not isinstance(obj, NotGiven)\n\n\n# Type safe methods for narrowing types with TypeVars.\n# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],\n# however this cause Pyright to rightfully report errors. As we know we don't\n# care about the contained types we can safely use `object` in it's place.\n#\n# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.\n# `is_*` is for when you're dealing with an unknown input\n# `is_*_t` is for when you're narrowing a known union type to a specific subset\n\n\ndef is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:\n return isinstance(obj, tuple)\n\n\ndef is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:\n return isinstance(obj, tuple)\n\n\ndef is_sequence(obj: object) -> TypeGuard[Sequence[object]]:\n return isinstance(obj, Sequence)\n\n\ndef is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:\n return isinstance(obj, Sequence)\n\n\ndef is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:\n return isinstance(obj, Mapping)\n\n\ndef is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:\n return isinstance(obj, Mapping)\n\n\ndef is_dict(obj: object) -> TypeGuard[dict[object, object]]:\n return isinstance(obj, dict)\n\n\ndef is_list(obj: object) -> TypeGuard[list[object]]:\n return isinstance(obj, list)\n\n\ndef is_iterable(obj: object) -> TypeGuard[Iterable[object]]:\n return isinstance(obj, Iterable)\n\n\ndef deepcopy_minimal(item: _T) -> _T:\n \"\"\"Minimal reimplementation of copy.deepcopy() that will only copy certain object types:\n\n - mappings, e.g. 
`dict`\n - list\n\n This is done for performance reasons.\n \"\"\"\n if is_mapping(item):\n return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})\n if is_list(item):\n return cast(_T, [deepcopy_minimal(entry) for entry in item])\n return item\n\n\n# copied from https://github.com/Rapptz/RoboDanny\ndef human_join(seq: Sequence[str], *, delim: str = \", \", final: str = \"or\") -> str:\n size = len(seq)\n if size == 0:\n return \"\"\n\n if size == 1:\n return seq[0]\n\n if size == 2:\n return f\"{seq[0]} {final} {seq[1]}\"\n\n return delim.join(seq[:-1]) + f\" {final} {seq[-1]}\"\n\n\ndef quote(string: str) -> str:\n \"\"\"Add single quotation marks around the given string. Does *not* do any escaping.\"\"\"\n return f\"'{string}'\"\n\n\ndef required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:\n \"\"\"Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.\n\n Useful for enforcing runtime validation of overloaded functions.\n\n Example usage:\n ```py\n @overload\n def foo(*, a: str) -> str:\n ...\n\n\n @overload\n def foo(*, b: bool) -> str:\n ...\n\n\n # This enforces the same constraints that a static type checker would\n # i.e. that either a or b must be passed to the function\n @required_args([\"a\"], [\"b\"])\n def foo(*, a: str | None = None, b: bool | None = None) -> str:\n ...\n ```\n \"\"\"\n\n def inner(func: CallableT) -> CallableT:\n params = inspect.signature(func).parameters\n positional = [\n name\n for name, param in params.items()\n if param.kind\n in {\n param.POSITIONAL_ONLY,\n param.POSITIONAL_OR_KEYWORD,\n }\n ]\n\n @functools.wraps(func)\n def wrapper(*args: object, **kwargs: object) -> object:\n given_params: set[str] = set()\n for i, _ in enumerate(args):\n try:\n given_params.add(positional[i])\n except IndexError:\n raise TypeError(\n f\"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given\"\n ) from None\n\n for key in kwargs.keys():\n given_params.add(key)\n\n for variant in variants:\n matches = all((param in given_params for param in variant))\n if matches:\n break\n else: # no break\n if len(variants) > 1:\n variations = human_join(\n [\"(\" + human_join([quote(arg) for arg in variant], final=\"and\") + \")\" for variant in variants]\n )\n msg = f\"Missing required arguments; Expected either {variations} arguments to be given\"\n else:\n # TODO: this error message is not deterministic\n missing = list(set(variants[0]) - given_params)\n if len(missing) > 1:\n msg = f\"Missing required arguments: {human_join([quote(arg) for arg in missing])}\"\n else:\n msg = f\"Missing required argument: {quote(missing[0])}\"\n raise TypeError(msg)\n return func(*args, **kwargs)\n\n return wrapper # type: ignore\n\n return inner\n\n\n_K = TypeVar(\"_K\")\n_V = TypeVar(\"_V\")\n\n\n@overload\ndef strip_not_given(obj: None) -> None:\n ...\n\n\n@overload\ndef strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]:\n ...\n\n\n@overload\ndef strip_not_given(obj: object) -> object:\n ...\n\n\ndef strip_not_given(obj: object | None) -> object:\n \"\"\"Remove all top-level keys where their values are instances of `NotGiven`\"\"\"\n if obj is None:\n return None\n\n if not is_mapping(obj):\n return obj\n\n return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}\n\n\ndef coerce_integer(val: str) -> int:\n return int(val, base=10)\n\n\ndef coerce_float(val: str) -> float:\n return float(val)\n\n\ndef coerce_boolean(val: str) -> bool:\n 
return val == \"true\" or val == \"1\" or val == \"on\"\n\n\ndef maybe_coerce_integer(val: str | None) -> int | None:\n if val is None:\n return None\n return coerce_integer(val)\n\n\ndef maybe_coerce_float(val: str | None) -> float | None:\n if val is None:\n return None\n return coerce_float(val)\n\n\ndef maybe_coerce_boolean(val: str | None) -> bool | None:\n if val is None:\n return None\n return coerce_boolean(val)\n\n\ndef removeprefix(string: str, prefix: str) -> str:\n \"\"\"Remove a prefix from a string.\n\n Backport of `str.removeprefix` for Python < 3.9\n \"\"\"\n if string.startswith(prefix):\n return string[len(prefix) :]\n return string\n\n\ndef removesuffix(string: str, suffix: str) -> str:\n \"\"\"Remove a suffix from a string.\n\n Backport of `str.removesuffix` for Python < 3.9\n \"\"\"\n if string.endswith(suffix):\n return string[: -len(suffix)]\n return string\n\n\ndef file_from_path(path: str) -> FileTypes:\n contents = Path(path).read_bytes()\n file_name = os.path.basename(path)\n return (file_name, contents)\n\n\ndef get_required_header(headers: HeadersLike, header: str) -> str:\n lower_header = header.lower()\n if isinstance(headers, Mapping):\n headers = cast(Headers, headers)\n for k, v in headers.items():\n if k.lower() == lower_header and isinstance(v, str):\n return v\n\n \"\"\" to deal with the case where the header looks like Stainless-Event-Id \"\"\"\n intercaps_header = re.sub(r\"([^\\w])(\\w)\", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())\n\n for normalized_header in [header, lower_header, header.upper(), intercaps_header]:\n value = headers.get(normalized_header)\n if value:\n return value\n\n raise ValueError(f\"Could not find {header} header\")\n\n\ndef get_async_library() -> str:\n try:\n return sniffio.current_async_library()\n except Exception:\n return \"false\"\n\n# Path: src/openai/_utils/_transform.py\nfrom __future__ import annotations\n\nfrom typing import Any, Mapping, TypeVar, cast\nfrom datetime import date, datetime\nfrom typing_extensions import Literal, get_args, override, get_type_hints\n\nimport pydantic\n\nfrom ._utils import (\n is_list,\n is_mapping,\n is_iterable,\n)\nfrom ._typing import (\n is_list_type,\n is_union_type,\n extract_type_arg,\n is_iterable_type,\n is_required_type,\n is_annotated_type,\n strip_annotated_type,\n)\nfrom .._compat import model_dump, is_typeddict\n\n_T = TypeVar(\"_T\")\n\n\n# TODO: support for drilling globals() and locals()\n# TODO: ensure works correctly with forward references in all cases\n\n\nPropertyFormat = Literal[\"iso8601\", \"custom\"]\n\n\nclass PropertyInfo:\n \"\"\"Metadata class to be used in Annotated types to provide information about a given type.\n\n For example:\n\n class MyParams(TypedDict):\n account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]\n\n This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.\n \"\"\"\n\n alias: str | None\n format: PropertyFormat | None\n format_template: str | None\n\n def __init__(\n self,\n *,\n alias: str | None = None,\n format: PropertyFormat | None = None,\n format_template: str | None = None,\n ) -> None:\n self.alias = alias\n self.format = format\n self.format_template = format_template\n\n @override\n def __repr__(self) -> str:\n return f\"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')\"\n\n\ndef maybe_transform(\n data: object,\n expected_type: 
object,\n) -> Any | None:\n \"\"\"Wrapper over `transform()` that allows `None` to be passed.\n\n See `transform()` for more details.\n \"\"\"\n if data is None:\n return None\n return transform(data, expected_type)\n\n\n# Wrapper over _transform_recursive providing fake types\ndef transform(\n data: _T,\n expected_type: object,\n) -> _T:\n \"\"\"Transform dictionaries based off of type information from the given type, for example:\n\n ```py\n class Params(TypedDict, total=False):\n card_id: Required[Annotated[str, PropertyInfo(alias=\"cardID\")]]\n\n\n transformed = transform({\"card_id\": \"\"}, Params)\n # {'cardID': ''}\n ```\n\n Any keys / data that does not have type information given will be included as is.\n\n It should be noted that the transformations that this function does are not represented in the type system.\n \"\"\"\n transformed = _transform_recursive(data, annotation=cast(type, expected_type))\n return cast(_T, transformed)\n\n\ndef _get_annotated_type(type_: type) -> type | None:\n \"\"\"If the given type is an `Annotated` type then it is returned, if not `None` is returned.\n\n This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`\n \"\"\"\n if is_required_type(type_):\n # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`\n type_ = get_args(type_)[0]\n\n if is_annotated_type(type_):\n return type_\n\n return None\n\n\ndef _maybe_transform_key(key: str, type_: type) -> str:\n \"\"\"Transform the given `data` based on the annotations provided in `type_`.\n\n Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.\n \"\"\"\n annotated_type = _get_annotated_type(type_)\n if annotated_type is None:\n # no `Annotated` definition for this type, no transformation needed\n return key\n\n # ignore the first argument as it is the actual type\n annotations = get_args(annotated_type)[1:]\n for annotation in annotations:\n if isinstance(annotation, PropertyInfo) and annotation.alias is not None:\n return annotation.alias\n\n return key\n\n\ndef _transform_recursive(\n data: object,\n *,\n annotation: type,\n inner_type: type | None = None,\n) -> object:\n \"\"\"Transform the given data against the expected type.\n\n Args:\n annotation: The direct type annotation given to the particular piece of data.\n This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc\n\n inner_type: If applicable, this is the \"inside\" type. This is useful in certain cases where the outside type\n is a container type such as `List[T]`. 
In that case `inner_type` should be set to `T` so that each entry in\n the list can be transformed using the metadata from the container type.\n\n Defaults to the same value as the `annotation` argument.\n \"\"\"\n if inner_type is None:\n inner_type = annotation\n\n stripped_type = strip_annotated_type(inner_type)\n if is_typeddict(stripped_type) and is_mapping(data):\n return _transform_typeddict(data, stripped_type)\n\n if (\n # List[T]\n (is_list_type(stripped_type) and is_list(data))\n # Iterable[T]\n or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))\n ):\n inner_type = extract_type_arg(stripped_type, 0)\n return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]\n\n if is_union_type(stripped_type):\n # For union types we run the transformation against all subtypes to ensure that everything is transformed.\n #\n # TODO: there may be edge cases where the same normalized field name will transform to two different names\n # in different subtypes.\n for subtype in get_args(stripped_type):\n data = _transform_recursive(data, annotation=annotation, inner_type=subtype)\n return data\n\n if isinstance(data, pydantic.BaseModel):\n return model_dump(data, exclude_unset=True)\n\n return _transform_value(data, annotation)\n\n\n\ndef _transform_value(data: object, type_: type) -> object:\n annotated_type = _get_annotated_type(type_)\n if annotated_type is None:\n return data\n\n # ignore the first argument as it is the actual type\n annotations = get_args(annotated_type)[1:]\n for annotation in annotations:\n if isinstance(annotation, PropertyInfo) and annotation.format is not None:\n return _format_data(data, annotation.format, annotation.format_template)\n\n return data\n\n\ndef _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:\n if isinstance(data, (date, datetime)):\n if format_ == \"iso8601\":\n return data.isoformat()\n\n if format_ == \"custom\" and format_template is not None:\n return data.strftime(format_template)\n\n return data\n\n\ndef _transform_typeddict(\n data: Mapping[str, object],\n expected_type: type,\n) -> Mapping[str, object]:\n result: dict[str, object] = {}\n annotations = get_type_hints(expected_type, include_extras=True)\n for key, value in data.items():\n type_ = annotations.get(key)\n if type_ is None:\n # we do not have a type annotation for this field, leave it as is\n result[key] = value\n else:\n result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)\n return result\n\n# Path: src/openai/_utils/__init__.py\nfrom ._sync import asyncify as asyncify\nfrom ._proxy import LazyProxy as LazyProxy\nfrom ._utils import (\n flatten as flatten,\n is_dict as is_dict,\n is_list as is_list,\n is_given as is_given,\n is_tuple as is_tuple,\n is_mapping as is_mapping,\n is_tuple_t as is_tuple_t,\n parse_date as parse_date,\n is_iterable as is_iterable,\n is_sequence as is_sequence,\n coerce_float as coerce_float,\n is_mapping_t as is_mapping_t,\n removeprefix as removeprefix,\n removesuffix as removesuffix,\n extract_files as extract_files,\n is_sequence_t as is_sequence_t,\n required_args as required_args,\n coerce_boolean as coerce_boolean,\n coerce_integer as coerce_integer,\n file_from_path as file_from_path,\n parse_datetime as parse_datetime,\n strip_not_given as strip_not_given,\n deepcopy_minimal as deepcopy_minimal,\n get_async_library as get_async_library,\n maybe_coerce_float as maybe_coerce_float,\n get_required_header as 
get_required_header,\n maybe_coerce_boolean as maybe_coerce_boolean,\n maybe_coerce_integer as maybe_coerce_integer,\n)\nfrom ._typing import (\n is_list_type as is_list_type,\n is_union_type as is_union_type,\n extract_type_arg as extract_type_arg,\n is_iterable_type as is_iterable_type,\n is_required_type as is_required_type,\n is_annotated_type as is_annotated_type,\n strip_annotated_type as strip_annotated_type,\n extract_type_var_from_base as extract_type_var_from_base,\n)\nfrom ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator\nfrom ._transform import (\n PropertyInfo as PropertyInfo,\n transform as transform,\n maybe_transform as maybe_transform,\n)\n\n# Path: src/openai/_exceptions.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Any, Optional, cast\nfrom typing_extensions import Literal\n\nimport httpx\n\nfrom ._utils import is_dict\n\n__all__ = [\n \"BadRequestError\",\n \"AuthenticationError\",\n \"PermissionDeniedError\",\n \"NotFoundError\",\n \"ConflictError\",\n \"UnprocessableEntityError\",\n \"RateLimitError\",\n \"InternalServerError\",\n]\n\n\nclass OpenAIError(Exception):\n pass\n\n\nclass APIError(OpenAIError):\n message: str\n request: httpx.Request\n\n body: object | None\n \"\"\"The API response body.\n\n If the API responded with a valid JSON structure then this property will be the\n decoded result.\n\n If it isn't a valid JSON structure then this will be the raw response.\n\n If there was no response associated with this error then it will be `None`.\n \"\"\"\n\n code: Optional[str] = None\n param: Optional[str] = None\n type: Optional[str]\n\n def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:\n super().__init__(message)\n self.request = request\n self.message = message\n self.body = body\n\n if is_dict(body):\n self.code = cast(Any, body.get(\"code\"))\n self.param = cast(Any, body.get(\"param\"))\n self.type = cast(Any, body.get(\"type\"))\n else:\n self.code = None\n self.param = None\n self.type = None\n\n\nclass APIResponseValidationError(APIError):\n response: httpx.Response\n status_code: int\n\n def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:\n super().__init__(message or \"Data returned by API invalid for expected schema.\", response.request, body=body)\n self.response = response\n self.status_code = response.status_code\n\n\nclass APIStatusError(APIError):\n \"\"\"Raised when an API response has a status code of 4xx or 5xx.\"\"\"\n\n response: httpx.Response\n status_code: int\n\n def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:\n super().__init__(message, response.request, body=body)\n self.response = response\n self.status_code = response.status_code\n\n\nclass APIConnectionError(APIError):\n def __init__(self, *, message: str = \"Connection error.\", request: httpx.Request) -> None:\n super().__init__(message, request, body=None)\n\n\nclass APITimeoutError(APIConnectionError):\n def __init__(self, request: httpx.Request) -> None:\n super().__init__(message=\"Request timed out.\", request=request)\n\n\nclass BadRequestError(APIStatusError):\n status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass AuthenticationError(APIStatusError):\n status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass 
PermissionDeniedError(APIStatusError):\n status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass NotFoundError(APIStatusError):\n status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass ConflictError(APIStatusError):\n status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass UnprocessableEntityError(APIStatusError):\n status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass RateLimitError(APIStatusError):\n status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass InternalServerError(APIStatusError):\n pass\n\n# Path: src/openai/_models.py\nfrom __future__ import annotations\n\nimport inspect\nfrom typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast\nfrom datetime import date, datetime\nfrom typing_extensions import (\n Unpack,\n Literal,\n ClassVar,\n Protocol,\n Required,\n TypedDict,\n final,\n override,\n runtime_checkable,\n)\n\nimport pydantic\nimport pydantic.generics\nfrom pydantic.fields import FieldInfo\n\nfrom ._types import (\n Body,\n IncEx,\n Query,\n ModelT,\n Headers,\n Timeout,\n NotGiven,\n AnyMapping,\n HttpxRequestFiles,\n)\nfrom ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given\nfrom ._compat import (\n PYDANTIC_V2,\n ConfigDict,\n GenericModel as BaseGenericModel,\n get_args,\n is_union,\n parse_obj,\n get_origin,\n is_literal_type,\n get_model_config,\n get_model_fields,\n field_get_default,\n)\nfrom ._constants import RAW_RESPONSE_HEADER\n\n__all__ = [\"BaseModel\", \"GenericModel\"]\n\n_T = TypeVar(\"_T\")\n\n\n@runtime_checkable\nclass _ConfigProtocol(Protocol):\n allow_population_by_field_name: bool\n\n\nclass BaseModel(pydantic.BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(extra=\"allow\")\n else:\n\n @property\n @override\n def model_fields_set(self) -> set[str]:\n # a forwards-compat shim for pydantic v2\n return self.__fields_set__ # type: ignore\n\n class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]\n extra: Any = pydantic.Extra.allow # type: ignore\n\n @override\n def __str__(self) -> str:\n # mypy complains about an invalid self arg\n return f'{self.__repr_name__()}({self.__repr_str__(\", \")})' # type: ignore[misc]\n\n # Override the 'construct' method in a way that supports recursive parsing without validation.\n # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.\n @classmethod\n @override\n def construct(\n cls: Type[ModelT],\n _fields_set: set[str] | None = None,\n **values: object,\n ) -> ModelT:\n m = cls.__new__(cls)\n fields_values: dict[str, object] = {}\n\n config = get_model_config(cls)\n populate_by_name = (\n config.allow_population_by_field_name\n if isinstance(config, _ConfigProtocol)\n else config.get(\"populate_by_name\")\n )\n\n if _fields_set is None:\n _fields_set = set()\n\n model_fields = get_model_fields(cls)\n for name, field in model_fields.items():\n key = field.alias\n if key is None or (key not in values and populate_by_name):\n key = name\n\n if key in values:\n fields_values[name] = _construct_field(value=values[key], field=field, key=key)\n _fields_set.add(name)\n else:\n fields_values[name] = field_get_default(field)\n\n _extra = {}\n for key, value in values.items():\n if key not in model_fields:\n if PYDANTIC_V2:\n _extra[key] = value\n else:\n _fields_set.add(key)\n 
fields_values[key] = value\n\n object.__setattr__(m, \"__dict__\", fields_values)\n\n if PYDANTIC_V2:\n # these properties are copied from Pydantic's `model_construct()` method\n object.__setattr__(m, \"__pydantic_private__\", None)\n object.__setattr__(m, \"__pydantic_extra__\", _extra)\n object.__setattr__(m, \"__pydantic_fields_set__\", _fields_set)\n else:\n # init_private_attributes() does not exist in v2\n m._init_private_attributes() # type: ignore\n\n # copied from Pydantic v1's `construct()` method\n object.__setattr__(m, \"__fields_set__\", _fields_set)\n\n return m\n\n if not TYPE_CHECKING:\n # type checkers incorrectly complain about this assignment\n # because the type signatures are technically different\n # although not in practice\n model_construct = construct\n\n if not PYDANTIC_V2:\n # we define aliases for some of the new pydantic v2 methods so\n # that we can just document these methods without having to specify\n # a specific pydantic version as some users may not know which\n # pydantic version they are currently using\n\n @override\n def model_dump(\n self,\n *,\n mode: Literal[\"json\", \"python\"] | str = \"python\",\n include: IncEx = None,\n exclude: IncEx = None,\n by_alias: bool = False,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n exclude_none: bool = False,\n round_trip: bool = False,\n warnings: bool = True,\n ) -> dict[str, Any]:\n \"\"\"Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump\n\n Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.\n\n Args:\n mode: The mode in which `to_python` should run.\n If mode is 'json', the dictionary will only contain JSON serializable types.\n If mode is 'python', the dictionary may contain any Python objects.\n include: A list of fields to include in the output.\n exclude: A list of fields to exclude from the output.\n by_alias: Whether to use the field's alias in the dictionary key if defined.\n exclude_unset: Whether to exclude fields that are unset or None from the output.\n exclude_defaults: Whether to exclude fields that are set to their default value from the output.\n exclude_none: Whether to exclude fields that have a value of `None` from the output.\n round_trip: Whether to enable serialization and deserialization round-trip support.\n warnings: Whether to log warnings when invalid fields are encountered.\n\n Returns:\n A dictionary representation of the model.\n \"\"\"\n if mode != \"python\":\n raise ValueError(\"mode is only supported in Pydantic v2\")\n if round_trip != False:\n raise ValueError(\"round_trip is only supported in Pydantic v2\")\n if warnings != True:\n raise ValueError(\"warnings is only supported in Pydantic v2\")\n return super().dict( # pyright: ignore[reportDeprecated]\n include=include,\n exclude=exclude,\n by_alias=by_alias,\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n exclude_none=exclude_none,\n )\n\n @override\n def model_dump_json(\n self,\n *,\n indent: int | None = None,\n include: IncEx = None,\n exclude: IncEx = None,\n by_alias: bool = False,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n exclude_none: bool = False,\n round_trip: bool = False,\n warnings: bool = True,\n ) -> str:\n \"\"\"Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json\n\n Generates a JSON representation of the model using Pydantic's `to_json` method.\n\n Args:\n indent: Indentation to use in the JSON output. 
If None is passed, the output will be compact.\n include: Field(s) to include in the JSON output. Can take either a string or set of strings.\n exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.\n by_alias: Whether to serialize using field aliases.\n exclude_unset: Whether to exclude fields that have not been explicitly set.\n exclude_defaults: Whether to exclude fields that have the default value.\n exclude_none: Whether to exclude fields that have a value of `None`.\n round_trip: Whether to use serialization/deserialization between JSON and class instance.\n warnings: Whether to show any warnings that occurred during serialization.\n\n Returns:\n A JSON string representation of the model.\n \"\"\"\n if round_trip != False:\n raise ValueError(\"round_trip is only supported in Pydantic v2\")\n if warnings != True:\n raise ValueError(\"warnings is only supported in Pydantic v2\")\n return super().json( # type: ignore[reportDeprecated]\n indent=indent,\n include=include,\n exclude=exclude,\n by_alias=by_alias,\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n exclude_none=exclude_none,\n )\n\n\ndef _construct_field(value: object, field: FieldInfo, key: str) -> object:\n if value is None:\n return field_get_default(field)\n\n if PYDANTIC_V2:\n type_ = field.annotation\n else:\n type_ = cast(type, field.outer_type_) # type: ignore\n\n if type_ is None:\n raise RuntimeError(f\"Unexpected field type is None for {key}\")\n\n return construct_type(value=value, type_=type_)\n\n\ndef is_basemodel(type_: type) -> bool:\n \"\"\"Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`\"\"\"\n origin = get_origin(type_) or type_\n if is_union(type_):\n for variant in get_args(type_):\n if is_basemodel(variant):\n return True\n\n return False\n\n return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)\n\n\ndef construct_type(*, value: object, type_: type) -> object:\n \"\"\"Loose coercion to the expected type with construction of nested values.\n\n If the given value does not match the expected type then it is returned as-is.\n \"\"\"\n\n # we need to use the origin class for any types that are subscripted generics\n # e.g. 
Dict[str, object]\n origin = get_origin(type_) or type_\n args = get_args(type_)\n\n if is_union(origin):\n try:\n return validate_type(type_=cast(\"type[object]\", type_), value=value)\n except Exception:\n pass\n\n # if the data is not valid, use the first variant that doesn't fail while deserializing\n for variant in args:\n try:\n return construct_type(value=value, type_=variant)\n except Exception:\n continue\n\n raise RuntimeError(f\"Could not convert data into a valid instance of {type_}\")\n\n if origin == dict:\n if not is_mapping(value):\n return value\n\n _, items_type = get_args(type_) # Dict[_, items_type]\n return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}\n\n if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):\n if is_list(value):\n return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]\n\n if is_mapping(value):\n if issubclass(type_, BaseModel):\n return type_.construct(**value) # type: ignore[arg-type]\n\n return cast(Any, type_).construct(**value)\n\n if origin == list:\n if not is_list(value):\n return value\n\n inner_type = args[0] # List[inner_type]\n return [construct_type(value=entry, type_=inner_type) for entry in value]\n\n if origin == float:\n if isinstance(value, int):\n coerced = float(value)\n if coerced != value:\n return value\n return coerced\n\n return value\n\n if type_ == datetime:\n try:\n return parse_datetime(value) # type: ignore\n except Exception:\n return value\n\n if type_ == date:\n try:\n return parse_date(value) # type: ignore\n except Exception:\n return value\n\n return value\n\n\ndef validate_type(*, type_: type[_T], value: object) -> _T:\n \"\"\"Strict validation that the given value matches the expected type\"\"\"\n if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):\n return cast(_T, parse_obj(type_, value))\n\n return cast(_T, _validate_non_model_type(type_=type_, value=value))\n\n\n# our use of subclasssing here causes weirdness for type checkers,\n# so we just pretend that we don't subclass\nif TYPE_CHECKING:\n GenericModel = BaseModel\nelse:\n\n class GenericModel(BaseGenericModel, BaseModel):\n pass\n\n\nif PYDANTIC_V2:\n from pydantic import TypeAdapter\n\n def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:\n return TypeAdapter(type_).validate_python(value)\n\nelif not TYPE_CHECKING: # TODO: condition is weird\n\n class RootModel(GenericModel, Generic[_T]):\n \"\"\"Used as a placeholder to easily convert runtime types to a Pydantic format\n to provide validation.\n\n For example:\n ```py\n validated = RootModel[int](__root__=\"5\").__root__\n # validated: 5\n ```\n \"\"\"\n\n __root__: _T\n\n def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:\n model = _create_pydantic_model(type_).validate(value)\n return cast(_T, model.__root__)\n\n def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:\n return RootModel[type_] # type: ignore\n\n\nclass FinalRequestOptionsInput(TypedDict, total=False):\n method: Required[str]\n url: Required[str]\n params: Query\n headers: Headers\n max_retries: int\n timeout: float | Timeout | None\n files: HttpxRequestFiles | None\n idempotency_key: str\n json_data: Body\n extra_json: AnyMapping\n\n\n@final\nclass FinalRequestOptions(pydantic.BaseModel):\n method: str\n url: str\n params: Query = {}\n headers: Union[Headers, NotGiven] = NotGiven()\n max_retries: Union[int, NotGiven] = NotGiven()\n timeout: 
Union[float, Timeout, None, NotGiven] = NotGiven()\n files: Union[HttpxRequestFiles, None] = None\n idempotency_key: Union[str, None] = None\n post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()\n\n # It should be noted that we cannot use `json` here as that would override\n # a BaseModel method in an incompatible fashion.\n json_data: Union[Body, None] = None\n extra_json: Union[AnyMapping, None] = None\n\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)\n else:\n\n class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]\n arbitrary_types_allowed: bool = True\n\n def get_max_retries(self, max_retries: int) -> int:\n if isinstance(self.max_retries, NotGiven):\n return max_retries\n return self.max_retries\n\n def _strip_raw_response_header(self) -> None:\n if not is_given(self.headers):\n return\n\n if self.headers.get(RAW_RESPONSE_HEADER):\n self.headers = {**self.headers}\n self.headers.pop(RAW_RESPONSE_HEADER)\n\n # override the `construct` method so that we can run custom transformations.\n # this is necessary as we don't want to do any actual runtime type checking\n # (which means we can't use validators) but we do want to ensure that `NotGiven`\n # values are not present\n #\n # type ignore required because we're adding explicit types to `**values`\n @classmethod\n def construct( # type: ignore\n cls,\n _fields_set: set[str] | None = None,\n **values: Unpack[FinalRequestOptionsInput],\n ) -> FinalRequestOptions:\n kwargs: dict[str, Any] = {\n # we unconditionally call `strip_not_given` on any value\n # as it will just ignore any non-mapping types\n key: strip_not_given(value)\n for key, value in values.items()\n }\n if PYDANTIC_V2:\n return super().model_construct(_fields_set, **kwargs)\n return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]\n\n if not TYPE_CHECKING:\n # type checkers incorrectly complain about this assignment\n model_construct = construct\n\n# Path: src/openai/_qs.py\nfrom __future__ import annotations\n\nfrom typing import Any, List, Tuple, Union, Mapping, TypeVar\nfrom urllib.parse import parse_qs, urlencode\nfrom typing_extensions import Literal, get_args\n\nfrom ._types import NOT_GIVEN, NotGiven, NotGivenOr\nfrom ._utils import flatten\n\n_T = TypeVar(\"_T\")\n\n\nArrayFormat = Literal[\"comma\", \"repeat\", \"indices\", \"brackets\"]\nNestedFormat = Literal[\"dots\", \"brackets\"]\n\nPrimitiveData = Union[str, int, float, bool, None]\n# this should be Data = Union[PrimitiveData, \"List[Data]\", \"Tuple[Data]\", \"Mapping[str, Data]\"]\n# https://github.com/microsoft/pyright/issues/3555\nData = Union[PrimitiveData, List[Any], Tuple[Any], \"Mapping[str, Any]\"]\nParams = Mapping[str, Data]\n\n\nclass Querystring:\n array_format: ArrayFormat\n nested_format: NestedFormat\n\n def __init__(\n self,\n *,\n array_format: ArrayFormat = \"repeat\",\n nested_format: NestedFormat = \"brackets\",\n ) -> None:\n self.array_format = array_format\n self.nested_format = nested_format\n\n def parse(self, query: str) -> Mapping[str, object]:\n # Note: custom format syntax is not supported yet\n return parse_qs(query)\n\n def stringify(\n self,\n params: Params,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> str:\n return urlencode(\n self.stringify_items(\n params,\n array_format=array_format,\n nested_format=nested_format,\n )\n )\n\n def stringify_items(\n self,\n params: 
Params,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n...\n", "needle_token_start": 8684, "needle_token_end": 8813, "code_context_ntokens": 15954, "output": ["```python\ndef maybe_transform(\n data: object,\n expected_type: object,\n) -> Any | None:\n \"\"\"Wrapper over `transform()` that allows `None` to be passed.\n\n See `transform()` for more details.\n \"\"\"\n if data is None:\n return None\n return transform(data, expected_type)\n```"]} +{"repo": "openai/openai-python", "name": "_construct_field", "language": "python", "path": "src/openai/_models.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to construct and return an object of a specified type based on the provided value and field metadata.\n2. **Input**: The function takes three parameters: a value (any object), field metadata (which includes type annotations and default values), and a key (string representing the field name).\n3. **Output**: Returns an object that is constructed according to the specified type in the field metadata. If the input value is `None`, it returns the default value defined in the field metadata.\n4. **Procedure**: \n - First, it checks if the input value is `None`. If true, it retrieves and returns the default value from the field metadata.\n - It then determines the type of the field. If a certain version condition is met, it uses the type annotation directly from the field metadata; otherwise, it uses a type conversion method to get the appropriate type.\n - If the determined type is `None`, it raises an error indicating an unexpected `None` type for the field.\n - Finally, it constructs and returns an object of the determined type using the provided value.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/_files.py\nfrom __future__ import annotations\n\nimport io\nimport os\nimport pathlib\nfrom typing import overload\nfrom typing_extensions import TypeGuard\n\nimport anyio\n\nfrom ._types import (\n FileTypes,\n FileContent,\n RequestFiles,\n HttpxFileTypes,\n HttpxFileContent,\n HttpxRequestFiles,\n)\nfrom ._utils import is_tuple_t, is_mapping_t, is_sequence_t\n\n\ndef is_file_content(obj: object) -> TypeGuard[FileContent]:\n return (\n isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)\n )\n\n\ndef assert_is_file_content(obj: object, *, key: str | None = None) -> None:\n if not is_file_content(obj):\n prefix = f\"Expected entry at `{key}`\" if key is not None else f\"Expected file input `{obj!r}`\"\n raise RuntimeError(\n f\"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. 
See https://github.com/openai/openai-python/tree/main#file-uploads\"\n ) from None\n\n\n@overload\ndef to_httpx_files(files: None) -> None:\n ...\n\n\n@overload\ndef to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:\n ...\n\n\ndef to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:\n if files is None:\n return None\n\n if is_mapping_t(files):\n files = {key: _transform_file(file) for key, file in files.items()}\n elif is_sequence_t(files):\n files = [(key, _transform_file(file)) for key, file in files]\n else:\n raise TypeError(f\"Unexpected file type input {type(files)}, expected mapping or sequence\")\n\n return files\n\n\ndef _transform_file(file: FileTypes) -> HttpxFileTypes:\n if is_file_content(file):\n if isinstance(file, os.PathLike):\n path = pathlib.Path(file)\n return (path.name, path.read_bytes())\n\n return file\n\n if is_tuple_t(file):\n return (file[0], _read_file_content(file[1]), *file[2:])\n\n raise TypeError(f\"Expected file types input to be a FileContent type or to be a tuple\")\n\n\n...\n# Path: src/openai/_utils/_utils.py\nfrom __future__ import annotations\n\nimport os\nimport re\nimport inspect\nimport functools\nfrom typing import (\n Any,\n Tuple,\n Mapping,\n TypeVar,\n Callable,\n Iterable,\n Sequence,\n cast,\n overload,\n)\nfrom pathlib import Path\nfrom typing_extensions import TypeGuard\n\nimport sniffio\n\nfrom .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike\nfrom .._compat import parse_date as parse_date, parse_datetime as parse_datetime\n\n_T = TypeVar(\"_T\")\n_TupleT = TypeVar(\"_TupleT\", bound=Tuple[object, ...])\n_MappingT = TypeVar(\"_MappingT\", bound=Mapping[str, object])\n_SequenceT = TypeVar(\"_SequenceT\", bound=Sequence[object])\nCallableT = TypeVar(\"CallableT\", bound=Callable[..., Any])\n\n\ndef flatten(t: Iterable[Iterable[_T]]) -> list[_T]:\n return [item for sublist in t for item in sublist]\n\n\ndef extract_files(\n # TODO: this needs to take Dict but variance issues.....\n # create protocol type ?\n query: Mapping[str, object],\n *,\n paths: Sequence[Sequence[str]],\n) -> list[tuple[str, FileTypes]]:\n \"\"\"Recursively extract files from the given dictionary based on specified paths.\n\n A path may look like this ['foo', 'files', '', 'data'].\n\n Note: this mutates the given dictionary.\n \"\"\"\n files: list[tuple[str, FileTypes]] = []\n for path in paths:\n files.extend(_extract_items(query, path, index=0, flattened_key=None))\n return files\n\n\ndef _extract_items(\n obj: object,\n path: Sequence[str],\n *,\n index: int,\n flattened_key: str | None,\n) -> list[tuple[str, FileTypes]]:\n try:\n key = path[index]\n except IndexError:\n if isinstance(obj, NotGiven):\n # no value was provided - we can safely ignore\n return []\n\n # cyclical import\n from .._files import assert_is_file_content\n\n # We have exhausted the path, return the entry we found.\n assert_is_file_content(obj, key=flattened_key)\n assert flattened_key is not None\n return [(flattened_key, cast(FileTypes, obj))]\n\n index += 1\n if is_dict(obj):\n try:\n # We are at the last entry in the path so we must remove the field\n if (len(path)) == index:\n item = obj.pop(key)\n else:\n item = obj[key]\n except KeyError:\n # Key was not present in the dictionary, this is not indicative of an error\n # as the given path may not point to a required field. 
We also do not want\n # to enforce required fields as the API may differ from the spec in some cases.\n return []\n if flattened_key is None:\n flattened_key = key\n else:\n flattened_key += f\"[{key}]\"\n return _extract_items(\n item,\n path,\n index=index,\n flattened_key=flattened_key,\n )\n elif is_list(obj):\n if key != \"\":\n return []\n\n return flatten(\n [\n _extract_items(\n item,\n path,\n index=index,\n flattened_key=flattened_key + \"[]\" if flattened_key is not None else \"[]\",\n )\n for item in obj\n ]\n )\n\n # Something unexpected was passed, just ignore it.\n return []\n\n\ndef is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:\n return not isinstance(obj, NotGiven)\n\n\n# Type safe methods for narrowing types with TypeVars.\n# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],\n# however this cause Pyright to rightfully report errors. As we know we don't\n# care about the contained types we can safely use `object` in it's place.\n#\n# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.\n# `is_*` is for when you're dealing with an unknown input\n# `is_*_t` is for when you're narrowing a known union type to a specific subset\n\n\ndef is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:\n return isinstance(obj, tuple)\n\n\ndef is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:\n return isinstance(obj, tuple)\n\n\ndef is_sequence(obj: object) -> TypeGuard[Sequence[object]]:\n return isinstance(obj, Sequence)\n\n\ndef is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:\n return isinstance(obj, Sequence)\n\n\ndef is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:\n return isinstance(obj, Mapping)\n\n\ndef is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:\n return isinstance(obj, Mapping)\n\n\ndef is_dict(obj: object) -> TypeGuard[dict[object, object]]:\n return isinstance(obj, dict)\n\n\ndef is_list(obj: object) -> TypeGuard[list[object]]:\n return isinstance(obj, list)\n\n\ndef is_iterable(obj: object) -> TypeGuard[Iterable[object]]:\n return isinstance(obj, Iterable)\n\n\ndef deepcopy_minimal(item: _T) -> _T:\n \"\"\"Minimal reimplementation of copy.deepcopy() that will only copy certain object types:\n\n - mappings, e.g. `dict`\n - list\n\n This is done for performance reasons.\n \"\"\"\n if is_mapping(item):\n return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})\n if is_list(item):\n return cast(_T, [deepcopy_minimal(entry) for entry in item])\n return item\n\n\n# copied from https://github.com/Rapptz/RoboDanny\ndef human_join(seq: Sequence[str], *, delim: str = \", \", final: str = \"or\") -> str:\n size = len(seq)\n if size == 0:\n return \"\"\n\n if size == 1:\n return seq[0]\n\n if size == 2:\n return f\"{seq[0]} {final} {seq[1]}\"\n\n return delim.join(seq[:-1]) + f\" {final} {seq[-1]}\"\n\n\ndef quote(string: str) -> str:\n \"\"\"Add single quotation marks around the given string. Does *not* do any escaping.\"\"\"\n return f\"'{string}'\"\n\n\ndef required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:\n \"\"\"Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.\n\n Useful for enforcing runtime validation of overloaded functions.\n\n Example usage:\n ```py\n @overload\n def foo(*, a: str) -> str:\n ...\n\n\n @overload\n def foo(*, b: bool) -> str:\n ...\n\n\n # This enforces the same constraints that a static type checker would\n # i.e. 
that either a or b must be passed to the function\n @required_args([\"a\"], [\"b\"])\n def foo(*, a: str | None = None, b: bool | None = None) -> str:\n ...\n ```\n \"\"\"\n\n def inner(func: CallableT) -> CallableT:\n params = inspect.signature(func).parameters\n positional = [\n name\n for name, param in params.items()\n if param.kind\n in {\n param.POSITIONAL_ONLY,\n param.POSITIONAL_OR_KEYWORD,\n }\n ]\n\n @functools.wraps(func)\n def wrapper(*args: object, **kwargs: object) -> object:\n given_params: set[str] = set()\n for i, _ in enumerate(args):\n try:\n given_params.add(positional[i])\n except IndexError:\n raise TypeError(\n f\"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given\"\n ) from None\n\n for key in kwargs.keys():\n given_params.add(key)\n\n for variant in variants:\n matches = all((param in given_params for param in variant))\n if matches:\n break\n else: # no break\n if len(variants) > 1:\n variations = human_join(\n [\"(\" + human_join([quote(arg) for arg in variant], final=\"and\") + \")\" for variant in variants]\n )\n msg = f\"Missing required arguments; Expected either {variations} arguments to be given\"\n else:\n # TODO: this error message is not deterministic\n missing = list(set(variants[0]) - given_params)\n if len(missing) > 1:\n msg = f\"Missing required arguments: {human_join([quote(arg) for arg in missing])}\"\n else:\n msg = f\"Missing required argument: {quote(missing[0])}\"\n raise TypeError(msg)\n return func(*args, **kwargs)\n\n return wrapper # type: ignore\n\n return inner\n\n\n_K = TypeVar(\"_K\")\n_V = TypeVar(\"_V\")\n\n\n@overload\ndef strip_not_given(obj: None) -> None:\n ...\n\n\n@overload\ndef strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]:\n ...\n\n\n@overload\ndef strip_not_given(obj: object) -> object:\n ...\n\n\ndef strip_not_given(obj: object | None) -> object:\n \"\"\"Remove all top-level keys where their values are instances of `NotGiven`\"\"\"\n if obj is None:\n return None\n\n if not is_mapping(obj):\n return obj\n\n return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}\n\n\ndef coerce_integer(val: str) -> int:\n return int(val, base=10)\n\n\ndef coerce_float(val: str) -> float:\n return float(val)\n\n\ndef coerce_boolean(val: str) -> bool:\n return val == \"true\" or val == \"1\" or val == \"on\"\n\n\ndef maybe_coerce_integer(val: str | None) -> int | None:\n if val is None:\n return None\n return coerce_integer(val)\n\n\ndef maybe_coerce_float(val: str | None) -> float | None:\n if val is None:\n return None\n return coerce_float(val)\n\n\ndef maybe_coerce_boolean(val: str | None) -> bool | None:\n if val is None:\n return None\n return coerce_boolean(val)\n\n\ndef removeprefix(string: str, prefix: str) -> str:\n \"\"\"Remove a prefix from a string.\n\n Backport of `str.removeprefix` for Python < 3.9\n \"\"\"\n if string.startswith(prefix):\n return string[len(prefix) :]\n return string\n\n\ndef removesuffix(string: str, suffix: str) -> str:\n \"\"\"Remove a suffix from a string.\n\n Backport of `str.removesuffix` for Python < 3.9\n \"\"\"\n if string.endswith(suffix):\n return string[: -len(suffix)]\n return string\n\n\ndef file_from_path(path: str) -> FileTypes:\n contents = Path(path).read_bytes()\n file_name = os.path.basename(path)\n return (file_name, contents)\n\n\ndef get_required_header(headers: HeadersLike, header: str) -> str:\n lower_header = header.lower()\n if isinstance(headers, Mapping):\n headers = cast(Headers, headers)\n for k, v in 
headers.items():\n if k.lower() == lower_header and isinstance(v, str):\n return v\n\n \"\"\" to deal with the case where the header looks like Stainless-Event-Id \"\"\"\n intercaps_header = re.sub(r\"([^\\w])(\\w)\", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())\n\n for normalized_header in [header, lower_header, header.upper(), intercaps_header]:\n value = headers.get(normalized_header)\n if value:\n return value\n\n raise ValueError(f\"Could not find {header} header\")\n\n\ndef get_async_library() -> str:\n try:\n return sniffio.current_async_library()\n except Exception:\n return \"false\"\n\n# Path: src/openai/_utils/_transform.py\nfrom __future__ import annotations\n\nfrom typing import Any, Mapping, TypeVar, cast\nfrom datetime import date, datetime\nfrom typing_extensions import Literal, get_args, override, get_type_hints\n\nimport pydantic\n\nfrom ._utils import (\n is_list,\n is_mapping,\n is_iterable,\n)\nfrom ._typing import (\n is_list_type,\n is_union_type,\n extract_type_arg,\n is_iterable_type,\n is_required_type,\n is_annotated_type,\n strip_annotated_type,\n)\nfrom .._compat import model_dump, is_typeddict\n\n_T = TypeVar(\"_T\")\n\n\n# TODO: support for drilling globals() and locals()\n# TODO: ensure works correctly with forward references in all cases\n\n\nPropertyFormat = Literal[\"iso8601\", \"custom\"]\n\n\nclass PropertyInfo:\n \"\"\"Metadata class to be used in Annotated types to provide information about a given type.\n\n For example:\n\n class MyParams(TypedDict):\n account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]\n\n This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.\n \"\"\"\n\n alias: str | None\n format: PropertyFormat | None\n format_template: str | None\n\n def __init__(\n self,\n *,\n alias: str | None = None,\n format: PropertyFormat | None = None,\n format_template: str | None = None,\n ) -> None:\n self.alias = alias\n self.format = format\n self.format_template = format_template\n\n @override\n def __repr__(self) -> str:\n return f\"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')\"\n\n\ndef maybe_transform(\n data: object,\n expected_type: object,\n) -> Any | None:\n \"\"\"Wrapper over `transform()` that allows `None` to be passed.\n\n See `transform()` for more details.\n \"\"\"\n if data is None:\n return None\n return transform(data, expected_type)\n\n\n# Wrapper over _transform_recursive providing fake types\ndef transform(\n data: _T,\n expected_type: object,\n) -> _T:\n \"\"\"Transform dictionaries based off of type information from the given type, for example:\n\n ```py\n class Params(TypedDict, total=False):\n card_id: Required[Annotated[str, PropertyInfo(alias=\"cardID\")]]\n\n\n transformed = transform({\"card_id\": \"\"}, Params)\n # {'cardID': ''}\n ```\n\n Any keys / data that does not have type information given will be included as is.\n\n It should be noted that the transformations that this function does are not represented in the type system.\n \"\"\"\n transformed = _transform_recursive(data, annotation=cast(type, expected_type))\n return cast(_T, transformed)\n\n\ndef _get_annotated_type(type_: type) -> type | None:\n \"\"\"If the given type is an `Annotated` type then it is returned, if not `None` is returned.\n\n This also unwraps the type when applicable, e.g. 
`Required[Annotated[T, ...]]`\n \"\"\"\n if is_required_type(type_):\n # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`\n type_ = get_args(type_)[0]\n\n if is_annotated_type(type_):\n return type_\n\n return None\n\n\ndef _maybe_transform_key(key: str, type_: type) -> str:\n \"\"\"Transform the given `data` based on the annotations provided in `type_`.\n\n Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.\n \"\"\"\n annotated_type = _get_annotated_type(type_)\n if annotated_type is None:\n # no `Annotated` definition for this type, no transformation needed\n return key\n\n # ignore the first argument as it is the actual type\n annotations = get_args(annotated_type)[1:]\n for annotation in annotations:\n if isinstance(annotation, PropertyInfo) and annotation.alias is not None:\n return annotation.alias\n\n return key\n\n\ndef _transform_recursive(\n data: object,\n *,\n annotation: type,\n inner_type: type | None = None,\n) -> object:\n \"\"\"Transform the given data against the expected type.\n\n Args:\n annotation: The direct type annotation given to the particular piece of data.\n This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc\n\n inner_type: If applicable, this is the \"inside\" type. This is useful in certain cases where the outside type\n is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in\n the list can be transformed using the metadata from the container type.\n\n Defaults to the same value as the `annotation` argument.\n \"\"\"\n if inner_type is None:\n inner_type = annotation\n\n stripped_type = strip_annotated_type(inner_type)\n if is_typeddict(stripped_type) and is_mapping(data):\n return _transform_typeddict(data, stripped_type)\n\n if (\n # List[T]\n (is_list_type(stripped_type) and is_list(data))\n # Iterable[T]\n or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))\n ):\n inner_type = extract_type_arg(stripped_type, 0)\n return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]\n\n if is_union_type(stripped_type):\n # For union types we run the transformation against all subtypes to ensure that everything is transformed.\n #\n # TODO: there may be edge cases where the same normalized field name will transform to two different names\n # in different subtypes.\n for subtype in get_args(stripped_type):\n data = _transform_recursive(data, annotation=annotation, inner_type=subtype)\n return data\n\n if isinstance(data, pydantic.BaseModel):\n return model_dump(data, exclude_unset=True)\n\n return _transform_value(data, annotation)\n\n\ndef _transform_value(data: object, type_: type) -> object:\n annotated_type = _get_annotated_type(type_)\n if annotated_type is None:\n return data\n\n # ignore the first argument as it is the actual type\n annotations = get_args(annotated_type)[1:]\n for annotation in annotations:\n if isinstance(annotation, PropertyInfo) and annotation.format is not None:\n return _format_data(data, annotation.format, annotation.format_template)\n\n return data\n\n\ndef _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:\n if isinstance(data, (date, datetime)):\n if format_ == \"iso8601\":\n return data.isoformat()\n\n if format_ == \"custom\" and format_template is not None:\n return data.strftime(format_template)\n\n return data\n\n\ndef _transform_typeddict(\n data: Mapping[str, object],\n 
expected_type: type,\n) -> Mapping[str, object]:\n result: dict[str, object] = {}\n annotations = get_type_hints(expected_type, include_extras=True)\n for key, value in data.items():\n type_ = annotations.get(key)\n if type_ is None:\n # we do not have a type annotation for this field, leave it as is\n result[key] = value\n else:\n result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)\n return result\n\n# Path: src/openai/_utils/__init__.py\nfrom ._sync import asyncify as asyncify\nfrom ._proxy import LazyProxy as LazyProxy\nfrom ._utils import (\n flatten as flatten,\n is_dict as is_dict,\n is_list as is_list,\n is_given as is_given,\n is_tuple as is_tuple,\n is_mapping as is_mapping,\n is_tuple_t as is_tuple_t,\n parse_date as parse_date,\n is_iterable as is_iterable,\n is_sequence as is_sequence,\n coerce_float as coerce_float,\n is_mapping_t as is_mapping_t,\n removeprefix as removeprefix,\n removesuffix as removesuffix,\n extract_files as extract_files,\n is_sequence_t as is_sequence_t,\n required_args as required_args,\n coerce_boolean as coerce_boolean,\n coerce_integer as coerce_integer,\n file_from_path as file_from_path,\n parse_datetime as parse_datetime,\n strip_not_given as strip_not_given,\n deepcopy_minimal as deepcopy_minimal,\n get_async_library as get_async_library,\n maybe_coerce_float as maybe_coerce_float,\n get_required_header as get_required_header,\n maybe_coerce_boolean as maybe_coerce_boolean,\n maybe_coerce_integer as maybe_coerce_integer,\n)\nfrom ._typing import (\n is_list_type as is_list_type,\n is_union_type as is_union_type,\n extract_type_arg as extract_type_arg,\n is_iterable_type as is_iterable_type,\n is_required_type as is_required_type,\n is_annotated_type as is_annotated_type,\n strip_annotated_type as strip_annotated_type,\n extract_type_var_from_base as extract_type_var_from_base,\n)\nfrom ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator\nfrom ._transform import (\n PropertyInfo as PropertyInfo,\n transform as transform,\n maybe_transform as maybe_transform,\n)\n\n# Path: src/openai/_exceptions.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nfrom typing import Any, Optional, cast\nfrom typing_extensions import Literal\n\nimport httpx\n\nfrom ._utils import is_dict\n\n__all__ = [\n \"BadRequestError\",\n \"AuthenticationError\",\n \"PermissionDeniedError\",\n \"NotFoundError\",\n \"ConflictError\",\n \"UnprocessableEntityError\",\n \"RateLimitError\",\n \"InternalServerError\",\n]\n\n\nclass OpenAIError(Exception):\n pass\n\n\nclass APIError(OpenAIError):\n message: str\n request: httpx.Request\n\n body: object | None\n \"\"\"The API response body.\n\n If the API responded with a valid JSON structure then this property will be the\n decoded result.\n\n If it isn't a valid JSON structure then this will be the raw response.\n\n If there was no response associated with this error then it will be `None`.\n \"\"\"\n\n code: Optional[str] = None\n param: Optional[str] = None\n type: Optional[str]\n\n def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:\n super().__init__(message)\n self.request = request\n self.message = message\n self.body = body\n\n if is_dict(body):\n self.code = cast(Any, body.get(\"code\"))\n self.param = cast(Any, body.get(\"param\"))\n self.type = cast(Any, body.get(\"type\"))\n else:\n self.code = None\n self.param = None\n self.type = 
None\n\n\nclass APIResponseValidationError(APIError):\n response: httpx.Response\n status_code: int\n\n def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:\n super().__init__(message or \"Data returned by API invalid for expected schema.\", response.request, body=body)\n self.response = response\n self.status_code = response.status_code\n\n\nclass APIStatusError(APIError):\n \"\"\"Raised when an API response has a status code of 4xx or 5xx.\"\"\"\n\n response: httpx.Response\n status_code: int\n\n def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:\n super().__init__(message, response.request, body=body)\n self.response = response\n self.status_code = response.status_code\n\n\nclass APIConnectionError(APIError):\n def __init__(self, *, message: str = \"Connection error.\", request: httpx.Request) -> None:\n super().__init__(message, request, body=None)\n\n\nclass APITimeoutError(APIConnectionError):\n def __init__(self, request: httpx.Request) -> None:\n super().__init__(message=\"Request timed out.\", request=request)\n\n\nclass BadRequestError(APIStatusError):\n status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass AuthenticationError(APIStatusError):\n status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass PermissionDeniedError(APIStatusError):\n status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass NotFoundError(APIStatusError):\n status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass ConflictError(APIStatusError):\n status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass UnprocessableEntityError(APIStatusError):\n status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass RateLimitError(APIStatusError):\n status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]\n\n\nclass InternalServerError(APIStatusError):\n pass\n\n# Path: src/openai/_models.py\nfrom __future__ import annotations\n\nimport inspect\nfrom typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast\nfrom datetime import date, datetime\nfrom typing_extensions import (\n Unpack,\n Literal,\n ClassVar,\n Protocol,\n Required,\n TypedDict,\n final,\n override,\n runtime_checkable,\n)\n\nimport pydantic\nimport pydantic.generics\nfrom pydantic.fields import FieldInfo\n\nfrom ._types import (\n Body,\n IncEx,\n Query,\n ModelT,\n Headers,\n Timeout,\n NotGiven,\n AnyMapping,\n HttpxRequestFiles,\n)\nfrom ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given\nfrom ._compat import (\n PYDANTIC_V2,\n ConfigDict,\n GenericModel as BaseGenericModel,\n get_args,\n is_union,\n parse_obj,\n get_origin,\n is_literal_type,\n get_model_config,\n get_model_fields,\n field_get_default,\n)\nfrom ._constants import RAW_RESPONSE_HEADER\n\n__all__ = [\"BaseModel\", \"GenericModel\"]\n\n_T = TypeVar(\"_T\")\n\n\n@runtime_checkable\nclass _ConfigProtocol(Protocol):\n allow_population_by_field_name: bool\n\n\nclass BaseModel(pydantic.BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(extra=\"allow\")\n else:\n\n @property\n @override\n def model_fields_set(self) -> set[str]:\n # a forwards-compat shim for pydantic v2\n return self.__fields_set__ # type: ignore\n\n class 
Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]\n extra: Any = pydantic.Extra.allow # type: ignore\n\n @override\n def __str__(self) -> str:\n # mypy complains about an invalid self arg\n return f'{self.__repr_name__()}({self.__repr_str__(\", \")})' # type: ignore[misc]\n\n # Override the 'construct' method in a way that supports recursive parsing without validation.\n # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.\n @classmethod\n @override\n def construct(\n cls: Type[ModelT],\n _fields_set: set[str] | None = None,\n **values: object,\n ) -> ModelT:\n m = cls.__new__(cls)\n fields_values: dict[str, object] = {}\n\n config = get_model_config(cls)\n populate_by_name = (\n config.allow_population_by_field_name\n if isinstance(config, _ConfigProtocol)\n else config.get(\"populate_by_name\")\n )\n\n if _fields_set is None:\n _fields_set = set()\n\n model_fields = get_model_fields(cls)\n for name, field in model_fields.items():\n key = field.alias\n if key is None or (key not in values and populate_by_name):\n key = name\n\n if key in values:\n fields_values[name] = _construct_field(value=values[key], field=field, key=key)\n _fields_set.add(name)\n else:\n fields_values[name] = field_get_default(field)\n\n _extra = {}\n for key, value in values.items():\n if key not in model_fields:\n if PYDANTIC_V2:\n _extra[key] = value\n else:\n _fields_set.add(key)\n fields_values[key] = value\n\n object.__setattr__(m, \"__dict__\", fields_values)\n\n if PYDANTIC_V2:\n # these properties are copied from Pydantic's `model_construct()` method\n object.__setattr__(m, \"__pydantic_private__\", None)\n object.__setattr__(m, \"__pydantic_extra__\", _extra)\n object.__setattr__(m, \"__pydantic_fields_set__\", _fields_set)\n else:\n # init_private_attributes() does not exist in v2\n m._init_private_attributes() # type: ignore\n\n # copied from Pydantic v1's `construct()` method\n object.__setattr__(m, \"__fields_set__\", _fields_set)\n\n return m\n\n if not TYPE_CHECKING:\n # type checkers incorrectly complain about this assignment\n # because the type signatures are technically different\n # although not in practice\n model_construct = construct\n\n if not PYDANTIC_V2:\n # we define aliases for some of the new pydantic v2 methods so\n # that we can just document these methods without having to specify\n # a specific pydantic version as some users may not know which\n # pydantic version they are currently using\n\n @override\n def model_dump(\n self,\n *,\n mode: Literal[\"json\", \"python\"] | str = \"python\",\n include: IncEx = None,\n exclude: IncEx = None,\n by_alias: bool = False,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n exclude_none: bool = False,\n round_trip: bool = False,\n warnings: bool = True,\n ) -> dict[str, Any]:\n \"\"\"Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump\n\n Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.\n\n Args:\n mode: The mode in which `to_python` should run.\n If mode is 'json', the dictionary will only contain JSON serializable types.\n If mode is 'python', the dictionary may contain any Python objects.\n include: A list of fields to include in the output.\n exclude: A list of fields to exclude from the output.\n by_alias: Whether to use the field's alias in the dictionary key if defined.\n exclude_unset: Whether to exclude fields that are unset or None from the output.\n exclude_defaults: Whether to 
exclude fields that are set to their default value from the output.\n exclude_none: Whether to exclude fields that have a value of `None` from the output.\n round_trip: Whether to enable serialization and deserialization round-trip support.\n warnings: Whether to log warnings when invalid fields are encountered.\n\n Returns:\n A dictionary representation of the model.\n \"\"\"\n if mode != \"python\":\n raise ValueError(\"mode is only supported in Pydantic v2\")\n if round_trip != False:\n raise ValueError(\"round_trip is only supported in Pydantic v2\")\n if warnings != True:\n raise ValueError(\"warnings is only supported in Pydantic v2\")\n return super().dict( # pyright: ignore[reportDeprecated]\n include=include,\n exclude=exclude,\n by_alias=by_alias,\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n exclude_none=exclude_none,\n )\n\n @override\n def model_dump_json(\n self,\n *,\n indent: int | None = None,\n include: IncEx = None,\n exclude: IncEx = None,\n by_alias: bool = False,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n exclude_none: bool = False,\n round_trip: bool = False,\n warnings: bool = True,\n ) -> str:\n \"\"\"Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json\n\n Generates a JSON representation of the model using Pydantic's `to_json` method.\n\n Args:\n indent: Indentation to use in the JSON output. If None is passed, the output will be compact.\n include: Field(s) to include in the JSON output. Can take either a string or set of strings.\n exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.\n by_alias: Whether to serialize using field aliases.\n exclude_unset: Whether to exclude fields that have not been explicitly set.\n exclude_defaults: Whether to exclude fields that have the default value.\n exclude_none: Whether to exclude fields that have a value of `None`.\n round_trip: Whether to use serialization/deserialization between JSON and class instance.\n warnings: Whether to show any warnings that occurred during serialization.\n\n Returns:\n A JSON string representation of the model.\n \"\"\"\n if round_trip != False:\n raise ValueError(\"round_trip is only supported in Pydantic v2\")\n if warnings != True:\n raise ValueError(\"warnings is only supported in Pydantic v2\")\n return super().json( # type: ignore[reportDeprecated]\n indent=indent,\n include=include,\n exclude=exclude,\n by_alias=by_alias,\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n exclude_none=exclude_none,\n )\n\n\n\ndef _construct_field(value: object, field: FieldInfo, key: str) -> object:\n if value is None:\n return field_get_default(field)\n\n if PYDANTIC_V2:\n type_ = field.annotation\n else:\n type_ = cast(type, field.outer_type_) # type: ignore\n\n if type_ is None:\n raise RuntimeError(f\"Unexpected field type is None for {key}\")\n\n return construct_type(value=value, type_=type_)\n\n\ndef is_basemodel(type_: type) -> bool:\n \"\"\"Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`\"\"\"\n origin = get_origin(type_) or type_\n if is_union(type_):\n for variant in get_args(type_):\n if is_basemodel(variant):\n return True\n\n return False\n\n return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)\n\n\ndef construct_type(*, value: object, type_: type) -> object:\n \"\"\"Loose coercion to the expected type with construction of nested values.\n\n If the given value does not match the expected 
type then it is returned as-is.\n \"\"\"\n\n # we need to use the origin class for any types that are subscripted generics\n # e.g. Dict[str, object]\n origin = get_origin(type_) or type_\n args = get_args(type_)\n\n if is_union(origin):\n try:\n return validate_type(type_=cast(\"type[object]\", type_), value=value)\n except Exception:\n pass\n\n # if the data is not valid, use the first variant that doesn't fail while deserializing\n for variant in args:\n try:\n return construct_type(value=value, type_=variant)\n except Exception:\n continue\n\n raise RuntimeError(f\"Could not convert data into a valid instance of {type_}\")\n\n if origin == dict:\n if not is_mapping(value):\n return value\n\n _, items_type = get_args(type_) # Dict[_, items_type]\n return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}\n\n if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):\n if is_list(value):\n return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]\n\n if is_mapping(value):\n if issubclass(type_, BaseModel):\n return type_.construct(**value) # type: ignore[arg-type]\n\n return cast(Any, type_).construct(**value)\n\n if origin == list:\n if not is_list(value):\n return value\n\n inner_type = args[0] # List[inner_type]\n return [construct_type(value=entry, type_=inner_type) for entry in value]\n\n if origin == float:\n if isinstance(value, int):\n coerced = float(value)\n if coerced != value:\n return value\n return coerced\n\n return value\n\n if type_ == datetime:\n try:\n return parse_datetime(value) # type: ignore\n except Exception:\n return value\n\n if type_ == date:\n try:\n return parse_date(value) # type: ignore\n except Exception:\n return value\n\n return value\n\n\ndef validate_type(*, type_: type[_T], value: object) -> _T:\n \"\"\"Strict validation that the given value matches the expected type\"\"\"\n if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):\n return cast(_T, parse_obj(type_, value))\n\n return cast(_T, _validate_non_model_type(type_=type_, value=value))\n\n\n# our use of subclasssing here causes weirdness for type checkers,\n# so we just pretend that we don't subclass\nif TYPE_CHECKING:\n GenericModel = BaseModel\nelse:\n\n class GenericModel(BaseGenericModel, BaseModel):\n pass\n\n\nif PYDANTIC_V2:\n from pydantic import TypeAdapter\n\n def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:\n return TypeAdapter(type_).validate_python(value)\n\nelif not TYPE_CHECKING: # TODO: condition is weird\n\n class RootModel(GenericModel, Generic[_T]):\n \"\"\"Used as a placeholder to easily convert runtime types to a Pydantic format\n to provide validation.\n\n For example:\n ```py\n validated = RootModel[int](__root__=\"5\").__root__\n # validated: 5\n ```\n \"\"\"\n\n __root__: _T\n\n def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:\n model = _create_pydantic_model(type_).validate(value)\n return cast(_T, model.__root__)\n\n def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:\n return RootModel[type_] # type: ignore\n\n\nclass FinalRequestOptionsInput(TypedDict, total=False):\n method: Required[str]\n url: Required[str]\n params: Query\n headers: Headers\n max_retries: int\n timeout: float | Timeout | None\n files: HttpxRequestFiles | None\n idempotency_key: str\n json_data: Body\n extra_json: AnyMapping\n\n\n@final\nclass FinalRequestOptions(pydantic.BaseModel):\n method: str\n url: str\n params: 
Query = {}\n headers: Union[Headers, NotGiven] = NotGiven()\n max_retries: Union[int, NotGiven] = NotGiven()\n timeout: Union[float, Timeout, None, NotGiven] = NotGiven()\n files: Union[HttpxRequestFiles, None] = None\n idempotency_key: Union[str, None] = None\n post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()\n\n # It should be noted that we cannot use `json` here as that would override\n # a BaseModel method in an incompatible fashion.\n json_data: Union[Body, None] = None\n extra_json: Union[AnyMapping, None] = None\n\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)\n else:\n\n class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]\n arbitrary_types_allowed: bool = True\n\n def get_max_retries(self, max_retries: int) -> int:\n if isinstance(self.max_retries, NotGiven):\n return max_retries\n return self.max_retries\n\n def _strip_raw_response_header(self) -> None:\n if not is_given(self.headers):\n return\n\n if self.headers.get(RAW_RESPONSE_HEADER):\n self.headers = {**self.headers}\n self.headers.pop(RAW_RESPONSE_HEADER)\n\n # override the `construct` method so that we can run custom transformations.\n # this is necessary as we don't want to do any actual runtime type checking\n # (which means we can't use validators) but we do want to ensure that `NotGiven`\n # values are not present\n #\n # type ignore required because we're adding explicit types to `**values`\n @classmethod\n def construct( # type: ignore\n cls,\n _fields_set: set[str] | None = None,\n **values: Unpack[FinalRequestOptionsInput],\n ) -> FinalRequestOptions:\n kwargs: dict[str, Any] = {\n # we unconditionally call `strip_not_given` on any value\n # as it will just ignore any non-mapping types\n key: strip_not_given(value)\n for key, value in values.items()\n }\n if PYDANTIC_V2:\n return super().model_construct(_fields_set, **kwargs)\n return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]\n\n if not TYPE_CHECKING:\n # type checkers incorrectly complain about this assignment\n model_construct = construct\n\n# Path: src/openai/_qs.py\nfrom __future__ import annotations\n\nfrom typing import Any, List, Tuple, Union, Mapping, TypeVar\nfrom urllib.parse import parse_qs, urlencode\nfrom typing_extensions import Literal, get_args\n\nfrom ._types import NOT_GIVEN, NotGiven, NotGivenOr\nfrom ._utils import flatten\n\n_T = TypeVar(\"_T\")\n\n\nArrayFormat = Literal[\"comma\", \"repeat\", \"indices\", \"brackets\"]\nNestedFormat = Literal[\"dots\", \"brackets\"]\n\nPrimitiveData = Union[str, int, float, bool, None]\n# this should be Data = Union[PrimitiveData, \"List[Data]\", \"Tuple[Data]\", \"Mapping[str, Data]\"]\n# https://github.com/microsoft/pyright/issues/3555\nData = Union[PrimitiveData, List[Any], Tuple[Any], \"Mapping[str, Any]\"]\nParams = Mapping[str, Data]\n\n\nclass Querystring:\n array_format: ArrayFormat\n nested_format: NestedFormat\n\n def __init__(\n self,\n *,\n array_format: ArrayFormat = \"repeat\",\n nested_format: NestedFormat = \"brackets\",\n ) -> None:\n self.array_format = array_format\n self.nested_format = nested_format\n\n def parse(self, query: str) -> Mapping[str, object]:\n # Note: custom format syntax is not supported yet\n return parse_qs(query)\n\n def stringify(\n self,\n params: Params,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> str:\n return urlencode(\n self.stringify_items(\n 
params,\n array_format=array_format,\n nested_format=nested_format,\n )\n )\n\n def stringify_items(\n self,\n params: Params,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> list[tuple[str, str]]:\n opts = Options(\n qs=self,\n array_format=array_format,\n nested_format=nested_format,\n )\n return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])\n\n def _stringify_item(\n self,\n key: str,\n value: Data,\n opts: Options,\n ) -> list[tuple[str, str]]:\n if isinstance(value, Mapping):\n items: list[tuple[str, str]] = []\n nested_format = opts.nested_format\n for subkey, subvalue in value.items():\n items.extend(\n self._stringify_item(\n # TODO: error if unknown format\n f\"{key}.{subkey}\" if nested_format == \"dots\" else f\"{key}[{subkey}]\",\n subvalue,\n opts,\n )\n )\n return items\n\n if isinstance(value, (list, tuple)):\n array_format = opts.array_format\n if array_format == \"comma\":\n return [\n (\n key,\n \",\".join(self._primitive_value_to_str(item) for item in value if item is not None),\n ),\n ]\n elif array_format == \"repeat\":\n items = []\n for item in value:\n items.extend(self._stringify_item(key, item, opts))\n return items\n elif array_format == \"indices\":\n raise NotImplementedError(\"The array indices format is not supported yet\")\n elif array_format == \"brackets\":\n items = []\n key = key + \"[]\"\n for item in value:\n items.extend(self._stringify_item(key, item, opts))\n return items\n else:\n raise NotImplementedError(\n f\"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}\"\n )\n\n serialised = self._primitive_value_to_str(value)\n if not serialised:\n return []\n return [(key, serialised)]\n\n def _primitive_value_to_str(self, value: PrimitiveData) -> str:\n # copied from httpx\n if value is True:\n return \"true\"\n elif value is False:\n return \"false\"\n elif value is None:\n return \"\"\n return str(value)\n\n\n_qs = Querystring()\nparse = _qs.parse\nstringify = _qs.stringify\nstringify_items = _qs.stringify_items\n\n\nclass Options:\n array_format: ArrayFormat\n nested_format: NestedFormat\n\n def __init__(\n self,\n qs: Querystring = _qs,\n *,\n array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,\n nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,\n ) -> None:\n self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format\n self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format\n\n# Path: src/openai/_version.py\n# File generated from our OpenAPI spec by Stainless.\n\n__title__ = \"openai\"\n__version__ = \"1.13.3\" # x-release-please-version\n\n# Path: src/openai/_resource.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nimport time\nfrom typing import TYPE_CHECKING\n\nimport anyio\n\nif TYPE_CHECKING:\n from ._client import OpenAI, AsyncOpenAI\n\n\nclass SyncAPIResource:\n _client: OpenAI\n\n def __init__(self, client: OpenAI) -> None:\n self._client = client\n self._get = client.get\n self._post = client.post\n self._patch = client.patch\n self._put = client.put\n self._delete = client.delete\n self._get_api_list = client.get_api_list\n\n def _sleep(self, seconds: float) -> None:\n time.sleep(seconds)\n\n\nclass AsyncAPIResource:\n _client: AsyncOpenAI\n\n def __init__(self, client: AsyncOpenAI) -> None:\n self._client = client\n self._get = client.get\n self._post = client.post\n 
self._patch = client.patch\n self._put = client.put\n self._delete = client.delete\n self._get_api_list = client.get_api_list\n\n async def _sleep(self, seconds: float) -> None:\n await anyio.sleep(seconds)\n\n# Path: src/openai/_response.py\nfrom __future__ import annotations\n\nimport os\nimport inspect\nimport logging\nimport datetime\nimport functools\nfrom types import TracebackType\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Union,\n Generic,\n TypeVar,\n Callable,\n Iterator,\n AsyncIterator,\n cast,\n overload,\n)\nfrom typing_extensions import Awaitable, ParamSpec, override, get_origin\n\nimport anyio\nimport httpx\nimport pydantic\n\nfrom ._types import NoneType\nfrom ._utils import is_given, extract_type_var_from_base\nfrom ._models import BaseModel, is_basemodel\nfrom ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER\nfrom ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type\nfrom ._exceptions import OpenAIError, APIResponseValidationError\n\nif TYPE_CHECKING:\n from ._models import FinalRequestOptions\n from ._base_client import BaseClient\n\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n_T = TypeVar(\"_T\")\n_APIResponseT = TypeVar(\"_APIResponseT\", bound=\"APIResponse[Any]\")\n_AsyncAPIResponseT = TypeVar(\"_AsyncAPIResponseT\", bound=\"AsyncAPIResponse[Any]\")\n\nlog: logging.Logger = logging.getLogger(__name__)\n\n\nclass BaseAPIResponse(Generic[R]):\n _cast_to: type[R]\n _client: BaseClient[Any, Any]\n _parsed_by_type: dict[type[Any], Any]\n _is_sse_stream: bool\n _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None\n _options: FinalRequestOptions\n\n http_response: httpx.Response\n\n def __init__(\n self,\n *,\n raw: httpx.Response,\n cast_to: type[R],\n client: BaseClient[Any, Any],\n stream: bool,\n stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,\n options: FinalRequestOptions,\n ) -> None:\n self._cast_to = cast_to\n self._client = client\n self._parsed_by_type = {}\n self._is_sse_stream = stream\n self._stream_cls = stream_cls\n self._options = options\n self.http_response = raw\n\n @property\n def headers(self) -> httpx.Headers:\n return self.http_response.headers\n\n @property\n def http_request(self) -> httpx.Request:\n \"\"\"Returns the httpx Request instance associated with the current response.\"\"\"\n return self.http_response.request\n\n @property\n def status_code(self) -> int:\n return self.http_response.status_code\n\n @property\n def url(self) -> httpx.URL:\n \"\"\"Returns the URL for which the request was made.\"\"\"\n return self.http_response.url\n\n @property\n def method(self) -> str:\n return self.http_request.method\n\n @property\n def http_version(self) -> str:\n return self.http_response.http_version\n\n @property\n def elapsed(self) -> datetime.timedelta:\n \"\"\"The time taken for the complete request/response cycle to complete.\"\"\"\n return self.http_response.elapsed\n\n @property\n def is_closed(self) -> bool:\n \"\"\"Whether or not the response body has been closed.\n\n If this is False then there is response data that has not been read yet.\n You must either fully consume the response body or call `.close()`\n before discarding the response to prevent resource leaks.\n \"\"\"\n return self.http_response.is_closed\n\n @override\n def __repr__(self) -> str:\n return (\n f\"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>\"\n )\n\n def _parse(self, *, to: type[_T] | None = None) -> R | _T:\n if 
self._is_sse_stream:\n if to:\n if not is_stream_class_type(to):\n raise TypeError(f\"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}\")\n\n return cast(\n _T,\n to(\n cast_to=extract_stream_chunk_type(\n to,\n failure_message=\"Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]\",\n ),\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n if self._stream_cls:\n return cast(\n R,\n self._stream_cls(\n cast_to=extract_stream_chunk_type(self._stream_cls),\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n stream_cls = cast(\"type[Stream[Any]] | type[AsyncStream[Any]] | None\", self._client._default_stream_cls)\n if stream_cls is None:\n raise MissingStreamClassError()\n\n return cast(\n R,\n stream_cls(\n cast_to=self._cast_to,\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n cast_to = to if to is not None else self._cast_to\n if cast_to is NoneType:\n return cast(R, None)\n\n response = self.http_response\n if cast_to == str:\n return cast(R, response.text)\n\n if cast_to == bytes:\n return cast(R, response.content)\n\n origin = get_origin(cast_to) or cast_to\n\n # handle the legacy binary response case\n if inspect.isclass(cast_to) and cast_to.__name__ == \"HttpxBinaryResponseContent\":\n return cast(R, cast_to(response)) # type: ignore\n\n if origin == APIResponse:\n raise RuntimeError(\"Unexpected state - cast_to is `APIResponse`\")\n\n if inspect.isclass(origin) and issubclass(origin, httpx.Response):\n # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response\n # and pass that class to our request functions. We cannot change the variance to be either\n...\n", "needle_token_start": 10316, "needle_token_end": 10446, "code_context_ntokens": 15992, "output": ["```python\ndef _construct_field(value: object, field: FieldInfo, key: str) -> object:\n if value is None:\n return field_get_default(field)\n\n if PYDANTIC_V2:\n type_ = field.annotation\n else:\n type_ = cast(type, field.outer_type_) # type: ignore\n\n if type_ is None:\n raise RuntimeError(f\"Unexpected field type is None for {key}\")\n\n return construct_type(value=value, type_=type_)\n```"]} +{"repo": "openai/openai-python", "name": "_basic_config", "language": "python", "path": "src/openai/_utils/_logs.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To configure the default logging settings for the application, ensuring that log messages are formatted and timestamped consistently.\n2. **Input**: None, as it does not require any parameters to be passed.\n3. **Output**: None, as it does not return any values but configures the logging system.\n4. **Procedure**: The function sets up the logging format to include the timestamp, logger name, line number, log level, and the log message itself. 
It also specifies the date and time format to be used in the log messages.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/_base_client.py\nfrom __future__ import annotations\n\nimport json\nimport time\nimport uuid\nimport email\nimport asyncio\nimport inspect\nimport logging\nimport platform\nimport warnings\nimport email.utils\nfrom types import TracebackType\nfrom random import random\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Type,\n Union,\n Generic,\n Mapping,\n TypeVar,\n Iterable,\n Iterator,\n Optional,\n Generator,\n AsyncIterator,\n cast,\n overload,\n)\nfrom functools import lru_cache\nfrom typing_extensions import Literal, override, get_origin\n\nimport anyio\nimport httpx\nimport distro\nimport pydantic\nfrom httpx import URL, Limits\nfrom pydantic import PrivateAttr\n\nfrom . import _exceptions\nfrom ._qs import Querystring\nfrom ._files import to_httpx_files, async_to_httpx_files\nfrom ._types import (\n NOT_GIVEN,\n Body,\n Omit,\n Query,\n Headers,\n Timeout,\n NotGiven,\n ResponseT,\n Transport,\n AnyMapping,\n PostParser,\n ProxiesTypes,\n RequestFiles,\n HttpxSendArgs,\n AsyncTransport,\n RequestOptions,\n ModelBuilderProtocol,\n)\nfrom ._utils import is_dict, is_list, is_given, is_mapping\nfrom ._compat import model_copy, model_dump\nfrom ._models import GenericModel, FinalRequestOptions, validate_type, construct_type\nfrom ._response import (\n APIResponse,\n BaseAPIResponse,\n AsyncAPIResponse,\n extract_response_type,\n)\nfrom ._constants import (\n DEFAULT_LIMITS,\n DEFAULT_TIMEOUT,\n MAX_RETRY_DELAY,\n DEFAULT_MAX_RETRIES,\n INITIAL_RETRY_DELAY,\n RAW_RESPONSE_HEADER,\n OVERRIDE_CAST_TO_HEADER,\n)\nfrom ._streaming import Stream, AsyncStream\nfrom ._exceptions import (\n APIStatusError,\n APITimeoutError,\n APIConnectionError,\n APIResponseValidationError,\n)\nfrom ._legacy_response import LegacyAPIResponse\n\nlog: logging.Logger = logging.getLogger(__name__)\n\n# TODO: make base page type vars covariant\nSyncPageT = TypeVar(\"SyncPageT\", bound=\"BaseSyncPage[Any]\")\nAsyncPageT = TypeVar(\"AsyncPageT\", bound=\"BaseAsyncPage[Any]\")\n\n\n_T = TypeVar(\"_T\")\n_T_co = TypeVar(\"_T_co\", covariant=True)\n\n_StreamT = TypeVar(\"_StreamT\", bound=Stream[Any])\n_AsyncStreamT = TypeVar(\"_AsyncStreamT\", bound=AsyncStream[Any])\n\nif TYPE_CHECKING:\n from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT\nelse:\n try:\n from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT\n except ImportError:\n # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366\n HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)\n\n\nclass PageInfo:\n \"\"\"Stores the necessary information to build the request to retrieve the next page.\n\n Either `url` or `params` must be set.\n \"\"\"\n\n url: URL | NotGiven\n params: Query | NotGiven\n\n @overload\n def __init__(\n self,\n *,\n url: URL,\n ) -> None:\n ...\n\n @overload\n def __init__(\n self,\n *,\n params: Query,\n ) -> None:\n ...\n\n def __init__(\n self,\n *,\n url: URL | NotGiven = NOT_GIVEN,\n params: Query | NotGiven = NOT_GIVEN,\n ) -> None:\n self.url = url\n self.params = params\n\n\nclass BasePage(GenericModel, Generic[_T]):\n \"\"\"\n Defines the core interface for pagination.\n\n Type Args:\n 
ModelT: The pydantic model that represents an item in the response.\n\n Methods:\n has_next_page(): Check if there is another page available\n next_page_info(): Get the necessary information to make a request for the next page\n \"\"\"\n\n _options: FinalRequestOptions = PrivateAttr()\n _model: Type[_T] = PrivateAttr()\n\n def has_next_page(self) -> bool:\n items = self._get_page_items()\n if not items:\n return False\n return self.next_page_info() is not None\n\n def next_page_info(self) -> Optional[PageInfo]:\n ...\n\n def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]\n ...\n\n def _params_from_url(self, url: URL) -> httpx.QueryParams:\n # TODO: do we have to preprocess params here?\n return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)\n\n def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:\n options = model_copy(self._options)\n options._strip_raw_response_header()\n\n if not isinstance(info.params, NotGiven):\n options.params = {**options.params, **info.params}\n return options\n\n if not isinstance(info.url, NotGiven):\n params = self._params_from_url(info.url)\n url = info.url.copy_with(params=params)\n options.params = dict(url.params)\n options.url = str(url)\n return options\n\n raise ValueError(\"Unexpected PageInfo state\")\n\n\nclass BaseSyncPage(BasePage[_T], Generic[_T]):\n _client: SyncAPIClient = pydantic.PrivateAttr()\n\n def _set_private_attributes(\n self,\n client: SyncAPIClient,\n model: Type[_T],\n options: FinalRequestOptions,\n ) -> None:\n self._model = model\n self._client = client\n self._options = options\n\n # Pydantic uses a custom `__iter__` method to support casting BaseModels\n # to dictionaries. e.g. dict(model).\n # As we want to support `for item in page`, this is inherently incompatible\n # with the default pydantic behaviour. It is not possible to support both\n # use cases at once. 
Fortunately, this is not a big deal as all other pydantic\n # methods should continue to work as expected as there is an alternative method\n # to cast a model to a dictionary, model.dict(), which is used internally\n # by pydantic.\n def __iter__(self) -> Iterator[_T]: # type: ignore\n for page in self.iter_pages():\n for item in page._get_page_items():\n yield item\n\n def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:\n page = self\n while True:\n yield page\n if page.has_next_page():\n page = page.get_next_page()\n else:\n return\n\n def get_next_page(self: SyncPageT) -> SyncPageT:\n info = self.next_page_info()\n if not info:\n raise RuntimeError(\n \"No next page expected; please check `.has_next_page()` before calling `.get_next_page()`.\"\n )\n\n options = self._info_to_options(info)\n return self._client._request_api_list(self._model, page=self.__class__, options=options)\n\n\nclass AsyncPaginator(Generic[_T, AsyncPageT]):\n def __init__(\n self,\n client: AsyncAPIClient,\n options: FinalRequestOptions,\n page_cls: Type[AsyncPageT],\n model: Type[_T],\n ) -> None:\n self._model = model\n self._client = client\n self._options = options\n self._page_cls = page_cls\n\n def __await__(self) -> Generator[Any, None, AsyncPageT]:\n return self._get_page().__await__()\n\n async def _get_page(self) -> AsyncPageT:\n def _parser(resp: AsyncPageT) -> AsyncPageT:\n resp._set_private_attributes(\n model=self._model,\n options=self._options,\n client=self._client,\n )\n return resp\n\n self._options.post_parser = _parser\n\n return await self._client.request(self._page_cls, self._options)\n\n async def __aiter__(self) -> AsyncIterator[_T]:\n # https://github.com/microsoft/pyright/issues/3464\n page = cast(\n AsyncPageT,\n await self, # type: ignore\n )\n async for item in page:\n yield item\n\n\nclass BaseAsyncPage(BasePage[_T], Generic[_T]):\n _client: AsyncAPIClient = pydantic.PrivateAttr()\n\n def _set_private_attributes(\n self,\n model: Type[_T],\n client: AsyncAPIClient,\n options: FinalRequestOptions,\n ) -> None:\n self._model = model\n self._client = client\n self._options = options\n\n async def __aiter__(self) -> AsyncIterator[_T]:\n async for page in self.iter_pages():\n for item in page._get_page_items():\n yield item\n\n async def iter_pages(self: AsyncPageT) -> AsyncIterator[AsyncPageT]:\n page = self\n while True:\n yield page\n if page.has_next_page():\n page = await page.get_next_page()\n else:\n return\n\n async def get_next_page(self: AsyncPageT) -> AsyncPageT:\n info = self.next_page_info()\n if not info:\n raise RuntimeError(\n \"No next page expected; please check `.has_next_page()` before calling `.get_next_page()`.\"\n )\n\n options = self._info_to_options(info)\n return await self._client._request_api_list(self._model, page=self.__class__, options=options)\n\n\n_HttpxClientT = TypeVar(\"_HttpxClientT\", bound=Union[httpx.Client, httpx.AsyncClient])\n_DefaultStreamT = TypeVar(\"_DefaultStreamT\", bound=Union[Stream[Any], AsyncStream[Any]])\n\n\nclass BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):\n _client: _HttpxClientT\n _version: str\n _base_url: URL\n max_retries: int\n timeout: Union[float, Timeout, None]\n _limits: httpx.Limits\n _proxies: ProxiesTypes | None\n _transport: Transport | AsyncTransport | None\n _strict_response_validation: bool\n _idempotency_header: str | None\n _default_stream_cls: type[_DefaultStreamT] | None = None\n\n def __init__(\n self,\n *,\n version: str,\n base_url: str | URL,\n _strict_response_validation: bool,\n max_retries: 
int = DEFAULT_MAX_RETRIES,\n timeout: float | Timeout | None = DEFAULT_TIMEOUT,\n limits: httpx.Limits,\n transport: Transport | AsyncTransport | None,\n proxies: ProxiesTypes | None,\n custom_headers: Mapping[str, str] | None = None,\n custom_query: Mapping[str, object] | None = None,\n ) -> None:\n self._version = version\n self._base_url = self._enforce_trailing_slash(URL(base_url))\n self.max_retries = max_retries\n self.timeout = timeout\n self._limits = limits\n self._proxies = proxies\n self._transport = transport\n self._custom_headers = custom_headers or {}\n self._custom_query = custom_query or {}\n self._strict_response_validation = _strict_response_validation\n self._idempotency_header = None\n\n def _enforce_trailing_slash(self, url: URL) -> URL:\n if url.raw_path.endswith(b\"/\"):\n return url\n return url.copy_with(raw_path=url.raw_path + b\"/\")\n\n def _make_status_error_from_response(\n self,\n response: httpx.Response,\n ) -> APIStatusError:\n if response.is_closed and not response.is_stream_consumed:\n # We can't read the response body as it has been closed\n # before it was read. This can happen if an event hook\n # raises a status error.\n body = None\n err_msg = f\"Error code: {response.status_code}\"\n else:\n err_text = response.text.strip()\n body = err_text\n\n try:\n body = json.loads(err_text)\n err_msg = f\"Error code: {response.status_code} - {body}\"\n except Exception:\n err_msg = err_text or f\"Error code: {response.status_code}\"\n\n return self._make_status_error(err_msg, body=body, response=response)\n\n def _make_status_error(\n self,\n err_msg: str,\n *,\n body: object,\n response: httpx.Response,\n ) -> _exceptions.APIStatusError:\n raise NotImplementedError()\n\n def _remaining_retries(\n self,\n remaining_retries: Optional[int],\n options: FinalRequestOptions,\n ) -> int:\n return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)\n\n def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:\n custom_headers = options.headers or {}\n headers_dict = _merge_mappings(self.default_headers, custom_headers)\n self._validate_headers(headers_dict, custom_headers)\n\n # headers are case-insensitive while dictionaries are not.\n headers = httpx.Headers(headers_dict)\n\n idempotency_header = self._idempotency_header\n if idempotency_header and options.method.lower() != \"get\" and idempotency_header not in headers:\n headers[idempotency_header] = options.idempotency_key or self._idempotency_key()\n\n return headers\n\n def _prepare_url(self, url: str) -> URL:\n \"\"\"\n Merge a URL argument together with any 'base_url' on the client,\n to create the URL used for the outgoing request.\n \"\"\"\n # Copied from httpx's `_merge_url` method.\n merge_url = URL(url)\n if merge_url.is_relative_url:\n merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b\"/\")\n return self.base_url.copy_with(raw_path=merge_raw_path)\n\n return merge_url\n\n def _build_request(\n self,\n options: FinalRequestOptions,\n ) -> httpx.Request:\n if log.isEnabledFor(logging.DEBUG):\n log.debug(\"Request options: %s\", model_dump(options, exclude_unset=True))\n\n kwargs: dict[str, Any] = {}\n\n json_data = options.json_data\n if options.extra_json is not None:\n if json_data is None:\n json_data = cast(Body, options.extra_json)\n elif is_mapping(json_data):\n json_data = _merge_mappings(json_data, options.extra_json)\n else:\n raise RuntimeError(f\"Unexpected JSON data type, {type(json_data)}, cannot merge with 
`extra_body`\")\n\n headers = self._build_headers(options)\n params = _merge_mappings(self._custom_query, options.params)\n content_type = headers.get(\"Content-Type\")\n\n # If the given Content-Type header is multipart/form-data then it\n # has to be removed so that httpx can generate the header with\n # additional information for us as it has to be in this form\n # for the server to be able to correctly parse the request:\n # multipart/form-data; boundary=---abc--\n if content_type is not None and content_type.startswith(\"multipart/form-data\"):\n if \"boundary\" not in content_type:\n # only remove the header if the boundary hasn't been explicitly set\n # as the caller doesn't want httpx to come up with their own boundary\n headers.pop(\"Content-Type\")\n\n # As we are now sending multipart/form-data instead of application/json\n # we need to tell httpx to use it, https://www.python-httpx.org/advanced/#multipart-file-encoding\n if json_data:\n if not is_dict(json_data):\n raise TypeError(\n f\"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead.\"\n )\n kwargs[\"data\"] = self._serialize_multipartform(json_data)\n\n # TODO: report this error to httpx\n return self._client.build_request( # pyright: ignore[reportUnknownMemberType]\n headers=headers,\n timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,\n method=options.method,\n url=self._prepare_url(options.url),\n # the `Query` type that we use is incompatible with qs'\n # `Params` type as it needs to be typed as `Mapping[str, object]`\n # so that passing a `TypedDict` doesn't cause an error.\n # https://github.com/microsoft/pyright/issues/3526#event-6715453066\n params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,\n json=json_data,\n files=options.files,\n **kwargs,\n )\n\n def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:\n items = self.qs.stringify_items(\n # TODO: type ignore is required as stringify_items is well typed but we can't be\n # well typed without heavy validation.\n data, # type: ignore\n array_format=\"brackets\",\n )\n serialized: dict[str, object] = {}\n for key, value in items:\n existing = serialized.get(key)\n\n if not existing:\n serialized[key] = value\n continue\n\n # If a value has already been set for this key then that\n # means we're sending data like `array[]=[1, 2, 3]` and we\n # need to tell httpx that we want to send multiple values with\n # the same key which is done by using a list or a tuple.\n #\n # Note: 2d arrays should never result in the same key at both\n # levels so it's safe to assume that if the value is a list,\n # it was because we changed it to be a list.\n if is_list(existing):\n existing.append(value)\n else:\n serialized[key] = [existing, value]\n\n return serialized\n\n def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]:\n if not is_given(options.headers):\n return cast_to\n\n # make a copy of the headers so we don't mutate user-input\n headers = dict(options.headers)\n\n # we internally support defining a temporary header to override the\n # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response`\n # see _response.py for implementation details\n override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN)\n if is_given(override_cast_to):\n options.headers = headers\n return cast(Type[ResponseT], override_cast_to)\n\n return cast_to\n\n def 
_should_stream_response_body(self, request: httpx.Request) -> bool:\n return request.headers.get(RAW_RESPONSE_HEADER) == \"stream\" # type: ignore[no-any-return]\n\n def _process_response_data(\n self,\n *,\n data: object,\n cast_to: type[ResponseT],\n response: httpx.Response,\n ) -> ResponseT:\n if data is None:\n return cast(ResponseT, None)\n\n if cast_to is object:\n return cast(ResponseT, data)\n\n try:\n if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):\n return cast(ResponseT, cast_to.build(response=response, data=data))\n\n if self._strict_response_validation:\n return cast(ResponseT, validate_type(type_=cast_to, value=data))\n\n return cast(ResponseT, construct_type(type_=cast_to, value=data))\n except pydantic.ValidationError as err:\n raise APIResponseValidationError(response=response, body=data) from err\n\n @property\n def qs(self) -> Querystring:\n return Querystring()\n\n @property\n def custom_auth(self) -> httpx.Auth | None:\n return None\n\n @property\n def auth_headers(self) -> dict[str, str]:\n return {}\n\n @property\n def default_headers(self) -> dict[str, str | Omit]:\n return {\n \"Accept\": \"application/json\",\n \"Content-Type\": \"application/json\",\n \"User-Agent\": self.user_agent,\n **self.platform_headers(),\n **self.auth_headers,\n **self._custom_headers,\n }\n\n def _validate_headers(\n self,\n headers: Headers, # noqa: ARG002\n custom_headers: Headers, # noqa: ARG002\n ) -> None:\n \"\"\"Validate the given default headers and custom headers.\n\n Does nothing by default.\n \"\"\"\n return\n\n @property\n def user_agent(self) -> str:\n return f\"{self.__class__.__name__}/Python {self._version}\"\n\n @property\n def base_url(self) -> URL:\n return self._base_url\n\n @base_url.setter\n def base_url(self, url: URL | str) -> None:\n self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))\n\n def platform_headers(self) -> Dict[str, str]:\n return platform_headers(self._version)\n\n def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:\n \"\"\"Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.\n\n About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After\n See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax\n \"\"\"\n if response_headers is None:\n return None\n\n # First, try the non-standard `retry-after-ms` header for milliseconds,\n # which is more precise than integer-seconds `retry-after`\n try:\n retry_ms_header = response_headers.get(\"retry-after-ms\", None)\n return float(retry_ms_header) / 1000\n except (TypeError, ValueError):\n pass\n\n # Next, try parsing `retry-after` header as seconds (allowing nonstandard floats).\n retry_header = response_headers.get(\"retry-after\")\n try:\n # note: the spec indicates that this should only ever be an integer\n # but if someone sends a float there's no reason for us to not respect it\n return float(retry_header)\n except (TypeError, ValueError):\n pass\n\n # Last, try parsing `retry-after` as a date.\n retry_date_tuple = email.utils.parsedate_tz(retry_header)\n if retry_date_tuple is None:\n return None\n\n retry_date = email.utils.mktime_tz(retry_date_tuple)\n return float(retry_date - time.time())\n\n def _calculate_retry_timeout(\n self,\n remaining_retries: int,\n options: FinalRequestOptions,\n response_headers: Optional[httpx.Headers] = None,\n ) -> float:\n 
max_retries = options.get_max_retries(self.max_retries)\n\n # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.\n retry_after = self._parse_retry_after_header(response_headers)\n if retry_after is not None and 0 < retry_after <= 60:\n return retry_after\n\n nb_retries = max_retries - remaining_retries\n\n # Apply exponential backoff, but not more than the max.\n sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)\n\n # Apply some jitter, plus-or-minus half a second.\n jitter = 1 - 0.25 * random()\n timeout = sleep_seconds * jitter\n return timeout if timeout >= 0 else 0\n\n def _should_retry(self, response: httpx.Response) -> bool:\n # Note: this is not a standard header\n should_retry_header = response.headers.get(\"x-should-retry\")\n\n # If the server explicitly says whether or not to retry, obey.\n if should_retry_header == \"true\":\n log.debug(\"Retrying as header `x-should-retry` is set to `true`\")\n return True\n if should_retry_header == \"false\":\n log.debug(\"Not retrying as header `x-should-retry` is set to `false`\")\n return False\n\n # Retry on request timeouts.\n if response.status_code == 408:\n log.debug(\"Retrying due to status code %i\", response.status_code)\n return True\n\n # Retry on lock timeouts.\n if response.status_code == 409:\n log.debug(\"Retrying due to status code %i\", response.status_code)\n return True\n\n # Retry on rate limits.\n if response.status_code == 429:\n log.debug(\"Retrying due to status code %i\", response.status_code)\n return True\n\n # Retry internal errors.\n if response.status_code >= 500:\n log.debug(\"Retrying due to status code %i\", response.status_code)\n return True\n\n log.debug(\"Not retrying\")\n return False\n\n def _idempotency_key(self) -> str:\n return f\"stainless-python-retry-{uuid.uuid4()}\"\n\n\nclass SyncHttpxClientWrapper(httpx.Client):\n def __del__(self) -> None:\n try:\n self.close()\n except Exception:\n pass\n\n\nclass SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):\n _client: httpx.Client\n _default_stream_cls: type[Stream[Any]] | None = None\n\n def __init__(\n self,\n *,\n version: str,\n base_url: str | URL,\n max_retries: int = DEFAULT_MAX_RETRIES,\n timeout: float | Timeout | None | NotGiven = NOT_GIVEN,\n transport: Transport | None = None,\n proxies: ProxiesTypes | None = None,\n limits: Limits | None = None,\n http_client: httpx.Client | None = None,\n custom_headers: Mapping[str, str] | None = None,\n custom_query: Mapping[str, object] | None = None,\n _strict_response_validation: bool,\n ) -> None:\n if limits is not None:\n warnings.warn(\n \"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead\",\n category=DeprecationWarning,\n stacklevel=3,\n )\n if http_client is not None:\n raise ValueError(\"The `http_client` argument is mutually exclusive with `connection_pool_limits`\")\n else:\n limits = DEFAULT_LIMITS\n\n if transport is not None:\n warnings.warn(\n \"The `transport` argument is deprecated. The `http_client` argument should be passed instead\",\n category=DeprecationWarning,\n stacklevel=3,\n )\n if http_client is not None:\n raise ValueError(\"The `http_client` argument is mutually exclusive with `transport`\")\n\n if proxies is not None:\n warnings.warn(\n \"The `proxies` argument is deprecated. 
The `http_client` argument should be passed instead\",\n category=DeprecationWarning,\n stacklevel=3,\n )\n if http_client is not None:\n raise ValueError(\"The `http_client` argument is mutually exclusive with `proxies`\")\n\n if not is_given(timeout):\n # if the user passed in a custom http client with a non-default\n # timeout set then we use that timeout.\n #\n # note: there is an edge case here where the user passes in a client\n # where they've explicitly set the timeout to match the default timeout\n # as this check is structural, meaning that we'll think they didn't\n # pass in a timeout and will ignore it\n if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT:\n timeout = http_client.timeout\n else:\n timeout = DEFAULT_TIMEOUT\n\n super().__init__(\n version=version,\n limits=limits,\n # cast to a valid type because mypy doesn't understand our type narrowing\n timeout=cast(Timeout, timeout),\n proxies=proxies,\n base_url=base_url,\n transport=transport,\n max_retries=max_retries,\n custom_query=custom_query,\n custom_headers=custom_headers,\n _strict_response_validation=_strict_response_validation,\n )\n self._client = http_client or SyncHttpxClientWrapper(\n base_url=base_url,\n # cast to a valid type because mypy doesn't understand our type narrowing\n timeout=cast(Timeout, timeout),\n proxies=proxies,\n transport=transport,\n limits=limits,\n follow_redirects=True,\n )\n\n def is_closed(self) -> bool:\n return self._client.is_closed\n\n def close(self) -> None:\n \"\"\"Close the underlying HTTPX client.\n\n The client will *not* be usable after this.\n \"\"\"\n # If an error is thrown while constructing a client, self._client\n # may not be present\n if hasattr(self, \"_client\"):\n self._client.close()\n\n def __enter__(self: _T) -> _T:\n return self\n\n def __exit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n self.close()\n\n def _prepare_options(\n self,\n options: FinalRequestOptions, # noqa: ARG002\n ) -> None:\n \"\"\"Hook for mutating the given options\"\"\"\n return None\n\n def _prepare_request(\n self,\n request: httpx.Request, # noqa: ARG002\n ) -> None:\n \"\"\"This method is used as a callback for mutating the `Request` object\n after it has been constructed.\n This is useful for cases where you want to add certain headers based off of\n the request properties, e.g. 
`url`, `method` etc.\n \"\"\"\n return None\n\n @overload\n def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n remaining_retries: Optional[int] = None,\n *,\n stream: Literal[True],\n stream_cls: Type[_StreamT],\n ) -> _StreamT:\n ...\n\n @overload\n def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n remaining_retries: Optional[int] = None,\n *,\n stream: Literal[False] = False,\n ) -> ResponseT:\n ...\n\n @overload\n def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n remaining_retries: Optional[int] = None,\n *,\n stream: bool = False,\n stream_cls: Type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n ...\n\n def request(\n self,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n remaining_retries: Optional[int] = None,\n *,\n stream: bool = False,\n stream_cls: type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n return self._request(\n cast_to=cast_to,\n options=options,\n stream=stream,\n stream_cls=stream_cls,\n remaining_retries=remaining_retries,\n )\n\n def _request(\n self,\n *,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n remaining_retries: int | None,\n stream: bool,\n stream_cls: type[_StreamT] | None,\n ) -> ResponseT | _StreamT:\n cast_to = self._maybe_override_cast_to(cast_to, options)\n self._prepare_options(options)\n\n retries = self._remaining_retries(remaining_retries, options)\n request = self._build_request(options)\n self._prepare_request(request)\n\n kwargs: HttpxSendArgs = {}\n if self.custom_auth is not None:\n kwargs[\"auth\"] = self.custom_auth\n\n try:\n response = self._client.send(\n request,\n stream=stream or self._should_stream_response_body(request=request),\n **kwargs,\n )\n except httpx.TimeoutException as err:\n log.debug(\"Encountered httpx.TimeoutException\", exc_info=True)\n\n if retries > 0:\n return self._retry_request(\n options,\n cast_to,\n retries,\n stream=stream,\n stream_cls=stream_cls,\n response_headers=None,\n )\n\n log.debug(\"Raising timeout error\")\n raise APITimeoutError(request=request) from err\n except Exception as err:\n log.debug(\"Encountered Exception\", exc_info=True)\n\n if retries > 0:\n return self._retry_request(\n options,\n cast_to,\n retries,\n stream=stream,\n stream_cls=stream_cls,\n response_headers=None,\n )\n\n log.debug(\"Raising connection error\")\n raise APIConnectionError(request=request) from err\n\n log.debug(\n 'HTTP Request: %s %s \"%i %s\"', request.method, request.url, response.status_code, response.reason_phrase\n )\n\n try:\n response.raise_for_status()\n except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code\n log.debug(\"Encountered httpx.HTTPStatusError\", exc_info=True)\n\n if retries > 0 and self._should_retry(err.response):\n err.response.close()\n return self._retry_request(\n options,\n cast_to,\n retries,\n err.response.headers,\n stream=stream,\n stream_cls=stream_cls,\n )\n\n # If the response is streamed then we need to explicitly read the response\n # to completion before attempting to access the response text.\n if not err.response.is_closed:\n err.response.read()\n\n log.debug(\"Re-raising status error\")\n raise self._make_status_error_from_response(err.response) from None\n\n return self._process_response(\n cast_to=cast_to,\n options=options,\n response=response,\n stream=stream,\n stream_cls=stream_cls,\n )\n\n def _retry_request(\n self,\n options: FinalRequestOptions,\n cast_to: Type[ResponseT],\n remaining_retries: int,\n 
response_headers: httpx.Headers | None,\n *,\n stream: bool,\n stream_cls: type[_StreamT] | None,\n ) -> ResponseT | _StreamT:\n remaining = remaining_retries - 1\n if remaining == 1:\n log.debug(\"1 retry left\")\n else:\n log.debug(\"%i retries left\", remaining)\n\n timeout = self._calculate_retry_timeout(remaining, options, response_headers)\n log.info(\"Retrying request to %s in %f seconds\", options.url, timeout)\n\n # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a\n # different thread if necessary.\n time.sleep(timeout)\n\n return self._request(\n options=options,\n cast_to=cast_to,\n remaining_retries=remaining,\n stream=stream,\n stream_cls=stream_cls,\n )\n\n def _process_response(\n self,\n *,\n cast_to: Type[ResponseT],\n options: FinalRequestOptions,\n response: httpx.Response,\n stream: bool,\n stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,\n ) -> ResponseT:\n if response.request.headers.get(RAW_RESPONSE_HEADER) == \"true\":\n return cast(\n ResponseT,\n LegacyAPIResponse(\n raw=response,\n client=self,\n cast_to=cast_to,\n stream=stream,\n stream_cls=stream_cls,\n options=options,\n ),\n )\n\n origin = get_origin(cast_to) or cast_to\n\n if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):\n if not issubclass(origin, APIResponse):\n raise TypeError(f\"API Response types must subclass {APIResponse}; Received {origin}\")\n\n response_cls = cast(\"type[BaseAPIResponse[Any]]\", cast_to)\n return cast(\n ResponseT,\n response_cls(\n raw=response,\n client=self,\n cast_to=extract_response_type(response_cls),\n stream=stream,\n stream_cls=stream_cls,\n options=options,\n ),\n )\n\n if cast_to == httpx.Response:\n return cast(ResponseT, response)\n\n api_response = APIResponse(\n raw=response,\n client=self,\n cast_to=cast(\"type[ResponseT]\", cast_to), # pyright: ignore[reportUnnecessaryCast]\n stream=stream,\n stream_cls=stream_cls,\n options=options,\n )\n if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):\n return cast(ResponseT, api_response)\n\n return api_response.parse()\n\n def _request_api_list(\n self,\n model: Type[object],\n page: Type[SyncPageT],\n options: FinalRequestOptions,\n ) -> SyncPageT:\n def _parser(resp: SyncPageT) -> SyncPageT:\n resp._set_private_attributes(\n client=self,\n model=model,\n options=options,\n )\n return resp\n\n options.post_parser = _parser\n\n return self.request(page, options, stream=False)\n\n @overload\n def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: Literal[False] = False,\n ) -> ResponseT:\n ...\n\n @overload\n def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: Literal[True],\n stream_cls: type[_StreamT],\n ) -> _StreamT:\n ...\n\n @overload\n def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: bool,\n stream_cls: type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n ...\n\n def get(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n options: RequestOptions = {},\n stream: bool = False,\n stream_cls: type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n opts = FinalRequestOptions.construct(method=\"get\", url=path, **options)\n # cast is required because mypy complains about returning Any even though\n # it understands the type variables\n return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))\n\n @overload\n def post(\n 
self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: Literal[False] = False,\n ) -> ResponseT:\n ...\n\n @overload\n def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: Literal[True],\n stream_cls: type[_StreamT],\n ) -> _StreamT:\n ...\n\n @overload\n def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: bool,\n stream_cls: type[_StreamT] | None = None,\n ) -> ResponseT | _StreamT:\n ...\n\n def post(\n self,\n path: str,\n *,\n cast_to: Type[ResponseT],\n body: Body | None = None,\n options: RequestOptions = {},\n files: RequestFiles | None = None,\n stream: bool = False,\n...\n# Path: src/openai/_module_client.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom typing_extensions import override\n\nfrom . import resources, _load_client\nfrom ._utils import LazyProxy\n\n\nclass ChatProxy(LazyProxy[resources.Chat]):\n @override\n def __load__(self) -> resources.Chat:\n return _load_client().chat\n\n\nclass BetaProxy(LazyProxy[resources.Beta]):\n @override\n def __load__(self) -> resources.Beta:\n return _load_client().beta\n\n\nclass FilesProxy(LazyProxy[resources.Files]):\n @override\n def __load__(self) -> resources.Files:\n return _load_client().files\n\n\nclass AudioProxy(LazyProxy[resources.Audio]):\n @override\n def __load__(self) -> resources.Audio:\n return _load_client().audio\n\n\nclass ImagesProxy(LazyProxy[resources.Images]):\n @override\n def __load__(self) -> resources.Images:\n return _load_client().images\n\n\nclass ModelsProxy(LazyProxy[resources.Models]):\n @override\n def __load__(self) -> resources.Models:\n return _load_client().models\n\n\nclass EmbeddingsProxy(LazyProxy[resources.Embeddings]):\n @override\n def __load__(self) -> resources.Embeddings:\n return _load_client().embeddings\n\n\nclass CompletionsProxy(LazyProxy[resources.Completions]):\n @override\n def __load__(self) -> resources.Completions:\n return _load_client().completions\n\n\nclass ModerationsProxy(LazyProxy[resources.Moderations]):\n @override\n def __load__(self) -> resources.Moderations:\n return _load_client().moderations\n\n\nclass FineTuningProxy(LazyProxy[resources.FineTuning]):\n @override\n def __load__(self) -> resources.FineTuning:\n return _load_client().fine_tuning\n\n\nchat: resources.Chat = ChatProxy().__as_proxied__()\nbeta: resources.Beta = BetaProxy().__as_proxied__()\nfiles: resources.Files = FilesProxy().__as_proxied__()\naudio: resources.Audio = AudioProxy().__as_proxied__()\nimages: resources.Images = ImagesProxy().__as_proxied__()\nmodels: resources.Models = ModelsProxy().__as_proxied__()\nembeddings: resources.Embeddings = EmbeddingsProxy().__as_proxied__()\ncompletions: resources.Completions = CompletionsProxy().__as_proxied__()\nmoderations: resources.Moderations = ModerationsProxy().__as_proxied__()\nfine_tuning: resources.FineTuning = FineTuningProxy().__as_proxied__()\n\n# Path: src/openai/_utils/_logs.py\nimport os\nimport logging\n\nlogger: logging.Logger = logging.getLogger(\"openai\")\nhttpx_logger: logging.Logger = logging.getLogger(\"httpx\")\n\n\n\ndef _basic_config() -> None:\n # e.g. 
[2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar \"200 OK\"\n logging.basicConfig(\n format=\"[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s\",\n datefmt=\"%Y-%m-%d %H:%M:%S\",\n )\n\n\ndef setup_logging() -> None:\n env = os.environ.get(\"OPENAI_LOG\")\n if env == \"debug\":\n _basic_config()\n logger.setLevel(logging.DEBUG)\n httpx_logger.setLevel(logging.DEBUG)\n elif env == \"info\":\n _basic_config()\n logger.setLevel(logging.INFO)\n httpx_logger.setLevel(logging.INFO)\n\n# Path: src/openai/__init__.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nimport os as _os\nfrom typing_extensions import override\n\nfrom . import types\nfrom ._types import NoneType, Transport, ProxiesTypes\nfrom ._utils import file_from_path\nfrom ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions\nfrom ._models import BaseModel\nfrom ._version import __title__, __version__\nfrom ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse\nfrom ._exceptions import (\n APIError,\n OpenAIError,\n ConflictError,\n NotFoundError,\n APIStatusError,\n RateLimitError,\n APITimeoutError,\n BadRequestError,\n APIConnectionError,\n AuthenticationError,\n InternalServerError,\n PermissionDeniedError,\n UnprocessableEntityError,\n APIResponseValidationError,\n)\nfrom ._utils._logs import setup_logging as _setup_logging\n\n__all__ = [\n \"types\",\n \"__version__\",\n \"__title__\",\n \"NoneType\",\n \"Transport\",\n \"ProxiesTypes\",\n \"OpenAIError\",\n \"APIError\",\n \"APIStatusError\",\n \"APITimeoutError\",\n \"APIConnectionError\",\n \"APIResponseValidationError\",\n \"BadRequestError\",\n \"AuthenticationError\",\n \"PermissionDeniedError\",\n \"NotFoundError\",\n \"ConflictError\",\n \"UnprocessableEntityError\",\n \"RateLimitError\",\n \"InternalServerError\",\n \"Timeout\",\n \"RequestOptions\",\n \"Client\",\n \"AsyncClient\",\n \"Stream\",\n \"AsyncStream\",\n \"OpenAI\",\n \"AsyncOpenAI\",\n \"file_from_path\",\n \"BaseModel\",\n]\n\nfrom .lib import azure as _azure\nfrom .version import VERSION as VERSION\nfrom .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI\nfrom .lib._old_api import *\n\n_setup_logging()\n\n# Update the __module__ attribute for exported symbols so that\n# error messages point to this module instead of the module\n# it was originally defined in, e.g.\n# openai._exceptions.NotFoundError -> openai.NotFoundError\n__locals = locals()\nfor __name in __all__:\n if not __name.startswith(\"__\"):\n try:\n __locals[__name].__module__ = \"openai\"\n except (TypeError, AttributeError):\n # Some of our exported symbols are builtins which we can't set attributes for.\n pass\n\n# ------ Module level client ------\nimport typing as _t\nimport typing_extensions as _te\n\nimport httpx as _httpx\n\nfrom ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES\n\napi_key: str | None = None\n\norganization: str | None = None\n\nbase_url: str | _httpx.URL | None = None\n\ntimeout: float | Timeout | None = DEFAULT_TIMEOUT\n\nmax_retries: int = DEFAULT_MAX_RETRIES\n\ndefault_headers: _t.Mapping[str, str] | None = None\n\ndefault_query: _t.Mapping[str, object] | None = None\n\nhttp_client: _httpx.Client | None = None\n\n_ApiType = _te.Literal[\"openai\", \"azure\"]\n\napi_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get(\"OPENAI_API_TYPE\"))\n\napi_version: str | 
None = _os.environ.get(\"OPENAI_API_VERSION\")\n\nazure_endpoint: str | None = _os.environ.get(\"AZURE_OPENAI_ENDPOINT\")\n\nazure_ad_token: str | None = _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\")\n\nazure_ad_token_provider: _azure.AzureADTokenProvider | None = None\n\n\nclass _ModuleClient(OpenAI):\n # Note: we have to use type: ignores here as overriding class members\n # with properties is technically unsafe but it is fine for our use case\n\n @property # type: ignore\n @override\n def api_key(self) -> str | None:\n return api_key\n\n @api_key.setter # type: ignore\n def api_key(self, value: str | None) -> None: # type: ignore\n global api_key\n\n api_key = value\n\n @property # type: ignore\n @override\n def organization(self) -> str | None:\n return organization\n\n @organization.setter # type: ignore\n def organization(self, value: str | None) -> None: # type: ignore\n global organization\n\n organization = value\n\n @property\n @override\n def base_url(self) -> _httpx.URL:\n if base_url is not None:\n return _httpx.URL(base_url)\n\n return super().base_url\n\n @base_url.setter\n def base_url(self, url: _httpx.URL | str) -> None:\n super().base_url = url # type: ignore[misc]\n\n @property # type: ignore\n @override\n def timeout(self) -> float | Timeout | None:\n return timeout\n\n @timeout.setter # type: ignore\n def timeout(self, value: float | Timeout | None) -> None: # type: ignore\n global timeout\n\n timeout = value\n\n @property # type: ignore\n @override\n def max_retries(self) -> int:\n return max_retries\n\n @max_retries.setter # type: ignore\n def max_retries(self, value: int) -> None: # type: ignore\n global max_retries\n\n max_retries = value\n\n @property # type: ignore\n @override\n def _custom_headers(self) -> _t.Mapping[str, str] | None:\n return default_headers\n\n @_custom_headers.setter # type: ignore\n def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore\n global default_headers\n\n default_headers = value\n\n @property # type: ignore\n @override\n def _custom_query(self) -> _t.Mapping[str, object] | None:\n return default_query\n\n @_custom_query.setter # type: ignore\n def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore\n global default_query\n\n default_query = value\n\n @property # type: ignore\n @override\n def _client(self) -> _httpx.Client:\n return http_client or super()._client\n\n @_client.setter # type: ignore\n def _client(self, value: _httpx.Client) -> None: # type: ignore\n global http_client\n\n http_client = value\n\n\nclass _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore\n ...\n\n\nclass _AmbiguousModuleClientUsageError(OpenAIError):\n def __init__(self) -> None:\n super().__init__(\n \"Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`\"\n )\n\n\ndef _has_openai_credentials() -> bool:\n return _os.environ.get(\"OPENAI_API_KEY\") is not None\n\n\ndef _has_azure_credentials() -> bool:\n return azure_endpoint is not None or _os.environ.get(\"AZURE_OPENAI_API_KEY\") is not None\n\n\ndef _has_azure_ad_credentials() -> bool:\n return (\n _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\") is not None\n or azure_ad_token is not None\n or azure_ad_token_provider is not None\n )\n\n\n_client: OpenAI | None = None\n\n\ndef _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]\n global _client\n\n if _client is None:\n global api_type, azure_endpoint, azure_ad_token, api_version\n\n if 
azure_endpoint is None:\n azure_endpoint = _os.environ.get(\"AZURE_OPENAI_ENDPOINT\")\n\n if azure_ad_token is None:\n azure_ad_token = _os.environ.get(\"AZURE_OPENAI_AD_TOKEN\")\n\n if api_version is None:\n api_version = _os.environ.get(\"OPENAI_API_VERSION\")\n\n if api_type is None:\n has_openai = _has_openai_credentials()\n has_azure = _has_azure_credentials()\n has_azure_ad = _has_azure_ad_credentials()\n\n if has_openai and (has_azure or has_azure_ad):\n raise _AmbiguousModuleClientUsageError()\n\n if (azure_ad_token is not None or azure_ad_token_provider is not None) and _os.environ.get(\n \"AZURE_OPENAI_API_KEY\"\n ) is not None:\n raise _AmbiguousModuleClientUsageError()\n\n if has_azure or has_azure_ad:\n api_type = \"azure\"\n else:\n api_type = \"openai\"\n\n if api_type == \"azure\":\n _client = _AzureModuleClient( # type: ignore\n api_version=api_version,\n azure_endpoint=azure_endpoint,\n api_key=api_key,\n azure_ad_token=azure_ad_token,\n azure_ad_token_provider=azure_ad_token_provider,\n organization=organization,\n base_url=base_url,\n timeout=timeout,\n max_retries=max_retries,\n default_headers=default_headers,\n default_query=default_query,\n http_client=http_client,\n )\n return _client\n\n _client = _ModuleClient(\n api_key=api_key,\n organization=organization,\n base_url=base_url,\n timeout=timeout,\n max_retries=max_retries,\n default_headers=default_headers,\n default_query=default_query,\n http_client=http_client,\n )\n return _client\n\n return _client\n\n\ndef _reset_client() -> None: # type: ignore[reportUnusedFunction]\n global _client\n\n _client = None\n\n\nfrom ._module_client import (\n beta as beta,\n chat as chat,\n audio as audio,\n files as files,\n images as images,\n models as models,\n embeddings as embeddings,\n completions as completions,\n fine_tuning as fine_tuning,\n moderations as moderations,\n)\n\n# Path: src/openai/cli/_models.py\nfrom typing import Any\nfrom typing_extensions import ClassVar\n\nimport pydantic\n\nfrom .. 
import _models\nfrom .._compat import PYDANTIC_V2, ConfigDict\n\n\nclass BaseModel(_models.BaseModel):\n if PYDANTIC_V2:\n model_config: ClassVar[ConfigDict] = ConfigDict(extra=\"ignore\", arbitrary_types_allowed=True)\n else:\n\n class Config(pydantic.BaseConfig): # type: ignore\n extra: Any = pydantic.Extra.ignore # type: ignore\n arbitrary_types_allowed: bool = True\n\n# Path: src/openai/cli/_progress.py\nfrom __future__ import annotations\n\nimport io\nfrom typing import Callable\nfrom typing_extensions import override\n\n\nclass CancelledError(Exception):\n def __init__(self, msg: str) -> None:\n self.msg = msg\n super().__init__(msg)\n\n @override\n def __str__(self) -> str:\n return self.msg\n\n __repr__ = __str__\n\n\nclass BufferReader(io.BytesIO):\n def __init__(self, buf: bytes = b\"\", desc: str | None = None) -> None:\n super().__init__(buf)\n self._len = len(buf)\n self._progress = 0\n self._callback = progress(len(buf), desc=desc)\n\n def __len__(self) -> int:\n return self._len\n\n @override\n def read(self, n: int | None = -1) -> bytes:\n chunk = io.BytesIO.read(self, n)\n self._progress += len(chunk)\n\n try:\n self._callback(self._progress)\n except Exception as e: # catches exception from the callback\n raise CancelledError(\"The upload was cancelled: {}\".format(e)) from e\n\n return chunk\n\n\ndef progress(total: float, desc: str | None) -> Callable[[float], None]:\n import tqdm\n\n meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc)\n\n def incr(progress: float) -> None:\n meter.n = progress\n if progress == total:\n meter.close()\n else:\n meter.refresh()\n\n return incr\n\n\ndef MB(i: int) -> int:\n return int(i // 1024**2)\n\n# Path: src/openai/cli/_utils.py\nfrom __future__ import annotations\n\nimport sys\n\nimport openai\n\nfrom .. import OpenAI, _load_client\nfrom .._compat import model_json\nfrom .._models import BaseModel\n\n\nclass Colors:\n HEADER = \"\\033[95m\"\n OKBLUE = \"\\033[94m\"\n OKGREEN = \"\\033[92m\"\n WARNING = \"\\033[93m\"\n FAIL = \"\\033[91m\"\n ENDC = \"\\033[0m\"\n BOLD = \"\\033[1m\"\n UNDERLINE = \"\\033[4m\"\n\n\ndef get_client() -> OpenAI:\n return _load_client()\n\n\ndef organization_info() -> str:\n organization = openai.organization\n if organization is not None:\n return \"[organization={}] \".format(organization)\n\n return \"\"\n\n\ndef print_model(model: BaseModel) -> None:\n sys.stdout.write(model_json(model, indent=2) + \"\\n\")\n\n\ndef can_use_http2() -> bool:\n try:\n import h2 # type: ignore # noqa\n except ImportError:\n return False\n\n return True\n\n...\n", "needle_token_start": 12002, "needle_token_end": 12148, "code_context_ntokens": 16036, "output": ["```python\nimport os\nimport logging\n\nlogger: logging.Logger = logging.getLogger(\"openai\")\nhttpx_logger: logging.Logger = logging.getLogger(\"httpx\")\n\ndef _basic_config() -> None:\n # e.g. 
[2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar \"200 OK\"\n logging.basicConfig(\n format=\"[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s\",\n datefmt=\"%Y-%m-%d %H:%M:%S\",\n )\n\ndef setup_logging() -> None:\n env = os.environ.get(\"OPENAI_LOG\")\n if env == \"debug\":\n _basic_config()\n logger.setLevel(logging.DEBUG)\n httpx_logger.setLevel(logging.DEBUG)\n elif env == \"info\":\n _basic_config()\n logger.setLevel(logging.INFO)\n httpx_logger.setLevel(logging.INFO)\n```"]} +{"repo": "openai/openai-python", "name": "__getattr__", "language": "python", "path": "src/openai/_utils/_proxy.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: To dynamically handle attribute access on instances that do not explicitly define those attributes, effectively forwarding the attribute access to another object.\n2. **Input**: A string representing the name of the attribute being accessed.\n3. **Output**: The value of the attribute from the proxied object, or the proxied object itself if it is an instance of the same class.\n4. **Procedure**: When an attribute is accessed that isn't directly found on the instance, this method retrieves the proxied object. If this proxied object is an instance of the same class, it returns the proxied object itself. Otherwise, it retrieves and returns the attribute from the proxied object.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/_constants.py\n# File generated from our OpenAPI spec by Stainless.\n\nimport httpx\n\nRAW_RESPONSE_HEADER = \"X-Stainless-Raw-Response\"\nOVERRIDE_CAST_TO_HEADER = \"____stainless_override_cast_to\"\n\n# default timeout is 10 minutes\nDEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=5.0)\nDEFAULT_MAX_RETRIES = 2\nDEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20)\n\nINITIAL_RETRY_DELAY = 0.5\nMAX_RETRY_DELAY = 8.0\n\n# Path: src/openai/_utils/_proxy.py\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom typing import Generic, TypeVar, Iterable, cast\nfrom typing_extensions import override\n\nT = TypeVar(\"T\")\n\n\nclass LazyProxy(Generic[T], ABC):\n \"\"\"Implements data methods to pretend that an instance is another instance.\n\n This includes forwarding attribute access and othe methods.\n \"\"\"\n\n # Note: we have to special case proxies that themselves return proxies\n # to support using a proxy as a catch-all for any random access, e.g. 
`proxy.foo.bar.baz`\n\n \ndef __getattr__(self, attr: str) -> object:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied # pyright: ignore\n return getattr(proxied, attr)\n\n @override\n def __repr__(self) -> str:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied.__class__.__name__\n return repr(self.__get_proxied__())\n\n @override\n def __str__(self) -> str:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied.__class__.__name__\n return str(proxied)\n\n @override\n def __dir__(self) -> Iterable[str]:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return []\n return proxied.__dir__()\n\n @property # type: ignore\n @override\n def __class__(self) -> type: # pyright: ignore\n proxied = self.__get_proxied__()\n if issubclass(type(proxied), LazyProxy):\n return type(proxied)\n return proxied.__class__\n\n def __get_proxied__(self) -> T:\n return self.__load__()\n\n def __as_proxied__(self) -> T:\n \"\"\"Helper method that returns the current proxy, typed as the loaded object\"\"\"\n return cast(T, self)\n\n @abstractmethod\n def __load__(self) -> T:\n ...\n\n# Path: src/openai/_utils/_streams.py\nfrom typing import Any\nfrom typing_extensions import Iterator, AsyncIterator\n\n\ndef consume_sync_iterator(iterator: Iterator[Any]) -> None:\n for _ in iterator:\n ...\n\n\nasync def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:\n async for _ in iterator:\n ...\n\n# Path: src/openai/_utils/_sync.py\nfrom __future__ import annotations\n\nimport functools\nfrom typing import TypeVar, Callable, Awaitable\nfrom typing_extensions import ParamSpec\n\nimport anyio\nimport anyio.to_thread\n\nT_Retval = TypeVar(\"T_Retval\")\nT_ParamSpec = ParamSpec(\"T_ParamSpec\")\n\n\n# copied from `asyncer`, https://github.com/tiangolo/asyncer\ndef asyncify(\n function: Callable[T_ParamSpec, T_Retval],\n *,\n cancellable: bool = False,\n limiter: anyio.CapacityLimiter | None = None,\n) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:\n \"\"\"\n Take a blocking function and create an async one that receives the same\n positional and keyword arguments, and that when called, calls the original function\n in a worker thread using `anyio.to_thread.run_sync()`. Internally,\n `asyncer.asyncify()` uses the same `anyio.to_thread.run_sync()`, but it supports\n keyword arguments additional to positional arguments and it adds better support for\n autocompletion and inline errors for the arguments of the function called and the\n return value.\n\n If the `cancellable` option is enabled and the task waiting for its completion is\n cancelled, the thread will still run its course but its return value (or any raised\n exception) will be ignored.\n\n Use it like this:\n\n ```Python\n def do_work(arg1, arg2, kwarg1=\"\", kwarg2=\"\") -> str:\n # Do work\n return \"Some result\"\n\n\n result = await to_thread.asyncify(do_work)(\"spam\", \"ham\", kwarg1=\"a\", kwarg2=\"b\")\n print(result)\n ```\n\n ## Arguments\n\n `function`: a blocking regular callable (e.g. 
a function)\n `cancellable`: `True` to allow cancellation of the operation\n `limiter`: capacity limiter to use to limit the total amount of threads running\n (if omitted, the default limiter is used)\n\n ## Return\n\n An async function that takes the same positional and keyword arguments as the\n original one, that when called runs the same original function in a thread worker\n and returns the result.\n \"\"\"\n\n async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:\n partial_f = functools.partial(function, *args, **kwargs)\n return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter)\n\n return wrapper\n\n# Path: src/openai/_utils/_typing.py\nfrom __future__ import annotations\n\nfrom typing import Any, TypeVar, Iterable, cast\nfrom collections import abc as _c_abc\nfrom typing_extensions import Required, Annotated, get_args, get_origin\n\nfrom .._types import InheritsGeneric\nfrom .._compat import is_union as _is_union\n\n\ndef is_annotated_type(typ: type) -> bool:\n return get_origin(typ) == Annotated\n\n\ndef is_list_type(typ: type) -> bool:\n return (get_origin(typ) or typ) == list\n\n\ndef is_iterable_type(typ: type) -> bool:\n \"\"\"If the given type is `typing.Iterable[T]`\"\"\"\n origin = get_origin(typ) or typ\n return origin == Iterable or origin == _c_abc.Iterable\n\n\ndef is_union_type(typ: type) -> bool:\n return _is_union(get_origin(typ))\n\n\ndef is_required_type(typ: type) -> bool:\n return get_origin(typ) == Required\n\n\ndef is_typevar(typ: type) -> bool:\n # type ignore is required because type checkers\n # think this expression will always return False\n return type(typ) == TypeVar # type: ignore\n\n\n# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]\ndef strip_annotated_type(typ: type) -> type:\n if is_required_type(typ) or is_annotated_type(typ):\n return strip_annotated_type(cast(type, get_args(typ)[0]))\n\n return typ\n\n\ndef extract_type_arg(typ: type, index: int) -> type:\n args = get_args(typ)\n try:\n return cast(type, args[index])\n except IndexError as err:\n raise RuntimeError(f\"Expected type {typ} to have a type argument at index {index} but it did not\") from err\n\n\ndef extract_type_var_from_base(\n typ: type,\n *,\n generic_bases: tuple[type, ...],\n index: int,\n failure_message: str | None = None,\n) -> type:\n \"\"\"Given a type like `Foo[T]`, returns the generic type variable `T`.\n\n This also handles the case where a concrete subclass is given, e.g.\n ```py\n class MyResponse(Foo[bytes]):\n ...\n\n extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes\n ```\n\n And where a generic subclass is given:\n ```py\n _T = TypeVar('_T')\n class MyResponse(Foo[_T]):\n ...\n\n extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes\n ```\n \"\"\"\n cls = cast(object, get_origin(typ) or typ)\n if cls in generic_bases:\n # we're given the class directly\n return extract_type_arg(typ, index)\n\n # if a subclass is given\n # ---\n # this is needed as __orig_bases__ is not present in the typeshed stubs\n # because it is intended to be for internal use only, however there does\n # not seem to be a way to resolve generic TypeVars for inherited subclasses\n # without using it.\n if isinstance(cls, InheritsGeneric):\n target_base_class: Any | None = None\n for base in cls.__orig_bases__:\n if base.__origin__ in generic_bases:\n target_base_class = base\n break\n\n if target_base_class is None:\n raise RuntimeError(\n \"Could not find the generic base 
class;\\n\"\n \"This should never happen;\\n\"\n f\"Does {cls} inherit from one of {generic_bases} ?\"\n )\n\n extracted = extract_type_arg(target_base_class, index)\n if is_typevar(extracted):\n # If the extracted type argument is itself a type variable\n # then that means the subclass itself is generic, so we have\n # to resolve the type argument from the class itself, not\n # the base class.\n #\n # Note: if there is more than 1 type argument, the subclass could\n # change the ordering of the type arguments, this is not currently\n # supported.\n return extract_type_arg(typ, index)\n\n return extracted\n\n raise RuntimeError(failure_message or f\"Could not resolve inner type variable at index {index} for {typ}\")\n\n# Path: src/openai/_files.py\nfrom __future__ import annotations\n\nimport io\nimport os\nimport pathlib\nfrom typing import overload\nfrom typing_extensions import TypeGuard\n\nimport anyio\n\nfrom ._types import (\n FileTypes,\n FileContent,\n...\n", "needle_token_start": 358, "needle_token_end": 425, "code_context_ntokens": 2787, "output": ["```python\ndef __getattr__(self, attr: str) -> object:\n proxied = self.__get_proxied__()\n if isinstance(proxied, LazyProxy):\n return proxied # pyright: ignore\n return getattr(proxied, attr)\n```"]} +{"repo": "openai/openai-python", "name": "_enforce_trailing_slash", "language": "python", "path": "src/openai/_base_client.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: Ensures that a given URL ends with a slash, which is often required for proper URL formation in web requests.\n2. **Input**: A URL object representing a web address.\n3. **Output**: A URL object that is guaranteed to end with a slash.\n4. **Procedure**: The function checks if the last character of the URL's path is a slash. If not, it appends a slash to the end of the path and returns the modified URL. If the path already ends with a slash, the original URL is returned unchanged.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/openai/_client.py\n# File generated from our OpenAPI spec by Stainless.\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Any, Union, Mapping\nfrom typing_extensions import Self, override\n\nimport httpx\n\nfrom . 
import resources, _exceptions\nfrom ._qs import Querystring\nfrom ._types import (\n NOT_GIVEN,\n Omit,\n Timeout,\n NotGiven,\n Transport,\n ProxiesTypes,\n RequestOptions,\n)\nfrom ._utils import (\n is_given,\n is_mapping,\n get_async_library,\n)\nfrom ._version import __version__\nfrom ._streaming import Stream as Stream, AsyncStream as AsyncStream\nfrom ._exceptions import OpenAIError, APIStatusError\nfrom ._base_client import (\n DEFAULT_MAX_RETRIES,\n SyncAPIClient,\n AsyncAPIClient,\n)\n\n__all__ = [\n \"Timeout\",\n \"Transport\",\n \"ProxiesTypes\",\n \"RequestOptions\",\n \"resources\",\n \"OpenAI\",\n \"AsyncOpenAI\",\n \"Client\",\n \"AsyncClient\",\n]\n\n\nclass OpenAI(SyncAPIClient):\n completions: resources.Completions\n chat: resources.Chat\n embeddings: resources.Embeddings\n files: resources.Files\n images: resources.Images\n audio: resources.Audio\n moderations: resources.Moderations\n models: resources.Models\n fine_tuning: resources.FineTuning\n beta: resources.Beta\n with_raw_response: OpenAIWithRawResponse\n with_streaming_response: OpenAIWithStreamedResponse\n\n # client options\n api_key: str\n organization: str | None\n\n def __init__(\n self,\n *,\n api_key: str | None = None,\n organization: str | None = None,\n base_url: str | httpx.URL | None = None,\n timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,\n max_retries: int = DEFAULT_MAX_RETRIES,\n default_headers: Mapping[str, str] | None = None,\n default_query: Mapping[str, object] | None = None,\n # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.\n http_client: httpx.Client | None = None,\n # Enable or disable schema validation for data returned by the API.\n # When enabled an error APIResponseValidationError is raised\n # if the API responds with invalid data for the expected schema.\n #\n # This parameter may be removed or changed in the future.\n # If you rely on this feature, please open a GitHub issue\n # outlining your use-case to help us decide if it should be\n # part of our public interface in the future.\n _strict_response_validation: bool = False,\n ) -> None:\n \"\"\"Construct a new synchronous openai client instance.\n\n This automatically infers the following arguments from their corresponding environment variables if they are not provided:\n - `api_key` from `OPENAI_API_KEY`\n - `organization` from `OPENAI_ORG_ID`\n \"\"\"\n if api_key is None:\n api_key = os.environ.get(\"OPENAI_API_KEY\")\n if api_key is None:\n raise OpenAIError(\n \"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable\"\n )\n self.api_key = api_key\n\n if organization is None:\n organization = os.environ.get(\"OPENAI_ORG_ID\")\n self.organization = organization\n\n if base_url is None:\n base_url = os.environ.get(\"OPENAI_BASE_URL\")\n if base_url is None:\n base_url = f\"https://api.openai.com/v1\"\n\n super().__init__(\n version=__version__,\n base_url=base_url,\n max_retries=max_retries,\n timeout=timeout,\n http_client=http_client,\n custom_headers=default_headers,\n custom_query=default_query,\n _strict_response_validation=_strict_response_validation,\n )\n\n self._default_stream_cls = Stream\n\n self.completions = resources.Completions(self)\n self.chat = resources.Chat(self)\n self.embeddings = resources.Embeddings(self)\n self.files = resources.Files(self)\n...\n# Path: src/openai/_streaming.py\n# Note: initially copied from 
https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py\nfrom __future__ import annotations\n\nimport json\nimport inspect\nfrom types import TracebackType\nfrom typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast\nfrom typing_extensions import Self, TypeGuard, override, get_origin\n\nimport httpx\n\nfrom ._utils import is_mapping, extract_type_var_from_base\nfrom ._exceptions import APIError\n\nif TYPE_CHECKING:\n from ._client import OpenAI, AsyncOpenAI\n\n\n_T = TypeVar(\"_T\")\n\n\nclass Stream(Generic[_T]):\n \"\"\"Provides the core interface to iterate over a synchronous stream response.\"\"\"\n\n response: httpx.Response\n\n def __init__(\n self,\n *,\n cast_to: type[_T],\n response: httpx.Response,\n client: OpenAI,\n ) -> None:\n self.response = response\n self._cast_to = cast_to\n self._client = client\n self._decoder = SSEDecoder()\n self._iterator = self.__stream__()\n\n def __next__(self) -> _T:\n return self._iterator.__next__()\n\n def __iter__(self) -> Iterator[_T]:\n for item in self._iterator:\n yield item\n\n def _iter_events(self) -> Iterator[ServerSentEvent]:\n yield from self._decoder.iter(self.response.iter_lines())\n\n def __stream__(self) -> Iterator[_T]:\n cast_to = cast(Any, self._cast_to)\n response = self.response\n process_data = self._client._process_response_data\n iterator = self._iter_events()\n\n for sse in iterator:\n if sse.data.startswith(\"[DONE]\"):\n break\n\n if sse.event is None:\n data = sse.json()\n if is_mapping(data) and data.get(\"error\"):\n raise APIError(\n message=\"An error occurred during streaming\",\n request=self.response.request,\n body=data[\"error\"],\n )\n\n yield process_data(data=data, cast_to=cast_to, response=response)\n\n # Ensure the entire stream is consumed\n for _sse in iterator:\n ...\n\n def __enter__(self) -> Self:\n return self\n\n def __exit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n self.close()\n\n def close(self) -> None:\n \"\"\"\n Close the response and release the connection.\n\n Automatically called if the response body is read to completion.\n \"\"\"\n self.response.close()\n\n\nclass AsyncStream(Generic[_T]):\n \"\"\"Provides the core interface to iterate over an asynchronous stream response.\"\"\"\n\n response: httpx.Response\n\n def __init__(\n self,\n *,\n cast_to: type[_T],\n response: httpx.Response,\n client: AsyncOpenAI,\n ) -> None:\n self.response = response\n self._cast_to = cast_to\n self._client = client\n self._decoder = SSEDecoder()\n self._iterator = self.__stream__()\n\n async def __anext__(self) -> _T:\n return await self._iterator.__anext__()\n\n async def __aiter__(self) -> AsyncIterator[_T]:\n async for item in self._iterator:\n yield item\n\n async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:\n async for sse in self._decoder.aiter(self.response.aiter_lines()):\n yield sse\n\n async def __stream__(self) -> AsyncIterator[_T]:\n cast_to = cast(Any, self._cast_to)\n response = self.response\n process_data = self._client._process_response_data\n iterator = self._iter_events()\n\n async for sse in iterator:\n if sse.data.startswith(\"[DONE]\"):\n break\n\n if sse.event is None:\n data = sse.json()\n if is_mapping(data) and data.get(\"error\"):\n raise APIError(\n message=\"An error occurred during streaming\",\n request=self.response.request,\n body=data[\"error\"],\n )\n\n yield process_data(data=data, cast_to=cast_to, 
response=response)\n\n # Ensure the entire stream is consumed\n async for _sse in iterator:\n ...\n\n async def __aenter__(self) -> Self:\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc: BaseException | None,\n exc_tb: TracebackType | None,\n ) -> None:\n await self.close()\n\n async def close(self) -> None:\n \"\"\"\n Close the response and release the connection.\n\n Automatically called if the response body is read to completion.\n \"\"\"\n await self.response.aclose()\n\n\nclass ServerSentEvent:\n def __init__(\n self,\n *,\n event: str | None = None,\n data: str | None = None,\n id: str | None = None,\n retry: int | None = None,\n ) -> None:\n if data is None:\n data = \"\"\n\n self._id = id\n self._data = data\n self._event = event or None\n self._retry = retry\n\n @property\n def event(self) -> str | None:\n return self._event\n\n @property\n def id(self) -> str | None:\n return self._id\n\n @property\n def retry(self) -> int | None:\n return self._retry\n\n @property\n def data(self) -> str:\n return self._data\n\n def json(self) -> Any:\n return json.loads(self.data)\n\n @override\n def __repr__(self) -> str:\n return f\"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})\"\n\n\nclass SSEDecoder:\n _data: list[str]\n _event: str | None\n _retry: int | None\n _last_event_id: str | None\n\n def __init__(self) -> None:\n self._event = None\n self._data = []\n self._last_event_id = None\n self._retry = None\n\n def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:\n \"\"\"Given an iterator that yields lines, iterate over it & yield every event encountered\"\"\"\n for line in iterator:\n line = line.rstrip(\"\\n\")\n sse = self.decode(line)\n if sse is not None:\n yield sse\n\n async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:\n \"\"\"Given an async iterator that yields lines, iterate over it & yield every event encountered\"\"\"\n async for line in iterator:\n line = line.rstrip(\"\\n\")\n sse = self.decode(line)\n if sse is not None:\n yield sse\n\n def decode(self, line: str) -> ServerSentEvent | None:\n # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501\n\n if not line:\n if not self._event and not self._data and not self._last_event_id and self._retry is None:\n return None\n\n sse = ServerSentEvent(\n event=self._event,\n data=\"\\n\".join(self._data),\n id=self._last_event_id,\n retry=self._retry,\n )\n\n # NOTE: as per the SSE spec, do not reset last_event_id.\n self._event = None\n self._data = []\n self._retry = None\n\n return sse\n\n if line.startswith(\":\"):\n return None\n\n fieldname, _, value = line.partition(\":\")\n\n if value.startswith(\" \"):\n value = value[1:]\n\n if fieldname == \"event\":\n self._event = value\n elif fieldname == \"data\":\n self._data.append(value)\n elif fieldname == \"id\":\n if \"\\0\" in value:\n pass\n else:\n self._last_event_id = value\n elif fieldname == \"retry\":\n try:\n self._retry = int(value)\n except (TypeError, ValueError):\n pass\n else:\n pass # Field is ignored.\n\n return None\n\n\ndef is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:\n \"\"\"TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`\"\"\"\n origin = get_origin(typ) or typ\n return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))\n\n\ndef 
extract_stream_chunk_type(\n stream_cls: type,\n *,\n failure_message: str | None = None,\n) -> type:\n \"\"\"Given a type like `Stream[T]`, returns the generic type variable `T`.\n\n This also handles the case where a concrete subclass is given, e.g.\n ```py\n class MyStream(Stream[bytes]):\n ...\n\n extract_stream_chunk_type(MyStream) -> bytes\n ```\n \"\"\"\n from ._base_client import Stream, AsyncStream\n\n return extract_type_var_from_base(\n stream_cls,\n index=0,\n generic_bases=cast(\"tuple[type, ...]\", (Stream, AsyncStream)),\n failure_message=failure_message,\n )\n\n# Path: src/openai/_legacy_response.py\nfrom __future__ import annotations\n\nimport os\nimport inspect\nimport logging\nimport datetime\nimport functools\nfrom typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload\nfrom typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin\n\nimport anyio\nimport httpx\nimport pydantic\n\nfrom ._types import NoneType\nfrom ._utils import is_given\nfrom ._models import BaseModel, is_basemodel\nfrom ._constants import RAW_RESPONSE_HEADER\nfrom ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type\nfrom ._exceptions import APIResponseValidationError\n\nif TYPE_CHECKING:\n from ._models import FinalRequestOptions\n from ._base_client import BaseClient\n\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n_T = TypeVar(\"_T\")\n\nlog: logging.Logger = logging.getLogger(__name__)\n\n\nclass LegacyAPIResponse(Generic[R]):\n \"\"\"This is a legacy class as it will be replaced by `APIResponse`\n and `AsyncAPIResponse` in the `_response.py` file in the next major\n release.\n\n For the sync client this will mostly be the same with the exception\n of `content` & `text` will be methods instead of properties. 
In the\n async client, all methods will be async.\n\n A migration script will be provided & the migration in general should\n be smooth.\n \"\"\"\n\n _cast_to: type[R]\n _client: BaseClient[Any, Any]\n _parsed_by_type: dict[type[Any], Any]\n _stream: bool\n _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None\n _options: FinalRequestOptions\n\n http_response: httpx.Response\n\n def __init__(\n self,\n *,\n raw: httpx.Response,\n cast_to: type[R],\n client: BaseClient[Any, Any],\n stream: bool,\n stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,\n options: FinalRequestOptions,\n ) -> None:\n self._cast_to = cast_to\n self._client = client\n self._parsed_by_type = {}\n self._stream = stream\n self._stream_cls = stream_cls\n self._options = options\n self.http_response = raw\n\n @overload\n def parse(self, *, to: type[_T]) -> _T:\n ...\n\n @overload\n def parse(self) -> R:\n ...\n\n def parse(self, *, to: type[_T] | None = None) -> R | _T:\n \"\"\"Returns the rich python representation of this response's data.\n\n NOTE: For the async client: this will become a coroutine in the next major version.\n\n For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.\n\n You can customise the type that the response is parsed into through\n the `to` argument, e.g.\n\n ```py\n from openai import BaseModel\n\n\n class MyModel(BaseModel):\n foo: str\n\n\n obj = response.parse(to=MyModel)\n print(obj.foo)\n ```\n\n We support parsing:\n - `BaseModel`\n - `dict`\n - `list`\n - `Union`\n - `str`\n - `httpx.Response`\n \"\"\"\n cache_key = to if to is not None else self._cast_to\n cached = self._parsed_by_type.get(cache_key)\n if cached is not None:\n return cached # type: ignore[no-any-return]\n\n parsed = self._parse(to=to)\n if is_given(self._options.post_parser):\n parsed = self._options.post_parser(parsed)\n\n self._parsed_by_type[cache_key] = parsed\n return parsed\n\n @property\n def headers(self) -> httpx.Headers:\n return self.http_response.headers\n\n @property\n def http_request(self) -> httpx.Request:\n return self.http_response.request\n\n @property\n def status_code(self) -> int:\n return self.http_response.status_code\n\n @property\n def url(self) -> httpx.URL:\n return self.http_response.url\n\n @property\n def method(self) -> str:\n return self.http_request.method\n\n @property\n def content(self) -> bytes:\n \"\"\"Return the binary response content.\n\n NOTE: this will be removed in favour of `.read()` in the\n next major version.\n \"\"\"\n return self.http_response.content\n\n @property\n def text(self) -> str:\n \"\"\"Return the decoded response content.\n\n NOTE: this will be turned into a method in the next major version.\n \"\"\"\n return self.http_response.text\n\n @property\n def http_version(self) -> str:\n return self.http_response.http_version\n\n @property\n def is_closed(self) -> bool:\n return self.http_response.is_closed\n\n @property\n def elapsed(self) -> datetime.timedelta:\n \"\"\"The time taken for the complete request/response cycle to complete.\"\"\"\n return self.http_response.elapsed\n\n def _parse(self, *, to: type[_T] | None = None) -> R | _T:\n if self._stream:\n if to:\n if not is_stream_class_type(to):\n raise TypeError(f\"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}\")\n\n return cast(\n _T,\n to(\n cast_to=extract_stream_chunk_type(\n to,\n failure_message=\"Expected custom stream type to be passed with a type argument, e.g. 
Stream[ChunkType]\",\n ),\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n if self._stream_cls:\n return cast(\n R,\n self._stream_cls(\n cast_to=extract_stream_chunk_type(self._stream_cls),\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n stream_cls = cast(\"type[Stream[Any]] | type[AsyncStream[Any]] | None\", self._client._default_stream_cls)\n if stream_cls is None:\n raise MissingStreamClassError()\n\n return cast(\n R,\n stream_cls(\n cast_to=self._cast_to,\n response=self.http_response,\n client=cast(Any, self._client),\n ),\n )\n\n cast_to = to if to is not None else self._cast_to\n if cast_to is NoneType:\n return cast(R, None)\n\n response = self.http_response\n if cast_to == str:\n return cast(R, response.text)\n\n origin = get_origin(cast_to) or cast_to\n\n if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent):\n return cast(R, cast_to(response)) # type: ignore\n\n if origin == LegacyAPIResponse:\n raise RuntimeError(\"Unexpected state - cast_to is `APIResponse`\")\n\n if inspect.isclass(origin) and issubclass(origin, httpx.Response):\n # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response\n # and pass that class to our request functions. We cannot change the variance to be either\n # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct\n # the response class ourselves but that is something that should be supported directly in httpx\n # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.\n if cast_to != httpx.Response:\n raise ValueError(f\"Subclasses of httpx.Response cannot be passed to `cast_to`\")\n return cast(R, response)\n\n if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):\n raise TypeError(\"Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`\")\n\n if (\n cast_to is not object\n and not origin is list\n and not origin is dict\n and not origin is Union\n and not issubclass(origin, BaseModel)\n ):\n raise RuntimeError(\n f\"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}.\"\n )\n\n # split is required to handle cases where additional information is included\n # in the response, e.g. 
application/json; charset=utf-8\n content_type, *_ = response.headers.get(\"content-type\", \"*\").split(\";\")\n if content_type != \"application/json\":\n if is_basemodel(cast_to):\n try:\n data = response.json()\n except Exception as exc:\n log.debug(\"Could not read JSON from response data due to %s - %s\", type(exc), exc)\n else:\n return self._client._process_response_data(\n data=data,\n cast_to=cast_to, # type: ignore\n response=response,\n )\n\n if self._client._strict_response_validation:\n raise APIResponseValidationError(\n response=response,\n message=f\"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.\",\n body=response.text,\n )\n\n # If the API responds with content that isn't JSON then we just return\n # the (decoded) text without performing any parsing so that you can still\n # handle the response however you need to.\n return response.text # type: ignore\n\n data = response.json()\n\n return self._client._process_response_data(\n data=data,\n cast_to=cast_to, # type: ignore\n response=response,\n )\n\n @override\n def __repr__(self) -> str:\n return f\"\"\n\n\nclass MissingStreamClassError(TypeError):\n def __init__(self) -> None:\n super().__init__(\n \"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference\",\n )\n\n\ndef to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"true\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(LegacyAPIResponse[R], func(*args, **kwargs))\n\n return wrapped\n\n\ndef async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[LegacyAPIResponse[R]]]:\n \"\"\"Higher order function that takes one of our bound API methods and wraps it\n to support returning the raw `APIResponse` object directly.\n \"\"\"\n\n @functools.wraps(func)\n async def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:\n extra_headers = {**(cast(Any, kwargs.get(\"extra_headers\")) or {})}\n extra_headers[RAW_RESPONSE_HEADER] = \"true\"\n\n kwargs[\"extra_headers\"] = extra_headers\n\n return cast(LegacyAPIResponse[R], await func(*args, **kwargs))\n\n return wrapped\n\n\nclass HttpxBinaryResponseContent:\n response: httpx.Response\n\n def __init__(self, response: httpx.Response) -> None:\n self.response = response\n\n @property\n def content(self) -> bytes:\n return self.response.content\n\n @property\n def text(self) -> str:\n return self.response.text\n\n @property\n def encoding(self) -> str | None:\n return self.response.encoding\n\n @property\n def charset_encoding(self) -> str | None:\n return self.response.charset_encoding\n\n def json(self, **kwargs: Any) -> Any:\n return self.response.json(**kwargs)\n\n def read(self) -> bytes:\n return self.response.read()\n\n def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:\n return self.response.iter_bytes(chunk_size)\n\n def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:\n return self.response.iter_text(chunk_size)\n\n def iter_lines(self) -> Iterator[str]:\n return self.response.iter_lines()\n\n def iter_raw(self, 
chunk_size: int | None = None) -> Iterator[bytes]:\n return self.response.iter_raw(chunk_size)\n\n def write_to_file(\n self,\n file: str | os.PathLike[str],\n ) -> None:\n \"\"\"Write the output to the given file.\n\n Accepts a filename or any path-like object, e.g. pathlib.Path\n\n Note: if you want to stream the data to the file instead of writing\n all at once then you should use `.with_streaming_response` when making\n the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`\n \"\"\"\n with open(file, mode=\"wb\") as f:\n for data in self.response.iter_bytes():\n f.write(data)\n\n @deprecated(\n \"Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead\"\n )\n def stream_to_file(\n self,\n file: str | os.PathLike[str],\n *,\n chunk_size: int | None = None,\n ) -> None:\n with open(file, mode=\"wb\") as f:\n for data in self.response.iter_bytes(chunk_size):\n f.write(data)\n\n def close(self) -> None:\n return self.response.close()\n\n async def aread(self) -> bytes:\n return await self.response.aread()\n\n async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:\n return self.response.aiter_bytes(chunk_size)\n\n async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:\n return self.response.aiter_text(chunk_size)\n\n async def aiter_lines(self) -> AsyncIterator[str]:\n return self.response.aiter_lines()\n\n async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:\n return self.response.aiter_raw(chunk_size)\n\n @deprecated(\n \"Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead\"\n )\n async def astream_to_file(\n self,\n file: str | os.PathLike[str],\n *,\n chunk_size: int | None = None,\n ) -> None:\n path = anyio.Path(file)\n async with await path.open(mode=\"wb\") as f:\n async for data in self.response.aiter_bytes(chunk_size):\n await f.write(data)\n\n async def aclose(self) -> None:\n return await self.response.aclose()\n\n# Path: src/openai/_types.py\nfrom __future__ import annotations\n\nfrom os import PathLike\nfrom typing import (\n IO,\n TYPE_CHECKING,\n Any,\n Dict,\n List,\n Type,\n Tuple,\n Union,\n Mapping,\n TypeVar,\n Callable,\n Optional,\n Sequence,\n)\nfrom typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable\n\nimport httpx\nimport pydantic\nfrom httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport\n\nif TYPE_CHECKING:\n from ._models import BaseModel\n from ._response import APIResponse, AsyncAPIResponse\n from ._legacy_response import HttpxBinaryResponseContent\n\nTransport = BaseTransport\nAsyncTransport = AsyncBaseTransport\nQuery = Mapping[str, object]\nBody = object\nAnyMapping = Mapping[str, object]\nModelT = TypeVar(\"ModelT\", bound=pydantic.BaseModel)\n_T = TypeVar(\"_T\")\n\n\n# Approximates httpx internal ProxiesTypes and RequestFiles types\n# while adding support for `PathLike` instances\nProxiesDict = Dict[\"str | URL\", Union[None, str, URL, Proxy]]\nProxiesTypes = Union[str, Proxy, ProxiesDict]\nif TYPE_CHECKING:\n FileContent = Union[IO[bytes], bytes, PathLike[str]]\nelse:\n FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.\nFileTypes = Union[\n # file (or bytes)\n FileContent,\n # (filename, file (or bytes))\n Tuple[Optional[str], FileContent],\n # (filename, file (or 
bytes), content_type)\n Tuple[Optional[str], FileContent, Optional[str]],\n # (filename, file (or bytes), content_type, headers)\n Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],\n]\nRequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]\n\n# duplicate of the above but without our custom file support\nHttpxFileContent = Union[IO[bytes], bytes]\nHttpxFileTypes = Union[\n # file (or bytes)\n HttpxFileContent,\n # (filename, file (or bytes))\n Tuple[Optional[str], HttpxFileContent],\n # (filename, file (or bytes), content_type)\n Tuple[Optional[str], HttpxFileContent, Optional[str]],\n # (filename, file (or bytes), content_type, headers)\n Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],\n]\nHttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]\n\n# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT\n# where ResponseT includes `None`. In order to support directly\n# passing `None`, overloads would have to be defined for every\n# method that uses `ResponseT` which would lead to an unacceptable\n# amount of code duplication and make it unreadable. See _base_client.py\n# for example usage.\n#\n# This unfortunately means that you will either have\n# to import this type and pass it explicitly:\n#\n# from openai import NoneType\n# client.get('/foo', cast_to=NoneType)\n#\n# or build it yourself:\n#\n# client.get('/foo', cast_to=type(None))\nif TYPE_CHECKING:\n NoneType: Type[None]\nelse:\n NoneType = type(None)\n\n\nclass RequestOptions(TypedDict, total=False):\n headers: Headers\n max_retries: int\n timeout: float | Timeout | None\n params: Query\n extra_json: AnyMapping\n idempotency_key: str\n\n\n# Sentinel class used until PEP 0661 is accepted\nclass NotGiven:\n \"\"\"\n A sentinel singleton class used to distinguish omitted keyword arguments\n from those passed in with the value None (which may have different behavior).\n\n For example:\n\n ```py\n def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:\n ...\n\n\n get(timeout=1) # 1s timeout\n get(timeout=None) # No timeout\n get() # Default timeout behavior, which may not be statically known at the method definition.\n ```\n \"\"\"\n\n def __bool__(self) -> Literal[False]:\n return False\n\n @override\n def __repr__(self) -> str:\n return \"NOT_GIVEN\"\n\n\nNotGivenOr = Union[_T, NotGiven]\nNOT_GIVEN = NotGiven()\n\n\nclass Omit:\n \"\"\"In certain situations you need to be able to represent a case where a default value has\n to be explicitly removed and `None` is not an appropriate substitute, for example:\n\n ```py\n # as the default `Content-Type` header is `application/json` that will be sent\n client.post(\"/upload/files\", files={\"file\": b\"my raw file content\"})\n\n # you can't explicitly override the header as it has to be dynamically generated\n # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'\n client.post(..., headers={\"Content-Type\": \"multipart/form-data\"})\n\n # instead you can remove the default `application/json` header by passing Omit\n client.post(..., headers={\"Content-Type\": Omit()})\n ```\n \"\"\"\n\n def __bool__(self) -> Literal[False]:\n return False\n\n\n@runtime_checkable\nclass ModelBuilderProtocol(Protocol):\n @classmethod\n def build(\n cls: type[_T],\n *,\n response: Response,\n data: object,\n ) -> _T:\n ...\n\n\nHeaders = Mapping[str, Union[str, Omit]]\n\n\nclass HeadersLikeProtocol(Protocol):\n def get(self, __key: str) -> str 
| None:\n ...\n\n\nHeadersLike = Union[Headers, HeadersLikeProtocol]\n\nResponseT = TypeVar(\n \"ResponseT\",\n bound=Union[\n object,\n str,\n None,\n \"BaseModel\",\n List[Any],\n Dict[str, Any],\n Response,\n ModelBuilderProtocol,\n \"APIResponse[Any]\",\n \"AsyncAPIResponse[Any]\",\n \"HttpxBinaryResponseContent\",\n ],\n)\n\nStrBytesIntFloat = Union[str, bytes, int, float]\n\n# Note: copied from Pydantic\n# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49\nIncEx: TypeAlias = \"set[int] | set[str] | dict[int, Any] | dict[str, Any] | None\"\n\nPostParser = Callable[[Any], Any]\n\n\n@runtime_checkable\nclass InheritsGeneric(Protocol):\n \"\"\"Represents a type that has inherited from `Generic`\n\n The `__orig_bases__` property can be used to determine the resolved\n type variable for a given base class.\n \"\"\"\n\n __orig_bases__: tuple[_GenericAlias]\n\n\nclass _GenericAlias(Protocol):\n __origin__: type[object]\n\n\nclass HttpxSendArgs(TypedDict, total=False):\n auth: httpx.Auth\n\n# Path: src/openai/_compat.py\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload\nfrom datetime import date, datetime\nfrom typing_extensions import Self\n\nimport pydantic\nfrom pydantic.fields import FieldInfo\n\nfrom ._types import StrBytesIntFloat\n\n_T = TypeVar(\"_T\")\n_ModelT = TypeVar(\"_ModelT\", bound=pydantic.BaseModel)\n\n# --------------- Pydantic v2 compatibility ---------------\n\n# Pyright incorrectly reports some of our functions as overriding a method when they don't\n# pyright: reportIncompatibleMethodOverride=false\n\nPYDANTIC_V2 = pydantic.VERSION.startswith(\"2.\")\n\n# v1 re-exports\nif TYPE_CHECKING:\n\n def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001\n ...\n\n def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001\n ...\n\n def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001\n ...\n\n def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001\n ...\n\n def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001\n ...\n\n def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001\n ...\n\n def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001\n ...\n\nelse:\n if PYDANTIC_V2:\n from pydantic.v1.typing import (\n get_args as get_args,\n is_union as is_union,\n get_origin as get_origin,\n is_typeddict as is_typeddict,\n is_literal_type as is_literal_type,\n )\n from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime\n else:\n from pydantic.typing import (\n get_args as get_args,\n is_union as is_union,\n get_origin as get_origin,\n is_typeddict as is_typeddict,\n is_literal_type as is_literal_type,\n )\n from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime\n\n\n# refactored config\nif TYPE_CHECKING:\n from pydantic import ConfigDict as ConfigDict\nelse:\n if PYDANTIC_V2:\n from pydantic import ConfigDict\n else:\n # TODO: provide an error message here?\n ConfigDict = None\n\n\n# renamed methods / properties\ndef parse_obj(model: type[_ModelT], value: object) -> _ModelT:\n if PYDANTIC_V2:\n return model.model_validate(value)\n else:\n return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]\n\n\ndef field_is_required(field: FieldInfo) -> bool:\n if PYDANTIC_V2:\n return field.is_required()\n return field.required # type: ignore\n\n\ndef 
field_get_default(field: FieldInfo) -> Any:\n value = field.get_default()\n if PYDANTIC_V2:\n from pydantic_core import PydanticUndefined\n\n if value == PydanticUndefined:\n return None\n return value\n return value\n\n\ndef field_outer_type(field: FieldInfo) -> Any:\n if PYDANTIC_V2:\n return field.annotation\n return field.outer_type_ # type: ignore\n\n\ndef get_model_config(model: type[pydantic.BaseModel]) -> Any:\n if PYDANTIC_V2:\n return model.model_config\n return model.__config__ # type: ignore\n\n\ndef get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:\n if PYDANTIC_V2:\n return model.model_fields\n return model.__fields__ # type: ignore\n\n\ndef model_copy(model: _ModelT) -> _ModelT:\n if PYDANTIC_V2:\n return model.model_copy()\n return model.copy() # type: ignore\n\n\ndef model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:\n if PYDANTIC_V2:\n return model.model_dump_json(indent=indent)\n return model.json(indent=indent) # type: ignore\n\n\ndef model_dump(\n model: pydantic.BaseModel,\n *,\n exclude_unset: bool = False,\n exclude_defaults: bool = False,\n) -> dict[str, Any]:\n if PYDANTIC_V2:\n return model.model_dump(\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n )\n return cast(\n \"dict[str, Any]\",\n model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]\n exclude_unset=exclude_unset,\n exclude_defaults=exclude_defaults,\n ),\n )\n\n\ndef model_parse(model: type[_ModelT], data: Any) -> _ModelT:\n if PYDANTIC_V2:\n return model.model_validate(data)\n return model.parse_obj(data) # pyright: ignore[reportDeprecated]\n\n\n# generic models\nif TYPE_CHECKING:\n\n class GenericModel(pydantic.BaseModel):\n ...\n\nelse:\n if PYDANTIC_V2:\n # there no longer needs to be a distinction in v2 but\n # we still have to create our own subclass to avoid\n # inconsistent MRO ordering errors\n class GenericModel(pydantic.BaseModel):\n ...\n\n else:\n import pydantic.generics\n\n class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):\n ...\n\n\n# cached properties\nif TYPE_CHECKING:\n cached_property = property\n\n # we define a separate type (copied from typeshed)\n # that represents that `cached_property` is `set`able\n # at runtime, which differs from `@property`.\n #\n # this is a separate type as editors likely special case\n # `@property` and we don't want to cause issues just to have\n # more helpful internal types.\n\n class typed_cached_property(Generic[_T]):\n func: Callable[[Any], _T]\n attrname: str | None\n\n def __init__(self, func: Callable[[Any], _T]) -> None:\n ...\n\n @overload\n def __get__(self, instance: None, owner: type[Any] | None = None) -> Self:\n ...\n\n @overload\n def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:\n ...\n\n def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:\n raise NotImplementedError()\n\n def __set_name__(self, owner: type[Any], name: str) -> None:\n ...\n\n # __set__ is not defined at runtime, but @cached_property is designed to be settable\n def __set__(self, instance: object, value: _T) -> None:\n ...\nelse:\n try:\n from functools import cached_property as cached_property\n except ImportError:\n from cached_property import cached_property as cached_property\n\n typed_cached_property = cached_property\n\n# Path: src/openai/_base_client.py\nfrom __future__ import annotations\n\nimport json\nimport time\nimport uuid\nimport email\nimport asyncio\nimport inspect\nimport logging\nimport 
platform\nimport warnings\nimport email.utils\nfrom types import TracebackType\nfrom random import random\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Type,\n Union,\n Generic,\n Mapping,\n TypeVar,\n Iterable,\n Iterator,\n Optional,\n Generator,\n AsyncIterator,\n cast,\n overload,\n)\nfrom functools import lru_cache\nfrom typing_extensions import Literal, override, get_origin\n\nimport anyio\nimport httpx\nimport distro\nimport pydantic\nfrom httpx import URL, Limits\nfrom pydantic import PrivateAttr\n\nfrom . import _exceptions\nfrom ._qs import Querystring\nfrom ._files import to_httpx_files, async_to_httpx_files\nfrom ._types import (\n NOT_GIVEN,\n Body,\n Omit,\n Query,\n Headers,\n Timeout,\n NotGiven,\n ResponseT,\n Transport,\n AnyMapping,\n PostParser,\n ProxiesTypes,\n RequestFiles,\n HttpxSendArgs,\n AsyncTransport,\n RequestOptions,\n ModelBuilderProtocol,\n)\nfrom ._utils import is_dict, is_list, is_given, is_mapping\nfrom ._compat import model_copy, model_dump\nfrom ._models import GenericModel, FinalRequestOptions, validate_type, construct_type\nfrom ._response import (\n APIResponse,\n BaseAPIResponse,\n AsyncAPIResponse,\n extract_response_type,\n)\nfrom ._constants import (\n DEFAULT_LIMITS,\n DEFAULT_TIMEOUT,\n MAX_RETRY_DELAY,\n DEFAULT_MAX_RETRIES,\n INITIAL_RETRY_DELAY,\n RAW_RESPONSE_HEADER,\n OVERRIDE_CAST_TO_HEADER,\n)\nfrom ._streaming import Stream, AsyncStream\nfrom ._exceptions import (\n APIStatusError,\n APITimeoutError,\n APIConnectionError,\n APIResponseValidationError,\n)\nfrom ._legacy_response import LegacyAPIResponse\n\nlog: logging.Logger = logging.getLogger(__name__)\n\n# TODO: make base page type vars covariant\nSyncPageT = TypeVar(\"SyncPageT\", bound=\"BaseSyncPage[Any]\")\nAsyncPageT = TypeVar(\"AsyncPageT\", bound=\"BaseAsyncPage[Any]\")\n\n\n_T = TypeVar(\"_T\")\n_T_co = TypeVar(\"_T_co\", covariant=True)\n\n_StreamT = TypeVar(\"_StreamT\", bound=Stream[Any])\n_AsyncStreamT = TypeVar(\"_AsyncStreamT\", bound=AsyncStream[Any])\n\nif TYPE_CHECKING:\n from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT\nelse:\n try:\n from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT\n except ImportError:\n # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366\n HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)\n\n\nclass PageInfo:\n \"\"\"Stores the necessary information to build the request to retrieve the next page.\n\n Either `url` or `params` must be set.\n \"\"\"\n\n url: URL | NotGiven\n params: Query | NotGiven\n\n @overload\n def __init__(\n self,\n *,\n url: URL,\n ) -> None:\n ...\n\n @overload\n def __init__(\n self,\n *,\n params: Query,\n ) -> None:\n ...\n\n def __init__(\n self,\n *,\n url: URL | NotGiven = NOT_GIVEN,\n params: Query | NotGiven = NOT_GIVEN,\n ) -> None:\n self.url = url\n self.params = params\n\n\nclass BasePage(GenericModel, Generic[_T]):\n \"\"\"\n Defines the core interface for pagination.\n\n Type Args:\n ModelT: The pydantic model that represents an item in the response.\n\n Methods:\n has_next_page(): Check if there is another page available\n next_page_info(): Get the necessary information to make a request for the next page\n \"\"\"\n\n _options: FinalRequestOptions = PrivateAttr()\n _model: Type[_T] = PrivateAttr()\n\n def has_next_page(self) -> bool:\n items = self._get_page_items()\n if not items:\n return False\n return self.next_page_info() is not None\n\n def next_page_info(self) -> Optional[PageInfo]:\n ...\n\n 
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]\n ...\n\n def _params_from_url(self, url: URL) -> httpx.QueryParams:\n # TODO: do we have to preprocess params here?\n return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)\n\n def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:\n options = model_copy(self._options)\n options._strip_raw_response_header()\n\n if not isinstance(info.params, NotGiven):\n options.params = {**options.params, **info.params}\n return options\n\n if not isinstance(info.url, NotGiven):\n params = self._params_from_url(info.url)\n url = info.url.copy_with(params=params)\n options.params = dict(url.params)\n options.url = str(url)\n return options\n\n raise ValueError(\"Unexpected PageInfo state\")\n\n\nclass BaseSyncPage(BasePage[_T], Generic[_T]):\n _client: SyncAPIClient = pydantic.PrivateAttr()\n\n def _set_private_attributes(\n self,\n client: SyncAPIClient,\n model: Type[_T],\n options: FinalRequestOptions,\n ) -> None:\n self._model = model\n self._client = client\n self._options = options\n\n # Pydantic uses a custom `__iter__` method to support casting BaseModels\n # to dictionaries. e.g. dict(model).\n # As we want to support `for item in page`, this is inherently incompatible\n # with the default pydantic behaviour. It is not possible to support both\n # use cases at once. Fortunately, this is not a big deal as all other pydantic\n # methods should continue to work as expected as there is an alternative method\n # to cast a model to a dictionary, model.dict(), which is used internally\n # by pydantic.\n def __iter__(self) -> Iterator[_T]: # type: ignore\n for page in self.iter_pages():\n for item in page._get_page_items():\n yield item\n\n def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:\n page = self\n while True:\n yield page\n if page.has_next_page():\n page = page.get_next_page()\n else:\n return\n\n def get_next_page(self: SyncPageT) -> SyncPageT:\n info = self.next_page_info()\n if not info:\n raise RuntimeError(\n \"No next page expected; please check `.has_next_page()` before calling `.get_next_page()`.\"\n )\n\n options = self._info_to_options(info)\n return self._client._request_api_list(self._model, page=self.__class__, options=options)\n\n\nclass AsyncPaginator(Generic[_T, AsyncPageT]):\n def __init__(\n self,\n client: AsyncAPIClient,\n options: FinalRequestOptions,\n page_cls: Type[AsyncPageT],\n model: Type[_T],\n ) -> None:\n self._model = model\n self._client = client\n self._options = options\n self._page_cls = page_cls\n\n def __await__(self) -> Generator[Any, None, AsyncPageT]:\n return self._get_page().__await__()\n\n async def _get_page(self) -> AsyncPageT:\n def _parser(resp: AsyncPageT) -> AsyncPageT:\n resp._set_private_attributes(\n model=self._model,\n options=self._options,\n client=self._client,\n )\n return resp\n\n self._options.post_parser = _parser\n\n return await self._client.request(self._page_cls, self._options)\n\n async def __aiter__(self) -> AsyncIterator[_T]:\n # https://github.com/microsoft/pyright/issues/3464\n page = cast(\n AsyncPageT,\n await self, # type: ignore\n )\n async for item in page:\n yield item\n\n\nclass BaseAsyncPage(BasePage[_T], Generic[_T]):\n _client: AsyncAPIClient = pydantic.PrivateAttr()\n\n def _set_private_attributes(\n self,\n model: Type[_T],\n client: AsyncAPIClient,\n options: FinalRequestOptions,\n ) -> None:\n self._model = model\n self._client = client\n self._options = options\n\n async def __aiter__(self) -> 
AsyncIterator[_T]:\n async for page in self.iter_pages():\n for item in page._get_page_items():\n yield item\n\n async def iter_pages(self: AsyncPageT) -> AsyncIterator[AsyncPageT]:\n page = self\n while True:\n yield page\n if page.has_next_page():\n page = await page.get_next_page()\n else:\n return\n\n async def get_next_page(self: AsyncPageT) -> AsyncPageT:\n info = self.next_page_info()\n if not info:\n raise RuntimeError(\n \"No next page expected; please check `.has_next_page()` before calling `.get_next_page()`.\"\n )\n\n options = self._info_to_options(info)\n return await self._client._request_api_list(self._model, page=self.__class__, options=options)\n\n\n_HttpxClientT = TypeVar(\"_HttpxClientT\", bound=Union[httpx.Client, httpx.AsyncClient])\n_DefaultStreamT = TypeVar(\"_DefaultStreamT\", bound=Union[Stream[Any], AsyncStream[Any]])\n\n\nclass BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):\n _client: _HttpxClientT\n _version: str\n _base_url: URL\n max_retries: int\n timeout: Union[float, Timeout, None]\n _limits: httpx.Limits\n _proxies: ProxiesTypes | None\n _transport: Transport | AsyncTransport | None\n _strict_response_validation: bool\n _idempotency_header: str | None\n _default_stream_cls: type[_DefaultStreamT] | None = None\n\n def __init__(\n self,\n *,\n version: str,\n base_url: str | URL,\n _strict_response_validation: bool,\n max_retries: int = DEFAULT_MAX_RETRIES,\n timeout: float | Timeout | None = DEFAULT_TIMEOUT,\n limits: httpx.Limits,\n transport: Transport | AsyncTransport | None,\n proxies: ProxiesTypes | None,\n custom_headers: Mapping[str, str] | None = None,\n custom_query: Mapping[str, object] | None = None,\n ) -> None:\n self._version = version\n self._base_url = self._enforce_trailing_slash(URL(base_url))\n self.max_retries = max_retries\n self.timeout = timeout\n self._limits = limits\n self._proxies = proxies\n self._transport = transport\n self._custom_headers = custom_headers or {}\n self._custom_query = custom_query or {}\n self._strict_response_validation = _strict_response_validation\n self._idempotency_header = None\n\n \ndef _enforce_trailing_slash(self, url: URL) -> URL:\n if url.raw_path.endswith(b\"/\"):\n return url\n return url.copy_with(raw_path=url.raw_path + b\"/\")\n\n def _make_status_error_from_response(\n self,\n response: httpx.Response,\n ) -> APIStatusError:\n if response.is_closed and not response.is_stream_consumed:\n # We can't read the response body as it has been closed\n # before it was read. 
This can happen if an event hook\n # raises a status error.\n body = None\n err_msg = f\"Error code: {response.status_code}\"\n else:\n err_text = response.text.strip()\n body = err_text\n\n try:\n body = json.loads(err_text)\n err_msg = f\"Error code: {response.status_code} - {body}\"\n except Exception:\n err_msg = err_text or f\"Error code: {response.status_code}\"\n\n return self._make_status_error(err_msg, body=body, response=response)\n\n def _make_status_error(\n self,\n err_msg: str,\n *,\n body: object,\n response: httpx.Response,\n ) -> _exceptions.APIStatusError:\n raise NotImplementedError()\n\n def _remaining_retries(\n self,\n remaining_retries: Optional[int],\n options: FinalRequestOptions,\n ) -> int:\n return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)\n\n def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:\n custom_headers = options.headers or {}\n headers_dict = _merge_mappings(self.default_headers, custom_headers)\n self._validate_headers(headers_dict, custom_headers)\n\n # headers are case-insensitive while dictionaries are not.\n headers = httpx.Headers(headers_dict)\n\n idempotency_header = self._idempotency_header\n if idempotency_header and options.method.lower() != \"get\" and idempotency_header not in headers:\n headers[idempotency_header] = options.idempotency_key or self._idempotency_key()\n\n return headers\n\n def _prepare_url(self, url: str) -> URL:\n \"\"\"\n Merge a URL argument together with any 'base_url' on the client,\n to create the URL used for the outgoing request.\n \"\"\"\n # Copied from httpx's `_merge_url` method.\n merge_url = URL(url)\n if merge_url.is_relative_url:\n merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b\"/\")\n return self.base_url.copy_with(raw_path=merge_raw_path)\n\n return merge_url\n\n def _build_request(\n self,\n options: FinalRequestOptions,\n ) -> httpx.Request:\n if log.isEnabledFor(logging.DEBUG):\n...\n", "needle_token_start": 15188, "needle_token_end": 15249, "code_context_ntokens": 16006, "output": ["```python\ndef _enforce_trailing_slash(self, url: URL) -> URL:\n if url.raw_path.endswith(b\"/\"):\n return url\n return url.copy_with(raw_path=url.raw_path + b\"/\")\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "fused_decode3_matmul1_fp16_before", "language": "python", "path": "mlc_llm/dispatch/llama/main.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function performs a decoding operation followed by a matrix multiplication, primarily working with 16-bit floating-point precision. This is typically used in neural network layers where decoding of inputs is followed by a weighted sum computation.\n2. **Input**: The function takes buffers containing 32-bit unsigned integers and 16-bit floating-point numbers, which represent encoded data and matrix values for the computation.\n3. **Output**: It outputs a buffer of 16-bit floating-point numbers, which are the results of the matrix multiplication after the decoding process.\n4. **Procedure**: The function first decodes the input data using bitwise operations and arithmetic, storing the results in an intermediate buffer. 
It then performs a matrix multiplication using the decoded data and another input matrix, initializing the output buffer to zero before accumulating the products.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\n\ndef fused_decode3_matmul1_fp16_before(lv5865: T.Buffer((T.int64(512), T.int64(32000)), \"uint32\"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv2705: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(32000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv5865[v_i // T.int64(8), v_j], lv5866[v_i // T.int64(32), v_j], lv5867[v_i // T.int64(32), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv5865[v_i // T.int64(8), v_j], T.Cast(\"uint32\", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v_i // T.int64(32), v_j] + lv5867[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv2705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n\n\n@T.prim_func\ndef fused_decode3_matmul1_fp16_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), \"uint32\"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n # with T.block(\"root\"):\n var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, 
\"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv1511_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2)\n T.reads(lv1511[v0, v1, v2])\n T.writes(lv1511_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2]\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(64)):\n for ax0_0 in range(T.int64(8)):\n for ax0_1 in T.unroll(T.int64(8)):\n for ax1 in range(T.int64(1)):\n with T.block(\"var_decode_intermediate_pad\"):\n v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)\n v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1123[v0 // T.int64(8), v1], lv5866[v0 // T.int64(32), v1], lv5867[v0 // T.int64(32), v1])\n T.writes(var_decode_intermediate_pad_local[v0, v1])\n var_decode_intermediate_pad_local[v0, v1] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast(\"uint32\", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v0 // T.int64(32), v1] + lv5867[v0 // T.int64(32), v1]\n for k_0_1_k_1_fused in range(T.int64(64)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_pad_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_pad_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode3_matmul1_cast_fp16_before(lv1803: T.Buffer((T.int64(512), T.int64(32000)), \"uint32\"), lv1804: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv1805: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float32\")):\n T.func_attr({\"tir.noalias\": 
T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(32000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1803[v_i // T.int64(8), v_j], lv1804[v_i // T.int64(32), v_j], lv1805[v_i // T.int64(32), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1803[v_i // T.int64(8), v_j], T.Cast(\"uint32\", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv1804[v_i // T.int64(32), v_j] + lv1805[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])\n p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast(\"float32\", var_matmul_intermediate[v_i0, v_i1, v_i2])\n\n\n@T.prim_func\ndef fused_decode3_matmul1_cast_fp16_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), \"uint32\"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n # with T.block(\"root\"):\n var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv1511_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2)\n T.reads(lv1511[v0, v1, v2])\n T.writes(lv1511_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2]\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = 
T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(64)):\n for ax0_0 in range(T.int64(8)):\n for ax0_1 in T.unroll(T.int64(8)):\n for ax1 in range(T.int64(1)):\n with T.block(\"var_decode_intermediate_pad\"):\n v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)\n v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1123[v0 // T.int64(8), v1], lv5866[v0 // T.int64(32), v1], lv5867[v0 // T.int64(32), v1])\n T.writes(var_decode_intermediate_pad_local[v0, v1])\n var_decode_intermediate_pad_local[v0, v1] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast(\"uint32\", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v0 // T.int64(32), v1] + lv5867[v0 // T.int64(32), v1]\n for k_0_1_k_1_fused in range(T.int64(64)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_pad_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = T.Cast(\"float32\", var_matmul_intermediate_pad_local[v0, v1, v2])\n\n\n@T.prim_func\ndef fused_decode4_fused_matmul5_add3_fp16_before(lv35: T.Buffer((T.int64(512), T.int64(4096)), \"uint32\"), lv36: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv37: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv2: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv35[v_i // T.int64(8), v_j], lv36[v_i // T.int64(32), v_j], lv37[v_i // T.int64(32), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv35[v_i // T.int64(8), v_j], T.Cast(\"uint32\", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv36[v_i // T.int64(32), v_j] + lv37[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in 
T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv2[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):\n with T.block(\"T_add\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv2710[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2710[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_decode4_fused_matmul5_add3_fp16_after(lv1143: T.Buffer((T.int64(512), T.int64(4096)), \"uint32\"), lv36: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv37: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv3_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2)\n T.reads(lv3[v0, v1, v2])\n T.writes(lv3_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv3_shared[v0, v1, v2] = lv3[v0, v1, v2]\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(64)):\n for ax0_0 in range(T.int64(8)):\n for ax0_1 in T.unroll(T.int64(8)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1143[v_j // T.int64(8), v_i], lv36[v_j // T.int64(32), v_i], lv37[v_j // T.int64(32), v_i])\n 
T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(8), v_i], T.Cast(\"uint32\", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv36[v_j // T.int64(32), v_i] + lv37[v_j // T.int64(32), v_i]\n for k_0_1_k_1_fused in range(T.int64(64)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode4_matmul5_fp16_before(lv11: T.Buffer((T.int64(512), T.int64(4096)), \"uint32\"), lv12: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv13: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv11[v_i // T.int64(8), v_j], lv12[v_i // T.int64(32), v_j], lv13[v_i // T.int64(32), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv11[v_i // T.int64(8), v_j], T.Cast(\"uint32\", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv12[v_i // T.int64(32), v_j] + lv13[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv2712[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2712[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n\n\n@T.prim_func\ndef fused_decode4_matmul5_fp16_after(lv1128: T.Buffer((T.int64(512), T.int64(4096)), \"uint32\"), lv12: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv13: T.Buffer((T.int64(128), T.int64(4096)), \"float16\"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": 
T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv2712_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2)\n T.reads(lv2712[v0, v1, v2])\n T.writes(lv2712_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv2712_shared[v0, v1, v2] = lv2712[v0, v1, v2]\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(64)):\n for ax0_0 in range(T.int64(8)):\n for ax0_1 in T.unroll(T.int64(8)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1128[v_j // T.int64(8), v_i], lv12[v_j // T.int64(32), v_i], lv13[v_j // T.int64(32), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(8), v_i], T.Cast(\"uint32\", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv12[v_j // T.int64(32), v_i] + lv13[v_j // T.int64(32), v_i]\n for k_0_1_k_1_fused in range(T.int64(64)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = 
var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode5_fused_matmul8_multiply1_fp16_before(lv51: T.Buffer((T.int64(512), T.int64(11008)), \"uint32\"), lv52: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"), lv53: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(11008)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv51[v_i // T.int64(8), v_j], lv52[v_i // T.int64(32), v_j], lv53[v_i // T.int64(32), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv51[v_i // T.int64(8), v_j], T.Cast(\"uint32\", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv52[v_i // T.int64(32), v_j] + lv53[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv5[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv5[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_decode5_fused_matmul8_multiply1_fp16_after(lv1153: T.Buffer((T.int64(512), T.int64(11008)), \"uint32\"), lv52: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"), lv53: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope=\"local\", dtype=\"float16\")\n lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in 
T.grid(T.int64(1), T.int64(4)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv2749_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2)\n T.reads(lv2749[v0, v1, v2])\n T.writes(lv2749_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2]\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(64)):\n...\n", "needle_token_start": 270, "needle_token_end": 1194, "code_context_ntokens": 16276, "output": ["```python\n@T.prim_func\ndef fused_decode3_matmul1_fp16_before(lv5865: T.Buffer((T.int64(512), T.int64(32000)), \"uint32\"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), \"float16\"), lv2705: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(32000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv5865[v_i // T.int64(8), v_j], lv5866[v_i // T.int64(32), v_j], lv5867[v_i // T.int64(32), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv5865[v_i // T.int64(8), v_j], T.Cast(\"uint32\", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v_i // T.int64(32), v_j] + lv5867[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv2705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "fused_fused_decode7_matmul4", "language": "python", "path": "mlc_llm/dispatch/dispatch_tir_operator_adreno.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to perform a matrix multiplication operation after decoding and transforming the input data. It is likely used in a neural network or similar computational model where such transformations and multiplications are common for processing the input data.\n2. **Input**: The function takes three primary inputs:\n - A 2D buffer of unsigned 32-bit integers,\n - A 2D buffer of 16-bit floating-point numbers,\n - A 3D buffer of 16-bit floating-point numbers.\n3. 
**Output**: It outputs a 3D buffer of 16-bit floating-point numbers, which represents the result of the matrix multiplication after the initial decoding and transformation of the input data.\n4. **Procedure**: The function proceeds through several steps:\n - **Decoding and Transformation**: Each element of the first input buffer is decoded and transformed using bitwise operations and then multiplied by corresponding elements from the second input buffer.\n - **Matrix Multiplication**: The transformed data is then used in a matrix multiplication with the third input buffer.\n - **Aggregation**: The results of the multiplication are aggregated into the output buffer.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(\n lv20_local[v0, v1],\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n )\n T.writes(\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n )\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = (\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n + var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n * lv20_local[v0, v1]\n )\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_update\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(352256),\n i0_i1_i2_fused_0 * T.int64(2048)\n + i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.reads(var_matmul_intermediate_local[v0, v1, v_i2k])\n T.writes(lv1654_shared[v0, v1, v2])\n lv1654_shared[v0, v1, v2] = var_matmul_intermediate_local[\n v0, v1, v_i2k\n ]\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"reduction_1\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v_i2k = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(8))\n T.reads(lv1654_shared[v0, v1, v_i2k])\n T.writes(lv1654_shared[v0, v1, v_i2k])\n lv1654_shared[v0, v1, v_i2k] = (\n lv1654_shared[v0, v1, v_i2k]\n + lv1654_shared[v0, v1, v_i2k + T.int64(32)]\n )\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"reduction_2\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v_i2k = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(4))\n T.reads(lv1654_shared[v0, v1, v_i2k])\n T.writes(lv1654_shared[v0, v1, v_i2k])\n lv1654_shared[v0, v1, v_i2k] = (\n lv1654_shared[v0, v1, v_i2k]\n + lv1654_shared[v0, v1, v_i2k + T.int64(16)]\n )\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(128)\n + i0_i1_i2_fused_1 * 
T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(1))\n T.reads(lv1654_shared[v0, v1, v_i2k])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = (\n lv1654_shared[v0, v1, v_i2k]\n + lv1654_shared[v0, v1, v_i2k + T.int64(4)]\n + lv1654_shared[v0, v1, v_i2k + T.int64(8)]\n + lv1654_shared[v0, v1, v_i2k + T.int64(12)]\n )\n\n\n@T.prim_func(private=True)\n\ndef fused_fused_decode7_matmul4(\n lv3: T.Buffer((T.int64(512), T.int64(12288)), \"uint32\"),\n lv4: T.Buffer((T.int64(128), T.int64(12288)), \"float16\"),\n lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n var_matmul_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(12288)), \"float16\"\n ),\n):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(12288)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv3[v_i // T.int64(8), v_j], lv4[v_i // T.int64(32), v_j])\n T.writes(p_output0_intermediate[v_i, v_j])\n p_output0_intermediate[v_i, v_j] = (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv3[v_i // T.int64(8), v_j],\n T.Cast(\"uint32\", v_i % T.int64(8)) * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n ) * lv4[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(12288), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1615[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = (\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n + lv1615[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2]\n )\n\n\n@T.prim_func(private=True)\ndef fused_fused_decode7_matmul4_after(\n lv3: T.Buffer((T.int64(512), T.int64(12288)), \"uint32\"),\n lv4: T.Buffer((T.int64(128), T.int64(12288)), \"float16\"),\n lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n var_matmul_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(12288)), \"float16\"\n ),\n):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_matmul_intermediate_local = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(24576)), \"float16\", scope=\"local\"\n )\n var_matmul_intermediate_local_batch = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(24576)), \"float16\", scope=\"local\"\n )\n lv3_local = T.alloc_buffer((T.int64(512), T.int64(12288)), \"uint32\", scope=\"local\")\n lv4_local = T.alloc_buffer((T.int64(128), T.int64(12288)), \"float16\", scope=\"local\")\n lv1615_shared = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(1024)), \"float16\", scope=\"shared\"\n )\n for i0_i1_i2_fused_0 in T.thread_binding(T.int64(48), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * 
T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2_init,\n )\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0)\n for k_0 in range(T.int64(4)):\n for ax2_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2_2 in T.vectorized(T.int64(8)):\n with T.block(\"lv1615_shared\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(4096),\n k_0 * T.int64(1024)\n + ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2,\n )\n v2k = T.axis.spatial(\n T.int64(1024),\n (\n ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2\n ),\n )\n T.reads(lv1615[v0, v1, v2])\n T.writes(lv1615_shared[v0, v1, v2k])\n lv1615_shared[v0, v1, v2k] = lv1615[v0, v1, v2]\n for k_1 in range(T.int64(16)):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init_local\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n T.reads()\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = T.float16(0)\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv4_local\"):\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(32)\n + (k_1 * T.int64(2) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(12288),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv4[v0, v1])\n T.writes(lv4_local[v0, v1])\n lv4_local[v0, v1] = lv4[v0, v1]\n for k_2 in range(T.int64(4)):\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv3_local\"):\n v0 = T.axis.spatial(\n T.int64(512),\n k_0 * T.int64(128)\n + (k_1 * T.int64(2) + ax2_y) * T.int64(4)\n + k_2\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(12288),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv3[v0, v1])\n T.writes(lv3_local[v0, v1])\n lv3_local[v0, v1] = lv3[v0, v1]\n for k_3 in range(T.int64(8)):\n for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(12288),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_i2k = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_k = T.axis.reduce(\n T.int64(4096),\n k_0 * T.int64(1024)\n + (k_1 * T.int64(2) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n v_ki = T.axis.reduce(\n T.int64(1024),\n (k_1 * T.int64(2) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n T.reads(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n lv1615_shared[v_i0, v_i1, v_ki],\n lv3_local[v_k // T.int64(8), v_i2],\n )\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] + 
lv1615_shared[\n v_i0, v_i1, v_ki\n ] * (\n (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv3_local[\n v_k // T.int64(8), v_i2\n ],\n T.Cast(\n \"uint32\",\n v_k % T.int64(8),\n )\n * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n )\n )\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"multiple_scale\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(32)\n + (k_1 * T.int64(2) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(12288),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(\n lv4_local[v0, v1],\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n )\n T.writes(\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n )\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = (\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n + var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n * lv4_local[v0, v1]\n )\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_update\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.reads(var_matmul_intermediate_local[v0, v1, v_i2k])\n T.writes(lv1615_shared[v0, v1, v2])\n lv1615_shared[v0, v1, v2] = var_matmul_intermediate_local[\n v0, v1, v_i2k\n ]\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(12288),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(1))\n T.reads(lv1615_shared[v0, v1, v_i2k])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = (\n lv1615_shared[v0, v1, v_i2k]\n + lv1615_shared[v0, v1, v_i2k + T.int64(4)]\n )\n\n\n@T.prim_func(private=True)\ndef fused_decode5_fused_matmul6_silu1(\n lv1611: T.Buffer((T.int64(512), T.int64(11008)), \"uint32\"),\n lv1612: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"),\n lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n p_output0_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(11008)), \"float16\"\n ),\n):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(11008)), \"float16\"\n )\n compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(11008)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1611[v_i // T.int64(8), v_j], lv1612[v_i // T.int64(32), v_j])\n 
T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv1611[v_i // T.int64(8), v_j],\n T.Cast(\"uint32\", v_i % T.int64(8)) * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n ) * lv1612[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = (\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n )\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(compute[v_i0, v_i1, v_i2])\n compute[v_i0, v_i1, v_i2] = T.sigmoid(\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n )\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(\n var_matmul_intermediate[v_ax0, v_ax1, v_ax2],\n compute[v_ax0, v_ax1, v_ax2],\n )\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = (\n var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n * compute[v_ax0, v_ax1, v_ax2]\n )\n\n\n@T.prim_func(private=True)\ndef fused_decode5_fused_matmul6_silu1_after(\n lv1611: T.Buffer((T.int64(512), T.int64(11008)), \"uint32\"),\n lv1612: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"),\n lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n p_output0_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(11008)), \"float16\"\n ),\n):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_matmul_intermediate_local = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\", scope=\"local\"\n )\n var_matmul_intermediate_local_batch = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\", scope=\"local\"\n )\n lv1611_local = T.alloc_buffer(\n (T.int64(512), T.int64(11008)), \"uint32\", scope=\"local\"\n )\n lv1612_local = T.alloc_buffer(\n (T.int64(128), T.int64(11008)), \"float16\", scope=\"local\"\n )\n lv1622_shared = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(1024)), \"float16\", scope=\"shared\"\n )\n for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2_init,\n )\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0)\n for k_0 in range(T.int64(4)):\n for ax2_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in 
T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2_2 in T.vectorized(T.int64(8)):\n with T.block(\"lv1622_shared\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(4096),\n k_0 * T.int64(1024)\n + ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2,\n )\n v2k = T.axis.spatial(\n T.int64(1024),\n (\n ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2\n ),\n )\n T.reads(lv1622[v0, v1, v2])\n T.writes(lv1622_shared[v0, v1, v2k])\n lv1622_shared[v0, v1, v2k] = lv1622[v0, v1, v2]\n for k_1 in range(T.int64(16)):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init_local\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n T.reads()\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = T.float16(0)\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv1612_local\"):\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(32)\n + (k_1 * T.int64(2) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv1612[v0, v1])\n T.writes(lv1612_local[v0, v1])\n lv1612_local[v0, v1] = lv1612[v0, v1]\n for k_2 in range(T.int64(4)):\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv1611_local\"):\n v0 = T.axis.spatial(\n T.int64(512),\n k_0 * T.int64(128)\n + (k_1 * T.int64(2) + ax2_y) * T.int64(4)\n + k_2\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv1611[v0, v1])\n T.writes(lv1611_local[v0, v1])\n lv1611_local[v0, v1] = lv1611[v0, v1]\n for k_3 in range(T.int64(8)):\n for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_k = T.axis.reduce(\n T.int64(4096),\n k_0 * T.int64(1024)\n + (k_1 * T.int64(2) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n v_ki = T.axis.reduce(\n T.int64(1024),\n (k_1 * T.int64(2) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n T.reads(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n lv1622_shared[v_i0, v_i1, v_ki],\n lv1611_local[v_k // T.int64(8), v_i2],\n )\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] + lv1622_shared[\n v_i0, v_i1, v_ki\n ] * (\n (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv1611_local[\n v_k // T.int64(8), v_i2\n ],\n T.Cast(\n \"uint32\",\n v_k % T.int64(8),\n )\n * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n )\n )\n for ax0 in range(T.int64(1)):\n for ax1 in 
T.vectorized(T.int64(4)):\n with T.block(\"multiple_scale\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(32)\n + (k_1 * T.int64(2) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(\n lv1612_local[v0, v1],\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n )\n T.writes(\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n )\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = (\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n + var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n * lv1612_local[v0, v1]\n )\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_update\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.reads(var_matmul_intermediate_local[v0, v1, v_i2k])\n T.writes(lv1622_shared[v0, v1, v2])\n lv1622_shared[v0, v1, v2] = var_matmul_intermediate_local[\n v0, v1, v_i2k\n ]\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"reduction\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(1))\n T.reads(lv1622_shared[v0, v1, v2])\n T.writes(lv1622_shared[v0, v1, v2])\n lv1622_shared[v0, v1, v2] = (\n lv1622_shared[v0, v1, v2]\n + lv1622_shared[v0, v1, v2 + T.int64(4)]\n )\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(1))\n T.reads(lv1622_shared[v0, v1, v_i2k])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv1622_shared[\n v0, v1, v_i2k\n ] * T.sigmoid(lv1622_shared[v0, v1, v_i2k])\n\n\ndef sch_fused_decode5_fused_matmul6_silu1(func):\n sch = tvm.tir.Schedule(func)\n b0 = sch.get_block(name=\"decode\", func_name=\"main\")\n b1 = sch.get_block(name=\"matmul\", func_name=\"main\")\n l2, l3, l4, l5 = sch.get_loops(block=b1)\n l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True)\n v7, v8, v9 = sch.sample_perfect_tile(\n loop=l6, n=3, max_innermost_factor=4, decision=[43, 64, 4]\n )\n l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True)\n v13, v14, v15 = sch.sample_perfect_tile(\n loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8]\n )\n l16, l17, l18 = sch.split(\n loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True\n )\n sch.reorder(l10, l11, l16, l17, l18, l12)\n sch.bind(loop=l10, 
thread_axis=\"blockIdx.x\")\n sch.bind(loop=l11, thread_axis=\"threadIdx.x\")\n sch.compute_inline(block=b0)\n b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope=\"local\")\n sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1)\n b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope=\"local\")\n b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope=\"local\")\n b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope=\"shared\")\n sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1)\n v23 = sch.sample_categorical(\n candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3\n )\n sch.annotate(\n block_or_loop=b22, ann_key=\"meta_schedule.cooperative_fetch\", ann_val=v23\n )\n sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1)\n sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1)\n l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20)\n sch.vectorize(loop=l29)\n l30, l31, l32, l33, l34 = sch.get_loops(block=b21)\n sch.vectorize(loop=l34)\n l35, l36, l37, l38, l39 = sch.get_loops(block=b19)\n sch.vectorize(loop=l39)\n sch.vectorize(loop=l12)\n b40 = sch.decompose_reduction(block=b1, loop=l16)\n b41 = sch.get_block(name=\"compute\", func_name=\"main\")\n sch.compute_inline(block=b41)\n b42 = sch.get_block(name=\"T_multiply\", func_name=\"main\")\n sch.reverse_compute_inline(block=b42)\n sch.enter_postproc()\n sch.unannotate(block_or_loop=b22, ann_key=\"meta_schedule.cooperative_fetch\")\n l43, l44, l45, l46, l47 = sch.get_loops(block=b22)\n l48, l49, l50 = sch.split(loop=l47, factors=[None, 64, 8], preserve_unit_iters=True)\n sch.vectorize(loop=l50)\n sch.bind(loop=l49, thread_axis=\"threadIdx.x\")\n return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n@T.prim_func(private=True)\ndef fused_decode81_fused_matmul1_cast2(\n lv1576: T.Buffer((T.int64(512), T.int64(64000)), \"uint32\"),\n lv1577: T.Buffer((T.int64(128), T.int64(64000)), \"float16\"),\n lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n p_output0_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(64000)), \"float32\"\n ),\n):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(64000)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(64000)), \"float16\"\n )\n for i, j in T.grid(T.int64(4096), T.int64(64000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n...\n", "needle_token_start": 1982, "needle_token_end": 2874, "code_context_ntokens": 16324, "output": ["```python\n@T.prim_func(private=True)\ndef fused_decode81_fused_matmul1_cast2(\n lv1576: T.Buffer((T.int64(512), T.int64(64000)), \"uint32\"),\n lv1577: T.Buffer((T.int64(128), T.int64(64000)), \"float16\"),\n lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n p_output0_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(64000)), \"float32\"\n ),\n):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(64000)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(64000)), \"float16\"\n )\n for i, j in T.grid(T.int64(4096), T.int64(64000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n var_decode_intermediate[v_i, v_j] = (\n T.Cast(\n \"float16\",\n 
T.bitwise_and(\n T.shift_right(\n lv1576[v_i // T.int64(8), v_j],\n T.Cast(\"uint32\", v_i % T.int64(8)) * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n ) * lv1577[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(64000), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = (\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n )\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(64000)):\n with T.block(\"T_cast\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])\n p_output0_intermediate[v_i0, v_i1, v_i2] = T.cast(\n \"float32\", var_matmul_intermediate[v_i0, v_i1, v_i2]\n )\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "contain_symbolic_var", "language": "python", "path": "mlc_llm/transform/lift_tir_global_buffer_alloc.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function checks if a tensor's shape contains any symbolic (non-constant) variables.\n2. **Input**: A structure containing information about a tensor, specifically its shape.\n3. **Output**: A boolean value indicating whether the tensor's shape includes any symbolic variables.\n4. **Procedure**: The function iterates through the values defining the tensor's shape. If any of these values are not integer constants, it returns true, indicating the presence of symbolic variables; otherwise, it returns false.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: mlc_llm/transform/fuse_split_rotary_embedding.py\nimport tvm\nfrom tvm import relax\nfrom tvm.relax.dpl import (\n PatternContext,\n is_op,\n rewrite_bindings,\n wildcard,\n is_tuple_get_item,\n GlobalVarPattern,\n TuplePattern,\n is_shape,\n)\nfrom tvm.script import relax as R, tir as T\n\n\ndef get_dynamic_split_rotary():\n \"\"\"Implementation of R.split(rotary_embedding(fused_qkv))\n\n Implementation is generic over the number of query heads,\n key/value heads, sequence length, head dimension, and position\n embedding base. 
These parameters can be replaced with static\n values using `PrimFunc.specialize`.\n \"\"\"\n\n @T.prim_func(private=True)\n def split_rotary(\n fused_qkv_handle: T.handle,\n embedded_query_handle: T.handle,\n embedded_key_handle: T.handle,\n value_handle: T.handle,\n rotary_offset: T.int64,\n batch_size: T.int64,\n seq_len: T.int64,\n num_query_heads: T.int64,\n num_kv_heads: T.int64,\n head_dim: T.int64,\n position_embedding_base: T.float32,\n ):\n Fused_QKV = T.match_buffer(\n fused_qkv_handle,\n [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim],\n dtype=\"float16\",\n )\n EmbeddedQuery = T.match_buffer(\n embedded_query_handle,\n [batch_size, seq_len, num_query_heads, head_dim],\n dtype=\"float16\",\n )\n EmbeddedKey = T.match_buffer(\n embedded_key_handle,\n [batch_size, seq_len, num_kv_heads, head_dim],\n dtype=\"float16\",\n )\n Value = T.match_buffer(\n value_handle,\n [batch_size, seq_len, num_kv_heads, head_dim],\n dtype=\"float16\",\n )\n\n T.func_attr({\"op_pattern\": 2, \"tir.noalias\": T.bool(True)})\n\n for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim):\n with T.block(\"FusedRotaryEmbeddingAndSplitQKV\"):\n batch_i, seq_i, head_num, head_i = T.axis.remap(\"SSSS\", iters)\n pos: T.float32 = T.Cast(\"float32\", rotary_offset + seq_i - seq_len)\n\n inv_freq: T.float32 = T.float32(1) / T.pow(\n position_embedding_base,\n T.Cast(\"float32\", (head_i * 2) % head_dim) / T.float32(head_dim),\n )\n freq: T.float32 = pos * inv_freq\n cos_value: T.float16 = T.Cast(\"float16\", T.cos(freq))\n sin_value: T.float16 = T.Cast(\"float16\", T.sin(freq))\n\n input_value = Fused_QKV[batch_i, seq_i, head_num, head_i]\n embedded_value = cos_value * input_value + sin_value * T.Select(\n head_i < T.int64(head_dim // 2),\n Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)]\n * T.float16(-1),\n Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)],\n )\n if head_num < num_query_heads:\n EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value\n elif head_num < num_query_heads + num_kv_heads:\n EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value\n else:\n Value[\n batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i\n ] = input_value\n\n param_sinfo = []\n for param in split_rotary.params:\n if param in split_rotary.buffer_map:\n buf = split_rotary.buffer_map[param]\n sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype)\n else:\n sinfo = relax.PrimStructInfo(param.dtype)\n param_sinfo.append(sinfo)\n\n relax.expr._update_struct_info(\n split_rotary,\n tvm.relax.FuncStructInfo(\n params=param_sinfo,\n ret=relax.TupleStructInfo([]),\n purity=False,\n ),\n )\n\n return split_rotary\n\n\ndef fuse_split_rotary_embedding(\n num_query_heads, num_kv_heads, hidden_size, position_embedding_base\n):\n @tvm.ir.transform.module_pass(opt_level=0, name=\"fuse_split_rotary_embedding\")\n def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule:\n head_dim = hidden_size // num_query_heads\n split_rotary = get_dynamic_split_rotary()\n\n (\n dyn_batch_size,\n dyn_seq_len,\n dyn_num_query_heads,\n dyn_num_kv_heads,\n dyn_head_dim,\n dyn_position_embedding_base,\n ) = split_rotary.params[-6:]\n\n split_rotary = split_rotary.specialize(\n {\n # Static model parameters\n dyn_batch_size: T.int64(1),\n dyn_num_query_heads: T.int64(num_query_heads),\n dyn_num_kv_heads: T.int64(num_kv_heads),\n dyn_head_dim: T.int64(head_dim),\n dyn_position_embedding_base: 
T.float32(position_embedding_base),\n # Dynamic parameters, to be inferred from TIR Buffer shapes\n dyn_seq_len: tvm.tir.Var(\"query_sequence_length\", \"int64\"),\n }\n )\n\n mod[\"split_rotary\"] = split_rotary\n\n split_rotary_gvar = mod.get_global_var(\"split_rotary\")\n relax.expr._update_struct_info(split_rotary_gvar, mod[\"split_rotary\"].struct_info)\n\n with PatternContext() as ctx:\n # flat_qkv_tuple: R.Tuple(\n # R.Tensor((batch_size, seq_len, 4096), dtype=\"float16\"),\n # R.Tensor((batch_size, seq_len, 4096), dtype=\"float16\"),\n # R.Tensor((batch_size, seq_len, 4096), dtype=\"float16\"),\n # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2)\n #\n # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype=\"float16\") = flat_qkv_tuple[0]\n # query: R.Tensor((batch_size, seq_len, 32, 128), dtype=\"float16\") = R.reshape(\n # flat_query, R.shape([batch_size, seq_len, 32, 128])\n # )\n # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype=\"float16\") = flat_qkv_tuple[1]\n # key: R.Tensor((batch_size, seq_len, 32, 128), dtype=\"float16\") = R.reshape(\n # flat_key, R.shape([batch_size, seq_len, 32, 128])\n # )\n # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype=\"float16\") = flat_qkv_tuple[2]\n # value: R.Tensor((batch_size, seq_len, 32, 128), dtype=\"float16\") = R.reshape(\n # flat_value, R.shape([batch_size, seq_len, 32, 128])\n # )\n # embedded_query = R.call_tir(\n # cls.rotary_embedding1,\n # [query],\n # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype=\"float16\"),\n # tir_vars=R.shape([n]),\n # )\n # embedded_key = R.call_tir(\n # cls.rotary_embedding1,\n # [key],\n # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype=\"float16\"),\n # tir_vars=R.shape([n]),\n # )\n\n pat_rotary_embedding_gvar = GlobalVarPattern()\n\n pat_flat_fused_qkv = wildcard()\n pat_offset = wildcard()\n\n # query_shape = is_shape([1, seq_len, num_query_heads, head_dim])\n pat_query_shape = wildcard()\n # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim])\n pat_key_shape = wildcard()\n # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim])\n pat_value_shape = wildcard()\n\n pat_flat_qkv_tuple = is_op(\"relax.split\")(pat_flat_fused_qkv)\n pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0)\n pat_query = is_op(\"relax.reshape\")(\n pat_flat_query, pat_query_shape, add_constraint=False\n )\n pat_flat_query.used_by(pat_query)\n pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1)\n pat_key = is_op(\"relax.reshape\")(pat_flat_key, pat_key_shape, add_constraint=False)\n pat_flat_key.used_by(pat_key)\n pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2)\n pat_value = is_op(\"relax.reshape\")(\n pat_flat_value, pat_value_shape, add_constraint=False\n )\n pat_flat_value.used_by(pat_value)\n\n pat_embedded_query = is_op(\"relax.call_tir\")(\n pat_rotary_embedding_gvar,\n TuplePattern([pat_query]),\n pat_offset,\n add_constraint=False,\n )\n pat_embedded_key = is_op(\"relax.call_tir\")(\n pat_rotary_embedding_gvar,\n TuplePattern([pat_key]),\n pat_offset,\n add_constraint=False,\n )\n\n pat_flat_qkv_tuple.used_by(pat_flat_query)\n pat_flat_qkv_tuple.used_by(pat_flat_key)\n pat_flat_qkv_tuple.used_by(pat_flat_value)\n pat_query.used_by(pat_embedded_query)\n pat_key.used_by(pat_embedded_key)\n\n def rewriter(matchings, bindings):\n # Extracting all the relax and TIR variables that we'll need\n flat_fused_qkv = matchings[pat_flat_fused_qkv]\n flat_qkv_tuple = matchings[pat_flat_qkv_tuple]\n\n flat_query = 
matchings[pat_flat_query]\n...\n# Path: mlc_llm/transform/lift_tir_global_buffer_alloc.py\n\"\"\"Lift global buffer allocation in TIR to graph level\"\"\"\n\nfrom typing import Dict, List, Tuple, Optional\n\nimport tvm\nfrom tvm import relax, tir\nfrom tvm.ir.module import IRModule\nfrom tvm.relax.analysis import remove_all_unused\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\n\n\ndef remove_global_buf_alloc(\n func: tir.PrimFunc,\n) -> Optional[Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]]:\n \"\"\"Remove the global buffer allocation for a given TIR PrimFunc.\"\"\"\n if not isinstance(func.body, tir.BlockRealize):\n return None\n\n params = list(func.params)\n buffer_map = dict(func.buffer_map)\n tensor_sinfo = []\n alloc_buffers = []\n\n insertion_point = len(params)\n while params[insertion_point - 1].dtype != \"handle\":\n insertion_point -= 1\n assert insertion_point >= 1\n\n prev_root_block = func.body.block\n for buf_alloc in func.body.block.alloc_buffers:\n if buf_alloc.scope() == \"global\":\n param = tir.Var(\"var_\" + buf_alloc.name, \"handle\")\n params.insert(insertion_point, param)\n insertion_point += 1\n buffer_map[param] = buf_alloc\n tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype))\n else:\n alloc_buffers.append(buf_alloc)\n\n if len(tensor_sinfo) == 0:\n return None\n\n assert len(prev_root_block.iter_vars) == 0\n assert len(prev_root_block.reads) == 0\n assert len(prev_root_block.writes) == 0\n assert len(prev_root_block.match_buffers) == 0\n assert prev_root_block.name_hint == \"root\"\n assert prev_root_block.init is None\n root_block = tir.Block(\n iter_vars=[],\n reads=[],\n writes=[],\n name_hint=\"root\",\n body=prev_root_block.body,\n alloc_buffers=alloc_buffers,\n annotations=prev_root_block.annotations,\n )\n\n updated_func = tir.PrimFunc(\n params=params,\n body=tir.BlockRealize(iter_values=[], predicate=True, block=root_block),\n ret_type=func.ret_type,\n buffer_map=buffer_map,\n attrs=func.attrs,\n )\n return updated_func, tensor_sinfo\n\n\n\ndef contain_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool:\n assert isinstance(tensor_sinfo.shape, relax.ShapeExpr)\n for v in tensor_sinfo.shape.values:\n if not isinstance(v, tir.IntImm):\n return True\n return False\n\n\ndef resolve_tir_var_mapping(\n func: tir.PrimFunc, call: relax.Call, tensor_sinfo: List[relax.TensorStructInfo]\n) -> Tuple[List[relax.TensorStructInfo], bool]:\n \"\"\"Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function\"\"\"\n var_map: Dict[tir.Var, tir.PrimExpr] = dict()\n\n n_arg = len(call.args[1].fields)\n for i in range(n_arg):\n buffer_shape = func.buffer_map[func.params[i]].shape\n arg_shape = call.args[1][i].struct_info.shape.values\n assert len(buffer_shape) == len(arg_shape)\n for vl, vr in zip(buffer_shape, arg_shape):\n if isinstance(vl, tir.Var):\n var_map[vl] = vr\n elif not isinstance(vl, tir.IntImm):\n return [], False\n\n ret_tensors = call.sinfo_args[0]\n ret_tensors = (\n [ret_tensors]\n if isinstance(ret_tensors, relax.TensorStructInfo)\n else list(ret_tensors.fields)\n )\n for i in range(len(ret_tensors)):\n buffer_shape = func.buffer_map[func.params[n_arg + i]].shape\n ret_tensor_shape = ret_tensors[i].shape.values\n assert len(buffer_shape) == len(ret_tensor_shape)\n for vl, vr in zip(buffer_shape, ret_tensor_shape):\n if isinstance(vl, tir.Var):\n var_map[vl] = vr\n elif not isinstance(vl, tir.IntImm):\n return [], False\n\n updated_tensor_sinfo = []\n for sinfo in tensor_sinfo:\n 
if not contain_symbolic_var(sinfo):\n updated_tensor_sinfo.append(sinfo)\n continue\n\n new_shape = []\n for v in sinfo.shape.values:\n new_shape.append(tir.stmt_functor.substitute(v, var_map))\n updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype))\n return updated_tensor_sinfo, True\n\n\ndef LiftTIRGlobalBufferAlloc():\n @mutator\n class TIRGlobalAllocRewriter(PyExprMutator):\n def __init__(self, mod: IRModule):\n super().__init__(mod)\n self.mod = mod\n\n def transform(self) -> IRModule:\n self.mod = self.builder_.get()\n for gv, func in self.mod.functions.items():\n if isinstance(func, relax.Function):\n updated_func = self.visit_expr(func)\n self.builder_.update_func(gv, updated_func)\n return self.builder_.get()\n\n def visit_call_(self, call: relax.Call):\n call = self.visit_expr_post_order(call)\n if call.op != tvm.ir.Op.get(\"relax.call_tir\"):\n return call\n\n old_gvar = call.args[0]\n\n func_before_update = self.mod.functions[old_gvar]\n updates = remove_global_buf_alloc(func_before_update)\n if updates is None:\n return call\n updated_func, tensor_sinfo = updates\n\n assert len(call.sinfo_args) == 1\n if any(contain_symbolic_var(sinfo) for sinfo in tensor_sinfo):\n tensor_sinfo, success = resolve_tir_var_mapping(\n func_before_update, call, tensor_sinfo\n )\n if not success:\n # Cannot resolve TIR var mapping. Fall back to no lifting.\n return call\n\n new_gvar = self.builder_.add_func(updated_func, old_gvar.name_hint)\n new_args = [new_gvar, *call.args[1:]]\n\n if isinstance(call.sinfo_args[0], relax.TensorStructInfo):\n new_call = relax.Call(\n call.op,\n args=new_args,\n sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)],\n attrs=call.attrs,\n )\n emitted_tuple = self.builder_.emit(new_call)\n return relax.TupleGetItem(emitted_tuple, 0)\n elif isinstance(call.sinfo_args[0], relax.TupleStructInfo):\n return relax.Call(\n call.op,\n args=new_args,\n sinfo_args=[\n relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)\n ],\n attrs=call.attrs,\n )\n else:\n raise TypeError(\n f\"Expected {call.op} to return either R.Tensor or R.Tuple, \"\n f\"but instead returned {call.sinfo_args[0]}\"\n )\n\n @tvm.transform.module_pass(opt_level=0, name=\"LiftTIRGlobalBufferAlloc.Inner\")\n def transform_module(mod: IRModule, _: tvm.transform.PassContext) -> IRModule:\n return TIRGlobalAllocRewriter(mod).transform()\n\n return tvm.ir.transform.Sequential(\n [\n transform_module,\n tvm.relax.transform.DeadCodeElimination(),\n ],\n name=\"LiftTIRGlobalBufferAlloc\",\n )\n\n# Path: mlc_llm/transform/reorder_transform_func.py\nfrom typing import Callable, Dict, List, Set, Tuple, Optional\n\nimport tvm\nfrom tvm import relax\nfrom tvm.ir.module import IRModule\n\n\"\"\"\nThis pass in this file reorders the bindings of the weight transform function\naccording to the weight location in binary files. The goal of the reorder is to\nreduce the memory pressure when loading the raw model weights and processing\nthem. In the ideal case, with this pass, the highest CPU memory usage will\naround the size of the largest raw weight binary file.\n\nRegarding the implementation, the bindings of fetching a raw weight in the\nweight transform function are all in the form of `lv = params[idx]`. Here, each\nindex specifies a raw weight tensor, and the raw weight tensor resides in a\nbinary file on the disk.\n\nWe group such `lv = params[idx]` into multiple groups, such that all raw weight\ntensors in a group come from a same binary file. 
We reorder the bindings\naccording to the grouping result based on topological sort.\n\nIn ideal case, after reordering the weight transform function has the following\nprocess during execution:\n* load a weight binary file,\n* process all weights in this file,\n* load another weight binary file,\n* process all weights in this file,\n* ...\n\nSo the maximum CPU memory usage will be the size of the largest raw weight\nbinary file, since we process and release all the raw weight tensors immediately\nafter loading them from the file.\n\"\"\"\n\n\ndef analyze_func(\n func: relax.Function,\n pidx2binname: Dict[int, str],\n) -> Tuple[List[relax.Binding], Dict[relax.Var, List[relax.Binding]], Dict[relax.Binding, int],]:\n \"\"\"Binding grouping analysis function.\n It takes the function to be analyzed, and mapping from each raw tensor index\n to the name of the binary file where it resides.\n\n This analysis function\n * computes a new order of weight fetching bindings (the bindings in form\n `lv = params[idx]`) based on weight location on disk.\n * collects the dataflow def-use information of the given function for\n topological sort (particularly, it collects the consumers of each binding\n variables and the number of variables each binding depends on).\n\n Parameters\n ----------\n func : relax.Function\n The weight transform function to be analyzed.\n\n pidx2binname : Dict[int, str]\n The mapping from each raw tensor index to the name of the binary\n file where it resides.\n\n Returns\n -------\n get_param_bindings : List[relax.Binding]\n The weight fetching bindings (`lv = params[idx]`) in the new order.\n\n var_users : Dict[relax.Var, List[relax.Binding]]\n The consumer bindings of each binding variable.\n Used for topological sort.\n\n num_depending_vars : Dict[relax.Binding, int]\n The number of variables each binding depends on.\n Used for topological sort.\n \"\"\"\n\n # The mapping of the weight fetching bindings in each binary file.\n # Here empty string means the weight is not in any binary file (e.g., cached\n # sin and cos values for rotary embeddings).\n binname2get_param_bindings: Dict[str, List[relax.Binding]] = {\"\": []}\n # The set of binding variables.\n binding_var_set: Set[relax.Var] = set()\n var_users: Dict[relax.Var, List[relax.Binding]] = {}\n num_depending_vars: Dict[relax.Binding, int] = {}\n\n if func.attrs is not None and \"num_input\" in func.attrs:\n num_input = func.attrs[\"num_input\"].value\n else:\n num_input = 0\n\n # Sanity check on the function pattern.\n assert isinstance(func.body, relax.SeqExpr)\n assert len(func.body.blocks) == 1\n assert isinstance(func.body.blocks[0], relax.DataflowBlock)\n assert func.body.blocks[0].bindings[-1].var.same_as(func.body.body)\n\n if isinstance(func.params[num_input].struct_info, relax.TupleStructInfo):\n model_param_tuple = func.params[num_input]\n else:\n model_param_tuple = None\n for i, var in enumerate(func.params[num_input:]):\n binname = pidx2binname.get(i, var.name_hint)\n if binname not in binname2get_param_bindings:\n binname2get_param_bindings[binname] = []\n binname2get_param_bindings[binname].append(var)\n\n bindings = list(func.body.blocks[0].bindings)\n\n # Go through each binding except the last one. 
(The last one is the output\n # binding `gv = (lv, lv1, ...)`) which we ignore for analysis.\n for binding in bindings[:-1]:\n value = binding.value\n binding_var_set.add(binding.var)\n var_users[binding.var] = []\n\n if (\n model_param_tuple is not None\n and isinstance(value, relax.TupleGetItem)\n and value.tuple_value.same_as(model_param_tuple)\n ):\n # For weight fetching bindings (`lv = params[idx]`), we group them\n # according to the binary file name.\n pidx = value.index\n if pidx not in pidx2binname:\n binname2get_param_bindings[\"\"].append(binding)\n continue\n\n binname = pidx2binname[pidx]\n if binname in binname2get_param_bindings:\n binname2get_param_bindings[binname].append(binding)\n else:\n binname2get_param_bindings[binname] = [binding]\n else:\n # For other bindings, we collect the use-def information for\n # topological sort.\n num_depending_vars[binding] = 0\n\n def fvisit(obj):\n if isinstance(obj, relax.Var) and obj in binding_var_set:\n assert obj in var_users\n var_users[obj].append(binding)\n num_depending_vars[binding] += 1\n\n relax.analysis.post_order_visit(value, fvisit)\n\n # Get the weight fetching bindings in new order according to the group results.\n get_param_bindings: List[relax.Binding] = []\n for bindings in binname2get_param_bindings.values():\n get_param_bindings += bindings\n\n return get_param_bindings, var_users, num_depending_vars\n\n\ndef reorder_func(\n func: relax.Function,\n pidx2binname: Optional[Dict[int, str]] = None,\n) -> relax.Function:\n \"\"\"Reorder the bindings of the input weight transform Relax function\n according the weight location in binary files.\n\n This function first analyzes the input function and gets the reordered\n weight fetching bindings and the use-def information for topological sort.\n It then reorders all bindings in the function with topological sort.\n\n Parameters\n ----------\n func : relax.Function\n The weight transform function to be analyzed.\n\n pidx2binname : Optional[Dict[int, str]]\n\n The mapping from each raw tensor index to the name of the\n binary file where it resides. 
If a relax dataflow graph has\n multiple valid topological sorts, the order that minimizes the\n number of simultaneously open files will be produced\n\n If `None` (default), the existing order of relax bindings is\n preserved in these cases.\n\n Returns\n -------\n func_updated : relax.Function\n The returned function where the bindings are updated with the new order.\n\n \"\"\"\n\n if pidx2binname is None:\n pidx2binname = {}\n\n bindings_to_visit = list(func.body.blocks[0].bindings)\n param_lookup = {param: i for i, param in enumerate(func.params)}\n binding_lookup = {}\n previously_defined = set(func.params)\n new_binding_order = []\n\n param_tuple = None\n if len(func.params) == 1 and isinstance(func.params[0].struct_info, relax.TupleStructInfo):\n param_tuple = func.params[0]\n\n def sort_key(i):\n binding = bindings_to_visit[i]\n upstream_vars = relax.analysis.free_vars(binding.value)\n\n valid_ordering = all(var in previously_defined for var in upstream_vars)\n last_param_used = max(\n (param_lookup[var] for var in upstream_vars if var in param_lookup), default=-1\n )\n earliest_binding_used = min(\n (binding_lookup[var] for var in upstream_vars if var in binding_lookup), default=-1\n )\n if (\n param_tuple\n and isinstance(binding.value, relax.TupleGetItem)\n and binding.value.tuple_value.same_as(param_tuple)\n and binding.value.index in pidx2binname\n ):\n tuple_param_group = pidx2binname[binding.value.index]\n else:\n tuple_param_group = \"\"\n\n return [\n # First, sort by valid orderings, so the min element will\n # always be a binding that would be legal to use.\n -valid_ordering,\n # Next, sort by the function parameter used by this\n # binding, in increasing order. That way, we start by\n # computing everything that required just the first\n # parameter, then move on to variables that can be\n # computed with the first two parameters, and so on.\n last_param_used,\n # Next, sort by the other bindings used. This way, for\n # variables that are only used as input in a single\n # downstream binding, the variable's required live range\n # is minimized.\n -earliest_binding_used,\n # Finally, if this is a `TupleGetItem(param_tuple, i)`,\n # select the option that uses an already-open file. 
This\n # is mainly used relevant when loading from pytorch, which\n # require loading the entire file at once.\n tuple_param_group,\n ]\n\n while bindings_to_visit:\n i_binding = min(range(len(bindings_to_visit)), key=sort_key)\n binding = bindings_to_visit.pop(i_binding)\n\n assert all(var in previously_defined for var in relax.analysis.free_vars(binding.value))\n new_binding_order.append(binding)\n previously_defined.add(binding.var)\n\n assert len(new_binding_order) == len(func.body.blocks[0].bindings)\n\n return relax.Function(\n func.params,\n relax.SeqExpr(\n blocks=[relax.DataflowBlock(new_binding_order)],\n body=func.body.body,\n ),\n func.ret_struct_info,\n func.is_pure,\n func.attrs,\n )\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"ReorderTransformFunc\")\nclass ReorderTransformFunc:\n def __init__(self, pidx2binname: Optional[Dict[int, str]] = None):\n if pidx2binname is None:\n pidx2binname = {}\n self.pidx2binname = pidx2binname\n\n def transform_module(\n self,\n mod: IRModule,\n ctx: tvm.transform.PassContext,\n ) -> IRModule:\n mod = mod.clone()\n for gv, func in list(mod.functions.items()):\n if isinstance(func, relax.Function) and func.attrs and \"global_symbol\" in func.attrs:\n assert gv.name_hint.endswith(\"transform_params\")\n func_updated = reorder_func(func, self.pidx2binname)\n mod[gv] = func_updated\n return mod\n\n# Path: mlc_llm/transform/rewrite_attention.py\nimport tvm\nfrom tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard\nfrom tvm.script import relax as R\n\n\ndef rewrite_attention(use_flash_mqa=False):\n @tvm.ir.transform.module_pass(opt_level=0, name=\"mlc_llm.transform.rewrite_attention\")\n def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule:\n Q = wildcard()\n K = wildcard()\n V = wildcard()\n\n Q_BNSH = is_op(\"relax.permute_dims\")(Q)\n\n if use_flash_mqa:\n K_BNSH = is_op(\"relax.permute_dims\")(is_op(\"relax.repeat\")(K))\n V_BNSH = is_op(\"relax.permute_dims\")(is_op(\"relax.repeat\")(V))\n else:\n K_BNSH = is_op(\"relax.permute_dims\")(K)\n V_BNSH = is_op(\"relax.permute_dims\")(V)\n\n K_BNSH_T = is_op(\"relax.permute_dims\")(K_BNSH)\n\n matmul1 = is_op(\"relax.matmul\")(Q_BNSH, K_BNSH_T)\n divide = is_op(\"relax.divide\")(matmul1, is_const())\n max = is_op(\"relax.maximum\")(divide, is_const())\n min = is_op(\"relax.minimum\")(max, wildcard())\n softmax = is_op(\"relax.nn.softmax\")(is_op(\"relax.astype\")(min))\n matmul2 = is_op(\"relax.matmul\")(is_op(\"relax.astype\")(softmax), V_BNSH)\n\n pattern = is_op(\"relax.permute_dims\")(matmul2)\n\n def callback(_, matchings):\n return R.nn.attention(\n matchings[Q], matchings[K], matchings[V], causal_mask=\"BottomRight\"\n )\n\n new_module = {}\n for gvar, func in mod.functions.items():\n if isinstance(func, tvm.relax.Function):\n func = rewrite_call(pattern, callback, func)\n new_module[gvar] = func\n\n return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos)\n\n return ir_module_transform\n\n# Path: mlc_llm/transform/set_entry_funcs.py\nimport re\n\nfrom typing import List, Union\n\nimport tvm\nfrom tvm.ir import GlobalVar\n\n\ndef SetEntryFuncs(*entry_funcs: List[Union[GlobalVar, str]]) -> tvm.ir.transform.Pass:\n \"\"\"Update which functions are externally-exposed\n\n All functions whose GlobalVar is contained `entry_funcs` list, or\n whose name matches a regular expression in `entry_funcs`, are set\n as externally exposed. 
All other functions are set as internal.\n\n This pass does not add or remove any functions from the\n `IRModule`. This pass may result in functions no longer being\n used by any externally-exposed function. In these cases, users\n may use the `relax.transform.DeadCodeElimination` pass to remove\n any unnecessary functions.\n\n Parameters\n ----------\n entry_funcs: List[Union[GlobalVar, str]]\n\n Specifies which functions that should be externally exposed,\n either by GlobalVar or by regular expression.\n\n Returns\n -------\n transform: tvm.ir.transform.Pass\n\n The IRModule-to-IRModule transformation\n \"\"\"\n\n def is_entry_func(gvar: GlobalVar) -> bool:\n for entry_func in entry_funcs:\n if isinstance(entry_func, GlobalVar):\n if entry_func.same_as(gvar):\n return True\n elif isinstance(entry_func, str):\n if re.fullmatch(entry_func, gvar.name_hint):\n return True\n else:\n raise TypeError(\n f\"SetEntryFuncs requires all arguments to be a GlobalVar or a str. \"\n f\"However, argument {entry_func} has type {type(entry_func)}.\"\n )\n\n def is_exposed(func: tvm.ir.BaseFunc) -> bool:\n return func.attrs is not None and \"global_symbol\" in func.attrs\n\n @tvm.ir.transform.module_pass(opt_level=0, name=\"SetEntryFuncs\")\n def transform(mod: tvm.IRModule, _pass_context) -> tvm.IRModule:\n updates = {}\n for gvar, func in mod.functions.items():\n if is_entry_func(gvar):\n if not is_exposed(func):\n updates[gvar] = func.with_attr(\"global_symbol\", gvar.name_hint)\n else:\n if is_exposed(func):\n updates[gvar] = func.without_attr(\"global_symbol\")\n\n if updates:\n mod = mod.clone()\n mod.update(updates)\n\n return mod\n\n return transform\n\n# Path: mlc_llm/transform/transpose_matmul.py\nimport tvm\nfrom tvm import IRModule, relax, te, tir\nfrom tvm.relax.dpl.pattern import is_op, wildcard\n\n\n@relax.expr_functor.mutator\nclass TransposeMatmulCodeGenerator(relax.PyExprMutator):\n def __init__(self, mod):\n super().__init__(mod)\n\n @staticmethod\n def pattern():\n w = wildcard()\n x = wildcard()\n wT = is_op(\"relax.permute_dims\")(w)\n o = is_op(\"relax.matmul\")(x, wT)\n annotations = {\"o\": o, \"w\": w, \"x\": x, \"wT\": wT}\n\n def _check(context: relax.transform.PatternCheckContext) -> bool:\n transpose_call = context.annotated_expr[\"wT\"]\n ndim = transpose_call.args[0].struct_info.ndim\n if ndim == -1:\n return False\n if ndim == 2 and transpose_call.attrs.axes is None:\n return True\n axes = list(range(ndim))\n axes[-1], axes[-2] = axes[-2], axes[-1]\n return list(transpose_call.attrs.axes) == axes\n\n return o, annotations, _check\n\n def visit_call_(self, call: relax.Call) -> relax.Expr:\n out_dtype = None\n\n def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:\n nonlocal out_dtype\n a_shape = list(a.shape)\n b_shape = list(b.shape)\n a_prepended = False\n b_appended = False\n if len(a_shape) == 1:\n a_prepended = True\n a_shape.insert(0, 1)\n if len(b_shape) == 1:\n b_appended = True\n b_shape.append(1)\n\n is_a_larger = len(a_shape) > len(b_shape)\n offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape)\n\n a_relax = relax.Var(\"a\", relax.TensorStructInfo(a.shape))\n bT_shape = list(b.shape)\n bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1]\n bT_relax = relax.Var(\"b\", relax.TensorStructInfo(bT_shape))\n output_shape = self.builder_.normalize(\n relax.op.matmul(a_relax, bT_relax)\n ).struct_info.shape\n\n def matmul_compute(*idx_spatial):\n k = te.reduce_axis((0, a_shape[-1]), name=\"k\")\n\n def 
multiply_compute(idx_reduce):\n a_indices = []\n b_indices = []\n\n for i in range(offset):\n if is_a_larger:\n a_indices.append(idx_spatial[i])\n else:\n b_indices.append(idx_spatial[i])\n for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)):\n a_dim = a_shape[i if is_a_larger else i - offset]\n b_dim = b_shape[i if not is_a_larger else i - offset]\n dim_equal = a_dim == b_dim\n if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0:\n a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1\n b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1\n a_indices.append(0 if a_dim_is_one else idx_spatial[i])\n b_indices.append(0 if b_dim_is_one else idx_spatial[i])\n else:\n a_indices.append(idx_spatial[i])\n b_indices.append(idx_spatial[i])\n\n if not a_prepended:\n a_indices.append(idx_spatial[-2 + b_appended])\n a_indices.append(idx_reduce)\n if not b_appended:\n b_indices.append(idx_spatial[-1])\n b_indices.append(idx_reduce)\n\n dtype = out_dtype\n if dtype != \"\":\n return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)\n return a(*a_indices) * b(*b_indices)\n\n return te.sum(multiply_compute(k), axis=k)\n\n return te.compute(\n output_shape,\n lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda\n name=\"NT_matmul\",\n )\n\n if isinstance(call.op, relax.GlobalVar):\n function = self.builder_.get()[call.op]\n if (\n function.attrs\n and \"Composite\" in function.attrs\n and function.attrs[\"Composite\"] == \"transpose_matmul_fuse\"\n ):\n out_dtype = function.ret_struct_info.dtype\n return self.builder_.call_te(\n te_transposed_matmul,\n call.args[1],\n call.args[0],\n primfunc_name_hint=\"NT_matmul\",\n )\n\n return super().visit_call_(call)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseTransposeMatmul\")\nclass FuseTransposeMatmul:\n def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:\n mod = relax.transform.FuseOpsByPattern(\n [(\"transpose_matmul_fuse\", *TransposeMatmulCodeGenerator.pattern())]\n )(mod)\n\n transpose_matmul_codegen = TransposeMatmulCodeGenerator(mod)\n for gv in mod.functions:\n func = mod[gv]\n if not isinstance(func, relax.Function):\n continue\n func = transpose_matmul_codegen.visit_expr(func)\n transpose_matmul_codegen.builder_.update_func(gv, func)\n\n return transpose_matmul_codegen.builder_.get()\n\n@relax.expr_functor.mutator\nclass Transpose1MatmulCodeGenerator(relax.PyExprMutator):\n def __init__(self, mod):\n super().__init__(mod)\n\n @staticmethod\n def pattern():\n w = wildcard()\n x = wildcard()\n xT = is_op(\"relax.permute_dims\")(x)\n wT = is_op(\"relax.permute_dims\")(w)\n o = is_op(\"relax.matmul\")(xT, wT)\n annotations = {\"o\": o, \"w\": w, \"x\": x, \"xT\": xT, \"wT\": wT}\n\n def _check(context: relax.transform.PatternCheckContext) -> bool:\n x_transpose_call = context.annotated_expr[\"o\"]\n w_transpose_call = context.annotated_expr[\"o\"]\n x_shape = context.annotated_expr[\"x\"].struct_info.shape\n w_shape = context.annotated_expr[\"w\"].struct_info.shape\n xT_shape = x_transpose_call.args[0].struct_info.shape\n wT_shape = w_transpose_call.args[1].struct_info.shape\n\n if not (\n xT_shape[0] == x_shape[0] and xT_shape[1] == x_shape[2]\n and xT_shape[2] == x_shape[1] and xT_shape[3] == x_shape[3]\n ):\n return False\n\n if not (\n wT_shape[0] == w_shape[0] and wT_shape[1] == w_shape[2]\n and wT_shape[2] == w_shape[3] and wT_shape[3] == w_shape[1]\n ):\n return False\n\n return True\n\n return o, annotations, _check\n\n def 
visit_call_(self, call: relax.Call) -> relax.Expr:\n out_dtype = None\n\n def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:\n nonlocal out_dtype\n a_shape = list(a.shape)\n b_shape = list(b.shape)\n\n aT_shape = list(a.shape)\n aT_shape[-2], aT_shape[-3] = aT_shape[-3], aT_shape[-2]\n aT_relax = relax.Var(\"a\", relax.TensorStructInfo(aT_shape))\n bT_shape = list(b.shape)\n bT_shape[-1], bT_shape[-2], bT_shape[-3] = bT_shape[-3], bT_shape[-1], bT_shape[-2]\n bT_relax = relax.Var(\"b\", relax.TensorStructInfo(bT_shape))\n output_shape = self.builder_.normalize(\n relax.op.matmul(aT_relax, bT_relax)\n ).struct_info.shape\n def matmul_compute(*idx_spatial):\n k = te.reduce_axis((0, a_shape[-1]), name=\"k\")\n def multiply_compute(idx_reduce):\n a_indices = [idx_spatial[0], idx_spatial[2], idx_spatial[1], idx_reduce]\n b_indices = [idx_spatial[0], idx_spatial[3], idx_spatial[1], idx_reduce]\n dtype = out_dtype\n if dtype != \"\":\n return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)\n return a(*a_indices) * b(*b_indices)\n\n return te.sum(multiply_compute(k), axis=k)\n\n return te.compute(\n output_shape,\n lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda\n name=\"NT_matmul\",\n )\n\n if isinstance(call.op, relax.GlobalVar):\n function = self.builder_.get()[call.op]\n if (\n \"Composite\" in function.attrs\n and function.attrs[\"Composite\"] == \"transpose1_matmul_fuse\"\n ):\n out_dtype = function.ret_struct_info.dtype\n return self.builder_.call_te(\n te_transposed_matmul,\n call.args[0],\n call.args[1],\n primfunc_name_hint=\"NT_matmul\",\n )\n\n return super().visit_call_(call)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseTranspose1Matmul\")\nclass FuseTranspose1Matmul:\n def transform_module(\n self, mod: IRModule, ctx: tvm.transform.PassContext\n ) -> IRModule:\n mod = relax.transform.FuseOpsByPattern(\n [(\"transpose1_matmul_fuse\", *Transpose1MatmulCodeGenerator.pattern())]\n )(mod)\n\n transpose_matmul_codegen = Transpose1MatmulCodeGenerator(mod)\n for gv in mod.functions:\n func = mod[gv]\n if not isinstance(func, relax.Function):\n continue\n func = transpose_matmul_codegen.visit_expr(func)\n transpose_matmul_codegen.builder_.update_func(gv, func)\n\n return transpose_matmul_codegen.builder_.get()\n\n\n@relax.expr_functor.mutator\nclass Transpose2MatmulCodeGenerator(relax.PyExprMutator):\n def __init__(self, mod):\n super().__init__(mod)\n\n @staticmethod\n def pattern():\n w = wildcard()\n x = wildcard()\n wT = is_op(\"relax.permute_dims\")(w)\n o = is_op(\"relax.permute_dims\")(is_op(\"relax.matmul\")(x, wT))\n #oT = is_op(\"relax.permute_dims\")(o)\n annotations = {\"o\": o, \"w\": w, \"x\": x, \"wT\": wT}\n\n def _check(context: relax.transform.PatternCheckContext) -> bool:\n w_transpose_call = context.annotated_expr[\"wT\"]\n w_shape = w_transpose_call.args[0].struct_info.shape\n wT_shape = w_transpose_call.struct_info.shape\n oT_call = context.annotated_expr[\"o\"]\n o_shape = oT_call.args[0].struct_info.shape\n oT_shape = oT_call.struct_info.shape\n\n if not (\n wT_shape[0] == w_shape[0] and wT_shape[1] == w_shape[2]\n and wT_shape[2] == w_shape[1] and wT_shape[3] == w_shape[3]\n ):\n return False\n\n if not (\n oT_shape[0] == o_shape[0] and oT_shape[1] == o_shape[2]\n and oT_shape[2] == o_shape[1] and oT_shape[3] == o_shape[3]\n ):\n return False\n\n return True\n\n return o, annotations, _check\n\n def visit_call_(self, call: relax.Call) -> relax.Expr:\n out_dtype = None\n\n def te_transposed_matmul(a: 
te.Tensor, b: te.Tensor) -> te.Tensor:\n nonlocal out_dtype\n a_shape = list(a.shape)\n b_shape = list(b.shape)\n output_shape = [a_shape[0], b_shape[-2], a_shape[2], a_shape[3]]\n def matmul_compute(*idx_spatial):\n k = te.reduce_axis((0, b_shape[-1]), name=\"k\")\n def multiply_compute(idx_reduce):\n a_indices = [idx_spatial[0], idx_reduce, idx_spatial[2], idx_spatial[3]]\n b_indices = [idx_spatial[0], idx_spatial[2], idx_spatial[1], idx_reduce]\n\n dtype = out_dtype\n if dtype != \"\":\n return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)\n return a(*a_indices) * b(*b_indices)\n\n return te.sum(multiply_compute(k), axis=k)\n\n return te.compute(\n output_shape,\n lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda\n name=\"NT_matmul\",\n )\n\n if isinstance(call.op, relax.GlobalVar):\n function = self.builder_.get()[call.op]\n if (\n \"Composite\" in function.attrs\n and function.attrs[\"Composite\"] == \"transpose2_matmul_fuse\"\n ):\n out_dtype = function.ret_struct_info.dtype\n #NT_output_shape = function.ret_struct_info.shape\n return self.builder_.call_te(\n te_transposed_matmul,\n call.args[0],\n call.args[1],\n primfunc_name_hint=\"NT_matmul\",\n )\n\n return super().visit_call_(call)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseTranspose2Matmul\")\nclass FuseTranspose2Matmul:\n def transform_module(\n self, mod: IRModule, ctx: tvm.transform.PassContext\n ) -> IRModule:\n mod = relax.transform.FuseOpsByPattern(\n [(\"transpose2_matmul_fuse\", *Transpose2MatmulCodeGenerator.pattern())]\n )(mod)\n\n transpose_matmul_codegen = Transpose2MatmulCodeGenerator(mod)\n for gv in mod.functions:\n func = mod[gv]\n if not isinstance(func, relax.Function):\n continue\n func = transpose_matmul_codegen.visit_expr(func)\n transpose_matmul_codegen.builder_.update_func(gv, func)\n\n return transpose_matmul_codegen.builder_.get()\n\n# Path: mlc_llm/transform/__init__.py\nfrom .clean_up_tir_attrs import CleanUpTIRAttrs\nfrom .decode_matmul_ewise import FuseDecodeMatmulEwise\nfrom .decode_take import FuseDecodeTake\nfrom .decode_transpose import FuseDecodeTranspose\nfrom .fuse_split_rotary_embedding import fuse_split_rotary_embedding\nfrom .lift_tir_global_buffer_alloc import LiftTIRGlobalBufferAlloc\nfrom .reorder_transform_func import ReorderTransformFunc\nfrom .rewrite_attention import rewrite_attention\nfrom .transpose_matmul import FuseTransposeMatmul, FuseTranspose1Matmul, FuseTranspose2Matmul\nfrom .set_entry_funcs import SetEntryFuncs\n\n# Path: mlc_llm/relax_model/param_manager.py\nimport json\nimport os\nfrom typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union\n\nimport tvm\nfrom torch import Tensor as torchTensor\nfrom tvm import relax, tir\nfrom tvm._ffi.runtime_ctypes import Device\nfrom tvm.relax.analysis import remove_all_unused\nfrom tvm.relax.expr import Expr, Function, Var\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\nfrom tvm.relax.testing import nn\n\nfrom .. import quantization\nfrom .modules import named_parameters\nfrom ..transform import ReorderTransformFunc\n\n\ndef f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any:\n \"\"\"The defualt `f_compute_relax_param` for ParamManager.\n See ParamManager for more details.\n \"\"\"\n raise NotImplementedError()\n\n\nclass Parameter:\n \"\"\"The abstraction of weight tensors (e.g., linear layer weight, embedding\n table, etc.) 
in a model.\n\n Attributes\n ----------\n name : str\n The name of the parameter.\n The name of a weight is got by `named_parameters()` method, similar to\n PyTorch's `named_parameters()` function.\n An example name is `model.layers.11.self_attn.k_proj.weight`.\n In a model, the name is the **unique** identifier of a parameter.\n\n param_info_dict : Dict[str, relax.TensorStructInfo]\n The shape and dtype of the parameter in each function.\n The shape can be accessed by `param_info_dict[func_name].shape`, which is\n a relax.ShapeExpr instance.\n And the dtype can be accessed by `param_info_dict[func_name].dtype`,\n which is a Python string.\n\n quant_spec : quantization.QuantizationSpec\n The quantization specification of this parameter.\n It specifies the algorithm to quantize and dequantize this parameter (or\n this parameter does not need quantization).\n\n shard_dim : Optional[int]\n The dimension to be sharded.\n\n shard_strategy : Optional[str]\n The strategy to shard the parameter.\n \"\"\"\n\n name: str\n param_info_dict: Dict[str, relax.TensorStructInfo]\n quant_spec: quantization.QuantizationSpec\n shard_dim: Optional[int]\n shard_strategy: Optional[str]\n\n def __init__(\n self,\n name: str,\n quant_spec: quantization.QuantizationSpec,\n shard_dim: Optional[int],\n shard_strategy: Optional[str],\n ) -> None:\n self.name = name\n self.param_info_dict = dict()\n self.quant_spec = quant_spec\n self.shard_dim = shard_dim\n self.shard_strategy = shard_strategy\n\n def register_func(self, func_name: str, param_info: relax.TensorStructInfo):\n self.param_info_dict[func_name] = param_info\n\n @property\n def param_info(self):\n \"\"\"Return the shape and dtype of the parameter (in some arbitrary function).\"\"\"\n return next(iter(self.param_info_dict.values()))\n\n\nclass ParamManager:\n \"\"\"The model-wise data structure which contains the information of every\n weight in the model and is in charge of applying quantization and dequantization\n to the parameters at the entire model level.\n\n Attributes\n ----------\n params : Dict[str, Parameter]\n The mapping from parameter names to parameters.\n\n param_names : List[str]\n The name list of all the parameters.\n To enforce a unique order or all the parameters for determinism, the\n parameter names are kept in the list, and the parameter order is\n uniquely determined by the parameter name list.\n\n func_raw_param_map : Dict[relax.Var, Tuple[str, Parameter]]\n The mapping from each relax.Var that denotes a weight parameter to the\n name of the function the var is in (e.g., \"prefill\" or \"decode\"), and\n the Parameter it corresponds to.\n This mapping is used for applying quantization transformation to the\n Relax functions (e.g., the \"prefill\", \"decode\", etc.) 
in the model.\n\n param2qrange : Dict[Parameter, range]\n The mapping from each parameter to the range of its quantized tensors\n in the list of quantized tensors of all parameters.\n Each parameter is quantized into multiple tensors.\n For example, assume we have parameters `p0`, `p1`, `p2`.\n - `p0` is quantized into `t0_0`, `t0_1`,\n - `p1` is quantized into `t1_0`, and\n - `p2` is quantized into `t2_0`, `t2_1` and `t2_2`.\n Then the list of all quantized tensors is `[t0_0, t0_1, t1_0, t2_0, t2_1, t2_2]`,\n and the dict `param2qrange` is\n `{p0: range(0, 2), p1: range(2, 3), p2: range(3, 6)}`.\n\n f_convert_pname_fwd : Callable[[str], List[str]]\n The function which converts Relax parameter name (ours) to torch's\n parameter names, suggesting \"to load this Relax parameter, which torch\n parameter(s) are needed\".\n - Usually, the function maps a name to itself. For example, in LLaMA we\n map `lm_head.weight` itself, as the parameter has the same name on both\n Relax side and torch side.\n - In some cases we map a name to multiple names. For example, if we\n support combined QKV computing when the torch side separates them, on\n Relax side we only have one QKV weight, while on torch side we have\n one weight for each of Q, K, V. In this case, we map one name to three\n...\n", "needle_token_start": 4008, "needle_token_end": 4086, "code_context_ntokens": 16163, "output": ["```python\ndef contain_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool:\n assert isinstance(tensor_sinfo.shape, relax.ShapeExpr)\n for v in tensor_sinfo.shape.values:\n if not isinstance(v, tir.IntImm):\n return True\n return False\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "get_param_set_item", "language": "python", "path": "mlc_llm/relax_model/param_manager.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to provide a mechanism for lazy loading of parameters in a machine learning model, specifically for use with a parameter transformation and loading framework. It facilitates the registration of a function that can be used to set parameters at specific indices during model execution.\n2. **Input**: This function does not take any direct inputs when called.\n3. **Output**: It returns a tuple containing a callable function and a list. The callable function is used to set a parameter at a specified index, and the list is intended to store these parameters once they are loaded and set.\n4. **Procedure**: \n - Initializes an empty list to store parameters.\n - Defines a nested function that takes an index and a parameter. This function checks if the list is long enough to accommodate the parameter at the given index, extends the list if necessary, and then sets the parameter at the specified index.\n - Returns the nested function and the initially empty list. 
The list is populated with parameters as the nested function is called during the execution of a compiled function in a machine learning framework.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " f_convert_param_bkwd: Callable[[str, Any], Optional[List[Tuple[str, Any]]]]\n f_compute_relax_param: Callable[[str, List[Any]], Any]\n f_run_prequantize: Optional[Callable[[str], str]]\n\n model_path: str\n use_safetensors: bool\n safetensors_load_func: Callable[[Union[str, os.PathLike], str], Dict[str, torchTensor]]\n pidx2pname: Dict[int, str]\n torch_pname2binname: Dict[str, str]\n\n def __init__(self) -> None:\n self.params = {}\n self.param_names = []\n self.params_in_func = {}\n\n self.func_raw_param_map = {}\n self.param2qrange = None\n\n self.nparam_to_load = None\n self.f_convert_pname_fwd = None\n self.f_convert_param_bkwd = None\n self.f_compute_relax_param = None\n self.f_run_prequantize = None\n\n self.qspec_updater_classes = []\n\n def register_params(\n self,\n model: nn.Module,\n func_name: str,\n quantization_scheme: quantization.QuantizationScheme,\n f_get_param_quant_kind: Callable[\n [str, relax.TensorStructInfo], quantization.ParamQuantKind\n ],\n ) -> None:\n \"\"\"Register the parameters of the input model (within the context of the\n input function) in the parameter manager.\n\n Parameters\n ----------\n model : nn.Module\n The input model whose parameters are registered.\n\n func_name : str\n The name of the function the input model is in.\n For example, the \"prefill\" function or the \"decode\" function.\n\n quantization_scheme : quantization.QuantizationScheme\n The quantization scheme of the input model, which describes how\n to quantize the model.\n\n f_get_param_quant_kind: Callable[[str, relax.TensorStructInfo], quantization.ParamQuantKind]\n A function which takes the name and StructInfo (effectively shape\n and dtype) of a parameter, and returns which quantization kind this\n parameter uses.\n This is used for applying quantization to the parameters.\n \"\"\"\n if quantization_scheme.qspec_updater_class is not None:\n self.qspec_updater_classes.append(quantization_scheme.qspec_updater_class)\n if quantization_scheme.f_convert_param_bkwd is not None:\n self.f_convert_param_bkwd = quantization_scheme.f_convert_param_bkwd\n if quantization_scheme.f_compute_relax_param is not None:\n self.f_compute_relax_param = quantization_scheme.f_compute_relax_param\n if quantization_scheme.f_run_prequantize is not None:\n self.f_run_prequantize = quantization_scheme.f_run_prequantize\n\n self.params_in_func[func_name] = []\n # For each parameter in the input model, get its quantization kind and\n # register the parameter with its name and quantization kind.\n for name, relax_param in named_parameters(model).items():\n quant_kind = f_get_param_quant_kind(name, relax_param.struct_info)\n param = self._register_param(\n name,\n relax_param,\n getattr(quantization_scheme, quant_kind.name),\n func_name,\n relax_param.__dict__.get(\"shard_dim\", None),\n relax_param.__dict__.get(\"shard_strategy\", None),\n )\n\n self.params_in_func[func_name].append(param)\n\n def run_pre_quantize(self, model_path: str):\n if self.f_run_prequantize is not None:\n model_path = self.f_run_prequantize(model_path)\n\n self.model_path = model_path\n return model_path\n\n def 
init_torch_pname_to_bin_name(self, use_safetensors: bool):\n assert hasattr(self, \"model_path\"), (\n \"Must call either set_param_loading_func or run_pre_quantize \"\n \"before init_torch_pname_to_bin_name\"\n )\n\n if self.pidx2pname:\n mapping = load_torch_pname2binname_map(\n self.model_path,\n use_safetensors,\n set(self.pidx2pname.values()),\n self.f_convert_pname_fwd,\n )\n else:\n mapping = {}\n\n self.torch_pname2binname = mapping\n\n def set_param_loading_func(\n self,\n model_path: str,\n use_safetensors: bool,\n f_convert_pname_fwd: Callable[[str], List[str]] = lambda pname: [pname],\n f_convert_param_bkwd: Callable[\n [str, Any], Optional[List[Tuple[str, Any]]]\n ] = lambda pname, torch_param: [(pname, torch_param)],\n f_compute_relax_param: Callable[[str, List[Any]], Any] = f_default_compute_relax_param,\n *,\n no_lazy_param_loading: bool = False,\n ) -> None:\n \"\"\"Set the parameter loading functions.\n\n Parameters\n ----------\n model_path : str\n The path of the Hugging Face model on disk.\n\n use_safetensors : bool\n Whether to use ``.safetensors`` instead of ``.bin`` to load model.\n\n f_convert_pname_fwd : Callable[[str], List[str]]\n The function which converts Relax parameter name (ours) to torch's\n parameter names. See the document of ParamManager for more details.\n\n f_convert_param_bkwd : Callable[[str, Any], Optional[List[Tuple[str, Any]]]]\n The function which converts torch parameter and param name back to\n Relax parameters with names. `Any` here stands for numpy.ndarray.\n See the document of ParamManager for more details.\n\n f_compute_relax_param : Callable[[str, List[Any]], Any]\n The function which computes a Relax parameter from a list of torch\n parameters. `Any` here stands for numpy.ndarray.\n See the document of ParamManager for more details.\n\n no_lazy_param_loading : bool\n A boolean indicating that no lazy parameter loading from torch is needed.\n This needs to be set as True when all the model weights are loaded\n at the time of constructing the model.\n \"\"\"\n self.f_convert_pname_fwd = f_convert_pname_fwd\n if self.f_convert_param_bkwd is None:\n self.f_convert_param_bkwd = f_convert_param_bkwd\n if self.f_compute_relax_param is None:\n self.f_compute_relax_param = f_compute_relax_param\n\n self.model_path = model_path\n self.use_safetensors = use_safetensors\n if self.use_safetensors:\n # Use a pointer here to prevent repeated import in tvm registered function\n from safetensors.torch import (\n load_file, # pylint: disable=import-outside-toplevel\n )\n\n def load_safetensors_func(*args):\n params = load_file(*args)\n for name, param in params.items():\n dtype = str(param.dtype)\n if dtype == \"torch.bfloat16\":\n param = param.float()\n params[name] = param\n return params\n\n self.safetensors_load_func = load_safetensors_func\n\n pnames_to_load = []\n for param_name in self.param_names:\n param = self.params[param_name]\n loaded_names, _ = param.quant_spec.get_loaded_tensor_info(param_name, param.param_info)\n pnames_to_load += loaded_names\n\n self.nparam_to_load = len(pnames_to_load)\n if not no_lazy_param_loading:\n self.pidx2pname = {pidx: pname for pidx, pname in enumerate(pnames_to_load)}\n else:\n self.pidx2pname = dict()\n\n def transform_dequantize(self) -> tvm.ir.transform.Pass:\n \"\"\"Apply dequantization to the input IRModule.\n\n Parameters\n ----------\n mod : tvm.IRModule\n The input IRModule to be applied dequantization.\n The IRModule contains all the constructed Relax functions\n (e.g., the \"prefill\"/\"decode\" 
functions) and is expected to\n have all of its parameters registered in the ParamManager.\n\n Returns\n -------\n updated_mod : tvm.IRModule\n The IRModule updated with the dequantization computation.\n \"\"\"\n\n @tvm.ir.transform.module_pass(opt_level=0, name=\"ParamManager.transform_dequantize\")\n def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule:\n # For each Relax function in the input IRModule (e.g., \"prefill\"),\n # we create its input relax.Var of all the quantized data, and\n # store the mapping from function name to the var.\n func_name_to_quantized_params: Dict[str, List[relax.Var]] = {}\n\n for gv, func in mod.functions.items():\n if isinstance(func, relax.Function) and func.attrs and \"num_input\" in func.attrs:\n func_name_to_quantized_params[gv.name_hint] = self.get_quantized_params(\n gv.name_hint\n )\n\n # Cache mapping to avoid duplicate dequantization.\n dequantized_cache: Dict[relax.Var, relax.Var] = {}\n\n # Define a var replacement function for applying dequantization.\n def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var:\n if var in dequantized_cache:\n return dequantized_cache[var]\n assert var in self.func_raw_param_map\n\n func_name, param = self.func_raw_param_map[var]\n quantized_params = func_name_to_quantized_params[func_name]\n relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]]\n\n dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name)\n\n dequantized_cache[var] = dequantized\n return dequantized\n\n # Create the function mutator for applying dequantization.\n replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace)\n # Update the input IRModule with dequantization.\n mod = replacer.transform()\n\n return mod\n\n return transform_func\n\n def get_quantized_params(self, func_name: str) -> List[relax.Var]:\n quantized_params: List[relax.Var] = []\n\n bb = relax.BlockBuilder()\n with bb.function(\"main\", []):\n self.param2qrange = dict()\n\n for name in self.param_names:\n param = self.params[name]\n param_info = None\n if func_name in param.param_info_dict:\n param_info = param.param_info_dict[func_name]\n else:\n param_info = relax.TensorStructInfo(\n tvm.ir.load_json(tvm.ir.save_json(param.param_info.shape)),\n param.param_info.dtype,\n )\n\n loaded_tensor_names, loaded_tensor_info = param.quant_spec.get_loaded_tensor_info(\n name, param_info\n )\n\n provided_tensor_vars: List[relax.Var] = [\n relax.Var(name, sinfo)\n for name, sinfo in zip(loaded_tensor_names, loaded_tensor_info)\n ]\n\n # Get the quantization function of this parameter.\n f_quantize = param.quant_spec.get_quantize_func(param_info)\n if f_quantize is None:\n # If the parameter does not have a quantization function, either it\n # does not need quantization or it is pre-quantized.\n self.param2qrange[param] = range(\n len(quantized_params),\n len(quantized_params) + len(provided_tensor_vars),\n )\n quantized_params.extend(provided_tensor_vars)\n else:\n # If the parameter has a quantization function, it is not expected\n # to be pre-quantized.\n assert len(provided_tensor_vars) == 1, (\n \"A parameter with quantization function is not expected \"\n \"to be pre-quantized.\"\n )\n\n # Apply the quantization function.\n quantized_data = bb.normalize(f_quantize(bb, provided_tensor_vars))\n if isinstance(quantized_data.struct_info, relax.TupleStructInfo):\n fields = quantized_data.struct_info.fields\n n_tensor = len(fields)\n assert n_tensor > 1\n # Record the range of quantized tensors of 
this parameter.\n self.param2qrange[param] = range(\n len(quantized_params),\n len(quantized_params) + n_tensor,\n )\n # Collect the quantized tensors to return.\n quantized_params.extend(\n relax.Var(f\"{name}.{field.dtype}.{i}\", field)\n for i, field in enumerate(fields)\n )\n\n else:\n field = quantized_data.struct_info\n assert isinstance(field, relax.TensorStructInfo)\n self.param2qrange[param] = range(\n len(quantized_params), len(quantized_params) + 1\n )\n quantized_params.append(relax.Var(f\"{name}.{field.dtype}\", field))\n bb.emit_func_output(relax.const(0, \"int64\"))\n\n return quantized_params\n\n def get_param_get_item(\n self, device: Device, model_params: List[Optional[tvm.nd.NDArray]] = []\n ) -> Callable:\n \"\"\"A wrapper function which returns the `get_item`\n functions for parameter lazy loading.\n\n The return value of this function is intended to be registered\n as `\"get_item\"`, for use in a module built with\n `LazyTransformParams`.\n\n .. code-block:: python\n\n get_item = manager.get_param_get_item(tvm.cuda())\n tvm.register_func(func_name=\"get_item\", f=get_item, override=True)\n compiled_function()\n\n Parameters\n ----------\n device : Device\n\n The device onto which tensor parameters should be loaded.\n\n model_params : List[Optional[tvm.nd.NDArray]]\n\n Any pre-loaded model parameters. For parameter at index\n `i`, if `model_params[i]` already contains an array, that\n array will be returned from `get_item`. Otherwise, the\n parameter will be loaded either from disk, or from an\n internal cache.\n\n Returns\n -------\n get_item: Callable[[int], tvm.nd.NDArray]\n\n A function that accepts an index, and returns the tensor\n parameter located at that index, loaded onto `device`.\n\n \"\"\"\n import torch # pylint: disable=import-outside-toplevel\n\n assert self.f_convert_pname_fwd is not None\n assert self.f_convert_param_bkwd is not None\n assert self.f_compute_relax_param is not None\n pname2pidx: Dict[str, int] = {pname: pidx for pidx, pname in self.pidx2pname.items()}\n\n # The set of indices of loaded parameters, serving for\n # robustness guarantee to avoid one parameter being loaded for\n # multiple times.\n loaded_idx_set: Set[int] = set()\n\n # The set of torch binary filenames, serving for robustness guarantee\n # to avoid one torch binary file being loaded for multiple times.\n loaded_torch_bins: Set[str] = set()\n\n # The set of cached Relax parameters.\n cached_relax_params: Dict[int, tvm.nd.NDArray] = {}\n\n # The set of cached torch parameters. 
`Any` here stands for\n # numpy.ndarray.\n cached_torch_params: Dict[str, Any] = {}\n\n device_cpu = tvm.cpu()\n\n def fetch_torch_param(torch_param):\n if str(torch_param.dtype) == \"torch.bfloat16\":\n # Convert to float32 first.\n return torch_param.detach().cpu().float().numpy()\n else:\n return torch_param.detach().cpu().numpy()\n\n def load_torch_params_from_bin(torch_binname: str):\n torch_binpath = os.path.join(self.model_path, torch_binname)\n torch_params = None\n if self.use_safetensors:\n torch_params = self.safetensors_load_func(torch_binpath)\n else:\n torch_params = torch.load(\n torch_binpath,\n map_location=torch.device(\"cpu\"),\n )\n torch_param_names = list(torch_params.keys())\n for torch_param_name in torch_param_names:\n torch_param = fetch_torch_param(torch_params[torch_param_name])\n del torch_params[torch_param_name]\n\n relax_params = self.f_convert_param_bkwd(torch_param_name, torch_param)\n if relax_params is not None:\n for param_name, param in relax_params:\n if param_name not in pname2pidx.keys():\n continue\n pidx = pname2pidx[param_name]\n assert pidx not in cached_relax_params\n cached_relax_params[pidx] = tvm.nd.array(param, device_cpu)\n else:\n assert torch_param_name not in cached_torch_params\n cached_torch_params[torch_param_name] = torch_param\n del torch_param\n\n def get_item(i):\n # If the weight is already provided by `model_params`, directly use it\n # and no need to load from binary file.\n if model_params and len(model_params) > i and model_params[i] is not None:\n assert i not in cached_relax_params\n return tvm.nd.array(model_params[i], device=device)\n\n # Otherwise, we load the weight from its corresponding binary file.\n assert i in self.pidx2pname\n relax_pname = self.pidx2pname[i]\n torch_pnames = self.f_convert_pname_fwd(relax_pname)\n\n if i not in cached_relax_params:\n for torch_binname in [\n self.torch_pname2binname[torch_pname] for torch_pname in torch_pnames\n ]:\n if torch_binname in loaded_torch_bins:\n continue\n load_torch_params_from_bin(torch_binname)\n loaded_torch_bins.add(torch_binname)\n\n if i not in cached_relax_params:\n assert len(torch_pnames) > 1\n assert all([torch_pname in cached_torch_params] for torch_pname in torch_pnames)\n cached_relax_params[i] = self.f_compute_relax_param(\n relax_pname,\n [cached_torch_params[torch_pname] for torch_pname in torch_pnames],\n )\n for torch_pname in torch_pnames:\n del cached_torch_params[torch_pname]\n\n assert i in cached_relax_params\n assert i not in loaded_idx_set\n param_on_device = tvm.nd.array(cached_relax_params[i], device=device)\n loaded_idx_set.add(i)\n del cached_relax_params[i]\n return param_on_device\n\n return get_item\n\n \ndef get_param_set_item(self) -> Tuple[Callable, List[tvm.nd.NDArray]]:\n \"\"\"A wrapper function which returns the `set_item`\n functions for parameter lazy loading.\n\n The return value of this function is intended to be registered\n as `\"set_item\"`, for use in a module built with\n `LazyTransformParams`.\n\n .. code-block:: python\n\n set_item,loaded_params = manager.get_param_set_item()\n tvm.register_func(func_name=\"set_item\", f=set_item, override=True)\n compiled_function()\n # `loaded_params` is now fully populated\n\n Returns\n -------\n set_item: Callable[[int,tvm.nd.NDArray]]\n\n A function that accepts an index and the return value at\n that index.\n\n loaded_params: List[tvm.nd.NDArray]\n\n A list of loaded parameters, populated by `set_item`.\n When initially returned, this list is empty. 
After\n executing the compiled function with\n `LazyTransformParams`, `loaded_params` will be\n populated.\n \"\"\"\n device_cpu = tvm.cpu()\n loaded_params: List[tvm.nd.NDArray] = []\n\n def set_item(i: int, computed_param: tvm.nd.NDArray):\n if len(loaded_params) <= i:\n loaded_params.extend([None for _ in range(i - len(loaded_params) + 1)])\n loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu)\n\n return set_item, loaded_params\n\n #################### Below are internally called methods ####################\n\n def _register_param(\n self,\n name: str,\n var: relax.Var,\n quant_spec: quantization.QuantizationSpec,\n func_name: str,\n shard_dim: Optional[int],\n shard_strategy: Optional[str],\n ) -> Parameter:\n \"\"\"Register a single parameter in the parameter manager.\n In most cases, this method is not directly used outside this class:\n it is called by `register_params` above.\n\n Parameters\n ----------\n name : str\n The name of the parameter to register.\n Name serves as the unique identifier of the parameter.\n\n var : relax.Var\n The parameter relax.Var on the nn.Module side.\n\n quant_spec : quantization.QuantizationSpec\n The quantization specification of the parameter\n\n func_name : str\n The name of the function the input var is in.\n For example, the \"prefill\" function or the \"decode\" function.\n\n shard_dim : Optional[int]\n The dimension along which the parameter is sharded.\n\n shard_strategy : Optional[str]\n The strategy of sharding the parameter.\n\n Returns\n -------\n param : Parameter\n The registered Parameter.\n \"\"\"\n assert (\n var not in self.func_raw_param_map\n ), \"The input var is not supposed to be already registered.\"\n assert isinstance(\n var.struct_info.shape, relax.ShapeExpr\n ), \"The parameter to register is expected to have shape as a tuple\"\n\n if name in self.params:\n # When the input name appears in `self.params`, it means the input\n # parameter has been previously registered in some other function.\n # Thus, we check if the dtype, shape and the quantization specification\n # of both sides are consistent.\n param = self.params[name]\n assert (\n param.quant_spec == quant_spec\n ), \"One parameter is expected to be quantized by single specification in all functions.\"\n assert (\n param.param_info.dtype == var.struct_info.dtype\n ), \"Dtype mismatch of one parameter in two functions.\"\n assert (\n param.param_info.ndim == var.struct_info.ndim\n ), \"Shape mismatch of one parameter in two functions.\"\n for len0, len1 in zip(param.param_info.shape.values, var.struct_info.shape.values):\n if isinstance(len0, tir.IntImm) and isinstance(len1, tir.IntImm):\n assert (\n len0.value == len1.value\n ), \"Shape mismatch of one parameter in two functions.\"\n else:\n # Otherwise, the parameter is registered for the first time.\n param = Parameter(name, quant_spec, shard_dim, shard_strategy)\n self.params[name] = param\n self.param_names.append(name)\n\n param.register_func(func_name, var.struct_info)\n # Record the mapping from the input relax.Var to the function name and\n # the parameter in the manager.\n self.func_raw_param_map[var] = (func_name, param)\n return param\n\n def _dequantize(\n self,\n param: Parameter,\n qparams: List[relax.Var],\n bb: relax.BlockBuilder,\n func_name: str,\n ) -> relax.Var:\n \"\"\"Applying dequantization to the input parameter.\n This method is called by `transform_module` below, and is not\n directly invoked outside the class.\n\n Parameters\n ----------\n param : Parameter\n The parameter 
whose quantized tensors are to be dequantized.\n\n qparams : List[relax.Var]\n The relax.Var of the quantized tensors of all parameters in the model.\n\n Returns\n -------\n The dequantized parameter, in the form of a relax.Var.\n \"\"\"\n # Get the dequantization function of this parameter.\n f_dequantize = param.quant_spec.get_dequantize_func(\n param_info=param.param_info_dict[func_name],\n qparam_info=[qparam.struct_info for qparam in qparams],\n )\n if f_dequantize is None:\n # If the parameter does not have a dequantization function, its \"quantized\n # data\" is expected to have only one element.\n assert len(qparams) == 1, (\n \"A parameter without dequantization function is expected not to have \"\n 'more than one \"quantized data\".'\n )\n return qparams[0]\n else:\n # Apply the dequantization function.\n return bb.emit(f_dequantize(bb, qparams))\n\n def create_parameter_transformation(self, optimize_parameter_order: bool = True):\n \"\"\"Produce an IRModule that can transform the parameters\n\n Parameters\n ----------\n optimize_parameter_order: bool\n\n If true, reorder the parameter transformations to\n prioritize operations that use a currently-open file. If\n false, transform the parameters in their default order.\n\n Returns\n -------\n tvm.IRModule\n The transformation module\n\n \"\"\"\n mod = _create_quantize_func(self)\n if optimize_parameter_order:\n mod = self.optimize_transform_param_order()(mod)\n return mod\n\n def optimize_transform_param_order(self) -> tvm.transform.Pass:\n \"\"\"Produce an transformation that optimizes for minimal memory footprint\n\n Returns\n -------\n tvm.transform.Pass\n The transformation\n \"\"\"\n\n pidx2binname: Dict[int, str] = {\n pidx: self.torch_pname2binname[self.f_convert_pname_fwd(pname)[0]]\n for pidx, pname in self.pidx2pname.items()\n if self.f_convert_pname_fwd(pname)[0] in self.torch_pname2binname\n }\n return ReorderTransformFunc(pidx2binname)\n\n\n@mutator\nclass ParamReplacer(PyExprMutator):\n \"\"\"The function mutator that updates the model with dequantization.\n\n Attributes\n ----------\n mod : tvm.IRModule\n The IRModule of the model to be updated.\n\n func_name_to_quantized_params : Dict[str, List[relax.Var]]\n The mapping from each function name to its input var of quantized data tuple.\n\n f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var]\n The function for updating a previous parameter in functions with dequantization.\n\n param_set : Set[relax.Var]\n The set of previous parameters (before applying quantization and dequantization)\n in the relax functions.\n \"\"\"\n\n mod: tvm.IRModule\n func_name_to_quantized_params: Dict[str, List[relax.Var]]\n f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var]\n param_set: Set[relax.Var]\n\n cur_func_name: str\n\n def __init__(\n self,\n mod: tvm.IRModule,\n func_name_to_quantized_params: Dict[str, relax.Var],\n f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var],\n ):\n super().__init__(mod)\n self.mod = mod\n self.func_name_to_quantized_params = func_name_to_quantized_params\n self.f_replace = f_replace\n self.cur_func_name = \"\"\n\n def transform(self) -> tvm.IRModule:\n for gv, func in self.mod.functions.items():\n if not isinstance(func, relax.Function):\n continue\n if func.attrs is None or not \"num_input\" in func.attrs:\n continue\n\n assert (\n gv.name_hint in self.func_name_to_quantized_params\n ), f\"{gv.name_hint} not in {self.func_name_to_quantized_params}\"\n updated_func = self.rewrite_func(func, 
self.func_name_to_quantized_params[gv.name_hint])\n updated_func = remove_all_unused(updated_func)\n self.builder_.update_func(gv, updated_func)\n return self.builder_.get()\n\n def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function:\n num_input = int(func.attrs[\"num_input\"])\n self.param_set = set(func.params[num_input:])\n\n body = self.visit_expr(func.body)\n return relax.Function(\n params=func.params[:num_input] + quantized_params,\n body=body,\n ret_struct_info=func.ret_struct_info,\n is_pure=func.is_pure,\n attrs=func.attrs,\n )\n\n def visit_var_(self, var: Var) -> Expr:\n if var in self.param_set:\n return self.f_replace(var, self.builder_)\n else:\n return super().visit_var_(var)\n\n\n##################################################################\n\n\ndef load_torch_pname2binname_map(\n model_path: str,\n use_safetensors: bool,\n relax_pnames: Set[str],\n f_convert_pname_fwd: Callable[[str], List[str]] = lambda pname: [pname],\n) -> Dict[str, str]:\n \"\"\"Constructing the dictionary from each torch parameter's name to\n the name of the binary shard where the torch parameter is saved.\n\n Parameters\n ----------\n model_path : str\n The path of the Hugging Face model on disk.\n\n use_safetensors: bool\n Whether to use ``.safetensors`` instead of ``.bin`` to load model.\n\n relax_pnames: Set[str]\n The name of the Relax parameters.\n\n f_convert_pname_fwd: Callable[[str], List[str]]\n The function which converts Relax parameter name to torch's\n parameter names. See ParamManager for more details.\n \"\"\"\n bin_idx_path = None\n single_shard_file_name = None\n if use_safetensors:\n bin_idx_path = os.path.join(model_path, \"model.safetensors.index.json\")\n single_shard_file_name = \"model.safetensors\"\n else:\n bin_idx_path = os.path.join(model_path, \"pytorch_model.bin.index.json\")\n single_shard_file_name = \"pytorch_model.bin\"\n single_shard_path = os.path.join(model_path, single_shard_file_name)\n\n if os.path.isfile(bin_idx_path):\n # Multiple weight shards.\n with open(bin_idx_path, \"r\") as f_torch_json:\n torch_bin_json = json.load(f_torch_json)\n torch_pname2binname = torch_bin_json[\"weight_map\"]\n elif os.path.isfile(single_shard_path):\n # Single weight shard.\n torch_pname2binname = {\n torch_pname: single_shard_file_name\n for relax_pname in relax_pnames\n for torch_pname in f_convert_pname_fwd(relax_pname)\n }\n else:\n suffix = \".safetensors\" if use_safetensors else \".bin\"\n shard_names = []\n # Collect Scan every single file with the suffix\n for filename in os.listdir(model_path):\n if filename.endswith(suffix):\n shard_names.append(filename)\n if len(shard_names) == 1:\n torch_pname2binname = {\n torch_pname: shard_names[0]\n for relax_pname in relax_pnames\n for torch_pname in f_convert_pname_fwd(relax_pname)\n }\n else:\n raise ValueError(\"Multiple weight shard files without json map is not supported\")\n return torch_pname2binname\n\n\ndef _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:\n \"\"\"Construct the Relax function which computes quantization.\n This method is called by `transform_module` below, and is not\n directly invoked outside the class.\n\n Parameters\n ----------\n param_manager : ParamManager\n The parameter manager which has all the parameter information.\n\n Returns\n -------\n The created function which computes quantization.\n Precisely, an IRModule which contains the main quantization Relax function\n and a series of TIR functions is returned.\n \"\"\"\n bb = 
relax.BlockBuilder()\n param2qrange = dict()\n\n # Construct the input of the function.\n # We need a list of ranges for each\n # parameter to get its corresponding tensors loaded from disk.\n input_tensor_info: List[relax.TensorStructInfo] = []\n loaded_tensor_ranges: List[range] = []\n for name in param_manager.param_names:\n param = param_manager.params[name]\n _, loaded_tensor_info = param.quant_spec.get_loaded_tensor_info(name, param.param_info)\n loaded_tensor_ranges.append(\n range(\n len(input_tensor_info),\n len(input_tensor_info) + len(loaded_tensor_info),\n )\n )\n input_tensor_info += loaded_tensor_info\n raw_param_tuple = relax.Var(\"params\", relax.TupleStructInfo(input_tensor_info))\n\n with bb.function(\"transform_params\", params=[raw_param_tuple]):\n with bb.dataflow():\n quantized_params: List[relax.Var] = []\n for pidx, name in enumerate(param_manager.param_names):\n param = param_manager.params[name]\n param_vars: List[relax.Var] = []\n # Emit relax.TupleGetItem to get the raw parameters or pre-quantized params.\n for loaded_tensor_idx in loaded_tensor_ranges[pidx]:\n param_vars.append(\n bb.emit(relax.TupleGetItem(raw_param_tuple, loaded_tensor_idx))\n )\n\n # Get the quantization function of this parameter.\n f_quantize = param.quant_spec.get_quantize_func(param.param_info)\n if f_quantize is None:\n # If the parameter does not have a quantization function, either it\n # does not need quantization or it is pre-quantized.\n param2qrange[param] = range(\n len(quantized_params),\n len(quantized_params) + len(param_vars),\n )\n quantized_params += param_vars\n else:\n # If the parameter has a quantization function, it is not expected\n # to be pre-quantized.\n assert len(param_vars) == 1, (\n \"A parameter with quantization function is not expected \"\n \"to be pre-quantized.\"\n )\n\n # Apply the quantization function.\n quantized_data = bb.emit(f_quantize(bb, param_vars))\n\n if isinstance(quantized_data.struct_info, relax.TupleStructInfo):\n n_tensor = len(quantized_data.struct_info.fields)\n assert n_tensor > 1\n # Record the range of quantized tensors of this parameter.\n param2qrange[param] = range(\n len(quantized_params), len(quantized_params) + n_tensor\n )\n # Collect the quantized tensors to return.\n for i in range(n_tensor):\n quantized_params.append(bb.emit(relax.TupleGetItem(quantized_data, i)))\n else:\n assert isinstance(quantized_data.struct_info, relax.TensorStructInfo)\n param2qrange[param] = range(\n len(quantized_params), len(quantized_params) + 1\n )\n quantized_params.append(quantized_data)\n\n output = bb.emit_output(relax.Tuple(quantized_params))\n bb.emit_func_output(output)\n\n mod = bb.get()\n param_manager.param2qrange = param2qrange\n # Return the created IRModule.\n return bb.get()\n\n\ndef transform_params_for_each_rank(\n num_shards: int, rank_argument_name: str = \"rank_arg\"\n) -> tvm.ir.transform.Pass:\n \"\"\"Update a parameter transform to apply across all ranks\n\n For use in generating a pre-sharded set of weights. 
Given a\n parameter transformation that generates sharded model weights for\n a single shard, produce a parameter transformation that generates\n sharded model weights for each shard.\n\n Parameters\n ----------\n mod: tvm.IRModule\n\n A module containing the parameter transformation function,\n named \"transform_params\", along with any subroutines called by\n the parameter transformation.\n\n num_shards: int\n\n The number of shards to generate.\n\n rank_argument_name: str\n\n The name of the argument that specifies the rank. Should be a\n R.ShapeTuple with a single R.PrimStructInfo('int64').\n\n Returns\n -------\n tvm.IRModule\n\n The modified parameter transformation\n \"\"\"\n\n @tvm.ir.transform.module_pass(opt_level=0, name=\"ParamManager.transform_params_for_each_rank\")\n def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule:\n generic_transform = mod[\"transform_params\"]\n\n if generic_transform.attrs is not None and \"num_input\" in generic_transform.attrs:\n num_input = generic_transform.attrs[\"num_input\"].value\n else:\n num_input = 0\n\n if num_input == 0:\n return mod\n\n tensor_params = generic_transform.params[num_input:]\n attrs = {\"num_input\": num_input - 1}\n\n bb = relax.BlockBuilder()\n\n with bb.function(\"transform_params\", params=tensor_params, attrs=attrs):\n output = []\n for rank in range(num_shards):\n # TODO(Lunderberg): Implement this in terms of a\n # generic utility that inlines local functions.\n func = generic_transform\n func = func.bind_params({rank_argument_name: relax.ShapeExpr([rank])})\n func = relax.utils.copy_with_new_vars(func)\n func = func.bind_params(\n {var: tensor_param for (var, tensor_param) in zip(func.params, tensor_params)}\n )\n shard_tuple = func.body\n output.extend([shard_tuple[i] for i in range(len(tensor_params))])\n\n with bb.dataflow():\n gv = bb.emit_output(relax.Tuple(output))\n bb.emit_func_output(gv)\n\n mod = mod.clone()\n mod[\"transform_params\"] = bb.get()[\"transform_params\"]\n return mod\n\n return transform_func\n\n\ndef chain_parameter_transforms(mod_a: tvm.IRModule, mod_b: tvm.IRModule) -> tvm.IRModule:\n \"\"\"Chain two sequential parameter transformations\n\n For use in manipulating sets of model weights. Given two\n parameter transformations that could be applied sequentially,\n produce a single parameter transformation whose output is the same\n as applying the parameter transformations sequentially.\n\n\n .. code-block:: python\n\n # Before\n params_after_a = mod_a['transform_params'](orig_params)\n params_after_b = mod_b['transform_params'](params_after_a)\n\n # After\n mod_ab = chain_parameter_transforms(mod_a, mod_b)\n params_after_b = mod_ab['transform_params'](orig_params)\n\n Parameters\n ----------\n mod_a: tvm.IRModule\n\n The module containing the first parameter transformation.\n\n mod_b: tvm.IRModule\n\n The module containing the second parameter transformation.\n\n Returns\n -------\n tvm.IRModule\n\n The module containing the output\n\n \"\"\"\n func_a = mod_a[\"transform_params\"]\n func_b = mod_b[\"transform_params\"]\n\n bb = relax.BlockBuilder()\n\n def get_num_input_attr(func):\n if func.attrs is None:\n return 0\n\n attrs = func.attrs\n if \"num_input\" not in attrs:\n return 0\n num_input = attrs[\"num_input\"]\n\n assert isinstance(num_input, tvm.tir.IntImm)\n return num_input.value\n\n # Either func_a or func_b may have parameters that are provided at\n # a later point. 
The chaining of parameter transforms assumes\n # that all model weights accepted by func_b are produced by\n # func_a. If func_b accepts non-weight parameters (e.g. the GPU\n # rank), these must still be provided.\n func_a_num_input = get_num_input_attr(func_a)\n func_b_num_input = get_num_input_attr(func_b)\n\n output_num_input = func_a_num_input + func_b_num_input\n output_params = [\n *func_a.params[:func_a_num_input],\n *func_b.params[:func_b_num_input],\n *func_a.params[func_a_num_input:],\n ]\n\n with bb.function(\n \"transform_params\", params=output_params, attrs={\"num_input\": output_num_input}\n ):\n with bb.dataflow():\n # TODO(Lunderberg): Implement this in terms of a\n # generic utility that inlines local functions.\n func_a_output = bb.emit(func_a.body)\n func_b_param_map = {\n param: expr\n for (param, expr) in zip(func_b.params[func_b_num_input:], func_a_output)\n }\n func_b_output = func_b.bind_params(func_b_param_map).body\n gv = bb.emit_output(func_b_output)\n bb.emit_func_output(gv)\n\n merged_transform_func = bb.get()[\"transform_params\"]\n\n new_mod = {\n **{\n gvar: func\n for gvar, func in mod_a.functions.items()\n if gvar.name_hint != \"transform_params\"\n },\n **{\n gvar: func\n for gvar, func in mod_b.functions.items()\n if gvar.name_hint != \"transform_params\"\n },\n \"transform_params\": merged_transform_func,\n }\n return tvm.IRModule(new_mod)\n\n# Path: mlc_llm/relax_model/llama.py\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport numpy as np\nimport tvm\nfrom tvm import relax, te, tir\nfrom tvm.relax.op import ccl\nfrom tvm.relax.testing import nn\nfrom tvm.script import relax as R\n\nfrom ..quantization import ParamQuantKind, QuantizationScheme\nfrom .commons import create_metadata_func\nfrom .modules import ModuleList\nfrom .param_manager import ParamManager\n\n\n@dataclass\nclass LlamaConfig:\n def __init__(\n self,\n dtype=\"float32\",\n max_sequence_length=2048,\n vocab_size=32000, # some models like WizardMath can have 32001\n hidden_size=4096,\n intermediate_size=11008,\n num_hidden_layers=32,\n num_attention_heads=32,\n num_key_value_heads=None,\n hidden_act=\"silu\",\n initializer_range=0.02,\n rms_norm_eps=1e-6,\n pad_token_id=-1,\n bos_token_id=0,\n eos_token_id=1,\n tie_word_embeddings=False,\n position_embedding_base=10000,\n combine_matmul=True,\n build_model_only=False,\n num_shards=1,\n sliding_window=None,\n target_kind=None,\n **kwargs,\n ):\n self.dtype = dtype\n self.max_sequence_length = max_sequence_length\n self.vocab_size = vocab_size\n self.hidden_size = hidden_size\n self.intermediate_size = intermediate_size\n self.num_hidden_layers = num_hidden_layers\n self.num_attention_heads = num_attention_heads\n self.num_key_value_heads = num_key_value_heads\n self.hidden_act = hidden_act\n self.initializer_range = initializer_range\n self.rms_norm_eps = rms_norm_eps\n self.pad_token_id = pad_token_id\n self.bos_token_id = bos_token_id\n self.eos_token_id = eos_token_id\n self.tie_word_embeddings = tie_word_embeddings\n self.position_embedding_base = position_embedding_base\n self.combine_matmul = combine_matmul\n self.sliding_window = sliding_window\n self.target_kind = target_kind\n\n if build_model_only and num_shards > 1:\n self.num_shards = num_shards\n else:\n self.num_shards = 1\n self.kwargs = kwargs\n\n def get_num_key_value_heads(self):\n if self.num_key_value_heads is None:\n return self.num_attention_heads\n\n return self.num_key_value_heads\n\n\nclass 
Linear(nn.Module):\n def __init__(self, in_features, out_features, dtype: str, bias=True):\n self.in_features = in_features\n self.out_features = out_features\n self.weight = nn.Parameter((out_features, in_features), dtype=dtype, name=\"linear_weight\")\n if bias:\n self.bias = nn.Parameter((out_features,), dtype=dtype, name=\"linear_bias\")\n else:\n self.bias = None\n\n def forward(self, input: relax.Expr) -> relax.Var:\n return nn.emit(relax.op.linear(input, self.weight, self.bias))\n\n\nclass Embedding(nn.Module):\n def __init__(self, num_embeddings, embedding_dim, dtype: str):\n self.num_embeddings = num_embeddings\n self.embedding_dim = embedding_dim\n self.weight = nn.Parameter(\n (num_embeddings, embedding_dim), dtype=dtype, name=\"embedding_weight\"\n )\n\n def forward(self, x: relax.Expr) -> relax.Var:\n from tvm.relax.op import reshape, take\n\n ndim = x.struct_info.ndim\n if ndim == 1:\n return nn.emit(take(self.weight, x, axis=0))\n else:\n x_shape = x.struct_info.shape.values\n emb_size = self.weight.struct_info.shape.values[-1]\n x = nn.emit(reshape(x, shape=[-1]))\n embedding = nn.emit(take(self.weight, x, axis=0))\n return nn.emit(reshape(embedding, [*x_shape, emb_size]))\n\n\nclass LlamaRMSNorm(nn.Module):\n def __init__(self, hidden_size, dtype, eps=1e-6):\n self.weight = nn.Parameter((hidden_size,), dtype=dtype, name=\"rms_norm_weight\")\n self.variance_epsilon = tvm.tir.const(eps, dtype)\n\n def forward(self, hidden_states):\n from tvm import te, tir\n\n def f_rms_norm(x, weight):\n is_float32 = x.dtype == \"float32\"\n\n def f_square(x):\n return tir.Cast(\"float32\", x) * tir.Cast(\"float32\", x) if not is_float32 else x * x\n\n def f_mul_cast(x, y):\n value = x * y\n if not is_float32:\n value = tir.Cast(x.dtype, value)\n return value\n\n def f_div_cast_2d(i, k):\n x_val = x[i, k]\n if not is_float32:\n x_val = tir.Cast(\"float32\", x_val)\n return x_val / tir.sqrt(square_sum[i] / x.shape[1] + self.variance_epsilon)\n\n def f_div_cast_3d(bsz, i, k):\n x_val = x[bsz, i, k]\n if not is_float32:\n x_val = tir.Cast(\"float32\", x_val)\n return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon)\n\n k = te.reduce_axis((0, x.shape[-1]), name=\"k\")\n\n if len(x.shape) == 2:\n square_sum = te.compute(\n (x.shape[0],),\n lambda i: te.sum(f_square(x[i, k]), axis=k),\n name=x.op.name + \"red_temp\",\n )\n\n return te.compute(\n x.shape,\n lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)),\n name=\"rms_norm\",\n )\n else:\n square_sum = te.compute(\n (x.shape[0], x.shape[1]),\n lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k),\n name=x.op.name + \"red_temp\",\n )\n\n return te.compute(\n x.shape,\n lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)),\n name=\"rms_norm\",\n )\n\n return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint=\"rms_norm\")\n\n\nclass LlamaMLP(nn.Module):\n def __init__(self, config: LlamaConfig):\n self.combine_matmul = config.combine_matmul\n self.num_shards = config.num_shards\n hidden_size = config.hidden_size\n intermediate_size = config.intermediate_size // self.num_shards\n dtype = config.dtype\n if self.combine_matmul:\n self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False)\n self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False)\n self.gate_up_proj.weight.shard_dim = 0\n self.gate_up_proj.weight.shard_strategy = \"shard_gate_up\"\n self.down_proj.weight.shard_dim = 1\n self.down_proj.weight.shard_strategy = 
\"shard_mlp_k\"\n else:\n self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)\n self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False)\n self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)\n self.gate_proj.weight.shard_dim = 0\n self.gate_proj.weight.shard_strategy = \"shard_axis_0\"\n self.down_proj.weight.shard_dim = 1\n self.down_proj.weight.shard_strategy = \"shard_axis_1\"\n self.up_proj.weight.shard_dim = 0\n self.up_proj.weight.shard_strategy = \"shard_axis_0\"\n\n def forward(self, x):\n if self.combine_matmul:\n gate_up_results = nn.emit(\n relax.op.split(\n self.gate_up_proj(x),\n indices_or_sections=2,\n axis=-1,\n )\n )\n gate_result = relax.TupleGetItem(gate_up_results, 0)\n up_result = relax.TupleGetItem(gate_up_results, 1)\n else:\n gate_result = self.gate_proj(x)\n up_result = self.up_proj(x)\n\n result = self.down_proj(relax.op.nn.silu(gate_result) * up_result)\n return result\n\n\ndef rotary_modulate_by_freq(tensor, idx, pos, position_embedding_base):\n head_dim = tensor.shape[-1]\n dtype = tensor.dtype\n n_feat_half = head_dim // 2\n feat_idx = idx[-1]\n inv_freq = te.const(1, \"float32\") / (\n te.power(\n te.const(position_embedding_base, \"float32\"),\n ((2 * feat_idx) % head_dim).astype(\"float32\") / head_dim.astype(\"float32\"),\n )\n )\n freq = pos * inv_freq\n left_indices = idx[:-1] + (feat_idx - n_feat_half,)\n right_indices = idx[:-1] + (feat_idx + n_feat_half,)\n return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype(dtype) * tvm.tir.Select(\n feat_idx >= n_feat_half,\n tensor[(*left_indices,)],\n -tensor[(*right_indices,)],\n )\n\n\ndef apply_rotary_pos_emb(q, k, position_embedding_base, offset: int = 0):\n def f_rotary_embedding(tensor, offset):\n def rotary_compute(*idx):\n pos = (offset + idx[-3]).astype(\"float32\")\n return rotary_modulate_by_freq(\n tensor,\n idx,\n pos,\n position_embedding_base,\n )\n\n return tvm.te.compute(tensor.shape, rotary_compute, name=\"rotary\")\n\n q_embed = nn.emit_te(f_rotary_embedding, q, offset, primfunc_name_hint=\"rotary_embedding\")\n k_embed = nn.emit_te(f_rotary_embedding, k, offset, primfunc_name_hint=\"rotary_embedding\")\n return q_embed, k_embed\n\n\nclass LlamaAttentionBase(nn.Module):\n \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n def __init__(self, config: LlamaConfig):\n dtype = config.dtype\n self.num_shards = config.num_shards\n self.hidden_size = config.hidden_size\n self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards\n self.num_query_heads = config.num_attention_heads // self.num_shards\n self.head_dim = self.hidden_size // config.num_attention_heads\n self.position_embedding_base = config.position_embedding_base\n\n self.combine_matmul = config.combine_matmul\n if self.combine_matmul:\n self.query_key_value_proj = Linear(\n self.hidden_size,\n (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim,\n dtype=dtype,\n bias=False,\n )\n self.query_key_value_proj.weight.shard_dim = 0\n self.query_key_value_proj.weight.shard_strategy = \"shard_qkv\"\n else:\n self.q_proj = Linear(\n self.hidden_size,\n self.num_query_heads * self.head_dim,\n dtype=dtype,\n bias=False,\n )\n self.k_proj = Linear(\n self.hidden_size,\n self.num_key_value_heads * self.head_dim,\n dtype=dtype,\n bias=False,\n )\n self.v_proj = Linear(\n self.hidden_size,\n self.num_key_value_heads * self.head_dim,\n dtype=dtype,\n bias=False,\n )\n 
self.q_proj.weight.shard_dim = 0\n self.k_proj.weight.shard_dim = 0\n self.v_proj.weight.shard_dim = 0\n self.q_proj.weight.shard_strategy = \"shard_axis_0\"\n self.k_proj.weight.shard_strategy = \"shard_axis_0\"\n self.v_proj.weight.shard_strategy = \"shard_axis_0\"\n\n self.o_proj = Linear(\n self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False\n )\n self.o_proj.weight.shard_dim = 1\n self.o_proj.weight.shard_strategy = \"shard_o_proj_k\"\n\n def project_qkv(self, hidden_states, query_output_shape, kv_output_shape):\n from tvm.relax.op import reshape, split\n\n if self.combine_matmul:\n qkv_states = nn.emit(\n split(\n self.query_key_value_proj(hidden_states),\n indices_or_sections=[\n self.num_query_heads * self.head_dim,\n (self.num_query_heads + self.num_key_value_heads) * self.head_dim,\n ],\n axis=-1,\n )\n )\n query_states = relax.TupleGetItem(qkv_states, 0)\n key_states = relax.TupleGetItem(qkv_states, 1)\n value_states = relax.TupleGetItem(qkv_states, 2)\n else:\n query_states = self.q_proj(hidden_states)\n key_states = self.k_proj(hidden_states)\n value_states = self.v_proj(hidden_states)\n\n query_states = nn.emit(\n reshape(query_states, query_output_shape),\n )\n key_states = nn.emit(\n reshape(key_states, kv_output_shape),\n )\n value_states = nn.emit(\n...\n", "needle_token_start": 5447, "needle_token_end": 5862, "code_context_ntokens": 16145, "output": ["```python\ndef get_param_set_item(self) -> Tuple[Callable, List[tvm.nd.NDArray]]:\n \"\"\"A wrapper function which returns the `set_item`\n functions for parameter lazy loading.\n\n The return value of this function is intended to be registered\n as `\"set_item\"`, for use in a module built with\n `LazyTransformParams`.\n\n .. code-block:: python\n\n set_item,loaded_params = manager.get_param_set_item()\n tvm.register_func(func_name=\"set_item\", f=set_item, override=True)\n compiled_function()\n # `loaded_params` is now fully populated\n\n Returns\n -------\n set_item: Callable[[int,tvm.nd.NDArray]]\n\n A function that accepts an index and the return value at\n that index.\n\n loaded_params: List[tvm.nd.NDArray]\n\n A list of loaded parameters, populated by `set_item`.\n When initially returned, this list is empty. After\n executing the compiled function with\n `LazyTransformParams`, `loaded_params` will be\n populated.\n \"\"\"\n device_cpu = tvm.cpu()\n loaded_params: List[tvm.nd.NDArray] = []\n\n def set_item(i: int, computed_param: tvm.nd.NDArray):\n if len(loaded_params) <= i:\n loaded_params.extend([None for _ in range(i - len(loaded_params) + 1)])\n loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu)\n\n return set_item, loaded_params\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "fused_fused_decode9_matmul7", "language": "python", "path": "mlc_llm/dispatch/dispatch_tir_operator_adreno.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function performs a decoding operation followed by a matrix multiplication. The decoding involves extracting and transforming elements from an input buffer, and the matrix multiplication aggregates these transformed elements with another input buffer.\n2. **Input**: \n - A buffer containing 32-bit unsigned integers, structured in a 2-dimensional array.\n - A buffer containing 16-bit floating-point numbers, also structured in a 2-dimensional array.\n - A buffer containing 16-bit floating-point numbers, structured in a 3-dimensional array.\n3. 
**Output**: A buffer containing 16-bit floating-point numbers, structured in a 3-dimensional array, which holds the result of the matrix multiplication.\n4. **Procedure**: \n - The function iterates over a grid to decode values from the first input buffer using bitwise operations and scaling, and multiplies the decoded values with corresponding values from the second input buffer.\n - The results are stored temporarily.\n - Another loop iterates over a grid to perform matrix multiplication using the temporary results and the third input buffer.\n - The final results of the matrix multiplication are stored in the output buffer.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = (\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n )\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(\n lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n )\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = (\n lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n )\n\n\n@T.prim_func(private=True)\ndef fused_decode5_fused_matmul6_multiply1_after(\n lv1617: T.Buffer((T.int64(512), T.int64(11008)), \"uint32\"),\n lv1618: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"),\n lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"),\n p_output0_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(11008)), \"float16\"\n ),\n):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_matmul_intermediate_local = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\", scope=\"local\"\n )\n var_matmul_intermediate_local_batch = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\", scope=\"local\"\n )\n lv1617_local = T.alloc_buffer(\n (T.int64(512), T.int64(11008)), \"uint32\", scope=\"local\"\n )\n lv1618_local = T.alloc_buffer(\n (T.int64(128), T.int64(11008)), \"float16\", scope=\"local\"\n )\n lv1622_shared = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(1024)), \"float16\", scope=\"shared\"\n )\n for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * 
T.int64(4)\n + i0_i1_i2_fused_2_init,\n )\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0)\n for k_0 in range(T.int64(4)):\n for ax2_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2_2 in T.vectorized(T.int64(8)):\n with T.block(\"lv1622_shared\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(4096),\n k_0 * T.int64(1024)\n + ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2,\n )\n v2k = T.axis.spatial(\n T.int64(1024),\n (\n ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2\n ),\n )\n T.reads(lv1622[v0, v1, v2])\n T.writes(lv1622_shared[v0, v1, v2k])\n lv1622_shared[v0, v1, v2k] = lv1622[v0, v1, v2]\n for k_1 in range(T.int64(16)):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init_local\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n T.reads()\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = T.float16(0)\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv1618_local\"):\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(32)\n + (k_1 * T.int64(2) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv1618[v0, v1])\n T.writes(lv1618_local[v0, v1])\n lv1618_local[v0, v1] = lv1618[v0, v1]\n for k_2 in range(T.int64(4)):\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv1617_local\"):\n v0 = T.axis.spatial(\n T.int64(512),\n k_0 * T.int64(128)\n + (k_1 * T.int64(2) + ax2_y) * T.int64(4)\n + k_2\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv1617[v0, v1])\n T.writes(lv1617_local[v0, v1])\n lv1617_local[v0, v1] = lv1617[v0, v1]\n for k_3 in range(T.int64(8)):\n for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_k = T.axis.reduce(\n T.int64(4096),\n k_0 * T.int64(1024)\n + (k_1 * T.int64(2) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n v_ki = T.axis.reduce(\n T.int64(1024),\n (k_1 * T.int64(2) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n T.reads(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n lv1622_shared[v_i0, v_i1, v_ki],\n lv1617_local[v_k // T.int64(8), v_i2],\n )\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] + lv1622_shared[\n v_i0, v_i1, v_ki\n ] * 
(\n (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv1617_local[\n v_k // T.int64(8), v_i2\n ],\n T.Cast(\n \"uint32\",\n v_k % T.int64(8),\n )\n * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n )\n )\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"multiple_scale\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(32)\n + (k_1 * T.int64(2) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(\n lv1618_local[v0, v1],\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n )\n T.writes(\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n )\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = (\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n + var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n * lv1618_local[v0, v1]\n )\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_update\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.reads(var_matmul_intermediate_local[v0, v1, v_i2k])\n T.writes(lv1622_shared[v0, v1, v2])\n lv1622_shared[v0, v1, v2] = var_matmul_intermediate_local[\n v0, v1, v_i2k\n ]\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(11008),\n i0_i1_i2_fused_0 * T.int64(256)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(512),\n i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(1))\n T.reads(lv1622_shared[v0, v1, v_i2k], lv4[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] * (\n lv1622_shared[v0, v1, v_i2k]\n + lv1622_shared[v0, v1, v_i2k + T.int64(4)]\n )\n\n\ndef sch_fused_decode5_fused_matmul6_multiply1(func):\n sch = tvm.tir.Schedule(func)\n b0 = sch.get_block(name=\"decode\", func_name=\"main\")\n b1 = sch.get_block(name=\"matmul\", func_name=\"main\")\n l2, l3, l4, l5 = sch.get_loops(block=b1)\n l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True)\n v7, v8, v9 = sch.sample_perfect_tile(\n loop=l6, n=3, max_innermost_factor=4, decision=[43, 64, 4]\n )\n l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True)\n v13, v14, v15 = sch.sample_perfect_tile(\n loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8]\n )\n l16, l17, l18 = sch.split(\n loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True\n )\n sch.reorder(l10, l11, l16, l17, l18, l12)\n sch.bind(loop=l10, thread_axis=\"blockIdx.x\")\n sch.bind(loop=l11, thread_axis=\"threadIdx.x\")\n sch.compute_inline(block=b0)\n b19 = sch.cache_write(block=b1, write_buffer_index=0, 
storage_scope=\"local\")\n sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1)\n b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope=\"local\")\n b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope=\"local\")\n b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope=\"shared\")\n sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1)\n v23 = sch.sample_categorical(\n candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3\n )\n sch.annotate(\n block_or_loop=b22, ann_key=\"meta_schedule.cooperative_fetch\", ann_val=v23\n )\n sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1)\n sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1)\n l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20)\n sch.vectorize(loop=l29)\n l30, l31, l32, l33, l34 = sch.get_loops(block=b21)\n sch.vectorize(loop=l34)\n l35, l36, l37, l38, l39 = sch.get_loops(block=b19)\n sch.vectorize(loop=l39)\n sch.vectorize(loop=l12)\n b40 = sch.decompose_reduction(block=b1, loop=l16)\n b41 = sch.get_block(name=\"T_multiply\", func_name=\"main\")\n sch.reverse_compute_inline(block=b41)\n sch.enter_postproc()\n sch.unannotate(block_or_loop=b22, ann_key=\"meta_schedule.cooperative_fetch\")\n l42, l43, l44, l45, l46 = sch.get_loops(block=b22)\n l47, l48, l49 = sch.split(loop=l46, factors=[None, 64, 8], preserve_unit_iters=True)\n sch.vectorize(loop=l49)\n sch.bind(loop=l48, thread_axis=\"threadIdx.x\")\n return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n\n@T.prim_func(private=True)\n\ndef fused_fused_decode9_matmul7(\n lv19: T.Buffer((T.int64(512), T.int64(22016)), \"uint32\"),\n lv20: T.Buffer((T.int64(128), T.int64(22016)), \"float16\"),\n lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n var_matmul_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\"\n ),\n):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(22016)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(22016)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv19[v_i // T.int64(8), v_j], lv20[v_i // T.int64(32), v_j])\n T.writes(p_output0_intermediate[v_i, v_j])\n p_output0_intermediate[v_i, v_j] = (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv19[v_i // T.int64(8), v_j],\n T.Cast(\"uint32\", v_i % T.int64(8)) * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n ) * lv20[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(22016), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = (\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2]\n )\n\n\n@T.prim_func(private=True)\ndef fused_fused_decode9_matmul7_after(\n lv19: T.Buffer((T.int64(512), T.int64(22016)), \"uint32\"),\n lv20: T.Buffer((T.int64(128), T.int64(22016)), \"float16\"),\n lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n var_matmul_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\"\n ),\n):\n 
T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_matmul_intermediate_local = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(352256)), \"float16\", scope=\"local\"\n )\n var_matmul_intermediate_local_batch = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(352256)), \"float16\", scope=\"local\"\n )\n lv19_local = T.alloc_buffer((T.int64(512), T.int64(22016)), \"uint32\", scope=\"local\")\n lv20_local = T.alloc_buffer(\n (T.int64(128), T.int64(22016)), \"float16\", scope=\"local\"\n )\n lv1654_shared = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(4096)), \"float16\", scope=\"shared\"\n )\n for i0_i1_i2_fused_0 in T.thread_binding(T.int64(172), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(352256),\n i0_i1_i2_fused_0 * T.int64(2048)\n + i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2_init\n )\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0)\n for k_0 in range(T.int64(1)):\n for ax2_1 in T.thread_binding(T.int64(32), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2_2 in T.vectorized(T.int64(8)):\n with T.block(\"lv1654_shared\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(4096),\n k_0 * T.int64(4096)\n + ax2_y * T.int64(256)\n + ax2_1 * T.int64(8)\n + ax2_2,\n )\n v2k = T.axis.spatial(\n T.int64(4096),\n (\n ax2_y * T.int64(256)\n + ax2_1 * T.int64(8)\n + ax2_2\n ),\n )\n T.reads(lv1654[v0, v1, v2])\n T.writes(lv1654_shared[v0, v1, v2k])\n lv1654_shared[v0, v1, v2k] = lv1654[v0, v1, v2]\n for k_1 in range(T.int64(8)):\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init_local\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(352256),\n i0_i1_i2_fused_0 * T.int64(2048)\n + i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax1,\n )\n T.reads()\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = T.float16(0)\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv20_local\"):\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(128)\n + (k_1 * T.int64(16) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(128)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv20[v0, v1])\n T.writes(lv20_local[v0, v1])\n lv20_local[v0, v1] = lv20[v0, v1]\n for k_2 in range(T.int64(4)):\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"lv19_local\"):\n v0 = T.axis.spatial(\n T.int64(512),\n k_0 * T.int64(512)\n + (k_1 * T.int64(16) + ax2_y) * T.int64(4)\n + k_2\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(128)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(lv19[v0, v1])\n T.writes(lv19_local[v0, v1])\n lv19_local[v0, v1] 
= lv19[v0, v1]\n for k_3 in range(T.int64(8)):\n for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(128)\n + i0_i1_i2_fused_1 * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_i2k = T.axis.spatial(\n T.int64(352256),\n i0_i1_i2_fused_0 * T.int64(2048)\n + i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2,\n )\n v_k = T.axis.reduce(\n T.int64(4096),\n k_0 * T.int64(4096)\n + (k_1 * T.int64(16) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n v_ki = T.axis.reduce(\n T.int64(4096),\n (k_1 * T.int64(16) + ax2_y) * T.int64(32)\n + k_2 * T.int64(8)\n + k_3,\n )\n T.reads(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n lv1654_shared[v_i0, v_i1, v_ki],\n lv19_local[v_k // T.int64(8), v_i2],\n )\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] + lv1654_shared[\n v_i0, v_i1, v_ki\n ] * (\n (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv19_local[\n v_k // T.int64(8), v_i2\n ],\n T.Cast(\n \"uint32\",\n v_k % T.int64(8),\n )\n * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n )\n )\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"multiple_scale\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(352256),\n i0_i1_i2_fused_0 * T.int64(2048)\n + i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax1,\n )\n v0 = T.axis.spatial(\n T.int64(128),\n k_0 * T.int64(128)\n + (k_1 * T.int64(16) + ax2_y)\n + ax0,\n )\n v1 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(128)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax1,\n )\n T.reads(\n lv20_local[v0, v1],\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ],\n )\n T.writes(\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n )\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = (\n var_matmul_intermediate_local[v_i0, v_i1, v_i2k]\n + var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n * lv20_local[v0, v1]\n )\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_update\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(352256),\n i0_i1_i2_fused_0 * T.int64(2048)\n + i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.reads(var_matmul_intermediate_local[v0, v1, v_i2k])\n T.writes(lv1654_shared[v0, v1, v2])\n lv1654_shared[v0, v1, v2] = var_matmul_intermediate_local[\n v0, v1, v_i2k\n ]\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"reduction_1\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v_i2k = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(8))\n T.reads(lv1654_shared[v0, v1, v_i2k])\n T.writes(lv1654_shared[v0, v1, v_i2k])\n lv1654_shared[v0, v1, v_i2k] = (\n 
lv1654_shared[v0, v1, v_i2k]\n + lv1654_shared[v0, v1, v_i2k + T.int64(32)]\n )\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"reduction_2\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v_i2k = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(4))\n T.reads(lv1654_shared[v0, v1, v_i2k])\n T.writes(lv1654_shared[v0, v1, v_i2k])\n lv1654_shared[v0, v1, v_i2k] = (\n lv1654_shared[v0, v1, v_i2k]\n + lv1654_shared[v0, v1, v_i2k + T.int64(16)]\n )\n for ax2_y in T.thread_binding(T.int64(16), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2 in T.vectorized(T.int64(4)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(128)\n + i0_i1_i2_fused_1 * T.int64(4)\n + ax2,\n )\n v_i2k = T.axis.spatial(\n T.int64(2048),\n i0_i1_i2_fused_1 * T.int64(64)\n + ax2_y * T.int64(4)\n + ax2,\n )\n T.where(ax2_y < T.int64(1))\n T.reads(lv1654_shared[v0, v1, v_i2k])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = (\n lv1654_shared[v0, v1, v_i2k]\n + lv1654_shared[v0, v1, v_i2k + T.int64(4)]\n + lv1654_shared[v0, v1, v_i2k + T.int64(8)]\n + lv1654_shared[v0, v1, v_i2k + T.int64(12)]\n )\n\n\n@T.prim_func(private=True)\ndef fused_fused_decode7_matmul4(\n lv3: T.Buffer((T.int64(512), T.int64(12288)), \"uint32\"),\n lv4: T.Buffer((T.int64(128), T.int64(12288)), \"float16\"),\n lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n var_matmul_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(12288)), \"float16\"\n ),\n):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(12288)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv3[v_i // T.int64(8), v_j], lv4[v_i // T.int64(32), v_j])\n T.writes(p_output0_intermediate[v_i, v_j])\n p_output0_intermediate[v_i, v_j] = (\n T.Cast(\n \"float16\",\n T.bitwise_and(\n T.shift_right(\n lv3[v_i // T.int64(8), v_j],\n T.Cast(\"uint32\", v_i % T.int64(8)) * T.uint32(4),\n ),\n T.uint32(15),\n ),\n )\n - T.float16(7)\n ) * lv4[v_i // T.int64(32), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(12288), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1615[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = (\n var_matmul_intermediate[v_i0, v_i1, v_i2]\n + lv1615[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2]\n )\n\n\n@T.prim_func(private=True)\ndef fused_fused_decode7_matmul4_after(\n lv3: T.Buffer((T.int64(512), T.int64(12288)), \"uint32\"),\n lv4: T.Buffer((T.int64(128), T.int64(12288)), \"float16\"),\n lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n var_matmul_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(12288)), \"float16\"\n ),\n):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_matmul_intermediate_local = 
T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(24576)), \"float16\", scope=\"local\"\n )\n var_matmul_intermediate_local_batch = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(24576)), \"float16\", scope=\"local\"\n )\n lv3_local = T.alloc_buffer((T.int64(512), T.int64(12288)), \"uint32\", scope=\"local\")\n lv4_local = T.alloc_buffer((T.int64(128), T.int64(12288)), \"float16\", scope=\"local\")\n lv1615_shared = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(1024)), \"float16\", scope=\"shared\"\n )\n for i0_i1_i2_fused_0 in T.thread_binding(T.int64(48), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2_init,\n )\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0)\n for k_0 in range(T.int64(4)):\n for ax2_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2_2 in T.vectorized(T.int64(8)):\n with T.block(\"lv1615_shared\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(4096),\n k_0 * T.int64(1024)\n + ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2,\n )\n v2k = T.axis.spatial(\n T.int64(1024),\n (\n ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2\n ),\n )\n T.reads(lv1615[v0, v1, v2])\n T.writes(lv1615_shared[v0, v1, v2k])\n lv1615_shared[v0, v1, v2k] = lv1615[v0, v1, v2]\n for k_1 in range(T.int64(16)):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax1 in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init_local\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2k = T.axis.spatial(\n T.int64(24576),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + ax1,\n )\n T.reads()\n T.writes(\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ]\n )\n var_matmul_intermediate_local_batch[\n v_i0, v_i1, v_i2k\n ] = T.float16(0)\n for ax0 in range(T.int64(1)):\n for ax1 in T.vectorized(T.int64(4)):\n...\n", "needle_token_start": 6907, "needle_token_end": 7805, "code_context_ntokens": 16361, "output": ["```python\n@T.prim_func(private=True)\ndef fused_decode5_fused_matmul6_multiply1_after(\n lv1617: T.Buffer((T.int64(512), T.int64(11008)), \"uint32\"),\n lv1618: T.Buffer((T.int64(128), T.int64(11008)), \"float16\"),\n lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"),\n lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"),\n p_output0_intermediate: T.Buffer(\n (T.int64(1), T.int64(1), T.int64(11008)), \"float16\"\n ),\n):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_matmul_intermediate_local = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\", scope=\"local\"\n )\n var_matmul_intermediate_local_batch = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(22016)), \"float16\", scope=\"local\"\n )\n lv1617_local = 
T.alloc_buffer(\n (T.int64(512), T.int64(11008)), \"uint32\", scope=\"local\"\n )\n lv1618_local = T.alloc_buffer(\n (T.int64(128), T.int64(11008)), \"float16\", scope=\"local\"\n )\n lv1622_shared = T.alloc_buffer(\n (T.int64(1), T.int64(1), T.int64(1024)), \"float16\", scope=\"shared\"\n )\n for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(\n T.int64(22016),\n i0_i1_i2_fused_0 * T.int64(512)\n + i0_i1_i2_fused_1 * T.int64(8)\n + ax2_y * T.int64(4)\n + i0_i1_i2_fused_2_init,\n )\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0)\n for k_0 in range(T.int64(4)):\n for ax2_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax2_y in T.thread_binding(T.int64(2), thread=\"threadIdx.y\"):\n for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):\n for ax2_2 in T.vectorized(T.int64(8)):\n with T.block(\"lv1622_shared\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(\n T.int64(4096),\n k_0 * T.int64(1024)\n + ax2_y * T.int64(512)\n + ax2_1 * T.int64(8)\n + ax2_2,\n )\n v2k = T.axis.spatial(\n T.int64(1024),\n (\n ax2_y *"]} +{"repo": "mlc-ai/mlc-llm", "name": "fused_decode4_matmul5_int3_fp16_before", "language": "python", "path": "mlc_llm/dispatch/llama/main.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function performs a decoding operation followed by a matrix multiplication, primarily working with 16-bit floating-point and 32-bit unsigned integer data types. This is typically used in neural network operations where decoding of compressed data is followed by a transformation via matrix multiplication.\n2. **Input**: The function takes four buffers: a 2D buffer of 32-bit unsigned integers, a 2D buffer of 16-bit floats, and two 3D buffers of 16-bit floats.\n3. **Output**: A 3D buffer of 16-bit floats that stores the result of the matrix multiplication.\n4. 
**Procedure**: \n - First, a decoding operation is performed where each element of the integer buffer is processed with bitwise operations and combined with elements from the first float buffer to produce a new 2D float buffer.\n - Next, a matrix multiplication is carried out between a slice of the second 3D float buffer and the newly created 2D float buffer, initializing the result buffer to zero before accumulation.\n - The result of this multiplication is stored in the output 3D float buffer.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(2)):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv6_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1))\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504))\n T.reads(lv6[v0, v1, v2])\n T.writes(lv6_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv6_shared[v0, v1, v2] = lv6[v0, v1, v2]\n for k_0_1 in range(T.int64(86)):\n for ax0_0 in range(T.int64(8)):\n for ax0_1 in T.unroll(T.int64(8)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1158[v_j // T.int64(8), v_i], lv60[v_j // T.int64(32), v_i], lv61[v_j // T.int64(32), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast(\"uint32\", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv60[v_j // T.int64(32), v_i] + lv61[v_j // T.int64(32), v_i]\n for k_0_2_k_1_fused in range(T.int64(64)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef 
fused_decode3_matmul1_cast_int3_fp16_before(lv2931: T.Buffer((T.int64(412), T.int64(32000)), \"uint32\"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), \"float16\"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(32000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv2931[v_i // T.int64(10), v_j], lv2932[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(10), v_j], T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])\n p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast(\"float32\", var_matmul_intermediate[v_i0, v_i1, v_i2])\n\n\n@T.prim_func\ndef fused_decode3_matmul1_cast_int3_fp16_after(lv1123: T.Buffer((T.int64(412), T.int64(32000)), \"uint32\"), lv5866: T.Buffer((T.int64(103), T.int64(32000)), \"float16\"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n # with T.block(\"root\"):\n var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4120), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv1511_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv1511[v0, v1, v2])\n T.writes(lv1511_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + 
ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv1511_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv1511[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"var_decode_intermediate_pad\"):\n v0 = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1123[v0 // T.int64(10), v1], lv5866[v0 // T.int64(40), v1])\n T.writes(var_decode_intermediate_pad_local[v0, v1])\n var_decode_intermediate_pad_local[v0, v1] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(10), v1], T.Cast(\"uint32\", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] * lv5866[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_pad_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = T.Cast(\"float32\", var_matmul_intermediate_pad_local[v0, v1, v2])\n\n\n@T.prim_func\ndef fused_decode4_fused_matmul5_add3_int3_fp16_before(lv1605: T.Buffer((T.int64(412), T.int64(4096)), \"uint32\"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv164: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv1518: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1605[v_i // T.int64(10), v_j], lv1606[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1605[v_i // T.int64(10), v_j], T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), 
T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv164[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv164[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):\n with T.block(\"T_add\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv1518[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1518[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_decode4_fused_matmul5_add3_int3_fp16_after(lv1143: T.Buffer((T.int64(412), T.int64(4096)), \"uint32\"), lv36: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv3_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv3[v0, v1, v2])\n T.writes(lv3_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv3_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv3[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1143[v_j // T.int64(10), v_i], lv36[v_j // T.int64(40), v_i])\n 
T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(10), v_i], T.Cast(\"uint32\", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv36[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\n\ndef fused_decode4_matmul5_int3_fp16_before(lv1587: T.Buffer((T.int64(412), T.int64(4096)), \"uint32\"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1587[v_i // T.int64(10), v_j], lv1588[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1587[v_i // T.int64(10), v_j], T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n\n\n@T.prim_func\ndef fused_decode4_matmul5_int3_fp16_after(lv1128: T.Buffer((T.int64(412), T.int64(4096)), \"uint32\"), lv12: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope=\"local\", 
dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv2712_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv2712[v0, v1, v2])\n T.writes(lv2712_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv2712_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2712[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1128[v_j // T.int64(10), v_i], lv12[v_j // T.int64(40), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(10), v_i], T.Cast(\"uint32\", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv12[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode5_fused_matmul8_multiply1_int3_fp16_before(lv1617: T.Buffer((T.int64(412), T.int64(11008)), \"uint32\"), lv1618: T.Buffer((T.int64(103), 
T.int64(11008)), \"float16\"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(11008)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1617[v_i // T.int64(10), v_j], lv1618[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1617[v_i // T.int64(10), v_j], T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_decode5_fused_matmul8_multiply1_int3_fp16_after(lv1153: T.Buffer((T.int64(412), T.int64(11008)), \"uint32\"), lv52: T.Buffer((T.int64(103), T.int64(11008)), \"float16\"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope=\"local\", dtype=\"float16\")\n lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv2749_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv2749[v0, v1, v2])\n 
T.writes(lv2749_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1153[v_j // T.int64(10), v_i], lv52[v_j // T.int64(40), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(10), v_i], T.Cast(\"uint32\", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv52[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode5_fused_matmul8_silu1_int3_fp16_before(lv1611: T.Buffer((T.int64(412), T.int64(11008)), \"uint32\"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), \"float16\"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(11008)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1611[v_i // T.int64(10), v_j], lv1612[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1611[v_i // T.int64(10), v_j], 
T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(compute[v_i0, v_i1, v_i2])\n compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n...\n", "needle_token_start": 8523, "needle_token_end": 9368, "code_context_ntokens": 16281, "output": ["```python\n@T.prim_func\ndef fused_decode5_fused_matmul8_silu1_int3_fp16_before(lv1611: T.Buffer((T.int64(412), T.int64(11008)), \"uint32\"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), \"float16\"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(11008)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1611[v_i // T.int64(10), v_j], lv1612[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1611[v_i // T.int64(10), v_j], T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(compute[v_i0, v_i1, v_i2])\n compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(compute[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v"]} +{"repo": "mlc-ai/mlc-llm", "name": "fused_NT_matmul3_add1_before", 
"language": "python", "path": "mlc_llm/dispatch/llama/main.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function performs a matrix multiplication followed by an addition operation. It is designed to compute the result of multiplying a matrix with another transposed matrix and then adding a third matrix to the result.\n2. **Input**: Three matrices are provided as input. The first matrix has dimensions (1, n, 11008), the second matrix is transposed with dimensions (4096, 11008), and the third matrix has dimensions (1, n, 4096).\n3. **Output**: The output is a matrix of dimensions (1, n, 4096), which is the result of the matrix multiplication followed by the addition with the third input matrix.\n4. **Procedure**: \n - The function first initializes an intermediate matrix to store the result of the matrix multiplication.\n - It then iterates over the elements of the input matrices, performing the multiplication of the first matrix and the transposed second matrix, and accumulates the results in the intermediate matrix.\n - After completing the matrix multiplication, the function iterates over the elements of the intermediate matrix and the third input matrix, adding corresponding elements together to produce the final output matrix.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " v_ax2 = T.axis.spatial(T.int64(1), T.int64(0))\n v_ax3 = T.axis.spatial(n, (ax0_ax1_ax2_ax3_fused_0 * T.int64(32) + ax0_ax1_ax2_ax3_fused_1) % n)\n T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2709[v_ax0, T.int64(0), v_ax2, v_ax3])\n T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])\n var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + lv2709[v_ax0, T.int64(0), v_ax2, v_ax3], T.float32(-3.4028234663852886e+38))\n\n\n@T.prim_func\ndef fused_NT_matmul2_multiply_before(p_lv43: T.handle, linear_weight6: T.Buffer((T.int64(11008), T.int64(4096)), \"float32\"), p_lv48: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096)))\n lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008)))\n var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)))\n # with T.block(\"root\"):\n var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)))\n for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):\n with T.block(\"NT_matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv43[v_i0, v_i1, v_k], linear_weight6[v_i2, v_k])\n T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * linear_weight6[v_i2, v_k]\n for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv48[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])\n var_T_multiply_intermediate[v_ax0, 
v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_NT_matmul2_multiply_after(p_lv37: T.handle, linear_weight6: T.Buffer((T.int64(11008), T.int64(4096)), \"float32\"), p_lv42: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n n = T.int64()\n lv37 = T.match_buffer(p_lv37, (T.int64(1), n, T.int64(4096)))\n lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(11008)))\n var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)))\n # with T.block(\"root\"):\n for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread=\"blockIdx.y\"):\n with T.block(\"NT_matmul_o\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0)\n T.reads(lv37[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight6[T.int64(0):T.int64(11008), T.int64(0):T.int64(4096)], lv42[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)])\n T.writes(var_T_multiply_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)])\n var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope=\"local\")\n lv37_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope=\"shared\")\n linear_weight6_shared = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope=\"shared\")\n for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(344), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(2)):\n with T.block(\"NT_matmul_init\"):\n v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(2) + i1_1_4_init)\n v_i2_i = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init * T.int64(2) + i2_4_init)\n T.reads()\n T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0)\n for k_0 in range(T.int64(128)):\n for ax0_ax1_ax2_fused_0 in range(T.int64(4)):\n for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv37_pad_shared\"):\n v0 = T.axis.spatial(T.int64(1), T.int64(0))\n v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32))\n v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32))\n T.reads(lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2])\n T.writes(lv37_pad_shared[v0, v1, v2])\n lv37_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0))\n 
for ax0_ax1_fused_0 in range(T.int64(8)):\n for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax0_ax1_fused_2 in T.vectorized(T.int64(2)):\n with T.block(\"linear_weight6_shared\"):\n v0 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32))\n v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32))\n T.reads(linear_weight6[v0, v1])\n T.writes(linear_weight6_shared[v0, v1])\n linear_weight6_shared[v0, v1] = linear_weight6[v0, v1]\n for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(2), T.int64(4), T.int64(1), T.int64(2), T.int64(2)):\n with T.block(\"NT_matmul_update\"):\n v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(2) + i1_1_4)\n v_i2_i = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3 * T.int64(2) + i2_4)\n v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)\n T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv37_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight6_shared[v_i2_i, v_k_i])\n T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv37_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight6_shared[v_i2_i, v_k_i]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)):\n with T.block(\"var_NT_matmul_intermediate_pad_local\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1)\n v2 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2)\n T.reads(lv42[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_T_multiply_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2])\n # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n:\n if v_i1_o * T.int64(32) + v1 < n:\n var_T_multiply_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv42[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] * var_NT_matmul_intermediate_pad_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_NT_matmul2_silu_before(p_lv43: T.handle, linear_weight4: T.Buffer((T.int64(11008), T.int64(4096)), \"float32\"), p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096)))\n var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)))\n # with T.block(\"root\"):\n var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)))\n compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)))\n for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):\n with T.block(\"NT_matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv43[v_i0, v_i1, v_k], linear_weight4[v_i2, 
v_k])\n T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * linear_weight4[v_i2, v_k]\n for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(compute[v_i0, v_i1, v_i2])\n compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])\n for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])\n T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])\n var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_NT_matmul2_silu_after(p_lv37: T.handle, linear_weight4: T.Buffer((T.int64(11008), T.int64(4096)), \"float32\"), p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n n = T.int64()\n lv37 = T.match_buffer(p_lv37, (T.int64(1), n, T.int64(4096)))\n var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)))\n # with T.block(\"root\"):\n var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)))\n for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread=\"blockIdx.y\"):\n with T.block(\"NT_matmul_o\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0)\n T.reads(lv37[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight4[T.int64(0):T.int64(11008), T.int64(0):T.int64(4096)])\n T.writes(var_NT_matmul_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)])\n var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope=\"local\")\n lv37_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope=\"shared\")\n linear_weight4_shared = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope=\"shared\")\n for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(344), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(4), T.int64(2), T.int64(1)):\n with T.block(\"NT_matmul_init\"):\n v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(2) + i1_1_4_init)\n v_i2_i = T.axis.spatial(T.int64(11008), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init)\n T.reads()\n T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0)\n for k_0 in 
range(T.int64(128)):\n for ax0_ax1_ax2_fused_0 in range(T.int64(4)):\n for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv37_pad_shared\"):\n v0 = T.axis.spatial(T.int64(1), T.int64(0))\n v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32))\n v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32))\n T.reads(lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2])\n T.writes(lv37_pad_shared[v0, v1, v2])\n lv37_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0))\n for ax0_ax1_fused_0 in range(T.int64(8)):\n for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax0_ax1_fused_2 in T.vectorized(T.int64(2)):\n with T.block(\"linear_weight4_shared\"):\n v0 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32))\n v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32))\n T.reads(linear_weight4[v0, v1])\n T.writes(linear_weight4_shared[v0, v1])\n linear_weight4_shared[v0, v1] = linear_weight4[v0, v1]\n for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(4), T.int64(4), T.int64(1), T.int64(2), T.int64(1)):\n with T.block(\"NT_matmul_update\"):\n v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(2) + i1_1_4)\n v_i2_i = T.axis.spatial(T.int64(11008), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3)\n v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)\n T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv37_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight4_shared[v_i2_i, v_k_i])\n T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv37_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight4_shared[v_i2_i, v_k_i]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)):\n with T.block(\"var_NT_matmul_intermediate_pad_local\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1)\n v2 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2)\n T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_NT_matmul_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2])\n # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n:\n if v_i1_o * T.int64(32) + v1 < n:\n var_NT_matmul_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = var_NT_matmul_intermediate_pad_local[v0, v1, v2]\n for ax0_ax1_ax2_fused_1 in 
T.thread_binding(T.int64(256), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0_ax1_ax2_fused_0 in range((n * T.int64(11008) + T.int64(65535)) // T.int64(65536)):\n with T.block(\"T_multiply\"):\n v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_ax1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * T.int64(65536) + ax0_ax1_ax2_fused_1 * T.int64(256) + ax0_ax1_ax2_fused_2) // T.int64(11008))\n v_ax2 = T.axis.spatial(T.int64(11008), (ax0_ax1_ax2_fused_0 * T.int64(65536) + ax0_ax1_ax2_fused_1 * T.int64(256) + ax0_ax1_ax2_fused_2) % T.int64(11008))\n T.where((ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1) * T.int64(256) + ax0_ax1_ax2_fused_2 < n * T.int64(11008))\n T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])\n var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * T.sigmoid(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n\n\n@T.prim_func\n\ndef fused_NT_matmul3_add1_before(p_lv49: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), \"float32\"), p_lv42: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008)))\n lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096)))\n var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)))\n # with T.block(\"root\"):\n var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))\n for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):\n with T.block(\"NT_matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv49[v_i0, v_i1, v_k], linear_weight5[v_i2, v_k])\n T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * linear_weight5[v_i2, v_k]\n for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):\n with T.block(\"T_add\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv42[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])\n var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_NT_matmul3_add1_after(p_lv43: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), \"float32\"), p_lv36: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n n = T.int64()\n lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(11008)))\n lv36 = T.match_buffer(p_lv36, (T.int64(1), n, T.int64(4096)))\n var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)))\n # with T.block(\"root\"):\n for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread=\"blockIdx.y\"):\n with T.block(\"NT_matmul_o\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0)\n T.reads(lv43[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)], linear_weight5[T.int64(0):T.int64(4096), T.int64(0):T.int64(11008)], 
lv36[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)])\n T.writes(var_T_add_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)])\n var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope=\"local\")\n lv43_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope=\"shared\")\n linear_weight5_shared = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope=\"shared\")\n for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(4), thread=\"vthread.x\"):\n for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(2), T.int64(1), T.int64(1)):\n with T.block(\"NT_matmul_init\"):\n v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + i1_1_3_init + i1_1_4_init)\n v_i2_i = T.axis.spatial(T.int64(4096), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + i2_3_init)\n T.reads()\n T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0)\n for k_0 in range(T.int64(344)):\n for ax0_ax1_ax2_fused_0 in range(T.int64(4)):\n for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"lv43_pad_shared\"):\n v0 = T.axis.spatial(T.int64(1), T.int64(0))\n v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32))\n v2 = T.axis.spatial(T.int64(11008), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32))\n T.reads(lv43[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2])\n T.writes(lv43_pad_shared[v0, v1, v2])\n lv43_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv43[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0))\n for ax0_ax1_fused_0 in range(T.int64(8)):\n for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread=\"threadIdx.x\"):\n for ax0_ax1_fused_2 in T.vectorized(T.int64(2)):\n with T.block(\"linear_weight5_shared\"):\n v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32))\n v1 = T.axis.spatial(T.int64(11008), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32))\n T.reads(linear_weight5[v0, v1])\n T.writes(linear_weight5_shared[v0, v1])\n linear_weight5_shared[v0, v1] = linear_weight5[v0, v1]\n for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"NT_matmul_update\"):\n v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + 
i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + i1_1_3 + i1_1_4)\n v_i2_i = T.axis.spatial(T.int64(4096), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + i2_3)\n v_k_i = T.axis.reduce(T.int64(11008), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)\n T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv43_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight5_shared[v_i2_i, v_k_i])\n T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv43_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight5_shared[v_i2_i, v_k_i]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(2)):\n with T.block(\"var_NT_matmul_intermediate_pad_local\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + ax1)\n v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + ax2)\n T.reads(lv36[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2])\n # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n:\n if v_i1_o * T.int64(32) + v1 < n:\n var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv36[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] + var_NT_matmul_intermediate_pad_local[v0, v1, v2]\n\n\n\n@T.prim_func\ndef fused_NT_matmul_divide_maximum_minimum_cast_before(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), \"float16\"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), \"float16\")\n lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), \"float16\")\n var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))\n # with T.block(\"root\"):\n var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):\n with T.block(\"NT_matmul\"):\n v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap(\"SSSSR\", [i0, i1, i2, i3, k])\n T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])\n T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])\n with T.init():\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, 
v_i3, v_k]\n for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_divide\"):\n v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap(\"SSSS\", [ax0, ax1, ax2, ax3])\n T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])\n T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])\n var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)\n for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_maximum\"):\n v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap(\"SSSS\", [ax0, ax1, ax2, ax3])\n T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])\n T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])\n...\n", "needle_token_start": 10214, "needle_token_end": 11053, "code_context_ntokens": 16341, "output": ["```python\n@T.prim_func\ndef fused_NT_matmul3_add1_before(p_lv49: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), \"float32\"), p_lv42: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008)))\n lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096)))\n var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)))\n # with T.block(\"root\"):\n var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))\n for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):\n with T.block(\"NT_matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv49[v_i0, v_i1, v_k], linear_weight5[v_i2, v_k])\n T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)\n var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * linear_weight5[v_i2, v_k]\n for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):\n with T.block(\"T_add\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv42[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])\n var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "fused_decode4_matmul5_int3_int16_fp16_before", "language": "python", "path": "mlc_llm/dispatch/llama/main.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The function performs a decoding operation followed by a matrix multiplication. The decoding involves extracting and transforming elements from an input buffer, which are then used in a matrix multiplication with another input buffer.\n2. **Input**: \n - A buffer of 16-bit unsigned integers with dimensions (824, 4096), representing encoded data.\n - A buffer of 16-bit floating-point numbers with dimensions (103, 4096), representing scaling factors for the decoding.\n - A buffer of 16-bit floating-point numbers with dimensions (1, 1, 4096), used as one of the multiplicands in the matrix multiplication.\n3. **Output**: \n - A buffer of 16-bit floating-point numbers with dimensions (1, 1, 4096), storing the result of the matrix multiplication.\n4. 
**Procedure**: \n - The function first allocates a buffer for intermediate decoded values.\n - It iterates over the dimensions of the encoded data buffer, decoding each element using bitwise operations and scaling it by corresponding values from the scaling factors buffer.\n - The decoded data is then used in a matrix multiplication operation with the provided multiplicand buffer.\n - The result of the matrix multiplication is stored in the output buffer.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(10), v_i], T.Cast(\"uint32\", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv44[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2])\n\n\n@T.prim_func\ndef fused_decode6_fused_matmul9_add3_int3_fp16_before(lv1623: T.Buffer((T.int64(1104), T.int64(4096)), \"uint32\"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), \"float16\"), lv167: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), lv165: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(11008), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1623[v_i // T.int64(10), v_j], lv1624[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1623[v_i // T.int64(10), v_j], T.Cast(\"uint32\", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv167[v_i0, 
v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv167[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):\n with T.block(\"T_add\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv165[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv165[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_decode6_fused_matmul9_add3_int3_fp16_after(lv1158: T.Buffer((T.int64(1104), T.int64(4096)), \"uint32\"), lv60: T.Buffer((T.int64(276), T.int64(4096)), \"float16\"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(11040), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11040)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(2)):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv6_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(5520) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1))\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5520))\n T.reads(lv6[v0, v1, v2])\n T.writes(lv6_shared[v0, v1, v2])\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv6_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(11008), lv6[v0, v1, v2], T.float16(0))\n for k_0_1 in range(T.int64(69)):\n for ax0_0 in T.unroll(T.int64(80)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(5520) + k_0_1 * T.int64(80) + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1158[v_j // T.int64(10), v_i], lv60[v_j // T.int64(40), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(10), v_i], 
T.Cast(\"uint32\", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)\n for k_0_2_k_1_fused in range(T.int64(80)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(11040), k_0_0 * T.int64(5520) + k_0_1 * T.int64(80) + k_0_2_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv60[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode3_matmul1_cast_int3_int16_fp16_before(lv2931: T.Buffer((T.int64(824), T.int64(32000)), \"uint16\"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), \"float16\"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(32000)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(T.Cast(\"uint32\", lv2931[v_i // T.int64(5), v_j]), T.Cast(\"uint32\", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2 = T.axis.remap(\"SSS\", [i0, i1, i2])\n T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])\n T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])\n p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast(\"float32\", var_matmul_intermediate[v_i0, v_i1, v_i2])\n\n\n@T.prim_func\ndef fused_decode3_matmul1_cast_int3_int16_fp16_after(lv1123: T.Buffer((T.int64(824), T.int64(32000)), \"uint16\"), lv5866: T.Buffer((T.int64(103), T.int64(32000)), \"float16\"), lv1511: 
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n # with T.block(\"root\"):\n var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4120), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope=\"local\", dtype=\"float16\")\n lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv1511_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv1511[v0, v1, v2])\n T.writes(lv1511_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv1511_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv1511[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"var_decode_intermediate_pad\"):\n v0 = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1123[v0 // T.int64(5), v1])\n T.writes(var_decode_intermediate_pad_local[v0, v1])\n var_decode_intermediate_pad_local[v0, v1] = T.Cast(\"float16\", T.Cast(\"int16\", T.bitwise_and(T.shift_right(T.Cast(\"uint16\", lv1123[v0 // T.int64(5), v1]), T.Cast(\"uint16\", v0 % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3))\n for ax0_0 in range(T.int64(1)):\n for ax1 in range(T.int64(1)):\n with T.block(\"scale\"):\n v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv5866[v_j, v_i])\n T.writes(var_scale_intermediate_local[v_j, v_i])\n var_scale_intermediate_local[v_j, v_i] = lv5866[v_j, v_i]\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], 
var_decode_intermediate_pad_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2])\n T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_pad_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_pad_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = T.Cast(\"float32\", var_matmul_intermediate_pad_local[v0, v1, v2])\n\n\n@T.prim_func\ndef fused_decode4_fused_matmul5_add3_int3_int16_fp16_before(lv1605: T.Buffer((T.int64(824), T.int64(4096)), \"uint16\"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv164: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv1518: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1605[v_i // T.int64(5), v_j], lv1606[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(T.Cast(\"uint32\", lv1605[v_i // T.int64(5), v_j]), T.Cast(\"uint32\", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv164[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv164[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):\n with T.block(\"T_add\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv1518[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1518[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]\n\n\n@T.prim_func\ndef fused_decode4_fused_matmul5_add3_int3_int16_fp16_after(lv1143: T.Buffer((T.int64(824), T.int64(4096)), \"uint16\"), lv36: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n 
var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv3_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv3[v0, v1, v2])\n T.writes(lv3_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv3_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv3[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1143[v_j // T.int64(5), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.Cast(\"int16\", T.bitwise_and(T.shift_right(T.Cast(\"uint16\", lv1143[v_j // T.int64(5), v_i]), T.Cast(\"uint16\", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3))\n for ax0_0 in range(T.int64(1)):\n for ax1 in range(T.int64(1)):\n with T.block(\"scale\"):\n v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv36[v_j, v_i])\n T.writes(var_scale_intermediate_local[v_j, v_i])\n var_scale_intermediate_local[v_j, v_i] = lv36[v_j, v_i]\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * 
var_scale_intermediate_local[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2])\n T.writes(p_output0_intermediate[v0, v1, v2])\n p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\n\ndef fused_decode4_matmul5_int3_int16_fp16_before(lv1587: T.Buffer((T.int64(824), T.int64(4096)), \"uint16\"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(T.Cast(\"uint32\", lv1587[v_i // T.int64(5), v_j]), T.Cast(\"uint32\", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n\n\n@T.prim_func\ndef fused_decode4_matmul5_int3_int16_fp16_after(lv1128: T.Buffer((T.int64(824), T.int64(4096)), \"uint16\"), lv12: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope=\"local\", dtype=\"float16\")\n lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope=\"shared\", dtype=\"float16\")\n for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 16, \"pragma_unroll_explicit\": 1}):\n for i2_1 in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i2_2 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)):\n for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"lv2712_shared\"):\n v0 = T.axis.spatial(T.int64(1), ax0)\n v1 = T.axis.spatial(T.int64(1), T.int64(0))\n 
v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)\n T.reads(lv2712[v0, v1, v2])\n T.writes(lv2712_shared[v0, v1, v2])\n T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120))\n T.block_attr({\"buffer_dim_align\": [[0, 1, 32, 8]]})\n lv2712_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2712[v0, v1, v2], T.float16(0))\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n T.reads()\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0)\n for k_0_0 in range(T.int64(103)):\n for ax0_0 in T.unroll(T.int64(40)):\n for ax1 in range(T.int64(1)):\n with T.block(\"decode\"):\n v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv1128[v_j // T.int64(5), v_i])\n T.writes(var_decode_intermediate_local[v_j, v_i])\n var_decode_intermediate_local[v_j, v_i] = T.Cast(\"float16\", T.Cast(\"int16\", T.bitwise_and(T.shift_right(T.Cast(\"uint16\", lv1128[v_j // T.int64(5), v_i]), T.Cast(\"uint16\", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3))\n for ax0_0 in range(T.int64(1)):\n for ax1 in range(T.int64(1)):\n with T.block(\"scale\"):\n v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0)\n v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1)\n T.reads(lv12[v_j, v_i])\n T.writes(var_scale_intermediate_local[v_j, v_i])\n var_scale_intermediate_local[v_j, v_i] = lv12[v_j, v_i]\n for k_0_1_k_1_fused in range(T.int64(40)):\n with T.block(\"matmul_update\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2)\n v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused)\n T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2])\n T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2])\n var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):\n with T.block(\"var_matmul_intermediate_local\"):\n v0, v1 = T.axis.remap(\"SS\", [ax0, ax1])\n v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2)\n T.reads(var_matmul_intermediate_local[v0, v1, v2])\n T.writes(var_matmul_intermediate[v0, v1, v2])\n var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2]\n\n\n@T.prim_func\ndef fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_before(lv1617: T.Buffer((T.int64(824), T.int64(11008)), \"uint16\"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), \"float16\"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = 
T.alloc_buffer((T.int64(4096), T.int64(11008)), \"float16\")\n var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(11008)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1617[v_i // T.int64(5), v_j], lv1618[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(T.Cast(\"uint32\", lv1617[v_i // T.int64(5), v_j]), T.Cast(\"uint32\", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):\n with T.block(\"T_multiply\"):\n v_ax0, v_ax1, v_ax2 = T.axis.remap(\"SSS\", [ax0, ax1, ax2])\n T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])\n T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])\n...\n", "needle_token_start": 11821, "needle_token_end": 12676, "code_context_ntokens": 16322, "output": ["```python\n@T.prim_func\ndef fused_decode4_matmul5_int3_int16_fp16_before(lv1587: T.Buffer((T.int64(824), T.int64(4096)), \"uint16\"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), \"float16\"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), \"float16\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n # with T.block(\"root\"):\n var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), \"float16\")\n for i, j in T.grid(T.int64(4096), T.int64(4096)):\n with T.block(\"decode\"):\n v_i, v_j = T.axis.remap(\"SS\", [i, j])\n T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j])\n T.writes(var_decode_intermediate[v_i, v_j])\n var_decode_intermediate[v_i, v_j] = (T.Cast(\"float16\", T.bitwise_and(T.shift_right(T.Cast(\"uint32\", lv1587[v_i // T.int64(5), v_j]), T.Cast(\"uint32\", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j]\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])\n T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])\n with T.init():\n var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)\n var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "softmax_1xn_sch_func", "language": "python", "path": "mlc_llm/dispatch/llama/main.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to optimize the execution of a softmax operation on a 1xN input tensor, potentially casting the results to a lower precision format based on a flag.\n2. 
**Input**: It accepts a primary function representing the softmax computation and a boolean flag indicating whether the output should be cast to half-precision floating-point format.\n3. **Output**: The output is a modified computational graph that includes optimizations such as inlining computations, splitting and binding loops to hardware threads, and setting memory scopes for efficient execution on parallel hardware.\n4. **Procedure**: The procedure involves:\n - Optionally inlining a casting computation based on the input flag.\n - Inlining the exponential computation directly into the normalization step to reduce memory overhead.\n - Splitting the normalization loop to facilitate parallel execution on GPU threads.\n - Computing the exponential sum and maximum element computations at optimal points in the loop hierarchy to enhance data locality and reduce latency.\n - Binding loops to GPU block and thread identifiers to ensure efficient distribution of work across the GPU's parallel execution units.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k])\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):\n with T.block(\"T_softmax_exp\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])\n T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])\n T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):\n with T.block(\"T_softmax_expsum\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0)\n T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):\n with T.block(\"T_softmax_norm\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])\n T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])\n T.block_attr({\"axis\": 3})\n T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]\n\n@T.prim_func\ndef softmax_mxn_fp16_after(var_A: T.handle, var_T_softmax_norm: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n n = T.int64()\n m = T.int64()\n A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), dtype=\"float16\")\n T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype=\"float16\")\n # with T.block(\"root\"):\n for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread=\"blockIdx.x\"):\n with T.block(\"T_softmax_maxelem_o\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0)\n T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)])\n 
T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m])\n T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope=\"shared\", dtype=\"float16\")\n T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope=\"shared\", dtype=\"float16\")\n for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)):\n for k_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_maxelem\"):\n v_i1_i, v_i2_i = T.axis.remap(\"SS\", [i1, i2_1])\n v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1)\n T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i])\n T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i])\n with T.init():\n T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(-65504)\n T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float16(-65504)))\n for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)):\n for k_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_expsum\"):\n v_i1_i, v_i2_i = T.axis.remap(\"SS\", [i1, i2_1])\n v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1)\n T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i])\n T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i])\n with T.init():\n T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(0)\n T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float16(0))\n for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)):\n for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_norm\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m)\n v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32))\n v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m)\n T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m)\n T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i])\n T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3])\n if v_i2_o * T.int64(32) + v_i2_i < n:\n T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i]\n\n\n@T.prim_func\ndef softmax_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n rxplaceholder = 
T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n), \"float16\")\n T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n), \"float16\")\n # with T.block(\"root\"):\n T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), \"float16\")\n T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, n), \"float16\")\n T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), \"float16\")\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n):\n with T.block(\"T_softmax_maxelem\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504)\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k])\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n):\n with T.block(\"T_softmax_exp\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])\n T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])\n T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n):\n with T.block(\"T_softmax_expsum\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0)\n T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n):\n with T.block(\"T_softmax_norm\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])\n T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])\n T.block_attr({\"axis\": 3})\n T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]\n\n\n@T.prim_func\ndef softmax_fp16_after(var_A: T.handle, var_T_softmax_norm: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n n = T.int64()\n A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, n), dtype=\"float16\")\n T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n), dtype=\"float16\")\n # with T.block(\"root\"):\n T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), dtype=\"float16\")\n T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), dtype=\"float16\")\n for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread=\"blockIdx.x\"):\n with T.block(\"T_softmax_maxelem_o\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0)\n T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)])\n T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)])\n T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope=\"shared\", dtype=\"float16\")\n for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)):\n for 
k_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_maxelem\"):\n v_i1_i, v_i2_i = T.axis.remap(\"SS\", [i1, i2_1])\n v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1)\n T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i])\n T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i])\n with T.init():\n T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(-65504)\n T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float16(-65504)))\n for i0_i1_i2_1_fused_0 in range(T.int64(8)):\n for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_maxelem_cache_write\"):\n v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32))\n v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32))\n T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n)\n T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i])\n T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i])\n T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]\n for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread=\"blockIdx.x\"):\n with T.block(\"T_softmax_expsum_o\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0)\n T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)])\n T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)])\n T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope=\"shared\", dtype=\"float16\")\n for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)):\n for k_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_expsum\"):\n v_i1_i, v_i2_i = T.axis.remap(\"SS\", [i1, i2_1])\n v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1)\n T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i])\n T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i])\n with T.init():\n T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(0)\n T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float16(0))\n for i0_i1_i2_1_fused_0 in range(T.int64(8)):\n for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_expsum_cache_write\"):\n v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32))\n v_i2_i = T.axis.spatial(T.int64(32), 
(i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32))\n T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n)\n T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i])\n T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i])\n T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]\n for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * n + T.int64(255)) // T.int64(256), thread=\"blockIdx.x\"):\n for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread=\"threadIdx.x\"):\n with T.block(\"T_softmax_norm\"):\n v_i0 = T.axis.spatial(T.int64(1), T.int64(0))\n v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n // n)\n v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n % n)\n v_i3 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % n)\n T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * n)\n T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])\n T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])\n T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2]\n\n\n@T.prim_func\ndef softmax_1xn_before(var_inp0: T.handle, var_T_softmax_norm: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n inp0 = T.match_buffer(var_inp0, (T.int64(1), T.int64(32), T.int64(1), n))\n T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n))\n # with T.block(\"root\"):\n T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))\n T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))\n T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_maxelem\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(inp0[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], inp0[v_i0, v_i1, v_i2, v_k])\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_exp\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(inp0[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])\n T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])\n T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(inp0[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_expsum\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)\n T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_norm\"):\n v_i0, v_i1, v_i2, v_i3 = 
T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])\n T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])\n T.block_attr({\"axis\": 3})\n T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]\n\n\n@T.prim_func\ndef softmax_cast_1xn_before(p_lv1614: T.handle, p_output0: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n lv1614 = T.match_buffer(p_lv1614, (T.int64(1), T.int64(32), T.int64(1), n))\n var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n # with T.block(\"root\"):\n T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))\n T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))\n T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))\n var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_maxelem\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(lv1614[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1614[v_i0, v_i1, v_i2, v_k])\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_exp\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(lv1614[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])\n T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])\n T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1614[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_expsum\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)\n T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_norm\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])\n T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])\n T.block_attr({\"axis\": 3})\n var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"compute\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])\n T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])\n var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast(\"float16\", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])\n\n\n@T.prim_func\ndef softmax_1xn_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n), 
\"float16\")\n T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n # with T.block(\"root\"):\n T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)), \"float16\")\n T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), \"float16\")\n T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)), \"float16\")\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_maxelem\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504)\n T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k])\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_exp\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])\n T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])\n T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])\n for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_expsum\"):\n v_i0, v_i1, v_i2, v_k = T.axis.remap(\"SSSR\", [i0, i1, i2, k])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])\n T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])\n with T.init():\n T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0)\n T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]\n for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):\n with T.block(\"T_softmax_norm\"):\n v_i0, v_i1, v_i2, v_i3 = T.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])\n T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])\n T.block_attr({\"axis\": 3})\n T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]\n\n\n\ndef softmax_1xn_sch_func(f_softmax, cast_to_fp16: bool = False):\n sch = tvm.tir.Schedule(f_softmax)\n if cast_to_fp16:\n b_cast = sch.get_block(\"compute\")\n sch.reverse_compute_inline(b_cast)\n\n b0 = sch.get_block(\"T_softmax_exp\")\n sch.compute_inline(b0)\n b1 = sch.get_block(\"T_softmax_norm\")\n l2, l3, l4, l5 = sch.get_loops(b1)\n l6, l7 = sch.split(l5, [None, 128])\n sch.bind(l7, \"threadIdx.x\")\n b8 = sch.get_block(\"T_softmax_expsum\")\n sch.compute_at(b8, l4)\n sch.set_scope(b8, 0, \"shared\")\n l9, l10, l11, l12 = sch.get_loops(b8)\n l13, l14 = sch.split(l12, [None, 128])\n sch.bind(l14, \"threadIdx.x\")\n b15 = sch.get_block(\"T_softmax_maxelem\")\n sch.compute_at(b15, l4)\n sch.set_scope(b15, 0, \"shared\")\n l16, l17, l18, l19 = sch.get_loops(b15)\n l20, l21 = sch.split(l19, [None, 128])\n sch.bind(l21, \"threadIdx.x\")\n l22 = sch.fuse(l2, l3, l4)\n sch.bind(l22, \"blockIdx.x\")\n return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n\n@T.prim_func\ndef matmul1_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True)})\n n = T.int64()\n rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n))\n 
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128)))\n # with T.block(\"root\"):\n for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):\n with T.block(\"matmul\"):\n v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap(\"SSSSR\", [i0, i1, i2, i3, k])\n T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3])\n T.writes(matmul[v_i0, v_i1, v_i2, v_i3])\n with T.init():\n matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0)\n matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3]\n\n\n@T.prim_func\ndef matmul1_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), \"float32\")):\n T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n n = T.int64()\n rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n))\n rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128)))\n # with T.block(\"root\"):\n matmul_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), scope=\"local\")\n rxplaceholder_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(128)), scope=\"shared\")\n rxplaceholder_1_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope=\"shared\")\n for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(T.int64(16), thread=\"blockIdx.x\", annotations={\"pragma_auto_unroll_max_step\": 512, \"pragma_unroll_explicit\": 1}):\n for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(T.int64(1), thread=\"vthread.x\"):\n for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1)):\n with T.block(\"matmul_init\"):\n v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init)\n v_i1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + i1_3_init * T.int64(2) + i1_4_init)\n v_i2 = T.axis.spatial(T.int64(1), i2_3_init + i2_4_init)\n v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + i3_3_init + i3_4_init)\n T.reads()\n T.writes(matmul_local[v_i0, v_i1, v_i2, v_i3])\n T.block_attr({\"meta_schedule.thread_extent_high_inclusive\": 256, \"meta_schedule.thread_extent_low_inclusive\": 32, \"meta_schedule.tiling_structure\": \"SSSRRSRS\"})\n matmul_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0)\n for k_0, k_1_0 in T.grid((n + T.int64(127)) // T.int64(128), T.int64(8)):\n for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)):\n for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(128), thread=\"threadIdx.x\"):\n for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)):\n with T.block(\"rxplaceholder_pad_shared\"):\n v0 = T.axis.spatial(T.int64(1), T.int64(0))\n v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(16))\n v2 = T.axis.spatial(T.int64(1), T.int64(0))\n...\n", "needle_token_start": 13651, "needle_token_end": 14154, "code_context_ntokens": 16251, "output": ["```python\ndef softmax_1xn_sch_func(f_softmax, 
cast_to_fp16: bool = False):\n sch = tvm.tir.Schedule(f_softmax)\n if cast_to_fp16:\n b_cast = sch.get_block(\"compute\")\n sch.reverse_compute_inline(b_cast)\n\n b0 = sch.get_block(\"T_softmax_exp\")\n sch.compute_inline(b0)\n b1 = sch.get_block(\"T_softmax_norm\")\n l2, l3, l4, l5 = sch.get_loops(b1)\n l6, l7 = sch.split(l5, [None, 128])\n sch.bind(l7, \"threadIdx.x\")\n b8 = sch.get_block(\"T_softmax_expsum\")\n sch.compute_at(b8, l4)\n sch.set_scope(b8, 0, \"shared\")\n l9, l10, l11, l12 = sch.get_loops(b8)\n l13, l14 = sch.split(l12, [None, 128])\n sch.bind(l14, \"threadIdx.x\")\n b15 = sch.get_block(\"T_softmax_maxelem\")\n sch.compute_at(b15, l4)\n sch.set_scope(b15, 0, \"shared\")\n l16, l17, l18, l19 = sch.get_loops(b15)\n l20, l21 = sch.split(l19, [None, 128])\n sch.bind(l21, \"threadIdx.x\")\n l22 = sch.fuse(l2, l3, l4)\n sch.bind(l22, \"blockIdx.x\")\n return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n```"]} +{"repo": "mlc-ai/mlc-llm", "name": "f_quantize", "language": "python", "path": "mlc_llm/quantization/ft_quantization.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to quantize and preprocess neural network weights for optimized processing on specific GPU architectures, particularly for use in deep learning inference.\n2. **Input**: The function takes a list of expressions representing neural network weights.\n3. **Output**: It returns a tuple containing the preprocessed and quantized weights along with additional encoded data.\n4. **Procedure**: \n - The function first quantizes the input weights using a specified bit precision and group size.\n - It then optionally preprocesses the quantized weights using a CUDA kernel tailored for specific NVIDIA GPU architectures, enhancing compatibility and performance for deep learning computations.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: mlc_llm/quantization/quantization.py\nimport enum\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, List, Literal, Optional, Tuple, Type, Union\n\nimport tvm\nfrom tvm import relax, te\nfrom tvm.relax.expr_functor import PyExprVisitor, visitor\n\nFQuantize = Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]\nFTEQuantize = Callable[[te.Tensor], List[te.Tensor]]\nFTEDequantize = Callable[[List[te.Tensor]], te.Tensor]\n\n\n@dataclass\nclass QuantizationSpec:\n \"\"\"The base dataclass of quantization specification.\n A specification describes how a parameter is quantized and dequantized.\n\n A subclass of QuantizationSpec\n - contains more data fields (e.g., the \"group size\" in group quantization)\n which instruct the quantization/dequantization,\n - defines the `get_quantize_func` method, which returns a function\n (`Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]`) that takes a\n Relax BlockBuilder and the weight relax Var to be quantized, computes\n the quantization and returns the relax Var of quantized results.\n algorithm of the quantization.\n - defines the `get_dequantize_func` method, which returns function\n (`Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]`) that takes\n the quantized results, computes and returns the dequantization result.\n - optionally overloads the `get_loaded_tensor_info` when the parameter is\n 
pre-quantized, in which case `get_loaded_tensor_info` needs to be overloaded\n so that we know how many quantized data tensors there are, and the dtype\n and shape of each quantized data tensor.\n \"\"\"\n\n dtype: str\n\n def get_loaded_tensor_info(\n self, pname: str, param_info: relax.TensorStructInfo\n ) -> Tuple[List[str], List[relax.TensorStructInfo]]:\n \"\"\"Returns the names and shapes and dtypes of the tensors that need to\n be loaded from the disk.\n\n It is useful when the parameter is pre-quantized. In such cases, we need\n to know how many tensors the parameter is quantized into, and together\n with the dtype and shape of each tensor, so that we can load the\n pre-quantized tensors in.\n \"\"\"\n return [pname], [param_info]\n\n def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:\n \"\"\"Returns the function which computes quantization.\n Returning `None` means the parameter does not need quantization or is\n pre-quantized.\n\n The returned function takes a Relax BlockBuilder and a (list of) weight\n relax Var to be quantized, computes the quantization and returns the\n quantization result Relax Var(s).\n\n You can use `convert_TE_func` to convert a TE function to the function\n of the desired return format. See `group_quantization.py` for examples.\n \"\"\"\n return NotImplementedError()\n\n def get_dequantize_func(\n self,\n param_info: relax.TensorStructInfo,\n qparam_info: List[relax.TensorStructInfo],\n ) -> Optional[FQuantize]:\n \"\"\"Returns the function which computes dequantization.\n Returning `None` means the parameter does not need dequantization.\n\n The returned function takes a Relax BlockBuilder and a (list of)\n quantized weight relax Var, computes the dequantization and returns the\n result Relax Var(s).\n\n You can use `convert_TE_func` to convert a TE function to the function\n of the desired return format. 
See `group_quantization.py` for examples.\n \"\"\"\n return NotImplementedError()\n\n\n@dataclass\nclass NoQuantizationSpec(QuantizationSpec):\n \"\"\"The quantization specification that describes doing no quantization.\"\"\"\n\n def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:\n return None\n\n def get_dequantize_func(\n self,\n param_info: relax.TensorStructInfo,\n qparam_info: List[relax.TensorStructInfo],\n ) -> Optional[FQuantize]:\n return None\n\n\nclass ParamQuantKind(enum.IntEnum):\n \"\"\"The parameter quantization kind class.\n\n We categorized all the parameters in a model into four kinds:\n - the weights of the internal linear layers, which are the main targets of quantization,\n - the embedding table of every token,\n - the weight of the fully-connected layer at the end of the model, which is\n used for computes the logits of each input token,\n - other parameters (e.g., the weight of layer normalization, etc.).\n \"\"\"\n\n linear_weight = 0\n embedding_table = 1\n final_fc_weight = 2\n others = 3\n\n\nclass QuantizationScheme:\n \"\"\"The quantization scheme class describes how an entire model is quantized.\n It contains the quantization specification for each parameter quantization kind.\n\n Besides, it has an optional field for a visitor class which has the ability to\n take the constructed model (in format of IRModule) as input, go through the\n model and update the QuantizationSpec for certain parameters.\n \"\"\"\n\n name: str\n linear_weight: QuantizationSpec\n embedding_table: QuantizationSpec\n final_fc_weight: QuantizationSpec\n others: QuantizationSpec\n\n qspec_updater_class: Optional[Type[\"QuantSpecUpdater\"]]\n f_convert_param_bkwd: Optional[Callable[[str, Any], Optional[List[Tuple[str, Any]]]]]\n f_compute_relax_param: Optional[Callable[[str, List[Any]], Any]]\n f_run_prequantize: Optional[Callable[[str], str]]\n\n def __init__(\n self,\n name: str,\n linear_weight: QuantizationSpec,\n *,\n embedding_table: Optional[Union[QuantizationSpec, Literal[\"same_as_linear_weight\"]]] = None,\n final_fc_weight: Optional[Union[QuantizationSpec, Literal[\"same_as_linear_weight\"]]] = None,\n others: Optional[QuantizationSpec] = None,\n qspec_updater_class: Optional[Type[\"QuantSpecUpdater\"]] = None,\n ) -> None:\n self.name = name\n self.linear_weight = linear_weight\n self.others = others if others is not None else NoQuantizationSpec(self.model_dtype)\n\n if embedding_table is None:\n self.embedding_table = self.others\n elif embedding_table == \"same_as_linear_weight\":\n self.embedding_table = self.linear_weight\n else:\n self.embedding_table = embedding_table\n\n if final_fc_weight is None:\n self.final_fc_weight = self.others\n elif final_fc_weight == \"same_as_linear_weight\":\n self.final_fc_weight = self.linear_weight\n else:\n self.final_fc_weight = final_fc_weight\n\n self.qspec_updater_class = qspec_updater_class\n self.f_convert_param_bkwd = None\n self.f_compute_relax_param = None\n self.f_run_prequantize = None\n\n for spec in [self.linear_weight, self.embedding_table, self.final_fc_weight, self.others]:\n if hasattr(spec, \"convert_param_bkwd\"):\n self.f_convert_param_bkwd = spec.convert_param_bkwd\n if hasattr(spec, \"compute_relax_param\"):\n self.f_compute_relax_param = spec.compute_relax_param\n if hasattr(spec, \"run_prequantize\"):\n self.f_run_prequantize = spec.run_prequantize\n\n @property\n def model_dtype(self) -> str:\n \"\"\"Returns the overall model dtype, which is defined as the dtype of\n the linear 
layers.\n \"\"\"\n return self.linear_weight.dtype\n\n\ndef convert_TE_func(te_func: Union[FTEQuantize, FTEDequantize], func_name: str) -> FQuantize:\n def func(bb: relax.BlockBuilder, inputs: List[relax.Expr]) -> relax.Var:\n return bb.call_te(te_func, *inputs, primfunc_name_hint=func_name)\n\n return func\n\n\n@visitor\nclass QuantSpecUpdater(PyExprVisitor):\n def __init__(self, param_manager) -> None:\n super().__init__()\n self.param_manager = param_manager\n self.param_map = None\n self.builder = relax.BlockBuilder()\n\n def lookup_binding(self, var: relax.Var):\n return self.builder.lookup_binding(var)\n\n def visit_module(self, mod: tvm.IRModule):\n for gv, func in mod.functions.items():\n if not isinstance(func, relax.Function):\n continue\n if func.attrs is None or not \"num_input\" in func.attrs:\n continue\n\n self.param_map = dict()\n num_input = int(func.attrs[\"num_input\"])\n params_in_func = self.param_manager.params_in_func[gv.name_hint]\n assert len(func.params) - num_input == len(params_in_func)\n for i, relax_param in enumerate(func.params[num_input:]):\n self.param_map[relax_param] = params_in_func[i]\n\n self.builder.normalize(func)\n self.visit_expr(func)\n\n# Path: mlc_llm/quantization/tir_utils.py\n\"\"\"TIR computation utilities for quantization.\"\"\"\n\nimport tvm\nfrom tvm import tir\n\n# fmt: off\ndef _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool=True):\n mask = tir.const((1 << 16) - 1, \"uint32\")\n res = []\n for data in [v0, v1]:\n u32_val = tir.reinterpret(\"uint32\", data)\n if round_to_even:\n rounding_bias = ((u32_val >> tir.const(16, \"uint32\")) & tir.const(1, \"uint32\")) + tir.const(0x7FFF, \"uint32\")\n u32_val += rounding_bias\n res.append((u32_val >> tir.const(16, \"uint32\")) & mask)\n return res[0] | (res[1] << tir.const(16, \"uint32\"))\n\n\ndef _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr):\n mask = tir.const((1 << 16) - 1, \"uint32\")\n x0 = x & mask\n x1 = (x >> 16) & mask\n return (tir.reinterpret(\"float32\", x << tir.const(16, \"uint32\")) for x in [x0, x1])\n\n\ndef _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):\n assert val.dtype == \"uint32\"\n mask = tvm.tir.const((1 << nbit) - 1, \"uint32\")\n return tir.Cast(dtype, (val >> (pos * nbit).astype(\"uint32\")) & mask)\n\n\ndef _tir_packed_uint_to_uint_to_float(storage_nbit: int):\n storage_dtype = \"uint\" + str(storage_nbit)\n\n def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):\n assert val.dtype == storage_dtype\n max_int_value = (1 << (nbit - 1)) - 1\n return ((val >> (pos.astype(\"uint32\") * tir.const(nbit, \"uint32\"))) & tir.const((1 << nbit) - 1, \"uint32\")).astype(dtype) - tir.const(max_int_value, dtype)\n\n return f_convert\n\n\ndef _tir_packed_int_to_int_to_float(storage_nbit: int):\n storage_dtype = \"int\" + str(storage_nbit)\n\n def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):\n assert val.dtype == storage_dtype\n mask = tir.const((1 << nbit) - 1, \"int32\")\n unextended = (val >> (pos.astype(\"int32\") * tir.const(nbit, \"int32\"))) & mask\n return tir.Cast(dtype, (unextended << tir.const(32 - nbit, \"int32\")) >> tir.const(32 - nbit, \"int32\"))\n\n return f_convert\n\n\ndef _tir_f32_to_uint_to_f4(val: tir.PrimExpr):\n assert val.dtype == \"float32\"\n val_u32 = tir.reinterpret(\"uint32\", val)\n # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7)\n # e_f32 == 120 -> e_f4 = 1\n # e_f32 < 120 -> e_f4 = 0\n m_h = (val_u32 >> tir.const(22, 
\"uint32\")) & tir.const(1, \"uint32\")\n e_f32 = (val_u32 >> tir.const(23, \"uint32\")) & tir.const(255, \"uint32\")\n s = (val_u32 >> tir.const(31, \"uint32\"))\n e_f4 = tir.Select(e_f32 > tir.const(120, \"uint32\"), tir.Min(e_f32 - tir.const(120, \"uint32\") + m_h, tir.const(7, \"uint32\")), tir.Select(e_f32 == tir.const(120, \"uint32\"), tir.const(1, \"uint32\"), tir.const(0, \"uint32\")))\n return (s << tir.const(3, \"uint32\")) | e_f4\n\n\ndef _tir_f16_to_uint_to_f4(val: tir.PrimExpr):\n assert val.dtype == \"float16\"\n val_u32 = tir.Cast(\"uint32\", tir.reinterpret(\"uint16\", val))\n m_h = (val_u32 >> tir.const(9, \"uint32\")) & tir.const(1, \"uint32\")\n e_f16 = (val_u32 >> tir.const(10, \"uint32\")) & tir.const(31, \"uint32\")\n s = (val_u32 >> tir.const(15, \"uint32\"))\n e_f4 = tir.Select(e_f16 > tir.const(8, \"uint32\"), tir.Min(e_f16 - tir.const(8, \"uint32\") + m_h, tir.const(7, \"uint32\")), tir.Select(e_f16 == tir.const(8, \"uint32\"), tir.const(1, \"uint32\"), tir.const(0, \"uint32\")))\n return (s << tir.const(3, \"uint32\")) | e_f4\n\n\ndef _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):\n assert nbit == 4\n assert dtype == \"float32\"\n assert val.dtype == \"uint32\"\n # e_f4 == 0 -> e_f32 = 0\n # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2\n mask = tvm.tir.const((1 << nbit) - 1, \"uint32\")\n f4 = (val >> (pos.astype(\"uint32\") * tir.const(nbit, \"uint32\"))) & mask\n s = f4 >> tir.const(3, \"uint32\")\n e_f4 = f4 & tir.const(7, \"uint32\")\n e_f32 = e_f4 | tir.const(120, \"uint32\")\n val_f32 = tir.reinterpret(\"float32\", (e_f32 | (s << tir.const(8, \"uint32\"))) << tir.const(23, \"uint32\"))\n return tir.Select(e_f4 == tir.const(0, \"uint32\"), tir.const(0, \"float32\"), val_f32)\n\n\ndef _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):\n assert nbit == 4\n assert dtype == \"float16\"\n assert val.dtype == \"uint32\"\n # e_f4 == 0 -> e_f16 = 0\n # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2\n mask = tvm.tir.const((1 << nbit) - 1, \"uint32\")\n f4 = (val >> (pos.astype(\"uint32\") * tir.const(nbit, \"uint32\"))) & mask\n s = f4 >> tir.const(3, \"uint32\")\n e_f4 = f4 & tir.const(7, \"uint32\")\n e_f16 = e_f4 | tir.const(8, \"uint32\")\n val_f16 = tir.reinterpret(\"float16\", (e_f16 | (s << tir.const(5, \"uint32\"))) << tir.const(10, \"uint32\"))\n return tir.Select(e_f4 == tir.const(0, \"uint32\"), tir.const(0, \"float16\"), val_f16)\n# fmt: on\n\n# Path: mlc_llm/quantization/autogptq_quantization.py\nfrom dataclasses import dataclass\nfrom typing import Any, List, Literal, Optional, Tuple\nfrom tvm import relax, te, tir, topi\nfrom . 
import tir_utils\nfrom .quantization import QuantizationSpec\nfrom .quantization import FQuantize, FTEDequantize, convert_TE_func\n\n\n@dataclass\nclass AutogptqQuantizationSpec(QuantizationSpec):\n \"\"\"The quantization specification for group quantization algorithm.\"\"\"\n\n mode: Literal[\"int2\", \"int3\", \"int4\", \"int8\"]\n sym: bool\n group_size: int\n storage_nbit: int = 32\n\n quantized_suffix = [\"qweight\", \"qzeros\", \"scales\", \"g_idx\"]\n\n def get_loaded_tensor_info(\n self, pname: str, param_info: relax.TensorStructInfo\n ) -> Tuple[List[str], List[relax.TensorStructInfo]]:\n assert self.storage_nbit == 32, \"Only support 32bit storage currently\"\n\n quantized_pnames = self.quant_convert_pname_fwd(pname)\n if len(quantized_pnames) == 1:\n return quantized_pnames, [param_info]\n else:\n assert len(quantized_pnames) == 4\n assert param_info.ndim == 2\n nbit = int(self.mode[-1])\n tensor_info = []\n outfeatures, infeatures = param_info.shape.values\n group_size = self.group_size if self.group_size != -1 else infeatures\n\n def get_quantized_shape_dtype(quantized_pname: str):\n if quantized_pname.endswith(\"qweight\"):\n return (infeatures // self.storage_nbit * nbit, outfeatures), \"uint32\"\n elif quantized_pname.endswith(\"qzeros\"):\n return (\n infeatures // group_size,\n outfeatures // self.storage_nbit * nbit,\n ), \"uint32\"\n elif quantized_pname.endswith(\"scales\"):\n return (infeatures // group_size, outfeatures), \"float16\"\n elif quantized_pname.endswith(\"g_idx\"):\n return (infeatures,), \"uint32\"\n else:\n raise ValueError(f\"Unrecognized quantized parameter name {quantized_pname}\")\n\n for quantized_pname in quantized_pnames:\n shape, dtype = get_quantized_shape_dtype(quantized_pname)\n tensor_info.append(relax.TensorStructInfo(shape, dtype))\n\n return quantized_pnames, tensor_info\n\n def quant_convert_pname_fwd(self, torch_pname: str) -> List[str]:\n # For Llama:\n if \"_proj.weight\" in torch_pname:\n return [torch_pname.replace(\"weight\", suffix) for suffix in self.quantized_suffix]\n return [torch_pname]\n\n def run_prequantize(self, model_path: str) -> str:\n # with auto-gptq >= 0.2.0\n try:\n import auto_gptq # pylint: disable=import-outside-toplevel\n import transformers # pylint: disable=import-outside-toplevel\n except ImportError:\n raise ImportError(\n \"Please install auto_gptq package (version >= 0.2.0) and \"\n \"transformers package to use AutoGPTQ quantization.\"\n )\n import os\n from transformers import AutoTokenizer\n from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig\n\n quantized_model_path = (\n model_path\n + f\"-gptq-i{self.mode[-1]}\"\n + (\"-sym\" if self.sym else \"\")\n + f\"-g{self.group_size}\"\n )\n if os.path.isdir(quantized_model_path):\n return quantized_model_path\n\n tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n examples = [\n tokenizer(\n \"MLC LLM is a universal solution that allows any language models \"\n \"to be deployed natively on a diverse set of hardware backends and \"\n \"native applications, plus a productive framework for everyone to \"\n \"further optimize model performance for their own use cases.\"\n )\n ]\n quantize_config = BaseQuantizeConfig(\n bits=int(self.mode[-1]), # quantize bits\n desc_act=False, # disable activation description\n group_size=self.group_size, # disable group quantization\n )\n\n model = AutoGPTQForCausalLM.from_pretrained(model_path, quantize_config)\n model.quantize(examples)\n\n # save quantized model\n 
model.save_quantized(quantized_model_path)\n tokenizer.save_pretrained(quantized_model_path)\n return quantized_model_path\n\n def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:\n return None\n\n def get_dequantize_func(\n self,\n param_info: relax.TensorStructInfo,\n qparam_info: List[relax.TensorStructInfo],\n ) -> Optional[FQuantize]:\n return convert_TE_func(\n decoding_func(\n sym=self.sym,\n nbit=int(self.mode[-1]),\n storage_nbit=self.storage_nbit,\n dim_length=param_info.shape.values[-1],\n dtype=self.dtype,\n ),\n func_name=\"decode\",\n )\n\n def convert_param_bkwd(self, torch_pname: str, torch_param):\n target_dtype = (\n self.dtype if \"_proj.\" not in torch_pname or \"scales\" in torch_pname else \"uint32\"\n )\n\n # For Llama\n combined_layers = [\"q_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\"]\n if any([name in torch_pname for name in combined_layers]):\n return None\n return [(torch_pname, torch_param.astype(target_dtype))]\n\n def compute_relax_param(self, relax_pname: str, torch_params: List[Any]):\n import numpy as np\n\n # For Llama\n if \"query_key_value_proj\" in relax_pname:\n assert len(torch_params) == 3\n elif \"gate_up_proj\" in relax_pname:\n assert len(torch_params) == 2\n else:\n raise ValueError(\"Unexpected param loading\")\n\n if \"g_idx\" in relax_pname:\n return torch_params[0].astype(\"uint32\")\n else:\n target_dtype = self.dtype if \"scales\" in relax_pname else \"uint32\"\n return np.concatenate(torch_params, axis=-1).astype(target_dtype)\n\n\ndef decoding_func(\n sym: bool,\n nbit: int,\n storage_nbit: int,\n dim_length: tir.PrimExpr,\n dtype: str = \"float16\",\n) -> FTEDequantize:\n assert dtype in [\"float16\"], \"Only support float16 currently\"\n assert sym == False, \"Only support sym=False currently\"\n assert storage_nbit == 32, \"Only support storage_nbit=32 currently\"\n\n def te_decode_asym(qweight, qzeros, scales, g_idx):\n n_float_per_u32 = 32 // nbit\n\n def f_decode_asym(i, j):\n zeros = tir_utils._tir_u32_to_int_to_float(\n nbit,\n qzeros[g_idx[i], j // n_float_per_u32],\n j % n_float_per_u32,\n dtype=dtype,\n )\n data_float = tir_utils._tir_u32_to_int_to_float(\n nbit,\n qweight[i // n_float_per_u32, j],\n i % n_float_per_u32,\n dtype=dtype,\n )\n scale_float, bias_float = scales[g_idx[i], j], zeros + 1\n w = (data_float - bias_float) * scale_float\n return w\n\n shape = (dim_length, qweight.shape[1])\n w = te.compute(shape=shape, fcompute=f_decode_asym, name=\"decode\")\n w = topi.transpose(w)\n return w\n\n return te_decode_asym\n\n# Path: mlc_llm/quantization/group_quantization.py\nfrom dataclasses import dataclass\nfrom typing import List, Literal, Optional\n\nimport tvm\nfrom tvm import relax, te, tir, topi\nfrom tvm.script import tir as T\nfrom tvm.relax.expr_functor import visitor\n\nfrom . 
import tir_utils\nfrom .quantization import QuantizationSpec, QuantSpecUpdater\nfrom .quantization import NoQuantizationSpec\nfrom .quantization import FQuantize, FTEQuantize, FTEDequantize, convert_TE_func\n\n\n@dataclass\nclass GroupQuantizationSpec(QuantizationSpec):\n \"\"\"The quantization specification for group quantization algorithm.\"\"\"\n\n mode: Literal[\"int3\", \"int4\"]\n sym: bool\n storage_nbit: int\n group_size: int\n transpose: bool\n\n def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:\n return convert_TE_func(\n encoding_func(\n sym=self.sym,\n group_size=self.group_size,\n nbit=int(self.mode[-1]),\n mode=self.mode,\n storage_nbit=self.storage_nbit,\n transpose=self.transpose,\n dtype=self.dtype,\n ),\n func_name=\"encode\",\n )\n\n def get_dequantize_func(\n self,\n param_info: relax.TensorStructInfo,\n qparam_info: List[relax.TensorStructInfo],\n ) -> Optional[FQuantize]:\n return convert_TE_func(\n decoding_func(\n sym=self.sym,\n group_size=self.group_size,\n nbit=int(self.mode[-1]),\n mode=self.mode,\n storage_nbit=self.storage_nbit,\n dim_length=param_info.shape.values[-1],\n data_transposed=self.transpose,\n transpose_output=self.transpose,\n dtype=self.dtype,\n ),\n func_name=\"decode\",\n )\n\n\n# fmt: off\ndef encoding_func(sym: bool, group_size: int, nbit: int, mode: str, storage_nbit: int, transpose: bool=True, dtype: str = \"float32\") -> FTEQuantize:\n def te_encode_asym(weight: te.Tensor):\n assert weight.shape[1] % group_size == 0\n n_group = weight.shape[1] // group_size\n n_float_per_u32 = 32 // nbit\n\n scale_min_shape = (weight.shape[0], n_group)\n k = te.reduce_axis((0, group_size), name=\"k\")\n min_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.min(weight[i, j * group_size + k], axis=k), name=\"min_value\")\n max_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(weight[i, j * group_size + k], axis=k), name=\"max_value\")\n scale = te.compute(shape=scale_min_shape, fcompute=lambda i, j: (max_value[i, j] - min_value[i, j]) / tir.const((1 << nbit) - 1, dtype), name=\"scale\")\n\n def f_scale_weight(i, j):\n group_idx = j // group_size\n w_scaled = tir.round((weight[i, j] - min_value[i, group_idx]) / scale[i, group_idx]).astype(\"int32\")\n w_scaled = T.min(T.max(w_scaled, tir.const(0, \"int32\")), tir.const((1 << nbit) - 1, \"int32\"))\n w_scaled = w_scaled.astype(\"uint32\")\n return w_scaled\n\n k = te.reduce_axis((0, n_float_per_u32), name=\"k\")\n reducer = te.comm_reducer(fcombine=lambda x, y: tir.bitwise_or(x, y), fidentity=lambda dtype: tir.const(0, dtype), name=\"bitwise_or\")\n if dtype == \"float32\":\n if transpose:\n w_gathered = te.compute(shape=(weight.shape[1] // n_float_per_u32, weight.shape[0]), fcompute=lambda j, i: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype(\"uint32\"), axis=k), name=\"w_gathered\")\n scale_bias = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: tir_utils._tir_f32x2_to_bf16x2_to_u32(scale[i, j], min_value[i, j], round_to_even=True), name=\"scale_min\")\n else:\n w_gathered = te.compute(shape=(weight.shape[0], weight.shape[1] // n_float_per_u32), fcompute=lambda i, j: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype(\"uint32\"), axis=k), name=\"w_gathered\")\n scale_bias = te.compute(shape=(weight.shape[0], n_group), fcompute=lambda i, j: tir_utils._tir_f32x2_to_bf16x2_to_u32(scale[i, j], min_value[i, j], round_to_even=True), name=\"scale_min\")\n return w_gathered, 
scale_bias\n else:\n if transpose:\n w_gathered = te.compute(shape=(weight.shape[1] // n_float_per_u32, weight.shape[0]), fcompute=lambda j, i: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype(\"uint32\"), axis=k), name=\"w_gathered\")\n scale = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: scale[i, j], name=\"scale_transpose\")\n min_value = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: min_value[i, j], name=\"min_transpose\")\n else:\n w_gathered = te.compute(shape=(weight.shape[0], weight.shape[1] // n_float_per_u32), fcompute=lambda i, j: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype(\"uint32\"), axis=k), name=\"w_gathered\")\n return w_gathered, scale, min_value\n\n def te_encode_sym(weight: te.Tensor):\n n_group = tir.ceildiv(weight.shape[1], group_size)\n n_float_per_int = storage_nbit // nbit\n max_int_value = (1 << (nbit - 1)) - 1\n assert group_size % n_float_per_int == 0\n\n scale_min_shape = (weight.shape[0], n_group)\n k = te.reduce_axis((0, group_size), name=\"k\")\n max_abs_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(tir.if_then_else(j * group_size + k < weight.shape[1], te.abs(weight[i, j * group_size + k]), tir.min_value(dtype)), axis=k), name=\"max_abs_value\")\n\n def f_compute_scale(i, j):\n max_value = tir.max(max_abs_value[i, j], tir.const(1e-4, dtype))\n return (max_value / tir.const(max_int_value, dtype)) if mode.startswith(\"int\") else max_value\n\n scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name=\"scale\")\n storage_dtype = (\"uint\" + str(storage_nbit)) if mode.startswith(\"int\") else \"uint32\"\n\n def f_scale_weight(i, j):\n group_idx = j // group_size\n if mode.startswith(\"int\"):\n w_scaled = tir.round(weight[i, j] / scale[i, group_idx] + tir.const(max_int_value, dtype))\n w_scaled = T.min(T.max(w_scaled, tir.const(0, dtype)), tir.const(max_int_value * 2, dtype)).astype(storage_dtype)\n return w_scaled\n else:\n f_convert = tir_utils._tir_f32_to_uint_to_f4 if dtype == \"float32\" else tir_utils._tir_f16_to_uint_to_f4\n return f_convert(weight[i, j] / scale[i, group_idx])\n\n k = te.reduce_axis((0, n_float_per_int), name=\"k\")\n reducer = te.comm_reducer(fcombine=lambda x, y: tir.bitwise_or(x, y), fidentity=lambda dtype: tir.const(0, dtype), name=\"bitwise_or\")\n n_i32 = tir.ceildiv(group_size, n_float_per_int) * n_group\n if transpose:\n w_gathered = te.compute(shape=(n_i32, weight.shape[0]), fcompute=lambda j, i: reducer(tir.if_then_else(j * n_float_per_int + k < weight.shape[1], f_scale_weight(i, j * n_float_per_int + k) << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), tir.const(0, storage_dtype)), axis=k), name=\"w_gathered\")\n scale = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: scale[i, j])\n else:\n w_gathered = te.compute(shape=(weight.shape[0], n_i32), fcompute=lambda i, j: reducer(tir.if_then_else(j * n_float_per_int + k < weight.shape[1], f_scale_weight(i, j * n_float_per_int + k) << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), tir.const(0, storage_dtype)), axis=k), name=\"w_gathered\")\n return w_gathered, scale\n\n return te_encode_sym if sym else te_encode_asym\n\n\ndef decoding_func(sym: bool, group_size: int, nbit: int, mode: str, storage_nbit: int, dim_length: tir.PrimExpr, data_transposed: bool=True, transpose_output: bool=False, dtype: str = \"float32\") -> FTEDequantize:\n def te_decode_asym(*args):\n n_float_per_u32 = 32 // nbit\n data = 
args[0]\n if dtype == \"float32\":\n scale_bias_bf16x2 = args[1]\n else:\n scale, min_value = args[1], args[2]\n\n def f_decode_asym(i, j):\n if data_transposed:\n data_float = tir_utils._tir_u32_to_int_to_float(nbit, data[i // n_float_per_u32, j], i % n_float_per_u32, dtype=dtype)\n if dtype == \"float32\":\n scale_float, bias_float = tir_utils._tir_u32_to_bf16x2_to_f32x2(scale_bias_bf16x2[i // group_size, j])\n else:\n scale_float, bias_float = scale[i // group_size, j], min_value[i // group_size, j]\n else:\n data_float = tir_utils._tir_u32_to_int_to_float(nbit, data[i, j // n_float_per_u32], j % n_float_per_u32, dtype=dtype)\n if dtype == \"float32\":\n scale_float, bias_float = tir_utils._tir_u32_to_bf16x2_to_f32x2(scale_bias_bf16x2[i, j // group_size])\n else:\n scale_float, bias_float = scale[i, j // group_size], min_value[i, j // group_size]\n w = data_float * scale_float + bias_float\n return w\n\n shape = (dim_length, data.shape[1]) if data_transposed else (data.shape[0], dim_length)\n w = te.compute(shape=shape, fcompute=f_decode_asym, name=\"decode\")\n if transpose_output:\n w = topi.transpose(w)\n return w\n\n def te_decode_sym(data, scale):\n n_float_per_int = storage_nbit // nbit\n\n def f_decode_sym(i, j):\n f_convert = tir_utils._tir_packed_uint_to_uint_to_float(storage_nbit) if mode.startswith(\"int\") else (tir_utils._tir_u32_to_f4_to_f32 if dtype == \"float32\" else tir_utils._tir_u32_to_f4_to_f16)\n if data_transposed:\n data_float = f_convert(nbit, data[i // n_float_per_int, j], i % n_float_per_int, dtype=dtype)\n scale_float = scale[i // group_size, j]\n else:\n data_float = f_convert(nbit, data[i, j // n_float_per_int], j % n_float_per_int, dtype=dtype)\n scale_float = scale[i, j // group_size]\n return data_float * scale_float\n\n shape = (dim_length, data.shape[1]) if data_transposed else (data.shape[0], dim_length)\n w = te.compute(shape=shape, fcompute=f_decode_sym, name=\"decode\")\n if transpose_output:\n w = topi.transpose(w)\n return w\n\n return te_decode_sym if sym else te_decode_asym\n# fmt: on\n\n\n# A simple example demo showing how QuantSpecUpdater is used.\n# NOTE: This visitor is only for demo purpose and should not be put into real use.\n@visitor\nclass GroupQuantDemoUpdater(QuantSpecUpdater._cls):\n def visit_call_(self, call: relax.Call):\n if call.op != tvm.ir.Op.get(\"relax.matmul\"):\n return\n rhs = self.lookup_binding(call.args[1])\n assert rhs is not None\n if (\n rhs.op != tvm.ir.Op.get(\"relax.permute_dims\")\n or rhs.attrs.axes is not None\n or rhs.args[0].struct_info.ndim != 2\n ):\n return\n\n if rhs.args[0] not in self.param_map:\n return\n param = self.param_map[rhs.args[0]]\n # Update to no quantization for matmul with float32 output dtype.\n if call.struct_info.dtype == \"float32\":\n param.quant_spec = NoQuantizationSpec(param.param_info.dtype)\n\n# Path: mlc_llm/quantization/ft_quantization.py\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport tvm\nfrom tvm.contrib.nvcc import parse_compute_version\nfrom tvm import relax, te, tir, topi\nfrom tvm.script import tir as T\nfrom tvm.relax.expr_functor import visitor\n\nfrom . 
import tir_utils\nfrom .quantization import QuantizationSpec, QuantSpecUpdater\nfrom .quantization import FQuantize, convert_TE_func\nfrom .group_quantization import GroupQuantizationSpec\n\n\n@dataclass\nclass FTQuantizationSpec(QuantizationSpec):\n \"\"\"The quantization specification for the FasterTransformer kernel.\"\"\"\n\n def __init__(self, dtype, nbit, group_size=-1):\n super().__init__(dtype)\n self.nbit = nbit\n assert group_size in [-1, 64, 128], f\"Group size {group_size} is not supported.\"\n self.group_size = group_size\n\n if tvm.cuda(0).exist:\n major, minor = parse_compute_version(tvm.cuda(0).compute_version)\n if major == 8:\n self.sm = 80\n else:\n self.sm = 10 * major + minor\n else:\n self.sm = None\n\n self.do_preprocess = True\n\n def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]:\n assert self.sm is not None\n\n \ndef f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]):\n encoded_data = bb.emit_te(\n encoding_func(\n self.nbit,\n 8,\n group_size=self.group_size,\n dtype=self.dtype,\n ),\n inputs[0],\n primfunc_name_hint=\"encode\",\n )\n\n packed_weight = bb.normalize(encoded_data[0])\n\n if self.do_preprocess:\n encoded_weight = bb.emit(\n relax.call_pure_packed(\n \"cutlass.ft_preprocess_weight\",\n packed_weight,\n self.sm,\n self.nbit == 4,\n sinfo_args=packed_weight.struct_info,\n )\n )\n else:\n encoded_weight = packed_weight\n\n return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]]))\n\n return f_quantize\n\n def get_dequantize_func(\n self,\n param_info: relax.TensorStructInfo,\n qparam_info: List[relax.TensorStructInfo],\n ) -> Optional[FQuantize]:\n return convert_TE_func(\n decoding_func(\n self.nbit,\n storage_nbit=8,\n group_size=self.group_size,\n ),\n func_name=\"decode\",\n )\n\n\ndef encoding_func(nbit: int, storage_nbit: int, group_size: int, dtype: str = \"float32\"):\n def te_encode_sym(weight: te.Tensor):\n \"\"\"Encode the weight tensor of shape [N, K] into a quantized weight tensor of shape\n [K, N // float_per_int] and a scale tensor of shape [K // group_size, N]\n \"\"\"\n n_float_per_int = storage_nbit // nbit\n max_int_value = (1 << (nbit - 1)) - 1\n\n cur_group_size = weight.shape[1] if group_size == -1 else group_size\n scale_min_shape = (tir.ceildiv(weight.shape[1], cur_group_size), weight.shape[0])\n k = te.reduce_axis((0, cur_group_size), name=\"k\")\n max_abs_value = te.compute(\n shape=scale_min_shape,\n fcompute=lambda group, i: te.max(\n te.abs(\n tir.if_then_else(\n group * cur_group_size + k < weight.shape[1],\n weight[i, group * cur_group_size + k],\n tir.const(0, dtype=weight.dtype),\n )\n ),\n axis=k,\n ),\n name=\"max_abs_value\",\n )\n\n def f_compute_scale(*idx):\n max_value = tir.max(tir.Cast(dtype, max_abs_value(*idx)), tir.const(1e-4, dtype))\n return max_value / tir.const(max_int_value, dtype)\n\n scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name=\"scale\")\n storage_dtype = \"int\" + str(storage_nbit)\n\n def f_scale_weight(i, j):\n w_scaled = tir.round(tir.Cast(dtype, weight[i, j]) / scale[j // cur_group_size, i])\n w_scaled = T.min(\n...\n", "needle_token_start": 11602, "needle_token_end": 11861, "code_context_ntokens": 12525, "output": ["```python\ndef f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]):\n encoded_data = bb.emit_te(\n encoding_func(\n self.nbit,\n 8,\n group_size=self.group_size,\n dtype=self.dtype,\n ),\n inputs[0],\n primfunc_name_hint=\"encode\",\n )\n\n packed_weight = bb.normalize(encoded_data[0])\n\n if 
self.do_preprocess:\n encoded_weight = bb.emit(\n relax.call_pure_packed(\n \"cutlass.ft_preprocess_weight\",\n packed_weight,\n self.sm,\n self.nbit == 4,\n sinfo_args=packed_weight.struct_info,\n )\n )\n else:\n encoded_weight = packed_weight\n\n return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]]))\n```"]} +{"repo": "reactive-python/reactpy", "name": "sync_function", "language": "python", "path": "src/py/reactpy/reactpy/core/hooks.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To execute a synchronous version of an asynchronous effect function, potentially including a cleanup process.\n2. **Input**: An asynchronous function that returns a coroutine, which when executed, may return a cleanup function.\n3. **Output**: A synchronous function that, when called, handles the execution of the asynchronous function and its cleanup process.\n4. **Procedure**: \n - The asynchronous function is wrapped into a synchronous function.\n - A task is created to run the asynchronous function using asyncio.\n - A cleanup function is defined within the synchronous function to handle the cancellation of the task and to execute any cleanup returned by the asynchronous function if the task completes successfully.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " return State(current_state.value, current_state.dispatch)\n\n\nclass _CurrentState(Generic[_Type]):\n __slots__ = \"value\", \"dispatch\"\n\n def __init__(\n self,\n initial_value: _Type | Callable[[], _Type],\n ) -> None:\n if callable(initial_value):\n self.value = initial_value()\n else:\n self.value = initial_value\n\n hook = current_hook()\n\n def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:\n if callable(new):\n next_value = new(self.value)\n else:\n next_value = new\n if not strictly_equal(next_value, self.value):\n self.value = next_value\n hook.schedule_render()\n\n self.dispatch = dispatch\n\n\n_EffectCleanFunc: TypeAlias = \"Callable[[], None]\"\n_SyncEffectFunc: TypeAlias = \"Callable[[], _EffectCleanFunc | None]\"\n_AsyncEffectFunc: TypeAlias = (\n \"Callable[[], Coroutine[None, None, _EffectCleanFunc | None]]\"\n)\n_EffectApplyFunc: TypeAlias = \"_SyncEffectFunc | _AsyncEffectFunc\"\n\n\n@overload\ndef use_effect(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None]: ...\n\n\n@overload\ndef use_effect(\n function: _EffectApplyFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None: ...\n\n\ndef use_effect(\n function: _EffectApplyFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None] | None:\n \"\"\"See the full :ref:`Use Effect` docs for details\n\n Parameters:\n function:\n Applies the effect and can return a clean-up function\n dependencies:\n Dependencies for the effect. The effect will only trigger if the identity\n of any value in the given sequence changes (i.e. their :func:`id` is\n different). By default these are inferred based on local variables that are\n referenced by the given function.\n\n Returns:\n If not function is provided, a decorator. 
Otherwise ``None``.\n \"\"\"\n hook = current_hook()\n\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n last_clean_callback: Ref[_EffectCleanFunc | None] = use_ref(None)\n\n def add_effect(function: _EffectApplyFunc) -> None:\n if not asyncio.iscoroutinefunction(function):\n sync_function = cast(_SyncEffectFunc, function)\n else:\n async_function = cast(_AsyncEffectFunc, function)\n\n \ndef sync_function() -> _EffectCleanFunc | None:\n task = asyncio.create_task(async_function())\n\n def clean_future() -> None:\n if not task.cancel():\n try:\n clean = task.result()\n except asyncio.CancelledError:\n pass\n else:\n if clean is not None:\n clean()\n\n return clean_future\n\n async def effect(stop: asyncio.Event) -> None:\n if last_clean_callback.current is not None:\n last_clean_callback.current()\n last_clean_callback.current = None\n clean = last_clean_callback.current = sync_function()\n await stop.wait()\n if clean is not None:\n clean()\n\n return memoize(lambda: hook.add_effect(effect))\n\n if function is not None:\n add_effect(function)\n return None\n else:\n return add_effect\n\n\ndef use_debug_value(\n message: Any | Callable[[], Any],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None:\n \"\"\"Log debug information when the given message changes.\n\n .. note::\n This hook only logs if :data:`~reactpy.config.REACTPY_DEBUG_MODE` is active.\n\n Unlike other hooks, a message is considered to have changed if the old and new\n values are ``!=``. Because this comparison is performed on every render of the\n component, it may be worth considering the performance cost in some situations.\n\n Parameters:\n message:\n The value to log or a memoized function for generating the value.\n dependencies:\n Dependencies for the memoized function. The message will only be recomputed\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). 
By default these are inferred based on local\n variables that are referenced by the given function.\n \"\"\"\n old: Ref[Any] = _use_const(lambda: Ref(object()))\n memo_func = message if callable(message) else lambda: message\n new = use_memo(memo_func, dependencies)\n\n if REACTPY_DEBUG_MODE.current and old.current != new:\n old.current = new\n logger.debug(f\"{current_hook().component} {new}\")\n\n\ndef create_context(default_value: _Type) -> Context[_Type]:\n \"\"\"Return a new context type for use in :func:`use_context`\"\"\"\n\n def context(\n *children: Any,\n value: _Type = default_value,\n key: Key | None = None,\n ) -> _ContextProvider[_Type]:\n return _ContextProvider(\n *children,\n value=value,\n key=key,\n type=context,\n )\n\n context.__qualname__ = \"context\"\n\n return context\n\n\ndef use_context(context: Context[_Type]) -> _Type:\n \"\"\"Get the current value for the given context type.\n\n See the full :ref:`Use Context` docs for more information.\n \"\"\"\n hook = current_hook()\n provider = hook.get_context_provider(context)\n\n if provider is None:\n # same assertions but with normal exceptions\n if not isinstance(context, FunctionType):\n raise TypeError(f\"{context} is not a Context\") # nocov\n if context.__kwdefaults__ is None:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n if \"value\" not in context.__kwdefaults__:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n return cast(_Type, context.__kwdefaults__[\"value\"])\n\n return provider.value\n\n\nclass _ContextProvider(Generic[_Type]):\n def __init__(\n self,\n *children: Any,\n value: _Type,\n key: Key | None,\n type: Context[_Type],\n ) -> None:\n self.children = children\n self.key = key\n self.type = type\n self.value = value\n\n def render(self) -> VdomDict:\n current_hook().set_context_provider(self)\n return {\"tagName\": \"\", \"children\": self.children}\n\n def __repr__(self) -> str:\n return f\"ContextProvider({self.type})\"\n\n\n_ActionType = TypeVar(\"_ActionType\")\n\n\ndef use_reducer(\n reducer: Callable[[_Type, _ActionType], _Type],\n initial_value: _Type,\n) -> tuple[_Type, Callable[[_ActionType], None]]:\n \"\"\"See the full :ref:`Use Reducer` docs for details\n\n Parameters:\n reducer:\n A function which applies an action to the current state in order to\n produce the next state.\n initial_value:\n The initial state value (same as for :func:`use_state`)\n\n Returns:\n A tuple containing the current state and a function to change it with an action\n \"\"\"\n state, set_state = use_state(initial_value)\n return state, _use_const(lambda: _create_dispatcher(reducer, set_state))\n\n\ndef _create_dispatcher(\n reducer: Callable[[_Type, _ActionType], _Type],\n set_state: Callable[[Callable[[_Type], _Type]], None],\n) -> Callable[[_ActionType], None]:\n def dispatch(action: _ActionType) -> None:\n set_state(lambda last_state: reducer(last_state, action))\n\n return dispatch\n\n\n_CallbackFunc = TypeVar(\"_CallbackFunc\", bound=Callable[..., Any])\n\n\n@overload\ndef use_callback(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_CallbackFunc], _CallbackFunc]: ...\n\n\n@overload\ndef use_callback(\n function: _CallbackFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc: ...\n\n\ndef use_callback(\n function: _CallbackFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc | Callable[[_CallbackFunc], _CallbackFunc]:\n \"\"\"See the full :ref:`Use Callback` 
docs for details\n\n Parameters:\n function:\n The function whose identity will be preserved\n dependencies:\n Dependencies of the callback. The identity the ``function`` will be updated\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current function\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n\n def setup(function: _CallbackFunc) -> _CallbackFunc:\n return memoize(lambda: function)\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _LambdaCaller(Protocol):\n \"\"\"MyPy doesn't know how to deal with TypeVars only used in function return\"\"\"\n\n def __call__(self, func: Callable[[], _Type]) -> _Type: ...\n\n\n@overload\ndef use_memo(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _LambdaCaller: ...\n\n\n@overload\ndef use_memo(\n function: Callable[[], _Type],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type: ...\n\n\ndef use_memo(\n function: Callable[[], _Type] | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type | Callable[[Callable[[], _Type]], _Type]:\n \"\"\"See the full :ref:`Use Memo` docs for details\n\n Parameters:\n function:\n The function to be memoized.\n dependencies:\n Dependencies for the memoized function. The memo will only be recomputed if\n the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current state\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n\n memo: _Memo[_Type] = _use_const(_Memo)\n\n if memo.empty():\n # we need to initialize on the first run\n changed = True\n memo.deps = () if dependencies is None else dependencies\n elif dependencies is None:\n changed = True\n memo.deps = ()\n elif (\n len(memo.deps) != len(dependencies)\n # if deps are same length check identity for each item\n or not all(\n strictly_equal(current, new)\n for current, new in zip(memo.deps, dependencies)\n )\n ):\n memo.deps = dependencies\n changed = True\n else:\n changed = False\n\n setup: Callable[[Callable[[], _Type]], _Type]\n\n if changed:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n current_value = memo.value = function()\n return current_value\n\n else:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n return memo.value\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _Memo(Generic[_Type]):\n \"\"\"Simple object for storing memoization data\"\"\"\n\n __slots__ = \"value\", \"deps\"\n\n value: _Type\n deps: Sequence[Any]\n\n def empty(self) -> bool:\n try:\n self.value # noqa: B018\n except AttributeError:\n return True\n else:\n return False\n\n\ndef use_ref(initial_value: _Type) -> Ref[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value: The value initially assigned to the reference.\n\n Returns:\n A :class:`Ref` object.\n \"\"\"\n return _use_const(lambda: Ref(initial_value))\n\n\ndef _use_const(function: Callable[[], _Type]) -> _Type:\n return current_hook().use_state(function)\n\n\ndef _try_to_infer_closure_values(\n func: Callable[..., Any] | None,\n values: Sequence[Any] | ellipsis | None,\n) -> 
Sequence[Any] | None:\n if values is ...:\n if isinstance(func, FunctionType):\n return (\n [cell.cell_contents for cell in func.__closure__]\n if func.__closure__\n else []\n )\n else:\n return None\n else:\n return values\n\n\ndef strictly_equal(x: Any, y: Any) -> bool:\n \"\"\"Check if two values are identical or, for a limited set or types, equal.\n\n Only the following types are checked for equality rather than identity:\n\n - ``int``\n - ``float``\n - ``complex``\n - ``str``\n - ``bytes``\n - ``bytearray``\n - ``memoryview``\n \"\"\"\n return x is y or (type(x) in _NUMERIC_TEXT_BINARY_TYPES and x == y)\n\n\n_NUMERIC_TEXT_BINARY_TYPES = {\n # numeric\n int,\n float,\n complex,\n # text\n str,\n # binary types\n bytes,\n bytearray,\n memoryview,\n}\n\n# Path: src/py/reactpy/reactpy/backend/hooks.py\nfrom __future__ import annotations\n\nfrom collections.abc import MutableMapping\nfrom typing import Any\n\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.core.hooks import create_context, use_context\nfrom reactpy.core.types import Context\n\n# backend implementations should establish this context at the root of an app\nConnectionContext: Context[Connection[Any] | None] = create_context(None)\n\n\ndef use_connection() -> Connection[Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`.\"\"\"\n conn = use_context(ConnectionContext)\n if conn is None: # nocov\n msg = \"No backend established a connection.\"\n raise RuntimeError(msg)\n return conn\n\n\ndef use_scope() -> MutableMapping[str, Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s scope.\"\"\"\n return use_connection().scope\n\n\ndef use_location() -> Location:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s location.\"\"\"\n return use_connection().location\n\n# Path: src/py/reactpy/reactpy/core/component.py\nfrom __future__ import annotations\n\nimport inspect\nfrom functools import wraps\nfrom typing import Any, Callable\n\nfrom reactpy.core.types import ComponentType, VdomDict\n\n\ndef component(\n function: Callable[..., ComponentType | VdomDict | str | None]\n) -> Callable[..., Component]:\n \"\"\"A decorator for defining a new component.\n\n Parameters:\n function: The component's :meth:`reactpy.core.proto.ComponentType.render` function.\n \"\"\"\n sig = inspect.signature(function)\n\n if \"key\" in sig.parameters and sig.parameters[\"key\"].kind in (\n inspect.Parameter.KEYWORD_ONLY,\n inspect.Parameter.POSITIONAL_OR_KEYWORD,\n ):\n msg = f\"Component render function {function} uses reserved parameter 'key'\"\n raise TypeError(msg)\n\n @wraps(function)\n def constructor(*args: Any, key: Any | None = None, **kwargs: Any) -> Component:\n return Component(function, key, args, kwargs, sig)\n\n return constructor\n\n\nclass Component:\n \"\"\"An object for rending component models.\"\"\"\n\n __slots__ = \"__weakref__\", \"_func\", \"_args\", \"_kwargs\", \"_sig\", \"key\", \"type\"\n\n def __init__(\n self,\n function: Callable[..., ComponentType | VdomDict | str | None],\n key: Any | None,\n args: tuple[Any, ...],\n kwargs: dict[str, Any],\n sig: inspect.Signature,\n ) -> None:\n self.key = key\n self.type = function\n self._args = args\n self._kwargs = kwargs\n self._sig = sig\n\n def render(self) -> ComponentType | VdomDict | str | None:\n return self.type(*self._args, **self._kwargs)\n\n def __repr__(self) -> str:\n try:\n args = self._sig.bind(*self._args, **self._kwargs).arguments\n except TypeError:\n return 
f\"{self.type.__name__}(...)\"\n else:\n items = \", \".join(f\"{k}={v!r}\" for k, v in args.items())\n if items:\n return f\"{self.type.__name__}({id(self):02x}, {items})\"\n else:\n return f\"{self.type.__name__}({id(self):02x})\"\n\n# Path: src/py/reactpy/reactpy/types.py\n\"\"\"Exports common types from:\n\n- :mod:`reactpy.core.types`\n- :mod:`reactpy.backend.types`\n\"\"\"\n\nfrom reactpy.backend.types import BackendType, Connection, Location\nfrom reactpy.core.component import Component\nfrom reactpy.core.types import (\n ComponentConstructor,\n ComponentType,\n Context,\n EventHandlerDict,\n EventHandlerFunc,\n EventHandlerMapping,\n EventHandlerType,\n ImportSourceDict,\n Key,\n LayoutType,\n RootComponentConstructor,\n State,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomJson,\n)\n\n__all__ = [\n \"BackendType\",\n \"Component\",\n \"ComponentConstructor\",\n \"ComponentType\",\n \"Connection\",\n \"Context\",\n \"EventHandlerDict\",\n \"EventHandlerFunc\",\n \"EventHandlerMapping\",\n \"EventHandlerType\",\n \"ImportSourceDict\",\n \"Key\",\n \"LayoutType\",\n \"Location\",\n \"RootComponentConstructor\",\n \"State\",\n \"VdomAttributes\",\n \"VdomChild\",\n \"VdomChildren\",\n \"VdomDict\",\n \"VdomJson\",\n]\n\n# Path: src/py/reactpy/reactpy/backend/utils.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport socket\nimport sys\nfrom collections.abc import Iterator\nfrom contextlib import closing\nfrom importlib import import_module\nfrom typing import Any\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_BACKENDS = (\n \"fastapi\",\n \"sanic\",\n \"tornado\",\n \"flask\",\n \"starlette\",\n)\n\n\ndef run(\n component: RootComponentConstructor,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n implementation: BackendType[Any] | None = None,\n) -> None:\n \"\"\"Run a component with a development server\"\"\"\n logger.warning(_DEVELOPMENT_RUN_FUNC_WARNING)\n\n implementation = implementation or import_module(\"reactpy.backend.default\")\n app = implementation.create_development_app()\n implementation.configure(app, component)\n port = port or find_available_port(host)\n app_cls = type(app)\n\n logger.info(\n \"ReactPy is running with '%s.%s' at http://%s:%s\",\n app_cls.__module__,\n app_cls.__name__,\n host,\n port,\n )\n asyncio.run(implementation.serve_development_app(app, host, port))\n\n\ndef find_available_port(host: str, port_min: int = 8000, port_max: int = 9000) -> int:\n \"\"\"Get a port that's available for the given host and port range\"\"\"\n for port in range(port_min, port_max):\n with closing(socket.socket()) as sock:\n try:\n if sys.platform in (\"linux\", \"darwin\"):\n # Fixes bug on Unix-like systems where every time you restart the\n # server you'll get a different port on Linux. 
This cannot be set\n # on Windows otherwise address will always be reused.\n # Ref: https://stackoverflow.com/a/19247688/3159288\n sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n sock.bind((host, port))\n except OSError:\n pass\n else:\n return port\n msg = f\"Host {host!r} has no available port in range {port_max}-{port_max}\"\n raise RuntimeError(msg)\n\n\ndef all_implementations() -> Iterator[BackendType[Any]]:\n \"\"\"Yield all available server implementations\"\"\"\n for name in SUPPORTED_BACKENDS:\n try:\n import_module(name)\n except ImportError: # nocov\n logger.debug(\"Failed to import %s\", name, exc_info=True)\n continue\n\n reactpy_backend_name = f\"{__name__.rsplit('.', 1)[0]}.{name}\"\n yield import_module(reactpy_backend_name)\n\n\n_DEVELOPMENT_RUN_FUNC_WARNING = \"\"\"\\\nThe `run()` function is only intended for testing during development! To run \\\nin production, refer to the docs on how to use reactpy.backend.*.configure.\\\n\"\"\"\n\n# Path: src/py/reactpy/reactpy/core/layout.py\nfrom __future__ import annotations\n\nimport abc\nfrom asyncio import (\n FIRST_COMPLETED,\n CancelledError,\n Queue,\n Task,\n create_task,\n get_running_loop,\n wait,\n)\nfrom collections import Counter\nfrom collections.abc import Sequence\nfrom contextlib import AsyncExitStack\nfrom logging import getLogger\nfrom typing import (\n Any,\n Callable,\n Generic,\n NamedTuple,\n NewType,\n TypeVar,\n cast,\n)\nfrom uuid import uuid4\nfrom weakref import ref as weakref\n\nfrom anyio import Semaphore\nfrom typing_extensions import TypeAlias\n\nfrom reactpy.config import (\n REACTPY_ASYNC_RENDERING,\n REACTPY_CHECK_VDOM_SPEC,\n REACTPY_DEBUG_MODE,\n)\nfrom reactpy.core._life_cycle_hook import LifeCycleHook\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n Key,\n LayoutEventMessage,\n LayoutUpdateMessage,\n VdomChild,\n VdomDict,\n VdomJson,\n)\nfrom reactpy.core.vdom import validate_vdom_json\nfrom reactpy.utils import Ref\n\nlogger = getLogger(__name__)\n\n\nclass Layout:\n \"\"\"Responsible for \"rendering\" components. That is, turning them into VDOM.\"\"\"\n\n __slots__: tuple[str, ...] 
= (\n \"root\",\n \"_event_handlers\",\n \"_rendering_queue\",\n \"_render_tasks\",\n \"_render_tasks_ready\",\n \"_root_life_cycle_state_id\",\n \"_model_states_by_life_cycle_state_id\",\n )\n\n if not hasattr(abc.ABC, \"__weakref__\"): # nocov\n __slots__ += (\"__weakref__\",)\n\n def __init__(self, root: ComponentType) -> None:\n super().__init__()\n if not isinstance(root, ComponentType):\n msg = f\"Expected a ComponentType, not {type(root)!r}.\"\n raise TypeError(msg)\n self.root = root\n\n async def __aenter__(self) -> Layout:\n # create attributes here to avoid access before entering context manager\n self._event_handlers: EventHandlerDict = {}\n self._render_tasks: set[Task[LayoutUpdateMessage]] = set()\n self._render_tasks_ready: Semaphore = Semaphore(0)\n\n self._rendering_queue: _ThreadSafeQueue[_LifeCycleStateId] = _ThreadSafeQueue()\n root_model_state = _new_root_model_state(self.root, self._schedule_render_task)\n\n self._root_life_cycle_state_id = root_id = root_model_state.life_cycle_state.id\n self._model_states_by_life_cycle_state_id = {root_id: root_model_state}\n self._schedule_render_task(root_id)\n\n return self\n\n async def __aexit__(self, *exc: Any) -> None:\n root_csid = self._root_life_cycle_state_id\n root_model_state = self._model_states_by_life_cycle_state_id[root_csid]\n\n for t in self._render_tasks:\n t.cancel()\n try:\n await t\n except CancelledError:\n pass\n\n await self._unmount_model_states([root_model_state])\n\n # delete attributes here to avoid access after exiting context manager\n del self._event_handlers\n del self._rendering_queue\n del self._root_life_cycle_state_id\n del self._model_states_by_life_cycle_state_id\n\n async def deliver(self, event: LayoutEventMessage) -> None:\n \"\"\"Dispatch an event to the targeted handler\"\"\"\n # It is possible for an element in the frontend to produce an event\n # associated with a backend model that has been deleted. We only handle\n # events if the element and the handler exist in the backend. Otherwise\n # we just ignore the event.\n handler = self._event_handlers.get(event[\"target\"])\n\n if handler is not None:\n try:\n await handler.function(event[\"data\"])\n except Exception:\n logger.exception(f\"Failed to execute event handler {handler}\")\n else:\n logger.info(\n f\"Ignored event - handler {event['target']!r} \"\n \"does not exist or its component unmounted\"\n )\n\n async def render(self) -> LayoutUpdateMessage:\n if REACTPY_ASYNC_RENDERING.current:\n return await self._concurrent_render()\n else: # nocov\n return await self._serial_render()\n\n async def _serial_render(self) -> LayoutUpdateMessage: # nocov\n \"\"\"Await the next available render. This will block until a component is updated\"\"\"\n while True:\n model_state_id = await self._rendering_queue.get()\n try:\n model_state = self._model_states_by_life_cycle_state_id[model_state_id]\n except KeyError:\n logger.debug(\n \"Did not render component with model state ID \"\n f\"{model_state_id!r} - component already unmounted\"\n )\n else:\n return await self._create_layout_update(model_state)\n\n async def _concurrent_render(self) -> LayoutUpdateMessage:\n \"\"\"Await the next available render. 
This will block until a component is updated\"\"\"\n await self._render_tasks_ready.acquire()\n done, _ = await wait(self._render_tasks, return_when=FIRST_COMPLETED)\n update_task: Task[LayoutUpdateMessage] = done.pop()\n self._render_tasks.remove(update_task)\n return update_task.result()\n\n async def _create_layout_update(\n self, old_state: _ModelState\n ) -> LayoutUpdateMessage:\n new_state = _copy_component_model_state(old_state)\n component = new_state.life_cycle_state.component\n\n async with AsyncExitStack() as exit_stack:\n await self._render_component(exit_stack, old_state, new_state, component)\n\n if REACTPY_CHECK_VDOM_SPEC.current:\n validate_vdom_json(new_state.model.current)\n\n return {\n \"type\": \"layout-update\",\n \"path\": new_state.patch_path,\n \"model\": new_state.model.current,\n }\n\n async def _render_component(\n self,\n exit_stack: AsyncExitStack,\n old_state: _ModelState | None,\n new_state: _ModelState,\n component: ComponentType,\n ) -> None:\n life_cycle_state = new_state.life_cycle_state\n life_cycle_hook = life_cycle_state.hook\n\n self._model_states_by_life_cycle_state_id[life_cycle_state.id] = new_state\n\n await life_cycle_hook.affect_component_will_render(component)\n exit_stack.push_async_callback(life_cycle_hook.affect_layout_did_render)\n try:\n raw_model = component.render()\n # wrap the model in a fragment (i.e. tagName=\"\") to ensure components have\n # a separate node in the model state tree. This could be removed if this\n # components are given a node in the tree some other way\n wrapper_model: VdomDict = {\"tagName\": \"\", \"children\": [raw_model]}\n await self._render_model(exit_stack, old_state, new_state, wrapper_model)\n except Exception as error:\n logger.exception(f\"Failed to render {component}\")\n new_state.model.current = {\n \"tagName\": \"\",\n \"error\": (\n f\"{type(error).__name__}: {error}\"\n if REACTPY_DEBUG_MODE.current\n else \"\"\n ),\n }\n finally:\n await life_cycle_hook.affect_component_did_render()\n\n try:\n parent = new_state.parent\n except AttributeError:\n pass # only happens for root component\n else:\n key, index = new_state.key, new_state.index\n parent.children_by_key[key] = new_state\n # need to add this model to parent's children without mutating parent model\n old_parent_model = parent.model.current\n old_parent_children = old_parent_model[\"children\"]\n parent.model.current = {\n **old_parent_model,\n \"children\": [\n *old_parent_children[:index],\n new_state.model.current,\n *old_parent_children[index + 1 :],\n ],\n }\n\n async def _render_model(\n self,\n exit_stack: AsyncExitStack,\n old_state: _ModelState | None,\n new_state: _ModelState,\n raw_model: Any,\n ) -> None:\n try:\n new_state.model.current = {\"tagName\": raw_model[\"tagName\"]}\n except Exception as e: # nocov\n msg = f\"Expected a VDOM element dict, not {raw_model}\"\n raise ValueError(msg) from e\n if \"key\" in raw_model:\n new_state.key = new_state.model.current[\"key\"] = raw_model[\"key\"]\n if \"importSource\" in raw_model:\n new_state.model.current[\"importSource\"] = raw_model[\"importSource\"]\n self._render_model_attributes(old_state, new_state, raw_model)\n await self._render_model_children(\n exit_stack, old_state, new_state, raw_model.get(\"children\", [])\n )\n\n def _render_model_attributes(\n self,\n old_state: _ModelState | None,\n new_state: _ModelState,\n raw_model: dict[str, Any],\n ) -> None:\n # extract event handlers from 'eventHandlers' and 'attributes'\n handlers_by_event: EventHandlerDict = 
raw_model.get(\"eventHandlers\", {})\n\n if \"attributes\" in raw_model:\n attrs = raw_model[\"attributes\"].copy()\n new_state.model.current[\"attributes\"] = attrs\n\n if old_state is None:\n self._render_model_event_handlers_without_old_state(\n new_state, handlers_by_event\n )\n return None\n\n for old_event in set(old_state.targets_by_event).difference(handlers_by_event):\n old_target = old_state.targets_by_event[old_event]\n del self._event_handlers[old_target]\n\n if not handlers_by_event:\n return None\n\n model_event_handlers = new_state.model.current[\"eventHandlers\"] = {}\n for event, handler in handlers_by_event.items():\n if event in old_state.targets_by_event:\n target = old_state.targets_by_event[event]\n else:\n target = uuid4().hex if handler.target is None else handler.target\n new_state.targets_by_event[event] = target\n self._event_handlers[target] = handler\n model_event_handlers[event] = {\n \"target\": target,\n \"preventDefault\": handler.prevent_default,\n \"stopPropagation\": handler.stop_propagation,\n }\n\n return None\n\n def _render_model_event_handlers_without_old_state(\n self,\n new_state: _ModelState,\n handlers_by_event: EventHandlerDict,\n ) -> None:\n if not handlers_by_event:\n return None\n\n model_event_handlers = new_state.model.current[\"eventHandlers\"] = {}\n for event, handler in handlers_by_event.items():\n target = uuid4().hex if handler.target is None else handler.target\n new_state.targets_by_event[event] = target\n self._event_handlers[target] = handler\n model_event_handlers[event] = {\n \"target\": target,\n \"preventDefault\": handler.prevent_default,\n \"stopPropagation\": handler.stop_propagation,\n }\n\n return None\n\n async def _render_model_children(\n self,\n exit_stack: AsyncExitStack,\n old_state: _ModelState | None,\n new_state: _ModelState,\n raw_children: Any,\n ) -> None:\n if not isinstance(raw_children, (list, tuple)):\n raw_children = [raw_children]\n\n if old_state is None:\n if raw_children:\n await self._render_model_children_without_old_state(\n exit_stack, new_state, raw_children\n )\n return None\n elif not raw_children:\n await self._unmount_model_states(list(old_state.children_by_key.values()))\n return None\n\n children_info = _get_children_info(raw_children)\n\n new_keys = {k for _, _, k in children_info}\n if len(new_keys) != len(children_info):\n key_counter = Counter(item[2] for item in children_info)\n duplicate_keys = [key for key, count in key_counter.items() if count > 1]\n msg = f\"Duplicate keys {duplicate_keys} at {new_state.patch_path or '/'!r}\"\n raise ValueError(msg)\n\n old_keys = set(old_state.children_by_key).difference(new_keys)\n if old_keys:\n await self._unmount_model_states(\n [old_state.children_by_key[key] for key in old_keys]\n )\n\n new_state.model.current[\"children\"] = []\n for index, (child, child_type, key) in enumerate(children_info):\n old_child_state = old_state.children_by_key.get(key)\n if child_type is _DICT_TYPE:\n old_child_state = old_state.children_by_key.get(key)\n if old_child_state is None:\n new_child_state = _make_element_model_state(\n new_state,\n index,\n key,\n )\n elif old_child_state.is_component_state:\n await self._unmount_model_states([old_child_state])\n new_child_state = _make_element_model_state(\n new_state,\n index,\n key,\n )\n old_child_state = None\n else:\n new_child_state = _update_element_model_state(\n old_child_state,\n new_state,\n index,\n )\n await self._render_model(\n exit_stack, old_child_state, new_child_state, child\n )\n 
new_state.append_child(new_child_state.model.current)\n new_state.children_by_key[key] = new_child_state\n elif child_type is _COMPONENT_TYPE:\n child = cast(ComponentType, child)\n old_child_state = old_state.children_by_key.get(key)\n if old_child_state is None:\n new_child_state = _make_component_model_state(\n new_state,\n index,\n key,\n child,\n self._schedule_render_task,\n )\n elif old_child_state.is_component_state and (\n old_child_state.life_cycle_state.component.type != child.type\n ):\n await self._unmount_model_states([old_child_state])\n old_child_state = None\n new_child_state = _make_component_model_state(\n new_state,\n index,\n key,\n child,\n self._schedule_render_task,\n )\n else:\n new_child_state = _update_component_model_state(\n old_child_state,\n new_state,\n index,\n child,\n self._schedule_render_task,\n )\n await self._render_component(\n exit_stack, old_child_state, new_child_state, child\n )\n else:\n old_child_state = old_state.children_by_key.get(key)\n if old_child_state is not None:\n await self._unmount_model_states([old_child_state])\n new_state.append_child(child)\n\n async def _render_model_children_without_old_state(\n self,\n exit_stack: AsyncExitStack,\n new_state: _ModelState,\n raw_children: list[Any],\n ) -> None:\n children_info = _get_children_info(raw_children)\n\n new_keys = {k for _, _, k in children_info}\n if len(new_keys) != len(children_info):\n key_counter = Counter(k for _, _, k in children_info)\n duplicate_keys = [key for key, count in key_counter.items() if count > 1]\n msg = f\"Duplicate keys {duplicate_keys} at {new_state.patch_path or '/'!r}\"\n raise ValueError(msg)\n\n new_state.model.current[\"children\"] = []\n for index, (child, child_type, key) in enumerate(children_info):\n if child_type is _DICT_TYPE:\n child_state = _make_element_model_state(new_state, index, key)\n await self._render_model(exit_stack, None, child_state, child)\n new_state.append_child(child_state.model.current)\n new_state.children_by_key[key] = child_state\n elif child_type is _COMPONENT_TYPE:\n child_state = _make_component_model_state(\n new_state, index, key, child, self._schedule_render_task\n )\n await self._render_component(exit_stack, None, child_state, child)\n else:\n new_state.append_child(child)\n\n async def _unmount_model_states(self, old_states: list[_ModelState]) -> None:\n to_unmount = old_states[::-1] # unmount in reversed order of rendering\n while to_unmount:\n model_state = to_unmount.pop()\n\n for target in model_state.targets_by_event.values():\n del self._event_handlers[target]\n\n if model_state.is_component_state:\n life_cycle_state = model_state.life_cycle_state\n del self._model_states_by_life_cycle_state_id[life_cycle_state.id]\n await life_cycle_state.hook.affect_component_will_unmount()\n\n to_unmount.extend(model_state.children_by_key.values())\n\n def _schedule_render_task(self, lcs_id: _LifeCycleStateId) -> None:\n if not REACTPY_ASYNC_RENDERING.current:\n self._rendering_queue.put(lcs_id)\n return None\n try:\n model_state = self._model_states_by_life_cycle_state_id[lcs_id]\n except KeyError:\n logger.debug(\n \"Did not render component with model state ID \"\n f\"{lcs_id!r} - component already unmounted\"\n )\n else:\n self._render_tasks.add(create_task(self._create_layout_update(model_state)))\n self._render_tasks_ready.release()\n\n def __repr__(self) -> str:\n return f\"{type(self).__name__}({self.root})\"\n\n\ndef _new_root_model_state(\n component: ComponentType, schedule_render: Callable[[_LifeCycleStateId], 
None]\n) -> _ModelState:\n return _ModelState(\n parent=None,\n index=-1,\n key=None,\n model=Ref(),\n patch_path=\"\",\n children_by_key={},\n targets_by_event={},\n life_cycle_state=_make_life_cycle_state(component, schedule_render),\n )\n\n\ndef _make_component_model_state(\n parent: _ModelState,\n index: int,\n key: Any,\n component: ComponentType,\n schedule_render: Callable[[_LifeCycleStateId], None],\n) -> _ModelState:\n return _ModelState(\n parent=parent,\n index=index,\n key=key,\n model=Ref(),\n patch_path=f\"{parent.patch_path}/children/{index}\",\n children_by_key={},\n targets_by_event={},\n life_cycle_state=_make_life_cycle_state(component, schedule_render),\n )\n\n\ndef _copy_component_model_state(old_model_state: _ModelState) -> _ModelState:\n # use try/except here because not having a parent is rare (only the root state)\n try:\n parent: _ModelState | None = old_model_state.parent\n except AttributeError:\n parent = None\n\n return _ModelState(\n parent=parent,\n index=old_model_state.index,\n key=old_model_state.key,\n model=Ref(), # does not copy the model\n patch_path=old_model_state.patch_path,\n children_by_key={},\n targets_by_event={},\n life_cycle_state=old_model_state.life_cycle_state,\n )\n\n\ndef _update_component_model_state(\n old_model_state: _ModelState,\n new_parent: _ModelState,\n new_index: int,\n new_component: ComponentType,\n schedule_render: Callable[[_LifeCycleStateId], None],\n) -> _ModelState:\n return _ModelState(\n parent=new_parent,\n index=new_index,\n key=old_model_state.key,\n model=Ref(), # does not copy the model\n patch_path=f\"{new_parent.patch_path}/children/{new_index}\",\n children_by_key={},\n targets_by_event={},\n life_cycle_state=(\n _update_life_cycle_state(old_model_state.life_cycle_state, new_component)\n if old_model_state.is_component_state\n else _make_life_cycle_state(new_component, schedule_render)\n ),\n )\n\n\ndef _make_element_model_state(\n parent: _ModelState,\n index: int,\n key: Any,\n) -> _ModelState:\n return _ModelState(\n parent=parent,\n index=index,\n key=key,\n model=Ref(),\n patch_path=f\"{parent.patch_path}/children/{index}\",\n children_by_key={},\n targets_by_event={},\n )\n\n\ndef _update_element_model_state(\n old_model_state: _ModelState,\n new_parent: _ModelState,\n new_index: int,\n) -> _ModelState:\n return _ModelState(\n parent=new_parent,\n index=new_index,\n key=old_model_state.key,\n model=Ref(), # does not copy the model\n patch_path=old_model_state.patch_path,\n children_by_key={},\n targets_by_event={},\n )\n\n\nclass _ModelState:\n \"\"\"State that is bound to a particular element within the layout\"\"\"\n\n __slots__ = (\n \"__weakref__\",\n \"_parent_ref\",\n \"_render_semaphore\",\n \"children_by_key\",\n \"index\",\n \"key\",\n \"life_cycle_state\",\n \"model\",\n \"patch_path\",\n \"targets_by_event\",\n )\n\n def __init__(\n self,\n parent: _ModelState | None,\n index: int,\n key: Any,\n model: Ref[VdomJson],\n patch_path: str,\n children_by_key: dict[Key, _ModelState],\n targets_by_event: dict[str, str],\n life_cycle_state: _LifeCycleState | None = None,\n ):\n self.index = index\n \"\"\"The index of the element amongst its siblings\"\"\"\n\n self.key = key\n \"\"\"A key that uniquely identifies the element amongst its siblings\"\"\"\n\n self.model = model\n \"\"\"The actual model of the element\"\"\"\n\n self.patch_path = patch_path\n \"\"\"A \"/\" delimited path to the element within the greater layout\"\"\"\n\n self.children_by_key = children_by_key\n \"\"\"Child model states 
indexed by their unique keys\"\"\"\n\n self.targets_by_event = targets_by_event\n \"\"\"The element's event handler target strings indexed by their event name\"\"\"\n\n # === Conditionally Available Attributes ===\n # It's easier to conditionally assign than to force a null check on every usage\n\n if parent is not None:\n self._parent_ref = weakref(parent)\n \"\"\"The parent model state\"\"\"\n\n if life_cycle_state is not None:\n self.life_cycle_state = life_cycle_state\n \"\"\"The state for the element's component (if it has one)\"\"\"\n\n @property\n def is_component_state(self) -> bool:\n return hasattr(self, \"life_cycle_state\")\n\n @property\n def parent(self) -> _ModelState:\n parent = self._parent_ref()\n if parent is None:\n raise RuntimeError(\"detached model state\") # nocov\n return parent\n\n def append_child(self, child: Any) -> None:\n self.model.current[\"children\"].append(child)\n\n def __repr__(self) -> str: # nocov\n return f\"ModelState({ {s: getattr(self, s, None) for s in self.__slots__} })\"\n\n\ndef _make_life_cycle_state(\n component: ComponentType,\n schedule_render: Callable[[_LifeCycleStateId], None],\n) -> _LifeCycleState:\n life_cycle_state_id = _LifeCycleStateId(uuid4().hex)\n return _LifeCycleState(\n life_cycle_state_id,\n LifeCycleHook(lambda: schedule_render(life_cycle_state_id)),\n component,\n )\n\n\ndef _update_life_cycle_state(\n old_life_cycle_state: _LifeCycleState,\n new_component: ComponentType,\n) -> _LifeCycleState:\n return _LifeCycleState(\n old_life_cycle_state.id,\n # the hook is preserved across renders because it holds the state\n old_life_cycle_state.hook,\n new_component,\n )\n\n\n_LifeCycleStateId = NewType(\"_LifeCycleStateId\", str)\n\n\nclass _LifeCycleState(NamedTuple):\n \"\"\"Component state for :class:`_ModelState`\"\"\"\n\n id: _LifeCycleStateId\n \"\"\"A unique identifier used in the :class:`~reactpy.core.hooks.LifeCycleHook` callback\"\"\"\n\n hook: LifeCycleHook\n \"\"\"The life cycle hook\"\"\"\n\n component: ComponentType\n \"\"\"The current component instance\"\"\"\n\n\n_Type = TypeVar(\"_Type\")\n\n\nclass _ThreadSafeQueue(Generic[_Type]):\n def __init__(self) -> None:\n self._loop = get_running_loop()\n self._queue: Queue[_Type] = Queue()\n self._pending: set[_Type] = set()\n\n def put(self, value: _Type) -> None:\n if value not in self._pending:\n self._pending.add(value)\n self._loop.call_soon_threadsafe(self._queue.put_nowait, value)\n\n async def get(self) -> _Type:\n value = await self._queue.get()\n self._pending.remove(value)\n return value\n\n\ndef _get_children_info(children: list[VdomChild]) -> Sequence[_ChildInfo]:\n infos: list[_ChildInfo] = []\n for index, child in enumerate(children):\n if child is None:\n continue\n elif isinstance(child, dict):\n child_type = _DICT_TYPE\n key = child.get(\"key\")\n elif isinstance(child, ComponentType):\n child_type = _COMPONENT_TYPE\n key = child.key\n else:\n child = f\"{child}\"\n child_type = _STRING_TYPE\n key = None\n\n if key is None:\n key = index\n\n infos.append((child, child_type, key))\n\n return infos\n\n\n_ChildInfo: TypeAlias = tuple[Any, \"_ElementType\", Key]\n\n# used in _process_child_type_and_key\n_ElementType = NewType(\"_ElementType\", int)\n_DICT_TYPE = _ElementType(1)\n_COMPONENT_TYPE = _ElementType(2)\n_STRING_TYPE = _ElementType(3)\n\n# Path: src/py/reactpy/reactpy/html.py\n\"\"\"\n\n**Fragment**\n\n- :func:`_`\n\n**Document metadata**\n\n- :func:`base`\n- :func:`head`\n- :func:`link`\n- :func:`meta`\n- :func:`style`\n- 
:func:`title`\n\n**Content sectioning**\n\n- :func:`address`\n- :func:`article`\n- :func:`aside`\n- :func:`footer`\n- :func:`header`\n- :func:`h1`\n- :func:`h2`\n- :func:`h3`\n- :func:`h4`\n- :func:`h5`\n- :func:`h6`\n- :func:`main`\n- :func:`nav`\n- :func:`section`\n\n**Text content**\n\n- :func:`blockquote`\n- :func:`dd`\n- :func:`div`\n- :func:`dl`\n- :func:`dt`\n- :func:`figcaption`\n- :func:`figure`\n- :func:`hr`\n- :func:`li`\n- :func:`ol`\n- :func:`p`\n- :func:`pre`\n- :func:`ul`\n\n**Inline text semantics**\n\n- :func:`a`\n- :func:`abbr`\n- :func:`b`\n- :func:`bdi`\n- :func:`bdo`\n- :func:`br`\n- :func:`cite`\n- :func:`code`\n- :func:`data`\n- :func:`em`\n- :func:`i`\n- :func:`kbd`\n- :func:`mark`\n- :func:`q`\n- :func:`rp`\n- :func:`rt`\n- :func:`ruby`\n- :func:`s`\n- :func:`samp`\n- :func:`small`\n- :func:`span`\n- :func:`strong`\n- :func:`sub`\n- :func:`sup`\n- :func:`time`\n- :func:`u`\n- :func:`var`\n- :func:`wbr`\n\n**Image and video**\n\n- :func:`area`\n- :func:`audio`\n- :func:`img`\n- :func:`map`\n- :func:`track`\n- :func:`video`\n\n**Embedded content**\n\n- :func:`embed`\n- :func:`iframe`\n- :func:`object`\n- :func:`param`\n- :func:`picture`\n- :func:`portal`\n- :func:`source`\n\n**SVG and MathML**\n\n- :func:`svg`\n- :func:`math`\n\n**Scripting**\n\n- :func:`canvas`\n- :func:`noscript`\n- :func:`script`\n\n**Demarcating edits**\n\n- :func:`del_`\n- :func:`ins`\n\n**Table content**\n\n- :func:`caption`\n- :func:`col`\n- :func:`colgroup`\n- :func:`table`\n- :func:`tbody`\n- :func:`td`\n- :func:`tfoot`\n- :func:`th`\n- :func:`thead`\n- :func:`tr`\n\n**Forms**\n\n- :func:`button`\n- :func:`fieldset`\n- :func:`form`\n- :func:`input`\n- :func:`label`\n- :func:`legend`\n- :func:`meter`\n- :func:`option`\n- :func:`output`\n- :func:`progress`\n- :func:`select`\n- :func:`textarea`\n\n**Interactive elements**\n\n- :func:`details`\n- :func:`dialog`\n- :func:`menu`\n- :func:`menuitem`\n- :func:`summary`\n\n**Web components**\n\n- :func:`slot`\n- :func:`template`\n\n.. 
autofunction:: _\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nfrom reactpy.core.types import (\n EventHandlerDict,\n Key,\n VdomAttributes,\n VdomChild,\n VdomDict,\n)\nfrom reactpy.core.vdom import custom_vdom_constructor, make_vdom_constructor\n\n__all__ = (\n \"_\",\n \"a\",\n \"abbr\",\n \"address\",\n \"area\",\n \"article\",\n \"aside\",\n \"audio\",\n \"b\",\n \"base\",\n \"bdi\",\n \"bdo\",\n \"blockquote\",\n \"br\",\n \"button\",\n \"canvas\",\n \"caption\",\n \"cite\",\n \"code\",\n \"col\",\n \"colgroup\",\n \"data\",\n \"dd\",\n \"del_\",\n \"details\",\n \"dialog\",\n \"div\",\n \"dl\",\n \"dt\",\n \"em\",\n \"embed\",\n \"fieldset\",\n \"figcaption\",\n \"figure\",\n \"footer\",\n \"form\",\n \"h1\",\n \"h2\",\n \"h3\",\n \"h4\",\n \"h5\",\n \"h6\",\n \"head\",\n \"header\",\n \"hr\",\n \"i\",\n \"iframe\",\n \"img\",\n \"input\",\n \"ins\",\n \"kbd\",\n \"label\",\n \"legend\",\n \"li\",\n \"link\",\n \"main\",\n \"map\",\n \"mark\",\n \"math\",\n \"menu\",\n \"menuitem\",\n \"meta\",\n \"meter\",\n \"nav\",\n \"noscript\",\n \"object\",\n \"ol\",\n \"option\",\n \"output\",\n \"p\",\n \"param\",\n \"picture\",\n \"portal\",\n \"pre\",\n \"progress\",\n \"q\",\n \"rp\",\n \"rt\",\n \"ruby\",\n \"s\",\n \"samp\",\n \"script\",\n \"section\",\n \"select\",\n \"slot\",\n \"small\",\n \"source\",\n \"span\",\n \"strong\",\n \"style\",\n \"sub\",\n \"summary\",\n \"sup\",\n \"svg\",\n \"table\",\n \"tbody\",\n \"td\",\n \"template\",\n \"textarea\",\n \"tfoot\",\n \"th\",\n \"thead\",\n \"time\",\n \"title\",\n \"tr\",\n \"track\",\n \"u\",\n \"ul\",\n \"var\",\n \"video\",\n \"wbr\",\n)\n\n\ndef _fragment(\n attributes: VdomAttributes,\n children: Sequence[VdomChild],\n key: Key | None,\n event_handlers: EventHandlerDict,\n) -> VdomDict:\n \"\"\"An HTML fragment - this element will not appear in the DOM\"\"\"\n if attributes or event_handlers:\n msg = \"Fragments cannot have attributes besides 'key'\"\n raise TypeError(msg)\n model: VdomDict = {\"tagName\": \"\"}\n\n if children:\n model[\"children\"] = children\n\n if key is not None:\n model[\"key\"] = key\n\n return model\n\n\n# FIXME: https://github.com/PyCQA/pylint/issues/5784\n_ = custom_vdom_constructor(_fragment)\n\n\n# Document metadata\nbase = make_vdom_constructor(\"base\")\nhead = make_vdom_constructor(\"head\")\nlink = make_vdom_constructor(\"link\")\nmeta = make_vdom_constructor(\"meta\")\nstyle = make_vdom_constructor(\"style\")\ntitle = make_vdom_constructor(\"title\")\n\n# Content sectioning\naddress = make_vdom_constructor(\"address\")\narticle = make_vdom_constructor(\"article\")\naside = make_vdom_constructor(\"aside\")\nfooter = make_vdom_constructor(\"footer\")\nheader = make_vdom_constructor(\"header\")\nh1 = make_vdom_constructor(\"h1\")\nh2 = make_vdom_constructor(\"h2\")\nh3 = make_vdom_constructor(\"h3\")\nh4 = make_vdom_constructor(\"h4\")\nh5 = make_vdom_constructor(\"h5\")\nh6 = make_vdom_constructor(\"h6\")\nmain = make_vdom_constructor(\"main\")\nnav = make_vdom_constructor(\"nav\")\nsection = make_vdom_constructor(\"section\")\n\n# Text content\nblockquote = make_vdom_constructor(\"blockquote\")\ndd = make_vdom_constructor(\"dd\")\ndiv = make_vdom_constructor(\"div\")\ndl = make_vdom_constructor(\"dl\")\ndt = make_vdom_constructor(\"dt\")\nfigcaption = make_vdom_constructor(\"figcaption\")\nfigure = make_vdom_constructor(\"figure\")\nhr = make_vdom_constructor(\"hr\", allow_children=False)\nli = make_vdom_constructor(\"li\")\nol = 
make_vdom_constructor(\"ol\")\np = make_vdom_constructor(\"p\")\npre = make_vdom_constructor(\"pre\")\nul = make_vdom_constructor(\"ul\")\n\n# Inline text semantics\na = make_vdom_constructor(\"a\")\nabbr = make_vdom_constructor(\"abbr\")\nb = make_vdom_constructor(\"b\")\nbdi = make_vdom_constructor(\"bdi\")\nbdo = make_vdom_constructor(\"bdo\")\nbr = make_vdom_constructor(\"br\", allow_children=False)\ncite = make_vdom_constructor(\"cite\")\ncode = make_vdom_constructor(\"code\")\ndata = make_vdom_constructor(\"data\")\nem = make_vdom_constructor(\"em\")\ni = make_vdom_constructor(\"i\")\nkbd = make_vdom_constructor(\"kbd\")\nmark = make_vdom_constructor(\"mark\")\nq = make_vdom_constructor(\"q\")\nrp = make_vdom_constructor(\"rp\")\nrt = make_vdom_constructor(\"rt\")\nruby = make_vdom_constructor(\"ruby\")\ns = make_vdom_constructor(\"s\")\nsamp = make_vdom_constructor(\"samp\")\nsmall = make_vdom_constructor(\"small\")\nspan = make_vdom_constructor(\"span\")\nstrong = make_vdom_constructor(\"strong\")\nsub = make_vdom_constructor(\"sub\")\nsup = make_vdom_constructor(\"sup\")\ntime = make_vdom_constructor(\"time\")\nu = make_vdom_constructor(\"u\")\nvar = make_vdom_constructor(\"var\")\nwbr = make_vdom_constructor(\"wbr\")\n\n# Image and video\narea = make_vdom_constructor(\"area\", allow_children=False)\naudio = make_vdom_constructor(\"audio\")\n...\n", "needle_token_start": 739, "needle_token_end": 847, "code_context_ntokens": 16057, "output": ["```python\ndef sync_function() -> _EffectCleanFunc | None:\n task = asyncio.create_task(async_function())\n\n def clean_future() -> None:\n if not task.cancel():\n try:\n clean = task.result()\n except asyncio.CancelledError:\n pass\n else:\n if clean is not None:\n clean()\n\n return clean_future\n```"]} +{"repo": "reactive-python/reactpy", "name": "custom_vdom_constructor", "language": "python", "path": "src/py/reactpy/reactpy/core/vdom.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To transform a given function into a specialized constructor that creates virtual DOM (VDOM) dictionaries, ensuring that the function adheres to a specific interface for handling attributes, children, keys, and event handlers.\n2. **Input**: The input is a function that accepts four parameters: attributes (a dictionary), children (a list), a unique key (string or None), and event handlers (a dictionary).\n3. **Output**: The output is a new function that conforms to the VDOM dictionary constructor interface, capable of creating structured VDOM dictionaries based on the provided inputs.\n4. 
**Procedure**: The procedure involves:\n - Wrapping the given function with another function that preprocesses the input by separating attributes and children, extracting the unique key, and segregating event handlers.\n - The original function is then called with these processed inputs to generate a VDOM dictionary.\n - The wrapper ensures that the function conforms to the expected interface for creating VDOM elements.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/core/events.py\nfrom __future__ import annotations\n\nimport asyncio\n...\n# Path: src/py/reactpy/reactpy/core/vdom.py\nfrom __future__ import annotations\n\nimport json\nfrom collections.abc import Mapping, Sequence\nfrom functools import wraps\nfrom typing import Any, Protocol, cast, overload\n\nfrom fastjsonschema import compile as compile_json_schema\n\nfrom reactpy._warnings import warn\nfrom reactpy.config import REACTPY_CHECK_JSON_ATTRS, REACTPY_DEBUG_MODE\nfrom reactpy.core._f_back import f_module_name\nfrom reactpy.core.events import EventHandler, to_event_handler_function\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n EventHandlerType,\n ImportSourceDict,\n Key,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomDictConstructor,\n VdomJson,\n)\n\nVDOM_JSON_SCHEMA = {\n \"$schema\": \"http://json-schema.org/draft-07/schema\",\n \"$ref\": \"#/definitions/element\",\n \"definitions\": {\n \"element\": {\n \"type\": \"object\",\n \"properties\": {\n \"tagName\": {\"type\": \"string\"},\n \"key\": {\"type\": [\"string\", \"number\", \"null\"]},\n \"error\": {\"type\": \"string\"},\n \"children\": {\"$ref\": \"#/definitions/elementChildren\"},\n \"attributes\": {\"type\": \"object\"},\n \"eventHandlers\": {\"$ref\": \"#/definitions/elementEventHandlers\"},\n \"importSource\": {\"$ref\": \"#/definitions/importSource\"},\n },\n # The 'tagName' is required because its presence is a useful indicator of\n # whether a dictionary describes a VDOM model or not.\n \"required\": [\"tagName\"],\n \"dependentSchemas\": {\n # When 'error' is given, the 'tagName' should be empty.\n \"error\": {\"properties\": {\"tagName\": {\"maxLength\": 0}}}\n },\n },\n \"elementChildren\": {\n \"type\": \"array\",\n \"items\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"elementEventHandlers\": {\n \"type\": \"object\",\n \"patternProperties\": {\n \".*\": {\"$ref\": \"#/definitions/eventHandler\"},\n },\n },\n \"eventHandler\": {\n \"type\": \"object\",\n \"properties\": {\n \"target\": {\"type\": \"string\"},\n \"preventDefault\": {\"type\": \"boolean\"},\n \"stopPropagation\": {\"type\": \"boolean\"},\n },\n \"required\": [\"target\"],\n },\n \"importSource\": {\n \"type\": \"object\",\n \"properties\": {\n \"source\": {\"type\": \"string\"},\n \"sourceType\": {\"enum\": [\"URL\", \"NAME\"]},\n \"fallback\": {\n \"type\": [\"object\", \"string\", \"null\"],\n \"if\": {\"not\": {\"type\": \"null\"}},\n \"then\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"unmountBeforeUpdate\": {\"type\": \"boolean\"},\n },\n \"required\": [\"source\"],\n },\n \"elementOrString\": {\n \"type\": [\"object\", \"string\"],\n \"if\": {\"type\": \"object\"},\n \"then\": {\"$ref\": \"#/definitions/element\"},\n },\n },\n}\n\"\"\"JSON Schema describing 
serialized VDOM - see :ref:`VDOM` for more info\"\"\"\n\n\n# we can't add a docstring to this because Sphinx doesn't know how to find its source\n_COMPILED_VDOM_VALIDATOR = compile_json_schema(VDOM_JSON_SCHEMA)\n\n\ndef validate_vdom_json(value: Any) -> VdomJson:\n \"\"\"Validate serialized VDOM - see :attr:`VDOM_JSON_SCHEMA` for more info\"\"\"\n _COMPILED_VDOM_VALIDATOR(value)\n return cast(VdomJson, value)\n\n\ndef is_vdom(value: Any) -> bool:\n \"\"\"Return whether a value is a :class:`VdomDict`\n\n This employs a very simple heuristic - something is VDOM if:\n\n 1. It is a ``dict`` instance\n 2. It contains the key ``\"tagName\"``\n 3. The value of the key ``\"tagName\"`` is a string\n\n .. note::\n\n Performing an ``isinstance(value, VdomDict)`` check is too restrictive since the\n user would be forced to import ``VdomDict`` every time they needed to declare a\n VDOM element. Giving the user more flexibility, at the cost of this check's\n accuracy, is worth it.\n \"\"\"\n return (\n isinstance(value, dict)\n and \"tagName\" in value\n and isinstance(value[\"tagName\"], str)\n )\n\n\n@overload\ndef vdom(tag: str, *children: VdomChildren) -> VdomDict: ...\n\n\n@overload\ndef vdom(tag: str, attributes: VdomAttributes, *children: VdomChildren) -> VdomDict: ...\n\n\ndef vdom(\n tag: str,\n *attributes_and_children: Any,\n **kwargs: Any,\n) -> VdomDict:\n \"\"\"A helper function for creating VDOM elements.\n\n Parameters:\n tag:\n The type of element (e.g. 'div', 'h1', 'img')\n attributes_and_children:\n An optional attribute mapping followed by any number of children or\n iterables of children. The attribute mapping **must** precede the children,\n or children which will be merged into their respective parts of the model.\n key:\n A string indicating the identity of a particular element. This is significant\n to preserve event handlers across updates - without a key, a re-render would\n cause these handlers to be deleted, but with a key, they would be redirected\n to any newly defined handlers.\n event_handlers:\n Maps event types to coroutines that are responsible for handling those events.\n import_source:\n (subject to change) specifies javascript that, when evaluated returns a\n React component.\n \"\"\"\n if kwargs: # nocov\n if \"key\" in kwargs:\n if attributes_and_children:\n maybe_attributes, *children = attributes_and_children\n if _is_attributes(maybe_attributes):\n attributes_and_children = (\n {**maybe_attributes, \"key\": kwargs.pop(\"key\")},\n *children,\n )\n else:\n attributes_and_children = (\n {\"key\": kwargs.pop(\"key\")},\n maybe_attributes,\n *children,\n )\n else:\n attributes_and_children = ({\"key\": kwargs.pop(\"key\")},)\n warn(\n \"An element's 'key' must be declared in an attribute dict instead \"\n \"of as a keyword argument. 
This will error in a future version.\",\n DeprecationWarning,\n )\n\n if kwargs:\n msg = f\"Extra keyword arguments {kwargs}\"\n raise ValueError(msg)\n\n model: VdomDict = {\"tagName\": tag}\n\n if not attributes_and_children:\n return model\n\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n\n if attributes:\n if REACTPY_CHECK_JSON_ATTRS.current:\n json.dumps(attributes)\n model[\"attributes\"] = attributes\n\n if children:\n model[\"children\"] = children\n\n if key is not None:\n model[\"key\"] = key\n\n if event_handlers:\n model[\"eventHandlers\"] = event_handlers\n\n return model\n\n\ndef make_vdom_constructor(\n tag: str, allow_children: bool = True, import_source: ImportSourceDict | None = None\n) -> VdomDictConstructor:\n \"\"\"Return a constructor for VDOM dictionaries with the given tag name.\n\n The resulting callable will have the same interface as :func:`vdom` but without its\n first ``tag`` argument.\n \"\"\"\n\n def constructor(*attributes_and_children: Any, **kwargs: Any) -> VdomDict:\n model = vdom(tag, *attributes_and_children, **kwargs)\n if not allow_children and \"children\" in model:\n msg = f\"{tag!r} nodes cannot have children.\"\n raise TypeError(msg)\n if import_source:\n model[\"importSource\"] = import_source\n return model\n\n # replicate common function attributes\n constructor.__name__ = tag\n constructor.__doc__ = (\n \"Return a new \"\n f\"`<{tag}> `__ \"\n \"element represented by a :class:`VdomDict`.\"\n )\n\n module_name = f_module_name(1)\n if module_name:\n constructor.__module__ = module_name\n constructor.__qualname__ = f\"{module_name}.{tag}\"\n\n return cast(VdomDictConstructor, constructor)\n\n\n\ndef custom_vdom_constructor(func: _CustomVdomDictConstructor) -> VdomDictConstructor:\n \"\"\"Cast function to VdomDictConstructor\"\"\"\n\n @wraps(func)\n def wrapper(*attributes_and_children: Any) -> VdomDict:\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n return func(attributes, children, key, event_handlers)\n\n return cast(VdomDictConstructor, wrapper)\n\n\ndef separate_attributes_and_children(\n values: Sequence[Any],\n) -> tuple[dict[str, Any], list[Any]]:\n if not values:\n return {}, []\n\n attributes: dict[str, Any]\n children_or_iterables: Sequence[Any]\n if _is_attributes(values[0]):\n attributes, *children_or_iterables = values\n else:\n attributes = {}\n children_or_iterables = values\n\n children: list[Any] = []\n for child in children_or_iterables:\n if _is_single_child(child):\n children.append(child)\n else:\n children.extend(child)\n\n return attributes, children\n\n\ndef separate_attributes_and_event_handlers(\n attributes: Mapping[str, Any]\n) -> tuple[dict[str, Any], EventHandlerDict]:\n separated_attributes = {}\n separated_event_handlers: dict[str, EventHandlerType] = {}\n\n for k, v in attributes.items():\n handler: EventHandlerType\n\n if callable(v):\n handler = EventHandler(to_event_handler_function(v))\n elif (\n # isinstance check on protocols is slow - use function attr pre-check as a\n # quick filter before actually performing slow EventHandlerType type check\n hasattr(v, \"function\")\n and isinstance(v, EventHandlerType)\n ):\n handler = v\n else:\n separated_attributes[k] = v\n continue\n\n 
separated_event_handlers[k] = handler\n\n return separated_attributes, dict(separated_event_handlers.items())\n\n\ndef _is_attributes(value: Any) -> bool:\n return isinstance(value, Mapping) and \"tagName\" not in value\n\n\ndef _is_single_child(value: Any) -> bool:\n if isinstance(value, (str, Mapping)) or not hasattr(value, \"__iter__\"):\n return True\n if REACTPY_DEBUG_MODE.current:\n _validate_child_key_integrity(value)\n return False\n\n\ndef _validate_child_key_integrity(value: Any) -> None:\n if hasattr(value, \"__iter__\") and not hasattr(value, \"__len__\"):\n warn(\n f\"Did not verify key-path integrity of children in generator {value} \"\n \"- pass a sequence (i.e. list of finite length) in order to verify\"\n )\n else:\n for child in value:\n if isinstance(child, ComponentType) and child.key is None:\n warn(f\"Key not specified for child in list {child}\", UserWarning)\n elif isinstance(child, Mapping) and \"key\" not in child:\n # remove 'children' to reduce log spam\n child_copy = {**child, \"children\": _EllipsisRepr()}\n warn(f\"Key not specified for child in list {child_copy}\", UserWarning)\n\n\nclass _CustomVdomDictConstructor(Protocol):\n def __call__(\n self,\n attributes: VdomAttributes,\n children: Sequence[VdomChild],\n key: Key | None,\n event_handlers: EventHandlerDict,\n ) -> VdomDict: ...\n\n\nclass _EllipsisRepr:\n def __repr__(self) -> str:\n return \"...\"\n\n# Path: src/py/reactpy/reactpy/utils.py\nfrom __future__ import annotations\n\nimport re\nfrom collections.abc import Iterable\nfrom itertools import chain\nfrom typing import Any, Callable, Generic, TypeVar, cast\n\nfrom lxml import etree\nfrom lxml.html import fromstring, tostring\n\nfrom reactpy.core.types import VdomDict\nfrom reactpy.core.vdom import vdom\n\n_RefValue = TypeVar(\"_RefValue\")\n_ModelTransform = Callable[[VdomDict], Any]\n_UNDEFINED: Any = object()\n\n\nclass Ref(Generic[_RefValue]):\n \"\"\"Hold a reference to a value\n\n This is used in imperative code to mutate the state of this object in order to\n incur side effects. 
Generally refs should be avoided if possible, but sometimes\n they are required.\n\n Notes:\n You can compare the contents for two ``Ref`` objects using the ``==`` operator.\n \"\"\"\n\n __slots__ = (\"current\",)\n\n def __init__(self, initial_value: _RefValue = _UNDEFINED) -> None:\n if initial_value is not _UNDEFINED:\n self.current = initial_value\n \"\"\"The present value\"\"\"\n\n def set_current(self, new: _RefValue) -> _RefValue:\n \"\"\"Set the current value and return what is now the old value\n\n This is nice to use in ``lambda`` functions.\n \"\"\"\n old = self.current\n self.current = new\n return old\n\n def __eq__(self, other: Any) -> bool:\n try:\n return isinstance(other, Ref) and (other.current == self.current)\n except AttributeError:\n # attribute error occurs for uninitialized refs\n return False\n\n def __repr__(self) -> str:\n try:\n current = repr(self.current)\n except AttributeError:\n # attribute error occurs for uninitialized refs\n current = \"\"\n return f\"{type(self).__name__}({current})\"\n\n\ndef vdom_to_html(vdom: VdomDict) -> str:\n \"\"\"Convert a VDOM dictionary into an HTML string\n\n Only the following keys are translated to HTML:\n\n - ``tagName``\n - ``attributes``\n - ``children`` (must be strings or more VDOM dicts)\n\n Parameters:\n vdom: The VdomDict element to convert to HTML\n \"\"\"\n temp_root = etree.Element(\"__temp__\")\n _add_vdom_to_etree(temp_root, vdom)\n html = cast(bytes, tostring(temp_root)).decode()\n # strip out temp root <__temp__> element\n return html[10:-11]\n\n\ndef html_to_vdom(\n html: str, *transforms: _ModelTransform, strict: bool = True\n) -> VdomDict:\n \"\"\"Transform HTML into a DOM model. Unique keys can be provided to HTML elements\n using a ``key=...`` attribute within your HTML tag.\n\n Parameters:\n html:\n The raw HTML as a string\n transforms:\n Functions of the form ``transform(old) -> new`` where ``old`` is a VDOM\n dictionary which will be replaced by ``new``. For example, you could use a\n transform function to add highlighting to a ```` block.\n strict:\n If ``True``, raise an exception if the HTML does not perfectly follow HTML5\n syntax.\n \"\"\"\n if not isinstance(html, str): # nocov\n msg = f\"Expected html to be a string, not {type(html).__name__}\"\n raise TypeError(msg)\n\n # If the user provided a string, convert it to a list of lxml.etree nodes\n try:\n root_node: etree._Element = fromstring(\n html.strip(),\n parser=etree.HTMLParser(\n remove_comments=True,\n remove_pis=True,\n remove_blank_text=True,\n recover=not strict,\n ),\n )\n except etree.XMLSyntaxError as e:\n if not strict:\n raise e # nocov\n msg = \"An error has occurred while parsing the HTML.\\n\\nThis HTML may be malformatted, or may not perfectly adhere to HTML5.\\nIf you believe the exception above was due to something intentional, you can disable the strict parameter on html_to_vdom().\\nOtherwise, repair your broken HTML and try again.\"\n raise HTMLParseError(msg) from e\n\n return _etree_to_vdom(root_node, transforms)\n\n\nclass HTMLParseError(etree.LxmlSyntaxError): # type: ignore[misc]\n \"\"\"Raised when an HTML document cannot be parsed using strict parsing.\"\"\"\n\n\ndef _etree_to_vdom(\n node: etree._Element, transforms: Iterable[_ModelTransform]\n) -> VdomDict:\n \"\"\"Transform an lxml etree node into a DOM model\n\n Parameters:\n node:\n The ``lxml.etree._Element`` node\n transforms:\n Functions of the form ``transform(old) -> new`` where ``old`` is a VDOM\n dictionary which will be replaced by ``new``. 
For example, you could use a\n transform function to add highlighting to a ```` block.\n \"\"\"\n if not isinstance(node, etree._Element): # nocov\n msg = f\"Expected node to be a etree._Element, not {type(node).__name__}\"\n raise TypeError(msg)\n\n # Recursively call _etree_to_vdom() on all children\n children = _generate_vdom_children(node, transforms)\n\n # Convert the lxml node to a VDOM dict\n el = vdom(node.tag, dict(node.items()), *children)\n\n # Perform any necessary mutations on the VDOM attributes to meet VDOM spec\n _mutate_vdom(el)\n\n # Apply any provided transforms.\n for transform in transforms:\n el = transform(el)\n\n return el\n\n\ndef _add_vdom_to_etree(parent: etree._Element, vdom: VdomDict | dict[str, Any]) -> None:\n try:\n tag = vdom[\"tagName\"]\n except KeyError as e:\n msg = f\"Expected a VDOM dict, not {vdom}\"\n raise TypeError(msg) from e\n else:\n vdom = cast(VdomDict, vdom)\n\n if tag:\n element = etree.SubElement(parent, tag)\n element.attrib.update(\n _vdom_attr_to_html_str(k, v) for k, v in vdom.get(\"attributes\", {}).items()\n )\n else:\n element = parent\n\n for c in vdom.get(\"children\", []):\n if isinstance(c, dict):\n _add_vdom_to_etree(element, c)\n else:\n \"\"\"\n LXML handles string children by storing them under `text` and `tail`\n attributes of Element objects. The `text` attribute, if present, effectively\n becomes that element's first child. Then the `tail` attribute, if present,\n becomes a sibling that follows that element. For example, consider the\n following HTML:\n\n

helloworld

\n\n In this code sample, \"hello\" is the `text` attribute of the `` element\n and \"world\" is the `tail` attribute of that same `` element. It's for\n this reason that, depending on whether the element being constructed has\n non-string a child element, we need to assign a `text` vs `tail` attribute\n to that element or the last non-string child respectively.\n \"\"\"\n if len(element):\n last_child = element[-1]\n last_child.tail = f\"{last_child.tail or ''}{c}\"\n else:\n element.text = f\"{element.text or ''}{c}\"\n\n\ndef _mutate_vdom(vdom: VdomDict) -> None:\n \"\"\"Performs any necessary mutations on the VDOM attributes to meet VDOM spec.\n\n Currently, this function only transforms the ``style`` attribute into a dictionary whose keys are\n camelCase so as to be renderable by React.\n\n This function may be extended in the future.\n \"\"\"\n # Determine if the style attribute needs to be converted to a dict\n if (\n \"attributes\" in vdom\n and \"style\" in vdom[\"attributes\"]\n and isinstance(vdom[\"attributes\"][\"style\"], str)\n ):\n # Convince type checker that it's safe to mutate attributes\n assert isinstance(vdom[\"attributes\"], dict) # noqa: S101\n\n # Convert style attribute from str -> dict with camelCase keys\n vdom[\"attributes\"][\"style\"] = {\n key.strip().replace(\"-\", \"_\"): value.strip()\n for key, value in (\n part.split(\":\", 1)\n for part in vdom[\"attributes\"][\"style\"].split(\";\")\n if \":\" in part\n )\n }\n\n\ndef _generate_vdom_children(\n node: etree._Element, transforms: Iterable[_ModelTransform]\n) -> list[VdomDict | str]:\n \"\"\"Generates a list of VDOM children from an lxml node.\n\n Inserts inner text and/or tail text in between VDOM children, if necessary.\n \"\"\"\n return ( # Get the inner text of the current node\n [node.text] if node.text else []\n ) + list(\n chain(\n *(\n # Recursively convert each child node to VDOM\n [_etree_to_vdom(child, transforms)]\n # Insert the tail text between each child node\n + ([child.tail] if child.tail else [])\n for child in node.iterchildren(None)\n )\n )\n )\n\n\ndef del_html_head_body_transform(vdom: VdomDict) -> VdomDict:\n \"\"\"Transform intended for use with `html_to_vdom`.\n\n Removes ``, ``, and `` while preserving their children.\n\n Parameters:\n vdom:\n The VDOM dictionary to transform.\n \"\"\"\n if vdom[\"tagName\"] in {\"html\", \"body\", \"head\"}:\n return {\"tagName\": \"\", \"children\": vdom[\"children\"]}\n return vdom\n\n\ndef _vdom_attr_to_html_str(key: str, value: Any) -> tuple[str, str]:\n if key == \"style\":\n if isinstance(value, dict):\n value = \";\".join(\n # We lower only to normalize - CSS is case-insensitive:\n # https://www.w3.org/TR/css-fonts-3/#font-family-casing\n f\"{_CAMEL_CASE_SUB_PATTERN.sub('-', k).lower()}:{v}\"\n for k, v in value.items()\n )\n elif (\n # camel to data-* attributes\n key.startswith(\"data_\")\n # camel to aria-* attributes\n or key.startswith(\"aria_\")\n # handle special cases\n or key in DASHED_HTML_ATTRS\n ):\n key = key.replace(\"_\", \"-\")\n elif (\n # camel to data-* attributes\n key.startswith(\"data\")\n # camel to aria-* attributes\n or key.startswith(\"aria\")\n # handle special cases\n or key in DASHED_HTML_ATTRS\n ):\n key = _CAMEL_CASE_SUB_PATTERN.sub(\"-\", key)\n\n if callable(value): # nocov\n raise TypeError(f\"Cannot convert callable attribute {key}={value} to HTML\")\n\n # Again, we lower the attribute name only to normalize - HTML is case-insensitive:\n # 
http://w3c.github.io/html-reference/documents.html#case-insensitivity\n return key.lower(), str(value)\n\n\n# see list of HTML attributes with dashes in them:\n# https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes#attribute_list\nDASHED_HTML_ATTRS = {\"accept_charset\", \"acceptCharset\", \"http_equiv\", \"httpEquiv\"}\n\n# Pattern for delimitting camelCase names (e.g. camelCase to camel-case)\n_CAMEL_CASE_SUB_PATTERN = re.compile(r\"(? State[_Type]: ...\n\n\n@overload\ndef use_state(initial_value: _Type) -> State[_Type]: ...\n\n\ndef use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value:\n Defines the initial value of the state. A callable (accepting no arguments)\n can be used as a constructor function to avoid re-creating the initial value\n on each render.\n\n Returns:\n A tuple containing the current state and a function to update it.\n \"\"\"\n current_state = _use_const(lambda: _CurrentState(initial_value))\n return State(current_state.value, current_state.dispatch)\n\n\nclass _CurrentState(Generic[_Type]):\n __slots__ = \"value\", \"dispatch\"\n\n def __init__(\n self,\n initial_value: _Type | Callable[[], _Type],\n ) -> None:\n if callable(initial_value):\n self.value = initial_value()\n else:\n self.value = initial_value\n\n hook = current_hook()\n\n def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:\n if callable(new):\n next_value = new(self.value)\n else:\n next_value = new\n if not strictly_equal(next_value, self.value):\n self.value = next_value\n hook.schedule_render()\n\n self.dispatch = dispatch\n\n\n_EffectCleanFunc: TypeAlias = \"Callable[[], None]\"\n_SyncEffectFunc: TypeAlias = \"Callable[[], _EffectCleanFunc | None]\"\n_AsyncEffectFunc: TypeAlias = (\n \"Callable[[], Coroutine[None, None, _EffectCleanFunc | None]]\"\n)\n_EffectApplyFunc: TypeAlias = \"_SyncEffectFunc | _AsyncEffectFunc\"\n\n\n@overload\ndef use_effect(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None]: ...\n\n\n@overload\ndef use_effect(\n function: _EffectApplyFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None: ...\n\n\ndef use_effect(\n function: _EffectApplyFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None] | None:\n \"\"\"See the full :ref:`Use Effect` docs for details\n\n Parameters:\n function:\n Applies the effect and can return a clean-up function\n dependencies:\n Dependencies for the effect. The effect will only trigger if the identity\n of any value in the given sequence changes (i.e. their :func:`id` is\n different). By default these are inferred based on local variables that are\n referenced by the given function.\n\n Returns:\n If not function is provided, a decorator. 
Otherwise ``None``.\n \"\"\"\n hook = current_hook()\n\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n last_clean_callback: Ref[_EffectCleanFunc | None] = use_ref(None)\n\n def add_effect(function: _EffectApplyFunc) -> None:\n if not asyncio.iscoroutinefunction(function):\n sync_function = cast(_SyncEffectFunc, function)\n else:\n async_function = cast(_AsyncEffectFunc, function)\n\n def sync_function() -> _EffectCleanFunc | None:\n task = asyncio.create_task(async_function())\n\n def clean_future() -> None:\n if not task.cancel():\n try:\n clean = task.result()\n except asyncio.CancelledError:\n pass\n else:\n if clean is not None:\n clean()\n\n return clean_future\n\n async def effect(stop: asyncio.Event) -> None:\n if last_clean_callback.current is not None:\n last_clean_callback.current()\n last_clean_callback.current = None\n clean = last_clean_callback.current = sync_function()\n await stop.wait()\n if clean is not None:\n clean()\n\n return memoize(lambda: hook.add_effect(effect))\n\n if function is not None:\n add_effect(function)\n return None\n else:\n return add_effect\n\n\ndef use_debug_value(\n message: Any | Callable[[], Any],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None:\n \"\"\"Log debug information when the given message changes.\n\n .. note::\n This hook only logs if :data:`~reactpy.config.REACTPY_DEBUG_MODE` is active.\n\n Unlike other hooks, a message is considered to have changed if the old and new\n values are ``!=``. Because this comparison is performed on every render of the\n component, it may be worth considering the performance cost in some situations.\n\n Parameters:\n message:\n The value to log or a memoized function for generating the value.\n dependencies:\n Dependencies for the memoized function. The message will only be recomputed\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). 
By default these are inferred based on local\n variables that are referenced by the given function.\n \"\"\"\n old: Ref[Any] = _use_const(lambda: Ref(object()))\n memo_func = message if callable(message) else lambda: message\n new = use_memo(memo_func, dependencies)\n\n if REACTPY_DEBUG_MODE.current and old.current != new:\n old.current = new\n logger.debug(f\"{current_hook().component} {new}\")\n\n\ndef create_context(default_value: _Type) -> Context[_Type]:\n \"\"\"Return a new context type for use in :func:`use_context`\"\"\"\n\n def context(\n *children: Any,\n value: _Type = default_value,\n key: Key | None = None,\n ) -> _ContextProvider[_Type]:\n return _ContextProvider(\n *children,\n value=value,\n key=key,\n type=context,\n )\n\n context.__qualname__ = \"context\"\n\n return context\n\n\ndef use_context(context: Context[_Type]) -> _Type:\n \"\"\"Get the current value for the given context type.\n\n See the full :ref:`Use Context` docs for more information.\n \"\"\"\n hook = current_hook()\n provider = hook.get_context_provider(context)\n\n if provider is None:\n # same assertions but with normal exceptions\n if not isinstance(context, FunctionType):\n raise TypeError(f\"{context} is not a Context\") # nocov\n if context.__kwdefaults__ is None:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n if \"value\" not in context.__kwdefaults__:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n return cast(_Type, context.__kwdefaults__[\"value\"])\n\n return provider.value\n\n\nclass _ContextProvider(Generic[_Type]):\n def __init__(\n self,\n *children: Any,\n value: _Type,\n key: Key | None,\n type: Context[_Type],\n ) -> None:\n self.children = children\n self.key = key\n self.type = type\n self.value = value\n\n def render(self) -> VdomDict:\n current_hook().set_context_provider(self)\n return {\"tagName\": \"\", \"children\": self.children}\n\n def __repr__(self) -> str:\n return f\"ContextProvider({self.type})\"\n\n\n_ActionType = TypeVar(\"_ActionType\")\n\n\ndef use_reducer(\n reducer: Callable[[_Type, _ActionType], _Type],\n initial_value: _Type,\n) -> tuple[_Type, Callable[[_ActionType], None]]:\n \"\"\"See the full :ref:`Use Reducer` docs for details\n\n Parameters:\n reducer:\n A function which applies an action to the current state in order to\n produce the next state.\n initial_value:\n The initial state value (same as for :func:`use_state`)\n\n Returns:\n A tuple containing the current state and a function to change it with an action\n \"\"\"\n state, set_state = use_state(initial_value)\n return state, _use_const(lambda: _create_dispatcher(reducer, set_state))\n\n\ndef _create_dispatcher(\n reducer: Callable[[_Type, _ActionType], _Type],\n set_state: Callable[[Callable[[_Type], _Type]], None],\n) -> Callable[[_ActionType], None]:\n def dispatch(action: _ActionType) -> None:\n set_state(lambda last_state: reducer(last_state, action))\n\n return dispatch\n\n\n_CallbackFunc = TypeVar(\"_CallbackFunc\", bound=Callable[..., Any])\n\n\n@overload\ndef use_callback(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_CallbackFunc], _CallbackFunc]: ...\n\n\n@overload\ndef use_callback(\n function: _CallbackFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc: ...\n\n\ndef use_callback(\n function: _CallbackFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc | Callable[[_CallbackFunc], _CallbackFunc]:\n \"\"\"See the full :ref:`Use Callback` 
docs for details\n\n Parameters:\n function:\n The function whose identity will be preserved\n dependencies:\n Dependencies of the callback. The identity the ``function`` will be updated\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current function\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n\n def setup(function: _CallbackFunc) -> _CallbackFunc:\n return memoize(lambda: function)\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _LambdaCaller(Protocol):\n \"\"\"MyPy doesn't know how to deal with TypeVars only used in function return\"\"\"\n\n def __call__(self, func: Callable[[], _Type]) -> _Type: ...\n\n\n@overload\ndef use_memo(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _LambdaCaller: ...\n\n\n@overload\ndef use_memo(\n function: Callable[[], _Type],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type: ...\n\n\ndef use_memo(\n function: Callable[[], _Type] | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type | Callable[[Callable[[], _Type]], _Type]:\n \"\"\"See the full :ref:`Use Memo` docs for details\n\n Parameters:\n function:\n The function to be memoized.\n dependencies:\n Dependencies for the memoized function. The memo will only be recomputed if\n the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current state\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n\n memo: _Memo[_Type] = _use_const(_Memo)\n\n if memo.empty():\n # we need to initialize on the first run\n changed = True\n memo.deps = () if dependencies is None else dependencies\n elif dependencies is None:\n changed = True\n memo.deps = ()\n elif (\n len(memo.deps) != len(dependencies)\n # if deps are same length check identity for each item\n or not all(\n strictly_equal(current, new)\n for current, new in zip(memo.deps, dependencies)\n )\n ):\n memo.deps = dependencies\n changed = True\n else:\n changed = False\n\n setup: Callable[[Callable[[], _Type]], _Type]\n\n if changed:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n current_value = memo.value = function()\n return current_value\n\n else:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n return memo.value\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _Memo(Generic[_Type]):\n \"\"\"Simple object for storing memoization data\"\"\"\n\n __slots__ = \"value\", \"deps\"\n\n value: _Type\n deps: Sequence[Any]\n\n def empty(self) -> bool:\n try:\n self.value # noqa: B018\n except AttributeError:\n return True\n else:\n return False\n\n\ndef use_ref(initial_value: _Type) -> Ref[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value: The value initially assigned to the reference.\n\n Returns:\n A :class:`Ref` object.\n \"\"\"\n return _use_const(lambda: Ref(initial_value))\n\n\ndef _use_const(function: Callable[[], _Type]) -> _Type:\n return current_hook().use_state(function)\n\n\ndef _try_to_infer_closure_values(\n func: Callable[..., Any] | None,\n values: Sequence[Any] | ellipsis | None,\n) -> 
Sequence[Any] | None:\n if values is ...:\n if isinstance(func, FunctionType):\n return (\n [cell.cell_contents for cell in func.__closure__]\n if func.__closure__\n else []\n )\n else:\n return None\n else:\n return values\n\n\ndef strictly_equal(x: Any, y: Any) -> bool:\n \"\"\"Check if two values are identical or, for a limited set or types, equal.\n\n Only the following types are checked for equality rather than identity:\n\n - ``int``\n - ``float``\n - ``complex``\n - ``str``\n - ``bytes``\n - ``bytearray``\n - ``memoryview``\n \"\"\"\n return x is y or (type(x) in _NUMERIC_TEXT_BINARY_TYPES and x == y)\n\n\n_NUMERIC_TEXT_BINARY_TYPES = {\n # numeric\n int,\n float,\n complex,\n # text\n str,\n # binary types\n bytes,\n bytearray,\n memoryview,\n}\n\n# Path: src/py/reactpy/reactpy/backend/hooks.py\nfrom __future__ import annotations\n\nfrom collections.abc import MutableMapping\nfrom typing import Any\n\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.core.hooks import create_context, use_context\nfrom reactpy.core.types import Context\n\n# backend implementations should establish this context at the root of an app\nConnectionContext: Context[Connection[Any] | None] = create_context(None)\n\n\ndef use_connection() -> Connection[Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`.\"\"\"\n conn = use_context(ConnectionContext)\n if conn is None: # nocov\n msg = \"No backend established a connection.\"\n raise RuntimeError(msg)\n return conn\n\n\ndef use_scope() -> MutableMapping[str, Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s scope.\"\"\"\n return use_connection().scope\n\n\ndef use_location() -> Location:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s location.\"\"\"\n return use_connection().location\n\n# Path: src/py/reactpy/reactpy/core/component.py\nfrom __future__ import annotations\n\nimport inspect\nfrom functools import wraps\nfrom typing import Any, Callable\n\nfrom reactpy.core.types import ComponentType, VdomDict\n\n\ndef component(\n function: Callable[..., ComponentType | VdomDict | str | None]\n) -> Callable[..., Component]:\n \"\"\"A decorator for defining a new component.\n\n Parameters:\n function: The component's :meth:`reactpy.core.proto.ComponentType.render` function.\n \"\"\"\n sig = inspect.signature(function)\n\n if \"key\" in sig.parameters and sig.parameters[\"key\"].kind in (\n inspect.Parameter.KEYWORD_ONLY,\n inspect.Parameter.POSITIONAL_OR_KEYWORD,\n ):\n msg = f\"Component render function {function} uses reserved parameter 'key'\"\n raise TypeError(msg)\n\n @wraps(function)\n def constructor(*args: Any, key: Any | None = None, **kwargs: Any) -> Component:\n return Component(function, key, args, kwargs, sig)\n\n return constructor\n\n\nclass Component:\n \"\"\"An object for rending component models.\"\"\"\n\n __slots__ = \"__weakref__\", \"_func\", \"_args\", \"_kwargs\", \"_sig\", \"key\", \"type\"\n\n def __init__(\n self,\n function: Callable[..., ComponentType | VdomDict | str | None],\n key: Any | None,\n args: tuple[Any, ...],\n kwargs: dict[str, Any],\n sig: inspect.Signature,\n ) -> None:\n self.key = key\n self.type = function\n self._args = args\n self._kwargs = kwargs\n self._sig = sig\n\n def render(self) -> ComponentType | VdomDict | str | None:\n return self.type(*self._args, **self._kwargs)\n\n def __repr__(self) -> str:\n try:\n args = self._sig.bind(*self._args, **self._kwargs).arguments\n except TypeError:\n return 
f\"{self.type.__name__}(...)\"\n else:\n items = \", \".join(f\"{k}={v!r}\" for k, v in args.items())\n if items:\n return f\"{self.type.__name__}({id(self):02x}, {items})\"\n else:\n return f\"{self.type.__name__}({id(self):02x})\"\n\n# Path: src/py/reactpy/reactpy/types.py\n\"\"\"Exports common types from:\n\n- :mod:`reactpy.core.types`\n- :mod:`reactpy.backend.types`\n\"\"\"\n\nfrom reactpy.backend.types import BackendType, Connection, Location\nfrom reactpy.core.component import Component\nfrom reactpy.core.types import (\n ComponentConstructor,\n ComponentType,\n Context,\n EventHandlerDict,\n EventHandlerFunc,\n EventHandlerMapping,\n EventHandlerType,\n ImportSourceDict,\n Key,\n LayoutType,\n RootComponentConstructor,\n State,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomJson,\n)\n\n__all__ = [\n \"BackendType\",\n \"Component\",\n \"ComponentConstructor\",\n \"ComponentType\",\n \"Connection\",\n \"Context\",\n \"EventHandlerDict\",\n \"EventHandlerFunc\",\n \"EventHandlerMapping\",\n \"EventHandlerType\",\n \"ImportSourceDict\",\n \"Key\",\n \"LayoutType\",\n \"Location\",\n \"RootComponentConstructor\",\n \"State\",\n \"VdomAttributes\",\n \"VdomChild\",\n \"VdomChildren\",\n \"VdomDict\",\n \"VdomJson\",\n]\n\n# Path: src/py/reactpy/reactpy/backend/utils.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport socket\nimport sys\nfrom collections.abc import Iterator\nfrom contextlib import closing\nfrom importlib import import_module\nfrom typing import Any\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_BACKENDS = (\n \"fastapi\",\n \"sanic\",\n \"tornado\",\n \"flask\",\n \"starlette\",\n)\n\n\ndef run(\n component: RootComponentConstructor,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n implementation: BackendType[Any] | None = None,\n) -> None:\n \"\"\"Run a component with a development server\"\"\"\n logger.warning(_DEVELOPMENT_RUN_FUNC_WARNING)\n\n implementation = implementation or import_module(\"reactpy.backend.default\")\n app = implementation.create_development_app()\n implementation.configure(app, component)\n port = port or find_available_port(host)\n app_cls = type(app)\n\n logger.info(\n \"ReactPy is running with '%s.%s' at http://%s:%s\",\n app_cls.__module__,\n app_cls.__name__,\n host,\n port,\n )\n asyncio.run(implementation.serve_development_app(app, host, port))\n\n\ndef find_available_port(host: str, port_min: int = 8000, port_max: int = 9000) -> int:\n \"\"\"Get a port that's available for the given host and port range\"\"\"\n for port in range(port_min, port_max):\n with closing(socket.socket()) as sock:\n try:\n if sys.platform in (\"linux\", \"darwin\"):\n # Fixes bug on Unix-like systems where every time you restart the\n # server you'll get a different port on Linux. 
This cannot be set\n # on Windows otherwise address will always be reused.\n # Ref: https://stackoverflow.com/a/19247688/3159288\n sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n sock.bind((host, port))\n except OSError:\n pass\n else:\n return port\n msg = f\"Host {host!r} has no available port in range {port_max}-{port_max}\"\n raise RuntimeError(msg)\n\n\ndef all_implementations() -> Iterator[BackendType[Any]]:\n \"\"\"Yield all available server implementations\"\"\"\n for name in SUPPORTED_BACKENDS:\n try:\n import_module(name)\n except ImportError: # nocov\n logger.debug(\"Failed to import %s\", name, exc_info=True)\n continue\n\n reactpy_backend_name = f\"{__name__.rsplit('.', 1)[0]}.{name}\"\n yield import_module(reactpy_backend_name)\n\n\n_DEVELOPMENT_RUN_FUNC_WARNING = \"\"\"\\\nThe `run()` function is only intended for testing during development! To run \\\nin production, refer to the docs on how to use reactpy.backend.*.configure.\\\n\"\"\"\n\n# Path: src/py/reactpy/reactpy/core/layout.py\nfrom __future__ import annotations\n\nimport abc\nfrom asyncio import (\n FIRST_COMPLETED,\n CancelledError,\n Queue,\n Task,\n create_task,\n get_running_loop,\n wait,\n)\nfrom collections import Counter\nfrom collections.abc import Sequence\nfrom contextlib import AsyncExitStack\nfrom logging import getLogger\nfrom typing import (\n Any,\n Callable,\n Generic,\n NamedTuple,\n NewType,\n TypeVar,\n cast,\n)\nfrom uuid import uuid4\nfrom weakref import ref as weakref\n\nfrom anyio import Semaphore\nfrom typing_extensions import TypeAlias\n\nfrom reactpy.config import (\n REACTPY_ASYNC_RENDERING,\n REACTPY_CHECK_VDOM_SPEC,\n REACTPY_DEBUG_MODE,\n)\nfrom reactpy.core._life_cycle_hook import LifeCycleHook\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n Key,\n LayoutEventMessage,\n LayoutUpdateMessage,\n VdomChild,\n VdomDict,\n VdomJson,\n)\nfrom reactpy.core.vdom import validate_vdom_json\nfrom reactpy.utils import Ref\n\nlogger = getLogger(__name__)\n\n\nclass Layout:\n \"\"\"Responsible for \"rendering\" components. That is, turning them into VDOM.\"\"\"\n\n __slots__: tuple[str, ...] 
= (\n \"root\",\n \"_event_handlers\",\n \"_rendering_queue\",\n \"_render_tasks\",\n \"_render_tasks_ready\",\n \"_root_life_cycle_state_id\",\n \"_model_states_by_life_cycle_state_id\",\n )\n\n if not hasattr(abc.ABC, \"__weakref__\"): # nocov\n __slots__ += (\"__weakref__\",)\n\n def __init__(self, root: ComponentType) -> None:\n super().__init__()\n if not isinstance(root, ComponentType):\n msg = f\"Expected a ComponentType, not {type(root)!r}.\"\n raise TypeError(msg)\n self.root = root\n\n async def __aenter__(self) -> Layout:\n # create attributes here to avoid access before entering context manager\n self._event_handlers: EventHandlerDict = {}\n self._render_tasks: set[Task[LayoutUpdateMessage]] = set()\n self._render_tasks_ready: Semaphore = Semaphore(0)\n\n self._rendering_queue: _ThreadSafeQueue[_LifeCycleStateId] = _ThreadSafeQueue()\n root_model_state = _new_root_model_state(self.root, self._schedule_render_task)\n\n self._root_life_cycle_state_id = root_id = root_model_state.life_cycle_state.id\n self._model_states_by_life_cycle_state_id = {root_id: root_model_state}\n self._schedule_render_task(root_id)\n\n return self\n\n async def __aexit__(self, *exc: Any) -> None:\n root_csid = self._root_life_cycle_state_id\n root_model_state = self._model_states_by_life_cycle_state_id[root_csid]\n\n for t in self._render_tasks:\n t.cancel()\n try:\n await t\n except CancelledError:\n pass\n\n await self._unmount_model_states([root_model_state])\n\n # delete attributes here to avoid access after exiting context manager\n del self._event_handlers\n del self._rendering_queue\n del self._root_life_cycle_state_id\n del self._model_states_by_life_cycle_state_id\n\n async def deliver(self, event: LayoutEventMessage) -> None:\n \"\"\"Dispatch an event to the targeted handler\"\"\"\n # It is possible for an element in the frontend to produce an event\n # associated with a backend model that has been deleted. We only handle\n # events if the element and the handler exist in the backend. Otherwise\n # we just ignore the event.\n handler = self._event_handlers.get(event[\"target\"])\n\n if handler is not None:\n try:\n await handler.function(event[\"data\"])\n except Exception:\n logger.exception(f\"Failed to execute event handler {handler}\")\n else:\n logger.info(\n f\"Ignored event - handler {event['target']!r} \"\n \"does not exist or its component unmounted\"\n )\n\n async def render(self) -> LayoutUpdateMessage:\n if REACTPY_ASYNC_RENDERING.current:\n return await self._concurrent_render()\n else: # nocov\n return await self._serial_render()\n\n async def _serial_render(self) -> LayoutUpdateMessage: # nocov\n \"\"\"Await the next available render. This will block until a component is updated\"\"\"\n while True:\n model_state_id = await self._rendering_queue.get()\n try:\n model_state = self._model_states_by_life_cycle_state_id[model_state_id]\n except KeyError:\n logger.debug(\n \"Did not render component with model state ID \"\n f\"{model_state_id!r} - component already unmounted\"\n )\n else:\n return await self._create_layout_update(model_state)\n\n async def _concurrent_render(self) -> LayoutUpdateMessage:\n \"\"\"Await the next available render. 
This will block until a component is updated\"\"\"\n await self._render_tasks_ready.acquire()\n done, _ = await wait(self._render_tasks, return_when=FIRST_COMPLETED)\n update_task: Task[LayoutUpdateMessage] = done.pop()\n self._render_tasks.remove(update_task)\n return update_task.result()\n\n async def _create_layout_update(\n self, old_state: _ModelState\n ) -> LayoutUpdateMessage:\n new_state = _copy_component_model_state(old_state)\n component = new_state.life_cycle_state.component\n\n async with AsyncExitStack() as exit_stack:\n await self._render_component(exit_stack, old_state, new_state, component)\n\n if REACTPY_CHECK_VDOM_SPEC.current:\n validate_vdom_json(new_state.model.current)\n\n return {\n \"type\": \"layout-update\",\n \"path\": new_state.patch_path,\n \"model\": new_state.model.current,\n }\n\n async def _render_component(\n self,\n exit_stack: AsyncExitStack,\n old_state: _ModelState | None,\n new_state: _ModelState,\n component: ComponentType,\n ) -> None:\n life_cycle_state = new_state.life_cycle_state\n life_cycle_hook = life_cycle_state.hook\n\n self._model_states_by_life_cycle_state_id[life_cycle_state.id] = new_state\n\n await life_cycle_hook.affect_component_will_render(component)\n exit_stack.push_async_callback(life_cycle_hook.affect_layout_did_render)\n try:\n raw_model = component.render()\n # wrap the model in a fragment (i.e. tagName=\"\") to ensure components have\n # a separate node in the model state tree. This could be removed if this\n # components are given a node in the tree some other way\n wrapper_model: VdomDict = {\"tagName\": \"\", \"children\": [raw_model]}\n await self._render_model(exit_stack, old_state, new_state, wrapper_model)\n except Exception as error:\n logger.exception(f\"Failed to render {component}\")\n new_state.model.current = {\n \"tagName\": \"\",\n \"error\": (\n f\"{type(error).__name__}: {error}\"\n if REACTPY_DEBUG_MODE.current\n else \"\"\n ),\n }\n finally:\n await life_cycle_hook.affect_component_did_render()\n\n try:\n parent = new_state.parent\n except AttributeError:\n pass # only happens for root component\n else:\n key, index = new_state.key, new_state.index\n parent.children_by_key[key] = new_state\n # need to add this model to parent's children without mutating parent model\n old_parent_model = parent.model.current\n old_parent_children = old_parent_model[\"children\"]\n parent.model.current = {\n **old_parent_model,\n \"children\": [\n *old_parent_children[:index],\n new_state.model.current,\n *old_parent_children[index + 1 :],\n ],\n }\n\n async def _render_model(\n self,\n exit_stack: AsyncExitStack,\n old_state: _ModelState | None,\n new_state: _ModelState,\n raw_model: Any,\n ) -> None:\n try:\n new_state.model.current = {\"tagName\": raw_model[\"tagName\"]}\n except Exception as e: # nocov\n msg = f\"Expected a VDOM element dict, not {raw_model}\"\n raise ValueError(msg) from e\n if \"key\" in raw_model:\n new_state.key = new_state.model.current[\"key\"] = raw_model[\"key\"]\n if \"importSource\" in raw_model:\n new_state.model.current[\"importSource\"] = raw_model[\"importSource\"]\n self._render_model_attributes(old_state, new_state, raw_model)\n await self._render_model_children(\n exit_stack, old_state, new_state, raw_model.get(\"children\", [])\n )\n\n def _render_model_attributes(\n self,\n old_state: _ModelState | None,\n new_state: _ModelState,\n raw_model: dict[str, Any],\n ) -> None:\n # extract event handlers from 'eventHandlers' and 'attributes'\n handlers_by_event: EventHandlerDict = 
raw_model.get(\"eventHandlers\", {})\n\n if \"attributes\" in raw_model:\n attrs = raw_model[\"attributes\"].copy()\n new_state.model.current[\"attributes\"] = attrs\n\n if old_state is None:\n self._render_model_event_handlers_without_old_state(\n new_state, handlers_by_event\n )\n return None\n\n for old_event in set(old_state.targets_by_event).difference(handlers_by_event):\n old_target = old_state.targets_by_event[old_event]\n del self._event_handlers[old_target]\n\n if not handlers_by_event:\n return None\n\n model_event_handlers = new_state.model.current[\"eventHandlers\"] = {}\n for event, handler in handlers_by_event.items():\n if event in old_state.targets_by_event:\n target = old_state.targets_by_event[event]\n else:\n target = uuid4().hex if handler.target is None else handler.target\n new_state.targets_by_event[event] = target\n self._event_handlers[target] = handler\n model_event_handlers[event] = {\n \"target\": target,\n \"preventDefault\": handler.prevent_default,\n \"stopPropagation\": handler.stop_propagation,\n }\n\n return None\n\n def _render_model_event_handlers_without_old_state(\n self,\n new_state: _ModelState,\n handlers_by_event: EventHandlerDict,\n ) -> None:\n if not handlers_by_event:\n return None\n\n model_event_handlers = new_state.model.current[\"eventHandlers\"] = {}\n for event, handler in handlers_by_event.items():\n target = uuid4().hex if handler.target is None else handler.target\n new_state.targets_by_event[event] = target\n self._event_handlers[target] = handler\n model_event_handlers[event] = {\n \"target\": target,\n \"preventDefault\": handler.prevent_default,\n \"stopPropagation\": handler.stop_propagation,\n }\n\n return None\n\n async def _render_model_children(\n self,\n exit_stack: AsyncExitStack,\n old_state: _ModelState | None,\n new_state: _ModelState,\n...\n", "needle_token_start": 2344, "needle_token_end": 2490, "code_context_ntokens": 16058, "output": ["```python\ndef custom_vdom_constructor(func: _CustomVdomDictConstructor) -> VdomDictConstructor:\n \"\"\"Cast function to VdomDictConstructor\"\"\"\n\n @wraps(func)\n def wrapper(*attributes_and_children: Any) -> VdomDict:\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n return func(attributes, children, key, event_handlers)\n\n return cast(VdomDictConstructor, wrapper)\n```"]} +{"repo": "reactive-python/reactpy", "name": "image", "language": "python", "path": "src/py/reactpy/reactpy/widgets.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to create an HTML image element from a provided string or bytes, encoding the data into base64 format suitable for embedding directly into web pages.\n2. **Input**: The function accepts three parameters: a string specifying the image format (e.g., \"png\", \"jpeg\"), the image data either as a string or bytes, and an optional dictionary of additional HTML attributes for the image element.\n3. **Output**: Returns a dictionary representing a virtual DOM element for an image, which includes the tag name and attributes such as the source URL, which is constructed using the base64-encoded image data.\n4. **Procedure**: The function first checks and adjusts the image format if necessary (specifically for SVG). It then determines whether the image data is in string or byte format, encodes this data into base64, and constructs a data URL. 
This URL is used as the source attribute in the returned virtual DOM dictionary, which represents the image element.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/testing/backend.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nfrom contextlib import AsyncExitStack, suppress\nfrom types import TracebackType\nfrom typing import Any, Callable\nfrom urllib.parse import urlencode, urlunparse\n\nfrom reactpy.backend import default as default_server\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import find_available_port\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.core.component import component\nfrom reactpy.core.hooks import use_callback, use_effect, use_state\nfrom reactpy.core.types import ComponentConstructor\nfrom reactpy.testing.logs import (\n LogAssertionError,\n capture_reactpy_logs,\n list_logged_exceptions,\n)\nfrom reactpy.utils import Ref\n\n\nclass BackendFixture:\n \"\"\"A test fixture for running a server and imperatively displaying views\n\n This fixture is typically used alongside async web drivers like ``playwight``.\n\n Example:\n .. code-block::\n\n async with BackendFixture() as server:\n server.mount(MyComponent)\n \"\"\"\n\n _records: list[logging.LogRecord]\n _server_future: asyncio.Task[Any]\n _exit_stack = AsyncExitStack()\n\n def __init__(\n self,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n app: Any | None = None,\n implementation: BackendType[Any] | None = None,\n options: Any | None = None,\n timeout: float | None = None,\n ) -> None:\n self.host = host\n self.port = port or find_available_port(host)\n self.mount, self._root_component = _hotswap()\n self.timeout = (\n REACTPY_TESTING_DEFAULT_TIMEOUT.current if timeout is None else timeout\n )\n\n if app is not None and implementation is None:\n msg = \"If an application instance its corresponding server implementation must be provided too.\"\n raise ValueError(msg)\n\n self._app = app\n self.implementation = implementation or default_server\n self._options = options\n\n @property\n def log_records(self) -> list[logging.LogRecord]:\n \"\"\"A list of captured log records\"\"\"\n return self._records\n\n def url(self, path: str = \"\", query: Any | None = None) -> str:\n \"\"\"Return a URL string pointing to the host and point of the server\n\n Args:\n path: the path to a resource on the server\n query: a dictionary or list of query parameters\n \"\"\"\n return urlunparse(\n [\n \"http\",\n f\"{self.host}:{self.port}\",\n path,\n \"\",\n urlencode(query or ()),\n \"\",\n ]\n )\n\n def list_logged_exceptions(\n self,\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] 
= Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n ) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n...\n# Path: src/py/reactpy/reactpy/testing/common.py\nfrom __future__ import annotations\n\nimport asyncio\nimport inspect\nimport shutil\nimport time\nfrom collections.abc import Awaitable\nfrom functools import wraps\nfrom typing import Any, Callable, Generic, TypeVar, cast\nfrom uuid import uuid4\nfrom weakref import ref\n\nfrom typing_extensions import ParamSpec\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT, REACTPY_WEB_MODULES_DIR\nfrom reactpy.core._life_cycle_hook import LifeCycleHook, current_hook\nfrom reactpy.core.events import EventHandler, to_event_handler_function\n\n\ndef clear_reactpy_web_modules_dir() -> None:\n \"\"\"Clear the directory where ReactPy stores registered web modules\"\"\"\n for path in REACTPY_WEB_MODULES_DIR.current.iterdir():\n shutil.rmtree(path) if path.is_dir() else path.unlink()\n\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\n_DEFAULT_POLL_DELAY = 0.1\n\n\nclass poll(Generic[_R]): # noqa: N801\n \"\"\"Wait until the result of an sync or async function meets some condition\"\"\"\n\n def __init__(\n self,\n function: Callable[_P, Awaitable[_R] | _R],\n *args: _P.args,\n **kwargs: _P.kwargs,\n ) -> None:\n coro: Callable[_P, Awaitable[_R]]\n if not inspect.iscoroutinefunction(function):\n\n async def coro(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n return cast(_R, function(*args, **kwargs))\n\n else:\n coro = cast(Callable[_P, Awaitable[_R]], function)\n self._func = coro\n self._args = args\n self._kwargs = kwargs\n\n async def until(\n self,\n condition: Callable[[_R], bool],\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n description: str = \"condition to be true\",\n ) -> None:\n \"\"\"Check that the coroutines result meets a condition within the timeout\"\"\"\n started_at = time.time()\n while True:\n await asyncio.sleep(delay)\n result = await self._func(*self._args, **self._kwargs)\n if condition(result):\n break\n elif (time.time() - started_at) > timeout: # nocov\n msg = f\"Expected {description} after {timeout} seconds - last value was {result!r}\"\n raise asyncio.TimeoutError(msg)\n\n async def until_is(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is identical to the given value\"\"\"\n return await self.until(\n lambda left: left is right,\n timeout,\n delay,\n f\"value to be identical to {right!r}\",\n )\n\n async def until_equals(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is equal to the given value\"\"\"\n return await self.until(\n lambda left: left == right,\n timeout,\n delay,\n f\"value to equal {right!r}\",\n )\n\n\nclass HookCatcher:\n \"\"\"Utility for capturing a LifeCycleHook from a component\n\n Example:\n .. code-block::\n\n hooks = HookCatcher(index_by_kwarg=\"thing\")\n\n @reactpy.component\n @hooks.capture\n def MyComponent(thing):\n ...\n\n ... 
# render the component\n\n # grab the last render of where MyComponent(thing='something')\n hooks.index[\"something\"]\n # or grab the hook from the component's last render\n hooks.latest\n\n After the first render of ``MyComponent`` the ``HookCatcher`` will have\n captured the component's ``LifeCycleHook``.\n \"\"\"\n\n latest: LifeCycleHook\n\n def __init__(self, index_by_kwarg: str | None = None):\n self.index_by_kwarg = index_by_kwarg\n self.index: dict[Any, LifeCycleHook] = {}\n\n def capture(self, render_function: Callable[..., Any]) -> Callable[..., Any]:\n \"\"\"Decorator for capturing a ``LifeCycleHook`` on each render of a component\"\"\"\n\n # The render function holds a reference to `self` and, via the `LifeCycleHook`,\n # the component. Some tests check whether components are garbage collected, thus\n # we must use a `ref` here to ensure these checks pass once the catcher itself\n # has been collected.\n self_ref = ref(self)\n\n @wraps(render_function)\n def wrapper(*args: Any, **kwargs: Any) -> Any:\n self = self_ref()\n if self is None:\n raise RuntimeError(\"Hook catcher has been garbage collected\")\n\n hook = current_hook()\n if self.index_by_kwarg is not None:\n self.index[kwargs[self.index_by_kwarg]] = hook\n self.latest = hook\n return render_function(*args, **kwargs)\n\n return wrapper\n\n\nclass StaticEventHandler:\n \"\"\"Utility for capturing the target of one event handler\n\n Example:\n .. code-block::\n\n static_handler = StaticEventHandler()\n\n @reactpy.component\n def MyComponent():\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handler.use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # gives the target ID for onClick where from the last render of MyComponent\n static_handlers.target\n\n If you need to capture event handlers from different instances of a component\n the you should create multiple ``StaticEventHandler`` instances.\n\n .. 
code-block::\n\n static_handlers_by_key = {\n \"first\": StaticEventHandler(),\n \"second\": StaticEventHandler(),\n }\n\n @reactpy.component\n def Parent():\n return reactpy.html.div(Child(key=\"first\"), Child(key=\"second\"))\n\n @reactpy.component\n def Child(key):\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handlers_by_key[key].use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # grab the individual targets for each instance above\n first_target = static_handlers_by_key[\"first\"].target\n second_target = static_handlers_by_key[\"second\"].target\n \"\"\"\n\n def __init__(self) -> None:\n self.target = uuid4().hex\n\n def use(\n self,\n function: Callable[..., Any],\n stop_propagation: bool = False,\n prevent_default: bool = False,\n ) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function),\n stop_propagation,\n prevent_default,\n self.target,\n )\n\n# Path: src/py/reactpy/reactpy/testing/display.py\nfrom __future__ import annotations\n\nfrom contextlib import AsyncExitStack\nfrom types import TracebackType\nfrom typing import Any\n\nfrom playwright.async_api import (\n Browser,\n BrowserContext,\n ElementHandle,\n Page,\n async_playwright,\n)\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.types import RootComponentConstructor\n\n\nclass DisplayFixture:\n \"\"\"A fixture for running web-based tests using ``playwright``\"\"\"\n\n _exit_stack: AsyncExitStack\n\n def __init__(\n self,\n backend: BackendFixture | None = None,\n driver: Browser | BrowserContext | Page | None = None,\n url_prefix: str = \"\",\n ) -> None:\n if backend is not None:\n self.backend = backend\n if driver is not None:\n if isinstance(driver, Page):\n self.page = driver\n else:\n self._browser = driver\n self.url_prefix = url_prefix\n\n async def show(\n self,\n component: RootComponentConstructor,\n ) -> None:\n self.backend.mount(component)\n await self.goto(\"/\")\n await self.root_element() # check that root element is attached\n\n async def goto(\n self, path: str, query: Any | None = None, add_url_prefix: bool = True\n ) -> None:\n await self.page.goto(\n self.backend.url(\n f\"{self.url_prefix}{path}\" if add_url_prefix else path, query\n )\n )\n\n async def root_element(self) -> ElementHandle:\n element = await self.page.wait_for_selector(\"#app\", state=\"attached\")\n if element is None: # nocov\n msg = \"Root element not attached\"\n raise RuntimeError(msg)\n return element\n\n async def __aenter__(self) -> DisplayFixture:\n es = self._exit_stack = AsyncExitStack()\n\n browser: Browser | BrowserContext\n if not hasattr(self, \"page\"):\n if not hasattr(self, \"_browser\"):\n pw = await es.enter_async_context(async_playwright())\n browser = await pw.chromium.launch()\n else:\n browser = self._browser\n self.page = await browser.new_page()\n\n self.page.set_default_timeout(REACTPY_TESTING_DEFAULT_TIMEOUT.current * 1000)\n\n if not hasattr(self, \"backend\"):\n self.backend = BackendFixture()\n await es.enter_async_context(self.backend)\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n self.backend.mount(None)\n await self._exit_stack.aclose()\n\n# Path: src/py/reactpy/reactpy/testing/__init__.py\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.testing.common import (\n 
HookCatcher,\n StaticEventHandler,\n clear_reactpy_web_modules_dir,\n poll,\n)\nfrom reactpy.testing.display import DisplayFixture\nfrom reactpy.testing.logs import (\n LogAssertionError,\n assert_reactpy_did_log,\n assert_reactpy_did_not_log,\n capture_reactpy_logs,\n)\n\n__all__ = [\n \"assert_reactpy_did_not_log\",\n \"assert_reactpy_did_log\",\n \"capture_reactpy_logs\",\n \"clear_reactpy_web_modules_dir\",\n \"DisplayFixture\",\n \"HookCatcher\",\n \"LogAssertionError\",\n \"poll\",\n \"BackendFixture\",\n \"StaticEventHandler\",\n]\n\n# Path: src/py/reactpy/reactpy/widgets.py\nfrom __future__ import annotations\n\nfrom base64 import b64encode\nfrom collections.abc import Sequence\nfrom typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar\n\nimport reactpy\nfrom reactpy import html\nfrom reactpy._warnings import warn\nfrom reactpy.core.types import ComponentConstructor, VdomDict\n\n\n\ndef image(\n format: str,\n value: str | bytes = \"\",\n attributes: dict[str, Any] | None = None,\n) -> VdomDict:\n \"\"\"Utility for constructing an image from a string or bytes\n\n The source value will automatically be encoded to base64\n \"\"\"\n if format == \"svg\":\n format = \"svg+xml\" # noqa: A001\n\n if isinstance(value, str):\n bytes_value = value.encode()\n else:\n bytes_value = value\n\n base64_value = b64encode(bytes_value).decode()\n src = f\"data:image/{format};base64,{base64_value}\"\n\n return {\"tagName\": \"img\", \"attributes\": {\"src\": src, **(attributes or {})}}\n\n\n_Value = TypeVar(\"_Value\")\n\n\ndef use_linked_inputs(\n attributes: Sequence[dict[str, Any]],\n on_change: Callable[[_Value], None] = lambda value: None,\n cast: _CastFunc[_Value] = lambda value: value,\n initial_value: str = \"\",\n ignore_empty: bool = True,\n) -> list[VdomDict]:\n \"\"\"Return a list of linked inputs equal to the number of given attributes.\n\n Parameters:\n attributes:\n That attributes of each returned input element. If the number of generated\n inputs is variable, you may need to assign each one a\n :ref:`key ` by including a ``\"key\"`` in each\n attribute dictionary.\n on_change:\n A callback which is triggered when any input is changed. 
This callback need\n not update the 'value' field in the attributes of the inputs since that is\n handled automatically.\n cast:\n Cast the 'value' of changed inputs that is passed to ``on_change``.\n initial_value:\n Initialize the 'value' field of the inputs.\n ignore_empty:\n Do not trigger ``on_change`` if the 'value' is an empty string.\n \"\"\"\n value, set_value = reactpy.hooks.use_state(initial_value)\n\n def sync_inputs(event: dict[str, Any]) -> None:\n new_value = event[\"target\"][\"value\"]\n set_value(new_value)\n if not new_value and ignore_empty:\n return None\n on_change(cast(new_value))\n\n inputs: list[VdomDict] = []\n for attrs in attributes:\n inputs.append(html.input({**attrs, \"on_change\": sync_inputs, \"value\": value}))\n\n return inputs\n\n\n_CastTo_co = TypeVar(\"_CastTo_co\", covariant=True)\n\n\nclass _CastFunc(Protocol[_CastTo_co]):\n def __call__(self, value: str) -> _CastTo_co: ...\n\n\nif TYPE_CHECKING:\n from reactpy.testing.backend import _MountFunc\n\n\ndef hotswap(\n update_on_change: bool = False,\n) -> tuple[_MountFunc, ComponentConstructor]: # nocov\n warn(\n \"The 'hotswap' function is deprecated and will be removed in a future release\",\n DeprecationWarning,\n stacklevel=2,\n )\n from reactpy.testing.backend import _hotswap\n\n return _hotswap(update_on_change)\n\n# Path: src/py/reactpy/reactpy/__init__.py\nfrom reactpy import backend, config, html, logging, sample, svg, types, web, widgets\nfrom reactpy.backend.hooks import use_connection, use_location, use_scope\nfrom reactpy.backend.utils import run\nfrom reactpy.core import hooks\nfrom reactpy.core.component import component\nfrom reactpy.core.events import event\nfrom reactpy.core.hooks import (\n create_context,\n use_callback,\n use_context,\n use_debug_value,\n use_effect,\n use_memo,\n use_reducer,\n use_ref,\n use_state,\n)\nfrom reactpy.core.layout import Layout\nfrom reactpy.core.vdom import vdom\nfrom reactpy.utils import Ref, html_to_vdom, vdom_to_html\n\n__author__ = \"The Reactive Python Team\"\n__version__ = \"1.0.2\" # DO NOT MODIFY\n\n__all__ = [\n \"backend\",\n \"component\",\n \"config\",\n \"create_context\",\n \"event\",\n \"hooks\",\n \"html_to_vdom\",\n \"html\",\n \"Layout\",\n \"logging\",\n \"Ref\",\n \"run\",\n \"sample\",\n \"Stop\",\n \"svg\",\n \"types\",\n \"use_callback\",\n \"use_connection\",\n \"use_context\",\n \"use_debug_value\",\n \"use_effect\",\n \"use_location\",\n \"use_memo\",\n \"use_reducer\",\n \"use_ref\",\n \"use_scope\",\n \"use_state\",\n \"vdom_to_html\",\n \"vdom\",\n \"web\",\n \"widgets\",\n]\n\n# Path: src/py/reactpy/reactpy/_console/__init__.py\n\n# Path: src/py/reactpy/reactpy/_console/ast_utils.py\nfrom __future__ import annotations\n\nimport ast\nfrom collections.abc import Iterator, Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom textwrap import indent\nfrom tokenize import COMMENT as COMMENT_TOKEN\nfrom tokenize import generate_tokens\nfrom typing import Any\n\nimport click\n\nfrom reactpy import html\n\n\ndef rewrite_changed_nodes(\n file: Path,\n source: str,\n tree: ast.AST,\n changed: list[ChangedNode],\n) -> str:\n ast.fix_missing_locations(tree)\n\n lines = source.split(\"\\n\")\n\n # find closest parent nodes that should be re-written\n nodes_to_unparse: list[ast.AST] = []\n for change in changed:\n node_lineage = [change.node, *change.parents]\n for i in range(len(node_lineage) - 1):\n current_node, next_node = node_lineage[i : i + 2]\n if (\n not hasattr(next_node, \"lineno\")\n or 
next_node.lineno < change.node.lineno\n or isinstance(next_node, (ast.ClassDef, ast.FunctionDef))\n ):\n nodes_to_unparse.append(current_node)\n break\n else: # nocov\n msg = \"Failed to change code\"\n raise RuntimeError(msg)\n\n # check if an nodes to rewrite contain each other, pick outermost nodes\n current_outermost_node, *sorted_nodes_to_unparse = sorted(\n nodes_to_unparse, key=lambda n: n.lineno\n )\n outermost_nodes_to_unparse = [current_outermost_node]\n for node in sorted_nodes_to_unparse:\n if (\n not current_outermost_node.end_lineno\n or node.lineno > current_outermost_node.end_lineno\n ):\n current_outermost_node = node\n outermost_nodes_to_unparse.append(node)\n\n moved_comment_lines_from_end: list[int] = []\n # now actually rewrite these nodes (in reverse to avoid changes earlier in file)\n for node in reversed(outermost_nodes_to_unparse):\n # make a best effort to preserve any comments that we're going to overwrite\n comments = _find_comments(lines[node.lineno - 1 : node.end_lineno])\n\n # there may be some content just before and after the content we're re-writing\n before_replacement = lines[node.lineno - 1][: node.col_offset].lstrip()\n\n after_replacement = (\n lines[node.end_lineno - 1][node.end_col_offset :].strip()\n if node.end_lineno is not None and node.end_col_offset is not None\n else \"\"\n )\n\n replacement = indent(\n before_replacement\n + \"\\n\".join([*comments, ast.unparse(node)])\n + after_replacement,\n \" \" * (node.col_offset - len(before_replacement)),\n )\n\n lines[node.lineno - 1 : node.end_lineno or node.lineno] = [replacement]\n\n if comments:\n moved_comment_lines_from_end.append(len(lines) - node.lineno)\n\n for lineno_from_end in sorted(set(moved_comment_lines_from_end)):\n click.echo(f\"Moved comments to {file}:{len(lines) - lineno_from_end}\")\n\n return \"\\n\".join(lines)\n\n\n@dataclass\nclass ChangedNode:\n node: ast.AST\n parents: Sequence[ast.AST]\n\n\ndef find_element_constructor_usages(\n tree: ast.AST, add_props: bool = False\n) -> Iterator[ElementConstructorInfo]:\n changed: list[Sequence[ast.AST]] = []\n for parents, node in _walk_with_parent(tree):\n if not (isinstance(node, ast.Call)):\n continue\n\n func = node.func\n if isinstance(func, ast.Attribute) and (\n (isinstance(func.value, ast.Name) and func.value.id == \"html\")\n or (isinstance(func.value, ast.Attribute) and func.value.attr == \"html\")\n ):\n name = func.attr\n elif isinstance(func, ast.Name):\n name = func.id\n else:\n continue\n\n maybe_attr_dict_node: Any | None = None\n\n if name == \"vdom\":\n if len(node.args) == 0:\n continue\n elif len(node.args) == 1:\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.append(maybe_attr_dict_node)\n else:\n continue\n elif isinstance(node.args[1], (ast.Constant, ast.JoinedStr)):\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.insert(1, maybe_attr_dict_node)\n else:\n continue\n elif len(node.args) >= 2: # noqa: PLR2004\n maybe_attr_dict_node = node.args[1]\n elif hasattr(html, name):\n if len(node.args) == 0:\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.append(maybe_attr_dict_node)\n else:\n continue\n elif isinstance(node.args[0], (ast.Constant, ast.JoinedStr)):\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.insert(0, maybe_attr_dict_node)\n else:\n continue\n else:\n maybe_attr_dict_node = node.args[0]\n\n if not maybe_attr_dict_node:\n continue\n\n if isinstance(maybe_attr_dict_node, 
ast.Dict) or (\n isinstance(maybe_attr_dict_node, ast.Call)\n and isinstance(maybe_attr_dict_node.func, ast.Name)\n and maybe_attr_dict_node.func.id == \"dict\"\n and isinstance(maybe_attr_dict_node.func.ctx, ast.Load)\n ):\n yield ElementConstructorInfo(node, maybe_attr_dict_node, parents)\n\n return changed\n\n\n@dataclass\nclass ElementConstructorInfo:\n call: ast.Call\n props: ast.Dict | ast.Call\n parents: Sequence[ast.AST]\n\n\ndef _find_comments(lines: list[str]) -> list[str]:\n iter_lines = iter(lines)\n return [\n token\n for token_type, token, _, _, _ in generate_tokens(lambda: next(iter_lines))\n if token_type == COMMENT_TOKEN\n ]\n\n\ndef _walk_with_parent(\n node: ast.AST, parents: tuple[ast.AST, ...] = ()\n) -> Iterator[tuple[tuple[ast.AST, ...], ast.AST]]:\n parents = (node, *parents)\n for child in ast.iter_child_nodes(node):\n yield parents, child\n yield from _walk_with_parent(child, parents)\n\n# Path: src/py/reactpy/reactpy/_console/rewrite_camel_case_props.py\nfrom __future__ import annotations\n\nimport ast\nimport re\nimport sys\nfrom copy import copy\nfrom keyword import kwlist\nfrom pathlib import Path\nfrom typing import Callable\n\nimport click\n\nfrom reactpy._console.ast_utils import (\n ChangedNode,\n find_element_constructor_usages,\n rewrite_changed_nodes,\n)\n\nCAMEL_CASE_SUB_PATTERN = re.compile(r\"(? None:\n \"\"\"Rewrite camelCase props to snake_case\"\"\"\n if sys.version_info < (3, 9): # nocov\n msg = \"This command requires Python>=3.9\"\n raise RuntimeError(msg)\n\n for p in map(Path, paths):\n for f in [p] if p.is_file() else p.rglob(\"*.py\"):\n result = generate_rewrite(file=f, source=f.read_text(encoding=\"utf-8\"))\n if result is not None:\n f.write_text(result)\n\n\ndef generate_rewrite(file: Path, source: str) -> str | None:\n tree = ast.parse(source)\n\n changed = find_nodes_to_change(tree)\n if not changed:\n return None\n\n new = rewrite_changed_nodes(file, source, tree, changed)\n return new\n\n\ndef find_nodes_to_change(tree: ast.AST) -> list[ChangedNode]:\n changed: list[ChangedNode] = []\n for el_info in find_element_constructor_usages(tree):\n if _rewrite_props(el_info.props, _construct_prop_item):\n changed.append(ChangedNode(el_info.call, el_info.parents))\n return changed\n\n\ndef conv_attr_name(name: str) -> str:\n new_name = CAMEL_CASE_SUB_PATTERN.sub(\"_\", name).lower()\n return f\"{new_name}_\" if new_name in kwlist else new_name\n\n\ndef _construct_prop_item(key: str, value: ast.expr) -> tuple[str, ast.expr]:\n if key == \"style\" and isinstance(value, (ast.Dict, ast.Call)):\n new_value = copy(value)\n if _rewrite_props(\n new_value,\n lambda k, v: (\n (k, v)\n # avoid infinite recursion\n if k == \"style\"\n else _construct_prop_item(k, v)\n ),\n ):\n value = new_value\n else:\n key = conv_attr_name(key)\n return key, value\n\n\ndef _rewrite_props(\n props_node: ast.Dict | ast.Call,\n constructor: Callable[[str, ast.expr], tuple[str, ast.expr]],\n) -> bool:\n if isinstance(props_node, ast.Dict):\n did_change = False\n keys: list[ast.expr | None] = []\n values: list[ast.expr] = []\n for k, v in zip(props_node.keys, props_node.values):\n if isinstance(k, ast.Constant) and isinstance(k.value, str):\n k_value, new_v = constructor(k.value, v)\n if k_value != k.value or new_v is not v:\n did_change = True\n k = ast.Constant(value=k_value)\n v = new_v\n keys.append(k)\n values.append(v)\n if not did_change:\n return False\n props_node.keys = keys\n props_node.values = values\n else:\n did_change = False\n keywords: list[ast.keyword] 
= []\n for kw in props_node.keywords:\n if kw.arg is not None:\n kw_arg, kw_value = constructor(kw.arg, kw.value)\n if kw_arg != kw.arg or kw_value is not kw.value:\n did_change = True\n kw = ast.keyword(arg=kw_arg, value=kw_value)\n keywords.append(kw)\n if not did_change:\n return False\n props_node.keywords = keywords\n return True\n\n# Path: src/py/reactpy/reactpy/_console/rewrite_keys.py\nfrom __future__ import annotations\n\nimport ast\nimport sys\nfrom pathlib import Path\n\nimport click\n\nfrom reactpy import html\nfrom reactpy._console.ast_utils import (\n ChangedNode,\n find_element_constructor_usages,\n rewrite_changed_nodes,\n)\n\n\n@click.command()\n@click.argument(\"paths\", nargs=-1, type=click.Path(exists=True))\ndef rewrite_keys(paths: list[str]) -> None:\n \"\"\"Rewrite files under the given paths using the new html element API.\n\n The old API required users to pass a dictionary of attributes to html element\n constructor functions. For example:\n\n >>> html.div({\"className\": \"x\"}, \"y\")\n {\"tagName\": \"div\", \"attributes\": {\"className\": \"x\"}, \"children\": [\"y\"]}\n\n The latest API though allows for attributes to be passed as snake_cased keyword\n arguments instead. The above example would be rewritten as:\n\n >>> html.div(\"y\", class_name=\"x\")\n {\"tagName\": \"div\", \"attributes\": {\"class_name\": \"x\"}, \"children\": [\"y\"]}\n\n All snake_case attributes are converted to camelCase by the client where necessary.\n\n ----- Notes -----\n\n While this command does it's best to preserve as much of the original code as\n possible, there are inevitably some limitations in doing this. As a result, we\n recommend running your code formatter like Black against your code after executing\n this command.\n\n Additionally, We are unable to preserve the location of comments that lie within any\n rewritten code. This command will place the comments in the code it plans to rewrite\n just above its changes. 
As such it requires manual intervention to put those\n comments back in their original location.\n \"\"\"\n if sys.version_info < (3, 9): # nocov\n msg = \"This command requires Python>=3.9\"\n raise RuntimeError(msg)\n\n for p in map(Path, paths):\n for f in [p] if p.is_file() else p.rglob(\"*.py\"):\n result = generate_rewrite(file=f, source=f.read_text(encoding=\"utf-8\"))\n if result is not None:\n f.write_text(result)\n\n\ndef generate_rewrite(file: Path, source: str) -> str | None:\n tree = ast.parse(source)\n\n changed = find_nodes_to_change(tree)\n if not changed:\n log_could_not_rewrite(file, tree)\n return None\n\n new = rewrite_changed_nodes(file, source, tree, changed)\n log_could_not_rewrite(file, ast.parse(new))\n\n return new\n\n\ndef find_nodes_to_change(tree: ast.AST) -> list[ChangedNode]:\n changed: list[ChangedNode] = []\n for el_info in find_element_constructor_usages(tree, add_props=True):\n for kw in list(el_info.call.keywords):\n if kw.arg == \"key\":\n break\n else:\n continue\n\n if isinstance(el_info.props, ast.Dict):\n el_info.props.keys.append(ast.Constant(\"key\"))\n el_info.props.values.append(kw.value)\n else:\n el_info.props.keywords.append(ast.keyword(arg=\"key\", value=kw.value))\n\n el_info.call.keywords.remove(kw)\n changed.append(ChangedNode(el_info.call, el_info.parents))\n\n return changed\n\n\ndef log_could_not_rewrite(file: Path, tree: ast.AST) -> None:\n for node in ast.walk(tree):\n if not (isinstance(node, ast.Call) and node.keywords):\n continue\n\n func = node.func\n if isinstance(func, ast.Attribute):\n name = func.attr\n elif isinstance(func, ast.Name):\n name = func.id\n else:\n continue\n\n if (\n name == \"vdom\"\n or hasattr(html, name)\n and any(kw.arg == \"key\" for kw in node.keywords)\n ):\n click.echo(f\"Unable to rewrite usage at {file}:{node.lineno}\")\n\n# Path: src/py/reactpy/reactpy/__main__.py\nimport click\n\nimport reactpy\nfrom reactpy._console.rewrite_camel_case_props import rewrite_camel_case_props\nfrom reactpy._console.rewrite_keys import rewrite_keys\n\n\n@click.group()\n@click.version_option(reactpy.__version__, prog_name=reactpy.__name__)\ndef app() -> None:\n pass\n\n\napp.add_command(rewrite_keys)\napp.add_command(rewrite_camel_case_props)\n\n\nif __name__ == \"__main__\":\n app()\n\n# Path: src/py/reactpy/reactpy/backend/_common.py\nfrom __future__ import annotations\n\nimport asyncio\nimport os\nfrom collections.abc import Awaitable, Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path, PurePosixPath\nfrom typing import TYPE_CHECKING, Any, cast\n\nfrom reactpy import __file__ as _reactpy_file_path\nfrom reactpy import html\nfrom reactpy.config import REACTPY_WEB_MODULES_DIR\nfrom reactpy.core.types import VdomDict\nfrom reactpy.utils import vdom_to_html\n\nif TYPE_CHECKING:\n import uvicorn\n from asgiref.typing import ASGIApplication\n\nPATH_PREFIX = PurePosixPath(\"/_reactpy\")\nMODULES_PATH = PATH_PREFIX / \"modules\"\nASSETS_PATH = PATH_PREFIX / \"assets\"\nSTREAM_PATH = PATH_PREFIX / \"stream\"\nCLIENT_BUILD_DIR = Path(_reactpy_file_path).parent / \"_static\" / \"app\" / \"dist\"\n\n\nasync def serve_with_uvicorn(\n app: ASGIApplication | Any,\n host: str,\n port: int,\n started: asyncio.Event | None,\n) -> None:\n \"\"\"Run a development server for an ASGI application\"\"\"\n import uvicorn\n\n server = uvicorn.Server(\n uvicorn.Config(\n app,\n host=host,\n port=port,\n loop=\"asyncio\",\n )\n )\n server.config.setup_event_loop()\n coros: list[Awaitable[Any]] = [server.serve()]\n\n # If a 
started event is provided, then use it signal based on `server.started`\n if started:\n coros.append(_check_if_started(server, started))\n\n try:\n await asyncio.gather(*coros)\n finally:\n # Since we aren't using the uvicorn's `run()` API, we can't guarantee uvicorn's\n # order of operations. So we need to make sure `shutdown()` always has an initialized\n # list of `self.servers` to use.\n if not hasattr(server, \"servers\"): # nocov\n server.servers = []\n await asyncio.wait_for(server.shutdown(), timeout=3)\n\n\nasync def _check_if_started(server: uvicorn.Server, started: asyncio.Event) -> None:\n while not server.started:\n await asyncio.sleep(0.2)\n started.set()\n\n\ndef safe_client_build_dir_path(path: str) -> Path:\n \"\"\"Prevent path traversal out of :data:`CLIENT_BUILD_DIR`\"\"\"\n return traversal_safe_path(\n CLIENT_BUILD_DIR, *(\"index.html\" if path in {\"\", \"/\"} else path).split(\"/\")\n )\n\n\ndef safe_web_modules_dir_path(path: str) -> Path:\n \"\"\"Prevent path traversal out of :data:`reactpy.config.REACTPY_WEB_MODULES_DIR`\"\"\"\n return traversal_safe_path(REACTPY_WEB_MODULES_DIR.current, *path.split(\"/\"))\n\n\ndef traversal_safe_path(root: str | Path, *unsafe: str | Path) -> Path:\n \"\"\"Raise a ``ValueError`` if the ``unsafe`` path resolves outside the root dir.\"\"\"\n root = os.path.abspath(root)\n\n # Resolve relative paths but not symlinks - symlinks should be ok since their\n # presence and where they point is under the control of the developer.\n path = os.path.abspath(os.path.join(root, *unsafe))\n\n if os.path.commonprefix([root, path]) != root:\n # If the common prefix is not root directory we resolved outside the root dir\n msg = \"Unsafe path\"\n raise ValueError(msg)\n\n return Path(path)\n\n\ndef read_client_index_html(options: CommonOptions) -> str:\n return (\n (CLIENT_BUILD_DIR / \"index.html\")\n .read_text()\n .format(__head__=vdom_head_elements_to_html(options.head))\n )\n\n\ndef vdom_head_elements_to_html(head: Sequence[VdomDict] | VdomDict | str) -> str:\n if isinstance(head, str):\n return head\n elif isinstance(head, dict):\n if head.get(\"tagName\") == \"head\":\n head = cast(VdomDict, {**head, \"tagName\": \"\"})\n return vdom_to_html(head)\n else:\n return vdom_to_html(html._(*head))\n\n\n@dataclass\nclass CommonOptions:\n \"\"\"Options for ReactPy's built-in backed server implementations\"\"\"\n\n head: Sequence[VdomDict] | VdomDict | str = (\n html.title(\"ReactPy\"),\n html.link(\n {\n \"rel\": \"icon\",\n \"href\": \"/_reactpy/assets/reactpy-logo.ico\",\n \"type\": \"image/x-icon\",\n }\n ),\n )\n \"\"\"Add elements to the ```` of the application.\n\n For example, this can be used to customize the title of the page, link extra\n scripts, or load stylesheets.\n \"\"\"\n\n url_prefix: str = \"\"\n \"\"\"The URL prefix where ReactPy resources will be served from\"\"\"\n\n serve_index_route: bool = True\n \"\"\"Automatically generate and serve the index route (``/``)\"\"\"\n\n def __post_init__(self) -> None:\n if self.url_prefix and not self.url_prefix.startswith(\"/\"):\n msg = \"Expected 'url_prefix' to start with '/'\"\n raise ValueError(msg)\n\n# Path: src/py/reactpy/reactpy/core/serve.py\nfrom __future__ import annotations\n\nfrom collections.abc import Awaitable\nfrom logging import getLogger\nfrom typing import Callable\nfrom warnings import warn\n\nfrom anyio import create_task_group\nfrom anyio.abc import TaskGroup\n\nfrom reactpy.config import REACTPY_DEBUG_MODE\nfrom reactpy.core.types import LayoutEventMessage, 
LayoutType, LayoutUpdateMessage\n\nlogger = getLogger(__name__)\n\n\nSendCoroutine = Callable[[LayoutUpdateMessage], Awaitable[None]]\n\"\"\"Send model patches given by a dispatcher\"\"\"\n\nRecvCoroutine = Callable[[], Awaitable[LayoutEventMessage]]\n\"\"\"Called by a dispatcher to return a :class:`reactpy.core.layout.LayoutEventMessage`\n\nThe event will then trigger an :class:`reactpy.core.proto.EventHandlerType` in a layout.\n\"\"\"\n\n\nclass Stop(BaseException):\n \"\"\"Deprecated\n\n Stop serving changes and events\n\n Raising this error will tell dispatchers to gracefully exit. Typically this is\n called by code running inside a layout to tell it to stop rendering.\n \"\"\"\n\n\nasync def serve_layout(\n layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage],\n send: SendCoroutine,\n recv: RecvCoroutine,\n) -> None:\n \"\"\"Run a dispatch loop for a single view instance\"\"\"\n async with layout:\n try:\n async with create_task_group() as task_group:\n task_group.start_soon(_single_outgoing_loop, layout, send)\n task_group.start_soon(_single_incoming_loop, task_group, layout, recv)\n except Stop: # nocov\n warn(\n \"The Stop exception is deprecated and will be removed in a future version\",\n UserWarning,\n stacklevel=1,\n )\n logger.info(f\"Stopped serving {layout}\")\n\n\nasync def _single_outgoing_loop(\n layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage], send: SendCoroutine\n) -> None:\n while True:\n update = await layout.render()\n try:\n await send(update)\n except Exception: # nocov\n if not REACTPY_DEBUG_MODE.current:\n msg = (\n \"Failed to send update. More info may be available \"\n \"if you enabling debug mode by setting \"\n \"`reactpy.config.REACTPY_DEBUG_MODE.current = True`.\"\n )\n logger.error(msg)\n raise\n\n\nasync def _single_incoming_loop(\n task_group: TaskGroup,\n layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage],\n recv: RecvCoroutine,\n) -> None:\n while True:\n # We need to fire and forget here so that we avoid waiting on the completion\n # of this event handler before receiving and running the next one.\n task_group.start_soon(layout.deliver, await recv())\n\n# Path: src/py/reactpy/reactpy/backend/starlette.py\nfrom __future__ import annotations\n\nimport asyncio\nimport json\nimport logging\nfrom collections.abc import Awaitable\nfrom dataclasses import dataclass\nfrom typing import Any, Callable\n\nfrom exceptiongroup import BaseExceptionGroup\nfrom starlette.applications import Starlette\nfrom starlette.middleware.cors import CORSMiddleware\nfrom starlette.requests import Request\nfrom starlette.responses import HTMLResponse\nfrom starlette.staticfiles import StaticFiles\nfrom starlette.websockets import WebSocket, WebSocketDisconnect\n\nfrom reactpy.backend._common import (\n ASSETS_PATH,\n CLIENT_BUILD_DIR,\n MODULES_PATH,\n STREAM_PATH,\n CommonOptions,\n read_client_index_html,\n serve_with_uvicorn,\n)\nfrom reactpy.backend.hooks import ConnectionContext\nfrom reactpy.backend.hooks import use_connection as _use_connection\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.config import REACTPY_WEB_MODULES_DIR\nfrom reactpy.core.layout import Layout\nfrom reactpy.core.serve import RecvCoroutine, SendCoroutine, serve_layout\nfrom reactpy.core.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\n\n# BackendType.Options\n@dataclass\nclass Options(CommonOptions):\n \"\"\"Render server config for :func:`reactpy.backend.starlette.configure`\"\"\"\n\n cors: bool | dict[str, Any] = 
False\n \"\"\"Enable or configure Cross Origin Resource Sharing (CORS)\n\n For more information see docs for ``starlette.middleware.cors.CORSMiddleware``\n \"\"\"\n\n\n# BackendType.configure\ndef configure(\n app: Starlette,\n component: RootComponentConstructor,\n options: Options | None = None,\n) -> None:\n \"\"\"Configure the necessary ReactPy routes on the given app.\n\n Parameters:\n app: An application instance\n component: A component constructor\n options: Options for configuring server behavior\n \"\"\"\n options = options or Options()\n\n # this route should take priority so set up it up first\n _setup_single_view_dispatcher_route(options, app, component)\n\n _setup_common_routes(options, app)\n\n\n# BackendType.create_development_app\ndef create_development_app() -> Starlette:\n \"\"\"Return a :class:`Starlette` app instance in debug mode\"\"\"\n return Starlette(debug=True)\n\n\n# BackendType.serve_development_app\nasync def serve_development_app(\n app: Starlette,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n) -> None:\n \"\"\"Run a development server for starlette\"\"\"\n await serve_with_uvicorn(app, host, port, started)\n\n\ndef use_websocket() -> WebSocket:\n \"\"\"Get the current WebSocket object\"\"\"\n return use_connection().carrier\n\n\ndef use_connection() -> Connection[WebSocket]:\n conn = _use_connection()\n if not isinstance(conn.carrier, WebSocket): # nocov\n msg = f\"Connection has unexpected carrier {conn.carrier}. Are you running with a Flask server?\"\n raise TypeError(msg)\n return conn\n\n\ndef _setup_common_routes(options: Options, app: Starlette) -> None:\n cors_options = options.cors\n if cors_options: # nocov\n cors_params = (\n cors_options if isinstance(cors_options, dict) else {\"allow_origins\": [\"*\"]}\n )\n app.add_middleware(CORSMiddleware, **cors_params)\n\n # This really should be added to the APIRouter, but there's a bug in Starlette\n # BUG: https://github.com/tiangolo/fastapi/issues/1469\n url_prefix = options.url_prefix\n\n app.mount(\n str(MODULES_PATH),\n StaticFiles(directory=REACTPY_WEB_MODULES_DIR.current, check_dir=False),\n )\n app.mount(\n str(ASSETS_PATH),\n StaticFiles(directory=CLIENT_BUILD_DIR / \"assets\", check_dir=False),\n )\n # register this last so it takes least priority\n index_route = _make_index_route(options)\n\n if options.serve_index_route:\n app.add_route(f\"{url_prefix}/\", index_route)\n app.add_route(url_prefix + \"/{path:path}\", index_route)\n\n\ndef _make_index_route(options: Options) -> Callable[[Request], Awaitable[HTMLResponse]]:\n index_html = read_client_index_html(options)\n\n async def serve_index(request: Request) -> HTMLResponse:\n return HTMLResponse(index_html)\n\n return serve_index\n\n\ndef _setup_single_view_dispatcher_route(\n options: Options, app: Starlette, component: RootComponentConstructor\n) -> None:\n async def model_stream(socket: WebSocket) -> None:\n await socket.accept()\n send, recv = _make_send_recv_callbacks(socket)\n\n pathname = \"/\" + socket.scope[\"path_params\"].get(\"path\", \"\")\n pathname = pathname[len(options.url_prefix) :] or \"/\"\n search = socket.scope[\"query_string\"].decode()\n\n try:\n await serve_layout(\n Layout(\n ConnectionContext(\n component(),\n value=Connection(\n scope=socket.scope,\n location=Location(pathname, f\"?{search}\" if search else \"\"),\n carrier=socket,\n ),\n )\n ),\n send,\n recv,\n )\n except BaseExceptionGroup as egroup:\n for e in egroup.exceptions:\n if isinstance(e, WebSocketDisconnect):\n 
logger.info(f\"WebSocket disconnect: {e.code}\")\n break\n else: # nocov\n raise\n\n app.add_websocket_route(str(STREAM_PATH), model_stream)\n app.add_websocket_route(f\"{STREAM_PATH}/{{path:path}}\", model_stream)\n\n\ndef _make_send_recv_callbacks(\n socket: WebSocket,\n) -> tuple[SendCoroutine, RecvCoroutine]:\n async def sock_send(value: Any) -> None:\n await socket.send_text(json.dumps(value))\n\n async def sock_recv() -> Any:\n return json.loads(await socket.receive_text())\n\n return sock_send, sock_recv\n\n# Path: src/py/reactpy/reactpy/backend/fastapi.py\nfrom __future__ import annotations\n\nfrom fastapi import FastAPI\n\nfrom reactpy.backend import starlette\n\n# BackendType.Options\nOptions = starlette.Options\n\n# BackendType.configure\nconfigure = starlette.configure\n\n\n# BackendType.create_development_app\ndef create_development_app() -> FastAPI:\n \"\"\"Create a development ``FastAPI`` application instance.\"\"\"\n return FastAPI(debug=True)\n\n\n# BackendType.serve_development_app\nserve_development_app = starlette.serve_development_app\n\nuse_connection = starlette.use_connection\n\nuse_websocket = starlette.use_websocket\n\n# Path: src/py/reactpy/reactpy/backend/flask.py\nfrom __future__ import annotations\n\nimport asyncio\nimport json\nimport logging\nimport os\nfrom asyncio import Queue as AsyncQueue\nfrom dataclasses import dataclass\nfrom queue import Queue as ThreadQueue\nfrom threading import Event as ThreadEvent\nfrom threading import Thread\nfrom typing import Any, Callable, NamedTuple, NoReturn, cast\n\nfrom flask import (\n Blueprint,\n Flask,\n Request,\n copy_current_request_context,\n request,\n send_file,\n)\nfrom flask_cors import CORS\nfrom flask_sock import Sock\nfrom simple_websocket import Server as WebSocket\nfrom werkzeug.serving import BaseWSGIServer, make_server\n\nimport reactpy\nfrom reactpy.backend._common import (\n ASSETS_PATH,\n MODULES_PATH,\n PATH_PREFIX,\n STREAM_PATH,\n CommonOptions,\n read_client_index_html,\n safe_client_build_dir_path,\n safe_web_modules_dir_path,\n)\nfrom reactpy.backend.hooks import ConnectionContext\nfrom reactpy.backend.hooks import use_connection as _use_connection\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.core.serve import serve_layout\nfrom reactpy.core.types import ComponentType, RootComponentConstructor\nfrom reactpy.utils import Ref\n\nlogger = logging.getLogger(__name__)\n\n\n# BackendType.Options\n@dataclass\nclass Options(CommonOptions):\n \"\"\"Render server config for :func:`reactpy.backend.flask.configure`\"\"\"\n\n cors: bool | dict[str, Any] = False\n \"\"\"Enable or configure Cross Origin Resource Sharing (CORS)\n\n For more information see docs for ``flask_cors.CORS``\n \"\"\"\n\n\n# BackendType.configure\ndef configure(\n app: Flask, component: RootComponentConstructor, options: Options | None = None\n) -> None:\n \"\"\"Configure the necessary ReactPy routes on the given app.\n\n Parameters:\n app: An application instance\n component: A component constructor\n options: Options for configuring server behavior\n \"\"\"\n options = options or Options()\n\n api_bp = Blueprint(f\"reactpy_api_{id(app)}\", __name__, url_prefix=str(PATH_PREFIX))\n spa_bp = Blueprint(\n f\"reactpy_spa_{id(app)}\", __name__, url_prefix=options.url_prefix\n )\n\n _setup_single_view_dispatcher_route(api_bp, options, component)\n _setup_common_routes(api_bp, spa_bp, options)\n\n app.register_blueprint(api_bp)\n app.register_blueprint(spa_bp)\n\n\n# BackendType.create_development_app\ndef 
create_development_app() -> Flask:\n \"\"\"Create an application instance for development purposes\"\"\"\n os.environ[\"FLASK_DEBUG\"] = \"true\"\n return Flask(__name__)\n\n\n# BackendType.serve_development_app\nasync def serve_development_app(\n app: Flask,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n) -> None:\n \"\"\"Run a development server for FastAPI\"\"\"\n loop = asyncio.get_running_loop()\n stopped = asyncio.Event()\n\n server: Ref[BaseWSGIServer] = Ref()\n\n def run_server() -> None:\n server.current = make_server(host, port, app, threaded=True)\n if started:\n loop.call_soon_threadsafe(started.set)\n try:\n server.current.serve_forever() # type: ignore\n finally:\n loop.call_soon_threadsafe(stopped.set)\n\n thread = Thread(target=run_server, daemon=True)\n thread.start()\n\n if started:\n await started.wait()\n\n try:\n await stopped.wait()\n finally:\n # we may have exited because this task was cancelled\n server.current.shutdown()\n # the thread should eventually join\n thread.join(timeout=3)\n # just double check it happened\n if thread.is_alive(): # nocov\n msg = \"Failed to shutdown server.\"\n raise RuntimeError(msg)\n\n\ndef use_websocket() -> WebSocket:\n \"\"\"A handle to the current websocket\"\"\"\n return use_connection().carrier.websocket\n\n\ndef use_request() -> Request:\n \"\"\"Get the current ``Request``\"\"\"\n return use_connection().carrier.request\n\n\ndef use_connection() -> Connection[_FlaskCarrier]:\n \"\"\"Get the current :class:`Connection`\"\"\"\n conn = _use_connection()\n if not isinstance(conn.carrier, _FlaskCarrier): # nocov\n msg = f\"Connection has unexpected carrier {conn.carrier}. Are you running with a Flask server?\"\n raise TypeError(msg)\n return conn\n\n\ndef _setup_common_routes(\n api_blueprint: Blueprint,\n spa_blueprint: Blueprint,\n options: Options,\n) -> None:\n cors_options = options.cors\n if cors_options: # nocov\n cors_params = cors_options if isinstance(cors_options, dict) else {}\n CORS(api_blueprint, **cors_params)\n\n @api_blueprint.route(f\"/{ASSETS_PATH.name}/\")\n def send_assets_dir(path: str = \"\") -> Any:\n return send_file(safe_client_build_dir_path(f\"assets/{path}\"))\n\n @api_blueprint.route(f\"/{MODULES_PATH.name}/\")\n def send_modules_dir(path: str = \"\") -> Any:\n return send_file(safe_web_modules_dir_path(path), mimetype=\"text/javascript\")\n\n index_html = read_client_index_html(options)\n\n if options.serve_index_route:\n\n @spa_blueprint.route(\"/\")\n @spa_blueprint.route(\"/\")\n def send_client_dir(_: str = \"\") -> Any:\n return index_html\n\n\ndef _setup_single_view_dispatcher_route(\n api_blueprint: Blueprint, options: Options, constructor: RootComponentConstructor\n) -> None:\n sock = Sock(api_blueprint)\n\n def model_stream(ws: WebSocket, path: str = \"\") -> None:\n def send(value: Any) -> None:\n ws.send(json.dumps(value))\n\n def recv() -> Any:\n return json.loads(ws.receive())\n\n _dispatch_in_thread(\n ws,\n # remove any url prefix from path\n path[len(options.url_prefix) :],\n constructor(),\n send,\n recv,\n )\n\n sock.route(STREAM_PATH.name, endpoint=\"without_path\")(model_stream)\n sock.route(f\"{STREAM_PATH.name}/\", endpoint=\"with_path\")(model_stream)\n\n\ndef _dispatch_in_thread(\n websocket: WebSocket,\n path: str,\n component: ComponentType,\n send: Callable[[Any], None],\n recv: Callable[[], Any | None],\n) -> NoReturn:\n dispatch_thread_info_created = ThreadEvent()\n dispatch_thread_info_ref: reactpy.Ref[_DispatcherThreadInfo | None] = reactpy.Ref(\n 
None\n )\n\n @copy_current_request_context\n def run_dispatcher() -> None:\n loop = asyncio.new_event_loop()\n asyncio.set_event_loop(loop)\n\n thread_send_queue: ThreadQueue[Any] = ThreadQueue()\n async_recv_queue: AsyncQueue[Any] = AsyncQueue()\n\n async def send_coro(value: Any) -> None:\n thread_send_queue.put(value)\n\n async def main() -> None:\n search = request.query_string.decode()\n await serve_layout(\n reactpy.Layout(\n ConnectionContext(\n component,\n value=Connection(\n scope=request.environ,\n location=Location(\n pathname=f\"/{path}\",\n search=f\"?{search}\" if search else \"\",\n ),\n carrier=_FlaskCarrier(request, websocket),\n ),\n ),\n ),\n send_coro,\n async_recv_queue.get,\n )\n\n main_future = asyncio.ensure_future(main(), loop=loop)\n\n dispatch_thread_info_ref.current = _DispatcherThreadInfo(\n dispatch_loop=loop,\n dispatch_future=main_future,\n thread_send_queue=thread_send_queue,\n async_recv_queue=async_recv_queue,\n )\n dispatch_thread_info_created.set()\n\n loop.run_until_complete(main_future)\n\n...\n", "needle_token_start": 3907, "needle_token_end": 4106, "code_context_ntokens": 16025, "output": ["```python\ndef image(\n format: str,\n value: str | bytes = \"\",\n attributes: dict[str, Any] | None = None,\n) -> VdomDict:\n \"\"\"Utility for constructing an image from a string or bytes\n\n The source value will automatically be encoded to base64\n \"\"\"\n if format == \"svg\":\n format = \"svg+xml\" # noqa: A001\n\n if isinstance(value, str):\n bytes_value = value.encode()\n else:\n bytes_value = value\n\n base64_value = b64encode(bytes_value).decode()\n src = f\"data:image/{format};base64,{base64_value}\"\n\n return {\"tagName\": \"img\", \"attributes\": {\"src\": src, **(attributes or {})}}\n```"]} +{"repo": "reactive-python/reactpy", "name": "use", "language": "python", "path": "src/py/reactpy/reactpy/testing/common.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To create an event handler that is associated with a unique identifier, which can be configured to stop propagation or prevent default actions during event handling.\n2. **Input**: A function that defines the behavior of the event handler, and optional boolean flags to control event propagation and default actions.\n3. **Output**: An event handler object that encapsulates the user-defined function, event control behaviors, and a unique identifier.\n4. **Procedure**: The method generates a unique identifier for the event handler upon object instantiation. 
When called, it wraps the provided function with additional functionalities based on the specified flags for stopping propagation and preventing default actions, and associates it with the unique identifier.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/backend/default.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom logging import getLogger\nfrom sys import exc_info\nfrom typing import Any, NoReturn\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import SUPPORTED_BACKENDS, all_implementations\n...\n# Path: src/py/reactpy/reactpy/testing/logs.py\nfrom __future__ import annotations\n\nimport logging\nimport re\nfrom collections.abc import Iterator\nfrom contextlib import contextmanager\nfrom traceback import format_exception\nfrom typing import Any, NoReturn\n\nfrom reactpy.logging import ROOT_LOGGER\n\n\nclass LogAssertionError(AssertionError):\n \"\"\"An assertion error raised in relation to log messages.\"\"\"\n\n\n@contextmanager\ndef assert_reactpy_did_log(\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> Iterator[None]:\n \"\"\"Assert that ReactPy produced a log matching the described message or error.\n\n Args:\n match_message: Must match a logged message.\n error_type: Checks the type of logged exceptions.\n match_error: Must match an error message.\n \"\"\"\n message_pattern = re.compile(match_message)\n error_pattern = re.compile(match_error)\n\n with capture_reactpy_logs() as log_records:\n try:\n yield None\n except Exception:\n raise\n else:\n for record in list(log_records):\n if (\n # record message matches\n message_pattern.findall(record.getMessage())\n # error type matches\n and (\n error_type is None\n or (\n record.exc_info is not None\n and record.exc_info[0] is not None\n and issubclass(record.exc_info[0], error_type)\n )\n )\n # error message pattern matches\n and (\n not match_error\n or (\n record.exc_info is not None\n and error_pattern.findall(\n \"\".join(format_exception(*record.exc_info))\n )\n )\n )\n ):\n break\n else: # nocov\n _raise_log_message_error(\n \"Could not find a log record matching the given\",\n match_message,\n error_type,\n match_error,\n )\n\n\n@contextmanager\ndef assert_reactpy_did_not_log(\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> Iterator[None]:\n \"\"\"Assert the inverse of :func:`assert_reactpy_logged`\"\"\"\n try:\n with assert_reactpy_did_log(match_message, error_type, match_error):\n yield None\n except LogAssertionError:\n pass\n else:\n _raise_log_message_error(\n \"Did find a log record matching the given\",\n match_message,\n error_type,\n match_error,\n )\n\n\ndef list_logged_exceptions(\n log_records: list[logging.LogRecord],\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] 
= Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n Args:\n log_level: The level of log to check\n exclude_exc_types: Any exception types to ignore\n del_log_records: Whether to delete the log records for yielded exceptions\n \"\"\"\n found: list[BaseException] = []\n compiled_pattern = re.compile(pattern)\n for index, record in enumerate(log_records):\n if record.levelno >= log_level and record.exc_info:\n error = record.exc_info[1]\n if (\n error is not None\n and isinstance(error, types)\n and compiled_pattern.search(str(error))\n ):\n if del_log_records:\n del log_records[index - len(found)]\n found.append(error)\n return found\n\n\n@contextmanager\ndef capture_reactpy_logs() -> Iterator[list[logging.LogRecord]]:\n \"\"\"Capture logs from ReactPy\n\n Any logs produced in this context are cleared afterwards\n \"\"\"\n original_level = ROOT_LOGGER.level\n ROOT_LOGGER.setLevel(logging.DEBUG)\n try:\n if _LOG_RECORD_CAPTOR in ROOT_LOGGER.handlers:\n start_index = len(_LOG_RECORD_CAPTOR.records)\n try:\n yield _LOG_RECORD_CAPTOR.records\n finally:\n end_index = len(_LOG_RECORD_CAPTOR.records)\n _LOG_RECORD_CAPTOR.records[start_index:end_index] = []\n return None\n\n ROOT_LOGGER.addHandler(_LOG_RECORD_CAPTOR)\n try:\n yield _LOG_RECORD_CAPTOR.records\n finally:\n ROOT_LOGGER.removeHandler(_LOG_RECORD_CAPTOR)\n _LOG_RECORD_CAPTOR.records.clear()\n finally:\n ROOT_LOGGER.setLevel(original_level)\n\n\nclass _LogRecordCaptor(logging.NullHandler):\n def __init__(self) -> None:\n self.records: list[logging.LogRecord] = []\n super().__init__()\n\n def handle(self, record: logging.LogRecord) -> bool:\n self.records.append(record)\n return True\n\n\n_LOG_RECORD_CAPTOR = _LogRecordCaptor()\n\n\ndef _raise_log_message_error(\n prefix: str,\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> NoReturn:\n conditions = []\n if match_message:\n conditions.append(f\"log message pattern {match_message!r}\")\n if error_type:\n conditions.append(f\"exception type {error_type}\")\n if match_error:\n conditions.append(f\"error message pattern {match_error!r}\")\n raise LogAssertionError(prefix + \" \" + \" and \".join(conditions))\n\n# Path: src/py/reactpy/reactpy/testing/backend.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nfrom contextlib import AsyncExitStack, suppress\nfrom types import TracebackType\nfrom typing import Any, Callable\nfrom urllib.parse import urlencode, urlunparse\n\nfrom reactpy.backend import default as default_server\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import find_available_port\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.core.component import component\nfrom reactpy.core.hooks import use_callback, use_effect, use_state\nfrom reactpy.core.types import ComponentConstructor\nfrom reactpy.testing.logs import (\n LogAssertionError,\n capture_reactpy_logs,\n list_logged_exceptions,\n)\nfrom reactpy.utils import Ref\n\n\nclass BackendFixture:\n \"\"\"A test fixture for running a server and imperatively displaying views\n\n This fixture is typically used alongside async web drivers like ``playwight``.\n\n Example:\n .. 
code-block::\n\n async with BackendFixture() as server:\n server.mount(MyComponent)\n \"\"\"\n\n _records: list[logging.LogRecord]\n _server_future: asyncio.Task[Any]\n _exit_stack = AsyncExitStack()\n\n def __init__(\n self,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n app: Any | None = None,\n implementation: BackendType[Any] | None = None,\n options: Any | None = None,\n timeout: float | None = None,\n ) -> None:\n self.host = host\n self.port = port or find_available_port(host)\n self.mount, self._root_component = _hotswap()\n self.timeout = (\n REACTPY_TESTING_DEFAULT_TIMEOUT.current if timeout is None else timeout\n )\n\n if app is not None and implementation is None:\n msg = \"If an application instance its corresponding server implementation must be provided too.\"\n raise ValueError(msg)\n\n self._app = app\n self.implementation = implementation or default_server\n self._options = options\n\n @property\n def log_records(self) -> list[logging.LogRecord]:\n \"\"\"A list of captured log records\"\"\"\n return self._records\n\n def url(self, path: str = \"\", query: Any | None = None) -> str:\n \"\"\"Return a URL string pointing to the host and point of the server\n\n Args:\n path: the path to a resource on the server\n query: a dictionary or list of query parameters\n \"\"\"\n return urlunparse(\n [\n \"http\",\n f\"{self.host}:{self.port}\",\n path,\n \"\",\n urlencode(query or ()),\n \"\",\n ]\n )\n\n def list_logged_exceptions(\n self,\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] = Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n ) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n Args:\n log_level: The level of log to check\n exclude_exc_types: Any exception types to ignore\n del_log_records: Whether to delete the log records for yielded exceptions\n \"\"\"\n return list_logged_exceptions(\n self.log_records,\n pattern,\n types,\n log_level,\n del_log_records,\n )\n\n async def __aenter__(self) -> BackendFixture:\n self._exit_stack = AsyncExitStack()\n self._records = self._exit_stack.enter_context(capture_reactpy_logs())\n\n app = self._app or self.implementation.create_development_app()\n self.implementation.configure(app, self._root_component, self._options)\n\n started = asyncio.Event()\n server_future = asyncio.create_task(\n self.implementation.serve_development_app(\n app, self.host, self.port, started\n )\n )\n\n async def stop_server() -> None:\n server_future.cancel()\n with suppress(asyncio.CancelledError):\n await asyncio.wait_for(server_future, timeout=self.timeout)\n\n self._exit_stack.push_async_callback(stop_server)\n\n try:\n await asyncio.wait_for(started.wait(), timeout=self.timeout)\n except Exception: # nocov\n # see if we can await the future for a more helpful error\n await asyncio.wait_for(server_future, timeout=self.timeout)\n raise\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n await self._exit_stack.aclose()\n\n self.mount(None) # reset the view\n\n logged_errors = self.list_logged_exceptions(del_log_records=False)\n if logged_errors: # nocov\n msg = \"Unexpected logged exception\"\n raise LogAssertionError(msg) from logged_errors[0]\n\n\n_MountFunc = Callable[[\"Callable[[], Any] | None\"], None]\n\n\ndef _hotswap(update_on_change: bool = False) -> tuple[_MountFunc, ComponentConstructor]:\n \"\"\"Swap out components 
from a layout on the fly.\n\n Since you can't change the component functions used to create a layout\n in an imperative manner, you can use ``hotswap`` to do this so\n long as you set things up ahead of time.\n\n Parameters:\n update_on_change: Whether or not all views of the layout should be updated on a swap.\n\n Example:\n .. code-block:: python\n\n import reactpy\n\n show, root = reactpy.hotswap()\n PerClientStateServer(root).run_in_thread(\"localhost\", 8765)\n\n @reactpy.component\n def DivOne(self):\n return {\"tagName\": \"div\", \"children\": [1]}\n\n show(DivOne)\n\n # displaying the output now will show DivOne\n\n @reactpy.component\n def DivTwo(self):\n return {\"tagName\": \"div\", \"children\": [2]}\n\n show(DivTwo)\n\n # displaying the output now will show DivTwo\n \"\"\"\n constructor_ref: Ref[Callable[[], Any]] = Ref(lambda: None)\n\n if update_on_change:\n set_constructor_callbacks: set[Callable[[Callable[[], Any]], None]] = set()\n\n @component\n def HotSwap() -> Any:\n # new displays will adopt the latest constructor and arguments\n constructor, _set_constructor = use_state(lambda: constructor_ref.current)\n set_constructor = use_callback(lambda new: _set_constructor(lambda _: new))\n\n def add_callback() -> Callable[[], None]:\n set_constructor_callbacks.add(set_constructor)\n return lambda: set_constructor_callbacks.remove(set_constructor)\n\n use_effect(add_callback)\n\n return constructor()\n\n def swap(constructor: Callable[[], Any] | None) -> None:\n constructor = constructor_ref.current = constructor or (lambda: None)\n\n for set_constructor in set_constructor_callbacks:\n set_constructor(constructor)\n\n else:\n\n @component\n def HotSwap() -> Any:\n return constructor_ref.current()\n\n def swap(constructor: Callable[[], Any] | None) -> None:\n constructor_ref.current = constructor or (lambda: None)\n\n return swap, HotSwap\n\n# Path: src/py/reactpy/reactpy/testing/common.py\nfrom __future__ import annotations\n\nimport asyncio\nimport inspect\nimport shutil\nimport time\nfrom collections.abc import Awaitable\nfrom functools import wraps\nfrom typing import Any, Callable, Generic, TypeVar, cast\nfrom uuid import uuid4\nfrom weakref import ref\n\nfrom typing_extensions import ParamSpec\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT, REACTPY_WEB_MODULES_DIR\nfrom reactpy.core._life_cycle_hook import LifeCycleHook, current_hook\nfrom reactpy.core.events import EventHandler, to_event_handler_function\n\n\ndef clear_reactpy_web_modules_dir() -> None:\n \"\"\"Clear the directory where ReactPy stores registered web modules\"\"\"\n for path in REACTPY_WEB_MODULES_DIR.current.iterdir():\n shutil.rmtree(path) if path.is_dir() else path.unlink()\n\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\n_DEFAULT_POLL_DELAY = 0.1\n\n\nclass poll(Generic[_R]): # noqa: N801\n \"\"\"Wait until the result of an sync or async function meets some condition\"\"\"\n\n def __init__(\n self,\n function: Callable[_P, Awaitable[_R] | _R],\n *args: _P.args,\n **kwargs: _P.kwargs,\n ) -> None:\n coro: Callable[_P, Awaitable[_R]]\n if not inspect.iscoroutinefunction(function):\n\n async def coro(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n return cast(_R, function(*args, **kwargs))\n\n else:\n coro = cast(Callable[_P, Awaitable[_R]], function)\n self._func = coro\n self._args = args\n self._kwargs = kwargs\n\n async def until(\n self,\n condition: Callable[[_R], bool],\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n 
description: str = \"condition to be true\",\n ) -> None:\n \"\"\"Check that the coroutines result meets a condition within the timeout\"\"\"\n started_at = time.time()\n while True:\n await asyncio.sleep(delay)\n result = await self._func(*self._args, **self._kwargs)\n if condition(result):\n break\n elif (time.time() - started_at) > timeout: # nocov\n msg = f\"Expected {description} after {timeout} seconds - last value was {result!r}\"\n raise asyncio.TimeoutError(msg)\n\n async def until_is(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is identical to the given value\"\"\"\n return await self.until(\n lambda left: left is right,\n timeout,\n delay,\n f\"value to be identical to {right!r}\",\n )\n\n async def until_equals(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is equal to the given value\"\"\"\n return await self.until(\n lambda left: left == right,\n timeout,\n delay,\n f\"value to equal {right!r}\",\n )\n\n\nclass HookCatcher:\n \"\"\"Utility for capturing a LifeCycleHook from a component\n\n Example:\n .. code-block::\n\n hooks = HookCatcher(index_by_kwarg=\"thing\")\n\n @reactpy.component\n @hooks.capture\n def MyComponent(thing):\n ...\n\n ... # render the component\n\n # grab the last render of where MyComponent(thing='something')\n hooks.index[\"something\"]\n # or grab the hook from the component's last render\n hooks.latest\n\n After the first render of ``MyComponent`` the ``HookCatcher`` will have\n captured the component's ``LifeCycleHook``.\n \"\"\"\n\n latest: LifeCycleHook\n\n def __init__(self, index_by_kwarg: str | None = None):\n self.index_by_kwarg = index_by_kwarg\n self.index: dict[Any, LifeCycleHook] = {}\n\n def capture(self, render_function: Callable[..., Any]) -> Callable[..., Any]:\n \"\"\"Decorator for capturing a ``LifeCycleHook`` on each render of a component\"\"\"\n\n # The render function holds a reference to `self` and, via the `LifeCycleHook`,\n # the component. Some tests check whether components are garbage collected, thus\n # we must use a `ref` here to ensure these checks pass once the catcher itself\n # has been collected.\n self_ref = ref(self)\n\n @wraps(render_function)\n def wrapper(*args: Any, **kwargs: Any) -> Any:\n self = self_ref()\n if self is None:\n raise RuntimeError(\"Hook catcher has been garbage collected\")\n\n hook = current_hook()\n if self.index_by_kwarg is not None:\n self.index[kwargs[self.index_by_kwarg]] = hook\n self.latest = hook\n return render_function(*args, **kwargs)\n\n return wrapper\n\n\nclass StaticEventHandler:\n \"\"\"Utility for capturing the target of one event handler\n\n Example:\n .. code-block::\n\n static_handler = StaticEventHandler()\n\n @reactpy.component\n def MyComponent():\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handler.use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # gives the target ID for onClick where from the last render of MyComponent\n static_handlers.target\n\n If you need to capture event handlers from different instances of a component\n the you should create multiple ``StaticEventHandler`` instances.\n\n .. 
code-block::\n\n static_handlers_by_key = {\n \"first\": StaticEventHandler(),\n \"second\": StaticEventHandler(),\n }\n\n @reactpy.component\n def Parent():\n return reactpy.html.div(Child(key=\"first\"), Child(key=\"second\"))\n\n @reactpy.component\n def Child(key):\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handlers_by_key[key].use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # grab the individual targets for each instance above\n first_target = static_handlers_by_key[\"first\"].target\n second_target = static_handlers_by_key[\"second\"].target\n \"\"\"\n\n def __init__(self) -> None:\n self.target = uuid4().hex\n\n \ndef use(\n self,\n function: Callable[..., Any],\n stop_propagation: bool = False,\n prevent_default: bool = False,\n ) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function),\n stop_propagation,\n prevent_default,\n self.target,\n )\n\n# Path: src/py/reactpy/reactpy/testing/display.py\nfrom __future__ import annotations\n\nfrom contextlib import AsyncExitStack\nfrom types import TracebackType\nfrom typing import Any\n\nfrom playwright.async_api import (\n Browser,\n BrowserContext,\n ElementHandle,\n Page,\n async_playwright,\n)\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.types import RootComponentConstructor\n\n\nclass DisplayFixture:\n \"\"\"A fixture for running web-based tests using ``playwright``\"\"\"\n\n _exit_stack: AsyncExitStack\n\n def __init__(\n self,\n backend: BackendFixture | None = None,\n driver: Browser | BrowserContext | Page | None = None,\n url_prefix: str = \"\",\n ) -> None:\n if backend is not None:\n self.backend = backend\n if driver is not None:\n if isinstance(driver, Page):\n self.page = driver\n else:\n self._browser = driver\n self.url_prefix = url_prefix\n\n async def show(\n self,\n component: RootComponentConstructor,\n ) -> None:\n self.backend.mount(component)\n await self.goto(\"/\")\n await self.root_element() # check that root element is attached\n\n async def goto(\n self, path: str, query: Any | None = None, add_url_prefix: bool = True\n ) -> None:\n await self.page.goto(\n self.backend.url(\n f\"{self.url_prefix}{path}\" if add_url_prefix else path, query\n )\n )\n\n async def root_element(self) -> ElementHandle:\n element = await self.page.wait_for_selector(\"#app\", state=\"attached\")\n if element is None: # nocov\n msg = \"Root element not attached\"\n raise RuntimeError(msg)\n return element\n\n async def __aenter__(self) -> DisplayFixture:\n es = self._exit_stack = AsyncExitStack()\n\n browser: Browser | BrowserContext\n if not hasattr(self, \"page\"):\n if not hasattr(self, \"_browser\"):\n pw = await es.enter_async_context(async_playwright())\n browser = await pw.chromium.launch()\n else:\n browser = self._browser\n self.page = await browser.new_page()\n\n self.page.set_default_timeout(REACTPY_TESTING_DEFAULT_TIMEOUT.current * 1000)\n\n if not hasattr(self, \"backend\"):\n self.backend = BackendFixture()\n await es.enter_async_context(self.backend)\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n self.backend.mount(None)\n await self._exit_stack.aclose()\n\n# Path: src/py/reactpy/reactpy/testing/__init__.py\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.testing.common import (\n 
HookCatcher,\n StaticEventHandler,\n clear_reactpy_web_modules_dir,\n poll,\n)\nfrom reactpy.testing.display import DisplayFixture\nfrom reactpy.testing.logs import (\n LogAssertionError,\n assert_reactpy_did_log,\n assert_reactpy_did_not_log,\n capture_reactpy_logs,\n)\n\n__all__ = [\n \"assert_reactpy_did_not_log\",\n \"assert_reactpy_did_log\",\n \"capture_reactpy_logs\",\n \"clear_reactpy_web_modules_dir\",\n \"DisplayFixture\",\n \"HookCatcher\",\n \"LogAssertionError\",\n \"poll\",\n \"BackendFixture\",\n \"StaticEventHandler\",\n]\n\n# Path: src/py/reactpy/reactpy/widgets.py\nfrom __future__ import annotations\n\nfrom base64 import b64encode\nfrom collections.abc import Sequence\nfrom typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar\n\nimport reactpy\nfrom reactpy import html\nfrom reactpy._warnings import warn\nfrom reactpy.core.types import ComponentConstructor, VdomDict\n\n\ndef image(\n format: str,\n value: str | bytes = \"\",\n attributes: dict[str, Any] | None = None,\n) -> VdomDict:\n \"\"\"Utility for constructing an image from a string or bytes\n\n The source value will automatically be encoded to base64\n \"\"\"\n if format == \"svg\":\n format = \"svg+xml\" # noqa: A001\n\n if isinstance(value, str):\n bytes_value = value.encode()\n else:\n bytes_value = value\n\n base64_value = b64encode(bytes_value).decode()\n src = f\"data:image/{format};base64,{base64_value}\"\n\n return {\"tagName\": \"img\", \"attributes\": {\"src\": src, **(attributes or {})}}\n\n\n_Value = TypeVar(\"_Value\")\n\n\ndef use_linked_inputs(\n attributes: Sequence[dict[str, Any]],\n on_change: Callable[[_Value], None] = lambda value: None,\n cast: _CastFunc[_Value] = lambda value: value,\n initial_value: str = \"\",\n ignore_empty: bool = True,\n) -> list[VdomDict]:\n \"\"\"Return a list of linked inputs equal to the number of given attributes.\n\n Parameters:\n attributes:\n That attributes of each returned input element. If the number of generated\n inputs is variable, you may need to assign each one a\n :ref:`key ` by including a ``\"key\"`` in each\n attribute dictionary.\n on_change:\n A callback which is triggered when any input is changed. 
This callback need\n not update the 'value' field in the attributes of the inputs since that is\n handled automatically.\n cast:\n Cast the 'value' of changed inputs that is passed to ``on_change``.\n initial_value:\n Initialize the 'value' field of the inputs.\n ignore_empty:\n Do not trigger ``on_change`` if the 'value' is an empty string.\n \"\"\"\n value, set_value = reactpy.hooks.use_state(initial_value)\n\n def sync_inputs(event: dict[str, Any]) -> None:\n new_value = event[\"target\"][\"value\"]\n set_value(new_value)\n if not new_value and ignore_empty:\n return None\n on_change(cast(new_value))\n\n inputs: list[VdomDict] = []\n for attrs in attributes:\n inputs.append(html.input({**attrs, \"on_change\": sync_inputs, \"value\": value}))\n\n return inputs\n\n\n_CastTo_co = TypeVar(\"_CastTo_co\", covariant=True)\n\n\nclass _CastFunc(Protocol[_CastTo_co]):\n def __call__(self, value: str) -> _CastTo_co: ...\n\n\nif TYPE_CHECKING:\n from reactpy.testing.backend import _MountFunc\n\n\ndef hotswap(\n update_on_change: bool = False,\n) -> tuple[_MountFunc, ComponentConstructor]: # nocov\n warn(\n \"The 'hotswap' function is deprecated and will be removed in a future release\",\n DeprecationWarning,\n stacklevel=2,\n )\n from reactpy.testing.backend import _hotswap\n\n return _hotswap(update_on_change)\n\n# Path: src/py/reactpy/reactpy/__init__.py\nfrom reactpy import backend, config, html, logging, sample, svg, types, web, widgets\nfrom reactpy.backend.hooks import use_connection, use_location, use_scope\nfrom reactpy.backend.utils import run\nfrom reactpy.core import hooks\nfrom reactpy.core.component import component\nfrom reactpy.core.events import event\nfrom reactpy.core.hooks import (\n create_context,\n use_callback,\n use_context,\n use_debug_value,\n use_effect,\n use_memo,\n use_reducer,\n use_ref,\n use_state,\n)\nfrom reactpy.core.layout import Layout\nfrom reactpy.core.vdom import vdom\nfrom reactpy.utils import Ref, html_to_vdom, vdom_to_html\n\n__author__ = \"The Reactive Python Team\"\n__version__ = \"1.0.2\" # DO NOT MODIFY\n\n__all__ = [\n \"backend\",\n \"component\",\n \"config\",\n \"create_context\",\n \"event\",\n \"hooks\",\n \"html_to_vdom\",\n \"html\",\n \"Layout\",\n \"logging\",\n \"Ref\",\n \"run\",\n \"sample\",\n \"Stop\",\n \"svg\",\n \"types\",\n \"use_callback\",\n \"use_connection\",\n \"use_context\",\n \"use_debug_value\",\n \"use_effect\",\n \"use_location\",\n \"use_memo\",\n \"use_reducer\",\n \"use_ref\",\n \"use_scope\",\n \"use_state\",\n \"vdom_to_html\",\n \"vdom\",\n \"web\",\n \"widgets\",\n]\n\n# Path: src/py/reactpy/reactpy/_console/__init__.py\n\n# Path: src/py/reactpy/reactpy/_console/ast_utils.py\nfrom __future__ import annotations\n\nimport ast\nfrom collections.abc import Iterator, Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom textwrap import indent\nfrom tokenize import COMMENT as COMMENT_TOKEN\nfrom tokenize import generate_tokens\nfrom typing import Any\n\nimport click\n\nfrom reactpy import html\n\n\ndef rewrite_changed_nodes(\n file: Path,\n source: str,\n tree: ast.AST,\n changed: list[ChangedNode],\n) -> str:\n ast.fix_missing_locations(tree)\n\n lines = source.split(\"\\n\")\n\n # find closest parent nodes that should be re-written\n nodes_to_unparse: list[ast.AST] = []\n for change in changed:\n node_lineage = [change.node, *change.parents]\n for i in range(len(node_lineage) - 1):\n current_node, next_node = node_lineage[i : i + 2]\n if (\n not hasattr(next_node, \"lineno\")\n or 
next_node.lineno < change.node.lineno\n or isinstance(next_node, (ast.ClassDef, ast.FunctionDef))\n ):\n nodes_to_unparse.append(current_node)\n break\n else: # nocov\n msg = \"Failed to change code\"\n raise RuntimeError(msg)\n\n # check if an nodes to rewrite contain each other, pick outermost nodes\n current_outermost_node, *sorted_nodes_to_unparse = sorted(\n nodes_to_unparse, key=lambda n: n.lineno\n )\n outermost_nodes_to_unparse = [current_outermost_node]\n for node in sorted_nodes_to_unparse:\n if (\n not current_outermost_node.end_lineno\n or node.lineno > current_outermost_node.end_lineno\n ):\n current_outermost_node = node\n outermost_nodes_to_unparse.append(node)\n\n moved_comment_lines_from_end: list[int] = []\n # now actually rewrite these nodes (in reverse to avoid changes earlier in file)\n for node in reversed(outermost_nodes_to_unparse):\n # make a best effort to preserve any comments that we're going to overwrite\n comments = _find_comments(lines[node.lineno - 1 : node.end_lineno])\n\n # there may be some content just before and after the content we're re-writing\n before_replacement = lines[node.lineno - 1][: node.col_offset].lstrip()\n\n after_replacement = (\n lines[node.end_lineno - 1][node.end_col_offset :].strip()\n if node.end_lineno is not None and node.end_col_offset is not None\n else \"\"\n )\n\n replacement = indent(\n before_replacement\n + \"\\n\".join([*comments, ast.unparse(node)])\n + after_replacement,\n \" \" * (node.col_offset - len(before_replacement)),\n )\n\n lines[node.lineno - 1 : node.end_lineno or node.lineno] = [replacement]\n\n if comments:\n moved_comment_lines_from_end.append(len(lines) - node.lineno)\n\n for lineno_from_end in sorted(set(moved_comment_lines_from_end)):\n click.echo(f\"Moved comments to {file}:{len(lines) - lineno_from_end}\")\n\n return \"\\n\".join(lines)\n\n\n@dataclass\nclass ChangedNode:\n node: ast.AST\n parents: Sequence[ast.AST]\n\n\ndef find_element_constructor_usages(\n tree: ast.AST, add_props: bool = False\n) -> Iterator[ElementConstructorInfo]:\n changed: list[Sequence[ast.AST]] = []\n for parents, node in _walk_with_parent(tree):\n if not (isinstance(node, ast.Call)):\n continue\n\n func = node.func\n if isinstance(func, ast.Attribute) and (\n (isinstance(func.value, ast.Name) and func.value.id == \"html\")\n or (isinstance(func.value, ast.Attribute) and func.value.attr == \"html\")\n ):\n name = func.attr\n elif isinstance(func, ast.Name):\n name = func.id\n else:\n continue\n\n maybe_attr_dict_node: Any | None = None\n\n if name == \"vdom\":\n if len(node.args) == 0:\n continue\n elif len(node.args) == 1:\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.append(maybe_attr_dict_node)\n else:\n continue\n elif isinstance(node.args[1], (ast.Constant, ast.JoinedStr)):\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.insert(1, maybe_attr_dict_node)\n else:\n continue\n elif len(node.args) >= 2: # noqa: PLR2004\n maybe_attr_dict_node = node.args[1]\n elif hasattr(html, name):\n if len(node.args) == 0:\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.append(maybe_attr_dict_node)\n else:\n continue\n elif isinstance(node.args[0], (ast.Constant, ast.JoinedStr)):\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.insert(0, maybe_attr_dict_node)\n else:\n continue\n else:\n maybe_attr_dict_node = node.args[0]\n\n if not maybe_attr_dict_node:\n continue\n\n if isinstance(maybe_attr_dict_node, 
ast.Dict) or (\n isinstance(maybe_attr_dict_node, ast.Call)\n and isinstance(maybe_attr_dict_node.func, ast.Name)\n and maybe_attr_dict_node.func.id == \"dict\"\n and isinstance(maybe_attr_dict_node.func.ctx, ast.Load)\n ):\n yield ElementConstructorInfo(node, maybe_attr_dict_node, parents)\n\n return changed\n\n\n@dataclass\nclass ElementConstructorInfo:\n call: ast.Call\n props: ast.Dict | ast.Call\n parents: Sequence[ast.AST]\n\n\ndef _find_comments(lines: list[str]) -> list[str]:\n iter_lines = iter(lines)\n return [\n token\n for token_type, token, _, _, _ in generate_tokens(lambda: next(iter_lines))\n if token_type == COMMENT_TOKEN\n ]\n\n\ndef _walk_with_parent(\n node: ast.AST, parents: tuple[ast.AST, ...] = ()\n) -> Iterator[tuple[tuple[ast.AST, ...], ast.AST]]:\n parents = (node, *parents)\n for child in ast.iter_child_nodes(node):\n yield parents, child\n yield from _walk_with_parent(child, parents)\n\n# Path: src/py/reactpy/reactpy/_console/rewrite_camel_case_props.py\nfrom __future__ import annotations\n\nimport ast\nimport re\nimport sys\nfrom copy import copy\nfrom keyword import kwlist\nfrom pathlib import Path\nfrom typing import Callable\n\nimport click\n\nfrom reactpy._console.ast_utils import (\n ChangedNode,\n find_element_constructor_usages,\n rewrite_changed_nodes,\n)\n\nCAMEL_CASE_SUB_PATTERN = re.compile(r\"(? None:\n \"\"\"Rewrite camelCase props to snake_case\"\"\"\n if sys.version_info < (3, 9): # nocov\n msg = \"This command requires Python>=3.9\"\n raise RuntimeError(msg)\n\n for p in map(Path, paths):\n for f in [p] if p.is_file() else p.rglob(\"*.py\"):\n result = generate_rewrite(file=f, source=f.read_text(encoding=\"utf-8\"))\n if result is not None:\n f.write_text(result)\n\n\ndef generate_rewrite(file: Path, source: str) -> str | None:\n tree = ast.parse(source)\n\n changed = find_nodes_to_change(tree)\n if not changed:\n return None\n\n new = rewrite_changed_nodes(file, source, tree, changed)\n return new\n\n\ndef find_nodes_to_change(tree: ast.AST) -> list[ChangedNode]:\n changed: list[ChangedNode] = []\n for el_info in find_element_constructor_usages(tree):\n if _rewrite_props(el_info.props, _construct_prop_item):\n changed.append(ChangedNode(el_info.call, el_info.parents))\n return changed\n\n\ndef conv_attr_name(name: str) -> str:\n new_name = CAMEL_CASE_SUB_PATTERN.sub(\"_\", name).lower()\n return f\"{new_name}_\" if new_name in kwlist else new_name\n\n\ndef _construct_prop_item(key: str, value: ast.expr) -> tuple[str, ast.expr]:\n if key == \"style\" and isinstance(value, (ast.Dict, ast.Call)):\n new_value = copy(value)\n if _rewrite_props(\n new_value,\n lambda k, v: (\n (k, v)\n # avoid infinite recursion\n if k == \"style\"\n else _construct_prop_item(k, v)\n ),\n ):\n value = new_value\n else:\n key = conv_attr_name(key)\n return key, value\n\n\ndef _rewrite_props(\n props_node: ast.Dict | ast.Call,\n constructor: Callable[[str, ast.expr], tuple[str, ast.expr]],\n) -> bool:\n if isinstance(props_node, ast.Dict):\n did_change = False\n keys: list[ast.expr | None] = []\n values: list[ast.expr] = []\n for k, v in zip(props_node.keys, props_node.values):\n if isinstance(k, ast.Constant) and isinstance(k.value, str):\n k_value, new_v = constructor(k.value, v)\n if k_value != k.value or new_v is not v:\n did_change = True\n k = ast.Constant(value=k_value)\n v = new_v\n keys.append(k)\n values.append(v)\n if not did_change:\n return False\n props_node.keys = keys\n props_node.values = values\n else:\n did_change = False\n keywords: list[ast.keyword] 
= []\n for kw in props_node.keywords:\n if kw.arg is not None:\n kw_arg, kw_value = constructor(kw.arg, kw.value)\n if kw_arg != kw.arg or kw_value is not kw.value:\n did_change = True\n kw = ast.keyword(arg=kw_arg, value=kw_value)\n keywords.append(kw)\n if not did_change:\n return False\n props_node.keywords = keywords\n return True\n\n# Path: src/py/reactpy/reactpy/_console/rewrite_keys.py\nfrom __future__ import annotations\n\nimport ast\nimport sys\nfrom pathlib import Path\n\nimport click\n\nfrom reactpy import html\nfrom reactpy._console.ast_utils import (\n ChangedNode,\n find_element_constructor_usages,\n rewrite_changed_nodes,\n)\n\n\n@click.command()\n@click.argument(\"paths\", nargs=-1, type=click.Path(exists=True))\ndef rewrite_keys(paths: list[str]) -> None:\n \"\"\"Rewrite files under the given paths using the new html element API.\n\n The old API required users to pass a dictionary of attributes to html element\n constructor functions. For example:\n\n >>> html.div({\"className\": \"x\"}, \"y\")\n {\"tagName\": \"div\", \"attributes\": {\"className\": \"x\"}, \"children\": [\"y\"]}\n\n The latest API though allows for attributes to be passed as snake_cased keyword\n arguments instead. The above example would be rewritten as:\n\n >>> html.div(\"y\", class_name=\"x\")\n {\"tagName\": \"div\", \"attributes\": {\"class_name\": \"x\"}, \"children\": [\"y\"]}\n\n All snake_case attributes are converted to camelCase by the client where necessary.\n\n ----- Notes -----\n\n While this command does it's best to preserve as much of the original code as\n possible, there are inevitably some limitations in doing this. As a result, we\n recommend running your code formatter like Black against your code after executing\n this command.\n\n Additionally, We are unable to preserve the location of comments that lie within any\n rewritten code. This command will place the comments in the code it plans to rewrite\n just above its changes. 
As such it requires manual intervention to put those\n comments back in their original location.\n \"\"\"\n if sys.version_info < (3, 9): # nocov\n msg = \"This command requires Python>=3.9\"\n raise RuntimeError(msg)\n\n for p in map(Path, paths):\n for f in [p] if p.is_file() else p.rglob(\"*.py\"):\n result = generate_rewrite(file=f, source=f.read_text(encoding=\"utf-8\"))\n if result is not None:\n f.write_text(result)\n\n\ndef generate_rewrite(file: Path, source: str) -> str | None:\n tree = ast.parse(source)\n\n changed = find_nodes_to_change(tree)\n if not changed:\n log_could_not_rewrite(file, tree)\n return None\n\n new = rewrite_changed_nodes(file, source, tree, changed)\n log_could_not_rewrite(file, ast.parse(new))\n\n return new\n\n\ndef find_nodes_to_change(tree: ast.AST) -> list[ChangedNode]:\n changed: list[ChangedNode] = []\n for el_info in find_element_constructor_usages(tree, add_props=True):\n for kw in list(el_info.call.keywords):\n if kw.arg == \"key\":\n break\n else:\n continue\n\n if isinstance(el_info.props, ast.Dict):\n el_info.props.keys.append(ast.Constant(\"key\"))\n el_info.props.values.append(kw.value)\n else:\n el_info.props.keywords.append(ast.keyword(arg=\"key\", value=kw.value))\n\n el_info.call.keywords.remove(kw)\n changed.append(ChangedNode(el_info.call, el_info.parents))\n\n return changed\n\n\ndef log_could_not_rewrite(file: Path, tree: ast.AST) -> None:\n for node in ast.walk(tree):\n if not (isinstance(node, ast.Call) and node.keywords):\n continue\n\n func = node.func\n if isinstance(func, ast.Attribute):\n name = func.attr\n elif isinstance(func, ast.Name):\n name = func.id\n else:\n continue\n\n if (\n name == \"vdom\"\n or hasattr(html, name)\n and any(kw.arg == \"key\" for kw in node.keywords)\n ):\n click.echo(f\"Unable to rewrite usage at {file}:{node.lineno}\")\n\n# Path: src/py/reactpy/reactpy/__main__.py\nimport click\n\nimport reactpy\nfrom reactpy._console.rewrite_camel_case_props import rewrite_camel_case_props\nfrom reactpy._console.rewrite_keys import rewrite_keys\n\n\n@click.group()\n@click.version_option(reactpy.__version__, prog_name=reactpy.__name__)\ndef app() -> None:\n pass\n\n\napp.add_command(rewrite_keys)\napp.add_command(rewrite_camel_case_props)\n\n\nif __name__ == \"__main__\":\n app()\n\n# Path: src/py/reactpy/reactpy/backend/_common.py\nfrom __future__ import annotations\n\nimport asyncio\nimport os\nfrom collections.abc import Awaitable, Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path, PurePosixPath\nfrom typing import TYPE_CHECKING, Any, cast\n\nfrom reactpy import __file__ as _reactpy_file_path\nfrom reactpy import html\nfrom reactpy.config import REACTPY_WEB_MODULES_DIR\nfrom reactpy.core.types import VdomDict\nfrom reactpy.utils import vdom_to_html\n\nif TYPE_CHECKING:\n import uvicorn\n from asgiref.typing import ASGIApplication\n\nPATH_PREFIX = PurePosixPath(\"/_reactpy\")\nMODULES_PATH = PATH_PREFIX / \"modules\"\nASSETS_PATH = PATH_PREFIX / \"assets\"\nSTREAM_PATH = PATH_PREFIX / \"stream\"\nCLIENT_BUILD_DIR = Path(_reactpy_file_path).parent / \"_static\" / \"app\" / \"dist\"\n\n\nasync def serve_with_uvicorn(\n app: ASGIApplication | Any,\n host: str,\n port: int,\n started: asyncio.Event | None,\n) -> None:\n \"\"\"Run a development server for an ASGI application\"\"\"\n import uvicorn\n\n server = uvicorn.Server(\n uvicorn.Config(\n app,\n host=host,\n port=port,\n loop=\"asyncio\",\n )\n )\n server.config.setup_event_loop()\n coros: list[Awaitable[Any]] = [server.serve()]\n\n # If a 
started event is provided, then use it signal based on `server.started`\n if started:\n coros.append(_check_if_started(server, started))\n\n try:\n await asyncio.gather(*coros)\n finally:\n # Since we aren't using the uvicorn's `run()` API, we can't guarantee uvicorn's\n # order of operations. So we need to make sure `shutdown()` always has an initialized\n # list of `self.servers` to use.\n if not hasattr(server, \"servers\"): # nocov\n server.servers = []\n await asyncio.wait_for(server.shutdown(), timeout=3)\n\n\nasync def _check_if_started(server: uvicorn.Server, started: asyncio.Event) -> None:\n while not server.started:\n await asyncio.sleep(0.2)\n started.set()\n\n\ndef safe_client_build_dir_path(path: str) -> Path:\n \"\"\"Prevent path traversal out of :data:`CLIENT_BUILD_DIR`\"\"\"\n return traversal_safe_path(\n CLIENT_BUILD_DIR, *(\"index.html\" if path in {\"\", \"/\"} else path).split(\"/\")\n )\n\n\ndef safe_web_modules_dir_path(path: str) -> Path:\n \"\"\"Prevent path traversal out of :data:`reactpy.config.REACTPY_WEB_MODULES_DIR`\"\"\"\n return traversal_safe_path(REACTPY_WEB_MODULES_DIR.current, *path.split(\"/\"))\n\n\ndef traversal_safe_path(root: str | Path, *unsafe: str | Path) -> Path:\n \"\"\"Raise a ``ValueError`` if the ``unsafe`` path resolves outside the root dir.\"\"\"\n root = os.path.abspath(root)\n\n # Resolve relative paths but not symlinks - symlinks should be ok since their\n # presence and where they point is under the control of the developer.\n path = os.path.abspath(os.path.join(root, *unsafe))\n\n if os.path.commonprefix([root, path]) != root:\n # If the common prefix is not root directory we resolved outside the root dir\n msg = \"Unsafe path\"\n raise ValueError(msg)\n\n return Path(path)\n\n\ndef read_client_index_html(options: CommonOptions) -> str:\n return (\n (CLIENT_BUILD_DIR / \"index.html\")\n .read_text()\n .format(__head__=vdom_head_elements_to_html(options.head))\n )\n\n\ndef vdom_head_elements_to_html(head: Sequence[VdomDict] | VdomDict | str) -> str:\n if isinstance(head, str):\n return head\n elif isinstance(head, dict):\n if head.get(\"tagName\") == \"head\":\n head = cast(VdomDict, {**head, \"tagName\": \"\"})\n return vdom_to_html(head)\n else:\n return vdom_to_html(html._(*head))\n\n\n@dataclass\nclass CommonOptions:\n \"\"\"Options for ReactPy's built-in backed server implementations\"\"\"\n\n head: Sequence[VdomDict] | VdomDict | str = (\n html.title(\"ReactPy\"),\n html.link(\n {\n \"rel\": \"icon\",\n \"href\": \"/_reactpy/assets/reactpy-logo.ico\",\n \"type\": \"image/x-icon\",\n }\n ),\n )\n \"\"\"Add elements to the ```` of the application.\n\n For example, this can be used to customize the title of the page, link extra\n scripts, or load stylesheets.\n \"\"\"\n\n url_prefix: str = \"\"\n \"\"\"The URL prefix where ReactPy resources will be served from\"\"\"\n\n serve_index_route: bool = True\n \"\"\"Automatically generate and serve the index route (``/``)\"\"\"\n\n def __post_init__(self) -> None:\n if self.url_prefix and not self.url_prefix.startswith(\"/\"):\n msg = \"Expected 'url_prefix' to start with '/'\"\n raise ValueError(msg)\n\n# Path: src/py/reactpy/reactpy/core/serve.py\nfrom __future__ import annotations\n\nfrom collections.abc import Awaitable\nfrom logging import getLogger\nfrom typing import Callable\nfrom warnings import warn\n\nfrom anyio import create_task_group\nfrom anyio.abc import TaskGroup\n\nfrom reactpy.config import REACTPY_DEBUG_MODE\nfrom reactpy.core.types import LayoutEventMessage, 
LayoutType, LayoutUpdateMessage\n\nlogger = getLogger(__name__)\n\n\nSendCoroutine = Callable[[LayoutUpdateMessage], Awaitable[None]]\n\"\"\"Send model patches given by a dispatcher\"\"\"\n\nRecvCoroutine = Callable[[], Awaitable[LayoutEventMessage]]\n\"\"\"Called by a dispatcher to return a :class:`reactpy.core.layout.LayoutEventMessage`\n\nThe event will then trigger an :class:`reactpy.core.proto.EventHandlerType` in a layout.\n\"\"\"\n\n\nclass Stop(BaseException):\n \"\"\"Deprecated\n\n Stop serving changes and events\n\n Raising this error will tell dispatchers to gracefully exit. Typically this is\n called by code running inside a layout to tell it to stop rendering.\n \"\"\"\n\n\nasync def serve_layout(\n layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage],\n send: SendCoroutine,\n recv: RecvCoroutine,\n) -> None:\n \"\"\"Run a dispatch loop for a single view instance\"\"\"\n async with layout:\n try:\n async with create_task_group() as task_group:\n task_group.start_soon(_single_outgoing_loop, layout, send)\n task_group.start_soon(_single_incoming_loop, task_group, layout, recv)\n except Stop: # nocov\n warn(\n \"The Stop exception is deprecated and will be removed in a future version\",\n UserWarning,\n stacklevel=1,\n )\n logger.info(f\"Stopped serving {layout}\")\n\n\nasync def _single_outgoing_loop(\n layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage], send: SendCoroutine\n) -> None:\n while True:\n update = await layout.render()\n try:\n await send(update)\n except Exception: # nocov\n if not REACTPY_DEBUG_MODE.current:\n msg = (\n \"Failed to send update. More info may be available \"\n \"if you enabling debug mode by setting \"\n \"`reactpy.config.REACTPY_DEBUG_MODE.current = True`.\"\n )\n logger.error(msg)\n raise\n\n\nasync def _single_incoming_loop(\n task_group: TaskGroup,\n layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage],\n recv: RecvCoroutine,\n) -> None:\n while True:\n # We need to fire and forget here so that we avoid waiting on the completion\n # of this event handler before receiving and running the next one.\n task_group.start_soon(layout.deliver, await recv())\n\n# Path: src/py/reactpy/reactpy/backend/starlette.py\nfrom __future__ import annotations\n\nimport asyncio\nimport json\nimport logging\nfrom collections.abc import Awaitable\nfrom dataclasses import dataclass\nfrom typing import Any, Callable\n\nfrom exceptiongroup import BaseExceptionGroup\nfrom starlette.applications import Starlette\nfrom starlette.middleware.cors import CORSMiddleware\nfrom starlette.requests import Request\nfrom starlette.responses import HTMLResponse\nfrom starlette.staticfiles import StaticFiles\nfrom starlette.websockets import WebSocket, WebSocketDisconnect\n\nfrom reactpy.backend._common import (\n ASSETS_PATH,\n CLIENT_BUILD_DIR,\n MODULES_PATH,\n STREAM_PATH,\n CommonOptions,\n read_client_index_html,\n serve_with_uvicorn,\n)\nfrom reactpy.backend.hooks import ConnectionContext\nfrom reactpy.backend.hooks import use_connection as _use_connection\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.config import REACTPY_WEB_MODULES_DIR\nfrom reactpy.core.layout import Layout\nfrom reactpy.core.serve import RecvCoroutine, SendCoroutine, serve_layout\nfrom reactpy.core.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\n\n# BackendType.Options\n@dataclass\nclass Options(CommonOptions):\n \"\"\"Render server config for :func:`reactpy.backend.starlette.configure`\"\"\"\n\n cors: bool | dict[str, Any] = 
False\n \"\"\"Enable or configure Cross Origin Resource Sharing (CORS)\n\n For more information see docs for ``starlette.middleware.cors.CORSMiddleware``\n \"\"\"\n\n\n# BackendType.configure\ndef configure(\n app: Starlette,\n component: RootComponentConstructor,\n options: Options | None = None,\n) -> None:\n \"\"\"Configure the necessary ReactPy routes on the given app.\n\n Parameters:\n app: An application instance\n component: A component constructor\n options: Options for configuring server behavior\n \"\"\"\n options = options or Options()\n\n # this route should take priority so set up it up first\n _setup_single_view_dispatcher_route(options, app, component)\n\n _setup_common_routes(options, app)\n\n\n# BackendType.create_development_app\ndef create_development_app() -> Starlette:\n \"\"\"Return a :class:`Starlette` app instance in debug mode\"\"\"\n return Starlette(debug=True)\n\n\n# BackendType.serve_development_app\nasync def serve_development_app(\n app: Starlette,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n) -> None:\n \"\"\"Run a development server for starlette\"\"\"\n await serve_with_uvicorn(app, host, port, started)\n\n\ndef use_websocket() -> WebSocket:\n \"\"\"Get the current WebSocket object\"\"\"\n return use_connection().carrier\n\n\ndef use_connection() -> Connection[WebSocket]:\n conn = _use_connection()\n if not isinstance(conn.carrier, WebSocket): # nocov\n msg = f\"Connection has unexpected carrier {conn.carrier}. Are you running with a Flask server?\"\n raise TypeError(msg)\n return conn\n\n\ndef _setup_common_routes(options: Options, app: Starlette) -> None:\n cors_options = options.cors\n if cors_options: # nocov\n cors_params = (\n cors_options if isinstance(cors_options, dict) else {\"allow_origins\": [\"*\"]}\n )\n app.add_middleware(CORSMiddleware, **cors_params)\n\n # This really should be added to the APIRouter, but there's a bug in Starlette\n # BUG: https://github.com/tiangolo/fastapi/issues/1469\n url_prefix = options.url_prefix\n\n app.mount(\n str(MODULES_PATH),\n StaticFiles(directory=REACTPY_WEB_MODULES_DIR.current, check_dir=False),\n )\n app.mount(\n str(ASSETS_PATH),\n StaticFiles(directory=CLIENT_BUILD_DIR / \"assets\", check_dir=False),\n )\n # register this last so it takes least priority\n index_route = _make_index_route(options)\n\n if options.serve_index_route:\n app.add_route(f\"{url_prefix}/\", index_route)\n app.add_route(url_prefix + \"/{path:path}\", index_route)\n\n\ndef _make_index_route(options: Options) -> Callable[[Request], Awaitable[HTMLResponse]]:\n index_html = read_client_index_html(options)\n\n async def serve_index(request: Request) -> HTMLResponse:\n return HTMLResponse(index_html)\n\n return serve_index\n\n\ndef _setup_single_view_dispatcher_route(\n options: Options, app: Starlette, component: RootComponentConstructor\n) -> None:\n async def model_stream(socket: WebSocket) -> None:\n await socket.accept()\n send, recv = _make_send_recv_callbacks(socket)\n\n pathname = \"/\" + socket.scope[\"path_params\"].get(\"path\", \"\")\n...\n", "needle_token_start": 5565, "needle_token_end": 5652, "code_context_ntokens": 16019, "output": ["```python\ndef use(\n self,\n function: Callable[..., Any],\n stop_propagation: bool = False,\n prevent_default: bool = False,\n ) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function),\n stop_propagation,\n prevent_default,\n self.target,\n )\n```"]} +{"repo": "reactive-python/reactpy", "name": "_equal_files", "language": "python", "path": 
"src/py/reactpy/reactpy/web/module.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: To determine if two specified files are identical.\n2. **Input**: Two file paths.\n3. **Output**: A boolean value indicating whether the files are the same.\n4. **Procedure**: The function first resolves the paths of the input files. If either file is a symbolic link, it checks if the resolved paths are identical. If not, it performs a deep comparison of the files' contents to determine if they are the same.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/html.py\n\"\"\"\n\n**Fragment**\n\n- :func:`_`\n\n**Document metadata**\n\n- :func:`base`\n- :func:`head`\n- :func:`link`\n- :func:`meta`\n- :func:`style`\n- :func:`title`\n\n**Content sectioning**\n\n- :func:`address`\n- :func:`article`\n- :func:`aside`\n- :func:`footer`\n- :func:`header`\n- :func:`h1`\n- :func:`h2`\n- :func:`h3`\n- :func:`h4`\n- :func:`h5`\n- :func:`h6`\n- :func:`main`\n- :func:`nav`\n- :func:`section`\n\n**Text content**\n\n- :func:`blockquote`\n- :func:`dd`\n- :func:`div`\n- :func:`dl`\n- :func:`dt`\n- :func:`figcaption`\n- :func:`figure`\n- :func:`hr`\n- :func:`li`\n- :func:`ol`\n- :func:`p`\n- :func:`pre`\n- :func:`ul`\n\n**Inline text semantics**\n\n- :func:`a`\n- :func:`abbr`\n- :func:`b`\n- :func:`bdi`\n- :func:`bdo`\n- :func:`br`\n- :func:`cite`\n- :func:`code`\n- :func:`data`\n- :func:`em`\n- :func:`i`\n- :func:`kbd`\n- :func:`mark`\n- :func:`q`\n- :func:`rp`\n- :func:`rt`\n- :func:`ruby`\n- :func:`s`\n- :func:`samp`\n- :func:`small`\n- :func:`span`\n- :func:`strong`\n- :func:`sub`\n- :func:`sup`\n- :func:`time`\n- :func:`u`\n- :func:`var`\n- :func:`wbr`\n\n**Image and video**\n\n- :func:`area`\n- :func:`audio`\n- :func:`img`\n- :func:`map`\n- :func:`track`\n- :func:`video`\n\n**Embedded content**\n\n- :func:`embed`\n- :func:`iframe`\n- :func:`object`\n- :func:`param`\n- :func:`picture`\n- :func:`portal`\n- :func:`source`\n\n**SVG and MathML**\n\n- :func:`svg`\n- :func:`math`\n\n**Scripting**\n\n- :func:`canvas`\n- :func:`noscript`\n- :func:`script`\n\n**Demarcating edits**\n\n- :func:`del_`\n- :func:`ins`\n\n**Table content**\n\n- :func:`caption`\n- :func:`col`\n- :func:`colgroup`\n- :func:`table`\n- :func:`tbody`\n- :func:`td`\n- :func:`tfoot`\n- :func:`th`\n- :func:`thead`\n- :func:`tr`\n\n**Forms**\n\n- :func:`button`\n- :func:`fieldset`\n- :func:`form`\n- :func:`input`\n- :func:`label`\n- :func:`legend`\n- :func:`meter`\n- :func:`option`\n- :func:`output`\n- :func:`progress`\n- :func:`select`\n- :func:`textarea`\n\n**Interactive elements**\n\n- :func:`details`\n- :func:`dialog`\n- :func:`menu`\n- :func:`menuitem`\n- :func:`summary`\n\n**Web components**\n\n- :func:`slot`\n- :func:`template`\n\n.. 
autofunction:: _\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\n...\n# Path: src/py/reactpy/reactpy/logging.py\nimport logging\nimport sys\nfrom logging.config import dictConfig\n\nfrom reactpy.config import REACTPY_DEBUG_MODE\n\ndictConfig(\n {\n \"version\": 1,\n \"disable_existing_loggers\": False,\n \"loggers\": {\n \"reactpy\": {\"handlers\": [\"console\"]},\n },\n \"handlers\": {\n \"console\": {\n \"class\": \"logging.StreamHandler\",\n \"formatter\": \"generic\",\n \"stream\": sys.stdout,\n }\n },\n \"formatters\": {\n \"generic\": {\n \"format\": \"%(asctime)s | %(log_color)s%(levelname)s%(reset)s | %(message)s\",\n \"datefmt\": r\"%Y-%m-%dT%H:%M:%S%z\",\n \"class\": \"colorlog.ColoredFormatter\",\n }\n },\n }\n)\n\n\nROOT_LOGGER = logging.getLogger(\"reactpy\")\n\"\"\"ReactPy's root logger instance\"\"\"\n\n\n@REACTPY_DEBUG_MODE.subscribe\ndef _set_debug_level(debug: bool) -> None:\n if debug:\n ROOT_LOGGER.setLevel(\"DEBUG\")\n ROOT_LOGGER.debug(\"ReactPy is in debug mode\")\n else:\n ROOT_LOGGER.setLevel(\"INFO\")\n\n# Path: src/py/reactpy/reactpy/sample.py\nfrom __future__ import annotations\n\nfrom reactpy import html\nfrom reactpy.core.component import component\nfrom reactpy.core.types import VdomDict\n\n\n@component\ndef SampleApp() -> VdomDict:\n return html.div(\n {\"id\": \"sample\", \"style\": {\"padding\": \"15px\"}},\n html.h1(\"Sample Application\"),\n html.p(\n \"This is a basic application made with ReactPy. Click \",\n html.a(\n {\"href\": \"https://pypi.org/project/reactpy/\", \"target\": \"_blank\"},\n \"here\",\n ),\n \" to learn more.\",\n ),\n )\n\n# Path: src/py/reactpy/reactpy/svg.py\nfrom reactpy.core.vdom import make_vdom_constructor\n\n__all__ = (\n \"a\",\n \"animate\",\n \"animate_motion\",\n \"animate_transform\",\n \"circle\",\n \"clip_path\",\n \"defs\",\n \"desc\",\n \"discard\",\n \"ellipse\",\n \"fe_blend\",\n \"fe_color_matrix\",\n \"fe_component_transfer\",\n \"fe_composite\",\n \"fe_convolve_matrix\",\n \"fe_diffuse_lighting\",\n \"fe_displacement_map\",\n \"fe_distant_light\",\n \"fe_drop_shadow\",\n \"fe_flood\",\n \"fe_func_a\",\n \"fe_func_b\",\n \"fe_func_g\",\n \"fe_func_r\",\n \"fe_gaussian_blur\",\n \"fe_image\",\n \"fe_merge\",\n \"fe_merge_node\",\n \"fe_morphology\",\n \"fe_offset\",\n \"fe_point_light\",\n \"fe_specular_lighting\",\n \"fe_spot_light\",\n \"fe_tile\",\n \"fe_turbulence\",\n \"filter\",\n \"foreign_object\",\n \"g\",\n \"hatch\",\n \"hatchpath\",\n \"image\",\n \"line\",\n \"linear_gradient\",\n \"marker\",\n \"mask\",\n \"metadata\",\n \"mpath\",\n \"path\",\n \"pattern\",\n \"polygon\",\n \"polyline\",\n \"radial_gradient\",\n \"rect\",\n \"script\",\n \"set\",\n \"stop\",\n \"style\",\n \"svg\",\n \"switch\",\n \"symbol\",\n \"text\",\n \"text_path\",\n \"title\",\n \"tspan\",\n \"use\",\n \"view\",\n)\n\na = make_vdom_constructor(\"a\")\nanimate = make_vdom_constructor(\"animate\", allow_children=False)\nanimate_motion = make_vdom_constructor(\"animateMotion\", allow_children=False)\nanimate_transform = make_vdom_constructor(\"animateTransform\", allow_children=False)\ncircle = make_vdom_constructor(\"circle\", allow_children=False)\nclip_path = make_vdom_constructor(\"clipPath\")\ndefs = make_vdom_constructor(\"defs\")\ndesc = make_vdom_constructor(\"desc\", allow_children=False)\ndiscard = make_vdom_constructor(\"discard\", allow_children=False)\nellipse = make_vdom_constructor(\"ellipse\", allow_children=False)\nfe_blend = make_vdom_constructor(\"feBlend\", 
allow_children=False)\nfe_color_matrix = make_vdom_constructor(\"feColorMatrix\", allow_children=False)\nfe_component_transfer = make_vdom_constructor(\n \"feComponentTransfer\", allow_children=False\n)\nfe_composite = make_vdom_constructor(\"feComposite\", allow_children=False)\nfe_convolve_matrix = make_vdom_constructor(\"feConvolveMatrix\", allow_children=False)\nfe_diffuse_lighting = make_vdom_constructor(\"feDiffuseLighting\", allow_children=False)\nfe_displacement_map = make_vdom_constructor(\"feDisplacementMap\", allow_children=False)\nfe_distant_light = make_vdom_constructor(\"feDistantLight\", allow_children=False)\nfe_drop_shadow = make_vdom_constructor(\"feDropShadow\", allow_children=False)\nfe_flood = make_vdom_constructor(\"feFlood\", allow_children=False)\nfe_func_a = make_vdom_constructor(\"feFuncA\", allow_children=False)\nfe_func_b = make_vdom_constructor(\"feFuncB\", allow_children=False)\nfe_func_g = make_vdom_constructor(\"feFuncG\", allow_children=False)\nfe_func_r = make_vdom_constructor(\"feFuncR\", allow_children=False)\nfe_gaussian_blur = make_vdom_constructor(\"feGaussianBlur\", allow_children=False)\nfe_image = make_vdom_constructor(\"feImage\", allow_children=False)\nfe_merge = make_vdom_constructor(\"feMerge\", allow_children=False)\nfe_merge_node = make_vdom_constructor(\"feMergeNode\", allow_children=False)\nfe_morphology = make_vdom_constructor(\"feMorphology\", allow_children=False)\nfe_offset = make_vdom_constructor(\"feOffset\", allow_children=False)\nfe_point_light = make_vdom_constructor(\"fePointLight\", allow_children=False)\nfe_specular_lighting = make_vdom_constructor(\"feSpecularLighting\", allow_children=False)\nfe_spot_light = make_vdom_constructor(\"feSpotLight\", allow_children=False)\nfe_tile = make_vdom_constructor(\"feTile\", allow_children=False)\nfe_turbulence = make_vdom_constructor(\"feTurbulence\", allow_children=False)\nfilter = make_vdom_constructor(\"filter\", allow_children=False) # noqa: A001\nforeign_object = make_vdom_constructor(\"foreignObject\", allow_children=False)\ng = make_vdom_constructor(\"g\")\nhatch = make_vdom_constructor(\"hatch\", allow_children=False)\nhatchpath = make_vdom_constructor(\"hatchpath\", allow_children=False)\nimage = make_vdom_constructor(\"image\", allow_children=False)\nline = make_vdom_constructor(\"line\", allow_children=False)\nlinear_gradient = make_vdom_constructor(\"linearGradient\", allow_children=False)\nmarker = make_vdom_constructor(\"marker\")\nmask = make_vdom_constructor(\"mask\")\nmetadata = make_vdom_constructor(\"metadata\", allow_children=False)\nmpath = make_vdom_constructor(\"mpath\", allow_children=False)\npath = make_vdom_constructor(\"path\", allow_children=False)\npattern = make_vdom_constructor(\"pattern\")\npolygon = make_vdom_constructor(\"polygon\", allow_children=False)\npolyline = make_vdom_constructor(\"polyline\", allow_children=False)\nradial_gradient = make_vdom_constructor(\"radialGradient\", allow_children=False)\nrect = make_vdom_constructor(\"rect\", allow_children=False)\nscript = make_vdom_constructor(\"script\", allow_children=False)\nset = make_vdom_constructor(\"set\", allow_children=False) # noqa: A001\nstop = make_vdom_constructor(\"stop\", allow_children=False)\nstyle = make_vdom_constructor(\"style\", allow_children=False)\nsvg = make_vdom_constructor(\"svg\")\nswitch = make_vdom_constructor(\"switch\")\nsymbol = make_vdom_constructor(\"symbol\")\ntext = make_vdom_constructor(\"text\", allow_children=False)\ntext_path = 
make_vdom_constructor(\"textPath\", allow_children=False)\ntitle = make_vdom_constructor(\"title\", allow_children=False)\ntspan = make_vdom_constructor(\"tspan\", allow_children=False)\nuse = make_vdom_constructor(\"use\", allow_children=False)\nview = make_vdom_constructor(\"view\", allow_children=False)\n\n# Path: src/py/reactpy/reactpy/web/utils.py\nimport logging\nimport re\nfrom pathlib import Path, PurePosixPath\nfrom urllib.parse import urlparse, urlunparse\n\nimport requests\n\nlogger = logging.getLogger(__name__)\n\n\ndef module_name_suffix(name: str) -> str:\n if name.startswith(\"@\"):\n name = name[1:]\n head, _, tail = name.partition(\"@\") # handle version identifier\n version, _, tail = tail.partition(\"/\") # get section after version\n return PurePosixPath(tail or head).suffix or \".js\"\n\n\ndef resolve_module_exports_from_file(\n file: Path,\n max_depth: int,\n is_re_export: bool = False,\n) -> set[str]:\n if max_depth == 0:\n logger.warning(f\"Did not resolve all exports for {file} - max depth reached\")\n return set()\n elif not file.exists():\n logger.warning(f\"Did not resolve exports for unknown file {file}\")\n return set()\n\n export_names, references = resolve_module_exports_from_source(\n file.read_text(encoding=\"utf-8\"), exclude_default=is_re_export\n )\n\n for ref in references:\n if urlparse(ref).scheme: # is an absolute URL\n export_names.update(\n resolve_module_exports_from_url(ref, max_depth - 1, is_re_export=True)\n )\n else:\n path = file.parent.joinpath(*ref.split(\"/\"))\n export_names.update(\n resolve_module_exports_from_file(path, max_depth - 1, is_re_export=True)\n )\n\n return export_names\n\n\ndef resolve_module_exports_from_url(\n url: str,\n max_depth: int,\n is_re_export: bool = False,\n) -> set[str]:\n if max_depth == 0:\n logger.warning(f\"Did not resolve all exports for {url} - max depth reached\")\n return set()\n\n try:\n text = requests.get(url, timeout=5).text\n except requests.exceptions.ConnectionError as error:\n reason = \"\" if error is None else \" - {error.errno}\"\n logger.warning(\"Did not resolve exports for url \" + url + reason)\n return set()\n\n export_names, references = resolve_module_exports_from_source(\n text, exclude_default=is_re_export\n )\n\n for ref in references:\n url = _resolve_relative_url(url, ref)\n export_names.update(\n resolve_module_exports_from_url(url, max_depth - 1, is_re_export=True)\n )\n\n return export_names\n\n\ndef resolve_module_exports_from_source(\n content: str, exclude_default: bool\n) -> tuple[set[str], set[str]]:\n names: set[str] = set()\n references: set[str] = set()\n\n if _JS_DEFAULT_EXPORT_PATTERN.search(content):\n names.add(\"default\")\n\n # Exporting functions and classes\n names.update(_JS_FUNC_OR_CLS_EXPORT_PATTERN.findall(content))\n\n for export in _JS_GENERAL_EXPORT_PATTERN.findall(content):\n export = export.rstrip(\";\").strip()\n # Exporting individual features\n if export.startswith(\"let \"):\n names.update(let.split(\"=\", 1)[0] for let in export[4:].split(\",\"))\n # Renaming exports and export list\n elif export.startswith(\"{\") and export.endswith(\"}\"):\n names.update(\n item.split(\" as \", 1)[-1] for item in export.strip(\"{}\").split(\",\")\n )\n # Exporting destructured assignments with renaming\n elif export.startswith(\"const \"):\n names.update(\n item.split(\":\", 1)[0]\n for item in export[6:].split(\"=\", 1)[0].strip(\"{}\").split(\",\")\n )\n # Default exports\n elif export.startswith(\"default \"):\n names.add(\"default\")\n # Aggregating 
modules\n elif export.startswith(\"* as \"):\n names.add(export[5:].split(\" from \", 1)[0])\n elif export.startswith(\"* \"):\n references.add(export[2:].split(\"from \", 1)[-1].strip(\"'\\\"\"))\n elif export.startswith(\"{\") and \" from \" in export:\n names.update(\n item.split(\" as \", 1)[-1]\n for item in export.split(\" from \")[0].strip(\"{}\").split(\",\")\n )\n elif not (export.startswith(\"function \") or export.startswith(\"class \")):\n logger.warning(f\"Unknown export type {export!r}\")\n\n names = {n.strip() for n in names}\n references = {r.strip() for r in references}\n\n if exclude_default and \"default\" in names:\n names.remove(\"default\")\n\n return names, references\n\n\ndef _resolve_relative_url(base_url: str, rel_url: str) -> str:\n if not rel_url.startswith(\".\"):\n if rel_url.startswith(\"/\"):\n # copy scheme and hostname from base_url\n return urlunparse(urlparse(base_url)[:2] + urlparse(rel_url)[2:])\n else:\n return rel_url\n\n base_url = base_url.rsplit(\"/\", 1)[0]\n\n if rel_url.startswith(\"./\"):\n return base_url + rel_url[1:]\n\n while rel_url.startswith(\"../\"):\n base_url = base_url.rsplit(\"/\", 1)[0]\n rel_url = rel_url[3:]\n\n return f\"{base_url}/{rel_url}\"\n\n\n_JS_DEFAULT_EXPORT_PATTERN = re.compile(\n r\";?\\s*export\\s+default\\s\",\n)\n_JS_FUNC_OR_CLS_EXPORT_PATTERN = re.compile(\n r\";?\\s*export\\s+(?:function|class)\\s+([a-zA-Z_$][0-9a-zA-Z_$]*)\"\n)\n_JS_GENERAL_EXPORT_PATTERN = re.compile(\n r\"(?:^|;|})\\s*export(?=\\s+|{)(.*?)(?=;|$)\", re.MULTILINE\n)\n\n# Path: src/py/reactpy/reactpy/web/module.py\nfrom __future__ import annotations\n\nimport filecmp\nimport logging\nimport shutil\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom string import Template\nfrom typing import Any, NewType, overload\nfrom urllib.parse import urlparse\n\nfrom reactpy._warnings import warn\nfrom reactpy.config import REACTPY_DEBUG_MODE, REACTPY_WEB_MODULES_DIR\nfrom reactpy.core.types import ImportSourceDict, VdomDictConstructor\nfrom reactpy.core.vdom import make_vdom_constructor\nfrom reactpy.web.utils import (\n module_name_suffix,\n resolve_module_exports_from_file,\n resolve_module_exports_from_url,\n)\n\nlogger = logging.getLogger(__name__)\n\nSourceType = NewType(\"SourceType\", str)\n\nNAME_SOURCE = SourceType(\"NAME\")\n\"\"\"A named source - usually a Javascript package name\"\"\"\n\nURL_SOURCE = SourceType(\"URL\")\n\"\"\"A source loaded from a URL, usually a CDN\"\"\"\n\n\ndef module_from_url(\n url: str,\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n) -> WebModule:\n \"\"\"Load a :class:`WebModule` from a :data:`URL_SOURCE`\n\n Parameters:\n url:\n Where the javascript module will be loaded from which conforms to the\n interface for :ref:`Custom Javascript Components`\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. 
See :issue:`461` for more info.\n \"\"\"\n return WebModule(\n source=url,\n source_type=URL_SOURCE,\n default_fallback=fallback,\n file=None,\n export_names=(\n resolve_module_exports_from_url(url, resolve_exports_depth)\n if (\n resolve_exports\n if resolve_exports is not None\n else REACTPY_DEBUG_MODE.current\n )\n else None\n ),\n unmount_before_update=unmount_before_update,\n )\n\n\n_FROM_TEMPLATE_DIR = \"__from_template__\"\n\n\ndef module_from_template(\n template: str,\n package: str,\n cdn: str = \"https://esm.sh\",\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n) -> WebModule:\n \"\"\"Create a :class:`WebModule` from a framework template\n\n This is useful for experimenting with component libraries that do not already\n support ReactPy's :ref:`Custom Javascript Component` interface.\n\n .. warning::\n\n This approach is not recommended for use in a production setting because the\n framework templates may use unpinned dependencies that could change without\n warning. It's best to author a module adhering to the\n :ref:`Custom Javascript Component` interface instead.\n\n **Templates**\n\n - ``react``: for modules exporting React components\n\n Parameters:\n template:\n The name of the framework template to use with the given ``package``.\n package:\n The name of a package to load. May include a file extension (defaults to\n ``.js`` if not given)\n cdn:\n Where the package should be loaded from. The CDN must distribute ESM modules\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. See :issue:`461` for more info.\n \"\"\"\n warn(\n \"module_from_template() is deprecated due to instability - use the Javascript \"\n \"Components API instead. This function will be removed in a future release.\",\n DeprecationWarning,\n )\n template_name, _, template_version = template.partition(\"@\")\n template_version = \"@\" + template_version if template_version else \"\"\n\n # We do this since the package may be any valid URL path. 
Thus we may need to strip\n # object parameters or query information so we save the resulting template under the\n # correct file name.\n package_name = urlparse(package).path\n\n # downstream code assumes no trailing slash\n cdn = cdn.rstrip(\"/\")\n\n template_file_name = template_name + module_name_suffix(package_name)\n\n template_file = Path(__file__).parent / \"templates\" / template_file_name\n if not template_file.exists():\n msg = f\"No template for {template_file_name!r} exists\"\n raise ValueError(msg)\n\n variables = {\"PACKAGE\": package, \"CDN\": cdn, \"VERSION\": template_version}\n content = Template(template_file.read_text(encoding=\"utf-8\")).substitute(variables)\n\n return module_from_string(\n _FROM_TEMPLATE_DIR + \"/\" + package_name,\n content,\n fallback,\n resolve_exports,\n resolve_exports_depth,\n unmount_before_update=unmount_before_update,\n )\n\n\ndef module_from_file(\n name: str,\n file: str | Path,\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n symlink: bool = False,\n) -> WebModule:\n \"\"\"Load a :class:`WebModule` from a given ``file``\n\n Parameters:\n name:\n The name of the package\n file:\n The file from which the content of the web module will be created.\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. 
See :issue:`461` for more info.\n symlink:\n Whether the web module should be saved as a symlink to the given ``file``.\n \"\"\"\n name += module_name_suffix(name)\n\n source_file = Path(file).resolve()\n target_file = _web_module_path(name)\n if not source_file.exists():\n msg = f\"Source file does not exist: {source_file}\"\n raise FileNotFoundError(msg)\n\n if not target_file.exists():\n _copy_file(target_file, source_file, symlink)\n elif not _equal_files(source_file, target_file):\n logger.info(\n f\"Existing web module {name!r} will \"\n f\"be replaced with {target_file.resolve()}\"\n )\n target_file.unlink()\n _copy_file(target_file, source_file, symlink)\n\n return WebModule(\n source=name,\n source_type=NAME_SOURCE,\n default_fallback=fallback,\n file=target_file,\n export_names=(\n resolve_module_exports_from_file(source_file, resolve_exports_depth)\n if (\n resolve_exports\n if resolve_exports is not None\n else REACTPY_DEBUG_MODE.current\n )\n else None\n ),\n unmount_before_update=unmount_before_update,\n )\n\n\n\ndef _equal_files(f1: Path, f2: Path) -> bool:\n f1 = f1.resolve()\n f2 = f2.resolve()\n return (\n (f1.is_symlink() or f2.is_symlink()) and (f1.resolve() == f2.resolve())\n ) or filecmp.cmp(str(f1), str(f2), shallow=False)\n\n\ndef _copy_file(target: Path, source: Path, symlink: bool) -> None:\n target.parent.mkdir(parents=True, exist_ok=True)\n if symlink:\n target.symlink_to(source)\n else:\n shutil.copy(source, target)\n\n\ndef module_from_string(\n name: str,\n content: str,\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n) -> WebModule:\n \"\"\"Load a :class:`WebModule` whose ``content`` comes from a string.\n\n Parameters:\n name:\n The name of the package\n content:\n The contents of the web module\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. 
See :issue:`461` for more info.\n \"\"\"\n name += module_name_suffix(name)\n\n target_file = _web_module_path(name)\n\n if target_file.exists() and target_file.read_text(encoding=\"utf-8\") != content:\n logger.info(\n f\"Existing web module {name!r} will \"\n f\"be replaced with {target_file.resolve()}\"\n )\n target_file.unlink()\n\n target_file.parent.mkdir(parents=True, exist_ok=True)\n target_file.write_text(content)\n\n return WebModule(\n source=name,\n source_type=NAME_SOURCE,\n default_fallback=fallback,\n file=target_file,\n export_names=(\n resolve_module_exports_from_file(target_file, resolve_exports_depth)\n if (\n resolve_exports\n if resolve_exports is not None\n else REACTPY_DEBUG_MODE.current\n )\n else None\n ),\n unmount_before_update=unmount_before_update,\n )\n\n\n@dataclass(frozen=True)\nclass WebModule:\n source: str\n source_type: SourceType\n default_fallback: Any | None\n export_names: set[str] | None\n file: Path | None\n unmount_before_update: bool\n\n\n@overload\ndef export(\n web_module: WebModule,\n export_names: str,\n fallback: Any | None = ...,\n allow_children: bool = ...,\n) -> VdomDictConstructor: ...\n\n\n@overload\ndef export(\n web_module: WebModule,\n export_names: list[str] | tuple[str, ...],\n fallback: Any | None = ...,\n allow_children: bool = ...,\n) -> list[VdomDictConstructor]: ...\n\n\ndef export(\n web_module: WebModule,\n export_names: str | list[str] | tuple[str, ...],\n fallback: Any | None = None,\n allow_children: bool = True,\n) -> VdomDictConstructor | list[VdomDictConstructor]:\n \"\"\"Return one or more VDOM constructors from a :class:`WebModule`\n\n Parameters:\n export_names:\n One or more names to export. If given as a string, a single component\n will be returned. If a list is given, then a list of components will be\n returned.\n fallback:\n What to temporarily display while the module is being loaded.\n allow_children:\n Whether or not these components can have children.\n \"\"\"\n if isinstance(export_names, str):\n if (\n web_module.export_names is not None\n and export_names not in web_module.export_names\n ):\n msg = f\"{web_module.source!r} does not export {export_names!r}\"\n raise ValueError(msg)\n return _make_export(web_module, export_names, fallback, allow_children)\n else:\n if web_module.export_names is not None:\n missing = sorted(set(export_names).difference(web_module.export_names))\n if missing:\n msg = f\"{web_module.source!r} does not export {missing!r}\"\n raise ValueError(msg)\n return [\n _make_export(web_module, name, fallback, allow_children)\n for name in export_names\n ]\n\n\ndef _make_export(\n web_module: WebModule,\n name: str,\n fallback: Any | None,\n allow_children: bool,\n) -> VdomDictConstructor:\n return make_vdom_constructor(\n name,\n allow_children=allow_children,\n import_source=ImportSourceDict(\n source=web_module.source,\n sourceType=web_module.source_type,\n fallback=(fallback or web_module.default_fallback),\n unmountBeforeUpdate=web_module.unmount_before_update,\n ),\n )\n\n\ndef _web_module_path(name: str) -> Path:\n directory = REACTPY_WEB_MODULES_DIR.current\n path = directory.joinpath(*name.split(\"/\"))\n return path.with_suffix(path.suffix)\n\n# Path: src/py/reactpy/reactpy/web/__init__.py\nfrom reactpy.web.module import (\n export,\n module_from_file,\n module_from_string,\n module_from_template,\n module_from_url,\n)\n\n__all__ = [\n \"module_from_file\",\n \"module_from_string\",\n \"module_from_template\",\n \"module_from_url\",\n \"export\",\n]\n\n# Path: 
src/py/reactpy/reactpy/backend/default.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom logging import getLogger\nfrom sys import exc_info\nfrom typing import Any, NoReturn\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import SUPPORTED_BACKENDS, all_implementations\nfrom reactpy.types import RootComponentConstructor\n\nlogger = getLogger(__name__)\n_DEFAULT_IMPLEMENTATION: BackendType[Any] | None = None\n\n\n# BackendType.Options\nclass Options: # nocov\n \"\"\"Configuration options that can be provided to the backend.\n This definition should not be used/instantiated. It exists only for\n type hinting purposes.\"\"\"\n\n def __init__(self, *args: Any, **kwds: Any) -> NoReturn:\n msg = \"Default implementation has no options.\"\n raise ValueError(msg)\n\n\n# BackendType.configure\ndef configure(\n app: Any, component: RootComponentConstructor, options: None = None\n) -> None:\n \"\"\"Configure the given app instance to display the given component\"\"\"\n if options is not None: # nocov\n msg = \"Default implementation cannot be configured with options\"\n raise ValueError(msg)\n return _default_implementation().configure(app, component)\n\n\n# BackendType.create_development_app\ndef create_development_app() -> Any:\n \"\"\"Create an application instance for development purposes\"\"\"\n return _default_implementation().create_development_app()\n\n\n# BackendType.serve_development_app\nasync def serve_development_app(\n app: Any,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n) -> None:\n \"\"\"Run an application using a development server\"\"\"\n return await _default_implementation().serve_development_app(\n app, host, port, started\n )\n\n\ndef _default_implementation() -> BackendType[Any]:\n \"\"\"Get the first available server implementation\"\"\"\n global _DEFAULT_IMPLEMENTATION # noqa: PLW0603\n\n if _DEFAULT_IMPLEMENTATION is not None:\n return _DEFAULT_IMPLEMENTATION\n\n try:\n implementation = next(all_implementations())\n except StopIteration: # nocov\n logger.debug(\"Backend implementation import failed\", exc_info=exc_info())\n supported_backends = \", \".join(SUPPORTED_BACKENDS)\n msg = (\n \"It seems you haven't installed a backend. 
To resolve this issue, \"\n \"you can install a backend by running:\\n\\n\"\n '\\033[1mpip install \"reactpy[starlette]\"\\033[0m\\n\\n'\n f\"Other supported backends include: {supported_backends}.\"\n )\n raise RuntimeError(msg) from None\n else:\n _DEFAULT_IMPLEMENTATION = implementation\n return implementation\n\n# Path: src/py/reactpy/reactpy/testing/logs.py\nfrom __future__ import annotations\n\nimport logging\nimport re\nfrom collections.abc import Iterator\nfrom contextlib import contextmanager\nfrom traceback import format_exception\nfrom typing import Any, NoReturn\n\nfrom reactpy.logging import ROOT_LOGGER\n\n\nclass LogAssertionError(AssertionError):\n \"\"\"An assertion error raised in relation to log messages.\"\"\"\n\n\n@contextmanager\ndef assert_reactpy_did_log(\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> Iterator[None]:\n \"\"\"Assert that ReactPy produced a log matching the described message or error.\n\n Args:\n match_message: Must match a logged message.\n error_type: Checks the type of logged exceptions.\n match_error: Must match an error message.\n \"\"\"\n message_pattern = re.compile(match_message)\n error_pattern = re.compile(match_error)\n\n with capture_reactpy_logs() as log_records:\n try:\n yield None\n except Exception:\n raise\n else:\n for record in list(log_records):\n if (\n # record message matches\n message_pattern.findall(record.getMessage())\n # error type matches\n and (\n error_type is None\n or (\n record.exc_info is not None\n and record.exc_info[0] is not None\n and issubclass(record.exc_info[0], error_type)\n )\n )\n # error message pattern matches\n and (\n not match_error\n or (\n record.exc_info is not None\n and error_pattern.findall(\n \"\".join(format_exception(*record.exc_info))\n )\n )\n )\n ):\n break\n else: # nocov\n _raise_log_message_error(\n \"Could not find a log record matching the given\",\n match_message,\n error_type,\n match_error,\n )\n\n\n@contextmanager\ndef assert_reactpy_did_not_log(\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> Iterator[None]:\n \"\"\"Assert the inverse of :func:`assert_reactpy_logged`\"\"\"\n try:\n with assert_reactpy_did_log(match_message, error_type, match_error):\n yield None\n except LogAssertionError:\n pass\n else:\n _raise_log_message_error(\n \"Did find a log record matching the given\",\n match_message,\n error_type,\n match_error,\n )\n\n\ndef list_logged_exceptions(\n log_records: list[logging.LogRecord],\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] 
= Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n Args:\n log_level: The level of log to check\n exclude_exc_types: Any exception types to ignore\n del_log_records: Whether to delete the log records for yielded exceptions\n \"\"\"\n found: list[BaseException] = []\n compiled_pattern = re.compile(pattern)\n for index, record in enumerate(log_records):\n if record.levelno >= log_level and record.exc_info:\n error = record.exc_info[1]\n if (\n error is not None\n and isinstance(error, types)\n and compiled_pattern.search(str(error))\n ):\n if del_log_records:\n del log_records[index - len(found)]\n found.append(error)\n return found\n\n\n@contextmanager\ndef capture_reactpy_logs() -> Iterator[list[logging.LogRecord]]:\n \"\"\"Capture logs from ReactPy\n\n Any logs produced in this context are cleared afterwards\n \"\"\"\n original_level = ROOT_LOGGER.level\n ROOT_LOGGER.setLevel(logging.DEBUG)\n try:\n if _LOG_RECORD_CAPTOR in ROOT_LOGGER.handlers:\n start_index = len(_LOG_RECORD_CAPTOR.records)\n try:\n yield _LOG_RECORD_CAPTOR.records\n finally:\n end_index = len(_LOG_RECORD_CAPTOR.records)\n _LOG_RECORD_CAPTOR.records[start_index:end_index] = []\n return None\n\n ROOT_LOGGER.addHandler(_LOG_RECORD_CAPTOR)\n try:\n yield _LOG_RECORD_CAPTOR.records\n finally:\n ROOT_LOGGER.removeHandler(_LOG_RECORD_CAPTOR)\n _LOG_RECORD_CAPTOR.records.clear()\n finally:\n ROOT_LOGGER.setLevel(original_level)\n\n\nclass _LogRecordCaptor(logging.NullHandler):\n def __init__(self) -> None:\n self.records: list[logging.LogRecord] = []\n super().__init__()\n\n def handle(self, record: logging.LogRecord) -> bool:\n self.records.append(record)\n return True\n\n\n_LOG_RECORD_CAPTOR = _LogRecordCaptor()\n\n\ndef _raise_log_message_error(\n prefix: str,\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> NoReturn:\n conditions = []\n if match_message:\n conditions.append(f\"log message pattern {match_message!r}\")\n if error_type:\n conditions.append(f\"exception type {error_type}\")\n if match_error:\n conditions.append(f\"error message pattern {match_error!r}\")\n raise LogAssertionError(prefix + \" \" + \" and \".join(conditions))\n\n# Path: src/py/reactpy/reactpy/testing/backend.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nfrom contextlib import AsyncExitStack, suppress\nfrom types import TracebackType\nfrom typing import Any, Callable\nfrom urllib.parse import urlencode, urlunparse\n\nfrom reactpy.backend import default as default_server\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import find_available_port\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.core.component import component\nfrom reactpy.core.hooks import use_callback, use_effect, use_state\nfrom reactpy.core.types import ComponentConstructor\nfrom reactpy.testing.logs import (\n LogAssertionError,\n capture_reactpy_logs,\n list_logged_exceptions,\n)\nfrom reactpy.utils import Ref\n\n\nclass BackendFixture:\n \"\"\"A test fixture for running a server and imperatively displaying views\n\n This fixture is typically used alongside async web drivers like ``playwight``.\n\n Example:\n .. 
code-block::\n\n async with BackendFixture() as server:\n server.mount(MyComponent)\n \"\"\"\n\n _records: list[logging.LogRecord]\n _server_future: asyncio.Task[Any]\n _exit_stack = AsyncExitStack()\n\n def __init__(\n self,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n app: Any | None = None,\n implementation: BackendType[Any] | None = None,\n options: Any | None = None,\n timeout: float | None = None,\n ) -> None:\n self.host = host\n self.port = port or find_available_port(host)\n self.mount, self._root_component = _hotswap()\n self.timeout = (\n REACTPY_TESTING_DEFAULT_TIMEOUT.current if timeout is None else timeout\n )\n\n if app is not None and implementation is None:\n msg = \"If an application instance its corresponding server implementation must be provided too.\"\n raise ValueError(msg)\n\n self._app = app\n self.implementation = implementation or default_server\n self._options = options\n\n @property\n def log_records(self) -> list[logging.LogRecord]:\n \"\"\"A list of captured log records\"\"\"\n return self._records\n\n def url(self, path: str = \"\", query: Any | None = None) -> str:\n \"\"\"Return a URL string pointing to the host and point of the server\n\n Args:\n path: the path to a resource on the server\n query: a dictionary or list of query parameters\n \"\"\"\n return urlunparse(\n [\n \"http\",\n f\"{self.host}:{self.port}\",\n path,\n \"\",\n urlencode(query or ()),\n \"\",\n ]\n )\n\n def list_logged_exceptions(\n self,\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] = Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n ) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n Args:\n log_level: The level of log to check\n exclude_exc_types: Any exception types to ignore\n del_log_records: Whether to delete the log records for yielded exceptions\n \"\"\"\n return list_logged_exceptions(\n self.log_records,\n pattern,\n types,\n log_level,\n del_log_records,\n )\n\n async def __aenter__(self) -> BackendFixture:\n self._exit_stack = AsyncExitStack()\n self._records = self._exit_stack.enter_context(capture_reactpy_logs())\n\n app = self._app or self.implementation.create_development_app()\n self.implementation.configure(app, self._root_component, self._options)\n\n started = asyncio.Event()\n server_future = asyncio.create_task(\n self.implementation.serve_development_app(\n app, self.host, self.port, started\n )\n )\n\n async def stop_server() -> None:\n server_future.cancel()\n with suppress(asyncio.CancelledError):\n await asyncio.wait_for(server_future, timeout=self.timeout)\n\n self._exit_stack.push_async_callback(stop_server)\n\n try:\n await asyncio.wait_for(started.wait(), timeout=self.timeout)\n except Exception: # nocov\n # see if we can await the future for a more helpful error\n await asyncio.wait_for(server_future, timeout=self.timeout)\n raise\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n await self._exit_stack.aclose()\n\n self.mount(None) # reset the view\n\n logged_errors = self.list_logged_exceptions(del_log_records=False)\n if logged_errors: # nocov\n msg = \"Unexpected logged exception\"\n raise LogAssertionError(msg) from logged_errors[0]\n\n\n_MountFunc = Callable[[\"Callable[[], Any] | None\"], None]\n\n\ndef _hotswap(update_on_change: bool = False) -> tuple[_MountFunc, ComponentConstructor]:\n \"\"\"Swap out components 
from a layout on the fly.\n\n Since you can't change the component functions used to create a layout\n in an imperative manner, you can use ``hotswap`` to do this so\n long as you set things up ahead of time.\n\n Parameters:\n update_on_change: Whether or not all views of the layout should be updated on a swap.\n\n Example:\n .. code-block:: python\n\n import reactpy\n\n show, root = reactpy.hotswap()\n PerClientStateServer(root).run_in_thread(\"localhost\", 8765)\n\n @reactpy.component\n def DivOne(self):\n return {\"tagName\": \"div\", \"children\": [1]}\n\n show(DivOne)\n\n # displaying the output now will show DivOne\n\n @reactpy.component\n def DivTwo(self):\n return {\"tagName\": \"div\", \"children\": [2]}\n\n show(DivTwo)\n\n # displaying the output now will show DivTwo\n \"\"\"\n constructor_ref: Ref[Callable[[], Any]] = Ref(lambda: None)\n\n if update_on_change:\n set_constructor_callbacks: set[Callable[[Callable[[], Any]], None]] = set()\n\n @component\n def HotSwap() -> Any:\n # new displays will adopt the latest constructor and arguments\n constructor, _set_constructor = use_state(lambda: constructor_ref.current)\n set_constructor = use_callback(lambda new: _set_constructor(lambda _: new))\n\n def add_callback() -> Callable[[], None]:\n set_constructor_callbacks.add(set_constructor)\n return lambda: set_constructor_callbacks.remove(set_constructor)\n\n use_effect(add_callback)\n\n return constructor()\n\n def swap(constructor: Callable[[], Any] | None) -> None:\n constructor = constructor_ref.current = constructor or (lambda: None)\n\n for set_constructor in set_constructor_callbacks:\n set_constructor(constructor)\n\n else:\n\n @component\n def HotSwap() -> Any:\n return constructor_ref.current()\n\n def swap(constructor: Callable[[], Any] | None) -> None:\n constructor_ref.current = constructor or (lambda: None)\n\n return swap, HotSwap\n\n# Path: src/py/reactpy/reactpy/testing/common.py\nfrom __future__ import annotations\n\nimport asyncio\nimport inspect\nimport shutil\nimport time\nfrom collections.abc import Awaitable\nfrom functools import wraps\nfrom typing import Any, Callable, Generic, TypeVar, cast\nfrom uuid import uuid4\nfrom weakref import ref\n\nfrom typing_extensions import ParamSpec\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT, REACTPY_WEB_MODULES_DIR\nfrom reactpy.core._life_cycle_hook import LifeCycleHook, current_hook\nfrom reactpy.core.events import EventHandler, to_event_handler_function\n\n\ndef clear_reactpy_web_modules_dir() -> None:\n \"\"\"Clear the directory where ReactPy stores registered web modules\"\"\"\n for path in REACTPY_WEB_MODULES_DIR.current.iterdir():\n shutil.rmtree(path) if path.is_dir() else path.unlink()\n\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\n_DEFAULT_POLL_DELAY = 0.1\n\n\nclass poll(Generic[_R]): # noqa: N801\n \"\"\"Wait until the result of an sync or async function meets some condition\"\"\"\n\n def __init__(\n self,\n function: Callable[_P, Awaitable[_R] | _R],\n *args: _P.args,\n **kwargs: _P.kwargs,\n ) -> None:\n coro: Callable[_P, Awaitable[_R]]\n if not inspect.iscoroutinefunction(function):\n\n async def coro(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n return cast(_R, function(*args, **kwargs))\n\n else:\n coro = cast(Callable[_P, Awaitable[_R]], function)\n self._func = coro\n self._args = args\n self._kwargs = kwargs\n\n async def until(\n self,\n condition: Callable[[_R], bool],\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n 
description: str = \"condition to be true\",\n ) -> None:\n \"\"\"Check that the coroutines result meets a condition within the timeout\"\"\"\n started_at = time.time()\n while True:\n await asyncio.sleep(delay)\n result = await self._func(*self._args, **self._kwargs)\n if condition(result):\n break\n elif (time.time() - started_at) > timeout: # nocov\n msg = f\"Expected {description} after {timeout} seconds - last value was {result!r}\"\n raise asyncio.TimeoutError(msg)\n\n async def until_is(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is identical to the given value\"\"\"\n return await self.until(\n lambda left: left is right,\n timeout,\n delay,\n f\"value to be identical to {right!r}\",\n )\n\n async def until_equals(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is equal to the given value\"\"\"\n return await self.until(\n lambda left: left == right,\n timeout,\n delay,\n f\"value to equal {right!r}\",\n )\n\n\nclass HookCatcher:\n \"\"\"Utility for capturing a LifeCycleHook from a component\n\n Example:\n .. code-block::\n\n hooks = HookCatcher(index_by_kwarg=\"thing\")\n\n @reactpy.component\n @hooks.capture\n def MyComponent(thing):\n ...\n\n ... # render the component\n\n # grab the last render of where MyComponent(thing='something')\n hooks.index[\"something\"]\n # or grab the hook from the component's last render\n hooks.latest\n\n After the first render of ``MyComponent`` the ``HookCatcher`` will have\n captured the component's ``LifeCycleHook``.\n \"\"\"\n\n latest: LifeCycleHook\n\n def __init__(self, index_by_kwarg: str | None = None):\n self.index_by_kwarg = index_by_kwarg\n self.index: dict[Any, LifeCycleHook] = {}\n\n def capture(self, render_function: Callable[..., Any]) -> Callable[..., Any]:\n \"\"\"Decorator for capturing a ``LifeCycleHook`` on each render of a component\"\"\"\n\n # The render function holds a reference to `self` and, via the `LifeCycleHook`,\n # the component. Some tests check whether components are garbage collected, thus\n # we must use a `ref` here to ensure these checks pass once the catcher itself\n # has been collected.\n self_ref = ref(self)\n\n @wraps(render_function)\n def wrapper(*args: Any, **kwargs: Any) -> Any:\n self = self_ref()\n if self is None:\n raise RuntimeError(\"Hook catcher has been garbage collected\")\n\n hook = current_hook()\n if self.index_by_kwarg is not None:\n self.index[kwargs[self.index_by_kwarg]] = hook\n self.latest = hook\n return render_function(*args, **kwargs)\n\n return wrapper\n\n\nclass StaticEventHandler:\n \"\"\"Utility for capturing the target of one event handler\n\n Example:\n .. code-block::\n\n static_handler = StaticEventHandler()\n\n @reactpy.component\n def MyComponent():\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handler.use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # gives the target ID for onClick where from the last render of MyComponent\n static_handlers.target\n\n If you need to capture event handlers from different instances of a component\n the you should create multiple ``StaticEventHandler`` instances.\n\n .. 
code-block::\n\n static_handlers_by_key = {\n \"first\": StaticEventHandler(),\n \"second\": StaticEventHandler(),\n }\n\n @reactpy.component\n def Parent():\n return reactpy.html.div(Child(key=\"first\"), Child(key=\"second\"))\n\n @reactpy.component\n def Child(key):\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handlers_by_key[key].use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # grab the individual targets for each instance above\n first_target = static_handlers_by_key[\"first\"].target\n second_target = static_handlers_by_key[\"second\"].target\n \"\"\"\n\n def __init__(self) -> None:\n self.target = uuid4().hex\n\n def use(\n self,\n function: Callable[..., Any],\n stop_propagation: bool = False,\n prevent_default: bool = False,\n ) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function),\n stop_propagation,\n prevent_default,\n self.target,\n )\n\n# Path: src/py/reactpy/reactpy/testing/display.py\nfrom __future__ import annotations\n\nfrom contextlib import AsyncExitStack\nfrom types import TracebackType\nfrom typing import Any\n\nfrom playwright.async_api import (\n Browser,\n BrowserContext,\n ElementHandle,\n Page,\n async_playwright,\n)\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.types import RootComponentConstructor\n\n\nclass DisplayFixture:\n \"\"\"A fixture for running web-based tests using ``playwright``\"\"\"\n\n _exit_stack: AsyncExitStack\n\n def __init__(\n self,\n backend: BackendFixture | None = None,\n driver: Browser | BrowserContext | Page | None = None,\n url_prefix: str = \"\",\n ) -> None:\n if backend is not None:\n self.backend = backend\n if driver is not None:\n if isinstance(driver, Page):\n self.page = driver\n else:\n self._browser = driver\n self.url_prefix = url_prefix\n\n async def show(\n self,\n component: RootComponentConstructor,\n ) -> None:\n self.backend.mount(component)\n await self.goto(\"/\")\n await self.root_element() # check that root element is attached\n\n async def goto(\n self, path: str, query: Any | None = None, add_url_prefix: bool = True\n ) -> None:\n await self.page.goto(\n self.backend.url(\n f\"{self.url_prefix}{path}\" if add_url_prefix else path, query\n )\n )\n\n async def root_element(self) -> ElementHandle:\n element = await self.page.wait_for_selector(\"#app\", state=\"attached\")\n if element is None: # nocov\n msg = \"Root element not attached\"\n raise RuntimeError(msg)\n return element\n\n async def __aenter__(self) -> DisplayFixture:\n es = self._exit_stack = AsyncExitStack()\n\n browser: Browser | BrowserContext\n if not hasattr(self, \"page\"):\n if not hasattr(self, \"_browser\"):\n pw = await es.enter_async_context(async_playwright())\n browser = await pw.chromium.launch()\n else:\n browser = self._browser\n self.page = await browser.new_page()\n\n self.page.set_default_timeout(REACTPY_TESTING_DEFAULT_TIMEOUT.current * 1000)\n\n if not hasattr(self, \"backend\"):\n self.backend = BackendFixture()\n await es.enter_async_context(self.backend)\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n self.backend.mount(None)\n await self._exit_stack.aclose()\n\n# Path: src/py/reactpy/reactpy/testing/__init__.py\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.testing.common import (\n 
HookCatcher,\n StaticEventHandler,\n clear_reactpy_web_modules_dir,\n poll,\n)\nfrom reactpy.testing.display import DisplayFixture\nfrom reactpy.testing.logs import (\n LogAssertionError,\n assert_reactpy_did_log,\n assert_reactpy_did_not_log,\n capture_reactpy_logs,\n)\n\n__all__ = [\n...\n", "needle_token_start": 7214, "needle_token_end": 7316, "code_context_ntokens": 16089, "output": ["```python\ndef _equal_files(f1: Path, f2: Path) -> bool:\n f1 = f1.resolve()\n f2 = f2.resolve()\n return (\n (f1.is_symlink() or f2.is_symlink()) and (f1.resolve() == f2.resolve())\n ) or filecmp.cmp(str(f1), str(f2), shallow=False)\n```"]} +{"repo": "reactive-python/reactpy", "name": "stop_server", "language": "python", "path": "src/py/reactpy/reactpy/testing/backend.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to terminate a running asynchronous server task gracefully within a specified timeout period.\n2. **Input**: There are no direct inputs as it operates within its defined context and uses the class's attributes.\n3. **Output**: There is no direct output returned from the function; however, it ensures the server task is cancelled and awaited for proper cleanup.\n4. **Procedure**: The function first sends a cancellation request to the server task. It then waits for the task to acknowledge the cancellation and complete any necessary cleanup operations, enforcing a timeout limit to prevent indefinite waiting.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/web/utils.py\nimport logging\nimport re\nfrom pathlib import Path, PurePosixPath\nfrom urllib.parse import urlparse, urlunparse\n\nimport requests\n\nlogger = logging.getLogger(__name__)\n\n\ndef module_name_suffix(name: str) -> str:\n if name.startswith(\"@\"):\n name = name[1:]\n head, _, tail = name.partition(\"@\") # handle version identifier\n version, _, tail = tail.partition(\"/\") # get section after version\n return PurePosixPath(tail or head).suffix or \".js\"\n\n\ndef resolve_module_exports_from_file(\n file: Path,\n max_depth: int,\n is_re_export: bool = False,\n) -> set[str]:\n if max_depth == 0:\n logger.warning(f\"Did not resolve all exports for {file} - max depth reached\")\n return set()\n elif not file.exists():\n logger.warning(f\"Did not resolve exports for unknown file {file}\")\n return set()\n\n export_names, references = resolve_module_exports_from_source(\n file.read_text(encoding=\"utf-8\"), exclude_default=is_re_export\n )\n\n for ref in references:\n if urlparse(ref).scheme: # is an absolute URL\n export_names.update(\n resolve_module_exports_from_url(ref, max_depth - 1, is_re_export=True)\n )\n else:\n path = file.parent.joinpath(*ref.split(\"/\"))\n export_names.update(\n resolve_module_exports_from_file(path, max_depth - 1, is_re_export=True)\n )\n\n return export_names\n\n\ndef resolve_module_exports_from_url(\n url: str,\n max_depth: int,\n is_re_export: bool = False,\n) -> set[str]:\n if max_depth == 0:\n logger.warning(f\"Did not resolve all exports for {url} - max depth reached\")\n return set()\n\n try:\n text = requests.get(url, timeout=5).text\n except requests.exceptions.ConnectionError as error:\n reason = \"\" if error is None else \" - {error.errno}\"\n logger.warning(\"Did not resolve exports 
for url \" + url + reason)\n return set()\n\n export_names, references = resolve_module_exports_from_source(\n text, exclude_default=is_re_export\n )\n\n for ref in references:\n url = _resolve_relative_url(url, ref)\n export_names.update(\n resolve_module_exports_from_url(url, max_depth - 1, is_re_export=True)\n )\n\n return export_names\n\n\ndef resolve_module_exports_from_source(\n content: str, exclude_default: bool\n) -> tuple[set[str], set[str]]:\n names: set[str] = set()\n references: set[str] = set()\n\n if _JS_DEFAULT_EXPORT_PATTERN.search(content):\n names.add(\"default\")\n\n # Exporting functions and classes\n names.update(_JS_FUNC_OR_CLS_EXPORT_PATTERN.findall(content))\n\n for export in _JS_GENERAL_EXPORT_PATTERN.findall(content):\n export = export.rstrip(\";\").strip()\n # Exporting individual features\n if export.startswith(\"let \"):\n names.update(let.split(\"=\", 1)[0] for let in export[4:].split(\",\"))\n # Renaming exports and export list\n elif export.startswith(\"{\") and export.endswith(\"}\"):\n names.update(\n item.split(\" as \", 1)[-1] for item in export.strip(\"{}\").split(\",\")\n )\n # Exporting destructured assignments with renaming\n elif export.startswith(\"const \"):\n names.update(\n item.split(\":\", 1)[0]\n for item in export[6:].split(\"=\", 1)[0].strip(\"{}\").split(\",\")\n )\n # Default exports\n elif export.startswith(\"default \"):\n names.add(\"default\")\n # Aggregating modules\n elif export.startswith(\"* as \"):\n names.add(export[5:].split(\" from \", 1)[0])\n elif export.startswith(\"* \"):\n references.add(export[2:].split(\"from \", 1)[-1].strip(\"'\\\"\"))\n elif export.startswith(\"{\") and \" from \" in export:\n names.update(\n item.split(\" as \", 1)[-1]\n for item in export.split(\" from \")[0].strip(\"{}\").split(\",\")\n )\n elif not (export.startswith(\"function \") or export.startswith(\"class \")):\n logger.warning(f\"Unknown export type {export!r}\")\n\n names = {n.strip() for n in names}\n references = {r.strip() for r in references}\n\n if exclude_default and \"default\" in names:\n names.remove(\"default\")\n\n return names, references\n\n\ndef _resolve_relative_url(base_url: str, rel_url: str) -> str:\n if not rel_url.startswith(\".\"):\n if rel_url.startswith(\"/\"):\n # copy scheme and hostname from base_url\n return urlunparse(urlparse(base_url)[:2] + urlparse(rel_url)[2:])\n else:\n return rel_url\n\n base_url = base_url.rsplit(\"/\", 1)[0]\n\n if rel_url.startswith(\"./\"):\n return base_url + rel_url[1:]\n\n while rel_url.startswith(\"../\"):\n base_url = base_url.rsplit(\"/\", 1)[0]\n rel_url = rel_url[3:]\n\n return f\"{base_url}/{rel_url}\"\n\n\n_JS_DEFAULT_EXPORT_PATTERN = re.compile(\n r\";?\\s*export\\s+default\\s\",\n)\n_JS_FUNC_OR_CLS_EXPORT_PATTERN = re.compile(\n...\n# Path: src/py/reactpy/reactpy/web/module.py\nfrom __future__ import annotations\n\nimport filecmp\nimport logging\nimport shutil\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom string import Template\nfrom typing import Any, NewType, overload\nfrom urllib.parse import urlparse\n\nfrom reactpy._warnings import warn\nfrom reactpy.config import REACTPY_DEBUG_MODE, REACTPY_WEB_MODULES_DIR\nfrom reactpy.core.types import ImportSourceDict, VdomDictConstructor\nfrom reactpy.core.vdom import make_vdom_constructor\nfrom reactpy.web.utils import (\n module_name_suffix,\n resolve_module_exports_from_file,\n resolve_module_exports_from_url,\n)\n\nlogger = logging.getLogger(__name__)\n\nSourceType = NewType(\"SourceType\", 
str)\n\nNAME_SOURCE = SourceType(\"NAME\")\n\"\"\"A named source - usually a Javascript package name\"\"\"\n\nURL_SOURCE = SourceType(\"URL\")\n\"\"\"A source loaded from a URL, usually a CDN\"\"\"\n\n\ndef module_from_url(\n url: str,\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n) -> WebModule:\n \"\"\"Load a :class:`WebModule` from a :data:`URL_SOURCE`\n\n Parameters:\n url:\n Where the javascript module will be loaded from which conforms to the\n interface for :ref:`Custom Javascript Components`\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. See :issue:`461` for more info.\n \"\"\"\n return WebModule(\n source=url,\n source_type=URL_SOURCE,\n default_fallback=fallback,\n file=None,\n export_names=(\n resolve_module_exports_from_url(url, resolve_exports_depth)\n if (\n resolve_exports\n if resolve_exports is not None\n else REACTPY_DEBUG_MODE.current\n )\n else None\n ),\n unmount_before_update=unmount_before_update,\n )\n\n\n_FROM_TEMPLATE_DIR = \"__from_template__\"\n\n\ndef module_from_template(\n template: str,\n package: str,\n cdn: str = \"https://esm.sh\",\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n) -> WebModule:\n \"\"\"Create a :class:`WebModule` from a framework template\n\n This is useful for experimenting with component libraries that do not already\n support ReactPy's :ref:`Custom Javascript Component` interface.\n\n .. warning::\n\n This approach is not recommended for use in a production setting because the\n framework templates may use unpinned dependencies that could change without\n warning. It's best to author a module adhering to the\n :ref:`Custom Javascript Component` interface instead.\n\n **Templates**\n\n - ``react``: for modules exporting React components\n\n Parameters:\n template:\n The name of the framework template to use with the given ``package``.\n package:\n The name of a package to load. May include a file extension (defaults to\n ``.js`` if not given)\n cdn:\n Where the package should be loaded from. The CDN must distribute ESM modules\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. See :issue:`461` for more info.\n \"\"\"\n warn(\n \"module_from_template() is deprecated due to instability - use the Javascript \"\n \"Components API instead. 
This function will be removed in a future release.\",\n DeprecationWarning,\n )\n template_name, _, template_version = template.partition(\"@\")\n template_version = \"@\" + template_version if template_version else \"\"\n\n # We do this since the package may be any valid URL path. Thus we may need to strip\n # object parameters or query information so we save the resulting template under the\n # correct file name.\n package_name = urlparse(package).path\n\n # downstream code assumes no trailing slash\n cdn = cdn.rstrip(\"/\")\n\n template_file_name = template_name + module_name_suffix(package_name)\n\n template_file = Path(__file__).parent / \"templates\" / template_file_name\n if not template_file.exists():\n msg = f\"No template for {template_file_name!r} exists\"\n raise ValueError(msg)\n\n variables = {\"PACKAGE\": package, \"CDN\": cdn, \"VERSION\": template_version}\n content = Template(template_file.read_text(encoding=\"utf-8\")).substitute(variables)\n\n return module_from_string(\n _FROM_TEMPLATE_DIR + \"/\" + package_name,\n content,\n fallback,\n resolve_exports,\n resolve_exports_depth,\n unmount_before_update=unmount_before_update,\n )\n\n\ndef module_from_file(\n name: str,\n file: str | Path,\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n symlink: bool = False,\n) -> WebModule:\n \"\"\"Load a :class:`WebModule` from a given ``file``\n\n Parameters:\n name:\n The name of the package\n file:\n The file from which the content of the web module will be created.\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. 
See :issue:`461` for more info.\n symlink:\n Whether the web module should be saved as a symlink to the given ``file``.\n \"\"\"\n name += module_name_suffix(name)\n\n source_file = Path(file).resolve()\n target_file = _web_module_path(name)\n if not source_file.exists():\n msg = f\"Source file does not exist: {source_file}\"\n raise FileNotFoundError(msg)\n\n if not target_file.exists():\n _copy_file(target_file, source_file, symlink)\n elif not _equal_files(source_file, target_file):\n logger.info(\n f\"Existing web module {name!r} will \"\n f\"be replaced with {target_file.resolve()}\"\n )\n target_file.unlink()\n _copy_file(target_file, source_file, symlink)\n\n return WebModule(\n source=name,\n source_type=NAME_SOURCE,\n default_fallback=fallback,\n file=target_file,\n export_names=(\n resolve_module_exports_from_file(source_file, resolve_exports_depth)\n if (\n resolve_exports\n if resolve_exports is not None\n else REACTPY_DEBUG_MODE.current\n )\n else None\n ),\n unmount_before_update=unmount_before_update,\n )\n\n\ndef _equal_files(f1: Path, f2: Path) -> bool:\n f1 = f1.resolve()\n f2 = f2.resolve()\n return (\n (f1.is_symlink() or f2.is_symlink()) and (f1.resolve() == f2.resolve())\n ) or filecmp.cmp(str(f1), str(f2), shallow=False)\n\n\ndef _copy_file(target: Path, source: Path, symlink: bool) -> None:\n target.parent.mkdir(parents=True, exist_ok=True)\n if symlink:\n target.symlink_to(source)\n else:\n shutil.copy(source, target)\n\n\ndef module_from_string(\n name: str,\n content: str,\n fallback: Any | None = None,\n resolve_exports: bool | None = None,\n resolve_exports_depth: int = 5,\n unmount_before_update: bool = False,\n) -> WebModule:\n \"\"\"Load a :class:`WebModule` whose ``content`` comes from a string.\n\n Parameters:\n name:\n The name of the package\n content:\n The contents of the web module\n fallback:\n What to temporarily display while the module is being loaded.\n resolve_imports:\n Whether to try and find all the named exports of this module.\n resolve_exports_depth:\n How deeply to search for those exports.\n unmount_before_update:\n Cause the component to be unmounted before each update. This option should\n only be used if the imported package fails to re-render when props change.\n Using this option has negative performance consequences since all DOM\n elements must be changed on each render. 
See :issue:`461` for more info.\n \"\"\"\n name += module_name_suffix(name)\n\n target_file = _web_module_path(name)\n\n if target_file.exists() and target_file.read_text(encoding=\"utf-8\") != content:\n logger.info(\n f\"Existing web module {name!r} will \"\n f\"be replaced with {target_file.resolve()}\"\n )\n target_file.unlink()\n\n target_file.parent.mkdir(parents=True, exist_ok=True)\n target_file.write_text(content)\n\n return WebModule(\n source=name,\n source_type=NAME_SOURCE,\n default_fallback=fallback,\n file=target_file,\n export_names=(\n resolve_module_exports_from_file(target_file, resolve_exports_depth)\n if (\n resolve_exports\n if resolve_exports is not None\n else REACTPY_DEBUG_MODE.current\n )\n else None\n ),\n unmount_before_update=unmount_before_update,\n )\n\n\n@dataclass(frozen=True)\nclass WebModule:\n source: str\n source_type: SourceType\n default_fallback: Any | None\n export_names: set[str] | None\n file: Path | None\n unmount_before_update: bool\n\n\n@overload\ndef export(\n web_module: WebModule,\n export_names: str,\n fallback: Any | None = ...,\n allow_children: bool = ...,\n) -> VdomDictConstructor: ...\n\n\n@overload\ndef export(\n web_module: WebModule,\n export_names: list[str] | tuple[str, ...],\n fallback: Any | None = ...,\n allow_children: bool = ...,\n) -> list[VdomDictConstructor]: ...\n\n\ndef export(\n web_module: WebModule,\n export_names: str | list[str] | tuple[str, ...],\n fallback: Any | None = None,\n allow_children: bool = True,\n) -> VdomDictConstructor | list[VdomDictConstructor]:\n \"\"\"Return one or more VDOM constructors from a :class:`WebModule`\n\n Parameters:\n export_names:\n One or more names to export. If given as a string, a single component\n will be returned. If a list is given, then a list of components will be\n returned.\n fallback:\n What to temporarily display while the module is being loaded.\n allow_children:\n Whether or not these components can have children.\n \"\"\"\n if isinstance(export_names, str):\n if (\n web_module.export_names is not None\n and export_names not in web_module.export_names\n ):\n msg = f\"{web_module.source!r} does not export {export_names!r}\"\n raise ValueError(msg)\n return _make_export(web_module, export_names, fallback, allow_children)\n else:\n if web_module.export_names is not None:\n missing = sorted(set(export_names).difference(web_module.export_names))\n if missing:\n msg = f\"{web_module.source!r} does not export {missing!r}\"\n raise ValueError(msg)\n return [\n _make_export(web_module, name, fallback, allow_children)\n for name in export_names\n ]\n\n\ndef _make_export(\n web_module: WebModule,\n name: str,\n fallback: Any | None,\n allow_children: bool,\n) -> VdomDictConstructor:\n return make_vdom_constructor(\n name,\n allow_children=allow_children,\n import_source=ImportSourceDict(\n source=web_module.source,\n sourceType=web_module.source_type,\n fallback=(fallback or web_module.default_fallback),\n unmountBeforeUpdate=web_module.unmount_before_update,\n ),\n )\n\n\ndef _web_module_path(name: str) -> Path:\n directory = REACTPY_WEB_MODULES_DIR.current\n path = directory.joinpath(*name.split(\"/\"))\n return path.with_suffix(path.suffix)\n\n# Path: src/py/reactpy/reactpy/web/__init__.py\nfrom reactpy.web.module import (\n export,\n module_from_file,\n module_from_string,\n module_from_template,\n module_from_url,\n)\n\n__all__ = [\n \"module_from_file\",\n \"module_from_string\",\n \"module_from_template\",\n \"module_from_url\",\n \"export\",\n]\n\n# Path: 
src/py/reactpy/reactpy/backend/default.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom logging import getLogger\nfrom sys import exc_info\nfrom typing import Any, NoReturn\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import SUPPORTED_BACKENDS, all_implementations\nfrom reactpy.types import RootComponentConstructor\n\nlogger = getLogger(__name__)\n_DEFAULT_IMPLEMENTATION: BackendType[Any] | None = None\n\n\n# BackendType.Options\nclass Options: # nocov\n \"\"\"Configuration options that can be provided to the backend.\n This definition should not be used/instantiated. It exists only for\n type hinting purposes.\"\"\"\n\n def __init__(self, *args: Any, **kwds: Any) -> NoReturn:\n msg = \"Default implementation has no options.\"\n raise ValueError(msg)\n\n\n# BackendType.configure\ndef configure(\n app: Any, component: RootComponentConstructor, options: None = None\n) -> None:\n \"\"\"Configure the given app instance to display the given component\"\"\"\n if options is not None: # nocov\n msg = \"Default implementation cannot be configured with options\"\n raise ValueError(msg)\n return _default_implementation().configure(app, component)\n\n\n# BackendType.create_development_app\ndef create_development_app() -> Any:\n \"\"\"Create an application instance for development purposes\"\"\"\n return _default_implementation().create_development_app()\n\n\n# BackendType.serve_development_app\nasync def serve_development_app(\n app: Any,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n) -> None:\n \"\"\"Run an application using a development server\"\"\"\n return await _default_implementation().serve_development_app(\n app, host, port, started\n )\n\n\ndef _default_implementation() -> BackendType[Any]:\n \"\"\"Get the first available server implementation\"\"\"\n global _DEFAULT_IMPLEMENTATION # noqa: PLW0603\n\n if _DEFAULT_IMPLEMENTATION is not None:\n return _DEFAULT_IMPLEMENTATION\n\n try:\n implementation = next(all_implementations())\n except StopIteration: # nocov\n logger.debug(\"Backend implementation import failed\", exc_info=exc_info())\n supported_backends = \", \".join(SUPPORTED_BACKENDS)\n msg = (\n \"It seems you haven't installed a backend. 
To resolve this issue, \"\n \"you can install a backend by running:\\n\\n\"\n '\\033[1mpip install \"reactpy[starlette]\"\\033[0m\\n\\n'\n f\"Other supported backends include: {supported_backends}.\"\n )\n raise RuntimeError(msg) from None\n else:\n _DEFAULT_IMPLEMENTATION = implementation\n return implementation\n\n# Path: src/py/reactpy/reactpy/testing/logs.py\nfrom __future__ import annotations\n\nimport logging\nimport re\nfrom collections.abc import Iterator\nfrom contextlib import contextmanager\nfrom traceback import format_exception\nfrom typing import Any, NoReturn\n\nfrom reactpy.logging import ROOT_LOGGER\n\n\nclass LogAssertionError(AssertionError):\n \"\"\"An assertion error raised in relation to log messages.\"\"\"\n\n\n@contextmanager\ndef assert_reactpy_did_log(\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> Iterator[None]:\n \"\"\"Assert that ReactPy produced a log matching the described message or error.\n\n Args:\n match_message: Must match a logged message.\n error_type: Checks the type of logged exceptions.\n match_error: Must match an error message.\n \"\"\"\n message_pattern = re.compile(match_message)\n error_pattern = re.compile(match_error)\n\n with capture_reactpy_logs() as log_records:\n try:\n yield None\n except Exception:\n raise\n else:\n for record in list(log_records):\n if (\n # record message matches\n message_pattern.findall(record.getMessage())\n # error type matches\n and (\n error_type is None\n or (\n record.exc_info is not None\n and record.exc_info[0] is not None\n and issubclass(record.exc_info[0], error_type)\n )\n )\n # error message pattern matches\n and (\n not match_error\n or (\n record.exc_info is not None\n and error_pattern.findall(\n \"\".join(format_exception(*record.exc_info))\n )\n )\n )\n ):\n break\n else: # nocov\n _raise_log_message_error(\n \"Could not find a log record matching the given\",\n match_message,\n error_type,\n match_error,\n )\n\n\n@contextmanager\ndef assert_reactpy_did_not_log(\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> Iterator[None]:\n \"\"\"Assert the inverse of :func:`assert_reactpy_logged`\"\"\"\n try:\n with assert_reactpy_did_log(match_message, error_type, match_error):\n yield None\n except LogAssertionError:\n pass\n else:\n _raise_log_message_error(\n \"Did find a log record matching the given\",\n match_message,\n error_type,\n match_error,\n )\n\n\ndef list_logged_exceptions(\n log_records: list[logging.LogRecord],\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] 
= Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n Args:\n log_level: The level of log to check\n exclude_exc_types: Any exception types to ignore\n del_log_records: Whether to delete the log records for yielded exceptions\n \"\"\"\n found: list[BaseException] = []\n compiled_pattern = re.compile(pattern)\n for index, record in enumerate(log_records):\n if record.levelno >= log_level and record.exc_info:\n error = record.exc_info[1]\n if (\n error is not None\n and isinstance(error, types)\n and compiled_pattern.search(str(error))\n ):\n if del_log_records:\n del log_records[index - len(found)]\n found.append(error)\n return found\n\n\n@contextmanager\ndef capture_reactpy_logs() -> Iterator[list[logging.LogRecord]]:\n \"\"\"Capture logs from ReactPy\n\n Any logs produced in this context are cleared afterwards\n \"\"\"\n original_level = ROOT_LOGGER.level\n ROOT_LOGGER.setLevel(logging.DEBUG)\n try:\n if _LOG_RECORD_CAPTOR in ROOT_LOGGER.handlers:\n start_index = len(_LOG_RECORD_CAPTOR.records)\n try:\n yield _LOG_RECORD_CAPTOR.records\n finally:\n end_index = len(_LOG_RECORD_CAPTOR.records)\n _LOG_RECORD_CAPTOR.records[start_index:end_index] = []\n return None\n\n ROOT_LOGGER.addHandler(_LOG_RECORD_CAPTOR)\n try:\n yield _LOG_RECORD_CAPTOR.records\n finally:\n ROOT_LOGGER.removeHandler(_LOG_RECORD_CAPTOR)\n _LOG_RECORD_CAPTOR.records.clear()\n finally:\n ROOT_LOGGER.setLevel(original_level)\n\n\nclass _LogRecordCaptor(logging.NullHandler):\n def __init__(self) -> None:\n self.records: list[logging.LogRecord] = []\n super().__init__()\n\n def handle(self, record: logging.LogRecord) -> bool:\n self.records.append(record)\n return True\n\n\n_LOG_RECORD_CAPTOR = _LogRecordCaptor()\n\n\ndef _raise_log_message_error(\n prefix: str,\n match_message: str = \"\",\n error_type: type[Exception] | None = None,\n match_error: str = \"\",\n) -> NoReturn:\n conditions = []\n if match_message:\n conditions.append(f\"log message pattern {match_message!r}\")\n if error_type:\n conditions.append(f\"exception type {error_type}\")\n if match_error:\n conditions.append(f\"error message pattern {match_error!r}\")\n raise LogAssertionError(prefix + \" \" + \" and \".join(conditions))\n\n# Path: src/py/reactpy/reactpy/testing/backend.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nfrom contextlib import AsyncExitStack, suppress\nfrom types import TracebackType\nfrom typing import Any, Callable\nfrom urllib.parse import urlencode, urlunparse\n\nfrom reactpy.backend import default as default_server\nfrom reactpy.backend.types import BackendType\nfrom reactpy.backend.utils import find_available_port\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.core.component import component\nfrom reactpy.core.hooks import use_callback, use_effect, use_state\nfrom reactpy.core.types import ComponentConstructor\nfrom reactpy.testing.logs import (\n LogAssertionError,\n capture_reactpy_logs,\n list_logged_exceptions,\n)\nfrom reactpy.utils import Ref\n\n\nclass BackendFixture:\n \"\"\"A test fixture for running a server and imperatively displaying views\n\n This fixture is typically used alongside async web drivers like ``playwight``.\n\n Example:\n .. 
code-block::\n\n async with BackendFixture() as server:\n server.mount(MyComponent)\n \"\"\"\n\n _records: list[logging.LogRecord]\n _server_future: asyncio.Task[Any]\n _exit_stack = AsyncExitStack()\n\n def __init__(\n self,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n app: Any | None = None,\n implementation: BackendType[Any] | None = None,\n options: Any | None = None,\n timeout: float | None = None,\n ) -> None:\n self.host = host\n self.port = port or find_available_port(host)\n self.mount, self._root_component = _hotswap()\n self.timeout = (\n REACTPY_TESTING_DEFAULT_TIMEOUT.current if timeout is None else timeout\n )\n\n if app is not None and implementation is None:\n msg = \"If an application instance its corresponding server implementation must be provided too.\"\n raise ValueError(msg)\n\n self._app = app\n self.implementation = implementation or default_server\n self._options = options\n\n @property\n def log_records(self) -> list[logging.LogRecord]:\n \"\"\"A list of captured log records\"\"\"\n return self._records\n\n def url(self, path: str = \"\", query: Any | None = None) -> str:\n \"\"\"Return a URL string pointing to the host and point of the server\n\n Args:\n path: the path to a resource on the server\n query: a dictionary or list of query parameters\n \"\"\"\n return urlunparse(\n [\n \"http\",\n f\"{self.host}:{self.port}\",\n path,\n \"\",\n urlencode(query or ()),\n \"\",\n ]\n )\n\n def list_logged_exceptions(\n self,\n pattern: str = \"\",\n types: type[Any] | tuple[type[Any], ...] = Exception,\n log_level: int = logging.ERROR,\n del_log_records: bool = True,\n ) -> list[BaseException]:\n \"\"\"Return a list of logged exception matching the given criteria\n\n Args:\n log_level: The level of log to check\n exclude_exc_types: Any exception types to ignore\n del_log_records: Whether to delete the log records for yielded exceptions\n \"\"\"\n return list_logged_exceptions(\n self.log_records,\n pattern,\n types,\n log_level,\n del_log_records,\n )\n\n async def __aenter__(self) -> BackendFixture:\n self._exit_stack = AsyncExitStack()\n self._records = self._exit_stack.enter_context(capture_reactpy_logs())\n\n app = self._app or self.implementation.create_development_app()\n self.implementation.configure(app, self._root_component, self._options)\n\n started = asyncio.Event()\n server_future = asyncio.create_task(\n self.implementation.serve_development_app(\n app, self.host, self.port, started\n )\n )\n\n \nasync def stop_server() -> None:\n server_future.cancel()\n with suppress(asyncio.CancelledError):\n await asyncio.wait_for(server_future, timeout=self.timeout)\n\n self._exit_stack.push_async_callback(stop_server)\n\n try:\n await asyncio.wait_for(started.wait(), timeout=self.timeout)\n except Exception: # nocov\n # see if we can await the future for a more helpful error\n await asyncio.wait_for(server_future, timeout=self.timeout)\n raise\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n await self._exit_stack.aclose()\n\n self.mount(None) # reset the view\n\n logged_errors = self.list_logged_exceptions(del_log_records=False)\n if logged_errors: # nocov\n msg = \"Unexpected logged exception\"\n raise LogAssertionError(msg) from logged_errors[0]\n\n\n_MountFunc = Callable[[\"Callable[[], Any] | None\"], None]\n\n\ndef _hotswap(update_on_change: bool = False) -> tuple[_MountFunc, ComponentConstructor]:\n \"\"\"Swap out components 
from a layout on the fly.\n\n Since you can't change the component functions used to create a layout\n in an imperative manner, you can use ``hotswap`` to do this so\n long as you set things up ahead of time.\n\n Parameters:\n update_on_change: Whether or not all views of the layout should be updated on a swap.\n\n Example:\n .. code-block:: python\n\n import reactpy\n\n show, root = reactpy.hotswap()\n PerClientStateServer(root).run_in_thread(\"localhost\", 8765)\n\n @reactpy.component\n def DivOne(self):\n return {\"tagName\": \"div\", \"children\": [1]}\n\n show(DivOne)\n\n # displaying the output now will show DivOne\n\n @reactpy.component\n def DivTwo(self):\n return {\"tagName\": \"div\", \"children\": [2]}\n\n show(DivTwo)\n\n # displaying the output now will show DivTwo\n \"\"\"\n constructor_ref: Ref[Callable[[], Any]] = Ref(lambda: None)\n\n if update_on_change:\n set_constructor_callbacks: set[Callable[[Callable[[], Any]], None]] = set()\n\n @component\n def HotSwap() -> Any:\n # new displays will adopt the latest constructor and arguments\n constructor, _set_constructor = use_state(lambda: constructor_ref.current)\n set_constructor = use_callback(lambda new: _set_constructor(lambda _: new))\n\n def add_callback() -> Callable[[], None]:\n set_constructor_callbacks.add(set_constructor)\n return lambda: set_constructor_callbacks.remove(set_constructor)\n\n use_effect(add_callback)\n\n return constructor()\n\n def swap(constructor: Callable[[], Any] | None) -> None:\n constructor = constructor_ref.current = constructor or (lambda: None)\n\n for set_constructor in set_constructor_callbacks:\n set_constructor(constructor)\n\n else:\n\n @component\n def HotSwap() -> Any:\n return constructor_ref.current()\n\n def swap(constructor: Callable[[], Any] | None) -> None:\n constructor_ref.current = constructor or (lambda: None)\n\n return swap, HotSwap\n\n# Path: src/py/reactpy/reactpy/testing/common.py\nfrom __future__ import annotations\n\nimport asyncio\nimport inspect\nimport shutil\nimport time\nfrom collections.abc import Awaitable\nfrom functools import wraps\nfrom typing import Any, Callable, Generic, TypeVar, cast\nfrom uuid import uuid4\nfrom weakref import ref\n\nfrom typing_extensions import ParamSpec\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT, REACTPY_WEB_MODULES_DIR\nfrom reactpy.core._life_cycle_hook import LifeCycleHook, current_hook\nfrom reactpy.core.events import EventHandler, to_event_handler_function\n\n\ndef clear_reactpy_web_modules_dir() -> None:\n \"\"\"Clear the directory where ReactPy stores registered web modules\"\"\"\n for path in REACTPY_WEB_MODULES_DIR.current.iterdir():\n shutil.rmtree(path) if path.is_dir() else path.unlink()\n\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\n_DEFAULT_POLL_DELAY = 0.1\n\n\nclass poll(Generic[_R]): # noqa: N801\n \"\"\"Wait until the result of an sync or async function meets some condition\"\"\"\n\n def __init__(\n self,\n function: Callable[_P, Awaitable[_R] | _R],\n *args: _P.args,\n **kwargs: _P.kwargs,\n ) -> None:\n coro: Callable[_P, Awaitable[_R]]\n if not inspect.iscoroutinefunction(function):\n\n async def coro(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n return cast(_R, function(*args, **kwargs))\n\n else:\n coro = cast(Callable[_P, Awaitable[_R]], function)\n self._func = coro\n self._args = args\n self._kwargs = kwargs\n\n async def until(\n self,\n condition: Callable[[_R], bool],\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n 
description: str = \"condition to be true\",\n ) -> None:\n \"\"\"Check that the coroutines result meets a condition within the timeout\"\"\"\n started_at = time.time()\n while True:\n await asyncio.sleep(delay)\n result = await self._func(*self._args, **self._kwargs)\n if condition(result):\n break\n elif (time.time() - started_at) > timeout: # nocov\n msg = f\"Expected {description} after {timeout} seconds - last value was {result!r}\"\n raise asyncio.TimeoutError(msg)\n\n async def until_is(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is identical to the given value\"\"\"\n return await self.until(\n lambda left: left is right,\n timeout,\n delay,\n f\"value to be identical to {right!r}\",\n )\n\n async def until_equals(\n self,\n right: _R,\n timeout: float = REACTPY_TESTING_DEFAULT_TIMEOUT.current,\n delay: float = _DEFAULT_POLL_DELAY,\n ) -> None:\n \"\"\"Wait until the result is equal to the given value\"\"\"\n return await self.until(\n lambda left: left == right,\n timeout,\n delay,\n f\"value to equal {right!r}\",\n )\n\n\nclass HookCatcher:\n \"\"\"Utility for capturing a LifeCycleHook from a component\n\n Example:\n .. code-block::\n\n hooks = HookCatcher(index_by_kwarg=\"thing\")\n\n @reactpy.component\n @hooks.capture\n def MyComponent(thing):\n ...\n\n ... # render the component\n\n # grab the last render of where MyComponent(thing='something')\n hooks.index[\"something\"]\n # or grab the hook from the component's last render\n hooks.latest\n\n After the first render of ``MyComponent`` the ``HookCatcher`` will have\n captured the component's ``LifeCycleHook``.\n \"\"\"\n\n latest: LifeCycleHook\n\n def __init__(self, index_by_kwarg: str | None = None):\n self.index_by_kwarg = index_by_kwarg\n self.index: dict[Any, LifeCycleHook] = {}\n\n def capture(self, render_function: Callable[..., Any]) -> Callable[..., Any]:\n \"\"\"Decorator for capturing a ``LifeCycleHook`` on each render of a component\"\"\"\n\n # The render function holds a reference to `self` and, via the `LifeCycleHook`,\n # the component. Some tests check whether components are garbage collected, thus\n # we must use a `ref` here to ensure these checks pass once the catcher itself\n # has been collected.\n self_ref = ref(self)\n\n @wraps(render_function)\n def wrapper(*args: Any, **kwargs: Any) -> Any:\n self = self_ref()\n if self is None:\n raise RuntimeError(\"Hook catcher has been garbage collected\")\n\n hook = current_hook()\n if self.index_by_kwarg is not None:\n self.index[kwargs[self.index_by_kwarg]] = hook\n self.latest = hook\n return render_function(*args, **kwargs)\n\n return wrapper\n\n\nclass StaticEventHandler:\n \"\"\"Utility for capturing the target of one event handler\n\n Example:\n .. code-block::\n\n static_handler = StaticEventHandler()\n\n @reactpy.component\n def MyComponent():\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handler.use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # gives the target ID for onClick where from the last render of MyComponent\n static_handlers.target\n\n If you need to capture event handlers from different instances of a component\n the you should create multiple ``StaticEventHandler`` instances.\n\n .. 
code-block::\n\n static_handlers_by_key = {\n \"first\": StaticEventHandler(),\n \"second\": StaticEventHandler(),\n }\n\n @reactpy.component\n def Parent():\n return reactpy.html.div(Child(key=\"first\"), Child(key=\"second\"))\n\n @reactpy.component\n def Child(key):\n state, set_state = reactpy.hooks.use_state(0)\n handler = static_handlers_by_key[key].use(lambda event: set_state(state + 1))\n return reactpy.html.button({\"onClick\": handler}, \"Click me!\")\n\n # grab the individual targets for each instance above\n first_target = static_handlers_by_key[\"first\"].target\n second_target = static_handlers_by_key[\"second\"].target\n \"\"\"\n\n def __init__(self) -> None:\n self.target = uuid4().hex\n\n def use(\n self,\n function: Callable[..., Any],\n stop_propagation: bool = False,\n prevent_default: bool = False,\n ) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function),\n stop_propagation,\n prevent_default,\n self.target,\n )\n\n# Path: src/py/reactpy/reactpy/testing/display.py\nfrom __future__ import annotations\n\nfrom contextlib import AsyncExitStack\nfrom types import TracebackType\nfrom typing import Any\n\nfrom playwright.async_api import (\n Browser,\n BrowserContext,\n ElementHandle,\n Page,\n async_playwright,\n)\n\nfrom reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.types import RootComponentConstructor\n\n\nclass DisplayFixture:\n \"\"\"A fixture for running web-based tests using ``playwright``\"\"\"\n\n _exit_stack: AsyncExitStack\n\n def __init__(\n self,\n backend: BackendFixture | None = None,\n driver: Browser | BrowserContext | Page | None = None,\n url_prefix: str = \"\",\n ) -> None:\n if backend is not None:\n self.backend = backend\n if driver is not None:\n if isinstance(driver, Page):\n self.page = driver\n else:\n self._browser = driver\n self.url_prefix = url_prefix\n\n async def show(\n self,\n component: RootComponentConstructor,\n ) -> None:\n self.backend.mount(component)\n await self.goto(\"/\")\n await self.root_element() # check that root element is attached\n\n async def goto(\n self, path: str, query: Any | None = None, add_url_prefix: bool = True\n ) -> None:\n await self.page.goto(\n self.backend.url(\n f\"{self.url_prefix}{path}\" if add_url_prefix else path, query\n )\n )\n\n async def root_element(self) -> ElementHandle:\n element = await self.page.wait_for_selector(\"#app\", state=\"attached\")\n if element is None: # nocov\n msg = \"Root element not attached\"\n raise RuntimeError(msg)\n return element\n\n async def __aenter__(self) -> DisplayFixture:\n es = self._exit_stack = AsyncExitStack()\n\n browser: Browser | BrowserContext\n if not hasattr(self, \"page\"):\n if not hasattr(self, \"_browser\"):\n pw = await es.enter_async_context(async_playwright())\n browser = await pw.chromium.launch()\n else:\n browser = self._browser\n self.page = await browser.new_page()\n\n self.page.set_default_timeout(REACTPY_TESTING_DEFAULT_TIMEOUT.current * 1000)\n\n if not hasattr(self, \"backend\"):\n self.backend = BackendFixture()\n await es.enter_async_context(self.backend)\n\n return self\n\n async def __aexit__(\n self,\n exc_type: type[BaseException] | None,\n exc_value: BaseException | None,\n traceback: TracebackType | None,\n ) -> None:\n self.backend.mount(None)\n await self._exit_stack.aclose()\n\n# Path: src/py/reactpy/reactpy/testing/__init__.py\nfrom reactpy.testing.backend import BackendFixture\nfrom reactpy.testing.common import (\n 
HookCatcher,\n StaticEventHandler,\n clear_reactpy_web_modules_dir,\n poll,\n)\nfrom reactpy.testing.display import DisplayFixture\nfrom reactpy.testing.logs import (\n LogAssertionError,\n assert_reactpy_did_log,\n assert_reactpy_did_not_log,\n capture_reactpy_logs,\n)\n\n__all__ = [\n \"assert_reactpy_did_not_log\",\n \"assert_reactpy_did_log\",\n \"capture_reactpy_logs\",\n \"clear_reactpy_web_modules_dir\",\n \"DisplayFixture\",\n \"HookCatcher\",\n \"LogAssertionError\",\n \"poll\",\n \"BackendFixture\",\n \"StaticEventHandler\",\n]\n\n# Path: src/py/reactpy/reactpy/widgets.py\nfrom __future__ import annotations\n\nfrom base64 import b64encode\nfrom collections.abc import Sequence\nfrom typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar\n\nimport reactpy\nfrom reactpy import html\nfrom reactpy._warnings import warn\nfrom reactpy.core.types import ComponentConstructor, VdomDict\n\n\ndef image(\n format: str,\n value: str | bytes = \"\",\n attributes: dict[str, Any] | None = None,\n) -> VdomDict:\n \"\"\"Utility for constructing an image from a string or bytes\n\n The source value will automatically be encoded to base64\n \"\"\"\n if format == \"svg\":\n format = \"svg+xml\" # noqa: A001\n\n if isinstance(value, str):\n bytes_value = value.encode()\n else:\n bytes_value = value\n\n base64_value = b64encode(bytes_value).decode()\n src = f\"data:image/{format};base64,{base64_value}\"\n\n return {\"tagName\": \"img\", \"attributes\": {\"src\": src, **(attributes or {})}}\n\n\n_Value = TypeVar(\"_Value\")\n\n\ndef use_linked_inputs(\n attributes: Sequence[dict[str, Any]],\n on_change: Callable[[_Value], None] = lambda value: None,\n cast: _CastFunc[_Value] = lambda value: value,\n initial_value: str = \"\",\n ignore_empty: bool = True,\n) -> list[VdomDict]:\n \"\"\"Return a list of linked inputs equal to the number of given attributes.\n\n Parameters:\n attributes:\n That attributes of each returned input element. If the number of generated\n inputs is variable, you may need to assign each one a\n :ref:`key ` by including a ``\"key\"`` in each\n attribute dictionary.\n on_change:\n A callback which is triggered when any input is changed. 
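As a brief sketch of the hook documented above (the attribute dictionaries, `cast` choice, and callback are illustrative assumptions, not part of the quoted source):

```python
# Sketch: two inputs that always display the same value, with changes cast to float.
from reactpy import component, html
from reactpy.widgets import use_linked_inputs


@component
def LinkedNumberInputs():
    first, second = use_linked_inputs(
        [{"id": "first"}, {"id": "second"}],  # one attribute dict per rendered input
        on_change=lambda value: print(f"new value: {value}"),
        cast=float,
        initial_value="0",
    )
    return html.div(first, second)
```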
This callback need\n not update the 'value' field in the attributes of the inputs since that is\n handled automatically.\n cast:\n Cast the 'value' of changed inputs that is passed to ``on_change``.\n initial_value:\n Initialize the 'value' field of the inputs.\n ignore_empty:\n Do not trigger ``on_change`` if the 'value' is an empty string.\n \"\"\"\n value, set_value = reactpy.hooks.use_state(initial_value)\n\n def sync_inputs(event: dict[str, Any]) -> None:\n new_value = event[\"target\"][\"value\"]\n set_value(new_value)\n if not new_value and ignore_empty:\n return None\n on_change(cast(new_value))\n\n inputs: list[VdomDict] = []\n for attrs in attributes:\n inputs.append(html.input({**attrs, \"on_change\": sync_inputs, \"value\": value}))\n\n return inputs\n\n\n_CastTo_co = TypeVar(\"_CastTo_co\", covariant=True)\n\n\nclass _CastFunc(Protocol[_CastTo_co]):\n def __call__(self, value: str) -> _CastTo_co: ...\n\n\nif TYPE_CHECKING:\n from reactpy.testing.backend import _MountFunc\n\n\ndef hotswap(\n update_on_change: bool = False,\n) -> tuple[_MountFunc, ComponentConstructor]: # nocov\n warn(\n \"The 'hotswap' function is deprecated and will be removed in a future release\",\n DeprecationWarning,\n stacklevel=2,\n )\n from reactpy.testing.backend import _hotswap\n\n return _hotswap(update_on_change)\n\n# Path: src/py/reactpy/reactpy/__init__.py\nfrom reactpy import backend, config, html, logging, sample, svg, types, web, widgets\nfrom reactpy.backend.hooks import use_connection, use_location, use_scope\nfrom reactpy.backend.utils import run\nfrom reactpy.core import hooks\nfrom reactpy.core.component import component\nfrom reactpy.core.events import event\nfrom reactpy.core.hooks import (\n create_context,\n use_callback,\n use_context,\n use_debug_value,\n use_effect,\n use_memo,\n use_reducer,\n use_ref,\n use_state,\n)\nfrom reactpy.core.layout import Layout\nfrom reactpy.core.vdom import vdom\nfrom reactpy.utils import Ref, html_to_vdom, vdom_to_html\n\n__author__ = \"The Reactive Python Team\"\n__version__ = \"1.0.2\" # DO NOT MODIFY\n\n__all__ = [\n \"backend\",\n \"component\",\n \"config\",\n \"create_context\",\n \"event\",\n \"hooks\",\n \"html_to_vdom\",\n \"html\",\n \"Layout\",\n \"logging\",\n \"Ref\",\n \"run\",\n \"sample\",\n \"Stop\",\n \"svg\",\n \"types\",\n \"use_callback\",\n \"use_connection\",\n \"use_context\",\n \"use_debug_value\",\n \"use_effect\",\n \"use_location\",\n \"use_memo\",\n \"use_reducer\",\n \"use_ref\",\n \"use_scope\",\n \"use_state\",\n \"vdom_to_html\",\n \"vdom\",\n \"web\",\n \"widgets\",\n]\n\n# Path: src/py/reactpy/reactpy/_console/__init__.py\n\n# Path: src/py/reactpy/reactpy/_console/ast_utils.py\nfrom __future__ import annotations\n\nimport ast\nfrom collections.abc import Iterator, Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom textwrap import indent\nfrom tokenize import COMMENT as COMMENT_TOKEN\nfrom tokenize import generate_tokens\nfrom typing import Any\n\nimport click\n\nfrom reactpy import html\n\n\ndef rewrite_changed_nodes(\n file: Path,\n source: str,\n tree: ast.AST,\n changed: list[ChangedNode],\n) -> str:\n ast.fix_missing_locations(tree)\n\n lines = source.split(\"\\n\")\n\n # find closest parent nodes that should be re-written\n nodes_to_unparse: list[ast.AST] = []\n for change in changed:\n node_lineage = [change.node, *change.parents]\n for i in range(len(node_lineage) - 1):\n current_node, next_node = node_lineage[i : i + 2]\n if (\n not hasattr(next_node, \"lineno\")\n or 
next_node.lineno < change.node.lineno\n or isinstance(next_node, (ast.ClassDef, ast.FunctionDef))\n ):\n nodes_to_unparse.append(current_node)\n break\n else: # nocov\n msg = \"Failed to change code\"\n raise RuntimeError(msg)\n\n # check if an nodes to rewrite contain each other, pick outermost nodes\n current_outermost_node, *sorted_nodes_to_unparse = sorted(\n nodes_to_unparse, key=lambda n: n.lineno\n )\n outermost_nodes_to_unparse = [current_outermost_node]\n for node in sorted_nodes_to_unparse:\n if (\n not current_outermost_node.end_lineno\n or node.lineno > current_outermost_node.end_lineno\n ):\n current_outermost_node = node\n outermost_nodes_to_unparse.append(node)\n\n moved_comment_lines_from_end: list[int] = []\n # now actually rewrite these nodes (in reverse to avoid changes earlier in file)\n for node in reversed(outermost_nodes_to_unparse):\n # make a best effort to preserve any comments that we're going to overwrite\n comments = _find_comments(lines[node.lineno - 1 : node.end_lineno])\n\n # there may be some content just before and after the content we're re-writing\n before_replacement = lines[node.lineno - 1][: node.col_offset].lstrip()\n\n after_replacement = (\n lines[node.end_lineno - 1][node.end_col_offset :].strip()\n if node.end_lineno is not None and node.end_col_offset is not None\n else \"\"\n )\n\n replacement = indent(\n before_replacement\n + \"\\n\".join([*comments, ast.unparse(node)])\n + after_replacement,\n \" \" * (node.col_offset - len(before_replacement)),\n )\n\n lines[node.lineno - 1 : node.end_lineno or node.lineno] = [replacement]\n\n if comments:\n moved_comment_lines_from_end.append(len(lines) - node.lineno)\n\n for lineno_from_end in sorted(set(moved_comment_lines_from_end)):\n click.echo(f\"Moved comments to {file}:{len(lines) - lineno_from_end}\")\n\n return \"\\n\".join(lines)\n\n\n@dataclass\nclass ChangedNode:\n node: ast.AST\n parents: Sequence[ast.AST]\n\n\ndef find_element_constructor_usages(\n tree: ast.AST, add_props: bool = False\n) -> Iterator[ElementConstructorInfo]:\n changed: list[Sequence[ast.AST]] = []\n for parents, node in _walk_with_parent(tree):\n if not (isinstance(node, ast.Call)):\n continue\n\n func = node.func\n if isinstance(func, ast.Attribute) and (\n (isinstance(func.value, ast.Name) and func.value.id == \"html\")\n or (isinstance(func.value, ast.Attribute) and func.value.attr == \"html\")\n ):\n name = func.attr\n elif isinstance(func, ast.Name):\n name = func.id\n else:\n continue\n\n maybe_attr_dict_node: Any | None = None\n\n if name == \"vdom\":\n if len(node.args) == 0:\n continue\n elif len(node.args) == 1:\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.append(maybe_attr_dict_node)\n else:\n continue\n elif isinstance(node.args[1], (ast.Constant, ast.JoinedStr)):\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.insert(1, maybe_attr_dict_node)\n else:\n continue\n elif len(node.args) >= 2: # noqa: PLR2004\n maybe_attr_dict_node = node.args[1]\n elif hasattr(html, name):\n if len(node.args) == 0:\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.append(maybe_attr_dict_node)\n else:\n continue\n elif isinstance(node.args[0], (ast.Constant, ast.JoinedStr)):\n maybe_attr_dict_node = ast.Dict(keys=[], values=[])\n if add_props:\n node.args.insert(0, maybe_attr_dict_node)\n else:\n continue\n else:\n maybe_attr_dict_node = node.args[0]\n\n if not maybe_attr_dict_node:\n continue\n\n if isinstance(maybe_attr_dict_node, 
ast.Dict) or (\n isinstance(maybe_attr_dict_node, ast.Call)\n and isinstance(maybe_attr_dict_node.func, ast.Name)\n and maybe_attr_dict_node.func.id == \"dict\"\n and isinstance(maybe_attr_dict_node.func.ctx, ast.Load)\n ):\n yield ElementConstructorInfo(node, maybe_attr_dict_node, parents)\n\n return changed\n\n\n@dataclass\nclass ElementConstructorInfo:\n call: ast.Call\n props: ast.Dict | ast.Call\n parents: Sequence[ast.AST]\n\n\ndef _find_comments(lines: list[str]) -> list[str]:\n iter_lines = iter(lines)\n return [\n token\n for token_type, token, _, _, _ in generate_tokens(lambda: next(iter_lines))\n if token_type == COMMENT_TOKEN\n ]\n\n\ndef _walk_with_parent(\n node: ast.AST, parents: tuple[ast.AST, ...] = ()\n) -> Iterator[tuple[tuple[ast.AST, ...], ast.AST]]:\n parents = (node, *parents)\n for child in ast.iter_child_nodes(node):\n yield parents, child\n...\n", "needle_token_start": 8800, "needle_token_end": 8853, "code_context_ntokens": 16034, "output": ["```python\nasync def stop_server() -> None:\n server_future.cancel()\n with suppress(asyncio.CancelledError):\n await asyncio.wait_for(server_future, timeout=self.timeout)\n```"]} +{"repo": "reactive-python/reactpy", "name": "create_context", "language": "python", "path": "src/py/reactpy/reactpy/core/hooks.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to generate a new context type that can be utilized within a specific context management function to share data across different parts of an application without having to pass props down manually through every level of the component tree.\n2. **Input**: A single parameter representing the default value for the context.\n3. **Output**: Returns a context object that can be used to provide and consume values in different parts of the application.\n4. **Procedure**: The function defines an inner function that acts as a context provider, capable of accepting children components and an optional value to override the default. 
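To make the provider/consumer relationship described here concrete, a short sketch of how such a context is typically used; the theme values and component names are invented, while `create_context` and `use_context` are the ReactPy hooks this record refers to.

```python
# Sketch: a context created with create_context acts as a provider component,
# and descendant components read its value with use_context.
from reactpy import component, create_context, html, use_context

ThemeContext = create_context("light")  # "light" is the default value


@component
def ThemedLabel():
    theme = use_context(ThemeContext)  # nearest provided value, else the default
    return html.p(f"current theme: {theme}")


@component
def App():
    # Calling the context object as a component provides a value to its children.
    return ThemeContext(ThemedLabel(), value="dark")
```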
This inner function is then returned as the context object, which can be used in conjunction with a context consumer function to access the provided value in descendant components.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/core/_life_cycle_hook.py\nfrom __future__ import annotations\n\nimport logging\nfrom asyncio import Event, Task, create_task, gather\nfrom typing import Any, Callable, Protocol, TypeVar\n\nfrom anyio import Semaphore\n...\n# Path: src/py/reactpy/reactpy/core/_f_back.py\nfrom __future__ import annotations\n\nimport inspect\nfrom types import FrameType\n\n\ndef f_module_name(index: int = 0) -> str:\n frame = f_back(index + 1)\n if frame is None:\n return \"\" # nocov\n name = frame.f_globals.get(\"__name__\", \"\")\n if not isinstance(name, str):\n raise TypeError(\"Expected module name to be a string\") # nocov\n return name\n\n\ndef f_back(index: int = 0) -> FrameType | None:\n frame = inspect.currentframe()\n while frame is not None:\n if index < 0:\n return frame\n frame = frame.f_back\n index -= 1\n return None # nocov\n\n# Path: src/py/reactpy/reactpy/core/events.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import Sequence\nfrom typing import Any, Callable, Literal, overload\n\nfrom anyio import create_task_group\n\nfrom reactpy.core.types import EventHandlerFunc, EventHandlerType\n\n\n@overload\ndef event(\n function: Callable[..., Any],\n *,\n stop_propagation: bool = ...,\n prevent_default: bool = ...,\n) -> EventHandler: ...\n\n\n@overload\ndef event(\n function: Literal[None] = ...,\n *,\n stop_propagation: bool = ...,\n prevent_default: bool = ...,\n) -> Callable[[Callable[..., Any]], EventHandler]: ...\n\n\ndef event(\n function: Callable[..., Any] | None = None,\n *,\n stop_propagation: bool = False,\n prevent_default: bool = False,\n) -> EventHandler | Callable[[Callable[..., Any]], EventHandler]:\n \"\"\"A decorator for constructing an :class:`EventHandler`.\n\n While you're always free to add callbacks by assigning them to an element's attributes\n\n .. code-block:: python\n\n element = reactpy.html.button({\"onClick\": my_callback})\n\n You may want the ability to prevent the default action associated with the event\n from taking place, or stopping the event from propagating up the DOM. This decorator\n allows you to add that functionality to your callbacks.\n\n .. 
code-block:: python\n\n @event(stop_propagation=True, prevent_default=True)\n def my_callback(*data):\n ...\n\n element = reactpy.html.button({\"onClick\": my_callback})\n\n Parameters:\n function:\n A function or coroutine responsible for handling the event.\n stop_propagation:\n Block the event from propagating further up the DOM.\n prevent_default:\n Stops the default actional associate with the event from taking place.\n \"\"\"\n\n def setup(function: Callable[..., Any]) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function, positional_args=True),\n stop_propagation,\n prevent_default,\n )\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass EventHandler:\n \"\"\"Turn a function or coroutine into an event handler\n\n Parameters:\n function:\n The function or coroutine which handles the event.\n stop_propagation:\n Block the event from propagating further up the DOM.\n prevent_default:\n Stops the default action associate with the event from taking place.\n target:\n A unique identifier for this event handler (auto-generated by default)\n \"\"\"\n\n __slots__ = (\n \"__weakref__\",\n \"function\",\n \"prevent_default\",\n \"stop_propagation\",\n \"target\",\n )\n\n def __init__(\n self,\n function: EventHandlerFunc,\n stop_propagation: bool = False,\n prevent_default: bool = False,\n target: str | None = None,\n ) -> None:\n self.function = to_event_handler_function(function, positional_args=False)\n self.prevent_default = prevent_default\n self.stop_propagation = stop_propagation\n self.target = target\n\n def __eq__(self, other: Any) -> bool:\n undefined = object()\n for attr in (\n \"function\",\n \"prevent_default\",\n \"stop_propagation\",\n \"target\",\n ):\n if not attr.startswith(\"_\"):\n if not getattr(other, attr, undefined) == getattr(self, attr):\n return False\n return True\n\n def __repr__(self) -> str:\n public_names = [name for name in self.__slots__ if not name.startswith(\"_\")]\n items = \", \".join([f\"{n}={getattr(self, n)!r}\" for n in public_names])\n return f\"{type(self).__name__}({items})\"\n\n\ndef to_event_handler_function(\n function: Callable[..., Any],\n positional_args: bool = True,\n) -> EventHandlerFunc:\n \"\"\"Make a :data:`~reactpy.core.proto.EventHandlerFunc` from a function or coroutine\n\n Parameters:\n function:\n A function or coroutine accepting a number of positional arguments.\n positional_args:\n Whether to pass the event parameters a positional args or as a list.\n \"\"\"\n if positional_args:\n if asyncio.iscoroutinefunction(function):\n\n async def wrapper(data: Sequence[Any]) -> None:\n await function(*data)\n\n else:\n\n async def wrapper(data: Sequence[Any]) -> None:\n function(*data)\n\n return wrapper\n elif not asyncio.iscoroutinefunction(function):\n\n async def wrapper(data: Sequence[Any]) -> None:\n function(data)\n\n return wrapper\n else:\n return function\n\n\ndef merge_event_handlers(\n event_handlers: Sequence[EventHandlerType],\n) -> EventHandlerType:\n \"\"\"Merge multiple event handlers into one\n\n Raises a ValueError if any handlers have conflicting\n :attr:`~reactpy.core.proto.EventHandlerType.stop_propagation` or\n :attr:`~reactpy.core.proto.EventHandlerType.prevent_default` attributes.\n \"\"\"\n if not event_handlers:\n msg = \"No event handlers to merge\"\n raise ValueError(msg)\n elif len(event_handlers) == 1:\n return event_handlers[0]\n\n first_handler = event_handlers[0]\n\n stop_propagation = first_handler.stop_propagation\n prevent_default = 
first_handler.prevent_default\n target = first_handler.target\n\n for handler in event_handlers:\n if (\n handler.stop_propagation != stop_propagation\n or handler.prevent_default != prevent_default\n or handler.target != target\n ):\n msg = \"Cannot merge handlers - 'stop_propagation', 'prevent_default' or 'target' mismatch.\"\n raise ValueError(msg)\n\n return EventHandler(\n merge_event_handler_funcs([h.function for h in event_handlers]),\n stop_propagation,\n prevent_default,\n target,\n )\n\n\ndef merge_event_handler_funcs(\n functions: Sequence[EventHandlerFunc],\n) -> EventHandlerFunc:\n \"\"\"Make one event handler function from many\"\"\"\n if not functions:\n msg = \"No event handler functions to merge\"\n raise ValueError(msg)\n elif len(functions) == 1:\n return functions[0]\n\n async def await_all_event_handlers(data: Sequence[Any]) -> None:\n async with create_task_group() as group:\n for func in functions:\n group.start_soon(func, data)\n\n return await_all_event_handlers\n\n# Path: src/py/reactpy/reactpy/core/vdom.py\nfrom __future__ import annotations\n\nimport json\nfrom collections.abc import Mapping, Sequence\nfrom functools import wraps\nfrom typing import Any, Protocol, cast, overload\n\nfrom fastjsonschema import compile as compile_json_schema\n\nfrom reactpy._warnings import warn\nfrom reactpy.config import REACTPY_CHECK_JSON_ATTRS, REACTPY_DEBUG_MODE\nfrom reactpy.core._f_back import f_module_name\nfrom reactpy.core.events import EventHandler, to_event_handler_function\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n EventHandlerType,\n ImportSourceDict,\n Key,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomDictConstructor,\n VdomJson,\n)\n\nVDOM_JSON_SCHEMA = {\n \"$schema\": \"http://json-schema.org/draft-07/schema\",\n \"$ref\": \"#/definitions/element\",\n \"definitions\": {\n \"element\": {\n \"type\": \"object\",\n \"properties\": {\n \"tagName\": {\"type\": \"string\"},\n \"key\": {\"type\": [\"string\", \"number\", \"null\"]},\n \"error\": {\"type\": \"string\"},\n \"children\": {\"$ref\": \"#/definitions/elementChildren\"},\n \"attributes\": {\"type\": \"object\"},\n \"eventHandlers\": {\"$ref\": \"#/definitions/elementEventHandlers\"},\n \"importSource\": {\"$ref\": \"#/definitions/importSource\"},\n },\n # The 'tagName' is required because its presence is a useful indicator of\n # whether a dictionary describes a VDOM model or not.\n \"required\": [\"tagName\"],\n \"dependentSchemas\": {\n # When 'error' is given, the 'tagName' should be empty.\n \"error\": {\"properties\": {\"tagName\": {\"maxLength\": 0}}}\n },\n },\n \"elementChildren\": {\n \"type\": \"array\",\n \"items\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"elementEventHandlers\": {\n \"type\": \"object\",\n \"patternProperties\": {\n \".*\": {\"$ref\": \"#/definitions/eventHandler\"},\n },\n },\n \"eventHandler\": {\n \"type\": \"object\",\n \"properties\": {\n \"target\": {\"type\": \"string\"},\n \"preventDefault\": {\"type\": \"boolean\"},\n \"stopPropagation\": {\"type\": \"boolean\"},\n },\n \"required\": [\"target\"],\n },\n \"importSource\": {\n \"type\": \"object\",\n \"properties\": {\n \"source\": {\"type\": \"string\"},\n \"sourceType\": {\"enum\": [\"URL\", \"NAME\"]},\n \"fallback\": {\n \"type\": [\"object\", \"string\", \"null\"],\n \"if\": {\"not\": {\"type\": \"null\"}},\n \"then\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"unmountBeforeUpdate\": {\"type\": \"boolean\"},\n },\n \"required\": 
[\"source\"],\n },\n \"elementOrString\": {\n \"type\": [\"object\", \"string\"],\n \"if\": {\"type\": \"object\"},\n \"then\": {\"$ref\": \"#/definitions/element\"},\n },\n },\n}\n\"\"\"JSON Schema describing serialized VDOM - see :ref:`VDOM` for more info\"\"\"\n\n\n# we can't add a docstring to this because Sphinx doesn't know how to find its source\n_COMPILED_VDOM_VALIDATOR = compile_json_schema(VDOM_JSON_SCHEMA)\n\n\ndef validate_vdom_json(value: Any) -> VdomJson:\n \"\"\"Validate serialized VDOM - see :attr:`VDOM_JSON_SCHEMA` for more info\"\"\"\n _COMPILED_VDOM_VALIDATOR(value)\n return cast(VdomJson, value)\n\n\ndef is_vdom(value: Any) -> bool:\n \"\"\"Return whether a value is a :class:`VdomDict`\n\n This employs a very simple heuristic - something is VDOM if:\n\n 1. It is a ``dict`` instance\n 2. It contains the key ``\"tagName\"``\n 3. The value of the key ``\"tagName\"`` is a string\n\n .. note::\n\n Performing an ``isinstance(value, VdomDict)`` check is too restrictive since the\n user would be forced to import ``VdomDict`` every time they needed to declare a\n VDOM element. Giving the user more flexibility, at the cost of this check's\n accuracy, is worth it.\n \"\"\"\n return (\n isinstance(value, dict)\n and \"tagName\" in value\n and isinstance(value[\"tagName\"], str)\n )\n\n\n@overload\ndef vdom(tag: str, *children: VdomChildren) -> VdomDict: ...\n\n\n@overload\ndef vdom(tag: str, attributes: VdomAttributes, *children: VdomChildren) -> VdomDict: ...\n\n\ndef vdom(\n tag: str,\n *attributes_and_children: Any,\n **kwargs: Any,\n) -> VdomDict:\n \"\"\"A helper function for creating VDOM elements.\n\n Parameters:\n tag:\n The type of element (e.g. 'div', 'h1', 'img')\n attributes_and_children:\n An optional attribute mapping followed by any number of children or\n iterables of children. The attribute mapping **must** precede the children,\n or children which will be merged into their respective parts of the model.\n key:\n A string indicating the identity of a particular element. This is significant\n to preserve event handlers across updates - without a key, a re-render would\n cause these handlers to be deleted, but with a key, they would be redirected\n to any newly defined handlers.\n event_handlers:\n Maps event types to coroutines that are responsible for handling those events.\n import_source:\n (subject to change) specifies javascript that, when evaluated returns a\n React component.\n \"\"\"\n if kwargs: # nocov\n if \"key\" in kwargs:\n if attributes_and_children:\n maybe_attributes, *children = attributes_and_children\n if _is_attributes(maybe_attributes):\n attributes_and_children = (\n {**maybe_attributes, \"key\": kwargs.pop(\"key\")},\n *children,\n )\n else:\n attributes_and_children = (\n {\"key\": kwargs.pop(\"key\")},\n maybe_attributes,\n *children,\n )\n else:\n attributes_and_children = ({\"key\": kwargs.pop(\"key\")},)\n warn(\n \"An element's 'key' must be declared in an attribute dict instead \"\n \"of as a keyword argument. 
This will error in a future version.\",\n DeprecationWarning,\n )\n\n if kwargs:\n msg = f\"Extra keyword arguments {kwargs}\"\n raise ValueError(msg)\n\n model: VdomDict = {\"tagName\": tag}\n\n if not attributes_and_children:\n return model\n\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n\n if attributes:\n if REACTPY_CHECK_JSON_ATTRS.current:\n json.dumps(attributes)\n model[\"attributes\"] = attributes\n\n if children:\n model[\"children\"] = children\n\n if key is not None:\n model[\"key\"] = key\n\n if event_handlers:\n model[\"eventHandlers\"] = event_handlers\n\n return model\n\n\ndef make_vdom_constructor(\n tag: str, allow_children: bool = True, import_source: ImportSourceDict | None = None\n) -> VdomDictConstructor:\n \"\"\"Return a constructor for VDOM dictionaries with the given tag name.\n\n The resulting callable will have the same interface as :func:`vdom` but without its\n first ``tag`` argument.\n \"\"\"\n\n def constructor(*attributes_and_children: Any, **kwargs: Any) -> VdomDict:\n model = vdom(tag, *attributes_and_children, **kwargs)\n if not allow_children and \"children\" in model:\n msg = f\"{tag!r} nodes cannot have children.\"\n raise TypeError(msg)\n if import_source:\n model[\"importSource\"] = import_source\n return model\n\n # replicate common function attributes\n constructor.__name__ = tag\n constructor.__doc__ = (\n \"Return a new \"\n f\"`<{tag}> `__ \"\n \"element represented by a :class:`VdomDict`.\"\n )\n\n module_name = f_module_name(1)\n if module_name:\n constructor.__module__ = module_name\n constructor.__qualname__ = f\"{module_name}.{tag}\"\n\n return cast(VdomDictConstructor, constructor)\n\n\ndef custom_vdom_constructor(func: _CustomVdomDictConstructor) -> VdomDictConstructor:\n \"\"\"Cast function to VdomDictConstructor\"\"\"\n\n @wraps(func)\n def wrapper(*attributes_and_children: Any) -> VdomDict:\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n return func(attributes, children, key, event_handlers)\n\n return cast(VdomDictConstructor, wrapper)\n\n\ndef separate_attributes_and_children(\n values: Sequence[Any],\n) -> tuple[dict[str, Any], list[Any]]:\n if not values:\n return {}, []\n\n attributes: dict[str, Any]\n children_or_iterables: Sequence[Any]\n if _is_attributes(values[0]):\n attributes, *children_or_iterables = values\n else:\n attributes = {}\n children_or_iterables = values\n\n children: list[Any] = []\n for child in children_or_iterables:\n if _is_single_child(child):\n children.append(child)\n else:\n children.extend(child)\n\n return attributes, children\n\n\ndef separate_attributes_and_event_handlers(\n attributes: Mapping[str, Any]\n) -> tuple[dict[str, Any], EventHandlerDict]:\n separated_attributes = {}\n separated_event_handlers: dict[str, EventHandlerType] = {}\n\n for k, v in attributes.items():\n handler: EventHandlerType\n\n if callable(v):\n handler = EventHandler(to_event_handler_function(v))\n elif (\n # isinstance check on protocols is slow - use function attr pre-check as a\n # quick filter before actually performing slow EventHandlerType type check\n hasattr(v, \"function\")\n and isinstance(v, EventHandlerType)\n ):\n handler = v\n else:\n separated_attributes[k] = v\n continue\n\n 
separated_event_handlers[k] = handler\n\n return separated_attributes, dict(separated_event_handlers.items())\n\n\ndef _is_attributes(value: Any) -> bool:\n return isinstance(value, Mapping) and \"tagName\" not in value\n\n\ndef _is_single_child(value: Any) -> bool:\n if isinstance(value, (str, Mapping)) or not hasattr(value, \"__iter__\"):\n return True\n if REACTPY_DEBUG_MODE.current:\n _validate_child_key_integrity(value)\n return False\n\n\ndef _validate_child_key_integrity(value: Any) -> None:\n if hasattr(value, \"__iter__\") and not hasattr(value, \"__len__\"):\n warn(\n f\"Did not verify key-path integrity of children in generator {value} \"\n \"- pass a sequence (i.e. list of finite length) in order to verify\"\n )\n else:\n for child in value:\n if isinstance(child, ComponentType) and child.key is None:\n warn(f\"Key not specified for child in list {child}\", UserWarning)\n elif isinstance(child, Mapping) and \"key\" not in child:\n # remove 'children' to reduce log spam\n child_copy = {**child, \"children\": _EllipsisRepr()}\n warn(f\"Key not specified for child in list {child_copy}\", UserWarning)\n\n\nclass _CustomVdomDictConstructor(Protocol):\n def __call__(\n self,\n attributes: VdomAttributes,\n children: Sequence[VdomChild],\n key: Key | None,\n event_handlers: EventHandlerDict,\n ) -> VdomDict: ...\n\n\nclass _EllipsisRepr:\n def __repr__(self) -> str:\n return \"...\"\n\n# Path: src/py/reactpy/reactpy/utils.py\nfrom __future__ import annotations\n\nimport re\nfrom collections.abc import Iterable\nfrom itertools import chain\nfrom typing import Any, Callable, Generic, TypeVar, cast\n\nfrom lxml import etree\nfrom lxml.html import fromstring, tostring\n\nfrom reactpy.core.types import VdomDict\nfrom reactpy.core.vdom import vdom\n\n_RefValue = TypeVar(\"_RefValue\")\n_ModelTransform = Callable[[VdomDict], Any]\n_UNDEFINED: Any = object()\n\n\nclass Ref(Generic[_RefValue]):\n \"\"\"Hold a reference to a value\n\n This is used in imperative code to mutate the state of this object in order to\n incur side effects. 
Generally refs should be avoided if possible, but sometimes\n they are required.\n\n Notes:\n You can compare the contents for two ``Ref`` objects using the ``==`` operator.\n \"\"\"\n\n __slots__ = (\"current\",)\n\n def __init__(self, initial_value: _RefValue = _UNDEFINED) -> None:\n if initial_value is not _UNDEFINED:\n self.current = initial_value\n \"\"\"The present value\"\"\"\n\n def set_current(self, new: _RefValue) -> _RefValue:\n \"\"\"Set the current value and return what is now the old value\n\n This is nice to use in ``lambda`` functions.\n \"\"\"\n old = self.current\n self.current = new\n return old\n\n def __eq__(self, other: Any) -> bool:\n try:\n return isinstance(other, Ref) and (other.current == self.current)\n except AttributeError:\n # attribute error occurs for uninitialized refs\n return False\n\n def __repr__(self) -> str:\n try:\n current = repr(self.current)\n except AttributeError:\n # attribute error occurs for uninitialized refs\n current = \"\"\n return f\"{type(self).__name__}({current})\"\n\n\ndef vdom_to_html(vdom: VdomDict) -> str:\n \"\"\"Convert a VDOM dictionary into an HTML string\n\n Only the following keys are translated to HTML:\n\n - ``tagName``\n - ``attributes``\n - ``children`` (must be strings or more VDOM dicts)\n\n Parameters:\n vdom: The VdomDict element to convert to HTML\n \"\"\"\n temp_root = etree.Element(\"__temp__\")\n _add_vdom_to_etree(temp_root, vdom)\n html = cast(bytes, tostring(temp_root)).decode()\n # strip out temp root <__temp__> element\n return html[10:-11]\n\n\ndef html_to_vdom(\n html: str, *transforms: _ModelTransform, strict: bool = True\n) -> VdomDict:\n \"\"\"Transform HTML into a DOM model. Unique keys can be provided to HTML elements\n using a ``key=...`` attribute within your HTML tag.\n\n Parameters:\n html:\n The raw HTML as a string\n transforms:\n Functions of the form ``transform(old) -> new`` where ``old`` is a VDOM\n dictionary which will be replaced by ``new``. For example, you could use a\n transform function to add highlighting to a ```` block.\n strict:\n If ``True``, raise an exception if the HTML does not perfectly follow HTML5\n syntax.\n \"\"\"\n if not isinstance(html, str): # nocov\n msg = f\"Expected html to be a string, not {type(html).__name__}\"\n raise TypeError(msg)\n\n # If the user provided a string, convert it to a list of lxml.etree nodes\n try:\n root_node: etree._Element = fromstring(\n html.strip(),\n parser=etree.HTMLParser(\n remove_comments=True,\n remove_pis=True,\n remove_blank_text=True,\n recover=not strict,\n ),\n )\n except etree.XMLSyntaxError as e:\n if not strict:\n raise e # nocov\n msg = \"An error has occurred while parsing the HTML.\\n\\nThis HTML may be malformatted, or may not perfectly adhere to HTML5.\\nIf you believe the exception above was due to something intentional, you can disable the strict parameter on html_to_vdom().\\nOtherwise, repair your broken HTML and try again.\"\n raise HTMLParseError(msg) from e\n\n return _etree_to_vdom(root_node, transforms)\n\n\nclass HTMLParseError(etree.LxmlSyntaxError): # type: ignore[misc]\n \"\"\"Raised when an HTML document cannot be parsed using strict parsing.\"\"\"\n\n\ndef _etree_to_vdom(\n node: etree._Element, transforms: Iterable[_ModelTransform]\n) -> VdomDict:\n \"\"\"Transform an lxml etree node into a DOM model\n\n Parameters:\n node:\n The ``lxml.etree._Element`` node\n transforms:\n Functions of the form ``transform(old) -> new`` where ``old`` is a VDOM\n dictionary which will be replaced by ``new``. 
For example, you could use a\n transform function to add highlighting to a ```` block.\n \"\"\"\n if not isinstance(node, etree._Element): # nocov\n msg = f\"Expected node to be a etree._Element, not {type(node).__name__}\"\n raise TypeError(msg)\n\n # Recursively call _etree_to_vdom() on all children\n children = _generate_vdom_children(node, transforms)\n\n # Convert the lxml node to a VDOM dict\n el = vdom(node.tag, dict(node.items()), *children)\n\n # Perform any necessary mutations on the VDOM attributes to meet VDOM spec\n _mutate_vdom(el)\n\n # Apply any provided transforms.\n for transform in transforms:\n el = transform(el)\n\n return el\n\n\ndef _add_vdom_to_etree(parent: etree._Element, vdom: VdomDict | dict[str, Any]) -> None:\n try:\n tag = vdom[\"tagName\"]\n except KeyError as e:\n msg = f\"Expected a VDOM dict, not {vdom}\"\n raise TypeError(msg) from e\n else:\n vdom = cast(VdomDict, vdom)\n\n if tag:\n element = etree.SubElement(parent, tag)\n element.attrib.update(\n _vdom_attr_to_html_str(k, v) for k, v in vdom.get(\"attributes\", {}).items()\n )\n else:\n element = parent\n\n for c in vdom.get(\"children\", []):\n if isinstance(c, dict):\n _add_vdom_to_etree(element, c)\n else:\n \"\"\"\n LXML handles string children by storing them under `text` and `tail`\n attributes of Element objects. The `text` attribute, if present, effectively\n becomes that element's first child. Then the `tail` attribute, if present,\n becomes a sibling that follows that element. For example, consider the\n following HTML:\n\n

<div><p>hello</p>world</div>
\n\n In this code sample, \"hello\" is the `text` attribute of the `` element\n and \"world\" is the `tail` attribute of that same `` element. It's for\n this reason that, depending on whether the element being constructed has\n non-string a child element, we need to assign a `text` vs `tail` attribute\n to that element or the last non-string child respectively.\n \"\"\"\n if len(element):\n last_child = element[-1]\n last_child.tail = f\"{last_child.tail or ''}{c}\"\n else:\n element.text = f\"{element.text or ''}{c}\"\n\n\ndef _mutate_vdom(vdom: VdomDict) -> None:\n \"\"\"Performs any necessary mutations on the VDOM attributes to meet VDOM spec.\n\n Currently, this function only transforms the ``style`` attribute into a dictionary whose keys are\n camelCase so as to be renderable by React.\n\n This function may be extended in the future.\n \"\"\"\n # Determine if the style attribute needs to be converted to a dict\n if (\n \"attributes\" in vdom\n and \"style\" in vdom[\"attributes\"]\n and isinstance(vdom[\"attributes\"][\"style\"], str)\n ):\n # Convince type checker that it's safe to mutate attributes\n assert isinstance(vdom[\"attributes\"], dict) # noqa: S101\n\n # Convert style attribute from str -> dict with camelCase keys\n vdom[\"attributes\"][\"style\"] = {\n key.strip().replace(\"-\", \"_\"): value.strip()\n for key, value in (\n part.split(\":\", 1)\n for part in vdom[\"attributes\"][\"style\"].split(\";\")\n if \":\" in part\n )\n }\n\n\ndef _generate_vdom_children(\n node: etree._Element, transforms: Iterable[_ModelTransform]\n) -> list[VdomDict | str]:\n \"\"\"Generates a list of VDOM children from an lxml node.\n\n Inserts inner text and/or tail text in between VDOM children, if necessary.\n \"\"\"\n return ( # Get the inner text of the current node\n [node.text] if node.text else []\n ) + list(\n chain(\n *(\n # Recursively convert each child node to VDOM\n [_etree_to_vdom(child, transforms)]\n # Insert the tail text between each child node\n + ([child.tail] if child.tail else [])\n for child in node.iterchildren(None)\n )\n )\n )\n\n\ndef del_html_head_body_transform(vdom: VdomDict) -> VdomDict:\n \"\"\"Transform intended for use with `html_to_vdom`.\n\n Removes ``, ``, and `` while preserving their children.\n\n Parameters:\n vdom:\n The VDOM dictionary to transform.\n \"\"\"\n if vdom[\"tagName\"] in {\"html\", \"body\", \"head\"}:\n return {\"tagName\": \"\", \"children\": vdom[\"children\"]}\n return vdom\n\n\ndef _vdom_attr_to_html_str(key: str, value: Any) -> tuple[str, str]:\n if key == \"style\":\n if isinstance(value, dict):\n value = \";\".join(\n # We lower only to normalize - CSS is case-insensitive:\n # https://www.w3.org/TR/css-fonts-3/#font-family-casing\n f\"{_CAMEL_CASE_SUB_PATTERN.sub('-', k).lower()}:{v}\"\n for k, v in value.items()\n )\n elif (\n # camel to data-* attributes\n key.startswith(\"data_\")\n # camel to aria-* attributes\n or key.startswith(\"aria_\")\n # handle special cases\n or key in DASHED_HTML_ATTRS\n ):\n key = key.replace(\"_\", \"-\")\n elif (\n # camel to data-* attributes\n key.startswith(\"data\")\n # camel to aria-* attributes\n or key.startswith(\"aria\")\n # handle special cases\n or key in DASHED_HTML_ATTRS\n ):\n key = _CAMEL_CASE_SUB_PATTERN.sub(\"-\", key)\n\n if callable(value): # nocov\n raise TypeError(f\"Cannot convert callable attribute {key}={value} to HTML\")\n\n # Again, we lower the attribute name only to normalize - HTML is case-insensitive:\n # 
http://w3c.github.io/html-reference/documents.html#case-insensitivity\n return key.lower(), str(value)\n\n\n# see list of HTML attributes with dashes in them:\n# https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes#attribute_list\nDASHED_HTML_ATTRS = {\"accept_charset\", \"acceptCharset\", \"http_equiv\", \"httpEquiv\"}\n\n# Pattern for delimitting camelCase names (e.g. camelCase to camel-case)\n_CAMEL_CASE_SUB_PATTERN = re.compile(r\"(? State[_Type]: ...\n\n\n@overload\ndef use_state(initial_value: _Type) -> State[_Type]: ...\n\n\ndef use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value:\n Defines the initial value of the state. A callable (accepting no arguments)\n can be used as a constructor function to avoid re-creating the initial value\n on each render.\n\n Returns:\n A tuple containing the current state and a function to update it.\n \"\"\"\n current_state = _use_const(lambda: _CurrentState(initial_value))\n return State(current_state.value, current_state.dispatch)\n\n\nclass _CurrentState(Generic[_Type]):\n __slots__ = \"value\", \"dispatch\"\n\n def __init__(\n self,\n initial_value: _Type | Callable[[], _Type],\n ) -> None:\n if callable(initial_value):\n self.value = initial_value()\n else:\n self.value = initial_value\n\n hook = current_hook()\n\n def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:\n if callable(new):\n next_value = new(self.value)\n else:\n next_value = new\n if not strictly_equal(next_value, self.value):\n self.value = next_value\n hook.schedule_render()\n\n self.dispatch = dispatch\n\n\n_EffectCleanFunc: TypeAlias = \"Callable[[], None]\"\n_SyncEffectFunc: TypeAlias = \"Callable[[], _EffectCleanFunc | None]\"\n_AsyncEffectFunc: TypeAlias = (\n \"Callable[[], Coroutine[None, None, _EffectCleanFunc | None]]\"\n)\n_EffectApplyFunc: TypeAlias = \"_SyncEffectFunc | _AsyncEffectFunc\"\n\n\n@overload\ndef use_effect(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None]: ...\n\n\n@overload\ndef use_effect(\n function: _EffectApplyFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None: ...\n\n\ndef use_effect(\n function: _EffectApplyFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None] | None:\n \"\"\"See the full :ref:`Use Effect` docs for details\n\n Parameters:\n function:\n Applies the effect and can return a clean-up function\n dependencies:\n Dependencies for the effect. The effect will only trigger if the identity\n of any value in the given sequence changes (i.e. their :func:`id` is\n different). By default these are inferred based on local variables that are\n referenced by the given function.\n\n Returns:\n If not function is provided, a decorator. 
Otherwise ``None``.\n \"\"\"\n hook = current_hook()\n\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n last_clean_callback: Ref[_EffectCleanFunc | None] = use_ref(None)\n\n def add_effect(function: _EffectApplyFunc) -> None:\n if not asyncio.iscoroutinefunction(function):\n sync_function = cast(_SyncEffectFunc, function)\n else:\n async_function = cast(_AsyncEffectFunc, function)\n\n def sync_function() -> _EffectCleanFunc | None:\n task = asyncio.create_task(async_function())\n\n def clean_future() -> None:\n if not task.cancel():\n try:\n clean = task.result()\n except asyncio.CancelledError:\n pass\n else:\n if clean is not None:\n clean()\n\n return clean_future\n\n async def effect(stop: asyncio.Event) -> None:\n if last_clean_callback.current is not None:\n last_clean_callback.current()\n last_clean_callback.current = None\n clean = last_clean_callback.current = sync_function()\n await stop.wait()\n if clean is not None:\n clean()\n\n return memoize(lambda: hook.add_effect(effect))\n\n if function is not None:\n add_effect(function)\n return None\n else:\n return add_effect\n\n\ndef use_debug_value(\n message: Any | Callable[[], Any],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None:\n \"\"\"Log debug information when the given message changes.\n\n .. note::\n This hook only logs if :data:`~reactpy.config.REACTPY_DEBUG_MODE` is active.\n\n Unlike other hooks, a message is considered to have changed if the old and new\n values are ``!=``. Because this comparison is performed on every render of the\n component, it may be worth considering the performance cost in some situations.\n\n Parameters:\n message:\n The value to log or a memoized function for generating the value.\n dependencies:\n Dependencies for the memoized function. The message will only be recomputed\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). 
By default these are inferred based on local\n variables that are referenced by the given function.\n \"\"\"\n old: Ref[Any] = _use_const(lambda: Ref(object()))\n memo_func = message if callable(message) else lambda: message\n new = use_memo(memo_func, dependencies)\n\n if REACTPY_DEBUG_MODE.current and old.current != new:\n old.current = new\n logger.debug(f\"{current_hook().component} {new}\")\n\n\n\ndef create_context(default_value: _Type) -> Context[_Type]:\n \"\"\"Return a new context type for use in :func:`use_context`\"\"\"\n\n def context(\n *children: Any,\n value: _Type = default_value,\n key: Key | None = None,\n ) -> _ContextProvider[_Type]:\n return _ContextProvider(\n *children,\n value=value,\n key=key,\n type=context,\n )\n\n context.__qualname__ = \"context\"\n\n return context\n\n\ndef use_context(context: Context[_Type]) -> _Type:\n \"\"\"Get the current value for the given context type.\n\n See the full :ref:`Use Context` docs for more information.\n \"\"\"\n hook = current_hook()\n provider = hook.get_context_provider(context)\n\n if provider is None:\n # same assertions but with normal exceptions\n if not isinstance(context, FunctionType):\n raise TypeError(f\"{context} is not a Context\") # nocov\n if context.__kwdefaults__ is None:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n if \"value\" not in context.__kwdefaults__:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n return cast(_Type, context.__kwdefaults__[\"value\"])\n\n return provider.value\n\n\nclass _ContextProvider(Generic[_Type]):\n def __init__(\n self,\n *children: Any,\n value: _Type,\n key: Key | None,\n type: Context[_Type],\n ) -> None:\n self.children = children\n self.key = key\n self.type = type\n self.value = value\n\n def render(self) -> VdomDict:\n current_hook().set_context_provider(self)\n return {\"tagName\": \"\", \"children\": self.children}\n\n def __repr__(self) -> str:\n return f\"ContextProvider({self.type})\"\n\n\n_ActionType = TypeVar(\"_ActionType\")\n\n\ndef use_reducer(\n reducer: Callable[[_Type, _ActionType], _Type],\n initial_value: _Type,\n) -> tuple[_Type, Callable[[_ActionType], None]]:\n \"\"\"See the full :ref:`Use Reducer` docs for details\n\n Parameters:\n reducer:\n A function which applies an action to the current state in order to\n produce the next state.\n initial_value:\n The initial state value (same as for :func:`use_state`)\n\n Returns:\n A tuple containing the current state and a function to change it with an action\n \"\"\"\n state, set_state = use_state(initial_value)\n return state, _use_const(lambda: _create_dispatcher(reducer, set_state))\n\n\ndef _create_dispatcher(\n reducer: Callable[[_Type, _ActionType], _Type],\n set_state: Callable[[Callable[[_Type], _Type]], None],\n) -> Callable[[_ActionType], None]:\n def dispatch(action: _ActionType) -> None:\n set_state(lambda last_state: reducer(last_state, action))\n\n return dispatch\n\n\n_CallbackFunc = TypeVar(\"_CallbackFunc\", bound=Callable[..., Any])\n\n\n@overload\ndef use_callback(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_CallbackFunc], _CallbackFunc]: ...\n\n\n@overload\ndef use_callback(\n function: _CallbackFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc: ...\n\n\ndef use_callback(\n function: _CallbackFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc | Callable[[_CallbackFunc], _CallbackFunc]:\n \"\"\"See the full :ref:`Use 
Callback` docs for details\n\n Parameters:\n function:\n The function whose identity will be preserved\n dependencies:\n Dependencies of the callback. The identity the ``function`` will be updated\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current function\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n\n def setup(function: _CallbackFunc) -> _CallbackFunc:\n return memoize(lambda: function)\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _LambdaCaller(Protocol):\n \"\"\"MyPy doesn't know how to deal with TypeVars only used in function return\"\"\"\n\n def __call__(self, func: Callable[[], _Type]) -> _Type: ...\n\n\n@overload\ndef use_memo(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _LambdaCaller: ...\n\n\n@overload\ndef use_memo(\n function: Callable[[], _Type],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type: ...\n\n\ndef use_memo(\n function: Callable[[], _Type] | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type | Callable[[Callable[[], _Type]], _Type]:\n \"\"\"See the full :ref:`Use Memo` docs for details\n\n Parameters:\n function:\n The function to be memoized.\n dependencies:\n Dependencies for the memoized function. The memo will only be recomputed if\n the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current state\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n\n memo: _Memo[_Type] = _use_const(_Memo)\n\n if memo.empty():\n # we need to initialize on the first run\n changed = True\n memo.deps = () if dependencies is None else dependencies\n elif dependencies is None:\n changed = True\n memo.deps = ()\n elif (\n len(memo.deps) != len(dependencies)\n # if deps are same length check identity for each item\n or not all(\n strictly_equal(current, new)\n for current, new in zip(memo.deps, dependencies)\n )\n ):\n memo.deps = dependencies\n changed = True\n else:\n changed = False\n\n setup: Callable[[Callable[[], _Type]], _Type]\n\n if changed:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n current_value = memo.value = function()\n return current_value\n\n else:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n return memo.value\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _Memo(Generic[_Type]):\n \"\"\"Simple object for storing memoization data\"\"\"\n\n __slots__ = \"value\", \"deps\"\n\n value: _Type\n deps: Sequence[Any]\n\n def empty(self) -> bool:\n try:\n self.value # noqa: B018\n except AttributeError:\n return True\n else:\n return False\n\n\ndef use_ref(initial_value: _Type) -> Ref[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value: The value initially assigned to the reference.\n\n Returns:\n A :class:`Ref` object.\n \"\"\"\n return _use_const(lambda: Ref(initial_value))\n\n\ndef _use_const(function: Callable[[], _Type]) -> _Type:\n return current_hook().use_state(function)\n\n\ndef _try_to_infer_closure_values(\n func: Callable[..., Any] | None,\n values: Sequence[Any] | ellipsis | None,\n) -> 
Sequence[Any] | None:\n if values is ...:\n if isinstance(func, FunctionType):\n return (\n [cell.cell_contents for cell in func.__closure__]\n if func.__closure__\n else []\n )\n else:\n return None\n else:\n return values\n\n\ndef strictly_equal(x: Any, y: Any) -> bool:\n \"\"\"Check if two values are identical or, for a limited set or types, equal.\n\n Only the following types are checked for equality rather than identity:\n\n - ``int``\n - ``float``\n - ``complex``\n - ``str``\n - ``bytes``\n - ``bytearray``\n - ``memoryview``\n \"\"\"\n return x is y or (type(x) in _NUMERIC_TEXT_BINARY_TYPES and x == y)\n\n\n_NUMERIC_TEXT_BINARY_TYPES = {\n # numeric\n int,\n float,\n complex,\n # text\n str,\n # binary types\n bytes,\n bytearray,\n memoryview,\n}\n\n# Path: src/py/reactpy/reactpy/backend/hooks.py\nfrom __future__ import annotations\n\nfrom collections.abc import MutableMapping\nfrom typing import Any\n\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.core.hooks import create_context, use_context\nfrom reactpy.core.types import Context\n\n# backend implementations should establish this context at the root of an app\nConnectionContext: Context[Connection[Any] | None] = create_context(None)\n\n\ndef use_connection() -> Connection[Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`.\"\"\"\n conn = use_context(ConnectionContext)\n if conn is None: # nocov\n msg = \"No backend established a connection.\"\n raise RuntimeError(msg)\n return conn\n\n\ndef use_scope() -> MutableMapping[str, Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s scope.\"\"\"\n return use_connection().scope\n\n\ndef use_location() -> Location:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s location.\"\"\"\n return use_connection().location\n\n# Path: src/py/reactpy/reactpy/core/component.py\nfrom __future__ import annotations\n\nimport inspect\nfrom functools import wraps\nfrom typing import Any, Callable\n\nfrom reactpy.core.types import ComponentType, VdomDict\n\n\ndef component(\n function: Callable[..., ComponentType | VdomDict | str | None]\n) -> Callable[..., Component]:\n \"\"\"A decorator for defining a new component.\n\n Parameters:\n function: The component's :meth:`reactpy.core.proto.ComponentType.render` function.\n \"\"\"\n sig = inspect.signature(function)\n\n if \"key\" in sig.parameters and sig.parameters[\"key\"].kind in (\n inspect.Parameter.KEYWORD_ONLY,\n inspect.Parameter.POSITIONAL_OR_KEYWORD,\n ):\n msg = f\"Component render function {function} uses reserved parameter 'key'\"\n raise TypeError(msg)\n\n @wraps(function)\n def constructor(*args: Any, key: Any | None = None, **kwargs: Any) -> Component:\n return Component(function, key, args, kwargs, sig)\n\n return constructor\n\n\nclass Component:\n \"\"\"An object for rending component models.\"\"\"\n\n __slots__ = \"__weakref__\", \"_func\", \"_args\", \"_kwargs\", \"_sig\", \"key\", \"type\"\n\n def __init__(\n self,\n function: Callable[..., ComponentType | VdomDict | str | None],\n key: Any | None,\n args: tuple[Any, ...],\n kwargs: dict[str, Any],\n sig: inspect.Signature,\n ) -> None:\n self.key = key\n self.type = function\n self._args = args\n self._kwargs = kwargs\n self._sig = sig\n\n def render(self) -> ComponentType | VdomDict | str | None:\n return self.type(*self._args, **self._kwargs)\n\n def __repr__(self) -> str:\n try:\n args = self._sig.bind(*self._args, **self._kwargs).arguments\n except TypeError:\n return 
f\"{self.type.__name__}(...)\"\n else:\n items = \", \".join(f\"{k}={v!r}\" for k, v in args.items())\n if items:\n return f\"{self.type.__name__}({id(self):02x}, {items})\"\n else:\n return f\"{self.type.__name__}({id(self):02x})\"\n\n# Path: src/py/reactpy/reactpy/types.py\n\"\"\"Exports common types from:\n\n- :mod:`reactpy.core.types`\n- :mod:`reactpy.backend.types`\n\"\"\"\n\nfrom reactpy.backend.types import BackendType, Connection, Location\nfrom reactpy.core.component import Component\nfrom reactpy.core.types import (\n ComponentConstructor,\n ComponentType,\n Context,\n EventHandlerDict,\n EventHandlerFunc,\n EventHandlerMapping,\n EventHandlerType,\n ImportSourceDict,\n Key,\n LayoutType,\n RootComponentConstructor,\n State,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomJson,\n)\n\n__all__ = [\n \"BackendType\",\n \"Component\",\n \"ComponentConstructor\",\n \"ComponentType\",\n \"Connection\",\n \"Context\",\n \"EventHandlerDict\",\n \"EventHandlerFunc\",\n \"EventHandlerMapping\",\n \"EventHandlerType\",\n \"ImportSourceDict\",\n \"Key\",\n \"LayoutType\",\n \"Location\",\n \"RootComponentConstructor\",\n \"State\",\n \"VdomAttributes\",\n \"VdomChild\",\n \"VdomChildren\",\n \"VdomDict\",\n \"VdomJson\",\n]\n\n# Path: src/py/reactpy/reactpy/backend/utils.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport socket\nimport sys\nfrom collections.abc import Iterator\nfrom contextlib import closing\nfrom importlib import import_module\nfrom typing import Any\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_BACKENDS = (\n \"fastapi\",\n \"sanic\",\n \"tornado\",\n \"flask\",\n \"starlette\",\n)\n\n\ndef run(\n component: RootComponentConstructor,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n implementation: BackendType[Any] | None = None,\n) -> None:\n \"\"\"Run a component with a development server\"\"\"\n logger.warning(_DEVELOPMENT_RUN_FUNC_WARNING)\n\n implementation = implementation or import_module(\"reactpy.backend.default\")\n app = implementation.create_development_app()\n implementation.configure(app, component)\n port = port or find_available_port(host)\n app_cls = type(app)\n\n logger.info(\n \"ReactPy is running with '%s.%s' at http://%s:%s\",\n app_cls.__module__,\n app_cls.__name__,\n host,\n port,\n )\n asyncio.run(implementation.serve_development_app(app, host, port))\n\n\ndef find_available_port(host: str, port_min: int = 8000, port_max: int = 9000) -> int:\n \"\"\"Get a port that's available for the given host and port range\"\"\"\n for port in range(port_min, port_max):\n with closing(socket.socket()) as sock:\n try:\n if sys.platform in (\"linux\", \"darwin\"):\n # Fixes bug on Unix-like systems where every time you restart the\n # server you'll get a different port on Linux. 
This cannot be set\n # on Windows otherwise address will always be reused.\n # Ref: https://stackoverflow.com/a/19247688/3159288\n sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n sock.bind((host, port))\n except OSError:\n pass\n else:\n return port\n msg = f\"Host {host!r} has no available port in range {port_max}-{port_max}\"\n raise RuntimeError(msg)\n\n\ndef all_implementations() -> Iterator[BackendType[Any]]:\n \"\"\"Yield all available server implementations\"\"\"\n for name in SUPPORTED_BACKENDS:\n try:\n import_module(name)\n except ImportError: # nocov\n logger.debug(\"Failed to import %s\", name, exc_info=True)\n continue\n\n reactpy_backend_name = f\"{__name__.rsplit('.', 1)[0]}.{name}\"\n yield import_module(reactpy_backend_name)\n\n\n_DEVELOPMENT_RUN_FUNC_WARNING = \"\"\"\\\nThe `run()` function is only intended for testing during development! To run \\\nin production, refer to the docs on how to use reactpy.backend.*.configure.\\\n\"\"\"\n\n# Path: src/py/reactpy/reactpy/core/layout.py\nfrom __future__ import annotations\n\nimport abc\nfrom asyncio import (\n FIRST_COMPLETED,\n CancelledError,\n Queue,\n Task,\n create_task,\n get_running_loop,\n wait,\n)\nfrom collections import Counter\nfrom collections.abc import Sequence\nfrom contextlib import AsyncExitStack\nfrom logging import getLogger\nfrom typing import (\n Any,\n Callable,\n Generic,\n NamedTuple,\n NewType,\n TypeVar,\n cast,\n)\nfrom uuid import uuid4\nfrom weakref import ref as weakref\n\nfrom anyio import Semaphore\nfrom typing_extensions import TypeAlias\n\nfrom reactpy.config import (\n REACTPY_ASYNC_RENDERING,\n REACTPY_CHECK_VDOM_SPEC,\n REACTPY_DEBUG_MODE,\n)\nfrom reactpy.core._life_cycle_hook import LifeCycleHook\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n Key,\n LayoutEventMessage,\n LayoutUpdateMessage,\n VdomChild,\n VdomDict,\n VdomJson,\n)\nfrom reactpy.core.vdom import validate_vdom_json\nfrom reactpy.utils import Ref\n\nlogger = getLogger(__name__)\n\n\nclass Layout:\n \"\"\"Responsible for \"rendering\" components. That is, turning them into VDOM.\"\"\"\n\n __slots__: tuple[str, ...] 
= (\n \"root\",\n \"_event_handlers\",\n \"_rendering_queue\",\n \"_render_tasks\",\n \"_render_tasks_ready\",\n \"_root_life_cycle_state_id\",\n \"_model_states_by_life_cycle_state_id\",\n )\n\n if not hasattr(abc.ABC, \"__weakref__\"): # nocov\n __slots__ += (\"__weakref__\",)\n\n def __init__(self, root: ComponentType) -> None:\n super().__init__()\n if not isinstance(root, ComponentType):\n msg = f\"Expected a ComponentType, not {type(root)!r}.\"\n raise TypeError(msg)\n self.root = root\n\n async def __aenter__(self) -> Layout:\n # create attributes here to avoid access before entering context manager\n self._event_handlers: EventHandlerDict = {}\n self._render_tasks: set[Task[LayoutUpdateMessage]] = set()\n self._render_tasks_ready: Semaphore = Semaphore(0)\n\n self._rendering_queue: _ThreadSafeQueue[_LifeCycleStateId] = _ThreadSafeQueue()\n root_model_state = _new_root_model_state(self.root, self._schedule_render_task)\n\n self._root_life_cycle_state_id = root_id = root_model_state.life_cycle_state.id\n self._model_states_by_life_cycle_state_id = {root_id: root_model_state}\n self._schedule_render_task(root_id)\n\n return self\n\n async def __aexit__(self, *exc: Any) -> None:\n root_csid = self._root_life_cycle_state_id\n root_model_state = self._model_states_by_life_cycle_state_id[root_csid]\n\n for t in self._render_tasks:\n t.cancel()\n try:\n await t\n except CancelledError:\n pass\n\n await self._unmount_model_states([root_model_state])\n\n # delete attributes here to avoid access after exiting context manager\n del self._event_handlers\n del self._rendering_queue\n del self._root_life_cycle_state_id\n del self._model_states_by_life_cycle_state_id\n\n async def deliver(self, event: LayoutEventMessage) -> None:\n \"\"\"Dispatch an event to the targeted handler\"\"\"\n # It is possible for an element in the frontend to produce an event\n # associated with a backend model that has been deleted. We only handle\n # events if the element and the handler exist in the backend. Otherwise\n # we just ignore the event.\n handler = self._event_handlers.get(event[\"target\"])\n\n if handler is not None:\n try:\n await handler.function(event[\"data\"])\n except Exception:\n logger.exception(f\"Failed to execute event handler {handler}\")\n else:\n logger.info(\n...\n", "needle_token_start": 10360, "needle_token_end": 10491, "code_context_ntokens": 16025, "output": ["```python\ndef create_context(default_value: _Type) -> Context[_Type]:\n \"\"\"Return a new context type for use in :func:`use_context`\"\"\"\n\n def context(\n *children: Any,\n value: _Type = default_value,\n key: Key | None = None,\n ) -> _ContextProvider[_Type]:\n return _ContextProvider(\n *children,\n value=value,\n key=key,\n type=context,\n )\n\n context.__qualname__ = \"context\"\n\n return context\n```"]} +{"repo": "reactive-python/reactpy", "name": "set_context_provider", "language": "python", "path": "src/py/reactpy/reactpy/core/_life_cycle_hook.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To assign a context provider to the hook, enabling it to supply state to any child components that request a context of the same type.\n2. **Input**: A context provider, which includes a specific context type and the associated state management logic.\n3. **Output**: There is no direct output; the function modifies the internal state of the hook by updating the context providers map.\n4. 
**Procedure**: The function takes a context provider as input and maps its context type to the provider itself within a dictionary that tracks context providers. This allows child components to access shared state through the context system.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/backend/__init__.py\nimport mimetypes\nfrom logging import getLogger\n\n_logger = getLogger(__name__)\n\n# Fix for missing mime types due to OS corruption/misconfiguration\n# Example: https://github.com/encode/starlette/issues/829\nif not mimetypes.inited:\n mimetypes.init()\nfor extension, mime_type in {\n \".js\": \"application/javascript\",\n \".css\": \"text/css\",\n \".json\": \"application/json\",\n}.items():\n if not mimetypes.types_map.get(extension): # pragma: no cover\n _logger.warning(\n \"Mime type '%s = %s' is missing. Please research how to \"\n \"fix missing mime types on your operating system.\",\n extension,\n mime_type,\n )\n mimetypes.add_type(mime_type, extension)\n\n# Path: src/py/reactpy/reactpy/core/__init__.py\n\n# Path: src/py/reactpy/reactpy/core/types.py\nfrom __future__ import annotations\n\nimport sys\nfrom collections import namedtuple\nfrom collections.abc import Mapping, Sequence\nfrom types import TracebackType\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Generic,\n Literal,\n NamedTuple,\n Protocol,\n TypeVar,\n overload,\n runtime_checkable,\n)\n\nfrom typing_extensions import TypeAlias, TypedDict\n\n_Type = TypeVar(\"_Type\")\n\n\nif TYPE_CHECKING or sys.version_info < (3, 9) or sys.version_info >= (3, 11):\n\n class State(NamedTuple, Generic[_Type]):\n value: _Type\n set_value: Callable[[_Type | Callable[[_Type], _Type]], None]\n\nelse: # nocov\n State = namedtuple(\"State\", (\"value\", \"set_value\"))\n\n\nComponentConstructor = Callable[..., \"ComponentType\"]\n\"\"\"Simple function returning a new component\"\"\"\n\nRootComponentConstructor = Callable[[], \"ComponentType\"]\n\"\"\"The root component should be constructed by a function accepting no arguments.\"\"\"\n\n\nKey: TypeAlias = \"str | int\"\n\n\n_OwnType = TypeVar(\"_OwnType\")\n\n\n@runtime_checkable\nclass ComponentType(Protocol):\n \"\"\"The expected interface for all component-like objects\"\"\"\n\n key: Key | None\n \"\"\"An identifier which is unique amongst a component's immediate siblings\"\"\"\n\n type: Any\n \"\"\"The function or class defining the behavior of this component\n\n This is used to see if two component instances share the same definition.\n \"\"\"\n\n def render(self) -> VdomDict | ComponentType | str | None:\n \"\"\"Render the component's view model.\"\"\"\n\n\n_Render_co = TypeVar(\"_Render_co\", covariant=True)\n_Event_contra = TypeVar(\"_Event_contra\", contravariant=True)\n\n\n@runtime_checkable\nclass LayoutType(Protocol[_Render_co, _Event_contra]):\n \"\"\"Renders and delivers, updates to views and events to handlers, respectively\"\"\"\n\n async def render(self) -> _Render_co:\n \"\"\"Render an update to a view\"\"\"\n\n async def deliver(self, event: _Event_contra) -> None:\n \"\"\"Relay an event to its respective handler\"\"\"\n\n async def __aenter__(self) -> LayoutType[_Render_co, _Event_contra]:\n \"\"\"Prepare the layout for its first render\"\"\"\n\n async def __aexit__(\n self,\n exc_type: 
type[Exception],\n exc_value: Exception,\n traceback: TracebackType,\n ) -> bool | None:\n \"\"\"Clean up the view after its final render\"\"\"\n\n\nVdomAttributes = Mapping[str, Any]\n\"\"\"Describes the attributes of a :class:`VdomDict`\"\"\"\n\nVdomChild: TypeAlias = \"ComponentType | VdomDict | str | None | Any\"\n\"\"\"A single child element of a :class:`VdomDict`\"\"\"\n\nVdomChildren: TypeAlias = \"Sequence[VdomChild] | VdomChild\"\n\"\"\"Describes a series of :class:`VdomChild` elements\"\"\"\n\n\nclass _VdomDictOptional(TypedDict, total=False):\n key: Key | None\n children: Sequence[ComponentType | VdomChild]\n attributes: VdomAttributes\n eventHandlers: EventHandlerDict\n importSource: ImportSourceDict\n\n\nclass _VdomDictRequired(TypedDict, total=True):\n tagName: str\n\n\nclass VdomDict(_VdomDictRequired, _VdomDictOptional):\n \"\"\"A :ref:`VDOM` dictionary\"\"\"\n\n\nclass ImportSourceDict(TypedDict):\n source: str\n fallback: Any\n sourceType: str\n unmountBeforeUpdate: bool\n\n\nclass _OptionalVdomJson(TypedDict, total=False):\n key: Key\n error: str\n children: list[Any]\n attributes: dict[str, Any]\n eventHandlers: dict[str, _JsonEventTarget]\n importSource: _JsonImportSource\n\n\nclass _RequiredVdomJson(TypedDict, total=True):\n tagName: str\n\n\nclass VdomJson(_RequiredVdomJson, _OptionalVdomJson):\n \"\"\"A JSON serializable form of :class:`VdomDict` matching the :data:`VDOM_JSON_SCHEMA`\"\"\"\n\n\nclass _JsonEventTarget(TypedDict):\n target: str\n preventDefault: bool\n stopPropagation: bool\n\n\nclass _JsonImportSource(TypedDict):\n source: str\n fallback: Any\n\n\nEventHandlerMapping = Mapping[str, \"EventHandlerType\"]\n\"\"\"A generic mapping between event names to their handlers\"\"\"\n\nEventHandlerDict: TypeAlias = \"dict[str, EventHandlerType]\"\n\"\"\"A dict mapping between event names to their handlers\"\"\"\n\n\nclass EventHandlerFunc(Protocol):\n \"\"\"A coroutine which can handle event data\"\"\"\n\n async def __call__(self, data: Sequence[Any]) -> None: ...\n\n\n@runtime_checkable\nclass EventHandlerType(Protocol):\n \"\"\"Defines a handler for some event\"\"\"\n\n prevent_default: bool\n \"\"\"Whether to block the event from propagating further up the DOM\"\"\"\n\n stop_propagation: bool\n \"\"\"Stops the default action associate with the event from taking place.\"\"\"\n\n function: EventHandlerFunc\n \"\"\"A coroutine which can respond to an event and its data\"\"\"\n\n target: str | None\n \"\"\"Typically left as ``None`` except when a static target is useful.\n\n When testing, it may be useful to specify a static target ID so events can be\n triggered programmatically.\n\n .. 
note::\n\n When ``None``, it is left to a :class:`LayoutType` to auto generate a unique ID.\n \"\"\"\n\n\nclass VdomDictConstructor(Protocol):\n \"\"\"Standard function for constructing a :class:`VdomDict`\"\"\"\n\n @overload\n def __call__(\n self, attributes: VdomAttributes, *children: VdomChildren\n ) -> VdomDict: ...\n\n @overload\n def __call__(self, *children: VdomChildren) -> VdomDict: ...\n\n @overload\n def __call__(\n self, *attributes_and_children: VdomAttributes | VdomChildren\n ) -> VdomDict: ...\n\n\nclass LayoutUpdateMessage(TypedDict):\n \"\"\"A message describing an update to a layout\"\"\"\n\n type: Literal[\"layout-update\"]\n \"\"\"The type of message\"\"\"\n path: str\n \"\"\"JSON Pointer path to the model element being updated\"\"\"\n model: VdomJson\n \"\"\"The model to assign at the given JSON Pointer path\"\"\"\n\n\nclass LayoutEventMessage(TypedDict):\n \"\"\"Message describing an event originating from an element in the layout\"\"\"\n\n type: Literal[\"layout-event\"]\n \"\"\"The type of message\"\"\"\n target: str\n \"\"\"The ID of the event handler.\"\"\"\n data: Sequence[Any]\n \"\"\"A list of event data passed to the event handler.\"\"\"\n\n\nclass Context(Protocol[_Type]):\n \"\"\"Returns a :class:`ContextProvider` component\"\"\"\n\n def __call__(\n self,\n *children: Any,\n value: _Type = ...,\n key: Key | None = ...,\n ) -> ContextProviderType[_Type]: ...\n\n\nclass ContextProviderType(ComponentType, Protocol[_Type]):\n \"\"\"A component which provides a context value to its children\"\"\"\n\n type: Context[_Type]\n \"\"\"The context type\"\"\"\n\n @property\n def value(self) -> _Type:\n \"Current context value\"\n\n# Path: src/py/reactpy/reactpy/backend/types.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import MutableMapping\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable\n\nfrom reactpy.core.types import RootComponentConstructor\n\n_App = TypeVar(\"_App\")\n\n\n@runtime_checkable\nclass BackendType(Protocol[_App]):\n \"\"\"Common interface for built-in web server/framework integrations\"\"\"\n\n Options: Callable[..., Any]\n \"\"\"A constructor for options passed to :meth:`BackendType.configure`\"\"\"\n\n def configure(\n self,\n app: _App,\n component: RootComponentConstructor,\n options: Any | None = None,\n ) -> None:\n \"\"\"Configure the given app instance to display the given component\"\"\"\n\n def create_development_app(self) -> _App:\n \"\"\"Create an application instance for development purposes\"\"\"\n\n async def serve_development_app(\n self,\n app: _App,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n ) -> None:\n \"\"\"Run an application using a development server\"\"\"\n\n\n_Carrier = TypeVar(\"_Carrier\")\n\n\n@dataclass\nclass Connection(Generic[_Carrier]):\n \"\"\"Represents a connection with a client\"\"\"\n\n scope: MutableMapping[str, Any]\n \"\"\"An ASGI scope or WSGI environment dictionary\"\"\"\n\n location: Location\n \"\"\"The current location (URL)\"\"\"\n\n carrier: _Carrier\n \"\"\"How the connection is mediated. 
For example, a request or websocket.\n\n This typically depends on the backend implementation.\n \"\"\"\n\n\n@dataclass\nclass Location:\n \"\"\"Represents the current location (URL)\n\n Analogous to, but not necessarily identical to, the client-side\n ``document.location`` object.\n \"\"\"\n\n pathname: str\n \"\"\"the path of the URL for the location\"\"\"\n\n search: str\n \"\"\"A search or query string - a '?' followed by the parameters of the URL.\n\n If there are no search parameters this should be an empty string\n \"\"\"\n\n# Path: src/py/reactpy/reactpy/_warnings.py\nfrom collections.abc import Iterator\nfrom functools import wraps\nfrom inspect import currentframe\nfrom types import FrameType\nfrom typing import TYPE_CHECKING, Any\nfrom warnings import warn as _warn\n\n\n@wraps(_warn)\ndef warn(*args: Any, **kwargs: Any) -> Any:\n # warn at call site outside of ReactPy\n _warn(*args, stacklevel=_frame_depth_in_module() + 1, **kwargs) # type: ignore\n\n\nif TYPE_CHECKING:\n warn = _warn # noqa: F811\n\n\ndef _frame_depth_in_module() -> int:\n depth = 0\n for frame in _iter_frames(2):\n module_name = frame.f_globals.get(\"__name__\")\n if not module_name or not module_name.startswith(\"reactpy.\"):\n break\n depth += 1\n return depth\n\n\ndef _iter_frames(index: int = 1) -> Iterator[FrameType]:\n frame = currentframe()\n while frame is not None:\n if index == 0:\n yield frame\n else:\n index -= 1\n frame = frame.f_back\n\n# Path: src/py/reactpy/reactpy/_option.py\nfrom __future__ import annotations\n\nimport os\nfrom logging import getLogger\nfrom typing import Any, Callable, Generic, TypeVar, cast\n\nfrom reactpy._warnings import warn\n\n_O = TypeVar(\"_O\")\nlogger = getLogger(__name__)\nUNDEFINED = cast(Any, object())\n\n\nclass Option(Generic[_O]):\n \"\"\"An option that can be set using an environment variable of the same name\"\"\"\n\n def __init__(\n self,\n name: str,\n default: _O = UNDEFINED,\n mutable: bool = True,\n parent: Option[_O] | None = None,\n validator: Callable[[Any], _O] = lambda x: cast(_O, x),\n ) -> None:\n self._name = name\n self._mutable = mutable\n self._validator = validator\n self._subscribers: list[Callable[[_O], None]] = []\n\n if name in os.environ:\n self._current = validator(os.environ[name])\n\n if parent is not None:\n if not (parent.mutable and self.mutable):\n raise TypeError(\"Parent and child options must be mutable\")\n self._default = parent.default\n parent.subscribe(self.set_current)\n elif default is not UNDEFINED:\n self._default = default\n else:\n raise TypeError(\"Must specify either a default or a parent option\")\n\n logger.debug(f\"{self._name}={self.current}\")\n\n @property\n def name(self) -> str:\n \"\"\"The name of this option (used to load environment variables)\"\"\"\n return self._name\n\n @property\n def mutable(self) -> bool:\n \"\"\"Whether this option can be modified after being loaded\"\"\"\n return self._mutable\n\n @property\n def default(self) -> _O:\n \"\"\"This option's default value\"\"\"\n return self._default\n\n @property\n def current(self) -> _O:\n try:\n return self._current\n except AttributeError:\n return self._default\n\n @current.setter\n def current(self, new: _O) -> None:\n self.set_current(new)\n\n @current.deleter\n def current(self) -> None:\n self.unset()\n\n def subscribe(self, handler: Callable[[_O], None]) -> Callable[[_O], None]:\n \"\"\"Register a callback that will be triggered when this option changes\"\"\"\n if not self.mutable:\n msg = \"Immutable options cannot be subscribed to.\"\n 
raise TypeError(msg)\n self._subscribers.append(handler)\n handler(self.current)\n return handler\n\n def is_set(self) -> bool:\n \"\"\"Whether this option has a value other than its default.\"\"\"\n return hasattr(self, \"_current\")\n\n def set_current(self, new: Any) -> None:\n \"\"\"Set the value of this option\n\n Raises a ``TypeError`` if this option is not :attr:`Option.mutable`.\n \"\"\"\n old = self.current\n if new is old:\n return None\n\n if not self._mutable:\n msg = f\"{self} cannot be modified after initial load\"\n raise TypeError(msg)\n\n try:\n new = self._current = self._validator(new)\n except ValueError as error:\n raise ValueError(f\"Invalid value for {self._name}: {new!r}\") from error\n\n logger.debug(f\"{self._name}={self._current}\")\n if new != old:\n for sub_func in self._subscribers:\n sub_func(new)\n\n def set_default(self, new: _O) -> _O:\n \"\"\"Set the value of this option if not :meth:`Option.is_set`\n\n Returns the current value (a la :meth:`dict.set_default`)\n \"\"\"\n if not self.is_set():\n self.set_current(new)\n return self._current\n\n def reload(self) -> None:\n \"\"\"Reload this option from its environment variable\"\"\"\n self.set_current(os.environ.get(self._name, self._default))\n\n def unset(self) -> None:\n \"\"\"Remove the current value, the default will be used until it is set again.\"\"\"\n if not self._mutable:\n msg = f\"{self} cannot be modified after initial load\"\n raise TypeError(msg)\n old = self.current\n if hasattr(self, \"_current\"):\n delattr(self, \"_current\")\n if self.current != old:\n for sub_func in self._subscribers:\n sub_func(self.current)\n\n def __repr__(self) -> str:\n return f\"Option({self._name}={self.current!r})\"\n\n\nclass DeprecatedOption(Option[_O]):\n \"\"\"An option that will warn when it is accessed\"\"\"\n\n def __init__(self, *args: Any, message: str, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs)\n self._deprecation_message = message\n\n @Option.current.getter # type: ignore\n def current(self) -> _O:\n try:\n # we access the current value during init to debug log it\n # no need to warn unless it's actually used. 
since this attr\n # is only set after super().__init__ is called, we can check\n # for it to determine if it's being accessed by a user.\n msg = self._deprecation_message\n except AttributeError:\n pass\n else:\n warn(msg, DeprecationWarning)\n return super().current\n\n# Path: src/py/reactpy/reactpy/config.py\n\"\"\"\nReactPy provides a series of configuration options that can be set using environment\nvariables or, for those which allow it, a programmatic interface.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\n\nfrom reactpy._option import Option\n\nTRUE_VALUES = {\"true\", \"1\"}\nFALSE_VALUES = {\"false\", \"0\"}\n\n\ndef boolean(value: str | bool | int) -> bool:\n if isinstance(value, bool):\n return value\n elif isinstance(value, int):\n return bool(value)\n elif not isinstance(value, str):\n raise TypeError(f\"Expected str or bool, got {type(value).__name__}\")\n\n if value.lower() in TRUE_VALUES:\n return True\n elif value.lower() in FALSE_VALUES:\n return False\n else:\n raise ValueError(\n f\"Invalid boolean value {value!r} - expected \"\n f\"one of {list(TRUE_VALUES | FALSE_VALUES)}\"\n )\n\n\nREACTPY_DEBUG_MODE = Option(\n \"REACTPY_DEBUG_MODE\", default=False, validator=boolean, mutable=True\n)\n\"\"\"Get extra logs and validation checks at the cost of performance.\n\nThis will enable the following:\n\n- :data:`REACTPY_CHECK_VDOM_SPEC`\n- :data:`REACTPY_CHECK_JSON_ATTRS`\n\"\"\"\n\nREACTPY_CHECK_VDOM_SPEC = Option(\"REACTPY_CHECK_VDOM_SPEC\", parent=REACTPY_DEBUG_MODE)\n\"\"\"Checks which ensure VDOM is rendered to spec\n\nFor more info on the VDOM spec, see here: :ref:`VDOM JSON Schema`\n\"\"\"\n\nREACTPY_CHECK_JSON_ATTRS = Option(\"REACTPY_CHECK_JSON_ATTRS\", parent=REACTPY_DEBUG_MODE)\n\"\"\"Checks that all VDOM attributes are JSON serializable\n\nThe VDOM spec is not able to enforce this on its own since attributes could anything.\n\"\"\"\n\n# Because these web modules will be linked dynamically at runtime this can be temporary.\n# Assigning to a variable here ensures that the directory is not deleted until the end\n# of the program.\n_DEFAULT_WEB_MODULES_DIR = TemporaryDirectory()\n\nREACTPY_WEB_MODULES_DIR = Option(\n \"REACTPY_WEB_MODULES_DIR\",\n default=Path(_DEFAULT_WEB_MODULES_DIR.name),\n validator=Path,\n)\n\"\"\"The location ReactPy will use to store its client application\n\nThis directory **MUST** be treated as a black box. Downstream applications **MUST NOT**\nassume anything about the structure of this directory see :mod:`reactpy.web.module` for a\nset of publicly available APIs for working with the client.\n\"\"\"\n\nREACTPY_TESTING_DEFAULT_TIMEOUT = Option(\n \"REACTPY_TESTING_DEFAULT_TIMEOUT\",\n 5.0,\n mutable=False,\n validator=float,\n)\n\"\"\"A default timeout for testing utilities in ReactPy\"\"\"\n\nREACTPY_ASYNC_RENDERING = Option(\n \"REACTPY_CONCURRENT_RENDERING\",\n default=False,\n mutable=True,\n validator=boolean,\n)\n\"\"\"Whether to render components concurrently. 
This is currently an experimental feature.\"\"\"\n\n# Path: src/py/reactpy/reactpy/core/_thread_local.py\nfrom threading import Thread, current_thread\nfrom typing import Callable, Generic, TypeVar\nfrom weakref import WeakKeyDictionary\n\n_StateType = TypeVar(\"_StateType\")\n\n\nclass ThreadLocal(Generic[_StateType]):\n \"\"\"Utility for managing per-thread state information\"\"\"\n\n def __init__(self, default: Callable[[], _StateType]):\n self._default = default\n self._state: WeakKeyDictionary[Thread, _StateType] = WeakKeyDictionary()\n\n def get(self) -> _StateType:\n thread = current_thread()\n if thread not in self._state:\n state = self._state[thread] = self._default()\n else:\n state = self._state[thread]\n return state\n\n# Path: src/py/reactpy/reactpy/core/_life_cycle_hook.py\nfrom __future__ import annotations\n\nimport logging\nfrom asyncio import Event, Task, create_task, gather\nfrom typing import Any, Callable, Protocol, TypeVar\n\nfrom anyio import Semaphore\n\nfrom reactpy.core._thread_local import ThreadLocal\nfrom reactpy.core.types import ComponentType, Context, ContextProviderType\n\nT = TypeVar(\"T\")\n\n\nclass EffectFunc(Protocol):\n async def __call__(self, stop: Event) -> None: ...\n\n\nlogger = logging.getLogger(__name__)\n\n_HOOK_STATE: ThreadLocal[list[LifeCycleHook]] = ThreadLocal(list)\n\n\ndef current_hook() -> LifeCycleHook:\n \"\"\"Get the current :class:`LifeCycleHook`\"\"\"\n hook_stack = _HOOK_STATE.get()\n if not hook_stack:\n msg = \"No life cycle hook is active. Are you rendering in a layout?\"\n raise RuntimeError(msg)\n return hook_stack[-1]\n\n\nclass LifeCycleHook:\n \"\"\"An object which manages the \"life cycle\" of a layout component.\n\n The \"life cycle\" of a component is the set of events which occur from the time\n a component is first rendered until it is removed from the layout. The life cycle\n is ultimately driven by the layout itself, but components can \"hook\" into those\n events to perform actions. Components gain access to their own life cycle hook\n by calling :func:`current_hook`. They can then perform actions such as:\n\n 1. Adding state via :meth:`use_state`\n 2. Adding effects via :meth:`add_effect`\n 3. Setting or getting context providers via\n :meth:`LifeCycleHook.set_context_provider` and\n :meth:`get_context_provider` respectively.\n\n Components can request access to their own life cycle events and state through hooks\n while :class:`~reactpy.core.proto.LayoutType` objects drive drive the life cycle\n forward by triggering events and rendering view changes.\n\n Example:\n\n If removed from the complexities of a layout, a very simplified full life cycle\n for a single component with no child components would look a bit like this:\n\n .. 
testcode::\n\n from reactpy.core._life_cycle_hook import LifeCycleHook\n from reactpy.core.hooks import current_hook\n\n # this function will come from a layout implementation\n schedule_render = lambda: ...\n\n # --- start life cycle ---\n\n hook = LifeCycleHook(schedule_render)\n\n # --- start render cycle ---\n\n component = ...\n await hook.affect_component_will_render(component)\n try:\n # render the component\n ...\n\n # the component may access the current hook\n assert current_hook() is hook\n\n # and save state or add effects\n current_hook().use_state(lambda: ...)\n\n async def my_effect(stop_event):\n ...\n\n current_hook().add_effect(my_effect)\n finally:\n await hook.affect_component_did_render()\n\n # This should only be called after the full set of changes associated with a\n # given render have been completed.\n await hook.affect_layout_did_render()\n\n # Typically an event occurs and a new render is scheduled, thus beginning\n # the render cycle anew.\n hook.schedule_render()\n\n\n # --- end render cycle ---\n\n hook.affect_component_will_unmount()\n del hook\n\n # --- end render cycle ---\n \"\"\"\n\n __slots__ = (\n \"__weakref__\",\n \"_context_providers\",\n \"_current_state_index\",\n \"_effect_funcs\",\n \"_effect_stops\",\n \"_effect_tasks\",\n \"_render_access\",\n \"_rendered_atleast_once\",\n \"_schedule_render_callback\",\n \"_scheduled_render\",\n \"_state\",\n \"component\",\n )\n\n component: ComponentType\n\n def __init__(\n self,\n schedule_render: Callable[[], None],\n ) -> None:\n self._context_providers: dict[Context[Any], ContextProviderType[Any]] = {}\n self._schedule_render_callback = schedule_render\n self._scheduled_render = False\n self._rendered_atleast_once = False\n self._current_state_index = 0\n self._state: tuple[Any, ...] = ()\n self._effect_funcs: list[EffectFunc] = []\n self._effect_tasks: list[Task[None]] = []\n self._effect_stops: list[Event] = []\n self._render_access = Semaphore(1) # ensure only one render at a time\n\n def schedule_render(self) -> None:\n if self._scheduled_render:\n return None\n try:\n self._schedule_render_callback()\n except Exception:\n msg = f\"Failed to schedule render via {self._schedule_render_callback}\"\n logger.exception(msg)\n else:\n self._scheduled_render = True\n\n def use_state(self, function: Callable[[], T]) -> T:\n \"\"\"Add state to this hook\n\n If this hook has not yet rendered, the state is appended to the state tuple.\n Otherwise, the state is retrieved from the tuple. This allows state to be\n preserved across renders.\n \"\"\"\n if not self._rendered_atleast_once:\n # since we're not initialized yet we're just appending state\n result = function()\n self._state += (result,)\n else:\n # once finalized we iterate over each succesively used piece of state\n result = self._state[self._current_state_index]\n self._current_state_index += 1\n return result\n\n def add_effect(self, effect_func: EffectFunc) -> None:\n \"\"\"Add an effect to this hook\n\n A task to run the effect is created when the component is done rendering.\n When the component will be unmounted, the event passed to the effect is\n triggered and the task is awaited. 
The effect should eventually halt after\n the event is triggered.\n \"\"\"\n self._effect_funcs.append(effect_func)\n\n \ndef set_context_provider(self, provider: ContextProviderType[Any]) -> None:\n \"\"\"Set a context provider for this hook\n\n The context provider will be used to provide state to any child components\n of this hook's component which request a context provider of the same type.\n \"\"\"\n self._context_providers[provider.type] = provider\n\n def get_context_provider(\n self, context: Context[T]\n ) -> ContextProviderType[T] | None:\n \"\"\"Get a context provider for this hook of the given type\n\n The context provider will have been set by a parent component. If no provider\n is found, ``None`` is returned.\n \"\"\"\n return self._context_providers.get(context)\n\n async def affect_component_will_render(self, component: ComponentType) -> None:\n \"\"\"The component is about to render\"\"\"\n await self._render_access.acquire()\n self._scheduled_render = False\n self.component = component\n self.set_current()\n\n async def affect_component_did_render(self) -> None:\n \"\"\"The component completed a render\"\"\"\n self.unset_current()\n self._rendered_atleast_once = True\n self._current_state_index = 0\n self._render_access.release()\n del self.component\n\n async def affect_layout_did_render(self) -> None:\n \"\"\"The layout completed a render\"\"\"\n stop = Event()\n self._effect_stops.append(stop)\n self._effect_tasks.extend(create_task(e(stop)) for e in self._effect_funcs)\n self._effect_funcs.clear()\n\n async def affect_component_will_unmount(self) -> None:\n \"\"\"The component is about to be removed from the layout\"\"\"\n for stop in self._effect_stops:\n stop.set()\n self._effect_stops.clear()\n try:\n await gather(*self._effect_tasks)\n except Exception:\n logger.exception(\"Error in effect\")\n finally:\n self._effect_tasks.clear()\n\n def set_current(self) -> None:\n \"\"\"Set this hook as the active hook in this thread\n\n This method is called by a layout before entering the render method\n of this hook's associated component.\n \"\"\"\n hook_stack = _HOOK_STATE.get()\n if hook_stack:\n parent = hook_stack[-1]\n self._context_providers.update(parent._context_providers)\n hook_stack.append(self)\n\n def unset_current(self) -> None:\n \"\"\"Unset this hook as the active hook in this thread\"\"\"\n if _HOOK_STATE.get().pop() is not self:\n raise RuntimeError(\"Hook stack is in an invalid state\") # nocov\n\n# Path: src/py/reactpy/reactpy/core/_f_back.py\nfrom __future__ import annotations\n\nimport inspect\nfrom types import FrameType\n\n\ndef f_module_name(index: int = 0) -> str:\n frame = f_back(index + 1)\n if frame is None:\n return \"\" # nocov\n name = frame.f_globals.get(\"__name__\", \"\")\n if not isinstance(name, str):\n raise TypeError(\"Expected module name to be a string\") # nocov\n return name\n\n\ndef f_back(index: int = 0) -> FrameType | None:\n frame = inspect.currentframe()\n while frame is not None:\n if index < 0:\n return frame\n frame = frame.f_back\n index -= 1\n return None # nocov\n\n# Path: src/py/reactpy/reactpy/core/events.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import Sequence\nfrom typing import Any, Callable, Literal, overload\n\nfrom anyio import create_task_group\n\nfrom reactpy.core.types import EventHandlerFunc, EventHandlerType\n\n\n@overload\ndef event(\n function: Callable[..., Any],\n *,\n stop_propagation: bool = ...,\n prevent_default: bool = ...,\n) -> EventHandler: 
...\n\n\n@overload\ndef event(\n function: Literal[None] = ...,\n *,\n stop_propagation: bool = ...,\n prevent_default: bool = ...,\n) -> Callable[[Callable[..., Any]], EventHandler]: ...\n\n\ndef event(\n function: Callable[..., Any] | None = None,\n *,\n stop_propagation: bool = False,\n prevent_default: bool = False,\n) -> EventHandler | Callable[[Callable[..., Any]], EventHandler]:\n \"\"\"A decorator for constructing an :class:`EventHandler`.\n\n While you're always free to add callbacks by assigning them to an element's attributes\n\n .. code-block:: python\n\n element = reactpy.html.button({\"onClick\": my_callback})\n\n You may want the ability to prevent the default action associated with the event\n from taking place, or stopping the event from propagating up the DOM. This decorator\n allows you to add that functionality to your callbacks.\n\n .. code-block:: python\n\n @event(stop_propagation=True, prevent_default=True)\n def my_callback(*data):\n ...\n\n element = reactpy.html.button({\"onClick\": my_callback})\n\n Parameters:\n function:\n A function or coroutine responsible for handling the event.\n stop_propagation:\n Block the event from propagating further up the DOM.\n prevent_default:\n Stops the default actional associate with the event from taking place.\n \"\"\"\n\n def setup(function: Callable[..., Any]) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function, positional_args=True),\n stop_propagation,\n prevent_default,\n )\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass EventHandler:\n \"\"\"Turn a function or coroutine into an event handler\n\n Parameters:\n function:\n The function or coroutine which handles the event.\n stop_propagation:\n Block the event from propagating further up the DOM.\n prevent_default:\n Stops the default action associate with the event from taking place.\n target:\n A unique identifier for this event handler (auto-generated by default)\n \"\"\"\n\n __slots__ = (\n \"__weakref__\",\n \"function\",\n \"prevent_default\",\n \"stop_propagation\",\n \"target\",\n )\n\n def __init__(\n self,\n function: EventHandlerFunc,\n stop_propagation: bool = False,\n prevent_default: bool = False,\n target: str | None = None,\n ) -> None:\n self.function = to_event_handler_function(function, positional_args=False)\n self.prevent_default = prevent_default\n self.stop_propagation = stop_propagation\n self.target = target\n\n def __eq__(self, other: Any) -> bool:\n undefined = object()\n for attr in (\n \"function\",\n \"prevent_default\",\n \"stop_propagation\",\n \"target\",\n ):\n if not attr.startswith(\"_\"):\n if not getattr(other, attr, undefined) == getattr(self, attr):\n return False\n return True\n\n def __repr__(self) -> str:\n public_names = [name for name in self.__slots__ if not name.startswith(\"_\")]\n items = \", \".join([f\"{n}={getattr(self, n)!r}\" for n in public_names])\n return f\"{type(self).__name__}({items})\"\n\n\ndef to_event_handler_function(\n function: Callable[..., Any],\n positional_args: bool = True,\n) -> EventHandlerFunc:\n \"\"\"Make a :data:`~reactpy.core.proto.EventHandlerFunc` from a function or coroutine\n\n Parameters:\n function:\n A function or coroutine accepting a number of positional arguments.\n positional_args:\n Whether to pass the event parameters a positional args or as a list.\n \"\"\"\n if positional_args:\n if asyncio.iscoroutinefunction(function):\n\n async def wrapper(data: Sequence[Any]) -> None:\n await function(*data)\n\n else:\n\n async 
def wrapper(data: Sequence[Any]) -> None:\n function(*data)\n\n return wrapper\n elif not asyncio.iscoroutinefunction(function):\n\n async def wrapper(data: Sequence[Any]) -> None:\n function(data)\n\n return wrapper\n else:\n return function\n\n\ndef merge_event_handlers(\n event_handlers: Sequence[EventHandlerType],\n) -> EventHandlerType:\n \"\"\"Merge multiple event handlers into one\n\n Raises a ValueError if any handlers have conflicting\n :attr:`~reactpy.core.proto.EventHandlerType.stop_propagation` or\n :attr:`~reactpy.core.proto.EventHandlerType.prevent_default` attributes.\n \"\"\"\n if not event_handlers:\n msg = \"No event handlers to merge\"\n raise ValueError(msg)\n elif len(event_handlers) == 1:\n return event_handlers[0]\n\n first_handler = event_handlers[0]\n\n stop_propagation = first_handler.stop_propagation\n prevent_default = first_handler.prevent_default\n target = first_handler.target\n\n for handler in event_handlers:\n if (\n handler.stop_propagation != stop_propagation\n or handler.prevent_default != prevent_default\n or handler.target != target\n ):\n msg = \"Cannot merge handlers - 'stop_propagation', 'prevent_default' or 'target' mismatch.\"\n raise ValueError(msg)\n\n return EventHandler(\n merge_event_handler_funcs([h.function for h in event_handlers]),\n stop_propagation,\n prevent_default,\n target,\n )\n\n\ndef merge_event_handler_funcs(\n functions: Sequence[EventHandlerFunc],\n) -> EventHandlerFunc:\n \"\"\"Make one event handler function from many\"\"\"\n if not functions:\n msg = \"No event handler functions to merge\"\n raise ValueError(msg)\n elif len(functions) == 1:\n return functions[0]\n\n async def await_all_event_handlers(data: Sequence[Any]) -> None:\n async with create_task_group() as group:\n for func in functions:\n group.start_soon(func, data)\n\n return await_all_event_handlers\n\n# Path: src/py/reactpy/reactpy/core/vdom.py\nfrom __future__ import annotations\n\nimport json\nfrom collections.abc import Mapping, Sequence\nfrom functools import wraps\nfrom typing import Any, Protocol, cast, overload\n\nfrom fastjsonschema import compile as compile_json_schema\n\nfrom reactpy._warnings import warn\nfrom reactpy.config import REACTPY_CHECK_JSON_ATTRS, REACTPY_DEBUG_MODE\nfrom reactpy.core._f_back import f_module_name\nfrom reactpy.core.events import EventHandler, to_event_handler_function\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n EventHandlerType,\n ImportSourceDict,\n Key,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomDictConstructor,\n VdomJson,\n)\n\nVDOM_JSON_SCHEMA = {\n \"$schema\": \"http://json-schema.org/draft-07/schema\",\n \"$ref\": \"#/definitions/element\",\n \"definitions\": {\n \"element\": {\n \"type\": \"object\",\n \"properties\": {\n \"tagName\": {\"type\": \"string\"},\n \"key\": {\"type\": [\"string\", \"number\", \"null\"]},\n \"error\": {\"type\": \"string\"},\n \"children\": {\"$ref\": \"#/definitions/elementChildren\"},\n \"attributes\": {\"type\": \"object\"},\n \"eventHandlers\": {\"$ref\": \"#/definitions/elementEventHandlers\"},\n \"importSource\": {\"$ref\": \"#/definitions/importSource\"},\n },\n # The 'tagName' is required because its presence is a useful indicator of\n # whether a dictionary describes a VDOM model or not.\n \"required\": [\"tagName\"],\n \"dependentSchemas\": {\n # When 'error' is given, the 'tagName' should be empty.\n \"error\": {\"properties\": {\"tagName\": {\"maxLength\": 0}}}\n },\n },\n \"elementChildren\": {\n \"type\": \"array\",\n 
\"items\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"elementEventHandlers\": {\n \"type\": \"object\",\n \"patternProperties\": {\n \".*\": {\"$ref\": \"#/definitions/eventHandler\"},\n },\n },\n \"eventHandler\": {\n \"type\": \"object\",\n \"properties\": {\n \"target\": {\"type\": \"string\"},\n \"preventDefault\": {\"type\": \"boolean\"},\n \"stopPropagation\": {\"type\": \"boolean\"},\n },\n \"required\": [\"target\"],\n },\n \"importSource\": {\n \"type\": \"object\",\n \"properties\": {\n \"source\": {\"type\": \"string\"},\n \"sourceType\": {\"enum\": [\"URL\", \"NAME\"]},\n \"fallback\": {\n \"type\": [\"object\", \"string\", \"null\"],\n \"if\": {\"not\": {\"type\": \"null\"}},\n \"then\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"unmountBeforeUpdate\": {\"type\": \"boolean\"},\n },\n \"required\": [\"source\"],\n },\n \"elementOrString\": {\n \"type\": [\"object\", \"string\"],\n \"if\": {\"type\": \"object\"},\n \"then\": {\"$ref\": \"#/definitions/element\"},\n },\n },\n}\n\"\"\"JSON Schema describing serialized VDOM - see :ref:`VDOM` for more info\"\"\"\n\n\n# we can't add a docstring to this because Sphinx doesn't know how to find its source\n_COMPILED_VDOM_VALIDATOR = compile_json_schema(VDOM_JSON_SCHEMA)\n\n\ndef validate_vdom_json(value: Any) -> VdomJson:\n \"\"\"Validate serialized VDOM - see :attr:`VDOM_JSON_SCHEMA` for more info\"\"\"\n _COMPILED_VDOM_VALIDATOR(value)\n return cast(VdomJson, value)\n\n\ndef is_vdom(value: Any) -> bool:\n \"\"\"Return whether a value is a :class:`VdomDict`\n\n This employs a very simple heuristic - something is VDOM if:\n\n 1. It is a ``dict`` instance\n 2. It contains the key ``\"tagName\"``\n 3. The value of the key ``\"tagName\"`` is a string\n\n .. note::\n\n Performing an ``isinstance(value, VdomDict)`` check is too restrictive since the\n user would be forced to import ``VdomDict`` every time they needed to declare a\n VDOM element. Giving the user more flexibility, at the cost of this check's\n accuracy, is worth it.\n \"\"\"\n return (\n isinstance(value, dict)\n and \"tagName\" in value\n and isinstance(value[\"tagName\"], str)\n )\n\n\n@overload\ndef vdom(tag: str, *children: VdomChildren) -> VdomDict: ...\n\n\n@overload\ndef vdom(tag: str, attributes: VdomAttributes, *children: VdomChildren) -> VdomDict: ...\n\n\ndef vdom(\n tag: str,\n *attributes_and_children: Any,\n **kwargs: Any,\n) -> VdomDict:\n...\n", "needle_token_start": 7172, "needle_token_end": 7254, "code_context_ntokens": 11219, "output": ["```python\ndef set_context_provider(self, provider: ContextProviderType[Any]) -> None:\n \"\"\"Set a context provider for this hook\n\n The context provider will be used to provide state to any child components\n of this hook's component which request a context provider of the same type.\n \"\"\"\n self._context_providers[provider.type] = provider\n```"]} +{"repo": "reactive-python/reactpy", "name": "boolean", "language": "python", "path": "src/py/reactpy/reactpy/config.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: To convert various data types into a boolean value based on predefined true and false representations.\n2. **Input**: Accepts a string, boolean, or integer.\n3. **Output**: Returns a boolean value.\n4. **Procedure**: The function first checks the data type of the input. If it's already a boolean, it returns the value directly. If it's an integer, it converts it to a boolean where zero is `False` and any non-zero value is `True`. 
If the input is a string, the function converts it to lowercase and checks against predefined sets of values that represent true and false. If the string matches a true value, it returns `True`; if it matches a false value, it returns `False`. If the string does not match any predefined values, it raises a `ValueError`.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/backend/__init__.py\nimport mimetypes\nfrom logging import getLogger\n\n_logger = getLogger(__name__)\n\n# Fix for missing mime types due to OS corruption/misconfiguration\n# Example: https://github.com/encode/starlette/issues/829\nif not mimetypes.inited:\n mimetypes.init()\nfor extension, mime_type in {\n \".js\": \"application/javascript\",\n \".css\": \"text/css\",\n \".json\": \"application/json\",\n}.items():\n if not mimetypes.types_map.get(extension): # pragma: no cover\n _logger.warning(\n \"Mime type '%s = %s' is missing. Please research how to \"\n \"fix missing mime types on your operating system.\",\n extension,\n mime_type,\n )\n mimetypes.add_type(mime_type, extension)\n\n# Path: src/py/reactpy/reactpy/core/__init__.py\n\n# Path: src/py/reactpy/reactpy/core/types.py\nfrom __future__ import annotations\n\nimport sys\nfrom collections import namedtuple\nfrom collections.abc import Mapping, Sequence\nfrom types import TracebackType\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Generic,\n Literal,\n NamedTuple,\n Protocol,\n TypeVar,\n overload,\n runtime_checkable,\n)\n\nfrom typing_extensions import TypeAlias, TypedDict\n\n_Type = TypeVar(\"_Type\")\n\n\nif TYPE_CHECKING or sys.version_info < (3, 9) or sys.version_info >= (3, 11):\n\n class State(NamedTuple, Generic[_Type]):\n value: _Type\n set_value: Callable[[_Type | Callable[[_Type], _Type]], None]\n\nelse: # nocov\n State = namedtuple(\"State\", (\"value\", \"set_value\"))\n\n\nComponentConstructor = Callable[..., \"ComponentType\"]\n\"\"\"Simple function returning a new component\"\"\"\n\nRootComponentConstructor = Callable[[], \"ComponentType\"]\n\"\"\"The root component should be constructed by a function accepting no arguments.\"\"\"\n\n\nKey: TypeAlias = \"str | int\"\n\n\n_OwnType = TypeVar(\"_OwnType\")\n\n\n@runtime_checkable\nclass ComponentType(Protocol):\n \"\"\"The expected interface for all component-like objects\"\"\"\n\n key: Key | None\n \"\"\"An identifier which is unique amongst a component's immediate siblings\"\"\"\n\n type: Any\n \"\"\"The function or class defining the behavior of this component\n\n This is used to see if two component instances share the same definition.\n \"\"\"\n\n def render(self) -> VdomDict | ComponentType | str | None:\n \"\"\"Render the component's view model.\"\"\"\n\n\n_Render_co = TypeVar(\"_Render_co\", covariant=True)\n_Event_contra = TypeVar(\"_Event_contra\", contravariant=True)\n\n\n@runtime_checkable\nclass LayoutType(Protocol[_Render_co, _Event_contra]):\n \"\"\"Renders and delivers, updates to views and events to handlers, respectively\"\"\"\n\n async def render(self) -> _Render_co:\n \"\"\"Render an update to a view\"\"\"\n\n async def deliver(self, event: _Event_contra) -> None:\n \"\"\"Relay an event to its respective handler\"\"\"\n\n async def __aenter__(self) -> LayoutType[_Render_co, _Event_contra]:\n \"\"\"Prepare the layout for its 
first render\"\"\"\n\n async def __aexit__(\n self,\n exc_type: type[Exception],\n exc_value: Exception,\n traceback: TracebackType,\n ) -> bool | None:\n \"\"\"Clean up the view after its final render\"\"\"\n\n\nVdomAttributes = Mapping[str, Any]\n\"\"\"Describes the attributes of a :class:`VdomDict`\"\"\"\n\nVdomChild: TypeAlias = \"ComponentType | VdomDict | str | None | Any\"\n\"\"\"A single child element of a :class:`VdomDict`\"\"\"\n\nVdomChildren: TypeAlias = \"Sequence[VdomChild] | VdomChild\"\n\"\"\"Describes a series of :class:`VdomChild` elements\"\"\"\n\n\nclass _VdomDictOptional(TypedDict, total=False):\n key: Key | None\n children: Sequence[ComponentType | VdomChild]\n attributes: VdomAttributes\n eventHandlers: EventHandlerDict\n importSource: ImportSourceDict\n\n\nclass _VdomDictRequired(TypedDict, total=True):\n tagName: str\n\n\nclass VdomDict(_VdomDictRequired, _VdomDictOptional):\n \"\"\"A :ref:`VDOM` dictionary\"\"\"\n\n\nclass ImportSourceDict(TypedDict):\n source: str\n fallback: Any\n sourceType: str\n unmountBeforeUpdate: bool\n\n\nclass _OptionalVdomJson(TypedDict, total=False):\n key: Key\n error: str\n children: list[Any]\n attributes: dict[str, Any]\n eventHandlers: dict[str, _JsonEventTarget]\n importSource: _JsonImportSource\n\n\nclass _RequiredVdomJson(TypedDict, total=True):\n tagName: str\n\n\nclass VdomJson(_RequiredVdomJson, _OptionalVdomJson):\n \"\"\"A JSON serializable form of :class:`VdomDict` matching the :data:`VDOM_JSON_SCHEMA`\"\"\"\n\n\nclass _JsonEventTarget(TypedDict):\n target: str\n preventDefault: bool\n stopPropagation: bool\n\n\nclass _JsonImportSource(TypedDict):\n source: str\n fallback: Any\n\n\nEventHandlerMapping = Mapping[str, \"EventHandlerType\"]\n\"\"\"A generic mapping between event names to their handlers\"\"\"\n\nEventHandlerDict: TypeAlias = \"dict[str, EventHandlerType]\"\n\"\"\"A dict mapping between event names to their handlers\"\"\"\n\n\nclass EventHandlerFunc(Protocol):\n \"\"\"A coroutine which can handle event data\"\"\"\n\n async def __call__(self, data: Sequence[Any]) -> None: ...\n\n\n@runtime_checkable\nclass EventHandlerType(Protocol):\n \"\"\"Defines a handler for some event\"\"\"\n\n prevent_default: bool\n \"\"\"Whether to block the event from propagating further up the DOM\"\"\"\n\n stop_propagation: bool\n \"\"\"Stops the default action associate with the event from taking place.\"\"\"\n\n function: EventHandlerFunc\n \"\"\"A coroutine which can respond to an event and its data\"\"\"\n\n target: str | None\n \"\"\"Typically left as ``None`` except when a static target is useful.\n\n When testing, it may be useful to specify a static target ID so events can be\n triggered programmatically.\n\n .. 
note::\n\n When ``None``, it is left to a :class:`LayoutType` to auto generate a unique ID.\n \"\"\"\n\n\nclass VdomDictConstructor(Protocol):\n \"\"\"Standard function for constructing a :class:`VdomDict`\"\"\"\n\n @overload\n def __call__(\n self, attributes: VdomAttributes, *children: VdomChildren\n ) -> VdomDict: ...\n\n @overload\n def __call__(self, *children: VdomChildren) -> VdomDict: ...\n\n @overload\n def __call__(\n self, *attributes_and_children: VdomAttributes | VdomChildren\n ) -> VdomDict: ...\n\n\nclass LayoutUpdateMessage(TypedDict):\n \"\"\"A message describing an update to a layout\"\"\"\n\n type: Literal[\"layout-update\"]\n \"\"\"The type of message\"\"\"\n path: str\n \"\"\"JSON Pointer path to the model element being updated\"\"\"\n model: VdomJson\n \"\"\"The model to assign at the given JSON Pointer path\"\"\"\n\n\nclass LayoutEventMessage(TypedDict):\n \"\"\"Message describing an event originating from an element in the layout\"\"\"\n\n type: Literal[\"layout-event\"]\n \"\"\"The type of message\"\"\"\n target: str\n \"\"\"The ID of the event handler.\"\"\"\n data: Sequence[Any]\n \"\"\"A list of event data passed to the event handler.\"\"\"\n\n\nclass Context(Protocol[_Type]):\n \"\"\"Returns a :class:`ContextProvider` component\"\"\"\n\n def __call__(\n self,\n *children: Any,\n value: _Type = ...,\n key: Key | None = ...,\n ) -> ContextProviderType[_Type]: ...\n\n\nclass ContextProviderType(ComponentType, Protocol[_Type]):\n \"\"\"A component which provides a context value to its children\"\"\"\n\n type: Context[_Type]\n \"\"\"The context type\"\"\"\n\n @property\n def value(self) -> _Type:\n \"Current context value\"\n\n# Path: src/py/reactpy/reactpy/backend/types.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import MutableMapping\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable\n\nfrom reactpy.core.types import RootComponentConstructor\n\n_App = TypeVar(\"_App\")\n\n\n@runtime_checkable\nclass BackendType(Protocol[_App]):\n \"\"\"Common interface for built-in web server/framework integrations\"\"\"\n\n Options: Callable[..., Any]\n \"\"\"A constructor for options passed to :meth:`BackendType.configure`\"\"\"\n\n def configure(\n self,\n app: _App,\n component: RootComponentConstructor,\n options: Any | None = None,\n ) -> None:\n \"\"\"Configure the given app instance to display the given component\"\"\"\n\n def create_development_app(self) -> _App:\n \"\"\"Create an application instance for development purposes\"\"\"\n\n async def serve_development_app(\n self,\n app: _App,\n host: str,\n port: int,\n started: asyncio.Event | None = None,\n ) -> None:\n \"\"\"Run an application using a development server\"\"\"\n\n\n_Carrier = TypeVar(\"_Carrier\")\n\n\n@dataclass\nclass Connection(Generic[_Carrier]):\n \"\"\"Represents a connection with a client\"\"\"\n\n scope: MutableMapping[str, Any]\n \"\"\"An ASGI scope or WSGI environment dictionary\"\"\"\n\n location: Location\n \"\"\"The current location (URL)\"\"\"\n\n carrier: _Carrier\n \"\"\"How the connection is mediated. 
For example, a request or websocket.\n\n This typically depends on the backend implementation.\n \"\"\"\n\n\n@dataclass\nclass Location:\n \"\"\"Represents the current location (URL)\n\n Analogous to, but not necessarily identical to, the client-side\n ``document.location`` object.\n \"\"\"\n\n pathname: str\n \"\"\"the path of the URL for the location\"\"\"\n\n search: str\n \"\"\"A search or query string - a '?' followed by the parameters of the URL.\n\n If there are no search parameters this should be an empty string\n \"\"\"\n\n# Path: src/py/reactpy/reactpy/_warnings.py\nfrom collections.abc import Iterator\nfrom functools import wraps\nfrom inspect import currentframe\nfrom types import FrameType\nfrom typing import TYPE_CHECKING, Any\nfrom warnings import warn as _warn\n\n\n@wraps(_warn)\ndef warn(*args: Any, **kwargs: Any) -> Any:\n # warn at call site outside of ReactPy\n _warn(*args, stacklevel=_frame_depth_in_module() + 1, **kwargs) # type: ignore\n\n\nif TYPE_CHECKING:\n warn = _warn # noqa: F811\n\n\ndef _frame_depth_in_module() -> int:\n depth = 0\n for frame in _iter_frames(2):\n module_name = frame.f_globals.get(\"__name__\")\n if not module_name or not module_name.startswith(\"reactpy.\"):\n break\n depth += 1\n return depth\n\n\ndef _iter_frames(index: int = 1) -> Iterator[FrameType]:\n frame = currentframe()\n while frame is not None:\n if index == 0:\n yield frame\n else:\n index -= 1\n frame = frame.f_back\n\n# Path: src/py/reactpy/reactpy/_option.py\nfrom __future__ import annotations\n\nimport os\nfrom logging import getLogger\nfrom typing import Any, Callable, Generic, TypeVar, cast\n\nfrom reactpy._warnings import warn\n\n_O = TypeVar(\"_O\")\nlogger = getLogger(__name__)\nUNDEFINED = cast(Any, object())\n\n\nclass Option(Generic[_O]):\n \"\"\"An option that can be set using an environment variable of the same name\"\"\"\n\n def __init__(\n self,\n name: str,\n default: _O = UNDEFINED,\n mutable: bool = True,\n parent: Option[_O] | None = None,\n validator: Callable[[Any], _O] = lambda x: cast(_O, x),\n ) -> None:\n self._name = name\n self._mutable = mutable\n self._validator = validator\n self._subscribers: list[Callable[[_O], None]] = []\n\n if name in os.environ:\n self._current = validator(os.environ[name])\n\n if parent is not None:\n if not (parent.mutable and self.mutable):\n raise TypeError(\"Parent and child options must be mutable\")\n self._default = parent.default\n parent.subscribe(self.set_current)\n elif default is not UNDEFINED:\n self._default = default\n else:\n raise TypeError(\"Must specify either a default or a parent option\")\n\n logger.debug(f\"{self._name}={self.current}\")\n\n @property\n def name(self) -> str:\n \"\"\"The name of this option (used to load environment variables)\"\"\"\n return self._name\n\n @property\n def mutable(self) -> bool:\n \"\"\"Whether this option can be modified after being loaded\"\"\"\n return self._mutable\n\n @property\n def default(self) -> _O:\n \"\"\"This option's default value\"\"\"\n return self._default\n\n @property\n def current(self) -> _O:\n try:\n return self._current\n except AttributeError:\n return self._default\n\n @current.setter\n def current(self, new: _O) -> None:\n self.set_current(new)\n\n @current.deleter\n def current(self) -> None:\n self.unset()\n\n def subscribe(self, handler: Callable[[_O], None]) -> Callable[[_O], None]:\n \"\"\"Register a callback that will be triggered when this option changes\"\"\"\n if not self.mutable:\n msg = \"Immutable options cannot be subscribed to.\"\n 
raise TypeError(msg)\n self._subscribers.append(handler)\n handler(self.current)\n return handler\n\n def is_set(self) -> bool:\n \"\"\"Whether this option has a value other than its default.\"\"\"\n return hasattr(self, \"_current\")\n\n def set_current(self, new: Any) -> None:\n \"\"\"Set the value of this option\n\n Raises a ``TypeError`` if this option is not :attr:`Option.mutable`.\n \"\"\"\n old = self.current\n if new is old:\n return None\n\n if not self._mutable:\n msg = f\"{self} cannot be modified after initial load\"\n raise TypeError(msg)\n\n try:\n new = self._current = self._validator(new)\n except ValueError as error:\n raise ValueError(f\"Invalid value for {self._name}: {new!r}\") from error\n\n logger.debug(f\"{self._name}={self._current}\")\n if new != old:\n for sub_func in self._subscribers:\n sub_func(new)\n\n def set_default(self, new: _O) -> _O:\n \"\"\"Set the value of this option if not :meth:`Option.is_set`\n\n Returns the current value (a la :meth:`dict.set_default`)\n \"\"\"\n if not self.is_set():\n self.set_current(new)\n return self._current\n\n def reload(self) -> None:\n \"\"\"Reload this option from its environment variable\"\"\"\n self.set_current(os.environ.get(self._name, self._default))\n\n def unset(self) -> None:\n \"\"\"Remove the current value, the default will be used until it is set again.\"\"\"\n if not self._mutable:\n msg = f\"{self} cannot be modified after initial load\"\n raise TypeError(msg)\n old = self.current\n if hasattr(self, \"_current\"):\n delattr(self, \"_current\")\n if self.current != old:\n for sub_func in self._subscribers:\n sub_func(self.current)\n\n def __repr__(self) -> str:\n return f\"Option({self._name}={self.current!r})\"\n\n\nclass DeprecatedOption(Option[_O]):\n \"\"\"An option that will warn when it is accessed\"\"\"\n\n def __init__(self, *args: Any, message: str, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs)\n self._deprecation_message = message\n\n @Option.current.getter # type: ignore\n def current(self) -> _O:\n try:\n # we access the current value during init to debug log it\n # no need to warn unless it's actually used. 
since this attr\n # is only set after super().__init__ is called, we can check\n # for it to determine if it's being accessed by a user.\n msg = self._deprecation_message\n except AttributeError:\n pass\n else:\n warn(msg, DeprecationWarning)\n return super().current\n\n# Path: src/py/reactpy/reactpy/config.py\n\"\"\"\nReactPy provides a series of configuration options that can be set using environment\nvariables or, for those which allow it, a programmatic interface.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\n\nfrom reactpy._option import Option\n\nTRUE_VALUES = {\"true\", \"1\"}\nFALSE_VALUES = {\"false\", \"0\"}\n\n\n\ndef boolean(value: str | bool | int) -> bool:\n if isinstance(value, bool):\n return value\n elif isinstance(value, int):\n return bool(value)\n elif not isinstance(value, str):\n raise TypeError(f\"Expected str or bool, got {type(value).__name__}\")\n\n if value.lower() in TRUE_VALUES:\n return True\n elif value.lower() in FALSE_VALUES:\n return False\n else:\n raise ValueError(\n f\"Invalid boolean value {value!r} - expected \"\n f\"one of {list(TRUE_VALUES | FALSE_VALUES)}\"\n )\n\n\nREACTPY_DEBUG_MODE = Option(\n \"REACTPY_DEBUG_MODE\", default=False, validator=boolean, mutable=True\n)\n\"\"\"Get extra logs and validation checks at the cost of performance.\n\nThis will enable the following:\n\n- :data:`REACTPY_CHECK_VDOM_SPEC`\n- :data:`REACTPY_CHECK_JSON_ATTRS`\n\"\"\"\n\nREACTPY_CHECK_VDOM_SPEC = Option(\"REACTPY_CHECK_VDOM_SPEC\", parent=REACTPY_DEBUG_MODE)\n\"\"\"Checks which ensure VDOM is rendered to spec\n\nFor more info on the VDOM spec, see here: :ref:`VDOM JSON Schema`\n\"\"\"\n\nREACTPY_CHECK_JSON_ATTRS = Option(\"REACTPY_CHECK_JSON_ATTRS\", parent=REACTPY_DEBUG_MODE)\n\"\"\"Checks that all VDOM attributes are JSON serializable\n\nThe VDOM spec is not able to enforce this on its own since attributes could anything.\n\"\"\"\n\n# Because these web modules will be linked dynamically at runtime this can be temporary.\n# Assigning to a variable here ensures that the directory is not deleted until the end\n# of the program.\n_DEFAULT_WEB_MODULES_DIR = TemporaryDirectory()\n\nREACTPY_WEB_MODULES_DIR = Option(\n \"REACTPY_WEB_MODULES_DIR\",\n default=Path(_DEFAULT_WEB_MODULES_DIR.name),\n validator=Path,\n)\n\"\"\"The location ReactPy will use to store its client application\n\nThis directory **MUST** be treated as a black box. Downstream applications **MUST NOT**\nassume anything about the structure of this directory see :mod:`reactpy.web.module` for a\nset of publicly available APIs for working with the client.\n\"\"\"\n\nREACTPY_TESTING_DEFAULT_TIMEOUT = Option(\n \"REACTPY_TESTING_DEFAULT_TIMEOUT\",\n 5.0,\n mutable=False,\n validator=float,\n)\n\"\"\"A default timeout for testing utilities in ReactPy\"\"\"\n\nREACTPY_ASYNC_RENDERING = Option(\n \"REACTPY_CONCURRENT_RENDERING\",\n default=False,\n mutable=True,\n validator=boolean,\n)\n\"\"\"Whether to render components concurrently. 
This is currently an experimental feature.\"\"\"\n\n# Path: src/py/reactpy/reactpy/core/_thread_local.py\nfrom threading import Thread, current_thread\nfrom typing import Callable, Generic, TypeVar\nfrom weakref import WeakKeyDictionary\n\n_StateType = TypeVar(\"_StateType\")\n\n\nclass ThreadLocal(Generic[_StateType]):\n \"\"\"Utility for managing per-thread state information\"\"\"\n\n def __init__(self, default: Callable[[], _StateType]):\n self._default = default\n self._state: WeakKeyDictionary[Thread, _StateType] = WeakKeyDictionary()\n\n def get(self) -> _StateType:\n thread = current_thread()\n if thread not in self._state:\n state = self._state[thread] = self._default()\n else:\n state = self._state[thread]\n return state\n\n# Path: src/py/reactpy/reactpy/core/_life_cycle_hook.py\nfrom __future__ import annotations\n\nimport logging\nfrom asyncio import Event, Task, create_task, gather\nfrom typing import Any, Callable, Protocol, TypeVar\n\nfrom anyio import Semaphore\n\nfrom reactpy.core._thread_local import ThreadLocal\nfrom reactpy.core.types import ComponentType, Context, ContextProviderType\n\nT = TypeVar(\"T\")\n\n\nclass EffectFunc(Protocol):\n async def __call__(self, stop: Event) -> None: ...\n\n\nlogger = logging.getLogger(__name__)\n\n_HOOK_STATE: ThreadLocal[list[LifeCycleHook]] = ThreadLocal(list)\n\n\ndef current_hook() -> LifeCycleHook:\n \"\"\"Get the current :class:`LifeCycleHook`\"\"\"\n hook_stack = _HOOK_STATE.get()\n if not hook_stack:\n msg = \"No life cycle hook is active. Are you rendering in a layout?\"\n raise RuntimeError(msg)\n return hook_stack[-1]\n\n\nclass LifeCycleHook:\n \"\"\"An object which manages the \"life cycle\" of a layout component.\n\n The \"life cycle\" of a component is the set of events which occur from the time\n a component is first rendered until it is removed from the layout. The life cycle\n is ultimately driven by the layout itself, but components can \"hook\" into those\n events to perform actions. Components gain access to their own life cycle hook\n by calling :func:`current_hook`. They can then perform actions such as:\n\n 1. Adding state via :meth:`use_state`\n 2. Adding effects via :meth:`add_effect`\n 3. Setting or getting context providers via\n :meth:`LifeCycleHook.set_context_provider` and\n :meth:`get_context_provider` respectively.\n\n Components can request access to their own life cycle events and state through hooks\n while :class:`~reactpy.core.proto.LayoutType` objects drive drive the life cycle\n forward by triggering events and rendering view changes.\n\n Example:\n\n If removed from the complexities of a layout, a very simplified full life cycle\n for a single component with no child components would look a bit like this:\n\n .. 
testcode::\n\n from reactpy.core._life_cycle_hook import LifeCycleHook\n from reactpy.core.hooks import current_hook\n\n # this function will come from a layout implementation\n schedule_render = lambda: ...\n\n # --- start life cycle ---\n\n hook = LifeCycleHook(schedule_render)\n\n # --- start render cycle ---\n\n component = ...\n await hook.affect_component_will_render(component)\n try:\n # render the component\n ...\n\n # the component may access the current hook\n assert current_hook() is hook\n\n # and save state or add effects\n current_hook().use_state(lambda: ...)\n\n async def my_effect(stop_event):\n ...\n\n current_hook().add_effect(my_effect)\n finally:\n await hook.affect_component_did_render()\n\n # This should only be called after the full set of changes associated with a\n # given render have been completed.\n await hook.affect_layout_did_render()\n\n # Typically an event occurs and a new render is scheduled, thus beginning\n # the render cycle anew.\n hook.schedule_render()\n\n\n # --- end render cycle ---\n\n hook.affect_component_will_unmount()\n del hook\n\n # --- end render cycle ---\n \"\"\"\n\n __slots__ = (\n \"__weakref__\",\n \"_context_providers\",\n \"_current_state_index\",\n \"_effect_funcs\",\n \"_effect_stops\",\n \"_effect_tasks\",\n \"_render_access\",\n \"_rendered_atleast_once\",\n \"_schedule_render_callback\",\n \"_scheduled_render\",\n \"_state\",\n \"component\",\n )\n\n component: ComponentType\n\n def __init__(\n self,\n schedule_render: Callable[[], None],\n ) -> None:\n self._context_providers: dict[Context[Any], ContextProviderType[Any]] = {}\n self._schedule_render_callback = schedule_render\n self._scheduled_render = False\n self._rendered_atleast_once = False\n self._current_state_index = 0\n self._state: tuple[Any, ...] = ()\n self._effect_funcs: list[EffectFunc] = []\n self._effect_tasks: list[Task[None]] = []\n self._effect_stops: list[Event] = []\n self._render_access = Semaphore(1) # ensure only one render at a time\n\n def schedule_render(self) -> None:\n if self._scheduled_render:\n return None\n try:\n self._schedule_render_callback()\n except Exception:\n msg = f\"Failed to schedule render via {self._schedule_render_callback}\"\n logger.exception(msg)\n else:\n self._scheduled_render = True\n\n def use_state(self, function: Callable[[], T]) -> T:\n \"\"\"Add state to this hook\n\n If this hook has not yet rendered, the state is appended to the state tuple.\n Otherwise, the state is retrieved from the tuple. 
This allows state to be\n preserved across renders.\n \"\"\"\n if not self._rendered_atleast_once:\n # since we're not initialized yet we're just appending state\n result = function()\n self._state += (result,)\n else:\n # once finalized we iterate over each succesively used piece of state\n result = self._state[self._current_state_index]\n self._current_state_index += 1\n return result\n\n def add_effect(self, effect_func: EffectFunc) -> None:\n \"\"\"Add an effect to this hook\n\n...\n", "needle_token_start": 4624, "needle_token_end": 4786, "code_context_ntokens": 7097, "output": ["```python\ndef boolean(value: str | bool | int) -> bool:\n if isinstance(value, bool):\n return value\n elif isinstance(value, int):\n return bool(value)\n elif not isinstance(value, str):\n raise TypeError(f\"Expected str or bool, got {type(value).__name__}\")\n\n if value.lower() in TRUE_VALUES:\n return True\n elif value.lower() in FALSE_VALUES:\n return False\n else:\n raise ValueError(\n f\"Invalid boolean value {value!r} - expected \"\n f\"one of {list(TRUE_VALUES | FALSE_VALUES)}\"\n )\n```"]} +{"repo": "reactive-python/reactpy", "name": "all_implementations", "language": "python", "path": "src/py/reactpy/reactpy/backend/utils.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to iterate through and provide all available server backend implementations supported by a Python library.\n2. **Input**: There are no inputs required for this function.\n3. **Output**: The function yields each successfully imported server backend implementation module.\n4. **Procedure**: The function iterates over a predefined list of backend names. For each name, it attempts to import the corresponding module. If the import is successful, the function then constructs the full module name for the backend specific to the library and yields the imported module. If the import fails, it logs the failure and continues to the next backend.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/py/reactpy/reactpy/core/_life_cycle_hook.py\nfrom __future__ import annotations\n\nimport logging\nfrom asyncio import Event, Task, create_task, gather\nfrom typing import Any, Callable, Protocol, TypeVar\n\nfrom anyio import Semaphore\n\nfrom reactpy.core._thread_local import ThreadLocal\nfrom reactpy.core.types import ComponentType, Context, ContextProviderType\n\nT = TypeVar(\"T\")\n\n\nclass EffectFunc(Protocol):\n async def __call__(self, stop: Event) -> None: ...\n\n\nlogger = logging.getLogger(__name__)\n\n_HOOK_STATE: ThreadLocal[list[LifeCycleHook]] = ThreadLocal(list)\n\n\ndef current_hook() -> LifeCycleHook:\n \"\"\"Get the current :class:`LifeCycleHook`\"\"\"\n hook_stack = _HOOK_STATE.get()\n if not hook_stack:\n msg = \"No life cycle hook is active. Are you rendering in a layout?\"\n raise RuntimeError(msg)\n return hook_stack[-1]\n\n\nclass LifeCycleHook:\n \"\"\"An object which manages the \"life cycle\" of a layout component.\n\n The \"life cycle\" of a component is the set of events which occur from the time\n a component is first rendered until it is removed from the layout. The life cycle\n is ultimately driven by the layout itself, but components can \"hook\" into those\n events to perform actions. 
Components gain access to their own life cycle hook\n by calling :func:`current_hook`. They can then perform actions such as:\n\n 1. Adding state via :meth:`use_state`\n 2. Adding effects via :meth:`add_effect`\n 3. Setting or getting context providers via\n :meth:`LifeCycleHook.set_context_provider` and\n :meth:`get_context_provider` respectively.\n\n Components can request access to their own life cycle events and state through hooks\n while :class:`~reactpy.core.proto.LayoutType` objects drive drive the life cycle\n forward by triggering events and rendering view changes.\n\n Example:\n\n If removed from the complexities of a layout, a very simplified full life cycle\n...\n# Path: src/py/reactpy/reactpy/core/_f_back.py\nfrom __future__ import annotations\n\nimport inspect\nfrom types import FrameType\n\n\ndef f_module_name(index: int = 0) -> str:\n frame = f_back(index + 1)\n if frame is None:\n return \"\" # nocov\n name = frame.f_globals.get(\"__name__\", \"\")\n if not isinstance(name, str):\n raise TypeError(\"Expected module name to be a string\") # nocov\n return name\n\n\ndef f_back(index: int = 0) -> FrameType | None:\n frame = inspect.currentframe()\n while frame is not None:\n if index < 0:\n return frame\n frame = frame.f_back\n index -= 1\n return None # nocov\n\n# Path: src/py/reactpy/reactpy/core/events.py\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import Sequence\nfrom typing import Any, Callable, Literal, overload\n\nfrom anyio import create_task_group\n\nfrom reactpy.core.types import EventHandlerFunc, EventHandlerType\n\n\n@overload\ndef event(\n function: Callable[..., Any],\n *,\n stop_propagation: bool = ...,\n prevent_default: bool = ...,\n) -> EventHandler: ...\n\n\n@overload\ndef event(\n function: Literal[None] = ...,\n *,\n stop_propagation: bool = ...,\n prevent_default: bool = ...,\n) -> Callable[[Callable[..., Any]], EventHandler]: ...\n\n\ndef event(\n function: Callable[..., Any] | None = None,\n *,\n stop_propagation: bool = False,\n prevent_default: bool = False,\n) -> EventHandler | Callable[[Callable[..., Any]], EventHandler]:\n \"\"\"A decorator for constructing an :class:`EventHandler`.\n\n While you're always free to add callbacks by assigning them to an element's attributes\n\n .. code-block:: python\n\n element = reactpy.html.button({\"onClick\": my_callback})\n\n You may want the ability to prevent the default action associated with the event\n from taking place, or stopping the event from propagating up the DOM. This decorator\n allows you to add that functionality to your callbacks.\n\n .. 
code-block:: python\n\n @event(stop_propagation=True, prevent_default=True)\n def my_callback(*data):\n ...\n\n element = reactpy.html.button({\"onClick\": my_callback})\n\n Parameters:\n function:\n A function or coroutine responsible for handling the event.\n stop_propagation:\n Block the event from propagating further up the DOM.\n prevent_default:\n Stops the default actional associate with the event from taking place.\n \"\"\"\n\n def setup(function: Callable[..., Any]) -> EventHandler:\n return EventHandler(\n to_event_handler_function(function, positional_args=True),\n stop_propagation,\n prevent_default,\n )\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass EventHandler:\n \"\"\"Turn a function or coroutine into an event handler\n\n Parameters:\n function:\n The function or coroutine which handles the event.\n stop_propagation:\n Block the event from propagating further up the DOM.\n prevent_default:\n Stops the default action associate with the event from taking place.\n target:\n A unique identifier for this event handler (auto-generated by default)\n \"\"\"\n\n __slots__ = (\n \"__weakref__\",\n \"function\",\n \"prevent_default\",\n \"stop_propagation\",\n \"target\",\n )\n\n def __init__(\n self,\n function: EventHandlerFunc,\n stop_propagation: bool = False,\n prevent_default: bool = False,\n target: str | None = None,\n ) -> None:\n self.function = to_event_handler_function(function, positional_args=False)\n self.prevent_default = prevent_default\n self.stop_propagation = stop_propagation\n self.target = target\n\n def __eq__(self, other: Any) -> bool:\n undefined = object()\n for attr in (\n \"function\",\n \"prevent_default\",\n \"stop_propagation\",\n \"target\",\n ):\n if not attr.startswith(\"_\"):\n if not getattr(other, attr, undefined) == getattr(self, attr):\n return False\n return True\n\n def __repr__(self) -> str:\n public_names = [name for name in self.__slots__ if not name.startswith(\"_\")]\n items = \", \".join([f\"{n}={getattr(self, n)!r}\" for n in public_names])\n return f\"{type(self).__name__}({items})\"\n\n\ndef to_event_handler_function(\n function: Callable[..., Any],\n positional_args: bool = True,\n) -> EventHandlerFunc:\n \"\"\"Make a :data:`~reactpy.core.proto.EventHandlerFunc` from a function or coroutine\n\n Parameters:\n function:\n A function or coroutine accepting a number of positional arguments.\n positional_args:\n Whether to pass the event parameters a positional args or as a list.\n \"\"\"\n if positional_args:\n if asyncio.iscoroutinefunction(function):\n\n async def wrapper(data: Sequence[Any]) -> None:\n await function(*data)\n\n else:\n\n async def wrapper(data: Sequence[Any]) -> None:\n function(*data)\n\n return wrapper\n elif not asyncio.iscoroutinefunction(function):\n\n async def wrapper(data: Sequence[Any]) -> None:\n function(data)\n\n return wrapper\n else:\n return function\n\n\ndef merge_event_handlers(\n event_handlers: Sequence[EventHandlerType],\n) -> EventHandlerType:\n \"\"\"Merge multiple event handlers into one\n\n Raises a ValueError if any handlers have conflicting\n :attr:`~reactpy.core.proto.EventHandlerType.stop_propagation` or\n :attr:`~reactpy.core.proto.EventHandlerType.prevent_default` attributes.\n \"\"\"\n if not event_handlers:\n msg = \"No event handlers to merge\"\n raise ValueError(msg)\n elif len(event_handlers) == 1:\n return event_handlers[0]\n\n first_handler = event_handlers[0]\n\n stop_propagation = first_handler.stop_propagation\n prevent_default = 
first_handler.prevent_default\n target = first_handler.target\n\n for handler in event_handlers:\n if (\n handler.stop_propagation != stop_propagation\n or handler.prevent_default != prevent_default\n or handler.target != target\n ):\n msg = \"Cannot merge handlers - 'stop_propagation', 'prevent_default' or 'target' mismatch.\"\n raise ValueError(msg)\n\n return EventHandler(\n merge_event_handler_funcs([h.function for h in event_handlers]),\n stop_propagation,\n prevent_default,\n target,\n )\n\n\ndef merge_event_handler_funcs(\n functions: Sequence[EventHandlerFunc],\n) -> EventHandlerFunc:\n \"\"\"Make one event handler function from many\"\"\"\n if not functions:\n msg = \"No event handler functions to merge\"\n raise ValueError(msg)\n elif len(functions) == 1:\n return functions[0]\n\n async def await_all_event_handlers(data: Sequence[Any]) -> None:\n async with create_task_group() as group:\n for func in functions:\n group.start_soon(func, data)\n\n return await_all_event_handlers\n\n# Path: src/py/reactpy/reactpy/core/vdom.py\nfrom __future__ import annotations\n\nimport json\nfrom collections.abc import Mapping, Sequence\nfrom functools import wraps\nfrom typing import Any, Protocol, cast, overload\n\nfrom fastjsonschema import compile as compile_json_schema\n\nfrom reactpy._warnings import warn\nfrom reactpy.config import REACTPY_CHECK_JSON_ATTRS, REACTPY_DEBUG_MODE\nfrom reactpy.core._f_back import f_module_name\nfrom reactpy.core.events import EventHandler, to_event_handler_function\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n EventHandlerType,\n ImportSourceDict,\n Key,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomDictConstructor,\n VdomJson,\n)\n\nVDOM_JSON_SCHEMA = {\n \"$schema\": \"http://json-schema.org/draft-07/schema\",\n \"$ref\": \"#/definitions/element\",\n \"definitions\": {\n \"element\": {\n \"type\": \"object\",\n \"properties\": {\n \"tagName\": {\"type\": \"string\"},\n \"key\": {\"type\": [\"string\", \"number\", \"null\"]},\n \"error\": {\"type\": \"string\"},\n \"children\": {\"$ref\": \"#/definitions/elementChildren\"},\n \"attributes\": {\"type\": \"object\"},\n \"eventHandlers\": {\"$ref\": \"#/definitions/elementEventHandlers\"},\n \"importSource\": {\"$ref\": \"#/definitions/importSource\"},\n },\n # The 'tagName' is required because its presence is a useful indicator of\n # whether a dictionary describes a VDOM model or not.\n \"required\": [\"tagName\"],\n \"dependentSchemas\": {\n # When 'error' is given, the 'tagName' should be empty.\n \"error\": {\"properties\": {\"tagName\": {\"maxLength\": 0}}}\n },\n },\n \"elementChildren\": {\n \"type\": \"array\",\n \"items\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"elementEventHandlers\": {\n \"type\": \"object\",\n \"patternProperties\": {\n \".*\": {\"$ref\": \"#/definitions/eventHandler\"},\n },\n },\n \"eventHandler\": {\n \"type\": \"object\",\n \"properties\": {\n \"target\": {\"type\": \"string\"},\n \"preventDefault\": {\"type\": \"boolean\"},\n \"stopPropagation\": {\"type\": \"boolean\"},\n },\n \"required\": [\"target\"],\n },\n \"importSource\": {\n \"type\": \"object\",\n \"properties\": {\n \"source\": {\"type\": \"string\"},\n \"sourceType\": {\"enum\": [\"URL\", \"NAME\"]},\n \"fallback\": {\n \"type\": [\"object\", \"string\", \"null\"],\n \"if\": {\"not\": {\"type\": \"null\"}},\n \"then\": {\"$ref\": \"#/definitions/elementOrString\"},\n },\n \"unmountBeforeUpdate\": {\"type\": \"boolean\"},\n },\n \"required\": 
[\"source\"],\n },\n \"elementOrString\": {\n \"type\": [\"object\", \"string\"],\n \"if\": {\"type\": \"object\"},\n \"then\": {\"$ref\": \"#/definitions/element\"},\n },\n },\n}\n\"\"\"JSON Schema describing serialized VDOM - see :ref:`VDOM` for more info\"\"\"\n\n\n# we can't add a docstring to this because Sphinx doesn't know how to find its source\n_COMPILED_VDOM_VALIDATOR = compile_json_schema(VDOM_JSON_SCHEMA)\n\n\ndef validate_vdom_json(value: Any) -> VdomJson:\n \"\"\"Validate serialized VDOM - see :attr:`VDOM_JSON_SCHEMA` for more info\"\"\"\n _COMPILED_VDOM_VALIDATOR(value)\n return cast(VdomJson, value)\n\n\ndef is_vdom(value: Any) -> bool:\n \"\"\"Return whether a value is a :class:`VdomDict`\n\n This employs a very simple heuristic - something is VDOM if:\n\n 1. It is a ``dict`` instance\n 2. It contains the key ``\"tagName\"``\n 3. The value of the key ``\"tagName\"`` is a string\n\n .. note::\n\n Performing an ``isinstance(value, VdomDict)`` check is too restrictive since the\n user would be forced to import ``VdomDict`` every time they needed to declare a\n VDOM element. Giving the user more flexibility, at the cost of this check's\n accuracy, is worth it.\n \"\"\"\n return (\n isinstance(value, dict)\n and \"tagName\" in value\n and isinstance(value[\"tagName\"], str)\n )\n\n\n@overload\ndef vdom(tag: str, *children: VdomChildren) -> VdomDict: ...\n\n\n@overload\ndef vdom(tag: str, attributes: VdomAttributes, *children: VdomChildren) -> VdomDict: ...\n\n\ndef vdom(\n tag: str,\n *attributes_and_children: Any,\n **kwargs: Any,\n) -> VdomDict:\n \"\"\"A helper function for creating VDOM elements.\n\n Parameters:\n tag:\n The type of element (e.g. 'div', 'h1', 'img')\n attributes_and_children:\n An optional attribute mapping followed by any number of children or\n iterables of children. The attribute mapping **must** precede the children,\n or children which will be merged into their respective parts of the model.\n key:\n A string indicating the identity of a particular element. This is significant\n to preserve event handlers across updates - without a key, a re-render would\n cause these handlers to be deleted, but with a key, they would be redirected\n to any newly defined handlers.\n event_handlers:\n Maps event types to coroutines that are responsible for handling those events.\n import_source:\n (subject to change) specifies javascript that, when evaluated returns a\n React component.\n \"\"\"\n if kwargs: # nocov\n if \"key\" in kwargs:\n if attributes_and_children:\n maybe_attributes, *children = attributes_and_children\n if _is_attributes(maybe_attributes):\n attributes_and_children = (\n {**maybe_attributes, \"key\": kwargs.pop(\"key\")},\n *children,\n )\n else:\n attributes_and_children = (\n {\"key\": kwargs.pop(\"key\")},\n maybe_attributes,\n *children,\n )\n else:\n attributes_and_children = ({\"key\": kwargs.pop(\"key\")},)\n warn(\n \"An element's 'key' must be declared in an attribute dict instead \"\n \"of as a keyword argument. 
This will error in a future version.\",\n DeprecationWarning,\n )\n\n if kwargs:\n msg = f\"Extra keyword arguments {kwargs}\"\n raise ValueError(msg)\n\n model: VdomDict = {\"tagName\": tag}\n\n if not attributes_and_children:\n return model\n\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n\n if attributes:\n if REACTPY_CHECK_JSON_ATTRS.current:\n json.dumps(attributes)\n model[\"attributes\"] = attributes\n\n if children:\n model[\"children\"] = children\n\n if key is not None:\n model[\"key\"] = key\n\n if event_handlers:\n model[\"eventHandlers\"] = event_handlers\n\n return model\n\n\ndef make_vdom_constructor(\n tag: str, allow_children: bool = True, import_source: ImportSourceDict | None = None\n) -> VdomDictConstructor:\n \"\"\"Return a constructor for VDOM dictionaries with the given tag name.\n\n The resulting callable will have the same interface as :func:`vdom` but without its\n first ``tag`` argument.\n \"\"\"\n\n def constructor(*attributes_and_children: Any, **kwargs: Any) -> VdomDict:\n model = vdom(tag, *attributes_and_children, **kwargs)\n if not allow_children and \"children\" in model:\n msg = f\"{tag!r} nodes cannot have children.\"\n raise TypeError(msg)\n if import_source:\n model[\"importSource\"] = import_source\n return model\n\n # replicate common function attributes\n constructor.__name__ = tag\n constructor.__doc__ = (\n \"Return a new \"\n f\"`<{tag}> `__ \"\n \"element represented by a :class:`VdomDict`.\"\n )\n\n module_name = f_module_name(1)\n if module_name:\n constructor.__module__ = module_name\n constructor.__qualname__ = f\"{module_name}.{tag}\"\n\n return cast(VdomDictConstructor, constructor)\n\n\ndef custom_vdom_constructor(func: _CustomVdomDictConstructor) -> VdomDictConstructor:\n \"\"\"Cast function to VdomDictConstructor\"\"\"\n\n @wraps(func)\n def wrapper(*attributes_and_children: Any) -> VdomDict:\n attributes, children = separate_attributes_and_children(attributes_and_children)\n key = attributes.pop(\"key\", None)\n attributes, event_handlers = separate_attributes_and_event_handlers(attributes)\n return func(attributes, children, key, event_handlers)\n\n return cast(VdomDictConstructor, wrapper)\n\n\ndef separate_attributes_and_children(\n values: Sequence[Any],\n) -> tuple[dict[str, Any], list[Any]]:\n if not values:\n return {}, []\n\n attributes: dict[str, Any]\n children_or_iterables: Sequence[Any]\n if _is_attributes(values[0]):\n attributes, *children_or_iterables = values\n else:\n attributes = {}\n children_or_iterables = values\n\n children: list[Any] = []\n for child in children_or_iterables:\n if _is_single_child(child):\n children.append(child)\n else:\n children.extend(child)\n\n return attributes, children\n\n\ndef separate_attributes_and_event_handlers(\n attributes: Mapping[str, Any]\n) -> tuple[dict[str, Any], EventHandlerDict]:\n separated_attributes = {}\n separated_event_handlers: dict[str, EventHandlerType] = {}\n\n for k, v in attributes.items():\n handler: EventHandlerType\n\n if callable(v):\n handler = EventHandler(to_event_handler_function(v))\n elif (\n # isinstance check on protocols is slow - use function attr pre-check as a\n # quick filter before actually performing slow EventHandlerType type check\n hasattr(v, \"function\")\n and isinstance(v, EventHandlerType)\n ):\n handler = v\n else:\n separated_attributes[k] = v\n continue\n\n 
separated_event_handlers[k] = handler\n\n return separated_attributes, dict(separated_event_handlers.items())\n\n\ndef _is_attributes(value: Any) -> bool:\n return isinstance(value, Mapping) and \"tagName\" not in value\n\n\ndef _is_single_child(value: Any) -> bool:\n if isinstance(value, (str, Mapping)) or not hasattr(value, \"__iter__\"):\n return True\n if REACTPY_DEBUG_MODE.current:\n _validate_child_key_integrity(value)\n return False\n\n\ndef _validate_child_key_integrity(value: Any) -> None:\n if hasattr(value, \"__iter__\") and not hasattr(value, \"__len__\"):\n warn(\n f\"Did not verify key-path integrity of children in generator {value} \"\n \"- pass a sequence (i.e. list of finite length) in order to verify\"\n )\n else:\n for child in value:\n if isinstance(child, ComponentType) and child.key is None:\n warn(f\"Key not specified for child in list {child}\", UserWarning)\n elif isinstance(child, Mapping) and \"key\" not in child:\n # remove 'children' to reduce log spam\n child_copy = {**child, \"children\": _EllipsisRepr()}\n warn(f\"Key not specified for child in list {child_copy}\", UserWarning)\n\n\nclass _CustomVdomDictConstructor(Protocol):\n def __call__(\n self,\n attributes: VdomAttributes,\n children: Sequence[VdomChild],\n key: Key | None,\n event_handlers: EventHandlerDict,\n ) -> VdomDict: ...\n\n\nclass _EllipsisRepr:\n def __repr__(self) -> str:\n return \"...\"\n\n# Path: src/py/reactpy/reactpy/utils.py\nfrom __future__ import annotations\n\nimport re\nfrom collections.abc import Iterable\nfrom itertools import chain\nfrom typing import Any, Callable, Generic, TypeVar, cast\n\nfrom lxml import etree\nfrom lxml.html import fromstring, tostring\n\nfrom reactpy.core.types import VdomDict\nfrom reactpy.core.vdom import vdom\n\n_RefValue = TypeVar(\"_RefValue\")\n_ModelTransform = Callable[[VdomDict], Any]\n_UNDEFINED: Any = object()\n\n\nclass Ref(Generic[_RefValue]):\n \"\"\"Hold a reference to a value\n\n This is used in imperative code to mutate the state of this object in order to\n incur side effects. 
Generally refs should be avoided if possible, but sometimes\n they are required.\n\n Notes:\n You can compare the contents for two ``Ref`` objects using the ``==`` operator.\n \"\"\"\n\n __slots__ = (\"current\",)\n\n def __init__(self, initial_value: _RefValue = _UNDEFINED) -> None:\n if initial_value is not _UNDEFINED:\n self.current = initial_value\n \"\"\"The present value\"\"\"\n\n def set_current(self, new: _RefValue) -> _RefValue:\n \"\"\"Set the current value and return what is now the old value\n\n This is nice to use in ``lambda`` functions.\n \"\"\"\n old = self.current\n self.current = new\n return old\n\n def __eq__(self, other: Any) -> bool:\n try:\n return isinstance(other, Ref) and (other.current == self.current)\n except AttributeError:\n # attribute error occurs for uninitialized refs\n return False\n\n def __repr__(self) -> str:\n try:\n current = repr(self.current)\n except AttributeError:\n # attribute error occurs for uninitialized refs\n current = \"\"\n return f\"{type(self).__name__}({current})\"\n\n\ndef vdom_to_html(vdom: VdomDict) -> str:\n \"\"\"Convert a VDOM dictionary into an HTML string\n\n Only the following keys are translated to HTML:\n\n - ``tagName``\n - ``attributes``\n - ``children`` (must be strings or more VDOM dicts)\n\n Parameters:\n vdom: The VdomDict element to convert to HTML\n \"\"\"\n temp_root = etree.Element(\"__temp__\")\n _add_vdom_to_etree(temp_root, vdom)\n html = cast(bytes, tostring(temp_root)).decode()\n # strip out temp root <__temp__> element\n return html[10:-11]\n\n\ndef html_to_vdom(\n html: str, *transforms: _ModelTransform, strict: bool = True\n) -> VdomDict:\n \"\"\"Transform HTML into a DOM model. Unique keys can be provided to HTML elements\n using a ``key=...`` attribute within your HTML tag.\n\n Parameters:\n html:\n The raw HTML as a string\n transforms:\n Functions of the form ``transform(old) -> new`` where ``old`` is a VDOM\n dictionary which will be replaced by ``new``. For example, you could use a\n transform function to add highlighting to a ```` block.\n strict:\n If ``True``, raise an exception if the HTML does not perfectly follow HTML5\n syntax.\n \"\"\"\n if not isinstance(html, str): # nocov\n msg = f\"Expected html to be a string, not {type(html).__name__}\"\n raise TypeError(msg)\n\n # If the user provided a string, convert it to a list of lxml.etree nodes\n try:\n root_node: etree._Element = fromstring(\n html.strip(),\n parser=etree.HTMLParser(\n remove_comments=True,\n remove_pis=True,\n remove_blank_text=True,\n recover=not strict,\n ),\n )\n except etree.XMLSyntaxError as e:\n if not strict:\n raise e # nocov\n msg = \"An error has occurred while parsing the HTML.\\n\\nThis HTML may be malformatted, or may not perfectly adhere to HTML5.\\nIf you believe the exception above was due to something intentional, you can disable the strict parameter on html_to_vdom().\\nOtherwise, repair your broken HTML and try again.\"\n raise HTMLParseError(msg) from e\n\n return _etree_to_vdom(root_node, transforms)\n\n\nclass HTMLParseError(etree.LxmlSyntaxError): # type: ignore[misc]\n \"\"\"Raised when an HTML document cannot be parsed using strict parsing.\"\"\"\n\n\ndef _etree_to_vdom(\n node: etree._Element, transforms: Iterable[_ModelTransform]\n) -> VdomDict:\n \"\"\"Transform an lxml etree node into a DOM model\n\n Parameters:\n node:\n The ``lxml.etree._Element`` node\n transforms:\n Functions of the form ``transform(old) -> new`` where ``old`` is a VDOM\n dictionary which will be replaced by ``new``. 
For example, you could use a\n transform function to add highlighting to a ```` block.\n \"\"\"\n if not isinstance(node, etree._Element): # nocov\n msg = f\"Expected node to be a etree._Element, not {type(node).__name__}\"\n raise TypeError(msg)\n\n # Recursively call _etree_to_vdom() on all children\n children = _generate_vdom_children(node, transforms)\n\n # Convert the lxml node to a VDOM dict\n el = vdom(node.tag, dict(node.items()), *children)\n\n # Perform any necessary mutations on the VDOM attributes to meet VDOM spec\n _mutate_vdom(el)\n\n # Apply any provided transforms.\n for transform in transforms:\n el = transform(el)\n\n return el\n\n\ndef _add_vdom_to_etree(parent: etree._Element, vdom: VdomDict | dict[str, Any]) -> None:\n try:\n tag = vdom[\"tagName\"]\n except KeyError as e:\n msg = f\"Expected a VDOM dict, not {vdom}\"\n raise TypeError(msg) from e\n else:\n vdom = cast(VdomDict, vdom)\n\n if tag:\n element = etree.SubElement(parent, tag)\n element.attrib.update(\n _vdom_attr_to_html_str(k, v) for k, v in vdom.get(\"attributes\", {}).items()\n )\n else:\n element = parent\n\n for c in vdom.get(\"children\", []):\n if isinstance(c, dict):\n _add_vdom_to_etree(element, c)\n else:\n \"\"\"\n LXML handles string children by storing them under `text` and `tail`\n attributes of Element objects. The `text` attribute, if present, effectively\n becomes that element's first child. Then the `tail` attribute, if present,\n becomes a sibling that follows that element. For example, consider the\n following HTML:\n\n

 helloworld 

\n\n In this code sample, \"hello\" is the `text` attribute of the `` element\n and \"world\" is the `tail` attribute of that same `` element. It's for\n this reason that, depending on whether the element being constructed has\n non-string a child element, we need to assign a `text` vs `tail` attribute\n to that element or the last non-string child respectively.\n \"\"\"\n if len(element):\n last_child = element[-1]\n last_child.tail = f\"{last_child.tail or ''}{c}\"\n else:\n element.text = f\"{element.text or ''}{c}\"\n\n\ndef _mutate_vdom(vdom: VdomDict) -> None:\n \"\"\"Performs any necessary mutations on the VDOM attributes to meet VDOM spec.\n\n Currently, this function only transforms the ``style`` attribute into a dictionary whose keys are\n camelCase so as to be renderable by React.\n\n This function may be extended in the future.\n \"\"\"\n # Determine if the style attribute needs to be converted to a dict\n if (\n \"attributes\" in vdom\n and \"style\" in vdom[\"attributes\"]\n and isinstance(vdom[\"attributes\"][\"style\"], str)\n ):\n # Convince type checker that it's safe to mutate attributes\n assert isinstance(vdom[\"attributes\"], dict) # noqa: S101\n\n # Convert style attribute from str -> dict with camelCase keys\n vdom[\"attributes\"][\"style\"] = {\n key.strip().replace(\"-\", \"_\"): value.strip()\n for key, value in (\n part.split(\":\", 1)\n for part in vdom[\"attributes\"][\"style\"].split(\";\")\n if \":\" in part\n )\n }\n\n\ndef _generate_vdom_children(\n node: etree._Element, transforms: Iterable[_ModelTransform]\n) -> list[VdomDict | str]:\n \"\"\"Generates a list of VDOM children from an lxml node.\n\n Inserts inner text and/or tail text in between VDOM children, if necessary.\n \"\"\"\n return ( # Get the inner text of the current node\n [node.text] if node.text else []\n ) + list(\n chain(\n *(\n # Recursively convert each child node to VDOM\n [_etree_to_vdom(child, transforms)]\n # Insert the tail text between each child node\n + ([child.tail] if child.tail else [])\n for child in node.iterchildren(None)\n )\n )\n )\n\n\ndef del_html_head_body_transform(vdom: VdomDict) -> VdomDict:\n \"\"\"Transform intended for use with `html_to_vdom`.\n\n Removes ``, ``, and `` while preserving their children.\n\n Parameters:\n vdom:\n The VDOM dictionary to transform.\n \"\"\"\n if vdom[\"tagName\"] in {\"html\", \"body\", \"head\"}:\n return {\"tagName\": \"\", \"children\": vdom[\"children\"]}\n return vdom\n\n\ndef _vdom_attr_to_html_str(key: str, value: Any) -> tuple[str, str]:\n if key == \"style\":\n if isinstance(value, dict):\n value = \";\".join(\n # We lower only to normalize - CSS is case-insensitive:\n # https://www.w3.org/TR/css-fonts-3/#font-family-casing\n f\"{_CAMEL_CASE_SUB_PATTERN.sub('-', k).lower()}:{v}\"\n for k, v in value.items()\n )\n elif (\n # camel to data-* attributes\n key.startswith(\"data_\")\n # camel to aria-* attributes\n or key.startswith(\"aria_\")\n # handle special cases\n or key in DASHED_HTML_ATTRS\n ):\n key = key.replace(\"_\", \"-\")\n elif (\n # camel to data-* attributes\n key.startswith(\"data\")\n # camel to aria-* attributes\n or key.startswith(\"aria\")\n # handle special cases\n or key in DASHED_HTML_ATTRS\n ):\n key = _CAMEL_CASE_SUB_PATTERN.sub(\"-\", key)\n\n if callable(value): # nocov\n raise TypeError(f\"Cannot convert callable attribute {key}={value} to HTML\")\n\n # Again, we lower the attribute name only to normalize - HTML is case-insensitive:\n # 
http://w3c.github.io/html-reference/documents.html#case-insensitivity\n return key.lower(), str(value)\n\n\n# see list of HTML attributes with dashes in them:\n# https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes#attribute_list\nDASHED_HTML_ATTRS = {\"accept_charset\", \"acceptCharset\", \"http_equiv\", \"httpEquiv\"}\n\n# Pattern for delimitting camelCase names (e.g. camelCase to camel-case)\n_CAMEL_CASE_SUB_PATTERN = re.compile(r\"(? State[_Type]: ...\n\n\n@overload\ndef use_state(initial_value: _Type) -> State[_Type]: ...\n\n\ndef use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value:\n Defines the initial value of the state. A callable (accepting no arguments)\n can be used as a constructor function to avoid re-creating the initial value\n on each render.\n\n Returns:\n A tuple containing the current state and a function to update it.\n \"\"\"\n current_state = _use_const(lambda: _CurrentState(initial_value))\n return State(current_state.value, current_state.dispatch)\n\n\nclass _CurrentState(Generic[_Type]):\n __slots__ = \"value\", \"dispatch\"\n\n def __init__(\n self,\n initial_value: _Type | Callable[[], _Type],\n ) -> None:\n if callable(initial_value):\n self.value = initial_value()\n else:\n self.value = initial_value\n\n hook = current_hook()\n\n def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:\n if callable(new):\n next_value = new(self.value)\n else:\n next_value = new\n if not strictly_equal(next_value, self.value):\n self.value = next_value\n hook.schedule_render()\n\n self.dispatch = dispatch\n\n\n_EffectCleanFunc: TypeAlias = \"Callable[[], None]\"\n_SyncEffectFunc: TypeAlias = \"Callable[[], _EffectCleanFunc | None]\"\n_AsyncEffectFunc: TypeAlias = (\n \"Callable[[], Coroutine[None, None, _EffectCleanFunc | None]]\"\n)\n_EffectApplyFunc: TypeAlias = \"_SyncEffectFunc | _AsyncEffectFunc\"\n\n\n@overload\ndef use_effect(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None]: ...\n\n\n@overload\ndef use_effect(\n function: _EffectApplyFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None: ...\n\n\ndef use_effect(\n function: _EffectApplyFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_EffectApplyFunc], None] | None:\n \"\"\"See the full :ref:`Use Effect` docs for details\n\n Parameters:\n function:\n Applies the effect and can return a clean-up function\n dependencies:\n Dependencies for the effect. The effect will only trigger if the identity\n of any value in the given sequence changes (i.e. their :func:`id` is\n different). By default these are inferred based on local variables that are\n referenced by the given function.\n\n Returns:\n If not function is provided, a decorator. 
Otherwise ``None``.\n \"\"\"\n hook = current_hook()\n\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n last_clean_callback: Ref[_EffectCleanFunc | None] = use_ref(None)\n\n def add_effect(function: _EffectApplyFunc) -> None:\n if not asyncio.iscoroutinefunction(function):\n sync_function = cast(_SyncEffectFunc, function)\n else:\n async_function = cast(_AsyncEffectFunc, function)\n\n def sync_function() -> _EffectCleanFunc | None:\n task = asyncio.create_task(async_function())\n\n def clean_future() -> None:\n if not task.cancel():\n try:\n clean = task.result()\n except asyncio.CancelledError:\n pass\n else:\n if clean is not None:\n clean()\n\n return clean_future\n\n async def effect(stop: asyncio.Event) -> None:\n if last_clean_callback.current is not None:\n last_clean_callback.current()\n last_clean_callback.current = None\n clean = last_clean_callback.current = sync_function()\n await stop.wait()\n if clean is not None:\n clean()\n\n return memoize(lambda: hook.add_effect(effect))\n\n if function is not None:\n add_effect(function)\n return None\n else:\n return add_effect\n\n\ndef use_debug_value(\n message: Any | Callable[[], Any],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> None:\n \"\"\"Log debug information when the given message changes.\n\n .. note::\n This hook only logs if :data:`~reactpy.config.REACTPY_DEBUG_MODE` is active.\n\n Unlike other hooks, a message is considered to have changed if the old and new\n values are ``!=``. Because this comparison is performed on every render of the\n component, it may be worth considering the performance cost in some situations.\n\n Parameters:\n message:\n The value to log or a memoized function for generating the value.\n dependencies:\n Dependencies for the memoized function. The message will only be recomputed\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). 
By default these are inferred based on local\n variables that are referenced by the given function.\n \"\"\"\n old: Ref[Any] = _use_const(lambda: Ref(object()))\n memo_func = message if callable(message) else lambda: message\n new = use_memo(memo_func, dependencies)\n\n if REACTPY_DEBUG_MODE.current and old.current != new:\n old.current = new\n logger.debug(f\"{current_hook().component} {new}\")\n\n\ndef create_context(default_value: _Type) -> Context[_Type]:\n \"\"\"Return a new context type for use in :func:`use_context`\"\"\"\n\n def context(\n *children: Any,\n value: _Type = default_value,\n key: Key | None = None,\n ) -> _ContextProvider[_Type]:\n return _ContextProvider(\n *children,\n value=value,\n key=key,\n type=context,\n )\n\n context.__qualname__ = \"context\"\n\n return context\n\n\ndef use_context(context: Context[_Type]) -> _Type:\n \"\"\"Get the current value for the given context type.\n\n See the full :ref:`Use Context` docs for more information.\n \"\"\"\n hook = current_hook()\n provider = hook.get_context_provider(context)\n\n if provider is None:\n # same assertions but with normal exceptions\n if not isinstance(context, FunctionType):\n raise TypeError(f\"{context} is not a Context\") # nocov\n if context.__kwdefaults__ is None:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n if \"value\" not in context.__kwdefaults__:\n raise TypeError(f\"{context} has no 'value' kwarg\") # nocov\n return cast(_Type, context.__kwdefaults__[\"value\"])\n\n return provider.value\n\n\nclass _ContextProvider(Generic[_Type]):\n def __init__(\n self,\n *children: Any,\n value: _Type,\n key: Key | None,\n type: Context[_Type],\n ) -> None:\n self.children = children\n self.key = key\n self.type = type\n self.value = value\n\n def render(self) -> VdomDict:\n current_hook().set_context_provider(self)\n return {\"tagName\": \"\", \"children\": self.children}\n\n def __repr__(self) -> str:\n return f\"ContextProvider({self.type})\"\n\n\n_ActionType = TypeVar(\"_ActionType\")\n\n\ndef use_reducer(\n reducer: Callable[[_Type, _ActionType], _Type],\n initial_value: _Type,\n) -> tuple[_Type, Callable[[_ActionType], None]]:\n \"\"\"See the full :ref:`Use Reducer` docs for details\n\n Parameters:\n reducer:\n A function which applies an action to the current state in order to\n produce the next state.\n initial_value:\n The initial state value (same as for :func:`use_state`)\n\n Returns:\n A tuple containing the current state and a function to change it with an action\n \"\"\"\n state, set_state = use_state(initial_value)\n return state, _use_const(lambda: _create_dispatcher(reducer, set_state))\n\n\ndef _create_dispatcher(\n reducer: Callable[[_Type, _ActionType], _Type],\n set_state: Callable[[Callable[[_Type], _Type]], None],\n) -> Callable[[_ActionType], None]:\n def dispatch(action: _ActionType) -> None:\n set_state(lambda last_state: reducer(last_state, action))\n\n return dispatch\n\n\n_CallbackFunc = TypeVar(\"_CallbackFunc\", bound=Callable[..., Any])\n\n\n@overload\ndef use_callback(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> Callable[[_CallbackFunc], _CallbackFunc]: ...\n\n\n@overload\ndef use_callback(\n function: _CallbackFunc,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc: ...\n\n\ndef use_callback(\n function: _CallbackFunc | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _CallbackFunc | Callable[[_CallbackFunc], _CallbackFunc]:\n \"\"\"See the full :ref:`Use Callback` 
docs for details\n\n Parameters:\n function:\n The function whose identity will be preserved\n dependencies:\n Dependencies of the callback. The identity the ``function`` will be updated\n if the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current function\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n memoize = use_memo(dependencies=dependencies)\n\n def setup(function: _CallbackFunc) -> _CallbackFunc:\n return memoize(lambda: function)\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _LambdaCaller(Protocol):\n \"\"\"MyPy doesn't know how to deal with TypeVars only used in function return\"\"\"\n\n def __call__(self, func: Callable[[], _Type]) -> _Type: ...\n\n\n@overload\ndef use_memo(\n function: None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _LambdaCaller: ...\n\n\n@overload\ndef use_memo(\n function: Callable[[], _Type],\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type: ...\n\n\ndef use_memo(\n function: Callable[[], _Type] | None = None,\n dependencies: Sequence[Any] | ellipsis | None = ...,\n) -> _Type | Callable[[Callable[[], _Type]], _Type]:\n \"\"\"See the full :ref:`Use Memo` docs for details\n\n Parameters:\n function:\n The function to be memoized.\n dependencies:\n Dependencies for the memoized function. The memo will only be recomputed if\n the identity of any value in the given sequence changes (i.e. their\n :func:`id` is different). By default these are inferred based on local\n variables that are referenced by the given function.\n\n Returns:\n The current state\n \"\"\"\n dependencies = _try_to_infer_closure_values(function, dependencies)\n\n memo: _Memo[_Type] = _use_const(_Memo)\n\n if memo.empty():\n # we need to initialize on the first run\n changed = True\n memo.deps = () if dependencies is None else dependencies\n elif dependencies is None:\n changed = True\n memo.deps = ()\n elif (\n len(memo.deps) != len(dependencies)\n # if deps are same length check identity for each item\n or not all(\n strictly_equal(current, new)\n for current, new in zip(memo.deps, dependencies)\n )\n ):\n memo.deps = dependencies\n changed = True\n else:\n changed = False\n\n setup: Callable[[Callable[[], _Type]], _Type]\n\n if changed:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n current_value = memo.value = function()\n return current_value\n\n else:\n\n def setup(function: Callable[[], _Type]) -> _Type:\n return memo.value\n\n if function is not None:\n return setup(function)\n else:\n return setup\n\n\nclass _Memo(Generic[_Type]):\n \"\"\"Simple object for storing memoization data\"\"\"\n\n __slots__ = \"value\", \"deps\"\n\n value: _Type\n deps: Sequence[Any]\n\n def empty(self) -> bool:\n try:\n self.value # noqa: B018\n except AttributeError:\n return True\n else:\n return False\n\n\ndef use_ref(initial_value: _Type) -> Ref[_Type]:\n \"\"\"See the full :ref:`Use State` docs for details\n\n Parameters:\n initial_value: The value initially assigned to the reference.\n\n Returns:\n A :class:`Ref` object.\n \"\"\"\n return _use_const(lambda: Ref(initial_value))\n\n\ndef _use_const(function: Callable[[], _Type]) -> _Type:\n return current_hook().use_state(function)\n\n\ndef _try_to_infer_closure_values(\n func: Callable[..., Any] | None,\n values: Sequence[Any] | ellipsis | None,\n) -> 
Sequence[Any] | None:\n if values is ...:\n if isinstance(func, FunctionType):\n return (\n [cell.cell_contents for cell in func.__closure__]\n if func.__closure__\n else []\n )\n else:\n return None\n else:\n return values\n\n\ndef strictly_equal(x: Any, y: Any) -> bool:\n \"\"\"Check if two values are identical or, for a limited set or types, equal.\n\n Only the following types are checked for equality rather than identity:\n\n - ``int``\n - ``float``\n - ``complex``\n - ``str``\n - ``bytes``\n - ``bytearray``\n - ``memoryview``\n \"\"\"\n return x is y or (type(x) in _NUMERIC_TEXT_BINARY_TYPES and x == y)\n\n\n_NUMERIC_TEXT_BINARY_TYPES = {\n # numeric\n int,\n float,\n complex,\n # text\n str,\n # binary types\n bytes,\n bytearray,\n memoryview,\n}\n\n# Path: src/py/reactpy/reactpy/backend/hooks.py\nfrom __future__ import annotations\n\nfrom collections.abc import MutableMapping\nfrom typing import Any\n\nfrom reactpy.backend.types import Connection, Location\nfrom reactpy.core.hooks import create_context, use_context\nfrom reactpy.core.types import Context\n\n# backend implementations should establish this context at the root of an app\nConnectionContext: Context[Connection[Any] | None] = create_context(None)\n\n\ndef use_connection() -> Connection[Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`.\"\"\"\n conn = use_context(ConnectionContext)\n if conn is None: # nocov\n msg = \"No backend established a connection.\"\n raise RuntimeError(msg)\n return conn\n\n\ndef use_scope() -> MutableMapping[str, Any]:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s scope.\"\"\"\n return use_connection().scope\n\n\ndef use_location() -> Location:\n \"\"\"Get the current :class:`~reactpy.backend.types.Connection`'s location.\"\"\"\n return use_connection().location\n\n# Path: src/py/reactpy/reactpy/core/component.py\nfrom __future__ import annotations\n\nimport inspect\nfrom functools import wraps\nfrom typing import Any, Callable\n\nfrom reactpy.core.types import ComponentType, VdomDict\n\n\ndef component(\n function: Callable[..., ComponentType | VdomDict | str | None]\n) -> Callable[..., Component]:\n \"\"\"A decorator for defining a new component.\n\n Parameters:\n function: The component's :meth:`reactpy.core.proto.ComponentType.render` function.\n \"\"\"\n sig = inspect.signature(function)\n\n if \"key\" in sig.parameters and sig.parameters[\"key\"].kind in (\n inspect.Parameter.KEYWORD_ONLY,\n inspect.Parameter.POSITIONAL_OR_KEYWORD,\n ):\n msg = f\"Component render function {function} uses reserved parameter 'key'\"\n raise TypeError(msg)\n\n @wraps(function)\n def constructor(*args: Any, key: Any | None = None, **kwargs: Any) -> Component:\n return Component(function, key, args, kwargs, sig)\n\n return constructor\n\n\nclass Component:\n \"\"\"An object for rending component models.\"\"\"\n\n __slots__ = \"__weakref__\", \"_func\", \"_args\", \"_kwargs\", \"_sig\", \"key\", \"type\"\n\n def __init__(\n self,\n function: Callable[..., ComponentType | VdomDict | str | None],\n key: Any | None,\n args: tuple[Any, ...],\n kwargs: dict[str, Any],\n sig: inspect.Signature,\n ) -> None:\n self.key = key\n self.type = function\n self._args = args\n self._kwargs = kwargs\n self._sig = sig\n\n def render(self) -> ComponentType | VdomDict | str | None:\n return self.type(*self._args, **self._kwargs)\n\n def __repr__(self) -> str:\n try:\n args = self._sig.bind(*self._args, **self._kwargs).arguments\n except TypeError:\n return 
f\"{self.type.__name__}(...)\"\n else:\n items = \", \".join(f\"{k}={v!r}\" for k, v in args.items())\n if items:\n return f\"{self.type.__name__}({id(self):02x}, {items})\"\n else:\n return f\"{self.type.__name__}({id(self):02x})\"\n\n# Path: src/py/reactpy/reactpy/types.py\n\"\"\"Exports common types from:\n\n- :mod:`reactpy.core.types`\n- :mod:`reactpy.backend.types`\n\"\"\"\n\nfrom reactpy.backend.types import BackendType, Connection, Location\nfrom reactpy.core.component import Component\nfrom reactpy.core.types import (\n ComponentConstructor,\n ComponentType,\n Context,\n EventHandlerDict,\n EventHandlerFunc,\n EventHandlerMapping,\n EventHandlerType,\n ImportSourceDict,\n Key,\n LayoutType,\n RootComponentConstructor,\n State,\n VdomAttributes,\n VdomChild,\n VdomChildren,\n VdomDict,\n VdomJson,\n)\n\n__all__ = [\n \"BackendType\",\n \"Component\",\n \"ComponentConstructor\",\n \"ComponentType\",\n \"Connection\",\n \"Context\",\n \"EventHandlerDict\",\n \"EventHandlerFunc\",\n \"EventHandlerMapping\",\n \"EventHandlerType\",\n \"ImportSourceDict\",\n \"Key\",\n \"LayoutType\",\n \"Location\",\n \"RootComponentConstructor\",\n \"State\",\n \"VdomAttributes\",\n \"VdomChild\",\n \"VdomChildren\",\n \"VdomDict\",\n \"VdomJson\",\n]\n\n# Path: src/py/reactpy/reactpy/backend/utils.py\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport socket\nimport sys\nfrom collections.abc import Iterator\nfrom contextlib import closing\nfrom importlib import import_module\nfrom typing import Any\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_BACKENDS = (\n \"fastapi\",\n \"sanic\",\n \"tornado\",\n \"flask\",\n \"starlette\",\n)\n\n\ndef run(\n component: RootComponentConstructor,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n implementation: BackendType[Any] | None = None,\n) -> None:\n \"\"\"Run a component with a development server\"\"\"\n logger.warning(_DEVELOPMENT_RUN_FUNC_WARNING)\n\n implementation = implementation or import_module(\"reactpy.backend.default\")\n app = implementation.create_development_app()\n implementation.configure(app, component)\n port = port or find_available_port(host)\n app_cls = type(app)\n\n logger.info(\n \"ReactPy is running with '%s.%s' at http://%s:%s\",\n app_cls.__module__,\n app_cls.__name__,\n host,\n port,\n )\n asyncio.run(implementation.serve_development_app(app, host, port))\n\n\ndef find_available_port(host: str, port_min: int = 8000, port_max: int = 9000) -> int:\n \"\"\"Get a port that's available for the given host and port range\"\"\"\n for port in range(port_min, port_max):\n with closing(socket.socket()) as sock:\n try:\n if sys.platform in (\"linux\", \"darwin\"):\n # Fixes bug on Unix-like systems where every time you restart the\n # server you'll get a different port on Linux. 
This cannot be set\n # on Windows otherwise address will always be reused.\n # Ref: https://stackoverflow.com/a/19247688/3159288\n sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n sock.bind((host, port))\n except OSError:\n pass\n else:\n return port\n msg = f\"Host {host!r} has no available port in range {port_max}-{port_max}\"\n raise RuntimeError(msg)\n\n\n\ndef all_implementations() -> Iterator[BackendType[Any]]:\n \"\"\"Yield all available server implementations\"\"\"\n for name in SUPPORTED_BACKENDS:\n try:\n import_module(name)\n except ImportError: # nocov\n logger.debug(\"Failed to import %s\", name, exc_info=True)\n continue\n\n reactpy_backend_name = f\"{__name__.rsplit('.', 1)[0]}.{name}\"\n yield import_module(reactpy_backend_name)\n\n\n_DEVELOPMENT_RUN_FUNC_WARNING = \"\"\"\\\nThe `run()` function is only intended for testing during development! To run \\\nin production, refer to the docs on how to use reactpy.backend.*.configure.\\\n\"\"\"\n\n# Path: src/py/reactpy/reactpy/core/layout.py\nfrom __future__ import annotations\n\nimport abc\nfrom asyncio import (\n FIRST_COMPLETED,\n CancelledError,\n Queue,\n Task,\n create_task,\n get_running_loop,\n wait,\n)\nfrom collections import Counter\nfrom collections.abc import Sequence\nfrom contextlib import AsyncExitStack\nfrom logging import getLogger\nfrom typing import (\n Any,\n Callable,\n Generic,\n NamedTuple,\n NewType,\n TypeVar,\n cast,\n)\nfrom uuid import uuid4\nfrom weakref import ref as weakref\n\nfrom anyio import Semaphore\nfrom typing_extensions import TypeAlias\n\nfrom reactpy.config import (\n REACTPY_ASYNC_RENDERING,\n REACTPY_CHECK_VDOM_SPEC,\n REACTPY_DEBUG_MODE,\n)\nfrom reactpy.core._life_cycle_hook import LifeCycleHook\nfrom reactpy.core.types import (\n ComponentType,\n EventHandlerDict,\n Key,\n LayoutEventMessage,\n LayoutUpdateMessage,\n VdomChild,\n VdomDict,\n VdomJson,\n)\nfrom reactpy.core.vdom import validate_vdom_json\nfrom reactpy.utils import Ref\n\nlogger = getLogger(__name__)\n\n\nclass Layout:\n \"\"\"Responsible for \"rendering\" components. That is, turning them into VDOM.\"\"\"\n\n __slots__: tuple[str, ...] 
= (\n \"root\",\n \"_event_handlers\",\n \"_rendering_queue\",\n \"_render_tasks\",\n \"_render_tasks_ready\",\n \"_root_life_cycle_state_id\",\n \"_model_states_by_life_cycle_state_id\",\n )\n\n if not hasattr(abc.ABC, \"__weakref__\"): # nocov\n __slots__ += (\"__weakref__\",)\n\n def __init__(self, root: ComponentType) -> None:\n super().__init__()\n if not isinstance(root, ComponentType):\n msg = f\"Expected a ComponentType, not {type(root)!r}.\"\n raise TypeError(msg)\n self.root = root\n\n async def __aenter__(self) -> Layout:\n # create attributes here to avoid access before entering context manager\n self._event_handlers: EventHandlerDict = {}\n self._render_tasks: set[Task[LayoutUpdateMessage]] = set()\n self._render_tasks_ready: Semaphore = Semaphore(0)\n\n...\n", "needle_token_start": 15147, "needle_token_end": 15280, "code_context_ntokens": 16005, "output": ["```python\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport socket\nimport sys\nfrom collections.abc import Iterator\nfrom contextlib import closing\nfrom importlib import import_module\nfrom typing import Any\n\nfrom reactpy.backend.types import BackendType\nfrom reactpy.types import RootComponentConstructor\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_BACKENDS = (\n \"fastapi\",\n \"sanic\",\n \"tornado\",\n \"flask\",\n \"starlette\",\n)\n\n\ndef run(\n component: RootComponentConstructor,\n host: str = \"127.0.0.1\",\n port: int | None = None,\n implementation: BackendType[Any] | None = None,\n) -> None:\n \"\"\"Run a component with a development server\"\"\"\n logger.warning(_DEVELOPMENT_RUN_FUNC_WARNING)\n\n implementation = implementation or import_module(\"reactpy.backend.default\")\n app = implementation.create_development_app()\n implementation.configure(app, component)\n port = port or find_available_port(host)\n app_cls = type(app)\n\n logger.info(\n \"ReactPy is running with '%s.%s' at http://%s:%s\",\n app_cls.__module__,\n app_cls.__name__,\n host,\n port,\n )\n asyncio.run(implementation.serve_development_app(app, host, port))\n\n\ndef find_available_port(host: str, port_min: int = 8000, port_max: int = 9000) -> int:\n \"\"\"Get a port that's available for the given host and port range\"\"\"\n for port in range(port_min, port_max):\n with closing(socket.socket()) as sock:\n try:\n if sys.platform in (\"linux\", \"darwin\"):\n # Fixes bug on Unix-like systems where every time you restart the\n # server you'll get a different port on Linux. This cannot be set\n # on Windows otherwise address will always be reused.\n # Ref: https://stackoverflow.com/a/19247688/3159288\n sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n sock.bind((host, port))\n except OSError:\n pass\n else:\n return port\n msg = f\"Host {host!r} has no available port in range {port_max}-{port_max}\"\n raise RuntimeError(msg)\n\n\n\ndef all_implementations() -> Iterator[BackendType[Any]]:\n \"\"\"Yield all available server implementations\"\"\"\n for name in SUPPORTED_BACKENDS:\n try:\n import_module(name)\n except ImportError: # nocov\n logger.debug(\"Failed to import %s\", name, exc_info=True)\n continue\n\n reactpy_backend_name = f\"{__name__.rsplit('.', 1)[0]}.{name}\"\n yield import_module(reactpy_backend_name)\n\n\n_DEVELOPMENT_RUN_FUNC_WARNING = \"\"\"\\\nThe `run()` function is only intended for testing during development! 
To run \\\nin production, refer to the docs on how to use reactpy.backend.*.configure.\\\n\"\"\"\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "register", "language": "python", "path": "src/marshmallow/class_registry.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to add a serializer class to a centralized registry, allowing for string-based lookup of these classes, which is particularly useful in scenarios involving nested serialization.\n2. **Input**: The function takes two parameters: a string representing the class name and the class itself.\n3. **Output**: There is no return value; the function operates by modifying a global registry.\n4. **Procedure**: The function first determines the module where the class is defined and constructs a full path identifier for the class. It then checks if the class is already registered under the simple class name without causing duplication in the same module. It adds the class to the registry under both the simple class name and the full path. If the full path is already registered, it updates the existing entry to ensure the class is correctly referenced.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/base.py\n\"\"\"Abstract base classes.\n\nThese are necessary to avoid circular imports between schema.py and fields.py.\n\n.. warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\n\n\nclass FieldABC(ABC):\n \"\"\"Abstract base class from which all Field classes inherit.\"\"\"\n\n parent = None\n name = None\n root = None\n\n @abstractmethod\n def serialize(self, attr, obj, accessor=None):\n pass\n\n @abstractmethod\n def deserialize(self, value):\n pass\n\n @abstractmethod\n def _serialize(self, value, attr, obj, **kwargs):\n pass\n\n @abstractmethod\n def _deserialize(self, value, attr, data, **kwargs):\n pass\n\n\nclass SchemaABC(ABC):\n \"\"\"Abstract base class from which all Schemas inherit.\"\"\"\n\n @abstractmethod\n def dump(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def dumps(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def load(self, data, *, many: bool | None = None, partial=None, unknown=None):\n pass\n\n @abstractmethod\n def loads(\n self,\n json_data,\n *,\n many: bool | None = None,\n...\n# Path: src/marshmallow/class_registry.py\n\"\"\"A registry of :class:`Schema ` classes. This allows for string\nlookup of schemas, which may be used with\nclass:`fields.Nested `.\n\n.. warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\nfrom marshmallow.exceptions import RegistryError\n\nif typing.TYPE_CHECKING:\n from marshmallow import Schema\n\n SchemaType = typing.Type[Schema]\n\n# {\n# : \n# : \n# }\n_registry = {} # type: dict[str, list[SchemaType]]\n\n\n\ndef register(classname: str, cls: SchemaType) -> None:\n \"\"\"Add a class to the registry of serializer classes. 
When a class is\n registered, an entry for both its classname and its full, module-qualified\n path are added to the registry.\n\n Example: ::\n\n class MyClass:\n pass\n\n register('MyClass', MyClass)\n # Registry:\n # {\n # 'MyClass': [path.to.MyClass],\n # 'path.to.MyClass': [path.to.MyClass],\n # }\n\n \"\"\"\n # Module where the class is located\n module = cls.__module__\n # Full module path to the class\n # e.g. user.schemas.UserSchema\n fullpath = \".\".join([module, classname])\n # If the class is already registered; need to check if the entries are\n # in the same module as cls to avoid having multiple instances of the same\n # class in the registry\n if classname in _registry and not any(\n each.__module__ == module for each in _registry[classname]\n ):\n _registry[classname].append(cls)\n elif classname not in _registry:\n _registry[classname] = [cls]\n\n # Also register the full path\n if fullpath not in _registry:\n _registry.setdefault(fullpath, []).append(cls)\n else:\n # If fullpath does exist, replace existing entry\n _registry[fullpath] = [cls]\n return None\n\n\ndef get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:\n \"\"\"Retrieve a class from the registry.\n\n :raises: marshmallow.exceptions.RegistryError if the class cannot be found\n or if there are multiple entries for the given class name.\n \"\"\"\n try:\n classes = _registry[classname]\n except KeyError as error:\n raise RegistryError(\n f\"Class with name {classname!r} was not found. You may need \"\n \"to import the class.\"\n ) from error\n if len(classes) > 1:\n if all:\n return _registry[classname]\n raise RegistryError(\n f\"Multiple classes with name {classname!r} \"\n \"were found. Please use the full, \"\n \"module-qualified path.\"\n )\n else:\n return _registry[classname][0]\n\n# Path: src/marshmallow/error_store.py\n\"\"\"Utilities for storing collections of error messages.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\n\nfrom marshmallow.exceptions import SCHEMA\n\n\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n def store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n return errors2\n if not errors2:\n return errors1\n if isinstance(errors1, list):\n if isinstance(errors2, list):\n return errors1 + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return errors1 + [errors2]\n if isinstance(errors1, dict):\n if isinstance(errors2, list):\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, dict):\n errors = dict(errors1)\n for key, val in errors2.items():\n if key in errors:\n errors[key] = merge_errors(errors[key], val)\n else:\n errors[key] = val\n return errors\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, list):\n return [errors1] + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return [errors1, errors2]\n\n# Path: src/marshmallow/orderedset.py\n# OrderedSet\n# Copyright (c) 2009 Raymond Hettinger\n#\n# Permission is hereby granted, free of charge, to any person\n# obtaining a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES\n# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT\n# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,\n# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR\n# OTHER DEALINGS IN THE SOFTWARE.\nfrom collections.abc import MutableSet\n\n\nclass OrderedSet(MutableSet):\n def __init__(self, iterable=None):\n self.end = end = []\n end += [None, end, end] # sentinel node for doubly linked list\n self.map = {} # key --> [key, prev, next]\n if iterable is not None:\n self |= iterable\n\n def __len__(self):\n return len(self.map)\n\n def __contains__(self, key):\n return key in self.map\n\n def add(self, key):\n if key not in self.map:\n end = self.end\n curr = end[1]\n curr[2] = end[1] = self.map[key] = [key, curr, end]\n\n def discard(self, key):\n if key in self.map:\n key, prev, next = self.map.pop(key)\n prev[2] = next\n next[1] = prev\n\n def __iter__(self):\n end = self.end\n curr = end[2]\n while curr is not end:\n yield curr[0]\n curr = curr[2]\n\n def __reversed__(self):\n end = self.end\n curr = end[1]\n while curr is not end:\n yield curr[0]\n curr = curr[1]\n\n def pop(self, last=True):\n if not self:\n raise KeyError(\"set is empty\")\n key = self.end[1][0] if last else self.end[2][0]\n self.discard(key)\n return key\n\n def __repr__(self):\n if not self:\n return f\"{self.__class__.__name__}()\"\n return f\"{self.__class__.__name__}({list(self)!r})\"\n\n def __eq__(self, other):\n if isinstance(other, OrderedSet):\n return len(self) == len(other) and list(self) == list(other)\n return set(self) == set(other)\n\n\nif __name__ == \"__main__\":\n s = OrderedSet(\"abracadaba\")\n t = OrderedSet(\"simsalabim\")\n print(s | t)\n print(s & t)\n print(s - t)\n\n# Path: src/marshmallow/types.py\n\"\"\"Type aliases.\n\n.. warning::\n\n This module is provisional. Types may be modified, added, and removed between minor releases.\n\"\"\"\nimport typing\n\nStrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]\nTag = typing.Union[str, typing.Tuple[str, bool]]\nValidator = typing.Callable[[typing.Any], typing.Any]\n\n# Path: src/marshmallow/warnings.py\nclass RemovedInMarshmallow4Warning(DeprecationWarning):\n pass\n\n# Path: src/marshmallow/utils.py\n\"\"\"Utility methods for marshmallow.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport datetime as dt\nimport functools\nimport inspect\nimport json\nimport re\nimport typing\nimport warnings\nfrom collections.abc import Mapping\nfrom email.utils import format_datetime, parsedate_to_datetime\nfrom pprint import pprint as py_pprint\n\nfrom marshmallow.base import FieldABC\nfrom marshmallow.exceptions import FieldInstanceResolutionError\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\nEXCLUDE = \"exclude\"\nINCLUDE = \"include\"\nRAISE = \"raise\"\n_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}\n\n\nclass _Missing:\n def __bool__(self):\n return False\n\n def __copy__(self):\n return self\n\n def __deepcopy__(self, _):\n return self\n\n def __repr__(self):\n return \"\"\n\n\n# Singleton value that indicates that a field's value is missing from input\n# dict passed to :meth:`Schema.load`. 
If the field's value is not required,\n# it's ``default`` value is used.\nmissing = _Missing()\n\n\ndef is_generator(obj) -> bool:\n \"\"\"Return True if ``obj`` is a generator\"\"\"\n return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)\n\n\ndef is_iterable_but_not_string(obj) -> bool:\n \"\"\"Return True if ``obj`` is an iterable object that isn't a string.\"\"\"\n return (hasattr(obj, \"__iter__\") and not hasattr(obj, \"strip\")) or is_generator(obj)\n\n\ndef is_collection(obj) -> bool:\n \"\"\"Return True if ``obj`` is a collection type, e.g list, tuple, queryset.\"\"\"\n return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)\n\n\ndef is_instance_or_subclass(val, class_) -> bool:\n \"\"\"Return True if ``val`` is either a subclass or instance of ``class_``.\"\"\"\n try:\n return issubclass(val, class_)\n except TypeError:\n return isinstance(val, class_)\n\n\ndef is_keyed_tuple(obj) -> bool:\n \"\"\"Return True if ``obj`` has keyed tuple behavior, such as\n namedtuples or SQLAlchemy's KeyedTuples.\n \"\"\"\n return isinstance(obj, tuple) and hasattr(obj, \"_fields\")\n\n\ndef pprint(obj, *args, **kwargs) -> None:\n \"\"\"Pretty-printing function that can pretty-print OrderedDicts\n like regular dictionaries. Useful for printing the output of\n :meth:`marshmallow.Schema.dump`.\n\n .. deprecated:: 3.7.0\n marshmallow.pprint will be removed in marshmallow 4.\n \"\"\"\n warnings.warn(\n \"marshmallow's pprint function is deprecated and will be removed in marshmallow 4.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if isinstance(obj, collections.OrderedDict):\n print(json.dumps(obj, *args, **kwargs))\n else:\n py_pprint(obj, *args, **kwargs)\n\n\n# https://stackoverflow.com/a/27596917\ndef is_aware(datetime: dt.datetime) -> bool:\n return (\n datetime.tzinfo is not None and datetime.tzinfo.utcoffset(datetime) is not None\n )\n\n\ndef from_rfc(datestring: str) -> dt.datetime:\n \"\"\"Parse a RFC822-formatted datetime string and return a datetime object.\n\n https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950\n \"\"\"\n return parsedate_to_datetime(datestring)\n\n\ndef rfcformat(datetime: dt.datetime) -> str:\n \"\"\"Return the RFC822-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return format_datetime(datetime)\n\n\n# Hat tip to Django for ISO8601 deserialization functions\n\n_iso8601_datetime_re = re.compile(\n r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})\"\n r\"[T ](?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n r\"(?PZ|[+-]\\d{2}(?::?\\d{2})?)?$\"\n)\n\n_iso8601_date_re = re.compile(r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})$\")\n\n_iso8601_time_re = re.compile(\n r\"(?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n)\n\n\ndef get_fixed_timezone(offset: int | float | dt.timedelta) -> dt.timezone:\n \"\"\"Return a tzinfo instance with a fixed offset from UTC.\"\"\"\n if isinstance(offset, dt.timedelta):\n offset = offset.total_seconds() // 60\n sign = \"-\" if offset < 0 else \"+\"\n hhmm = \"%02d%02d\" % divmod(abs(offset), 60)\n name = sign + hhmm\n return dt.timezone(dt.timedelta(minutes=offset), name)\n\n\ndef from_iso_datetime(value):\n \"\"\"Parse a string and return a datetime.datetime.\n\n This function supports time zone offsets. 
When the input contains one,\n the output uses a timezone with a fixed offset from UTC.\n \"\"\"\n match = _iso8601_datetime_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted datetime string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n tzinfo = kw.pop(\"tzinfo\")\n if tzinfo == \"Z\":\n tzinfo = dt.timezone.utc\n elif tzinfo is not None:\n offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0\n offset = 60 * int(tzinfo[1:3]) + offset_mins\n if tzinfo[0] == \"-\":\n offset = -offset\n tzinfo = get_fixed_timezone(offset)\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n kw[\"tzinfo\"] = tzinfo\n return dt.datetime(**kw)\n\n\ndef from_iso_time(value):\n \"\"\"Parse a string and return a datetime.time.\n\n This function doesn't support time zone offsets.\n \"\"\"\n match = _iso8601_time_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted time string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n return dt.time(**kw)\n\n\ndef from_iso_date(value):\n \"\"\"Parse a string and return a datetime.date.\"\"\"\n match = _iso8601_date_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted date string\")\n kw = {k: int(v) for k, v in match.groupdict().items()}\n return dt.date(**kw)\n\n\ndef from_timestamp(value: typing.Any) -> dt.datetime:\n value = float(value)\n if value < 0:\n raise ValueError(\"Not a valid POSIX timestamp\")\n\n # Load a timestamp with utc as timezone to prevent using system timezone.\n # Then set timezone to None, to let the Field handle adding timezone info.\n try:\n return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)\n except OverflowError as exc:\n raise ValueError(\"Timestamp is too large\") from exc\n except OSError as exc:\n raise ValueError(\"Error converting value to datetime\") from exc\n\n\ndef from_timestamp_ms(value: typing.Any) -> dt.datetime:\n value = float(value)\n return from_timestamp(value / 1000)\n\n\ndef timestamp(\n value: dt.datetime,\n) -> float:\n if not is_aware(value):\n # When a date is naive, use UTC as zone info to prevent using system timezone.\n value = value.replace(tzinfo=dt.timezone.utc)\n return value.timestamp()\n\n\ndef timestamp_ms(value: dt.datetime) -> float:\n return timestamp(value) * 1000\n\n\ndef isoformat(datetime: dt.datetime) -> str:\n \"\"\"Return the ISO8601-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return datetime.isoformat()\n\n\ndef to_iso_time(time: dt.time) -> str:\n return dt.time.isoformat(time)\n\n\ndef to_iso_date(date: dt.date) -> str:\n return dt.date.isoformat(date)\n\n\ndef ensure_text_type(val: str | bytes) -> str:\n if isinstance(val, bytes):\n val = val.decode(\"utf-8\")\n return str(val)\n\n\ndef pluck(dictlist: list[dict[str, typing.Any]], key: str):\n \"\"\"Extracts a list of dictionary values from a list of dictionaries.\n ::\n\n >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]\n >>> pluck(dlist, 'id')\n [1, 2]\n \"\"\"\n return [d[key] for d in dictlist]\n\n\n# Various utilities for pulling keyed values from objects\n\n\ndef get_value(obj, key: int | str, default=missing):\n \"\"\"Helper for pulling a keyed value off various types of objects. 
Fields use\n this method by default to access attributes of the source object. For object `x`\n and attribute `i`, this method first tries to access `x[i]`, and then falls back to\n `x.i` if an exception is raised.\n\n .. warning::\n If an object `x` does not raise an exception when `x[i]` does not exist,\n `get_value` will never check the value `x.i`. Consider overriding\n `marshmallow.fields.Field.get_value` in this case.\n \"\"\"\n if not isinstance(key, int) and \".\" in key:\n return _get_value_for_keys(obj, key.split(\".\"), default)\n else:\n return _get_value_for_key(obj, key, default)\n\n\ndef _get_value_for_keys(obj, keys, default):\n if len(keys) == 1:\n return _get_value_for_key(obj, keys[0], default)\n else:\n return _get_value_for_keys(\n _get_value_for_key(obj, keys[0], default), keys[1:], default\n )\n\n\ndef _get_value_for_key(obj, key, default):\n if not hasattr(obj, \"__getitem__\"):\n return getattr(obj, key, default)\n\n try:\n return obj[key]\n except (KeyError, IndexError, TypeError, AttributeError):\n return getattr(obj, key, default)\n\n\ndef set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):\n \"\"\"Set a value in a dict. If `key` contains a '.', it is assumed\n be a path (i.e. dot-delimited string) to the value's location.\n\n ::\n\n >>> d = {}\n >>> set_value(d, 'foo.bar', 42)\n >>> d\n {'foo': {'bar': 42}}\n \"\"\"\n if \".\" in key:\n head, rest = key.split(\".\", 1)\n target = dct.setdefault(head, {})\n if not isinstance(target, dict):\n raise ValueError(\n f\"Cannot set {key} in {head} \" f\"due to existing value: {target}\"\n )\n set_value(target, rest, value)\n else:\n dct[key] = value\n\n\ndef callable_or_raise(obj):\n \"\"\"Check that an object is callable, else raise a :exc:`TypeError`.\"\"\"\n if not callable(obj):\n raise TypeError(f\"Object {obj!r} is not callable.\")\n return obj\n\n\ndef _signature(func: typing.Callable) -> list[str]:\n return list(inspect.signature(func).parameters.keys())\n\n\ndef get_func_args(func: typing.Callable) -> list[str]:\n \"\"\"Given a callable, return a list of argument names. Handles\n `functools.partial` objects and class-based callables.\n\n .. versionchanged:: 3.0.0a1\n Do not return bound arguments, eg. 
``self``.\n \"\"\"\n if inspect.isfunction(func) or inspect.ismethod(func):\n return _signature(func)\n if isinstance(func, functools.partial):\n return _signature(func.func)\n # Callable class\n return _signature(func)\n\n\ndef resolve_field_instance(cls_or_instance):\n \"\"\"Return a Schema instance from a Schema class or instance.\n\n :param type|Schema cls_or_instance: Marshmallow Schema class or instance.\n \"\"\"\n if isinstance(cls_or_instance, type):\n if not issubclass(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance()\n else:\n if not isinstance(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance\n\n\ndef timedelta_to_microseconds(value: dt.timedelta) -> int:\n \"\"\"Compute the total microseconds of a timedelta\n\n https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950\n \"\"\"\n return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds\n\n\ndef validate_unknown_parameter_value(obj: typing.Any) -> str:\n if obj not in _UNKNOWN_VALUES:\n raise ValueError(\n f\"Object {obj!r} is not a valid value for the 'unknown' parameter\"\n )\n return obj\n\n# Path: src/marshmallow/schema.py\n\"\"\"The :class:`Schema` class, including its metaclass and options (class Meta).\"\"\"\nfrom __future__ import annotations\n\nimport copy\nimport datetime as dt\nimport decimal\nimport inspect\nimport json\nimport typing\nimport uuid\nimport warnings\nfrom abc import ABCMeta\nfrom collections import OrderedDict, defaultdict\nfrom collections.abc import Mapping\nfrom functools import lru_cache\n\nfrom marshmallow import base, class_registry, types\nfrom marshmallow import fields as ma_fields\nfrom marshmallow.decorators import (\n POST_DUMP,\n POST_LOAD,\n PRE_DUMP,\n PRE_LOAD,\n VALIDATES,\n VALIDATES_SCHEMA,\n)\nfrom marshmallow.error_store import ErrorStore\nfrom marshmallow.exceptions import StringNotCollectionError, ValidationError\nfrom marshmallow.orderedset import OrderedSet\nfrom marshmallow.utils import (\n EXCLUDE,\n INCLUDE,\n RAISE,\n get_value,\n is_collection,\n is_instance_or_subclass,\n missing,\n set_value,\n validate_unknown_parameter_value,\n)\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n_T = typing.TypeVar(\"_T\")\n\n\ndef _get_fields(attrs):\n \"\"\"Get fields from a class\n\n :param attrs: Mapping of class attributes\n \"\"\"\n return [\n (field_name, field_value)\n for field_name, field_value in attrs.items()\n if is_instance_or_subclass(field_value, base.FieldABC)\n ]\n\n\n# This function allows Schemas to inherit from non-Schema classes and ensures\n# inheritance according to the MRO\ndef _get_fields_by_mro(klass):\n \"\"\"Collect fields from a class, following its method resolution order. The\n class itself is excluded from the search; only its parents are checked. Get\n fields from ``_declared_fields`` if available, else use ``__dict__``.\n\n :param type klass: Class whose fields to retrieve\n \"\"\"\n mro = inspect.getmro(klass)\n # Loop over mro in reverse to maintain correct order of fields\n return sum(\n (\n _get_fields(\n getattr(base, \"_declared_fields\", base.__dict__),\n )\n for base in mro[:0:-1]\n ),\n [],\n )\n\n\nclass SchemaMeta(ABCMeta):\n \"\"\"Metaclass for the Schema class. Binds the declared fields to\n a ``_declared_fields`` attribute, which is a dictionary mapping attribute\n names to field objects. 
Also sets the ``opts`` class attribute, which is\n the Schema class's ``class Meta`` options.\n \"\"\"\n\n def __new__(mcs, name, bases, attrs):\n meta = attrs.get(\"Meta\")\n ordered = getattr(meta, \"ordered\", False)\n if not ordered:\n # Inherit 'ordered' option\n # Warning: We loop through bases instead of MRO because we don't\n # yet have access to the class object\n # (i.e. can't call super before we have fields)\n for base_ in bases:\n if hasattr(base_, \"Meta\") and hasattr(base_.Meta, \"ordered\"):\n ordered = base_.Meta.ordered\n break\n else:\n ordered = False\n cls_fields = _get_fields(attrs)\n # Remove fields from list of class attributes to avoid shadowing\n # Schema attributes/methods in case of name conflict\n for field_name, _ in cls_fields:\n del attrs[field_name]\n klass = super().__new__(mcs, name, bases, attrs)\n inherited_fields = _get_fields_by_mro(klass)\n\n meta = klass.Meta\n # Set klass.opts in __new__ rather than __init__ so that it is accessible in\n # get_declared_fields\n klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)\n # Add fields specified in the `include` class Meta option\n cls_fields += list(klass.opts.include.items())\n\n # Assign _declared_fields on class\n klass._declared_fields = mcs.get_declared_fields(\n klass=klass,\n cls_fields=cls_fields,\n inherited_fields=inherited_fields,\n dict_cls=dict,\n )\n return klass\n\n @classmethod\n def get_declared_fields(\n mcs,\n klass: type,\n cls_fields: list,\n inherited_fields: list,\n dict_cls: type = dict,\n ):\n \"\"\"Returns a dictionary of field_name => `Field` pairs declared on the class.\n This is exposed mainly so that plugins can add additional fields, e.g. fields\n computed from class Meta options.\n\n :param klass: The class object.\n :param cls_fields: The fields declared on the class, including those added\n by the ``include`` class Meta option.\n :param inherited_fields: Inherited fields.\n :param dict_cls: dict-like class to use for dict output Default to ``dict``.\n \"\"\"\n return dict_cls(inherited_fields + cls_fields)\n\n def __init__(cls, name, bases, attrs):\n super().__init__(name, bases, attrs)\n if name and cls.opts.register:\n class_registry.register(name, cls)\n cls._hooks = cls.resolve_hooks()\n\n def resolve_hooks(cls) -> dict[types.Tag, list[str]]:\n \"\"\"Add in the decorated processors\n\n By doing this after constructing the class, we let standard inheritance\n do all the hard work.\n \"\"\"\n mro = inspect.getmro(cls)\n\n hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]\n\n for attr_name in dir(cls):\n # Need to look up the actual descriptor, not whatever might be\n # bound to the class. This needs to come from the __dict__ of the\n # declaring class.\n for parent in mro:\n try:\n attr = parent.__dict__[attr_name]\n except KeyError:\n continue\n else:\n break\n else:\n # In case we didn't find the attribute and didn't break above.\n # We should never hit this - it's just here for completeness\n # to exclude the possibility of attr being undefined.\n continue\n\n try:\n hook_config = attr.__marshmallow_hook__\n except AttributeError:\n pass\n else:\n for key in hook_config.keys():\n # Use name here so we can get the bound method later, in\n # case the processor was a descriptor or something.\n hooks[key].append(attr_name)\n\n return hooks\n\n\nclass SchemaOpts:\n \"\"\"class Meta options for the :class:`Schema`. 
Defines defaults.\"\"\"\n\n def __init__(self, meta, ordered: bool = False):\n self.fields = getattr(meta, \"fields\", ())\n if not isinstance(self.fields, (list, tuple)):\n raise ValueError(\"`fields` option must be a list or tuple.\")\n self.additional = getattr(meta, \"additional\", ())\n if not isinstance(self.additional, (list, tuple)):\n raise ValueError(\"`additional` option must be a list or tuple.\")\n if self.fields and self.additional:\n raise ValueError(\n \"Cannot set both `fields` and `additional` options\"\n \" for the same Schema.\"\n )\n self.exclude = getattr(meta, \"exclude\", ())\n if not isinstance(self.exclude, (list, tuple)):\n raise ValueError(\"`exclude` must be a list or tuple.\")\n self.dateformat = getattr(meta, \"dateformat\", None)\n self.datetimeformat = getattr(meta, \"datetimeformat\", None)\n self.timeformat = getattr(meta, \"timeformat\", None)\n if hasattr(meta, \"json_module\"):\n warnings.warn(\n \"The json_module class Meta option is deprecated. Use render_module instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n render_module = getattr(meta, \"json_module\", json)\n else:\n render_module = json\n self.render_module = getattr(meta, \"render_module\", render_module)\n self.ordered = getattr(meta, \"ordered\", ordered)\n self.index_errors = getattr(meta, \"index_errors\", True)\n self.include = getattr(meta, \"include\", {})\n self.load_only = getattr(meta, \"load_only\", ())\n self.dump_only = getattr(meta, \"dump_only\", ())\n self.unknown = validate_unknown_parameter_value(getattr(meta, \"unknown\", RAISE))\n self.register = getattr(meta, \"register\", True)\n\n\nclass Schema(base.SchemaABC, metaclass=SchemaMeta):\n \"\"\"Base schema class with which to define custom schemas.\n\n Example usage:\n\n .. code-block:: python\n\n import datetime as dt\n from dataclasses import dataclass\n\n from marshmallow import Schema, fields\n\n\n @dataclass\n class Album:\n title: str\n release_date: dt.date\n\n\n class AlbumSchema(Schema):\n title = fields.Str()\n release_date = fields.Date()\n\n\n album = Album(\"Beggars Banquet\", dt.date(1968, 12, 6))\n schema = AlbumSchema()\n data = schema.dump(album)\n data # {'release_date': '1968-12-06', 'title': 'Beggars Banquet'}\n\n :param only: Whitelist of the declared fields to select when\n instantiating the Schema. If None, all fields are used. Nested fields\n can be represented with dot delimiters.\n :param exclude: Blacklist of the declared fields to exclude\n when instantiating the Schema. If a field appears in both `only` and\n `exclude`, it is not used. Nested fields can be represented with dot\n delimiters.\n :param many: Should be set to `True` if ``obj`` is a collection\n so that the object will be serialized to a list.\n :param context: Optional context passed to :class:`fields.Method` and\n :class:`fields.Function` fields.\n :param load_only: Fields to skip during serialization (write-only fields)\n :param dump_only: Fields to skip during deserialization (read-only fields)\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n\n .. versionchanged:: 3.0.0\n `prefix` parameter removed.\n\n .. 
versionchanged:: 2.0.0\n `__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of\n `marshmallow.decorators.validates_schema`,\n `marshmallow.decorators.pre_load` and `marshmallow.decorators.post_dump`.\n `__accessor__` and `__error_handler__` are deprecated. Implement the\n `handle_error` and `get_attribute` methods instead.\n \"\"\"\n\n TYPE_MAPPING = {\n str: ma_fields.String,\n bytes: ma_fields.String,\n dt.datetime: ma_fields.DateTime,\n float: ma_fields.Float,\n bool: ma_fields.Boolean,\n tuple: ma_fields.Raw,\n list: ma_fields.Raw,\n set: ma_fields.Raw,\n int: ma_fields.Integer,\n uuid.UUID: ma_fields.UUID,\n dt.time: ma_fields.Time,\n dt.date: ma_fields.Date,\n dt.timedelta: ma_fields.TimeDelta,\n decimal.Decimal: ma_fields.Decimal,\n } # type: typing.Dict[type, typing.Type[ma_fields.Field]]\n #: Overrides for default schema-level error messages\n error_messages = {} # type: typing.Dict[str, str]\n\n _default_error_messages = {\n \"type\": \"Invalid input type.\",\n \"unknown\": \"Unknown field.\",\n } # type: typing.Dict[str, str]\n\n OPTIONS_CLASS = SchemaOpts # type: type\n\n set_class = OrderedSet\n\n # These get set by SchemaMeta\n opts = None # type: SchemaOpts\n _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]\n _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]\n\n class Meta:\n \"\"\"Options object for a Schema.\n\n Example usage: ::\n\n class Meta:\n fields = (\"id\", \"email\", \"date_created\")\n exclude = (\"password\", \"secret_attribute\")\n\n Available options:\n\n - ``fields``: Tuple or list of fields to include in the serialized result.\n - ``additional``: Tuple or list of fields to include *in addition* to the\n explicitly declared fields. ``additional`` and ``fields`` are\n mutually-exclusive options.\n - ``include``: Dictionary of additional fields to include in the schema. It is\n usually better to define fields as class variables, but you may need to\n use this option, e.g., if your fields are Python keywords. May be an\n `OrderedDict`.\n - ``exclude``: Tuple or list of fields to exclude in the serialized result.\n Nested fields can be represented with dot delimiters.\n - ``dateformat``: Default format for `Date ` fields.\n - ``datetimeformat``: Default format for `DateTime ` fields.\n - ``timeformat``: Default format for `Time ` fields.\n - ``render_module``: Module to use for `loads ` and `dumps `.\n Defaults to `json` from the standard library.\n - ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.\n - ``index_errors``: If `True`, errors dictionaries will include the index\n of invalid items in a collection.\n - ``load_only``: Tuple or list of fields to exclude from serialized results.\n - ``dump_only``: Tuple or list of fields to exclude from deserialization\n - ``unknown``: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n - ``register``: Whether to register the `Schema` with marshmallow's internal\n class registry. Must be `True` if you intend to refer to this `Schema`\n by class name in `Nested` fields. Only set this to `False` when memory\n usage is critical. 
Defaults to `True`.\n \"\"\"\n\n def __init__(\n self,\n *,\n only: types.StrSequenceOrSet | None = None,\n exclude: types.StrSequenceOrSet = (),\n many: bool = False,\n context: dict | None = None,\n load_only: types.StrSequenceOrSet = (),\n dump_only: types.StrSequenceOrSet = (),\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n # Raise error if only or exclude is passed as string, not list of strings\n if only is not None and not is_collection(only):\n raise StringNotCollectionError('\"only\" should be a list of strings')\n if not is_collection(exclude):\n raise StringNotCollectionError('\"exclude\" should be a list of strings')\n # copy declared fields from metaclass\n self.declared_fields = copy.deepcopy(self._declared_fields)\n self.many = many\n self.only = only\n self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(\n self.opts.exclude\n ) | set(exclude)\n self.ordered = self.opts.ordered\n self.load_only = set(load_only) or set(self.opts.load_only)\n self.dump_only = set(dump_only) or set(self.opts.dump_only)\n self.partial = partial\n self.unknown = (\n self.opts.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n self.context = context or {}\n self._normalize_nested_options()\n #: Dictionary mapping field_names -> :class:`Field` objects\n self.fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self._init_fields()\n messages = {}\n messages.update(self._default_error_messages)\n for cls in reversed(self.__class__.__mro__):\n messages.update(getattr(cls, \"error_messages\", {}))\n messages.update(self.error_messages or {})\n self.error_messages = messages\n\n def __repr__(self) -> str:\n return f\"<{self.__class__.__name__}(many={self.many})>\"\n\n @property\n def dict_class(self) -> type:\n return OrderedDict if self.ordered else dict\n\n @classmethod\n def from_dict(\n cls,\n fields: dict[str, ma_fields.Field | type],\n *,\n name: str = \"GeneratedSchema\",\n ) -> type:\n \"\"\"Generate a `Schema` class given a dictionary of fields.\n\n .. code-block:: python\n\n from marshmallow import Schema, fields\n\n PersonSchema = Schema.from_dict({\"name\": fields.Str()})\n print(PersonSchema().load({\"name\": \"David\"})) # => {'name': 'David'}\n\n Generated schemas are not added to the class registry and therefore cannot\n be referred to by name in `Nested` fields.\n\n :param dict fields: Dictionary mapping field names to field instances.\n :param str name: Optional name for the class, which will appear in\n the ``repr`` for the class.\n\n .. versionadded:: 3.0.0\n \"\"\"\n attrs = fields.copy()\n attrs[\"Meta\"] = type(\n \"GeneratedMeta\", (getattr(cls, \"Meta\", object),), {\"register\": False}\n )\n schema_cls = type(name, (cls,), attrs)\n return schema_cls\n\n ##### Override-able methods #####\n\n def handle_error(\n self, error: ValidationError, data: typing.Any, *, many: bool, **kwargs\n ):\n \"\"\"Custom error handler function for the schema.\n\n :param error: The `ValidationError` raised during (de)serialization.\n :param data: The original input data.\n :param many: Value of ``many`` on dump or load.\n :param partial: Value of ``partial`` on load.\n\n .. versionadded:: 2.0.0\n\n .. 
versionchanged:: 3.0.0rc9\n Receives `many` and `partial` (on deserialization) as keyword arguments.\n \"\"\"\n pass\n\n def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):\n \"\"\"Defines how to pull values from an object to serialize.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0a1\n Changed position of ``obj`` and ``attr``.\n \"\"\"\n return get_value(obj, attr, default)\n\n ##### Serialization/Deserialization API #####\n\n @staticmethod\n def _call_and_store(getter_func, data, *, field_name, error_store, index=None):\n \"\"\"Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.\n\n :param callable getter_func: Function for getting the serialized/deserialized\n value from ``data``.\n :param data: The data passed to ``getter_func``.\n :param str field_name: Field name.\n :param int index: Index of the item being validated, if validating a collection,\n otherwise `None`.\n \"\"\"\n try:\n value = getter_func(data)\n except ValidationError as error:\n error_store.store_error(error.messages, field_name, index=index)\n # When a Nested field fails validation, the marshalled data is stored\n # on the ValidationError's valid_data attribute\n return error.valid_data or missing\n return value\n\n def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):\n \"\"\"Serialize ``obj``.\n\n :param obj: The object(s) to serialize.\n :param bool many: `True` if ``data`` should be serialized as a collection.\n :return: A dictionary of the serialized data\n\n .. versionchanged:: 1.0.0\n Renamed from ``marshal``.\n \"\"\"\n if many and obj is not None:\n return [\n self._serialize(d, many=False)\n for d in typing.cast(typing.Iterable[_T], obj)\n ]\n ret = self.dict_class()\n for attr_name, field_obj in self.dump_fields.items():\n value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)\n if value is missing:\n continue\n key = field_obj.data_key if field_obj.data_key is not None else attr_name\n ret[key] = value\n return ret\n\n def dump(self, obj: typing.Any, *, many: bool | None = None):\n \"\"\"Serialize an object to native Python data types according to this\n Schema's fields.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: Serialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n .. versionchanged:: 3.0.0rc9\n Validation no longer occurs upon serialization.\n \"\"\"\n many = self.many if many is None else bool(many)\n if self._has_processors(PRE_DUMP):\n processed_obj = self._invoke_dump_processors(\n PRE_DUMP, obj, many=many, original_data=obj\n )\n else:\n processed_obj = obj\n\n result = self._serialize(processed_obj, many=many)\n\n if self._has_processors(POST_DUMP):\n result = self._invoke_dump_processors(\n POST_DUMP, result, many=many, original_data=obj\n )\n\n return result\n\n def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):\n \"\"\"Same as :meth:`dump`, except return a JSON-encoded string.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: A ``json`` string\n\n .. versionadded:: 1.0.0\n .. 
versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n \"\"\"\n serialized = self.dump(obj, many=many)\n return self.opts.render_module.dumps(serialized, *args, **kwargs)\n\n def _deserialize(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n error_store: ErrorStore,\n many: bool = False,\n partial=None,\n unknown=RAISE,\n index=None,\n ) -> _T | list[_T]:\n \"\"\"Deserialize ``data``.\n\n :param dict data: The data to deserialize.\n :param ErrorStore error_store: Structure to store errors.\n :param bool many: `True` if ``data`` should be deserialized as a collection.\n :param bool|tuple partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param int index: Index of the item being serialized (for storing errors) if\n serializing a collection, otherwise `None`.\n :return: A dictionary of the deserialized data.\n \"\"\"\n index_errors = self.opts.index_errors\n index = index if index_errors else None\n if many:\n if not is_collection(data):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n ret_l = [] # type: typing.List[_T]\n else:\n ret_l = [\n typing.cast(\n _T,\n self._deserialize(\n typing.cast(typing.Mapping[str, typing.Any], d),\n error_store=error_store,\n many=False,\n partial=partial,\n unknown=unknown,\n index=idx,\n ),\n )\n for idx, d in enumerate(data)\n ]\n return ret_l\n ret_d = self.dict_class()\n # Check data is a dict\n if not isinstance(data, Mapping):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n else:\n partial_is_collection = is_collection(partial)\n for attr_name, field_obj in self.load_fields.items():\n field_name = (\n field_obj.data_key if field_obj.data_key is not None else attr_name\n )\n raw_value = data.get(field_name, missing)\n if raw_value is missing:\n # Ignore missing field if we're allowed to.\n if partial is True or (\n partial_is_collection and attr_name in partial\n ):\n continue\n d_kwargs = {}\n # Allow partial loading of nested schemas.\n if partial_is_collection:\n prefix = field_name + \".\"\n len_prefix = len(prefix)\n sub_partial = [\n f[len_prefix:] for f in partial if f.startswith(prefix)\n ]\n d_kwargs[\"partial\"] = sub_partial\n elif partial is not None:\n d_kwargs[\"partial\"] = partial\n\n def getter(\n val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs\n ):\n return field_obj.deserialize(\n val,\n field_name,\n data,\n **d_kwargs,\n )\n\n value = self._call_and_store(\n getter_func=getter,\n data=raw_value,\n field_name=field_name,\n error_store=error_store,\n index=index,\n )\n if value is not missing:\n key = field_obj.attribute or attr_name\n set_value(ret_d, key, value)\n if unknown != EXCLUDE:\n fields = {\n field_obj.data_key if field_obj.data_key is not None else field_name\n for field_name, field_obj in self.load_fields.items()\n }\n for key in set(data) - fields:\n value = data[key]\n if unknown == INCLUDE:\n ret_d[key] = value\n elif unknown == RAISE:\n error_store.store_error(\n [self.error_messages[\"unknown\"]],\n key,\n (index if 
index_errors else None),\n )\n return ret_d\n\n def load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n \"\"\"Deserialize a data structure to an object defined by this Schema's fields.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n return self._do_load(\n data, many=many, partial=partial, unknown=unknown, postprocess=True\n )\n\n def loads(\n self,\n json_data: str,\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n **kwargs,\n ):\n \"\"\"Same as :meth:`load`, except it takes a JSON string as input.\n\n :param json_data: A JSON string of the data to deserialize.\n :param many: Whether to deserialize `obj` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n data = self.opts.render_module.loads(json_data, **kwargs)\n return self.load(data, many=many, partial=partial, unknown=unknown)\n\n def _run_validator(\n self,\n validator_func,\n output,\n *,\n original_data,\n error_store,\n many,\n partial,\n pass_original,\n index=None,\n ):\n try:\n if pass_original: # Pass original, raw data (before unmarshalling)\n validator_func(output, original_data, partial=partial, many=many)\n else:\n validator_func(output, partial=partial, many=many)\n except ValidationError as err:\n error_store.store_error(err.messages, err.field_name, index=index)\n\n def validate(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n ) -> dict[str, list[str]]:\n \"\"\"Validate `data` against the schema, returning a dictionary of\n validation errors.\n\n :param data: The data to validate.\n :param many: Whether to validate `data` as a collection. 
If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :return: A dictionary of validation errors.\n\n .. versionadded:: 1.1.0\n \"\"\"\n try:\n self._do_load(data, many=many, partial=partial, postprocess=False)\n except ValidationError as exc:\n return typing.cast(typing.Dict[str, typing.List[str]], exc.messages)\n return {}\n\n ##### Private Helpers #####\n\n def _do_load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n postprocess: bool = True,\n ):\n \"\"\"Deserialize `data`, returning the deserialized result.\n This method is private API.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to validate required fields. If its\n value is an iterable, only fields listed in that iterable will be\n ignored will be allowed missing. If `True`, all fields will be allowed missing.\n If `None`, the value for `self.partial` is used.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :param postprocess: Whether to run post_load methods..\n :return: Deserialized data\n \"\"\"\n...\n", "needle_token_start": 603, "needle_token_end": 985, "code_context_ntokens": 16130, "output": ["```python\ndef register(classname: str, cls: SchemaType) -> None:\n \"\"\"Add a class to the registry of serializer classes. When a class is\n registered, an entry for both its classname and its full, module-qualified\n path are added to the registry.\n\n Example: ::\n\n class MyClass:\n pass\n\n register('MyClass', MyClass)\n # Registry:\n # {\n # 'MyClass': [path.to.MyClass],\n # 'path.to.MyClass': [path.to.MyClass],\n # }\n\n \"\"\"\n # Module where the class is located\n module = cls.__module__\n # Full module path to the class\n # e.g. user.schemas.UserSchema\n fullpath = \".\".join([module, classname])\n # If the class is already registered; need to check if the entries are\n # in the same module as cls to avoid having multiple instances of the same\n # class in the registry\n if classname in _registry and not any(\n each.__module__ == module for each in _registry[classname]\n ):\n _registry[classname].append(cls)\n elif classname not in _registry:\n _registry[classname] = [cls]\n\n # Also register the full path\n if fullpath not in _registry:\n _registry.setdefault(fullpath, []).append(cls)\n else:\n # If fullpath does exist, replace existing entry\n _registry[fullpath] = [cls]\n return None\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "store_error", "language": "python", "path": "src/marshmallow/error_store.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to accumulate error messages related to data serialization, categorizing them by field name or schema, and optionally by an index if provided.\n2. 
**Input**: The function accepts error messages (either as strings, lists, or dictionaries), a field name identifier (defaulting to a schema-wide error identifier if not specified), and an optional index to specify a particular element in a collection.\n3. **Output**: There is no direct output returned from the function; instead, it modifies an internal dictionary to include the new error messages, organizing them according to the provided field name and index.\n4. **Procedure**: The function first checks if the error message should be categorized under a specific field or as a general schema error. If an index is provided, the error message is further nested under this index. The function then merges these newly formatted error messages into an existing dictionary of errors, ensuring that all related errors are grouped together appropriately.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/decorators.py\n\"\"\"Decorators for registering schema pre-processing and post-processing methods.\nThese should be imported from the top-level `marshmallow` module.\n\nMethods decorated with\n`pre_load `, `post_load `,\n`pre_dump `, `post_dump `,\nand `validates_schema ` receive\n``many`` as a keyword argument. In addition, `pre_load `,\n`post_load `,\nand `validates_schema ` receive\n``partial``. If you don't need these arguments, add ``**kwargs`` to your method\nsignature.\n\n\nExample: ::\n\n from marshmallow import (\n Schema, pre_load, pre_dump, post_load, validates_schema,\n validates, fields, ValidationError\n )\n\n class UserSchema(Schema):\n\n email = fields.Str(required=True)\n...\n# Path: src/marshmallow/exceptions.py\n\"\"\"Exception classes for marshmallow-related errors.\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\n# Key used for schema-level validation errors\nSCHEMA = \"_schema\"\n\n\nclass MarshmallowError(Exception):\n \"\"\"Base class for all marshmallow-related errors.\"\"\"\n\n\nclass ValidationError(MarshmallowError):\n \"\"\"Raised when validation fails on a field or schema.\n\n Validators and custom fields should raise this exception.\n\n :param message: An error message, list of error messages, or dict of\n error messages. 
If a dict, the keys are subitems and the values are error messages.\n :param field_name: Field name to store the error on.\n If `None`, the error is stored as schema-level error.\n :param data: Raw input data.\n :param valid_data: Valid (de)serialized data.\n \"\"\"\n\n def __init__(\n self,\n message: str | list | dict,\n field_name: str = SCHEMA,\n data: typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n | None = None,\n valid_data: list[dict[str, typing.Any]] | dict[str, typing.Any] | None = None,\n **kwargs,\n ):\n self.messages = [message] if isinstance(message, (str, bytes)) else message\n self.field_name = field_name\n self.data = data\n self.valid_data = valid_data\n self.kwargs = kwargs\n super().__init__(message)\n\n def normalized_messages(self):\n if self.field_name == SCHEMA and isinstance(self.messages, dict):\n return self.messages\n return {self.field_name: self.messages}\n\n @property\n def messages_dict(self) -> dict[str, typing.Any]:\n if not isinstance(self.messages, dict):\n raise TypeError(\n \"cannot access 'messages_dict' when 'messages' is of type \"\n + type(self.messages).__name__\n )\n return self.messages\n\n\nclass RegistryError(NameError):\n \"\"\"Raised when an invalid operation is performed on the serializer\n class registry.\n \"\"\"\n\n\nclass StringNotCollectionError(MarshmallowError, TypeError):\n \"\"\"Raised when a string is passed when a list of strings is expected.\"\"\"\n\n\nclass FieldInstanceResolutionError(MarshmallowError, TypeError):\n \"\"\"Raised when schema to instantiate is neither a Schema class nor an instance.\"\"\"\n\n# Path: src/marshmallow/base.py\n\"\"\"Abstract base classes.\n\nThese are necessary to avoid circular imports between schema.py and fields.py.\n\n.. warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\n\n\nclass FieldABC(ABC):\n \"\"\"Abstract base class from which all Field classes inherit.\"\"\"\n\n parent = None\n name = None\n root = None\n\n @abstractmethod\n def serialize(self, attr, obj, accessor=None):\n pass\n\n @abstractmethod\n def deserialize(self, value):\n pass\n\n @abstractmethod\n def _serialize(self, value, attr, obj, **kwargs):\n pass\n\n @abstractmethod\n def _deserialize(self, value, attr, data, **kwargs):\n pass\n\n\nclass SchemaABC(ABC):\n \"\"\"Abstract base class from which all Schemas inherit.\"\"\"\n\n @abstractmethod\n def dump(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def dumps(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def load(self, data, *, many: bool | None = None, partial=None, unknown=None):\n pass\n\n @abstractmethod\n def loads(\n self,\n json_data,\n *,\n many: bool | None = None,\n partial=None,\n unknown=None,\n **kwargs,\n ):\n pass\n\n# Path: src/marshmallow/class_registry.py\n\"\"\"A registry of :class:`Schema ` classes. This allows for string\nlookup of schemas, which may be used with\nclass:`fields.Nested `.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\nfrom marshmallow.exceptions import RegistryError\n\nif typing.TYPE_CHECKING:\n from marshmallow import Schema\n\n SchemaType = typing.Type[Schema]\n\n# {\n# : \n# : \n# }\n_registry = {} # type: dict[str, list[SchemaType]]\n\n\ndef register(classname: str, cls: SchemaType) -> None:\n \"\"\"Add a class to the registry of serializer classes. When a class is\n registered, an entry for both its classname and its full, module-qualified\n path are added to the registry.\n\n Example: ::\n\n class MyClass:\n pass\n\n register('MyClass', MyClass)\n # Registry:\n # {\n # 'MyClass': [path.to.MyClass],\n # 'path.to.MyClass': [path.to.MyClass],\n # }\n\n \"\"\"\n # Module where the class is located\n module = cls.__module__\n # Full module path to the class\n # e.g. user.schemas.UserSchema\n fullpath = \".\".join([module, classname])\n # If the class is already registered; need to check if the entries are\n # in the same module as cls to avoid having multiple instances of the same\n # class in the registry\n if classname in _registry and not any(\n each.__module__ == module for each in _registry[classname]\n ):\n _registry[classname].append(cls)\n elif classname not in _registry:\n _registry[classname] = [cls]\n\n # Also register the full path\n if fullpath not in _registry:\n _registry.setdefault(fullpath, []).append(cls)\n else:\n # If fullpath does exist, replace existing entry\n _registry[fullpath] = [cls]\n return None\n\n\ndef get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:\n \"\"\"Retrieve a class from the registry.\n\n :raises: marshmallow.exceptions.RegistryError if the class cannot be found\n or if there are multiple entries for the given class name.\n \"\"\"\n try:\n classes = _registry[classname]\n except KeyError as error:\n raise RegistryError(\n f\"Class with name {classname!r} was not found. You may need \"\n \"to import the class.\"\n ) from error\n if len(classes) > 1:\n if all:\n return _registry[classname]\n raise RegistryError(\n f\"Multiple classes with name {classname!r} \"\n \"were found. Please use the full, \"\n \"module-qualified path.\"\n )\n else:\n return _registry[classname][0]\n\n# Path: src/marshmallow/error_store.py\n\"\"\"Utilities for storing collections of error messages.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\n\nfrom marshmallow.exceptions import SCHEMA\n\n\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n \ndef store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n return errors2\n if not errors2:\n return errors1\n if isinstance(errors1, list):\n if isinstance(errors2, list):\n return errors1 + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return errors1 + [errors2]\n if isinstance(errors1, dict):\n if isinstance(errors2, list):\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, dict):\n errors = dict(errors1)\n for key, val in errors2.items():\n if key in errors:\n errors[key] = merge_errors(errors[key], val)\n else:\n errors[key] = val\n return errors\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, list):\n return [errors1] + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return [errors1, errors2]\n\n# Path: src/marshmallow/orderedset.py\n# OrderedSet\n# Copyright (c) 2009 Raymond Hettinger\n#\n# Permission is hereby granted, free of charge, to any person\n# obtaining a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES\n# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT\n# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,\n# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR\n# OTHER DEALINGS IN THE SOFTWARE.\nfrom collections.abc import MutableSet\n\n\nclass OrderedSet(MutableSet):\n def __init__(self, iterable=None):\n self.end = end = []\n end += [None, end, end] # sentinel node for doubly linked list\n self.map = {} # key --> [key, prev, next]\n if iterable is not None:\n self |= iterable\n\n def __len__(self):\n return len(self.map)\n\n def __contains__(self, key):\n return key in self.map\n\n def add(self, key):\n if key not in self.map:\n end = self.end\n curr = end[1]\n curr[2] = end[1] = self.map[key] = [key, curr, end]\n\n def discard(self, key):\n if key in self.map:\n key, prev, next = self.map.pop(key)\n prev[2] = next\n next[1] = prev\n\n def __iter__(self):\n end = self.end\n curr = end[2]\n while curr is not end:\n yield curr[0]\n curr = curr[2]\n\n def __reversed__(self):\n end = self.end\n curr = end[1]\n while curr is not end:\n yield curr[0]\n curr = curr[1]\n\n def pop(self, last=True):\n if not self:\n raise KeyError(\"set is empty\")\n key = self.end[1][0] if last else self.end[2][0]\n self.discard(key)\n return key\n\n def __repr__(self):\n if not self:\n return f\"{self.__class__.__name__}()\"\n return f\"{self.__class__.__name__}({list(self)!r})\"\n\n def __eq__(self, other):\n if isinstance(other, OrderedSet):\n return len(self) == len(other) and list(self) == list(other)\n return set(self) == set(other)\n\n\nif __name__ == \"__main__\":\n s = OrderedSet(\"abracadaba\")\n t = OrderedSet(\"simsalabim\")\n print(s | t)\n print(s & t)\n print(s - t)\n\n# Path: src/marshmallow/types.py\n\"\"\"Type aliases.\n\n.. warning::\n\n This module is provisional. Types may be modified, added, and removed between minor releases.\n\"\"\"\nimport typing\n\nStrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]\nTag = typing.Union[str, typing.Tuple[str, bool]]\nValidator = typing.Callable[[typing.Any], typing.Any]\n\n# Path: src/marshmallow/warnings.py\nclass RemovedInMarshmallow4Warning(DeprecationWarning):\n pass\n\n# Path: src/marshmallow/utils.py\n\"\"\"Utility methods for marshmallow.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport datetime as dt\nimport functools\nimport inspect\nimport json\nimport re\nimport typing\nimport warnings\nfrom collections.abc import Mapping\nfrom email.utils import format_datetime, parsedate_to_datetime\nfrom pprint import pprint as py_pprint\n\nfrom marshmallow.base import FieldABC\nfrom marshmallow.exceptions import FieldInstanceResolutionError\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\nEXCLUDE = \"exclude\"\nINCLUDE = \"include\"\nRAISE = \"raise\"\n_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}\n\n\nclass _Missing:\n def __bool__(self):\n return False\n\n def __copy__(self):\n return self\n\n def __deepcopy__(self, _):\n return self\n\n def __repr__(self):\n return \"\"\n\n\n# Singleton value that indicates that a field's value is missing from input\n# dict passed to :meth:`Schema.load`. 
If the field's value is not required,\n# it's ``default`` value is used.\nmissing = _Missing()\n\n\ndef is_generator(obj) -> bool:\n \"\"\"Return True if ``obj`` is a generator\"\"\"\n return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)\n\n\ndef is_iterable_but_not_string(obj) -> bool:\n \"\"\"Return True if ``obj`` is an iterable object that isn't a string.\"\"\"\n return (hasattr(obj, \"__iter__\") and not hasattr(obj, \"strip\")) or is_generator(obj)\n\n\ndef is_collection(obj) -> bool:\n \"\"\"Return True if ``obj`` is a collection type, e.g list, tuple, queryset.\"\"\"\n return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)\n\n\ndef is_instance_or_subclass(val, class_) -> bool:\n \"\"\"Return True if ``val`` is either a subclass or instance of ``class_``.\"\"\"\n try:\n return issubclass(val, class_)\n except TypeError:\n return isinstance(val, class_)\n\n\ndef is_keyed_tuple(obj) -> bool:\n \"\"\"Return True if ``obj`` has keyed tuple behavior, such as\n namedtuples or SQLAlchemy's KeyedTuples.\n \"\"\"\n return isinstance(obj, tuple) and hasattr(obj, \"_fields\")\n\n\ndef pprint(obj, *args, **kwargs) -> None:\n \"\"\"Pretty-printing function that can pretty-print OrderedDicts\n like regular dictionaries. Useful for printing the output of\n :meth:`marshmallow.Schema.dump`.\n\n .. deprecated:: 3.7.0\n marshmallow.pprint will be removed in marshmallow 4.\n \"\"\"\n warnings.warn(\n \"marshmallow's pprint function is deprecated and will be removed in marshmallow 4.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if isinstance(obj, collections.OrderedDict):\n print(json.dumps(obj, *args, **kwargs))\n else:\n py_pprint(obj, *args, **kwargs)\n\n\n# https://stackoverflow.com/a/27596917\ndef is_aware(datetime: dt.datetime) -> bool:\n return (\n datetime.tzinfo is not None and datetime.tzinfo.utcoffset(datetime) is not None\n )\n\n\ndef from_rfc(datestring: str) -> dt.datetime:\n \"\"\"Parse a RFC822-formatted datetime string and return a datetime object.\n\n https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950\n \"\"\"\n return parsedate_to_datetime(datestring)\n\n\ndef rfcformat(datetime: dt.datetime) -> str:\n \"\"\"Return the RFC822-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return format_datetime(datetime)\n\n\n# Hat tip to Django for ISO8601 deserialization functions\n\n_iso8601_datetime_re = re.compile(\n r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})\"\n r\"[T ](?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n r\"(?PZ|[+-]\\d{2}(?::?\\d{2})?)?$\"\n)\n\n_iso8601_date_re = re.compile(r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})$\")\n\n_iso8601_time_re = re.compile(\n r\"(?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n)\n\n\ndef get_fixed_timezone(offset: int | float | dt.timedelta) -> dt.timezone:\n \"\"\"Return a tzinfo instance with a fixed offset from UTC.\"\"\"\n if isinstance(offset, dt.timedelta):\n offset = offset.total_seconds() // 60\n sign = \"-\" if offset < 0 else \"+\"\n hhmm = \"%02d%02d\" % divmod(abs(offset), 60)\n name = sign + hhmm\n return dt.timezone(dt.timedelta(minutes=offset), name)\n\n\ndef from_iso_datetime(value):\n \"\"\"Parse a string and return a datetime.datetime.\n\n This function supports time zone offsets. 
When the input contains one,\n the output uses a timezone with a fixed offset from UTC.\n \"\"\"\n match = _iso8601_datetime_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted datetime string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n tzinfo = kw.pop(\"tzinfo\")\n if tzinfo == \"Z\":\n tzinfo = dt.timezone.utc\n elif tzinfo is not None:\n offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0\n offset = 60 * int(tzinfo[1:3]) + offset_mins\n if tzinfo[0] == \"-\":\n offset = -offset\n tzinfo = get_fixed_timezone(offset)\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n kw[\"tzinfo\"] = tzinfo\n return dt.datetime(**kw)\n\n\ndef from_iso_time(value):\n \"\"\"Parse a string and return a datetime.time.\n\n This function doesn't support time zone offsets.\n \"\"\"\n match = _iso8601_time_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted time string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n return dt.time(**kw)\n\n\ndef from_iso_date(value):\n \"\"\"Parse a string and return a datetime.date.\"\"\"\n match = _iso8601_date_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted date string\")\n kw = {k: int(v) for k, v in match.groupdict().items()}\n return dt.date(**kw)\n\n\ndef from_timestamp(value: typing.Any) -> dt.datetime:\n value = float(value)\n if value < 0:\n raise ValueError(\"Not a valid POSIX timestamp\")\n\n # Load a timestamp with utc as timezone to prevent using system timezone.\n # Then set timezone to None, to let the Field handle adding timezone info.\n try:\n return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)\n except OverflowError as exc:\n raise ValueError(\"Timestamp is too large\") from exc\n except OSError as exc:\n raise ValueError(\"Error converting value to datetime\") from exc\n\n\ndef from_timestamp_ms(value: typing.Any) -> dt.datetime:\n value = float(value)\n return from_timestamp(value / 1000)\n\n\ndef timestamp(\n value: dt.datetime,\n) -> float:\n if not is_aware(value):\n # When a date is naive, use UTC as zone info to prevent using system timezone.\n value = value.replace(tzinfo=dt.timezone.utc)\n return value.timestamp()\n\n\ndef timestamp_ms(value: dt.datetime) -> float:\n return timestamp(value) * 1000\n\n\ndef isoformat(datetime: dt.datetime) -> str:\n \"\"\"Return the ISO8601-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return datetime.isoformat()\n\n\ndef to_iso_time(time: dt.time) -> str:\n return dt.time.isoformat(time)\n\n\ndef to_iso_date(date: dt.date) -> str:\n return dt.date.isoformat(date)\n\n\ndef ensure_text_type(val: str | bytes) -> str:\n if isinstance(val, bytes):\n val = val.decode(\"utf-8\")\n return str(val)\n\n\ndef pluck(dictlist: list[dict[str, typing.Any]], key: str):\n \"\"\"Extracts a list of dictionary values from a list of dictionaries.\n ::\n\n >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]\n >>> pluck(dlist, 'id')\n [1, 2]\n \"\"\"\n return [d[key] for d in dictlist]\n\n\n# Various utilities for pulling keyed values from objects\n\n\ndef get_value(obj, key: int | str, default=missing):\n \"\"\"Helper for pulling a keyed value off various types of objects. 
Fields use\n this method by default to access attributes of the source object. For object `x`\n and attribute `i`, this method first tries to access `x[i]`, and then falls back to\n `x.i` if an exception is raised.\n\n .. warning::\n If an object `x` does not raise an exception when `x[i]` does not exist,\n `get_value` will never check the value `x.i`. Consider overriding\n `marshmallow.fields.Field.get_value` in this case.\n \"\"\"\n if not isinstance(key, int) and \".\" in key:\n return _get_value_for_keys(obj, key.split(\".\"), default)\n else:\n return _get_value_for_key(obj, key, default)\n\n\ndef _get_value_for_keys(obj, keys, default):\n if len(keys) == 1:\n return _get_value_for_key(obj, keys[0], default)\n else:\n return _get_value_for_keys(\n _get_value_for_key(obj, keys[0], default), keys[1:], default\n )\n\n\ndef _get_value_for_key(obj, key, default):\n if not hasattr(obj, \"__getitem__\"):\n return getattr(obj, key, default)\n\n try:\n return obj[key]\n except (KeyError, IndexError, TypeError, AttributeError):\n return getattr(obj, key, default)\n\n\ndef set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):\n \"\"\"Set a value in a dict. If `key` contains a '.', it is assumed\n be a path (i.e. dot-delimited string) to the value's location.\n\n ::\n\n >>> d = {}\n >>> set_value(d, 'foo.bar', 42)\n >>> d\n {'foo': {'bar': 42}}\n \"\"\"\n if \".\" in key:\n head, rest = key.split(\".\", 1)\n target = dct.setdefault(head, {})\n if not isinstance(target, dict):\n raise ValueError(\n f\"Cannot set {key} in {head} \" f\"due to existing value: {target}\"\n )\n set_value(target, rest, value)\n else:\n dct[key] = value\n\n\ndef callable_or_raise(obj):\n \"\"\"Check that an object is callable, else raise a :exc:`TypeError`.\"\"\"\n if not callable(obj):\n raise TypeError(f\"Object {obj!r} is not callable.\")\n return obj\n\n\ndef _signature(func: typing.Callable) -> list[str]:\n return list(inspect.signature(func).parameters.keys())\n\n\ndef get_func_args(func: typing.Callable) -> list[str]:\n \"\"\"Given a callable, return a list of argument names. Handles\n `functools.partial` objects and class-based callables.\n\n .. versionchanged:: 3.0.0a1\n Do not return bound arguments, eg. 
``self``.\n \"\"\"\n if inspect.isfunction(func) or inspect.ismethod(func):\n return _signature(func)\n if isinstance(func, functools.partial):\n return _signature(func.func)\n # Callable class\n return _signature(func)\n\n\ndef resolve_field_instance(cls_or_instance):\n \"\"\"Return a Schema instance from a Schema class or instance.\n\n :param type|Schema cls_or_instance: Marshmallow Schema class or instance.\n \"\"\"\n if isinstance(cls_or_instance, type):\n if not issubclass(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance()\n else:\n if not isinstance(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance\n\n\ndef timedelta_to_microseconds(value: dt.timedelta) -> int:\n \"\"\"Compute the total microseconds of a timedelta\n\n https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950\n \"\"\"\n return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds\n\n\ndef validate_unknown_parameter_value(obj: typing.Any) -> str:\n if obj not in _UNKNOWN_VALUES:\n raise ValueError(\n f\"Object {obj!r} is not a valid value for the 'unknown' parameter\"\n )\n return obj\n\n# Path: src/marshmallow/schema.py\n\"\"\"The :class:`Schema` class, including its metaclass and options (class Meta).\"\"\"\nfrom __future__ import annotations\n\nimport copy\nimport datetime as dt\nimport decimal\nimport inspect\nimport json\nimport typing\nimport uuid\nimport warnings\nfrom abc import ABCMeta\nfrom collections import OrderedDict, defaultdict\nfrom collections.abc import Mapping\nfrom functools import lru_cache\n\nfrom marshmallow import base, class_registry, types\nfrom marshmallow import fields as ma_fields\nfrom marshmallow.decorators import (\n POST_DUMP,\n POST_LOAD,\n PRE_DUMP,\n PRE_LOAD,\n VALIDATES,\n VALIDATES_SCHEMA,\n)\nfrom marshmallow.error_store import ErrorStore\nfrom marshmallow.exceptions import StringNotCollectionError, ValidationError\nfrom marshmallow.orderedset import OrderedSet\nfrom marshmallow.utils import (\n EXCLUDE,\n INCLUDE,\n RAISE,\n get_value,\n is_collection,\n is_instance_or_subclass,\n missing,\n set_value,\n validate_unknown_parameter_value,\n)\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n_T = typing.TypeVar(\"_T\")\n\n\ndef _get_fields(attrs):\n \"\"\"Get fields from a class\n\n :param attrs: Mapping of class attributes\n \"\"\"\n return [\n (field_name, field_value)\n for field_name, field_value in attrs.items()\n if is_instance_or_subclass(field_value, base.FieldABC)\n ]\n\n\n# This function allows Schemas to inherit from non-Schema classes and ensures\n# inheritance according to the MRO\ndef _get_fields_by_mro(klass):\n \"\"\"Collect fields from a class, following its method resolution order. The\n class itself is excluded from the search; only its parents are checked. Get\n fields from ``_declared_fields`` if available, else use ``__dict__``.\n\n :param type klass: Class whose fields to retrieve\n \"\"\"\n mro = inspect.getmro(klass)\n # Loop over mro in reverse to maintain correct order of fields\n return sum(\n (\n _get_fields(\n getattr(base, \"_declared_fields\", base.__dict__),\n )\n for base in mro[:0:-1]\n ),\n [],\n )\n\n\nclass SchemaMeta(ABCMeta):\n \"\"\"Metaclass for the Schema class. Binds the declared fields to\n a ``_declared_fields`` attribute, which is a dictionary mapping attribute\n names to field objects. 
Also sets the ``opts`` class attribute, which is\n the Schema class's ``class Meta`` options.\n \"\"\"\n\n def __new__(mcs, name, bases, attrs):\n meta = attrs.get(\"Meta\")\n ordered = getattr(meta, \"ordered\", False)\n if not ordered:\n # Inherit 'ordered' option\n # Warning: We loop through bases instead of MRO because we don't\n # yet have access to the class object\n # (i.e. can't call super before we have fields)\n for base_ in bases:\n if hasattr(base_, \"Meta\") and hasattr(base_.Meta, \"ordered\"):\n ordered = base_.Meta.ordered\n break\n else:\n ordered = False\n cls_fields = _get_fields(attrs)\n # Remove fields from list of class attributes to avoid shadowing\n # Schema attributes/methods in case of name conflict\n for field_name, _ in cls_fields:\n del attrs[field_name]\n klass = super().__new__(mcs, name, bases, attrs)\n inherited_fields = _get_fields_by_mro(klass)\n\n meta = klass.Meta\n # Set klass.opts in __new__ rather than __init__ so that it is accessible in\n # get_declared_fields\n klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)\n # Add fields specified in the `include` class Meta option\n cls_fields += list(klass.opts.include.items())\n\n # Assign _declared_fields on class\n klass._declared_fields = mcs.get_declared_fields(\n klass=klass,\n cls_fields=cls_fields,\n inherited_fields=inherited_fields,\n dict_cls=dict,\n )\n return klass\n\n @classmethod\n def get_declared_fields(\n mcs,\n klass: type,\n cls_fields: list,\n inherited_fields: list,\n dict_cls: type = dict,\n ):\n \"\"\"Returns a dictionary of field_name => `Field` pairs declared on the class.\n This is exposed mainly so that plugins can add additional fields, e.g. fields\n computed from class Meta options.\n\n :param klass: The class object.\n :param cls_fields: The fields declared on the class, including those added\n by the ``include`` class Meta option.\n :param inherited_fields: Inherited fields.\n :param dict_cls: dict-like class to use for dict output Default to ``dict``.\n \"\"\"\n return dict_cls(inherited_fields + cls_fields)\n\n def __init__(cls, name, bases, attrs):\n super().__init__(name, bases, attrs)\n if name and cls.opts.register:\n class_registry.register(name, cls)\n cls._hooks = cls.resolve_hooks()\n\n def resolve_hooks(cls) -> dict[types.Tag, list[str]]:\n \"\"\"Add in the decorated processors\n\n By doing this after constructing the class, we let standard inheritance\n do all the hard work.\n \"\"\"\n mro = inspect.getmro(cls)\n\n hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]\n\n for attr_name in dir(cls):\n # Need to look up the actual descriptor, not whatever might be\n # bound to the class. This needs to come from the __dict__ of the\n # declaring class.\n for parent in mro:\n try:\n attr = parent.__dict__[attr_name]\n except KeyError:\n continue\n else:\n break\n else:\n # In case we didn't find the attribute and didn't break above.\n # We should never hit this - it's just here for completeness\n # to exclude the possibility of attr being undefined.\n continue\n\n try:\n hook_config = attr.__marshmallow_hook__\n except AttributeError:\n pass\n else:\n for key in hook_config.keys():\n # Use name here so we can get the bound method later, in\n # case the processor was a descriptor or something.\n hooks[key].append(attr_name)\n\n return hooks\n\n\nclass SchemaOpts:\n \"\"\"class Meta options for the :class:`Schema`. 
Defines defaults.\"\"\"\n\n def __init__(self, meta, ordered: bool = False):\n self.fields = getattr(meta, \"fields\", ())\n if not isinstance(self.fields, (list, tuple)):\n raise ValueError(\"`fields` option must be a list or tuple.\")\n self.additional = getattr(meta, \"additional\", ())\n if not isinstance(self.additional, (list, tuple)):\n raise ValueError(\"`additional` option must be a list or tuple.\")\n if self.fields and self.additional:\n raise ValueError(\n \"Cannot set both `fields` and `additional` options\"\n \" for the same Schema.\"\n )\n self.exclude = getattr(meta, \"exclude\", ())\n if not isinstance(self.exclude, (list, tuple)):\n raise ValueError(\"`exclude` must be a list or tuple.\")\n self.dateformat = getattr(meta, \"dateformat\", None)\n self.datetimeformat = getattr(meta, \"datetimeformat\", None)\n self.timeformat = getattr(meta, \"timeformat\", None)\n if hasattr(meta, \"json_module\"):\n warnings.warn(\n \"The json_module class Meta option is deprecated. Use render_module instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n render_module = getattr(meta, \"json_module\", json)\n else:\n render_module = json\n self.render_module = getattr(meta, \"render_module\", render_module)\n self.ordered = getattr(meta, \"ordered\", ordered)\n self.index_errors = getattr(meta, \"index_errors\", True)\n self.include = getattr(meta, \"include\", {})\n self.load_only = getattr(meta, \"load_only\", ())\n self.dump_only = getattr(meta, \"dump_only\", ())\n self.unknown = validate_unknown_parameter_value(getattr(meta, \"unknown\", RAISE))\n self.register = getattr(meta, \"register\", True)\n\n\nclass Schema(base.SchemaABC, metaclass=SchemaMeta):\n \"\"\"Base schema class with which to define custom schemas.\n\n Example usage:\n\n .. code-block:: python\n\n import datetime as dt\n from dataclasses import dataclass\n\n from marshmallow import Schema, fields\n\n\n @dataclass\n class Album:\n title: str\n release_date: dt.date\n\n\n class AlbumSchema(Schema):\n title = fields.Str()\n release_date = fields.Date()\n\n\n album = Album(\"Beggars Banquet\", dt.date(1968, 12, 6))\n schema = AlbumSchema()\n data = schema.dump(album)\n data # {'release_date': '1968-12-06', 'title': 'Beggars Banquet'}\n\n :param only: Whitelist of the declared fields to select when\n instantiating the Schema. If None, all fields are used. Nested fields\n can be represented with dot delimiters.\n :param exclude: Blacklist of the declared fields to exclude\n when instantiating the Schema. If a field appears in both `only` and\n `exclude`, it is not used. Nested fields can be represented with dot\n delimiters.\n :param many: Should be set to `True` if ``obj`` is a collection\n so that the object will be serialized to a list.\n :param context: Optional context passed to :class:`fields.Method` and\n :class:`fields.Function` fields.\n :param load_only: Fields to skip during serialization (write-only fields)\n :param dump_only: Fields to skip during deserialization (read-only fields)\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n\n .. versionchanged:: 3.0.0\n `prefix` parameter removed.\n\n .. 
versionchanged:: 2.0.0\n `__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of\n `marshmallow.decorators.validates_schema`,\n `marshmallow.decorators.pre_load` and `marshmallow.decorators.post_dump`.\n `__accessor__` and `__error_handler__` are deprecated. Implement the\n `handle_error` and `get_attribute` methods instead.\n \"\"\"\n\n TYPE_MAPPING = {\n str: ma_fields.String,\n bytes: ma_fields.String,\n dt.datetime: ma_fields.DateTime,\n float: ma_fields.Float,\n bool: ma_fields.Boolean,\n tuple: ma_fields.Raw,\n list: ma_fields.Raw,\n set: ma_fields.Raw,\n int: ma_fields.Integer,\n uuid.UUID: ma_fields.UUID,\n dt.time: ma_fields.Time,\n dt.date: ma_fields.Date,\n dt.timedelta: ma_fields.TimeDelta,\n decimal.Decimal: ma_fields.Decimal,\n } # type: typing.Dict[type, typing.Type[ma_fields.Field]]\n #: Overrides for default schema-level error messages\n error_messages = {} # type: typing.Dict[str, str]\n\n _default_error_messages = {\n \"type\": \"Invalid input type.\",\n \"unknown\": \"Unknown field.\",\n } # type: typing.Dict[str, str]\n\n OPTIONS_CLASS = SchemaOpts # type: type\n\n set_class = OrderedSet\n\n # These get set by SchemaMeta\n opts = None # type: SchemaOpts\n _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]\n _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]\n\n class Meta:\n \"\"\"Options object for a Schema.\n\n Example usage: ::\n\n class Meta:\n fields = (\"id\", \"email\", \"date_created\")\n exclude = (\"password\", \"secret_attribute\")\n\n Available options:\n\n - ``fields``: Tuple or list of fields to include in the serialized result.\n - ``additional``: Tuple or list of fields to include *in addition* to the\n explicitly declared fields. ``additional`` and ``fields`` are\n mutually-exclusive options.\n - ``include``: Dictionary of additional fields to include in the schema. It is\n usually better to define fields as class variables, but you may need to\n use this option, e.g., if your fields are Python keywords. May be an\n `OrderedDict`.\n - ``exclude``: Tuple or list of fields to exclude in the serialized result.\n Nested fields can be represented with dot delimiters.\n - ``dateformat``: Default format for `Date ` fields.\n - ``datetimeformat``: Default format for `DateTime ` fields.\n - ``timeformat``: Default format for `Time ` fields.\n - ``render_module``: Module to use for `loads ` and `dumps `.\n Defaults to `json` from the standard library.\n - ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.\n - ``index_errors``: If `True`, errors dictionaries will include the index\n of invalid items in a collection.\n - ``load_only``: Tuple or list of fields to exclude from serialized results.\n - ``dump_only``: Tuple or list of fields to exclude from deserialization\n - ``unknown``: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n - ``register``: Whether to register the `Schema` with marshmallow's internal\n class registry. Must be `True` if you intend to refer to this `Schema`\n by class name in `Nested` fields. Only set this to `False` when memory\n usage is critical. 
Defaults to `True`.\n \"\"\"\n\n def __init__(\n self,\n *,\n only: types.StrSequenceOrSet | None = None,\n exclude: types.StrSequenceOrSet = (),\n many: bool = False,\n context: dict | None = None,\n load_only: types.StrSequenceOrSet = (),\n dump_only: types.StrSequenceOrSet = (),\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n # Raise error if only or exclude is passed as string, not list of strings\n if only is not None and not is_collection(only):\n raise StringNotCollectionError('\"only\" should be a list of strings')\n if not is_collection(exclude):\n raise StringNotCollectionError('\"exclude\" should be a list of strings')\n # copy declared fields from metaclass\n self.declared_fields = copy.deepcopy(self._declared_fields)\n self.many = many\n self.only = only\n self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(\n self.opts.exclude\n ) | set(exclude)\n self.ordered = self.opts.ordered\n self.load_only = set(load_only) or set(self.opts.load_only)\n self.dump_only = set(dump_only) or set(self.opts.dump_only)\n self.partial = partial\n self.unknown = (\n self.opts.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n self.context = context or {}\n self._normalize_nested_options()\n #: Dictionary mapping field_names -> :class:`Field` objects\n self.fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self._init_fields()\n messages = {}\n messages.update(self._default_error_messages)\n for cls in reversed(self.__class__.__mro__):\n messages.update(getattr(cls, \"error_messages\", {}))\n messages.update(self.error_messages or {})\n self.error_messages = messages\n\n def __repr__(self) -> str:\n return f\"<{self.__class__.__name__}(many={self.many})>\"\n\n @property\n def dict_class(self) -> type:\n return OrderedDict if self.ordered else dict\n\n @classmethod\n def from_dict(\n cls,\n fields: dict[str, ma_fields.Field | type],\n *,\n name: str = \"GeneratedSchema\",\n ) -> type:\n \"\"\"Generate a `Schema` class given a dictionary of fields.\n\n .. code-block:: python\n\n from marshmallow import Schema, fields\n\n PersonSchema = Schema.from_dict({\"name\": fields.Str()})\n print(PersonSchema().load({\"name\": \"David\"})) # => {'name': 'David'}\n\n Generated schemas are not added to the class registry and therefore cannot\n be referred to by name in `Nested` fields.\n\n :param dict fields: Dictionary mapping field names to field instances.\n :param str name: Optional name for the class, which will appear in\n the ``repr`` for the class.\n\n .. versionadded:: 3.0.0\n \"\"\"\n attrs = fields.copy()\n attrs[\"Meta\"] = type(\n \"GeneratedMeta\", (getattr(cls, \"Meta\", object),), {\"register\": False}\n )\n schema_cls = type(name, (cls,), attrs)\n return schema_cls\n\n ##### Override-able methods #####\n\n def handle_error(\n self, error: ValidationError, data: typing.Any, *, many: bool, **kwargs\n ):\n \"\"\"Custom error handler function for the schema.\n\n :param error: The `ValidationError` raised during (de)serialization.\n :param data: The original input data.\n :param many: Value of ``many`` on dump or load.\n :param partial: Value of ``partial`` on load.\n\n .. versionadded:: 2.0.0\n\n .. 
versionchanged:: 3.0.0rc9\n Receives `many` and `partial` (on deserialization) as keyword arguments.\n \"\"\"\n pass\n\n def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):\n \"\"\"Defines how to pull values from an object to serialize.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0a1\n Changed position of ``obj`` and ``attr``.\n \"\"\"\n return get_value(obj, attr, default)\n\n ##### Serialization/Deserialization API #####\n\n @staticmethod\n def _call_and_store(getter_func, data, *, field_name, error_store, index=None):\n \"\"\"Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.\n\n :param callable getter_func: Function for getting the serialized/deserialized\n value from ``data``.\n :param data: The data passed to ``getter_func``.\n :param str field_name: Field name.\n :param int index: Index of the item being validated, if validating a collection,\n otherwise `None`.\n \"\"\"\n try:\n value = getter_func(data)\n except ValidationError as error:\n error_store.store_error(error.messages, field_name, index=index)\n # When a Nested field fails validation, the marshalled data is stored\n # on the ValidationError's valid_data attribute\n return error.valid_data or missing\n return value\n\n def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):\n \"\"\"Serialize ``obj``.\n\n :param obj: The object(s) to serialize.\n :param bool many: `True` if ``data`` should be serialized as a collection.\n :return: A dictionary of the serialized data\n\n .. versionchanged:: 1.0.0\n Renamed from ``marshal``.\n \"\"\"\n if many and obj is not None:\n return [\n self._serialize(d, many=False)\n for d in typing.cast(typing.Iterable[_T], obj)\n ]\n ret = self.dict_class()\n for attr_name, field_obj in self.dump_fields.items():\n value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)\n if value is missing:\n continue\n key = field_obj.data_key if field_obj.data_key is not None else attr_name\n ret[key] = value\n return ret\n\n def dump(self, obj: typing.Any, *, many: bool | None = None):\n \"\"\"Serialize an object to native Python data types according to this\n Schema's fields.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: Serialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n .. versionchanged:: 3.0.0rc9\n Validation no longer occurs upon serialization.\n \"\"\"\n many = self.many if many is None else bool(many)\n if self._has_processors(PRE_DUMP):\n processed_obj = self._invoke_dump_processors(\n PRE_DUMP, obj, many=many, original_data=obj\n )\n else:\n processed_obj = obj\n\n result = self._serialize(processed_obj, many=many)\n\n if self._has_processors(POST_DUMP):\n result = self._invoke_dump_processors(\n POST_DUMP, result, many=many, original_data=obj\n )\n\n return result\n\n def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):\n \"\"\"Same as :meth:`dump`, except return a JSON-encoded string.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: A ``json`` string\n\n .. versionadded:: 1.0.0\n .. 
versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n \"\"\"\n serialized = self.dump(obj, many=many)\n return self.opts.render_module.dumps(serialized, *args, **kwargs)\n\n def _deserialize(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n error_store: ErrorStore,\n many: bool = False,\n partial=None,\n unknown=RAISE,\n index=None,\n ) -> _T | list[_T]:\n \"\"\"Deserialize ``data``.\n\n :param dict data: The data to deserialize.\n :param ErrorStore error_store: Structure to store errors.\n :param bool many: `True` if ``data`` should be deserialized as a collection.\n :param bool|tuple partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param int index: Index of the item being serialized (for storing errors) if\n serializing a collection, otherwise `None`.\n :return: A dictionary of the deserialized data.\n \"\"\"\n index_errors = self.opts.index_errors\n index = index if index_errors else None\n if many:\n if not is_collection(data):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n ret_l = [] # type: typing.List[_T]\n else:\n ret_l = [\n typing.cast(\n _T,\n self._deserialize(\n typing.cast(typing.Mapping[str, typing.Any], d),\n error_store=error_store,\n many=False,\n partial=partial,\n unknown=unknown,\n index=idx,\n ),\n )\n for idx, d in enumerate(data)\n ]\n return ret_l\n ret_d = self.dict_class()\n # Check data is a dict\n if not isinstance(data, Mapping):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n else:\n partial_is_collection = is_collection(partial)\n for attr_name, field_obj in self.load_fields.items():\n field_name = (\n field_obj.data_key if field_obj.data_key is not None else attr_name\n )\n raw_value = data.get(field_name, missing)\n if raw_value is missing:\n # Ignore missing field if we're allowed to.\n if partial is True or (\n partial_is_collection and attr_name in partial\n ):\n continue\n d_kwargs = {}\n # Allow partial loading of nested schemas.\n if partial_is_collection:\n prefix = field_name + \".\"\n len_prefix = len(prefix)\n sub_partial = [\n f[len_prefix:] for f in partial if f.startswith(prefix)\n ]\n d_kwargs[\"partial\"] = sub_partial\n elif partial is not None:\n d_kwargs[\"partial\"] = partial\n\n def getter(\n val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs\n ):\n return field_obj.deserialize(\n val,\n field_name,\n data,\n **d_kwargs,\n )\n\n value = self._call_and_store(\n getter_func=getter,\n data=raw_value,\n field_name=field_name,\n error_store=error_store,\n index=index,\n )\n if value is not missing:\n key = field_obj.attribute or attr_name\n set_value(ret_d, key, value)\n if unknown != EXCLUDE:\n fields = {\n field_obj.data_key if field_obj.data_key is not None else field_name\n for field_name, field_obj in self.load_fields.items()\n }\n for key in set(data) - fields:\n value = data[key]\n if unknown == INCLUDE:\n ret_d[key] = value\n elif unknown == RAISE:\n error_store.store_error(\n [self.error_messages[\"unknown\"]],\n key,\n (index if 
index_errors else None),\n )\n return ret_d\n\n def load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n \"\"\"Deserialize a data structure to an object defined by this Schema's fields.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n return self._do_load(\n data, many=many, partial=partial, unknown=unknown, postprocess=True\n )\n\n def loads(\n self,\n json_data: str,\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n **kwargs,\n ):\n \"\"\"Same as :meth:`load`, except it takes a JSON string as input.\n\n :param json_data: A JSON string of the data to deserialize.\n :param many: Whether to deserialize `obj` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n...\n", "needle_token_start": 2336, "needle_token_end": 2471, "code_context_ntokens": 16117, "output": ["```python\ndef store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "dict_class", "language": "python", "path": "src/marshmallow/schema.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: To determine the type of dictionary used in the serialization process based on the ordering preference.\n2. **Input**: No direct inputs; the decision is based on the internal state of the ordering flag.\n3. **Output**: Returns the class type of the dictionary to be used.\n4. **Procedure**: Checks if the ordering flag is set to true; if so, it returns an ordered dictionary type that maintains the order of insertion. 
Otherwise, it returns a standard dictionary type that does not guarantee order.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "\n :param type klass: Class whose fields to retrieve\n \"\"\"\n mro = inspect.getmro(klass)\n # Loop over mro in reverse to maintain correct order of fields\n return sum(\n (\n _get_fields(\n getattr(base, \"_declared_fields\", base.__dict__),\n )\n for base in mro[:0:-1]\n ),\n [],\n )\n\n\nclass SchemaMeta(ABCMeta):\n \"\"\"Metaclass for the Schema class. Binds the declared fields to\n a ``_declared_fields`` attribute, which is a dictionary mapping attribute\n names to field objects. Also sets the ``opts`` class attribute, which is\n the Schema class's ``class Meta`` options.\n \"\"\"\n\n def __new__(mcs, name, bases, attrs):\n meta = attrs.get(\"Meta\")\n ordered = getattr(meta, \"ordered\", False)\n if not ordered:\n # Inherit 'ordered' option\n # Warning: We loop through bases instead of MRO because we don't\n # yet have access to the class object\n # (i.e. can't call super before we have fields)\n for base_ in bases:\n if hasattr(base_, \"Meta\") and hasattr(base_.Meta, \"ordered\"):\n ordered = base_.Meta.ordered\n break\n else:\n ordered = False\n cls_fields = _get_fields(attrs)\n # Remove fields from list of class attributes to avoid shadowing\n # Schema attributes/methods in case of name conflict\n for field_name, _ in cls_fields:\n del attrs[field_name]\n klass = super().__new__(mcs, name, bases, attrs)\n inherited_fields = _get_fields_by_mro(klass)\n\n meta = klass.Meta\n # Set klass.opts in __new__ rather than __init__ so that it is accessible in\n # get_declared_fields\n klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)\n # Add fields specified in the `include` class Meta option\n cls_fields += list(klass.opts.include.items())\n\n # Assign _declared_fields on class\n klass._declared_fields = mcs.get_declared_fields(\n klass=klass,\n cls_fields=cls_fields,\n inherited_fields=inherited_fields,\n dict_cls=dict,\n )\n return klass\n\n @classmethod\n def get_declared_fields(\n mcs,\n klass: type,\n cls_fields: list,\n inherited_fields: list,\n dict_cls: type = dict,\n ):\n \"\"\"Returns a dictionary of field_name => `Field` pairs declared on the class.\n This is exposed mainly so that plugins can add additional fields, e.g. fields\n computed from class Meta options.\n\n :param klass: The class object.\n :param cls_fields: The fields declared on the class, including those added\n by the ``include`` class Meta option.\n :param inherited_fields: Inherited fields.\n :param dict_cls: dict-like class to use for dict output Default to ``dict``.\n \"\"\"\n return dict_cls(inherited_fields + cls_fields)\n\n def __init__(cls, name, bases, attrs):\n super().__init__(name, bases, attrs)\n if name and cls.opts.register:\n class_registry.register(name, cls)\n cls._hooks = cls.resolve_hooks()\n\n def resolve_hooks(cls) -> dict[types.Tag, list[str]]:\n \"\"\"Add in the decorated processors\n\n By doing this after constructing the class, we let standard inheritance\n do all the hard work.\n \"\"\"\n mro = inspect.getmro(cls)\n\n hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]\n\n for attr_name in dir(cls):\n # Need to look up the actual descriptor, not whatever might be\n # bound to the class. 
This needs to come from the __dict__ of the\n # declaring class.\n for parent in mro:\n try:\n attr = parent.__dict__[attr_name]\n except KeyError:\n continue\n else:\n break\n else:\n # In case we didn't find the attribute and didn't break above.\n # We should never hit this - it's just here for completeness\n # to exclude the possibility of attr being undefined.\n continue\n\n try:\n hook_config = attr.__marshmallow_hook__\n except AttributeError:\n pass\n else:\n for key in hook_config.keys():\n # Use name here so we can get the bound method later, in\n # case the processor was a descriptor or something.\n hooks[key].append(attr_name)\n\n return hooks\n\n\nclass SchemaOpts:\n \"\"\"class Meta options for the :class:`Schema`. Defines defaults.\"\"\"\n\n def __init__(self, meta, ordered: bool = False):\n self.fields = getattr(meta, \"fields\", ())\n if not isinstance(self.fields, (list, tuple)):\n raise ValueError(\"`fields` option must be a list or tuple.\")\n self.additional = getattr(meta, \"additional\", ())\n if not isinstance(self.additional, (list, tuple)):\n raise ValueError(\"`additional` option must be a list or tuple.\")\n if self.fields and self.additional:\n raise ValueError(\n \"Cannot set both `fields` and `additional` options\"\n \" for the same Schema.\"\n )\n self.exclude = getattr(meta, \"exclude\", ())\n if not isinstance(self.exclude, (list, tuple)):\n raise ValueError(\"`exclude` must be a list or tuple.\")\n self.dateformat = getattr(meta, \"dateformat\", None)\n self.datetimeformat = getattr(meta, \"datetimeformat\", None)\n self.timeformat = getattr(meta, \"timeformat\", None)\n if hasattr(meta, \"json_module\"):\n warnings.warn(\n \"The json_module class Meta option is deprecated. Use render_module instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n render_module = getattr(meta, \"json_module\", json)\n else:\n render_module = json\n self.render_module = getattr(meta, \"render_module\", render_module)\n self.ordered = getattr(meta, \"ordered\", ordered)\n self.index_errors = getattr(meta, \"index_errors\", True)\n self.include = getattr(meta, \"include\", {})\n self.load_only = getattr(meta, \"load_only\", ())\n self.dump_only = getattr(meta, \"dump_only\", ())\n self.unknown = validate_unknown_parameter_value(getattr(meta, \"unknown\", RAISE))\n self.register = getattr(meta, \"register\", True)\n\n\nclass Schema(base.SchemaABC, metaclass=SchemaMeta):\n \"\"\"Base schema class with which to define custom schemas.\n\n Example usage:\n\n .. code-block:: python\n\n import datetime as dt\n from dataclasses import dataclass\n\n from marshmallow import Schema, fields\n\n\n @dataclass\n class Album:\n title: str\n release_date: dt.date\n\n\n class AlbumSchema(Schema):\n title = fields.Str()\n release_date = fields.Date()\n\n\n album = Album(\"Beggars Banquet\", dt.date(1968, 12, 6))\n schema = AlbumSchema()\n data = schema.dump(album)\n data # {'release_date': '1968-12-06', 'title': 'Beggars Banquet'}\n\n :param only: Whitelist of the declared fields to select when\n instantiating the Schema. If None, all fields are used. Nested fields\n can be represented with dot delimiters.\n :param exclude: Blacklist of the declared fields to exclude\n when instantiating the Schema. If a field appears in both `only` and\n `exclude`, it is not used. 
Nested fields can be represented with dot\n delimiters.\n :param many: Should be set to `True` if ``obj`` is a collection\n so that the object will be serialized to a list.\n :param context: Optional context passed to :class:`fields.Method` and\n :class:`fields.Function` fields.\n :param load_only: Fields to skip during serialization (write-only fields)\n :param dump_only: Fields to skip during deserialization (read-only fields)\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n\n .. versionchanged:: 3.0.0\n `prefix` parameter removed.\n\n .. versionchanged:: 2.0.0\n `__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of\n `marshmallow.decorators.validates_schema`,\n `marshmallow.decorators.pre_load` and `marshmallow.decorators.post_dump`.\n `__accessor__` and `__error_handler__` are deprecated. Implement the\n `handle_error` and `get_attribute` methods instead.\n \"\"\"\n\n TYPE_MAPPING = {\n str: ma_fields.String,\n bytes: ma_fields.String,\n dt.datetime: ma_fields.DateTime,\n float: ma_fields.Float,\n bool: ma_fields.Boolean,\n tuple: ma_fields.Raw,\n list: ma_fields.Raw,\n set: ma_fields.Raw,\n int: ma_fields.Integer,\n uuid.UUID: ma_fields.UUID,\n dt.time: ma_fields.Time,\n dt.date: ma_fields.Date,\n dt.timedelta: ma_fields.TimeDelta,\n decimal.Decimal: ma_fields.Decimal,\n } # type: typing.Dict[type, typing.Type[ma_fields.Field]]\n #: Overrides for default schema-level error messages\n error_messages = {} # type: typing.Dict[str, str]\n\n _default_error_messages = {\n \"type\": \"Invalid input type.\",\n \"unknown\": \"Unknown field.\",\n } # type: typing.Dict[str, str]\n\n OPTIONS_CLASS = SchemaOpts # type: type\n\n set_class = OrderedSet\n\n # These get set by SchemaMeta\n opts = None # type: SchemaOpts\n _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]\n _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]\n\n class Meta:\n \"\"\"Options object for a Schema.\n\n Example usage: ::\n\n class Meta:\n fields = (\"id\", \"email\", \"date_created\")\n exclude = (\"password\", \"secret_attribute\")\n\n Available options:\n\n - ``fields``: Tuple or list of fields to include in the serialized result.\n - ``additional``: Tuple or list of fields to include *in addition* to the\n explicitly declared fields. ``additional`` and ``fields`` are\n mutually-exclusive options.\n - ``include``: Dictionary of additional fields to include in the schema. It is\n usually better to define fields as class variables, but you may need to\n use this option, e.g., if your fields are Python keywords. 
May be an\n `OrderedDict`.\n - ``exclude``: Tuple or list of fields to exclude in the serialized result.\n Nested fields can be represented with dot delimiters.\n - ``dateformat``: Default format for `Date ` fields.\n - ``datetimeformat``: Default format for `DateTime ` fields.\n - ``timeformat``: Default format for `Time ` fields.\n - ``render_module``: Module to use for `loads ` and `dumps `.\n Defaults to `json` from the standard library.\n - ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.\n - ``index_errors``: If `True`, errors dictionaries will include the index\n of invalid items in a collection.\n - ``load_only``: Tuple or list of fields to exclude from serialized results.\n - ``dump_only``: Tuple or list of fields to exclude from deserialization\n - ``unknown``: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n - ``register``: Whether to register the `Schema` with marshmallow's internal\n class registry. Must be `True` if you intend to refer to this `Schema`\n by class name in `Nested` fields. Only set this to `False` when memory\n usage is critical. Defaults to `True`.\n \"\"\"\n\n def __init__(\n self,\n *,\n only: types.StrSequenceOrSet | None = None,\n exclude: types.StrSequenceOrSet = (),\n many: bool = False,\n context: dict | None = None,\n load_only: types.StrSequenceOrSet = (),\n dump_only: types.StrSequenceOrSet = (),\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n # Raise error if only or exclude is passed as string, not list of strings\n if only is not None and not is_collection(only):\n raise StringNotCollectionError('\"only\" should be a list of strings')\n if not is_collection(exclude):\n raise StringNotCollectionError('\"exclude\" should be a list of strings')\n # copy declared fields from metaclass\n self.declared_fields = copy.deepcopy(self._declared_fields)\n self.many = many\n self.only = only\n self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(\n self.opts.exclude\n ) | set(exclude)\n self.ordered = self.opts.ordered\n self.load_only = set(load_only) or set(self.opts.load_only)\n self.dump_only = set(dump_only) or set(self.opts.dump_only)\n self.partial = partial\n self.unknown = (\n self.opts.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n self.context = context or {}\n self._normalize_nested_options()\n #: Dictionary mapping field_names -> :class:`Field` objects\n self.fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self._init_fields()\n messages = {}\n messages.update(self._default_error_messages)\n for cls in reversed(self.__class__.__mro__):\n messages.update(getattr(cls, \"error_messages\", {}))\n messages.update(self.error_messages or {})\n self.error_messages = messages\n\n def __repr__(self) -> str:\n return f\"<{self.__class__.__name__}(many={self.many})>\"\n\n @property\n \ndef dict_class(self) -> type:\n return OrderedDict if self.ordered else dict\n\n @classmethod\n def from_dict(\n cls,\n fields: dict[str, ma_fields.Field | type],\n *,\n name: str = \"GeneratedSchema\",\n ) -> type:\n \"\"\"Generate a `Schema` class given a dictionary of fields.\n\n .. 
code-block:: python\n\n from marshmallow import Schema, fields\n\n PersonSchema = Schema.from_dict({\"name\": fields.Str()})\n print(PersonSchema().load({\"name\": \"David\"})) # => {'name': 'David'}\n\n Generated schemas are not added to the class registry and therefore cannot\n be referred to by name in `Nested` fields.\n\n :param dict fields: Dictionary mapping field names to field instances.\n :param str name: Optional name for the class, which will appear in\n the ``repr`` for the class.\n\n .. versionadded:: 3.0.0\n \"\"\"\n attrs = fields.copy()\n attrs[\"Meta\"] = type(\n \"GeneratedMeta\", (getattr(cls, \"Meta\", object),), {\"register\": False}\n )\n schema_cls = type(name, (cls,), attrs)\n return schema_cls\n\n ##### Override-able methods #####\n\n def handle_error(\n self, error: ValidationError, data: typing.Any, *, many: bool, **kwargs\n ):\n \"\"\"Custom error handler function for the schema.\n\n :param error: The `ValidationError` raised during (de)serialization.\n :param data: The original input data.\n :param many: Value of ``many`` on dump or load.\n :param partial: Value of ``partial`` on load.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0rc9\n Receives `many` and `partial` (on deserialization) as keyword arguments.\n \"\"\"\n pass\n\n def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):\n \"\"\"Defines how to pull values from an object to serialize.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0a1\n Changed position of ``obj`` and ``attr``.\n \"\"\"\n return get_value(obj, attr, default)\n\n ##### Serialization/Deserialization API #####\n\n @staticmethod\n def _call_and_store(getter_func, data, *, field_name, error_store, index=None):\n \"\"\"Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.\n\n :param callable getter_func: Function for getting the serialized/deserialized\n value from ``data``.\n :param data: The data passed to ``getter_func``.\n :param str field_name: Field name.\n :param int index: Index of the item being validated, if validating a collection,\n otherwise `None`.\n \"\"\"\n try:\n value = getter_func(data)\n except ValidationError as error:\n error_store.store_error(error.messages, field_name, index=index)\n # When a Nested field fails validation, the marshalled data is stored\n # on the ValidationError's valid_data attribute\n return error.valid_data or missing\n return value\n\n def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):\n \"\"\"Serialize ``obj``.\n\n :param obj: The object(s) to serialize.\n :param bool many: `True` if ``data`` should be serialized as a collection.\n :return: A dictionary of the serialized data\n\n .. versionchanged:: 1.0.0\n Renamed from ``marshal``.\n \"\"\"\n if many and obj is not None:\n return [\n self._serialize(d, many=False)\n for d in typing.cast(typing.Iterable[_T], obj)\n ]\n ret = self.dict_class()\n for attr_name, field_obj in self.dump_fields.items():\n value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)\n if value is missing:\n continue\n key = field_obj.data_key if field_obj.data_key is not None else attr_name\n ret[key] = value\n return ret\n\n def dump(self, obj: typing.Any, *, many: bool | None = None):\n \"\"\"Serialize an object to native Python data types according to this\n Schema's fields.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: Serialized data\n\n .. 
versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n .. versionchanged:: 3.0.0rc9\n Validation no longer occurs upon serialization.\n \"\"\"\n many = self.many if many is None else bool(many)\n if self._has_processors(PRE_DUMP):\n processed_obj = self._invoke_dump_processors(\n PRE_DUMP, obj, many=many, original_data=obj\n )\n else:\n processed_obj = obj\n\n result = self._serialize(processed_obj, many=many)\n\n if self._has_processors(POST_DUMP):\n result = self._invoke_dump_processors(\n POST_DUMP, result, many=many, original_data=obj\n )\n\n return result\n\n def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):\n \"\"\"Same as :meth:`dump`, except return a JSON-encoded string.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: A ``json`` string\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n \"\"\"\n serialized = self.dump(obj, many=many)\n return self.opts.render_module.dumps(serialized, *args, **kwargs)\n\n def _deserialize(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n error_store: ErrorStore,\n many: bool = False,\n partial=None,\n unknown=RAISE,\n index=None,\n ) -> _T | list[_T]:\n \"\"\"Deserialize ``data``.\n\n :param dict data: The data to deserialize.\n :param ErrorStore error_store: Structure to store errors.\n :param bool many: `True` if ``data`` should be deserialized as a collection.\n :param bool|tuple partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. 
Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param int index: Index of the item being serialized (for storing errors) if\n serializing a collection, otherwise `None`.\n :return: A dictionary of the deserialized data.\n \"\"\"\n index_errors = self.opts.index_errors\n index = index if index_errors else None\n if many:\n if not is_collection(data):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n ret_l = [] # type: typing.List[_T]\n else:\n ret_l = [\n typing.cast(\n _T,\n self._deserialize(\n typing.cast(typing.Mapping[str, typing.Any], d),\n error_store=error_store,\n many=False,\n partial=partial,\n unknown=unknown,\n index=idx,\n ),\n )\n for idx, d in enumerate(data)\n ]\n return ret_l\n ret_d = self.dict_class()\n # Check data is a dict\n if not isinstance(data, Mapping):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n else:\n partial_is_collection = is_collection(partial)\n for attr_name, field_obj in self.load_fields.items():\n field_name = (\n field_obj.data_key if field_obj.data_key is not None else attr_name\n )\n raw_value = data.get(field_name, missing)\n if raw_value is missing:\n # Ignore missing field if we're allowed to.\n if partial is True or (\n partial_is_collection and attr_name in partial\n ):\n continue\n d_kwargs = {}\n # Allow partial loading of nested schemas.\n if partial_is_collection:\n prefix = field_name + \".\"\n len_prefix = len(prefix)\n sub_partial = [\n f[len_prefix:] for f in partial if f.startswith(prefix)\n ]\n d_kwargs[\"partial\"] = sub_partial\n elif partial is not None:\n d_kwargs[\"partial\"] = partial\n\n def getter(\n val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs\n ):\n return field_obj.deserialize(\n val,\n field_name,\n data,\n **d_kwargs,\n )\n\n value = self._call_and_store(\n getter_func=getter,\n data=raw_value,\n field_name=field_name,\n error_store=error_store,\n index=index,\n )\n if value is not missing:\n key = field_obj.attribute or attr_name\n set_value(ret_d, key, value)\n if unknown != EXCLUDE:\n fields = {\n field_obj.data_key if field_obj.data_key is not None else field_name\n for field_name, field_obj in self.load_fields.items()\n }\n for key in set(data) - fields:\n value = data[key]\n if unknown == INCLUDE:\n ret_d[key] = value\n elif unknown == RAISE:\n error_store.store_error(\n [self.error_messages[\"unknown\"]],\n key,\n (index if index_errors else None),\n )\n return ret_d\n\n def load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n \"\"\"Deserialize a data structure to an object defined by this Schema's fields.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. 
versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n return self._do_load(\n data, many=many, partial=partial, unknown=unknown, postprocess=True\n )\n\n def loads(\n self,\n json_data: str,\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n **kwargs,\n ):\n \"\"\"Same as :meth:`load`, except it takes a JSON string as input.\n\n :param json_data: A JSON string of the data to deserialize.\n :param many: Whether to deserialize `obj` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n data = self.opts.render_module.loads(json_data, **kwargs)\n return self.load(data, many=many, partial=partial, unknown=unknown)\n\n def _run_validator(\n self,\n validator_func,\n output,\n *,\n original_data,\n error_store,\n many,\n partial,\n pass_original,\n index=None,\n ):\n try:\n if pass_original: # Pass original, raw data (before unmarshalling)\n validator_func(output, original_data, partial=partial, many=many)\n else:\n validator_func(output, partial=partial, many=many)\n except ValidationError as err:\n error_store.store_error(err.messages, err.field_name, index=index)\n\n def validate(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n ) -> dict[str, list[str]]:\n \"\"\"Validate `data` against the schema, returning a dictionary of\n validation errors.\n\n :param data: The data to validate.\n :param many: Whether to validate `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :return: A dictionary of validation errors.\n\n .. versionadded:: 1.1.0\n \"\"\"\n try:\n self._do_load(data, many=many, partial=partial, postprocess=False)\n except ValidationError as exc:\n return typing.cast(typing.Dict[str, typing.List[str]], exc.messages)\n return {}\n\n ##### Private Helpers #####\n\n def _do_load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n postprocess: bool = True,\n ):\n \"\"\"Deserialize `data`, returning the deserialized result.\n This method is private API.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. 
If `None`, the\n value for `self.many` is used.\n :param partial: Whether to validate required fields. If its\n value is an iterable, only fields listed in that iterable will be\n ignored will be allowed missing. If `True`, all fields will be allowed missing.\n If `None`, the value for `self.partial` is used.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :param postprocess: Whether to run post_load methods..\n :return: Deserialized data\n \"\"\"\n error_store = ErrorStore()\n errors = {} # type: dict[str, list[str]]\n many = self.many if many is None else bool(many)\n unknown = (\n self.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n if partial is None:\n partial = self.partial\n # Run preprocessors\n if self._has_processors(PRE_LOAD):\n try:\n processed_data = self._invoke_load_processors(\n PRE_LOAD, data, many=many, original_data=data, partial=partial\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n result = None # type: list | dict | None\n else:\n processed_data = data\n if not errors:\n # Deserialize data\n result = self._deserialize(\n processed_data,\n error_store=error_store,\n many=many,\n partial=partial,\n unknown=unknown,\n )\n # Run field-level validation\n self._invoke_field_validators(\n error_store=error_store, data=result, many=many\n )\n # Run schema-level validation\n if self._has_processors(VALIDATES_SCHEMA):\n field_errors = bool(error_store.errors)\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=True,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=False,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n errors = error_store.errors\n # Run post processors\n if not errors and postprocess and self._has_processors(POST_LOAD):\n try:\n result = self._invoke_load_processors(\n POST_LOAD,\n result,\n many=many,\n original_data=data,\n partial=partial,\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n if errors:\n exc = ValidationError(errors, data=data, valid_data=result)\n self.handle_error(exc, data, many=many, partial=partial)\n raise exc\n\n return result\n\n def _normalize_nested_options(self) -> None:\n \"\"\"Apply then flatten nested schema options.\n This method is private API.\n \"\"\"\n if self.only is not None:\n # Apply the only option to nested fields.\n self.__apply_nested_option(\"only\", self.only, \"intersection\")\n # Remove the child field names from the only option.\n self.only = self.set_class([field.split(\".\", 1)[0] for field in self.only])\n if self.exclude:\n # Apply the exclude option to nested fields.\n self.__apply_nested_option(\"exclude\", self.exclude, \"union\")\n # Remove the parent field names from the exclude option.\n self.exclude = self.set_class(\n [field for field in self.exclude if \".\" not in field]\n )\n\n def __apply_nested_option(self, option_name, field_names, set_operation) -> None:\n \"\"\"Apply nested options to nested fields\"\"\"\n # Split nested field names on the first dot.\n nested_fields = [name.split(\".\", 1) for name in field_names if \".\" in name]\n # Partition the nested field names by parent field.\n nested_options = defaultdict(list) # type: defaultdict\n for parent, nested_names 
in nested_fields:\n nested_options[parent].append(nested_names)\n # Apply the nested field options.\n for key, options in iter(nested_options.items()):\n new_options = self.set_class(options)\n original_options = getattr(self.declared_fields[key], option_name, ())\n if original_options:\n if set_operation == \"union\":\n new_options |= self.set_class(original_options)\n if set_operation == \"intersection\":\n new_options &= self.set_class(original_options)\n setattr(self.declared_fields[key], option_name, new_options)\n\n def _init_fields(self) -> None:\n \"\"\"Update self.fields, self.load_fields, and self.dump_fields based on schema options.\n This method is private API.\n \"\"\"\n if self.opts.fields:\n available_field_names = self.set_class(self.opts.fields)\n else:\n available_field_names = self.set_class(self.declared_fields.keys())\n if self.opts.additional:\n available_field_names |= self.set_class(self.opts.additional)\n\n invalid_fields = self.set_class()\n\n if self.only is not None:\n # Return only fields specified in only option\n field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)\n\n invalid_fields |= field_names - available_field_names\n else:\n field_names = available_field_names\n\n # If \"exclude\" option or param is specified, remove those fields.\n if self.exclude:\n # Note that this isn't available_field_names, since we want to\n # apply \"only\" for the actual calculation.\n field_names = field_names - self.exclude\n invalid_fields |= self.exclude - available_field_names\n\n if invalid_fields:\n message = f\"Invalid fields for {self}: {invalid_fields}.\"\n raise ValueError(message)\n\n fields_dict = self.dict_class()\n for field_name in field_names:\n field_obj = self.declared_fields.get(field_name, ma_fields.Inferred())\n self._bind_field(field_name, field_obj)\n fields_dict[field_name] = field_obj\n\n load_fields, dump_fields = self.dict_class(), self.dict_class()\n for field_name, field_obj in fields_dict.items():\n if not field_obj.dump_only:\n load_fields[field_name] = field_obj\n if not field_obj.load_only:\n dump_fields[field_name] = field_obj\n\n dump_data_keys = [\n field_obj.data_key if field_obj.data_key is not None else name\n for name, field_obj in dump_fields.items()\n ]\n if len(dump_data_keys) != len(set(dump_data_keys)):\n data_keys_duplicates = {\n x for x in dump_data_keys if dump_data_keys.count(x) > 1\n }\n raise ValueError(\n \"The data_key argument for one or more fields collides \"\n \"with another field's name or data_key argument. \"\n \"Check the following field names and \"\n f\"data_key arguments: {list(data_keys_duplicates)}\"\n )\n load_attributes = [obj.attribute or name for name, obj in load_fields.items()]\n if len(load_attributes) != len(set(load_attributes)):\n attributes_duplicates = {\n x for x in load_attributes if load_attributes.count(x) > 1\n }\n raise ValueError(\n \"The attribute argument for one or more fields collides \"\n \"with another field's name or attribute argument. 
\"\n \"Check the following field names and \"\n f\"attribute arguments: {list(attributes_duplicates)}\"\n )\n\n self.fields = fields_dict\n self.dump_fields = dump_fields\n self.load_fields = load_fields\n\n def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:\n \"\"\"Hook to modify a field when it is bound to the `Schema`.\n\n No-op by default.\n \"\"\"\n return None\n\n def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:\n \"\"\"Bind field to the schema, setting any necessary attributes on the\n field (e.g. parent and name).\n\n Also set field load_only and dump_only values if field_name was\n specified in ``class Meta``.\n \"\"\"\n if field_name in self.load_only:\n field_obj.load_only = True\n if field_name in self.dump_only:\n field_obj.dump_only = True\n try:\n field_obj._bind_to_schema(field_name, self)\n except TypeError as error:\n # Field declared as a class, not an instance. Ignore type checking because\n # we handle unsupported arg types, i.e. this is dead code from\n # the type checker's perspective.\n if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC):\n msg = (\n f'Field for \"{field_name}\" must be declared as a '\n \"Field instance, not a class. \"\n f'Did you mean \"fields.{field_obj.__name__}()\"?' # type: ignore\n )\n raise TypeError(msg) from error\n raise error\n self.on_bind_field(field_name, field_obj)\n\n @lru_cache(maxsize=8) # noqa (https://github.com/PyCQA/flake8-bugbear/issues/310)\n def _has_processors(self, tag) -> bool:\n return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)])\n\n def _invoke_dump_processors(\n self, tag: str, data, *, many: bool, original_data=None\n ):\n # The pass_many post-dump processors may do things like add an envelope, so\n # invoke those after invoking the non-pass_many processors which will expect\n # to get a list of items.\n data = self._invoke_processors(\n tag, pass_many=False, data=data, many=many, original_data=original_data\n )\n data = self._invoke_processors(\n tag, pass_many=True, data=data, many=many, original_data=original_data\n )\n return data\n\n def _invoke_load_processors(\n self,\n tag: str,\n data,\n *,\n many: bool,\n original_data,\n partial: bool | types.StrSequenceOrSet | None,\n ):\n # This has to invert the order of the dump processors, so run the pass_many\n # processors first.\n data = self._invoke_processors(\n tag,\n pass_many=True,\n data=data,\n many=many,\n original_data=original_data,\n partial=partial,\n )\n data = self._invoke_processors(\n tag,\n pass_many=False,\n data=data,\n many=many,\n original_data=original_data,\n partial=partial,\n )\n return data\n\n def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):\n for attr_name in self._hooks[VALIDATES]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[VALIDATES]\n field_name = validator_kwargs[\"field_name\"]\n\n try:\n field_obj = self.fields[field_name]\n except KeyError as error:\n if field_name in self.declared_fields:\n continue\n raise ValueError(f'\"{field_name}\" field does not exist.') from error\n\n data_key = (\n field_obj.data_key if field_obj.data_key is not None else field_name\n )\n if many:\n for idx, item in enumerate(data):\n try:\n value = item[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n index=(idx if 
self.opts.index_errors else None),\n )\n if validated_value is missing:\n data[idx].pop(field_name, None)\n else:\n try:\n value = data[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n )\n if validated_value is missing:\n data.pop(field_name, None)\n\n def _invoke_schema_validators(\n self,\n *,\n error_store: ErrorStore,\n pass_many: bool,\n data,\n original_data,\n many: bool,\n partial: bool | types.StrSequenceOrSet | None,\n field_errors: bool = False,\n ):\n for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[\n (VALIDATES_SCHEMA, pass_many)\n ]\n if field_errors and validator_kwargs[\"skip_on_field_errors\"]:\n continue\n pass_original = validator_kwargs.get(\"pass_original\", False)\n\n if many and not pass_many:\n for idx, (item, orig) in enumerate(zip(data, original_data)):\n self._run_validator(\n validator,\n item,\n original_data=orig,\n error_store=error_store,\n many=many,\n partial=partial,\n index=idx,\n pass_original=pass_original,\n )\n else:\n self._run_validator(\n validator,\n data,\n original_data=original_data,\n error_store=error_store,\n many=many,\n pass_original=pass_original,\n partial=partial,\n )\n\n def _invoke_processors(\n self,\n tag: str,\n *,\n pass_many: bool,\n data,\n many: bool,\n original_data=None,\n **kwargs,\n ):\n key = (tag, pass_many)\n for attr_name in self._hooks[key]:\n # This will be a bound method.\n processor = getattr(self, attr_name)\n\n processor_kwargs = processor.__marshmallow_hook__[key]\n pass_original = processor_kwargs.get(\"pass_original\", False)\n\n if many and not pass_many:\n if pass_original:\n data = [\n processor(item, original, many=many, **kwargs)\n for item, original in zip(data, original_data)\n ]\n else:\n data = [processor(item, many=many, **kwargs) for item in data]\n else:\n if pass_original:\n data = processor(data, original_data, many=many, **kwargs)\n else:\n data = processor(data, many=many, **kwargs)\n return data\n\n\nBaseSchema = Schema # for backwards compatibility\n\n# Path: src/marshmallow/validate.py\n\"\"\"Validation classes for various types of data.\"\"\"\nfrom __future__ import annotations\n\nimport re\nimport typing\nfrom abc import ABC, abstractmethod\nfrom itertools import zip_longest\nfrom operator import attrgetter\n\nfrom marshmallow import types\nfrom marshmallow.exceptions import ValidationError\n\n_T = typing.TypeVar(\"_T\")\n\n\nclass Validator(ABC):\n \"\"\"Abstract base class for validators.\n\n .. note::\n This class does not provide any validation behavior. It is only used to\n add a useful `__repr__` implementation for validators.\n \"\"\"\n\n error = None # type: str | None\n\n def __repr__(self) -> str:\n args = self._repr_args()\n args = f\"{args}, \" if args else \"\"\n\n return f\"<{self.__class__.__name__}({args}error={self.error!r})>\"\n\n def _repr_args(self) -> str:\n \"\"\"A string representation of the args passed to this validator. 
Used by\n `__repr__`.\n \"\"\"\n return \"\"\n\n @abstractmethod\n def __call__(self, value: typing.Any) -> typing.Any:\n ...\n\n\nclass And(Validator):\n \"\"\"Compose multiple validators and combine their error messages.\n\n Example: ::\n\n from marshmallow import validate, ValidationError\n\n def is_even(value):\n if value % 2 != 0:\n raise ValidationError(\"Not an even value.\")\n\n validator = validate.And(validate.Range(min=0), is_even)\n validator(-1)\n # ValidationError: ['Must be greater than or equal to 0.', 'Not an even value.']\n\n :param validators: Validators to combine.\n :param error: Error message to use when a validator returns ``False``.\n \"\"\"\n\n default_error_message = \"Invalid value.\"\n\n def __init__(self, *validators: types.Validator, error: str | None = None):\n self.validators = tuple(validators)\n self.error = error or self.default_error_message # type: str\n\n def _repr_args(self) -> str:\n return f\"validators={self.validators!r}\"\n\n def __call__(self, value: typing.Any) -> typing.Any:\n errors = []\n kwargs = {}\n for validator in self.validators:\n try:\n r = validator(value)\n if not isinstance(validator, Validator) and r is False:\n raise ValidationError(self.error)\n except ValidationError as err:\n kwargs.update(err.kwargs)\n if isinstance(err.messages, dict):\n errors.append(err.messages)\n else:\n # FIXME : Get rid of cast\n errors.extend(typing.cast(list, err.messages))\n if errors:\n raise ValidationError(errors, **kwargs)\n return value\n\n\nclass URL(Validator):\n \"\"\"Validate a URL.\n\n :param relative: Whether to allow relative URLs.\n :param absolute: Whether to allow absolute URLs.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`.\n :param schemes: Valid schemes. By default, ``http``, ``https``,\n ``ftp``, and ``ftps`` are allowed.\n :param require_tld: Whether to reject non-FQDN hostnames.\n \"\"\"\n\n class RegexMemoizer:\n def __init__(self):\n self._memoized = {}\n\n def _regex_generator(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n hostname_variants = [\n # a normal domain name, expressed in [A-Z0-9] chars with hyphens allowed only in the middle\n # note that the regex will be compiled with IGNORECASE, so these are upper and lowercase chars\n (\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}\\.?|[A-Z0-9-]{2,}\\.?)\"\n ),\n # or the special string 'localhost'\n r\"localhost\",\n # or IPv4\n r\"\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\",\n # or IPv6\n r\"\\[[A-F0-9]*:[A-F0-9:]+\\]\",\n ]\n if not require_tld:\n # allow dotless hostnames\n hostname_variants.append(r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.?)\")\n\n absolute_part = \"\".join(\n (\n # scheme (e.g. 'https://', 'ftp://', etc)\n # this is validated separately against allowed schemes, so in the regex\n # we simply want to capture its existence\n r\"(?:[a-z0-9\\.\\-\\+]*)://\",\n # userinfo, for URLs encoding authentication\n # e.g. 'ftp://foo:bar@ftp.example.org/'\n r\"(?:(?:[a-z0-9\\-._~!$&'()*+,;=:]|%[0-9a-f]{2})*@)?\",\n # netloc, the hostname/domain part of the URL plus the optional port\n r\"(?:\",\n \"|\".join(hostname_variants),\n r\")\",\n r\"(?::\\d+)?\",\n )\n )\n relative_part = r\"(?:/?|[/?]\\S+)\\Z\"\n\n if relative:\n if absolute:\n parts: tuple[str, ...] 
= (\n r\"^(\",\n absolute_part,\n r\")?\",\n relative_part,\n )\n else:\n parts = (r\"^\", relative_part)\n else:\n parts = (r\"^\", absolute_part, relative_part)\n\n return re.compile(\"\".join(parts), re.IGNORECASE)\n\n def __call__(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n key = (relative, absolute, require_tld)\n if key not in self._memoized:\n self._memoized[key] = self._regex_generator(\n relative, absolute, require_tld\n )\n\n return self._memoized[key]\n\n _regex = RegexMemoizer()\n\n default_message = \"Not a valid URL.\"\n default_schemes = {\"http\", \"https\", \"ftp\", \"ftps\"}\n\n def __init__(\n self,\n *,\n relative: bool = False,\n absolute: bool = True,\n schemes: types.StrSequenceOrSet | None = None,\n require_tld: bool = True,\n error: str | None = None,\n ):\n if not relative and not absolute:\n raise ValueError(\n \"URL validation cannot set both relative and absolute to False.\"\n )\n self.relative = relative\n self.absolute = absolute\n self.error = error or self.default_message # type: str\n self.schemes = schemes or self.default_schemes\n self.require_tld = require_tld\n\n def _repr_args(self) -> str:\n return f\"relative={self.relative!r}, absolute={self.absolute!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n if not value:\n raise ValidationError(message)\n\n # Check first if the scheme is valid\n if \"://\" in value:\n scheme = value.split(\"://\")[0].lower()\n if scheme not in self.schemes:\n raise ValidationError(message)\n\n regex = self._regex(self.relative, self.absolute, self.require_tld)\n\n if not regex.search(value):\n raise ValidationError(message)\n\n return value\n\n\nclass Email(Validator):\n \"\"\"Validate an email address.\n\n :param error: Error message to raise in case of a validation error. 
Can be\n interpolated with `{input}`.\n \"\"\"\n\n USER_REGEX = re.compile(\n r\"(^[-!#$%&'*+/=?^`{}|~\\w]+(\\.[-!#$%&'*+/=?^`{}|~\\w]+)*\\Z\" # dot-atom\n # quoted-string\n r'|^\"([\\001-\\010\\013\\014\\016-\\037!#-\\[\\]-\\177]'\n r'|\\\\[\\001-\\011\\013\\014\\016-\\177])*\"\\Z)',\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_REGEX = re.compile(\n # domain\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\\Z\"\n # literal form, ipv4 address (SMTP 4.1.3)\n r\"|^\\[(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)\"\n r\"(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\]\\Z\",\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_WHITELIST = (\"localhost\",)\n\n default_message = \"Not a valid email address.\"\n\n def __init__(self, *, error: str | None = None):\n self.error = error or self.default_message # type: str\n\n def _format_error(self, value: str) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n\n if not value or \"@\" not in value:\n raise ValidationError(message)\n\n user_part, domain_part = value.rsplit(\"@\", 1)\n\n if not self.USER_REGEX.match(user_part):\n raise ValidationError(message)\n\n if domain_part not in self.DOMAIN_WHITELIST:\n if not self.DOMAIN_REGEX.match(domain_part):\n try:\n domain_part = domain_part.encode(\"idna\").decode(\"ascii\")\n except UnicodeError:\n pass\n else:\n if self.DOMAIN_REGEX.match(domain_part):\n return value\n raise ValidationError(message)\n\n return value\n\n\nclass Range(Validator):\n \"\"\"Validator which succeeds if the value passed to it is within the specified\n range. If ``min`` is not specified, or is specified as `None`,\n no lower bound exists. If ``max`` is not specified, or is specified as `None`,\n no upper bound exists. The inclusivity of the bounds (if they exist) is configurable.\n If ``min_inclusive`` is not specified, or is specified as `True`, then\n the ``min`` bound is included in the range. If ``max_inclusive`` is not specified,\n or is specified as `True`, then the ``max`` bound is included in the range.\n\n :param min: The minimum value (lower bound). If not provided, minimum\n value will not be checked.\n :param max: The maximum value (upper bound). 
If not provided, maximum\n value will not be checked.\n :param min_inclusive: Whether the `min` bound is included in the range.\n :param max_inclusive: Whether the `max` bound is included in the range.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`, `{min}` and `{max}`.\n \"\"\"\n\n message_min = \"Must be {min_op} {{min}}.\"\n message_max = \"Must be {max_op} {{max}}.\"\n message_all = \"Must be {min_op} {{min}} and {max_op} {{max}}.\"\n\n message_gte = \"greater than or equal to\"\n message_gt = \"greater than\"\n message_lte = \"less than or equal to\"\n message_lt = \"less than\"\n\n def __init__(\n self,\n min=None,\n max=None,\n *,\n min_inclusive: bool = True,\n max_inclusive: bool = True,\n error: str | None = None,\n ):\n self.min = min\n self.max = max\n self.error = error\n self.min_inclusive = min_inclusive\n self.max_inclusive = max_inclusive\n\n # interpolate messages based on bound inclusivity\n self.message_min = self.message_min.format(\n min_op=self.message_gte if self.min_inclusive else self.message_gt\n )\n self.message_max = self.message_max.format(\n max_op=self.message_lte if self.max_inclusive else self.message_lt\n )\n self.message_all = self.message_all.format(\n min_op=self.message_gte if self.min_inclusive else self.message_gt,\n max_op=self.message_lte if self.max_inclusive else self.message_lt,\n )\n\n def _repr_args(self) -> str:\n...\n", "needle_token_start": 4033, "needle_token_end": 4055, "code_context_ntokens": 16179, "output": ["```python\ndef dict_class(self) -> type:\n return OrderedDict if self.ordered else dict\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "options", "language": "python", "path": "src/marshmallow/validate.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to generate a sequence of pairs, each consisting of a string representation of a choice and its corresponding label. This is particularly useful for populating user interface elements like dropdown menus in forms.\n2. **Input**: Accepts either a callable function or a string. The callable should take one argument and return a value that represents a choice. If a string is provided, it is used to specify the name of an attribute of the choice objects.\n3. **Output**: Produces an iterable of tuples, where each tuple contains two elements: a string representation of a choice (obtained via the provided callable or attribute name) and a label for that choice.\n4. **Procedure**: The function first determines how to obtain the value representation of each choice, either by using a provided callable or by fetching the specified attribute of the choice. It then pairs each choice with its corresponding label (or a default empty string if no label is provided) and returns these pairs as an iterable.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "\n :param relative: Whether to allow relative URLs.\n :param absolute: Whether to allow absolute URLs.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`.\n :param schemes: Valid schemes. 
By default, ``http``, ``https``,\n ``ftp``, and ``ftps`` are allowed.\n :param require_tld: Whether to reject non-FQDN hostnames.\n \"\"\"\n\n class RegexMemoizer:\n def __init__(self):\n self._memoized = {}\n\n def _regex_generator(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n hostname_variants = [\n # a normal domain name, expressed in [A-Z0-9] chars with hyphens allowed only in the middle\n # note that the regex will be compiled with IGNORECASE, so these are upper and lowercase chars\n (\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}\\.?|[A-Z0-9-]{2,}\\.?)\"\n ),\n # or the special string 'localhost'\n r\"localhost\",\n # or IPv4\n r\"\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\",\n # or IPv6\n r\"\\[[A-F0-9]*:[A-F0-9:]+\\]\",\n ]\n if not require_tld:\n # allow dotless hostnames\n hostname_variants.append(r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.?)\")\n\n absolute_part = \"\".join(\n (\n # scheme (e.g. 'https://', 'ftp://', etc)\n # this is validated separately against allowed schemes, so in the regex\n # we simply want to capture its existence\n r\"(?:[a-z0-9\\.\\-\\+]*)://\",\n # userinfo, for URLs encoding authentication\n # e.g. 'ftp://foo:bar@ftp.example.org/'\n r\"(?:(?:[a-z0-9\\-._~!$&'()*+,;=:]|%[0-9a-f]{2})*@)?\",\n # netloc, the hostname/domain part of the URL plus the optional port\n r\"(?:\",\n \"|\".join(hostname_variants),\n r\")\",\n r\"(?::\\d+)?\",\n )\n )\n relative_part = r\"(?:/?|[/?]\\S+)\\Z\"\n\n if relative:\n if absolute:\n parts: tuple[str, ...] = (\n r\"^(\",\n absolute_part,\n r\")?\",\n relative_part,\n )\n else:\n parts = (r\"^\", relative_part)\n else:\n parts = (r\"^\", absolute_part, relative_part)\n\n return re.compile(\"\".join(parts), re.IGNORECASE)\n\n def __call__(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n key = (relative, absolute, require_tld)\n if key not in self._memoized:\n self._memoized[key] = self._regex_generator(\n relative, absolute, require_tld\n )\n\n return self._memoized[key]\n\n _regex = RegexMemoizer()\n\n default_message = \"Not a valid URL.\"\n default_schemes = {\"http\", \"https\", \"ftp\", \"ftps\"}\n\n def __init__(\n self,\n *,\n relative: bool = False,\n absolute: bool = True,\n schemes: types.StrSequenceOrSet | None = None,\n require_tld: bool = True,\n error: str | None = None,\n ):\n if not relative and not absolute:\n raise ValueError(\n \"URL validation cannot set both relative and absolute to False.\"\n )\n self.relative = relative\n self.absolute = absolute\n self.error = error or self.default_message # type: str\n self.schemes = schemes or self.default_schemes\n self.require_tld = require_tld\n\n def _repr_args(self) -> str:\n return f\"relative={self.relative!r}, absolute={self.absolute!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n if not value:\n raise ValidationError(message)\n\n # Check first if the scheme is valid\n if \"://\" in value:\n scheme = value.split(\"://\")[0].lower()\n if scheme not in self.schemes:\n raise ValidationError(message)\n\n regex = self._regex(self.relative, self.absolute, self.require_tld)\n\n if not regex.search(value):\n raise ValidationError(message)\n\n return value\n\n\nclass Email(Validator):\n \"\"\"Validate an email address.\n\n :param error: Error message to raise in case of a validation error. 
Can be\n interpolated with `{input}`.\n \"\"\"\n\n USER_REGEX = re.compile(\n r\"(^[-!#$%&'*+/=?^`{}|~\\w]+(\\.[-!#$%&'*+/=?^`{}|~\\w]+)*\\Z\" # dot-atom\n # quoted-string\n r'|^\"([\\001-\\010\\013\\014\\016-\\037!#-\\[\\]-\\177]'\n r'|\\\\[\\001-\\011\\013\\014\\016-\\177])*\"\\Z)',\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_REGEX = re.compile(\n # domain\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\\Z\"\n # literal form, ipv4 address (SMTP 4.1.3)\n r\"|^\\[(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)\"\n r\"(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\]\\Z\",\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_WHITELIST = (\"localhost\",)\n\n default_message = \"Not a valid email address.\"\n\n def __init__(self, *, error: str | None = None):\n self.error = error or self.default_message # type: str\n\n def _format_error(self, value: str) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n\n if not value or \"@\" not in value:\n raise ValidationError(message)\n\n user_part, domain_part = value.rsplit(\"@\", 1)\n\n if not self.USER_REGEX.match(user_part):\n raise ValidationError(message)\n\n if domain_part not in self.DOMAIN_WHITELIST:\n if not self.DOMAIN_REGEX.match(domain_part):\n try:\n domain_part = domain_part.encode(\"idna\").decode(\"ascii\")\n except UnicodeError:\n pass\n else:\n if self.DOMAIN_REGEX.match(domain_part):\n return value\n raise ValidationError(message)\n\n return value\n\n\nclass Range(Validator):\n \"\"\"Validator which succeeds if the value passed to it is within the specified\n range. If ``min`` is not specified, or is specified as `None`,\n no lower bound exists. If ``max`` is not specified, or is specified as `None`,\n no upper bound exists. The inclusivity of the bounds (if they exist) is configurable.\n If ``min_inclusive`` is not specified, or is specified as `True`, then\n the ``min`` bound is included in the range. If ``max_inclusive`` is not specified,\n or is specified as `True`, then the ``max`` bound is included in the range.\n\n :param min: The minimum value (lower bound). If not provided, minimum\n value will not be checked.\n :param max: The maximum value (upper bound). 
If not provided, maximum\n value will not be checked.\n :param min_inclusive: Whether the `min` bound is included in the range.\n :param max_inclusive: Whether the `max` bound is included in the range.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`, `{min}` and `{max}`.\n \"\"\"\n\n message_min = \"Must be {min_op} {{min}}.\"\n message_max = \"Must be {max_op} {{max}}.\"\n message_all = \"Must be {min_op} {{min}} and {max_op} {{max}}.\"\n\n message_gte = \"greater than or equal to\"\n message_gt = \"greater than\"\n message_lte = \"less than or equal to\"\n message_lt = \"less than\"\n\n def __init__(\n self,\n min=None,\n max=None,\n *,\n min_inclusive: bool = True,\n max_inclusive: bool = True,\n error: str | None = None,\n ):\n self.min = min\n self.max = max\n self.error = error\n self.min_inclusive = min_inclusive\n self.max_inclusive = max_inclusive\n\n # interpolate messages based on bound inclusivity\n self.message_min = self.message_min.format(\n min_op=self.message_gte if self.min_inclusive else self.message_gt\n )\n self.message_max = self.message_max.format(\n max_op=self.message_lte if self.max_inclusive else self.message_lt\n )\n self.message_all = self.message_all.format(\n min_op=self.message_gte if self.min_inclusive else self.message_gt,\n max_op=self.message_lte if self.max_inclusive else self.message_lt,\n )\n\n def _repr_args(self) -> str:\n return \"min={!r}, max={!r}, min_inclusive={!r}, max_inclusive={!r}\".format(\n self.min, self.max, self.min_inclusive, self.max_inclusive\n )\n\n def _format_error(self, value: _T, message: str) -> str:\n return (self.error or message).format(input=value, min=self.min, max=self.max)\n\n def __call__(self, value: _T) -> _T:\n if self.min is not None and (\n value < self.min if self.min_inclusive else value <= self.min\n ):\n message = self.message_min if self.max is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n if self.max is not None and (\n value > self.max if self.max_inclusive else value >= self.max\n ):\n message = self.message_max if self.min is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n return value\n\n\nclass Length(Validator):\n \"\"\"Validator which succeeds if the value passed to it has a\n length between a minimum and maximum. Uses len(), so it\n can work for strings, lists, or anything with length.\n\n :param min: The minimum length. If not provided, minimum length\n will not be checked.\n :param max: The maximum length. If not provided, maximum length\n will not be checked.\n :param equal: The exact length. 
If provided, maximum and minimum\n length will not be checked.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`, `{min}` and `{max}`.\n \"\"\"\n\n message_min = \"Shorter than minimum length {min}.\"\n message_max = \"Longer than maximum length {max}.\"\n message_all = \"Length must be between {min} and {max}.\"\n message_equal = \"Length must be {equal}.\"\n\n def __init__(\n self,\n min: int | None = None,\n max: int | None = None,\n *,\n equal: int | None = None,\n error: str | None = None,\n ):\n if equal is not None and any([min, max]):\n raise ValueError(\n \"The `equal` parameter was provided, maximum or \"\n \"minimum parameter must not be provided.\"\n )\n\n self.min = min\n self.max = max\n self.error = error\n self.equal = equal\n\n def _repr_args(self) -> str:\n return f\"min={self.min!r}, max={self.max!r}, equal={self.equal!r}\"\n\n def _format_error(self, value: typing.Sized, message: str) -> str:\n return (self.error or message).format(\n input=value, min=self.min, max=self.max, equal=self.equal\n )\n\n def __call__(self, value: typing.Sized) -> typing.Sized:\n length = len(value)\n\n if self.equal is not None:\n if length != self.equal:\n raise ValidationError(self._format_error(value, self.message_equal))\n return value\n\n if self.min is not None and length < self.min:\n message = self.message_min if self.max is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n if self.max is not None and length > self.max:\n message = self.message_max if self.min is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n return value\n\n\nclass Equal(Validator):\n \"\"\"Validator which succeeds if the ``value`` passed to it is\n equal to ``comparable``.\n\n :param comparable: The object to compare to.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}` and `{other}`.\n \"\"\"\n\n default_message = \"Must be equal to {other}.\"\n\n def __init__(self, comparable, *, error: str | None = None):\n self.comparable = comparable\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"comparable={self.comparable!r}\"\n\n def _format_error(self, value: _T) -> str:\n return self.error.format(input=value, other=self.comparable)\n\n def __call__(self, value: _T) -> _T:\n if value != self.comparable:\n raise ValidationError(self._format_error(value))\n return value\n\n\nclass Regexp(Validator):\n \"\"\"Validator which succeeds if the ``value`` matches ``regex``.\n\n .. note::\n\n Uses `re.match`, which searches for a match at the beginning of a string.\n\n :param regex: The regular expression string to use. Can also be a compiled\n regular expression pattern.\n :param flags: The regexp flags to use, for example re.IGNORECASE. 
Ignored\n if ``regex`` is not a string.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}` and `{regex}`.\n \"\"\"\n\n default_message = \"String does not match expected pattern.\"\n\n def __init__(\n self,\n regex: str | bytes | typing.Pattern,\n flags: int = 0,\n *,\n error: str | None = None,\n ):\n self.regex = (\n re.compile(regex, flags) if isinstance(regex, (str, bytes)) else regex\n )\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"regex={self.regex!r}\"\n\n def _format_error(self, value: str | bytes) -> str:\n return self.error.format(input=value, regex=self.regex.pattern)\n\n @typing.overload\n def __call__(self, value: str) -> str:\n ...\n\n @typing.overload\n def __call__(self, value: bytes) -> bytes:\n ...\n\n def __call__(self, value):\n if self.regex.match(value) is None:\n raise ValidationError(self._format_error(value))\n\n return value\n\n\nclass Predicate(Validator):\n \"\"\"Call the specified ``method`` of the ``value`` object. The\n validator succeeds if the invoked method returns an object that\n evaluates to True in a Boolean context. Any additional keyword\n argument will be passed to the method.\n\n :param method: The name of the method to invoke.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}` and `{method}`.\n :param kwargs: Additional keyword arguments to pass to the method.\n \"\"\"\n\n default_message = \"Invalid input.\"\n\n def __init__(self, method: str, *, error: str | None = None, **kwargs):\n self.method = method\n self.error = error or self.default_message # type: str\n self.kwargs = kwargs\n\n def _repr_args(self) -> str:\n return f\"method={self.method!r}, kwargs={self.kwargs!r}\"\n\n def _format_error(self, value: typing.Any) -> str:\n return self.error.format(input=value, method=self.method)\n\n def __call__(self, value: typing.Any) -> typing.Any:\n method = getattr(value, self.method)\n\n if not method(**self.kwargs):\n raise ValidationError(self._format_error(value))\n\n return value\n\n\nclass NoneOf(Validator):\n \"\"\"Validator which fails if ``value`` is a member of ``iterable``.\n\n :param iterable: A sequence of invalid values.\n :param error: Error message to raise in case of a validation error. Can be\n interpolated using `{input}` and `{values}`.\n \"\"\"\n\n default_message = \"Invalid input.\"\n\n def __init__(self, iterable: typing.Iterable, *, error: str | None = None):\n self.iterable = iterable\n self.values_text = \", \".join(str(each) for each in self.iterable)\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"iterable={self.iterable!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(input=value, values=self.values_text)\n\n def __call__(self, value: typing.Any) -> typing.Any:\n try:\n if value in self.iterable:\n raise ValidationError(self._format_error(value))\n except TypeError:\n pass\n\n return value\n\n\nclass OneOf(Validator):\n \"\"\"Validator which succeeds if ``value`` is a member of ``choices``.\n\n :param choices: A sequence of valid values.\n :param labels: Optional sequence of labels to pair with the choices.\n :param error: Error message to raise in case of a validation error. 
Can be\n interpolated with `{input}`, `{choices}` and `{labels}`.\n \"\"\"\n\n default_message = \"Must be one of: {choices}.\"\n\n def __init__(\n self,\n choices: typing.Iterable,\n labels: typing.Iterable[str] | None = None,\n *,\n error: str | None = None,\n ):\n self.choices = choices\n self.choices_text = \", \".join(str(choice) for choice in self.choices)\n self.labels = labels if labels is not None else []\n self.labels_text = \", \".join(str(label) for label in self.labels)\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"choices={self.choices!r}, labels={self.labels!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(\n input=value, choices=self.choices_text, labels=self.labels_text\n )\n\n def __call__(self, value: typing.Any) -> typing.Any:\n try:\n if value not in self.choices:\n raise ValidationError(self._format_error(value))\n except TypeError as error:\n raise ValidationError(self._format_error(value)) from error\n\n return value\n\n \ndef options(\n self,\n valuegetter: str | typing.Callable[[typing.Any], typing.Any] = str,\n ) -> typing.Iterable[tuple[typing.Any, str]]:\n \"\"\"Return a generator over the (value, label) pairs, where value\n is a string associated with each choice. This convenience method\n is useful to populate, for instance, a form select field.\n\n :param valuegetter: Can be a callable or a string. In the former case, it must\n be a one-argument callable which returns the value of a\n choice. In the latter case, the string specifies the name\n of an attribute of the choice objects. Defaults to `str()`\n or `str()`.\n \"\"\"\n valuegetter = valuegetter if callable(valuegetter) else attrgetter(valuegetter)\n pairs = zip_longest(self.choices, self.labels, fillvalue=\"\")\n\n return ((valuegetter(choice), label) for choice, label in pairs)\n\n\nclass ContainsOnly(OneOf):\n \"\"\"Validator which succeeds if ``value`` is a sequence and each element\n in the sequence is also in the sequence passed as ``choices``. Empty input\n is considered valid.\n\n :param iterable choices: Same as :class:`OneOf`.\n :param iterable labels: Same as :class:`OneOf`.\n :param str error: Same as :class:`OneOf`.\n\n .. versionchanged:: 3.0.0b2\n Duplicate values are considered valid.\n .. versionchanged:: 3.0.0b2\n Empty input is considered valid. Use `validate.Length(min=1) `\n to validate against empty inputs.\n \"\"\"\n\n default_message = \"One or more of the choices you made was not in: {choices}.\"\n\n def _format_error(self, value) -> str:\n value_text = \", \".join(str(val) for val in value)\n return super()._format_error(value_text)\n\n def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:\n # We can't use set.issubset because does not handle unhashable types\n for val in value:\n if val not in self.choices:\n raise ValidationError(self._format_error(value))\n return value\n\n\nclass ContainsNoneOf(NoneOf):\n \"\"\"Validator which fails if ``value`` is a sequence and any element\n in the sequence is a member of the sequence passed as ``iterable``. Empty input\n is considered valid.\n\n :param iterable iterable: Same as :class:`NoneOf`.\n :param str error: Same as :class:`NoneOf`.\n\n .. 
versionadded:: 3.6.0\n \"\"\"\n\n default_message = \"One or more of the choices you made was in: {values}.\"\n\n def _format_error(self, value) -> str:\n value_text = \", \".join(str(val) for val in value)\n return super()._format_error(value_text)\n\n def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:\n for val in value:\n if val in self.iterable:\n raise ValidationError(self._format_error(value))\n return value\n\n# Path: src/marshmallow/fields.py\n\"\"\"Field classes for various types of data.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport copy\nimport datetime as dt\nimport decimal\nimport ipaddress\nimport math\nimport numbers\nimport typing\nimport uuid\nimport warnings\nfrom collections.abc import Mapping as _Mapping\nfrom enum import Enum as EnumType\n\nfrom marshmallow import class_registry, types, utils, validate\nfrom marshmallow.base import FieldABC, SchemaABC\nfrom marshmallow.exceptions import (\n FieldInstanceResolutionError,\n StringNotCollectionError,\n ValidationError,\n)\nfrom marshmallow.utils import (\n is_aware,\n is_collection,\n resolve_field_instance,\n)\nfrom marshmallow.utils import (\n missing as missing_,\n)\nfrom marshmallow.validate import And, Length\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n__all__ = [\n \"Field\",\n \"Raw\",\n \"Nested\",\n \"Mapping\",\n \"Dict\",\n \"List\",\n \"Tuple\",\n \"String\",\n \"UUID\",\n \"Number\",\n \"Integer\",\n \"Decimal\",\n \"Boolean\",\n \"Float\",\n \"DateTime\",\n \"NaiveDateTime\",\n \"AwareDateTime\",\n \"Time\",\n \"Date\",\n \"TimeDelta\",\n \"Url\",\n \"URL\",\n \"Email\",\n \"IP\",\n \"IPv4\",\n \"IPv6\",\n \"IPInterface\",\n \"IPv4Interface\",\n \"IPv6Interface\",\n \"Enum\",\n \"Method\",\n \"Function\",\n \"Str\",\n \"Bool\",\n \"Int\",\n \"Constant\",\n \"Pluck\",\n]\n\n_T = typing.TypeVar(\"_T\")\n\n\nclass Field(FieldABC):\n \"\"\"Basic field from which other fields should extend. It applies no\n formatting by default, and should only be used in cases where\n data does not need to be formatted before being serialized or deserialized.\n On error, the name of the field will be returned.\n\n :param dump_default: If set, this value will be used during serialization if the\n input value is missing. If not set, the field will be excluded from the\n serialized output if the input value is missing. May be a value or a callable.\n :param load_default: Default deserialization value for the field if the field is not\n found in the input data. May be a value or a callable.\n :param data_key: The name of the dict key in the external representation, i.e.\n the input of `load` and the output of `dump`.\n If `None`, the key will match the name of the field.\n :param attribute: The name of the key/attribute in the internal representation, i.e.\n the output of `load` and the input of `dump`.\n If `None`, the key/attribute will match the name of the field.\n Note: This should only be used for very specific use cases such as\n outputting multiple fields for a single attribute, or using keys/attributes\n that are invalid variable names, unsuitable for field names. In most cases,\n you should use ``data_key`` instead.\n :param validate: Validator or collection of validators that are called\n during deserialization. 
Validator takes a field's input value as\n its only parameter and returns a boolean.\n If it returns `False`, an :exc:`ValidationError` is raised.\n :param required: Raise a :exc:`ValidationError` if the field value\n is not supplied during deserialization.\n :param allow_none: Set this to `True` if `None` should be considered a valid value during\n validation/deserialization. If ``load_default=None`` and ``allow_none`` is unset,\n will default to ``True``. Otherwise, the default is ``False``.\n :param load_only: If `True` skip this field during serialization, otherwise\n its value will be present in the serialized data.\n :param dump_only: If `True` skip this field during deserialization, otherwise\n its value will be present in the deserialized object. In the context of an\n HTTP API, this effectively marks the field as \"read-only\".\n :param dict error_messages: Overrides for `Field.default_error_messages`.\n :param metadata: Extra information to be stored as field metadata.\n\n .. versionchanged:: 2.0.0\n Removed `error` parameter. Use ``error_messages`` instead.\n\n .. versionchanged:: 2.0.0\n Added `allow_none` parameter, which makes validation/deserialization of `None`\n consistent across fields.\n\n .. versionchanged:: 2.0.0\n Added `load_only` and `dump_only` parameters, which allow field skipping\n during the (de)serialization process.\n\n .. versionchanged:: 2.0.0\n Added `missing` parameter, which indicates the value for a field if the field\n is not found during deserialization.\n\n .. versionchanged:: 2.0.0\n ``default`` value is only used if explicitly set. Otherwise, missing values\n inputs are excluded from serialized output.\n\n .. versionchanged:: 3.0.0b8\n Add ``data_key`` parameter for the specifying the key in the input and\n output data. This parameter replaced both ``load_from`` and ``dump_to``.\n \"\"\"\n\n # Some fields, such as Method fields and Function fields, are not expected\n # to exist as attributes on the objects to serialize. Set this to False\n # for those fields\n _CHECK_ATTRIBUTE = True\n\n #: Default error messages for various kinds of errors. The keys in this dictionary\n #: are passed to `Field.make_error`. The values are error messages passed to\n #: :exc:`marshmallow.exceptions.ValidationError`.\n default_error_messages = {\n \"required\": \"Missing data for required field.\",\n \"null\": \"Field may not be null.\",\n \"validator_failed\": \"Invalid value.\",\n }\n\n def __init__(\n self,\n *,\n load_default: typing.Any = missing_,\n missing: typing.Any = missing_,\n dump_default: typing.Any = missing_,\n default: typing.Any = missing_,\n data_key: str | None = None,\n attribute: str | None = None,\n validate: (\n None\n | typing.Callable[[typing.Any], typing.Any]\n | typing.Iterable[typing.Callable[[typing.Any], typing.Any]]\n ) = None,\n required: bool = False,\n allow_none: bool | None = None,\n load_only: bool = False,\n dump_only: bool = False,\n error_messages: dict[str, str] | None = None,\n metadata: typing.Mapping[str, typing.Any] | None = None,\n **additional_metadata,\n ) -> None:\n # handle deprecated `default` and `missing` parameters\n if default is not missing_:\n warnings.warn(\n \"The 'default' argument to fields is deprecated. \"\n \"Use 'dump_default' instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if dump_default is missing_:\n dump_default = default\n if missing is not missing_:\n warnings.warn(\n \"The 'missing' argument to fields is deprecated. 
\"\n \"Use 'load_default' instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if load_default is missing_:\n load_default = missing\n self.dump_default = dump_default\n self.load_default = load_default\n\n self.attribute = attribute\n self.data_key = data_key\n self.validate = validate\n if validate is None:\n self.validators = []\n elif callable(validate):\n self.validators = [validate]\n elif utils.is_iterable_but_not_string(validate):\n self.validators = list(validate)\n else:\n raise ValueError(\n \"The 'validate' parameter must be a callable \"\n \"or a collection of callables.\"\n )\n\n # If allow_none is None and load_default is None\n # None should be considered valid by default\n self.allow_none = load_default is None if allow_none is None else allow_none\n self.load_only = load_only\n self.dump_only = dump_only\n if required is True and load_default is not missing_:\n raise ValueError(\"'load_default' must not be set for required fields.\")\n self.required = required\n\n metadata = metadata or {}\n self.metadata = {**metadata, **additional_metadata}\n if additional_metadata:\n warnings.warn(\n \"Passing field metadata as keyword arguments is deprecated. Use the \"\n \"explicit `metadata=...` argument instead. \"\n f\"Additional metadata: {additional_metadata}\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n\n # Collect default error message from self and parent classes\n messages = {} # type: dict[str, str]\n for cls in reversed(self.__class__.__mro__):\n messages.update(getattr(cls, \"default_error_messages\", {}))\n messages.update(error_messages or {})\n self.error_messages = messages\n\n def __repr__(self) -> str:\n return (\n f\"\"\n )\n\n def __deepcopy__(self, memo):\n return copy.copy(self)\n\n def get_value(self, obj, attr, accessor=None, default=missing_):\n \"\"\"Return the value for a given key from an object.\n\n :param object obj: The object to get the value from.\n :param str attr: The attribute/key in `obj` to get the value from.\n :param callable accessor: A callable used to retrieve the value of `attr` from\n the object `obj`. Defaults to `marshmallow.utils.get_value`.\n \"\"\"\n accessor_func = accessor or utils.get_value\n check_key = attr if self.attribute is None else self.attribute\n return accessor_func(obj, check_key, default)\n\n def _validate(self, value):\n \"\"\"Perform validation on ``value``. Raise a :exc:`ValidationError` if validation\n does not succeed.\n \"\"\"\n self._validate_all(value)\n\n @property\n def _validate_all(self):\n return And(*self.validators, error=self.error_messages[\"validator_failed\"])\n\n def make_error(self, key: str, **kwargs) -> ValidationError:\n \"\"\"Helper method to make a `ValidationError` with an error message\n from ``self.error_messages``.\n \"\"\"\n try:\n msg = self.error_messages[key]\n except KeyError as error:\n class_name = self.__class__.__name__\n message = (\n f\"ValidationError raised by `{class_name}`, but error key `{key}` does \"\n \"not exist in the `error_messages` dictionary.\"\n )\n raise AssertionError(message) from error\n if isinstance(msg, (str, bytes)):\n msg = msg.format(**kwargs)\n return ValidationError(msg)\n\n def fail(self, key: str, **kwargs):\n \"\"\"Helper method that raises a `ValidationError` with an error message\n from ``self.error_messages``.\n\n .. deprecated:: 3.0.0\n Use `make_error ` instead.\n \"\"\"\n warnings.warn(\n f'`Field.fail` is deprecated. 
Use `raise self.make_error(\"{key}\", ...)` instead.',\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n raise self.make_error(key=key, **kwargs)\n\n def _validate_missing(self, value):\n \"\"\"Validate missing values. Raise a :exc:`ValidationError` if\n `value` should be considered missing.\n \"\"\"\n if value is missing_ and self.required:\n raise self.make_error(\"required\")\n if value is None and not self.allow_none:\n raise self.make_error(\"null\")\n\n def serialize(\n self,\n attr: str,\n obj: typing.Any,\n accessor: typing.Callable[[typing.Any, str, typing.Any], typing.Any]\n | None = None,\n **kwargs,\n ):\n \"\"\"Pulls the value for the given key from the object, applies the\n field's formatting and returns the result.\n\n :param attr: The attribute/key to get from the object.\n :param obj: The object to access the attribute/key from.\n :param accessor: Function used to access values from ``obj``.\n :param kwargs: Field-specific keyword arguments.\n \"\"\"\n if self._CHECK_ATTRIBUTE:\n value = self.get_value(obj, attr, accessor=accessor)\n if value is missing_:\n default = self.dump_default\n value = default() if callable(default) else default\n if value is missing_:\n return value\n else:\n value = None\n return self._serialize(value, attr, obj, **kwargs)\n\n def deserialize(\n self,\n value: typing.Any,\n attr: str | None = None,\n data: typing.Mapping[str, typing.Any] | None = None,\n **kwargs,\n ):\n \"\"\"Deserialize ``value``.\n\n :param value: The value to deserialize.\n :param attr: The attribute/key in `data` to deserialize.\n :param data: The raw input data passed to `Schema.load`.\n :param kwargs: Field-specific keyword arguments.\n :raise ValidationError: If an invalid value is passed or if a required value\n is missing.\n \"\"\"\n # Validate required fields, deserialize, then validate\n # deserialized value\n self._validate_missing(value)\n if value is missing_:\n _miss = self.load_default\n return _miss() if callable(_miss) else _miss\n if self.allow_none and value is None:\n return None\n output = self._deserialize(value, attr, data, **kwargs)\n self._validate(output)\n return output\n\n # Methods for concrete classes to override.\n\n def _bind_to_schema(self, field_name, schema):\n \"\"\"Update field with values from its parent schema. Called by\n :meth:`Schema._bind_field `.\n\n :param str field_name: Field name set in schema.\n :param Schema|Field schema: Parent object.\n \"\"\"\n self.parent = self.parent or schema\n self.name = self.name or field_name\n self.root = self.root or (\n self.parent.root if isinstance(self.parent, FieldABC) else self.parent\n )\n\n def _serialize(\n self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs\n ):\n \"\"\"Serializes ``value`` to a basic Python datatype. Noop by default.\n Concrete :class:`Field` classes should implement this method.\n\n Example: ::\n\n class TitleCase(Field):\n def _serialize(self, value, attr, obj, **kwargs):\n if not value:\n return ''\n return str(value).title()\n\n :param value: The value to be serialized.\n :param str attr: The attribute or key on the object to be serialized.\n :param object obj: The object the value was pulled from.\n :param dict kwargs: Field-specific keyword arguments.\n :return: The serialized value\n \"\"\"\n return value\n\n def _deserialize(\n self,\n value: typing.Any,\n attr: str | None,\n data: typing.Mapping[str, typing.Any] | None,\n **kwargs,\n ):\n \"\"\"Deserialize value. 
Concrete :class:`Field` classes should implement this method.\n\n :param value: The value to be deserialized.\n :param attr: The attribute/key in `data` to be deserialized.\n :param data: The raw input data passed to the `Schema.load`.\n :param kwargs: Field-specific keyword arguments.\n :raise ValidationError: In case of formatting or validation failure.\n :return: The deserialized value.\n\n .. versionchanged:: 2.0.0\n Added ``attr`` and ``data`` parameters.\n\n .. versionchanged:: 3.0.0\n Added ``**kwargs`` to signature.\n \"\"\"\n return value\n\n # Properties\n\n @property\n def context(self):\n \"\"\"The context dictionary for the parent :class:`Schema`.\"\"\"\n return self.parent.context\n\n # the default and missing properties are provided for compatibility and\n # emit warnings when they are accessed and set\n @property\n def default(self):\n warnings.warn(\n \"The 'default' attribute of fields is deprecated. \"\n \"Use 'dump_default' instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n return self.dump_default\n\n @default.setter\n def default(self, value):\n warnings.warn(\n \"The 'default' attribute of fields is deprecated. \"\n \"Use 'dump_default' instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n self.dump_default = value\n\n @property\n def missing(self):\n warnings.warn(\n \"The 'missing' attribute of fields is deprecated. \"\n \"Use 'load_default' instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n return self.load_default\n\n @missing.setter\n def missing(self, value):\n warnings.warn(\n \"The 'missing' attribute of fields is deprecated. \"\n \"Use 'load_default' instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n self.load_default = value\n\n\nclass Raw(Field):\n \"\"\"Field that applies no formatting.\"\"\"\n\n\nclass Nested(Field):\n \"\"\"Allows you to nest a :class:`Schema `\n inside a field.\n\n Examples: ::\n\n class ChildSchema(Schema):\n id = fields.Str()\n name = fields.Str()\n # Use lambda functions when you need two-way nesting or self-nesting\n parent = fields.Nested(lambda: ParentSchema(only=(\"id\",)), dump_only=True)\n siblings = fields.List(fields.Nested(lambda: ChildSchema(only=(\"id\", \"name\"))))\n\n class ParentSchema(Schema):\n id = fields.Str()\n children = fields.List(\n fields.Nested(ChildSchema(only=(\"id\", \"parent\", \"siblings\")))\n )\n spouse = fields.Nested(lambda: ParentSchema(only=(\"id\",)))\n\n When passing a `Schema ` instance as the first argument,\n the instance's ``exclude``, ``only``, and ``many`` attributes will be respected.\n\n Therefore, when passing the ``exclude``, ``only``, or ``many`` arguments to `fields.Nested`,\n you should pass a `Schema ` class (not an instance) as the first argument.\n\n ::\n\n # Yes\n author = fields.Nested(UserSchema, only=('id', 'name'))\n\n # No\n author = fields.Nested(UserSchema(), only=('id', 'name'))\n\n :param nested: `Schema` instance, class, class name (string), dictionary, or callable that\n returns a `Schema` or dictionary. Dictionaries are converted with `Schema.from_dict`.\n :param exclude: A list or tuple of fields to exclude.\n :param only: A list or tuple of fields to marshal. If `None`, all fields are marshalled.\n This parameter takes precedence over ``exclude``.\n :param many: Whether the field is a collection of objects.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. 
Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param kwargs: The same keyword arguments that :class:`Field` receives.\n \"\"\"\n\n #: Default error messages.\n default_error_messages = {\"type\": \"Invalid type.\"}\n\n def __init__(\n self,\n nested: SchemaABC\n | type\n | str\n | dict[str, Field | type]\n | typing.Callable[[], SchemaABC | type | dict[str, Field | type]],\n *,\n dump_default: typing.Any = missing_,\n default: typing.Any = missing_,\n only: types.StrSequenceOrSet | None = None,\n exclude: types.StrSequenceOrSet = (),\n many: bool = False,\n unknown: str | None = None,\n **kwargs,\n ):\n # Raise error if only or exclude is passed as string, not list of strings\n if only is not None and not is_collection(only):\n raise StringNotCollectionError('\"only\" should be a collection of strings.')\n if not is_collection(exclude):\n raise StringNotCollectionError(\n '\"exclude\" should be a collection of strings.'\n )\n if nested == \"self\":\n warnings.warn(\n \"Passing 'self' to `Nested` is deprecated. \"\n \"Use `Nested(lambda: MySchema(...))` instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n self.nested = nested\n self.only = only\n self.exclude = exclude\n self.many = many\n self.unknown = unknown\n self._schema = None # Cached Schema instance\n super().__init__(default=default, dump_default=dump_default, **kwargs)\n\n @property\n def schema(self):\n \"\"\"The nested Schema object.\n\n .. versionchanged:: 1.0.0\n Renamed from `serializer` to `schema`.\n \"\"\"\n if not self._schema:\n # Inherit context from parent.\n context = getattr(self.parent, \"context\", {})\n if callable(self.nested) and not isinstance(self.nested, type):\n nested = self.nested()\n else:\n nested = self.nested\n if isinstance(nested, dict):\n # defer the import of `marshmallow.schema` to avoid circular imports\n from marshmallow.schema import Schema\n\n nested = Schema.from_dict(nested)\n\n if isinstance(nested, SchemaABC):\n self._schema = copy.copy(nested)\n self._schema.context.update(context)\n # Respect only and exclude passed from parent and re-initialize fields\n set_class = self._schema.set_class\n if self.only is not None:\n if self._schema.only is not None:\n original = self._schema.only\n else: # only=None -> all fields\n original = self._schema.fields.keys()\n self._schema.only = set_class(self.only) & set_class(original)\n if self.exclude:\n original = self._schema.exclude\n self._schema.exclude = set_class(self.exclude) | set_class(original)\n self._schema._init_fields()\n else:\n if isinstance(nested, type) and issubclass(nested, SchemaABC):\n schema_class = nested\n elif not isinstance(nested, (str, bytes)):\n raise ValueError(\n \"`Nested` fields must be passed a \"\n f\"`Schema`, not {nested.__class__}.\"\n )\n elif nested == \"self\":\n schema_class = self.root.__class__\n else:\n schema_class = class_registry.get_class(nested)\n self._schema = schema_class(\n many=self.many,\n only=self.only,\n exclude=self.exclude,\n context=context,\n load_only=self._nested_normalized_option(\"load_only\"),\n dump_only=self._nested_normalized_option(\"dump_only\"),\n )\n return self._schema\n\n def _nested_normalized_option(self, option_name: str) -> list[str]:\n nested_field = \"%s.\" % self.name\n return [\n field.split(nested_field, 1)[1]\n for field in getattr(self.root, option_name, set())\n if field.startswith(nested_field)\n ]\n\n def _serialize(self, nested_obj, attr, obj, **kwargs):\n # Load up the schema first. 
This allows a RegistryError to be raised\n # if an invalid schema name was passed\n schema = self.schema\n if nested_obj is None:\n return None\n many = schema.many or self.many\n return schema.dump(nested_obj, many=many)\n\n def _test_collection(self, value):\n many = self.schema.many or self.many\n if many and not utils.is_collection(value):\n raise self.make_error(\"type\", input=value, type=value.__class__.__name__)\n\n def _load(self, value, data, partial=None):\n try:\n valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)\n except ValidationError as error:\n raise ValidationError(\n error.messages, valid_data=error.valid_data\n ) from error\n return valid_data\n\n def _deserialize(self, value, attr, data, partial=None, **kwargs):\n \"\"\"Same as :meth:`Field._deserialize` with additional ``partial`` argument.\n\n :param bool|tuple partial: For nested schemas, the ``partial``\n parameter passed to `Schema.load`.\n\n .. versionchanged:: 3.0.0\n Add ``partial`` parameter.\n \"\"\"\n self._test_collection(value)\n return self._load(value, data, partial=partial)\n\n\nclass Pluck(Nested):\n \"\"\"Allows you to replace nested data with one of the data's fields.\n\n Example: ::\n\n from marshmallow import Schema, fields\n\n class ArtistSchema(Schema):\n id = fields.Int()\n name = fields.Str()\n\n class AlbumSchema(Schema):\n artist = fields.Pluck(ArtistSchema, 'id')\n\n\n in_data = {'artist': 42}\n loaded = AlbumSchema().load(in_data) # => {'artist': {'id': 42}}\n dumped = AlbumSchema().dump(loaded) # => {'artist': 42}\n\n :param Schema nested: The Schema class or class name (string)\n to nest, or ``\"self\"`` to nest the :class:`Schema` within itself.\n :param str field_name: The key to pluck a value from.\n :param kwargs: The same keyword arguments that :class:`Nested` receives.\n \"\"\"\n\n def __init__(\n self,\n nested: SchemaABC | type | str | typing.Callable[[], SchemaABC],\n field_name: str,\n **kwargs,\n ):\n super().__init__(nested, only=(field_name,), **kwargs)\n self.field_name = field_name\n\n @property\n def _field_data_key(self):\n only_field = self.schema.fields[self.field_name]\n return only_field.data_key or self.field_name\n\n def _serialize(self, nested_obj, attr, obj, **kwargs):\n ret = super()._serialize(nested_obj, attr, obj, **kwargs)\n if ret is None:\n return None\n if self.many:\n return utils.pluck(ret, key=self._field_data_key)\n return ret[self._field_data_key]\n\n def _deserialize(self, value, attr, data, partial=None, **kwargs):\n self._test_collection(value)\n if self.many:\n value = [{self._field_data_key: v} for v in value]\n else:\n value = {self._field_data_key: value}\n return self._load(value, data, partial=partial)\n\n\nclass List(Field):\n \"\"\"A list field, composed with another `Field` class or\n instance.\n\n Example: ::\n\n numbers = fields.List(fields.Float())\n\n :param cls_or_instance: A field class or instance.\n :param kwargs: The same keyword arguments that :class:`Field` receives.\n\n .. versionchanged:: 2.0.0\n The ``allow_none`` parameter now applies to deserialization and\n has the same semantics as the other fields.\n\n .. 
versionchanged:: 3.0.0rc9\n Does not serialize scalar values to single-item lists.\n \"\"\"\n\n #: Default error messages.\n default_error_messages = {\"invalid\": \"Not a valid list.\"}\n\n def __init__(self, cls_or_instance: Field | type, **kwargs):\n super().__init__(**kwargs)\n try:\n self.inner = resolve_field_instance(cls_or_instance)\n except FieldInstanceResolutionError as error:\n raise ValueError(\n \"The list elements must be a subclass or instance of \"\n \"marshmallow.base.FieldABC.\"\n ) from error\n if isinstance(self.inner, Nested):\n self.only = self.inner.only\n self.exclude = self.inner.exclude\n\n def _bind_to_schema(self, field_name, schema):\n super()._bind_to_schema(field_name, schema)\n self.inner = copy.deepcopy(self.inner)\n self.inner._bind_to_schema(field_name, self)\n if isinstance(self.inner, Nested):\n self.inner.only = self.only\n self.inner.exclude = self.exclude\n\n def _serialize(self, value, attr, obj, **kwargs) -> list[typing.Any] | None:\n if value is None:\n return None\n return [self.inner._serialize(each, attr, obj, **kwargs) for each in value]\n\n def _deserialize(self, value, attr, data, **kwargs) -> list[typing.Any]:\n if not utils.is_collection(value):\n raise self.make_error(\"invalid\")\n\n result = []\n errors = {}\n for idx, each in enumerate(value):\n try:\n result.append(self.inner.deserialize(each, **kwargs))\n except ValidationError as error:\n if error.valid_data is not None:\n result.append(error.valid_data)\n errors.update({idx: error.messages})\n if errors:\n raise ValidationError(errors, valid_data=result)\n return result\n\n\nclass Tuple(Field):\n \"\"\"A tuple field, composed of a fixed number of other `Field` classes or\n instances\n\n Example: ::\n\n row = Tuple((fields.String(), fields.Integer(), fields.Float()))\n\n .. note::\n Because of the structured nature of `collections.namedtuple` and\n `typing.NamedTuple`, using a Schema within a Nested field for them is\n more appropriate than using a `Tuple` field.\n\n :param Iterable[Field] tuple_fields: An iterable of field classes or\n instances.\n :param kwargs: The same keyword arguments that :class:`Field` receives.\n\n .. 
versionadded:: 3.0.0rc4\n \"\"\"\n\n #: Default error messages.\n default_error_messages = {\"invalid\": \"Not a valid tuple.\"}\n\n def __init__(self, tuple_fields, *args, **kwargs):\n super().__init__(*args, **kwargs)\n if not utils.is_collection(tuple_fields):\n raise ValueError(\n \"tuple_fields must be an iterable of Field classes or \" \"instances.\"\n )\n\n try:\n self.tuple_fields = [\n resolve_field_instance(cls_or_instance)\n for cls_or_instance in tuple_fields\n ]\n except FieldInstanceResolutionError as error:\n raise ValueError(\n 'Elements of \"tuple_fields\" must be subclasses or '\n \"instances of marshmallow.base.FieldABC.\"\n ) from error\n\n self.validate_length = Length(equal=len(self.tuple_fields))\n\n def _bind_to_schema(self, field_name, schema):\n super()._bind_to_schema(field_name, schema)\n new_tuple_fields = []\n for field in self.tuple_fields:\n field = copy.deepcopy(field)\n field._bind_to_schema(field_name, self)\n new_tuple_fields.append(field)\n\n self.tuple_fields = new_tuple_fields\n\n def _serialize(self, value, attr, obj, **kwargs) -> tuple | None:\n if value is None:\n return None\n\n return tuple(\n field._serialize(each, attr, obj, **kwargs)\n for field, each in zip(self.tuple_fields, value)\n )\n\n def _deserialize(self, value, attr, data, **kwargs) -> tuple:\n if not utils.is_collection(value):\n raise self.make_error(\"invalid\")\n\n self.validate_length(value)\n\n result = []\n errors = {}\n\n for idx, (field, each) in enumerate(zip(self.tuple_fields, value)):\n try:\n result.append(field.deserialize(each, **kwargs))\n except ValidationError as error:\n if error.valid_data is not None:\n result.append(error.valid_data)\n errors.update({idx: error.messages})\n if errors:\n raise ValidationError(errors, valid_data=result)\n\n return tuple(result)\n\n\nclass String(Field):\n \"\"\"A string field.\n\n :param kwargs: The same keyword arguments that :class:`Field` receives.\n \"\"\"\n\n #: Default error messages.\n default_error_messages = {\n \"invalid\": \"Not a valid string.\",\n \"invalid_utf8\": \"Not a valid utf-8 string.\",\n }\n\n def _serialize(self, value, attr, obj, **kwargs) -> str | None:\n if value is None:\n return None\n return utils.ensure_text_type(value)\n\n def _deserialize(self, value, attr, data, **kwargs) -> typing.Any:\n if not isinstance(value, (str, bytes)):\n raise self.make_error(\"invalid\")\n try:\n return utils.ensure_text_type(value)\n except UnicodeDecodeError as error:\n raise self.make_error(\"invalid_utf8\") from error\n\n\nclass UUID(String):\n \"\"\"A UUID field.\"\"\"\n\n #: Default error messages.\n default_error_messages = {\"invalid_uuid\": \"Not a valid UUID.\"}\n\n def _validated(self, value) -> uuid.UUID | None:\n \"\"\"Format the value or raise a :exc:`ValidationError` if an error occurs.\"\"\"\n if value is None:\n return None\n if isinstance(value, uuid.UUID):\n return value\n try:\n if isinstance(value, bytes) and len(value) == 16:\n return uuid.UUID(bytes=value)\n return uuid.UUID(value)\n except (ValueError, AttributeError, TypeError) as error:\n raise self.make_error(\"invalid_uuid\") from error\n\n def _deserialize(self, value, attr, data, **kwargs) -> uuid.UUID | None:\n return self._validated(value)\n\n\nclass Number(Field):\n \"\"\"Base class for number fields.\n\n :param bool as_string: If `True`, format the serialized value as a string.\n :param kwargs: The same keyword arguments that :class:`Field` receives.\n \"\"\"\n\n num_type = float # type: typing.Type\n\n #: Default error messages.\n 
default_error_messages = {\n \"invalid\": \"Not a valid number.\",\n \"too_large\": \"Number too large.\",\n }\n\n def __init__(self, *, as_string: bool = False, **kwargs):\n self.as_string = as_string\n super().__init__(**kwargs)\n\n def _format_num(self, value) -> typing.Any:\n \"\"\"Return the number value for value, given this field's `num_type`.\"\"\"\n return self.num_type(value)\n\n def _validated(self, value) -> _T | None:\n \"\"\"Format the value or raise a :exc:`ValidationError` if an error occurs.\"\"\"\n if value is None:\n return None\n # (value is True or value is False) is ~5x faster than isinstance(value, bool)\n if value is True or value is False:\n raise self.make_error(\"invalid\", input=value)\n try:\n return self._format_num(value)\n except (TypeError, ValueError) as error:\n raise self.make_error(\"invalid\", input=value) from error\n except OverflowError as error:\n raise self.make_error(\"too_large\", input=value) from error\n\n def _to_string(self, value) -> str:\n return str(value)\n\n def _serialize(self, value, attr, obj, **kwargs) -> str | _T | None:\n \"\"\"Return a string if `self.as_string=True`, otherwise return this field's `num_type`.\"\"\"\n if value is None:\n return None\n ret = self._format_num(value) # type: _T\n return self._to_string(ret) if self.as_string else ret\n\n def _deserialize(self, value, attr, data, **kwargs) -> _T | None:\n return self._validated(value)\n\n\nclass Integer(Number):\n \"\"\"An integer field.\n\n :param strict: If `True`, only integer types are valid.\n...\n", "needle_token_start": 5502, "needle_token_end": 5746, "code_context_ntokens": 16109, "output": ["```python\ndef options(\n self,\n valuegetter: str | typing.Callable[[typing.Any], typing.Any] = str,\n ) -> typing.Iterable[tuple[typing.Any, str]]:\n \"\"\"Return a generator over the (value, label) pairs, where value\n is a string associated with each choice. This convenience method\n is useful to populate, for instance, a form select field.\n\n :param valuegetter: Can be a callable or a string. In the former case, it must\n be a one-argument callable which returns the value of a\n choice. In the latter case, the string specifies the name\n of an attribute of the choice objects. Defaults to `str()`\n or `str()`.\n \"\"\"\n valuegetter = valuegetter if callable(valuegetter) else attrgetter(valuegetter)\n pairs = zip_longest(self.choices, self.labels, fillvalue=\"\")\n\n return ((valuegetter(choice), label) for choice, label in pairs)\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "_invoke_field_validators", "language": "python", "path": "src/marshmallow/schema.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to execute field-specific validation methods on data, either for individual items or collections of items, and handle any resulting errors or data modifications.\n2. **Input**: The function takes an error handling object, the data to be validated (which can be a single item or a list of items), and a boolean indicating if the data represents multiple items.\n3. **Output**: There is no direct output from this function; however, it modifies the input data based on validation results and updates the error handling object with any validation errors encountered.\n4. 
**Procedure**: \n - The function iterates over registered validation methods.\n - For each validation method, it identifies the corresponding field in the data.\n - If the field exists, the function retrieves the value from the data.\n - The validation method is then called with this value.\n - If the validation method indicates that the value is missing (using a specific missing marker), the value is removed from the data.\n - Any errors raised during validation are captured and stored in the error handling object.\n - This process is repeated for each item in the data if the input represents multiple items.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " ):\n \"\"\"Custom error handler function for the schema.\n\n :param error: The `ValidationError` raised during (de)serialization.\n :param data: The original input data.\n :param many: Value of ``many`` on dump or load.\n :param partial: Value of ``partial`` on load.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0rc9\n Receives `many` and `partial` (on deserialization) as keyword arguments.\n \"\"\"\n pass\n\n def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):\n \"\"\"Defines how to pull values from an object to serialize.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0a1\n Changed position of ``obj`` and ``attr``.\n \"\"\"\n return get_value(obj, attr, default)\n\n ##### Serialization/Deserialization API #####\n\n @staticmethod\n def _call_and_store(getter_func, data, *, field_name, error_store, index=None):\n \"\"\"Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.\n\n :param callable getter_func: Function for getting the serialized/deserialized\n value from ``data``.\n :param data: The data passed to ``getter_func``.\n :param str field_name: Field name.\n :param int index: Index of the item being validated, if validating a collection,\n otherwise `None`.\n \"\"\"\n try:\n value = getter_func(data)\n except ValidationError as error:\n error_store.store_error(error.messages, field_name, index=index)\n # When a Nested field fails validation, the marshalled data is stored\n # on the ValidationError's valid_data attribute\n return error.valid_data or missing\n return value\n\n def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):\n \"\"\"Serialize ``obj``.\n\n :param obj: The object(s) to serialize.\n :param bool many: `True` if ``data`` should be serialized as a collection.\n :return: A dictionary of the serialized data\n\n .. versionchanged:: 1.0.0\n Renamed from ``marshal``.\n \"\"\"\n if many and obj is not None:\n return [\n self._serialize(d, many=False)\n for d in typing.cast(typing.Iterable[_T], obj)\n ]\n ret = self.dict_class()\n for attr_name, field_obj in self.dump_fields.items():\n value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)\n if value is missing:\n continue\n key = field_obj.data_key if field_obj.data_key is not None else attr_name\n ret[key] = value\n return ret\n\n def dump(self, obj: typing.Any, *, many: bool | None = None):\n \"\"\"Serialize an object to native Python data types according to this\n Schema's fields.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. 
If `None`, the value\n for `self.many` is used.\n :return: Serialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n .. versionchanged:: 3.0.0rc9\n Validation no longer occurs upon serialization.\n \"\"\"\n many = self.many if many is None else bool(many)\n if self._has_processors(PRE_DUMP):\n processed_obj = self._invoke_dump_processors(\n PRE_DUMP, obj, many=many, original_data=obj\n )\n else:\n processed_obj = obj\n\n result = self._serialize(processed_obj, many=many)\n\n if self._has_processors(POST_DUMP):\n result = self._invoke_dump_processors(\n POST_DUMP, result, many=many, original_data=obj\n )\n\n return result\n\n def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):\n \"\"\"Same as :meth:`dump`, except return a JSON-encoded string.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: A ``json`` string\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n \"\"\"\n serialized = self.dump(obj, many=many)\n return self.opts.render_module.dumps(serialized, *args, **kwargs)\n\n def _deserialize(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n error_store: ErrorStore,\n many: bool = False,\n partial=None,\n unknown=RAISE,\n index=None,\n ) -> _T | list[_T]:\n \"\"\"Deserialize ``data``.\n\n :param dict data: The data to deserialize.\n :param ErrorStore error_store: Structure to store errors.\n :param bool many: `True` if ``data`` should be deserialized as a collection.\n :param bool|tuple partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. 
Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param int index: Index of the item being serialized (for storing errors) if\n serializing a collection, otherwise `None`.\n :return: A dictionary of the deserialized data.\n \"\"\"\n index_errors = self.opts.index_errors\n index = index if index_errors else None\n if many:\n if not is_collection(data):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n ret_l = [] # type: typing.List[_T]\n else:\n ret_l = [\n typing.cast(\n _T,\n self._deserialize(\n typing.cast(typing.Mapping[str, typing.Any], d),\n error_store=error_store,\n many=False,\n partial=partial,\n unknown=unknown,\n index=idx,\n ),\n )\n for idx, d in enumerate(data)\n ]\n return ret_l\n ret_d = self.dict_class()\n # Check data is a dict\n if not isinstance(data, Mapping):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n else:\n partial_is_collection = is_collection(partial)\n for attr_name, field_obj in self.load_fields.items():\n field_name = (\n field_obj.data_key if field_obj.data_key is not None else attr_name\n )\n raw_value = data.get(field_name, missing)\n if raw_value is missing:\n # Ignore missing field if we're allowed to.\n if partial is True or (\n partial_is_collection and attr_name in partial\n ):\n continue\n d_kwargs = {}\n # Allow partial loading of nested schemas.\n if partial_is_collection:\n prefix = field_name + \".\"\n len_prefix = len(prefix)\n sub_partial = [\n f[len_prefix:] for f in partial if f.startswith(prefix)\n ]\n d_kwargs[\"partial\"] = sub_partial\n elif partial is not None:\n d_kwargs[\"partial\"] = partial\n\n def getter(\n val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs\n ):\n return field_obj.deserialize(\n val,\n field_name,\n data,\n **d_kwargs,\n )\n\n value = self._call_and_store(\n getter_func=getter,\n data=raw_value,\n field_name=field_name,\n error_store=error_store,\n index=index,\n )\n if value is not missing:\n key = field_obj.attribute or attr_name\n set_value(ret_d, key, value)\n if unknown != EXCLUDE:\n fields = {\n field_obj.data_key if field_obj.data_key is not None else field_name\n for field_name, field_obj in self.load_fields.items()\n }\n for key in set(data) - fields:\n value = data[key]\n if unknown == INCLUDE:\n ret_d[key] = value\n elif unknown == RAISE:\n error_store.store_error(\n [self.error_messages[\"unknown\"]],\n key,\n (index if index_errors else None),\n )\n return ret_d\n\n def load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n \"\"\"Deserialize a data structure to an object defined by this Schema's fields.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. 
versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n return self._do_load(\n data, many=many, partial=partial, unknown=unknown, postprocess=True\n )\n\n def loads(\n self,\n json_data: str,\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n **kwargs,\n ):\n \"\"\"Same as :meth:`load`, except it takes a JSON string as input.\n\n :param json_data: A JSON string of the data to deserialize.\n :param many: Whether to deserialize `obj` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n data = self.opts.render_module.loads(json_data, **kwargs)\n return self.load(data, many=many, partial=partial, unknown=unknown)\n\n def _run_validator(\n self,\n validator_func,\n output,\n *,\n original_data,\n error_store,\n many,\n partial,\n pass_original,\n index=None,\n ):\n try:\n if pass_original: # Pass original, raw data (before unmarshalling)\n validator_func(output, original_data, partial=partial, many=many)\n else:\n validator_func(output, partial=partial, many=many)\n except ValidationError as err:\n error_store.store_error(err.messages, err.field_name, index=index)\n\n def validate(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n ) -> dict[str, list[str]]:\n \"\"\"Validate `data` against the schema, returning a dictionary of\n validation errors.\n\n :param data: The data to validate.\n :param many: Whether to validate `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :return: A dictionary of validation errors.\n\n .. versionadded:: 1.1.0\n \"\"\"\n try:\n self._do_load(data, many=many, partial=partial, postprocess=False)\n except ValidationError as exc:\n return typing.cast(typing.Dict[str, typing.List[str]], exc.messages)\n return {}\n\n ##### Private Helpers #####\n\n def _do_load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n postprocess: bool = True,\n ):\n \"\"\"Deserialize `data`, returning the deserialized result.\n This method is private API.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. 
If `None`, the\n value for `self.many` is used.\n :param partial: Whether to validate required fields. If its\n value is an iterable, only fields listed in that iterable will be\n ignored will be allowed missing. If `True`, all fields will be allowed missing.\n If `None`, the value for `self.partial` is used.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :param postprocess: Whether to run post_load methods..\n :return: Deserialized data\n \"\"\"\n error_store = ErrorStore()\n errors = {} # type: dict[str, list[str]]\n many = self.many if many is None else bool(many)\n unknown = (\n self.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n if partial is None:\n partial = self.partial\n # Run preprocessors\n if self._has_processors(PRE_LOAD):\n try:\n processed_data = self._invoke_load_processors(\n PRE_LOAD, data, many=many, original_data=data, partial=partial\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n result = None # type: list | dict | None\n else:\n processed_data = data\n if not errors:\n # Deserialize data\n result = self._deserialize(\n processed_data,\n error_store=error_store,\n many=many,\n partial=partial,\n unknown=unknown,\n )\n # Run field-level validation\n self._invoke_field_validators(\n error_store=error_store, data=result, many=many\n )\n # Run schema-level validation\n if self._has_processors(VALIDATES_SCHEMA):\n field_errors = bool(error_store.errors)\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=True,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=False,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n errors = error_store.errors\n # Run post processors\n if not errors and postprocess and self._has_processors(POST_LOAD):\n try:\n result = self._invoke_load_processors(\n POST_LOAD,\n result,\n many=many,\n original_data=data,\n partial=partial,\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n if errors:\n exc = ValidationError(errors, data=data, valid_data=result)\n self.handle_error(exc, data, many=many, partial=partial)\n raise exc\n\n return result\n\n def _normalize_nested_options(self) -> None:\n \"\"\"Apply then flatten nested schema options.\n This method is private API.\n \"\"\"\n if self.only is not None:\n # Apply the only option to nested fields.\n self.__apply_nested_option(\"only\", self.only, \"intersection\")\n # Remove the child field names from the only option.\n self.only = self.set_class([field.split(\".\", 1)[0] for field in self.only])\n if self.exclude:\n # Apply the exclude option to nested fields.\n self.__apply_nested_option(\"exclude\", self.exclude, \"union\")\n # Remove the parent field names from the exclude option.\n self.exclude = self.set_class(\n [field for field in self.exclude if \".\" not in field]\n )\n\n def __apply_nested_option(self, option_name, field_names, set_operation) -> None:\n \"\"\"Apply nested options to nested fields\"\"\"\n # Split nested field names on the first dot.\n nested_fields = [name.split(\".\", 1) for name in field_names if \".\" in name]\n # Partition the nested field names by parent field.\n nested_options = defaultdict(list) # type: defaultdict\n for parent, nested_names 
in nested_fields:\n nested_options[parent].append(nested_names)\n # Apply the nested field options.\n for key, options in iter(nested_options.items()):\n new_options = self.set_class(options)\n original_options = getattr(self.declared_fields[key], option_name, ())\n if original_options:\n if set_operation == \"union\":\n new_options |= self.set_class(original_options)\n if set_operation == \"intersection\":\n new_options &= self.set_class(original_options)\n setattr(self.declared_fields[key], option_name, new_options)\n\n def _init_fields(self) -> None:\n \"\"\"Update self.fields, self.load_fields, and self.dump_fields based on schema options.\n This method is private API.\n \"\"\"\n if self.opts.fields:\n available_field_names = self.set_class(self.opts.fields)\n else:\n available_field_names = self.set_class(self.declared_fields.keys())\n if self.opts.additional:\n available_field_names |= self.set_class(self.opts.additional)\n\n invalid_fields = self.set_class()\n\n if self.only is not None:\n # Return only fields specified in only option\n field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)\n\n invalid_fields |= field_names - available_field_names\n else:\n field_names = available_field_names\n\n # If \"exclude\" option or param is specified, remove those fields.\n if self.exclude:\n # Note that this isn't available_field_names, since we want to\n # apply \"only\" for the actual calculation.\n field_names = field_names - self.exclude\n invalid_fields |= self.exclude - available_field_names\n\n if invalid_fields:\n message = f\"Invalid fields for {self}: {invalid_fields}.\"\n raise ValueError(message)\n\n fields_dict = self.dict_class()\n for field_name in field_names:\n field_obj = self.declared_fields.get(field_name, ma_fields.Inferred())\n self._bind_field(field_name, field_obj)\n fields_dict[field_name] = field_obj\n\n load_fields, dump_fields = self.dict_class(), self.dict_class()\n for field_name, field_obj in fields_dict.items():\n if not field_obj.dump_only:\n load_fields[field_name] = field_obj\n if not field_obj.load_only:\n dump_fields[field_name] = field_obj\n\n dump_data_keys = [\n field_obj.data_key if field_obj.data_key is not None else name\n for name, field_obj in dump_fields.items()\n ]\n if len(dump_data_keys) != len(set(dump_data_keys)):\n data_keys_duplicates = {\n x for x in dump_data_keys if dump_data_keys.count(x) > 1\n }\n raise ValueError(\n \"The data_key argument for one or more fields collides \"\n \"with another field's name or data_key argument. \"\n \"Check the following field names and \"\n f\"data_key arguments: {list(data_keys_duplicates)}\"\n )\n load_attributes = [obj.attribute or name for name, obj in load_fields.items()]\n if len(load_attributes) != len(set(load_attributes)):\n attributes_duplicates = {\n x for x in load_attributes if load_attributes.count(x) > 1\n }\n raise ValueError(\n \"The attribute argument for one or more fields collides \"\n \"with another field's name or attribute argument. 
\"\n \"Check the following field names and \"\n f\"attribute arguments: {list(attributes_duplicates)}\"\n )\n\n self.fields = fields_dict\n self.dump_fields = dump_fields\n self.load_fields = load_fields\n\n def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:\n \"\"\"Hook to modify a field when it is bound to the `Schema`.\n\n No-op by default.\n \"\"\"\n return None\n\n def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:\n \"\"\"Bind field to the schema, setting any necessary attributes on the\n field (e.g. parent and name).\n\n Also set field load_only and dump_only values if field_name was\n specified in ``class Meta``.\n \"\"\"\n if field_name in self.load_only:\n field_obj.load_only = True\n if field_name in self.dump_only:\n field_obj.dump_only = True\n try:\n field_obj._bind_to_schema(field_name, self)\n except TypeError as error:\n # Field declared as a class, not an instance. Ignore type checking because\n # we handle unsupported arg types, i.e. this is dead code from\n # the type checker's perspective.\n if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC):\n msg = (\n f'Field for \"{field_name}\" must be declared as a '\n \"Field instance, not a class. \"\n f'Did you mean \"fields.{field_obj.__name__}()\"?' # type: ignore\n )\n raise TypeError(msg) from error\n raise error\n self.on_bind_field(field_name, field_obj)\n\n @lru_cache(maxsize=8) # noqa (https://github.com/PyCQA/flake8-bugbear/issues/310)\n def _has_processors(self, tag) -> bool:\n return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)])\n\n def _invoke_dump_processors(\n self, tag: str, data, *, many: bool, original_data=None\n ):\n # The pass_many post-dump processors may do things like add an envelope, so\n # invoke those after invoking the non-pass_many processors which will expect\n # to get a list of items.\n data = self._invoke_processors(\n tag, pass_many=False, data=data, many=many, original_data=original_data\n )\n data = self._invoke_processors(\n tag, pass_many=True, data=data, many=many, original_data=original_data\n )\n return data\n\n def _invoke_load_processors(\n self,\n tag: str,\n data,\n *,\n many: bool,\n original_data,\n partial: bool | types.StrSequenceOrSet | None,\n ):\n # This has to invert the order of the dump processors, so run the pass_many\n # processors first.\n data = self._invoke_processors(\n tag,\n pass_many=True,\n data=data,\n many=many,\n original_data=original_data,\n partial=partial,\n )\n data = self._invoke_processors(\n tag,\n pass_many=False,\n data=data,\n many=many,\n original_data=original_data,\n partial=partial,\n )\n return data\n\n \ndef _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):\n for attr_name in self._hooks[VALIDATES]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[VALIDATES]\n field_name = validator_kwargs[\"field_name\"]\n\n try:\n field_obj = self.fields[field_name]\n except KeyError as error:\n if field_name in self.declared_fields:\n continue\n raise ValueError(f'\"{field_name}\" field does not exist.') from error\n\n data_key = (\n field_obj.data_key if field_obj.data_key is not None else field_name\n )\n if many:\n for idx, item in enumerate(data):\n try:\n value = item[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n index=(idx if 
self.opts.index_errors else None),\n )\n if validated_value is missing:\n data[idx].pop(field_name, None)\n else:\n try:\n value = data[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n )\n if validated_value is missing:\n data.pop(field_name, None)\n\n def _invoke_schema_validators(\n self,\n *,\n error_store: ErrorStore,\n pass_many: bool,\n data,\n original_data,\n many: bool,\n partial: bool | types.StrSequenceOrSet | None,\n field_errors: bool = False,\n ):\n for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[\n (VALIDATES_SCHEMA, pass_many)\n ]\n if field_errors and validator_kwargs[\"skip_on_field_errors\"]:\n continue\n pass_original = validator_kwargs.get(\"pass_original\", False)\n\n if many and not pass_many:\n for idx, (item, orig) in enumerate(zip(data, original_data)):\n self._run_validator(\n validator,\n item,\n original_data=orig,\n error_store=error_store,\n many=many,\n partial=partial,\n index=idx,\n pass_original=pass_original,\n )\n else:\n self._run_validator(\n validator,\n data,\n original_data=original_data,\n error_store=error_store,\n many=many,\n pass_original=pass_original,\n partial=partial,\n )\n\n def _invoke_processors(\n self,\n tag: str,\n *,\n pass_many: bool,\n data,\n many: bool,\n original_data=None,\n **kwargs,\n ):\n key = (tag, pass_many)\n for attr_name in self._hooks[key]:\n # This will be a bound method.\n processor = getattr(self, attr_name)\n\n processor_kwargs = processor.__marshmallow_hook__[key]\n pass_original = processor_kwargs.get(\"pass_original\", False)\n\n if many and not pass_many:\n if pass_original:\n data = [\n processor(item, original, many=many, **kwargs)\n for item, original in zip(data, original_data)\n ]\n else:\n data = [processor(item, many=many, **kwargs) for item in data]\n else:\n if pass_original:\n data = processor(data, original_data, many=many, **kwargs)\n else:\n data = processor(data, many=many, **kwargs)\n return data\n\n\nBaseSchema = Schema # for backwards compatibility\n\n# Path: src/marshmallow/validate.py\n\"\"\"Validation classes for various types of data.\"\"\"\nfrom __future__ import annotations\n\nimport re\nimport typing\nfrom abc import ABC, abstractmethod\nfrom itertools import zip_longest\nfrom operator import attrgetter\n\nfrom marshmallow import types\nfrom marshmallow.exceptions import ValidationError\n\n_T = typing.TypeVar(\"_T\")\n\n\nclass Validator(ABC):\n \"\"\"Abstract base class for validators.\n\n .. note::\n This class does not provide any validation behavior. It is only used to\n add a useful `__repr__` implementation for validators.\n \"\"\"\n\n error = None # type: str | None\n\n def __repr__(self) -> str:\n args = self._repr_args()\n args = f\"{args}, \" if args else \"\"\n\n return f\"<{self.__class__.__name__}({args}error={self.error!r})>\"\n\n def _repr_args(self) -> str:\n \"\"\"A string representation of the args passed to this validator. 
Used by\n `__repr__`.\n \"\"\"\n return \"\"\n\n @abstractmethod\n def __call__(self, value: typing.Any) -> typing.Any:\n ...\n\n\nclass And(Validator):\n \"\"\"Compose multiple validators and combine their error messages.\n\n Example: ::\n\n from marshmallow import validate, ValidationError\n\n def is_even(value):\n if value % 2 != 0:\n raise ValidationError(\"Not an even value.\")\n\n validator = validate.And(validate.Range(min=0), is_even)\n validator(-1)\n # ValidationError: ['Must be greater than or equal to 0.', 'Not an even value.']\n\n :param validators: Validators to combine.\n :param error: Error message to use when a validator returns ``False``.\n \"\"\"\n\n default_error_message = \"Invalid value.\"\n\n def __init__(self, *validators: types.Validator, error: str | None = None):\n self.validators = tuple(validators)\n self.error = error or self.default_error_message # type: str\n\n def _repr_args(self) -> str:\n return f\"validators={self.validators!r}\"\n\n def __call__(self, value: typing.Any) -> typing.Any:\n errors = []\n kwargs = {}\n for validator in self.validators:\n try:\n r = validator(value)\n if not isinstance(validator, Validator) and r is False:\n raise ValidationError(self.error)\n except ValidationError as err:\n kwargs.update(err.kwargs)\n if isinstance(err.messages, dict):\n errors.append(err.messages)\n else:\n # FIXME : Get rid of cast\n errors.extend(typing.cast(list, err.messages))\n if errors:\n raise ValidationError(errors, **kwargs)\n return value\n\n\nclass URL(Validator):\n \"\"\"Validate a URL.\n\n :param relative: Whether to allow relative URLs.\n :param absolute: Whether to allow absolute URLs.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`.\n :param schemes: Valid schemes. By default, ``http``, ``https``,\n ``ftp``, and ``ftps`` are allowed.\n :param require_tld: Whether to reject non-FQDN hostnames.\n \"\"\"\n\n class RegexMemoizer:\n def __init__(self):\n self._memoized = {}\n\n def _regex_generator(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n hostname_variants = [\n # a normal domain name, expressed in [A-Z0-9] chars with hyphens allowed only in the middle\n # note that the regex will be compiled with IGNORECASE, so these are upper and lowercase chars\n (\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}\\.?|[A-Z0-9-]{2,}\\.?)\"\n ),\n # or the special string 'localhost'\n r\"localhost\",\n # or IPv4\n r\"\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\",\n # or IPv6\n r\"\\[[A-F0-9]*:[A-F0-9:]+\\]\",\n ]\n if not require_tld:\n # allow dotless hostnames\n hostname_variants.append(r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.?)\")\n\n absolute_part = \"\".join(\n (\n # scheme (e.g. 'https://', 'ftp://', etc)\n # this is validated separately against allowed schemes, so in the regex\n # we simply want to capture its existence\n r\"(?:[a-z0-9\\.\\-\\+]*)://\",\n # userinfo, for URLs encoding authentication\n # e.g. 'ftp://foo:bar@ftp.example.org/'\n r\"(?:(?:[a-z0-9\\-._~!$&'()*+,;=:]|%[0-9a-f]{2})*@)?\",\n # netloc, the hostname/domain part of the URL plus the optional port\n r\"(?:\",\n \"|\".join(hostname_variants),\n r\")\",\n r\"(?::\\d+)?\",\n )\n )\n relative_part = r\"(?:/?|[/?]\\S+)\\Z\"\n\n if relative:\n if absolute:\n parts: tuple[str, ...] 
= (\n r\"^(\",\n absolute_part,\n r\")?\",\n relative_part,\n )\n else:\n parts = (r\"^\", relative_part)\n else:\n parts = (r\"^\", absolute_part, relative_part)\n\n return re.compile(\"\".join(parts), re.IGNORECASE)\n\n def __call__(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n key = (relative, absolute, require_tld)\n if key not in self._memoized:\n self._memoized[key] = self._regex_generator(\n relative, absolute, require_tld\n )\n\n return self._memoized[key]\n\n _regex = RegexMemoizer()\n\n default_message = \"Not a valid URL.\"\n default_schemes = {\"http\", \"https\", \"ftp\", \"ftps\"}\n\n def __init__(\n self,\n *,\n relative: bool = False,\n absolute: bool = True,\n schemes: types.StrSequenceOrSet | None = None,\n require_tld: bool = True,\n error: str | None = None,\n ):\n if not relative and not absolute:\n raise ValueError(\n \"URL validation cannot set both relative and absolute to False.\"\n )\n self.relative = relative\n self.absolute = absolute\n self.error = error or self.default_message # type: str\n self.schemes = schemes or self.default_schemes\n self.require_tld = require_tld\n\n def _repr_args(self) -> str:\n return f\"relative={self.relative!r}, absolute={self.absolute!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n if not value:\n raise ValidationError(message)\n\n # Check first if the scheme is valid\n if \"://\" in value:\n scheme = value.split(\"://\")[0].lower()\n if scheme not in self.schemes:\n raise ValidationError(message)\n\n regex = self._regex(self.relative, self.absolute, self.require_tld)\n\n if not regex.search(value):\n raise ValidationError(message)\n\n return value\n\n\nclass Email(Validator):\n \"\"\"Validate an email address.\n\n :param error: Error message to raise in case of a validation error. 
Can be\n interpolated with `{input}`.\n \"\"\"\n\n USER_REGEX = re.compile(\n r\"(^[-!#$%&'*+/=?^`{}|~\\w]+(\\.[-!#$%&'*+/=?^`{}|~\\w]+)*\\Z\" # dot-atom\n # quoted-string\n r'|^\"([\\001-\\010\\013\\014\\016-\\037!#-\\[\\]-\\177]'\n r'|\\\\[\\001-\\011\\013\\014\\016-\\177])*\"\\Z)',\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_REGEX = re.compile(\n # domain\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\\Z\"\n # literal form, ipv4 address (SMTP 4.1.3)\n r\"|^\\[(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)\"\n r\"(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\]\\Z\",\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_WHITELIST = (\"localhost\",)\n\n default_message = \"Not a valid email address.\"\n\n def __init__(self, *, error: str | None = None):\n self.error = error or self.default_message # type: str\n\n def _format_error(self, value: str) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n\n if not value or \"@\" not in value:\n raise ValidationError(message)\n\n user_part, domain_part = value.rsplit(\"@\", 1)\n\n if not self.USER_REGEX.match(user_part):\n raise ValidationError(message)\n\n if domain_part not in self.DOMAIN_WHITELIST:\n if not self.DOMAIN_REGEX.match(domain_part):\n try:\n domain_part = domain_part.encode(\"idna\").decode(\"ascii\")\n except UnicodeError:\n pass\n else:\n if self.DOMAIN_REGEX.match(domain_part):\n return value\n raise ValidationError(message)\n\n return value\n\n\nclass Range(Validator):\n \"\"\"Validator which succeeds if the value passed to it is within the specified\n range. If ``min`` is not specified, or is specified as `None`,\n no lower bound exists. If ``max`` is not specified, or is specified as `None`,\n no upper bound exists. The inclusivity of the bounds (if they exist) is configurable.\n If ``min_inclusive`` is not specified, or is specified as `True`, then\n the ``min`` bound is included in the range. If ``max_inclusive`` is not specified,\n or is specified as `True`, then the ``max`` bound is included in the range.\n\n :param min: The minimum value (lower bound). If not provided, minimum\n value will not be checked.\n :param max: The maximum value (upper bound). 
If not provided, maximum\n value will not be checked.\n :param min_inclusive: Whether the `min` bound is included in the range.\n :param max_inclusive: Whether the `max` bound is included in the range.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`, `{min}` and `{max}`.\n \"\"\"\n\n message_min = \"Must be {min_op} {{min}}.\"\n message_max = \"Must be {max_op} {{max}}.\"\n message_all = \"Must be {min_op} {{min}} and {max_op} {{max}}.\"\n\n message_gte = \"greater than or equal to\"\n message_gt = \"greater than\"\n message_lte = \"less than or equal to\"\n message_lt = \"less than\"\n\n def __init__(\n self,\n min=None,\n max=None,\n *,\n min_inclusive: bool = True,\n max_inclusive: bool = True,\n error: str | None = None,\n ):\n self.min = min\n self.max = max\n self.error = error\n self.min_inclusive = min_inclusive\n self.max_inclusive = max_inclusive\n\n # interpolate messages based on bound inclusivity\n self.message_min = self.message_min.format(\n min_op=self.message_gte if self.min_inclusive else self.message_gt\n )\n self.message_max = self.message_max.format(\n max_op=self.message_lte if self.max_inclusive else self.message_lt\n )\n self.message_all = self.message_all.format(\n min_op=self.message_gte if self.min_inclusive else self.message_gt,\n max_op=self.message_lte if self.max_inclusive else self.message_lt,\n )\n\n def _repr_args(self) -> str:\n return \"min={!r}, max={!r}, min_inclusive={!r}, max_inclusive={!r}\".format(\n self.min, self.max, self.min_inclusive, self.max_inclusive\n )\n\n def _format_error(self, value: _T, message: str) -> str:\n return (self.error or message).format(input=value, min=self.min, max=self.max)\n\n def __call__(self, value: _T) -> _T:\n if self.min is not None and (\n value < self.min if self.min_inclusive else value <= self.min\n ):\n message = self.message_min if self.max is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n if self.max is not None and (\n value > self.max if self.max_inclusive else value >= self.max\n ):\n message = self.message_max if self.min is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n return value\n\n\nclass Length(Validator):\n \"\"\"Validator which succeeds if the value passed to it has a\n length between a minimum and maximum. Uses len(), so it\n can work for strings, lists, or anything with length.\n\n :param min: The minimum length. If not provided, minimum length\n will not be checked.\n :param max: The maximum length. If not provided, maximum length\n will not be checked.\n :param equal: The exact length. 
If provided, maximum and minimum\n length will not be checked.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`, `{min}` and `{max}`.\n \"\"\"\n\n message_min = \"Shorter than minimum length {min}.\"\n message_max = \"Longer than maximum length {max}.\"\n message_all = \"Length must be between {min} and {max}.\"\n message_equal = \"Length must be {equal}.\"\n\n def __init__(\n self,\n min: int | None = None,\n max: int | None = None,\n *,\n equal: int | None = None,\n error: str | None = None,\n ):\n if equal is not None and any([min, max]):\n raise ValueError(\n \"The `equal` parameter was provided, maximum or \"\n \"minimum parameter must not be provided.\"\n )\n\n self.min = min\n self.max = max\n self.error = error\n self.equal = equal\n\n def _repr_args(self) -> str:\n return f\"min={self.min!r}, max={self.max!r}, equal={self.equal!r}\"\n\n def _format_error(self, value: typing.Sized, message: str) -> str:\n return (self.error or message).format(\n input=value, min=self.min, max=self.max, equal=self.equal\n )\n\n def __call__(self, value: typing.Sized) -> typing.Sized:\n length = len(value)\n\n if self.equal is not None:\n if length != self.equal:\n raise ValidationError(self._format_error(value, self.message_equal))\n return value\n\n if self.min is not None and length < self.min:\n message = self.message_min if self.max is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n if self.max is not None and length > self.max:\n message = self.message_max if self.min is None else self.message_all\n raise ValidationError(self._format_error(value, message))\n\n return value\n\n\nclass Equal(Validator):\n \"\"\"Validator which succeeds if the ``value`` passed to it is\n equal to ``comparable``.\n\n :param comparable: The object to compare to.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}` and `{other}`.\n \"\"\"\n\n default_message = \"Must be equal to {other}.\"\n\n def __init__(self, comparable, *, error: str | None = None):\n self.comparable = comparable\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"comparable={self.comparable!r}\"\n\n def _format_error(self, value: _T) -> str:\n return self.error.format(input=value, other=self.comparable)\n\n def __call__(self, value: _T) -> _T:\n if value != self.comparable:\n raise ValidationError(self._format_error(value))\n return value\n\n\nclass Regexp(Validator):\n \"\"\"Validator which succeeds if the ``value`` matches ``regex``.\n\n .. note::\n\n Uses `re.match`, which searches for a match at the beginning of a string.\n\n :param regex: The regular expression string to use. Can also be a compiled\n regular expression pattern.\n :param flags: The regexp flags to use, for example re.IGNORECASE. 
Ignored\n if ``regex`` is not a string.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}` and `{regex}`.\n \"\"\"\n\n default_message = \"String does not match expected pattern.\"\n\n def __init__(\n self,\n regex: str | bytes | typing.Pattern,\n flags: int = 0,\n *,\n error: str | None = None,\n ):\n self.regex = (\n re.compile(regex, flags) if isinstance(regex, (str, bytes)) else regex\n )\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"regex={self.regex!r}\"\n\n def _format_error(self, value: str | bytes) -> str:\n return self.error.format(input=value, regex=self.regex.pattern)\n\n @typing.overload\n def __call__(self, value: str) -> str:\n ...\n\n @typing.overload\n def __call__(self, value: bytes) -> bytes:\n ...\n\n def __call__(self, value):\n if self.regex.match(value) is None:\n raise ValidationError(self._format_error(value))\n\n return value\n\n\nclass Predicate(Validator):\n \"\"\"Call the specified ``method`` of the ``value`` object. The\n validator succeeds if the invoked method returns an object that\n evaluates to True in a Boolean context. Any additional keyword\n argument will be passed to the method.\n\n :param method: The name of the method to invoke.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}` and `{method}`.\n :param kwargs: Additional keyword arguments to pass to the method.\n \"\"\"\n\n default_message = \"Invalid input.\"\n\n def __init__(self, method: str, *, error: str | None = None, **kwargs):\n self.method = method\n self.error = error or self.default_message # type: str\n self.kwargs = kwargs\n\n def _repr_args(self) -> str:\n return f\"method={self.method!r}, kwargs={self.kwargs!r}\"\n\n def _format_error(self, value: typing.Any) -> str:\n return self.error.format(input=value, method=self.method)\n\n def __call__(self, value: typing.Any) -> typing.Any:\n method = getattr(value, self.method)\n\n if not method(**self.kwargs):\n raise ValidationError(self._format_error(value))\n\n return value\n\n\nclass NoneOf(Validator):\n \"\"\"Validator which fails if ``value`` is a member of ``iterable``.\n\n :param iterable: A sequence of invalid values.\n :param error: Error message to raise in case of a validation error. Can be\n interpolated using `{input}` and `{values}`.\n \"\"\"\n\n default_message = \"Invalid input.\"\n\n def __init__(self, iterable: typing.Iterable, *, error: str | None = None):\n self.iterable = iterable\n self.values_text = \", \".join(str(each) for each in self.iterable)\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"iterable={self.iterable!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(input=value, values=self.values_text)\n\n def __call__(self, value: typing.Any) -> typing.Any:\n try:\n if value in self.iterable:\n raise ValidationError(self._format_error(value))\n except TypeError:\n pass\n\n return value\n\n\nclass OneOf(Validator):\n \"\"\"Validator which succeeds if ``value`` is a member of ``choices``.\n\n :param choices: A sequence of valid values.\n :param labels: Optional sequence of labels to pair with the choices.\n :param error: Error message to raise in case of a validation error. 
Can be\n interpolated with `{input}`, `{choices}` and `{labels}`.\n \"\"\"\n\n default_message = \"Must be one of: {choices}.\"\n\n def __init__(\n self,\n choices: typing.Iterable,\n labels: typing.Iterable[str] | None = None,\n *,\n error: str | None = None,\n ):\n self.choices = choices\n self.choices_text = \", \".join(str(choice) for choice in self.choices)\n self.labels = labels if labels is not None else []\n self.labels_text = \", \".join(str(label) for label in self.labels)\n self.error = error or self.default_message # type: str\n\n def _repr_args(self) -> str:\n return f\"choices={self.choices!r}, labels={self.labels!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(\n input=value, choices=self.choices_text, labels=self.labels_text\n )\n\n def __call__(self, value: typing.Any) -> typing.Any:\n try:\n if value not in self.choices:\n raise ValidationError(self._format_error(value))\n except TypeError as error:\n raise ValidationError(self._format_error(value)) from error\n\n return value\n\n def options(\n self,\n valuegetter: str | typing.Callable[[typing.Any], typing.Any] = str,\n ) -> typing.Iterable[tuple[typing.Any, str]]:\n \"\"\"Return a generator over the (value, label) pairs, where value\n is a string associated with each choice. This convenience method\n is useful to populate, for instance, a form select field.\n\n :param valuegetter: Can be a callable or a string. In the former case, it must\n be a one-argument callable which returns the value of a\n choice. In the latter case, the string specifies the name\n of an attribute of the choice objects. Defaults to `str()`\n or `str()`.\n \"\"\"\n valuegetter = valuegetter if callable(valuegetter) else attrgetter(valuegetter)\n pairs = zip_longest(self.choices, self.labels, fillvalue=\"\")\n\n return ((valuegetter(choice), label) for choice, label in pairs)\n\n\nclass ContainsOnly(OneOf):\n \"\"\"Validator which succeeds if ``value`` is a sequence and each element\n in the sequence is also in the sequence passed as ``choices``. Empty input\n is considered valid.\n\n :param iterable choices: Same as :class:`OneOf`.\n :param iterable labels: Same as :class:`OneOf`.\n :param str error: Same as :class:`OneOf`.\n\n .. versionchanged:: 3.0.0b2\n Duplicate values are considered valid.\n .. versionchanged:: 3.0.0b2\n Empty input is considered valid. Use `validate.Length(min=1) `\n to validate against empty inputs.\n \"\"\"\n\n default_message = \"One or more of the choices you made was not in: {choices}.\"\n\n def _format_error(self, value) -> str:\n value_text = \", \".join(str(val) for val in value)\n return super()._format_error(value_text)\n\n def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:\n # We can't use set.issubset because does not handle unhashable types\n for val in value:\n if val not in self.choices:\n raise ValidationError(self._format_error(value))\n return value\n\n\nclass ContainsNoneOf(NoneOf):\n \"\"\"Validator which fails if ``value`` is a sequence and any element\n in the sequence is a member of the sequence passed as ``iterable``. Empty input\n is considered valid.\n\n :param iterable iterable: Same as :class:`NoneOf`.\n :param str error: Same as :class:`NoneOf`.\n\n .. 
versionadded:: 3.6.0\n \"\"\"\n\n default_message = \"One or more of the choices you made was in: {values}.\"\n\n def _format_error(self, value) -> str:\n value_text = \", \".join(str(val) for val in value)\n return super()._format_error(value_text)\n\n def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:\n for val in value:\n if val in self.iterable:\n raise ValidationError(self._format_error(value))\n return value\n\n# Path: src/marshmallow/fields.py\n\"\"\"Field classes for various types of data.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport copy\nimport datetime as dt\nimport decimal\nimport ipaddress\nimport math\nimport numbers\nimport typing\nimport uuid\nimport warnings\nfrom collections.abc import Mapping as _Mapping\nfrom enum import Enum as EnumType\n\nfrom marshmallow import class_registry, types, utils, validate\nfrom marshmallow.base import FieldABC, SchemaABC\nfrom marshmallow.exceptions import (\n FieldInstanceResolutionError,\n StringNotCollectionError,\n ValidationError,\n)\nfrom marshmallow.utils import (\n is_aware,\n is_collection,\n resolve_field_instance,\n)\nfrom marshmallow.utils import (\n missing as missing_,\n)\nfrom marshmallow.validate import And, Length\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n__all__ = [\n \"Field\",\n \"Raw\",\n \"Nested\",\n \"Mapping\",\n \"Dict\",\n \"List\",\n \"Tuple\",\n \"String\",\n \"UUID\",\n \"Number\",\n \"Integer\",\n \"Decimal\",\n \"Boolean\",\n \"Float\",\n \"DateTime\",\n \"NaiveDateTime\",\n \"AwareDateTime\",\n \"Time\",\n \"Date\",\n \"TimeDelta\",\n \"Url\",\n \"URL\",\n \"Email\",\n \"IP\",\n \"IPv4\",\n \"IPv6\",\n \"IPInterface\",\n \"IPv4Interface\",\n \"IPv6Interface\",\n \"Enum\",\n \"Method\",\n \"Function\",\n \"Str\",\n \"Bool\",\n \"Int\",\n \"Constant\",\n \"Pluck\",\n]\n\n_T = typing.TypeVar(\"_T\")\n\n\nclass Field(FieldABC):\n \"\"\"Basic field from which other fields should extend. It applies no\n formatting by default, and should only be used in cases where\n data does not need to be formatted before being serialized or deserialized.\n On error, the name of the field will be returned.\n\n :param dump_default: If set, this value will be used during serialization if the\n input value is missing. If not set, the field will be excluded from the\n serialized output if the input value is missing. May be a value or a callable.\n :param load_default: Default deserialization value for the field if the field is not\n found in the input data. May be a value or a callable.\n :param data_key: The name of the dict key in the external representation, i.e.\n the input of `load` and the output of `dump`.\n If `None`, the key will match the name of the field.\n :param attribute: The name of the key/attribute in the internal representation, i.e.\n the output of `load` and the input of `dump`.\n If `None`, the key/attribute will match the name of the field.\n Note: This should only be used for very specific use cases such as\n outputting multiple fields for a single attribute, or using keys/attributes\n that are invalid variable names, unsuitable for field names. 
In most cases,\n you should use ``data_key`` instead.\n :param validate: Validator or collection of validators that are called\n...\n", "needle_token_start": 7062, "needle_token_end": 7537, "code_context_ntokens": 16143, "output": ["```python\ndef _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):\n for attr_name in self._hooks[VALIDATES]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[VALIDATES]\n field_name = validator_kwargs[\"field_name\"]\n\n try:\n field_obj = self.fields[field_name]\n except KeyError as error:\n if field_name in self.declared_fields:\n continue\n raise ValueError(f'\"{field_name}\" field does not exist.') from error\n\n data_key = (\n field_obj.data_key if field_obj.data_key is not None else field_name\n )\n if many:\n for idx, item in enumerate(data):\n try:\n value = item[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n index=(idx if self.opts.index_errors else None),\n )\n if validated_value is missing:\n data[idx].pop(field_name, None)\n else:\n try:\n value = data[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n )\n if validated_value is missing:\n data.pop(field_name, None)\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "_run_validator", "language": "python", "path": "src/marshmallow/schema.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to execute a validation function on deserialized data, handling any validation errors that occur.\n2. **Input**: It accepts a validation function, the deserialized data to validate, the original data structure, an error storage mechanism, and flags indicating whether the data represents multiple items, whether partial data is allowed, and whether the original data should be passed to the validation function.\n3. **Output**: There is no direct output from this function; however, it updates the error storage with any validation errors encountered.\n4. **Procedure**: The function first determines whether to pass the original data structure along with the deserialized data to the validation function based on a specific flag. It then executes the validation function. 
If the validation function raises a validation error, this function captures the error and updates the error storage with the error details, including the specific field and index if applicable.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/utils.py\n\"\"\"Utility methods for marshmallow.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport datetime as dt\nimport functools\nimport inspect\nimport json\nimport re\nimport typing\nimport warnings\nfrom collections.abc import Mapping\nfrom email.utils import format_datetime, parsedate_to_datetime\nfrom pprint import pprint as py_pprint\n\nfrom marshmallow.base import FieldABC\nfrom marshmallow.exceptions import FieldInstanceResolutionError\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\nEXCLUDE = \"exclude\"\nINCLUDE = \"include\"\nRAISE = \"raise\"\n_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}\n\n\nclass _Missing:\n def __bool__(self):\n return False\n\n def __copy__(self):\n return self\n\n def __deepcopy__(self, _):\n return self\n\n def __repr__(self):\n return \"\"\n\n\n# Singleton value that indicates that a field's value is missing from input\n# dict passed to :meth:`Schema.load`. If the field's value is not required,\n# it's ``default`` value is used.\nmissing = _Missing()\n\n\ndef is_generator(obj) -> bool:\n \"\"\"Return True if ``obj`` is a generator\"\"\"\n return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)\n\n\ndef is_iterable_but_not_string(obj) -> bool:\n \"\"\"Return True if ``obj`` is an iterable object that isn't a string.\"\"\"\n return (hasattr(obj, \"__iter__\") and not hasattr(obj, \"strip\")) or is_generator(obj)\n\n\ndef is_collection(obj) -> bool:\n \"\"\"Return True if ``obj`` is a collection type, e.g list, tuple, queryset.\"\"\"\n return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)\n\n\ndef is_instance_or_subclass(val, class_) -> bool:\n...\n# Path: src/marshmallow/schema.py\n\"\"\"The :class:`Schema` class, including its metaclass and options (class Meta).\"\"\"\nfrom __future__ import annotations\n\nimport copy\nimport datetime as dt\nimport decimal\nimport inspect\nimport json\nimport typing\nimport uuid\nimport warnings\nfrom abc import ABCMeta\nfrom collections import OrderedDict, defaultdict\nfrom collections.abc import Mapping\nfrom functools import lru_cache\n\nfrom marshmallow import base, class_registry, types\nfrom marshmallow import fields as ma_fields\nfrom marshmallow.decorators import (\n POST_DUMP,\n POST_LOAD,\n PRE_DUMP,\n PRE_LOAD,\n VALIDATES,\n VALIDATES_SCHEMA,\n)\nfrom marshmallow.error_store import ErrorStore\nfrom marshmallow.exceptions import StringNotCollectionError, ValidationError\nfrom marshmallow.orderedset import OrderedSet\nfrom marshmallow.utils import (\n EXCLUDE,\n INCLUDE,\n RAISE,\n get_value,\n is_collection,\n is_instance_or_subclass,\n missing,\n set_value,\n validate_unknown_parameter_value,\n)\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n_T = typing.TypeVar(\"_T\")\n\n\ndef _get_fields(attrs):\n \"\"\"Get fields from a class\n\n :param attrs: Mapping of class attributes\n \"\"\"\n return [\n (field_name, field_value)\n for field_name, field_value in attrs.items()\n if is_instance_or_subclass(field_value, base.FieldABC)\n ]\n\n\n# This 
function allows Schemas to inherit from non-Schema classes and ensures\n# inheritance according to the MRO\ndef _get_fields_by_mro(klass):\n \"\"\"Collect fields from a class, following its method resolution order. The\n class itself is excluded from the search; only its parents are checked. Get\n fields from ``_declared_fields`` if available, else use ``__dict__``.\n\n :param type klass: Class whose fields to retrieve\n \"\"\"\n mro = inspect.getmro(klass)\n # Loop over mro in reverse to maintain correct order of fields\n return sum(\n (\n _get_fields(\n getattr(base, \"_declared_fields\", base.__dict__),\n )\n for base in mro[:0:-1]\n ),\n [],\n )\n\n\nclass SchemaMeta(ABCMeta):\n \"\"\"Metaclass for the Schema class. Binds the declared fields to\n a ``_declared_fields`` attribute, which is a dictionary mapping attribute\n names to field objects. Also sets the ``opts`` class attribute, which is\n the Schema class's ``class Meta`` options.\n \"\"\"\n\n def __new__(mcs, name, bases, attrs):\n meta = attrs.get(\"Meta\")\n ordered = getattr(meta, \"ordered\", False)\n if not ordered:\n # Inherit 'ordered' option\n # Warning: We loop through bases instead of MRO because we don't\n # yet have access to the class object\n # (i.e. can't call super before we have fields)\n for base_ in bases:\n if hasattr(base_, \"Meta\") and hasattr(base_.Meta, \"ordered\"):\n ordered = base_.Meta.ordered\n break\n else:\n ordered = False\n cls_fields = _get_fields(attrs)\n # Remove fields from list of class attributes to avoid shadowing\n # Schema attributes/methods in case of name conflict\n for field_name, _ in cls_fields:\n del attrs[field_name]\n klass = super().__new__(mcs, name, bases, attrs)\n inherited_fields = _get_fields_by_mro(klass)\n\n meta = klass.Meta\n # Set klass.opts in __new__ rather than __init__ so that it is accessible in\n # get_declared_fields\n klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)\n # Add fields specified in the `include` class Meta option\n cls_fields += list(klass.opts.include.items())\n\n # Assign _declared_fields on class\n klass._declared_fields = mcs.get_declared_fields(\n klass=klass,\n cls_fields=cls_fields,\n inherited_fields=inherited_fields,\n dict_cls=dict,\n )\n return klass\n\n @classmethod\n def get_declared_fields(\n mcs,\n klass: type,\n cls_fields: list,\n inherited_fields: list,\n dict_cls: type = dict,\n ):\n \"\"\"Returns a dictionary of field_name => `Field` pairs declared on the class.\n This is exposed mainly so that plugins can add additional fields, e.g. fields\n computed from class Meta options.\n\n :param klass: The class object.\n :param cls_fields: The fields declared on the class, including those added\n by the ``include`` class Meta option.\n :param inherited_fields: Inherited fields.\n :param dict_cls: dict-like class to use for dict output Default to ``dict``.\n \"\"\"\n return dict_cls(inherited_fields + cls_fields)\n\n def __init__(cls, name, bases, attrs):\n super().__init__(name, bases, attrs)\n if name and cls.opts.register:\n class_registry.register(name, cls)\n cls._hooks = cls.resolve_hooks()\n\n def resolve_hooks(cls) -> dict[types.Tag, list[str]]:\n \"\"\"Add in the decorated processors\n\n By doing this after constructing the class, we let standard inheritance\n do all the hard work.\n \"\"\"\n mro = inspect.getmro(cls)\n\n hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]\n\n for attr_name in dir(cls):\n # Need to look up the actual descriptor, not whatever might be\n # bound to the class. 
This needs to come from the __dict__ of the\n # declaring class.\n for parent in mro:\n try:\n attr = parent.__dict__[attr_name]\n except KeyError:\n continue\n else:\n break\n else:\n # In case we didn't find the attribute and didn't break above.\n # We should never hit this - it's just here for completeness\n # to exclude the possibility of attr being undefined.\n continue\n\n try:\n hook_config = attr.__marshmallow_hook__\n except AttributeError:\n pass\n else:\n for key in hook_config.keys():\n # Use name here so we can get the bound method later, in\n # case the processor was a descriptor or something.\n hooks[key].append(attr_name)\n\n return hooks\n\n\nclass SchemaOpts:\n \"\"\"class Meta options for the :class:`Schema`. Defines defaults.\"\"\"\n\n def __init__(self, meta, ordered: bool = False):\n self.fields = getattr(meta, \"fields\", ())\n if not isinstance(self.fields, (list, tuple)):\n raise ValueError(\"`fields` option must be a list or tuple.\")\n self.additional = getattr(meta, \"additional\", ())\n if not isinstance(self.additional, (list, tuple)):\n raise ValueError(\"`additional` option must be a list or tuple.\")\n if self.fields and self.additional:\n raise ValueError(\n \"Cannot set both `fields` and `additional` options\"\n \" for the same Schema.\"\n )\n self.exclude = getattr(meta, \"exclude\", ())\n if not isinstance(self.exclude, (list, tuple)):\n raise ValueError(\"`exclude` must be a list or tuple.\")\n self.dateformat = getattr(meta, \"dateformat\", None)\n self.datetimeformat = getattr(meta, \"datetimeformat\", None)\n self.timeformat = getattr(meta, \"timeformat\", None)\n if hasattr(meta, \"json_module\"):\n warnings.warn(\n \"The json_module class Meta option is deprecated. Use render_module instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n render_module = getattr(meta, \"json_module\", json)\n else:\n render_module = json\n self.render_module = getattr(meta, \"render_module\", render_module)\n self.ordered = getattr(meta, \"ordered\", ordered)\n self.index_errors = getattr(meta, \"index_errors\", True)\n self.include = getattr(meta, \"include\", {})\n self.load_only = getattr(meta, \"load_only\", ())\n self.dump_only = getattr(meta, \"dump_only\", ())\n self.unknown = validate_unknown_parameter_value(getattr(meta, \"unknown\", RAISE))\n self.register = getattr(meta, \"register\", True)\n\n\nclass Schema(base.SchemaABC, metaclass=SchemaMeta):\n \"\"\"Base schema class with which to define custom schemas.\n\n Example usage:\n\n .. code-block:: python\n\n import datetime as dt\n from dataclasses import dataclass\n\n from marshmallow import Schema, fields\n\n\n @dataclass\n class Album:\n title: str\n release_date: dt.date\n\n\n class AlbumSchema(Schema):\n title = fields.Str()\n release_date = fields.Date()\n\n\n album = Album(\"Beggars Banquet\", dt.date(1968, 12, 6))\n schema = AlbumSchema()\n data = schema.dump(album)\n data # {'release_date': '1968-12-06', 'title': 'Beggars Banquet'}\n\n :param only: Whitelist of the declared fields to select when\n instantiating the Schema. If None, all fields are used. Nested fields\n can be represented with dot delimiters.\n :param exclude: Blacklist of the declared fields to exclude\n when instantiating the Schema. If a field appears in both `only` and\n `exclude`, it is not used. 
Nested fields can be represented with dot\n delimiters.\n :param many: Should be set to `True` if ``obj`` is a collection\n so that the object will be serialized to a list.\n :param context: Optional context passed to :class:`fields.Method` and\n :class:`fields.Function` fields.\n :param load_only: Fields to skip during serialization (write-only fields)\n :param dump_only: Fields to skip during deserialization (read-only fields)\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n\n .. versionchanged:: 3.0.0\n `prefix` parameter removed.\n\n .. versionchanged:: 2.0.0\n `__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of\n `marshmallow.decorators.validates_schema`,\n `marshmallow.decorators.pre_load` and `marshmallow.decorators.post_dump`.\n `__accessor__` and `__error_handler__` are deprecated. Implement the\n `handle_error` and `get_attribute` methods instead.\n \"\"\"\n\n TYPE_MAPPING = {\n str: ma_fields.String,\n bytes: ma_fields.String,\n dt.datetime: ma_fields.DateTime,\n float: ma_fields.Float,\n bool: ma_fields.Boolean,\n tuple: ma_fields.Raw,\n list: ma_fields.Raw,\n set: ma_fields.Raw,\n int: ma_fields.Integer,\n uuid.UUID: ma_fields.UUID,\n dt.time: ma_fields.Time,\n dt.date: ma_fields.Date,\n dt.timedelta: ma_fields.TimeDelta,\n decimal.Decimal: ma_fields.Decimal,\n } # type: typing.Dict[type, typing.Type[ma_fields.Field]]\n #: Overrides for default schema-level error messages\n error_messages = {} # type: typing.Dict[str, str]\n\n _default_error_messages = {\n \"type\": \"Invalid input type.\",\n \"unknown\": \"Unknown field.\",\n } # type: typing.Dict[str, str]\n\n OPTIONS_CLASS = SchemaOpts # type: type\n\n set_class = OrderedSet\n\n # These get set by SchemaMeta\n opts = None # type: SchemaOpts\n _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]\n _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]\n\n class Meta:\n \"\"\"Options object for a Schema.\n\n Example usage: ::\n\n class Meta:\n fields = (\"id\", \"email\", \"date_created\")\n exclude = (\"password\", \"secret_attribute\")\n\n Available options:\n\n - ``fields``: Tuple or list of fields to include in the serialized result.\n - ``additional``: Tuple or list of fields to include *in addition* to the\n explicitly declared fields. ``additional`` and ``fields`` are\n mutually-exclusive options.\n - ``include``: Dictionary of additional fields to include in the schema. It is\n usually better to define fields as class variables, but you may need to\n use this option, e.g., if your fields are Python keywords. 
May be an\n `OrderedDict`.\n - ``exclude``: Tuple or list of fields to exclude in the serialized result.\n Nested fields can be represented with dot delimiters.\n - ``dateformat``: Default format for `Date ` fields.\n - ``datetimeformat``: Default format for `DateTime ` fields.\n - ``timeformat``: Default format for `Time ` fields.\n - ``render_module``: Module to use for `loads ` and `dumps `.\n Defaults to `json` from the standard library.\n - ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.\n - ``index_errors``: If `True`, errors dictionaries will include the index\n of invalid items in a collection.\n - ``load_only``: Tuple or list of fields to exclude from serialized results.\n - ``dump_only``: Tuple or list of fields to exclude from deserialization\n - ``unknown``: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n - ``register``: Whether to register the `Schema` with marshmallow's internal\n class registry. Must be `True` if you intend to refer to this `Schema`\n by class name in `Nested` fields. Only set this to `False` when memory\n usage is critical. Defaults to `True`.\n \"\"\"\n\n def __init__(\n self,\n *,\n only: types.StrSequenceOrSet | None = None,\n exclude: types.StrSequenceOrSet = (),\n many: bool = False,\n context: dict | None = None,\n load_only: types.StrSequenceOrSet = (),\n dump_only: types.StrSequenceOrSet = (),\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n # Raise error if only or exclude is passed as string, not list of strings\n if only is not None and not is_collection(only):\n raise StringNotCollectionError('\"only\" should be a list of strings')\n if not is_collection(exclude):\n raise StringNotCollectionError('\"exclude\" should be a list of strings')\n # copy declared fields from metaclass\n self.declared_fields = copy.deepcopy(self._declared_fields)\n self.many = many\n self.only = only\n self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(\n self.opts.exclude\n ) | set(exclude)\n self.ordered = self.opts.ordered\n self.load_only = set(load_only) or set(self.opts.load_only)\n self.dump_only = set(dump_only) or set(self.opts.dump_only)\n self.partial = partial\n self.unknown = (\n self.opts.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n self.context = context or {}\n self._normalize_nested_options()\n #: Dictionary mapping field_names -> :class:`Field` objects\n self.fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self._init_fields()\n messages = {}\n messages.update(self._default_error_messages)\n for cls in reversed(self.__class__.__mro__):\n messages.update(getattr(cls, \"error_messages\", {}))\n messages.update(self.error_messages or {})\n self.error_messages = messages\n\n def __repr__(self) -> str:\n return f\"<{self.__class__.__name__}(many={self.many})>\"\n\n @property\n def dict_class(self) -> type:\n return OrderedDict if self.ordered else dict\n\n @classmethod\n def from_dict(\n cls,\n fields: dict[str, ma_fields.Field | type],\n *,\n name: str = \"GeneratedSchema\",\n ) -> type:\n \"\"\"Generate a `Schema` class given a dictionary of fields.\n\n .. 
code-block:: python\n\n from marshmallow import Schema, fields\n\n PersonSchema = Schema.from_dict({\"name\": fields.Str()})\n print(PersonSchema().load({\"name\": \"David\"})) # => {'name': 'David'}\n\n Generated schemas are not added to the class registry and therefore cannot\n be referred to by name in `Nested` fields.\n\n :param dict fields: Dictionary mapping field names to field instances.\n :param str name: Optional name for the class, which will appear in\n the ``repr`` for the class.\n\n .. versionadded:: 3.0.0\n \"\"\"\n attrs = fields.copy()\n attrs[\"Meta\"] = type(\n \"GeneratedMeta\", (getattr(cls, \"Meta\", object),), {\"register\": False}\n )\n schema_cls = type(name, (cls,), attrs)\n return schema_cls\n\n ##### Override-able methods #####\n\n def handle_error(\n self, error: ValidationError, data: typing.Any, *, many: bool, **kwargs\n ):\n \"\"\"Custom error handler function for the schema.\n\n :param error: The `ValidationError` raised during (de)serialization.\n :param data: The original input data.\n :param many: Value of ``many`` on dump or load.\n :param partial: Value of ``partial`` on load.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0rc9\n Receives `many` and `partial` (on deserialization) as keyword arguments.\n \"\"\"\n pass\n\n def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):\n \"\"\"Defines how to pull values from an object to serialize.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0a1\n Changed position of ``obj`` and ``attr``.\n \"\"\"\n return get_value(obj, attr, default)\n\n ##### Serialization/Deserialization API #####\n\n @staticmethod\n def _call_and_store(getter_func, data, *, field_name, error_store, index=None):\n \"\"\"Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.\n\n :param callable getter_func: Function for getting the serialized/deserialized\n value from ``data``.\n :param data: The data passed to ``getter_func``.\n :param str field_name: Field name.\n :param int index: Index of the item being validated, if validating a collection,\n otherwise `None`.\n \"\"\"\n try:\n value = getter_func(data)\n except ValidationError as error:\n error_store.store_error(error.messages, field_name, index=index)\n # When a Nested field fails validation, the marshalled data is stored\n # on the ValidationError's valid_data attribute\n return error.valid_data or missing\n return value\n\n def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):\n \"\"\"Serialize ``obj``.\n\n :param obj: The object(s) to serialize.\n :param bool many: `True` if ``data`` should be serialized as a collection.\n :return: A dictionary of the serialized data\n\n .. versionchanged:: 1.0.0\n Renamed from ``marshal``.\n \"\"\"\n if many and obj is not None:\n return [\n self._serialize(d, many=False)\n for d in typing.cast(typing.Iterable[_T], obj)\n ]\n ret = self.dict_class()\n for attr_name, field_obj in self.dump_fields.items():\n value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)\n if value is missing:\n continue\n key = field_obj.data_key if field_obj.data_key is not None else attr_name\n ret[key] = value\n return ret\n\n def dump(self, obj: typing.Any, *, many: bool | None = None):\n \"\"\"Serialize an object to native Python data types according to this\n Schema's fields.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: Serialized data\n\n .. 
versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n .. versionchanged:: 3.0.0rc9\n Validation no longer occurs upon serialization.\n \"\"\"\n many = self.many if many is None else bool(many)\n if self._has_processors(PRE_DUMP):\n processed_obj = self._invoke_dump_processors(\n PRE_DUMP, obj, many=many, original_data=obj\n )\n else:\n processed_obj = obj\n\n result = self._serialize(processed_obj, many=many)\n\n if self._has_processors(POST_DUMP):\n result = self._invoke_dump_processors(\n POST_DUMP, result, many=many, original_data=obj\n )\n\n return result\n\n def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):\n \"\"\"Same as :meth:`dump`, except return a JSON-encoded string.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: A ``json`` string\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n \"\"\"\n serialized = self.dump(obj, many=many)\n return self.opts.render_module.dumps(serialized, *args, **kwargs)\n\n def _deserialize(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n error_store: ErrorStore,\n many: bool = False,\n partial=None,\n unknown=RAISE,\n index=None,\n ) -> _T | list[_T]:\n \"\"\"Deserialize ``data``.\n\n :param dict data: The data to deserialize.\n :param ErrorStore error_store: Structure to store errors.\n :param bool many: `True` if ``data`` should be deserialized as a collection.\n :param bool|tuple partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. 
Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param int index: Index of the item being serialized (for storing errors) if\n serializing a collection, otherwise `None`.\n :return: A dictionary of the deserialized data.\n \"\"\"\n index_errors = self.opts.index_errors\n index = index if index_errors else None\n if many:\n if not is_collection(data):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n ret_l = [] # type: typing.List[_T]\n else:\n ret_l = [\n typing.cast(\n _T,\n self._deserialize(\n typing.cast(typing.Mapping[str, typing.Any], d),\n error_store=error_store,\n many=False,\n partial=partial,\n unknown=unknown,\n index=idx,\n ),\n )\n for idx, d in enumerate(data)\n ]\n return ret_l\n ret_d = self.dict_class()\n # Check data is a dict\n if not isinstance(data, Mapping):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n else:\n partial_is_collection = is_collection(partial)\n for attr_name, field_obj in self.load_fields.items():\n field_name = (\n field_obj.data_key if field_obj.data_key is not None else attr_name\n )\n raw_value = data.get(field_name, missing)\n if raw_value is missing:\n # Ignore missing field if we're allowed to.\n if partial is True or (\n partial_is_collection and attr_name in partial\n ):\n continue\n d_kwargs = {}\n # Allow partial loading of nested schemas.\n if partial_is_collection:\n prefix = field_name + \".\"\n len_prefix = len(prefix)\n sub_partial = [\n f[len_prefix:] for f in partial if f.startswith(prefix)\n ]\n d_kwargs[\"partial\"] = sub_partial\n elif partial is not None:\n d_kwargs[\"partial\"] = partial\n\n def getter(\n val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs\n ):\n return field_obj.deserialize(\n val,\n field_name,\n data,\n **d_kwargs,\n )\n\n value = self._call_and_store(\n getter_func=getter,\n data=raw_value,\n field_name=field_name,\n error_store=error_store,\n index=index,\n )\n if value is not missing:\n key = field_obj.attribute or attr_name\n set_value(ret_d, key, value)\n if unknown != EXCLUDE:\n fields = {\n field_obj.data_key if field_obj.data_key is not None else field_name\n for field_name, field_obj in self.load_fields.items()\n }\n for key in set(data) - fields:\n value = data[key]\n if unknown == INCLUDE:\n ret_d[key] = value\n elif unknown == RAISE:\n error_store.store_error(\n [self.error_messages[\"unknown\"]],\n key,\n (index if index_errors else None),\n )\n return ret_d\n\n def load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n \"\"\"Deserialize a data structure to an object defined by this Schema's fields.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. 
versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n return self._do_load(\n data, many=many, partial=partial, unknown=unknown, postprocess=True\n )\n\n def loads(\n self,\n json_data: str,\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n **kwargs,\n ):\n \"\"\"Same as :meth:`load`, except it takes a JSON string as input.\n\n :param json_data: A JSON string of the data to deserialize.\n :param many: Whether to deserialize `obj` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n data = self.opts.render_module.loads(json_data, **kwargs)\n return self.load(data, many=many, partial=partial, unknown=unknown)\n\n \ndef _run_validator(\n self,\n validator_func,\n output,\n *,\n original_data,\n error_store,\n many,\n partial,\n pass_original,\n index=None,\n ):\n try:\n if pass_original: # Pass original, raw data (before unmarshalling)\n validator_func(output, original_data, partial=partial, many=many)\n else:\n validator_func(output, partial=partial, many=many)\n except ValidationError as err:\n error_store.store_error(err.messages, err.field_name, index=index)\n\n def validate(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n ) -> dict[str, list[str]]:\n \"\"\"Validate `data` against the schema, returning a dictionary of\n validation errors.\n\n :param data: The data to validate.\n :param many: Whether to validate `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :return: A dictionary of validation errors.\n\n .. versionadded:: 1.1.0\n \"\"\"\n try:\n self._do_load(data, many=many, partial=partial, postprocess=False)\n except ValidationError as exc:\n return typing.cast(typing.Dict[str, typing.List[str]], exc.messages)\n return {}\n\n ##### Private Helpers #####\n\n def _do_load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n postprocess: bool = True,\n ):\n \"\"\"Deserialize `data`, returning the deserialized result.\n This method is private API.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. 
If `None`, the\n value for `self.many` is used.\n :param partial: Whether to validate required fields. If its\n value is an iterable, only fields listed in that iterable will be\n ignored will be allowed missing. If `True`, all fields will be allowed missing.\n If `None`, the value for `self.partial` is used.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :param postprocess: Whether to run post_load methods..\n :return: Deserialized data\n \"\"\"\n error_store = ErrorStore()\n errors = {} # type: dict[str, list[str]]\n many = self.many if many is None else bool(many)\n unknown = (\n self.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n if partial is None:\n partial = self.partial\n # Run preprocessors\n if self._has_processors(PRE_LOAD):\n try:\n processed_data = self._invoke_load_processors(\n PRE_LOAD, data, many=many, original_data=data, partial=partial\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n result = None # type: list | dict | None\n else:\n processed_data = data\n if not errors:\n # Deserialize data\n result = self._deserialize(\n processed_data,\n error_store=error_store,\n many=many,\n partial=partial,\n unknown=unknown,\n )\n # Run field-level validation\n self._invoke_field_validators(\n error_store=error_store, data=result, many=many\n )\n # Run schema-level validation\n if self._has_processors(VALIDATES_SCHEMA):\n field_errors = bool(error_store.errors)\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=True,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=False,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n errors = error_store.errors\n # Run post processors\n if not errors and postprocess and self._has_processors(POST_LOAD):\n try:\n result = self._invoke_load_processors(\n POST_LOAD,\n result,\n many=many,\n original_data=data,\n partial=partial,\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n if errors:\n exc = ValidationError(errors, data=data, valid_data=result)\n self.handle_error(exc, data, many=many, partial=partial)\n raise exc\n\n return result\n\n def _normalize_nested_options(self) -> None:\n \"\"\"Apply then flatten nested schema options.\n This method is private API.\n \"\"\"\n if self.only is not None:\n # Apply the only option to nested fields.\n self.__apply_nested_option(\"only\", self.only, \"intersection\")\n # Remove the child field names from the only option.\n self.only = self.set_class([field.split(\".\", 1)[0] for field in self.only])\n if self.exclude:\n # Apply the exclude option to nested fields.\n self.__apply_nested_option(\"exclude\", self.exclude, \"union\")\n # Remove the parent field names from the exclude option.\n self.exclude = self.set_class(\n [field for field in self.exclude if \".\" not in field]\n )\n\n def __apply_nested_option(self, option_name, field_names, set_operation) -> None:\n \"\"\"Apply nested options to nested fields\"\"\"\n # Split nested field names on the first dot.\n nested_fields = [name.split(\".\", 1) for name in field_names if \".\" in name]\n # Partition the nested field names by parent field.\n nested_options = defaultdict(list) # type: defaultdict\n for parent, nested_names 
in nested_fields:\n nested_options[parent].append(nested_names)\n # Apply the nested field options.\n for key, options in iter(nested_options.items()):\n new_options = self.set_class(options)\n original_options = getattr(self.declared_fields[key], option_name, ())\n if original_options:\n if set_operation == \"union\":\n new_options |= self.set_class(original_options)\n if set_operation == \"intersection\":\n new_options &= self.set_class(original_options)\n setattr(self.declared_fields[key], option_name, new_options)\n\n def _init_fields(self) -> None:\n \"\"\"Update self.fields, self.load_fields, and self.dump_fields based on schema options.\n This method is private API.\n \"\"\"\n if self.opts.fields:\n available_field_names = self.set_class(self.opts.fields)\n else:\n available_field_names = self.set_class(self.declared_fields.keys())\n if self.opts.additional:\n available_field_names |= self.set_class(self.opts.additional)\n\n invalid_fields = self.set_class()\n\n if self.only is not None:\n # Return only fields specified in only option\n field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)\n\n invalid_fields |= field_names - available_field_names\n else:\n field_names = available_field_names\n\n # If \"exclude\" option or param is specified, remove those fields.\n if self.exclude:\n # Note that this isn't available_field_names, since we want to\n # apply \"only\" for the actual calculation.\n field_names = field_names - self.exclude\n invalid_fields |= self.exclude - available_field_names\n\n if invalid_fields:\n message = f\"Invalid fields for {self}: {invalid_fields}.\"\n raise ValueError(message)\n\n fields_dict = self.dict_class()\n for field_name in field_names:\n field_obj = self.declared_fields.get(field_name, ma_fields.Inferred())\n self._bind_field(field_name, field_obj)\n fields_dict[field_name] = field_obj\n\n load_fields, dump_fields = self.dict_class(), self.dict_class()\n for field_name, field_obj in fields_dict.items():\n if not field_obj.dump_only:\n load_fields[field_name] = field_obj\n if not field_obj.load_only:\n dump_fields[field_name] = field_obj\n\n dump_data_keys = [\n field_obj.data_key if field_obj.data_key is not None else name\n for name, field_obj in dump_fields.items()\n ]\n if len(dump_data_keys) != len(set(dump_data_keys)):\n data_keys_duplicates = {\n x for x in dump_data_keys if dump_data_keys.count(x) > 1\n }\n raise ValueError(\n \"The data_key argument for one or more fields collides \"\n \"with another field's name or data_key argument. \"\n \"Check the following field names and \"\n f\"data_key arguments: {list(data_keys_duplicates)}\"\n )\n load_attributes = [obj.attribute or name for name, obj in load_fields.items()]\n if len(load_attributes) != len(set(load_attributes)):\n attributes_duplicates = {\n x for x in load_attributes if load_attributes.count(x) > 1\n }\n raise ValueError(\n \"The attribute argument for one or more fields collides \"\n \"with another field's name or attribute argument. 
\"\n \"Check the following field names and \"\n f\"attribute arguments: {list(attributes_duplicates)}\"\n )\n\n self.fields = fields_dict\n self.dump_fields = dump_fields\n self.load_fields = load_fields\n\n def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:\n \"\"\"Hook to modify a field when it is bound to the `Schema`.\n\n No-op by default.\n \"\"\"\n return None\n\n def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:\n \"\"\"Bind field to the schema, setting any necessary attributes on the\n field (e.g. parent and name).\n\n Also set field load_only and dump_only values if field_name was\n specified in ``class Meta``.\n \"\"\"\n if field_name in self.load_only:\n field_obj.load_only = True\n if field_name in self.dump_only:\n field_obj.dump_only = True\n try:\n field_obj._bind_to_schema(field_name, self)\n except TypeError as error:\n # Field declared as a class, not an instance. Ignore type checking because\n # we handle unsupported arg types, i.e. this is dead code from\n # the type checker's perspective.\n if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC):\n msg = (\n f'Field for \"{field_name}\" must be declared as a '\n \"Field instance, not a class. \"\n f'Did you mean \"fields.{field_obj.__name__}()\"?' # type: ignore\n )\n raise TypeError(msg) from error\n raise error\n self.on_bind_field(field_name, field_obj)\n\n @lru_cache(maxsize=8) # noqa (https://github.com/PyCQA/flake8-bugbear/issues/310)\n def _has_processors(self, tag) -> bool:\n return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)])\n\n def _invoke_dump_processors(\n self, tag: str, data, *, many: bool, original_data=None\n ):\n # The pass_many post-dump processors may do things like add an envelope, so\n # invoke those after invoking the non-pass_many processors which will expect\n # to get a list of items.\n data = self._invoke_processors(\n tag, pass_many=False, data=data, many=many, original_data=original_data\n )\n data = self._invoke_processors(\n tag, pass_many=True, data=data, many=many, original_data=original_data\n )\n return data\n\n def _invoke_load_processors(\n self,\n tag: str,\n data,\n *,\n many: bool,\n original_data,\n partial: bool | types.StrSequenceOrSet | None,\n ):\n # This has to invert the order of the dump processors, so run the pass_many\n # processors first.\n data = self._invoke_processors(\n tag,\n pass_many=True,\n data=data,\n many=many,\n original_data=original_data,\n partial=partial,\n )\n data = self._invoke_processors(\n tag,\n pass_many=False,\n data=data,\n many=many,\n original_data=original_data,\n partial=partial,\n )\n return data\n\n def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):\n for attr_name in self._hooks[VALIDATES]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[VALIDATES]\n field_name = validator_kwargs[\"field_name\"]\n\n try:\n field_obj = self.fields[field_name]\n except KeyError as error:\n if field_name in self.declared_fields:\n continue\n raise ValueError(f'\"{field_name}\" field does not exist.') from error\n\n data_key = (\n field_obj.data_key if field_obj.data_key is not None else field_name\n )\n if many:\n for idx, item in enumerate(data):\n try:\n value = item[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n index=(idx if 
self.opts.index_errors else None),\n )\n if validated_value is missing:\n data[idx].pop(field_name, None)\n else:\n try:\n value = data[field_obj.attribute or field_name]\n except KeyError:\n pass\n else:\n validated_value = self._call_and_store(\n getter_func=validator,\n data=value,\n field_name=data_key,\n error_store=error_store,\n )\n if validated_value is missing:\n data.pop(field_name, None)\n\n def _invoke_schema_validators(\n self,\n *,\n error_store: ErrorStore,\n pass_many: bool,\n data,\n original_data,\n many: bool,\n partial: bool | types.StrSequenceOrSet | None,\n field_errors: bool = False,\n ):\n for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]:\n validator = getattr(self, attr_name)\n validator_kwargs = validator.__marshmallow_hook__[\n (VALIDATES_SCHEMA, pass_many)\n ]\n if field_errors and validator_kwargs[\"skip_on_field_errors\"]:\n continue\n pass_original = validator_kwargs.get(\"pass_original\", False)\n\n if many and not pass_many:\n for idx, (item, orig) in enumerate(zip(data, original_data)):\n self._run_validator(\n validator,\n item,\n original_data=orig,\n error_store=error_store,\n many=many,\n partial=partial,\n index=idx,\n pass_original=pass_original,\n )\n else:\n self._run_validator(\n validator,\n data,\n original_data=original_data,\n error_store=error_store,\n many=many,\n pass_original=pass_original,\n partial=partial,\n )\n\n def _invoke_processors(\n self,\n tag: str,\n *,\n pass_many: bool,\n data,\n many: bool,\n original_data=None,\n **kwargs,\n ):\n key = (tag, pass_many)\n for attr_name in self._hooks[key]:\n # This will be a bound method.\n processor = getattr(self, attr_name)\n\n processor_kwargs = processor.__marshmallow_hook__[key]\n pass_original = processor_kwargs.get(\"pass_original\", False)\n\n if many and not pass_many:\n if pass_original:\n data = [\n processor(item, original, many=many, **kwargs)\n for item, original in zip(data, original_data)\n ]\n else:\n data = [processor(item, many=many, **kwargs) for item in data]\n else:\n if pass_original:\n data = processor(data, original_data, many=many, **kwargs)\n else:\n data = processor(data, many=many, **kwargs)\n return data\n\n\nBaseSchema = Schema # for backwards compatibility\n\n# Path: src/marshmallow/validate.py\n\"\"\"Validation classes for various types of data.\"\"\"\nfrom __future__ import annotations\n\nimport re\nimport typing\nfrom abc import ABC, abstractmethod\nfrom itertools import zip_longest\nfrom operator import attrgetter\n\nfrom marshmallow import types\nfrom marshmallow.exceptions import ValidationError\n\n_T = typing.TypeVar(\"_T\")\n\n\nclass Validator(ABC):\n \"\"\"Abstract base class for validators.\n\n .. note::\n This class does not provide any validation behavior. It is only used to\n add a useful `__repr__` implementation for validators.\n \"\"\"\n\n error = None # type: str | None\n\n def __repr__(self) -> str:\n args = self._repr_args()\n args = f\"{args}, \" if args else \"\"\n\n return f\"<{self.__class__.__name__}({args}error={self.error!r})>\"\n\n def _repr_args(self) -> str:\n \"\"\"A string representation of the args passed to this validator. 
Used by\n `__repr__`.\n \"\"\"\n return \"\"\n\n @abstractmethod\n def __call__(self, value: typing.Any) -> typing.Any:\n ...\n\n\nclass And(Validator):\n \"\"\"Compose multiple validators and combine their error messages.\n\n Example: ::\n\n from marshmallow import validate, ValidationError\n\n def is_even(value):\n if value % 2 != 0:\n raise ValidationError(\"Not an even value.\")\n\n validator = validate.And(validate.Range(min=0), is_even)\n validator(-1)\n # ValidationError: ['Must be greater than or equal to 0.', 'Not an even value.']\n\n :param validators: Validators to combine.\n :param error: Error message to use when a validator returns ``False``.\n \"\"\"\n\n default_error_message = \"Invalid value.\"\n\n def __init__(self, *validators: types.Validator, error: str | None = None):\n self.validators = tuple(validators)\n self.error = error or self.default_error_message # type: str\n\n def _repr_args(self) -> str:\n return f\"validators={self.validators!r}\"\n\n def __call__(self, value: typing.Any) -> typing.Any:\n errors = []\n kwargs = {}\n for validator in self.validators:\n try:\n r = validator(value)\n if not isinstance(validator, Validator) and r is False:\n raise ValidationError(self.error)\n except ValidationError as err:\n kwargs.update(err.kwargs)\n if isinstance(err.messages, dict):\n errors.append(err.messages)\n else:\n # FIXME : Get rid of cast\n errors.extend(typing.cast(list, err.messages))\n if errors:\n raise ValidationError(errors, **kwargs)\n return value\n\n\nclass URL(Validator):\n \"\"\"Validate a URL.\n\n :param relative: Whether to allow relative URLs.\n :param absolute: Whether to allow absolute URLs.\n :param error: Error message to raise in case of a validation error.\n Can be interpolated with `{input}`.\n :param schemes: Valid schemes. By default, ``http``, ``https``,\n ``ftp``, and ``ftps`` are allowed.\n :param require_tld: Whether to reject non-FQDN hostnames.\n \"\"\"\n\n class RegexMemoizer:\n def __init__(self):\n self._memoized = {}\n\n def _regex_generator(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n hostname_variants = [\n # a normal domain name, expressed in [A-Z0-9] chars with hyphens allowed only in the middle\n # note that the regex will be compiled with IGNORECASE, so these are upper and lowercase chars\n (\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}\\.?|[A-Z0-9-]{2,}\\.?)\"\n ),\n # or the special string 'localhost'\n r\"localhost\",\n # or IPv4\n r\"\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\",\n # or IPv6\n r\"\\[[A-F0-9]*:[A-F0-9:]+\\]\",\n ]\n if not require_tld:\n # allow dotless hostnames\n hostname_variants.append(r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.?)\")\n\n absolute_part = \"\".join(\n (\n # scheme (e.g. 'https://', 'ftp://', etc)\n # this is validated separately against allowed schemes, so in the regex\n # we simply want to capture its existence\n r\"(?:[a-z0-9\\.\\-\\+]*)://\",\n # userinfo, for URLs encoding authentication\n # e.g. 'ftp://foo:bar@ftp.example.org/'\n r\"(?:(?:[a-z0-9\\-._~!$&'()*+,;=:]|%[0-9a-f]{2})*@)?\",\n # netloc, the hostname/domain part of the URL plus the optional port\n r\"(?:\",\n \"|\".join(hostname_variants),\n r\")\",\n r\"(?::\\d+)?\",\n )\n )\n relative_part = r\"(?:/?|[/?]\\S+)\\Z\"\n\n if relative:\n if absolute:\n parts: tuple[str, ...] 
= (\n r\"^(\",\n absolute_part,\n r\")?\",\n relative_part,\n )\n else:\n parts = (r\"^\", relative_part)\n else:\n parts = (r\"^\", absolute_part, relative_part)\n\n return re.compile(\"\".join(parts), re.IGNORECASE)\n\n def __call__(\n self, relative: bool, absolute: bool, require_tld: bool\n ) -> typing.Pattern:\n key = (relative, absolute, require_tld)\n if key not in self._memoized:\n self._memoized[key] = self._regex_generator(\n relative, absolute, require_tld\n )\n\n return self._memoized[key]\n\n _regex = RegexMemoizer()\n\n default_message = \"Not a valid URL.\"\n default_schemes = {\"http\", \"https\", \"ftp\", \"ftps\"}\n\n def __init__(\n self,\n *,\n relative: bool = False,\n absolute: bool = True,\n schemes: types.StrSequenceOrSet | None = None,\n require_tld: bool = True,\n error: str | None = None,\n ):\n if not relative and not absolute:\n raise ValueError(\n \"URL validation cannot set both relative and absolute to False.\"\n )\n self.relative = relative\n self.absolute = absolute\n self.error = error or self.default_message # type: str\n self.schemes = schemes or self.default_schemes\n self.require_tld = require_tld\n\n def _repr_args(self) -> str:\n return f\"relative={self.relative!r}, absolute={self.absolute!r}\"\n\n def _format_error(self, value) -> str:\n return self.error.format(input=value)\n\n def __call__(self, value: str) -> str:\n message = self._format_error(value)\n if not value:\n raise ValidationError(message)\n\n # Check first if the scheme is valid\n if \"://\" in value:\n scheme = value.split(\"://\")[0].lower()\n if scheme not in self.schemes:\n raise ValidationError(message)\n\n regex = self._regex(self.relative, self.absolute, self.require_tld)\n\n if not regex.search(value):\n raise ValidationError(message)\n\n return value\n\n\nclass Email(Validator):\n \"\"\"Validate an email address.\n\n :param error: Error message to raise in case of a validation error. Can be\n interpolated with `{input}`.\n \"\"\"\n\n USER_REGEX = re.compile(\n r\"(^[-!#$%&'*+/=?^`{}|~\\w]+(\\.[-!#$%&'*+/=?^`{}|~\\w]+)*\\Z\" # dot-atom\n # quoted-string\n r'|^\"([\\001-\\010\\013\\014\\016-\\037!#-\\[\\]-\\177]'\n r'|\\\\[\\001-\\011\\013\\014\\016-\\177])*\"\\Z)',\n re.IGNORECASE | re.UNICODE,\n )\n\n DOMAIN_REGEX = re.compile(\n # domain\n r\"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+\"\n r\"(?:[A-Z]{2,6}|[A-Z0-9-]{2,})\\Z\"\n # literal form, ipv4 address (SMTP 4.1.3)\n r\"|^\\[(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)\"\n...\n", "needle_token_start": 8818, "needle_token_end": 8979, "code_context_ntokens": 16170, "output": ["```python\ndef _run_validator(\n self,\n validator_func,\n output,\n *,\n original_data,\n error_store,\n many,\n partial,\n pass_original,\n index=None,\n ):\n try:\n if pass_original: # Pass original, raw data (before unmarshalling)\n validator_func(output, original_data, partial=partial, many=many)\n else:\n validator_func(output, partial=partial, many=many)\n except ValidationError as err:\n error_store.store_error(err.messages, err.field_name, index=index)\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "from_iso_time", "language": "python", "path": "src/marshmallow/utils.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To convert a string representing a time in ISO8601 format into a time object.\n2. **Input**: A string formatted according to ISO8601 standards, representing a time without timezone information.\n3. **Output**: A time object representing the time specified in the input string.\n4. 
**Procedure**: The function first matches the input string against a regular expression designed to parse ISO8601 time strings. If the string does not match the pattern, it raises a ValueError indicating the string is not valid. If the string matches, it extracts components such as hour, minute, second, and microsecond from the string. These extracted values are then converted into integers (with microseconds adjusted for precision) and used to construct a time object, which is then returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/decorators.py\n\"\"\"Decorators for registering schema pre-processing and post-processing methods.\nThese should be imported from the top-level `marshmallow` module.\n\nMethods decorated with\n`pre_load `, `post_load `,\n`pre_dump `, `post_dump `,\nand `validates_schema ` receive\n``many`` as a keyword argument. In addition, `pre_load `,\n`post_load `,\nand `validates_schema ` receive\n``partial``. If you don't need these arguments, add ``**kwargs`` to your method\nsignature.\n\n\nExample: ::\n\n from marshmallow import (\n Schema, pre_load, pre_dump, post_load, validates_schema,\n validates, fields, ValidationError\n )\n\n class UserSchema(Schema):\n\n email = fields.Str(required=True)\n age = fields.Integer(required=True)\n\n @post_load\n def lowerstrip_email(self, item, many, **kwargs):\n item['email'] = item['email'].lower().strip()\n return item\n\n @pre_load(pass_many=True)\n def remove_envelope(self, data, many, **kwargs):\n namespace = 'results' if many else 'result'\n return data[namespace]\n\n @post_dump(pass_many=True)\n def add_envelope(self, data, many, **kwargs):\n namespace = 'results' if many else 'result'\n return {namespace: data}\n\n @validates_schema\n def validate_email(self, data, **kwargs):\n if len(data['email']) < 3:\n raise ValidationError('Email must be more than 3 characters', 'email')\n\n @validates('age')\n def validate_age(self, data, **kwargs):\n if data < 14:\n raise ValidationError('Too young!')\n\n.. note::\n These decorators only work with instance methods. Class and static\n methods are not supported.\n\n.. 
warning::\n The invocation order of decorated methods of the same type is not guaranteed.\n If you need to guarantee order of different processing steps, you should put\n them in the same processing method.\n\"\"\"\nfrom __future__ import annotations\n\nimport functools\nfrom typing import Any, Callable, cast\n\nPRE_DUMP = \"pre_dump\"\nPOST_DUMP = \"post_dump\"\nPRE_LOAD = \"pre_load\"\nPOST_LOAD = \"post_load\"\nVALIDATES = \"validates\"\nVALIDATES_SCHEMA = \"validates_schema\"\n\n\nclass MarshmallowHook:\n __marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None\n\n\ndef validates(field_name: str) -> Callable[..., Any]:\n \"\"\"Register a field validator.\n\n :param str field_name: Name of the field that the method validates.\n \"\"\"\n return set_hook(None, VALIDATES, field_name=field_name)\n\n\ndef validates_schema(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n skip_on_field_errors: bool = True,\n) -> Callable[..., Any]:\n \"\"\"Register a schema-level validator.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.validate` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before unmarshalling) will be passed as\n an additional argument to the method.\n\n If ``skip_on_field_errors=True``, this validation method will be skipped whenever\n validation errors have been detected when validating fields.\n\n .. versionchanged:: 3.0.0b1\n ``skip_on_field_errors`` defaults to `True`.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(\n fn,\n (VALIDATES_SCHEMA, pass_many),\n pass_original=pass_original,\n skip_on_field_errors=skip_on_field_errors,\n )\n\n\ndef pre_dump(\n fn: Callable[..., Any] | None = None, pass_many: bool = False\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke before serializing an object. The method\n receives the object to be serialized and returns the processed object.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.dump` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n .. versionchanged:: 3.0.0\n ``many`` is always passed as a keyword arguments to the decorated method.\n \"\"\"\n return set_hook(fn, (PRE_DUMP, pass_many))\n\n\ndef post_dump(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke after serializing an object. The method\n receives the serialized object and returns the processed object.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.dump` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before serializing) will be passed as\n an additional argument to the method.\n\n .. 
versionchanged:: 3.0.0\n ``many`` is always passed as a keyword arguments to the decorated method.\n \"\"\"\n return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)\n\n\ndef pre_load(\n fn: Callable[..., Any] | None = None, pass_many: bool = False\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke before deserializing an object. The method\n receives the data to be deserialized and returns the processed data.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.load` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(fn, (PRE_LOAD, pass_many))\n\n\ndef post_load(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke after deserializing an object. The method\n receives the deserialized data and returns the processed data.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.load` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before deserializing) will be passed as\n an additional argument to the method.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)\n\n\ndef set_hook(\n fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any\n) -> Callable[..., Any]:\n \"\"\"Mark decorated function as a hook to be picked up later.\n You should not need to use this method directly.\n\n .. note::\n Currently only works with functions and instance methods. Class and\n static methods are not supported.\n\n :return: Decorated function if supplied, else this decorator with its args\n bound.\n \"\"\"\n # Allow using this as either a decorator or a decorator factory.\n if fn is None:\n return functools.partial(set_hook, key=key, **kwargs)\n\n # Set a __marshmallow_hook__ attribute instead of wrapping in some class,\n # because I still want this to end up as a normal (unbound) method.\n function = cast(MarshmallowHook, fn)\n try:\n hook_config = function.__marshmallow_hook__\n except AttributeError:\n function.__marshmallow_hook__ = hook_config = {}\n # Also save the kwargs for the tagged function on\n # __marshmallow_hook__, keyed by (, )\n if hook_config is not None:\n hook_config[key] = kwargs\n\n return fn\n\n# Path: src/marshmallow/exceptions.py\n\"\"\"Exception classes for marshmallow-related errors.\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\n# Key used for schema-level validation errors\nSCHEMA = \"_schema\"\n\n\nclass MarshmallowError(Exception):\n \"\"\"Base class for all marshmallow-related errors.\"\"\"\n\n\nclass ValidationError(MarshmallowError):\n \"\"\"Raised when validation fails on a field or schema.\n\n Validators and custom fields should raise this exception.\n\n :param message: An error message, list of error messages, or dict of\n error messages. 
If a dict, the keys are subitems and the values are error messages.\n :param field_name: Field name to store the error on.\n If `None`, the error is stored as schema-level error.\n :param data: Raw input data.\n :param valid_data: Valid (de)serialized data.\n \"\"\"\n\n def __init__(\n self,\n message: str | list | dict,\n field_name: str = SCHEMA,\n data: typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n | None = None,\n valid_data: list[dict[str, typing.Any]] | dict[str, typing.Any] | None = None,\n **kwargs,\n ):\n self.messages = [message] if isinstance(message, (str, bytes)) else message\n self.field_name = field_name\n self.data = data\n self.valid_data = valid_data\n self.kwargs = kwargs\n super().__init__(message)\n\n def normalized_messages(self):\n if self.field_name == SCHEMA and isinstance(self.messages, dict):\n return self.messages\n return {self.field_name: self.messages}\n\n @property\n def messages_dict(self) -> dict[str, typing.Any]:\n if not isinstance(self.messages, dict):\n raise TypeError(\n \"cannot access 'messages_dict' when 'messages' is of type \"\n + type(self.messages).__name__\n )\n return self.messages\n\n\nclass RegistryError(NameError):\n \"\"\"Raised when an invalid operation is performed on the serializer\n class registry.\n \"\"\"\n\n\nclass StringNotCollectionError(MarshmallowError, TypeError):\n \"\"\"Raised when a string is passed when a list of strings is expected.\"\"\"\n\n\nclass FieldInstanceResolutionError(MarshmallowError, TypeError):\n \"\"\"Raised when schema to instantiate is neither a Schema class nor an instance.\"\"\"\n\n# Path: src/marshmallow/base.py\n\"\"\"Abstract base classes.\n\nThese are necessary to avoid circular imports between schema.py and fields.py.\n\n.. warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\n\n\nclass FieldABC(ABC):\n \"\"\"Abstract base class from which all Field classes inherit.\"\"\"\n\n parent = None\n name = None\n root = None\n\n @abstractmethod\n def serialize(self, attr, obj, accessor=None):\n pass\n\n @abstractmethod\n def deserialize(self, value):\n pass\n\n @abstractmethod\n def _serialize(self, value, attr, obj, **kwargs):\n pass\n\n @abstractmethod\n def _deserialize(self, value, attr, data, **kwargs):\n pass\n\n\nclass SchemaABC(ABC):\n \"\"\"Abstract base class from which all Schemas inherit.\"\"\"\n\n @abstractmethod\n def dump(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def dumps(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def load(self, data, *, many: bool | None = None, partial=None, unknown=None):\n pass\n\n @abstractmethod\n def loads(\n self,\n json_data,\n *,\n many: bool | None = None,\n partial=None,\n unknown=None,\n **kwargs,\n ):\n pass\n\n# Path: src/marshmallow/class_registry.py\n\"\"\"A registry of :class:`Schema ` classes. This allows for string\nlookup of schemas, which may be used with\nclass:`fields.Nested `.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\nfrom marshmallow.exceptions import RegistryError\n\nif typing.TYPE_CHECKING:\n from marshmallow import Schema\n\n SchemaType = typing.Type[Schema]\n\n# {\n# : \n# : \n# }\n_registry = {} # type: dict[str, list[SchemaType]]\n\n\ndef register(classname: str, cls: SchemaType) -> None:\n \"\"\"Add a class to the registry of serializer classes. When a class is\n registered, an entry for both its classname and its full, module-qualified\n path are added to the registry.\n\n Example: ::\n\n class MyClass:\n pass\n\n register('MyClass', MyClass)\n # Registry:\n # {\n # 'MyClass': [path.to.MyClass],\n # 'path.to.MyClass': [path.to.MyClass],\n # }\n\n \"\"\"\n # Module where the class is located\n module = cls.__module__\n # Full module path to the class\n # e.g. user.schemas.UserSchema\n fullpath = \".\".join([module, classname])\n # If the class is already registered; need to check if the entries are\n # in the same module as cls to avoid having multiple instances of the same\n # class in the registry\n if classname in _registry and not any(\n each.__module__ == module for each in _registry[classname]\n ):\n _registry[classname].append(cls)\n elif classname not in _registry:\n _registry[classname] = [cls]\n\n # Also register the full path\n if fullpath not in _registry:\n _registry.setdefault(fullpath, []).append(cls)\n else:\n # If fullpath does exist, replace existing entry\n _registry[fullpath] = [cls]\n return None\n\n\ndef get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:\n \"\"\"Retrieve a class from the registry.\n\n :raises: marshmallow.exceptions.RegistryError if the class cannot be found\n or if there are multiple entries for the given class name.\n \"\"\"\n try:\n classes = _registry[classname]\n except KeyError as error:\n raise RegistryError(\n f\"Class with name {classname!r} was not found. You may need \"\n \"to import the class.\"\n ) from error\n if len(classes) > 1:\n if all:\n return _registry[classname]\n raise RegistryError(\n f\"Multiple classes with name {classname!r} \"\n \"were found. Please use the full, \"\n \"module-qualified path.\"\n )\n else:\n return _registry[classname][0]\n\n# Path: src/marshmallow/error_store.py\n\"\"\"Utilities for storing collections of error messages.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\n\nfrom marshmallow.exceptions import SCHEMA\n\n\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n def store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n return errors2\n if not errors2:\n return errors1\n if isinstance(errors1, list):\n if isinstance(errors2, list):\n return errors1 + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return errors1 + [errors2]\n if isinstance(errors1, dict):\n if isinstance(errors2, list):\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, dict):\n errors = dict(errors1)\n for key, val in errors2.items():\n if key in errors:\n errors[key] = merge_errors(errors[key], val)\n else:\n errors[key] = val\n return errors\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, list):\n return [errors1] + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return [errors1, errors2]\n\n# Path: src/marshmallow/orderedset.py\n# OrderedSet\n# Copyright (c) 2009 Raymond Hettinger\n#\n# Permission is hereby granted, free of charge, to any person\n# obtaining a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES\n# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT\n# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,\n# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR\n# OTHER DEALINGS IN THE SOFTWARE.\nfrom collections.abc import MutableSet\n\n\nclass OrderedSet(MutableSet):\n def __init__(self, iterable=None):\n self.end = end = []\n end += [None, end, end] # sentinel node for doubly linked list\n self.map = {} # key --> [key, prev, next]\n if iterable is not None:\n self |= iterable\n\n def __len__(self):\n return len(self.map)\n\n def __contains__(self, key):\n return key in self.map\n\n def add(self, key):\n if key not in self.map:\n end = self.end\n curr = end[1]\n curr[2] = end[1] = self.map[key] = [key, curr, end]\n\n def discard(self, key):\n if key in self.map:\n key, prev, next = self.map.pop(key)\n prev[2] = next\n next[1] = prev\n\n def __iter__(self):\n end = self.end\n curr = end[2]\n while curr is not end:\n yield curr[0]\n curr = curr[2]\n\n def __reversed__(self):\n end = self.end\n curr = end[1]\n while curr is not end:\n yield curr[0]\n curr = curr[1]\n\n def pop(self, last=True):\n if not self:\n raise KeyError(\"set is empty\")\n key = self.end[1][0] if last else self.end[2][0]\n self.discard(key)\n return key\n\n def __repr__(self):\n if not self:\n return f\"{self.__class__.__name__}()\"\n return f\"{self.__class__.__name__}({list(self)!r})\"\n\n def __eq__(self, other):\n if isinstance(other, OrderedSet):\n return len(self) == len(other) and list(self) == list(other)\n return set(self) == set(other)\n\n\nif __name__ == \"__main__\":\n s = OrderedSet(\"abracadaba\")\n t = OrderedSet(\"simsalabim\")\n print(s | t)\n print(s & t)\n print(s - t)\n\n# Path: src/marshmallow/types.py\n\"\"\"Type aliases.\n\n.. warning::\n\n This module is provisional. Types may be modified, added, and removed between minor releases.\n\"\"\"\nimport typing\n\nStrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]\nTag = typing.Union[str, typing.Tuple[str, bool]]\nValidator = typing.Callable[[typing.Any], typing.Any]\n\n# Path: src/marshmallow/warnings.py\nclass RemovedInMarshmallow4Warning(DeprecationWarning):\n pass\n\n# Path: src/marshmallow/utils.py\n\"\"\"Utility methods for marshmallow.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport datetime as dt\nimport functools\nimport inspect\nimport json\nimport re\nimport typing\nimport warnings\nfrom collections.abc import Mapping\nfrom email.utils import format_datetime, parsedate_to_datetime\nfrom pprint import pprint as py_pprint\n\nfrom marshmallow.base import FieldABC\nfrom marshmallow.exceptions import FieldInstanceResolutionError\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\nEXCLUDE = \"exclude\"\nINCLUDE = \"include\"\nRAISE = \"raise\"\n_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}\n\n\nclass _Missing:\n def __bool__(self):\n return False\n\n def __copy__(self):\n return self\n\n def __deepcopy__(self, _):\n return self\n\n def __repr__(self):\n return \"\"\n\n\n# Singleton value that indicates that a field's value is missing from input\n# dict passed to :meth:`Schema.load`. 
If the field's value is not required,\n# it's ``default`` value is used.\nmissing = _Missing()\n\n\ndef is_generator(obj) -> bool:\n \"\"\"Return True if ``obj`` is a generator\"\"\"\n return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)\n\n\ndef is_iterable_but_not_string(obj) -> bool:\n \"\"\"Return True if ``obj`` is an iterable object that isn't a string.\"\"\"\n return (hasattr(obj, \"__iter__\") and not hasattr(obj, \"strip\")) or is_generator(obj)\n\n\ndef is_collection(obj) -> bool:\n \"\"\"Return True if ``obj`` is a collection type, e.g list, tuple, queryset.\"\"\"\n return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)\n\n\ndef is_instance_or_subclass(val, class_) -> bool:\n \"\"\"Return True if ``val`` is either a subclass or instance of ``class_``.\"\"\"\n try:\n return issubclass(val, class_)\n except TypeError:\n return isinstance(val, class_)\n\n\ndef is_keyed_tuple(obj) -> bool:\n \"\"\"Return True if ``obj`` has keyed tuple behavior, such as\n namedtuples or SQLAlchemy's KeyedTuples.\n \"\"\"\n return isinstance(obj, tuple) and hasattr(obj, \"_fields\")\n\n\ndef pprint(obj, *args, **kwargs) -> None:\n \"\"\"Pretty-printing function that can pretty-print OrderedDicts\n like regular dictionaries. Useful for printing the output of\n :meth:`marshmallow.Schema.dump`.\n\n .. deprecated:: 3.7.0\n marshmallow.pprint will be removed in marshmallow 4.\n \"\"\"\n warnings.warn(\n \"marshmallow's pprint function is deprecated and will be removed in marshmallow 4.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if isinstance(obj, collections.OrderedDict):\n print(json.dumps(obj, *args, **kwargs))\n else:\n py_pprint(obj, *args, **kwargs)\n\n\n# https://stackoverflow.com/a/27596917\ndef is_aware(datetime: dt.datetime) -> bool:\n return (\n datetime.tzinfo is not None and datetime.tzinfo.utcoffset(datetime) is not None\n )\n\n\ndef from_rfc(datestring: str) -> dt.datetime:\n \"\"\"Parse a RFC822-formatted datetime string and return a datetime object.\n\n https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950\n \"\"\"\n return parsedate_to_datetime(datestring)\n\n\ndef rfcformat(datetime: dt.datetime) -> str:\n \"\"\"Return the RFC822-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return format_datetime(datetime)\n\n\n# Hat tip to Django for ISO8601 deserialization functions\n\n_iso8601_datetime_re = re.compile(\n r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})\"\n r\"[T ](?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n r\"(?PZ|[+-]\\d{2}(?::?\\d{2})?)?$\"\n)\n\n_iso8601_date_re = re.compile(r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})$\")\n\n_iso8601_time_re = re.compile(\n r\"(?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n)\n\n\ndef get_fixed_timezone(offset: int | float | dt.timedelta) -> dt.timezone:\n \"\"\"Return a tzinfo instance with a fixed offset from UTC.\"\"\"\n if isinstance(offset, dt.timedelta):\n offset = offset.total_seconds() // 60\n sign = \"-\" if offset < 0 else \"+\"\n hhmm = \"%02d%02d\" % divmod(abs(offset), 60)\n name = sign + hhmm\n return dt.timezone(dt.timedelta(minutes=offset), name)\n\n\ndef from_iso_datetime(value):\n \"\"\"Parse a string and return a datetime.datetime.\n\n This function supports time zone offsets. 
When the input contains one,\n the output uses a timezone with a fixed offset from UTC.\n \"\"\"\n match = _iso8601_datetime_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted datetime string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n tzinfo = kw.pop(\"tzinfo\")\n if tzinfo == \"Z\":\n tzinfo = dt.timezone.utc\n elif tzinfo is not None:\n offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0\n offset = 60 * int(tzinfo[1:3]) + offset_mins\n if tzinfo[0] == \"-\":\n offset = -offset\n tzinfo = get_fixed_timezone(offset)\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n kw[\"tzinfo\"] = tzinfo\n return dt.datetime(**kw)\n\n\n\ndef from_iso_time(value):\n \"\"\"Parse a string and return a datetime.time.\n\n This function doesn't support time zone offsets.\n \"\"\"\n match = _iso8601_time_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted time string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n return dt.time(**kw)\n\n\ndef from_iso_date(value):\n \"\"\"Parse a string and return a datetime.date.\"\"\"\n match = _iso8601_date_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted date string\")\n kw = {k: int(v) for k, v in match.groupdict().items()}\n return dt.date(**kw)\n\n\ndef from_timestamp(value: typing.Any) -> dt.datetime:\n value = float(value)\n if value < 0:\n raise ValueError(\"Not a valid POSIX timestamp\")\n\n # Load a timestamp with utc as timezone to prevent using system timezone.\n # Then set timezone to None, to let the Field handle adding timezone info.\n try:\n return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)\n except OverflowError as exc:\n raise ValueError(\"Timestamp is too large\") from exc\n except OSError as exc:\n raise ValueError(\"Error converting value to datetime\") from exc\n\n\ndef from_timestamp_ms(value: typing.Any) -> dt.datetime:\n value = float(value)\n return from_timestamp(value / 1000)\n\n\ndef timestamp(\n value: dt.datetime,\n) -> float:\n if not is_aware(value):\n # When a date is naive, use UTC as zone info to prevent using system timezone.\n value = value.replace(tzinfo=dt.timezone.utc)\n return value.timestamp()\n\n\ndef timestamp_ms(value: dt.datetime) -> float:\n return timestamp(value) * 1000\n\n\ndef isoformat(datetime: dt.datetime) -> str:\n \"\"\"Return the ISO8601-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return datetime.isoformat()\n\n\ndef to_iso_time(time: dt.time) -> str:\n return dt.time.isoformat(time)\n\n\ndef to_iso_date(date: dt.date) -> str:\n return dt.date.isoformat(date)\n\n\ndef ensure_text_type(val: str | bytes) -> str:\n if isinstance(val, bytes):\n val = val.decode(\"utf-8\")\n return str(val)\n\n\ndef pluck(dictlist: list[dict[str, typing.Any]], key: str):\n \"\"\"Extracts a list of dictionary values from a list of dictionaries.\n ::\n\n >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]\n >>> pluck(dlist, 'id')\n [1, 2]\n \"\"\"\n return [d[key] for d in dictlist]\n\n\n# Various utilities for pulling keyed values from objects\n\n\ndef get_value(obj, key: int | str, default=missing):\n \"\"\"Helper for pulling a keyed value off various types of objects. 
Fields use\n this method by default to access attributes of the source object. For object `x`\n and attribute `i`, this method first tries to access `x[i]`, and then falls back to\n `x.i` if an exception is raised.\n\n .. warning::\n If an object `x` does not raise an exception when `x[i]` does not exist,\n `get_value` will never check the value `x.i`. Consider overriding\n `marshmallow.fields.Field.get_value` in this case.\n \"\"\"\n if not isinstance(key, int) and \".\" in key:\n return _get_value_for_keys(obj, key.split(\".\"), default)\n else:\n return _get_value_for_key(obj, key, default)\n\n\ndef _get_value_for_keys(obj, keys, default):\n if len(keys) == 1:\n return _get_value_for_key(obj, keys[0], default)\n else:\n return _get_value_for_keys(\n _get_value_for_key(obj, keys[0], default), keys[1:], default\n )\n\n\ndef _get_value_for_key(obj, key, default):\n if not hasattr(obj, \"__getitem__\"):\n return getattr(obj, key, default)\n\n try:\n return obj[key]\n except (KeyError, IndexError, TypeError, AttributeError):\n return getattr(obj, key, default)\n\n\ndef set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):\n \"\"\"Set a value in a dict. If `key` contains a '.', it is assumed\n be a path (i.e. dot-delimited string) to the value's location.\n\n ::\n\n >>> d = {}\n >>> set_value(d, 'foo.bar', 42)\n >>> d\n {'foo': {'bar': 42}}\n \"\"\"\n if \".\" in key:\n head, rest = key.split(\".\", 1)\n target = dct.setdefault(head, {})\n if not isinstance(target, dict):\n raise ValueError(\n f\"Cannot set {key} in {head} \" f\"due to existing value: {target}\"\n )\n set_value(target, rest, value)\n else:\n dct[key] = value\n\n\ndef callable_or_raise(obj):\n \"\"\"Check that an object is callable, else raise a :exc:`TypeError`.\"\"\"\n if not callable(obj):\n raise TypeError(f\"Object {obj!r} is not callable.\")\n return obj\n\n\ndef _signature(func: typing.Callable) -> list[str]:\n return list(inspect.signature(func).parameters.keys())\n\n\ndef get_func_args(func: typing.Callable) -> list[str]:\n \"\"\"Given a callable, return a list of argument names. Handles\n `functools.partial` objects and class-based callables.\n\n .. versionchanged:: 3.0.0a1\n Do not return bound arguments, eg. 
``self``.\n \"\"\"\n if inspect.isfunction(func) or inspect.ismethod(func):\n return _signature(func)\n if isinstance(func, functools.partial):\n return _signature(func.func)\n # Callable class\n return _signature(func)\n\n\ndef resolve_field_instance(cls_or_instance):\n \"\"\"Return a Schema instance from a Schema class or instance.\n\n :param type|Schema cls_or_instance: Marshmallow Schema class or instance.\n \"\"\"\n if isinstance(cls_or_instance, type):\n if not issubclass(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance()\n else:\n if not isinstance(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance\n\n\ndef timedelta_to_microseconds(value: dt.timedelta) -> int:\n \"\"\"Compute the total microseconds of a timedelta\n\n https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950\n \"\"\"\n return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds\n\n\ndef validate_unknown_parameter_value(obj: typing.Any) -> str:\n if obj not in _UNKNOWN_VALUES:\n raise ValueError(\n f\"Object {obj!r} is not a valid value for the 'unknown' parameter\"\n )\n return obj\n\n# Path: src/marshmallow/schema.py\n\"\"\"The :class:`Schema` class, including its metaclass and options (class Meta).\"\"\"\nfrom __future__ import annotations\n\nimport copy\nimport datetime as dt\nimport decimal\nimport inspect\nimport json\nimport typing\nimport uuid\nimport warnings\nfrom abc import ABCMeta\nfrom collections import OrderedDict, defaultdict\nfrom collections.abc import Mapping\nfrom functools import lru_cache\n\nfrom marshmallow import base, class_registry, types\nfrom marshmallow import fields as ma_fields\nfrom marshmallow.decorators import (\n POST_DUMP,\n POST_LOAD,\n PRE_DUMP,\n PRE_LOAD,\n VALIDATES,\n VALIDATES_SCHEMA,\n)\nfrom marshmallow.error_store import ErrorStore\nfrom marshmallow.exceptions import StringNotCollectionError, ValidationError\nfrom marshmallow.orderedset import OrderedSet\nfrom marshmallow.utils import (\n EXCLUDE,\n INCLUDE,\n RAISE,\n get_value,\n is_collection,\n is_instance_or_subclass,\n missing,\n set_value,\n validate_unknown_parameter_value,\n)\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n_T = typing.TypeVar(\"_T\")\n\n\ndef _get_fields(attrs):\n \"\"\"Get fields from a class\n\n :param attrs: Mapping of class attributes\n \"\"\"\n return [\n (field_name, field_value)\n for field_name, field_value in attrs.items()\n if is_instance_or_subclass(field_value, base.FieldABC)\n ]\n\n\n# This function allows Schemas to inherit from non-Schema classes and ensures\n# inheritance according to the MRO\ndef _get_fields_by_mro(klass):\n \"\"\"Collect fields from a class, following its method resolution order. The\n class itself is excluded from the search; only its parents are checked. Get\n fields from ``_declared_fields`` if available, else use ``__dict__``.\n\n :param type klass: Class whose fields to retrieve\n \"\"\"\n mro = inspect.getmro(klass)\n # Loop over mro in reverse to maintain correct order of fields\n return sum(\n (\n _get_fields(\n getattr(base, \"_declared_fields\", base.__dict__),\n )\n for base in mro[:0:-1]\n ),\n [],\n )\n\n\nclass SchemaMeta(ABCMeta):\n \"\"\"Metaclass for the Schema class. Binds the declared fields to\n a ``_declared_fields`` attribute, which is a dictionary mapping attribute\n names to field objects. 
Also sets the ``opts`` class attribute, which is\n the Schema class's ``class Meta`` options.\n \"\"\"\n\n def __new__(mcs, name, bases, attrs):\n meta = attrs.get(\"Meta\")\n ordered = getattr(meta, \"ordered\", False)\n if not ordered:\n # Inherit 'ordered' option\n # Warning: We loop through bases instead of MRO because we don't\n # yet have access to the class object\n # (i.e. can't call super before we have fields)\n for base_ in bases:\n if hasattr(base_, \"Meta\") and hasattr(base_.Meta, \"ordered\"):\n ordered = base_.Meta.ordered\n break\n else:\n ordered = False\n cls_fields = _get_fields(attrs)\n # Remove fields from list of class attributes to avoid shadowing\n # Schema attributes/methods in case of name conflict\n for field_name, _ in cls_fields:\n del attrs[field_name]\n klass = super().__new__(mcs, name, bases, attrs)\n inherited_fields = _get_fields_by_mro(klass)\n\n meta = klass.Meta\n # Set klass.opts in __new__ rather than __init__ so that it is accessible in\n # get_declared_fields\n klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)\n # Add fields specified in the `include` class Meta option\n cls_fields += list(klass.opts.include.items())\n\n # Assign _declared_fields on class\n klass._declared_fields = mcs.get_declared_fields(\n klass=klass,\n cls_fields=cls_fields,\n inherited_fields=inherited_fields,\n dict_cls=dict,\n )\n return klass\n\n @classmethod\n def get_declared_fields(\n mcs,\n klass: type,\n cls_fields: list,\n inherited_fields: list,\n dict_cls: type = dict,\n ):\n \"\"\"Returns a dictionary of field_name => `Field` pairs declared on the class.\n This is exposed mainly so that plugins can add additional fields, e.g. fields\n computed from class Meta options.\n\n :param klass: The class object.\n :param cls_fields: The fields declared on the class, including those added\n by the ``include`` class Meta option.\n :param inherited_fields: Inherited fields.\n :param dict_cls: dict-like class to use for dict output Default to ``dict``.\n \"\"\"\n return dict_cls(inherited_fields + cls_fields)\n\n def __init__(cls, name, bases, attrs):\n super().__init__(name, bases, attrs)\n if name and cls.opts.register:\n class_registry.register(name, cls)\n cls._hooks = cls.resolve_hooks()\n\n def resolve_hooks(cls) -> dict[types.Tag, list[str]]:\n \"\"\"Add in the decorated processors\n\n By doing this after constructing the class, we let standard inheritance\n do all the hard work.\n \"\"\"\n mro = inspect.getmro(cls)\n\n hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]\n\n for attr_name in dir(cls):\n # Need to look up the actual descriptor, not whatever might be\n # bound to the class. This needs to come from the __dict__ of the\n # declaring class.\n for parent in mro:\n try:\n attr = parent.__dict__[attr_name]\n except KeyError:\n continue\n else:\n break\n else:\n # In case we didn't find the attribute and didn't break above.\n # We should never hit this - it's just here for completeness\n # to exclude the possibility of attr being undefined.\n continue\n\n try:\n hook_config = attr.__marshmallow_hook__\n except AttributeError:\n pass\n else:\n for key in hook_config.keys():\n # Use name here so we can get the bound method later, in\n # case the processor was a descriptor or something.\n hooks[key].append(attr_name)\n\n return hooks\n\n\nclass SchemaOpts:\n \"\"\"class Meta options for the :class:`Schema`. 
Defines defaults.\"\"\"\n\n def __init__(self, meta, ordered: bool = False):\n self.fields = getattr(meta, \"fields\", ())\n if not isinstance(self.fields, (list, tuple)):\n raise ValueError(\"`fields` option must be a list or tuple.\")\n self.additional = getattr(meta, \"additional\", ())\n if not isinstance(self.additional, (list, tuple)):\n raise ValueError(\"`additional` option must be a list or tuple.\")\n if self.fields and self.additional:\n raise ValueError(\n \"Cannot set both `fields` and `additional` options\"\n \" for the same Schema.\"\n )\n self.exclude = getattr(meta, \"exclude\", ())\n if not isinstance(self.exclude, (list, tuple)):\n raise ValueError(\"`exclude` must be a list or tuple.\")\n self.dateformat = getattr(meta, \"dateformat\", None)\n self.datetimeformat = getattr(meta, \"datetimeformat\", None)\n self.timeformat = getattr(meta, \"timeformat\", None)\n if hasattr(meta, \"json_module\"):\n warnings.warn(\n \"The json_module class Meta option is deprecated. Use render_module instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n render_module = getattr(meta, \"json_module\", json)\n else:\n render_module = json\n self.render_module = getattr(meta, \"render_module\", render_module)\n self.ordered = getattr(meta, \"ordered\", ordered)\n self.index_errors = getattr(meta, \"index_errors\", True)\n self.include = getattr(meta, \"include\", {})\n self.load_only = getattr(meta, \"load_only\", ())\n self.dump_only = getattr(meta, \"dump_only\", ())\n self.unknown = validate_unknown_parameter_value(getattr(meta, \"unknown\", RAISE))\n self.register = getattr(meta, \"register\", True)\n\n\nclass Schema(base.SchemaABC, metaclass=SchemaMeta):\n \"\"\"Base schema class with which to define custom schemas.\n\n Example usage:\n\n .. code-block:: python\n\n import datetime as dt\n from dataclasses import dataclass\n\n from marshmallow import Schema, fields\n\n\n @dataclass\n class Album:\n title: str\n release_date: dt.date\n\n\n class AlbumSchema(Schema):\n title = fields.Str()\n release_date = fields.Date()\n\n\n album = Album(\"Beggars Banquet\", dt.date(1968, 12, 6))\n schema = AlbumSchema()\n data = schema.dump(album)\n data # {'release_date': '1968-12-06', 'title': 'Beggars Banquet'}\n\n :param only: Whitelist of the declared fields to select when\n instantiating the Schema. If None, all fields are used. Nested fields\n can be represented with dot delimiters.\n :param exclude: Blacklist of the declared fields to exclude\n when instantiating the Schema. If a field appears in both `only` and\n `exclude`, it is not used. Nested fields can be represented with dot\n delimiters.\n :param many: Should be set to `True` if ``obj`` is a collection\n so that the object will be serialized to a list.\n :param context: Optional context passed to :class:`fields.Method` and\n :class:`fields.Function` fields.\n :param load_only: Fields to skip during serialization (write-only fields)\n :param dump_only: Fields to skip during deserialization (read-only fields)\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n\n .. versionchanged:: 3.0.0\n `prefix` parameter removed.\n\n .. 
versionchanged:: 2.0.0\n `__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of\n `marshmallow.decorators.validates_schema`,\n `marshmallow.decorators.pre_load` and `marshmallow.decorators.post_dump`.\n `__accessor__` and `__error_handler__` are deprecated. Implement the\n `handle_error` and `get_attribute` methods instead.\n \"\"\"\n\n TYPE_MAPPING = {\n str: ma_fields.String,\n bytes: ma_fields.String,\n dt.datetime: ma_fields.DateTime,\n float: ma_fields.Float,\n bool: ma_fields.Boolean,\n tuple: ma_fields.Raw,\n list: ma_fields.Raw,\n set: ma_fields.Raw,\n int: ma_fields.Integer,\n uuid.UUID: ma_fields.UUID,\n dt.time: ma_fields.Time,\n dt.date: ma_fields.Date,\n dt.timedelta: ma_fields.TimeDelta,\n decimal.Decimal: ma_fields.Decimal,\n } # type: typing.Dict[type, typing.Type[ma_fields.Field]]\n #: Overrides for default schema-level error messages\n error_messages = {} # type: typing.Dict[str, str]\n\n _default_error_messages = {\n \"type\": \"Invalid input type.\",\n \"unknown\": \"Unknown field.\",\n } # type: typing.Dict[str, str]\n\n OPTIONS_CLASS = SchemaOpts # type: type\n\n set_class = OrderedSet\n\n # These get set by SchemaMeta\n opts = None # type: SchemaOpts\n _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]\n _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]\n\n class Meta:\n \"\"\"Options object for a Schema.\n\n Example usage: ::\n\n class Meta:\n fields = (\"id\", \"email\", \"date_created\")\n exclude = (\"password\", \"secret_attribute\")\n\n Available options:\n\n - ``fields``: Tuple or list of fields to include in the serialized result.\n - ``additional``: Tuple or list of fields to include *in addition* to the\n explicitly declared fields. ``additional`` and ``fields`` are\n mutually-exclusive options.\n - ``include``: Dictionary of additional fields to include in the schema. It is\n usually better to define fields as class variables, but you may need to\n use this option, e.g., if your fields are Python keywords. May be an\n...\n", "needle_token_start": 7998, "needle_token_end": 8157, "code_context_ntokens": 13707, "output": ["```python\ndef from_iso_time(value):\n \"\"\"Parse a string and return a datetime.time.\n\n This function doesn't support time zone offsets.\n \"\"\"\n match = _iso8601_time_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted time string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n return dt.time(**kw)\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "__len__", "language": "python", "path": "src/marshmallow/orderedset.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To determine the number of unique elements stored in a custom set-like data structure.\n2. **Input**: No explicit input is required from the user when this method is called.\n3. **Output**: Returns an integer representing the count of unique elements.\n4. 
**Procedure**: This method calculates the total number of keys in an internal dictionary that maps each unique element to its position in a doubly linked list, effectively counting the elements in the data structure.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/decorators.py\n\"\"\"Decorators for registering schema pre-processing and post-processing methods.\nThese should be imported from the top-level `marshmallow` module.\n\nMethods decorated with\n`pre_load `, `post_load `,\n`pre_dump `, `post_dump `,\nand `validates_schema ` receive\n``many`` as a keyword argument. In addition, `pre_load `,\n`post_load `,\nand `validates_schema ` receive\n``partial``. If you don't need these arguments, add ``**kwargs`` to your method\nsignature.\n\n\nExample: ::\n\n from marshmallow import (\n Schema, pre_load, pre_dump, post_load, validates_schema,\n validates, fields, ValidationError\n )\n\n class UserSchema(Schema):\n\n email = fields.Str(required=True)\n age = fields.Integer(required=True)\n\n @post_load\n def lowerstrip_email(self, item, many, **kwargs):\n item['email'] = item['email'].lower().strip()\n return item\n\n @pre_load(pass_many=True)\n def remove_envelope(self, data, many, **kwargs):\n namespace = 'results' if many else 'result'\n return data[namespace]\n\n @post_dump(pass_many=True)\n def add_envelope(self, data, many, **kwargs):\n namespace = 'results' if many else 'result'\n return {namespace: data}\n\n @validates_schema\n def validate_email(self, data, **kwargs):\n if len(data['email']) < 3:\n raise ValidationError('Email must be more than 3 characters', 'email')\n\n @validates('age')\n def validate_age(self, data, **kwargs):\n if data < 14:\n raise ValidationError('Too young!')\n\n.. note::\n These decorators only work with instance methods. Class and static\n methods are not supported.\n\n.. 
warning::\n The invocation order of decorated methods of the same type is not guaranteed.\n If you need to guarantee order of different processing steps, you should put\n them in the same processing method.\n\"\"\"\nfrom __future__ import annotations\n\nimport functools\nfrom typing import Any, Callable, cast\n\nPRE_DUMP = \"pre_dump\"\nPOST_DUMP = \"post_dump\"\nPRE_LOAD = \"pre_load\"\nPOST_LOAD = \"post_load\"\nVALIDATES = \"validates\"\nVALIDATES_SCHEMA = \"validates_schema\"\n\n\nclass MarshmallowHook:\n __marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None\n\n\ndef validates(field_name: str) -> Callable[..., Any]:\n \"\"\"Register a field validator.\n\n :param str field_name: Name of the field that the method validates.\n \"\"\"\n return set_hook(None, VALIDATES, field_name=field_name)\n\n\ndef validates_schema(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n skip_on_field_errors: bool = True,\n) -> Callable[..., Any]:\n \"\"\"Register a schema-level validator.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.validate` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before unmarshalling) will be passed as\n an additional argument to the method.\n\n If ``skip_on_field_errors=True``, this validation method will be skipped whenever\n validation errors have been detected when validating fields.\n\n .. versionchanged:: 3.0.0b1\n ``skip_on_field_errors`` defaults to `True`.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(\n fn,\n (VALIDATES_SCHEMA, pass_many),\n pass_original=pass_original,\n skip_on_field_errors=skip_on_field_errors,\n )\n\n\ndef pre_dump(\n fn: Callable[..., Any] | None = None, pass_many: bool = False\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke before serializing an object. The method\n receives the object to be serialized and returns the processed object.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.dump` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n .. versionchanged:: 3.0.0\n ``many`` is always passed as a keyword arguments to the decorated method.\n \"\"\"\n return set_hook(fn, (PRE_DUMP, pass_many))\n\n\ndef post_dump(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke after serializing an object. The method\n receives the serialized object and returns the processed object.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.dump` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before serializing) will be passed as\n an additional argument to the method.\n\n .. 
versionchanged:: 3.0.0\n ``many`` is always passed as a keyword arguments to the decorated method.\n \"\"\"\n return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)\n\n\ndef pre_load(\n fn: Callable[..., Any] | None = None, pass_many: bool = False\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke before deserializing an object. The method\n receives the data to be deserialized and returns the processed data.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.load` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(fn, (PRE_LOAD, pass_many))\n\n\ndef post_load(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke after deserializing an object. The method\n receives the deserialized data and returns the processed data.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.load` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before deserializing) will be passed as\n an additional argument to the method.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)\n\n\ndef set_hook(\n fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any\n) -> Callable[..., Any]:\n \"\"\"Mark decorated function as a hook to be picked up later.\n You should not need to use this method directly.\n\n .. note::\n Currently only works with functions and instance methods. Class and\n static methods are not supported.\n\n :return: Decorated function if supplied, else this decorator with its args\n bound.\n \"\"\"\n # Allow using this as either a decorator or a decorator factory.\n if fn is None:\n return functools.partial(set_hook, key=key, **kwargs)\n\n # Set a __marshmallow_hook__ attribute instead of wrapping in some class,\n # because I still want this to end up as a normal (unbound) method.\n function = cast(MarshmallowHook, fn)\n try:\n hook_config = function.__marshmallow_hook__\n except AttributeError:\n function.__marshmallow_hook__ = hook_config = {}\n # Also save the kwargs for the tagged function on\n # __marshmallow_hook__, keyed by (, )\n if hook_config is not None:\n hook_config[key] = kwargs\n\n return fn\n\n# Path: src/marshmallow/exceptions.py\n\"\"\"Exception classes for marshmallow-related errors.\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\n# Key used for schema-level validation errors\nSCHEMA = \"_schema\"\n\n\nclass MarshmallowError(Exception):\n \"\"\"Base class for all marshmallow-related errors.\"\"\"\n\n\nclass ValidationError(MarshmallowError):\n \"\"\"Raised when validation fails on a field or schema.\n\n Validators and custom fields should raise this exception.\n\n :param message: An error message, list of error messages, or dict of\n error messages. 
If a dict, the keys are subitems and the values are error messages.\n :param field_name: Field name to store the error on.\n If `None`, the error is stored as schema-level error.\n :param data: Raw input data.\n :param valid_data: Valid (de)serialized data.\n \"\"\"\n\n def __init__(\n self,\n message: str | list | dict,\n field_name: str = SCHEMA,\n data: typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n | None = None,\n valid_data: list[dict[str, typing.Any]] | dict[str, typing.Any] | None = None,\n **kwargs,\n ):\n self.messages = [message] if isinstance(message, (str, bytes)) else message\n self.field_name = field_name\n self.data = data\n self.valid_data = valid_data\n self.kwargs = kwargs\n super().__init__(message)\n\n def normalized_messages(self):\n if self.field_name == SCHEMA and isinstance(self.messages, dict):\n return self.messages\n return {self.field_name: self.messages}\n\n @property\n def messages_dict(self) -> dict[str, typing.Any]:\n if not isinstance(self.messages, dict):\n raise TypeError(\n \"cannot access 'messages_dict' when 'messages' is of type \"\n + type(self.messages).__name__\n )\n return self.messages\n\n\nclass RegistryError(NameError):\n \"\"\"Raised when an invalid operation is performed on the serializer\n class registry.\n \"\"\"\n\n\nclass StringNotCollectionError(MarshmallowError, TypeError):\n \"\"\"Raised when a string is passed when a list of strings is expected.\"\"\"\n\n\nclass FieldInstanceResolutionError(MarshmallowError, TypeError):\n \"\"\"Raised when schema to instantiate is neither a Schema class nor an instance.\"\"\"\n\n# Path: src/marshmallow/base.py\n\"\"\"Abstract base classes.\n\nThese are necessary to avoid circular imports between schema.py and fields.py.\n\n.. warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\n\n\nclass FieldABC(ABC):\n \"\"\"Abstract base class from which all Field classes inherit.\"\"\"\n\n parent = None\n name = None\n root = None\n\n @abstractmethod\n def serialize(self, attr, obj, accessor=None):\n pass\n\n @abstractmethod\n def deserialize(self, value):\n pass\n\n @abstractmethod\n def _serialize(self, value, attr, obj, **kwargs):\n pass\n\n @abstractmethod\n def _deserialize(self, value, attr, data, **kwargs):\n pass\n\n\nclass SchemaABC(ABC):\n \"\"\"Abstract base class from which all Schemas inherit.\"\"\"\n\n @abstractmethod\n def dump(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def dumps(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def load(self, data, *, many: bool | None = None, partial=None, unknown=None):\n pass\n\n @abstractmethod\n def loads(\n self,\n json_data,\n *,\n many: bool | None = None,\n partial=None,\n unknown=None,\n **kwargs,\n ):\n pass\n\n# Path: src/marshmallow/class_registry.py\n\"\"\"A registry of :class:`Schema ` classes. This allows for string\nlookup of schemas, which may be used with\nclass:`fields.Nested `.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\nfrom marshmallow.exceptions import RegistryError\n\nif typing.TYPE_CHECKING:\n from marshmallow import Schema\n\n SchemaType = typing.Type[Schema]\n\n# {\n# : \n# : \n# }\n_registry = {} # type: dict[str, list[SchemaType]]\n\n\ndef register(classname: str, cls: SchemaType) -> None:\n \"\"\"Add a class to the registry of serializer classes. When a class is\n registered, an entry for both its classname and its full, module-qualified\n path are added to the registry.\n\n Example: ::\n\n class MyClass:\n pass\n\n register('MyClass', MyClass)\n # Registry:\n # {\n # 'MyClass': [path.to.MyClass],\n # 'path.to.MyClass': [path.to.MyClass],\n # }\n\n \"\"\"\n # Module where the class is located\n module = cls.__module__\n # Full module path to the class\n # e.g. user.schemas.UserSchema\n fullpath = \".\".join([module, classname])\n # If the class is already registered; need to check if the entries are\n # in the same module as cls to avoid having multiple instances of the same\n # class in the registry\n if classname in _registry and not any(\n each.__module__ == module for each in _registry[classname]\n ):\n _registry[classname].append(cls)\n elif classname not in _registry:\n _registry[classname] = [cls]\n\n # Also register the full path\n if fullpath not in _registry:\n _registry.setdefault(fullpath, []).append(cls)\n else:\n # If fullpath does exist, replace existing entry\n _registry[fullpath] = [cls]\n return None\n\n\ndef get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:\n \"\"\"Retrieve a class from the registry.\n\n :raises: marshmallow.exceptions.RegistryError if the class cannot be found\n or if there are multiple entries for the given class name.\n \"\"\"\n try:\n classes = _registry[classname]\n except KeyError as error:\n raise RegistryError(\n f\"Class with name {classname!r} was not found. You may need \"\n \"to import the class.\"\n ) from error\n if len(classes) > 1:\n if all:\n return _registry[classname]\n raise RegistryError(\n f\"Multiple classes with name {classname!r} \"\n \"were found. Please use the full, \"\n \"module-qualified path.\"\n )\n else:\n return _registry[classname][0]\n\n# Path: src/marshmallow/error_store.py\n\"\"\"Utilities for storing collections of error messages.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\n\nfrom marshmallow.exceptions import SCHEMA\n\n\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n def store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n return errors2\n if not errors2:\n return errors1\n if isinstance(errors1, list):\n if isinstance(errors2, list):\n return errors1 + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return errors1 + [errors2]\n if isinstance(errors1, dict):\n if isinstance(errors2, list):\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, dict):\n errors = dict(errors1)\n for key, val in errors2.items():\n if key in errors:\n errors[key] = merge_errors(errors[key], val)\n else:\n errors[key] = val\n return errors\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, list):\n return [errors1] + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return [errors1, errors2]\n\n# Path: src/marshmallow/orderedset.py\n# OrderedSet\n# Copyright (c) 2009 Raymond Hettinger\n#\n# Permission is hereby granted, free of charge, to any person\n# obtaining a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES\n# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT\n# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,\n# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR\n# OTHER DEALINGS IN THE SOFTWARE.\nfrom collections.abc import MutableSet\n\n\nclass OrderedSet(MutableSet):\n def __init__(self, iterable=None):\n self.end = end = []\n end += [None, end, end] # sentinel node for doubly linked list\n self.map = {} # key --> [key, prev, next]\n if iterable is not None:\n self |= iterable\n\n \ndef __len__(self):\n return len(self.map)\n\n def __contains__(self, key):\n return key in self.map\n\n def add(self, key):\n if key not in self.map:\n end = self.end\n curr = end[1]\n curr[2] = end[1] = self.map[key] = [key, curr, end]\n\n def discard(self, key):\n if key in self.map:\n key, prev, next = self.map.pop(key)\n prev[2] = next\n next[1] = prev\n\n def __iter__(self):\n end = self.end\n curr = end[2]\n while curr is not end:\n yield curr[0]\n curr = curr[2]\n\n def __reversed__(self):\n end = self.end\n curr = end[1]\n while curr is not end:\n yield curr[0]\n curr = curr[1]\n\n def pop(self, last=True):\n if not self:\n raise KeyError(\"set is empty\")\n key = self.end[1][0] if last else self.end[2][0]\n self.discard(key)\n return key\n\n def __repr__(self):\n if not self:\n return f\"{self.__class__.__name__}()\"\n return f\"{self.__class__.__name__}({list(self)!r})\"\n\n def __eq__(self, other):\n if isinstance(other, OrderedSet):\n return len(self) == len(other) and list(self) == list(other)\n return set(self) == set(other)\n\n\nif __name__ == \"__main__\":\n s = OrderedSet(\"abracadaba\")\n t = OrderedSet(\"simsalabim\")\n print(s | t)\n print(s & t)\n print(s - t)\n\n# Path: src/marshmallow/types.py\n\"\"\"Type aliases.\n\n.. warning::\n\n This module is provisional. Types may be modified, added, and removed between minor releases.\n\"\"\"\nimport typing\n\nStrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]\nTag = typing.Union[str, typing.Tuple[str, bool]]\nValidator = typing.Callable[[typing.Any], typing.Any]\n\n# Path: src/marshmallow/warnings.py\nclass RemovedInMarshmallow4Warning(DeprecationWarning):\n pass\n\n# Path: src/marshmallow/utils.py\n\"\"\"Utility methods for marshmallow.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport datetime as dt\nimport functools\nimport inspect\nimport json\nimport re\nimport typing\nimport warnings\nfrom collections.abc import Mapping\nfrom email.utils import format_datetime, parsedate_to_datetime\nfrom pprint import pprint as py_pprint\n\nfrom marshmallow.base import FieldABC\nfrom marshmallow.exceptions import FieldInstanceResolutionError\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\nEXCLUDE = \"exclude\"\nINCLUDE = \"include\"\nRAISE = \"raise\"\n_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}\n\n\nclass _Missing:\n def __bool__(self):\n return False\n\n def __copy__(self):\n return self\n\n def __deepcopy__(self, _):\n return self\n\n def __repr__(self):\n return \"\"\n\n\n# Singleton value that indicates that a field's value is missing from input\n# dict passed to :meth:`Schema.load`. 
If the field's value is not required,\n# it's ``default`` value is used.\nmissing = _Missing()\n\n\ndef is_generator(obj) -> bool:\n \"\"\"Return True if ``obj`` is a generator\"\"\"\n return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)\n\n\ndef is_iterable_but_not_string(obj) -> bool:\n \"\"\"Return True if ``obj`` is an iterable object that isn't a string.\"\"\"\n return (hasattr(obj, \"__iter__\") and not hasattr(obj, \"strip\")) or is_generator(obj)\n\n\ndef is_collection(obj) -> bool:\n \"\"\"Return True if ``obj`` is a collection type, e.g list, tuple, queryset.\"\"\"\n return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)\n\n\ndef is_instance_or_subclass(val, class_) -> bool:\n \"\"\"Return True if ``val`` is either a subclass or instance of ``class_``.\"\"\"\n try:\n return issubclass(val, class_)\n except TypeError:\n return isinstance(val, class_)\n\n\ndef is_keyed_tuple(obj) -> bool:\n \"\"\"Return True if ``obj`` has keyed tuple behavior, such as\n namedtuples or SQLAlchemy's KeyedTuples.\n \"\"\"\n return isinstance(obj, tuple) and hasattr(obj, \"_fields\")\n\n\ndef pprint(obj, *args, **kwargs) -> None:\n \"\"\"Pretty-printing function that can pretty-print OrderedDicts\n like regular dictionaries. Useful for printing the output of\n :meth:`marshmallow.Schema.dump`.\n\n .. deprecated:: 3.7.0\n marshmallow.pprint will be removed in marshmallow 4.\n \"\"\"\n warnings.warn(\n \"marshmallow's pprint function is deprecated and will be removed in marshmallow 4.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if isinstance(obj, collections.OrderedDict):\n print(json.dumps(obj, *args, **kwargs))\n else:\n py_pprint(obj, *args, **kwargs)\n\n\n# https://stackoverflow.com/a/27596917\ndef is_aware(datetime: dt.datetime) -> bool:\n return (\n datetime.tzinfo is not None and datetime.tzinfo.utcoffset(datetime) is not None\n )\n\n\ndef from_rfc(datestring: str) -> dt.datetime:\n \"\"\"Parse a RFC822-formatted datetime string and return a datetime object.\n\n https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950\n \"\"\"\n return parsedate_to_datetime(datestring)\n\n\ndef rfcformat(datetime: dt.datetime) -> str:\n \"\"\"Return the RFC822-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return format_datetime(datetime)\n\n\n# Hat tip to Django for ISO8601 deserialization functions\n\n_iso8601_datetime_re = re.compile(\n r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})\"\n r\"[T ](?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n r\"(?PZ|[+-]\\d{2}(?::?\\d{2})?)?$\"\n)\n\n_iso8601_date_re = re.compile(r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})$\")\n\n_iso8601_time_re = re.compile(\n r\"(?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n)\n\n\ndef get_fixed_timezone(offset: int | float | dt.timedelta) -> dt.timezone:\n \"\"\"Return a tzinfo instance with a fixed offset from UTC.\"\"\"\n if isinstance(offset, dt.timedelta):\n offset = offset.total_seconds() // 60\n sign = \"-\" if offset < 0 else \"+\"\n hhmm = \"%02d%02d\" % divmod(abs(offset), 60)\n name = sign + hhmm\n return dt.timezone(dt.timedelta(minutes=offset), name)\n\n\ndef from_iso_datetime(value):\n \"\"\"Parse a string and return a datetime.datetime.\n\n This function supports time zone offsets. 
When the input contains one,\n the output uses a timezone with a fixed offset from UTC.\n \"\"\"\n match = _iso8601_datetime_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted datetime string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n tzinfo = kw.pop(\"tzinfo\")\n if tzinfo == \"Z\":\n tzinfo = dt.timezone.utc\n elif tzinfo is not None:\n offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0\n offset = 60 * int(tzinfo[1:3]) + offset_mins\n if tzinfo[0] == \"-\":\n offset = -offset\n tzinfo = get_fixed_timezone(offset)\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n kw[\"tzinfo\"] = tzinfo\n return dt.datetime(**kw)\n\n\ndef from_iso_time(value):\n \"\"\"Parse a string and return a datetime.time.\n\n This function doesn't support time zone offsets.\n \"\"\"\n match = _iso8601_time_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted time string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n return dt.time(**kw)\n\n\ndef from_iso_date(value):\n \"\"\"Parse a string and return a datetime.date.\"\"\"\n match = _iso8601_date_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted date string\")\n kw = {k: int(v) for k, v in match.groupdict().items()}\n return dt.date(**kw)\n\n\ndef from_timestamp(value: typing.Any) -> dt.datetime:\n value = float(value)\n if value < 0:\n raise ValueError(\"Not a valid POSIX timestamp\")\n\n # Load a timestamp with utc as timezone to prevent using system timezone.\n # Then set timezone to None, to let the Field handle adding timezone info.\n try:\n return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)\n except OverflowError as exc:\n raise ValueError(\"Timestamp is too large\") from exc\n except OSError as exc:\n raise ValueError(\"Error converting value to datetime\") from exc\n\n\ndef from_timestamp_ms(value: typing.Any) -> dt.datetime:\n value = float(value)\n return from_timestamp(value / 1000)\n\n\ndef timestamp(\n value: dt.datetime,\n) -> float:\n if not is_aware(value):\n # When a date is naive, use UTC as zone info to prevent using system timezone.\n value = value.replace(tzinfo=dt.timezone.utc)\n return value.timestamp()\n\n\ndef timestamp_ms(value: dt.datetime) -> float:\n return timestamp(value) * 1000\n\n\ndef isoformat(datetime: dt.datetime) -> str:\n \"\"\"Return the ISO8601-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return datetime.isoformat()\n\n\ndef to_iso_time(time: dt.time) -> str:\n return dt.time.isoformat(time)\n\n\ndef to_iso_date(date: dt.date) -> str:\n return dt.date.isoformat(date)\n\n\ndef ensure_text_type(val: str | bytes) -> str:\n if isinstance(val, bytes):\n val = val.decode(\"utf-8\")\n return str(val)\n\n\ndef pluck(dictlist: list[dict[str, typing.Any]], key: str):\n \"\"\"Extracts a list of dictionary values from a list of dictionaries.\n ::\n\n >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]\n >>> pluck(dlist, 'id')\n [1, 2]\n \"\"\"\n return [d[key] for d in dictlist]\n\n\n# Various utilities for pulling keyed values from objects\n\n\ndef get_value(obj, key: int | str, default=missing):\n \"\"\"Helper for pulling a keyed value off various types of objects. 
Fields use\n this method by default to access attributes of the source object. For object `x`\n and attribute `i`, this method first tries to access `x[i]`, and then falls back to\n `x.i` if an exception is raised.\n\n .. warning::\n If an object `x` does not raise an exception when `x[i]` does not exist,\n `get_value` will never check the value `x.i`. Consider overriding\n `marshmallow.fields.Field.get_value` in this case.\n \"\"\"\n if not isinstance(key, int) and \".\" in key:\n return _get_value_for_keys(obj, key.split(\".\"), default)\n else:\n return _get_value_for_key(obj, key, default)\n\n\ndef _get_value_for_keys(obj, keys, default):\n if len(keys) == 1:\n return _get_value_for_key(obj, keys[0], default)\n else:\n return _get_value_for_keys(\n _get_value_for_key(obj, keys[0], default), keys[1:], default\n )\n\n\ndef _get_value_for_key(obj, key, default):\n if not hasattr(obj, \"__getitem__\"):\n return getattr(obj, key, default)\n\n try:\n return obj[key]\n except (KeyError, IndexError, TypeError, AttributeError):\n return getattr(obj, key, default)\n\n\ndef set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):\n \"\"\"Set a value in a dict. If `key` contains a '.', it is assumed\n be a path (i.e. dot-delimited string) to the value's location.\n\n ::\n\n >>> d = {}\n >>> set_value(d, 'foo.bar', 42)\n >>> d\n {'foo': {'bar': 42}}\n \"\"\"\n if \".\" in key:\n head, rest = key.split(\".\", 1)\n target = dct.setdefault(head, {})\n if not isinstance(target, dict):\n raise ValueError(\n f\"Cannot set {key} in {head} \" f\"due to existing value: {target}\"\n )\n...\n", "needle_token_start": 5474, "needle_token_end": 5489, "code_context_ntokens": 9485, "output": ["```python\nclass OrderedSet(MutableSet):\n def __init__(self, iterable=None):\n self.end = end = []\n end += [None, end, end] # sentinel node for doubly linked list\n self.map = {} # key --> [key, prev, next]\n if iterable is not None:\n self |= iterable\n\n def __len__(self):\n return len(self.map)\n\n def __contains__(self, key):\n return key in self.map\n\n def add(self, key):\n if key not in self.map:\n end = self.end\n curr = end[1]\n curr[2] = end[1] = self.map[key] = [key, curr, end]\n\n def discard(self, key):\n if key in self.map:\n key, prev, next = self.map.pop(key)\n prev[2] = next\n next[1] = prev\n\n def __iter__(self):\n end = self.end\n curr = end[2]\n while curr is not end:\n yield curr[0]\n curr = curr[2]\n\n def __reversed__(self):\n end = self.end\n curr = end[1]\n while curr is not end:\n yield curr[0]\n curr = curr[1]\n\n def pop(self, last=True):\n if not self:\n raise KeyError(\"set is empty\")\n key = self.end[1][0] if last else self.end[2][0]\n self.discard(key)\n return key\n\n def __repr__(self):\n if not self:\n return f\"{self.__class__.__name__}()\"\n return f\"{self.__class__.__name__}({list(self)!r})\"\n\n def __eq__(self, other):\n if isinstance(other, OrderedSet):\n return len(self) == len(other) and list(self) == list(other)\n return set(self) == set(other)\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "normalized_messages", "language": "python", "path": "src/marshmallow/exceptions.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: To format error messages in a consistent dictionary structure, facilitating easier error handling and display.\n2. **Input**: This method does not take external inputs as it operates on the instance's attributes.\n3. 
**Output**: A dictionary where keys represent field names and values are the corresponding error messages. If the error is schema-level, the output will be the error message dictionary directly.\n4. **Procedure**: The method checks if the error is a schema-level error and if the error messages are already in dictionary form. If both conditions are met, it returns the error messages as is. Otherwise, it constructs a new dictionary with the field name as the key and the error messages as the value.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/decorators.py\n\"\"\"Decorators for registering schema pre-processing and post-processing methods.\nThese should be imported from the top-level `marshmallow` module.\n\nMethods decorated with\n`pre_load `, `post_load `,\n`pre_dump `, `post_dump `,\nand `validates_schema ` receive\n``many`` as a keyword argument. In addition, `pre_load `,\n`post_load `,\nand `validates_schema ` receive\n``partial``. If you don't need these arguments, add ``**kwargs`` to your method\nsignature.\n\n\nExample: ::\n\n from marshmallow import (\n Schema, pre_load, pre_dump, post_load, validates_schema,\n validates, fields, ValidationError\n )\n\n class UserSchema(Schema):\n\n email = fields.Str(required=True)\n age = fields.Integer(required=True)\n\n @post_load\n def lowerstrip_email(self, item, many, **kwargs):\n item['email'] = item['email'].lower().strip()\n return item\n\n @pre_load(pass_many=True)\n def remove_envelope(self, data, many, **kwargs):\n namespace = 'results' if many else 'result'\n return data[namespace]\n\n @post_dump(pass_many=True)\n def add_envelope(self, data, many, **kwargs):\n namespace = 'results' if many else 'result'\n return {namespace: data}\n\n @validates_schema\n def validate_email(self, data, **kwargs):\n if len(data['email']) < 3:\n raise ValidationError('Email must be more than 3 characters', 'email')\n\n @validates('age')\n def validate_age(self, data, **kwargs):\n if data < 14:\n raise ValidationError('Too young!')\n\n.. note::\n These decorators only work with instance methods. Class and static\n methods are not supported.\n\n.. 
warning::\n The invocation order of decorated methods of the same type is not guaranteed.\n If you need to guarantee order of different processing steps, you should put\n them in the same processing method.\n\"\"\"\nfrom __future__ import annotations\n\nimport functools\nfrom typing import Any, Callable, cast\n\nPRE_DUMP = \"pre_dump\"\nPOST_DUMP = \"post_dump\"\nPRE_LOAD = \"pre_load\"\nPOST_LOAD = \"post_load\"\nVALIDATES = \"validates\"\nVALIDATES_SCHEMA = \"validates_schema\"\n\n\nclass MarshmallowHook:\n __marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None\n\n\ndef validates(field_name: str) -> Callable[..., Any]:\n \"\"\"Register a field validator.\n\n :param str field_name: Name of the field that the method validates.\n \"\"\"\n return set_hook(None, VALIDATES, field_name=field_name)\n\n\ndef validates_schema(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n skip_on_field_errors: bool = True,\n) -> Callable[..., Any]:\n \"\"\"Register a schema-level validator.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.validate` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before unmarshalling) will be passed as\n an additional argument to the method.\n\n If ``skip_on_field_errors=True``, this validation method will be skipped whenever\n validation errors have been detected when validating fields.\n\n .. versionchanged:: 3.0.0b1\n ``skip_on_field_errors`` defaults to `True`.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(\n fn,\n (VALIDATES_SCHEMA, pass_many),\n pass_original=pass_original,\n skip_on_field_errors=skip_on_field_errors,\n )\n\n\ndef pre_dump(\n fn: Callable[..., Any] | None = None, pass_many: bool = False\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke before serializing an object. The method\n receives the object to be serialized and returns the processed object.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.dump` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n .. versionchanged:: 3.0.0\n ``many`` is always passed as a keyword arguments to the decorated method.\n \"\"\"\n return set_hook(fn, (PRE_DUMP, pass_many))\n\n\ndef post_dump(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke after serializing an object. The method\n receives the serialized object and returns the processed object.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.dump` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before serializing) will be passed as\n an additional argument to the method.\n\n .. 
versionchanged:: 3.0.0\n ``many`` is always passed as a keyword arguments to the decorated method.\n \"\"\"\n return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)\n\n\ndef pre_load(\n fn: Callable[..., Any] | None = None, pass_many: bool = False\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke before deserializing an object. The method\n receives the data to be deserialized and returns the processed data.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.load` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(fn, (PRE_LOAD, pass_many))\n\n\ndef post_load(\n fn: Callable[..., Any] | None = None,\n pass_many: bool = False,\n pass_original: bool = False,\n) -> Callable[..., Any]:\n \"\"\"Register a method to invoke after deserializing an object. The method\n receives the deserialized data and returns the processed data.\n\n By default it receives a single object at a time, transparently handling the ``many``\n argument passed to the `Schema`'s :func:`~marshmallow.Schema.load` call.\n If ``pass_many=True``, the raw data (which may be a collection) is passed.\n\n If ``pass_original=True``, the original data (before deserializing) will be passed as\n an additional argument to the method.\n\n .. versionchanged:: 3.0.0\n ``partial`` and ``many`` are always passed as keyword arguments to\n the decorated method.\n \"\"\"\n return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)\n\n\ndef set_hook(\n fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any\n) -> Callable[..., Any]:\n \"\"\"Mark decorated function as a hook to be picked up later.\n You should not need to use this method directly.\n\n .. note::\n Currently only works with functions and instance methods. Class and\n static methods are not supported.\n\n :return: Decorated function if supplied, else this decorator with its args\n bound.\n \"\"\"\n # Allow using this as either a decorator or a decorator factory.\n if fn is None:\n return functools.partial(set_hook, key=key, **kwargs)\n\n # Set a __marshmallow_hook__ attribute instead of wrapping in some class,\n # because I still want this to end up as a normal (unbound) method.\n function = cast(MarshmallowHook, fn)\n try:\n hook_config = function.__marshmallow_hook__\n except AttributeError:\n function.__marshmallow_hook__ = hook_config = {}\n # Also save the kwargs for the tagged function on\n # __marshmallow_hook__, keyed by (, )\n if hook_config is not None:\n hook_config[key] = kwargs\n\n return fn\n\n# Path: src/marshmallow/exceptions.py\n\"\"\"Exception classes for marshmallow-related errors.\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\n# Key used for schema-level validation errors\nSCHEMA = \"_schema\"\n\n\nclass MarshmallowError(Exception):\n \"\"\"Base class for all marshmallow-related errors.\"\"\"\n\n\nclass ValidationError(MarshmallowError):\n \"\"\"Raised when validation fails on a field or schema.\n\n Validators and custom fields should raise this exception.\n\n :param message: An error message, list of error messages, or dict of\n error messages. 
If a dict, the keys are subitems and the values are error messages.\n :param field_name: Field name to store the error on.\n If `None`, the error is stored as schema-level error.\n :param data: Raw input data.\n :param valid_data: Valid (de)serialized data.\n \"\"\"\n\n def __init__(\n self,\n message: str | list | dict,\n field_name: str = SCHEMA,\n data: typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n | None = None,\n valid_data: list[dict[str, typing.Any]] | dict[str, typing.Any] | None = None,\n **kwargs,\n ):\n self.messages = [message] if isinstance(message, (str, bytes)) else message\n self.field_name = field_name\n self.data = data\n self.valid_data = valid_data\n self.kwargs = kwargs\n super().__init__(message)\n\n \ndef normalized_messages(self):\n if self.field_name == SCHEMA and isinstance(self.messages, dict):\n return self.messages\n return {self.field_name: self.messages}\n\n @property\n def messages_dict(self) -> dict[str, typing.Any]:\n if not isinstance(self.messages, dict):\n raise TypeError(\n \"cannot access 'messages_dict' when 'messages' is of type \"\n + type(self.messages).__name__\n )\n return self.messages\n\n\nclass RegistryError(NameError):\n \"\"\"Raised when an invalid operation is performed on the serializer\n class registry.\n \"\"\"\n\n\nclass StringNotCollectionError(MarshmallowError, TypeError):\n \"\"\"Raised when a string is passed when a list of strings is expected.\"\"\"\n\n\nclass FieldInstanceResolutionError(MarshmallowError, TypeError):\n \"\"\"Raised when schema to instantiate is neither a Schema class nor an instance.\"\"\"\n\n# Path: src/marshmallow/base.py\n\"\"\"Abstract base classes.\n\nThese are necessary to avoid circular imports between schema.py and fields.py.\n\n.. warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\n\n\nclass FieldABC(ABC):\n \"\"\"Abstract base class from which all Field classes inherit.\"\"\"\n\n parent = None\n name = None\n root = None\n\n @abstractmethod\n def serialize(self, attr, obj, accessor=None):\n pass\n\n @abstractmethod\n def deserialize(self, value):\n pass\n\n @abstractmethod\n def _serialize(self, value, attr, obj, **kwargs):\n pass\n\n @abstractmethod\n def _deserialize(self, value, attr, data, **kwargs):\n pass\n\n\nclass SchemaABC(ABC):\n \"\"\"Abstract base class from which all Schemas inherit.\"\"\"\n\n @abstractmethod\n def dump(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def dumps(self, obj, *, many: bool | None = None):\n pass\n\n @abstractmethod\n def load(self, data, *, many: bool | None = None, partial=None, unknown=None):\n pass\n\n @abstractmethod\n def loads(\n self,\n json_data,\n *,\n many: bool | None = None,\n partial=None,\n unknown=None,\n **kwargs,\n ):\n pass\n\n# Path: src/marshmallow/class_registry.py\n\"\"\"A registry of :class:`Schema ` classes. This allows for string\nlookup of schemas, which may be used with\nclass:`fields.Nested `.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\nfrom __future__ import annotations\n\nimport typing\n\nfrom marshmallow.exceptions import RegistryError\n\nif typing.TYPE_CHECKING:\n from marshmallow import Schema\n\n SchemaType = typing.Type[Schema]\n\n# {\n# : \n# : \n# }\n_registry = {} # type: dict[str, list[SchemaType]]\n\n\ndef register(classname: str, cls: SchemaType) -> None:\n \"\"\"Add a class to the registry of serializer classes. When a class is\n registered, an entry for both its classname and its full, module-qualified\n path are added to the registry.\n\n Example: ::\n\n class MyClass:\n pass\n\n register('MyClass', MyClass)\n # Registry:\n # {\n # 'MyClass': [path.to.MyClass],\n # 'path.to.MyClass': [path.to.MyClass],\n # }\n\n \"\"\"\n # Module where the class is located\n module = cls.__module__\n # Full module path to the class\n # e.g. user.schemas.UserSchema\n fullpath = \".\".join([module, classname])\n # If the class is already registered; need to check if the entries are\n # in the same module as cls to avoid having multiple instances of the same\n # class in the registry\n if classname in _registry and not any(\n each.__module__ == module for each in _registry[classname]\n ):\n _registry[classname].append(cls)\n elif classname not in _registry:\n _registry[classname] = [cls]\n\n # Also register the full path\n if fullpath not in _registry:\n _registry.setdefault(fullpath, []).append(cls)\n else:\n # If fullpath does exist, replace existing entry\n _registry[fullpath] = [cls]\n return None\n\n\ndef get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:\n \"\"\"Retrieve a class from the registry.\n\n :raises: marshmallow.exceptions.RegistryError if the class cannot be found\n or if there are multiple entries for the given class name.\n \"\"\"\n try:\n classes = _registry[classname]\n except KeyError as error:\n raise RegistryError(\n f\"Class with name {classname!r} was not found. You may need \"\n \"to import the class.\"\n ) from error\n if len(classes) > 1:\n if all:\n return _registry[classname]\n raise RegistryError(\n f\"Multiple classes with name {classname!r} \"\n \"were found. Please use the full, \"\n \"module-qualified path.\"\n )\n else:\n return _registry[classname][0]\n\n# Path: src/marshmallow/error_store.py\n\"\"\"Utilities for storing collections of error messages.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\n\nfrom marshmallow.exceptions import SCHEMA\n\n\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n def store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n return errors2\n if not errors2:\n return errors1\n if isinstance(errors1, list):\n if isinstance(errors2, list):\n return errors1 + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return errors1 + [errors2]\n if isinstance(errors1, dict):\n if isinstance(errors2, list):\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, dict):\n errors = dict(errors1)\n for key, val in errors2.items():\n if key in errors:\n errors[key] = merge_errors(errors[key], val)\n else:\n errors[key] = val\n return errors\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, list):\n return [errors1] + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return [errors1, errors2]\n\n# Path: src/marshmallow/orderedset.py\n# OrderedSet\n# Copyright (c) 2009 Raymond Hettinger\n#\n# Permission is hereby granted, free of charge, to any person\n# obtaining a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES\n# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT\n...\n", "needle_token_start": 2849, "needle_token_end": 2899, "code_context_ntokens": 5279, "output": ["```python\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n def store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n return errors2\n if not errors2:\n return errors1\n if isinstance(errors1, list):\n if isinstance(errors2, list):\n return errors1 + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return errors1 + [errors2]\n if isinstance(errors1, dict):\n if isinstance(errors2, list):\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, dict):\n errors = dict(errors1)\n for key, val in errors2.items():\n if key in errors:\n errors[key] = merge_errors(errors[key], val)\n else:\n errors[key] = val\n return errors\n return dict(errors1, **{SCHEMA: merge_errors(errors1.get(SCHEMA), errors2)})\n if isinstance(errors2, list):\n return [errors1] + errors2\n if isinstance(errors2, dict):\n return dict(errors2, **{SCHEMA: merge_errors(errors1, errors2.get(SCHEMA))})\n return [errors1, errors2]\n```"]} +{"repo": "marshmallow-code/marshmallow", "name": "_normalize_nested_options", "language": "python", "path": "src/marshmallow/schema.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To adjust and simplify the settings related to field inclusion and exclusion for nested schemas within a data processing context.\n2. **Input**: The function does not take parameters directly; it operates on the internal state of an object that includes settings for which fields to include or exclude.\n3. **Output**: There is no direct output; the function modifies the internal state of the object.\n4. **Procedure**: \n - If there is a specification on which fields to include, this specification is applied recursively to nested fields, and then the specification is simplified to only include the top-level field names.\n - If there is a specification on which fields to exclude, this specification is also applied recursively to nested fields. Subsequently, the specification is adjusted to exclude any top-level field names that do not contain nested specifications.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: src/marshmallow/error_store.py\n\"\"\"Utilities for storing collections of error messages.\n\n.. 
warning::\n\n This module is treated as private API.\n Users should not need to use this module directly.\n\"\"\"\n\nfrom marshmallow.exceptions import SCHEMA\n\n\nclass ErrorStore:\n def __init__(self):\n #: Dictionary of errors stored during serialization\n self.errors = {}\n\n def store_error(self, messages, field_name=SCHEMA, index=None):\n # field error -> store/merge error messages under field name key\n # schema error -> if string or list, store/merge under _schema key\n # -> if dict, store/merge with other top-level keys\n if field_name != SCHEMA or not isinstance(messages, dict):\n messages = {field_name: messages}\n if index is not None:\n messages = {index: messages}\n self.errors = merge_errors(self.errors, messages)\n\n\ndef merge_errors(errors1, errors2):\n \"\"\"Deeply merge two error messages.\n\n The format of ``errors1`` and ``errors2`` matches the ``message``\n parameter of :exc:`marshmallow.exceptions.ValidationError`.\n \"\"\"\n if not errors1:\n...\n# Path: src/marshmallow/orderedset.py\n# OrderedSet\n# Copyright (c) 2009 Raymond Hettinger\n#\n# Permission is hereby granted, free of charge, to any person\n# obtaining a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES\n# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\n# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT\n# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,\n# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR\n# OTHER DEALINGS IN THE SOFTWARE.\nfrom collections.abc import MutableSet\n\n\nclass OrderedSet(MutableSet):\n def __init__(self, iterable=None):\n self.end = end = []\n end += [None, end, end] # sentinel node for doubly linked list\n self.map = {} # key --> [key, prev, next]\n if iterable is not None:\n self |= iterable\n\n def __len__(self):\n return len(self.map)\n\n def __contains__(self, key):\n return key in self.map\n\n def add(self, key):\n if key not in self.map:\n end = self.end\n curr = end[1]\n curr[2] = end[1] = self.map[key] = [key, curr, end]\n\n def discard(self, key):\n if key in self.map:\n key, prev, next = self.map.pop(key)\n prev[2] = next\n next[1] = prev\n\n def __iter__(self):\n end = self.end\n curr = end[2]\n while curr is not end:\n yield curr[0]\n curr = curr[2]\n\n def __reversed__(self):\n end = self.end\n curr = end[1]\n while curr is not end:\n yield curr[0]\n curr = curr[1]\n\n def pop(self, last=True):\n if not self:\n raise KeyError(\"set is empty\")\n key = self.end[1][0] if last else self.end[2][0]\n self.discard(key)\n return key\n\n def __repr__(self):\n if not self:\n return f\"{self.__class__.__name__}()\"\n return f\"{self.__class__.__name__}({list(self)!r})\"\n\n def __eq__(self, other):\n if isinstance(other, OrderedSet):\n return len(self) == len(other) and list(self) == list(other)\n return set(self) == set(other)\n\n\nif __name__ == \"__main__\":\n s = OrderedSet(\"abracadaba\")\n t = OrderedSet(\"simsalabim\")\n print(s | t)\n print(s & t)\n print(s - t)\n\n# Path: src/marshmallow/types.py\n\"\"\"Type aliases.\n\n.. warning::\n\n This module is provisional. Types may be modified, added, and removed between minor releases.\n\"\"\"\nimport typing\n\nStrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]\nTag = typing.Union[str, typing.Tuple[str, bool]]\nValidator = typing.Callable[[typing.Any], typing.Any]\n\n# Path: src/marshmallow/warnings.py\nclass RemovedInMarshmallow4Warning(DeprecationWarning):\n pass\n\n# Path: src/marshmallow/utils.py\n\"\"\"Utility methods for marshmallow.\"\"\"\nfrom __future__ import annotations\n\nimport collections\nimport datetime as dt\nimport functools\nimport inspect\nimport json\nimport re\nimport typing\nimport warnings\nfrom collections.abc import Mapping\nfrom email.utils import format_datetime, parsedate_to_datetime\nfrom pprint import pprint as py_pprint\n\nfrom marshmallow.base import FieldABC\nfrom marshmallow.exceptions import FieldInstanceResolutionError\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\nEXCLUDE = \"exclude\"\nINCLUDE = \"include\"\nRAISE = \"raise\"\n_UNKNOWN_VALUES = {EXCLUDE, INCLUDE, RAISE}\n\n\nclass _Missing:\n def __bool__(self):\n return False\n\n def __copy__(self):\n return self\n\n def __deepcopy__(self, _):\n return self\n\n def __repr__(self):\n return \"\"\n\n\n# Singleton value that indicates that a field's value is missing from input\n# dict passed to :meth:`Schema.load`. 
If the field's value is not required,\n# it's ``default`` value is used.\nmissing = _Missing()\n\n\ndef is_generator(obj) -> bool:\n \"\"\"Return True if ``obj`` is a generator\"\"\"\n return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)\n\n\ndef is_iterable_but_not_string(obj) -> bool:\n \"\"\"Return True if ``obj`` is an iterable object that isn't a string.\"\"\"\n return (hasattr(obj, \"__iter__\") and not hasattr(obj, \"strip\")) or is_generator(obj)\n\n\ndef is_collection(obj) -> bool:\n \"\"\"Return True if ``obj`` is a collection type, e.g list, tuple, queryset.\"\"\"\n return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)\n\n\ndef is_instance_or_subclass(val, class_) -> bool:\n \"\"\"Return True if ``val`` is either a subclass or instance of ``class_``.\"\"\"\n try:\n return issubclass(val, class_)\n except TypeError:\n return isinstance(val, class_)\n\n\ndef is_keyed_tuple(obj) -> bool:\n \"\"\"Return True if ``obj`` has keyed tuple behavior, such as\n namedtuples or SQLAlchemy's KeyedTuples.\n \"\"\"\n return isinstance(obj, tuple) and hasattr(obj, \"_fields\")\n\n\ndef pprint(obj, *args, **kwargs) -> None:\n \"\"\"Pretty-printing function that can pretty-print OrderedDicts\n like regular dictionaries. Useful for printing the output of\n :meth:`marshmallow.Schema.dump`.\n\n .. deprecated:: 3.7.0\n marshmallow.pprint will be removed in marshmallow 4.\n \"\"\"\n warnings.warn(\n \"marshmallow's pprint function is deprecated and will be removed in marshmallow 4.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n if isinstance(obj, collections.OrderedDict):\n print(json.dumps(obj, *args, **kwargs))\n else:\n py_pprint(obj, *args, **kwargs)\n\n\n# https://stackoverflow.com/a/27596917\ndef is_aware(datetime: dt.datetime) -> bool:\n return (\n datetime.tzinfo is not None and datetime.tzinfo.utcoffset(datetime) is not None\n )\n\n\ndef from_rfc(datestring: str) -> dt.datetime:\n \"\"\"Parse a RFC822-formatted datetime string and return a datetime object.\n\n https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime # noqa: B950\n \"\"\"\n return parsedate_to_datetime(datestring)\n\n\ndef rfcformat(datetime: dt.datetime) -> str:\n \"\"\"Return the RFC822-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return format_datetime(datetime)\n\n\n# Hat tip to Django for ISO8601 deserialization functions\n\n_iso8601_datetime_re = re.compile(\n r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})\"\n r\"[T ](?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n r\"(?PZ|[+-]\\d{2}(?::?\\d{2})?)?$\"\n)\n\n_iso8601_date_re = re.compile(r\"(?P\\d{4})-(?P\\d{1,2})-(?P\\d{1,2})$\")\n\n_iso8601_time_re = re.compile(\n r\"(?P\\d{1,2}):(?P\\d{1,2})\"\n r\"(?::(?P\\d{1,2})(?:\\.(?P\\d{1,6})\\d{0,6})?)?\"\n)\n\n\ndef get_fixed_timezone(offset: int | float | dt.timedelta) -> dt.timezone:\n \"\"\"Return a tzinfo instance with a fixed offset from UTC.\"\"\"\n if isinstance(offset, dt.timedelta):\n offset = offset.total_seconds() // 60\n sign = \"-\" if offset < 0 else \"+\"\n hhmm = \"%02d%02d\" % divmod(abs(offset), 60)\n name = sign + hhmm\n return dt.timezone(dt.timedelta(minutes=offset), name)\n\n\ndef from_iso_datetime(value):\n \"\"\"Parse a string and return a datetime.datetime.\n\n This function supports time zone offsets. 
When the input contains one,\n the output uses a timezone with a fixed offset from UTC.\n \"\"\"\n match = _iso8601_datetime_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted datetime string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n tzinfo = kw.pop(\"tzinfo\")\n if tzinfo == \"Z\":\n tzinfo = dt.timezone.utc\n elif tzinfo is not None:\n offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0\n offset = 60 * int(tzinfo[1:3]) + offset_mins\n if tzinfo[0] == \"-\":\n offset = -offset\n tzinfo = get_fixed_timezone(offset)\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n kw[\"tzinfo\"] = tzinfo\n return dt.datetime(**kw)\n\n\ndef from_iso_time(value):\n \"\"\"Parse a string and return a datetime.time.\n\n This function doesn't support time zone offsets.\n \"\"\"\n match = _iso8601_time_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted time string\")\n kw = match.groupdict()\n kw[\"microsecond\"] = kw[\"microsecond\"] and kw[\"microsecond\"].ljust(6, \"0\")\n kw = {k: int(v) for k, v in kw.items() if v is not None}\n return dt.time(**kw)\n\n\ndef from_iso_date(value):\n \"\"\"Parse a string and return a datetime.date.\"\"\"\n match = _iso8601_date_re.match(value)\n if not match:\n raise ValueError(\"Not a valid ISO8601-formatted date string\")\n kw = {k: int(v) for k, v in match.groupdict().items()}\n return dt.date(**kw)\n\n\ndef from_timestamp(value: typing.Any) -> dt.datetime:\n value = float(value)\n if value < 0:\n raise ValueError(\"Not a valid POSIX timestamp\")\n\n # Load a timestamp with utc as timezone to prevent using system timezone.\n # Then set timezone to None, to let the Field handle adding timezone info.\n try:\n return dt.datetime.fromtimestamp(value, tz=dt.timezone.utc).replace(tzinfo=None)\n except OverflowError as exc:\n raise ValueError(\"Timestamp is too large\") from exc\n except OSError as exc:\n raise ValueError(\"Error converting value to datetime\") from exc\n\n\ndef from_timestamp_ms(value: typing.Any) -> dt.datetime:\n value = float(value)\n return from_timestamp(value / 1000)\n\n\ndef timestamp(\n value: dt.datetime,\n) -> float:\n if not is_aware(value):\n # When a date is naive, use UTC as zone info to prevent using system timezone.\n value = value.replace(tzinfo=dt.timezone.utc)\n return value.timestamp()\n\n\ndef timestamp_ms(value: dt.datetime) -> float:\n return timestamp(value) * 1000\n\n\ndef isoformat(datetime: dt.datetime) -> str:\n \"\"\"Return the ISO8601-formatted representation of a datetime object.\n\n :param datetime datetime: The datetime.\n \"\"\"\n return datetime.isoformat()\n\n\ndef to_iso_time(time: dt.time) -> str:\n return dt.time.isoformat(time)\n\n\ndef to_iso_date(date: dt.date) -> str:\n return dt.date.isoformat(date)\n\n\ndef ensure_text_type(val: str | bytes) -> str:\n if isinstance(val, bytes):\n val = val.decode(\"utf-8\")\n return str(val)\n\n\ndef pluck(dictlist: list[dict[str, typing.Any]], key: str):\n \"\"\"Extracts a list of dictionary values from a list of dictionaries.\n ::\n\n >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]\n >>> pluck(dlist, 'id')\n [1, 2]\n \"\"\"\n return [d[key] for d in dictlist]\n\n\n# Various utilities for pulling keyed values from objects\n\n\ndef get_value(obj, key: int | str, default=missing):\n \"\"\"Helper for pulling a keyed value off various types of objects. 
Fields use\n this method by default to access attributes of the source object. For object `x`\n and attribute `i`, this method first tries to access `x[i]`, and then falls back to\n `x.i` if an exception is raised.\n\n .. warning::\n If an object `x` does not raise an exception when `x[i]` does not exist,\n `get_value` will never check the value `x.i`. Consider overriding\n `marshmallow.fields.Field.get_value` in this case.\n \"\"\"\n if not isinstance(key, int) and \".\" in key:\n return _get_value_for_keys(obj, key.split(\".\"), default)\n else:\n return _get_value_for_key(obj, key, default)\n\n\ndef _get_value_for_keys(obj, keys, default):\n if len(keys) == 1:\n return _get_value_for_key(obj, keys[0], default)\n else:\n return _get_value_for_keys(\n _get_value_for_key(obj, keys[0], default), keys[1:], default\n )\n\n\ndef _get_value_for_key(obj, key, default):\n if not hasattr(obj, \"__getitem__\"):\n return getattr(obj, key, default)\n\n try:\n return obj[key]\n except (KeyError, IndexError, TypeError, AttributeError):\n return getattr(obj, key, default)\n\n\ndef set_value(dct: dict[str, typing.Any], key: str, value: typing.Any):\n \"\"\"Set a value in a dict. If `key` contains a '.', it is assumed\n be a path (i.e. dot-delimited string) to the value's location.\n\n ::\n\n >>> d = {}\n >>> set_value(d, 'foo.bar', 42)\n >>> d\n {'foo': {'bar': 42}}\n \"\"\"\n if \".\" in key:\n head, rest = key.split(\".\", 1)\n target = dct.setdefault(head, {})\n if not isinstance(target, dict):\n raise ValueError(\n f\"Cannot set {key} in {head} \" f\"due to existing value: {target}\"\n )\n set_value(target, rest, value)\n else:\n dct[key] = value\n\n\ndef callable_or_raise(obj):\n \"\"\"Check that an object is callable, else raise a :exc:`TypeError`.\"\"\"\n if not callable(obj):\n raise TypeError(f\"Object {obj!r} is not callable.\")\n return obj\n\n\ndef _signature(func: typing.Callable) -> list[str]:\n return list(inspect.signature(func).parameters.keys())\n\n\ndef get_func_args(func: typing.Callable) -> list[str]:\n \"\"\"Given a callable, return a list of argument names. Handles\n `functools.partial` objects and class-based callables.\n\n .. versionchanged:: 3.0.0a1\n Do not return bound arguments, eg. 
``self``.\n \"\"\"\n if inspect.isfunction(func) or inspect.ismethod(func):\n return _signature(func)\n if isinstance(func, functools.partial):\n return _signature(func.func)\n # Callable class\n return _signature(func)\n\n\ndef resolve_field_instance(cls_or_instance):\n \"\"\"Return a Schema instance from a Schema class or instance.\n\n :param type|Schema cls_or_instance: Marshmallow Schema class or instance.\n \"\"\"\n if isinstance(cls_or_instance, type):\n if not issubclass(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance()\n else:\n if not isinstance(cls_or_instance, FieldABC):\n raise FieldInstanceResolutionError\n return cls_or_instance\n\n\ndef timedelta_to_microseconds(value: dt.timedelta) -> int:\n \"\"\"Compute the total microseconds of a timedelta\n\n https://github.com/python/cpython/blob/bb3e0c240bc60fe08d332ff5955d54197f79751c/Lib/datetime.py#L665-L667 # noqa: B950\n \"\"\"\n return (value.days * (24 * 3600) + value.seconds) * 1000000 + value.microseconds\n\n\ndef validate_unknown_parameter_value(obj: typing.Any) -> str:\n if obj not in _UNKNOWN_VALUES:\n raise ValueError(\n f\"Object {obj!r} is not a valid value for the 'unknown' parameter\"\n )\n return obj\n\n# Path: src/marshmallow/schema.py\n\"\"\"The :class:`Schema` class, including its metaclass and options (class Meta).\"\"\"\nfrom __future__ import annotations\n\nimport copy\nimport datetime as dt\nimport decimal\nimport inspect\nimport json\nimport typing\nimport uuid\nimport warnings\nfrom abc import ABCMeta\nfrom collections import OrderedDict, defaultdict\nfrom collections.abc import Mapping\nfrom functools import lru_cache\n\nfrom marshmallow import base, class_registry, types\nfrom marshmallow import fields as ma_fields\nfrom marshmallow.decorators import (\n POST_DUMP,\n POST_LOAD,\n PRE_DUMP,\n PRE_LOAD,\n VALIDATES,\n VALIDATES_SCHEMA,\n)\nfrom marshmallow.error_store import ErrorStore\nfrom marshmallow.exceptions import StringNotCollectionError, ValidationError\nfrom marshmallow.orderedset import OrderedSet\nfrom marshmallow.utils import (\n EXCLUDE,\n INCLUDE,\n RAISE,\n get_value,\n is_collection,\n is_instance_or_subclass,\n missing,\n set_value,\n validate_unknown_parameter_value,\n)\nfrom marshmallow.warnings import RemovedInMarshmallow4Warning\n\n_T = typing.TypeVar(\"_T\")\n\n\ndef _get_fields(attrs):\n \"\"\"Get fields from a class\n\n :param attrs: Mapping of class attributes\n \"\"\"\n return [\n (field_name, field_value)\n for field_name, field_value in attrs.items()\n if is_instance_or_subclass(field_value, base.FieldABC)\n ]\n\n\n# This function allows Schemas to inherit from non-Schema classes and ensures\n# inheritance according to the MRO\ndef _get_fields_by_mro(klass):\n \"\"\"Collect fields from a class, following its method resolution order. The\n class itself is excluded from the search; only its parents are checked. Get\n fields from ``_declared_fields`` if available, else use ``__dict__``.\n\n :param type klass: Class whose fields to retrieve\n \"\"\"\n mro = inspect.getmro(klass)\n # Loop over mro in reverse to maintain correct order of fields\n return sum(\n (\n _get_fields(\n getattr(base, \"_declared_fields\", base.__dict__),\n )\n for base in mro[:0:-1]\n ),\n [],\n )\n\n\nclass SchemaMeta(ABCMeta):\n \"\"\"Metaclass for the Schema class. Binds the declared fields to\n a ``_declared_fields`` attribute, which is a dictionary mapping attribute\n names to field objects. 
Also sets the ``opts`` class attribute, which is\n the Schema class's ``class Meta`` options.\n \"\"\"\n\n def __new__(mcs, name, bases, attrs):\n meta = attrs.get(\"Meta\")\n ordered = getattr(meta, \"ordered\", False)\n if not ordered:\n # Inherit 'ordered' option\n # Warning: We loop through bases instead of MRO because we don't\n # yet have access to the class object\n # (i.e. can't call super before we have fields)\n for base_ in bases:\n if hasattr(base_, \"Meta\") and hasattr(base_.Meta, \"ordered\"):\n ordered = base_.Meta.ordered\n break\n else:\n ordered = False\n cls_fields = _get_fields(attrs)\n # Remove fields from list of class attributes to avoid shadowing\n # Schema attributes/methods in case of name conflict\n for field_name, _ in cls_fields:\n del attrs[field_name]\n klass = super().__new__(mcs, name, bases, attrs)\n inherited_fields = _get_fields_by_mro(klass)\n\n meta = klass.Meta\n # Set klass.opts in __new__ rather than __init__ so that it is accessible in\n # get_declared_fields\n klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered)\n # Add fields specified in the `include` class Meta option\n cls_fields += list(klass.opts.include.items())\n\n # Assign _declared_fields on class\n klass._declared_fields = mcs.get_declared_fields(\n klass=klass,\n cls_fields=cls_fields,\n inherited_fields=inherited_fields,\n dict_cls=dict,\n )\n return klass\n\n @classmethod\n def get_declared_fields(\n mcs,\n klass: type,\n cls_fields: list,\n inherited_fields: list,\n dict_cls: type = dict,\n ):\n \"\"\"Returns a dictionary of field_name => `Field` pairs declared on the class.\n This is exposed mainly so that plugins can add additional fields, e.g. fields\n computed from class Meta options.\n\n :param klass: The class object.\n :param cls_fields: The fields declared on the class, including those added\n by the ``include`` class Meta option.\n :param inherited_fields: Inherited fields.\n :param dict_cls: dict-like class to use for dict output Default to ``dict``.\n \"\"\"\n return dict_cls(inherited_fields + cls_fields)\n\n def __init__(cls, name, bases, attrs):\n super().__init__(name, bases, attrs)\n if name and cls.opts.register:\n class_registry.register(name, cls)\n cls._hooks = cls.resolve_hooks()\n\n def resolve_hooks(cls) -> dict[types.Tag, list[str]]:\n \"\"\"Add in the decorated processors\n\n By doing this after constructing the class, we let standard inheritance\n do all the hard work.\n \"\"\"\n mro = inspect.getmro(cls)\n\n hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]\n\n for attr_name in dir(cls):\n # Need to look up the actual descriptor, not whatever might be\n # bound to the class. This needs to come from the __dict__ of the\n # declaring class.\n for parent in mro:\n try:\n attr = parent.__dict__[attr_name]\n except KeyError:\n continue\n else:\n break\n else:\n # In case we didn't find the attribute and didn't break above.\n # We should never hit this - it's just here for completeness\n # to exclude the possibility of attr being undefined.\n continue\n\n try:\n hook_config = attr.__marshmallow_hook__\n except AttributeError:\n pass\n else:\n for key in hook_config.keys():\n # Use name here so we can get the bound method later, in\n # case the processor was a descriptor or something.\n hooks[key].append(attr_name)\n\n return hooks\n\n\nclass SchemaOpts:\n \"\"\"class Meta options for the :class:`Schema`. 
Defines defaults.\"\"\"\n\n def __init__(self, meta, ordered: bool = False):\n self.fields = getattr(meta, \"fields\", ())\n if not isinstance(self.fields, (list, tuple)):\n raise ValueError(\"`fields` option must be a list or tuple.\")\n self.additional = getattr(meta, \"additional\", ())\n if not isinstance(self.additional, (list, tuple)):\n raise ValueError(\"`additional` option must be a list or tuple.\")\n if self.fields and self.additional:\n raise ValueError(\n \"Cannot set both `fields` and `additional` options\"\n \" for the same Schema.\"\n )\n self.exclude = getattr(meta, \"exclude\", ())\n if not isinstance(self.exclude, (list, tuple)):\n raise ValueError(\"`exclude` must be a list or tuple.\")\n self.dateformat = getattr(meta, \"dateformat\", None)\n self.datetimeformat = getattr(meta, \"datetimeformat\", None)\n self.timeformat = getattr(meta, \"timeformat\", None)\n if hasattr(meta, \"json_module\"):\n warnings.warn(\n \"The json_module class Meta option is deprecated. Use render_module instead.\",\n RemovedInMarshmallow4Warning,\n stacklevel=2,\n )\n render_module = getattr(meta, \"json_module\", json)\n else:\n render_module = json\n self.render_module = getattr(meta, \"render_module\", render_module)\n self.ordered = getattr(meta, \"ordered\", ordered)\n self.index_errors = getattr(meta, \"index_errors\", True)\n self.include = getattr(meta, \"include\", {})\n self.load_only = getattr(meta, \"load_only\", ())\n self.dump_only = getattr(meta, \"dump_only\", ())\n self.unknown = validate_unknown_parameter_value(getattr(meta, \"unknown\", RAISE))\n self.register = getattr(meta, \"register\", True)\n\n\nclass Schema(base.SchemaABC, metaclass=SchemaMeta):\n \"\"\"Base schema class with which to define custom schemas.\n\n Example usage:\n\n .. code-block:: python\n\n import datetime as dt\n from dataclasses import dataclass\n\n from marshmallow import Schema, fields\n\n\n @dataclass\n class Album:\n title: str\n release_date: dt.date\n\n\n class AlbumSchema(Schema):\n title = fields.Str()\n release_date = fields.Date()\n\n\n album = Album(\"Beggars Banquet\", dt.date(1968, 12, 6))\n schema = AlbumSchema()\n data = schema.dump(album)\n data # {'release_date': '1968-12-06', 'title': 'Beggars Banquet'}\n\n :param only: Whitelist of the declared fields to select when\n instantiating the Schema. If None, all fields are used. Nested fields\n can be represented with dot delimiters.\n :param exclude: Blacklist of the declared fields to exclude\n when instantiating the Schema. If a field appears in both `only` and\n `exclude`, it is not used. Nested fields can be represented with dot\n delimiters.\n :param many: Should be set to `True` if ``obj`` is a collection\n so that the object will be serialized to a list.\n :param context: Optional context passed to :class:`fields.Method` and\n :class:`fields.Function` fields.\n :param load_only: Fields to skip during serialization (write-only fields)\n :param dump_only: Fields to skip during deserialization (read-only fields)\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n\n .. versionchanged:: 3.0.0\n `prefix` parameter removed.\n\n .. 
versionchanged:: 2.0.0\n `__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of\n `marshmallow.decorators.validates_schema`,\n `marshmallow.decorators.pre_load` and `marshmallow.decorators.post_dump`.\n `__accessor__` and `__error_handler__` are deprecated. Implement the\n `handle_error` and `get_attribute` methods instead.\n \"\"\"\n\n TYPE_MAPPING = {\n str: ma_fields.String,\n bytes: ma_fields.String,\n dt.datetime: ma_fields.DateTime,\n float: ma_fields.Float,\n bool: ma_fields.Boolean,\n tuple: ma_fields.Raw,\n list: ma_fields.Raw,\n set: ma_fields.Raw,\n int: ma_fields.Integer,\n uuid.UUID: ma_fields.UUID,\n dt.time: ma_fields.Time,\n dt.date: ma_fields.Date,\n dt.timedelta: ma_fields.TimeDelta,\n decimal.Decimal: ma_fields.Decimal,\n } # type: typing.Dict[type, typing.Type[ma_fields.Field]]\n #: Overrides for default schema-level error messages\n error_messages = {} # type: typing.Dict[str, str]\n\n _default_error_messages = {\n \"type\": \"Invalid input type.\",\n \"unknown\": \"Unknown field.\",\n } # type: typing.Dict[str, str]\n\n OPTIONS_CLASS = SchemaOpts # type: type\n\n set_class = OrderedSet\n\n # These get set by SchemaMeta\n opts = None # type: SchemaOpts\n _declared_fields = {} # type: typing.Dict[str, ma_fields.Field]\n _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]\n\n class Meta:\n \"\"\"Options object for a Schema.\n\n Example usage: ::\n\n class Meta:\n fields = (\"id\", \"email\", \"date_created\")\n exclude = (\"password\", \"secret_attribute\")\n\n Available options:\n\n - ``fields``: Tuple or list of fields to include in the serialized result.\n - ``additional``: Tuple or list of fields to include *in addition* to the\n explicitly declared fields. ``additional`` and ``fields`` are\n mutually-exclusive options.\n - ``include``: Dictionary of additional fields to include in the schema. It is\n usually better to define fields as class variables, but you may need to\n use this option, e.g., if your fields are Python keywords. May be an\n `OrderedDict`.\n - ``exclude``: Tuple or list of fields to exclude in the serialized result.\n Nested fields can be represented with dot delimiters.\n - ``dateformat``: Default format for `Date ` fields.\n - ``datetimeformat``: Default format for `DateTime ` fields.\n - ``timeformat``: Default format for `Time ` fields.\n - ``render_module``: Module to use for `loads ` and `dumps `.\n Defaults to `json` from the standard library.\n - ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.\n - ``index_errors``: If `True`, errors dictionaries will include the index\n of invalid items in a collection.\n - ``load_only``: Tuple or list of fields to exclude from serialized results.\n - ``dump_only``: Tuple or list of fields to exclude from deserialization\n - ``unknown``: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n - ``register``: Whether to register the `Schema` with marshmallow's internal\n class registry. Must be `True` if you intend to refer to this `Schema`\n by class name in `Nested` fields. Only set this to `False` when memory\n usage is critical. 
Defaults to `True`.\n \"\"\"\n\n def __init__(\n self,\n *,\n only: types.StrSequenceOrSet | None = None,\n exclude: types.StrSequenceOrSet = (),\n many: bool = False,\n context: dict | None = None,\n load_only: types.StrSequenceOrSet = (),\n dump_only: types.StrSequenceOrSet = (),\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n # Raise error if only or exclude is passed as string, not list of strings\n if only is not None and not is_collection(only):\n raise StringNotCollectionError('\"only\" should be a list of strings')\n if not is_collection(exclude):\n raise StringNotCollectionError('\"exclude\" should be a list of strings')\n # copy declared fields from metaclass\n self.declared_fields = copy.deepcopy(self._declared_fields)\n self.many = many\n self.only = only\n self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(\n self.opts.exclude\n ) | set(exclude)\n self.ordered = self.opts.ordered\n self.load_only = set(load_only) or set(self.opts.load_only)\n self.dump_only = set(dump_only) or set(self.opts.dump_only)\n self.partial = partial\n self.unknown = (\n self.opts.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n self.context = context or {}\n self._normalize_nested_options()\n #: Dictionary mapping field_names -> :class:`Field` objects\n self.fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]\n self._init_fields()\n messages = {}\n messages.update(self._default_error_messages)\n for cls in reversed(self.__class__.__mro__):\n messages.update(getattr(cls, \"error_messages\", {}))\n messages.update(self.error_messages or {})\n self.error_messages = messages\n\n def __repr__(self) -> str:\n return f\"<{self.__class__.__name__}(many={self.many})>\"\n\n @property\n def dict_class(self) -> type:\n return OrderedDict if self.ordered else dict\n\n @classmethod\n def from_dict(\n cls,\n fields: dict[str, ma_fields.Field | type],\n *,\n name: str = \"GeneratedSchema\",\n ) -> type:\n \"\"\"Generate a `Schema` class given a dictionary of fields.\n\n .. code-block:: python\n\n from marshmallow import Schema, fields\n\n PersonSchema = Schema.from_dict({\"name\": fields.Str()})\n print(PersonSchema().load({\"name\": \"David\"})) # => {'name': 'David'}\n\n Generated schemas are not added to the class registry and therefore cannot\n be referred to by name in `Nested` fields.\n\n :param dict fields: Dictionary mapping field names to field instances.\n :param str name: Optional name for the class, which will appear in\n the ``repr`` for the class.\n\n .. versionadded:: 3.0.0\n \"\"\"\n attrs = fields.copy()\n attrs[\"Meta\"] = type(\n \"GeneratedMeta\", (getattr(cls, \"Meta\", object),), {\"register\": False}\n )\n schema_cls = type(name, (cls,), attrs)\n return schema_cls\n\n ##### Override-able methods #####\n\n def handle_error(\n self, error: ValidationError, data: typing.Any, *, many: bool, **kwargs\n ):\n \"\"\"Custom error handler function for the schema.\n\n :param error: The `ValidationError` raised during (de)serialization.\n :param data: The original input data.\n :param many: Value of ``many`` on dump or load.\n :param partial: Value of ``partial`` on load.\n\n .. versionadded:: 2.0.0\n\n .. 
versionchanged:: 3.0.0rc9\n Receives `many` and `partial` (on deserialization) as keyword arguments.\n \"\"\"\n pass\n\n def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):\n \"\"\"Defines how to pull values from an object to serialize.\n\n .. versionadded:: 2.0.0\n\n .. versionchanged:: 3.0.0a1\n Changed position of ``obj`` and ``attr``.\n \"\"\"\n return get_value(obj, attr, default)\n\n ##### Serialization/Deserialization API #####\n\n @staticmethod\n def _call_and_store(getter_func, data, *, field_name, error_store, index=None):\n \"\"\"Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.\n\n :param callable getter_func: Function for getting the serialized/deserialized\n value from ``data``.\n :param data: The data passed to ``getter_func``.\n :param str field_name: Field name.\n :param int index: Index of the item being validated, if validating a collection,\n otherwise `None`.\n \"\"\"\n try:\n value = getter_func(data)\n except ValidationError as error:\n error_store.store_error(error.messages, field_name, index=index)\n # When a Nested field fails validation, the marshalled data is stored\n # on the ValidationError's valid_data attribute\n return error.valid_data or missing\n return value\n\n def _serialize(self, obj: _T | typing.Iterable[_T], *, many: bool = False):\n \"\"\"Serialize ``obj``.\n\n :param obj: The object(s) to serialize.\n :param bool many: `True` if ``data`` should be serialized as a collection.\n :return: A dictionary of the serialized data\n\n .. versionchanged:: 1.0.0\n Renamed from ``marshal``.\n \"\"\"\n if many and obj is not None:\n return [\n self._serialize(d, many=False)\n for d in typing.cast(typing.Iterable[_T], obj)\n ]\n ret = self.dict_class()\n for attr_name, field_obj in self.dump_fields.items():\n value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)\n if value is missing:\n continue\n key = field_obj.data_key if field_obj.data_key is not None else attr_name\n ret[key] = value\n return ret\n\n def dump(self, obj: typing.Any, *, many: bool | None = None):\n \"\"\"Serialize an object to native Python data types according to this\n Schema's fields.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: Serialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n .. versionchanged:: 3.0.0rc9\n Validation no longer occurs upon serialization.\n \"\"\"\n many = self.many if many is None else bool(many)\n if self._has_processors(PRE_DUMP):\n processed_obj = self._invoke_dump_processors(\n PRE_DUMP, obj, many=many, original_data=obj\n )\n else:\n processed_obj = obj\n\n result = self._serialize(processed_obj, many=many)\n\n if self._has_processors(POST_DUMP):\n result = self._invoke_dump_processors(\n POST_DUMP, result, many=many, original_data=obj\n )\n\n return result\n\n def dumps(self, obj: typing.Any, *args, many: bool | None = None, **kwargs):\n \"\"\"Same as :meth:`dump`, except return a JSON-encoded string.\n\n :param obj: The object to serialize.\n :param many: Whether to serialize `obj` as a collection. If `None`, the value\n for `self.many` is used.\n :return: A ``json`` string\n\n .. versionadded:: 1.0.0\n .. 
versionchanged:: 3.0.0b7\n This method returns the serialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if ``obj`` is invalid.\n \"\"\"\n serialized = self.dump(obj, many=many)\n return self.opts.render_module.dumps(serialized, *args, **kwargs)\n\n def _deserialize(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n error_store: ErrorStore,\n many: bool = False,\n partial=None,\n unknown=RAISE,\n index=None,\n ) -> _T | list[_T]:\n \"\"\"Deserialize ``data``.\n\n :param dict data: The data to deserialize.\n :param ErrorStore error_store: Structure to store errors.\n :param bool many: `True` if ``data`` should be deserialized as a collection.\n :param bool|tuple partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n :param int index: Index of the item being serialized (for storing errors) if\n serializing a collection, otherwise `None`.\n :return: A dictionary of the deserialized data.\n \"\"\"\n index_errors = self.opts.index_errors\n index = index if index_errors else None\n if many:\n if not is_collection(data):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n ret_l = [] # type: typing.List[_T]\n else:\n ret_l = [\n typing.cast(\n _T,\n self._deserialize(\n typing.cast(typing.Mapping[str, typing.Any], d),\n error_store=error_store,\n many=False,\n partial=partial,\n unknown=unknown,\n index=idx,\n ),\n )\n for idx, d in enumerate(data)\n ]\n return ret_l\n ret_d = self.dict_class()\n # Check data is a dict\n if not isinstance(data, Mapping):\n error_store.store_error([self.error_messages[\"type\"]], index=index)\n else:\n partial_is_collection = is_collection(partial)\n for attr_name, field_obj in self.load_fields.items():\n field_name = (\n field_obj.data_key if field_obj.data_key is not None else attr_name\n )\n raw_value = data.get(field_name, missing)\n if raw_value is missing:\n # Ignore missing field if we're allowed to.\n if partial is True or (\n partial_is_collection and attr_name in partial\n ):\n continue\n d_kwargs = {}\n # Allow partial loading of nested schemas.\n if partial_is_collection:\n prefix = field_name + \".\"\n len_prefix = len(prefix)\n sub_partial = [\n f[len_prefix:] for f in partial if f.startswith(prefix)\n ]\n d_kwargs[\"partial\"] = sub_partial\n elif partial is not None:\n d_kwargs[\"partial\"] = partial\n\n def getter(\n val, field_obj=field_obj, field_name=field_name, d_kwargs=d_kwargs\n ):\n return field_obj.deserialize(\n val,\n field_name,\n data,\n **d_kwargs,\n )\n\n value = self._call_and_store(\n getter_func=getter,\n data=raw_value,\n field_name=field_name,\n error_store=error_store,\n index=index,\n )\n if value is not missing:\n key = field_obj.attribute or attr_name\n set_value(ret_d, key, value)\n if unknown != EXCLUDE:\n fields = {\n field_obj.data_key if field_obj.data_key is not None else field_name\n for field_name, field_obj in self.load_fields.items()\n }\n for key in set(data) - fields:\n value = data[key]\n if unknown == INCLUDE:\n ret_d[key] = value\n elif unknown == RAISE:\n error_store.store_error(\n [self.error_messages[\"unknown\"]],\n key,\n (index if 
index_errors else None),\n )\n return ret_d\n\n def load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n ):\n \"\"\"Deserialize a data structure to an object defined by this Schema's fields.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n return self._do_load(\n data, many=many, partial=partial, unknown=unknown, postprocess=True\n )\n\n def loads(\n self,\n json_data: str,\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n **kwargs,\n ):\n \"\"\"Same as :meth:`load`, except it takes a JSON string as input.\n\n :param json_data: A JSON string of the data to deserialize.\n :param many: Whether to deserialize `obj` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :return: Deserialized data\n\n .. versionadded:: 1.0.0\n .. versionchanged:: 3.0.0b7\n This method returns the deserialized data rather than a ``(data, errors)`` duple.\n A :exc:`ValidationError ` is raised\n if invalid data are passed.\n \"\"\"\n data = self.opts.render_module.loads(json_data, **kwargs)\n return self.load(data, many=many, partial=partial, unknown=unknown)\n\n def _run_validator(\n self,\n validator_func,\n output,\n *,\n original_data,\n error_store,\n many,\n partial,\n pass_original,\n index=None,\n ):\n try:\n if pass_original: # Pass original, raw data (before unmarshalling)\n validator_func(output, original_data, partial=partial, many=many)\n else:\n validator_func(output, partial=partial, many=many)\n except ValidationError as err:\n error_store.store_error(err.messages, err.field_name, index=index)\n\n def validate(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n ) -> dict[str, list[str]]:\n \"\"\"Validate `data` against the schema, returning a dictionary of\n validation errors.\n\n :param data: The data to validate.\n :param many: Whether to validate `data` as a collection. 
If `None`, the\n value for `self.many` is used.\n :param partial: Whether to ignore missing fields and not require\n any fields declared. Propagates down to ``Nested`` fields as well. If\n its value is an iterable, only missing fields listed in that iterable\n will be ignored. Use dot delimiters to specify nested fields.\n :return: A dictionary of validation errors.\n\n .. versionadded:: 1.1.0\n \"\"\"\n try:\n self._do_load(data, many=many, partial=partial, postprocess=False)\n except ValidationError as exc:\n return typing.cast(typing.Dict[str, typing.List[str]], exc.messages)\n return {}\n\n ##### Private Helpers #####\n\n def _do_load(\n self,\n data: (\n typing.Mapping[str, typing.Any]\n | typing.Iterable[typing.Mapping[str, typing.Any]]\n ),\n *,\n many: bool | None = None,\n partial: bool | types.StrSequenceOrSet | None = None,\n unknown: str | None = None,\n postprocess: bool = True,\n ):\n \"\"\"Deserialize `data`, returning the deserialized result.\n This method is private API.\n\n :param data: The data to deserialize.\n :param many: Whether to deserialize `data` as a collection. If `None`, the\n value for `self.many` is used.\n :param partial: Whether to validate required fields. If its\n value is an iterable, only fields listed in that iterable will be\n ignored will be allowed missing. If `True`, all fields will be allowed missing.\n If `None`, the value for `self.partial` is used.\n :param unknown: Whether to exclude, include, or raise an error for unknown\n fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.\n If `None`, the value for `self.unknown` is used.\n :param postprocess: Whether to run post_load methods..\n :return: Deserialized data\n \"\"\"\n error_store = ErrorStore()\n errors = {} # type: dict[str, list[str]]\n many = self.many if many is None else bool(many)\n unknown = (\n self.unknown\n if unknown is None\n else validate_unknown_parameter_value(unknown)\n )\n if partial is None:\n partial = self.partial\n # Run preprocessors\n if self._has_processors(PRE_LOAD):\n try:\n processed_data = self._invoke_load_processors(\n PRE_LOAD, data, many=many, original_data=data, partial=partial\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n result = None # type: list | dict | None\n else:\n processed_data = data\n if not errors:\n # Deserialize data\n result = self._deserialize(\n processed_data,\n error_store=error_store,\n many=many,\n partial=partial,\n unknown=unknown,\n )\n # Run field-level validation\n self._invoke_field_validators(\n error_store=error_store, data=result, many=many\n )\n # Run schema-level validation\n if self._has_processors(VALIDATES_SCHEMA):\n field_errors = bool(error_store.errors)\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=True,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n self._invoke_schema_validators(\n error_store=error_store,\n pass_many=False,\n data=result,\n original_data=data,\n many=many,\n partial=partial,\n field_errors=field_errors,\n )\n errors = error_store.errors\n # Run post processors\n if not errors and postprocess and self._has_processors(POST_LOAD):\n try:\n result = self._invoke_load_processors(\n POST_LOAD,\n result,\n many=many,\n original_data=data,\n partial=partial,\n )\n except ValidationError as err:\n errors = err.normalized_messages()\n if errors:\n exc = ValidationError(errors, data=data, valid_data=result)\n self.handle_error(exc, data, many=many, partial=partial)\n raise exc\n\n return 
result\n\n \ndef _normalize_nested_options(self) -> None:\n \"\"\"Apply then flatten nested schema options.\n This method is private API.\n \"\"\"\n if self.only is not None:\n # Apply the only option to nested fields.\n self.__apply_nested_option(\"only\", self.only, \"intersection\")\n # Remove the child field names from the only option.\n self.only = self.set_class([field.split(\".\", 1)[0] for field in self.only])\n if self.exclude:\n # Apply the exclude option to nested fields.\n self.__apply_nested_option(\"exclude\", self.exclude, \"union\")\n # Remove the parent field names from the exclude option.\n self.exclude = self.set_class(\n [field for field in self.exclude if \".\" not in field]\n )\n\n def __apply_nested_option(self, option_name, field_names, set_operation) -> None:\n \"\"\"Apply nested options to nested fields\"\"\"\n # Split nested field names on the first dot.\n nested_fields = [name.split(\".\", 1) for name in field_names if \".\" in name]\n # Partition the nested field names by parent field.\n nested_options = defaultdict(list) # type: defaultdict\n for parent, nested_names in nested_fields:\n nested_options[parent].append(nested_names)\n # Apply the nested field options.\n for key, options in iter(nested_options.items()):\n new_options = self.set_class(options)\n original_options = getattr(self.declared_fields[key], option_name, ())\n if original_options:\n if set_operation == \"union\":\n new_options |= self.set_class(original_options)\n if set_operation == \"intersection\":\n new_options &= self.set_class(original_options)\n setattr(self.declared_fields[key], option_name, new_options)\n\n def _init_fields(self) -> None:\n \"\"\"Update self.fields, self.load_fields, and self.dump_fields based on schema options.\n This method is private API.\n \"\"\"\n if self.opts.fields:\n available_field_names = self.set_class(self.opts.fields)\n else:\n available_field_names = self.set_class(self.declared_fields.keys())\n if self.opts.additional:\n available_field_names |= self.set_class(self.opts.additional)\n\n invalid_fields = self.set_class()\n\n if self.only is not None:\n # Return only fields specified in only option\n field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)\n\n invalid_fields |= field_names - available_field_names\n else:\n field_names = available_field_names\n\n # If \"exclude\" option or param is specified, remove those fields.\n if self.exclude:\n # Note that this isn't available_field_names, since we want to\n # apply \"only\" for the actual calculation.\n field_names = field_names - self.exclude\n invalid_fields |= self.exclude - available_field_names\n\n if invalid_fields:\n message = f\"Invalid fields for {self}: {invalid_fields}.\"\n raise ValueError(message)\n\n fields_dict = self.dict_class()\n for field_name in field_names:\n field_obj = self.declared_fields.get(field_name, ma_fields.Inferred())\n...\n", "needle_token_start": 15237, "needle_token_end": 15446, "code_context_ntokens": 16143, "output": ["```python\ndef _normalize_nested_options(self) -> None:\n \"\"\"Apply then flatten nested schema options.\n This method is private API.\n \"\"\"\n if self.only is not None:\n # Apply the only option to nested fields.\n self.__apply_nested_option(\"only\", self.only, \"intersection\")\n # Remove the child field names from the only option.\n self.only = self.set_class([field.split(\".\", 1)[0] for field in self.only])\n if self.exclude:\n # Apply the exclude option to nested fields.\n self.__apply_nested_option(\"exclude\", 
self.exclude, \"union\")\n # Remove the parent field names from the exclude option.\n self.exclude = self.set_class(\n [field for field in self.exclude if \".\" not in field]\n )\n\ndef __apply_nested_option(self, option_name, field_names, set_operation) -> None:\n \"\"\"Apply nested options to nested fields\"\"\"\n # Split nested field names on the first dot.\n nested_fields = [name.split(\".\", 1) for name in field_names if \".\" in name]\n # Partition the nested field names by parent field.\n nested_options = defaultdict(list) # type: defaultdict\n for parent, nested_names in nested_fields:\n nested_options[parent].append(nested_names)\n # Apply the nested field options.\n for key, options in iter(nested_options.items()):\n new_options = self.set_class(options)\n original_options = getattr(self.declared_fields[key], option_name, ())\n if original_options:\n if set_operation == \"union\":\n new_options |= self.set_class(original_options)\n if set_operation == \"intersection\":\n new_options &= self.set_class(original_options)\n setattr(self.declared_fields[key], option_name, new_options)\n```"]} +{"repo": "ethereum/web3.py", "name": "test_eth_get_logs_with_logs_none_topic_args", "language": "python", "path": "web3/_utils/module_testing/eth_module.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To verify that querying blockchain logs with a filter containing only `None` values for topics results in no logs being returned.\n2. **Input**: A filter parameter set with the starting block number and an array of topics filled entirely with `None`.\n3. **Output**: An empty list of logs, confirming that no entries match the given filter criteria.\n4. **Procedure**: The function sets up a filter to search from the genesis block and specifies three `None` values in the topics array. It then queries the blockchain logs using this filter and checks that the result is an empty list, indicating no logs were found that match the filter.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " ) -> None:\n # Test with block range\n\n # the range includes the block where the log resides in\n filter_params: FilterParams = {\n \"fromBlock\": block_with_txn_with_log[\"number\"],\n \"toBlock\": block_with_txn_with_log[\"number\"],\n }\n result = w3.eth.get_logs(filter_params)\n assert_contains_log(\n result, block_with_txn_with_log, emitter_contract_address, txn_hash_with_log\n )\n\n # specify only `from_block`. 
by default `to_block` should be 'latest'\n filter_params = {\n \"fromBlock\": BlockNumber(0),\n }\n result = w3.eth.get_logs(filter_params)\n assert_contains_log(\n result, block_with_txn_with_log, emitter_contract_address, txn_hash_with_log\n )\n\n # Test with `address`\n\n # filter with emitter_contract.address\n filter_params = {\n \"fromBlock\": BlockNumber(0),\n \"address\": emitter_contract_address,\n }\n\n def test_eth_get_logs_with_logs_topic_args(\n self,\n w3: \"Web3\",\n block_with_txn_with_log: BlockData,\n emitter_contract_address: ChecksumAddress,\n txn_hash_with_log: HexStr,\n ) -> None:\n # Test with None event sig\n\n filter_params: FilterParams = {\n \"fromBlock\": BlockNumber(0),\n \"topics\": [\n None,\n HexStr(\n \"0x000000000000000000000000000000000000000000000000000000000000d431\"\n ),\n ],\n }\n\n result = w3.eth.get_logs(filter_params)\n assert_contains_log(\n result, block_with_txn_with_log, emitter_contract_address, txn_hash_with_log\n )\n\n # Test with None indexed arg\n filter_params = {\n \"fromBlock\": BlockNumber(0),\n \"topics\": [\n HexStr(\n \"0x057bc32826fbe161da1c110afcdcae7c109a8b69149f727fc37a603c60ef94ca\"\n ),\n None,\n ],\n }\n result = w3.eth.get_logs(filter_params)\n assert_contains_log(\n result, block_with_txn_with_log, emitter_contract_address, txn_hash_with_log\n )\n\n def \ntest_eth_get_logs_with_logs_none_topic_args(self, w3: \"Web3\") -> None:\n # Test with None overflowing\n filter_params: FilterParams = {\n \"fromBlock\": BlockNumber(0),\n \"topics\": [None, None, None],\n }\n\n result = w3.eth.get_logs(filter_params)\n assert len(result) == 0\n\n def test_eth_call_old_contract_state(\n self, w3: \"Web3\", math_contract: \"Contract\", unlocked_account: ChecksumAddress\n ) -> None:\n start_block = w3.eth.get_block(\"latest\")\n block_num = start_block[\"number\"]\n block_hash = start_block[\"hash\"]\n\n math_contract.functions.incrementCounter().transact({\"from\": unlocked_account})\n\n # This isn't an incredibly convincing test since we can't mine, and\n # the default resolved block is latest, So if block_identifier was ignored\n # we would get the same result. 
For now, we mostly depend on core tests.\n # Ideas to improve this test:\n # - Enable on-demand mining in more clients\n # - Increment the math contract in all of the fixtures, and check the\n # value in an old block\n block_hash_call_result = math_contract.functions.counter().call(\n block_identifier=block_hash\n )\n block_num_call_result = math_contract.functions.counter().call(\n block_identifier=block_num\n )\n latest_call_result = math_contract.functions.counter().call(\n block_identifier=\"latest\"\n )\n default_call_result = math_contract.functions.counter().call()\n pending_call_result = math_contract.functions.counter().call(\n block_identifier=\"pending\"\n )\n\n assert block_hash_call_result == 0\n assert block_num_call_result == 0\n assert latest_call_result == 0\n assert default_call_result == 0\n\n if pending_call_result != 1:\n raise AssertionError(\n f\"pending call result was {pending_call_result} instead of 1\"\n )\n\n def test_eth_uninstall_filter(self, w3: \"Web3\") -> None:\n filter = w3.eth.filter({})\n assert is_string(filter.filter_id)\n\n success = w3.eth.uninstall_filter(filter.filter_id)\n assert success is True\n\n failure = w3.eth.uninstall_filter(filter.filter_id)\n assert failure is False\n\n def test_eth_get_raw_transaction(self, w3: \"Web3\", mined_txn_hash: HexStr) -> None:\n raw_transaction = w3.eth.get_raw_transaction(mined_txn_hash)\n assert is_bytes(raw_transaction)\n\n def test_eth_get_raw_transaction_raises_error(self, w3: \"Web3\") -> None:\n with pytest.raises(\n TransactionNotFound, match=f\"Transaction with hash: '{UNKNOWN_HASH}'\"\n ):\n w3.eth.get_raw_transaction(UNKNOWN_HASH)\n\n def test_eth_get_raw_transaction_by_block(\n self,\n w3: \"Web3\",\n unlocked_account_dual_type: ChecksumAddress,\n block_with_txn: BlockData,\n ) -> None:\n # eth_getRawTransactionByBlockNumberAndIndex: block identifier\n # send a txn to make sure pending block has at least one txn\n w3.eth.send_transaction(\n {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n }\n )\n last_pending_txn_index = len(w3.eth.get_block(\"pending\")[\"transactions\"]) - 1\n raw_transaction = w3.eth.get_raw_transaction_by_block(\n \"pending\", last_pending_txn_index\n )\n assert is_bytes(raw_transaction)\n\n # eth_getRawTransactionByBlockNumberAndIndex: block number\n block_with_txn_number = block_with_txn[\"number\"]\n assert is_integer(block_with_txn_number)\n raw_transaction = w3.eth.get_raw_transaction_by_block(block_with_txn_number, 0)\n assert is_bytes(raw_transaction)\n\n # eth_getRawTransactionByBlockHashAndIndex: block hash\n block_with_txn_hash = block_with_txn[\"hash\"]\n assert is_bytes(block_with_txn_hash)\n raw_transaction = w3.eth.get_raw_transaction_by_block(block_with_txn_hash, 0)\n assert is_bytes(raw_transaction)\n\n @pytest.mark.parametrize(\"unknown_block_num_or_hash\", (1234567899999, UNKNOWN_HASH))\n def test_eth_get_raw_transaction_by_block_raises_error(\n self, w3: \"Web3\", unknown_block_num_or_hash: Union[int, HexBytes]\n ) -> None:\n with pytest.raises(\n TransactionNotFound,\n match=(\n f\"Transaction index: 0 on block id: \"\n f\"{to_hex_if_integer(unknown_block_num_or_hash)!r} not found.\"\n ),\n ):\n w3.eth.get_raw_transaction_by_block(unknown_block_num_or_hash, 0)\n\n def test_eth_get_raw_transaction_by_block_raises_error_block_identifier(\n self, w3: \"Web3\"\n ) -> None:\n unknown_identifier = \"unknown\"\n with pytest.raises(\n ValueError,\n match=(\n \"Value did not match any of the recognized block 
identifiers: \"\n f\"{unknown_identifier}\"\n ),\n ):\n w3.eth.get_raw_transaction_by_block(unknown_identifier, 0) # type: ignore\n\n def test_default_account(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n # check defaults to empty\n default_account = w3.eth.default_account\n assert default_account is empty\n\n # check setter\n w3.eth.default_account = unlocked_account_dual_type\n default_account = w3.eth.default_account\n assert default_account == unlocked_account_dual_type\n\n # reset to default\n w3.eth.default_account = empty\n\n def test_default_block(\n self,\n w3: \"Web3\",\n ) -> None:\n # check defaults to 'latest'\n default_block = w3.eth.default_block\n assert default_block == \"latest\"\n\n # check setter\n w3.eth.default_block = BlockNumber(12345)\n default_block = w3.eth.default_block\n assert default_block == BlockNumber(12345)\n\n # reset to default\n w3.eth.default_block = \"latest\"\n\n# Path: web3/_utils/module_testing/go_ethereum_admin_module.py\nimport pytest\nfrom typing import (\n TYPE_CHECKING,\n List,\n)\n\nfrom web3.datastructures import (\n AttributeDict,\n)\nfrom web3.types import (\n EnodeURI,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\n\nclass GoEthereumAdminModuleTest:\n def test_add_peer(self, w3: \"Web3\") -> None:\n result = w3.geth.admin.add_peer(\n EnodeURI(\n \"enode://f1a6b0bdbf014355587c3018454d070ac57801f05d3b39fe85da574f002a32e929f683d72aa5a8318382e4d3c7a05c9b91687b0d997a39619fb8a6e7ad88e512@1.1.1.1:30303\" # noqa: E501\n ),\n )\n assert result is True\n\n def test_admin_datadir(self, w3: \"Web3\", datadir: str) -> None:\n result = w3.geth.admin.datadir()\n assert result == datadir\n\n def test_admin_node_info(self, w3: \"Web3\") -> None:\n result = w3.geth.admin.node_info()\n expected = AttributeDict(\n {\n \"id\": \"\",\n \"name\": \"\",\n \"enode\": \"\",\n \"ip\": \"\",\n \"ports\": AttributeDict({}),\n \"listenAddr\": \"\",\n \"protocols\": AttributeDict({}),\n }\n )\n # Test that result gives at least the keys that are listed in `expected`\n assert not set(expected.keys()).difference(result.keys())\n\n def test_admin_peers(self, w3: \"Web3\") -> None:\n enode = w3.geth.admin.node_info()[\"enode\"]\n w3.geth.admin.add_peer(enode)\n result = w3.geth.admin.peers()\n assert len(result) == 1\n\n def test_admin_start_stop_http(self, w3: \"Web3\") -> None:\n stop = w3.geth.admin.stop_http()\n assert stop is True\n\n start = w3.geth.admin.start_http()\n assert start is True\n\n def test_admin_start_stop_ws(self, w3: \"Web3\") -> None:\n stop = w3.geth.admin.stop_ws()\n assert stop is True\n\n start = w3.geth.admin.start_ws()\n assert start is True\n\n\nclass GoEthereumAsyncAdminModuleTest:\n @pytest.mark.asyncio\n async def test_async_datadir(self, async_w3: \"AsyncWeb3\") -> None:\n datadir = await async_w3.geth.admin.datadir()\n assert isinstance(datadir, str)\n\n @pytest.mark.asyncio\n async def test_async_node_info(self, async_w3: \"AsyncWeb3\") -> None:\n node_info = await async_w3.geth.admin.node_info()\n assert \"Geth\" in node_info[\"name\"]\n\n @pytest.mark.asyncio\n async def test_async_nodes(self, async_w3: \"AsyncWeb3\") -> None:\n nodes = await async_w3.geth.admin.peers()\n assert isinstance(nodes, List)\n\n @pytest.mark.asyncio\n async def test_admin_peers(self, async_w3: \"AsyncWeb3\") -> None:\n node_info = await async_w3.geth.admin.node_info()\n await async_w3.geth.admin.add_peer(node_info[\"enode\"])\n result = await async_w3.geth.admin.peers()\n assert 
len(result) == 1\n\n @pytest.mark.asyncio\n async def test_admin_start_stop_http(self, async_w3: \"AsyncWeb3\") -> None:\n stop = await async_w3.geth.admin.stop_http()\n assert stop is True\n\n start = await async_w3.geth.admin.start_http()\n assert start is True\n\n @pytest.mark.asyncio\n async def test_admin_start_stop_ws(self, async_w3: \"AsyncWeb3\") -> None:\n stop = await async_w3.geth.admin.stop_ws()\n assert stop is True\n\n start = await async_w3.geth.admin.start_ws()\n assert start is True\n\n# Path: web3/_utils/module_testing/go_ethereum_personal_module.py\nimport json\nimport pytest\nfrom typing import (\n TYPE_CHECKING,\n cast,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n)\nfrom eth_utils import (\n is_checksum_address,\n is_list_like,\n is_same_address,\n is_string,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3 import (\n constants,\n)\nfrom web3.datastructures import (\n AttributeDict,\n)\nfrom web3.types import (\n TxParams,\n Wei,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\nPRIVATE_KEY_HEX = \"0x56ebb41875ceedd42e395f730e03b5c44989393c9f0484ee6bc05f933673458f\"\nSECOND_PRIVATE_KEY_HEX = (\n \"0x56ebb41875ceedd42e395f730e03b5c44989393c9f0484ee6bc05f9336712345\"\n)\nTHIRD_PRIVATE_KEY_HEX = (\n \"0x56ebb41875ceedd42e395f730e03b5c44989393c9f0484ee6bc05f9336754321\"\n)\nPASSWORD = \"web3-testing\"\nADDRESS = \"0x844B417c0C58B02c2224306047B9fb0D3264fE8c\"\nSECOND_ADDRESS = \"0xB96b6B21053e67BA59907E252D990C71742c41B8\"\n\n\nPRIVATE_KEY_FOR_UNLOCK = (\n \"0x392f63a79b1ff8774845f3fa69de4a13800a59e7083f5187f1558f0797ad0f01\"\n)\nACCOUNT_FOR_UNLOCK = \"0x12efDc31B1a8FA1A1e756DFD8A1601055C971E13\"\n\n\nclass GoEthereumPersonalModuleTest:\n def test_personal_import_raw_key(self, w3: \"Web3\") -> None:\n actual = w3.geth.personal.import_raw_key(PRIVATE_KEY_HEX, PASSWORD)\n assert actual == ADDRESS\n\n def test_personal_list_accounts(self, w3: \"Web3\") -> None:\n accounts = w3.geth.personal.list_accounts()\n assert is_list_like(accounts)\n assert len(accounts) > 0\n assert all((is_checksum_address(item) for item in accounts))\n\n def test_personal_list_wallets(self, w3: \"Web3\") -> None:\n wallets = w3.geth.personal.list_wallets()\n assert is_list_like(wallets)\n assert len(wallets) > 0\n assert is_checksum_address(wallets[0][\"accounts\"][0][\"address\"])\n assert is_string(wallets[0][\"accounts\"][0][\"url\"])\n assert is_string(wallets[0][\"status\"])\n assert is_string(wallets[0][\"url\"])\n\n def test_personal_lock_account(\n self, w3: \"Web3\", unlockable_account_dual_type: ChecksumAddress\n ) -> None:\n # TODO: how do we test this better?\n w3.geth.personal.lock_account(unlockable_account_dual_type)\n\n def test_personal_unlock_account_success(\n self,\n w3: \"Web3\",\n unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n result = w3.geth.personal.unlock_account(\n unlockable_account_dual_type, unlockable_account_pw\n )\n assert result is True\n\n def test_personal_unlock_account_failure(\n self, w3: \"Web3\", unlockable_account_dual_type: ChecksumAddress\n ) -> None:\n with pytest.raises(ValueError):\n w3.geth.personal.unlock_account(\n unlockable_account_dual_type, \"bad-password\"\n )\n\n def test_personal_new_account(self, w3: \"Web3\") -> None:\n new_account = w3.geth.personal.new_account(PASSWORD)\n assert is_checksum_address(new_account)\n\n def test_personal_send_transaction(\n self,\n w3: \"Web3\",\n unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: 
str,\n ) -> None:\n assert (\n w3.eth.get_balance(unlockable_account_dual_type) > constants.WEI_PER_ETHER\n )\n txn_params: TxParams = {\n \"from\": unlockable_account_dual_type,\n \"to\": unlockable_account_dual_type,\n \"gas\": 21000,\n \"value\": Wei(1),\n \"gasPrice\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.geth.personal.send_transaction(txn_params, unlockable_account_pw)\n assert txn_hash\n transaction = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(\n transaction[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n transaction[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert transaction[\"gas\"] == txn_params[\"gas\"]\n assert transaction[\"value\"] == txn_params[\"value\"]\n assert transaction[\"gasPrice\"] == txn_params[\"gasPrice\"]\n\n def test_personal_sign_and_ecrecover(\n self,\n w3: \"Web3\",\n unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n message = \"test-web3-geth-personal-sign\"\n signature = w3.geth.personal.sign(\n message, unlockable_account_dual_type, unlockable_account_pw\n )\n signer = w3.geth.personal.ec_recover(message, signature)\n assert is_same_address(signer, unlockable_account_dual_type)\n\n @pytest.mark.xfail(\n reason=\"personal_sign_typed_data JSON RPC call has not been released in geth\"\n )\n def test_personal_sign_typed_data(\n self,\n w3: \"Web3\",\n unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n typed_message = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": {\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n },\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n signature = HexBytes(\n w3.geth.personal.sign_typed_data(\n json.loads(typed_message),\n unlockable_account_dual_type,\n unlockable_account_pw,\n )\n )\n\n expected_signature = HexBytes(\n \"0xc8b56aaeefd10ab4005c2455daf28d9082af661ac347cd\"\n \"b612d5b5e11f339f2055be831bf57a6e6cb5f6d93448fa35\"\n \"c1bd56fe1d745ffa101e74697108668c401c\"\n )\n assert signature == expected_signature\n assert len(signature) == 32 + 32 + 1\n\n\nclass GoEthereumAsyncPersonalModuleTest:\n @pytest.mark.asyncio\n async def test_async_sign_and_ec_recover(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n message = \"This is a test\"\n signature = await async_w3.geth.personal.sign(\n message, async_unlockable_account_dual_type, unlockable_account_pw\n )\n address = await async_w3.geth.personal.ec_recover(message, signature)\n assert is_same_address(async_unlockable_account_dual_type, address)\n\n @pytest.mark.asyncio\n async def 
test_async_import_key(self, async_w3: \"AsyncWeb3\") -> None:\n address = await async_w3.geth.personal.import_raw_key(\n THIRD_PRIVATE_KEY_HEX, \"Testing\"\n )\n assert address is not None\n\n @pytest.mark.asyncio\n async def test_async_list_accounts(self, async_w3: \"AsyncWeb3\") -> None:\n accounts = await async_w3.geth.personal.list_accounts()\n assert len(accounts) > 0\n\n @pytest.mark.asyncio\n async def test_async_list_wallets(self, async_w3: \"AsyncWeb3\") -> None:\n wallets = await async_w3.geth.personal.list_wallets()\n assert isinstance(wallets[0], AttributeDict)\n\n @pytest.mark.asyncio\n async def test_async_new_account(self, async_w3: \"AsyncWeb3\") -> None:\n passphrase = \"Create New Account\"\n account = await async_w3.geth.personal.new_account(passphrase)\n assert is_checksum_address(account)\n\n @pytest.mark.asyncio\n async def test_async_unlock_lock_account(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n unlocked = await async_w3.geth.personal.unlock_account(\n async_unlockable_account_dual_type, unlockable_account_pw\n )\n assert unlocked is True\n locked = await async_w3.geth.personal.lock_account(\n async_unlockable_account_dual_type\n )\n assert locked is True\n\n @pytest.mark.asyncio\n async def test_async_send_transaction(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n tx_params = TxParams()\n tx_params[\"to\"] = async_unlockable_account_dual_type\n tx_params[\"from\"] = async_unlockable_account_dual_type\n tx_params[\"value\"] = Wei(123)\n response = await async_w3.geth.personal.send_transaction(\n tx_params, unlockable_account_pw\n )\n assert response is not None\n\n @pytest.mark.xfail(\n reason=\"personal_signTypedData JSON RPC call has not been released in geth\"\n )\n @pytest.mark.asyncio\n async def test_async_sign_typed_data(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlockable_account_dual_type: ChecksumAddress,\n unlockable_account_pw: str,\n ) -> None:\n message = {\"message\": \"This is a test\"}\n signature = await async_w3.geth.personal.sign_typed_data(\n message, async_unlockable_account_dual_type, unlockable_account_pw\n )\n address = await async_w3.geth.personal.ec_recover(\n json.dumps(message), signature\n )\n assert is_same_address(async_unlockable_account_dual_type, address)\n\n# Path: web3/_utils/module_testing/go_ethereum_txpool_module.py\nimport pytest\n\nfrom web3 import (\n AsyncWeb3,\n Web3,\n)\n\n\nclass GoEthereumAsyncTxPoolModuleTest:\n @pytest.mark.asyncio\n async def test_async_geth_txpool_inspect(self, async_w3: \"AsyncWeb3\") -> None:\n test_data = await async_w3.geth.txpool.inspect()\n assert \"pending\" in test_data\n\n @pytest.mark.asyncio\n async def test_async_geth_txpool_content(self, async_w3: \"AsyncWeb3\") -> None:\n test_data = await async_w3.geth.txpool.content()\n assert \"pending\" in test_data\n\n @pytest.mark.asyncio\n async def test_async_geth_txpool_status(self, async_w3: \"AsyncWeb3\") -> None:\n test_data = await async_w3.geth.txpool.status()\n assert \"pending\" in test_data\n\n\nclass GoEthereumTxPoolModuleTest:\n def test_geth_txpool_inspect(self, w3: \"Web3\") -> None:\n test_data = w3.geth.txpool.inspect()\n assert \"pending\" in test_data\n\n def test_geth_txpool_content(self, w3: \"Web3\") -> None:\n test_data = w3.geth.txpool.content()\n assert \"pending\" in test_data\n\n def test_geth_txpool_status(self, w3: \"Web3\") -> None:\n 
test_data = w3.geth.txpool.status()\n assert \"pending\" in test_data\n\n# Path: web3/_utils/module_testing/net_module.py\nimport pytest\nfrom typing import (\n TYPE_CHECKING,\n)\n\nfrom eth_utils import (\n is_boolean,\n is_integer,\n is_string,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\n\nclass NetModuleTest:\n def test_net_version(self, w3: \"Web3\") -> None:\n version = w3.net.version\n\n assert is_string(version)\n assert version.isdigit()\n\n def test_net_listening(self, w3: \"Web3\") -> None:\n listening = w3.net.listening\n\n assert is_boolean(listening)\n\n def test_net_peer_count(self, w3: \"Web3\") -> None:\n peer_count = w3.net.peer_count\n\n assert is_integer(peer_count)\n\n\nclass AsyncNetModuleTest:\n @pytest.mark.asyncio\n async def test_net_version(self, async_w3: \"AsyncWeb3\") -> None:\n version = await async_w3.net.version\n\n assert is_string(version)\n assert version.isdigit()\n\n @pytest.mark.asyncio\n async def test_net_listening(self, async_w3: \"AsyncWeb3\") -> None:\n listening = await async_w3.net.listening\n\n assert is_boolean(listening)\n\n @pytest.mark.asyncio\n async def test_net_peer_count(self, async_w3: \"AsyncWeb3\") -> None:\n peer_count = await async_w3.net.peer_count\n\n assert is_integer(peer_count)\n\n# Path: web3/_utils/module_testing/web3_module.py\nimport pytest\nfrom typing import (\n Any,\n NoReturn,\n Sequence,\n Union,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n HexAddress,\n HexStr,\n TypeStr,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3 import (\n AsyncWeb3,\n Web3,\n)\nfrom web3._utils.ens import (\n ens_addresses,\n)\nfrom web3.exceptions import (\n InvalidAddress,\n)\n\n\nclass Web3ModuleTest:\n def test_web3_client_version(self, w3: Web3) -> None:\n client_version = w3.client_version\n self._check_web3_client_version(client_version)\n\n def _check_web3_client_version(self, client_version: str) -> NoReturn:\n raise NotImplementedError(\"Must be implemented by subclasses\")\n\n # Contract that calculated test values can be found at\n # https://kovan.etherscan.io/address/0xb9be06f5b99372cf9afbccadbbb9954ccaf7f4bb#code\n @pytest.mark.parametrize(\n \"types,values,expected\",\n (\n (\n [\"bool\"],\n [True],\n HexBytes(\n \"0x5fe7f977e71dba2ea1a68e21057beebb9be2ac30c6410aa38d4f3fbe41dcffd2\"\n ),\n ),\n (\n [\"uint8\", \"uint8\", \"uint8\"],\n [97, 98, 99],\n HexBytes(\n \"0x4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45\"\n ),\n ),\n (\n [\"uint248\"],\n [30],\n HexBytes(\n \"0x30f95d210785601eb33ae4d53d405b26f920e765dff87cca8e9a4aec99f82671\"\n ),\n ),\n (\n [\"bool\", \"uint16\"],\n [True, 299],\n HexBytes(\n \"0xed18599ccd80ee9fae9a28b0e34a5573c3233d7468f808fd659bc171cf0b43bd\"\n ),\n ),\n (\n [\"int256\"],\n [-10],\n HexBytes(\n \"0xd6fb717f7e270a360f5093ce6a7a3752183e89c9a9afe5c0cb54b458a304d3d5\"\n ),\n ),\n (\n [\"int256\"],\n [10],\n HexBytes(\n \"0xc65a7bb8d6351c1cf70c95a316cc6a92839c986682d98bc35f958f4883f9d2a8\"\n ),\n ),\n (\n [\"int8\", \"uint8\"],\n [-10, 18],\n HexBytes(\n \"0x5c6ab1e634c08d9c0f4df4d789e8727943ef010dd7ca8e3c89de197a26d148be\"\n ),\n ),\n (\n [\"address\"],\n [\"0x49eddd3769c0712032808d86597b84ac5c2f5614\"],\n InvalidAddress,\n ),\n (\n [\"address\"],\n [\"0x49EdDD3769c0712032808D86597B84ac5c2F5614\"],\n HexBytes(\n \"0x2ff37b5607484cd4eecf6d13292e22bd6e5401eaffcc07e279583bc742c68882\"\n ),\n ),\n (\n [\"bytes2\"],\n [\"0x5402\"],\n HexBytes(\n \"0x4ed9171bda52fca71ab28e7f452bd6eacc3e5a568a47e0fa53b503159a9b8910\"\n ),\n ),\n (\n 
[\"bytes3\"],\n [\"0x5402\"],\n HexBytes(\n \"0x4ed9171bda52fca71ab28e7f452bd6eacc3e5a568a47e0fa53b503159a9b8910\"\n ),\n ),\n (\n [\"bytes\"],\n [\n \"0x636865636b6c6f6e6762797465737472696e676167\"\n \"61696e7374736f6c6964697479736861336861736866756e6374696f6e\"\n ],\n HexBytes(\n \"0xd78a84d65721b67e4011b10c99dafdedcdcd7cb30153064f773e210b4762e22f\"\n ),\n ),\n (\n [\"string\"],\n [\"testing a string!\"],\n HexBytes(\n \"0xe8c275c0b4070a5ec6cfcb83f0ba394b30ddd283de785d43f2eabfb04bd96747\"\n ),\n ),\n (\n [\"string\", \"bool\", \"uint16\", \"bytes2\", \"address\"],\n [\n \"testing a string!\",\n False,\n 299,\n \"0x5402\",\n \"0x49eddd3769c0712032808d86597b84ac5c2f5614\",\n ],\n InvalidAddress,\n ),\n (\n [\"string\", \"bool\", \"uint16\", \"bytes2\", \"address\"],\n [\n \"testing a string!\",\n False,\n 299,\n \"0x5402\",\n \"0x49EdDD3769c0712032808D86597B84ac5c2F5614\",\n ],\n HexBytes(\n \"0x8cc6eabb25b842715e8ca39e2524ed946759aa37bfb7d4b81829cf5a7e266103\"\n ),\n ),\n (\n [\"bool[2][]\"],\n [[[True, False], [False, True]]],\n HexBytes(\n \"0x1eef261f2eb51a8c736d52be3f91ff79e78a9ec5df2b7f50d0c6f98ed1e2bc06\"\n ),\n ),\n (\n [\"bool[]\"],\n [[True, False, True]],\n HexBytes(\n \"0x5c6090c0461491a2941743bda5c3658bf1ea53bbd3edcde54e16205e18b45792\"\n ),\n ),\n (\n [\"uint24[]\"],\n [[1, 0, 1]],\n HexBytes(\n \"0x5c6090c0461491a2941743bda5c3658bf1ea53bbd3edcde54e16205e18b45792\"\n ),\n ),\n (\n [\"uint8[2]\"],\n [[8, 9]],\n HexBytes(\n \"0xc7694af312c4f286114180fd0ba6a52461fcee8a381636770b19a343af92538a\"\n ),\n ),\n (\n [\"uint256[2]\"],\n [[8, 9]],\n HexBytes(\n \"0xc7694af312c4f286114180fd0ba6a52461fcee8a381636770b19a343af92538a\"\n ),\n ),\n (\n [\"uint8[]\"],\n [[8]],\n HexBytes(\n \"0xf3f7a9fe364faab93b216da50a3214154f22a0a2b415b23a84c8169e8b636ee3\"\n ),\n ),\n (\n [\"address[]\"],\n [\n [\n \"0x49EdDD3769c0712032808D86597B84ac5c2F5614\",\n \"0xA6b759bBbf4B59D24acf7E06e79f3a5D104fdCE5\",\n ]\n ],\n HexBytes(\n \"0xb98565c0c26a962fd54d93b0ed6fb9296e03e9da29d2281ed3e3473109ef7dde\"\n ),\n ),\n (\n [\"address[]\"],\n [\n [\n \"0x49EdDD3769c0712032808D86597B84ac5c2F5614\",\n \"0xa6b759bbbf4b59d24acf7e06e79f3a5d104fdce5\",\n ]\n ],\n InvalidAddress,\n ),\n ),\n )\n @pytest.mark.parametrize(\n \"w3\",\n (\n Web3,\n AsyncWeb3,\n ),\n )\n def test_solidity_keccak(\n self,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n types: Sequence[TypeStr],\n values: Sequence[Any],\n expected: HexBytes,\n ) -> None:\n if isinstance(expected, type) and issubclass(expected, Exception):\n with pytest.raises(expected):\n w3.solidity_keccak(types, values)\n return\n\n actual = w3.solidity_keccak(types, values)\n assert actual == expected\n\n @pytest.mark.parametrize(\n \"types, values, expected\",\n (\n (\n [\"address\"],\n [\"one.eth\"],\n HexBytes(\n \"0x2ff37b5607484cd4eecf6d13292e22bd6e5401eaffcc07e279583bc742c68882\"\n ),\n ),\n (\n [\"address[]\"],\n [[\"one.eth\", \"two.eth\"]],\n HexBytes(\n \"0xb98565c0c26a962fd54d93b0ed6fb9296e03e9da29d2281ed3e3473109ef7dde\"\n ),\n ),\n ),\n )\n @pytest.mark.parametrize(\n \"w3\",\n (\n Web3(),\n AsyncWeb3(),\n ),\n )\n def test_solidity_keccak_ens(\n self,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n types: Sequence[TypeStr],\n values: Sequence[str],\n expected: HexBytes,\n ) -> None:\n with ens_addresses(\n w3,\n {\n \"one.eth\": ChecksumAddress(\n HexAddress(HexStr(\"0x49EdDD3769c0712032808D86597B84ac5c2F5614\"))\n ),\n \"two.eth\": ChecksumAddress(\n HexAddress(HexStr(\"0xA6b759bBbf4B59D24acf7E06e79f3a5D104fdCE5\"))\n ),\n },\n ):\n # when called as class method, any name 
lookup attempt will fail\n with pytest.raises(InvalidAddress):\n Web3.solidity_keccak(types, values)\n\n # when called as instance method, ens lookups can succeed\n actual = w3.solidity_keccak(types, values)\n assert actual == expected\n\n @pytest.mark.parametrize(\n \"types,values\",\n (\n ([\"address\"], [\"0xA6b759bBbf4B59D24acf7E06e79f3a5D104fdCE5\", True]),\n ([\"address\", \"bool\"], [\"0xA6b759bBbf4B59D24acf7E06e79f3a5D104fdCE5\"]),\n ([], [\"0xA6b759bBbf4B59D24acf7E06e79f3a5D104fdCE5\"]),\n ),\n )\n def test_solidity_keccak_same_number_of_types_and_values(\n self, w3: \"Web3\", types: Sequence[TypeStr], values: Sequence[Any]\n ) -> None:\n with pytest.raises(ValueError):\n w3.solidity_keccak(types, values)\n\n def test_is_connected(self, w3: \"Web3\") -> None:\n assert w3.is_connected()\n\n# Path: web3/_utils/module_testing/__init__.py\nfrom .eth_module import (\n AsyncEthModuleTest,\n EthModuleTest,\n)\nfrom .go_ethereum_admin_module import (\n GoEthereumAdminModuleTest,\n)\nfrom .go_ethereum_personal_module import (\n GoEthereumPersonalModuleTest,\n)\nfrom .go_ethereum_txpool_module import (\n GoEthereumAsyncTxPoolModuleTest,\n GoEthereumTxPoolModuleTest,\n)\nfrom .net_module import (\n AsyncNetModuleTest,\n NetModuleTest,\n)\nfrom .web3_module import (\n Web3ModuleTest,\n)\n\n# Path: web3/_utils/module_testing/persistent_connection_provider.py\nimport asyncio\nimport pytest\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Tuple,\n cast,\n)\n\nfrom eth_utils import (\n is_hexstr,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3.datastructures import (\n AttributeDict,\n)\nfrom web3.middleware import (\n ExtraDataToPOAMiddleware,\n)\nfrom web3.types import (\n FormattedEthSubscriptionResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import (\n AsyncWeb3,\n )\n\n\nclass PersistentConnectionProviderTest:\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"subscription_params,ws_subscription_response,expected_formatted_result\",\n (\n (\n (\"newHeads\",),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": {\n \"number\": \"0x539\",\n \"hash\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"parentHash\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"sha3Uncles\": \"0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347\", # noqa: E501\n \"logsBloom\": \"0x00\",\n \"transactionsRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"stateRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"receiptsRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"miner\": \"0x0000000000000000000000000000000000000000\",\n \"difficulty\": \"0x0\",\n \"extraData\": \"0x496c6c756d696e61746520446d6f63726174697a6520447374726962757465\", # noqa: E501\n \"gasLimit\": \"0x1c9c380\",\n \"gasUsed\": \"0xd1ce44\",\n \"timestamp\": \"0x539\",\n \"baseFeePerGas\": \"0x26f93fef9\",\n \"withdrawalsRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"nonce\": \"0x0000000000000000\",\n \"mixHash\": \"0x73e9e036ec894047f29954571d4b6d9e8717de7304269c263cbf150caa4e0768\", # noqa: E501\n },\n },\n },\n AttributeDict(\n {\n \"number\": 1337,\n \"hash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # 
noqa: E501\n ),\n \"parentHash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # noqa: E501\n ),\n \"sha3Uncles\": HexBytes(\n \"0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347\" # noqa: E501\n ),\n \"logsBloom\": HexBytes(\"0x00\"),\n \"transactionsRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"stateRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"receiptsRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"miner\": \"0x0000000000000000000000000000000000000000\",\n \"difficulty\": 0,\n \"extraData\": HexBytes(\n \"0x496c6c756d696e61746520446d6f63726174697a6520447374726962757465\" # noqa: E501\n ),\n \"gasLimit\": 30000000,\n \"gasUsed\": 13749828,\n \"timestamp\": 1337,\n \"baseFeePerGas\": 10461904633,\n \"withdrawalsRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"nonce\": HexBytes(\"0x0000000000000000\"),\n \"mixHash\": HexBytes(\n \"0x73e9e036ec894047f29954571d4b6d9e8717de7304269c263cbf150caa4e0768\" # noqa: E501\n ),\n }\n ),\n ),\n (\n (\"newPendingTransactions\", True),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": {\n \"blockHash\": None,\n \"blockNumber\": None,\n \"from\": \"0x0000000000000000000000000000000000000000\",\n \"gas\": \"0xf2f4\",\n \"gasPrice\": \"0x29035f36f\",\n \"maxFeePerGas\": \"0x29035f36f\",\n \"maxPriorityFeePerGas\": \"0x3b9aca00\",\n \"hash\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"input\": \"0x00\",\n \"nonce\": \"0x2013\",\n \"to\": \"0x0000000000000000000000000000000000000000\",\n \"transactionIndex\": None,\n \"value\": \"0x0\",\n \"type\": \"0x2\",\n \"accessList\": [],\n \"chainId\": \"0x1\",\n \"v\": \"0x1\",\n \"r\": \"0x3c144a7c00ed3118d55445cd5be2ae4620ca377f7c685e9c5f3687671d4dece1\", # noqa: E501\n \"s\": \"0x284de67cbf75fec8a9edb368dee3a37cf6faba87f0af4413b2f869ebfa87d002\", # noqa: E501\n \"yParity\": \"0x1\",\n },\n },\n },\n AttributeDict(\n {\n \"blockHash\": None,\n \"blockNumber\": None,\n \"from\": \"0x0000000000000000000000000000000000000000\",\n \"gas\": 62196,\n \"gasPrice\": 11009389423,\n \"maxFeePerGas\": 11009389423,\n \"maxPriorityFeePerGas\": 1000000000,\n \"hash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # noqa: E501\n ),\n \"input\": HexBytes(\"0x00\"),\n \"nonce\": 8211,\n \"to\": \"0x0000000000000000000000000000000000000000\",\n \"transactionIndex\": None,\n \"value\": 0,\n \"type\": 2,\n \"accessList\": [],\n \"chainId\": 1,\n \"v\": 1,\n \"r\": HexBytes(\n \"0x3c144a7c00ed3118d55445cd5be2ae4620ca377f7c685e9c5f3687671d4dece1\" # noqa: E501\n ),\n \"s\": HexBytes(\n \"0x284de67cbf75fec8a9edb368dee3a37cf6faba87f0af4413b2f869ebfa87d002\" # noqa: E501\n ),\n \"yParity\": 1,\n }\n ),\n ),\n (\n (\"newPendingTransactions\", False),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n...\n", "needle_token_start": 747, "needle_token_end": 853, "code_context_ntokens": 16244, "output": ["```python\ndef test_eth_get_logs_with_logs_none_topic_args(self, w3: \"Web3\") -> None:\n # Test with None overflowing\n filter_params: FilterParams = {\n \"fromBlock\": BlockNumber(0),\n \"topics\": [None, None, 
None],\n }\n\n result = w3.eth.get_logs(filter_params)\n assert len(result) == 0\n```"]} +{"repo": "ethereum/web3.py", "name": "_get_response_for_request_id", "language": "python", "path": "web3/providers/persistent/persistent.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To retrieve the response associated with a specific request identifier from a cache, with a mechanism to handle timeouts if the response is not available within a specified duration.\n2. **Input**: A unique identifier for the request and an optional timeout duration.\n3. **Output**: The response corresponding to the request identifier if available; otherwise, an exception is raised if the response is not found within the timeout period.\n4. **Procedure**: The function checks a cache for the response corresponding to the given request identifier. It continuously yields control back to the event loop to allow other operations to proceed while waiting. If the response is found within the allotted time, it is returned. If not, an exception indicating a timeout is raised.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: web3/providers/persistent/request_processor.py\nimport asyncio\nfrom copy import (\n copy,\n)\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Dict,\n Optional,\n Tuple,\n)\n\nfrom web3._utils.caching import (\n RequestInformation,\n generate_cache_key,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\nfrom web3.utils import (\n SimpleCache,\n)\n\nif TYPE_CHECKING:\n from web3.providers.persistent import (\n PersistentConnectionProvider,\n )\n\n\nclass RequestProcessor:\n _subscription_queue_synced_with_ws_stream: bool = False\n\n def __init__(\n self,\n provider: \"PersistentConnectionProvider\",\n subscription_response_queue_size: int = 500,\n ) -> None:\n self._provider = provider\n\n self._request_information_cache: SimpleCache = SimpleCache(500)\n self._request_response_cache: SimpleCache = SimpleCache(500)\n self._subscription_response_queue: asyncio.Queue[RPCResponse] = asyncio.Queue(\n maxsize=subscription_response_queue_size\n )\n\n @property\n def active_subscriptions(self) -> Dict[str, Any]:\n return {\n value.subscription_id: {\"params\": value.params}\n for key, value in self._request_information_cache.items()\n if value.method == \"eth_subscribe\"\n }\n\n # request information cache\n\n def cache_request_information(\n self,\n method: RPCEndpoint,\n params: Any,\n response_formatters: Tuple[Callable[..., Any], ...],\n ) -> Optional[str]:\n cached_requests_key = generate_cache_key((method, params))\n if cached_requests_key in self._provider._request_cache._data:\n cached_response = self._provider._request_cache._data[cached_requests_key]\n cached_response_id = cached_response.get(\"id\")\n cache_key = generate_cache_key(cached_response_id)\n if cache_key in self._request_information_cache:\n self._provider.logger.debug(\n \"This is a cached request, not caching request info because it is \"\n f\"not unique:\\n method={method},\\n params={params}\"\n )\n return None\n\n # copy the request counter and find the next request id without incrementing\n # since this is done when / if the request is successfully sent\n request_id = next(copy(self._provider.request_counter))\n cache_key = 
generate_cache_key(request_id)\n\n self._bump_cache_if_key_present(cache_key, request_id)\n\n request_info = RequestInformation(\n method,\n params,\n response_formatters,\n )\n self._provider.logger.debug(\n f\"Caching request info:\\n request_id={request_id},\\n\"\n f\" cache_key={cache_key},\\n request_info={request_info.__dict__}\"\n )\n self._request_information_cache.cache(\n cache_key,\n request_info,\n )\n return cache_key\n\n def _bump_cache_if_key_present(self, cache_key: str, request_id: int) -> None:\n \"\"\"\n If the cache key is present in the cache, bump the cache key and request id\n by one to make room for the new request. This behavior is necessary when a\n request is made but inner requests, say to `eth_estimateGas` if the `gas` is\n missing, are made before the original request is sent.\n \"\"\"\n if cache_key in self._request_information_cache:\n original_request_info = self._request_information_cache.get_cache_entry(\n cache_key\n )\n bump = generate_cache_key(request_id + 1)\n\n # recursively bump the cache if the new key is also present\n self._bump_cache_if_key_present(bump, request_id + 1)\n\n self._provider.logger.debug(\n \"Caching internal request. Bumping original request in cache:\\n\"\n f\" request_id=[{request_id}] -> [{request_id + 1}],\\n\"\n f\" cache_key=[{cache_key}] -> [{bump}],\\n\"\n f\" request_info={original_request_info.__dict__}\"\n )\n self._request_information_cache.cache(bump, original_request_info)\n\n def pop_cached_request_information(\n self, cache_key: str\n ) -> Optional[RequestInformation]:\n request_info = self._request_information_cache.pop(cache_key)\n if request_info is not None:\n self._provider.logger.debug(\n \"Request info popped from cache:\\n\"\n f\" cache_key={cache_key},\\n request_info={request_info.__dict__}\"\n )\n return request_info\n\n def get_request_information_for_response(\n self,\n response: RPCResponse,\n ) -> RequestInformation:\n if \"method\" in response and response[\"method\"] == \"eth_subscription\":\n if \"params\" not in response:\n raise ValueError(\"Subscription response must have params field\")\n if \"subscription\" not in response[\"params\"]:\n raise ValueError(\n \"Subscription response params must have subscription field\"\n )\n\n # retrieve the request info from the cache using the subscription id\n cache_key = generate_cache_key(response[\"params\"][\"subscription\"])\n request_info = (\n # don't pop the request info from the cache, since we need to keep it\n # to process future subscription responses\n # i.e. subscription request information remains in the cache\n self._request_information_cache.get_cache_entry(cache_key)\n )\n\n else:\n # retrieve the request info from the cache using the request id\n cache_key = generate_cache_key(response[\"id\"])\n if response in self._provider._request_cache._data.values():\n request_info = (\n # don't pop the request info from the cache, since we need to keep\n # it to process future responses\n # i.e. 
request information remains in the cache\n self._request_information_cache.get_cache_entry(cache_key)\n )\n else:\n request_info = (\n...\n# Path: web3/providers/persistent/persistent.py\nfrom abc import (\n ABC,\n)\nimport asyncio\nimport logging\nfrom typing import (\n Optional,\n)\n\nfrom web3._utils.caching import (\n generate_cache_key,\n)\nfrom web3.exceptions import (\n TimeExhausted,\n)\nfrom web3.providers.async_base import (\n AsyncJSONBaseProvider,\n)\nfrom web3.providers.persistent.request_processor import (\n RequestProcessor,\n)\nfrom web3.types import (\n RPCId,\n RPCResponse,\n)\n\nDEFAULT_PERSISTENT_CONNECTION_TIMEOUT = 30.0\n\n\nclass PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):\n logger = logging.getLogger(\"web3.providers.PersistentConnectionProvider\")\n has_persistent_connection = True\n endpoint_uri: Optional[str] = None\n\n _request_processor: RequestProcessor\n _message_listener_task: Optional[\"asyncio.Task[None]\"] = None\n _listen_event: asyncio.Event = asyncio.Event()\n\n def __init__(\n self,\n request_timeout: float = DEFAULT_PERSISTENT_CONNECTION_TIMEOUT,\n subscription_response_queue_size: int = 500,\n silence_listener_task_exceptions: bool = False,\n ) -> None:\n super().__init__()\n self._request_processor = RequestProcessor(\n self,\n subscription_response_queue_size=subscription_response_queue_size,\n )\n self.request_timeout = request_timeout\n self.silence_listener_task_exceptions = silence_listener_task_exceptions\n\n async def connect(self) -> None:\n raise NotImplementedError(\"Must be implemented by subclasses\")\n\n async def disconnect(self) -> None:\n raise NotImplementedError(\"Must be implemented by subclasses\")\n\n async def _message_listener(self) -> None:\n raise NotImplementedError(\"Must be implemented by subclasses\")\n\n \nasync def _get_response_for_request_id(\n self, request_id: RPCId, timeout: Optional[float] = None\n ) -> RPCResponse:\n if timeout is None:\n timeout = self.request_timeout\n\n async def _match_response_id_to_request_id() -> RPCResponse:\n request_cache_key = generate_cache_key(request_id)\n\n while True:\n # sleep(0) here seems to be the most efficient way to yield control\n # back to the event loop while waiting for the response to be in the\n # queue.\n await asyncio.sleep(0)\n\n if request_cache_key in self._request_processor._request_response_cache:\n self.logger.debug(\n f\"Popping response for id {request_id} from cache.\"\n )\n popped_response = self._request_processor.pop_raw_response(\n cache_key=request_cache_key,\n )\n return popped_response\n\n try:\n # Add the request timeout around the while loop that checks the request\n # cache and tried to recv(). If the request is neither in the cache, nor\n # received within the request_timeout, raise ``TimeExhausted``.\n return await asyncio.wait_for(_match_response_id_to_request_id(), timeout)\n except asyncio.TimeoutError:\n raise TimeExhausted(\n f\"Timed out waiting for response with request id `{request_id}` after \"\n f\"{self.request_timeout} second(s). 
This may be due to the provider \"\n \"not returning a response with the same id that was sent in the \"\n \"request or an exception raised during the request was caught and \"\n \"allowed to continue.\"\n )\n\n# Path: web3/method.py\nimport functools\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Dict,\n Generic,\n List,\n Optional,\n Sequence,\n Tuple,\n Type,\n TypeVar,\n Union,\n)\nimport warnings\n\nfrom eth_utils.curried import (\n to_tuple,\n)\nfrom eth_utils.toolz import (\n pipe,\n)\n\nfrom web3._utils.method_formatters import (\n get_error_formatters,\n get_null_result_formatters,\n get_request_formatters,\n get_result_formatters,\n)\nfrom web3._utils.rpc_abi import (\n RPC,\n)\nfrom web3.exceptions import (\n Web3ValidationError,\n)\nfrom web3.types import (\n RPCEndpoint,\n TReturn,\n)\n\nif TYPE_CHECKING:\n from web3 import Web3 # noqa: F401\n from web3.module import Module # noqa: F401\n\n\nMunger = Callable[..., Any]\n\n\n@to_tuple\ndef _apply_request_formatters(\n params: Any, request_formatters: Dict[RPCEndpoint, Callable[..., TReturn]]\n) -> Tuple[Any, ...]:\n if request_formatters:\n formatted_params = pipe(params, request_formatters)\n return formatted_params\n return params\n\n\ndef _set_mungers(\n mungers: Optional[Sequence[Munger]], is_property: bool\n) -> Sequence[Any]:\n if is_property and mungers:\n raise Web3ValidationError(\"Mungers cannot be used with a property.\")\n\n return (\n mungers\n if mungers\n else [default_munger] if is_property else [default_root_munger]\n )\n\n\ndef default_munger(_module: \"Module\", *args: Any, **kwargs: Any) -> Tuple[()]:\n if args or kwargs:\n raise Web3ValidationError(\"Parameters cannot be passed to a property.\")\n return ()\n\n\ndef default_root_munger(_module: \"Module\", *args: Any) -> List[Any]:\n return [*args]\n\n\nTFunc = TypeVar(\"TFunc\", bound=Callable[..., Any])\n\n\nclass Method(Generic[TFunc]):\n \"\"\"Method object for web3 module methods\n\n Calls to the Method go through these steps:\n\n 1. input munging - includes normalization, parameter checking, early parameter\n formatting. Any processing on the input parameters that need to happen before\n json_rpc method string selection occurs.\n\n A note about mungers: The first (root) munger should reflect the desired\n api function arguments. In other words, if the api function wants to\n behave as: `get_balance(account, block_identifier=None)`, the root munger\n should accept these same arguments, with the addition of the module as\n the first argument e.g.:\n\n ```\n def get_balance_root_munger(module, account, block_identifier=None):\n if block_identifier is None:\n block_identifier = DEFAULT_BLOCK\n return module, [account, block_identifier]\n ```\n\n all mungers should return an argument list.\n\n if no munger is provided, a default munger expecting no method arguments\n will be used.\n\n 2. method selection - The json_rpc_method argument can be method string or a\n function that returns a method string. If a callable is provided the processed\n method inputs are passed to the method selection function, and the returned\n method string is used.\n\n 3. request and response formatters are set - formatters are retrieved\n using the json rpc method string.\n\n 4. 
After the parameter processing from steps 1-3 the request is made using\n the calling function returned by the module attribute ``retrieve_caller_fn``\n and the response formatters are applied to the output.\n \"\"\"\n\n def __init__(\n self,\n json_rpc_method: Optional[RPCEndpoint] = None,\n mungers: Optional[Sequence[Munger]] = None,\n request_formatters: Optional[Callable[..., TReturn]] = None,\n result_formatters: Optional[Callable[..., TReturn]] = None,\n null_result_formatters: Optional[Callable[..., TReturn]] = None,\n method_choice_depends_on_args: Optional[Callable[..., RPCEndpoint]] = None,\n is_property: bool = False,\n ):\n self.json_rpc_method = json_rpc_method\n self.mungers = _set_mungers(mungers, is_property)\n self.request_formatters = request_formatters or get_request_formatters\n self.result_formatters = result_formatters or get_result_formatters\n self.null_result_formatters = (\n null_result_formatters or get_null_result_formatters\n )\n self.method_choice_depends_on_args = method_choice_depends_on_args\n self.is_property = is_property\n\n def __get__(\n self, obj: Optional[\"Module\"] = None, obj_type: Optional[Type[\"Module\"]] = None\n ) -> TFunc:\n if obj is None:\n raise TypeError(\n \"Direct calls to methods are not supported. \"\n \"Methods must be called from an module instance, \"\n \"usually attached to a web3 instance.\"\n )\n return obj.retrieve_caller_fn(self)\n\n @property\n def method_selector_fn(\n self,\n ) -> Callable[..., Union[RPCEndpoint, Callable[..., RPCEndpoint]]]:\n \"\"\"Gets the method selector from the config.\"\"\"\n if callable(self.json_rpc_method):\n return self.json_rpc_method\n elif isinstance(self.json_rpc_method, (str,)):\n return lambda *_: self.json_rpc_method\n raise ValueError(\n \"``json_rpc_method`` config invalid. 
May be a string or function\"\n )\n\n def input_munger(self, module: \"Module\", args: Any, kwargs: Any) -> List[Any]:\n # This function takes the input parameters and munges them.\n # See the test_process_params test in ``tests/core/method-class/test_method.py``\n # for an example with multiple mungers.\n return functools.reduce(\n lambda args, munger: munger(module, *args, **kwargs), self.mungers, args\n )\n\n def process_params(self, module: \"Module\", *args: Any, **kwargs: Any) -> Tuple[\n Tuple[Union[RPCEndpoint, Callable[..., RPCEndpoint]], Tuple[Any, ...]],\n Tuple[\n Union[TReturn, Dict[str, Callable[..., Any]]],\n Callable[..., Any],\n Union[TReturn, Callable[..., Any]],\n ],\n ]:\n params = self.input_munger(module, args, kwargs)\n\n if self.method_choice_depends_on_args:\n # If the method choice depends on the args that get passed in,\n # the first parameter determines which method needs to be called\n self.json_rpc_method = self.method_choice_depends_on_args(value=params[0])\n\n pending_or_latest_filter_methods = [\n RPC.eth_newPendingTransactionFilter,\n RPC.eth_newBlockFilter,\n ]\n if self.json_rpc_method in pending_or_latest_filter_methods:\n # For pending or latest filter methods, use params to determine\n # which method to call, but don't pass them through with the request\n params = []\n\n method = self.method_selector_fn()\n response_formatters = (\n self.result_formatters(method, module),\n get_error_formatters(method),\n self.null_result_formatters(method),\n )\n request = (\n method,\n _apply_request_formatters(params, self.request_formatters(method)),\n )\n return request, response_formatters\n\n\nclass DeprecatedMethod:\n def __init__(\n self, method: Method[Callable[..., Any]], old_name: str, new_name: str\n ) -> None:\n self.method = method\n self.old_name = old_name\n self.new_name = new_name\n\n def __get__(\n self, obj: Optional[\"Module\"] = None, obj_type: Optional[Type[\"Module\"]] = None\n ) -> Any:\n warnings.warn(\n f\"{self.old_name} is deprecated in favor of {self.new_name}\",\n category=DeprecationWarning,\n )\n return self.method.__get__(obj, obj_type)\n\n# Path: web3/module.py\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Coroutine,\n Dict,\n Optional,\n TypeVar,\n Union,\n cast,\n)\n\nfrom eth_abi.codec import (\n ABICodec,\n)\nfrom eth_utils.toolz import (\n curry,\n pipe,\n)\n\nfrom web3._utils.filters import (\n AsyncLogFilter,\n LogFilter,\n _UseExistingFilter,\n)\nfrom web3.method import (\n Method,\n)\nfrom web3.providers.persistent import (\n PersistentConnectionProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\n\n@curry\ndef apply_result_formatters(\n result_formatters: Callable[..., Any], result: RPCResponse\n) -> RPCResponse:\n if result_formatters:\n formatted_result = pipe(result, result_formatters)\n return formatted_result\n else:\n return result\n\n\nTReturn = TypeVar(\"TReturn\")\n\n\n@curry\ndef retrieve_blocking_method_call_fn(\n w3: \"Web3\", module: \"Module\", method: Method[Callable[..., TReturn]]\n) -> Callable[..., Union[TReturn, LogFilter]]:\n def caller(*args: Any, **kwargs: Any) -> Union[TReturn, LogFilter]:\n try:\n (method_str, params), response_formatters = method.process_params(\n module, *args, **kwargs\n )\n except _UseExistingFilter as err:\n return LogFilter(eth_module=module, filter_id=err.filter_id)\n\n (\n result_formatters,\n error_formatters,\n null_result_formatters,\n ) = 
response_formatters\n result = w3.manager.request_blocking(\n method_str, params, error_formatters, null_result_formatters\n )\n return apply_result_formatters(result_formatters, result)\n\n return caller\n\n\n@curry\ndef retrieve_async_method_call_fn(\n async_w3: \"AsyncWeb3\", module: \"Module\", method: Method[Callable[..., Any]]\n) -> Callable[..., Coroutine[Any, Any, Optional[Union[RPCResponse, AsyncLogFilter]]]]:\n async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter]:\n try:\n (method_str, params), response_formatters = method.process_params(\n module, *args, **kwargs\n )\n except _UseExistingFilter as err:\n return AsyncLogFilter(eth_module=module, filter_id=err.filter_id)\n\n if isinstance(async_w3.provider, PersistentConnectionProvider):\n # TODO: The typing does not seem to be correct for response_formatters.\n # For now, keep the expected typing but ignore it here.\n provider = async_w3.provider\n cache_key = provider._request_processor.cache_request_information(\n cast(RPCEndpoint, method_str), params, response_formatters # type: ignore # noqa: E501\n )\n try:\n method_str = cast(RPCEndpoint, method_str)\n return await async_w3.manager.send(method_str, params)\n except Exception as e:\n if (\n cache_key is not None\n and cache_key\n in provider._request_processor._request_information_cache\n ):\n provider._request_processor.pop_cached_request_information(\n cache_key\n )\n raise e\n else:\n (\n result_formatters,\n error_formatters,\n null_result_formatters,\n ) = response_formatters\n\n result = await async_w3.manager.coro_request(\n method_str, params, error_formatters, null_result_formatters\n )\n return apply_result_formatters(result_formatters, result)\n\n return caller\n\n\n# Module should no longer have access to the full web3 api.\n# Only the calling functions need access to the request methods.\n# Any \"re-entrant\" shenanigans can go in the middlewares, which do\n# have web3 access.\nclass Module:\n is_async = False\n\n def __init__(self, w3: Union[\"AsyncWeb3\", \"Web3\"]) -> None:\n if self.is_async:\n self.retrieve_caller_fn = retrieve_async_method_call_fn(w3, self)\n else:\n self.retrieve_caller_fn = retrieve_blocking_method_call_fn(w3, self)\n self.w3 = w3\n\n @property\n def codec(self) -> ABICodec:\n # use codec set on the Web3 instance\n return self.w3.codec\n\n def attach_methods(\n self,\n methods: Dict[str, Method[Callable[..., Any]]],\n ) -> None:\n for method_name, method_class in methods.items():\n klass = (\n method_class.__get__(obj=self)()\n if method_class.is_property\n else method_class.__get__(obj=self)\n )\n setattr(self, method_name, klass)\n\n# Path: web3/manager.py\nimport asyncio\nimport logging\nfrom typing import (\n TYPE_CHECKING,\n Any,\n AsyncGenerator,\n Callable,\n List,\n Optional,\n Sequence,\n Tuple,\n Union,\n cast,\n)\n\nfrom eth_utils.toolz import (\n pipe,\n)\nfrom hexbytes import (\n HexBytes,\n)\nfrom websockets.exceptions import (\n ConnectionClosedOK,\n)\n\nfrom web3._utils.caching import (\n generate_cache_key,\n)\nfrom web3._utils.compat import (\n Self,\n)\nfrom web3.datastructures import (\n NamedElementOnion,\n)\nfrom web3.exceptions import (\n BadResponseFormat,\n MethodUnavailable,\n ProviderConnectionError,\n)\nfrom web3.middleware import (\n AttributeDictMiddleware,\n BufferedGasEstimateMiddleware,\n ENSNameToAddressMiddleware,\n GasPriceStrategyMiddleware,\n ValidationMiddleware,\n)\nfrom web3.middleware.base import (\n Middleware,\n MiddlewareOnion,\n)\nfrom web3.module import (\n 
apply_result_formatters,\n)\nfrom web3.providers import (\n AutoProvider,\n PersistentConnectionProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n from web3.middleware.base import ( # noqa: F401\n Web3Middleware,\n )\n from web3.providers import ( # noqa: F401\n AsyncBaseProvider,\n BaseProvider,\n )\n from web3.providers.persistent.request_processor import ( # noqa: F401\n RequestProcessor,\n )\n\n\nNULL_RESPONSES = [None, HexBytes(\"0x\"), \"0x\"]\nMETHOD_NOT_FOUND = -32601\n\n\ndef _raise_bad_response_format(response: RPCResponse, error: str = \"\") -> None:\n message = \"The response was in an unexpected format and unable to be parsed.\"\n raw_response = f\"The raw response is: {response}\"\n\n if error is not None and error != \"\":\n message = f\"{message} {error}. {raw_response}\"\n else:\n message = f\"{message} {raw_response}\"\n\n raise BadResponseFormat(message)\n\n\ndef apply_error_formatters(\n error_formatters: Callable[..., Any],\n response: RPCResponse,\n) -> RPCResponse:\n if error_formatters:\n formatted_resp = pipe(response, error_formatters)\n return formatted_resp\n else:\n return response\n\n\ndef apply_null_result_formatters(\n null_result_formatters: Callable[..., Any],\n response: RPCResponse,\n params: Optional[Any] = None,\n) -> RPCResponse:\n if null_result_formatters:\n formatted_resp = pipe(params, null_result_formatters)\n return formatted_resp\n else:\n return response\n\n\nclass RequestManager:\n logger = logging.getLogger(\"web3.manager.RequestManager\")\n\n middleware_onion: Union[\"MiddlewareOnion\", NamedElementOnion[None, None]]\n\n def __init__(\n self,\n w3: Union[\"AsyncWeb3\", \"Web3\"],\n provider: Optional[Union[\"BaseProvider\", \"AsyncBaseProvider\"]] = None,\n middlewares: Optional[Sequence[Tuple[Middleware, str]]] = None,\n ) -> None:\n self.w3 = w3\n\n if provider is None:\n self.provider = AutoProvider()\n else:\n self.provider = provider\n\n if middlewares is None:\n middlewares = self.get_default_middlewares()\n\n self.middleware_onion = NamedElementOnion(middlewares)\n\n if isinstance(provider, PersistentConnectionProvider):\n # set up the request processor to be able to properly process ordered\n # responses from the persistent connection as FIFO\n provider = cast(PersistentConnectionProvider, self.provider)\n self._request_processor: RequestProcessor = provider._request_processor\n\n w3: Union[\"AsyncWeb3\", \"Web3\"] = None\n _provider = None\n\n @property\n def provider(self) -> Union[\"BaseProvider\", \"AsyncBaseProvider\"]:\n return self._provider\n\n @provider.setter\n def provider(self, provider: Union[\"BaseProvider\", \"AsyncBaseProvider\"]) -> None:\n self._provider = provider\n\n @staticmethod\n def get_default_middlewares() -> List[Tuple[Middleware, str]]:\n \"\"\"\n List the default middlewares for the request manager.\n Documentation should remain in sync with these defaults.\n \"\"\"\n return [\n (GasPriceStrategyMiddleware, \"gas_price_strategy\"),\n (ENSNameToAddressMiddleware, \"ens_name_to_address\"),\n (AttributeDictMiddleware, \"attrdict\"),\n (ValidationMiddleware, \"validation\"),\n (BufferedGasEstimateMiddleware, \"gas_estimate\"),\n ]\n\n #\n # Provider requests and response\n #\n def _make_request(\n self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any\n ) -> RPCResponse:\n provider = cast(\"BaseProvider\", self.provider)\n request_func = provider.request_func(\n cast(\"Web3\", 
self.w3), cast(\"MiddlewareOnion\", self.middleware_onion)\n )\n self.logger.debug(f\"Making request. Method: {method}\")\n return request_func(method, params)\n\n async def _coro_make_request(\n self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any\n ) -> RPCResponse:\n provider = cast(\"AsyncBaseProvider\", self.provider)\n request_func = await provider.request_func(\n cast(\"AsyncWeb3\", self.w3), cast(\"MiddlewareOnion\", self.middleware_onion)\n )\n self.logger.debug(f\"Making request. Method: {method}\")\n return await request_func(method, params)\n\n #\n # formatted_response parses and validates JSON-RPC responses for expected\n # properties (result or an error) with the expected types.\n #\n # Required properties are not strictly enforced to further determine which\n # exception to raise for specific cases.\n #\n # See also: https://www.jsonrpc.org/specification\n #\n @staticmethod\n def formatted_response(\n response: RPCResponse,\n params: Any,\n error_formatters: Optional[Callable[..., Any]] = None,\n null_result_formatters: Optional[Callable[..., Any]] = None,\n ) -> Any:\n # jsonrpc is not enforced (as per the spec) but if present, it must be 2.0\n if \"jsonrpc\" in response and response[\"jsonrpc\"] != \"2.0\":\n _raise_bad_response_format(\n response, 'The \"jsonrpc\" field must be present with a value of \"2.0\"'\n )\n\n # id is not enforced (as per the spec) but if present, it must be a\n # string or integer\n # TODO: v7 - enforce id per the spec\n if \"id\" in response:\n response_id = response[\"id\"]\n # id is always None for errors\n if response_id is None and \"error\" not in response:\n _raise_bad_response_format(\n response, '\"id\" must be None when an error is present'\n )\n elif not isinstance(response_id, (str, int, type(None))):\n _raise_bad_response_format(response, '\"id\" must be a string or integer')\n\n # Response may not include both \"error\" and \"result\"\n if \"error\" in response and \"result\" in response:\n _raise_bad_response_format(\n response, 'Response cannot include both \"error\" and \"result\"'\n )\n\n # Format and validate errors\n elif \"error\" in response:\n error = response.get(\"error\")\n # Raise the error when the value is a string\n if error is None or isinstance(error, str):\n raise ValueError(error)\n\n # Errors must include an integer code\n code = error.get(\"code\")\n if not isinstance(code, int):\n _raise_bad_response_format(response, \"error['code'] must be an integer\")\n elif code == METHOD_NOT_FOUND:\n raise MethodUnavailable(error)\n\n # Errors must include a message\n if not isinstance(error.get(\"message\"), str):\n _raise_bad_response_format(\n response, \"error['message'] must be a string\"\n )\n\n apply_error_formatters(error_formatters, response)\n\n raise ValueError(error)\n\n # Format and validate results\n elif \"result\" in response:\n # Null values for result should apply null_result_formatters\n # Skip when result not present in the response (fallback to False)\n if response.get(\"result\", False) in NULL_RESPONSES:\n apply_null_result_formatters(null_result_formatters, response, params)\n return response.get(\"result\")\n\n # Response from eth_subscription includes response[\"params\"][\"result\"]\n elif (\n response.get(\"method\") == \"eth_subscription\"\n and response.get(\"params\") is not None\n and response[\"params\"].get(\"subscription\") is not None\n and response[\"params\"].get(\"result\") is not None\n ):\n return {\n \"subscription\": response[\"params\"][\"subscription\"],\n 
\"result\": response[\"params\"][\"result\"],\n }\n\n # Any other response type raises BadResponseFormat\n else:\n _raise_bad_response_format(response)\n\n def request_blocking(\n self,\n method: Union[RPCEndpoint, Callable[..., RPCEndpoint]],\n params: Any,\n error_formatters: Optional[Callable[..., Any]] = None,\n null_result_formatters: Optional[Callable[..., Any]] = None,\n ) -> Any:\n \"\"\"\n Make a synchronous request using the provider\n \"\"\"\n response = self._make_request(method, params)\n return self.formatted_response(\n response, params, error_formatters, null_result_formatters\n )\n\n async def coro_request(\n self,\n method: Union[RPCEndpoint, Callable[..., RPCEndpoint]],\n params: Any,\n error_formatters: Optional[Callable[..., Any]] = None,\n null_result_formatters: Optional[Callable[..., Any]] = None,\n ) -> Any:\n \"\"\"\n Coroutine for making a request using the provider\n \"\"\"\n response = await self._coro_make_request(method, params)\n return self.formatted_response(\n response, params, error_formatters, null_result_formatters\n )\n\n # -- persistent connection -- #\n\n async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n provider = cast(PersistentConnectionProvider, self._provider)\n request_func = await provider.request_func(\n cast(\"AsyncWeb3\", self.w3), cast(\"MiddlewareOnion\", self.middleware_onion)\n )\n self.logger.debug(\n \"Making request to open socket connection - \"\n f\"uri: {provider.endpoint_uri}, method: {method}\"\n )\n response = await request_func(method, params)\n return await self._process_response(response)\n\n def _persistent_message_stream(self) -> \"_AsyncPersistentMessageStream\":\n return _AsyncPersistentMessageStream(self)\n\n async def _get_next_message(self) -> Any:\n return await self._message_stream().__anext__()\n\n async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:\n if not isinstance(self._provider, PersistentConnectionProvider):\n raise TypeError(\n \"Only providers that maintain an open, persistent connection \"\n \"can listen to streams.\"\n )\n\n if self._provider._message_listener_task is None:\n raise ProviderConnectionError(\n \"No listener found for persistent connection.\"\n )\n\n while True:\n # sleep(0) here seems to be the most efficient way to yield control\n # back to the event loop while waiting for the response in the queue.\n await asyncio.sleep(0)\n\n response = self._request_processor.pop_raw_response(subscription=True)\n if (\n response is not None\n and response.get(\"params\", {}).get(\"subscription\")\n in self._request_processor.active_subscriptions\n ):\n # if response is an active subscription response, process it\n yield await self._process_response(response)\n\n async def _process_response(self, response: RPCResponse) -> RPCResponse:\n provider = cast(PersistentConnectionProvider, self._provider)\n request_info = self._request_processor.get_request_information_for_response(\n response\n )\n\n if request_info is None:\n self.logger.debug(\"No cache key found for response, returning raw response\")\n return response\n else:\n if request_info.method == \"eth_subscribe\" and \"result\" in response.keys():\n # if response for the initial eth_subscribe request, which returns the\n # subscription id\n subscription_id = response[\"result\"]\n cache_key = generate_cache_key(subscription_id)\n if cache_key not in self._request_processor._request_information_cache:\n # cache by subscription id in order to process each response for the\n # subscription as it comes in\n 
request_info.subscription_id = subscription_id\n provider.logger.debug(\n \"Caching eth_subscription info:\\n \"\n f\"cache_key={cache_key},\\n \"\n f\"request_info={request_info.__dict__}\"\n )\n self._request_processor._request_information_cache.cache(\n cache_key, request_info\n )\n\n # pipe response back through middleware response processors\n if len(request_info.middleware_response_processors) > 0:\n response = pipe(response, *request_info.middleware_response_processors)\n\n (\n result_formatters,\n error_formatters,\n null_formatters,\n ) = request_info.response_formatters\n partly_formatted_response = self.formatted_response(\n response,\n request_info.params,\n error_formatters,\n null_formatters,\n )\n return apply_result_formatters(result_formatters, partly_formatted_response)\n\n\nclass _AsyncPersistentMessageStream:\n \"\"\"\n Async generator for pulling subscription responses from the request processor\n subscription queue. This abstraction is necessary to define the `__aiter__()`\n method required for use with \"async for\" loops.\n \"\"\"\n\n def __init__(self, manager: RequestManager, *args: Any, **kwargs: Any) -> None:\n self.manager = manager\n self.provider: PersistentConnectionProvider = cast(\n PersistentConnectionProvider, manager._provider\n )\n super().__init__(*args, **kwargs)\n\n def __aiter__(self) -> Self:\n return self\n\n async def __anext__(self) -> RPCResponse:\n try:\n return await self.manager._get_next_message()\n except ConnectionClosedOK:\n raise StopAsyncIteration\n\n# Path: web3/providers/persistent/persistent_connection.py\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n)\n\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n )\n from web3.manager import ( # noqa: F401\n _AsyncPersistentMessageStream,\n )\n\n\nclass PersistentConnection:\n \"\"\"\n A class that houses the public API for interacting with the persistent connection\n via a `AsyncWeb3` instance instantiated with a `PersistentConnectionProvider` class.\n \"\"\"\n\n def __init__(self, w3: \"AsyncWeb3\"):\n self._manager = w3.manager\n\n # -- public methods -- #\n @property\n def subscriptions(self) -> Dict[str, Any]:\n return self._manager._request_processor.active_subscriptions\n\n async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n return await self._manager.send(method, params)\n\n async def recv(self) -> Any:\n return await self._manager._get_next_message()\n\n def process_subscriptions(self) -> \"_AsyncPersistentMessageStream\":\n return self._manager._persistent_message_stream()\n\n# Path: web3/providers/persistent/websocket.py\nimport asyncio\nimport json\nimport logging\nimport os\nfrom typing import (\n Any,\n Dict,\n Optional,\n Union,\n)\n\nfrom eth_typing import (\n URI,\n)\nfrom toolz import (\n merge,\n)\nfrom websockets import (\n WebSocketClientProtocol,\n)\nfrom websockets.client import (\n connect,\n)\nfrom websockets.exceptions import (\n WebSocketException,\n)\n\nfrom web3._utils.caching import (\n async_handle_request_caching,\n)\nfrom web3.exceptions import (\n ProviderConnectionError,\n Web3ValidationError,\n)\nfrom web3.providers.persistent import (\n PersistentConnectionProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nDEFAULT_PING_INTERVAL = 30 # 30 seconds\nDEFAULT_PING_TIMEOUT = 300 # 5 minutes\n\nVALID_WEBSOCKET_URI_PREFIXES = {\"ws://\", \"wss://\"}\nRESTRICTED_WEBSOCKET_KWARGS = {\"uri\", \"loop\"}\nDEFAULT_WEBSOCKET_KWARGS = 
{\n # set how long to wait between pings from the server\n \"ping_interval\": DEFAULT_PING_INTERVAL,\n # set how long to wait without a pong response before closing the connection\n \"ping_timeout\": DEFAULT_PING_TIMEOUT,\n}\n\n\ndef get_default_endpoint() -> URI:\n return URI(os.environ.get(\"WEB3_WS_PROVIDER_URI\", \"ws://127.0.0.1:8546\"))\n\n\nclass WebSocketProvider(PersistentConnectionProvider):\n logger = logging.getLogger(\"web3.providers.WebSocketProvider\")\n is_async: bool = True\n\n _max_connection_retries: int = 5\n _ws: Optional[WebSocketClientProtocol] = None\n\n def __init__(\n self,\n endpoint_uri: Optional[Union[URI, str]] = None,\n websocket_kwargs: Optional[Dict[str, Any]] = None,\n # `PersistentConnectionProvider` kwargs can be passed through\n **kwargs: Any,\n ) -> None:\n self.endpoint_uri = URI(endpoint_uri)\n if self.endpoint_uri is None:\n self.endpoint_uri = get_default_endpoint()\n\n if not any(\n self.endpoint_uri.startswith(prefix)\n for prefix in VALID_WEBSOCKET_URI_PREFIXES\n ):\n raise Web3ValidationError(\n \"WebSocket endpoint uri must begin with 'ws://' or 'wss://': \"\n f\"{self.endpoint_uri}\"\n )\n\n if websocket_kwargs is not None:\n found_restricted_keys = set(websocket_kwargs).intersection(\n RESTRICTED_WEBSOCKET_KWARGS\n )\n if found_restricted_keys:\n raise Web3ValidationError(\n \"Found restricted keys for websocket_kwargs: \"\n f\"{found_restricted_keys}.\"\n )\n\n self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs or {})\n\n super().__init__(**kwargs)\n\n def __str__(self) -> str:\n return f\"WebSocket connection: {self.endpoint_uri}\"\n\n async def is_connected(self, show_traceback: bool = False) -> bool:\n if not self._ws:\n return False\n\n try:\n await self._ws.pong()\n return True\n\n except WebSocketException as e:\n if show_traceback:\n raise ProviderConnectionError(\n f\"Error connecting to endpoint: '{self.endpoint_uri}'\"\n ) from e\n return False\n\n async def connect(self) -> None:\n _connection_attempts = 0\n _backoff_rate_change = 1.75\n _backoff_time = 1.75\n\n while _connection_attempts != self._max_connection_retries:\n try:\n _connection_attempts += 1\n self._ws = await connect(self.endpoint_uri, **self.websocket_kwargs)\n self._message_listener_task = asyncio.create_task(\n self._message_listener()\n )\n break\n except WebSocketException as e:\n if _connection_attempts == self._max_connection_retries:\n raise ProviderConnectionError(\n f\"Could not connect to endpoint: {self.endpoint_uri}. \"\n f\"Retries exceeded max of {self._max_connection_retries}.\"\n ) from e\n self.logger.info(\n f\"Could not connect to endpoint: {self.endpoint_uri}. 
Retrying in \"\n f\"{round(_backoff_time, 1)} seconds.\",\n exc_info=True,\n )\n await asyncio.sleep(_backoff_time)\n _backoff_time *= _backoff_rate_change\n\n async def disconnect(self) -> None:\n if self._ws is not None and not self._ws.closed:\n await self._ws.close()\n self._ws = None\n self.logger.debug(\n f'Successfully disconnected from endpoint: \"{self.endpoint_uri}'\n )\n\n try:\n self._message_listener_task.cancel()\n await self._message_listener_task\n except (asyncio.CancelledError, StopAsyncIteration):\n pass\n self._request_processor.clear_caches()\n\n @async_handle_request_caching\n async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n request_data = self.encode_rpc_request(method, params)\n\n if self._ws is None:\n raise ProviderConnectionError(\n \"Connection to websocket has not been initiated for the provider.\"\n )\n\n await asyncio.wait_for(\n self._ws.send(request_data), timeout=self.request_timeout\n )\n\n current_request_id = json.loads(request_data)[\"id\"]\n response = await self._get_response_for_request_id(current_request_id)\n\n return response\n\n async def _message_listener(self) -> None:\n self.logger.info(\n \"WebSocket listener background task started. Storing all messages in \"\n \"appropriate request processor queues / caches to be processed.\"\n )\n while True:\n # the use of sleep(0) seems to be the most efficient way to yield control\n # back to the event loop to share the loop with other tasks.\n await asyncio.sleep(0)\n\n try:\n async for raw_message in self._ws:\n await asyncio.sleep(0)\n\n response = json.loads(raw_message)\n subscription = response.get(\"method\") == \"eth_subscription\"\n await self._request_processor.cache_raw_response(\n response, subscription=subscription\n )\n except Exception as e:\n if not self.silence_listener_task_exceptions:\n loop = asyncio.get_event_loop()\n for task in asyncio.all_tasks(loop=loop):\n task.cancel()\n raise e\n\n self.logger.error(\n \"Exception caught in listener, error logging and keeping \"\n \"listener background task alive.\"\n f\"\\n error={e.__class__.__name__}: {e}\"\n )\n\n# Path: web3/providers/persistent/__init__.py\nfrom .persistent import (\n PersistentConnectionProvider,\n)\nfrom .persistent_connection import (\n PersistentConnection,\n)\nfrom .request_processor import (\n RequestProcessor,\n)\nfrom .async_ipc import (\n AsyncIPCProvider,\n)\nfrom .websocket import (\n WebSocketProvider,\n)\n\n# Path: web3/providers/async_base.py\nimport asyncio\nimport itertools\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Coroutine,\n Optional,\n Set,\n Tuple,\n cast,\n)\n\nfrom eth_utils import (\n is_text,\n to_bytes,\n to_text,\n)\n\nfrom web3._utils.caching import (\n async_handle_request_caching,\n)\nfrom web3._utils.encoding import (\n FriendlyJsonSerde,\n Web3JsonEncoder,\n)\nfrom web3.exceptions import (\n ProviderConnectionError,\n)\nfrom web3.middleware import (\n async_combine_middlewares,\n)\nfrom web3.middleware.base import (\n Middleware,\n MiddlewareOnion,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\nfrom web3.utils import (\n SimpleCache,\n)\n\nif TYPE_CHECKING:\n from websockets import (\n WebSocketClientProtocol,\n )\n\n from web3 import ( # noqa: F401\n AsyncWeb3,\n WebSocketProvider,\n )\n from web3.providers.persistent import ( # noqa: F401\n RequestProcessor,\n )\n\n\nCACHEABLE_REQUESTS = cast(\n Set[RPCEndpoint],\n (\n \"eth_chainId\",\n \"eth_getBlockByHash\",\n \"eth_getBlockTransactionCountByHash\",\n 
\"eth_getRawTransactionByHash\",\n \"eth_getTransactionByBlockHashAndIndex\",\n \"eth_getTransactionByHash\",\n \"eth_getUncleByBlockHashAndIndex\",\n \"eth_getUncleCountByBlockHash\",\n \"net_version\",\n \"web3_clientVersion\",\n ),\n)\n\n\nclass AsyncBaseProvider:\n _request_func_cache: Tuple[\n Tuple[Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]]\n ] = (None, None)\n\n is_async = True\n has_persistent_connection = False\n global_ccip_read_enabled: bool = True\n ccip_read_max_redirects: int = 4\n\n # request caching\n cache_allowed_requests: bool = False\n cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS\n _request_cache: SimpleCache\n _request_cache_lock: asyncio.Lock = asyncio.Lock()\n\n def __init__(self) -> None:\n self._request_cache = SimpleCache(1000)\n\n async def request_func(\n self, async_w3: \"AsyncWeb3\", middleware_onion: MiddlewareOnion\n ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]:\n middlewares: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middlewares()\n\n cache_key = self._request_func_cache[0]\n if cache_key != middlewares:\n self._request_func_cache = (\n middlewares,\n await async_combine_middlewares(\n middlewares=middlewares,\n async_w3=async_w3,\n provider_request_fn=self.make_request,\n ),\n )\n return self._request_func_cache[-1]\n\n @async_handle_request_caching\n async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n raise NotImplementedError(\"Providers must implement this method\")\n\n async def is_connected(self, show_traceback: bool = False) -> bool:\n raise NotImplementedError(\"Providers must implement this method\")\n\n # -- persistent connection providers -- #\n\n _request_processor: \"RequestProcessor\"\n _message_listener_task: \"asyncio.Task[None]\"\n _listen_event: \"asyncio.Event\"\n\n async def connect(self) -> None:\n raise NotImplementedError(\n \"Persistent connection providers must implement this method\"\n )\n\n async def disconnect(self) -> None:\n raise NotImplementedError(\n \"Persistent connection providers must implement this method\"\n )\n\n # WebSocket typing\n _ws: \"WebSocketClientProtocol\"\n\n # IPC typing\n _reader: Optional[asyncio.StreamReader]\n _writer: Optional[asyncio.StreamWriter]\n\n\nclass AsyncJSONBaseProvider(AsyncBaseProvider):\n def __init__(self) -> None:\n super().__init__()\n self.request_counter = itertools.count()\n\n def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:\n request_id = next(self.request_counter)\n rpc_dict = {\n \"jsonrpc\": \"2.0\",\n \"method\": method,\n \"params\": params or [],\n \"id\": request_id,\n }\n encoded = FriendlyJsonSerde().json_encode(rpc_dict, cls=Web3JsonEncoder)\n return to_bytes(text=encoded)\n\n def decode_rpc_response(self, raw_response: bytes) -> RPCResponse:\n text_response = str(\n to_text(raw_response) if not is_text(raw_response) else raw_response\n )\n return cast(RPCResponse, FriendlyJsonSerde().json_decode(text_response))\n\n async def is_connected(self, show_traceback: bool = False) -> bool:\n try:\n response = await self.make_request(RPCEndpoint(\"web3_clientVersion\"), [])\n except (OSError, ProviderConnectionError) as e:\n if show_traceback:\n raise ProviderConnectionError(\n f\"Problem connecting to provider with error: {type(e)}: {e}\"\n )\n return False\n\n if \"error\" in response:\n if show_traceback:\n raise ProviderConnectionError(\n f\"Error received from provider: {response}\"\n )\n return False\n\n if response.get(\"jsonrpc\") == \"2.0\":\n return True\n 
else:\n if show_traceback:\n raise ProviderConnectionError(f\"Bad jsonrpc version: {response}\")\n return False\n\n# Path: web3/providers/auto.py\nimport os\nfrom typing import (\n Any,\n Callable,\n Dict,\n Optional,\n Sequence,\n Tuple,\n Type,\n Union,\n)\nfrom urllib.parse import (\n urlparse,\n)\n\nfrom eth_typing import (\n URI,\n)\n\nfrom web3.exceptions import (\n CannotHandleRequest,\n)\nfrom web3.providers import (\n BaseProvider,\n HTTPProvider,\n IPCProvider,\n LegacyWebSocketProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nHTTP_SCHEMES = {\"http\", \"https\"}\nWS_SCHEMES = {\"ws\", \"wss\"}\n\n\ndef load_provider_from_environment() -> BaseProvider:\n uri_string = URI(os.environ.get(\"WEB3_PROVIDER_URI\", \"\"))\n if not uri_string:\n return None\n\n return load_provider_from_uri(uri_string)\n\n\ndef load_provider_from_uri(\n uri_string: URI, headers: Optional[Dict[str, Tuple[str, str]]] = None\n) -> BaseProvider:\n uri = urlparse(uri_string)\n if uri.scheme == \"file\":\n return IPCProvider(uri.path)\n elif uri.scheme in HTTP_SCHEMES:\n return HTTPProvider(uri_string, headers)\n elif uri.scheme in WS_SCHEMES:\n return LegacyWebSocketProvider(uri_string)\n else:\n raise NotImplementedError(\n \"Web3 does not know how to connect to scheme \"\n f\"{uri.scheme!r} in {uri_string!r}\"\n )\n\n\nclass AutoProvider(BaseProvider):\n default_providers = (\n load_provider_from_environment,\n IPCProvider,\n HTTPProvider,\n LegacyWebSocketProvider,\n )\n _active_provider = None\n\n def __init__(\n self,\n potential_providers: Optional[\n Sequence[Union[Callable[..., BaseProvider], Type[BaseProvider]]]\n ] = None,\n ) -> None:\n \"\"\"\n :param iterable potential_providers: ordered series of provider classes\n to attempt with\n\n AutoProvider will initialize each potential provider (without arguments),\n in an attempt to find an active node. 
The list will default to\n :attribute:`default_providers`.\n \"\"\"\n if potential_providers:\n self._potential_providers = potential_providers\n else:\n self._potential_providers = self.default_providers\n\n def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n try:\n return self._proxy_request(method, params)\n except OSError:\n return self._proxy_request(method, params, use_cache=False)\n\n def is_connected(self, show_traceback: bool = False) -> bool:\n provider = self._get_active_provider(use_cache=True)\n return provider is not None and provider.is_connected(show_traceback)\n\n def _proxy_request(\n self, method: RPCEndpoint, params: Any, use_cache: bool = True\n ) -> RPCResponse:\n provider = self._get_active_provider(use_cache)\n if provider is None:\n raise CannotHandleRequest(\n \"Could not discover provider while making request: \"\n f\"method:{method}\\nparams:{params}\\n\"\n )\n\n return provider.make_request(method, params)\n\n def _get_active_provider(self, use_cache: bool) -> Optional[BaseProvider]:\n if use_cache and self._active_provider is not None:\n return self._active_provider\n\n for Provider in self._potential_providers:\n provider = Provider()\n if provider is not None and provider.is_connected():\n self._active_provider = provider\n return provider\n\n return None\n\n# Path: web3/providers/legacy_websocket.py\nimport asyncio\nimport json\nimport logging\nimport os\nfrom threading import (\n Thread,\n)\nfrom types import (\n TracebackType,\n)\nfrom typing import (\n Any,\n Optional,\n Type,\n Union,\n)\n\nfrom eth_typing import (\n URI,\n)\nfrom websockets.client import (\n connect,\n)\nfrom websockets.legacy.client import (\n WebSocketClientProtocol,\n)\n\nfrom web3.exceptions import (\n Web3ValidationError,\n)\nfrom web3.providers.base import (\n JSONBaseProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nRESTRICTED_WEBSOCKET_KWARGS = {\"uri\", \"loop\"}\nDEFAULT_WEBSOCKET_TIMEOUT = 30\n\n\ndef _start_event_loop(loop: asyncio.AbstractEventLoop) -> None:\n asyncio.set_event_loop(loop)\n loop.run_forever()\n loop.close()\n\n\ndef _get_threaded_loop() -> asyncio.AbstractEventLoop:\n new_loop = asyncio.new_event_loop()\n thread_loop = Thread(target=_start_event_loop, args=(new_loop,), daemon=True)\n thread_loop.start()\n return new_loop\n\n\ndef get_default_endpoint() -> URI:\n return URI(os.environ.get(\"WEB3_WS_PROVIDER_URI\", \"ws://127.0.0.1:8546\"))\n\n\nclass PersistentWebSocket:\n def __init__(self, endpoint_uri: URI, websocket_kwargs: Any) -> None:\n self.ws: Optional[WebSocketClientProtocol] = None\n self.endpoint_uri = endpoint_uri\n self.websocket_kwargs = websocket_kwargs\n\n async def __aenter__(self) -> WebSocketClientProtocol:\n if self.ws is None:\n self.ws = await connect(uri=self.endpoint_uri, **self.websocket_kwargs)\n return self.ws\n\n async def __aexit__(\n self,\n exc_type: Type[BaseException],\n exc_val: BaseException,\n exc_tb: TracebackType,\n ) -> None:\n if exc_val is not None:\n try:\n await self.ws.close()\n except Exception:\n pass\n self.ws = None\n\n\nclass LegacyWebSocketProvider(JSONBaseProvider):\n logger = logging.getLogger(\"web3.providers.WebSocketProvider\")\n _loop = None\n\n def __init__(\n self,\n endpoint_uri: Optional[Union[URI, str]] = None,\n websocket_kwargs: Optional[Any] = None,\n...\n", "needle_token_start": 2195, "needle_token_end": 2635, "code_context_ntokens": 16115, "output": ["```python\nasync def _get_response_for_request_id(\n self, request_id: RPCId, timeout: 
Optional[float] = None\n ) -> RPCResponse:\n if timeout is None:\n timeout = self.request_timeout\n\n async def _match_response_id_to_request_id() -> RPCResponse:\n request_cache_key = generate_cache_key(request_id)\n\n while True:\n # sleep(0) here seems to be the most efficient way to yield control\n # back to the event loop while waiting for the response to be in the\n # queue.\n await asyncio.sleep(0)\n\n if request_cache_key in self._request_processor._request_response_cache:\n self.logger.debug(\n f\"Popping response for id {request_id} from cache.\"\n )\n popped_response = self._request_processor.pop_raw_response(\n cache_key=request_cache_key,\n )\n return popped_response\n\n try:\n # Add the request timeout around the while loop that checks the request\n # cache and tried to recv(). If the request is neither in the cache, nor\n # received within the request_timeout, raise ``TimeExhausted``.\n return await asyncio.wait_for(_match_response_id_to_request_id(), timeout)\n except asyncio.TimeoutError:\n raise TimeExhausted(\n f\"Timed out waiting for response with request id `{request_id}` after \"\n f\"{self.request_timeout} second(s). This may be due to the provider \"\n \"not returning a response with the same id that was sent in the \"\n \"request or an exception raised during the request was caught and \"\n \"allowed to continue.\"\n )\n```"]} +{"repo": "ethereum/web3.py", "name": "get_default_middlewares", "language": "python", "path": "web3/manager.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: To provide a list of default middleware components that are essential for processing and managing blockchain requests.\n2. **Input**: None, as this method does not require any parameters to be passed.\n3. **Output**: A list of tuples, where each tuple contains a middleware component and its corresponding identifier as a string.\n4. 
**Procedure**: The method statically defines and returns a list of middleware components that are commonly used to handle various aspects of blockchain interactions, such as gas price strategy, name resolution, data formatting, input validation, and gas estimation.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: web3/method.py\nimport functools\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Dict,\n Generic,\n List,\n Optional,\n Sequence,\n Tuple,\n Type,\n TypeVar,\n Union,\n)\nimport warnings\n\nfrom eth_utils.curried import (\n to_tuple,\n)\nfrom eth_utils.toolz import (\n pipe,\n)\n\nfrom web3._utils.method_formatters import (\n get_error_formatters,\n get_null_result_formatters,\n get_request_formatters,\n get_result_formatters,\n)\nfrom web3._utils.rpc_abi import (\n RPC,\n)\nfrom web3.exceptions import (\n Web3ValidationError,\n)\nfrom web3.types import (\n RPCEndpoint,\n TReturn,\n)\n\nif TYPE_CHECKING:\n from web3 import Web3 # noqa: F401\n from web3.module import Module # noqa: F401\n\n\nMunger = Callable[..., Any]\n\n\n@to_tuple\ndef _apply_request_formatters(\n params: Any, request_formatters: Dict[RPCEndpoint, Callable[..., TReturn]]\n) -> Tuple[Any, ...]:\n if request_formatters:\n formatted_params = pipe(params, request_formatters)\n return formatted_params\n return params\n\n\ndef _set_mungers(\n mungers: Optional[Sequence[Munger]], is_property: bool\n) -> Sequence[Any]:\n if is_property and mungers:\n raise Web3ValidationError(\"Mungers cannot be used with a property.\")\n\n return (\n mungers\n if mungers\n else [default_munger] if is_property else [default_root_munger]\n )\n\n\ndef default_munger(_module: \"Module\", *args: Any, **kwargs: Any) -> Tuple[()]:\n if args or kwargs:\n raise Web3ValidationError(\"Parameters cannot be passed to a property.\")\n return ()\n\n\ndef default_root_munger(_module: \"Module\", *args: Any) -> List[Any]:\n return [*args]\n\n\nTFunc = TypeVar(\"TFunc\", bound=Callable[..., Any])\n\n\nclass Method(Generic[TFunc]):\n \"\"\"Method object for web3 module methods\n\n Calls to the Method go through these steps:\n\n 1. input munging - includes normalization, parameter checking, early parameter\n formatting. Any processing on the input parameters that need to happen before\n json_rpc method string selection occurs.\n\n A note about mungers: The first (root) munger should reflect the desired\n api function arguments. In other words, if the api function wants to\n behave as: `get_balance(account, block_identifier=None)`, the root munger\n should accept these same arguments, with the addition of the module as\n the first argument e.g.:\n\n ```\n def get_balance_root_munger(module, account, block_identifier=None):\n if block_identifier is None:\n block_identifier = DEFAULT_BLOCK\n return module, [account, block_identifier]\n ```\n\n all mungers should return an argument list.\n\n if no munger is provided, a default munger expecting no method arguments\n will be used.\n\n 2. method selection - The json_rpc_method argument can be method string or a\n function that returns a method string. If a callable is provided the processed\n method inputs are passed to the method selection function, and the returned\n method string is used.\n\n 3. 
request and response formatters are set - formatters are retrieved\n using the json rpc method string.\n\n 4. After the parameter processing from steps 1-3 the request is made using\n the calling function returned by the module attribute ``retrieve_caller_fn``\n and the response formatters are applied to the output.\n \"\"\"\n\n def __init__(\n self,\n json_rpc_method: Optional[RPCEndpoint] = None,\n mungers: Optional[Sequence[Munger]] = None,\n request_formatters: Optional[Callable[..., TReturn]] = None,\n result_formatters: Optional[Callable[..., TReturn]] = None,\n null_result_formatters: Optional[Callable[..., TReturn]] = None,\n method_choice_depends_on_args: Optional[Callable[..., RPCEndpoint]] = None,\n is_property: bool = False,\n ):\n self.json_rpc_method = json_rpc_method\n self.mungers = _set_mungers(mungers, is_property)\n self.request_formatters = request_formatters or get_request_formatters\n self.result_formatters = result_formatters or get_result_formatters\n self.null_result_formatters = (\n null_result_formatters or get_null_result_formatters\n )\n self.method_choice_depends_on_args = method_choice_depends_on_args\n self.is_property = is_property\n\n def __get__(\n...\n# Path: web3/module.py\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Coroutine,\n Dict,\n Optional,\n TypeVar,\n Union,\n cast,\n)\n\nfrom eth_abi.codec import (\n ABICodec,\n)\nfrom eth_utils.toolz import (\n curry,\n pipe,\n)\n\nfrom web3._utils.filters import (\n AsyncLogFilter,\n LogFilter,\n _UseExistingFilter,\n)\nfrom web3.method import (\n Method,\n)\nfrom web3.providers.persistent import (\n PersistentConnectionProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\n\n@curry\ndef apply_result_formatters(\n result_formatters: Callable[..., Any], result: RPCResponse\n) -> RPCResponse:\n if result_formatters:\n formatted_result = pipe(result, result_formatters)\n return formatted_result\n else:\n return result\n\n\nTReturn = TypeVar(\"TReturn\")\n\n\n@curry\ndef retrieve_blocking_method_call_fn(\n w3: \"Web3\", module: \"Module\", method: Method[Callable[..., TReturn]]\n) -> Callable[..., Union[TReturn, LogFilter]]:\n def caller(*args: Any, **kwargs: Any) -> Union[TReturn, LogFilter]:\n try:\n (method_str, params), response_formatters = method.process_params(\n module, *args, **kwargs\n )\n except _UseExistingFilter as err:\n return LogFilter(eth_module=module, filter_id=err.filter_id)\n\n (\n result_formatters,\n error_formatters,\n null_result_formatters,\n ) = response_formatters\n result = w3.manager.request_blocking(\n method_str, params, error_formatters, null_result_formatters\n )\n return apply_result_formatters(result_formatters, result)\n\n return caller\n\n\n@curry\ndef retrieve_async_method_call_fn(\n async_w3: \"AsyncWeb3\", module: \"Module\", method: Method[Callable[..., Any]]\n) -> Callable[..., Coroutine[Any, Any, Optional[Union[RPCResponse, AsyncLogFilter]]]]:\n async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter]:\n try:\n (method_str, params), response_formatters = method.process_params(\n module, *args, **kwargs\n )\n except _UseExistingFilter as err:\n return AsyncLogFilter(eth_module=module, filter_id=err.filter_id)\n\n if isinstance(async_w3.provider, PersistentConnectionProvider):\n # TODO: The typing does not seem to be correct for response_formatters.\n # For now, keep the expected typing but ignore it here.\n provider = 
async_w3.provider\n cache_key = provider._request_processor.cache_request_information(\n cast(RPCEndpoint, method_str), params, response_formatters # type: ignore # noqa: E501\n )\n try:\n method_str = cast(RPCEndpoint, method_str)\n return await async_w3.manager.send(method_str, params)\n except Exception as e:\n if (\n cache_key is not None\n and cache_key\n in provider._request_processor._request_information_cache\n ):\n provider._request_processor.pop_cached_request_information(\n cache_key\n )\n raise e\n else:\n (\n result_formatters,\n error_formatters,\n null_result_formatters,\n ) = response_formatters\n\n result = await async_w3.manager.coro_request(\n method_str, params, error_formatters, null_result_formatters\n )\n return apply_result_formatters(result_formatters, result)\n\n return caller\n\n\n# Module should no longer have access to the full web3 api.\n# Only the calling functions need access to the request methods.\n# Any \"re-entrant\" shenanigans can go in the middlewares, which do\n# have web3 access.\nclass Module:\n is_async = False\n\n def __init__(self, w3: Union[\"AsyncWeb3\", \"Web3\"]) -> None:\n if self.is_async:\n self.retrieve_caller_fn = retrieve_async_method_call_fn(w3, self)\n else:\n self.retrieve_caller_fn = retrieve_blocking_method_call_fn(w3, self)\n self.w3 = w3\n\n @property\n def codec(self) -> ABICodec:\n # use codec set on the Web3 instance\n return self.w3.codec\n\n def attach_methods(\n self,\n methods: Dict[str, Method[Callable[..., Any]]],\n ) -> None:\n for method_name, method_class in methods.items():\n klass = (\n method_class.__get__(obj=self)()\n if method_class.is_property\n else method_class.__get__(obj=self)\n )\n setattr(self, method_name, klass)\n\n# Path: web3/manager.py\nimport asyncio\nimport logging\nfrom typing import (\n TYPE_CHECKING,\n Any,\n AsyncGenerator,\n Callable,\n List,\n Optional,\n Sequence,\n Tuple,\n Union,\n cast,\n)\n\nfrom eth_utils.toolz import (\n pipe,\n)\nfrom hexbytes import (\n HexBytes,\n)\nfrom websockets.exceptions import (\n ConnectionClosedOK,\n)\n\nfrom web3._utils.caching import (\n generate_cache_key,\n)\nfrom web3._utils.compat import (\n Self,\n)\nfrom web3.datastructures import (\n NamedElementOnion,\n)\nfrom web3.exceptions import (\n BadResponseFormat,\n MethodUnavailable,\n ProviderConnectionError,\n)\nfrom web3.middleware import (\n AttributeDictMiddleware,\n BufferedGasEstimateMiddleware,\n ENSNameToAddressMiddleware,\n GasPriceStrategyMiddleware,\n ValidationMiddleware,\n)\nfrom web3.middleware.base import (\n Middleware,\n MiddlewareOnion,\n)\nfrom web3.module import (\n apply_result_formatters,\n)\nfrom web3.providers import (\n AutoProvider,\n PersistentConnectionProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n from web3.middleware.base import ( # noqa: F401\n Web3Middleware,\n )\n from web3.providers import ( # noqa: F401\n AsyncBaseProvider,\n BaseProvider,\n )\n from web3.providers.persistent.request_processor import ( # noqa: F401\n RequestProcessor,\n )\n\n\nNULL_RESPONSES = [None, HexBytes(\"0x\"), \"0x\"]\nMETHOD_NOT_FOUND = -32601\n\n\ndef _raise_bad_response_format(response: RPCResponse, error: str = \"\") -> None:\n message = \"The response was in an unexpected format and unable to be parsed.\"\n raw_response = f\"The raw response is: {response}\"\n\n if error is not None and error != \"\":\n message = f\"{message} {error}. 
{raw_response}\"\n else:\n message = f\"{message} {raw_response}\"\n\n raise BadResponseFormat(message)\n\n\ndef apply_error_formatters(\n error_formatters: Callable[..., Any],\n response: RPCResponse,\n) -> RPCResponse:\n if error_formatters:\n formatted_resp = pipe(response, error_formatters)\n return formatted_resp\n else:\n return response\n\n\ndef apply_null_result_formatters(\n null_result_formatters: Callable[..., Any],\n response: RPCResponse,\n params: Optional[Any] = None,\n) -> RPCResponse:\n if null_result_formatters:\n formatted_resp = pipe(params, null_result_formatters)\n return formatted_resp\n else:\n return response\n\n\nclass RequestManager:\n logger = logging.getLogger(\"web3.manager.RequestManager\")\n\n middleware_onion: Union[\"MiddlewareOnion\", NamedElementOnion[None, None]]\n\n def __init__(\n self,\n w3: Union[\"AsyncWeb3\", \"Web3\"],\n provider: Optional[Union[\"BaseProvider\", \"AsyncBaseProvider\"]] = None,\n middlewares: Optional[Sequence[Tuple[Middleware, str]]] = None,\n ) -> None:\n self.w3 = w3\n\n if provider is None:\n self.provider = AutoProvider()\n else:\n self.provider = provider\n\n if middlewares is None:\n middlewares = self.get_default_middlewares()\n\n self.middleware_onion = NamedElementOnion(middlewares)\n\n if isinstance(provider, PersistentConnectionProvider):\n # set up the request processor to be able to properly process ordered\n # responses from the persistent connection as FIFO\n provider = cast(PersistentConnectionProvider, self.provider)\n self._request_processor: RequestProcessor = provider._request_processor\n\n w3: Union[\"AsyncWeb3\", \"Web3\"] = None\n _provider = None\n\n @property\n def provider(self) -> Union[\"BaseProvider\", \"AsyncBaseProvider\"]:\n return self._provider\n\n @provider.setter\n def provider(self, provider: Union[\"BaseProvider\", \"AsyncBaseProvider\"]) -> None:\n self._provider = provider\n\n @staticmethod\n \ndef get_default_middlewares() -> List[Tuple[Middleware, str]]:\n \"\"\"\n List the default middlewares for the request manager.\n Documentation should remain in sync with these defaults.\n \"\"\"\n return [\n (GasPriceStrategyMiddleware, \"gas_price_strategy\"),\n (ENSNameToAddressMiddleware, \"ens_name_to_address\"),\n (AttributeDictMiddleware, \"attrdict\"),\n (ValidationMiddleware, \"validation\"),\n (BufferedGasEstimateMiddleware, \"gas_estimate\"),\n ]\n\n #\n # Provider requests and response\n #\n def _make_request(\n self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any\n ) -> RPCResponse:\n provider = cast(\"BaseProvider\", self.provider)\n request_func = provider.request_func(\n cast(\"Web3\", self.w3), cast(\"MiddlewareOnion\", self.middleware_onion)\n )\n self.logger.debug(f\"Making request. Method: {method}\")\n return request_func(method, params)\n\n async def _coro_make_request(\n self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any\n ) -> RPCResponse:\n provider = cast(\"AsyncBaseProvider\", self.provider)\n request_func = await provider.request_func(\n cast(\"AsyncWeb3\", self.w3), cast(\"MiddlewareOnion\", self.middleware_onion)\n )\n self.logger.debug(f\"Making request. 
Method: {method}\")\n return await request_func(method, params)\n\n #\n # formatted_response parses and validates JSON-RPC responses for expected\n # properties (result or an error) with the expected types.\n #\n # Required properties are not strictly enforced to further determine which\n # exception to raise for specific cases.\n #\n # See also: https://www.jsonrpc.org/specification\n #\n @staticmethod\n def formatted_response(\n response: RPCResponse,\n params: Any,\n error_formatters: Optional[Callable[..., Any]] = None,\n null_result_formatters: Optional[Callable[..., Any]] = None,\n ) -> Any:\n # jsonrpc is not enforced (as per the spec) but if present, it must be 2.0\n if \"jsonrpc\" in response and response[\"jsonrpc\"] != \"2.0\":\n _raise_bad_response_format(\n response, 'The \"jsonrpc\" field must be present with a value of \"2.0\"'\n )\n\n # id is not enforced (as per the spec) but if present, it must be a\n # string or integer\n # TODO: v7 - enforce id per the spec\n if \"id\" in response:\n response_id = response[\"id\"]\n # id is always None for errors\n if response_id is None and \"error\" not in response:\n _raise_bad_response_format(\n response, '\"id\" must be None when an error is present'\n )\n elif not isinstance(response_id, (str, int, type(None))):\n _raise_bad_response_format(response, '\"id\" must be a string or integer')\n\n # Response may not include both \"error\" and \"result\"\n if \"error\" in response and \"result\" in response:\n _raise_bad_response_format(\n response, 'Response cannot include both \"error\" and \"result\"'\n )\n\n # Format and validate errors\n elif \"error\" in response:\n error = response.get(\"error\")\n # Raise the error when the value is a string\n if error is None or isinstance(error, str):\n raise ValueError(error)\n\n # Errors must include an integer code\n code = error.get(\"code\")\n if not isinstance(code, int):\n _raise_bad_response_format(response, \"error['code'] must be an integer\")\n elif code == METHOD_NOT_FOUND:\n raise MethodUnavailable(error)\n\n # Errors must include a message\n if not isinstance(error.get(\"message\"), str):\n _raise_bad_response_format(\n response, \"error['message'] must be a string\"\n )\n\n apply_error_formatters(error_formatters, response)\n\n raise ValueError(error)\n\n # Format and validate results\n elif \"result\" in response:\n # Null values for result should apply null_result_formatters\n # Skip when result not present in the response (fallback to False)\n if response.get(\"result\", False) in NULL_RESPONSES:\n apply_null_result_formatters(null_result_formatters, response, params)\n return response.get(\"result\")\n\n # Response from eth_subscription includes response[\"params\"][\"result\"]\n elif (\n response.get(\"method\") == \"eth_subscription\"\n and response.get(\"params\") is not None\n and response[\"params\"].get(\"subscription\") is not None\n and response[\"params\"].get(\"result\") is not None\n ):\n return {\n \"subscription\": response[\"params\"][\"subscription\"],\n \"result\": response[\"params\"][\"result\"],\n }\n\n # Any other response type raises BadResponseFormat\n else:\n _raise_bad_response_format(response)\n\n def request_blocking(\n self,\n method: Union[RPCEndpoint, Callable[..., RPCEndpoint]],\n params: Any,\n error_formatters: Optional[Callable[..., Any]] = None,\n null_result_formatters: Optional[Callable[..., Any]] = None,\n ) -> Any:\n \"\"\"\n Make a synchronous request using the provider\n \"\"\"\n response = self._make_request(method, params)\n return 
self.formatted_response(\n response, params, error_formatters, null_result_formatters\n )\n\n async def coro_request(\n self,\n method: Union[RPCEndpoint, Callable[..., RPCEndpoint]],\n params: Any,\n error_formatters: Optional[Callable[..., Any]] = None,\n null_result_formatters: Optional[Callable[..., Any]] = None,\n ) -> Any:\n \"\"\"\n Coroutine for making a request using the provider\n \"\"\"\n response = await self._coro_make_request(method, params)\n return self.formatted_response(\n response, params, error_formatters, null_result_formatters\n )\n\n # -- persistent connection -- #\n\n async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n provider = cast(PersistentConnectionProvider, self._provider)\n request_func = await provider.request_func(\n cast(\"AsyncWeb3\", self.w3), cast(\"MiddlewareOnion\", self.middleware_onion)\n )\n self.logger.debug(\n \"Making request to open socket connection - \"\n f\"uri: {provider.endpoint_uri}, method: {method}\"\n )\n response = await request_func(method, params)\n return await self._process_response(response)\n\n def _persistent_message_stream(self) -> \"_AsyncPersistentMessageStream\":\n return _AsyncPersistentMessageStream(self)\n\n async def _get_next_message(self) -> Any:\n return await self._message_stream().__anext__()\n\n async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:\n if not isinstance(self._provider, PersistentConnectionProvider):\n raise TypeError(\n \"Only providers that maintain an open, persistent connection \"\n \"can listen to streams.\"\n )\n\n if self._provider._message_listener_task is None:\n raise ProviderConnectionError(\n \"No listener found for persistent connection.\"\n )\n\n while True:\n # sleep(0) here seems to be the most efficient way to yield control\n # back to the event loop while waiting for the response in the queue.\n await asyncio.sleep(0)\n\n response = self._request_processor.pop_raw_response(subscription=True)\n if (\n response is not None\n and response.get(\"params\", {}).get(\"subscription\")\n in self._request_processor.active_subscriptions\n ):\n # if response is an active subscription response, process it\n yield await self._process_response(response)\n\n async def _process_response(self, response: RPCResponse) -> RPCResponse:\n provider = cast(PersistentConnectionProvider, self._provider)\n request_info = self._request_processor.get_request_information_for_response(\n response\n )\n\n if request_info is None:\n self.logger.debug(\"No cache key found for response, returning raw response\")\n return response\n else:\n if request_info.method == \"eth_subscribe\" and \"result\" in response.keys():\n # if response for the initial eth_subscribe request, which returns the\n # subscription id\n subscription_id = response[\"result\"]\n cache_key = generate_cache_key(subscription_id)\n if cache_key not in self._request_processor._request_information_cache:\n # cache by subscription id in order to process each response for the\n # subscription as it comes in\n request_info.subscription_id = subscription_id\n provider.logger.debug(\n \"Caching eth_subscription info:\\n \"\n f\"cache_key={cache_key},\\n \"\n f\"request_info={request_info.__dict__}\"\n )\n self._request_processor._request_information_cache.cache(\n cache_key, request_info\n )\n\n # pipe response back through middleware response processors\n if len(request_info.middleware_response_processors) > 0:\n response = pipe(response, *request_info.middleware_response_processors)\n\n (\n result_formatters,\n 
error_formatters,\n null_formatters,\n ) = request_info.response_formatters\n partly_formatted_response = self.formatted_response(\n response,\n request_info.params,\n error_formatters,\n null_formatters,\n )\n return apply_result_formatters(result_formatters, partly_formatted_response)\n\n\nclass _AsyncPersistentMessageStream:\n \"\"\"\n Async generator for pulling subscription responses from the request processor\n subscription queue. This abstraction is necessary to define the `__aiter__()`\n method required for use with \"async for\" loops.\n \"\"\"\n\n def __init__(self, manager: RequestManager, *args: Any, **kwargs: Any) -> None:\n self.manager = manager\n self.provider: PersistentConnectionProvider = cast(\n PersistentConnectionProvider, manager._provider\n )\n super().__init__(*args, **kwargs)\n\n def __aiter__(self) -> Self:\n return self\n\n async def __anext__(self) -> RPCResponse:\n try:\n return await self.manager._get_next_message()\n except ConnectionClosedOK:\n raise StopAsyncIteration\n\n# Path: web3/providers/persistent/persistent_connection.py\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n)\n\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n )\n from web3.manager import ( # noqa: F401\n _AsyncPersistentMessageStream,\n )\n\n\nclass PersistentConnection:\n \"\"\"\n A class that houses the public API for interacting with the persistent connection\n via a `AsyncWeb3` instance instantiated with a `PersistentConnectionProvider` class.\n \"\"\"\n\n def __init__(self, w3: \"AsyncWeb3\"):\n self._manager = w3.manager\n\n # -- public methods -- #\n @property\n def subscriptions(self) -> Dict[str, Any]:\n return self._manager._request_processor.active_subscriptions\n\n async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n return await self._manager.send(method, params)\n\n async def recv(self) -> Any:\n return await self._manager._get_next_message()\n\n def process_subscriptions(self) -> \"_AsyncPersistentMessageStream\":\n return self._manager._persistent_message_stream()\n\n# Path: web3/providers/persistent/websocket.py\nimport asyncio\nimport json\nimport logging\nimport os\nfrom typing import (\n Any,\n Dict,\n Optional,\n Union,\n)\n\nfrom eth_typing import (\n URI,\n)\nfrom toolz import (\n merge,\n)\nfrom websockets import (\n WebSocketClientProtocol,\n)\nfrom websockets.client import (\n connect,\n)\nfrom websockets.exceptions import (\n WebSocketException,\n)\n\nfrom web3._utils.caching import (\n async_handle_request_caching,\n)\nfrom web3.exceptions import (\n ProviderConnectionError,\n Web3ValidationError,\n)\nfrom web3.providers.persistent import (\n PersistentConnectionProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nDEFAULT_PING_INTERVAL = 30 # 30 seconds\nDEFAULT_PING_TIMEOUT = 300 # 5 minutes\n\nVALID_WEBSOCKET_URI_PREFIXES = {\"ws://\", \"wss://\"}\nRESTRICTED_WEBSOCKET_KWARGS = {\"uri\", \"loop\"}\nDEFAULT_WEBSOCKET_KWARGS = {\n # set how long to wait between pings from the server\n \"ping_interval\": DEFAULT_PING_INTERVAL,\n # set how long to wait without a pong response before closing the connection\n \"ping_timeout\": DEFAULT_PING_TIMEOUT,\n}\n\n\ndef get_default_endpoint() -> URI:\n return URI(os.environ.get(\"WEB3_WS_PROVIDER_URI\", \"ws://127.0.0.1:8546\"))\n\n\nclass WebSocketProvider(PersistentConnectionProvider):\n logger = logging.getLogger(\"web3.providers.WebSocketProvider\")\n is_async: bool = True\n\n 
_max_connection_retries: int = 5\n _ws: Optional[WebSocketClientProtocol] = None\n\n def __init__(\n self,\n endpoint_uri: Optional[Union[URI, str]] = None,\n websocket_kwargs: Optional[Dict[str, Any]] = None,\n # `PersistentConnectionProvider` kwargs can be passed through\n **kwargs: Any,\n ) -> None:\n self.endpoint_uri = URI(endpoint_uri)\n if self.endpoint_uri is None:\n self.endpoint_uri = get_default_endpoint()\n\n if not any(\n self.endpoint_uri.startswith(prefix)\n for prefix in VALID_WEBSOCKET_URI_PREFIXES\n ):\n raise Web3ValidationError(\n \"WebSocket endpoint uri must begin with 'ws://' or 'wss://': \"\n f\"{self.endpoint_uri}\"\n )\n\n if websocket_kwargs is not None:\n found_restricted_keys = set(websocket_kwargs).intersection(\n RESTRICTED_WEBSOCKET_KWARGS\n )\n if found_restricted_keys:\n raise Web3ValidationError(\n \"Found restricted keys for websocket_kwargs: \"\n f\"{found_restricted_keys}.\"\n )\n\n self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs or {})\n\n super().__init__(**kwargs)\n\n def __str__(self) -> str:\n return f\"WebSocket connection: {self.endpoint_uri}\"\n\n async def is_connected(self, show_traceback: bool = False) -> bool:\n if not self._ws:\n return False\n\n try:\n await self._ws.pong()\n return True\n\n except WebSocketException as e:\n if show_traceback:\n raise ProviderConnectionError(\n f\"Error connecting to endpoint: '{self.endpoint_uri}'\"\n ) from e\n return False\n\n async def connect(self) -> None:\n _connection_attempts = 0\n _backoff_rate_change = 1.75\n _backoff_time = 1.75\n\n while _connection_attempts != self._max_connection_retries:\n try:\n _connection_attempts += 1\n self._ws = await connect(self.endpoint_uri, **self.websocket_kwargs)\n self._message_listener_task = asyncio.create_task(\n self._message_listener()\n )\n break\n except WebSocketException as e:\n if _connection_attempts == self._max_connection_retries:\n raise ProviderConnectionError(\n f\"Could not connect to endpoint: {self.endpoint_uri}. \"\n f\"Retries exceeded max of {self._max_connection_retries}.\"\n ) from e\n self.logger.info(\n f\"Could not connect to endpoint: {self.endpoint_uri}. Retrying in \"\n f\"{round(_backoff_time, 1)} seconds.\",\n exc_info=True,\n )\n await asyncio.sleep(_backoff_time)\n _backoff_time *= _backoff_rate_change\n\n async def disconnect(self) -> None:\n if self._ws is not None and not self._ws.closed:\n await self._ws.close()\n self._ws = None\n self.logger.debug(\n f'Successfully disconnected from endpoint: \"{self.endpoint_uri}'\n )\n\n try:\n self._message_listener_task.cancel()\n await self._message_listener_task\n except (asyncio.CancelledError, StopAsyncIteration):\n pass\n self._request_processor.clear_caches()\n\n @async_handle_request_caching\n async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n request_data = self.encode_rpc_request(method, params)\n\n if self._ws is None:\n raise ProviderConnectionError(\n \"Connection to websocket has not been initiated for the provider.\"\n )\n\n await asyncio.wait_for(\n self._ws.send(request_data), timeout=self.request_timeout\n )\n\n current_request_id = json.loads(request_data)[\"id\"]\n response = await self._get_response_for_request_id(current_request_id)\n\n return response\n\n async def _message_listener(self) -> None:\n self.logger.info(\n \"WebSocket listener background task started. 
Storing all messages in \"\n \"appropriate request processor queues / caches to be processed.\"\n )\n while True:\n # the use of sleep(0) seems to be the most efficient way to yield control\n # back to the event loop to share the loop with other tasks.\n await asyncio.sleep(0)\n\n try:\n async for raw_message in self._ws:\n await asyncio.sleep(0)\n\n response = json.loads(raw_message)\n subscription = response.get(\"method\") == \"eth_subscription\"\n await self._request_processor.cache_raw_response(\n response, subscription=subscription\n )\n except Exception as e:\n if not self.silence_listener_task_exceptions:\n loop = asyncio.get_event_loop()\n for task in asyncio.all_tasks(loop=loop):\n task.cancel()\n raise e\n\n self.logger.error(\n \"Exception caught in listener, error logging and keeping \"\n \"listener background task alive.\"\n f\"\\n error={e.__class__.__name__}: {e}\"\n )\n\n# Path: web3/providers/persistent/__init__.py\nfrom .persistent import (\n PersistentConnectionProvider,\n)\nfrom .persistent_connection import (\n PersistentConnection,\n)\nfrom .request_processor import (\n RequestProcessor,\n)\nfrom .async_ipc import (\n AsyncIPCProvider,\n)\nfrom .websocket import (\n WebSocketProvider,\n)\n\n# Path: web3/providers/async_base.py\nimport asyncio\nimport itertools\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Coroutine,\n Optional,\n Set,\n Tuple,\n cast,\n)\n\nfrom eth_utils import (\n is_text,\n to_bytes,\n to_text,\n)\n\nfrom web3._utils.caching import (\n async_handle_request_caching,\n)\nfrom web3._utils.encoding import (\n FriendlyJsonSerde,\n Web3JsonEncoder,\n)\nfrom web3.exceptions import (\n ProviderConnectionError,\n)\nfrom web3.middleware import (\n async_combine_middlewares,\n)\nfrom web3.middleware.base import (\n Middleware,\n MiddlewareOnion,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\nfrom web3.utils import (\n SimpleCache,\n)\n\nif TYPE_CHECKING:\n from websockets import (\n WebSocketClientProtocol,\n )\n\n from web3 import ( # noqa: F401\n AsyncWeb3,\n WebSocketProvider,\n )\n from web3.providers.persistent import ( # noqa: F401\n RequestProcessor,\n )\n\n\nCACHEABLE_REQUESTS = cast(\n Set[RPCEndpoint],\n (\n \"eth_chainId\",\n \"eth_getBlockByHash\",\n \"eth_getBlockTransactionCountByHash\",\n \"eth_getRawTransactionByHash\",\n \"eth_getTransactionByBlockHashAndIndex\",\n \"eth_getTransactionByHash\",\n \"eth_getUncleByBlockHashAndIndex\",\n \"eth_getUncleCountByBlockHash\",\n \"net_version\",\n \"web3_clientVersion\",\n ),\n)\n\n\nclass AsyncBaseProvider:\n _request_func_cache: Tuple[\n Tuple[Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]]\n ] = (None, None)\n\n is_async = True\n has_persistent_connection = False\n global_ccip_read_enabled: bool = True\n ccip_read_max_redirects: int = 4\n\n # request caching\n cache_allowed_requests: bool = False\n cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS\n _request_cache: SimpleCache\n _request_cache_lock: asyncio.Lock = asyncio.Lock()\n\n def __init__(self) -> None:\n self._request_cache = SimpleCache(1000)\n\n async def request_func(\n self, async_w3: \"AsyncWeb3\", middleware_onion: MiddlewareOnion\n ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]:\n middlewares: Tuple[Middleware, ...] 
= middleware_onion.as_tuple_of_middlewares()\n\n cache_key = self._request_func_cache[0]\n if cache_key != middlewares:\n self._request_func_cache = (\n middlewares,\n await async_combine_middlewares(\n middlewares=middlewares,\n async_w3=async_w3,\n provider_request_fn=self.make_request,\n ),\n )\n return self._request_func_cache[-1]\n\n @async_handle_request_caching\n async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n raise NotImplementedError(\"Providers must implement this method\")\n\n async def is_connected(self, show_traceback: bool = False) -> bool:\n raise NotImplementedError(\"Providers must implement this method\")\n\n # -- persistent connection providers -- #\n\n _request_processor: \"RequestProcessor\"\n _message_listener_task: \"asyncio.Task[None]\"\n _listen_event: \"asyncio.Event\"\n\n async def connect(self) -> None:\n raise NotImplementedError(\n \"Persistent connection providers must implement this method\"\n )\n\n async def disconnect(self) -> None:\n raise NotImplementedError(\n \"Persistent connection providers must implement this method\"\n )\n\n # WebSocket typing\n _ws: \"WebSocketClientProtocol\"\n\n # IPC typing\n _reader: Optional[asyncio.StreamReader]\n _writer: Optional[asyncio.StreamWriter]\n\n\nclass AsyncJSONBaseProvider(AsyncBaseProvider):\n def __init__(self) -> None:\n super().__init__()\n self.request_counter = itertools.count()\n\n def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:\n request_id = next(self.request_counter)\n rpc_dict = {\n \"jsonrpc\": \"2.0\",\n \"method\": method,\n \"params\": params or [],\n \"id\": request_id,\n }\n encoded = FriendlyJsonSerde().json_encode(rpc_dict, cls=Web3JsonEncoder)\n return to_bytes(text=encoded)\n\n def decode_rpc_response(self, raw_response: bytes) -> RPCResponse:\n text_response = str(\n to_text(raw_response) if not is_text(raw_response) else raw_response\n )\n return cast(RPCResponse, FriendlyJsonSerde().json_decode(text_response))\n\n async def is_connected(self, show_traceback: bool = False) -> bool:\n try:\n response = await self.make_request(RPCEndpoint(\"web3_clientVersion\"), [])\n except (OSError, ProviderConnectionError) as e:\n if show_traceback:\n raise ProviderConnectionError(\n f\"Problem connecting to provider with error: {type(e)}: {e}\"\n )\n return False\n\n if \"error\" in response:\n if show_traceback:\n raise ProviderConnectionError(\n f\"Error received from provider: {response}\"\n )\n return False\n\n if response.get(\"jsonrpc\") == \"2.0\":\n return True\n else:\n if show_traceback:\n raise ProviderConnectionError(f\"Bad jsonrpc version: {response}\")\n return False\n\n# Path: web3/providers/auto.py\nimport os\nfrom typing import (\n Any,\n Callable,\n Dict,\n Optional,\n Sequence,\n Tuple,\n Type,\n Union,\n)\nfrom urllib.parse import (\n urlparse,\n)\n\nfrom eth_typing import (\n URI,\n)\n\nfrom web3.exceptions import (\n CannotHandleRequest,\n)\nfrom web3.providers import (\n BaseProvider,\n HTTPProvider,\n IPCProvider,\n LegacyWebSocketProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nHTTP_SCHEMES = {\"http\", \"https\"}\nWS_SCHEMES = {\"ws\", \"wss\"}\n\n\ndef load_provider_from_environment() -> BaseProvider:\n uri_string = URI(os.environ.get(\"WEB3_PROVIDER_URI\", \"\"))\n if not uri_string:\n return None\n\n return load_provider_from_uri(uri_string)\n\n\ndef load_provider_from_uri(\n uri_string: URI, headers: Optional[Dict[str, Tuple[str, str]]] = None\n) -> BaseProvider:\n uri = 
urlparse(uri_string)\n if uri.scheme == \"file\":\n return IPCProvider(uri.path)\n elif uri.scheme in HTTP_SCHEMES:\n return HTTPProvider(uri_string, headers)\n elif uri.scheme in WS_SCHEMES:\n return LegacyWebSocketProvider(uri_string)\n else:\n raise NotImplementedError(\n \"Web3 does not know how to connect to scheme \"\n f\"{uri.scheme!r} in {uri_string!r}\"\n )\n\n\nclass AutoProvider(BaseProvider):\n default_providers = (\n load_provider_from_environment,\n IPCProvider,\n HTTPProvider,\n LegacyWebSocketProvider,\n )\n _active_provider = None\n\n def __init__(\n self,\n potential_providers: Optional[\n Sequence[Union[Callable[..., BaseProvider], Type[BaseProvider]]]\n ] = None,\n ) -> None:\n \"\"\"\n :param iterable potential_providers: ordered series of provider classes\n to attempt with\n\n AutoProvider will initialize each potential provider (without arguments),\n in an attempt to find an active node. The list will default to\n :attribute:`default_providers`.\n \"\"\"\n if potential_providers:\n self._potential_providers = potential_providers\n else:\n self._potential_providers = self.default_providers\n\n def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n try:\n return self._proxy_request(method, params)\n except OSError:\n return self._proxy_request(method, params, use_cache=False)\n\n def is_connected(self, show_traceback: bool = False) -> bool:\n provider = self._get_active_provider(use_cache=True)\n return provider is not None and provider.is_connected(show_traceback)\n\n def _proxy_request(\n self, method: RPCEndpoint, params: Any, use_cache: bool = True\n ) -> RPCResponse:\n provider = self._get_active_provider(use_cache)\n if provider is None:\n raise CannotHandleRequest(\n \"Could not discover provider while making request: \"\n f\"method:{method}\\nparams:{params}\\n\"\n )\n\n return provider.make_request(method, params)\n\n def _get_active_provider(self, use_cache: bool) -> Optional[BaseProvider]:\n if use_cache and self._active_provider is not None:\n return self._active_provider\n\n for Provider in self._potential_providers:\n provider = Provider()\n if provider is not None and provider.is_connected():\n self._active_provider = provider\n return provider\n\n return None\n\n# Path: web3/providers/legacy_websocket.py\nimport asyncio\nimport json\nimport logging\nimport os\nfrom threading import (\n Thread,\n)\nfrom types import (\n TracebackType,\n)\nfrom typing import (\n Any,\n Optional,\n Type,\n Union,\n)\n\nfrom eth_typing import (\n URI,\n)\nfrom websockets.client import (\n connect,\n)\nfrom websockets.legacy.client import (\n WebSocketClientProtocol,\n)\n\nfrom web3.exceptions import (\n Web3ValidationError,\n)\nfrom web3.providers.base import (\n JSONBaseProvider,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nRESTRICTED_WEBSOCKET_KWARGS = {\"uri\", \"loop\"}\nDEFAULT_WEBSOCKET_TIMEOUT = 30\n\n\ndef _start_event_loop(loop: asyncio.AbstractEventLoop) -> None:\n asyncio.set_event_loop(loop)\n loop.run_forever()\n loop.close()\n\n\ndef _get_threaded_loop() -> asyncio.AbstractEventLoop:\n new_loop = asyncio.new_event_loop()\n thread_loop = Thread(target=_start_event_loop, args=(new_loop,), daemon=True)\n thread_loop.start()\n return new_loop\n\n\ndef get_default_endpoint() -> URI:\n return URI(os.environ.get(\"WEB3_WS_PROVIDER_URI\", \"ws://127.0.0.1:8546\"))\n\n\nclass PersistentWebSocket:\n def __init__(self, endpoint_uri: URI, websocket_kwargs: Any) -> None:\n self.ws: Optional[WebSocketClientProtocol] = None\n 
self.endpoint_uri = endpoint_uri\n self.websocket_kwargs = websocket_kwargs\n\n async def __aenter__(self) -> WebSocketClientProtocol:\n if self.ws is None:\n self.ws = await connect(uri=self.endpoint_uri, **self.websocket_kwargs)\n return self.ws\n\n async def __aexit__(\n self,\n exc_type: Type[BaseException],\n exc_val: BaseException,\n exc_tb: TracebackType,\n ) -> None:\n if exc_val is not None:\n try:\n await self.ws.close()\n except Exception:\n pass\n self.ws = None\n\n\nclass LegacyWebSocketProvider(JSONBaseProvider):\n logger = logging.getLogger(\"web3.providers.WebSocketProvider\")\n _loop = None\n\n def __init__(\n self,\n endpoint_uri: Optional[Union[URI, str]] = None,\n websocket_kwargs: Optional[Any] = None,\n websocket_timeout: int = DEFAULT_WEBSOCKET_TIMEOUT,\n ) -> None:\n self.endpoint_uri = URI(endpoint_uri)\n self.websocket_timeout = websocket_timeout\n if self.endpoint_uri is None:\n self.endpoint_uri = get_default_endpoint()\n if LegacyWebSocketProvider._loop is None:\n LegacyWebSocketProvider._loop = _get_threaded_loop()\n if websocket_kwargs is None:\n websocket_kwargs = {}\n else:\n found_restricted_keys = set(websocket_kwargs).intersection(\n RESTRICTED_WEBSOCKET_KWARGS\n )\n if found_restricted_keys:\n raise Web3ValidationError(\n f\"{RESTRICTED_WEBSOCKET_KWARGS} are not allowed \"\n f\"in websocket_kwargs, found: {found_restricted_keys}\"\n )\n self.conn = PersistentWebSocket(self.endpoint_uri, websocket_kwargs)\n super().__init__()\n\n def __str__(self) -> str:\n return f\"WS connection {self.endpoint_uri}\"\n\n async def coro_make_request(self, request_data: bytes) -> RPCResponse:\n async with self.conn as conn:\n await asyncio.wait_for(\n conn.send(request_data), timeout=self.websocket_timeout\n )\n return json.loads(\n await asyncio.wait_for(conn.recv(), timeout=self.websocket_timeout)\n )\n\n def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n self.logger.debug(\n f\"Making request WebSocket. 
URI: {self.endpoint_uri}, \" f\"Method: {method}\"\n )\n request_data = self.encode_rpc_request(method, params)\n future = asyncio.run_coroutine_threadsafe(\n self.coro_make_request(request_data), LegacyWebSocketProvider._loop\n )\n return future.result()\n\n# Path: web3/_utils/http.py\ndef construct_user_agent(class_type: type) -> str:\n from web3 import (\n __version__ as web3_version,\n )\n\n return f\"web3.py/{web3_version}/{class_type.__module__}.{class_type.__qualname__}\"\n\n# Path: web3/providers/rpc/utils.py\nfrom typing import (\n Sequence,\n Type,\n)\n\nfrom pydantic import (\n BaseModel,\n)\n\nfrom web3.types import (\n RPCEndpoint,\n)\n\nREQUEST_RETRY_ALLOWLIST = [\n \"admin\",\n \"net\",\n \"txpool\",\n \"testing\",\n \"evm\",\n \"eth_protocolVersion\",\n \"eth_syncing\",\n \"eth_coinbase\",\n \"eth_mining\",\n \"eth_hashrate\",\n \"eth_chainId\",\n \"eth_gasPrice\",\n \"eth_accounts\",\n \"eth_blockNumber\",\n \"eth_getBalance\",\n \"eth_getStorageAt\",\n \"eth_getProof\",\n \"eth_getCode\",\n \"eth_getBlockByNumber\",\n \"eth_getBlockByHash\",\n \"eth_getBlockTransactionCountByNumber\",\n \"eth_getBlockTransactionCountByHash\",\n \"eth_getUncleCountByBlockNumber\",\n \"eth_getUncleCountByBlockHash\",\n \"eth_getTransactionByHash\",\n \"eth_getTransactionByBlockHashAndIndex\",\n \"eth_getTransactionByBlockNumberAndIndex\",\n \"eth_getTransactionReceipt\",\n \"eth_getTransactionCount\",\n \"eth_getRawTransactionByHash\",\n \"eth_call\",\n \"eth_estimateGas\",\n \"eth_createAccessList\",\n \"eth_maxPriorityFeePerGas\",\n \"eth_newBlockFilter\",\n \"eth_newPendingTransactionFilter\",\n \"eth_newFilter\",\n \"eth_getFilterChanges\",\n \"eth_getFilterLogs\",\n \"eth_getLogs\",\n \"eth_uninstallFilter\",\n \"eth_getCompilers\",\n \"eth_getWork\",\n \"eth_sign\",\n \"eth_signTypedData\",\n \"eth_sendRawTransaction\",\n \"personal_importRawKey\",\n \"personal_newAccount\",\n \"personal_listAccounts\",\n \"personal_listWallets\",\n \"personal_lockAccount\",\n \"personal_unlockAccount\",\n \"personal_ecRecover\",\n \"personal_sign\",\n \"personal_signTypedData\",\n]\n\n\ndef check_if_retry_on_failure(\n method: RPCEndpoint,\n allowlist: Sequence[str] = None,\n) -> bool:\n if allowlist is None:\n allowlist = REQUEST_RETRY_ALLOWLIST\n\n if method in allowlist or method.split(\"_\")[0] in allowlist:\n return True\n else:\n return False\n\n\nclass ExceptionRetryConfiguration(BaseModel):\n errors: Sequence[Type[BaseException]]\n retries: int\n backoff_factor: float\n method_allowlist: Sequence[str]\n\n def __init__(\n self,\n errors: Sequence[Type[BaseException]] = None,\n retries: int = 5,\n backoff_factor: float = 0.5,\n method_allowlist: Sequence[str] = None,\n ):\n super().__init__(\n errors=errors,\n retries=retries,\n backoff_factor=backoff_factor,\n method_allowlist=method_allowlist or REQUEST_RETRY_ALLOWLIST,\n )\n\n# Path: web3/providers/rpc/async_rpc.py\nimport asyncio\nimport logging\nfrom typing import (\n Any,\n Dict,\n Iterable,\n Optional,\n Tuple,\n Union,\n)\n\nfrom aiohttp import (\n ClientError,\n ClientSession,\n)\nfrom eth_typing import (\n URI,\n)\nfrom eth_utils import (\n to_dict,\n)\n\nfrom web3._utils.http import (\n construct_user_agent,\n)\nfrom web3._utils.request import (\n async_cache_and_return_session as _async_cache_and_return_session,\n async_make_post_request,\n get_default_http_endpoint,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nfrom ..._utils.caching import (\n async_handle_request_caching,\n)\nfrom ..async_base import (\n 
AsyncJSONBaseProvider,\n)\nfrom .utils import (\n ExceptionRetryConfiguration,\n check_if_retry_on_failure,\n)\n\n\nclass AsyncHTTPProvider(AsyncJSONBaseProvider):\n logger = logging.getLogger(\"web3.providers.AsyncHTTPProvider\")\n endpoint_uri = None\n _request_kwargs = None\n\n def __init__(\n self,\n endpoint_uri: Optional[Union[URI, str]] = None,\n request_kwargs: Optional[Any] = None,\n exception_retry_configuration: Optional[\n ExceptionRetryConfiguration\n ] = ExceptionRetryConfiguration(errors=(ClientError, TimeoutError)),\n ) -> None:\n if endpoint_uri is None:\n self.endpoint_uri = get_default_http_endpoint()\n else:\n self.endpoint_uri = URI(endpoint_uri)\n\n self._request_kwargs = request_kwargs or {}\n self.exception_retry_configuration = exception_retry_configuration\n\n super().__init__()\n\n async def cache_async_session(self, session: ClientSession) -> ClientSession:\n return await _async_cache_and_return_session(self.endpoint_uri, session)\n\n def __str__(self) -> str:\n return f\"RPC connection {self.endpoint_uri}\"\n\n @to_dict\n def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:\n if \"headers\" not in self._request_kwargs:\n yield \"headers\", self.get_request_headers()\n for key, value in self._request_kwargs.items():\n yield key, value\n\n def get_request_headers(self) -> Dict[str, str]:\n return {\n \"Content-Type\": \"application/json\",\n \"User-Agent\": construct_user_agent(type(self)),\n }\n\n async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes:\n \"\"\"\n If exception_retry_configuration is set, retry on failure; otherwise, make\n the request without retrying.\n \"\"\"\n if (\n self.exception_retry_configuration is not None\n and check_if_retry_on_failure(\n method, self.exception_retry_configuration.method_allowlist\n )\n ):\n for i in range(self.exception_retry_configuration.retries):\n try:\n return await async_make_post_request(\n self.endpoint_uri, request_data, **self.get_request_kwargs()\n )\n except tuple(self.exception_retry_configuration.errors):\n if i < self.exception_retry_configuration.retries - 1:\n await asyncio.sleep(\n self.exception_retry_configuration.backoff_factor\n )\n continue\n else:\n raise\n return None\n else:\n return await async_make_post_request(\n self.endpoint_uri, request_data, **self.get_request_kwargs()\n )\n\n @async_handle_request_caching\n async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:\n self.logger.debug(\n f\"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}\"\n )\n request_data = self.encode_rpc_request(method, params)\n raw_response = await self._make_request(method, request_data)\n response = self.decode_rpc_response(raw_response)\n self.logger.debug(\n f\"Getting response HTTP. 
URI: {self.endpoint_uri}, \"\n f\"Method: {method}, Response: {response}\"\n )\n return response\n\n# Path: web3/providers/rpc/rpc.py\nimport logging\nimport time\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Iterable,\n Optional,\n Tuple,\n Union,\n)\n\nfrom eth_typing import (\n URI,\n)\nfrom eth_utils import (\n to_dict,\n)\nimport requests\n\nfrom web3._utils.http import (\n construct_user_agent,\n)\nfrom web3._utils.request import (\n cache_and_return_session,\n get_default_http_endpoint,\n make_post_request,\n)\nfrom web3.types import (\n RPCEndpoint,\n RPCResponse,\n)\n\nfrom ..._utils.caching import (\n handle_request_caching,\n)\nfrom ..base import (\n JSONBaseProvider,\n)\nfrom .utils import (\n ExceptionRetryConfiguration,\n check_if_retry_on_failure,\n)\n\nif TYPE_CHECKING:\n from web3.middleware.base import ( # noqa: F401\n Middleware,\n )\n\n\nclass HTTPProvider(JSONBaseProvider):\n logger = logging.getLogger(\"web3.providers.HTTPProvider\")\n endpoint_uri = None\n\n _request_args = None\n _request_kwargs = None\n\n exception_retry_configuration: Optional[ExceptionRetryConfiguration] = None\n\n def __init__(\n self,\n endpoint_uri: Optional[Union[URI, str]] = None,\n request_kwargs: Optional[Any] = None,\n session: Optional[Any] = None,\n exception_retry_configuration: Optional[ExceptionRetryConfiguration] = (\n ExceptionRetryConfiguration(\n errors=(\n ConnectionError,\n requests.HTTPError,\n requests.Timeout,\n )\n )\n ),\n ) -> None:\n if endpoint_uri is None:\n self.endpoint_uri = get_default_http_endpoint()\n else:\n self.endpoint_uri = URI(endpoint_uri)\n\n self._request_kwargs = request_kwargs or {}\n self.exception_retry_configuration = exception_retry_configuration\n\n if session:\n cache_and_return_session(self.endpoint_uri, session)\n\n super().__init__()\n\n def __str__(self) -> str:\n return f\"RPC connection {self.endpoint_uri}\"\n\n @to_dict\n def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:\n if \"headers\" not in self._request_kwargs:\n yield \"headers\", self.get_request_headers()\n for key, value in self._request_kwargs.items():\n yield key, value\n\n def get_request_headers(self) -> Dict[str, str]:\n return {\n \"Content-Type\": \"application/json\",\n \"User-Agent\": construct_user_agent(type(self)),\n }\n\n def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes:\n \"\"\"\n If exception_retry_configuration is set, retry on failure; otherwise, make\n the request without retrying.\n \"\"\"\n if (\n self.exception_retry_configuration is not None\n and check_if_retry_on_failure(\n method, self.exception_retry_configuration.method_allowlist\n )\n ):\n for i in range(self.exception_retry_configuration.retries):\n try:\n return make_post_request(\n...\n", "needle_token_start": 3935, "needle_token_end": 4072, "code_context_ntokens": 16086, "output": ["```python\n@staticmethod\ndef get_default_middlewares() -> List[Tuple[Middleware, str]]:\n \"\"\"\n List the default middlewares for the request manager.\n Documentation should remain in sync with these defaults.\n \"\"\"\n return [\n (GasPriceStrategyMiddleware, \"gas_price_strategy\"),\n (ENSNameToAddressMiddleware, \"ens_name_to_address\"),\n (AttributeDictMiddleware, \"attrdict\"),\n (ValidationMiddleware, \"validation\"),\n (BufferedGasEstimateMiddleware, \"gas_estimate\"),\n ]\n```"]} +{"repo": "ethereum/web3.py", "name": "_get_transaction", "language": "python", "path": "web3/contract/base_contract.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. 
**Purpose**: The function is designed to prepare transaction parameters for deploying a smart contract, ensuring all necessary data is included and formatted correctly.\n2. **Input**: An optional dictionary containing transaction parameters.\n3. **Output**: A dictionary of transaction parameters, including default values and encoded contract data necessary for deployment.\n4. **Procedure**: \n - Initializes a transaction dictionary either from the provided input or as an empty dictionary if no input is provided.\n - Validates and removes any forbidden keys from the transaction dictionary.\n - Sets a default sender address if it's not specified and a default account exists.\n - Adds encoded contract data to the transaction dictionary.\n - Returns the prepared transaction dictionary.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " self.w3.eth.default_account, # type: ignore\n )\n\n if \"to\" not in call_transaction:\n if isinstance(self, type):\n raise ValueError(\n \"When using `Contract.[methodtype].[method].call()` from\"\n \" a contract factory you \"\n \"must provide a `to` address with the transaction\"\n )\n else:\n raise ValueError(\n \"Please ensure that this contract instance has an address.\"\n )\n\n return call_transaction\n\n def _transact(self, transaction: Optional[TxParams] = None) -> TxParams:\n if transaction is None:\n transact_transaction: TxParams = {}\n else:\n transact_transaction = cast(TxParams, dict(**transaction))\n\n if \"data\" in transact_transaction:\n raise ValueError(\"Cannot set 'data' field in transact transaction\")\n\n if self.address is not None:\n transact_transaction.setdefault(\"to\", self.address)\n if self.w3.eth.default_account is not empty:\n # type ignored b/c check prevents an empty default_account\n transact_transaction.setdefault(\n \"from\", self.w3.eth.default_account # type: ignore\n )\n\n if \"to\" not in transact_transaction:\n if isinstance(self, type):\n raise ValueError(\n \"When using `Contract.transact` from a contract factory you \"\n \"must provide a `to` address with the transaction\"\n )\n else:\n raise ValueError(\n \"Please ensure that this contract instance has an address.\"\n )\n return transact_transaction\n\n def _estimate_gas(self, transaction: Optional[TxParams] = None) -> TxParams:\n if transaction is None:\n estimate_gas_transaction: TxParams = {}\n else:\n estimate_gas_transaction = cast(TxParams, dict(**transaction))\n\n if \"data\" in estimate_gas_transaction:\n raise ValueError(\"Cannot set 'data' field in estimate_gas transaction\")\n if \"to\" in estimate_gas_transaction:\n raise ValueError(\"Cannot set to in estimate_gas transaction\")\n\n if self.address:\n estimate_gas_transaction.setdefault(\"to\", self.address)\n if self.w3.eth.default_account is not empty:\n # type ignored b/c check prevents an empty default_account\n estimate_gas_transaction.setdefault(\n \"from\", self.w3.eth.default_account # type: ignore\n )\n\n if \"to\" not in estimate_gas_transaction:\n if isinstance(self, type):\n raise ValueError(\n \"When using `Contract.estimate_gas` from a contract factory \"\n \"you must provide a `to` address with the transaction\"\n )\n else:\n raise ValueError(\n \"Please ensure that this contract instance has an address.\"\n )\n return estimate_gas_transaction\n\n def 
_build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:\n if transaction is None:\n built_transaction: TxParams = {}\n else:\n built_transaction = cast(TxParams, dict(**transaction))\n\n if \"data\" in built_transaction:\n raise ValueError(\"Cannot set 'data' field in build transaction\")\n\n if not self.address and \"to\" not in built_transaction:\n raise ValueError(\n \"When using `ContractFunction.build_transaction` from a contract \"\n \"factory you must provide a `to` address with the transaction\"\n )\n if self.address and \"to\" in built_transaction:\n raise ValueError(\"Cannot set 'to' field in contract call build transaction\")\n\n if self.address:\n built_transaction.setdefault(\"to\", self.address)\n\n if \"to\" not in built_transaction:\n raise ValueError(\n \"Please ensure that this contract instance has an address.\"\n )\n\n return built_transaction\n\n @combomethod\n def _encode_transaction_data(cls) -> HexStr:\n return add_0x_prefix(encode_abi(cls.w3, cls.abi, cls.arguments, cls.selector))\n\n _return_data_normalizers: Optional[Tuple[Callable[..., Any], ...]] = tuple()\n\n def __repr__(self) -> str:\n if self.abi:\n _repr = f\"\"\n return f\"\"\n\n @classmethod\n def factory(\n cls, class_name: str, **kwargs: Any\n ) -> Union[\"ContractFunction\", \"AsyncContractFunction\"]:\n return PropertyCheckingFactory(class_name, (cls,), kwargs)(kwargs.get(\"abi\"))\n\n\nclass BaseContractFunctions:\n \"\"\"Class containing contract function objects\"\"\"\n\n def __init__(\n self,\n abi: ABI,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n contract_function_class: Union[\n Type[\"ContractFunction\"], Type[\"AsyncContractFunction\"]\n ],\n address: Optional[ChecksumAddress] = None,\n decode_tuples: Optional[bool] = False,\n ) -> None:\n self.abi = abi\n self.w3 = w3\n self.address = address\n\n if self.abi:\n self._functions = filter_by_type(\"function\", self.abi)\n for func in self._functions:\n setattr(\n self,\n func[\"name\"],\n contract_function_class.factory(\n func[\"name\"],\n w3=self.w3,\n contract_abi=self.abi,\n address=self.address,\n decode_tuples=decode_tuples,\n function_identifier=func[\"name\"],\n ),\n )\n\n def __iter__(self) -> Iterable[\"ABIFunction\"]:\n if not hasattr(self, \"_functions\") or not self._functions:\n return\n\n for func in self._functions:\n yield self[func[\"name\"]]\n\n def __getitem__(self, function_name: str) -> ABIFunction:\n return getattr(self, function_name)\n\n def __hasattr__(self, function_name: str) -> bool:\n try:\n return function_name in self.__dict__[\"_functions\"]\n except ABIFunctionNotFound:\n return False\n\n\nclass BaseContract:\n \"\"\"Base class for Contract proxy classes.\n\n First you need to create your Contract classes using\n :meth:`web3.eth.Eth.contract` that takes compiled Solidity contract\n ABI definitions as input. 
The created class object will be a subclass of\n this base class.\n\n After you have your Contract proxy class created you can interact with\n smart contracts\n\n * Create a Contract proxy object for an existing deployed smart contract by\n its address using :meth:`__init__`\n\n * Deploy a new smart contract using :py:meth:`Contract.constructor.transact()`\n \"\"\"\n\n # set during class construction\n w3: Union[\"Web3\", \"AsyncWeb3\"] = None\n\n # instance level properties\n address: ChecksumAddress = None\n\n # class properties (overridable at instance level)\n abi: ABI = None\n\n asm = None\n ast = None\n\n bytecode = None\n bytecode_runtime = None\n clone_bin = None\n\n decode_tuples = None\n dev_doc = None\n interface = None\n metadata = None\n opcodes = None\n src_map = None\n src_map_runtime = None\n user_doc = None\n\n # Public API\n #\n @combomethod\n def encodeABI(\n cls,\n fn_name: str,\n args: Optional[Any] = None,\n kwargs: Optional[Any] = None,\n data: Optional[HexStr] = None,\n ) -> HexStr:\n \"\"\"\n Encodes the arguments using the Ethereum ABI for the contract function\n that matches the given name and arguments..\n\n :param data: defaults to function selector\n \"\"\"\n fn_abi, fn_selector, fn_arguments = get_function_info(\n fn_name,\n cls.w3.codec,\n contract_abi=cls.abi,\n args=args,\n kwargs=kwargs,\n )\n\n if data is None:\n data = fn_selector\n\n return encode_abi(cls.w3, fn_abi, fn_arguments, data)\n\n @combomethod\n def all_functions(\n self,\n ) -> \"BaseContractFunction\":\n return self.find_functions_by_identifier(\n self.abi, self.w3, self.address, lambda _: True\n )\n\n @combomethod\n def get_function_by_signature(self, signature: str) -> \"BaseContractFunction\":\n if \" \" in signature:\n raise ValueError(\n \"Function signature should not contain any spaces. 
\"\n f\"Found spaces in input: {signature}\"\n )\n\n def callable_check(fn_abi: ABIFunction) -> bool:\n return abi_to_signature(fn_abi) == signature\n\n fns = self.find_functions_by_identifier(\n self.abi, self.w3, self.address, callable_check\n )\n return self.get_function_by_identifier(fns, \"signature\")\n\n @combomethod\n def find_functions_by_name(self, fn_name: str) -> \"BaseContractFunction\":\n def callable_check(fn_abi: ABIFunction) -> bool:\n return fn_abi[\"name\"] == fn_name\n\n return self.find_functions_by_identifier(\n self.abi, self.w3, self.address, callable_check\n )\n\n @combomethod\n def get_function_by_name(self, fn_name: str) -> \"BaseContractFunction\":\n fns = self.find_functions_by_name(fn_name)\n return self.get_function_by_identifier(fns, \"name\")\n\n @combomethod\n def get_function_by_selector(\n self, selector: Union[bytes, int, HexStr]\n ) -> \"BaseContractFunction\":\n def callable_check(fn_abi: ABIFunction) -> bool:\n # typed dict cannot be used w/ a normal Dict\n # https://github.com/python/mypy/issues/4976\n return encode_hex(function_abi_to_4byte_selector(fn_abi)) == to_4byte_hex(selector) # type: ignore # noqa: E501\n\n fns = self.find_functions_by_identifier(\n self.abi, self.w3, self.address, callable_check\n )\n return self.get_function_by_identifier(fns, \"selector\")\n\n @combomethod\n def decode_function_input(\n self, data: HexStr\n ) -> Tuple[\"BaseContractFunction\", Dict[str, Any]]:\n # type ignored b/c expects data arg to be HexBytes\n data = HexBytes(data) # type: ignore\n func = self.get_function_by_selector(data[:4])\n arguments = decode_transaction_data(\n func.abi, data, normalizers=BASE_RETURN_NORMALIZERS\n )\n return func, arguments\n\n @combomethod\n def find_functions_by_args(self, *args: Any) -> \"BaseContractFunction\":\n def callable_check(fn_abi: ABIFunction) -> bool:\n return check_if_arguments_can_be_encoded(\n fn_abi, self.w3.codec, args=args, kwargs={}\n )\n\n return self.find_functions_by_identifier(\n self.abi, self.w3, self.address, callable_check\n )\n\n @combomethod\n def get_function_by_args(self, *args: Any) -> \"BaseContractFunction\":\n fns = self.find_functions_by_args(*args)\n return self.get_function_by_identifier(fns, \"args\")\n\n #\n # Private Helpers\n #\n _return_data_normalizers: Tuple[Callable[..., Any], ...] 
= tuple()\n\n @classmethod\n def _prepare_transaction(\n cls,\n fn_name: str,\n fn_args: Optional[Any] = None,\n fn_kwargs: Optional[Any] = None,\n transaction: Optional[TxParams] = None,\n ) -> TxParams:\n return prepare_transaction(\n cls.address,\n cls.w3,\n fn_identifier=fn_name,\n contract_abi=cls.abi,\n transaction=transaction,\n fn_args=fn_args,\n fn_kwargs=fn_kwargs,\n )\n\n @classmethod\n def _find_matching_fn_abi(\n cls,\n fn_identifier: Optional[str] = None,\n args: Optional[Any] = None,\n kwargs: Optional[Any] = None,\n ) -> ABIFunction:\n return find_matching_fn_abi(\n cls.abi, cls.w3.codec, fn_identifier=fn_identifier, args=args, kwargs=kwargs\n )\n\n @classmethod\n def _find_matching_event_abi(\n cls,\n event_name: Optional[str] = None,\n argument_names: Optional[Sequence[str]] = None,\n ) -> ABIEvent:\n return find_matching_event_abi(\n abi=cls.abi, event_name=event_name, argument_names=argument_names\n )\n\n @combomethod\n def _encode_constructor_data(\n cls, args: Optional[Any] = None, kwargs: Optional[Any] = None\n ) -> HexStr:\n constructor_abi = get_constructor_abi(cls.abi)\n\n if constructor_abi:\n if args is None:\n args = tuple()\n if kwargs is None:\n kwargs = {}\n\n arguments = merge_args_and_kwargs(constructor_abi, args, kwargs)\n\n deploy_data = add_0x_prefix(\n encode_abi(cls.w3, constructor_abi, arguments, data=cls.bytecode)\n )\n else:\n if args is not None or kwargs is not None:\n msg = \"Constructor args were provided, but no constructor function was provided.\" # noqa: E501\n raise TypeError(msg)\n\n deploy_data = to_hex(cls.bytecode)\n\n return deploy_data\n\n @combomethod\n def find_functions_by_identifier(\n cls,\n contract_abi: ABI,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n address: ChecksumAddress,\n callable_check: Callable[..., Any],\n ) -> List[Any]:\n raise NotImplementedError(\n \"This method should be implemented in the inherited class\"\n )\n\n @combomethod\n def get_function_by_identifier(\n cls, fns: Sequence[\"BaseContractFunction\"], identifier: str\n ) -> \"BaseContractFunction\":\n raise NotImplementedError(\n \"This method should be implemented in the inherited class\"\n )\n\n @staticmethod\n def get_fallback_function(\n abi: ABI,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n function_type: Type[\"BaseContractFunction\"],\n address: Optional[ChecksumAddress] = None,\n ) -> \"BaseContractFunction\":\n if abi and fallback_func_abi_exists(abi):\n return function_type.factory(\n \"fallback\",\n w3=w3,\n contract_abi=abi,\n address=address,\n function_identifier=FallbackFn,\n )()\n\n return cast(function_type, NonExistentFallbackFunction()) # type: ignore\n\n @staticmethod\n def get_receive_function(\n abi: ABI,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n function_type: Type[\"BaseContractFunction\"],\n address: Optional[ChecksumAddress] = None,\n ) -> \"BaseContractFunction\":\n if abi and receive_func_abi_exists(abi):\n return function_type.factory(\n \"receive\",\n w3=w3,\n contract_abi=abi,\n address=address,\n function_identifier=ReceiveFn,\n )()\n\n return cast(function_type, NonExistentReceiveFunction()) # type: ignore\n\n\nclass BaseContractCaller:\n \"\"\"\n An alternative Contract API.\n\n This call:\n\n > contract.caller({'from': eth.accounts[1], 'gas': 100000, ...}).add(2, 3)\n is equivalent to this call in the classic contract:\n > contract.functions.add(2, 3).call({'from': eth.accounts[1], 'gas': 100000, ...})\n\n Other options for invoking this class include:\n\n > contract.caller.add(2, 3)\n\n or\n\n > contract.caller().add(2, 3)\n\n 
or\n\n > contract.caller(transaction={'from': eth.accounts[1], 'gas': 100000, ...}).add(2, 3) # noqa: E501\n \"\"\"\n\n # mypy types\n _functions: List[Union[ABIFunction, ABIEvent]]\n\n def __init__(\n self,\n abi: ABI,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n address: ChecksumAddress,\n decode_tuples: Optional[bool] = False,\n ) -> None:\n self.w3 = w3\n self.address = address\n self.abi = abi\n self.decode_tuples = decode_tuples\n self._functions = []\n\n def __getattr__(self, function_name: str) -> Any:\n if self.abi is None:\n raise NoABIFound(\n \"There is no ABI found for this contract.\",\n )\n elif not self._functions or len(self._functions) == 0:\n raise NoABIFunctionsFound(\n \"The ABI for this contract contains no function definitions. \",\n \"Are you sure you provided the correct contract ABI?\",\n )\n elif function_name not in {fn[\"name\"] for fn in self._functions}:\n functions_available = \", \".join([fn[\"name\"] for fn in self._functions])\n raise ABIFunctionNotFound(\n f\"The function '{function_name}' was not found in this contract's ABI.\",\n \" Here is a list of all of the function names found: \",\n f\"{functions_available}. \",\n \"Did you mean to call one of those functions?\",\n )\n else:\n return super().__getattribute__(function_name)\n\n def __hasattr__(self, event_name: str) -> bool:\n try:\n return event_name in self.__dict__[\"_events\"]\n except ABIFunctionNotFound:\n return False\n\n @staticmethod\n def call_function(\n fn: TContractFn,\n *args: Any,\n transaction: Optional[TxParams] = None,\n block_identifier: BlockIdentifier = None,\n ccip_read_enabled: Optional[bool] = None,\n **kwargs: Any,\n ) -> Any:\n if transaction is None:\n transaction = {}\n return fn(*args, **kwargs).call(\n transaction=transaction,\n block_identifier=block_identifier,\n ccip_read_enabled=ccip_read_enabled,\n )\n\n\nclass BaseContractConstructor:\n \"\"\"\n Class for contract constructor API.\n \"\"\"\n\n def __init__(\n self,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n abi: ABI,\n bytecode: HexStr,\n *args: Any,\n **kwargs: Any,\n ) -> None:\n self.w3 = w3\n self.abi = abi\n self.bytecode = bytecode\n self.data_in_transaction = self._encode_data_in_transaction(*args, **kwargs)\n\n @combomethod\n def _encode_data_in_transaction(self, *args: Any, **kwargs: Any) -> HexStr:\n constructor_abi = get_constructor_abi(self.abi)\n\n if constructor_abi:\n if not args:\n args = tuple()\n if not kwargs:\n kwargs = {}\n\n arguments = merge_args_and_kwargs(constructor_abi, args, kwargs)\n data = add_0x_prefix(\n encode_abi(self.w3, constructor_abi, arguments, data=self.bytecode)\n )\n else:\n data = to_hex(self.bytecode)\n\n return data\n\n @combomethod\n def _estimate_gas(self, transaction: Optional[TxParams] = None) -> TxParams:\n if transaction is None:\n estimate_gas_transaction: TxParams = {}\n else:\n estimate_gas_transaction = cast(TxParams, dict(**transaction))\n self.check_forbidden_keys_in_transaction(\n estimate_gas_transaction, [\"data\", \"to\"]\n )\n\n if self.w3.eth.default_account is not empty:\n # type ignored b/c check prevents an empty default_account\n estimate_gas_transaction.setdefault(\n \"from\", self.w3.eth.default_account # type: ignore\n )\n\n estimate_gas_transaction[\"data\"] = self.data_in_transaction\n\n return estimate_gas_transaction\n\n \ndef _get_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:\n if transaction is None:\n transact_transaction: TxParams = {}\n else:\n transact_transaction = cast(TxParams, dict(**transaction))\n 
self.check_forbidden_keys_in_transaction(\n transact_transaction, [\"data\", \"to\"]\n )\n\n if self.w3.eth.default_account is not empty:\n # type ignored b/c check prevents an empty default_account\n transact_transaction.setdefault(\n \"from\", self.w3.eth.default_account # type: ignore\n )\n\n transact_transaction[\"data\"] = self.data_in_transaction\n\n return transact_transaction\n\n @combomethod\n def _build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:\n built_transaction = self._get_transaction(transaction)\n built_transaction[\"to\"] = Address(b\"\")\n return built_transaction\n\n @staticmethod\n def check_forbidden_keys_in_transaction(\n transaction: TxParams, forbidden_keys: Optional[Collection[str]] = None\n ) -> None:\n keys_found = transaction.keys() & forbidden_keys\n if keys_found:\n raise ValueError(\n f\"Cannot set '{', '.join(keys_found)}' field(s) in transaction\"\n )\n\n\nclass NonExistentFallbackFunction:\n @staticmethod\n def _raise_exception() -> NoReturn:\n raise FallbackNotFound(\"No fallback function was found in the contract ABI.\")\n\n def __getattr__(self, attr: Any) -> Callable[[], None]:\n return self._raise_exception\n\n\nclass NonExistentReceiveFunction:\n @staticmethod\n def _raise_exception() -> NoReturn:\n raise FallbackNotFound(\"No receive function was found in the contract ABI.\")\n\n def __getattr__(self, attr: Any) -> Callable[[], None]:\n return self._raise_exception\n\n# Path: web3/contract/utils.py\nimport itertools\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n List,\n Optional,\n Sequence,\n Tuple,\n Type,\n Union,\n)\n\nfrom eth_abi.exceptions import (\n DecodingError,\n)\nfrom eth_typing import (\n ChecksumAddress,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3._utils.abi import (\n filter_by_type,\n get_abi_output_types,\n map_abi_data,\n named_tree,\n recursive_dict_to_namedtuple,\n)\nfrom web3._utils.async_transactions import (\n async_fill_transaction_defaults,\n)\nfrom web3._utils.contracts import (\n find_matching_fn_abi,\n prepare_transaction,\n)\nfrom web3._utils.normalizers import (\n BASE_RETURN_NORMALIZERS,\n)\nfrom web3._utils.transactions import (\n fill_transaction_defaults,\n)\nfrom web3.exceptions import (\n BadFunctionCallOutput,\n)\nfrom web3.types import (\n ABI,\n ABIFunction,\n BlockIdentifier,\n FunctionIdentifier,\n StateOverride,\n TContractFn,\n TxParams,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\nACCEPTABLE_EMPTY_STRINGS = [\"0x\", b\"0x\", \"\", b\"\"]\n\n\ndef call_contract_function(\n w3: \"Web3\",\n address: ChecksumAddress,\n normalizers: Tuple[Callable[..., Any], ...],\n function_identifier: FunctionIdentifier,\n transaction: TxParams,\n block_id: Optional[BlockIdentifier] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n state_override: Optional[StateOverride] = None,\n ccip_read_enabled: Optional[bool] = None,\n decode_tuples: Optional[bool] = False,\n *args: Any,\n **kwargs: Any,\n) -> Any:\n \"\"\"\n Helper function for interacting with a contract function using the\n `eth_call` API.\n \"\"\"\n call_transaction = prepare_transaction(\n address,\n w3,\n fn_identifier=function_identifier,\n contract_abi=contract_abi,\n fn_abi=fn_abi,\n transaction=transaction,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n return_data = w3.eth.call(\n call_transaction,\n block_identifier=block_id,\n state_override=state_override,\n ccip_read_enabled=ccip_read_enabled,\n )\n\n if fn_abi is None:\n fn_abi 
= find_matching_fn_abi(\n contract_abi, w3.codec, function_identifier, args, kwargs\n )\n\n output_types = get_abi_output_types(fn_abi)\n\n try:\n output_data = w3.codec.decode(output_types, return_data)\n except DecodingError as e:\n # Provide a more helpful error message than the one provided by\n # eth-abi-utils\n is_missing_code_error = (\n return_data in ACCEPTABLE_EMPTY_STRINGS\n and w3.eth.get_code(address) in ACCEPTABLE_EMPTY_STRINGS\n )\n if is_missing_code_error:\n msg = (\n \"Could not transact with/call contract function, is contract \"\n \"deployed correctly and chain synced?\"\n )\n else:\n msg = (\n f\"Could not decode contract function call to {function_identifier} \"\n f\"with return data: {str(return_data)}, output_types: {output_types}\"\n )\n raise BadFunctionCallOutput(msg) from e\n\n _normalizers = itertools.chain(\n BASE_RETURN_NORMALIZERS,\n normalizers,\n )\n normalized_data = map_abi_data(_normalizers, output_types, output_data)\n\n if decode_tuples:\n decoded = named_tree(fn_abi[\"outputs\"], normalized_data)\n normalized_data = recursive_dict_to_namedtuple(decoded)\n\n if len(normalized_data) == 1:\n return normalized_data[0]\n else:\n return normalized_data\n\n\ndef transact_with_contract_function(\n address: ChecksumAddress,\n w3: \"Web3\",\n function_name: Optional[FunctionIdentifier] = None,\n transaction: Optional[TxParams] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n *args: Any,\n **kwargs: Any,\n) -> HexBytes:\n \"\"\"\n Helper function for interacting with a contract function by sending a\n transaction.\n \"\"\"\n transact_transaction = prepare_transaction(\n address,\n w3,\n fn_identifier=function_name,\n contract_abi=contract_abi,\n transaction=transaction,\n fn_abi=fn_abi,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n txn_hash = w3.eth.send_transaction(transact_transaction)\n return txn_hash\n\n\ndef estimate_gas_for_function(\n address: ChecksumAddress,\n w3: \"Web3\",\n fn_identifier: Optional[FunctionIdentifier] = None,\n transaction: Optional[TxParams] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n block_identifier: Optional[BlockIdentifier] = None,\n state_override: Optional[StateOverride] = None,\n *args: Any,\n **kwargs: Any,\n) -> int:\n \"\"\"Estimates gas cost a function call would take.\n\n Don't call this directly, instead use :meth:`Contract.estimate_gas`\n on your contract instance.\n \"\"\"\n estimate_transaction = prepare_transaction(\n address,\n w3,\n fn_identifier=fn_identifier,\n contract_abi=contract_abi,\n fn_abi=fn_abi,\n transaction=transaction,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n return w3.eth.estimate_gas(estimate_transaction, block_identifier, state_override)\n\n\ndef build_transaction_for_function(\n address: ChecksumAddress,\n w3: \"Web3\",\n function_name: Optional[FunctionIdentifier] = None,\n transaction: Optional[TxParams] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n *args: Any,\n **kwargs: Any,\n) -> TxParams:\n \"\"\"Builds a dictionary with the fields required to make the given transaction\n\n Don't call this directly, instead use :meth:`Contract.build_transaction`\n on your contract instance.\n \"\"\"\n prepared_transaction = prepare_transaction(\n address,\n w3,\n fn_identifier=function_name,\n contract_abi=contract_abi,\n fn_abi=fn_abi,\n transaction=transaction,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n prepared_transaction = fill_transaction_defaults(w3, 
prepared_transaction)\n\n return prepared_transaction\n\n\ndef find_functions_by_identifier(\n contract_abi: ABI,\n w3: Union[\"Web3\", \"AsyncWeb3\"],\n address: ChecksumAddress,\n callable_check: Callable[..., Any],\n function_type: Type[TContractFn],\n) -> List[TContractFn]:\n fns_abi = filter_by_type(\"function\", contract_abi)\n return [\n function_type.factory(\n fn_abi[\"name\"],\n w3=w3,\n contract_abi=contract_abi,\n address=address,\n function_identifier=fn_abi[\"name\"],\n abi=fn_abi,\n )\n for fn_abi in fns_abi\n if callable_check(fn_abi)\n ]\n\n\ndef get_function_by_identifier(\n fns: Sequence[TContractFn], identifier: str\n) -> TContractFn:\n if len(fns) > 1:\n raise ValueError(\n f\"Found multiple functions with matching {identifier}. \" f\"Found: {fns!r}\"\n )\n elif len(fns) == 0:\n raise ValueError(f\"Could not find any function with matching {identifier}\")\n return fns[0]\n\n\n# --- async --- #\n\n\nasync def async_call_contract_function(\n async_w3: \"AsyncWeb3\",\n address: ChecksumAddress,\n normalizers: Tuple[Callable[..., Any], ...],\n function_identifier: FunctionIdentifier,\n transaction: TxParams,\n block_id: Optional[BlockIdentifier] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n state_override: Optional[StateOverride] = None,\n ccip_read_enabled: Optional[bool] = None,\n decode_tuples: Optional[bool] = False,\n *args: Any,\n **kwargs: Any,\n) -> Any:\n \"\"\"\n Helper function for interacting with a contract function using the\n `eth_call` API.\n \"\"\"\n call_transaction = prepare_transaction(\n address,\n async_w3,\n fn_identifier=function_identifier,\n contract_abi=contract_abi,\n fn_abi=fn_abi,\n transaction=transaction,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n return_data = await async_w3.eth.call(\n call_transaction,\n block_identifier=block_id,\n state_override=state_override,\n ccip_read_enabled=ccip_read_enabled,\n )\n\n if fn_abi is None:\n fn_abi = find_matching_fn_abi(\n contract_abi, async_w3.codec, function_identifier, args, kwargs\n )\n\n output_types = get_abi_output_types(fn_abi)\n\n try:\n output_data = async_w3.codec.decode(output_types, return_data)\n except DecodingError as e:\n # Provide a more helpful error message than the one provided by\n # eth-abi-utils\n is_missing_code_error = (\n return_data in ACCEPTABLE_EMPTY_STRINGS\n and await async_w3.eth.get_code(address) in ACCEPTABLE_EMPTY_STRINGS\n )\n if is_missing_code_error:\n msg = (\n \"Could not transact with/call contract function, is contract \"\n \"deployed correctly and chain synced?\"\n )\n else:\n msg = (\n f\"Could not decode contract function call to {function_identifier} \"\n f\"with return data: {str(return_data)}, output_types: {output_types}\"\n )\n raise BadFunctionCallOutput(msg) from e\n\n _normalizers = itertools.chain(\n BASE_RETURN_NORMALIZERS,\n normalizers,\n )\n normalized_data = map_abi_data(_normalizers, output_types, output_data)\n\n if decode_tuples:\n decoded = named_tree(fn_abi[\"outputs\"], normalized_data)\n normalized_data = recursive_dict_to_namedtuple(decoded)\n\n if len(normalized_data) == 1:\n return normalized_data[0]\n else:\n return normalized_data\n\n\nasync def async_transact_with_contract_function(\n address: ChecksumAddress,\n async_w3: \"AsyncWeb3\",\n function_name: Optional[FunctionIdentifier] = None,\n transaction: Optional[TxParams] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n *args: Any,\n **kwargs: Any,\n) -> HexBytes:\n \"\"\"\n Helper function for 
interacting with a contract function by sending a\n transaction.\n \"\"\"\n transact_transaction = prepare_transaction(\n address,\n async_w3,\n fn_identifier=function_name,\n contract_abi=contract_abi,\n transaction=transaction,\n fn_abi=fn_abi,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n txn_hash = await async_w3.eth.send_transaction(transact_transaction)\n return txn_hash\n\n\nasync def async_estimate_gas_for_function(\n address: ChecksumAddress,\n async_w3: \"AsyncWeb3\",\n fn_identifier: Optional[FunctionIdentifier] = None,\n transaction: Optional[TxParams] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n block_identifier: Optional[BlockIdentifier] = None,\n state_override: Optional[StateOverride] = None,\n *args: Any,\n **kwargs: Any,\n) -> int:\n \"\"\"Estimates gas cost a function call would take.\n\n Don't call this directly, instead use :meth:`Contract.estimate_gas`\n on your contract instance.\n \"\"\"\n estimate_transaction = prepare_transaction(\n address,\n async_w3,\n fn_identifier=fn_identifier,\n contract_abi=contract_abi,\n fn_abi=fn_abi,\n transaction=transaction,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n return await async_w3.eth.estimate_gas(\n estimate_transaction, block_identifier, state_override\n )\n\n\nasync def async_build_transaction_for_function(\n address: ChecksumAddress,\n async_w3: \"AsyncWeb3\",\n function_name: Optional[FunctionIdentifier] = None,\n transaction: Optional[TxParams] = None,\n contract_abi: Optional[ABI] = None,\n fn_abi: Optional[ABIFunction] = None,\n *args: Any,\n **kwargs: Any,\n) -> TxParams:\n \"\"\"Builds a dictionary with the fields required to make the given transaction\n\n Don't call this directly, instead use :meth:`Contract.build_transaction`\n on your contract instance.\n \"\"\"\n prepared_transaction = prepare_transaction(\n address,\n async_w3,\n fn_identifier=function_name,\n contract_abi=contract_abi,\n fn_abi=fn_abi,\n transaction=transaction,\n fn_args=args,\n fn_kwargs=kwargs,\n )\n\n return await async_fill_transaction_defaults(async_w3, prepared_transaction)\n\n# Path: web3/contract/contract.py\nimport copy\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Dict,\n Iterable,\n List,\n Optional,\n Sequence,\n Type,\n cast,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n)\nfrom eth_utils import (\n combomethod,\n)\nfrom eth_utils.toolz import (\n partial,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3._utils.abi import (\n fallback_func_abi_exists,\n filter_by_type,\n receive_func_abi_exists,\n)\nfrom web3._utils.compat import (\n Self,\n)\nfrom web3._utils.contracts import (\n parse_block_identifier,\n)\nfrom web3._utils.datatypes import (\n PropertyCheckingFactory,\n)\nfrom web3._utils.events import (\n EventFilterBuilder,\n get_event_data,\n)\nfrom web3._utils.filters import (\n LogFilter,\n)\nfrom web3._utils.function_identifiers import (\n FallbackFn,\n ReceiveFn,\n)\nfrom web3._utils.normalizers import (\n normalize_abi,\n normalize_address,\n normalize_bytecode,\n)\nfrom web3._utils.transactions import (\n fill_transaction_defaults,\n)\nfrom web3.contract.base_contract import (\n BaseContract,\n BaseContractCaller,\n BaseContractConstructor,\n BaseContractEvent,\n BaseContractEvents,\n BaseContractFunction,\n BaseContractFunctions,\n NonExistentFallbackFunction,\n NonExistentReceiveFunction,\n)\nfrom web3.contract.utils import (\n build_transaction_for_function,\n call_contract_function,\n estimate_gas_for_function,\n find_functions_by_identifier,\n 
get_function_by_identifier,\n transact_with_contract_function,\n)\nfrom web3.exceptions import (\n ABIFunctionNotFound,\n NoABIFound,\n NoABIFunctionsFound,\n Web3ValidationError,\n)\nfrom web3.types import (\n ABI,\n BlockIdentifier,\n EventData,\n StateOverride,\n TxParams,\n)\nfrom web3.utils import (\n get_abi_input_names,\n)\n\nif TYPE_CHECKING:\n from ens import ENS # noqa: F401\n from web3 import Web3 # noqa: F401\n\n\nclass ContractEvent(BaseContractEvent):\n # mypy types\n w3: \"Web3\"\n\n @combomethod\n def get_logs(\n self,\n argument_filters: Optional[Dict[str, Any]] = None,\n fromBlock: Optional[BlockIdentifier] = None,\n toBlock: Optional[BlockIdentifier] = None,\n block_hash: Optional[HexBytes] = None,\n ) -> Iterable[EventData]:\n \"\"\"Get events for this contract instance using eth_getLogs API.\n\n This is a stateless method, as opposed to create_filter.\n It can be safely called against nodes which do not provide\n eth_newFilter API, like Infura nodes.\n\n If there are many events,\n like ``Transfer`` events for a popular token,\n the Ethereum node might be overloaded and timeout\n on the underlying JSON-RPC call.\n\n Example - how to get all ERC-20 token transactions\n for the latest 10 blocks:\n\n .. code-block:: python\n\n from = max(mycontract.web3.eth.block_number - 10, 1)\n to = mycontract.web3.eth.block_number\n\n events = mycontract.events.Transfer.get_logs(fromBlock=from, toBlock=to)\n\n for e in events:\n print(e[\"args\"][\"from\"],\n e[\"args\"][\"to\"],\n e[\"args\"][\"value\"])\n\n The returned processed log values will look like:\n\n .. code-block:: python\n\n (\n AttributeDict({\n 'args': AttributeDict({}),\n 'event': 'LogNoArguments',\n 'logIndex': 0,\n 'transactionIndex': 0,\n 'transactionHash': HexBytes('...'),\n 'address': '0xF2E246BB76DF876Cef8b38ae84130F4F55De395b',\n 'blockHash': HexBytes('...'),\n 'blockNumber': 3\n }),\n AttributeDict(...),\n ...\n )\n\n See also: :func:`web3.middleware.filter.LocalFilterMiddleware`.\n\n :param argument_filters: Filter by argument values. Indexed arguments are\n filtered by the node while non-indexed arguments are filtered by the library.\n :param fromBlock: block number or \"latest\", defaults to \"latest\"\n :param toBlock: block number or \"latest\". Defaults to \"latest\"\n :param block_hash: block hash. 
block_hash cannot be set at the\n same time as fromBlock or toBlock\n :yield: Tuple of :class:`AttributeDict` instances\n \"\"\"\n event_abi = self._get_event_abi()\n\n # validate ``argument_filters`` if present\n if argument_filters is not None:\n event_arg_names = get_abi_input_names(event_abi)\n if not all(arg in event_arg_names for arg in argument_filters.keys()):\n raise Web3ValidationError(\n \"When filtering by argument names, all argument names must be \"\n \"present in the contract's event ABI.\"\n )\n\n _filter_params = self._get_event_filter_params(\n event_abi, argument_filters, fromBlock, toBlock, block_hash\n )\n # call JSON-RPC API\n logs = self.w3.eth.get_logs(_filter_params)\n\n # convert raw binary data to Python proxy objects as described by ABI:\n all_event_logs = tuple(\n get_event_data(self.w3.codec, event_abi, entry) for entry in logs\n )\n filtered_logs = self._process_get_logs_argument_filters(\n event_abi,\n all_event_logs,\n argument_filters,\n )\n sorted_logs = sorted(filtered_logs, key=lambda e: e[\"logIndex\"])\n sorted_logs = sorted(sorted_logs, key=lambda e: e[\"blockNumber\"])\n return sorted_logs\n\n @combomethod\n def create_filter(\n self,\n *, # PEP 3102\n argument_filters: Optional[Dict[str, Any]] = None,\n fromBlock: Optional[BlockIdentifier] = None,\n toBlock: BlockIdentifier = \"latest\",\n address: Optional[ChecksumAddress] = None,\n topics: Optional[Sequence[Any]] = None,\n ) -> LogFilter:\n \"\"\"\n Create filter object that tracks logs emitted by this contract event.\n \"\"\"\n filter_builder = EventFilterBuilder(self._get_event_abi(), self.w3.codec)\n self._set_up_filter_builder(\n argument_filters,\n fromBlock,\n toBlock,\n address,\n topics,\n filter_builder,\n )\n log_filter = filter_builder.deploy(self.w3)\n log_filter.log_entry_formatter = get_event_data(\n self.w3.codec, self._get_event_abi()\n )\n log_filter.builder = filter_builder\n\n return log_filter\n\n @combomethod\n def build_filter(self) -> EventFilterBuilder:\n builder = EventFilterBuilder(\n self._get_event_abi(),\n self.w3.codec,\n formatter=get_event_data(self.w3.codec, self._get_event_abi()),\n )\n builder.address = self.address\n return builder\n\n\nclass ContractEvents(BaseContractEvents):\n def __init__(\n self, abi: ABI, w3: \"Web3\", address: Optional[ChecksumAddress] = None\n ) -> None:\n super().__init__(abi, w3, ContractEvent, address)\n\n\nclass ContractFunction(BaseContractFunction):\n # mypy types\n w3: \"Web3\"\n\n def __call__(self, *args: Any, **kwargs: Any) -> \"ContractFunction\":\n clone = copy.copy(self)\n if args is None:\n clone.args = tuple()\n else:\n clone.args = args\n\n if kwargs is None:\n clone.kwargs = {}\n else:\n clone.kwargs = kwargs\n clone._set_function_info()\n return clone\n\n @classmethod\n def factory(cls, class_name: str, **kwargs: Any) -> Self:\n return PropertyCheckingFactory(class_name, (cls,), kwargs)(kwargs.get(\"abi\"))\n\n def call(\n self,\n transaction: Optional[TxParams] = None,\n block_identifier: BlockIdentifier = None,\n state_override: Optional[StateOverride] = None,\n ccip_read_enabled: Optional[bool] = None,\n ) -> Any:\n \"\"\"\n Execute a contract function call using the `eth_call` interface.\n\n This method prepares a ``Caller`` object that exposes the contract\n functions and public variables as callable Python functions.\n\n Reading a public ``owner`` address variable example:\n\n .. 
code-block:: python\n\n ContractFactory = w3.eth.contract(\n abi=wallet_contract_definition[\"abi\"]\n )\n\n # Not a real contract address\n contract = ContractFactory(\"0x2f70d3d26829e412A602E83FE8EeBF80255AEeA5\")\n\n # Read \"owner\" public variable\n addr = contract.functions.owner().call()\n\n :param transaction: Dictionary of transaction info for web3 interface\n :param block_identifier: TODO\n :param state_override TODO\n :param ccip_read_enabled TODO\n :return: ``Caller`` object that has contract public functions\n and variables exposed as Python methods\n \"\"\"\n call_transaction = self._get_call_txparams(transaction)\n\n block_id = parse_block_identifier(self.w3, block_identifier)\n\n return call_contract_function(\n self.w3,\n self.address,\n self._return_data_normalizers,\n self.function_identifier,\n call_transaction,\n block_id,\n self.contract_abi,\n self.abi,\n state_override,\n ccip_read_enabled,\n self.decode_tuples,\n *self.args,\n **self.kwargs,\n )\n\n def transact(self, transaction: Optional[TxParams] = None) -> HexBytes:\n setup_transaction = self._transact(transaction)\n return transact_with_contract_function(\n self.address,\n self.w3,\n self.function_identifier,\n setup_transaction,\n self.contract_abi,\n self.abi,\n *self.args,\n **self.kwargs,\n )\n\n def estimate_gas(\n self,\n transaction: Optional[TxParams] = None,\n block_identifier: Optional[BlockIdentifier] = None,\n state_override: Optional[StateOverride] = None,\n ) -> int:\n setup_transaction = self._estimate_gas(transaction)\n return estimate_gas_for_function(\n self.address,\n self.w3,\n self.function_identifier,\n setup_transaction,\n self.contract_abi,\n self.abi,\n block_identifier,\n state_override,\n *self.args,\n **self.kwargs,\n )\n\n def build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:\n built_transaction = self._build_transaction(transaction)\n return build_transaction_for_function(\n self.address,\n self.w3,\n self.function_identifier,\n built_transaction,\n self.contract_abi,\n self.abi,\n *self.args,\n **self.kwargs,\n )\n\n @staticmethod\n def get_fallback_function(\n abi: ABI,\n w3: \"Web3\",\n address: Optional[ChecksumAddress] = None,\n ) -> \"ContractFunction\":\n if abi and fallback_func_abi_exists(abi):\n return ContractFunction.factory(\n \"fallback\",\n w3=w3,\n contract_abi=abi,\n address=address,\n function_identifier=FallbackFn,\n )()\n return cast(ContractFunction, NonExistentFallbackFunction())\n\n @staticmethod\n def get_receive_function(\n abi: ABI,\n w3: \"Web3\",\n address: Optional[ChecksumAddress] = None,\n ) -> \"ContractFunction\":\n if abi and receive_func_abi_exists(abi):\n return ContractFunction.factory(\n \"receive\",\n w3=w3,\n contract_abi=abi,\n address=address,\n function_identifier=ReceiveFn,\n )()\n return cast(ContractFunction, NonExistentReceiveFunction())\n\n\nclass ContractFunctions(BaseContractFunctions):\n def __init__(\n self,\n abi: ABI,\n w3: \"Web3\",\n address: Optional[ChecksumAddress] = None,\n decode_tuples: Optional[bool] = False,\n ) -> None:\n super().__init__(abi, w3, ContractFunction, address, decode_tuples)\n\n def __getattr__(self, function_name: str) -> \"ContractFunction\":\n if self.abi is None:\n raise NoABIFound(\n \"There is no ABI found for this contract.\",\n )\n if \"_functions\" not in self.__dict__:\n raise NoABIFunctionsFound(\n \"The abi for this contract contains no function definitions. 
\",\n \"Are you sure you provided the correct contract abi?\",\n )\n elif function_name not in self.__dict__[\"_functions\"]:\n raise ABIFunctionNotFound(\n f\"The function '{function_name}' was not found in this contract's abi.\",\n \" Are you sure you provided the correct contract abi?\",\n )\n else:\n return super().__getattribute__(function_name)\n\n\nclass Contract(BaseContract):\n # mypy types\n w3: \"Web3\"\n functions: ContractFunctions = None\n caller: \"ContractCaller\" = None\n\n # Instance of :class:`ContractEvents` presenting available Event ABIs\n events: ContractEvents = None\n\n def __init__(self, address: Optional[ChecksumAddress] = None) -> None:\n \"\"\"Create a new smart contract proxy object.\n :param address: Contract address as 0x hex string\"\"\"\n _w3 = self.w3\n if _w3 is None:\n raise AttributeError(\n \"The `Contract` class has not been initialized. Please use the \"\n \"`web3.contract` interface to create your contract class.\"\n )\n\n if address:\n self.address = normalize_address(cast(\"ENS\", _w3.ens), address)\n\n if not self.address:\n raise TypeError(\n \"The address argument is required to instantiate a contract.\"\n )\n\n self.functions = ContractFunctions(\n self.abi, _w3, self.address, decode_tuples=self.decode_tuples\n )\n self.caller = ContractCaller(\n self.abi, _w3, self.address, decode_tuples=self.decode_tuples\n )\n self.events = ContractEvents(self.abi, _w3, self.address)\n self.fallback = Contract.get_fallback_function(\n self.abi,\n _w3,\n ContractFunction,\n self.address,\n )\n self.receive = Contract.get_receive_function(\n self.abi,\n _w3,\n ContractFunction,\n self.address,\n )\n\n @classmethod\n def factory(\n cls, w3: \"Web3\", class_name: Optional[str] = None, **kwargs: Any\n ) -> Type[Self]:\n kwargs[\"w3\"] = w3\n\n normalizers = {\n \"abi\": normalize_abi,\n \"address\": partial(normalize_address, w3.ens),\n \"bytecode\": normalize_bytecode,\n \"bytecode_runtime\": normalize_bytecode,\n }\n\n contract = cast(\n Type[Self],\n PropertyCheckingFactory(\n class_name or cls.__name__,\n (cls,),\n kwargs,\n normalizers=normalizers,\n ),\n )\n contract.functions = ContractFunctions(\n contract.abi, contract.w3, decode_tuples=contract.decode_tuples\n )\n contract.caller = ContractCaller(\n contract.abi,\n contract.w3,\n contract.address,\n decode_tuples=contract.decode_tuples,\n )\n contract.events = ContractEvents(contract.abi, contract.w3)\n contract.fallback = Contract.get_fallback_function(\n contract.abi,\n contract.w3,\n ContractFunction,\n )\n contract.receive = Contract.get_receive_function(\n contract.abi,\n contract.w3,\n ContractFunction,\n )\n\n return contract\n\n @classmethod\n def constructor(cls, *args: Any, **kwargs: Any) -> \"ContractConstructor\":\n \"\"\"\n :param args: The contract constructor arguments as positional arguments\n :param kwargs: The contract constructor arguments as keyword arguments\n :return: a contract constructor object\n \"\"\"\n if cls.bytecode is None:\n raise ValueError(\n \"Cannot call constructor on a contract that does not have \"\n \"'bytecode' associated with it\"\n )\n\n return ContractConstructor(cls.w3, cls.abi, cls.bytecode, *args, **kwargs)\n\n @combomethod\n def find_functions_by_identifier(\n cls,\n contract_abi: ABI,\n w3: \"Web3\",\n address: ChecksumAddress,\n callable_check: Callable[..., Any],\n ) -> List[\"ContractFunction\"]:\n return cast(\n List[\"ContractFunction\"],\n find_functions_by_identifier(\n contract_abi, w3, address, callable_check, ContractFunction\n ),\n )\n\n 
@combomethod\n def get_function_by_identifier(\n cls, fns: Sequence[\"ContractFunction\"], identifier: str\n ) -> \"ContractFunction\":\n return get_function_by_identifier(fns, identifier)\n\n\nclass ContractCaller(BaseContractCaller):\n # mypy types\n w3: \"Web3\"\n\n def __init__(\n self,\n abi: ABI,\n w3: \"Web3\",\n address: ChecksumAddress,\n transaction: Optional[TxParams] = None,\n block_identifier: BlockIdentifier = None,\n ccip_read_enabled: Optional[bool] = None,\n decode_tuples: Optional[bool] = False,\n ) -> None:\n super().__init__(abi, w3, address, decode_tuples=decode_tuples)\n\n if self.abi:\n if transaction is None:\n transaction = {}\n\n self._functions = filter_by_type(\"function\", self.abi)\n for func in self._functions:\n fn = ContractFunction.factory(\n func[\"name\"],\n w3=w3,\n contract_abi=self.abi,\n address=self.address,\n function_identifier=func[\"name\"],\n decode_tuples=decode_tuples,\n )\n\n caller_method = partial(\n self.call_function,\n fn,\n transaction=transaction,\n block_identifier=block_identifier,\n ccip_read_enabled=ccip_read_enabled,\n )\n\n setattr(self, func[\"name\"], caller_method)\n\n def __call__(\n self,\n transaction: Optional[TxParams] = None,\n block_identifier: BlockIdentifier = None,\n ccip_read_enabled: Optional[bool] = None,\n ) -> \"ContractCaller\":\n if transaction is None:\n transaction = {}\n\n return type(self)(\n self.abi,\n self.w3,\n self.address,\n transaction=transaction,\n block_identifier=block_identifier,\n ccip_read_enabled=ccip_read_enabled,\n decode_tuples=self.decode_tuples,\n )\n\n\nclass ContractConstructor(BaseContractConstructor):\n # mypy types\n w3: \"Web3\"\n\n @combomethod\n def transact(self, transaction: Optional[TxParams] = None) -> HexBytes:\n return self.w3.eth.send_transaction(self._get_transaction(transaction))\n\n @combomethod\n def build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:\n \"\"\"\n Build the transaction dictionary without sending\n \"\"\"\n built_transaction = self._build_transaction(transaction)\n return fill_transaction_defaults(self.w3, built_transaction)\n\n @combomethod\n def estimate_gas(\n self,\n transaction: Optional[TxParams] = None,\n block_identifier: Optional[BlockIdentifier] = None,\n ) -> int:\n transaction = self._estimate_gas(transaction)\n\n return self.w3.eth.estimate_gas(transaction, block_identifier=block_identifier)\n\n# Path: web3/types.py\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Coroutine,\n Dict,\n List,\n NewType,\n Optional,\n Sequence,\n Type,\n TypeVar,\n Union,\n)\n\nfrom eth_typing import (\n Address,\n BlockNumber,\n ChecksumAddress,\n Hash32,\n HexStr,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3._utils.compat import (\n Literal,\n NotRequired,\n TypedDict,\n)\nfrom web3._utils.function_identifiers import (\n FallbackFn,\n ReceiveFn,\n)\n\nif TYPE_CHECKING:\n from web3.contract.async_contract import AsyncContractFunction # noqa: F401\n from web3.contract.contract import ContractFunction # noqa: F401\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\n\nTReturn = TypeVar(\"TReturn\")\nTParams = TypeVar(\"TParams\")\nTValue = TypeVar(\"TValue\")\n\nBlockParams = Literal[\"latest\", \"earliest\", \"pending\", \"safe\", \"finalized\"]\nBlockIdentifier = Union[BlockParams, BlockNumber, Hash32, HexStr, HexBytes, int]\nLatestBlockParam = Literal[\"latest\"]\n\nFunctionIdentifier = Union[str, Type[FallbackFn], Type[ReceiveFn]]\n\n# bytes, hexbytes, or hexstr representing a 32 byte hash\n_Hash32 
= Union[Hash32, HexBytes, HexStr]\nEnodeURI = NewType(\"EnodeURI\", str)\nENS = NewType(\"ENS\", str)\nNonce = NewType(\"Nonce\", int)\nRPCEndpoint = NewType(\"RPCEndpoint\", str)\nTimestamp = NewType(\"Timestamp\", int)\nWei = NewType(\"Wei\", int)\nGwei = NewType(\"Gwei\", int)\nFormatters = Dict[RPCEndpoint, Callable[..., Any]]\n\n\nclass AccessListEntry(TypedDict):\n address: HexStr\n storageKeys: Sequence[HexStr]\n\n\nAccessList = NewType(\"AccessList\", Sequence[AccessListEntry])\n\n\n# todo: move these to eth_typing once web3 is type hinted\nclass ABIEventParams(TypedDict, total=False):\n indexed: bool\n name: str\n type: str\n\n\n...\n", "needle_token_start": 5536, "needle_token_end": 5725, "code_context_ntokens": 16114, "output": ["```python\ndef _get_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:\n if transaction is None:\n transact_transaction: TxParams = {}\n else:\n transact_transaction = cast(TxParams, dict(**transaction))\n self.check_forbidden_keys_in_transaction(\n transact_transaction, [\"data\", \"to\"]\n )\n\n if self.w3.eth.default_account is not empty:\n # type ignored b/c check prevents an empty default_account\n transact_transaction.setdefault(\n \"from\", self.w3.eth.default_account # type: ignore\n )\n\n transact_transaction[\"data\"] = self.data_in_transaction\n\n return transact_transaction\n```"]} +{"repo": "ethereum/web3.py", "name": "test_async_eth_sign_transaction_legacy", "language": "python", "path": "web3/_utils/module_testing/eth_module.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: To verify the correct signing and recovery of a transaction using legacy parameters in an asynchronous environment.\n2. **Input**: Transaction parameters including sender, recipient, value, gas, gas price, and nonce.\n3. **Output**: The signed transaction details and confirmation that the recovered signatory matches the original sender.\n4. **Procedure**: \n - Define the transaction parameters.\n - Sign the transaction asynchronously.\n - Recover the signatory from the signed transaction.\n - Validate that the recovered signatory matches the original sender and that all transaction parameters are correctly reflected in the signed transaction.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: web3/_utils/module_testing/utils.py\nfrom asyncio import (\n iscoroutinefunction,\n)\nimport copy\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Union,\n cast,\n)\n\nfrom toolz import (\n merge,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n from web3._utils.compat import ( # noqa: F401\n Self,\n )\n from web3.types import ( # noqa: F401\n AsyncMakeRequestFn,\n MakeRequestFn,\n RPCEndpoint,\n RPCResponse,\n )\n\n\nclass RequestMocker:\n \"\"\"\n Context manager to mock requests made by a web3 instance. 
This is meant to be used\n via a ``request_mocker`` fixture defined within the appropriate context.\n\n Example:\n\n def test_my_w3(w3, request_mocker):\n assert w3.eth.block_number == 0\n\n with request_mocker(w3, mock_results={\"eth_blockNumber\": \"0x1\"}):\n assert w3.eth.block_number == 1\n\n assert w3.eth.block_number == 0\n\n Example with async and a mocked response object:\n\n async def test_my_w3(async_w3, request_mocker):\n def _iter_responses():\n while True:\n yield {\"error\": {\"message\": \"transaction indexing in progress\"}}\n yield {\"error\": {\"message\": \"transaction indexing in progress\"}}\n yield {\"result\": {\"status\": \"0x1\"}}\n\n iter_responses = _iter_responses()\n\n async with request_mocker(\n async_w3,\n mock_responses={\n \"eth_getTransactionReceipt\": lambda *_: next(iter_responses)\n },\n ):\n # assert that the first two error responses are handled and the result\n # is eventually returned when present\n assert await w3.eth.get_transaction_receipt(\"0x1\") == \"0x1\"\n\n\n - ``mock_results`` is a dict mapping method names to the desired \"result\" object of\n the RPC response.\n - ``mock_errors`` is a dict mapping method names to the desired\n \"error\" object of the RPC response.\n -``mock_responses`` is a dict mapping method names to the entire RPC response\n object. This can be useful if you wish to return an iterator which returns\n different responses on each call to the method.\n\n If a method name is not present in any of the dicts above, the request is made as\n usual.\n \"\"\"\n\n def __init__(\n self,\n w3: Union[\"AsyncWeb3\", \"Web3\"],\n mock_results: Dict[Union[\"RPCEndpoint\", str], Any] = None,\n mock_errors: Dict[Union[\"RPCEndpoint\", str], Any] = None,\n mock_responses: Dict[Union[\"RPCEndpoint\", str], Any] = None,\n ):\n self.w3 = w3\n...\n# Path: web3/_utils/module_testing/eth_module.py\nimport json\nimport math\nimport pytest\nfrom random import (\n randint,\n)\nimport re\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n List,\n Type,\n Union,\n cast,\n)\n\nimport eth_abi as abi\nfrom eth_typing import (\n BlockNumber,\n ChecksumAddress,\n HexAddress,\n HexStr,\n)\nfrom eth_utils import (\n is_boolean,\n is_bytes,\n is_checksum_address,\n is_dict,\n is_integer,\n is_list_like,\n is_same_address,\n is_string,\n remove_0x_prefix,\n to_bytes,\n)\nfrom eth_utils.toolz import (\n assoc,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3._utils.empty import (\n empty,\n)\nfrom web3._utils.ens import (\n ens_addresses,\n)\nfrom web3._utils.error_formatters_utils import (\n PANIC_ERROR_CODES,\n)\nfrom web3._utils.fee_utils import (\n PRIORITY_FEE_MIN,\n)\nfrom web3._utils.method_formatters import (\n to_hex_if_integer,\n)\nfrom web3._utils.module_testing.module_testing_utils import (\n assert_contains_log,\n async_mock_offchain_lookup_request_response,\n flaky_geth_dev_mining,\n mock_offchain_lookup_request_response,\n)\nfrom web3._utils.module_testing.utils import (\n RequestMocker,\n)\nfrom web3._utils.type_conversion import (\n to_hex_if_bytes,\n)\nfrom web3.exceptions import (\n BlockNotFound,\n ContractCustomError,\n ContractLogicError,\n ContractPanicError,\n InvalidAddress,\n InvalidTransaction,\n MultipleFailedRequests,\n NameNotFound,\n OffchainLookup,\n TimeExhausted,\n TooManyRequests,\n TransactionNotFound,\n TransactionTypeMismatch,\n Web3ValidationError,\n)\nfrom web3.middleware import (\n ExtraDataToPOAMiddleware,\n)\nfrom web3.types import (\n ENS,\n BlockData,\n FilterParams,\n Nonce,\n RPCEndpoint,\n 
StateOverrideParams,\n SyncStatus,\n TxData,\n TxParams,\n Wei,\n)\n\nUNKNOWN_ADDRESS = ChecksumAddress(\n HexAddress(HexStr(\"0xdEADBEeF00000000000000000000000000000000\"))\n)\n\nUNKNOWN_HASH = HexStr(\n \"0xdeadbeef00000000000000000000000000000000000000000000000000000000\"\n)\n# \"test offchain lookup\" as an abi-encoded string\nOFFCHAIN_LOOKUP_TEST_DATA = \"0x0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001474657374206f6666636861696e206c6f6f6b7570000000000000000000000000\" # noqa: E501\nOFFCHAIN_LOOKUP_4BYTE_DATA = \"0x556f1830\"\nOFFCHAIN_LOOKUP_RETURN_DATA = \"00000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000001c0da96d05a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002400000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000002c68747470733a2f2f776562332e70792f676174657761792f7b73656e6465727d2f7b646174617d2e6a736f6e0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002568747470733a2f2f776562332e70792f676174657761792f7b73656e6465727d2e6a736f6e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001474657374206f6666636861696e206c6f6f6b757000000000000000000000000000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001474657374206f6666636861696e206c6f6f6b7570000000000000000000000000\" # noqa: E501\n# \"web3py\" as an abi-encoded string\nWEB3PY_AS_HEXBYTES = \"0x000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000067765623370790000000000000000000000000000000000000000000000000000\" # noqa: E501\n\nRLP_ACCESS_LIST = [\n (\n \"0xde0b295669a9fd93d5f28d9ec85e40f4cb697bae\",\n (\n \"0x0000000000000000000000000000000000000000000000000000000000000003\",\n \"0x0000000000000000000000000000000000000000000000000000000000000007\",\n ),\n ),\n (\"0xbb9bc244d798123fde783fcc1c72d3bb8c189413\", ()),\n]\n\nRPC_ACCESS_LIST = [\n {\n \"address\": \"0xde0b295669a9fd93d5f28d9ec85e40f4cb697bae\",\n \"storageKeys\": (\n \"0x0000000000000000000000000000000000000000000000000000000000000003\",\n \"0x0000000000000000000000000000000000000000000000000000000000000007\",\n ),\n },\n {\"address\": \"0xbb9bc244d798123fde783fcc1c72d3bb8c189413\", \"storageKeys\": ()},\n]\n\nif TYPE_CHECKING:\n from _pytest.monkeypatch import MonkeyPatch # noqa: F401\n\n from web3.contract import Contract # noqa: F401\n from web3.main import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n\n\ndef abi_encoded_offchain_lookup_contract_address(\n w3: Union[\"Web3\", \"AsyncWeb3\"], offchain_lookup_contract: \"Contract\"\n) -> HexAddress:\n return HexAddress(\n remove_0x_prefix(\n w3.to_hex(\n abi.encode(\n [\"address\"],\n [to_bytes(hexstr=offchain_lookup_contract.address)],\n )\n )\n )\n )\n\n\nclass AsyncEthModuleTest:\n @pytest.mark.asyncio\n async def test_eth_gas_price(self, async_w3: 
\"AsyncWeb3\") -> None:\n gas_price = await async_w3.eth.gas_price\n\n assert gas_price > 0\n\n @pytest.mark.asyncio\n async def test_is_connected(self, async_w3: \"AsyncWeb3\") -> None:\n is_connected = await async_w3.is_connected()\n assert is_connected is True\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_legacy(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": await async_w3.eth.gas_price,\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"gasPrice\"] == txn_params[\"gasPrice\"]\n\n @pytest.mark.asyncio\n async def test_eth_modify_transaction_legacy(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": async_w3.to_wei(\n 1, \"gwei\"\n ), # must be greater than base_fee post London\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = await async_w3.eth.modify_transaction(\n txn_hash, gasPrice=(cast(int, txn_params[\"gasPrice\"]) * 2), value=2\n )\n modified_txn = await async_w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert modified_txn[\"gasPrice\"] == cast(int, txn_params[\"gasPrice\"]) * 2\n\n @pytest.mark.asyncio\n async def test_eth_modify_transaction(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = await async_w3.eth.modify_transaction(\n txn_hash,\n value=2,\n maxPriorityFeePerGas=(cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2),\n maxFeePerGas=(cast(Wei, txn_params[\"maxFeePerGas\"]) * 2),\n )\n modified_txn = await async_w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert (\n modified_txn[\"maxPriorityFeePerGas\"]\n == cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2\n )\n assert modified_txn[\"maxFeePerGas\"] == cast(Wei, txn_params[\"maxFeePerGas\"]) * 2\n\n @pytest.mark.asyncio\n async def test_async_eth_sign_transaction(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n 
\"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert (\n result[\"tx\"][\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n async def test_eth_sign_typed_data(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n async_skip_if_testrpc: Callable[[\"AsyncWeb3\"], None],\n ) -> None:\n validJSONMessage = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": {\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n },\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n async_skip_if_testrpc(async_w3)\n signature = HexBytes(\n await async_w3.eth.sign_typed_data(\n async_unlocked_account_dual_type, json.loads(validJSONMessage)\n )\n )\n assert len(signature) == 32 + 32 + 1\n\n @pytest.mark.asyncio\n async def test_invalid_eth_sign_typed_data(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n async_skip_if_testrpc: Callable[[\"AsyncWeb3\"], None],\n ) -> None:\n async_skip_if_testrpc(async_w3)\n invalid_typed_message = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person[2]\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": [{\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n }],\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n with pytest.raises(\n ValueError,\n 
match=r\".*Expected 2 items for array type Person\\[2\\], got 1 items.*\",\n ):\n await async_w3.eth.sign_typed_data(\n async_unlocked_account_dual_type, json.loads(invalid_typed_message)\n )\n\n @pytest.mark.asyncio\n \nasync def test_async_eth_sign_transaction_legacy(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": await async_w3.eth.gas_price,\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"gasPrice\"] == txn_params[\"gasPrice\"]\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n async def test_async_eth_sign_transaction_hex_fees(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": hex(async_w3.to_wei(2, \"gwei\")),\n \"maxPriorityFeePerGas\": hex(async_w3.to_wei(1, \"gwei\")),\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == int(str(txn_params[\"maxFeePerGas\"]), 16)\n assert result[\"tx\"][\"maxPriorityFeePerGas\"] == int(\n str(txn_params[\"maxPriorityFeePerGas\"]), 16\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n @pytest.mark.xfail(\n reason=\"async name_to_address_middleware has not been implemented yet\"\n )\n async def test_async_eth_sign_transaction_ens_names(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n with ens_addresses(async_w3, {\"unlocked-account.eth\": async_unlocked_account}):\n txn_params: TxParams = {\n \"from\": \"unlocked-account.eth\",\n \"to\": \"unlocked-account.eth\",\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == async_unlocked_account\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert (\n result[\"tx\"][\"maxPriorityFeePerGas\"]\n == txn_params[\"maxPriorityFeePerGas\"]\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": 
async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert txn[\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n assert txn[\"gasPrice\"] == txn_params[\"maxFeePerGas\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_default_fees(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert is_integer(txn[\"maxPriorityFeePerGas\"])\n assert is_integer(txn[\"maxFeePerGas\"])\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_hex_fees(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": hex(250 * 10**9),\n \"maxPriorityFeePerGas\": hex(2 * 10**9),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == 250 * 10**9\n assert txn[\"maxPriorityFeePerGas\"] == 2 * 10**9\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_no_gas(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"maxFeePerGas\": Wei(250 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 121000 # 21000 + buffer\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_with_gas_price(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": Wei(1),\n 
\"maxFeePerGas\": Wei(250 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(TransactionTypeMismatch):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_no_priority_fee(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(250 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_no_max_fee(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n maxPriorityFeePerGas = async_w3.to_wei(2, \"gwei\")\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": maxPriorityFeePerGas,\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n\n block = await async_w3.eth.get_block(\"latest\")\n assert txn[\"maxFeePerGas\"] == maxPriorityFeePerGas + 2 * block[\"baseFeePerGas\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_max_fee_less_than_tip(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxFeePerGas must be >= maxPriorityFeePerGas\"\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_validation_middleware_chain_id_mismatch(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n wrong_chain_id = 1234567890\n actual_chain_id = await async_w3.eth.chain_id\n\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"chainId\": wrong_chain_id,\n }\n with pytest.raises(\n Web3ValidationError,\n match=f\"The transaction declared chain ID {wrong_chain_id}, \"\n f\"but the connected node is on {actual_chain_id}\",\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_ExtraDataToPOAMiddleware(\n self, async_w3: \"AsyncWeb3\", request_mocker: Type[RequestMocker]\n ) -> None:\n async_w3.middleware_onion.inject(ExtraDataToPOAMiddleware, \"poa\", layer=0)\n extra_data = f\"0x{'ff' * 33}\"\n\n async with request_mocker(\n async_w3,\n mock_results={\"eth_getBlockByNumber\": {\"extraData\": extra_data}},\n ):\n block = await async_w3.eth.get_block(\"latest\")\n\n assert \"extraData\" not in block\n assert block[\"proofOfAuthorityData\"] == to_bytes(hexstr=extra_data)\n\n # clean up\n 
async_w3.middleware_onion.remove(\"poa\")\n\n @pytest.mark.asyncio\n async def test_eth_send_raw_transaction(self, async_w3: \"AsyncWeb3\") -> None:\n # private key 0x3c2ab4e8f17a7dea191b8c991522660126d681039509dc3bb31af7c9bdb63518\n # This is an unfunded account, but the transaction has a 0 gas price, so is\n # valid. It never needs to be mined, we just want the transaction hash back\n # to confirm.\n # tx = {'to': '0x0000000000000000000000000000000000000000', 'value': 0, 'nonce': 1, 'gas': 21000, 'gasPrice': 0, 'chainId': 131277322940537} # noqa: E501\n # NOTE: nonce=1 to make txn unique from the non-async version of this test\n raw_txn = HexBytes(\n \"0xf8650180825208940000000000000000000000000000000000000000808086eecac466e115a0ffdd42d7dee4ac85427468bc616812e49432e285e4e8f5cd9381163ac3b28108a04ec6b0d89ecbd5e89b0399f336ad50f283fafd70e86593250bf5a2adfb93d17e\" # noqa: E501\n )\n expected_hash = HexStr(\n \"0x52b0ff9cb472f25872fa8ec6a62fa59454fc2ae7901cfcc6cc89d096f49b8fc1\"\n )\n txn_hash = await async_w3.eth.send_raw_transaction(raw_txn)\n assert txn_hash == async_w3.to_bytes(hexstr=expected_hash)\n\n @pytest.mark.asyncio\n async def test_GasPriceStrategyMiddleware(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = async_w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return two_gwei_in_wei\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n async def test_gas_price_strategy_middleware_hex_value(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = async_w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> str:\n return hex(two_gwei_in_wei)\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy) # type: ignore\n\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"max_fee\", (1000000000, None), ids=[\"with_max_fee\", \"without_max_fee\"]\n )\n async def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n max_fee: Wei,\n ) -> None:\n max_priority_fee = async_w3.to_wei(1, \"gwei\")\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": max_priority_fee,\n }\n if max_fee is not None:\n txn_params = assoc(txn_params, \"maxFeePerGas\", max_fee)\n\n def gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return async_w3.to_wei(2, \"gwei\")\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await 
async_w3.eth.get_transaction(txn_hash)\n\n latest_block = await async_w3.eth.get_block(\"latest\")\n assert (\n txn[\"maxFeePerGas\"] == max_fee\n if max_fee is not None\n else 2 * latest_block[\"baseFeePerGas\"] + max_priority_fee\n )\n assert txn[\"maxPriorityFeePerGas\"] == max_priority_fee\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n async def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn_no_tip(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1000000000),\n }\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> Wei:\n return async_w3.to_wei(2, \"gwei\")\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n async def test_eth_estimate_gas(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n gas_estimate = await async_w3.eth.estimate_gas(\n {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n }\n )\n assert is_integer(gas_estimate)\n assert gas_estimate > 0\n\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"params\",\n (\n {\n \"nonce\": 1, # int\n \"balance\": 1, # int\n \"code\": HexStr(\"0x\"), # HexStr\n # with state\n \"state\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n {\n \"nonce\": HexStr(\"0x1\"), # HexStr\n \"balance\": HexStr(\"0x1\"), # HexStr\n \"code\": b\"\\x00\", # bytes\n # with stateDiff\n \"stateDiff\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n ),\n )\n async def test_eth_estimate_gas_with_override_param_type_check(\n self,\n async_w3: \"AsyncWeb3\",\n async_math_contract: \"Contract\",\n params: StateOverrideParams,\n ) -> None:\n txn_params: TxParams = {\"from\": await async_w3.eth.coinbase}\n\n # assert does not raise\n await async_w3.eth.estimate_gas(\n txn_params, None, {async_math_contract.address: params}\n )\n\n @pytest.mark.asyncio\n async def test_eth_fee_history(self, async_w3: \"AsyncWeb3\") -> None:\n fee_history = await async_w3.eth.fee_history(1, \"latest\", [50])\n assert is_list_like(fee_history[\"baseFeePerGas\"])\n assert is_list_like(fee_history[\"gasUsedRatio\"])\n assert is_integer(fee_history[\"oldestBlock\"])\n assert fee_history[\"oldestBlock\"] >= 0\n assert is_list_like(fee_history[\"reward\"])\n if len(fee_history[\"reward\"]) > 0:\n assert is_list_like(fee_history[\"reward\"][0])\n\n @pytest.mark.asyncio\n async def test_eth_fee_history_with_integer(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n fee_history = await async_w3.eth.fee_history(\n 1, async_empty_block[\"number\"], [50]\n )\n assert is_list_like(fee_history[\"baseFeePerGas\"])\n assert is_list_like(fee_history[\"gasUsedRatio\"])\n assert is_integer(fee_history[\"oldestBlock\"])\n assert fee_history[\"oldestBlock\"] >= 0\n assert is_list_like(fee_history[\"reward\"])\n if len(fee_history[\"reward\"]) > 0:\n assert is_list_like(fee_history[\"reward\"][0])\n\n @pytest.mark.asyncio\n async def 
test_eth_fee_history_no_reward_percentiles(\n self, async_w3: \"AsyncWeb3\"\n ) -> None:\n fee_history = await async_w3.eth.fee_history(1, \"latest\")\n assert is_list_like(fee_history[\"baseFeePerGas\"])\n assert is_list_like(fee_history[\"gasUsedRatio\"])\n assert is_integer(fee_history[\"oldestBlock\"])\n assert fee_history[\"oldestBlock\"] >= 0\n\n @pytest.mark.asyncio\n async def test_eth_max_priority_fee(self, async_w3: \"AsyncWeb3\") -> None:\n max_priority_fee = await async_w3.eth.max_priority_fee\n assert is_integer(max_priority_fee)\n\n @pytest.mark.asyncio\n async def test_eth_max_priority_fee_with_fee_history_calculation(\n self, async_w3: \"AsyncWeb3\", request_mocker: Type[RequestMocker]\n ) -> None:\n async with request_mocker(\n async_w3,\n mock_errors={RPCEndpoint(\"eth_maxPriorityFeePerGas\"): {}},\n mock_results={RPCEndpoint(\"eth_feeHistory\"): {\"reward\": [[0]]}},\n ):\n with pytest.warns(\n UserWarning,\n match=(\n \"There was an issue with the method eth_maxPriorityFeePerGas. \"\n \"Calculating using eth_feeHistory.\"\n ),\n ):\n priority_fee = await async_w3.eth.max_priority_fee\n assert is_integer(priority_fee)\n assert priority_fee == PRIORITY_FEE_MIN\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByHash(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(async_empty_block[\"hash\"])\n assert block[\"hash\"] == async_empty_block[\"hash\"]\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByHash_not_found(self, async_w3: \"AsyncWeb3\") -> None:\n with pytest.raises(BlockNotFound):\n await async_w3.eth.get_block(UNKNOWN_HASH)\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByHash_pending(self, async_w3: \"AsyncWeb3\") -> None:\n block = await async_w3.eth.get_block(\"pending\")\n assert block[\"hash\"] is None\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_with_integer(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(async_empty_block[\"number\"])\n assert block[\"number\"] == async_empty_block[\"number\"]\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_latest(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n current_block_number = await async_w3.eth.block_number\n block = await async_w3.eth.get_block(\"latest\")\n assert block[\"number\"] == current_block_number\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_not_found(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n with pytest.raises(BlockNotFound):\n await async_w3.eth.get_block(BlockNumber(12345))\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_pending(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n current_block_number = await async_w3.eth.block_number\n block = await async_w3.eth.get_block(\"pending\")\n assert block[\"number\"] == current_block_number + 1\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_earliest(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n genesis_block = await async_w3.eth.get_block(BlockNumber(0))\n block = await async_w3.eth.get_block(\"earliest\")\n assert block[\"number\"] == 0\n assert block[\"hash\"] == genesis_block[\"hash\"]\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_safe(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(\"safe\")\n assert block is not None\n 
assert isinstance(block[\"number\"], int)\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_finalized(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(\"finalized\")\n assert block is not None\n assert isinstance(block[\"number\"], int)\n\n @pytest.mark.asyncio\n async def test_eth_get_block_by_number_full_transactions(\n self, async_w3: \"AsyncWeb3\", async_block_with_txn: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(async_block_with_txn[\"number\"], True)\n transaction = cast(TxData, block[\"transactions\"][0])\n assert transaction[\"hash\"] == async_block_with_txn[\"transactions\"][0]\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction(\n self, async_w3: \"AsyncWeb3\", mined_txn_hash: HexStr\n ) -> None:\n raw_transaction = await async_w3.eth.get_raw_transaction(mined_txn_hash)\n assert is_bytes(raw_transaction)\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction_raises_error(\n self, async_w3: \"AsyncWeb3\"\n ) -> None:\n with pytest.raises(\n TransactionNotFound, match=f\"Transaction with hash: '{UNKNOWN_HASH}'\"\n ):\n await async_w3.eth.get_raw_transaction(UNKNOWN_HASH)\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction_by_block(\n self,\n async_w3: \"AsyncWeb3\",\n async_block_with_txn: BlockData,\n async_unlocked_account_dual_type: ChecksumAddress,\n ) -> None:\n # eth_getRawTransactionByBlockNumberAndIndex: block identifier\n...\n", "needle_token_start": 7175, "needle_token_end": 7463, "code_context_ntokens": 16254, "output": ["```python\n@pytest.mark.asyncio\nasync def test_async_eth_sign_transaction_legacy(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": await async_w3.eth.gas_price,\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"gasPrice\"] == txn_params[\"gasPrice\"]\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n```"]} +{"repo": "ethereum/web3.py", "name": "data_tree_map", "language": "python", "path": "web3/_utils/abi.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to apply a given transformation function to each element within a structured data tree, where each element is associated with a specific data type.\n2. **Input**: The function takes two parameters: a transformation function that operates on a data type and its associated data, and a data tree where each element is a tuple of a data type and its corresponding data.\n3. **Output**: The output is a modified data tree where each element has been transformed by the provided function, maintaining the structure of the original data tree.\n4. **Procedure**: The function iterates over the data tree, applying the transformation function to each element. If an element is a tuple containing a data type and data (recognized as a specific data type structure), the transformation function is applied to it. 
The function is applied recursively to ensure all nested elements are processed.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: web3/contract/async_contract.py\nimport copy\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Awaitable,\n Callable,\n Dict,\n Iterable,\n List,\n Optional,\n Sequence,\n Type,\n cast,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n)\nfrom eth_utils import (\n combomethod,\n)\nfrom eth_utils.toolz import (\n partial,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3._utils.abi import (\n fallback_func_abi_exists,\n filter_by_type,\n receive_func_abi_exists,\n)\nfrom web3._utils.async_transactions import (\n async_fill_transaction_defaults,\n)\nfrom web3._utils.compat import (\n Self,\n)\nfrom web3._utils.contracts import (\n async_parse_block_identifier,\n)\nfrom web3._utils.datatypes import (\n PropertyCheckingFactory,\n)\nfrom web3._utils.events import (\n AsyncEventFilterBuilder,\n get_event_data,\n)\nfrom web3._utils.filters import (\n AsyncLogFilter,\n)\nfrom web3._utils.function_identifiers import (\n FallbackFn,\n ReceiveFn,\n)\nfrom web3._utils.normalizers import (\n normalize_abi,\n normalize_address_no_ens,\n normalize_bytecode,\n)\nfrom web3.contract.base_contract import (\n BaseContract,\n BaseContractCaller,\n BaseContractConstructor,\n BaseContractEvent,\n BaseContractEvents,\n BaseContractFunction,\n BaseContractFunctions,\n NonExistentFallbackFunction,\n NonExistentReceiveFunction,\n)\nfrom web3.contract.utils import (\n async_build_transaction_for_function,\n async_call_contract_function,\n async_estimate_gas_for_function,\n...\n# Path: web3/contract/__init__.py\nfrom web3.contract.async_contract import (\n AsyncContract,\n AsyncContractCaller,\n)\nfrom web3.contract.contract import (\n Contract,\n ContractCaller,\n ContractConstructor,\n)\n\n# Path: web3/_utils/ens.py\nfrom contextlib import (\n contextmanager,\n)\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Iterator,\n Union,\n cast,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n)\nfrom eth_utils import (\n is_0x_prefixed,\n is_hex,\n is_hex_address,\n to_checksum_address,\n)\n\nfrom ens import (\n ENS,\n AsyncENS,\n)\nfrom web3.exceptions import (\n NameNotFound,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n Web3,\n )\n from web3.contract import ( # noqa: F401\n Contract,\n )\n\n\ndef is_ens_name(value: Any) -> bool:\n if not isinstance(value, str):\n return False\n elif is_hex_address(value):\n return False\n elif is_0x_prefixed(value) and is_hex(value):\n return False\n else:\n return ENS.is_valid_name(value)\n\n\ndef validate_name_has_address(ens: ENS, name: str) -> ChecksumAddress:\n addr = ens.address(name)\n if addr:\n return to_checksum_address(addr)\n else:\n raise NameNotFound(f\"Could not find address for name {name!r}\")\n\n\nclass StaticENS:\n def __init__(self, name_addr_pairs: Dict[str, ChecksumAddress]) -> None:\n self.registry = dict(name_addr_pairs)\n\n def address(self, name: str) -> ChecksumAddress:\n return self.registry.get(name, None)\n\n\n@contextmanager\ndef ens_addresses(\n w3: Union[\"Web3\", \"AsyncWeb3\"], name_addr_pairs: Dict[str, ChecksumAddress]\n) -> Iterator[None]:\n original_ens = w3.ens\n if w3.provider.is_async:\n w3.ens = cast(AsyncENS, StaticENS(name_addr_pairs))\n else:\n w3.ens 
= cast(ENS, StaticENS(name_addr_pairs))\n yield\n w3.ens = original_ens\n\n\n@contextmanager\ndef contract_ens_addresses(\n contract: \"Contract\", name_addr_pairs: Dict[str, ChecksumAddress]\n) -> Iterator[None]:\n \"\"\"\n Use this context manager to temporarily resolve name/address pairs\n supplied as the argument. For example:\n\n with contract_ens_addresses(mycontract, [('resolve-as-1s.eth', '0x111...111')]):\n # any contract call or transaction in here would only resolve the above ENS pair\n \"\"\"\n with ens_addresses(contract.w3, name_addr_pairs):\n yield\n\n\n# --- async --- #\n\n\nasync def async_validate_name_has_address(\n async_ens: AsyncENS, name: str\n) -> ChecksumAddress:\n addr = await async_ens.address(name)\n if not addr:\n raise NameNotFound(f\"Could not find address for name {name!r}\")\n return addr\n\n# Path: web3/_utils/abi.py\nimport binascii\nfrom collections import (\n abc,\n namedtuple,\n)\nimport copy\nimport itertools\nimport re\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Callable,\n Collection,\n Coroutine,\n Dict,\n Iterable,\n List,\n Mapping,\n Optional,\n Sequence,\n Tuple,\n Type,\n Union,\n cast,\n)\n\nfrom eth_abi import (\n codec,\n decoding,\n encoding,\n)\nfrom eth_abi.base import (\n parse_type_str,\n)\nfrom eth_abi.exceptions import (\n ValueOutOfBounds,\n)\nfrom eth_abi.grammar import (\n ABIType,\n BasicType,\n TupleType,\n parse,\n)\nfrom eth_abi.registry import (\n ABIRegistry,\n BaseEquals,\n registry as default_registry,\n)\nfrom eth_typing import (\n HexStr,\n TypeStr,\n)\nfrom eth_utils import (\n decode_hex,\n is_bytes,\n is_list_like,\n is_string,\n is_text,\n to_text,\n to_tuple,\n)\nfrom eth_utils.abi import (\n collapse_if_tuple,\n)\nfrom eth_utils.toolz import (\n curry,\n partial,\n pipe,\n)\n\nfrom web3._utils.decorators import (\n reject_recursive_repeats,\n)\nfrom web3._utils.ens import (\n is_ens_name,\n)\nfrom web3._utils.formatters import (\n recursive_map,\n)\nfrom web3.exceptions import (\n FallbackNotFound,\n MismatchedABI,\n)\nfrom web3.types import (\n ABI,\n ABIEvent,\n ABIEventParams,\n ABIFunction,\n ABIFunctionParams,\n TReturn,\n)\nfrom web3.utils import ( # public utils module\n get_abi_input_names,\n)\n\nif TYPE_CHECKING:\n from web3 import ( # noqa: F401\n AsyncWeb3,\n )\n\n\ndef filter_by_type(_type: str, contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]:\n return [abi for abi in contract_abi if abi[\"type\"] == _type]\n\n\ndef filter_by_name(name: str, contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]:\n return [\n abi\n for abi in contract_abi\n if (\n abi[\"type\"] not in (\"fallback\", \"constructor\", \"receive\")\n and abi[\"name\"] == name\n )\n ]\n\n\ndef get_abi_input_types(abi: ABIFunction) -> List[str]:\n if \"inputs\" not in abi and (abi[\"type\"] == \"fallback\" or abi[\"type\"] == \"receive\"):\n return []\n else:\n return [collapse_if_tuple(cast(Dict[str, Any], arg)) for arg in abi[\"inputs\"]]\n\n\ndef get_abi_output_types(abi: ABIFunction) -> List[str]:\n if abi[\"type\"] == \"fallback\":\n return []\n else:\n return [collapse_if_tuple(cast(Dict[str, Any], arg)) for arg in abi[\"outputs\"]]\n\n\ndef get_receive_func_abi(contract_abi: ABI) -> ABIFunction:\n receive_abis = filter_by_type(\"receive\", contract_abi)\n if receive_abis:\n return cast(ABIFunction, receive_abis[0])\n else:\n raise FallbackNotFound(\"No receive function was found in the contract ABI.\")\n\n\ndef get_fallback_func_abi(contract_abi: ABI) -> ABIFunction:\n fallback_abis = filter_by_type(\"fallback\", 
contract_abi)\n if fallback_abis:\n return cast(ABIFunction, fallback_abis[0])\n else:\n raise FallbackNotFound(\"No fallback function was found in the contract ABI.\")\n\n\ndef fallback_func_abi_exists(contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]:\n return filter_by_type(\"fallback\", contract_abi)\n\n\ndef receive_func_abi_exists(contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]:\n return filter_by_type(\"receive\", contract_abi)\n\n\ndef get_indexed_event_inputs(event_abi: ABIEvent) -> List[ABIEventParams]:\n return [arg for arg in event_abi[\"inputs\"] if arg[\"indexed\"] is True]\n\n\ndef exclude_indexed_event_inputs(event_abi: ABIEvent) -> List[ABIEventParams]:\n return [arg for arg in event_abi[\"inputs\"] if arg[\"indexed\"] is False]\n\n\ndef get_normalized_abi_arg_type(abi_arg: ABIEventParams) -> str:\n \"\"\"\n Return the normalized type for the abi argument provided.\n In order to account for tuple argument types, this abstraction\n makes use of `collapse_if_tuple()` to collapse the appropriate component\n types within a tuple type, if present.\n \"\"\"\n return collapse_if_tuple(dict(abi_arg))\n\n\ndef filter_by_argument_count(\n num_arguments: int, contract_abi: ABI\n) -> List[Union[ABIFunction, ABIEvent]]:\n return [abi for abi in contract_abi if len(abi[\"inputs\"]) == num_arguments]\n\n\ndef filter_by_argument_name(\n argument_names: Collection[str], contract_abi: ABI\n) -> List[Union[ABIFunction, ABIEvent]]:\n return [\n abi\n for abi in contract_abi\n if set(argument_names).intersection(get_abi_input_names(abi))\n == set(argument_names)\n ]\n\n\nclass AddressEncoder(encoding.AddressEncoder):\n @classmethod\n def validate_value(cls, value: Any) -> None:\n if is_ens_name(value):\n return\n\n super().validate_value(value)\n\n\nclass AcceptsHexStrEncoder(encoding.BaseEncoder):\n subencoder_cls: Type[encoding.BaseEncoder] = None\n is_strict: bool = None\n is_big_endian: bool = False\n data_byte_size: int = None\n value_bit_size: int = None\n\n def __init__(\n self,\n subencoder: encoding.BaseEncoder,\n **kwargs: Dict[str, Any],\n ) -> None:\n super().__init__(**kwargs)\n self.subencoder = subencoder\n self.is_dynamic = subencoder.is_dynamic\n\n @classmethod\n def from_type_str(\n cls, abi_type: TypeStr, registry: ABIRegistry\n ) -> \"AcceptsHexStrEncoder\":\n subencoder_cls = cls.get_subencoder_class()\n # cast b/c expects BaseCoder but `from_type_string`\n # restricted to BaseEncoder subclasses\n subencoder = cast(\n encoding.BaseEncoder, subencoder_cls.from_type_str(abi_type, registry)\n )\n return cls(subencoder)\n\n @classmethod\n def get_subencoder_class(cls) -> Type[encoding.BaseEncoder]:\n if cls.subencoder_cls is None:\n raise AttributeError(f\"No subencoder class is set. 
{cls.__name__}\")\n return cls.subencoder_cls\n\n def validate_value(self, value: Any) -> None:\n normalized_value = self.validate_and_normalize(value)\n self.subencoder.validate_value(normalized_value)\n\n def encode(self, value: Any) -> bytes:\n normalized_value = self.validate_and_normalize(value)\n return self.subencoder.encode(normalized_value)\n\n def validate_and_normalize(self, value: Any) -> HexStr:\n if not is_bytes(value) and not is_text(value):\n self.invalidate_value(value)\n\n raw_value = value\n if is_text(value):\n try:\n value = decode_hex(value)\n except binascii.Error:\n self.invalidate_value(\n value,\n msg=f\"{value} is an invalid hex string\",\n )\n else:\n if raw_value[:2] != \"0x\" and self.is_strict:\n self.invalidate_value(\n raw_value, msg=\"hex string must be prefixed with 0x\"\n )\n\n if self.is_strict and self.data_byte_size is not None:\n if len(value) > self.data_byte_size:\n self.invalidate_value(\n value,\n exc=ValueOutOfBounds,\n msg=f\"exceeds total byte size for bytes{self.data_byte_size} \"\n \"encoding\",\n )\n elif len(value) < self.data_byte_size:\n self.invalidate_value(\n value,\n exc=ValueOutOfBounds,\n msg=f\"less than total byte size for bytes{self.data_byte_size} \"\n \"encoding\",\n )\n\n return value\n\n\nclass BytesEncoder(AcceptsHexStrEncoder):\n subencoder_cls = encoding.BytesEncoder\n is_strict = False\n\n\nclass ExactLengthBytesEncoder(BytesEncoder):\n is_strict = True\n\n def validate(self) -> None:\n super().validate()\n if self.value_bit_size is None:\n raise ValueError(\"`value_bit_size` may not be none\")\n if self.data_byte_size is None:\n raise ValueError(\"`data_byte_size` may not be none\")\n if self.is_big_endian is None:\n raise ValueError(\"`is_big_endian` may not be none\")\n\n if self.value_bit_size % 8 != 0:\n raise ValueError(\n f\"Invalid value bit size: {self.value_bit_size}. \"\n \"Must be a multiple of 8\"\n )\n\n if self.value_bit_size > self.data_byte_size * 8:\n raise ValueError(\"Value byte size exceeds data size\")\n\n @parse_type_str(\"bytes\")\n def from_type_str(\n cls, abi_type: BasicType, registry: ABIRegistry\n ) -> \"ExactLengthBytesEncoder\":\n subencoder_cls = cls.get_subencoder_class()\n subencoder = subencoder_cls.from_type_str(abi_type.to_type_str(), registry)\n # type ignored b/c @parse_type_str decorator turns it into a classmethod,\n # so mypy thinks cls(...) 
is a call to __call__, but actually calls __init__\n return cls( # type: ignore\n subencoder,\n value_bit_size=abi_type.sub * 8,\n data_byte_size=abi_type.sub,\n )\n\n\nclass ByteStringEncoder(AcceptsHexStrEncoder):\n subencoder_cls = encoding.ByteStringEncoder\n is_strict = False\n\n\nclass StrictByteStringEncoder(AcceptsHexStrEncoder):\n subencoder_cls = encoding.ByteStringEncoder\n is_strict = True\n\n\nclass TextStringEncoder(encoding.TextStringEncoder):\n @classmethod\n def validate_value(cls, value: Any) -> None:\n if is_bytes(value):\n try:\n value = to_text(value)\n except UnicodeDecodeError:\n cls.invalidate_value(\n value,\n msg=\"not decodable as unicode string\",\n )\n\n super().validate_value(value)\n\n\ndef filter_by_encodability(\n abi_codec: codec.ABIEncoder,\n args: Sequence[Any],\n kwargs: Dict[str, Any],\n contract_abi: ABI,\n) -> List[ABIFunction]:\n return [\n cast(ABIFunction, function_abi)\n for function_abi in contract_abi\n if check_if_arguments_can_be_encoded(\n cast(ABIFunction, function_abi), abi_codec, args, kwargs\n )\n ]\n\n\ndef check_if_arguments_can_be_encoded(\n function_abi: ABIFunction,\n abi_codec: codec.ABIEncoder,\n args: Sequence[Any],\n kwargs: Dict[str, Any],\n) -> bool:\n try:\n arguments = merge_args_and_kwargs(function_abi, args, kwargs)\n except TypeError:\n return False\n\n if len(function_abi.get(\"inputs\", [])) != len(arguments):\n return False\n\n try:\n types, aligned_args = get_aligned_abi_inputs(function_abi, arguments)\n except TypeError:\n return False\n\n return all(\n abi_codec.is_encodable(_type, arg) for _type, arg in zip(types, aligned_args)\n )\n\n\ndef merge_args_and_kwargs(\n function_abi: ABIFunction, args: Sequence[Any], kwargs: Dict[str, Any]\n) -> Tuple[Any, ...]:\n \"\"\"\n Takes a list of positional args (``args``) and a dict of keyword args\n (``kwargs``) defining values to be passed to a call to the contract function\n described by ``function_abi``. Checks to ensure that the correct number of\n args were given, no duplicate args were given, and no unknown args were\n given. Returns a list of argument values aligned to the order of inputs\n defined in ``function_abi``.\n \"\"\"\n # Ensure the function is being applied to the correct number of args\n if len(args) + len(kwargs) != len(function_abi.get(\"inputs\", [])):\n raise TypeError(\n f\"Incorrect argument count. Expected '{len(function_abi['inputs'])}'\"\n f\". 
Got '{len(args) + len(kwargs)}'\"\n )\n\n # If no keyword args were given, we don't need to align them\n if not kwargs:\n return cast(Tuple[Any, ...], args)\n\n kwarg_names = set(kwargs.keys())\n sorted_arg_names = tuple(arg_abi[\"name\"] for arg_abi in function_abi[\"inputs\"])\n args_as_kwargs = dict(zip(sorted_arg_names, args))\n\n # Check for duplicate args\n duplicate_args = kwarg_names.intersection(args_as_kwargs.keys())\n if duplicate_args:\n raise TypeError(\n f\"{function_abi.get('name')}() got multiple values for argument(s) \"\n f\"'{', '.join(duplicate_args)}'\"\n )\n\n # Check for unknown args\n unknown_args = kwarg_names.difference(sorted_arg_names)\n if unknown_args:\n if function_abi.get(\"name\"):\n raise TypeError(\n f\"{function_abi.get('name')}() got unexpected keyword argument(s)\"\n f\" '{', '.join(unknown_args)}'\"\n )\n raise TypeError(\n f\"Type: '{function_abi.get('type')}' got unexpected keyword argument(s)\"\n f\" '{', '.join(unknown_args)}'\"\n )\n\n # Sort args according to their position in the ABI and unzip them from their\n # names\n sorted_args = tuple(\n zip(\n *sorted(\n itertools.chain(kwargs.items(), args_as_kwargs.items()),\n key=lambda kv: sorted_arg_names.index(kv[0]),\n )\n )\n )\n\n if sorted_args:\n return sorted_args[1]\n else:\n return tuple()\n\n\nTUPLE_TYPE_STR_RE = re.compile(r\"^(tuple)((\\[([1-9]\\d*\\b)?])*)??$\")\n\n\ndef get_tuple_type_str_parts(s: str) -> Optional[Tuple[str, Optional[str]]]:\n \"\"\"\n Takes a JSON ABI type string. For tuple type strings, returns the separated\n prefix and array dimension parts. For all other strings, returns ``None``.\n \"\"\"\n match = TUPLE_TYPE_STR_RE.match(s)\n\n if match is not None:\n tuple_prefix = match.group(1)\n tuple_dims = match.group(2)\n\n return tuple_prefix, tuple_dims\n\n return None\n\n\ndef _align_abi_input(arg_abi: ABIFunctionParams, arg: Any) -> Tuple[Any, ...]:\n \"\"\"\n Aligns the values of any mapping at any level of nesting in ``arg``\n according to the layout of the corresponding abi spec.\n \"\"\"\n tuple_parts = get_tuple_type_str_parts(arg_abi[\"type\"])\n\n if tuple_parts is None:\n # Arg is non-tuple. Just return value.\n return arg\n\n tuple_prefix, tuple_dims = tuple_parts\n if tuple_dims is None:\n # Arg is non-list tuple. Each sub arg in `arg` will be aligned\n # according to its corresponding abi.\n sub_abis = arg_abi[\"components\"]\n else:\n num_dims = tuple_dims.count(\"[\")\n\n # Arg is list tuple. A non-list version of its abi will be used to\n # align each element in `arg`.\n new_abi = copy.copy(arg_abi)\n new_abi[\"type\"] = tuple_prefix + \"[]\" * (num_dims - 1)\n\n sub_abis = itertools.repeat(new_abi) # type: ignore\n\n if isinstance(arg, abc.Mapping):\n # Arg is mapping. 
Align values according to abi order.\n aligned_arg = tuple(arg[abi[\"name\"]] for abi in sub_abis)\n else:\n aligned_arg = arg\n\n if not is_list_like(aligned_arg):\n raise TypeError(\n f'Expected non-string sequence for \"{arg_abi.get(\"type\")}\" '\n f\"component type: got {aligned_arg}\"\n )\n\n # convert NamedTuple to regular tuple\n typing = tuple if isinstance(aligned_arg, tuple) else type(aligned_arg)\n\n return typing(\n _align_abi_input(sub_abi, sub_arg)\n for sub_abi, sub_arg in zip(sub_abis, aligned_arg)\n )\n\n\ndef get_aligned_abi_inputs(\n abi: ABIFunction, args: Union[Tuple[Any, ...], Mapping[Any, Any]]\n) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:\n \"\"\"\n Takes a function ABI (``abi``) and a sequence or mapping of args (``args``).\n Returns a list of type strings for the function's inputs and a list of\n arguments which have been aligned to the layout of those types. The args\n contained in ``args`` may contain nested mappings or sequences corresponding\n to tuple-encoded values in ``abi``.\n \"\"\"\n input_abis = abi.get(\"inputs\", [])\n\n if isinstance(args, abc.Mapping):\n # `args` is mapping. Align values according to abi order.\n args = tuple(args[abi[\"name\"]] for abi in input_abis)\n\n return (\n # typed dict cannot be used w/ a normal Dict\n # https://github.com/python/mypy/issues/4976\n tuple(collapse_if_tuple(abi) for abi in input_abis), # type: ignore\n type(args)(_align_abi_input(abi, arg) for abi, arg in zip(input_abis, args)),\n )\n\n\ndef get_constructor_abi(contract_abi: ABI) -> ABIFunction:\n candidates = [abi for abi in contract_abi if abi[\"type\"] == \"constructor\"]\n if len(candidates) == 1:\n return candidates[0]\n elif len(candidates) == 0:\n return None\n elif len(candidates) > 1:\n raise ValueError(\"Found multiple constructors.\")\n return None\n\n\nDYNAMIC_TYPES = [\"bytes\", \"string\"]\n\nINT_SIZES = range(8, 257, 8)\nBYTES_SIZES = range(1, 33)\nUINT_TYPES = [f\"uint{i}\" for i in INT_SIZES]\nINT_TYPES = [f\"int{i}\" for i in INT_SIZES]\nBYTES_TYPES = [f\"bytes{i}\" for i in BYTES_SIZES] + [\"bytes32.byte\"]\n\nSTATIC_TYPES = list(\n itertools.chain(\n [\"address\", \"bool\"],\n UINT_TYPES,\n INT_TYPES,\n BYTES_TYPES,\n )\n)\n\nBASE_TYPE_REGEX = \"|\".join(\n (_type + \"(?![a-z0-9])\" for _type in itertools.chain(STATIC_TYPES, DYNAMIC_TYPES))\n)\n\nSUB_TYPE_REGEX = r\"\\[\" \"[0-9]*\" r\"\\]\"\n\nTYPE_REGEX = (\"^\" \"(?:{base_type})\" \"(?:(?:{sub_type})*)?\" \"$\").format(\n base_type=BASE_TYPE_REGEX,\n sub_type=SUB_TYPE_REGEX,\n)\n\n\ndef is_recognized_type(abi_type: TypeStr) -> bool:\n return bool(re.match(TYPE_REGEX, abi_type))\n\n\ndef is_bool_type(abi_type: TypeStr) -> bool:\n return abi_type == \"bool\"\n\n\ndef is_uint_type(abi_type: TypeStr) -> bool:\n return abi_type in UINT_TYPES\n\n\ndef is_int_type(abi_type: TypeStr) -> bool:\n return abi_type in INT_TYPES\n\n\ndef is_address_type(abi_type: TypeStr) -> bool:\n return abi_type == \"address\"\n\n\ndef is_bytes_type(abi_type: TypeStr) -> bool:\n return abi_type in BYTES_TYPES + [\"bytes\"]\n\n\ndef is_string_type(abi_type: TypeStr) -> bool:\n return abi_type == \"string\"\n\n\n@curry\ndef is_length(target_length: int, value: abc.Sized) -> bool:\n return len(value) == target_length\n\n\ndef size_of_type(abi_type: TypeStr) -> int:\n \"\"\"\n Returns size in bits of abi_type\n \"\"\"\n if \"string\" in abi_type:\n return None\n if \"byte\" in abi_type:\n return None\n if \"[\" in abi_type:\n return None\n if abi_type == \"bool\":\n return 8\n if abi_type == \"address\":\n return 
160\n return int(re.sub(r\"\\D\", \"\", abi_type))\n\n\nEND_BRACKETS_OF_ARRAY_TYPE_REGEX = r\"\\[[^]]*\\]$\"\n\n\ndef sub_type_of_array_type(abi_type: TypeStr) -> str:\n if not is_array_type(abi_type):\n raise ValueError(f\"Cannot parse subtype of nonarray abi-type: {abi_type}\")\n\n return re.sub(END_BRACKETS_OF_ARRAY_TYPE_REGEX, \"\", abi_type, 1)\n\n\ndef length_of_array_type(abi_type: TypeStr) -> int:\n if not is_array_type(abi_type):\n raise ValueError(f\"Cannot parse length of nonarray abi-type: {abi_type}\")\n\n inner_brackets = (\n re.search(END_BRACKETS_OF_ARRAY_TYPE_REGEX, abi_type).group(0).strip(\"[]\")\n )\n if not inner_brackets:\n return None\n else:\n return int(inner_brackets)\n\n\nARRAY_REGEX = (\"^\" \"[a-zA-Z0-9_]+\" \"({sub_type})+\" \"$\").format(sub_type=SUB_TYPE_REGEX)\n\n\ndef is_array_type(abi_type: TypeStr) -> bool:\n return bool(re.match(ARRAY_REGEX, abi_type))\n\n\nNAME_REGEX = \"[a-zA-Z_]\" \"[a-zA-Z0-9_]*\"\n\n\nENUM_REGEX = (\"^\" \"{lib_name}\" r\"\\.\" \"{enum_name}\" \"$\").format(\n lib_name=NAME_REGEX, enum_name=NAME_REGEX\n)\n\n\ndef is_probably_enum(abi_type: TypeStr) -> bool:\n return bool(re.match(ENUM_REGEX, abi_type))\n\n\n@to_tuple\ndef normalize_event_input_types(\n abi_args: Collection[Union[ABIFunction, ABIEvent]]\n) -> Iterable[Union[ABIFunction, ABIEvent, Dict[TypeStr, Any]]]:\n for arg in abi_args:\n if is_recognized_type(arg[\"type\"]):\n yield arg\n elif is_probably_enum(arg[\"type\"]):\n yield {k: \"uint8\" if k == \"type\" else v for k, v in arg.items()}\n else:\n yield arg\n\n\ndef abi_to_signature(abi: Union[ABIFunction, ABIEvent]) -> str:\n function_signature = \"{fn_name}({fn_input_types})\".format(\n fn_name=abi[\"name\"],\n fn_input_types=\",\".join(\n collapse_if_tuple(dict(arg))\n for arg in normalize_event_input_types(abi.get(\"inputs\", []))\n ),\n )\n return function_signature\n\n\n########################################################\n#\n# Conditionally modifying data, tagged with ABI Types\n#\n########################################################\n\n\n@curry\ndef map_abi_data(\n normalizers: Sequence[Callable[[TypeStr, Any], Tuple[TypeStr, Any]]],\n types: Sequence[TypeStr],\n data: Sequence[Any],\n) -> Any:\n \"\"\"\n This function will apply normalizers to your data, in the\n context of the relevant types. Each normalizer is in the format:\n\n def normalizer(datatype, data):\n # Conditionally modify data\n return (datatype, data)\n\n Where datatype is a valid ABI type string, like \"uint\".\n\n In case of an array, like \"bool[2]\", normalizer will receive `data`\n as an iterable of typed data, like `[(\"bool\", True), (\"bool\", False)]`.\n\n Internals\n ---\n\n This is accomplished by:\n\n 1. Decorating the data tree with types\n 2. Recursively mapping each of the normalizers to the data\n 3. Stripping the types back out of the tree\n \"\"\"\n pipeline = itertools.chain(\n [abi_data_tree(types)],\n map(data_tree_map, normalizers),\n [partial(recursive_map, strip_abi_type)],\n )\n\n return pipe(data, *pipeline)\n\n\n@curry\ndef abi_data_tree(types: Sequence[TypeStr], data: Sequence[Any]) -> List[Any]:\n \"\"\"\n Decorate the data tree with pairs of (type, data). 
The pair tuple is actually an\n ABITypedData, but can be accessed as a tuple.\n\n As an example:\n\n >>> abi_data_tree(types=[\"bool[2]\", \"uint\"], data=[[True, False], 0])\n [(\"bool[2]\", [(\"bool\", True), (\"bool\", False)]), (\"uint256\", 0)]\n \"\"\"\n return [\n abi_sub_tree(data_type, data_value)\n for data_type, data_value in zip(types, data)\n ]\n\n\n@curry\n\ndef data_tree_map(\n func: Callable[[TypeStr, Any], Tuple[TypeStr, Any]], data_tree: Any\n) -> \"ABITypedData\":\n \"\"\"\n Map func to every ABITypedData element in the tree. func will\n receive two args: abi_type, and data\n \"\"\"\n\n def map_to_typed_data(elements: Any) -> \"ABITypedData\":\n if isinstance(elements, ABITypedData) and elements.abi_type is not None:\n return ABITypedData(func(*elements))\n else:\n return elements\n\n return recursive_map(map_to_typed_data, data_tree)\n\n\nclass ABITypedData(namedtuple(\"ABITypedData\", \"abi_type, data\")):\n \"\"\"\n This class marks data as having a certain ABI-type.\n\n >>> a1 = ABITypedData(['address', addr1])\n >>> a2 = ABITypedData(['address', addr2])\n >>> addrs = ABITypedData(['address[]', [a1, a2]])\n\n You can access the fields using tuple() interface, or with\n attributes:\n\n >>> assert a1.abi_type == a1[0]\n >>> assert a1.data == a1[1]\n\n Unlike a typical `namedtuple`, you initialize with a single\n positional argument that is iterable, to match the init\n interface of all other relevant collections.\n \"\"\"\n\n def __new__(cls, iterable: Iterable[Any]) -> \"ABITypedData\":\n return super().__new__(cls, *iterable)\n\n\ndef abi_sub_tree(\n type_str_or_abi_type: Optional[Union[TypeStr, ABIType]], data_value: Any\n) -> ABITypedData:\n if type_str_or_abi_type is None:\n return ABITypedData([None, data_value])\n\n if isinstance(type_str_or_abi_type, TypeStr):\n abi_type = parse(type_str_or_abi_type)\n else:\n abi_type = type_str_or_abi_type\n\n # In the two special cases below, we rebuild the given data structures with\n # annotated items\n if abi_type.is_array:\n # If type is array, determine item type and annotate all\n # items in iterable with that type\n item_type_str = abi_type.item_type.to_type_str()\n value_to_annotate = [\n abi_sub_tree(item_type_str, item_value) for item_value in data_value\n ]\n elif isinstance(abi_type, TupleType):\n # Otherwise, if type is tuple, determine component types and annotate\n # tuple components in iterable respectively with those types\n value_to_annotate = type(data_value)(\n abi_sub_tree(comp_type.to_type_str(), comp_value)\n for comp_type, comp_value in zip(abi_type.components, data_value)\n )\n else:\n value_to_annotate = data_value\n\n return ABITypedData(\n [\n abi_type.to_type_str(),\n value_to_annotate,\n ]\n )\n\n\ndef strip_abi_type(elements: Any) -> Any:\n if isinstance(elements, ABITypedData):\n return elements.data\n else:\n return elements\n\n\ndef build_non_strict_registry() -> ABIRegistry:\n # We make a copy here just to make sure that eth-abi's default registry is not\n # affected by our custom encoder subclasses\n registry = default_registry.copy()\n\n registry.unregister(\"address\")\n registry.unregister(\"bytes\")\n registry.unregister(\"bytes\")\n registry.unregister(\"string\")\n\n registry.register(\n BaseEquals(\"address\"),\n AddressEncoder,\n decoding.AddressDecoder,\n label=\"address\",\n )\n registry.register(\n BaseEquals(\"bytes\", with_sub=True),\n BytesEncoder,\n decoding.BytesDecoder,\n label=\"bytes\",\n )\n registry.register(\n BaseEquals(\"bytes\", with_sub=False),\n 
ByteStringEncoder,\n decoding.ByteStringDecoder,\n label=\"bytes\",\n )\n registry.register(\n BaseEquals(\"string\"),\n TextStringEncoder,\n decoding.StringDecoder,\n label=\"string\",\n )\n return registry\n\n\ndef build_strict_registry() -> ABIRegistry:\n registry = default_registry.copy()\n\n registry.unregister(\"address\")\n registry.unregister(\"bytes\")\n registry.unregister(\"bytes\")\n registry.unregister(\"string\")\n\n registry.register(\n BaseEquals(\"address\"),\n AddressEncoder,\n decoding.AddressDecoder,\n label=\"address\",\n )\n registry.register(\n BaseEquals(\"bytes\", with_sub=True),\n ExactLengthBytesEncoder,\n decoding.BytesDecoder,\n label=\"bytes\",\n )\n registry.register(\n BaseEquals(\"bytes\", with_sub=False),\n StrictByteStringEncoder,\n decoding.ByteStringDecoder,\n label=\"bytes\",\n )\n registry.register(\n BaseEquals(\"string\"),\n encoding.TextStringEncoder,\n decoding.StringDecoder,\n label=\"string\",\n )\n return registry\n\n\ndef named_tree(\n abi: Iterable[Union[ABIFunctionParams, ABIFunction, ABIEvent, Dict[TypeStr, Any]]],\n data: Iterable[Tuple[Any, ...]],\n) -> Dict[str, Any]:\n \"\"\"\n Convert function inputs/outputs or event data tuple to dict with names from ABI.\n \"\"\"\n names = [item[\"name\"] for item in abi]\n items = [_named_subtree(*item) for item in zip(abi, data)]\n\n return dict(zip(names, items))\n\n\ndef _named_subtree(\n abi: Union[ABIFunctionParams, ABIFunction, ABIEvent, Dict[TypeStr, Any]],\n data: Tuple[Any, ...],\n) -> Union[Dict[str, Any], Tuple[Any, ...], List[Any]]:\n abi_type = parse(collapse_if_tuple(dict(abi)))\n\n if abi_type.is_array:\n item_type = abi_type.item_type.to_type_str()\n item_abi = {**abi, \"type\": item_type, \"name\": \"\"}\n items = [_named_subtree(item_abi, item) for item in data]\n return items\n\n elif isinstance(abi_type, TupleType):\n abi = cast(ABIFunctionParams, abi)\n names = [item[\"name\"] for item in abi[\"components\"]]\n items = [_named_subtree(*item) for item in zip(abi[\"components\"], data)]\n\n if len(names) == len(data):\n return dict(zip(names, items))\n else:\n raise MismatchedABI(\n f\"ABI fields {names} has length {len(names)} but received \"\n f\"data {data} with length {len(data)}\"\n )\n\n return data\n\n\ndef recursive_dict_to_namedtuple(data: Dict[str, Any]) -> Tuple[Any, ...]:\n def _dict_to_namedtuple(\n value: Union[Dict[str, Any], List[Any]]\n ) -> Union[Tuple[Any, ...], List[Any]]:\n if not isinstance(value, dict):\n return value\n\n keys, values = zip(*value.items()) if value else ((), ())\n return abi_decoded_namedtuple_factory(keys)(values)\n\n return recursive_map(_dict_to_namedtuple, data)\n\n\ndef abi_decoded_namedtuple_factory(\n fields: Tuple[Any, ...]\n) -> Callable[..., Tuple[Any, ...]]:\n class ABIDecodedNamedTuple(namedtuple(\"ABIDecodedNamedTuple\", fields, rename=True)): # type: ignore # noqa: E501\n def __new__(self, args: Any) -> \"ABIDecodedNamedTuple\":\n return super().__new__(self, *args)\n\n return ABIDecodedNamedTuple\n\n\n# -- async -- #\n\n\nasync def async_data_tree_map(\n async_w3: \"AsyncWeb3\",\n func: Callable[\n [\"AsyncWeb3\", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]]\n ],\n data_tree: Any,\n) -> \"ABITypedData\":\n \"\"\"\n Map an awaitable method to every ABITypedData element in the tree.\n\n The awaitable method should receive three positional args:\n async_w3, abi_type, and data\n \"\"\"\n\n async def async_map_to_typed_data(elements: Any) -> \"ABITypedData\":\n if isinstance(elements, ABITypedData) and elements.abi_type 
is not None:\n formatted = await func(async_w3, *elements)\n return ABITypedData(formatted)\n else:\n return elements\n\n return await async_recursive_map(async_w3, async_map_to_typed_data, data_tree)\n\n\n@reject_recursive_repeats\nasync def async_recursive_map(\n async_w3: \"AsyncWeb3\",\n func: Callable[[Any], Coroutine[Any, Any, TReturn]],\n data: Any,\n) -> TReturn:\n \"\"\"\n Apply an awaitable method to data and any collection items inside data\n (using async_map_collection).\n\n Define the awaitable method so that it only applies to the type of value that you\n want it to apply to.\n \"\"\"\n\n async def async_recurse(item: Any) -> TReturn:\n return await async_recursive_map(async_w3, func, item)\n\n items_mapped = await async_map_if_collection(async_recurse, data)\n return await func(items_mapped)\n\n\nasync def async_map_if_collection(\n func: Callable[[Any], Coroutine[Any, Any, Any]], value: Any\n) -> Any:\n \"\"\"\n Apply an awaitable method to each element of a collection or value of a dictionary.\n If the value is not a collection, return it unmodified.\n \"\"\"\n\n datatype = type(value)\n if isinstance(value, Mapping):\n return datatype({key: await func(val) for key, val in value.values()})\n if is_string(value):\n return value\n elif isinstance(value, Iterable):\n return datatype([await func(item) for item in value])\n else:\n return value\n\n# Path: web3/_utils/module.py\nimport inspect\nfrom io import (\n UnsupportedOperation,\n)\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n List,\n Optional,\n Sequence,\n Union,\n)\n\nfrom web3.exceptions import (\n Web3ValidationError,\n)\nfrom web3.module import (\n Module,\n)\n\nif TYPE_CHECKING:\n from web3.main import BaseWeb3 # noqa: F401\n\n\ndef _validate_init_params_and_return_if_found(module_class: Any) -> List[str]:\n init_params_raw = list(inspect.signature(module_class.__init__).parameters)\n module_init_params = [\n param for param in init_params_raw if param not in [\"self\", \"args\", \"kwargs\"]\n ]\n\n if len(module_init_params) > 1:\n raise UnsupportedOperation(\n \"A module class may accept a single `Web3` instance as the first \"\n \"argument of its __init__() method. More than one argument found for \"\n f\"{module_class.__name__}: {module_init_params}\"\n )\n\n return module_init_params\n\n\ndef attach_modules(\n parent_module: Union[\"BaseWeb3\", \"Module\"],\n module_definitions: Dict[str, Any],\n w3: Optional[Union[\"BaseWeb3\", \"Module\"]] = None,\n) -> None:\n for module_name, module_info in module_definitions.items():\n module_info_is_list_like = isinstance(module_info, Sequence)\n\n module_class = module_info[0] if module_info_is_list_like else module_info\n\n if hasattr(parent_module, module_name):\n raise AttributeError(\n f\"Cannot set {parent_module} module named '{module_name}'. \"\n \" The web3 object already has an attribute with that name\"\n )\n\n # The parent module is the ``Web3`` instance on first run of the loop and w3 is\n # None. Thus, set w3 to the parent_module. The import needs to happen locally\n # due to circular import issues.\n if w3 is None:\n from web3 import (\n AsyncWeb3,\n Web3,\n )\n\n if isinstance(parent_module, Web3) or isinstance(parent_module, AsyncWeb3):\n w3 = parent_module\n\n module_init_params = _validate_init_params_and_return_if_found(module_class)\n if len(module_init_params) == 1:\n # Modules that need access to the ``Web3`` instance may accept the\n # instance as the first arg in their ``__init__()`` method. 
This is the\n # case for any module that inherits from ``web3.module.Module``.\n # e.g. def __init__(self, w3):\n setattr(parent_module, module_name, module_class(w3))\n else:\n # Modules need not take in a ``Web3`` instance in\n # their ``__init__()`` if not needed\n setattr(parent_module, module_name, module_class())\n\n if module_info_is_list_like:\n if len(module_info) == 2:\n submodule_definitions = module_info[1]\n module = getattr(parent_module, module_name)\n attach_modules(module, submodule_definitions, w3)\n elif len(module_info) != 1:\n raise Web3ValidationError(\n \"Module definitions can only have 1 or 2 elements.\"\n )\n\n# Path: web3/geth.py\nfrom typing import (\n Any,\n Awaitable,\n Callable,\n Dict,\n List,\n Optional,\n Tuple,\n)\n\nfrom eth_typing.encoding import (\n HexStr,\n)\nfrom eth_typing.evm import (\n ChecksumAddress,\n)\nfrom hexbytes.main import (\n HexBytes,\n)\n\nfrom web3._utils.compat import (\n Protocol,\n)\nfrom web3._utils.rpc_abi import (\n RPC,\n)\nfrom web3.method import (\n Method,\n default_root_munger,\n)\nfrom web3.module import (\n Module,\n)\nfrom web3.types import (\n EnodeURI,\n GethWallet,\n NodeInfo,\n Peer,\n TxParams,\n TxPoolContent,\n TxPoolInspect,\n TxPoolStatus,\n)\n\n\nclass UnlockAccountWrapper(Protocol):\n def __call__(\n self,\n account: ChecksumAddress,\n passphrase: str,\n duration: Optional[int] = None,\n ) -> bool:\n pass\n\n\nclass GethPersonal(Module):\n \"\"\"\n https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-personal\n \"\"\"\n\n is_async = False\n\n ec_recover: Method[Callable[[str, HexStr], ChecksumAddress]] = Method(\n RPC.personal_ecRecover,\n mungers=[default_root_munger],\n )\n\n import_raw_key: Method[Callable[[str, str], ChecksumAddress]] = Method(\n RPC.personal_importRawKey,\n mungers=[default_root_munger],\n )\n\n list_accounts: Method[Callable[[], List[ChecksumAddress]]] = Method(\n RPC.personal_listAccounts,\n is_property=True,\n )\n\n list_wallets: Method[Callable[[], List[GethWallet]]] = Method(\n RPC.personal_listWallets,\n is_property=True,\n )\n\n send_transaction: Method[Callable[[TxParams, str], HexBytes]] = Method(\n RPC.personal_sendTransaction,\n mungers=[default_root_munger],\n )\n\n sign: Method[Callable[[str, ChecksumAddress, Optional[str]], HexStr]] = Method(\n RPC.personal_sign,\n mungers=[default_root_munger],\n )\n\n sign_typed_data: Method[\n Callable[[Dict[str, Any], ChecksumAddress, str], HexStr]\n ] = Method(\n RPC.personal_signTypedData,\n mungers=[default_root_munger],\n )\n\n new_account: Method[Callable[[str], ChecksumAddress]] = Method(\n RPC.personal_newAccount,\n mungers=[default_root_munger],\n )\n\n lock_account: Method[Callable[[ChecksumAddress], bool]] = Method(\n RPC.personal_lockAccount,\n mungers=[default_root_munger],\n )\n\n unlock_account: Method[UnlockAccountWrapper] = Method(\n RPC.personal_unlockAccount,\n mungers=[default_root_munger],\n )\n\n\nclass GethTxPool(Module):\n \"\"\"\n https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-txpool\n \"\"\"\n\n is_async = False\n\n content: Method[Callable[[], TxPoolContent]] = Method(\n RPC.txpool_content,\n is_property=True,\n )\n\n inspect: Method[Callable[[], TxPoolInspect]] = Method(\n RPC.txpool_inspect,\n is_property=True,\n )\n\n status: Method[Callable[[], TxPoolStatus]] = Method(\n RPC.txpool_status,\n is_property=True,\n )\n\n\nclass ServerConnection(Protocol):\n def __call__(\n self,\n host: str = \"localhost\",\n port: int = 8546,\n cors: str = \"\",\n apis: str = \"eth,net,web3\",\n ) -> 
bool:\n pass\n\n\ndef admin_start_params_munger(\n _module: Module,\n host: str = \"localhost\",\n port: int = 8546,\n cors: str = \"\",\n apis: str = \"eth,net,web3\",\n) -> Tuple[str, int, str, str]:\n return (host, port, cors, apis)\n\n\nclass GethAdmin(Module):\n \"\"\"\n https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-admin\n \"\"\"\n\n is_async = False\n\n add_peer: Method[Callable[[EnodeURI], bool]] = Method(\n RPC.admin_addPeer,\n mungers=[default_root_munger],\n )\n\n datadir: Method[Callable[[], str]] = Method(\n RPC.admin_datadir,\n is_property=True,\n )\n\n node_info: Method[Callable[[], NodeInfo]] = Method(\n RPC.admin_nodeInfo,\n is_property=True,\n )\n\n peers: Method[Callable[[], List[Peer]]] = Method(\n RPC.admin_peers,\n is_property=True,\n )\n\n start_http: Method[ServerConnection] = Method(\n RPC.admin_startHTTP,\n mungers=[admin_start_params_munger],\n )\n\n start_ws: Method[ServerConnection] = Method(\n RPC.admin_startWS,\n mungers=[admin_start_params_munger],\n )\n\n stop_http: Method[Callable[[], bool]] = Method(\n RPC.admin_stopHTTP,\n is_property=True,\n )\n\n stop_ws: Method[Callable[[], bool]] = Method(\n RPC.admin_stopWS,\n is_property=True,\n )\n\n\nclass Geth(Module):\n personal: GethPersonal\n admin: GethAdmin\n txpool: GethTxPool\n\n\n# --- async --- #\n\n\nclass AsyncGethTxPool(Module):\n \"\"\"\n https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-txpool\n \"\"\"\n\n is_async = True\n\n _content: Method[Callable[[], Awaitable[TxPoolContent]]] = Method(\n RPC.txpool_content,\n is_property=True,\n )\n\n async def content(self) -> TxPoolContent:\n return await self._content()\n\n _inspect: Method[Callable[[], Awaitable[TxPoolInspect]]] = Method(\n RPC.txpool_inspect,\n is_property=True,\n )\n\n async def inspect(self) -> TxPoolInspect:\n return await self._inspect()\n\n _status: Method[Callable[[], Awaitable[TxPoolStatus]]] = Method(\n RPC.txpool_status,\n is_property=True,\n )\n\n async def status(self) -> TxPoolStatus:\n return await self._status()\n\n\nclass AsyncGethAdmin(Module):\n \"\"\"\n https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-admin\n \"\"\"\n\n is_async = True\n\n _add_peer: Method[Callable[[EnodeURI], Awaitable[bool]]] = Method(\n RPC.admin_addPeer,\n mungers=[default_root_munger],\n )\n\n async def add_peer(self, node_url: EnodeURI) -> bool:\n return await self._add_peer(node_url)\n\n _datadir: Method[Callable[[], Awaitable[str]]] = Method(\n RPC.admin_datadir,\n is_property=True,\n )\n\n async def datadir(self) -> str:\n return await self._datadir()\n\n _node_info: Method[Callable[[], Awaitable[NodeInfo]]] = Method(\n RPC.admin_nodeInfo,\n is_property=True,\n )\n\n async def node_info(self) -> NodeInfo:\n return await self._node_info()\n\n _peers: Method[Callable[[], Awaitable[List[Peer]]]] = Method(\n RPC.admin_peers,\n is_property=True,\n )\n\n async def peers(self) -> List[Peer]:\n return await self._peers()\n\n # start_http and stop_http\n\n _start_http: Method[Callable[[str, int, str, str], Awaitable[bool]]] = Method(\n RPC.admin_startHTTP,\n mungers=[admin_start_params_munger],\n )\n\n _stop_http: Method[Callable[[], Awaitable[bool]]] = Method(\n RPC.admin_stopHTTP,\n is_property=True,\n )\n\n async def start_http(\n self,\n host: str = \"localhost\",\n port: int = 8546,\n cors: str = \"\",\n apis: str = \"eth,net,web3\",\n ) -> bool:\n return await self._start_http(host, port, cors, apis)\n\n async def stop_http(self) -> bool:\n return await self._stop_http()\n\n # start_ws and stop_ws\n\n _start_ws: 
Method[Callable[[str, int, str, str], Awaitable[bool]]] = Method(\n RPC.admin_startWS,\n mungers=[admin_start_params_munger],\n )\n\n _stop_ws: Method[Callable[[], Awaitable[bool]]] = Method(\n RPC.admin_stopWS,\n is_property=True,\n )\n\n async def start_ws(\n self,\n host: str = \"localhost\",\n port: int = 8546,\n cors: str = \"\",\n apis: str = \"eth,net,web3\",\n ) -> bool:\n return await self._start_ws(host, port, cors, apis)\n\n async def stop_ws(self) -> bool:\n return await self._stop_ws()\n\n\nclass AsyncGethPersonal(Module):\n \"\"\"\n https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-personal\n \"\"\"\n\n is_async = True\n\n # ec_recover\n\n _ec_recover: Method[Callable[[str, HexStr], Awaitable[ChecksumAddress]]] = Method(\n RPC.personal_ecRecover,\n mungers=[default_root_munger],\n )\n\n async def ec_recover(self, message: str, signature: HexStr) -> ChecksumAddress:\n return await self._ec_recover(message, signature)\n\n # import_raw_key\n\n _import_raw_key: Method[Callable[[str, str], Awaitable[ChecksumAddress]]] = Method(\n RPC.personal_importRawKey,\n mungers=[default_root_munger],\n )\n\n async def import_raw_key(\n self, private_key: str, passphrase: str\n ) -> ChecksumAddress:\n return await self._import_raw_key(private_key, passphrase)\n\n # list_accounts and list_wallets\n\n _list_accounts: Method[Callable[[], Awaitable[List[ChecksumAddress]]]] = Method(\n RPC.personal_listAccounts,\n is_property=True,\n )\n\n _list_wallets: Method[Callable[[], Awaitable[List[GethWallet]]]] = Method(\n RPC.personal_listWallets,\n is_property=True,\n )\n\n async def list_accounts(self) -> List[ChecksumAddress]:\n return await self._list_accounts()\n\n async def list_wallets(self) -> List[GethWallet]:\n return await self._list_wallets()\n\n # send_transaction\n\n _send_transaction: Method[Callable[[TxParams, str], Awaitable[HexBytes]]] = Method(\n RPC.personal_sendTransaction,\n mungers=[default_root_munger],\n )\n\n async def send_transaction(\n self, transaction: TxParams, passphrase: str\n ) -> HexBytes:\n return await self._send_transaction(transaction, passphrase)\n\n # sign and sign_typed_data\n\n _sign: Method[\n Callable[[str, ChecksumAddress, Optional[str]], Awaitable[HexStr]]\n ] = Method(\n RPC.personal_sign,\n mungers=[default_root_munger],\n )\n\n _sign_typed_data: Method[\n Callable[[Dict[str, Any], ChecksumAddress, str], Awaitable[HexStr]]\n ] = Method(\n RPC.personal_signTypedData,\n mungers=[default_root_munger],\n )\n\n async def sign(\n self, message: str, account: ChecksumAddress, passphrase: str\n ) -> HexStr:\n return await self._sign(message, account, passphrase)\n\n async def sign_typed_data(\n self, message: Dict[str, Any], account: ChecksumAddress, passphrase: str\n ) -> HexStr:\n return await self._sign_typed_data(message, account, passphrase)\n\n # new_account, lock_account, and unlock_account\n...\n", "needle_token_start": 8725, "needle_token_end": 8895, "code_context_ntokens": 16020, "output": ["```python\n@curry\ndef map_abi_data(\n normalizers: Sequence[Callable[[TypeStr, Any], Tuple[TypeStr, Any]]],\n types: Sequence[TypeStr],\n data: Sequence[Any],\n) -> Any:\n \"\"\"\n This function will apply normalizers to your data, in the\n context of the relevant types. 
Each normalizer is in the format:\n\n def normalizer(datatype, data):\n # Conditionally modify data\n return (datatype, data)\n\n Where datatype is a valid ABI type string, like \"uint\".\n\n In case of an array, like \"bool[2]\", normalizer will receive `data`\n as an iterable of typed data, like `[(\"bool\", True), (\"bool\", False)]`.\n\n Internals\n ---\n\n This is accomplished by:\n\n 1. Decorating the data tree with types\n 2. Recursively mapping each of the normalizers to the data\n 3. Stripping the types back out of the tree\n \"\"\"\n pipeline = itertools.chain(\n [abi_data_tree(types)],\n map(data_tree_map, normalizers),\n [partial(recursive_map, strip_abi_type)],\n )\n\n return pipe(data, *pipeline)\n```"]} +{"repo": "ethereum/web3.py", "name": "test_gas_price_strategy_hex_value", "language": "python", "path": "web3/_utils/module_testing/eth_module.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To verify that the system correctly handles a gas price strategy function returning a hexadecimal string value for the gas price when sending a transaction.\n2. **Input**: The input includes a transaction parameter dictionary without a specified gas price but with a custom gas price strategy set to return a hexadecimal string.\n3. **Output**: The output is the validation that the transaction uses the gas price as defined by the hexadecimal string returned from the gas price strategy.\n4. **Procedure**: \n - Define a transaction parameter dictionary without a specified gas price.\n - Set a custom gas price strategy that returns a hexadecimal string representing the gas price.\n - Send a transaction using these parameters.\n - Retrieve and verify that the transaction's gas price matches the expected value converted from the hexadecimal string.\n - Reset the gas price strategy to ensure no side effects for other tests.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " assert len(accounts) != 0\n assert all((is_checksum_address(account) for account in accounts))\n assert w3.eth.coinbase in accounts\n\n def test_eth_block_number(self, w3: \"Web3\") -> None:\n block_number = w3.eth.block_number\n assert is_integer(block_number)\n assert block_number >= 0\n\n def test_eth_get_block_number(self, w3: \"Web3\") -> None:\n block_number = w3.eth.get_block_number()\n assert is_integer(block_number)\n assert block_number >= 0\n\n def test_eth_get_balance(self, w3: \"Web3\") -> None:\n coinbase = w3.eth.coinbase\n\n with pytest.raises(InvalidAddress):\n w3.eth.get_balance(ChecksumAddress(HexAddress(HexStr(coinbase.lower()))))\n\n balance = w3.eth.get_balance(coinbase)\n\n assert is_integer(balance)\n assert balance >= 0\n\n def test_eth_get_balance_with_block_identifier(self, w3: \"Web3\") -> None:\n miner_address = w3.eth.get_block(1)[\"miner\"]\n balance_post_genesis = w3.eth.get_balance(miner_address, 1)\n later_balance = w3.eth.get_balance(miner_address, \"latest\")\n\n assert is_integer(balance_post_genesis)\n assert is_integer(later_balance)\n assert later_balance > balance_post_genesis\n\n @pytest.mark.parametrize(\n \"address, expect_success\",\n [(\"test-address.eth\", True), (\"not-an-address.eth\", False)],\n )\n def test_eth_get_balance_with_ens_name(\n self, w3: \"Web3\", address: ChecksumAddress, expect_success: bool\n ) -> None:\n 
with ens_addresses(w3, {\"test-address.eth\": w3.eth.accounts[0]}):\n if expect_success:\n balance = w3.eth.get_balance(address)\n assert is_integer(balance)\n assert balance >= 0\n else:\n with pytest.raises(NameNotFound):\n w3.eth.get_balance(address)\n\n def test_eth_get_storage_at(self, w3: \"Web3\", storage_contract: \"Contract\") -> None:\n storage_contract_address = storage_contract.address\n\n slot_0 = w3.eth.get_storage_at(storage_contract_address, 0)\n assert slot_0 == HexBytes(f\"0x{'00' * 32}\")\n\n slot_1 = w3.eth.get_storage_at(storage_contract_address, 1)\n assert slot_1 == HexBytes(f\"0x{'00' * 31}01\")\n\n slot_2 = w3.eth.get_storage_at(storage_contract_address, 2)\n assert slot_2 == HexBytes(f\"0x{'00' * 31}02\")\n\n slot_3 = w3.eth.get_storage_at(storage_contract_address, 3)\n assert slot_3 == HexBytes(\n \"0x746872656500000000000000000000000000000000000000000000000000000a\"\n )\n assert bytes(slot_3[:5]) == b\"three\"\n\n slot_4 = w3.eth.get_storage_at(storage_contract_address, 4)\n assert slot_4 == HexBytes(\n \"0x666f757200000000000000000000000000000000000000000000000000000008\"\n )\n assert bytes(slot_4[:4]) == b\"four\"\n\n def test_eth_get_storage_at_ens_name(\n self, w3: \"Web3\", storage_contract: \"Contract\"\n ) -> None:\n with ens_addresses(w3, {\"storage.eth\": storage_contract.address}):\n storage = w3.eth.get_storage_at(ENS(\"storage.eth\"), 1)\n assert storage == HexBytes(f\"0x{'00' * 31}01\")\n\n def test_eth_get_storage_at_invalid_address(self, w3: \"Web3\") -> None:\n coinbase = w3.eth.coinbase\n with pytest.raises(InvalidAddress):\n w3.eth.get_storage_at(\n ChecksumAddress(HexAddress(HexStr(coinbase.lower()))), 0\n )\n\n def test_eth_get_transaction_count(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n transaction_count = w3.eth.get_transaction_count(unlocked_account_dual_type)\n assert is_integer(transaction_count)\n assert transaction_count >= 0\n\n def test_eth_get_transaction_count_ens_name(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n with ens_addresses(\n w3, {\"unlocked-acct-dual-type.eth\": unlocked_account_dual_type}\n ):\n transaction_count = w3.eth.get_transaction_count(\n ENS(\"unlocked-acct-dual-type.eth\")\n )\n assert is_integer(transaction_count)\n assert transaction_count >= 0\n\n def test_eth_get_transaction_count_invalid_address(self, w3: \"Web3\") -> None:\n coinbase = w3.eth.coinbase\n with pytest.raises(InvalidAddress):\n w3.eth.get_transaction_count(\n ChecksumAddress(HexAddress(HexStr(coinbase.lower())))\n )\n\n def test_eth_getBlockTransactionCountByHash_empty_block(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n transaction_count = w3.eth.get_block_transaction_count(empty_block[\"hash\"])\n\n assert is_integer(transaction_count)\n assert transaction_count == 0\n\n def test_eth_getBlockTransactionCountByNumber_empty_block(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n transaction_count = w3.eth.get_block_transaction_count(empty_block[\"number\"])\n\n assert is_integer(transaction_count)\n assert transaction_count == 0\n\n def test_eth_getBlockTransactionCountByHash_block_with_txn(\n self, w3: \"Web3\", block_with_txn: BlockData\n ) -> None:\n transaction_count = w3.eth.get_block_transaction_count(block_with_txn[\"hash\"])\n\n assert is_integer(transaction_count)\n assert transaction_count >= 1\n\n def test_eth_getBlockTransactionCountByNumber_block_with_txn(\n self, w3: \"Web3\", block_with_txn: BlockData\n ) -> None:\n 
transaction_count = w3.eth.get_block_transaction_count(block_with_txn[\"number\"])\n\n assert is_integer(transaction_count)\n assert transaction_count >= 1\n\n def test_eth_getUncleCountByBlockHash(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n uncle_count = w3.eth.get_uncle_count(empty_block[\"hash\"])\n\n assert is_integer(uncle_count)\n assert uncle_count == 0\n\n def test_eth_getUncleCountByBlockNumber(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n uncle_count = w3.eth.get_uncle_count(empty_block[\"number\"])\n\n assert is_integer(uncle_count)\n assert uncle_count == 0\n\n def test_eth_get_code(\n self, w3: \"Web3\", math_contract_address: ChecksumAddress\n ) -> None:\n code = w3.eth.get_code(math_contract_address)\n assert isinstance(code, HexBytes)\n assert len(code) > 0\n\n def test_eth_get_code_ens_address(\n self, w3: \"Web3\", math_contract_address: ChecksumAddress\n ) -> None:\n with ens_addresses(w3, {\"mathcontract.eth\": math_contract_address}):\n code = w3.eth.get_code(ENS(\"mathcontract.eth\"))\n assert isinstance(code, HexBytes)\n assert len(code) > 0\n\n def test_eth_get_code_invalid_address(\n self, w3: \"Web3\", math_contract: \"Contract\"\n ) -> None:\n with pytest.raises(InvalidAddress):\n w3.eth.get_code(\n ChecksumAddress(HexAddress(HexStr(math_contract.address.lower())))\n )\n\n def test_eth_get_code_with_block_identifier(\n self, w3: \"Web3\", emitter_contract: \"Contract\"\n ) -> None:\n code = w3.eth.get_code(\n emitter_contract.address, block_identifier=w3.eth.block_number\n )\n assert isinstance(code, HexBytes)\n assert len(code) > 0\n\n def test_eth_create_access_list(\n self,\n w3: \"Web3\",\n unlocked_account_dual_type: ChecksumAddress,\n math_contract: \"Contract\",\n ) -> None:\n # Initialize transaction for gas estimation\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n\n txn = math_contract.functions.incrementCounter(1).build_transaction(txn_params)\n\n # create access list using data from transaction\n response = w3.eth.create_access_list(\n {\n \"from\": unlocked_account_dual_type,\n \"to\": math_contract.address,\n \"data\": txn[\"data\"],\n }\n )\n\n assert is_dict(response)\n access_list = response[\"accessList\"]\n assert len(access_list) > 0\n assert access_list[0][\"address\"] is not None\n assert is_checksum_address(access_list[0][\"address\"])\n assert len(access_list[0][\"storageKeys\"][0]) == 32\n assert int(response[\"gasUsed\"]) >= 0\n\n def test_eth_sign(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n signature = w3.eth.sign(\n unlocked_account_dual_type, text=\"Message t\u00f6 sign. Longer than hash!\"\n )\n assert is_bytes(signature)\n assert len(signature) == 32 + 32 + 1\n\n # test other formats\n hexsign = w3.eth.sign(\n unlocked_account_dual_type,\n hexstr=HexStr(\n \"0x4d6573736167652074c3b6207369676e2e204c6f6e676572207468616e206861736821\" # noqa: E501\n ),\n )\n assert hexsign == signature\n\n intsign = w3.eth.sign(\n unlocked_account_dual_type,\n 0x4D6573736167652074C3B6207369676E2E204C6F6E676572207468616E206861736821,\n )\n assert intsign == signature\n\n bytessign = w3.eth.sign(\n unlocked_account_dual_type, b\"Message t\\xc3\\xb6 sign. 
Longer than hash!\"\n )\n assert bytessign == signature\n\n new_signature = w3.eth.sign(\n unlocked_account_dual_type, text=\"different message is different\"\n )\n assert new_signature != signature\n\n def test_eth_sign_ens_names(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n with ens_addresses(w3, {\"unlocked-acct.eth\": unlocked_account_dual_type}):\n signature = w3.eth.sign(\n \"unlocked-acct.eth\", text=\"Message t\u00f6 sign. Longer than hash!\"\n )\n assert is_bytes(signature)\n assert len(signature) == 32 + 32 + 1\n\n def test_eth_sign_typed_data(\n self,\n w3: \"Web3\",\n unlocked_account_dual_type: ChecksumAddress,\n skip_if_testrpc: Callable[[\"Web3\"], None],\n ) -> None:\n validJSONMessage = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": {\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n },\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n skip_if_testrpc(w3)\n signature = HexBytes(\n w3.eth.sign_typed_data(\n unlocked_account_dual_type, json.loads(validJSONMessage)\n )\n )\n assert len(signature) == 32 + 32 + 1\n\n def test_invalid_eth_sign_typed_data(\n self,\n w3: \"Web3\",\n unlocked_account_dual_type: ChecksumAddress,\n skip_if_testrpc: Callable[[\"Web3\"], None],\n ) -> None:\n skip_if_testrpc(w3)\n invalid_typed_message = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person[2]\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": [{\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n }],\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n with pytest.raises(\n ValueError,\n match=r\".*Expected 2 items for array type Person\\[2\\], got 1 items.*\",\n ):\n w3.eth.sign_typed_data(\n unlocked_account_dual_type, json.loads(invalid_typed_message)\n )\n\n def test_eth_sign_transaction_legacy(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": 
unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.eth.gas_price,\n \"nonce\": Nonce(0),\n }\n result = w3.eth.sign_transaction(txn_params)\n signatory_account = w3.eth.account.recover_transaction(result[\"raw\"])\n assert unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"gasPrice\"] == txn_params[\"gasPrice\"]\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n def test_eth_sign_transaction(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n \"nonce\": Nonce(0),\n }\n result = w3.eth.sign_transaction(txn_params)\n signatory_account = w3.eth.account.recover_transaction(result[\"raw\"])\n assert unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert (\n result[\"tx\"][\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n def test_eth_sign_transaction_hex_fees(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": hex(w3.to_wei(2, \"gwei\")),\n \"maxPriorityFeePerGas\": hex(w3.to_wei(1, \"gwei\")),\n \"nonce\": Nonce(0),\n }\n result = w3.eth.sign_transaction(txn_params)\n signatory_account = w3.eth.account.recover_transaction(result[\"raw\"])\n assert unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == int(str(txn_params[\"maxFeePerGas\"]), 16)\n assert result[\"tx\"][\"maxPriorityFeePerGas\"] == int(\n str(txn_params[\"maxPriorityFeePerGas\"]), 16\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n def test_eth_sign_transaction_ens_names(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n with ens_addresses(w3, {\"unlocked-account.eth\": unlocked_account}):\n txn_params: TxParams = {\n \"from\": \"unlocked-account.eth\",\n \"to\": \"unlocked-account.eth\",\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n \"nonce\": Nonce(0),\n }\n result = w3.eth.sign_transaction(txn_params)\n signatory_account = w3.eth.account.recover_transaction(result[\"raw\"])\n assert unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == unlocked_account\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert (\n result[\"tx\"][\"maxPriorityFeePerGas\"]\n == txn_params[\"maxPriorityFeePerGas\"]\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n def test_eth_send_transaction_addr_checksum_required(\n self, 
w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n non_checksum_addr = unlocked_account.lower()\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n\n with pytest.raises(InvalidAddress):\n invalid_params = cast(\n TxParams, dict(txn_params, **{\"from\": non_checksum_addr})\n )\n w3.eth.send_transaction(invalid_params)\n\n with pytest.raises(InvalidAddress):\n invalid_params = cast(\n TxParams, dict(txn_params, **{\"to\": non_checksum_addr})\n )\n w3.eth.send_transaction(invalid_params)\n\n def test_eth_send_transaction_legacy(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(\n 1, \"gwei\"\n ), # post-london needs to be more than the base fee\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"gasPrice\"] == txn_params[\"gasPrice\"]\n\n def test_eth_send_transaction(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert txn[\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n assert txn[\"gasPrice\"] == txn_params[\"maxFeePerGas\"]\n\n def test_eth_send_transaction_with_nonce(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n max_priority_fee_per_gas = w3.to_wei(1.234, \"gwei\")\n max_fee_per_gas = Wei(\n w3.eth.get_block(\"latest\")[\"baseFeePerGas\"] + max_priority_fee_per_gas\n )\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": max_fee_per_gas,\n \"maxPriorityFeePerGas\": max_priority_fee_per_gas,\n \"nonce\": Nonce(w3.eth.get_transaction_count(unlocked_account, \"pending\")),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert txn[\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n assert txn[\"nonce\"] == txn_params[\"nonce\"]\n assert is_integer(txn[\"gasPrice\"])\n assert is_integer(txn_params[\"maxFeePerGas\"])\n\n def test_eth_send_transaction_default_fees(\n self, w3: \"Web3\", 
unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert is_integer(txn[\"maxPriorityFeePerGas\"])\n assert is_integer(txn[\"maxFeePerGas\"])\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n def test_eth_send_transaction_hex_fees(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": hex(250 * 10**9),\n \"maxPriorityFeePerGas\": hex(2 * 10**9),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == 250 * 10**9\n assert txn[\"maxPriorityFeePerGas\"] == 2 * 10**9\n\n def test_eth_send_transaction_no_gas(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"maxFeePerGas\": Wei(250 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 121000 # 21000 + buffer\n\n def test_eth_send_transaction_with_gas_price(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": Wei(1),\n \"maxFeePerGas\": Wei(250 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(TransactionTypeMismatch):\n w3.eth.send_transaction(txn_params)\n\n def test_eth_send_transaction_no_priority_fee(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(250 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n w3.eth.send_transaction(txn_params)\n\n def test_eth_send_transaction_no_max_fee(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n max_priority_fee_per_gas = w3.to_wei(2, \"gwei\")\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": max_priority_fee_per_gas,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], 
cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert is_integer(txn[\"maxPriorityFeePerGas\"])\n assert txn[\"maxPriorityFeePerGas\"] == max_priority_fee_per_gas\n assert is_integer(txn[\"maxFeePerGas\"])\n\n def test_eth_send_transaction_max_fee_less_than_tip(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxFeePerGas must be >= maxPriorityFeePerGas\"\n ):\n w3.eth.send_transaction(txn_params)\n\n def test_validation_middleware_chain_id_mismatch(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n wrong_chain_id = 1234567890\n actual_chain_id = w3.eth.chain_id\n\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": Wei(21000),\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n \"chainId\": wrong_chain_id,\n }\n with pytest.raises(\n Web3ValidationError,\n match=f\"The transaction declared chain ID {wrong_chain_id}, \"\n f\"but the connected node is on {actual_chain_id}\",\n ):\n w3.eth.send_transaction(txn_params)\n\n @pytest.mark.parametrize(\n \"max_fee\", (1000000000, None), ids=[\"with_max_fee\", \"without_max_fee\"]\n )\n def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress, max_fee: Wei\n ) -> None:\n max_priority_fee = w3.to_wei(1, \"gwei\")\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": max_priority_fee,\n }\n if max_fee is not None:\n txn_params = assoc(txn_params, \"maxFeePerGas\", max_fee)\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> Wei:\n return w3.to_wei(2, \"gwei\")\n\n w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n latest_block = w3.eth.get_block(\"latest\")\n assert (\n txn[\"maxFeePerGas\"] == max_fee\n if max_fee is not None\n else 2 * latest_block[\"baseFeePerGas\"] + max_priority_fee\n )\n assert txn[\"maxPriorityFeePerGas\"] == max_priority_fee\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn_no_tip(\n self,\n w3: \"Web3\",\n unlocked_account_dual_type: ChecksumAddress,\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1000000000),\n }\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> Wei:\n return w3.to_wei(2, \"gwei\")\n\n w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n w3.eth.send_transaction(txn_params)\n\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n def \ntest_gas_price_strategy_hex_value(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n 
txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> str:\n return hex(two_gwei_in_wei)\n\n w3.eth.set_gas_price_strategy(gas_price_strategy) # type: ignore\n\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_legacy(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(\n 1, \"gwei\"\n ), # must be greater than base_fee post London\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params[\"gasPrice\"] = w3.to_wei(2, \"gwei\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n\n assert is_same_address(\n replace_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n replace_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert replace_txn[\"value\"] == 1\n assert replace_txn[\"gas\"] == 21000\n assert replace_txn[\"gasPrice\"] == txn_params[\"gasPrice\"]\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n three_gwei_in_wei = w3.to_wei(3, \"gwei\")\n\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": two_gwei_in_wei,\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params[\"maxFeePerGas\"] = three_gwei_in_wei\n txn_params[\"maxPriorityFeePerGas\"] = two_gwei_in_wei\n\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n\n assert is_same_address(\n replace_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n replace_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert replace_txn[\"value\"] == 1\n assert replace_txn[\"gas\"] == 21000\n assert replace_txn[\"maxFeePerGas\"] == three_gwei_in_wei\n assert replace_txn[\"maxPriorityFeePerGas\"] == two_gwei_in_wei\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_underpriced(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(2, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n one_gwei_in_wei = w3.to_wei(1, \"gwei\")\n txn_params[\"maxFeePerGas\"] = one_gwei_in_wei\n txn_params[\"maxPriorityFeePerGas\"] = one_gwei_in_wei\n\n with pytest.raises(ValueError, match=\"replacement transaction underpriced\"):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_non_existing_transaction(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> 
None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n with pytest.raises(TransactionNotFound):\n w3.eth.replace_transaction(\n HexStr(\n \"0x98e8cc09b311583c5079fa600f6c2a3bea8611af168c52e4b60b5b243a441997\"\n ),\n txn_params,\n )\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_already_mined(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n w3.eth.wait_for_transaction_receipt(txn_hash, timeout=10)\n\n txn_params[\"maxFeePerGas\"] = w3.to_wei(3, \"gwei\")\n txn_params[\"maxPriorityFeePerGas\"] = w3.to_wei(2, \"gwei\")\n with pytest.raises(ValueError, match=\"Supplied transaction with hash\"):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_incorrect_nonce(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n txn_params[\"maxFeePerGas\"] = w3.to_wei(3, \"gwei\")\n txn_params[\"maxPriorityFeePerGas\"] = w3.to_wei(2, \"gwei\")\n txn_params[\"nonce\"] = Nonce(txn[\"nonce\"] + 1)\n with pytest.raises(ValueError):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_too_low(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(2, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params[\"gasPrice\"] = w3.to_wei(1, \"gwei\")\n with pytest.raises(ValueError):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_defaulting_minimum(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n gas_price = w3.to_wei(1, \"gwei\")\n\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": gas_price,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params.pop(\"gasPrice\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n\n assert replace_txn[\"gasPrice\"] == math.ceil(\n gas_price * 1.125\n ) # minimum gas price\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_defaulting_strategy_higher(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n\n def 
higher_gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return two_gwei_in_wei\n\n w3.eth.set_gas_price_strategy(higher_gas_price_strategy)\n\n txn_params.pop(\"gasPrice\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n assert (\n replace_txn[\"gasPrice\"] == two_gwei_in_wei\n ) # Strategy provides higher gas price\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_defaulting_strategy_lower(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n gas_price = w3.to_wei(2, \"gwei\")\n\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": gas_price,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n def lower_gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return w3.to_wei(1, \"gwei\")\n\n w3.eth.set_gas_price_strategy(lower_gas_price_strategy)\n\n txn_params.pop(\"gasPrice\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n # Strategy provides lower gas price - minimum preferred\n assert replace_txn[\"gasPrice\"] == math.ceil(gas_price * 1.125)\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n def test_eth_modify_transaction_legacy(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(\n 1, \"gwei\"\n ), # must be greater than base_fee post London\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = w3.eth.modify_transaction(\n txn_hash, gasPrice=(cast(int, txn_params[\"gasPrice\"]) * 2), value=2\n )\n modified_txn = w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert modified_txn[\"gasPrice\"] == cast(int, txn_params[\"gasPrice\"]) * 2\n\n def test_eth_modify_transaction(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = w3.eth.modify_transaction(\n txn_hash,\n value=2,\n maxPriorityFeePerGas=(cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2),\n maxFeePerGas=(cast(Wei, txn_params[\"maxFeePerGas\"]) * 2),\n )\n modified_txn = w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert (\n modified_txn[\"maxPriorityFeePerGas\"]\n == cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2\n )\n assert modified_txn[\"maxFeePerGas\"] == cast(Wei, txn_params[\"maxFeePerGas\"]) * 2\n\n def test_eth_send_raw_transaction(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n signed_tx = 
w3.eth.account.sign_transaction(\n {\n \"to\": \"0x0000000000000000000000000000000000000000\",\n \"value\": 0,\n \"nonce\": w3.eth.get_transaction_count(unlocked_account),\n \"gas\": 21000,\n \"maxFeePerGas\": 1000000000,\n \"maxPriorityFeePerGas\": 1000000000,\n \"chainId\": 131277322940537,\n },\n # unlocked_account private key:\n \"0x392f63a79b1ff8774845f3fa69de4a13800a59e7083f5187f1558f0797ad0f01\",\n )\n txn_hash = w3.eth.send_raw_transaction(signed_tx.rawTransaction)\n assert txn_hash == signed_tx.hash\n\n def test_eth_call(self, w3: \"Web3\", math_contract: \"Contract\") -> None:\n coinbase = w3.eth.coinbase\n txn_params = math_contract._prepare_transaction(\n fn_name=\"add\",\n fn_args=(7, 11),\n transaction={\"from\": coinbase, \"to\": math_contract.address},\n )\n call_result = w3.eth.call(txn_params)\n assert is_string(call_result)\n (result,) = w3.codec.decode([\"uint256\"], call_result)\n assert result == 18\n\n def test_eth_call_with_override_code(\n self, w3: \"Web3\", revert_contract: \"Contract\"\n ) -> None:\n coinbase = w3.eth.coinbase\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"normalFunction\",\n transaction={\"from\": coinbase, \"to\": revert_contract.address},\n )\n call_result = w3.eth.call(txn_params)\n (result,) = w3.codec.decode([\"bool\"], call_result)\n assert result is True\n\n # override runtime bytecode: `normalFunction` returns `false`\n override_code = HexStr(\n \"0x6080604052348015600f57600080fd5b5060043610603c5760003560e01c8063185c38a4146041578063c06a97cb146049578063d67e4b84146051575b600080fd5b60476071565b005b604f60df565b005b605760e4565b604051808215151515815260200191505060405180910390f35b6040517f08c379a000000000000000000000000000000000000000000000000000000000815260040180806020018281038252601b8152602001807f46756e6374696f6e20686173206265656e2072657665727465642e000000000081525060200191505060405180910390fd5b600080fd5b60008090509056fea2646970667358221220bb71e9e9a2e271cd0fbe833524a3ea67df95f25ea13aef5b0a761fa52b538f1064736f6c63430006010033\" # noqa: E501\n )\n call_result = w3.eth.call(\n txn_params, \"latest\", {revert_contract.address: {\"code\": override_code}}\n )\n (result,) = w3.codec.decode([\"bool\"], call_result)\n assert result is False\n\n # test bytes\n\n bytes_call_result = w3.eth.call(\n txn_params,\n \"latest\",\n...\n", "needle_token_start": 10417, "needle_token_end": 10707, "code_context_ntokens": 16239, "output": ["```python\ndef test_gas_price_strategy_hex_value(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> str:\n return hex(two_gwei_in_wei)\n\n w3.eth.set_gas_price_strategy(gas_price_strategy) # type: ignore\n\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n w3.eth.set_gas_price_strategy(None) # reset strategy\n```"]} +{"repo": "ethereum/web3.py", "name": "test_eth_estimate_gas_revert_with_msg", "language": "python", "path": "web3/_utils/module_testing/eth_module.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To verify that the system correctly identifies and handles a transaction that is expected to revert due to a specific error message in a smart contract function.\n2. 
**Input**: A transaction object that includes details such as the sender's address and the target contract address.\n3. **Output**: The test is expected to raise an exception indicating that the transaction has been reverted along with a specific error message.\n4. **Procedure**: The test prepares a transaction intended to call a function known to revert. It then attempts to estimate the gas for this transaction, expecting to catch a specific error that indicates the function's execution has been reverted with a predefined message.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n max_priority_fee_per_gas = w3.to_wei(2, \"gwei\")\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": max_priority_fee_per_gas,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert is_integer(txn[\"maxPriorityFeePerGas\"])\n assert txn[\"maxPriorityFeePerGas\"] == max_priority_fee_per_gas\n assert is_integer(txn[\"maxFeePerGas\"])\n\n def test_eth_send_transaction_max_fee_less_than_tip(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxFeePerGas must be >= maxPriorityFeePerGas\"\n ):\n w3.eth.send_transaction(txn_params)\n\n def test_validation_middleware_chain_id_mismatch(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n wrong_chain_id = 1234567890\n actual_chain_id = w3.eth.chain_id\n\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": Wei(21000),\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n \"chainId\": wrong_chain_id,\n }\n with pytest.raises(\n Web3ValidationError,\n match=f\"The transaction declared chain ID {wrong_chain_id}, \"\n f\"but the connected node is on {actual_chain_id}\",\n ):\n w3.eth.send_transaction(txn_params)\n\n @pytest.mark.parametrize(\n \"max_fee\", (1000000000, None), ids=[\"with_max_fee\", \"without_max_fee\"]\n )\n def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress, max_fee: Wei\n ) -> None:\n max_priority_fee = w3.to_wei(1, \"gwei\")\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": max_priority_fee,\n }\n if max_fee is not None:\n txn_params = assoc(txn_params, \"maxFeePerGas\", max_fee)\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> Wei:\n return w3.to_wei(2, \"gwei\")\n\n 
w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n latest_block = w3.eth.get_block(\"latest\")\n assert (\n txn[\"maxFeePerGas\"] == max_fee\n if max_fee is not None\n else 2 * latest_block[\"baseFeePerGas\"] + max_priority_fee\n )\n assert txn[\"maxPriorityFeePerGas\"] == max_priority_fee\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn_no_tip(\n self,\n w3: \"Web3\",\n unlocked_account_dual_type: ChecksumAddress,\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1000000000),\n }\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> Wei:\n return w3.to_wei(2, \"gwei\")\n\n w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n w3.eth.send_transaction(txn_params)\n\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n def test_gas_price_strategy_hex_value(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> str:\n return hex(two_gwei_in_wei)\n\n w3.eth.set_gas_price_strategy(gas_price_strategy) # type: ignore\n\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_legacy(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(\n 1, \"gwei\"\n ), # must be greater than base_fee post London\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params[\"gasPrice\"] = w3.to_wei(2, \"gwei\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n\n assert is_same_address(\n replace_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n replace_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert replace_txn[\"value\"] == 1\n assert replace_txn[\"gas\"] == 21000\n assert replace_txn[\"gasPrice\"] == txn_params[\"gasPrice\"]\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n three_gwei_in_wei = w3.to_wei(3, \"gwei\")\n\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": two_gwei_in_wei,\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params[\"maxFeePerGas\"] = three_gwei_in_wei\n txn_params[\"maxPriorityFeePerGas\"] = two_gwei_in_wei\n\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = 
w3.eth.get_transaction(replace_txn_hash)\n\n assert is_same_address(\n replace_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n replace_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert replace_txn[\"value\"] == 1\n assert replace_txn[\"gas\"] == 21000\n assert replace_txn[\"maxFeePerGas\"] == three_gwei_in_wei\n assert replace_txn[\"maxPriorityFeePerGas\"] == two_gwei_in_wei\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_underpriced(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(2, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n one_gwei_in_wei = w3.to_wei(1, \"gwei\")\n txn_params[\"maxFeePerGas\"] = one_gwei_in_wei\n txn_params[\"maxPriorityFeePerGas\"] = one_gwei_in_wei\n\n with pytest.raises(ValueError, match=\"replacement transaction underpriced\"):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_non_existing_transaction(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n with pytest.raises(TransactionNotFound):\n w3.eth.replace_transaction(\n HexStr(\n \"0x98e8cc09b311583c5079fa600f6c2a3bea8611af168c52e4b60b5b243a441997\"\n ),\n txn_params,\n )\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_already_mined(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n w3.eth.wait_for_transaction_receipt(txn_hash, timeout=10)\n\n txn_params[\"maxFeePerGas\"] = w3.to_wei(3, \"gwei\")\n txn_params[\"maxPriorityFeePerGas\"] = w3.to_wei(2, \"gwei\")\n with pytest.raises(ValueError, match=\"Supplied transaction with hash\"):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_incorrect_nonce(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n txn = w3.eth.get_transaction(txn_hash)\n\n txn_params[\"maxFeePerGas\"] = w3.to_wei(3, \"gwei\")\n txn_params[\"maxPriorityFeePerGas\"] = w3.to_wei(2, \"gwei\")\n txn_params[\"nonce\"] = Nonce(txn[\"nonce\"] + 1)\n with pytest.raises(ValueError):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_too_low(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n 
\"gasPrice\": w3.to_wei(2, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params[\"gasPrice\"] = w3.to_wei(1, \"gwei\")\n with pytest.raises(ValueError):\n w3.eth.replace_transaction(txn_hash, txn_params)\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_defaulting_minimum(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n gas_price = w3.to_wei(1, \"gwei\")\n\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": gas_price,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n txn_params.pop(\"gasPrice\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n\n assert replace_txn[\"gasPrice\"] == math.ceil(\n gas_price * 1.125\n ) # minimum gas price\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_defaulting_strategy_higher(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(1, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n two_gwei_in_wei = w3.to_wei(2, \"gwei\")\n\n def higher_gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return two_gwei_in_wei\n\n w3.eth.set_gas_price_strategy(higher_gas_price_strategy)\n\n txn_params.pop(\"gasPrice\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n assert (\n replace_txn[\"gasPrice\"] == two_gwei_in_wei\n ) # Strategy provides higher gas price\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @flaky_geth_dev_mining\n def test_eth_replace_transaction_gas_price_defaulting_strategy_lower(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n gas_price = w3.to_wei(2, \"gwei\")\n\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": gas_price,\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n def lower_gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return w3.to_wei(1, \"gwei\")\n\n w3.eth.set_gas_price_strategy(lower_gas_price_strategy)\n\n txn_params.pop(\"gasPrice\")\n replace_txn_hash = w3.eth.replace_transaction(txn_hash, txn_params)\n replace_txn = w3.eth.get_transaction(replace_txn_hash)\n # Strategy provides lower gas price - minimum preferred\n assert replace_txn[\"gasPrice\"] == math.ceil(gas_price * 1.125)\n w3.eth.set_gas_price_strategy(None) # reset strategy\n\n def test_eth_modify_transaction_legacy(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": w3.to_wei(\n 1, \"gwei\"\n ), # must be greater than base_fee post London\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = w3.eth.modify_transaction(\n txn_hash, gasPrice=(cast(int, txn_params[\"gasPrice\"]) * 2), value=2\n )\n modified_txn = w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert 
modified_txn[\"gasPrice\"] == cast(int, txn_params[\"gasPrice\"]) * 2\n\n def test_eth_modify_transaction(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": unlocked_account,\n \"to\": unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n \"maxFeePerGas\": w3.to_wei(2, \"gwei\"),\n }\n txn_hash = w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = w3.eth.modify_transaction(\n txn_hash,\n value=2,\n maxPriorityFeePerGas=(cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2),\n maxFeePerGas=(cast(Wei, txn_params[\"maxFeePerGas\"]) * 2),\n )\n modified_txn = w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert (\n modified_txn[\"maxPriorityFeePerGas\"]\n == cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2\n )\n assert modified_txn[\"maxFeePerGas\"] == cast(Wei, txn_params[\"maxFeePerGas\"]) * 2\n\n def test_eth_send_raw_transaction(\n self, w3: \"Web3\", unlocked_account: ChecksumAddress\n ) -> None:\n signed_tx = w3.eth.account.sign_transaction(\n {\n \"to\": \"0x0000000000000000000000000000000000000000\",\n \"value\": 0,\n \"nonce\": w3.eth.get_transaction_count(unlocked_account),\n \"gas\": 21000,\n \"maxFeePerGas\": 1000000000,\n \"maxPriorityFeePerGas\": 1000000000,\n \"chainId\": 131277322940537,\n },\n # unlocked_account private key:\n \"0x392f63a79b1ff8774845f3fa69de4a13800a59e7083f5187f1558f0797ad0f01\",\n )\n txn_hash = w3.eth.send_raw_transaction(signed_tx.rawTransaction)\n assert txn_hash == signed_tx.hash\n\n def test_eth_call(self, w3: \"Web3\", math_contract: \"Contract\") -> None:\n coinbase = w3.eth.coinbase\n txn_params = math_contract._prepare_transaction(\n fn_name=\"add\",\n fn_args=(7, 11),\n transaction={\"from\": coinbase, \"to\": math_contract.address},\n )\n call_result = w3.eth.call(txn_params)\n assert is_string(call_result)\n (result,) = w3.codec.decode([\"uint256\"], call_result)\n assert result == 18\n\n def test_eth_call_with_override_code(\n self, w3: \"Web3\", revert_contract: \"Contract\"\n ) -> None:\n coinbase = w3.eth.coinbase\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"normalFunction\",\n transaction={\"from\": coinbase, \"to\": revert_contract.address},\n )\n call_result = w3.eth.call(txn_params)\n (result,) = w3.codec.decode([\"bool\"], call_result)\n assert result is True\n\n # override runtime bytecode: `normalFunction` returns `false`\n override_code = HexStr(\n \"0x6080604052348015600f57600080fd5b5060043610603c5760003560e01c8063185c38a4146041578063c06a97cb146049578063d67e4b84146051575b600080fd5b60476071565b005b604f60df565b005b605760e4565b604051808215151515815260200191505060405180910390f35b6040517f08c379a000000000000000000000000000000000000000000000000000000000815260040180806020018281038252601b8152602001807f46756e6374696f6e20686173206265656e2072657665727465642e000000000081525060200191505060405180910390fd5b600080fd5b60008090509056fea2646970667358221220bb71e9e9a2e271cd0fbe833524a3ea67df95f25ea13aef5b0a761fa52b538f1064736f6c63430006010033\" # noqa: E501\n )\n call_result = w3.eth.call(\n txn_params, \"latest\", {revert_contract.address: {\"code\": override_code}}\n )\n (result,) = w3.codec.decode([\"bool\"], call_result)\n assert result is 
False\n\n # test bytes\n\n bytes_call_result = w3.eth.call(\n txn_params,\n \"latest\",\n {revert_contract.address: {\"code\": to_bytes(hexstr=override_code)}},\n )\n (bytes_result,) = w3.codec.decode([\"bool\"], bytes_call_result)\n assert bytes_result is False\n\n @pytest.mark.parametrize(\n \"params\",\n (\n {\n \"nonce\": 1, # int\n \"balance\": 1, # int\n \"code\": HexStr(\"0x\"), # HexStr\n # with state\n \"state\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n {\n \"nonce\": HexStr(\"0x1\"), # HexStr\n \"balance\": HexStr(\"0x1\"), # HexStr\n \"code\": b\"\\x00\", # bytes\n # with stateDiff\n \"stateDiff\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n ),\n )\n def test_eth_call_with_override_param_type_check(\n self,\n w3: \"Web3\",\n math_contract: \"Contract\",\n params: StateOverrideParams,\n ) -> None:\n txn_params: TxParams = {\"from\": w3.eth.coinbase}\n\n # assert does not raise\n w3.eth.call(txn_params, \"latest\", {math_contract.address: params})\n\n def test_eth_call_with_0_result(\n self, w3: \"Web3\", math_contract: \"Contract\"\n ) -> None:\n coinbase = w3.eth.coinbase\n txn_params = math_contract._prepare_transaction(\n fn_name=\"add\",\n fn_args=(0, 0),\n transaction={\"from\": coinbase, \"to\": math_contract.address},\n )\n call_result = w3.eth.call(txn_params)\n assert is_string(call_result)\n (result,) = w3.codec.decode([\"uint256\"], call_result)\n assert result == 0\n\n def test_eth_call_revert_with_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"revertWithMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n data = \"0x08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001b46756e6374696f6e20686173206265656e2072657665727465642e0000000000\" # noqa: E501\n with pytest.raises(\n ContractLogicError, match=\"execution reverted: Function has been reverted\"\n ) as excinfo:\n w3.eth.call(txn_params)\n assert excinfo.value.data == data\n\n def test_eth_call_revert_without_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n with pytest.raises(ContractLogicError, match=\"execution reverted\"):\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"revertWithoutMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n w3.eth.call(txn_params)\n\n def test_eth_call_custom_error_revert_with_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n data = revert_contract.encodeABI(\n fn_name=\"UnauthorizedWithMessage\", args=[\"You are not authorized\"]\n )\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"customErrorWithMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data) as excinfo:\n w3.eth.call(txn_params)\n assert excinfo.value.data == data\n\n def test_eth_call_custom_error_revert_without_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n data = revert_contract.encodeABI(fn_name=\"Unauthorized\")\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"customErrorWithoutMessage\",\n transaction={\n \"from\": 
unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data) as excinfo:\n w3.eth.call(txn_params)\n assert excinfo.value.data == data\n\n @pytest.mark.parametrize(\n \"panic_error,params\",\n (\n (\"01\", []),\n (\"11\", []),\n (\"12\", [0]),\n (\"21\", [-1]),\n (\"22\", []),\n (\"31\", []),\n (\"32\", []),\n (\"41\", []),\n (\"51\", []),\n ),\n )\n def test_contract_panic_errors(\n self,\n w3: \"Web3\",\n panic_errors_contract: \"Contract\",\n panic_error: str,\n params: List[Any],\n ) -> None:\n method = getattr(\n panic_errors_contract.functions,\n f\"errorCode{panic_error}\",\n )\n error_msg = PANIC_ERROR_CODES[panic_error]\n\n with pytest.raises(ContractPanicError, match=re.escape(error_msg)):\n method(*params).call()\n\n def test_eth_call_offchain_lookup(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n offchain_lookup_contract.address\n ).lower()\n mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/{OFFCHAIN_LOOKUP_TEST_DATA}.json\", # noqa: E501\n mocked_json_data=WEB3PY_AS_HEXBYTES,\n )\n response = offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n assert w3.codec.decode([\"string\"], response)[0] == \"web3py\"\n\n def test_eth_call_offchain_lookup_raises_when_ccip_read_is_disabled(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n ) -> None:\n return_data = (\n OFFCHAIN_LOOKUP_4BYTE_DATA\n + abi_encoded_offchain_lookup_contract_address(w3, offchain_lookup_contract)\n + OFFCHAIN_LOOKUP_RETURN_DATA\n )\n # test ContractFunction call\n with pytest.raises(OffchainLookup) as e:\n offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call(ccip_read_enabled=False)\n assert e.value.data == return_data\n\n # test ContractCaller call\n with pytest.raises(OffchainLookup) as excinfo:\n offchain_lookup_contract.caller(ccip_read_enabled=False).testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n )\n assert excinfo.value.data == return_data\n\n # test global flag on the provider\n w3.provider.global_ccip_read_enabled = False\n\n with pytest.raises(OffchainLookup) as exc_info:\n offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n assert exc_info.value.data == return_data\n\n w3.provider.global_ccip_read_enabled = True # cleanup\n\n def test_eth_call_offchain_lookup_call_flag_overrides_provider_flag(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n offchain_lookup_contract.address\n ).lower()\n mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/{OFFCHAIN_LOOKUP_TEST_DATA}.json\", # noqa: E501\n mocked_json_data=WEB3PY_AS_HEXBYTES,\n )\n\n w3.provider.global_ccip_read_enabled = False\n\n response = offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call(ccip_read_enabled=True)\n assert w3.codec.decode([\"string\"], response)[0] == \"web3py\"\n\n w3.provider.global_ccip_read_enabled = True # cleanup\n\n @pytest.mark.parametrize(\"max_redirects\", range(-1, 4))\n def 
test_eth_call_offchain_lookup_raises_if_max_redirects_is_less_than_4(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n max_redirects: int,\n ) -> None:\n default_max_redirects = w3.provider.ccip_read_max_redirects\n\n w3.provider.ccip_read_max_redirects = max_redirects\n with pytest.raises(ValueError, match=\"at least 4\"):\n offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n\n w3.provider.ccip_read_max_redirects = default_max_redirects # cleanup\n\n def test_eth_call_offchain_lookup_raises_for_improperly_formatted_rest_request_response( # noqa: E501\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n offchain_lookup_contract.address\n ).lower()\n mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/{OFFCHAIN_LOOKUP_TEST_DATA}.json\", # noqa: E501\n mocked_json_data=WEB3PY_AS_HEXBYTES,\n json_data_field=\"not_data\",\n )\n with pytest.raises(Web3ValidationError, match=\"missing 'data' field\"):\n offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n\n @pytest.mark.parametrize(\"status_code_non_4xx_error\", [100, 300, 500, 600])\n def test_eth_call_offchain_lookup_tries_next_url_for_non_4xx_error_status_and_tests_POST( # noqa: E501\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n status_code_non_4xx_error: int,\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n offchain_lookup_contract.address\n ).lower()\n\n # The next url in our test contract doesn't contain '{data}', triggering\n # the POST request logic. 
The idea here is to return a bad status for\n # the first url (GET) and a success status from the second call (POST)\n # to test both that we move on to the next url with non 4xx status and\n # that the POST logic is also working as expected.\n mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/{OFFCHAIN_LOOKUP_TEST_DATA}.json\", # noqa: E501\n mocked_status_code=status_code_non_4xx_error,\n mocked_json_data=WEB3PY_AS_HEXBYTES,\n )\n mock_offchain_lookup_request_response(\n monkeypatch,\n http_method=\"POST\",\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}.json\", # noqa: E501\n mocked_status_code=200,\n mocked_json_data=WEB3PY_AS_HEXBYTES,\n sender=normalized_contract_address,\n calldata=OFFCHAIN_LOOKUP_TEST_DATA,\n )\n response = offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n assert w3.codec.decode([\"string\"], response)[0] == \"web3py\"\n\n def test_eth_call_offchain_lookup_calls_raise_for_status_for_4xx_status_code(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n offchain_lookup_contract.address\n ).lower()\n mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/{OFFCHAIN_LOOKUP_TEST_DATA}.json\", # noqa: E501\n mocked_status_code=randint(400, 499),\n mocked_json_data=WEB3PY_AS_HEXBYTES,\n )\n with pytest.raises(Exception, match=\"called raise_for_status\\\\(\\\\)\"):\n offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n\n def test_eth_call_offchain_lookup_raises_when_all_supplied_urls_fail(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n ) -> None:\n # GET and POST requests should fail since responses are not mocked\n with pytest.raises(\n MultipleFailedRequests, match=\"Offchain lookup failed for supplied urls\"\n ):\n offchain_lookup_contract.functions.testOffchainLookup(\n OFFCHAIN_LOOKUP_TEST_DATA\n ).call()\n\n def test_eth_call_continuous_offchain_lookup_raises_with_too_many_requests(\n self,\n w3: \"Web3\",\n offchain_lookup_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n offchain_lookup_contract.address\n ).lower()\n mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/0x.json\", # noqa: E501\n )\n with pytest.raises(TooManyRequests, match=\"Too many CCIP read redirects\"):\n offchain_lookup_contract.caller().continuousOffchainLookup()\n\n def \ntest_eth_estimate_gas_revert_with_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n with pytest.raises(\n ContractLogicError, match=\"execution reverted: Function has been reverted\"\n ):\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"revertWithMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n w3.eth.estimate_gas(txn_params)\n\n def test_eth_estimate_gas_revert_without_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n with pytest.raises(ContractLogicError, match=\"execution reverted\"):\n txn_params = 
revert_contract._prepare_transaction(\n fn_name=\"revertWithoutMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n w3.eth.estimate_gas(txn_params)\n\n def test_eth_estimate_gas_custom_error_revert_with_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n data = revert_contract.encodeABI(\n fn_name=\"UnauthorizedWithMessage\", args=[\"You are not authorized\"]\n )\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"customErrorWithMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data) as excinfo:\n w3.eth.estimate_gas(txn_params)\n assert excinfo.value.data == data\n\n def test_eth_estimate_gas_custom_error_revert_without_msg(\n self,\n w3: \"Web3\",\n revert_contract: \"Contract\",\n unlocked_account: ChecksumAddress,\n ) -> None:\n data = revert_contract.encodeABI(fn_name=\"Unauthorized\")\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"customErrorWithoutMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data) as excinfo:\n w3.eth.estimate_gas(txn_params)\n assert excinfo.value.data == data\n\n def test_eth_estimate_gas(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n gas_estimate = w3.eth.estimate_gas(\n {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n }\n )\n assert is_integer(gas_estimate)\n assert gas_estimate > 0\n\n def test_eth_estimate_gas_with_block(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n gas_estimate = w3.eth.estimate_gas(\n {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n },\n \"latest\",\n )\n assert is_integer(gas_estimate)\n assert gas_estimate > 0\n\n @pytest.mark.parametrize(\n \"params\",\n (\n {\n \"nonce\": 1, # int\n \"balance\": 1, # int\n \"code\": HexStr(\"0x\"), # HexStr\n # with state\n \"state\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n {\n \"nonce\": HexStr(\"0x1\"), # HexStr\n \"balance\": HexStr(\"0x1\"), # HexStr\n \"code\": b\"\\x00\", # bytes\n # with stateDiff\n \"stateDiff\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n ),\n )\n def test_eth_estimate_gas_with_override_param_type_check(\n self,\n w3: \"Web3\",\n math_contract: \"Contract\",\n params: StateOverrideParams,\n ) -> None:\n txn_params: TxParams = {\"from\": w3.eth.coinbase}\n\n # assert does not raise\n w3.eth.estimate_gas(txn_params, None, {math_contract.address: params})\n\n def test_eth_getBlockByHash(self, w3: \"Web3\", empty_block: BlockData) -> None:\n block = w3.eth.get_block(empty_block[\"hash\"])\n assert block[\"hash\"] == empty_block[\"hash\"]\n assert block[\"receiptsRoot\"] == empty_block[\"receiptsRoot\"]\n assert block[\"logsBloom\"] == empty_block[\"logsBloom\"]\n\n def test_eth_getBlockByHash_not_found(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n with pytest.raises(BlockNotFound):\n w3.eth.get_block(UNKNOWN_HASH)\n\n def test_eth_getBlockByHash_pending(self, w3: \"Web3\") -> None:\n block = w3.eth.get_block(\"pending\")\n assert block[\"hash\"] is None\n\n def test_eth_getBlockByNumber_with_integer(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n block = 
w3.eth.get_block(empty_block[\"number\"])\n assert block[\"number\"] == empty_block[\"number\"]\n\n def test_eth_getBlockByNumber_latest(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n current_block_number = w3.eth.block_number\n block = w3.eth.get_block(\"latest\")\n assert block[\"number\"] == current_block_number\n\n def test_eth_getBlockByNumber_not_found(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n with pytest.raises(BlockNotFound):\n w3.eth.get_block(BlockNumber(12345))\n\n def test_eth_getBlockByNumber_pending(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n current_block_number = w3.eth.block_number\n block = w3.eth.get_block(\"pending\")\n assert block[\"number\"] == current_block_number + 1\n\n def test_eth_getBlockByNumber_earliest(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n genesis_block = w3.eth.get_block(BlockNumber(0))\n block = w3.eth.get_block(\"earliest\")\n assert block[\"number\"] == 0\n assert block[\"hash\"] == genesis_block[\"hash\"]\n\n def test_eth_getBlockByNumber_safe(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n block = w3.eth.get_block(\"safe\")\n assert block is not None\n assert isinstance(block[\"number\"], int)\n\n def test_eth_getBlockByNumber_finalized(\n self, w3: \"Web3\", empty_block: BlockData\n ) -> None:\n block = w3.eth.get_block(\"finalized\")\n assert block is not None\n assert isinstance(block[\"number\"], int)\n\n def test_eth_getBlockByNumber_full_transactions(\n self, w3: \"Web3\", block_with_txn: BlockData\n ) -> None:\n block = w3.eth.get_block(block_with_txn[\"number\"], True)\n transaction = block[\"transactions\"][0]\n assert transaction[\"hash\"] == block_with_txn[\"transactions\"][0] # type: ignore\n\n def test_eth_getTransactionByHash(self, w3: \"Web3\", mined_txn_hash: HexStr) -> None:\n transaction = w3.eth.get_transaction(mined_txn_hash)\n assert is_dict(transaction)\n assert transaction[\"hash\"] == HexBytes(mined_txn_hash)\n\n def test_eth_getTransactionByHash_contract_creation(\n self, w3: \"Web3\", math_contract_deploy_txn_hash: HexStr\n ) -> None:\n transaction = w3.eth.get_transaction(math_contract_deploy_txn_hash)\n assert is_dict(transaction)\n assert transaction[\"to\"] is None, f\"to field is {transaction['to']!r}\"\n\n def test_eth_getTransactionByBlockHashAndIndex(\n self, w3: \"Web3\", block_with_txn: BlockData, mined_txn_hash: HexStr\n ) -> None:\n transaction = w3.eth.get_transaction_by_block(block_with_txn[\"hash\"], 0)\n assert is_dict(transaction)\n assert transaction[\"hash\"] == HexBytes(mined_txn_hash)\n\n def test_eth_getTransactionByBlockNumberAndIndex(\n self, w3: \"Web3\", block_with_txn: BlockData, mined_txn_hash: HexStr\n ) -> None:\n transaction = w3.eth.get_transaction_by_block(block_with_txn[\"number\"], 0)\n assert is_dict(transaction)\n assert transaction[\"hash\"] == HexBytes(mined_txn_hash)\n\n def test_eth_get_transaction_receipt_mined(\n self, w3: \"Web3\", block_with_txn: BlockData, mined_txn_hash: HexStr\n ) -> None:\n receipt = w3.eth.get_transaction_receipt(mined_txn_hash)\n assert is_dict(receipt)\n assert receipt[\"blockNumber\"] == block_with_txn[\"number\"]\n assert receipt[\"blockHash\"] == block_with_txn[\"hash\"]\n assert receipt[\"transactionIndex\"] == 0\n assert receipt[\"transactionHash\"] == HexBytes(mined_txn_hash)\n assert is_checksum_address(receipt[\"to\"])\n assert receipt[\"from\"] is not None\n assert is_checksum_address(receipt[\"from\"])\n\n effective_gas_price = receipt[\"effectiveGasPrice\"]\n 
assert isinstance(effective_gas_price, int)\n assert effective_gas_price > 0\n\n def test_eth_get_transaction_receipt_unmined(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_hash = w3.eth.send_transaction(\n {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n )\n with pytest.raises(TransactionNotFound):\n w3.eth.get_transaction_receipt(txn_hash)\n\n def test_eth_get_transaction_receipt_with_log_entry(\n self,\n w3: \"Web3\",\n block_with_txn_with_log: BlockData,\n emitter_contract: \"Contract\",\n txn_hash_with_log: HexStr,\n ) -> None:\n receipt = w3.eth.get_transaction_receipt(txn_hash_with_log)\n assert is_dict(receipt)\n assert receipt[\"blockNumber\"] == block_with_txn_with_log[\"number\"]\n assert receipt[\"blockHash\"] == block_with_txn_with_log[\"hash\"]\n assert receipt[\"transactionIndex\"] == 0\n assert receipt[\"transactionHash\"] == HexBytes(txn_hash_with_log)\n\n assert len(receipt[\"logs\"]) == 1\n log_entry = receipt[\"logs\"][0]\n\n assert log_entry[\"blockNumber\"] == block_with_txn_with_log[\"number\"]\n assert log_entry[\"blockHash\"] == block_with_txn_with_log[\"hash\"]\n assert log_entry[\"logIndex\"] == 0\n assert is_same_address(log_entry[\"address\"], emitter_contract.address)\n assert log_entry[\"transactionIndex\"] == 0\n assert log_entry[\"transactionHash\"] == HexBytes(txn_hash_with_log)\n\n def test_eth_wait_for_transaction_receipt_mined(\n self, w3: \"Web3\", block_with_txn: BlockData, mined_txn_hash: HexStr\n ) -> None:\n receipt = w3.eth.wait_for_transaction_receipt(mined_txn_hash)\n assert is_dict(receipt)\n assert receipt[\"blockNumber\"] == block_with_txn[\"number\"]\n assert receipt[\"blockHash\"] == block_with_txn[\"hash\"]\n assert receipt[\"transactionIndex\"] == 0\n assert receipt[\"transactionHash\"] == HexBytes(mined_txn_hash)\n assert is_checksum_address(receipt[\"to\"])\n assert receipt[\"from\"] is not None\n assert is_checksum_address(receipt[\"from\"])\n\n effective_gas_price = receipt[\"effectiveGasPrice\"]\n assert isinstance(effective_gas_price, int)\n assert effective_gas_price > 0\n\n def test_eth_wait_for_transaction_receipt_unmined(\n self, w3: \"Web3\", unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_hash = w3.eth.send_transaction(\n {\n \"from\": unlocked_account_dual_type,\n \"to\": unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": w3.to_wei(1, \"gwei\"),\n }\n )\n\n timeout = 0.01\n with pytest.raises(TimeExhausted) as exc_info:\n w3.eth.wait_for_transaction_receipt(txn_hash, timeout=timeout)\n\n assert (_ in str(exc_info) for _ in [repr(txn_hash), timeout])\n\n def test_eth_wait_for_transaction_receipt_with_log_entry(\n self,\n w3: \"Web3\",\n block_with_txn_with_log: BlockData,\n emitter_contract: \"Contract\",\n txn_hash_with_log: HexStr,\n ) -> None:\n receipt = w3.eth.wait_for_transaction_receipt(txn_hash_with_log)\n assert is_dict(receipt)\n assert receipt[\"blockNumber\"] == block_with_txn_with_log[\"number\"]\n assert receipt[\"blockHash\"] == block_with_txn_with_log[\"hash\"]\n...\n", "needle_token_start": 12107, "needle_token_end": 12276, "code_context_ntokens": 16253, "output": ["```python\ndef test_eth_estimate_gas_revert_with_msg(\n self, w3: \"Web3\", revert_contract: \"Contract\", unlocked_account: 
ChecksumAddress\n) -> None:\n with pytest.raises(\n ContractLogicError, match=\"execution reverted: Function has been reverted\"\n ):\n txn_params = revert_contract._prepare_transaction(\n fn_name=\"revertWithMessage\",\n transaction={\n \"from\": unlocked_account,\n \"to\": revert_contract.address,\n },\n )\n w3.eth.estimate_gas(txn_params)\n```"]} +{"repo": "ethereum/web3.py", "name": "_make_get_request", "language": "python", "path": "web3/beacon/beacon.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to fetch data from a specified server endpoint and return the response in a structured format.\n2. **Input**: It takes a string representing the specific endpoint of a server from which data needs to be retrieved.\n3. **Output**: It returns a dictionary containing the data fetched from the server endpoint.\n4. **Procedure**: The function constructs the full URL by appending the endpoint to the base server URL. It then sends an HTTP GET request to this URL and waits for the response. The response is processed and returned as a JSON-formatted dictionary.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: web3/_utils/module_testing/web3_module.py\nimport pytest\nfrom typing import (\n Any,\n NoReturn,\n Sequence,\n Union,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n HexAddress,\n HexStr,\n TypeStr,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3 import (\n AsyncWeb3,\n Web3,\n)\nfrom web3._utils.ens import (\n ens_addresses,\n)\nfrom web3.exceptions import (\n InvalidAddress,\n)\n\n\nclass Web3ModuleTest:\n def test_web3_client_version(self, w3: Web3) -> None:\n client_version = w3.client_version\n self._check_web3_client_version(client_version)\n\n def _check_web3_client_version(self, client_version: str) -> NoReturn:\n raise NotImplementedError(\"Must be implemented by subclasses\")\n\n # Contract that calculated test values can be found at\n # https://kovan.etherscan.io/address/0xb9be06f5b99372cf9afbccadbbb9954ccaf7f4bb#code\n @pytest.mark.parametrize(\n \"types,values,expected\",\n (\n (\n [\"bool\"],\n [True],\n HexBytes(\n \"0x5fe7f977e71dba2ea1a68e21057beebb9be2ac30c6410aa38d4f3fbe41dcffd2\"\n ),\n ),\n (\n [\"uint8\", \"uint8\", \"uint8\"],\n [97, 98, 99],\n HexBytes(\n \"0x4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45\"\n ),\n ),\n (\n [\"uint248\"],\n [30],\n HexBytes(\n \"0x30f95d210785601eb33ae4d53d405b26f920e765dff87cca8e9a4aec99f82671\"\n ),\n ),\n (\n [\"bool\", \"uint16\"],\n [True, 299],\n HexBytes(\n \"0xed18599ccd80ee9fae9a28b0e34a5573c3233d7468f808fd659bc171cf0b43bd\"\n ),\n ),\n (\n [\"int256\"],\n [-10],\n HexBytes(\n \"0xd6fb717f7e270a360f5093ce6a7a3752183e89c9a9afe5c0cb54b458a304d3d5\"\n ),\n ),\n (\n [\"int256\"],\n [10],\n HexBytes(\n \"0xc65a7bb8d6351c1cf70c95a316cc6a92839c986682d98bc35f958f4883f9d2a8\"\n ),\n ),\n (\n [\"int8\", \"uint8\"],\n [-10, 18],\n HexBytes(\n \"0x5c6ab1e634c08d9c0f4df4d789e8727943ef010dd7ca8e3c89de197a26d148be\"\n ),\n ),\n (\n [\"address\"],\n [\"0x49eddd3769c0712032808d86597b84ac5c2f5614\"],\n InvalidAddress,\n ),\n (\n [\"address\"],\n [\"0x49EdDD3769c0712032808D86597B84ac5c2F5614\"],\n HexBytes(\n \"0x2ff37b5607484cd4eecf6d13292e22bd6e5401eaffcc07e279583bc742c68882\"\n ),\n ),\n (\n [\"bytes2\"],\n [\"0x5402\"],\n 
HexBytes(\n \"0x4ed9171bda52fca71ab28e7f452bd6eacc3e5a568a47e0fa53b503159a9b8910\"\n ),\n ),\n (\n [\"bytes3\"],\n [\"0x5402\"],\n HexBytes(\n \"0x4ed9171bda52fca71ab28e7f452bd6eacc3e5a568a47e0fa53b503159a9b8910\"\n ),\n ),\n (\n [\"bytes\"],\n [\n \"0x636865636b6c6f6e6762797465737472696e676167\"\n \"61696e7374736f6c6964697479736861336861736866756e6374696f6e\"\n ],\n HexBytes(\n \"0xd78a84d65721b67e4011b10c99dafdedcdcd7cb30153064f773e210b4762e22f\"\n ),\n ),\n (\n [\"string\"],\n [\"testing a string!\"],\n HexBytes(\n \"0xe8c275c0b4070a5ec6cfcb83f0ba394b30ddd283de785d43f2eabfb04bd96747\"\n ),\n ),\n (\n [\"string\", \"bool\", \"uint16\", \"bytes2\", \"address\"],\n [\n \"testing a string!\",\n False,\n 299,\n \"0x5402\",\n \"0x49eddd3769c0712032808d86597b84ac5c2f5614\",\n ],\n InvalidAddress,\n ),\n (\n [\"string\", \"bool\", \"uint16\", \"bytes2\", \"address\"],\n [\n \"testing a string!\",\n False,\n 299,\n \"0x5402\",\n \"0x49EdDD3769c0712032808D86597B84ac5c2F5614\",\n ],\n HexBytes(\n \"0x8cc6eabb25b842715e8ca39e2524ed946759aa37bfb7d4b81829cf5a7e266103\"\n ),\n ),\n (\n [\"bool[2][]\"],\n [[[True, False], [False, True]]],\n HexBytes(\n \"0x1eef261f2eb51a8c736d52be3f91ff79e78a9ec5df2b7f50d0c6f98ed1e2bc06\"\n ),\n ),\n (\n [\"bool[]\"],\n [[True, False, True]],\n HexBytes(\n \"0x5c6090c0461491a2941743bda5c3658bf1ea53bbd3edcde54e16205e18b45792\"\n ),\n ),\n (\n [\"uint24[]\"],\n [[1, 0, 1]],\n HexBytes(\n \"0x5c6090c0461491a2941743bda5c3658bf1ea53bbd3edcde54e16205e18b45792\"\n ),\n ),\n (\n [\"uint8[2]\"],\n [[8, 9]],\n HexBytes(\n \"0xc7694af312c4f286114180fd0ba6a52461fcee8a381636770b19a343af92538a\"\n ),\n ),\n (\n [\"uint256[2]\"],\n [[8, 9]],\n HexBytes(\n \"0xc7694af312c4f286114180fd0ba6a52461fcee8a381636770b19a343af92538a\"\n ),\n ),\n (\n [\"uint8[]\"],\n [[8]],\n HexBytes(\n \"0xf3f7a9fe364faab93b216da50a3214154f22a0a2b415b23a84c8169e8b636ee3\"\n ),\n ),\n (\n [\"address[]\"],\n [\n [\n \"0x49EdDD3769c0712032808D86597B84ac5c2F5614\",\n \"0xA6b759bBbf4B59D24acf7E06e79f3a5D104fdCE5\",\n ]\n ],\n HexBytes(\n...\n# Path: web3/_utils/module_testing/__init__.py\nfrom .eth_module import (\n AsyncEthModuleTest,\n EthModuleTest,\n)\nfrom .go_ethereum_admin_module import (\n GoEthereumAdminModuleTest,\n)\nfrom .go_ethereum_personal_module import (\n GoEthereumPersonalModuleTest,\n)\nfrom .go_ethereum_txpool_module import (\n GoEthereumAsyncTxPoolModuleTest,\n GoEthereumTxPoolModuleTest,\n)\nfrom .net_module import (\n AsyncNetModuleTest,\n NetModuleTest,\n)\nfrom .web3_module import (\n Web3ModuleTest,\n)\n\n# Path: web3/_utils/module_testing/persistent_connection_provider.py\nimport asyncio\nimport pytest\nfrom typing import (\n TYPE_CHECKING,\n Any,\n Dict,\n Tuple,\n cast,\n)\n\nfrom eth_utils import (\n is_hexstr,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3.datastructures import (\n AttributeDict,\n)\nfrom web3.middleware import (\n ExtraDataToPOAMiddleware,\n)\nfrom web3.types import (\n FormattedEthSubscriptionResponse,\n)\n\nif TYPE_CHECKING:\n from web3.main import (\n AsyncWeb3,\n )\n\n\nclass PersistentConnectionProviderTest:\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"subscription_params,ws_subscription_response,expected_formatted_result\",\n (\n (\n (\"newHeads\",),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": {\n \"number\": \"0x539\",\n \"hash\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"parentHash\": 
\"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"sha3Uncles\": \"0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347\", # noqa: E501\n \"logsBloom\": \"0x00\",\n \"transactionsRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"stateRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"receiptsRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"miner\": \"0x0000000000000000000000000000000000000000\",\n \"difficulty\": \"0x0\",\n \"extraData\": \"0x496c6c756d696e61746520446d6f63726174697a6520447374726962757465\", # noqa: E501\n \"gasLimit\": \"0x1c9c380\",\n \"gasUsed\": \"0xd1ce44\",\n \"timestamp\": \"0x539\",\n \"baseFeePerGas\": \"0x26f93fef9\",\n \"withdrawalsRoot\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"nonce\": \"0x0000000000000000\",\n \"mixHash\": \"0x73e9e036ec894047f29954571d4b6d9e8717de7304269c263cbf150caa4e0768\", # noqa: E501\n },\n },\n },\n AttributeDict(\n {\n \"number\": 1337,\n \"hash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # noqa: E501\n ),\n \"parentHash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # noqa: E501\n ),\n \"sha3Uncles\": HexBytes(\n \"0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347\" # noqa: E501\n ),\n \"logsBloom\": HexBytes(\"0x00\"),\n \"transactionsRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"stateRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"receiptsRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"miner\": \"0x0000000000000000000000000000000000000000\",\n \"difficulty\": 0,\n \"extraData\": HexBytes(\n \"0x496c6c756d696e61746520446d6f63726174697a6520447374726962757465\" # noqa: E501\n ),\n \"gasLimit\": 30000000,\n \"gasUsed\": 13749828,\n \"timestamp\": 1337,\n \"baseFeePerGas\": 10461904633,\n \"withdrawalsRoot\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"nonce\": HexBytes(\"0x0000000000000000\"),\n \"mixHash\": HexBytes(\n \"0x73e9e036ec894047f29954571d4b6d9e8717de7304269c263cbf150caa4e0768\" # noqa: E501\n ),\n }\n ),\n ),\n (\n (\"newPendingTransactions\", True),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": {\n \"blockHash\": None,\n \"blockNumber\": None,\n \"from\": \"0x0000000000000000000000000000000000000000\",\n \"gas\": \"0xf2f4\",\n \"gasPrice\": \"0x29035f36f\",\n \"maxFeePerGas\": \"0x29035f36f\",\n \"maxPriorityFeePerGas\": \"0x3b9aca00\",\n \"hash\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"input\": \"0x00\",\n \"nonce\": \"0x2013\",\n \"to\": \"0x0000000000000000000000000000000000000000\",\n \"transactionIndex\": None,\n \"value\": \"0x0\",\n \"type\": \"0x2\",\n \"accessList\": [],\n \"chainId\": \"0x1\",\n \"v\": \"0x1\",\n \"r\": \"0x3c144a7c00ed3118d55445cd5be2ae4620ca377f7c685e9c5f3687671d4dece1\", # noqa: E501\n \"s\": \"0x284de67cbf75fec8a9edb368dee3a37cf6faba87f0af4413b2f869ebfa87d002\", # noqa: E501\n \"yParity\": \"0x1\",\n },\n },\n },\n AttributeDict(\n {\n \"blockHash\": 
None,\n \"blockNumber\": None,\n \"from\": \"0x0000000000000000000000000000000000000000\",\n \"gas\": 62196,\n \"gasPrice\": 11009389423,\n \"maxFeePerGas\": 11009389423,\n \"maxPriorityFeePerGas\": 1000000000,\n \"hash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # noqa: E501\n ),\n \"input\": HexBytes(\"0x00\"),\n \"nonce\": 8211,\n \"to\": \"0x0000000000000000000000000000000000000000\",\n \"transactionIndex\": None,\n \"value\": 0,\n \"type\": 2,\n \"accessList\": [],\n \"chainId\": 1,\n \"v\": 1,\n \"r\": HexBytes(\n \"0x3c144a7c00ed3118d55445cd5be2ae4620ca377f7c685e9c5f3687671d4dece1\" # noqa: E501\n ),\n \"s\": HexBytes(\n \"0x284de67cbf75fec8a9edb368dee3a37cf6faba87f0af4413b2f869ebfa87d002\" # noqa: E501\n ),\n \"yParity\": 1,\n }\n ),\n ),\n (\n (\"newPendingTransactions\", False),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n },\n },\n HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\"\n ),\n ),\n (\n (\"logs\", {\"address\": \"0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2\"}),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": {\n \"removed\": False,\n \"logIndex\": \"0x0\",\n \"transactionIndex\": \"0x0\",\n \"transactionHash\": \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\", # noqa: E501\n \"blockHash\": \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\", # noqa: E501\n \"blockNumber\": \"0x539\",\n \"address\": \"0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2\",\n \"data\": \"0x00\",\n \"topics\": [\n \"0xe1fffdd4923d04f559f4d29e8bfc6cda04eb5b0d3c460751c2402c5c5cc9105c\", # noqa: E501\n \"0x00000000000000000000000016250d5630b4cf539739df2c5dacb4c659f2482d\", # noqa: E501\n ],\n },\n },\n },\n AttributeDict(\n {\n \"removed\": False,\n \"logIndex\": 0,\n \"transactionIndex\": 0,\n \"transactionHash\": HexBytes(\n \"0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988\" # noqa: E501\n ),\n \"blockHash\": HexBytes(\n \"0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e\" # noqa: E501\n ),\n \"blockNumber\": 1337,\n \"address\": \"0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2\",\n \"data\": HexBytes(\"0x00\"),\n \"topics\": [\n HexBytes(\n \"0xe1fffdd4923d04f559f4d29e8bfc6cda04eb5b0d3c460751c2402c5c5cc9105c\" # noqa: E501\n ),\n HexBytes(\n \"0x00000000000000000000000016250d5630b4cf539739df2c5dacb4c659f2482d\" # noqa: E501\n ),\n ],\n }\n ),\n ),\n (\n (\"syncing\",),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": False,\n },\n },\n False,\n ),\n (\n (\"syncing\",),\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": \"THIS_WILL_BE_REPLACED_IN_THE_TEST\",\n \"result\": {\n \"isSyncing\": True,\n \"startingBlock\": \"0x0\",\n \"currentBlock\": \"0x4346fe\",\n \"highestBlock\": \"0x434806\",\n },\n },\n },\n AttributeDict(\n {\n \"isSyncing\": True,\n \"startingBlock\": 0,\n \"currentBlock\": 4409086,\n \"highestBlock\": 4409350,\n }\n ),\n ),\n ),\n ids=[\n \"newHeads\",\n \"newPendingTransactions-FullTxs\",\n \"newPendingTransactions-TxHashes\",\n \"logs\",\n \"syncing-False\",\n \"syncing-True\",\n ],\n 
)\n async def test_async_eth_subscribe_mocked(\n self,\n async_w3: \"AsyncWeb3\",\n subscription_params: Tuple[Any, ...],\n ws_subscription_response: Dict[str, Any],\n expected_formatted_result: Any,\n ) -> None:\n sub_id = await async_w3.eth.subscribe(*subscription_params)\n assert is_hexstr(sub_id)\n\n # stub out the subscription id so we know how to process the response\n ws_subscription_response[\"params\"][\"subscription\"] = sub_id\n\n # add the response to the subscription response cache as if it came from the\n # websocket connection\n await async_w3.provider._request_processor.cache_raw_response(\n ws_subscription_response, subscription=True\n )\n\n async for msg in async_w3.socket.process_subscriptions():\n response = cast(FormattedEthSubscriptionResponse, msg)\n assert response[\"subscription\"] == sub_id\n assert response[\"result\"] == expected_formatted_result\n\n # only testing one message, so break here\n break\n\n @pytest.mark.asyncio\n async def test_async_extradata_poa_middleware_on_eth_subscription(\n self,\n async_w3: \"AsyncWeb3\",\n ) -> None:\n async_w3.middleware_onion.inject(\n ExtraDataToPOAMiddleware, \"poa_middleware\", layer=0\n )\n\n sub_id = await async_w3.eth.subscribe(\"newHeads\")\n assert is_hexstr(sub_id)\n\n # add the response to the subscription response cache as if it came from the\n # websocket connection\n await async_w3.provider._request_processor.cache_raw_response(\n {\n \"jsonrpc\": \"2.0\",\n \"method\": \"eth_subscription\",\n \"params\": {\n \"subscription\": sub_id,\n \"result\": {\n \"extraData\": f\"0x{'00' * 100}\",\n },\n },\n },\n subscription=True,\n )\n\n async for msg in async_w3.socket.process_subscriptions():\n response = cast(FormattedEthSubscriptionResponse, msg)\n assert response.keys() == {\"subscription\", \"result\"}\n assert response[\"subscription\"] == sub_id\n assert response[\"result\"][\"proofOfAuthorityData\"] == HexBytes( # type: ignore # noqa: E501\n f\"0x{'00' * 100}\"\n )\n\n # only testing one message, so break here\n break\n\n # clean up\n async_w3.middleware_onion.remove(\"poa_middleware\")\n\n @pytest.mark.asyncio\n async def test_asyncio_gather_for_multiple_requests_matches_the_responses(\n self,\n async_w3: \"AsyncWeb3\",\n ) -> None:\n (\n latest,\n chain_id,\n block_num,\n chain_id2,\n pending,\n chain_id3,\n ) = await asyncio.gather(\n async_w3.eth.get_block(\"latest\"),\n async_w3.eth.chain_id,\n async_w3.eth.block_number,\n async_w3.eth.chain_id,\n async_w3.eth.get_block(\"pending\"),\n async_w3.eth.chain_id,\n )\n\n # assert attrdict middleware was applied appropriately\n assert isinstance(latest, AttributeDict)\n assert isinstance(pending, AttributeDict)\n\n # assert block values\n some_block_keys = [\n \"number\",\n \"hash\",\n \"parentHash\",\n \"transactionsRoot\",\n \"stateRoot\",\n \"receiptsRoot\",\n \"size\",\n \"gasLimit\",\n \"gasUsed\",\n \"timestamp\",\n \"transactions\",\n \"baseFeePerGas\",\n ]\n assert all(k in latest.keys() for k in some_block_keys)\n assert all(k in pending.keys() for k in some_block_keys)\n\n assert isinstance(block_num, int)\n assert latest[\"number\"] == block_num\n\n assert isinstance(chain_id, int)\n assert isinstance(chain_id2, int)\n assert isinstance(chain_id3, int)\n\n# Path: web3/auto/__init__.py\nfrom web3 import (\n Web3,\n)\n\nw3 = Web3()\n\n# Path: web3/auto/gethdev.py\nfrom web3 import (\n AsyncIPCProvider,\n AsyncWeb3,\n IPCProvider,\n Web3,\n)\nfrom web3.middleware import (\n ExtraDataToPOAMiddleware,\n)\nfrom web3.providers.ipc import (\n 
get_dev_ipc_path,\n)\n\nw3 = Web3(IPCProvider(get_dev_ipc_path()))\nw3.middleware_onion.inject(ExtraDataToPOAMiddleware, layer=0)\n\nasync_w3 = AsyncWeb3(AsyncIPCProvider(get_dev_ipc_path()))\nasync_w3.middleware_onion.inject(ExtraDataToPOAMiddleware, layer=0)\n\n# Path: web3/beacon/api_endpoints.py\n# [ BEACON endpoints ]\n\nGET_GENESIS = \"/eth/v1/beacon/genesis\"\n\n# states\nGET_HASH_ROOT = \"/eth/v1/beacon/states/{0}/root\"\nGET_FORK_DATA = \"/eth/v1/beacon/states/{0}/fork\"\nGET_FINALITY_CHECKPOINT = \"/eth/v1/beacon/states/{0}/finality_checkpoints\"\nGET_VALIDATORS = \"/eth/v1/beacon/states/{0}/validators\"\nGET_VALIDATOR = \"/eth/v1/beacon/states/{0}/validators/{1}\"\nGET_VALIDATOR_BALANCES = \"/eth/v1/beacon/states/{0}/validator_balances\"\nGET_EPOCH_COMMITTEES = \"/eth/v1/beacon/states/{0}/committees\"\nGET_EPOCH_SYNC_COMMITTEES = \"/eth/v1/beacon/states/{0}/sync_committees\"\nGET_EPOCH_RANDAO = \"/eth/v1/beacon/states/{0}/randao\"\n\n# headers\nGET_BLOCK_HEADERS = \"/eth/v1/beacon/headers\"\nGET_BLOCK_HEADER = \"/eth/v1/beacon/headers/{0}\"\n\n# blocks\nGET_BLOCK = \"/eth/v2/beacon/blocks/{0}\"\nGET_BLOCK_ROOT = \"/eth/v1/beacon/blocks/{0}/root\"\nGET_BLOCK_ATTESTATIONS = \"/eth/v1/beacon/blocks/{0}/attestations\"\nGET_BLINDED_BLOCKS = \"/eth/v1/beacon/blinded_blocks/{0}\"\n\n# rewards\nGET_REWARDS = \"/eth/v1/beacon/rewards/blocks/{0}\"\n\n# light client\nGET_LIGHT_CLIENT_BOOTSTRAP_STRUCTURE = \"/eth/v1/beacon/light_client/bootstrap/{0}\"\nGET_LIGHT_CLIENT_UPDATES = \"/eth/v1/beacon/light_client/updates\"\nGET_LIGHT_CLIENT_FINALITY_UPDATE = \"/eth/v1/beacon/light_client/finality_update\"\nGET_LIGHT_CLIENT_OPTIMISTIC_UPDATE = \"/eth/v1/beacon/light_client/optimistic_update\"\n\n# pool\nGET_ATTESTATIONS = \"/eth/v1/beacon/pool/attestations\"\nGET_ATTESTER_SLASHINGS = \"/eth/v1/beacon/pool/attester_slashings\"\nGET_PROPOSER_SLASHINGS = \"/eth/v1/beacon/pool/proposer_slashings\"\nGET_VOLUNTARY_EXITS = \"/eth/v1/beacon/pool/voluntary_exits\"\nGET_BLS_TO_EXECUTION_CHANGES = \"/eth/v1/beacon/pool/bls_to_execution_changes\"\n\n\n# [ CONFIG endpoints ]\n\nGET_FORK_SCHEDULE = \"/eth/v1/config/fork_schedule\"\nGET_SPEC = \"/eth/v1/config/spec\"\nGET_DEPOSIT_CONTRACT = \"/eth/v1/config/deposit_contract\"\n\n# [ DEBUG endpoints ]\n\nGET_BEACON_STATE = \"/eth/v1/debug/beacon/states/{0}\"\nGET_BEACON_HEADS = \"/eth/v1/debug/beacon/heads\"\n\n# [ NODE endpoints ]\n\nGET_NODE_IDENTITY = \"/eth/v1/node/identity\"\nGET_PEERS = \"/eth/v1/node/peers\"\nGET_PEER = \"/eth/v1/node/peers/{0}\"\nGET_HEALTH = \"/eth/v1/node/health\"\nGET_VERSION = \"/eth/v1/node/version\"\nGET_SYNCING = \"/eth/v1/node/syncing\"\n\n# Path: web3/beacon/async_beacon.py\nfrom typing import (\n Any,\n Dict,\n)\n\nfrom eth_typing import (\n URI,\n HexStr,\n)\n\nfrom web3._utils.request import (\n async_get_response_from_get_request,\n async_json_make_get_request,\n)\nfrom web3.beacon.api_endpoints import (\n GET_ATTESTATIONS,\n GET_ATTESTER_SLASHINGS,\n GET_BEACON_HEADS,\n GET_BEACON_STATE,\n GET_BLINDED_BLOCKS,\n GET_BLOCK,\n GET_BLOCK_ATTESTATIONS,\n GET_BLOCK_HEADER,\n GET_BLOCK_HEADERS,\n GET_BLOCK_ROOT,\n GET_BLS_TO_EXECUTION_CHANGES,\n GET_DEPOSIT_CONTRACT,\n GET_EPOCH_COMMITTEES,\n GET_EPOCH_RANDAO,\n GET_EPOCH_SYNC_COMMITTEES,\n GET_FINALITY_CHECKPOINT,\n GET_FORK_DATA,\n GET_FORK_SCHEDULE,\n GET_GENESIS,\n GET_HASH_ROOT,\n GET_HEALTH,\n GET_LIGHT_CLIENT_BOOTSTRAP_STRUCTURE,\n GET_LIGHT_CLIENT_FINALITY_UPDATE,\n GET_LIGHT_CLIENT_OPTIMISTIC_UPDATE,\n GET_LIGHT_CLIENT_UPDATES,\n GET_NODE_IDENTITY,\n GET_PEER,\n 
GET_PEERS,\n GET_PROPOSER_SLASHINGS,\n GET_REWARDS,\n GET_SPEC,\n GET_SYNCING,\n GET_VALIDATOR,\n GET_VALIDATOR_BALANCES,\n GET_VALIDATORS,\n GET_VERSION,\n GET_VOLUNTARY_EXITS,\n)\n\n\nclass AsyncBeacon:\n is_async = True\n\n def __init__(\n self,\n base_url: str,\n request_timeout: float = 10.0,\n ) -> None:\n self.base_url = base_url\n self.request_timeout = request_timeout\n\n async def _async_make_get_request(self, endpoint_uri: str) -> Dict[str, Any]:\n uri = URI(self.base_url + endpoint_uri)\n return await async_json_make_get_request(uri, timeout=self.request_timeout)\n\n # [ BEACON endpoints ]\n\n # states\n\n async def get_genesis(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_GENESIS)\n\n async def get_hash_root(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(GET_HASH_ROOT.format(state_id))\n\n async def get_fork_data(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(GET_FORK_DATA.format(state_id))\n\n async def get_finality_checkpoint(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(\n GET_FINALITY_CHECKPOINT.format(state_id)\n )\n\n async def get_validators(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(GET_VALIDATORS.format(state_id))\n\n async def get_validator(\n self, validator_id: str, state_id: str = \"head\"\n ) -> Dict[str, Any]:\n return await self._async_make_get_request(\n GET_VALIDATOR.format(state_id, validator_id)\n )\n\n async def get_validator_balances(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(\n GET_VALIDATOR_BALANCES.format(state_id)\n )\n\n async def get_epoch_committees(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(GET_EPOCH_COMMITTEES.format(state_id))\n\n async def get_epoch_sync_committees(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(\n GET_EPOCH_SYNC_COMMITTEES.format(state_id)\n )\n\n async def get_epoch_randao(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(GET_EPOCH_RANDAO.format(state_id))\n\n # headers\n\n async def get_block_headers(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BLOCK_HEADERS)\n\n async def get_block_header(self, block_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BLOCK_HEADER.format(block_id))\n\n # block\n\n async def get_block(self, block_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BLOCK.format(block_id))\n\n async def get_block_root(self, block_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BLOCK_ROOT.format(block_id))\n\n async def get_block_attestations(self, block_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(\n GET_BLOCK_ATTESTATIONS.format(block_id)\n )\n\n async def get_blinded_blocks(self, block_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BLINDED_BLOCKS.format(block_id))\n\n # rewards\n\n async def get_rewards(self, block_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_REWARDS.format(block_id))\n\n # light client (untested but follows spec)\n\n async def get_light_client_bootstrap_structure(\n self, block_root: HexStr\n ) -> Dict[str, Any]:\n return await self._async_make_get_request(\n 
GET_LIGHT_CLIENT_BOOTSTRAP_STRUCTURE.format(block_root)\n )\n\n async def get_light_client_updates(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_LIGHT_CLIENT_UPDATES)\n\n async def get_light_client_finality_update(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_LIGHT_CLIENT_FINALITY_UPDATE)\n\n async def get_light_client_optimistic_update(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_LIGHT_CLIENT_OPTIMISTIC_UPDATE)\n\n # pool\n\n async def get_attestations(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_ATTESTATIONS)\n\n async def get_attester_slashings(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_ATTESTER_SLASHINGS)\n\n async def get_proposer_slashings(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_PROPOSER_SLASHINGS)\n\n async def get_voluntary_exits(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_VOLUNTARY_EXITS)\n\n async def get_bls_to_execution_changes(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BLS_TO_EXECUTION_CHANGES)\n\n # [ CONFIG endpoints ]\n\n async def get_fork_schedule(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_FORK_SCHEDULE)\n\n async def get_spec(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_SPEC)\n\n async def get_deposit_contract(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_DEPOSIT_CONTRACT)\n\n # [ DEBUG endpoints ]\n\n async def get_beacon_state(self, state_id: str = \"head\") -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BEACON_STATE.format(state_id))\n\n async def get_beacon_heads(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_BEACON_HEADS)\n\n # [ NODE endpoints ]\n\n async def get_node_identity(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_NODE_IDENTITY)\n\n async def get_peers(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_PEERS)\n\n async def get_peer(self, peer_id: str) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_PEER.format(peer_id))\n\n async def get_health(self) -> int:\n url = URI(self.base_url + GET_HEALTH)\n response = await async_get_response_from_get_request(url)\n return response.status\n\n async def get_version(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_VERSION)\n\n async def get_syncing(self) -> Dict[str, Any]:\n return await self._async_make_get_request(GET_SYNCING)\n\n# Path: web3/beacon/beacon.py\nfrom typing import (\n Any,\n Dict,\n)\n\nfrom eth_typing import (\n URI,\n HexStr,\n)\n\nfrom web3._utils.request import (\n get_response_from_get_request,\n json_make_get_request,\n)\nfrom web3.beacon.api_endpoints import (\n GET_ATTESTATIONS,\n GET_ATTESTER_SLASHINGS,\n GET_BEACON_HEADS,\n GET_BEACON_STATE,\n GET_BLINDED_BLOCKS,\n GET_BLOCK,\n GET_BLOCK_ATTESTATIONS,\n GET_BLOCK_HEADER,\n GET_BLOCK_HEADERS,\n GET_BLOCK_ROOT,\n GET_BLS_TO_EXECUTION_CHANGES,\n GET_DEPOSIT_CONTRACT,\n GET_EPOCH_COMMITTEES,\n GET_EPOCH_RANDAO,\n GET_EPOCH_SYNC_COMMITTEES,\n GET_FINALITY_CHECKPOINT,\n GET_FORK_DATA,\n GET_FORK_SCHEDULE,\n GET_GENESIS,\n GET_HASH_ROOT,\n GET_HEALTH,\n GET_LIGHT_CLIENT_BOOTSTRAP_STRUCTURE,\n GET_LIGHT_CLIENT_FINALITY_UPDATE,\n GET_LIGHT_CLIENT_OPTIMISTIC_UPDATE,\n GET_LIGHT_CLIENT_UPDATES,\n GET_NODE_IDENTITY,\n GET_PEER,\n GET_PEERS,\n GET_PROPOSER_SLASHINGS,\n GET_REWARDS,\n GET_SPEC,\n 
GET_SYNCING,\n GET_VALIDATOR,\n GET_VALIDATOR_BALANCES,\n GET_VALIDATORS,\n GET_VERSION,\n GET_VOLUNTARY_EXITS,\n)\n\n\nclass Beacon:\n def __init__(\n self,\n base_url: str,\n request_timeout: float = 10.0,\n ) -> None:\n self.base_url = base_url\n self.request_timeout = request_timeout\n\n \ndef _make_get_request(self, endpoint_url: str) -> Dict[str, Any]:\n uri = URI(self.base_url + endpoint_url)\n return json_make_get_request(uri, timeout=self.request_timeout)\n\n # [ BEACON endpoints ]\n\n # states\n\n def get_genesis(self) -> Dict[str, Any]:\n return self._make_get_request(GET_GENESIS)\n\n def get_hash_root(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_HASH_ROOT.format(state_id))\n\n def get_fork_data(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_FORK_DATA.format(state_id))\n\n def get_finality_checkpoint(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_FINALITY_CHECKPOINT.format(state_id))\n\n def get_validators(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_VALIDATORS.format(state_id))\n\n def get_validator(\n self, validator_id: str, state_id: str = \"head\"\n ) -> Dict[str, Any]:\n return self._make_get_request(GET_VALIDATOR.format(state_id, validator_id))\n\n def get_validator_balances(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_VALIDATOR_BALANCES.format(state_id))\n\n def get_epoch_committees(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_EPOCH_COMMITTEES.format(state_id))\n\n def get_epoch_sync_committees(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_EPOCH_SYNC_COMMITTEES.format(state_id))\n\n def get_epoch_randao(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_EPOCH_RANDAO.format(state_id))\n\n # headers\n\n def get_block_headers(self) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_HEADERS)\n\n def get_block_header(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_HEADER.format(block_id))\n\n # blocks\n\n def get_block(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK.format(block_id))\n\n def get_block_root(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_ROOT.format(block_id))\n\n def get_block_attestations(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_ATTESTATIONS.format(block_id))\n\n def get_blinded_blocks(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLINDED_BLOCKS.format(block_id))\n\n # rewards\n\n def get_rewards(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_REWARDS.format(block_id))\n\n # light client (untested but follows spec)\n\n def get_light_client_bootstrap_structure(\n self, block_root: HexStr\n ) -> Dict[str, Any]:\n return self._make_get_request(\n GET_LIGHT_CLIENT_BOOTSTRAP_STRUCTURE.format(block_root)\n )\n\n def get_light_client_updates(self) -> Dict[str, Any]:\n return self._make_get_request(GET_LIGHT_CLIENT_UPDATES)\n\n def get_light_client_finality_update(self) -> Dict[str, Any]:\n return self._make_get_request(GET_LIGHT_CLIENT_FINALITY_UPDATE)\n\n def get_light_client_optimistic_update(self) -> Dict[str, Any]:\n return self._make_get_request(GET_LIGHT_CLIENT_OPTIMISTIC_UPDATE)\n\n # pool\n\n def get_attestations(self) -> Dict[str, 
Any]:\n return self._make_get_request(GET_ATTESTATIONS)\n\n def get_attester_slashings(self) -> Dict[str, Any]:\n return self._make_get_request(GET_ATTESTER_SLASHINGS)\n\n def get_proposer_slashings(self) -> Dict[str, Any]:\n return self._make_get_request(GET_PROPOSER_SLASHINGS)\n\n def get_voluntary_exits(self) -> Dict[str, Any]:\n return self._make_get_request(GET_VOLUNTARY_EXITS)\n\n def get_bls_to_execution_changes(self) -> Dict[str, Any]:\n return self._make_get_request(GET_BLS_TO_EXECUTION_CHANGES)\n\n # [ CONFIG endpoints ]\n\n def get_fork_schedule(self) -> Dict[str, Any]:\n return self._make_get_request(GET_FORK_SCHEDULE)\n\n def get_spec(self) -> Dict[str, Any]:\n return self._make_get_request(GET_SPEC)\n\n def get_deposit_contract(self) -> Dict[str, Any]:\n return self._make_get_request(GET_DEPOSIT_CONTRACT)\n\n # [ DEBUG endpoints ]\n\n def get_beacon_state(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_BEACON_STATE.format(state_id))\n\n def get_beacon_heads(self) -> Dict[str, Any]:\n return self._make_get_request(GET_BEACON_HEADS)\n\n # [ NODE endpoints ]\n\n def get_node_identity(self) -> Dict[str, Any]:\n return self._make_get_request(GET_NODE_IDENTITY)\n\n def get_peers(self) -> Dict[str, Any]:\n return self._make_get_request(GET_PEERS)\n\n def get_peer(self, peer_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_PEER.format(peer_id))\n\n def get_health(self) -> int:\n url = URI(self.base_url + GET_HEALTH)\n response = get_response_from_get_request(url)\n return response.status_code\n\n def get_version(self) -> Dict[str, Any]:\n return self._make_get_request(GET_VERSION)\n\n def get_syncing(self) -> Dict[str, Any]:\n return self._make_get_request(GET_SYNCING)\n\n# Path: web3/beacon/__init__.py\nfrom .async_beacon import AsyncBeacon\nfrom .beacon import Beacon\n\n# Path: web3/gas_strategies/rpc.py\nfrom typing import (\n Optional,\n)\n\nfrom web3 import (\n Web3,\n)\nfrom web3.types import (\n TxParams,\n Wei,\n)\n\n\ndef rpc_gas_price_strategy(\n w3: Web3, transaction_params: Optional[TxParams] = None\n) -> Wei:\n \"\"\"\n A simple gas price strategy deriving it's value from the eth_gasPrice JSON-RPC call.\n \"\"\"\n return w3.eth.gas_price\n\n# Path: web3/gas_strategies/time_based.py\nimport collections\nimport math\nimport operator\nfrom typing import (\n Iterable,\n Sequence,\n Tuple,\n)\n\nfrom eth_typing import (\n ChecksumAddress,\n)\nfrom eth_utils import (\n to_tuple,\n)\nfrom eth_utils.toolz import (\n curry,\n groupby,\n sliding_window,\n)\nfrom hexbytes import (\n HexBytes,\n)\n\nfrom web3 import (\n Web3,\n)\nfrom web3._utils.math import (\n percentile,\n)\nfrom web3.exceptions import (\n InsufficientData,\n Web3ValidationError,\n)\nfrom web3.types import (\n BlockNumber,\n GasPriceStrategy,\n TxParams,\n Wei,\n)\n\nMinerData = collections.namedtuple(\n \"MinerData\", [\"miner\", \"num_blocks\", \"min_gas_price\", \"low_percentile_gas_price\"]\n)\n...\n", "needle_token_start": 13728, "needle_token_end": 13789, "code_context_ntokens": 16145, "output": ["```python\nfrom typing import (\n Any,\n Dict,\n)\n\nfrom eth_typing import (\n URI,\n HexStr,\n)\n\nfrom web3._utils.request import (\n get_response_from_get_request,\n json_make_get_request,\n)\nfrom web3.beacon.api_endpoints import (\n GET_ATTESTATIONS,\n GET_ATTESTER_SLASHINGS,\n GET_BEACON_HEADS,\n GET_BEACON_STATE,\n GET_BLINDED_BLOCKS,\n GET_BLOCK,\n GET_BLOCK_ATTESTATIONS,\n GET_BLOCK_HEADER,\n GET_BLOCK_HEADERS,\n GET_BLOCK_ROOT,\n 
GET_BLS_TO_EXECUTION_CHANGES,\n GET_DEPOSIT_CONTRACT,\n GET_EPOCH_COMMITTEES,\n GET_EPOCH_RANDAO,\n GET_EPOCH_SYNC_COMMITTEES,\n GET_FINALITY_CHECKPOINT,\n GET_FORK_DATA,\n GET_FORK_SCHEDULE,\n GET_GENESIS,\n GET_HASH_ROOT,\n GET_HEALTH,\n GET_LIGHT_CLIENT_BOOTSTRAP_STRUCTURE,\n GET_LIGHT_CLIENT_FINALITY_UPDATE,\n GET_LIGHT_CLIENT_OPTIMISTIC_UPDATE,\n GET_LIGHT_CLIENT_UPDATES,\n GET_NODE_IDENTITY,\n GET_PEER,\n GET_PEERS,\n GET_PROPOSER_SLASHINGS,\n GET_REWARDS,\n GET_SPEC,\n GET_SYNCING,\n GET_VALIDATOR,\n GET_VALIDATOR_BALANCES,\n GET_VALIDATORS,\n GET_VERSION,\n GET_VOLUNTARY_EXITS,\n)\n\n\nclass Beacon:\n def __init__(\n self,\n base_url: str,\n request_timeout: float = 10.0,\n ) -> None:\n self.base_url = base_url\n self.request_timeout = request_timeout\n\n def _make_get_request(self, endpoint_url: str) -> Dict[str, Any]:\n uri = URI(self.base_url + endpoint_url)\n return json_make_get_request(uri, timeout=self.request_timeout)\n\n # [ BEACON endpoints ]\n\n # states\n\n def get_genesis(self) -> Dict[str, Any]:\n return self._make_get_request(GET_GENESIS)\n\n def get_hash_root(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_HASH_ROOT.format(state_id))\n\n def get_fork_data(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_FORK_DATA.format(state_id))\n\n def get_finality_checkpoint(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_FINALITY_CHECKPOINT.format(state_id))\n\n def get_validators(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_VALIDATORS.format(state_id))\n\n def get_validator(\n self, validator_id: str, state_id: str = \"head\"\n ) -> Dict[str, Any]:\n return self._make_get_request(GET_VALIDATOR.format(state_id, validator_id))\n\n def get_validator_balances(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_VALIDATOR_BALANCES.format(state_id))\n\n def get_epoch_committees(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_EPOCH_COMMITTEES.format(state_id))\n\n def get_epoch_sync_committees(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_EPOCH_SYNC_COMMITTEES.format(state_id))\n\n def get_epoch_randao(self, state_id: str = \"head\") -> Dict[str, Any]:\n return self._make_get_request(GET_EPOCH_RANDAO.format(state_id))\n\n # headers\n\n def get_block_headers(self) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_HEADERS)\n\n def get_block_header(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_HEADER.format(block_id))\n\n # blocks\n\n def get_block(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK.format(block_id))\n\n def get_block_root(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_ROOT.format(block_id))\n\n def get_block_attestations(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLOCK_ATTESTATIONS.format(block_id))\n\n def get_blinded_blocks(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_BLINDED_BLOCKS.format(block_id))\n\n # rewards\n\n def get_rewards(self, block_id: str) -> Dict[str, Any]:\n return self._make_get_request(GET_REWARDS.format(block_id))\n\n # light client (untested but follows spec)\n\n def get_light_client"]} +{"repo": "ethereum/web3.py", "name": "test_eth_call_revert_custom_error_with_msg", "language": "python", "path": 
"web3/_utils/module_testing/eth_module.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To verify that a specific Ethereum smart contract function correctly raises a custom error with a predefined message when called under unauthorized conditions.\n2. **Input**: The test function inputs include a transaction parameter set specifying the caller and the contract address, and the encoded ABI of the unauthorized action.\n3. **Output**: The expected output is the raising of a custom error that matches the encoded unauthorized action message.\n4. **Procedure**: The test initiates by encoding the unauthorized action and its message using the contract's ABI. It then prepares a transaction to call the custom error function from an unauthorized account. The test asserts that this transaction call raises the expected custom error with the specific message encoded earlier.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " ), # must be greater than base_fee post London\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = await async_w3.eth.modify_transaction(\n txn_hash, gasPrice=(cast(int, txn_params[\"gasPrice\"]) * 2), value=2\n )\n modified_txn = await async_w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert modified_txn[\"gasPrice\"] == cast(int, txn_params[\"gasPrice\"]) * 2\n\n @pytest.mark.asyncio\n async def test_eth_modify_transaction(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n\n modified_txn_hash = await async_w3.eth.modify_transaction(\n txn_hash,\n value=2,\n maxPriorityFeePerGas=(cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2),\n maxFeePerGas=(cast(Wei, txn_params[\"maxFeePerGas\"]) * 2),\n )\n modified_txn = await async_w3.eth.get_transaction(modified_txn_hash)\n\n assert is_same_address(\n modified_txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"])\n )\n assert is_same_address(\n modified_txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"])\n )\n assert modified_txn[\"value\"] == 2\n assert modified_txn[\"gas\"] == 21000\n assert (\n modified_txn[\"maxPriorityFeePerGas\"]\n == cast(Wei, txn_params[\"maxPriorityFeePerGas\"]) * 2\n )\n assert modified_txn[\"maxFeePerGas\"] == cast(Wei, txn_params[\"maxFeePerGas\"]) * 2\n\n @pytest.mark.asyncio\n async def test_async_eth_sign_transaction(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"nonce\": Nonce(0),\n }\n result = await 
async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert (\n result[\"tx\"][\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n async def test_eth_sign_typed_data(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n async_skip_if_testrpc: Callable[[\"AsyncWeb3\"], None],\n ) -> None:\n validJSONMessage = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": {\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n },\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n async_skip_if_testrpc(async_w3)\n signature = HexBytes(\n await async_w3.eth.sign_typed_data(\n async_unlocked_account_dual_type, json.loads(validJSONMessage)\n )\n )\n assert len(signature) == 32 + 32 + 1\n\n @pytest.mark.asyncio\n async def test_invalid_eth_sign_typed_data(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n async_skip_if_testrpc: Callable[[\"AsyncWeb3\"], None],\n ) -> None:\n async_skip_if_testrpc(async_w3)\n invalid_typed_message = \"\"\"\n {\n \"types\": {\n \"EIP712Domain\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"version\", \"type\": \"string\"},\n {\"name\": \"chainId\", \"type\": \"uint256\"},\n {\"name\": \"verifyingContract\", \"type\": \"address\"}\n ],\n \"Person\": [\n {\"name\": \"name\", \"type\": \"string\"},\n {\"name\": \"wallet\", \"type\": \"address\"}\n ],\n \"Mail\": [\n {\"name\": \"from\", \"type\": \"Person\"},\n {\"name\": \"to\", \"type\": \"Person[2]\"},\n {\"name\": \"contents\", \"type\": \"string\"}\n ]\n },\n \"primaryType\": \"Mail\",\n \"domain\": {\n \"name\": \"Ether Mail\",\n \"version\": \"1\",\n \"chainId\": \"0x01\",\n \"verifyingContract\": \"0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC\"\n },\n \"message\": {\n \"from\": {\n \"name\": \"Cow\",\n \"wallet\": \"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826\"\n },\n \"to\": [{\n \"name\": \"Bob\",\n \"wallet\": \"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB\"\n }],\n \"contents\": \"Hello, Bob!\"\n }\n }\n \"\"\"\n with pytest.raises(\n ValueError,\n match=r\".*Expected 2 items for array type Person\\[2\\], got 1 items.*\",\n ):\n await async_w3.eth.sign_typed_data(\n async_unlocked_account_dual_type, json.loads(invalid_typed_message)\n )\n\n 
@pytest.mark.asyncio\n async def test_async_eth_sign_transaction_legacy(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": await async_w3.eth.gas_price,\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"gasPrice\"] == txn_params[\"gasPrice\"]\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n async def test_async_eth_sign_transaction_hex_fees(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account,\n \"to\": async_unlocked_account,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": hex(async_w3.to_wei(2, \"gwei\")),\n \"maxPriorityFeePerGas\": hex(async_w3.to_wei(1, \"gwei\")),\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == txn_params[\"to\"]\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == int(str(txn_params[\"maxFeePerGas\"]), 16)\n assert result[\"tx\"][\"maxPriorityFeePerGas\"] == int(\n str(txn_params[\"maxPriorityFeePerGas\"]), 16\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n @pytest.mark.xfail(\n reason=\"async name_to_address_middleware has not been implemented yet\"\n )\n async def test_async_eth_sign_transaction_ens_names(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account: ChecksumAddress\n ) -> None:\n with ens_addresses(async_w3, {\"unlocked-account.eth\": async_unlocked_account}):\n txn_params: TxParams = {\n \"from\": \"unlocked-account.eth\",\n \"to\": \"unlocked-account.eth\",\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"nonce\": Nonce(0),\n }\n result = await async_w3.eth.sign_transaction(txn_params)\n signatory_account = async_w3.eth.account.recover_transaction(result[\"raw\"])\n assert async_unlocked_account == signatory_account\n assert result[\"tx\"][\"to\"] == async_unlocked_account\n assert result[\"tx\"][\"value\"] == txn_params[\"value\"]\n assert result[\"tx\"][\"gas\"] == txn_params[\"gas\"]\n assert result[\"tx\"][\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert (\n result[\"tx\"][\"maxPriorityFeePerGas\"]\n == txn_params[\"maxPriorityFeePerGas\"]\n )\n assert result[\"tx\"][\"nonce\"] == txn_params[\"nonce\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(3, \"gwei\"),\n \"maxPriorityFeePerGas\": 
async_w3.to_wei(1, \"gwei\"),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == txn_params[\"maxFeePerGas\"]\n assert txn[\"maxPriorityFeePerGas\"] == txn_params[\"maxPriorityFeePerGas\"]\n assert txn[\"gasPrice\"] == txn_params[\"maxFeePerGas\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_default_fees(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert is_integer(txn[\"maxPriorityFeePerGas\"])\n assert is_integer(txn[\"maxFeePerGas\"])\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_hex_fees(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": hex(250 * 10**9),\n \"maxPriorityFeePerGas\": hex(2 * 10**9),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n assert txn[\"maxFeePerGas\"] == 250 * 10**9\n assert txn[\"maxPriorityFeePerGas\"] == 2 * 10**9\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_no_gas(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"maxFeePerGas\": Wei(250 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 121000 # 21000 + buffer\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_with_gas_price(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"gasPrice\": Wei(1),\n \"maxFeePerGas\": Wei(250 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(TransactionTypeMismatch):\n await async_w3.eth.send_transaction(txn_params)\n\n 
@pytest.mark.asyncio\n async def test_eth_send_transaction_no_priority_fee(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(250 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_no_max_fee(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n maxPriorityFeePerGas = async_w3.to_wei(2, \"gwei\")\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": maxPriorityFeePerGas,\n }\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert is_same_address(txn[\"from\"], cast(ChecksumAddress, txn_params[\"from\"]))\n assert is_same_address(txn[\"to\"], cast(ChecksumAddress, txn_params[\"to\"]))\n assert txn[\"value\"] == 1\n assert txn[\"gas\"] == 21000\n\n block = await async_w3.eth.get_block(\"latest\")\n assert txn[\"maxFeePerGas\"] == maxPriorityFeePerGas + 2 * block[\"baseFeePerGas\"]\n\n @pytest.mark.asyncio\n async def test_eth_send_transaction_max_fee_less_than_tip(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1 * 10**9),\n \"maxPriorityFeePerGas\": Wei(2 * 10**9),\n }\n with pytest.raises(\n InvalidTransaction, match=\"maxFeePerGas must be >= maxPriorityFeePerGas\"\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_validation_middleware_chain_id_mismatch(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n wrong_chain_id = 1234567890\n actual_chain_id = await async_w3.eth.chain_id\n\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": async_w3.to_wei(2, \"gwei\"),\n \"maxPriorityFeePerGas\": async_w3.to_wei(1, \"gwei\"),\n \"chainId\": wrong_chain_id,\n }\n with pytest.raises(\n Web3ValidationError,\n match=f\"The transaction declared chain ID {wrong_chain_id}, \"\n f\"but the connected node is on {actual_chain_id}\",\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n @pytest.mark.asyncio\n async def test_ExtraDataToPOAMiddleware(\n self, async_w3: \"AsyncWeb3\", request_mocker: Type[RequestMocker]\n ) -> None:\n async_w3.middleware_onion.inject(ExtraDataToPOAMiddleware, \"poa\", layer=0)\n extra_data = f\"0x{'ff' * 33}\"\n\n async with request_mocker(\n async_w3,\n mock_results={\"eth_getBlockByNumber\": {\"extraData\": extra_data}},\n ):\n block = await async_w3.eth.get_block(\"latest\")\n\n assert \"extraData\" not in block\n assert block[\"proofOfAuthorityData\"] == to_bytes(hexstr=extra_data)\n\n # clean up\n async_w3.middleware_onion.remove(\"poa\")\n\n @pytest.mark.asyncio\n async def test_eth_send_raw_transaction(self, async_w3: \"AsyncWeb3\") -> None:\n # private key 
0x3c2ab4e8f17a7dea191b8c991522660126d681039509dc3bb31af7c9bdb63518\n # This is an unfunded account, but the transaction has a 0 gas price, so is\n # valid. It never needs to be mined, we just want the transaction hash back\n # to confirm.\n # tx = {'to': '0x0000000000000000000000000000000000000000', 'value': 0, 'nonce': 1, 'gas': 21000, 'gasPrice': 0, 'chainId': 131277322940537} # noqa: E501\n # NOTE: nonce=1 to make txn unique from the non-async version of this test\n raw_txn = HexBytes(\n \"0xf8650180825208940000000000000000000000000000000000000000808086eecac466e115a0ffdd42d7dee4ac85427468bc616812e49432e285e4e8f5cd9381163ac3b28108a04ec6b0d89ecbd5e89b0399f336ad50f283fafd70e86593250bf5a2adfb93d17e\" # noqa: E501\n )\n expected_hash = HexStr(\n \"0x52b0ff9cb472f25872fa8ec6a62fa59454fc2ae7901cfcc6cc89d096f49b8fc1\"\n )\n txn_hash = await async_w3.eth.send_raw_transaction(raw_txn)\n assert txn_hash == async_w3.to_bytes(hexstr=expected_hash)\n\n @pytest.mark.asyncio\n async def test_GasPriceStrategyMiddleware(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = async_w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return two_gwei_in_wei\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n async def test_gas_price_strategy_middleware_hex_value(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n two_gwei_in_wei = async_w3.to_wei(2, \"gwei\")\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> str:\n return hex(two_gwei_in_wei)\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy) # type: ignore\n\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n assert txn[\"gasPrice\"] == two_gwei_in_wei\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"max_fee\", (1000000000, None), ids=[\"with_max_fee\", \"without_max_fee\"]\n )\n async def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n max_fee: Wei,\n ) -> None:\n max_priority_fee = async_w3.to_wei(1, \"gwei\")\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxPriorityFeePerGas\": max_priority_fee,\n }\n if max_fee is not None:\n txn_params = assoc(txn_params, \"maxFeePerGas\", max_fee)\n\n def gas_price_strategy(w3: \"Web3\", txn: TxParams) -> Wei:\n return async_w3.to_wei(2, \"gwei\")\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n txn_hash = await async_w3.eth.send_transaction(txn_params)\n txn = await async_w3.eth.get_transaction(txn_hash)\n\n latest_block = await async_w3.eth.get_block(\"latest\")\n assert (\n txn[\"maxFeePerGas\"] == max_fee\n if max_fee is not 
None\n else 2 * latest_block[\"baseFeePerGas\"] + max_priority_fee\n )\n assert txn[\"maxPriorityFeePerGas\"] == max_priority_fee\n assert txn[\"gasPrice\"] == txn[\"maxFeePerGas\"]\n\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n async def test_gas_price_from_strategy_bypassed_for_dynamic_fee_txn_no_tip(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n ) -> None:\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n \"maxFeePerGas\": Wei(1000000000),\n }\n\n def gas_price_strategy(_w3: \"Web3\", _txn: TxParams) -> Wei:\n return async_w3.to_wei(2, \"gwei\")\n\n async_w3.eth.set_gas_price_strategy(gas_price_strategy)\n\n with pytest.raises(\n InvalidTransaction, match=\"maxPriorityFeePerGas must be defined\"\n ):\n await async_w3.eth.send_transaction(txn_params)\n\n async_w3.eth.set_gas_price_strategy(None) # reset strategy\n\n @pytest.mark.asyncio\n async def test_eth_estimate_gas(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n gas_estimate = await async_w3.eth.estimate_gas(\n {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n }\n )\n assert is_integer(gas_estimate)\n assert gas_estimate > 0\n\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"params\",\n (\n {\n \"nonce\": 1, # int\n \"balance\": 1, # int\n \"code\": HexStr(\"0x\"), # HexStr\n # with state\n \"state\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n {\n \"nonce\": HexStr(\"0x1\"), # HexStr\n \"balance\": HexStr(\"0x1\"), # HexStr\n \"code\": b\"\\x00\", # bytes\n # with stateDiff\n \"stateDiff\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n ),\n )\n async def test_eth_estimate_gas_with_override_param_type_check(\n self,\n async_w3: \"AsyncWeb3\",\n async_math_contract: \"Contract\",\n params: StateOverrideParams,\n ) -> None:\n txn_params: TxParams = {\"from\": await async_w3.eth.coinbase}\n\n # assert does not raise\n await async_w3.eth.estimate_gas(\n txn_params, None, {async_math_contract.address: params}\n )\n\n @pytest.mark.asyncio\n async def test_eth_fee_history(self, async_w3: \"AsyncWeb3\") -> None:\n fee_history = await async_w3.eth.fee_history(1, \"latest\", [50])\n assert is_list_like(fee_history[\"baseFeePerGas\"])\n assert is_list_like(fee_history[\"gasUsedRatio\"])\n assert is_integer(fee_history[\"oldestBlock\"])\n assert fee_history[\"oldestBlock\"] >= 0\n assert is_list_like(fee_history[\"reward\"])\n if len(fee_history[\"reward\"]) > 0:\n assert is_list_like(fee_history[\"reward\"][0])\n\n @pytest.mark.asyncio\n async def test_eth_fee_history_with_integer(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n fee_history = await async_w3.eth.fee_history(\n 1, async_empty_block[\"number\"], [50]\n )\n assert is_list_like(fee_history[\"baseFeePerGas\"])\n assert is_list_like(fee_history[\"gasUsedRatio\"])\n assert is_integer(fee_history[\"oldestBlock\"])\n assert fee_history[\"oldestBlock\"] >= 0\n assert is_list_like(fee_history[\"reward\"])\n if len(fee_history[\"reward\"]) > 0:\n assert is_list_like(fee_history[\"reward\"][0])\n\n @pytest.mark.asyncio\n async def test_eth_fee_history_no_reward_percentiles(\n self, async_w3: \"AsyncWeb3\"\n ) -> None:\n fee_history = await async_w3.eth.fee_history(1, \"latest\")\n assert 
is_list_like(fee_history[\"baseFeePerGas\"])\n assert is_list_like(fee_history[\"gasUsedRatio\"])\n assert is_integer(fee_history[\"oldestBlock\"])\n assert fee_history[\"oldestBlock\"] >= 0\n\n @pytest.mark.asyncio\n async def test_eth_max_priority_fee(self, async_w3: \"AsyncWeb3\") -> None:\n max_priority_fee = await async_w3.eth.max_priority_fee\n assert is_integer(max_priority_fee)\n\n @pytest.mark.asyncio\n async def test_eth_max_priority_fee_with_fee_history_calculation(\n self, async_w3: \"AsyncWeb3\", request_mocker: Type[RequestMocker]\n ) -> None:\n async with request_mocker(\n async_w3,\n mock_errors={RPCEndpoint(\"eth_maxPriorityFeePerGas\"): {}},\n mock_results={RPCEndpoint(\"eth_feeHistory\"): {\"reward\": [[0]]}},\n ):\n with pytest.warns(\n UserWarning,\n match=(\n \"There was an issue with the method eth_maxPriorityFeePerGas. \"\n \"Calculating using eth_feeHistory.\"\n ),\n ):\n priority_fee = await async_w3.eth.max_priority_fee\n assert is_integer(priority_fee)\n assert priority_fee == PRIORITY_FEE_MIN\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByHash(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(async_empty_block[\"hash\"])\n assert block[\"hash\"] == async_empty_block[\"hash\"]\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByHash_not_found(self, async_w3: \"AsyncWeb3\") -> None:\n with pytest.raises(BlockNotFound):\n await async_w3.eth.get_block(UNKNOWN_HASH)\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByHash_pending(self, async_w3: \"AsyncWeb3\") -> None:\n block = await async_w3.eth.get_block(\"pending\")\n assert block[\"hash\"] is None\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_with_integer(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(async_empty_block[\"number\"])\n assert block[\"number\"] == async_empty_block[\"number\"]\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_latest(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n current_block_number = await async_w3.eth.block_number\n block = await async_w3.eth.get_block(\"latest\")\n assert block[\"number\"] == current_block_number\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_not_found(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n with pytest.raises(BlockNotFound):\n await async_w3.eth.get_block(BlockNumber(12345))\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_pending(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n current_block_number = await async_w3.eth.block_number\n block = await async_w3.eth.get_block(\"pending\")\n assert block[\"number\"] == current_block_number + 1\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_earliest(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n genesis_block = await async_w3.eth.get_block(BlockNumber(0))\n block = await async_w3.eth.get_block(\"earliest\")\n assert block[\"number\"] == 0\n assert block[\"hash\"] == genesis_block[\"hash\"]\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_safe(\n self, async_w3: \"AsyncWeb3\", async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(\"safe\")\n assert block is not None\n assert isinstance(block[\"number\"], int)\n\n @pytest.mark.asyncio\n async def test_eth_getBlockByNumber_finalized(\n self, async_w3: \"AsyncWeb3\", 
async_empty_block: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(\"finalized\")\n assert block is not None\n assert isinstance(block[\"number\"], int)\n\n @pytest.mark.asyncio\n async def test_eth_get_block_by_number_full_transactions(\n self, async_w3: \"AsyncWeb3\", async_block_with_txn: BlockData\n ) -> None:\n block = await async_w3.eth.get_block(async_block_with_txn[\"number\"], True)\n transaction = cast(TxData, block[\"transactions\"][0])\n assert transaction[\"hash\"] == async_block_with_txn[\"transactions\"][0]\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction(\n self, async_w3: \"AsyncWeb3\", mined_txn_hash: HexStr\n ) -> None:\n raw_transaction = await async_w3.eth.get_raw_transaction(mined_txn_hash)\n assert is_bytes(raw_transaction)\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction_raises_error(\n self, async_w3: \"AsyncWeb3\"\n ) -> None:\n with pytest.raises(\n TransactionNotFound, match=f\"Transaction with hash: '{UNKNOWN_HASH}'\"\n ):\n await async_w3.eth.get_raw_transaction(UNKNOWN_HASH)\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction_by_block(\n self,\n async_w3: \"AsyncWeb3\",\n async_block_with_txn: BlockData,\n async_unlocked_account_dual_type: ChecksumAddress,\n ) -> None:\n # eth_getRawTransactionByBlockNumberAndIndex: block identifier\n # send a txn to make sure pending block has at least one txn\n await async_w3.eth.send_transaction(\n {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n }\n )\n pending_block = await async_w3.eth.get_block(\"pending\")\n last_pending_txn_index = len(pending_block[\"transactions\"]) - 1\n raw_txn = await async_w3.eth.get_raw_transaction_by_block(\n \"pending\", last_pending_txn_index\n )\n assert is_bytes(raw_txn)\n\n # eth_getRawTransactionByBlockNumberAndIndex: block number\n async_block_with_txn_number = async_block_with_txn[\"number\"]\n raw_transaction = await async_w3.eth.get_raw_transaction_by_block(\n async_block_with_txn_number, 0\n )\n assert is_bytes(raw_transaction)\n\n # eth_getRawTransactionByBlockHashAndIndex: block hash\n async_block_with_txn_hash = async_block_with_txn[\"hash\"]\n raw_transaction = await async_w3.eth.get_raw_transaction_by_block(\n async_block_with_txn_hash, 0\n )\n assert is_bytes(raw_transaction)\n\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\"unknown_block_num_or_hash\", (1234567899999, UNKNOWN_HASH))\n async def test_eth_get_raw_transaction_by_block_raises_error(\n self, async_w3: \"AsyncWeb3\", unknown_block_num_or_hash: Union[int, HexBytes]\n ) -> None:\n with pytest.raises(\n TransactionNotFound,\n match=(\n f\"Transaction index: 0 on block id: \"\n f\"{to_hex_if_integer(unknown_block_num_or_hash)!r} \"\n f\"not found.\"\n ),\n ):\n await async_w3.eth.get_raw_transaction_by_block(\n unknown_block_num_or_hash, 0\n )\n\n @pytest.mark.asyncio\n async def test_eth_get_raw_transaction_by_block_raises_error_block_identifier(\n self, async_w3: \"AsyncWeb3\"\n ) -> None:\n unknown_identifier = \"unknown\"\n with pytest.raises(\n ValueError,\n match=(\n \"Value did not match any of the recognized block identifiers: \"\n f\"{unknown_identifier}\"\n ),\n ):\n await async_w3.eth.get_raw_transaction_by_block(\n unknown_identifier, 0 # type: ignore\n )\n\n @pytest.mark.asyncio\n async def test_eth_get_balance(self, async_w3: \"AsyncWeb3\") -> None:\n coinbase = await async_w3.eth.coinbase\n\n with pytest.raises(InvalidAddress):\n await async_w3.eth.get_balance(\n 
ChecksumAddress(HexAddress(HexStr(coinbase.lower())))\n )\n\n balance = await async_w3.eth.get_balance(coinbase)\n\n assert is_integer(balance)\n assert balance >= 0\n\n @pytest.mark.asyncio\n async def test_eth_get_code(\n self, async_w3: \"AsyncWeb3\", async_math_contract_address: ChecksumAddress\n ) -> None:\n code = await async_w3.eth.get_code(async_math_contract_address)\n assert isinstance(code, HexBytes)\n assert len(code) > 0\n\n @pytest.mark.asyncio\n async def test_eth_get_code_invalid_address(\n self,\n async_w3: \"AsyncWeb3\",\n async_math_contract: \"Contract\",\n ) -> None:\n with pytest.raises(InvalidAddress):\n await async_w3.eth.get_code(\n ChecksumAddress(HexAddress(HexStr(async_math_contract.address.lower())))\n )\n\n @pytest.mark.asyncio\n async def test_eth_get_code_with_block_identifier(\n self, async_w3: \"AsyncWeb3\", async_emitter_contract: \"Contract\"\n ) -> None:\n block_id = await async_w3.eth.block_number\n code = await async_w3.eth.get_code(async_emitter_contract.address, block_id)\n assert isinstance(code, HexBytes)\n assert len(code) > 0\n\n @pytest.mark.asyncio\n async def test_eth_create_access_list(\n self,\n async_w3: \"AsyncWeb3\",\n async_unlocked_account_dual_type: ChecksumAddress,\n async_math_contract: \"Contract\",\n ) -> None:\n # Initialize transaction for gas estimation\n txn_params: TxParams = {\n \"from\": async_unlocked_account_dual_type,\n \"value\": Wei(1),\n \"gas\": 21000,\n }\n txn = async_math_contract._prepare_transaction(\n fn_name=\"incrementCounter\",\n fn_args=[1],\n transaction=txn_params,\n )\n\n # create access list using data from transaction\n response = await async_w3.eth.create_access_list(\n {\n \"from\": async_unlocked_account_dual_type,\n \"to\": async_math_contract.address,\n \"data\": txn[\"data\"],\n }\n )\n\n assert is_dict(response)\n access_list = response[\"accessList\"]\n assert len(access_list) > 0\n assert access_list[0][\"address\"] is not None\n assert is_checksum_address(access_list[0][\"address\"])\n assert len(access_list[0][\"storageKeys\"][0]) == 32\n assert int(response[\"gasUsed\"]) >= 0\n\n @pytest.mark.asyncio\n async def test_eth_get_transaction_count(\n self, async_w3: \"AsyncWeb3\", async_unlocked_account_dual_type: ChecksumAddress\n ) -> None:\n transaction_count = await async_w3.eth.get_transaction_count(\n async_unlocked_account_dual_type\n )\n assert is_integer(transaction_count)\n assert transaction_count >= 0\n\n @pytest.mark.asyncio\n async def test_eth_call(\n self, async_w3: \"AsyncWeb3\", async_math_contract: \"Contract\"\n ) -> None:\n coinbase = await async_w3.eth.coinbase\n txn_params = async_math_contract._prepare_transaction(\n fn_name=\"add\",\n fn_args=(7, 11),\n transaction={\"from\": coinbase, \"to\": async_math_contract.address},\n )\n call_result = await async_w3.eth.call(txn_params)\n assert is_string(call_result)\n (result,) = async_w3.codec.decode([\"uint256\"], call_result)\n assert result == 18\n\n @pytest.mark.asyncio\n async def test_eth_call_with_override_code(\n self,\n async_w3: \"AsyncWeb3\",\n async_revert_contract: \"Contract\",\n ) -> None:\n coinbase = await async_w3.eth.coinbase\n txn_params = async_revert_contract._prepare_transaction(\n fn_name=\"normalFunction\",\n transaction={\"from\": coinbase, \"to\": async_revert_contract.address},\n )\n call_result = await async_w3.eth.call(txn_params)\n (result,) = async_w3.codec.decode([\"bool\"], call_result)\n assert result is True\n\n # override runtime bytecode: `normalFunction` returns `false`\n override_code = 
HexStr(\n \"0x6080604052348015600f57600080fd5b5060043610603c5760003560e01c8063185c38a4146041578063c06a97cb146049578063d67e4b84146051575b600080fd5b60476071565b005b604f60df565b005b605760e4565b604051808215151515815260200191505060405180910390f35b6040517f08c379a000000000000000000000000000000000000000000000000000000000815260040180806020018281038252601b8152602001807f46756e6374696f6e20686173206265656e2072657665727465642e000000000081525060200191505060405180910390fd5b600080fd5b60008090509056fea2646970667358221220bb71e9e9a2e271cd0fbe833524a3ea67df95f25ea13aef5b0a761fa52b538f1064736f6c63430006010033\" # noqa: E501\n )\n call_result = await async_w3.eth.call(\n txn_params,\n \"latest\",\n {async_revert_contract.address: {\"code\": override_code}},\n )\n (result,) = async_w3.codec.decode([\"bool\"], call_result)\n assert result is False\n\n # test bytes\n\n bytes_call_result = await async_w3.eth.call(\n txn_params,\n \"latest\",\n {async_revert_contract.address: {\"code\": to_bytes(hexstr=override_code)}},\n )\n (bytes_result,) = async_w3.codec.decode([\"bool\"], bytes_call_result)\n assert bytes_result is False\n\n @pytest.mark.asyncio\n @pytest.mark.parametrize(\n \"params\",\n (\n {\n \"nonce\": 1, # int\n \"balance\": 1, # int\n \"code\": HexStr(\"0x\"), # HexStr\n # with state\n \"state\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n {\n \"nonce\": HexStr(\"0x1\"), # HexStr\n \"balance\": HexStr(\"0x1\"), # HexStr\n \"code\": b\"\\x00\", # bytes\n # with stateDiff\n \"stateDiff\": {HexStr(f\"0x{'00' * 32}\"): HexStr(f\"0x{'00' * 32}\")},\n },\n ),\n )\n async def test_eth_call_with_override_param_type_check(\n self,\n async_w3: \"AsyncWeb3\",\n async_math_contract: \"Contract\",\n params: StateOverrideParams,\n ) -> None:\n coinbase = await async_w3.eth.coinbase\n txn_params: TxParams = {\"from\": coinbase}\n\n # assert does not raise\n await async_w3.eth.call(\n txn_params, \"latest\", {async_math_contract.address: params}\n )\n\n @pytest.mark.asyncio\n async def test_eth_call_with_0_result(\n self, async_w3: \"AsyncWeb3\", async_math_contract: \"Contract\"\n ) -> None:\n coinbase = await async_w3.eth.coinbase\n txn_params = async_math_contract._prepare_transaction(\n fn_name=\"add\",\n fn_args=(0, 0),\n transaction={\"from\": coinbase, \"to\": async_math_contract.address},\n )\n call_result = await async_w3.eth.call(txn_params)\n assert is_string(call_result)\n (result,) = async_w3.codec.decode([\"uint256\"], call_result)\n assert result == 0\n\n @pytest.mark.asyncio\n async def test_eth_call_revert_with_msg(\n self,\n async_w3: \"AsyncWeb3\",\n async_revert_contract: \"Contract\",\n async_unlocked_account: ChecksumAddress,\n ) -> None:\n txn_params = async_revert_contract._prepare_transaction(\n fn_name=\"revertWithMessage\",\n transaction={\n \"from\": async_unlocked_account,\n \"to\": async_revert_contract.address,\n },\n )\n with pytest.raises(\n ContractLogicError, match=\"execution reverted: Function has been reverted\"\n ):\n await async_w3.eth.call(txn_params)\n\n @pytest.mark.asyncio\n async def test_eth_call_revert_without_msg(\n self,\n async_w3: \"AsyncWeb3\",\n async_revert_contract: \"Contract\",\n async_unlocked_account: ChecksumAddress,\n ) -> None:\n with pytest.raises(ContractLogicError, match=\"execution reverted\"):\n txn_params = async_revert_contract._prepare_transaction(\n fn_name=\"revertWithoutMessage\",\n transaction={\n \"from\": async_unlocked_account,\n \"to\": async_revert_contract.address,\n },\n )\n await async_w3.eth.call(txn_params)\n\n 
@pytest.mark.asyncio\n \nasync def test_eth_call_revert_custom_error_with_msg(\n self,\n async_w3: \"AsyncWeb3\",\n async_revert_contract: \"Contract\",\n async_unlocked_account: ChecksumAddress,\n ) -> None:\n data = async_revert_contract.encodeABI(\n fn_name=\"UnauthorizedWithMessage\", args=[\"You are not authorized\"]\n )\n txn_params = async_revert_contract._prepare_transaction(\n fn_name=\"customErrorWithMessage\",\n transaction={\n \"from\": async_unlocked_account,\n \"to\": async_revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data):\n await async_w3.eth.call(txn_params)\n\n @pytest.mark.asyncio\n async def test_eth_call_revert_custom_error_without_msg(\n self,\n async_w3: \"AsyncWeb3\",\n async_revert_contract: \"Contract\",\n async_unlocked_account: ChecksumAddress,\n ) -> None:\n data = async_revert_contract.encodeABI(fn_name=\"Unauthorized\")\n txn_params = async_revert_contract._prepare_transaction(\n fn_name=\"customErrorWithoutMessage\",\n transaction={\n \"from\": async_unlocked_account,\n \"to\": async_revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data):\n await async_w3.eth.call(txn_params)\n\n @pytest.mark.parametrize(\n \"panic_error,params\",\n (\n (\"01\", []),\n (\"11\", []),\n (\"12\", [0]),\n (\"21\", [-1]),\n (\"22\", []),\n (\"31\", []),\n (\"32\", []),\n (\"41\", []),\n (\"51\", []),\n ),\n )\n @pytest.mark.asyncio\n async def test_contract_panic_errors(\n self,\n async_w3: \"AsyncWeb3\",\n async_panic_errors_contract: \"Contract\",\n panic_error: str,\n params: List[Any],\n ) -> None:\n method = getattr(\n async_panic_errors_contract.functions,\n f\"errorCode{panic_error}\",\n )\n error_msg = PANIC_ERROR_CODES[panic_error]\n\n with pytest.raises(ContractPanicError, match=re.escape(error_msg)):\n await method(*params).call()\n\n @pytest.mark.asyncio\n async def test_eth_call_offchain_lookup(\n self,\n async_w3: \"AsyncWeb3\",\n async_offchain_lookup_contract: \"Contract\",\n async_unlocked_account: ChecksumAddress,\n monkeypatch: \"MonkeyPatch\",\n ) -> None:\n normalized_contract_address = to_hex_if_bytes(\n async_offchain_lookup_contract.address\n ).lower()\n\n async_mock_offchain_lookup_request_response(\n monkeypatch,\n mocked_request_url=f\"https://web3.py/gateway/{normalized_contract_address}/{OFFCHAIN_LOOKUP_TEST_DATA}.json\", # noqa: E501\n...\n", "needle_token_start": 15358, "needle_token_end": 15568, "code_context_ntokens": 16260, "output": ["```python\n@pytest.mark.asyncio\nasync def test_eth_call_revert_custom_error_with_msg(\n self,\n async_w3: \"AsyncWeb3\",\n async_revert_contract: \"Contract\",\n async_unlocked_account: ChecksumAddress,\n) -> None:\n data = async_revert_contract.encodeABI(\n fn_name=\"UnauthorizedWithMessage\", args=[\"You are not authorized\"]\n )\n txn_params = async_revert_contract._prepare_transaction(\n fn_name=\"customErrorWithMessage\",\n transaction={\n \"from\": async_unlocked_account,\n \"to\": async_revert_contract.address,\n },\n )\n with pytest.raises(ContractCustomError, match=data):\n await async_w3.eth.call(txn_params)\n```"]} +{"repo": "Ciphey/Ciphey", "name": "chars_to_byte", "language": "python", "path": "ciphey/basemods/Decoders/base69.py", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: Converts a pair of characters from a specific encoding into a single byte value.\n2. **Input**: A string of two characters.\n3. **Output**: An integer representing the byte value derived from the input characters.\n4. 
**Procedure**: The function calculates the byte value by finding the positions of each character in a predefined character set. The position of the second character is multiplied by the length of the character set (69 in this case) and added to the position of the first character to produce the final byte value.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/basemods/Decoders/base62.py\nfrom typing import Dict, Optional\n\nimport base62\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base62(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base62 decoding\n \"\"\"\n try:\n...\n# Path: ciphey/basemods/Decoders/base69.py\n# Translated to Python and adapted for Ciphey from the JS original at https://github.com/pshihn/base69\n\n\nimport re\nfrom math import ceil\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Base69(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base69 decoding\n \"\"\"\n # Remove whitespace\n try:\n ctext = re.sub(r\"\\s+\", \"\", ctext, flags=re.UNICODE)\n extra_bytes = 0\n clen = len(ctext)\n\n if ctext[:-1] == \"=\":\n extra_bytes = int(ctext[clen - 2])\n\n CHUNK_COUNT = ceil(clen / 16)\n result = [0 for _ in range(CHUNK_COUNT * 7 - extra_bytes)]\n\n for i in range(CHUNK_COUNT):\n chunk_string = ctext[i * 16 : (i + 1) * 16]\n if extra_bytes and (i == CHUNK_COUNT - 1):\n insert = self.decode_chunk(chunk_string)\n for n, elem in enumerate(insert[0 : 7 - extra_bytes]):\n result[n + i * 7] = elem\n else:\n insert = self.decode_chunk(chunk_string)\n for n, elem in enumerate(insert):\n result[n + i * 7] = elem % 256\n return bytearray(result).decode().strip(\"\\x00\")\n except Exception:\n return None\n\n def decode_chunk(self, s: str):\n padded_bytes = s.endswith(\"=\")\n\n decoded = [0 for _ in range(8)]\n for i in range(8):\n decoded[i] = (\n 0\n if i == 7 and padded_bytes\n else self.chars_to_byte(s[i * 2 : i * 2 + 2])\n )\n\n result = [0 for _ in range(7)]\n for i in range(7):\n t1 = decoded[i] << (i + 1)\n t2 = decoded[i + 1] >> (7 - i - 1)\n result[i] = t1 | t2\n return result\n\n \ndef chars_to_byte(self, s: str):\n return (69 * self.CHARS.index(s[1])) + (self.CHARS.index(s[0]))\n\n @staticmethod\n def priority() -> float:\n # If this becomes lower or equal to the reverse, it breaks.\n # So I'll set it to 0.2 for now since it is very fast anyways.\n return 0.2\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.CHARS = config.get_resource(self._params()[\"dict\"], WordList)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The charset used for the decoder.\",\n req=False,\n default=\"cipheydists::list::base69\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"base69\"\n\n# Path: ciphey/basemods/Decoders/base91.py\nfrom typing import Dict, Optional\n\nimport base91\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base91(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base91 decoding\n \"\"\"\n try:\n return 
base91.decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base91\"\n\n# Path: ciphey/basemods/Decoders/bases.py\nimport base64\nimport types\nfrom typing import Any, Callable, Optional\n\nimport logging\nfrom rich.logging import RichHandler\nimport re\n\nfrom ciphey.common import id_lambda\nfrom ciphey.iface import Decoder, registry\n\n\ndef _dispatch(self: Any, ctext: str, func: Callable[[str], bytes]) -> Optional[bytes]:\n logging.debug(f\"Attempting {self.getTarget()}\")\n\n try:\n # remove all whitespace\n ctext = re.sub(r\"\\s+\", \"\", ctext, re.UNICODE)\n result = func(ctext)\n logging.info(f\"{self.getTarget()} successful, returning {result}\")\n return result\n except ValueError:\n logging.debug(f\"Failed to decode {self.getTarget()}\")\n return None\n\n\n_bases = {\n \"base16\": (base64.b16decode, 0.4),\n \"base32\": (base64.b32decode, 0.01),\n \"base64\": (base64.b64decode, 0.4),\n \"base85\": (base64.b85decode, 0.01),\n \"ascii85\": (base64.a85decode, 0.1),\n}\n\n\ndef gen_class(name, decoder, priority, ns):\n ns[\"_get_func\"] = id_lambda(decoder)\n ns[\"decode\"] = lambda self, ctext: _dispatch(self, ctext, self._get_func())\n ns[\"getParams\"] = id_lambda(None)\n ns[\"getTarget\"] = id_lambda(name)\n ns[\"priority\"] = id_lambda(priority)\n ns[\"__init__\"] = lambda self, config: super(type(self), self).__init__(config)\n\n\nfor name, (decoder, priority) in _bases.items():\n t = types.new_class(\n name,\n (Decoder[str],),\n exec_body=lambda x: gen_class(name, decoder, priority, x),\n )\n\n registry.register(t)\n\n# Path: ciphey/basemods/Decoders/baudot.py\nimport re\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Baudot(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n result = \"\"\n switch_to_digit_map = 0\n if re.search(\"^[01]{5}$\", ctext.split()[0]):\n for i in ctext.split():\n if i == \"11011\":\n switch_to_digit_map = 1\n if i == \"11111\":\n switch_to_digit_map = 0\n if switch_to_digit_map == 1:\n result += self.BAUDOT_DICT[\"+\" + i]\n if switch_to_digit_map == 0:\n result += self.BAUDOT_DICT[i]\n return result\n else:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BAUDOT_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The baudot alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::baudot\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"baudot\"\n\n# Path: ciphey/basemods/Decoders/binary.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Binary(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n try:\n ctext = re.sub(r\"[^\\S \\n]\", \" \", ctext, flags=re.UNICODE)\n ctext = ctext.replace(\"\\n\", \" \")\n\n existing_split = self.try_split(ctext.split(\" \"))\n if 
existing_split is not None:\n return existing_split\n\n # Now we try our own grouping\n\n # Remove final bit of whitespace\n ctext = ctext.replace(\" \", \"\")\n # Split into bytes, and test\n return self.try_split([ctext[i : i + 8] for i in range(0, len(ctext), 8)])\n # Catch bad octal chars\n except ValueError:\n return None\n\n def try_split(self, split_text: List[str]):\n ret = []\n\n for i in split_text:\n if len(i) == 0:\n continue\n val = int(i, 2)\n if val > 255 or val < 0:\n return None\n ret.append(val)\n\n if len(ret) != 0:\n ret = bytes(ret)\n logging.info(f\"binary successful, returning {ret.__repr__()}\")\n return ret\n\n @staticmethod\n def priority() -> float:\n return 0.3\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"binary\"\n\n# Path: ciphey/basemods/Decoders/braille.py\nimport re\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\nimport logging\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Braille(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Braille decoding\n \"\"\"\n logging.debug(\"Attempting Braille\")\n ctext_decoded = \"\"\n braille_matches = 0\n for symbol in self.BRAILLE_DICT_INV.values():\n if symbol in ctext:\n braille_matches += 1\n else:\n continue\n if braille_matches == 0:\n logging.debug(\"Failed to decode Braille due to invalid characters\")\n return None\n\n for pattern, value in self.BRAILLE_DICT.items():\n ctext = re.sub(pattern, value, ctext)\n\n wordArr = []\n for word in ctext.split(\" \"):\n # If two commas are in front of a word, uppercase the word and remove the comma\n if word[:2].find(\",,\") != -1:\n wordArr.append(word.replace(\",,\", \"\").upper())\n else:\n wordArr.append(word)\n\n result = []\n for word in wordArr:\n # If one comma is in front of a word, capitalize the word and remove the comma\n if word[0].find(\",\") != -1:\n result.append(word.replace(\",\", \"\").capitalize())\n else:\n result.append(word)\n ctext_decoded = \" \".join(result)\n logging.info(f\"Braille successful, returning '{ctext_decoded}'\")\n return ctext_decoded\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BRAILLE_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.BRAILLE_DICT_INV = {v: k for k, v in self.BRAILLE_DICT.items()}\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The Braille dictionary to use\",\n req=False,\n default=\"cipheydists::translate::braille\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"braille\"\n\n# Path: ciphey/basemods/Decoders/brainfuck.py\nimport re\nimport time\nfrom typing import Dict, Optional, Tuple\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Brainfuck(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Takes a ciphertext and treats it as a Brainfuck program,\n interpreting it and saving the output as a string to return.\n\n Brainfuck is a very simple, Turing-complete esoteric language.\n Below is a simplified interpreter that attempts to check whether a\n given ciphertext is a brainfuck program that would output a 
string.\n\n A program that can be \"decoded\" like this is one that:\n * Does not require user input (\",\" instruction)\n * Includes at least one putchar instruction (\".\")\n * Does not contain anything but the main 7 instructions,\n (excluding \",\") and whitespace\n\n Details:\n * This implementation wraps the memory pointer for \">\" and \"<\"\n * It is time-limited to 60 seconds, to prevent hangups\n * The program starts with 100 memory cells, chosen arbitrarily\n \"\"\"\n\n logging.debug(\"Attempting brainfuck\")\n\n result = \"\"\n memory = [0] * 100\n codeptr, memptr = 0, 0 # Instruction pointer and stack pointer\n timelimit = 60 # The timeout in seconds\n\n bracemap, isbf = self.bracemap_and_check(ctext)\n\n # If it doesn't appear to be valid brainfuck code\n if not isbf:\n logging.debug(\"Failed to interpret brainfuck due to invalid characters\")\n return None\n\n # Get start time\n start = time.time()\n\n while codeptr < len(ctext):\n\n current = time.time()\n\n # Return none if we've been running for over a minute\n if current - start > timelimit:\n logging.debug(\"Failed to interpret brainfuck due to timing out\")\n return None\n\n cmd = ctext[codeptr]\n\n if cmd == \"+\":\n if memory[memptr] < 255:\n memory[memptr] = memory[memptr] + 1\n else:\n memory[memptr] = 0\n\n elif cmd == \"-\":\n if memory[memptr] > 0:\n memory[memptr] = memory[memptr] - 1\n else:\n memory[memptr] = 255\n\n elif cmd == \">\":\n if memptr == len(memory) - 1:\n memory.append(0)\n memptr += 1\n\n elif cmd == \"<\":\n if memptr == 0:\n memptr = len(memory) - 1\n else:\n memptr -= 1\n\n # If we're at the beginning of the loop and the memory is 0, exit the loop\n elif cmd == \"[\" and memory[memptr] == 0:\n codeptr = bracemap[codeptr]\n\n # If we're at the end of the loop and the memory is >0, jmp to the beginning of the loop\n elif cmd == \"]\" and memory[memptr]:\n codeptr = bracemap[codeptr]\n\n # Store the output as a string instead of printing it out\n elif cmd == \".\":\n result += chr(memory[memptr])\n\n codeptr += 1\n\n logging.info(f\"Brainfuck successful, returning '{result}'\")\n return result\n\n def bracemap_and_check(self, program: str) -> Tuple[Optional[Dict], bool]:\n \"\"\"\n Create a bracemap of brackets in the program, to compute jmps.\n Maps open -> close brackets as well as close -> open brackets.\n\n Also returns True if the program is valid Brainfuck code. If False, we\n won't even try to run it.\n \"\"\"\n\n open_stack = []\n bracemap = dict()\n legal_instructions = {\"+\", \"-\", \">\", \"<\", \"[\", \"]\", \".\"}\n legal_count = 0\n\n # If the program actually outputs anything (contains \".\")\n prints = False\n\n for idx, instruction in enumerate(program):\n # If instruction is brainfuck (without input) or whitespace, it counts\n if instruction in legal_instructions or re.match(r\"\\s\", instruction):\n legal_count += 1\n\n if not prints and instruction == \".\":\n # If there are no \".\" instructions then this program will not output anything\n prints = True\n\n elif instruction == \"[\":\n open_stack.append(idx)\n\n elif instruction == \"]\":\n try:\n opbracket = open_stack.pop()\n bracemap[opbracket] = idx\n bracemap[idx] = opbracket\n except IndexError:\n # Mismatched braces, not a valid program\n # Closing braces > opening braces\n return (None, False)\n\n # 1. All characters are instructions or whitespace\n # 2. There are no extra open braces\n # 3. 
There is at least one character to be \"printed\"\n # (result is >=1 in length)\n is_brainfuck = legal_count == len(program) and len(open_stack) == 0 and prints\n\n return bracemap, is_brainfuck\n\n @staticmethod\n def priority() -> float:\n # Not uncommon, but not very common either. It's also slow.\n return 0.08\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.ALPHABET = config.get_resource(self._params()[\"dict\"], WordList)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"Brainfuck alphabet (default English)\",\n req=False,\n default=\"cipheydists::list::englishAlphabet\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"brainfuck\"\n\n# Path: ciphey/basemods/Decoders/decimal.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Decimal(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Decimal decoding\n \"\"\"\n logging.debug(\"Attempting decimal\")\n ctext_converted = []\n ctext_split = re.split(r\"[ ,;:\\-\\n]\", ctext)\n delimiters = set(sorted(re.sub(r\"[^ ,;:\\-\\n]\", \"\", ctext)))\n ctext_num = re.sub(r\"[,;:\\-\\s]\", \"\", ctext)\n ctext_decoded = \"\"\n if ctext_num.isnumeric() is False:\n logging.debug(\"Failed to decode decimal due to non numeric character(s)\")\n return None\n try:\n for i in ctext_split:\n val = int(i)\n if val > 255 or val < 0:\n logging.debug(\n f\"Failed to decode decimal due to invalid number '{val}'\"\n )\n return None\n ctext_converted.append(chr(val))\n ctext_decoded = \"\".join(ctext_converted)\n logging.info(\n f\"Decimal successful, returning '{ctext_decoded}' with delimiter(s) {delimiters}\"\n )\n return ctext_decoded\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"decimal\"\n\n# Path: ciphey/basemods/Decoders/dna.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Dna(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs DNA decoding\n \"\"\"\n logging.debug(\"Attempting DNA decoder\")\n ctext_decoded = \"\"\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext)\n ctext = \" \".join(ctext[i : i + 3] for i in range(0, len(ctext), 3))\n ctext_split = ctext.split(\" \")\n dna_keys = self.DNA_DICT.keys()\n\n for i in ctext_split:\n if i in dna_keys:\n ctext_decoded += self.DNA_DICT[i]\n else:\n return None\n logging.info(f\"DNA successful, returning '{ctext_decoded}'\")\n return ctext_decoded\n\n @staticmethod\n def priority() -> float:\n return 0.2\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.DNA_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The DNA alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::dna\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"dna\"\n\n# Path: ciphey/basemods/Decoders/dtmf.py\nimport re\nfrom typing 
import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Dtmf(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs DTMF decoding\n \"\"\"\n logging.debug(\"Attempting DTMF decoder\")\n ctext_decoded = \"\"\n ctext = re.sub(r\"[,;:\\-\\/\\s]\", \"\", ctext)\n ctext = \" \".join(ctext[i : i + 7] for i in range(0, len(ctext), 7))\n ctext_split = ctext.split(\" \")\n dtmf_keys = self.DTMF_DICT.keys()\n\n for i in ctext_split:\n if i in dtmf_keys:\n ctext_decoded += self.DTMF_DICT[i]\n else:\n return None\n logging.info(f\"DTMF successful, returning '{ctext_decoded}'\")\n return ctext_decoded\n\n @staticmethod\n def priority() -> float:\n return 0.2\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.DTMF_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The DTMF alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::dtmf\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"dtmf\"\n\n# Path: ciphey/basemods/Decoders/galactic.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Galactic(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Takes a string written in the 'Standard Galactic Alphabet'\n (aka Minecraft Enchanting Table Symbols) and translates it to ASCII text.\n \"\"\"\n logging.debug(\"Attempting Standard Galactic Alphabet decoder\")\n\n # To avoid complications, only move forward with the decoding if we can\n # reasonably assume that the input string is written in the galactic alphabet\n galactic_matches = 0\n for symbol in self.GALACTIC_DICT.keys():\n # These symbols are assumed to be frequent enough in regular\n # text to be skipped when counting the matches. All others are counted.\n if symbol in ctext and symbol not in [\"!\", \"|\"]:\n galactic_matches += 1\n else:\n continue\n if galactic_matches == 0:\n logging.debug(\n \"No matching galactic alphabet letters found. Skipping galactic decoder\"\n )\n return None\n logging.debug(f\"{galactic_matches} galactic alphabet letters found. 
\")\n\n result = \"\"\n # Take out the problematic characters consisting of multiple symbols\n ctext = (\n ctext.replace(\"||\", \"|\")\n .replace(\"/\", \"\")\n .replace(\"\u00a1\", \"\")\n .replace(\" \u0323 \", \"\")\n .replace(\"\u0307\", \"x\")\n )\n logging.debug(f\"Modified string is {ctext}\")\n\n for letter in ctext:\n if letter in self.GALACTIC_DICT.keys():\n # Match every letter of the input to its galactic counterpoint\n result += self.GALACTIC_DICT[letter]\n else:\n # If the current character is not in the defined alphabet,\n # just accept it as-is (useful for numbers, punctuation, etc.)\n result += letter\n # Remove the trailing space (appearing as a leading space)\n # from the x that results from the diacritic replacement\n result = result.replace(\"x \", \"x\")\n logging.debug(f\"Decoded string is {result}\")\n return result\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.01\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.GALACTIC_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The galactic alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::galactic\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"galactic\"\n\n# Path: ciphey/basemods/Decoders/gzip.py\nimport zlib\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Gzip(Decoder[bytes]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Gzip decoding\n \"\"\"\n try:\n return zlib.decompress(ctext, 16 + zlib.MAX_WBITS).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"gzip\"\n\n# Path: ciphey/basemods/Decoders/hexadecimal.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Hexadecimal(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Hexadecimal decoding\n \"\"\"\n ctext_decoded = \"\"\n try:\n ctext_decoded = bytearray.fromhex(ctext).decode(\"utf-8\")\n return ctext_decoded\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.015\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"hexadecimal\"\n\n# Path: ciphey/basemods/Decoders/leetspeak.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Leetspeak(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n for src, dst in self.translate.items():\n ctext = ctext.replace(src, dst)\n return ctext\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.translate = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return 
{\n \"dict\": ParamSpec(\n desc=\"The leetspeak dictionary to use\",\n req=False,\n default=\"cipheydists::translate::leet\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"leetspeak\"\n\n# Path: ciphey/basemods/Decoders/morse_code.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Morse_code(Decoder[str]):\n # A priority list for char/word boundaries\n BOUNDARIES = {\" \": 1, \"/\": 2, \"\\n\": 3}\n PURGE = {ord(c): None for c in BOUNDARIES.keys()}\n MAX_PRIORITY = 3\n ALLOWED = {\".\", \"-\", \" \", \"/\", \"\\n\"}\n MORSE_CODE_DICT: Dict[str, str]\n MORSE_CODE_DICT_INV: Dict[str, str]\n\n def decode(self, ctext: T) -> Optional[U]:\n logging.debug(\"Attempting Morse code decoder\")\n\n char_boundary = word_boundary = None\n\n char_boundary = word_boundary = None\n char_priority = word_priority = 0\n # Custom loop allows early break\n for i in ctext:\n i_priority = self.BOUNDARIES.get(i)\n if i_priority is None:\n if i in self.ALLOWED:\n continue\n logging.debug(f\"Non-morse char '{i}' found\")\n return None\n\n if i_priority <= char_priority or i == char_boundary or i == word_boundary:\n continue\n # Default to having a char boundary over a word boundary\n if (\n i_priority > word_priority\n and word_boundary is None\n and char_boundary is not None\n ):\n word_priority = i_priority\n word_boundary = i\n continue\n char_priority = i_priority\n char_boundary = i\n\n logging.debug(\n f\"Char boundary is unicode {ord(char_boundary)}, and word boundary is unicode {ord(word_boundary) if word_boundary is not None else None}\"\n )\n\n result = \"\"\n\n for word in ctext.split(word_boundary) if word_boundary else [ctext]:\n logging.debug(f\"Attempting to decode word {word}\")\n for char in word.split(char_boundary):\n char = char.translate(self.PURGE)\n if len(char) == 0:\n continue\n try:\n m = self.MORSE_CODE_DICT_INV[char]\n except KeyError:\n logging.debug(f\"Invalid codeword '{char}' found\")\n return None\n result = result + m\n # after every word add a space\n result = result + \" \"\n if len(result) == 0:\n logging.debug(\"Morse code failed to match\")\n return None\n # Remove trailing space\n result = result[:-1]\n logging.info(f\"Morse code successful, returning {result}\")\n return result.strip().upper()\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.MORSE_CODE_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.MORSE_CODE_DICT_INV = {v: k for k, v in self.MORSE_CODE_DICT.items()}\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The morse code dictionary to use\",\n req=False,\n default=\"cipheydists::translate::morse\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"morse_code\"\n\n# Path: ciphey/basemods/Decoders/multi_tap.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Multi_tap(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n result = \"\"\n for x in ctext.split():\n if x == self.SPACE_DIGIT: # Check if it's a space\n result += \" \"\n elif not Multi_tap.valid_code_part(x):\n return None\n else:\n result += self.decode_num_to_char(x)\n\n return result\n\n @staticmethod\n def valid_code_part(code: str) -> bool:\n if not 
code.isdigit():\n return False\n\n # if not all the digits are the same\n if not Multi_tap.is_all_dup(code):\n return False\n\n if int(code[0]) not in range(2, 10):\n return False\n\n if len(code) > 4:\n return False\n\n return True\n\n @staticmethod\n def decode_num_to_char(number: str) -> str:\n index = Multi_tap.calculate_index(number)\n return Multi_tap.number_index_to_char(index)\n\n @staticmethod\n def is_all_dup(code):\n return len(set(code)) == 1\n\n @staticmethod\n def calculate_index(number: str) -> int:\n first_number_as_int = int(number[0])\n\n number_index = Multi_tap.get_index_from_first_digit(first_number_as_int)\n\n # Add to index the number of the char : \"22\" -> index += 1\n num_rest_numbers = len(number) - 1\n number_index += num_rest_numbers\n\n return number_index\n\n @staticmethod\n def number_index_to_char(index_number: int) -> str:\n start_ascii_value = ord(\"A\")\n return chr(start_ascii_value + index_number)\n\n @staticmethod\n def get_index_from_first_digit(first_digit: int) -> int:\n number_index = 0\n if first_digit >= 8: # s have 4 chars\n number_index += 1\n\n first_digit -= 2 # start in 200\n\n number_index += first_digit * 3 # jump 3 every time\n\n return number_index\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.SPACE_DIGIT = \"0\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"multi_tap\"\n\n# Path: ciphey/basemods/Decoders/octal.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Octal(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Octal decoding\n \"\"\"\n str_converted = []\n octal_seq = ctext.split(\" \")\n if len(octal_seq) == 1:\n # Concatted octal must be formed of octal triplets\n if len(ctext) % 3 != 0:\n return None\n octal_seq = [ctext[i : i + 3] for i in range(0, len(ctext), 3)]\n logging.debug(f\"Trying chunked octal {octal_seq}\")\n try:\n for octal_char in octal_seq:\n if len(octal_char) > 3:\n logging.debug(\"Octal subseq too long\")\n return None\n n = int(octal_char, 8)\n if (\n n < 0\n ): # n cannot be greater than 255, as we checked that with the earlier length check\n logging.debug(f\"Non octal char {octal_char}\")\n return None\n str_converted.append(n)\n\n return bytes(str_converted)\n # Catch bad octal chars\n except ValueError:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.025\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"octal\"\n\n# Path: ciphey/basemods/Decoders/reverse.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Reverse(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n return ctext[::-1]\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"reverse\"\n\n# Path: ciphey/basemods/Decoders/tap_code.py\n# by https://github.com/RustyDucky and https://github.com/lukasgabriel\n\nfrom 
typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Tap_code(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Tap code decoding\n \"\"\"\n try:\n result = \"\"\n combinations = ctext.split(\" \")\n for fragment in combinations:\n result += self.TABLE.get(fragment)\n return result\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.06\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.TABLE = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The table of letters used for the tap code interpretation.\",\n req=False,\n default=\"cipheydists::translate::tap_code\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"tap_code\"\n\n# Path: ciphey/basemods/Decoders/unicode.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Utf8(Decoder[bytes]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs UTF-8 decoding\n \"\"\"\n logging.debug(\"Attempting UTF-8 decoder\")\n result = \"\"\n try:\n result = ctext.decode(\"utf-8\")\n if result != ctext:\n logging.info(f\"UTF-8 successful, returning '{result}'\")\n return result\n else:\n return None\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.9\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"utf8\"\n\n# Path: ciphey/basemods/Decoders/url.py\nfrom typing import Dict, Optional\nfrom urllib.parse import unquote_plus\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Url(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs URL decoding\n \"\"\"\n logging.debug(\"Attempting URL\")\n result = \"\"\n try:\n result = unquote_plus(ctext, errors=\"strict\")\n if result != ctext:\n logging.info(f\"URL successful, returning '{result}'\")\n return result\n else:\n return None\n except Exception:\n logging.debug(\"Failed to decode URL\")\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"url\"\n\n# Path: ciphey/basemods/Decoders/uuencode.py\nfrom binascii import a2b_uu\nfrom codecs import decode\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Uuencode(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n UUEncode (Unix to Unix Encoding) is a symmetric encryption\n based on conversion of binary data (split into 6-bit blocks) into ASCII characters.\n\n This function decodes the input string 'ctext' if it has been encoded using 'uuencoder'\n It will return None otherwise\n \"\"\"\n logging.debug(\"Attempting UUencode\")\n result = \"\"\n try:\n # UUencoded messages may begin with prefix 
\"begin\" and end with suffix \"end\"\n # In that case, we use the codecs module in Python\n ctext_strip = ctext.strip()\n if ctext_strip.startswith(\"begin\") and ctext_strip.endswith(\"end\"):\n result = decode(bytes(ctext, \"utf-8\"), \"uu\").decode()\n else:\n # If there isn't a \"being\" prefix and \"end\" suffix, we use the binascii module instead\n # It is possible that the ctext has multiple lines, so convert each line and append\n ctext_split = list(filter(None, ctext.splitlines()))\n for _, value in enumerate(ctext_split):\n result += a2b_uu(value).decode(\"utf-8\")\n logging.info(f\"UUencode successful, returning '{result}'\")\n return result\n except Exception:\n logging.debug(\"Failed to decode UUencode\")\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"uuencode\"\n\n# Path: ciphey/basemods/Decoders/__init__.py\nfrom . import (\n a1z26,\n atbash,\n base58_bitcoin,\n base58_ripple,\n base62,\n base69,\n base91,\n bases,\n baudot,\n binary,\n braille,\n brainfuck,\n decimal,\n dna,\n dtmf,\n galactic,\n gzip,\n hexadecimal,\n leetspeak,\n morse_code,\n multi_tap,\n octal,\n reverse,\n tap_code,\n unicode,\n url,\n uuencode,\n)\n\n# Path: ciphey/basemods/Resources/cipheydists.py\nfrom functools import lru_cache\nfrom typing import Any, Dict, Optional, Set\n\nimport cipheydists\nimport logging\n\nfrom ciphey.iface import (\n Config,\n Distribution,\n ParamSpec,\n ResourceLoader,\n Translation,\n WordList,\n registry,\n)\n\n\n@registry.register_multi(WordList, Distribution, Translation)\nclass CipheyDists(ResourceLoader):\n # _wordlists: Set[str] = frozenset({\"english\", \"english1000\", \"englishStopWords\"})\n # _brandons: Set[str] = frozenset({\"english\"})\n # _dists: Set[str] = frozenset({\"twist\"})\n # _translates: Set[str] = frozenset({\"morse\"})\n _getters = {\n \"list\": cipheydists.get_list,\n \"dist\": cipheydists.get_dist,\n \"brandon\": cipheydists.get_brandon,\n \"translate\": cipheydists.get_translate,\n }\n\n def whatResources(self) -> Optional[Set[str]]:\n pass\n\n @lru_cache()\n def getResource(self, name: str) -> Any:\n logging.debug(f\"Loading cipheydists resource {name}\")\n prefix, name = name.split(\"::\", 1)\n return self._getters[prefix](name)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n# Path: ciphey/basemods/Resources/files.py\nimport csv\nimport json\nfrom functools import lru_cache\nfrom typing import Dict, Generic, Optional, Set\n\nfrom ciphey.iface import (\n Config,\n Distribution,\n ParamSpec,\n ResourceLoader,\n T,\n WordList,\n registry,\n)\n\n\n# We can use a generic resource loader here, as we can instantiate it later\n@registry.register_multi(WordList, Distribution)\nclass Json(ResourceLoader):\n def whatResources(self) -> T:\n return self._names\n\n @lru_cache()\n def getResource(self, name: str) -> T:\n prefix, name = name.split(\"::\", 1)\n return {\"wordlist\": (lambda js: {js}), \"dist\": (lambda js: js)}[prefix](\n json.load(open(self._paths[int(name) - 1]))\n )\n\n @staticmethod\n def getName() -> str:\n return \"json\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\"path\": ParamSpec(req=True, 
desc=\"The path to a JSON file\", list=True)}\n\n def __init__(self, config: Config):\n super().__init__(config)\n self._paths = self._params()[\"path\"]\n self._names = set(range(1, len(self._paths)))\n\n\n# We can use a generic resource loader here, as we can instantiate it later\n@registry.register_multi(WordList, Distribution)\nclass Csv(Generic[T], ResourceLoader[T]):\n def whatResources(self) -> Set[str]:\n return self._names\n\n @lru_cache()\n def getResource(self, name: str) -> T:\n prefix, name = name.split(\"::\", 1)\n return {\n \"wordlist\": (lambda reader: {i[0] for i in reader}),\n \"dist\": (lambda reader: {i[0]: float(i[1]) for i in reader}),\n }[prefix](csv.reader(open(self._paths[int(name) - 1])))\n\n @staticmethod\n def getName() -> str:\n return \"csv\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\"path\": ParamSpec(req=True, desc=\"The path to a CSV file\", list=True)}\n\n def __init__(self, config: Config):\n super().__init__(config)\n self._paths = self._params()[\"path\"]\n self._names = set(range(1, len(self._paths)))\n\n# Path: ciphey/basemods/Resources/__init__.py\nfrom . import cipheydists, files\n\n# Path: ciphey/basemods/Searchers/ausearch.py\nimport bisect\nimport distutils\nimport math\nfrom copy import copy\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Any, Dict, Generic, List, Optional, TypeVar, Union\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import (\n Checker,\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n Decoder,\n ParamSpec,\n Searcher,\n SearchLevel,\n SearchResult,\n T,\n registry,\n)\n\n\"\"\"\nWe are using a tree structure here, because that makes searching and tracing back easier\nAs such, when we encounter another possible parent, we remove that edge\n\"\"\"\n\n\nclass DuplicateNode(Exception):\n pass\n\n\n@dataclass\nclass AuSearchSuccessful(Exception):\n target: \"Node\"\n info: str\n\n\n@dataclass\nclass Node:\n # The root has no parent edge\n level: SearchLevel\n parent: Optional[\"Edge\"] = None\n depth: int = 0\n\n @staticmethod\n def decoding(\n config: Config, route: Union[Cracker, Decoder], result: Any, source: \"Node\"\n ) -> \"Node\":\n if not config.cache.mark_ctext(result):\n raise DuplicateNode()\n\n checker: Checker = config.objs[\"checker\"]\n ret = Node(\n parent=None,\n level=SearchLevel(\n name=type(route).__name__.lower(), result=CrackResult(value=result)\n ),\n depth=source.depth + 1,\n )\n edge = Edge(source=source, route=route, dest=ret)\n ret.parent = edge\n check_res = checker(result)\n if check_res is not None:\n raise AuSearchSuccessful(target=ret, info=check_res)\n return ret\n\n @staticmethod\n def cracker(config: Config, edge_template: \"Edge\", result: CrackResult) -> \"Node\":\n if not config.cache.mark_ctext(result.value):\n raise DuplicateNode()\n\n checker: Checker = config.objs[\"checker\"]\n # Edges do not directly contain containers, so this is fine\n edge = copy(edge_template)\n ret = Node(\n parent=edge,\n level=SearchLevel(name=type(edge.route).__name__.lower(), result=result),\n depth=edge.source.depth + 1,\n )\n edge.dest = ret\n check_res = checker(result.value)\n if check_res is not None:\n raise AuSearchSuccessful(target=ret, info=check_res)\n return ret\n\n @staticmethod\n def root(config: Config, ctext: Any):\n if not config.cache.mark_ctext(ctext):\n raise DuplicateNode()\n\n return Node(parent=None, level=SearchLevel.input(ctext))\n\n def get_path(self):\n if self.parent is 
None:\n return [self.level]\n return self.parent.source.get_path() + [self.level]\n\n\n@dataclass\nclass AusearchEdge:\n # TODO: This is just CrackInfo with failure probability added...\n success_probability: float\n failure_probability: float\n success_time: float\n failure_time: float\n\n def __init__(self, success_probability, success_time, failure_time):\n self.success_probability = success_probability\n self.failure_probability = 1.0 - success_probability\n self.success_time = success_time\n self.failure_time = failure_time\n\n\n@dataclass\nclass AusearchResult:\n weight: float\n index: int\n\n\ndef calculate_score(info: CrackInfo):\n return info.success_likelihood / \\\n (info.success_runtime * info.success_likelihood + info.failure_runtime * (1-info.success_likelihood))\n\n\n@dataclass\nclass Edge:\n source: Node\n route: Union[Cracker, Decoder]\n dest: Optional[Node] = None\n # Info is not filled in for Decoders\n score: Optional[float] = None\n\n\nPriorityType = TypeVar(\"PriorityType\")\n\n\nclass PriorityWorkQueue(Generic[PriorityType, T]):\n _sorted_priorities: List[PriorityType]\n _queues: Dict[Any, List[T]]\n\n def add_work(self, priority: PriorityType, work: List[T]) -> None:\n logging.debug(f\"\"\"Adding work at depth {priority}\"\"\")\n\n idx = bisect.bisect_left(self._sorted_priorities, priority)\n if (\n idx == len(self._sorted_priorities)\n or self._sorted_priorities[idx] != priority\n ):\n self._sorted_priorities.insert(idx, priority)\n self._queues.setdefault(priority, []).extend(work)\n\n def get_work(self) -> T:\n best_priority = self._sorted_priorities[0]\n target = self._queues[best_priority]\n ret = target.pop(0)\n if len(target) == 0:\n self._sorted_priorities.pop()\n return ret\n\n def get_work_chunk(self) -> List[T]:\n \"\"\"Returns the best work for now\"\"\"\n if len(self._sorted_priorities) == 0:\n return []\n best_priority = self._sorted_priorities.pop(0)\n return self._queues.pop(best_priority)\n\n def empty(self):\n return len(self._sorted_priorities) == 0\n\n def __init__(self):\n self._sorted_priorities = []\n self._queues = {}\n\n\n@registry.register\nclass AuSearch(Searcher):\n # Deeper paths get done later\n work: PriorityWorkQueue[int, Edge]\n\n @staticmethod\n def get_crackers_for(t: type):\n return registry[Cracker[t]]\n\n @lru_cache() # To save extra sorting\n def get_decoders_for(self, t: type):\n ret = registry[Decoder[t]]\n ret.sort(key=lambda x: x.priority(), reverse=True)\n return ret\n\n # def expand(self, edge: Edge) -> List[Edge]:\n # \"\"\"Evaluates the destination of the given, and adds its child edges to the pool\"\"\"\n # edge.dest = Node(parent=edge, level=edge.route(edge.source.level.result.value))\n\n def expand_crackers(self, node: Node) -> None:\n if node.depth >= self.max_cipher_depth:\n return\n\n res = node.level.result.value\n additional_work = []\n\n for i in self.get_crackers_for(type(res)):\n inst = self._config()(i)\n info = inst.getInfo(res)\n if info.success_likelihood < self.p_threshold:\n continue\n additional_work.append(\n Edge(source=node, route=inst, score=calculate_score(inst.getInfo(res)))\n )\n\n priority = min(node.depth, self.priority_cap)\n if self.invert_priority:\n priority = -priority\n\n self.work.add_work(priority, additional_work)\n\n def expand_decodings(self, node: Node) -> None:\n val = node.level.result.value\n\n for decoder in self.get_decoders_for(type(val)):\n inst = self._config()(decoder)\n res = inst(val)\n if res is None:\n continue\n try:\n new_node = Node.decoding(\n 
config=self._config(), route=inst, result=res, source=node\n )\n except DuplicateNode:\n continue\n\n logging.debug(\"Nesting encodings\")\n self.recursive_expand(new_node, False)\n\n def recursive_expand(self, node: Node, nested: bool = True) -> None:\n if node.depth >= self.max_depth:\n return\n\n logging.debug(f\"Expanding depth {node.depth}\")\n\n self.expand_decodings(node)\n\n # Doing this last allows us to catch simple nested encodings faster\n if not nested or self.enable_nested:\n self.expand_crackers(node)\n\n def search(self, ctext: Any) -> Optional[SearchResult]:\n logging.debug(\n f\"\"\"Beginning AuSearch with {\"inverted\" if self.invert_priority else \"normal\"} priority\"\"\"\n )\n\n try:\n root = Node.root(self._config(), ctext)\n except DuplicateNode:\n return None\n\n check_res = self._config().objs[\"checker\"](ctext)\n if check_res is not None:\n return SearchResult(check_res=check_res, path=[root.level])\n\n try:\n self.recursive_expand(root, False)\n\n while True:\n if self.work.empty():\n break\n # Get the highest level result\n chunk = self.work.get_work_chunk()\n chunk.sort(key=lambda i: i.score)\n # Work through all of this level's results\n while len(chunk) != 0:\n logging.debug(f\"{len(chunk)} remaining on this level\")\n # TODO Cyclic uses some tricky C++ here\n # I know because it's sorted the one at the back (the anti-weight)\n # is the most likely\n\n edge: Edge = chunk.pop(-1)\n\n # Expand the node\n res = edge.route(edge.source.level.result.value)\n if res is None:\n continue\n for i in res:\n try:\n node = Node.cracker(\n...\n", "needle_token_start": 770, "needle_token_end": 814, "code_context_ntokens": 15997, "output": ["```python\ndef chars_to_byte(self, s: str):\n return (69 * self.CHARS.index(s[1])) + (self.CHARS.index(s[0]))\n```"]} +{"repo": "Ciphey/Ciphey", "name": "__ge__", "language": "python", "path": "ciphey/iface/_modules.py", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To determine if the priority of the current instance is greater than or equal to another instance.\n2. **Input**: Another instance of the same class.\n3. **Output**: A boolean value indicating whether the current instance's priority is greater than or equal to the other instance's priority.\n4. **Procedure**: The method retrieves the priority value of the current instance and compares it with the priority value of the other instance using a greater than or equal to comparison.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/iface/_fwd.py\nregistry = None\nconfig = type(None)\n\n# Path: ciphey/iface/_modules.py\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, Generic, List, NamedTuple, Optional, Set, Type, TypeVar\n\nfrom rich import box\nfrom rich.console import Console\nfrom rich.markup import escape\nfrom rich.table import Table\n\nfrom ._fwd import config as Config\n\nT = TypeVar(\"T\")\nU = TypeVar(\"U\")\n\nconsole = Console()\n\n\nclass ParamSpec(NamedTuple):\n \"\"\"\n Attributes:\n req Whether this argument is required\n desc A description of what this argument does\n default The default value for this argument. 
Ignored if req == True or configPath is not None\n config_ref The path to the config that should be the default value\n list Whether this parameter is in the form of a list, and can therefore be specified more than once\n visible Whether the user can tweak this via the command line\n \"\"\"\n\n req: bool\n desc: str\n default: Optional[Any] = None\n list: bool = False\n config_ref: Optional[List[str]] = None\n visible: bool = True\n\n\nclass ConfigurableModule(ABC):\n @staticmethod\n @abstractmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n \"\"\"\n Returns a dictionary of `argument name: argument specification`\n \"\"\"\n pass\n\n def _checkParams(self):\n \"\"\"\n Fills the given params dict with default values where arguments are not given,\n using None as the default value for default values\n \"\"\"\n\n params = self._params()\n config = self._config()\n\n for key, value in self.getParams().items():\n # If we already have it, then we don't need to do anything\n if key in params:\n continue\n # If we don't have it, but it's required, then fail\n if value.req:\n raise KeyError(\n f\"Missing required param {key} for {type(self).__name__.lower()}\"\n )\n # If it's a reference by default, fill that in\n if value.config_ref is not None:\n tmp = getattr(config, value.config_ref[0])\n params[key] = (\n tmp[value.config_ref[1:]] if len(value.config_ref) > 1 else tmp\n )\n # Otherwise, put in the default value (if it exists)\n elif value.default is not None:\n params[key] = value.default\n\n def _params(self):\n return self._params_obj\n\n def _config(self):\n return self._config_obj\n\n @abstractmethod\n def __init__(self, config: Config):\n self._config_obj = config\n if self.getParams() is not None:\n self._params_obj = config.params.setdefault(type(self).__name__.lower(), {})\n self._checkParams()\n\n\nclass Targeted(ABC):\n @staticmethod\n @abstractmethod\n def getTarget() -> str:\n \"\"\"Should return the target that this object attacks/decodes\"\"\"\n pass\n\n\nclass PolymorphicChecker(ConfigurableModule):\n @abstractmethod\n def check(self, text) -> Optional[str]:\n \"\"\"Should return some description (or an empty string) on success, otherwise return None\"\"\"\n pass\n\n @abstractmethod\n def getExpectedRuntime(self, text) -> float:\n pass\n\n def __call__(self, *args):\n return self.check(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass Checker(Generic[T], ConfigurableModule):\n @abstractmethod\n def check(self, text: T) -> Optional[str]:\n \"\"\"Should return some description (or an empty string) on success, otherwise return None\"\"\"\n pass\n\n @abstractmethod\n def getExpectedRuntime(self, text: T) -> float:\n pass\n\n def __call__(self, *args):\n return self.check(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n @classmethod\n def convert(cls, expected: Set[type]):\n class PolyWrapperClass(PolymorphicChecker):\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return cls.getParams()\n\n def check(self, text) -> Optional[str]:\n \"\"\"Should return some description (or an empty string) on success, otherwise return None\"\"\"\n if type(text) not in expected:\n return None\n else:\n return self._base.check(text)\n\n def getExpectedRuntime(self, text) -> float:\n if type(text) not in expected:\n return 0\n else:\n return self._base.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n # This is easier than 
inheritance\n self._base = cls(config)\n\n PolyWrapperClass.__name__ = cls.__name__\n\n return PolyWrapperClass\n\n\n# class Detector(Generic[T], ConfigurableModule, KnownUtility, Targeted):\n# @abstractmethod\n# def scoreLikelihood(self, ctext: T) -> Dict[str, float]:\n# \"\"\"Should return a dictionary of (cipher_name: score)\"\"\"\n# pass\n#\n# def __call__(self, *args): return self.scoreLikelihood(*args)\n#\n# @abstractmethod\n# def __init__(self, config: Config): super().__init__(config)\n\n\nclass Decoder(Generic[T], ConfigurableModule, Targeted):\n \"\"\"Represents the undoing of some encoding into a different (or the same) type\"\"\"\n\n @abstractmethod\n def decode(self, ctext: T) -> Optional[U]:\n pass\n\n @staticmethod\n @abstractmethod\n def priority() -> float:\n \"\"\"What proportion of decodings are this?\"\"\"\n pass\n\n def __call__(self, *args):\n return self.decode(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass DecoderComparer:\n value: Type[Decoder]\n\n def __le__(self, other: \"DecoderComparer\"):\n return self.value.priority() <= other.value.priority()\n\n \ndef __ge__(self, other: \"DecoderComparer\"):\n return self.value.priority() >= other.value.priority()\n\n def __lt__(self, other: \"DecoderComparer\"):\n return self.value.priority() < other.value.priority() and self != other\n\n def __gt__(self, other: \"DecoderComparer\"):\n return self.value.priority() > other.value.priority() and self != other\n\n def __init__(self, value: Type[Decoder]):\n self.value = value\n\n def __repr__(self):\n return f\"\"\n\n\nclass CrackResult(NamedTuple):\n # TODO consider using Generic[T] again for value's type once\n # https://bugs.python.org/issue36517 is resolved\n value: Any\n key_info: Optional[str] = None\n misc_info: Optional[str] = None\n\n\nclass CrackInfo(NamedTuple):\n success_likelihood: float\n success_runtime: float\n failure_runtime: float\n\n\nclass Cracker(Generic[T], ConfigurableModule, Targeted):\n @abstractmethod\n def getInfo(self, ctext: T) -> CrackInfo:\n \"\"\"Should return some informed guesses on resource consumption when run on `ctext`\"\"\"\n pass\n\n @abstractmethod\n def attemptCrack(self, ctext: T) -> List[CrackResult]:\n \"\"\"\n This should attempt to crack the cipher `target`, and return a list of candidate solutions\n \"\"\"\n # FIXME: Actually CrackResult[T], but python complains\n pass\n\n def __call__(self, *args):\n return self.attemptCrack(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass ResourceLoader(Generic[T], ConfigurableModule):\n @abstractmethod\n def whatResources(self) -> Optional[Set[str]]:\n \"\"\"\n Return a set of the names of instances T you can provide.\n The names SHOULD be unique amongst ResourceLoaders of the same type\n\n These names will be exposed as f\"{self.__name__}::{name}\", use split_resource_name to recover this\n\n If you cannot reasonably determine what resources you provide, return None instead\n \"\"\"\n pass\n\n @abstractmethod\n def getResource(self, name: str) -> T:\n \"\"\"\n Returns the requested distribution\n\n The behavior is undefined if `name not in self.what_resources()`\n \"\"\"\n pass\n\n def __call__(self, *args):\n return self.getResource(*args)\n\n def __getitem__(self, *args):\n return self.getResource(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass SearchLevel(NamedTuple):\n name: str\n result: CrackResult\n\n @staticmethod\n def 
input(ctext: Any):\n return SearchLevel(name=\"input\", result=CrackResult(ctext))\n\n\nclass SearchResult(NamedTuple):\n path: List[SearchLevel]\n check_res: str\n\n\nclass Searcher(ConfigurableModule):\n \"\"\"A very basic interface for code that plans out how to crack the ciphertext\"\"\"\n\n @abstractmethod\n def search(self, ctext: Any) -> Optional[SearchResult]:\n \"\"\"Returns the path to the correct ciphertext\"\"\"\n pass\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\ndef pretty_search_results(res: SearchResult, display_intermediate: bool = False) -> str:\n # TODO what is display_intermediate\n ret: str = \"\"\n table = Table(show_header=False, box=box.ROUNDED, safe_box=False)\n # Only print the checker if we need to. Normal people don't know what\n # \"quadgrams\", \"brandon\", \"json checker\" is.\n # We print the checker if its regex or another language, so long as it starts with:\n # \"The\" like \"The plaintext is a Uniform Resource Locator (URL).\"\n if len(res.check_res) != 0 and (\"The\" == res.check_res[0:3] or \"Passed\" == res.check_res[0:6]):\n ret += f\"{res.check_res}\\n\"\n\n def add_one():\n out = \"\"\n if i.name == \"utf8\":\n out += f\" [#808080]{i.name}[/#808080]\\n\"\n else:\n out += f\" {i.name}\"\n already_broken = False\n if i.result.key_info is not None:\n out += f\":\\n Key: {i.result.key_info}\\n\"\n already_broken = True\n if i.result.misc_info is not None:\n if not already_broken:\n out += \":\\n\"\n out += f\" Misc: {i.result.misc_info}\\n\"\n already_broken = True\n if display_intermediate:\n if not already_broken:\n out += \":\\n\"\n out += f' Value: \"{i.result.value}\"\\n'\n already_broken = True\n if not already_broken:\n out += \"\\n\"\n return out, already_broken\n\n # Skip the 'input' and print in order\n already_broken = False\n out = \"\"\n for i in res.path[1:]:\n output, already_broken = add_one()\n out += output\n\n if out:\n if len(out.split(\"\\n\")) > 1:\n ret += \"Formats used:\\n\"\n else:\n ret += \"Format used:\\n\"\n ret += out\n\n # Remove trailing newline\n ret = ret[:-1]\n\n # If we didn't show intermediate steps, then print the final result\n if already_broken:\n ret += f\"\"\"\\nPlaintext: [bold green]\"{escape(res.path[-1].result.value)}\"[bold green]\"\"\"\n else:\n ret += f\"\"\"Plaintext: [bold green]\"{escape(res.path[-1].result.value)}\"[bold green]\"\"\"\n\n table.add_row(ret)\n return table\n\n\n# Some common collection types\nDistribution = Dict[str, float]\nTranslation = Dict[str, str]\nWordList = Set[str]\n\n# Path: ciphey/iface/_config.py\nimport datetime\nimport os\nimport pydoc\nfrom typing import Any, Callable, Dict, List, Optional, Type, Union\n\nimport appdirs\nimport yaml\nimport logging\nfrom rich.logging import RichHandler\n\nfrom . 
import _fwd\nfrom ._modules import PolymorphicChecker, ResourceLoader, Searcher\n\n\nclass Cache:\n \"\"\"Used to track state between levels of recursion to stop infinite loops, and to optimise repeating actions\"\"\"\n\n def __init__(self):\n self._cache: Dict[Any, Dict[str, Any]] = {}\n\n def mark_ctext(self, ctext: Any) -> bool:\n if (isinstance(ctext, str) or isinstance(ctext, bytes)) and len(ctext) < 4:\n logging.debug(f\"Candidate {ctext.__repr__()} too short!\")\n return False\n\n if ctext in self._cache:\n logging.debug(f\"Deduped {ctext.__repr__()}\")\n return False\n\n logging.debug(f\"New ctext {ctext.__repr__()}\")\n\n self._cache[ctext] = {}\n return True\n\n def get_or_update(self, ctext: Any, keyname: str, get_value: Callable[[], Any]):\n # Should have been marked first\n target = self._cache[ctext]\n res = target.get(keyname)\n if res is not None:\n return res\n\n val = get_value()\n target[keyname] = val\n return val\n\n def try_get(self, ctext: Any, keyname: str):\n return self._cache[ctext].get(keyname)\n\n\ndef split_resource_name(full_name: str) -> (str, str):\n return full_name.split(\"::\", 1)\n\n\nclass Config:\n def __init__(self):\n self.verbosity: int = 0\n self.searcher: str = \"ausearch\"\n self.params: Dict[str, Dict[str, Union[str, List[str]]]] = {}\n self.format: str = \"str\"\n self.modules: List[str] = []\n self.checker: str = \"ezcheck\"\n self.default_dist: str = \"cipheydists::dist::english\"\n self.timeout: Optional[int] = None\n self._inst: Dict[type, Any] = {}\n self.objs: Dict[str, Any] = {}\n self.cache: Cache = Cache()\n\n @staticmethod\n def get_default_dir() -> str:\n return appdirs.user_config_dir(\"ciphey\")\n\n def merge_dict(self, config_file: Optional[Dict[str, Any]]):\n if config_file is None:\n return\n for a, b in config_file.items():\n self.update(a, b)\n\n def load_file(\n self,\n path: str = os.path.join(get_default_dir.__func__(), \"config.yml\"),\n create=False,\n ):\n try:\n with open(path, \"r+\") as file:\n return self.merge_dict(yaml.safe_load(file))\n except FileNotFoundError:\n if create:\n open(path, \"w+\")\n\n def instantiate(self, t: type) -> Any:\n \"\"\"\n Used to enable caching of a instantiated type after the configuration has settled\n \"\"\"\n # We cannot use set default as that would construct it again, and throw away the result\n res = self._inst.get(t)\n if res is not None:\n return res\n ret = t(self)\n self._inst[t] = ret\n return ret\n\n def __call__(self, t: type) -> Any:\n return self.instantiate(t)\n\n def update(self, attrname: str, value: Optional[Any]):\n if value is not None:\n setattr(self, attrname, value)\n\n def update_param(self, owner: str, name: str, value: Optional[Any]):\n if value is None:\n return\n\n target = self.params.setdefault(owner, {})\n\n if _fwd.registry.get_named(owner).getParams()[name].list:\n target.setdefault(name, []).append(value)\n else:\n target[name] = value\n\n def update_format(self, value: Optional[str]):\n if value is not None:\n self.format = value\n\n def load_objs(self):\n # Basic type conversion\n if self.timeout is not None:\n self.objs[\"timeout\"] = datetime.timedelta(seconds=int(self.timeout))\n self.objs[\"format\"] = pydoc.locate(self.format)\n\n # Checkers do not depend on any other config object\n logging.debug(f\"Registry is {_fwd.registry._reg[PolymorphicChecker]}\")\n self.objs[\"checker\"] = self(\n _fwd.registry.get_named(self.checker, PolymorphicChecker)\n )\n # Searchers only depend on checkers\n self.objs[\"searcher\"] = 
self(_fwd.registry.get_named(self.searcher, Searcher))\n\n def update_log_level(self, verbosity: Optional[int]):\n if verbosity is None:\n return\n self.verbosity = verbosity\n\n if verbosity == 0:\n self.verbosity = logging.WARNING\n elif verbosity == 1:\n self.verbosity = logging.INFO\n elif verbosity >= 2:\n self.verbosity = logging.DEBUG\n else:\n logging.disable(logging.CRITICAL)\n return\n\n # https://rich.readthedocs.io/en/latest/logging.html for more on RichHandler\n logging.basicConfig(\n level=self.verbosity,\n datefmt=\"[%X]\",\n handlers=[RichHandler(markup=True, rich_tracebacks=True)],\n )\n logging.debug(f\"Verbosity set to level {verbosity}\")\n\n def load_modules(self):\n import importlib.util\n\n for i in self.modules:\n spec = importlib.util.spec_from_file_location(\"ciphey.module_load_site\", i)\n mod = importlib.util.module_from_spec(spec)\n spec.loader.exec_module(mod)\n\n logging.info(f\"Loaded modules {_fwd.registry.get_all_names()}\")\n\n def complete_config(self) -> \"Config\":\n \"\"\"This does all the loading for the config, and then returns itself\"\"\"\n self.load_modules()\n self.load_objs()\n self.update_log_level(self.verbosity)\n return self\n\n def get_resource(self, res_name: str, t: Optional[Type] = None) -> Any:\n logging.debug(f\"Loading resource {res_name} of type {t}\")\n\n # FIXME: Actually returns obj of type `t`, but python is bad\n loader, name = split_resource_name(res_name)\n if t is None:\n return self(_fwd.registry.get_named(loader, ResourceLoader))(name)\n else:\n return self(_fwd.registry.get_named(loader, ResourceLoader[t]))(name)\n\n # Setter methods for cleaner library API\n def set_verbosity(self, i):\n self.update_log_level(i)\n return self\n\n def set_spinner(self, spinner):\n self.objs[\"spinner\"] = spinner\n\n def pause_spinner_handle(self):\n spinner = self.objs.get(\"spinner\")\n\n class PausedSpinner:\n def __enter__(self):\n if spinner is not None:\n spinner.stop()\n\n def __exit__(self, exc_type, exc_val, exc_tb):\n if spinner is not None:\n spinner.start()\n\n return PausedSpinner()\n\n @staticmethod\n def library_default():\n \"\"\"The default config for use in a library\"\"\"\n return Config().set_verbosity(-1)\n\n def __str__(self):\n return str(\n {\n \"verbosity\": self.verbosity,\n \"searcher\": self.searcher,\n \"params\": self.params,\n \"format\": self.format,\n \"modules\": self.modules,\n \"checker\": self.checker,\n \"default_dist\": self.default_dist,\n \"timeout\": self.timeout,\n }\n )\n\n\n_fwd.config = Config\n\n# Path: ciphey/iface/_registry.py\nfrom typing import Any, Dict, List, Optional, Set, Tuple, Type, Union\n\ntry:\n from typing import get_args, get_origin\nexcept ImportError:\n from typing_inspect import get_origin, get_args\n\nfrom . 
import _fwd\nfrom ._modules import *\n\n\nclass Registry:\n # I was planning on using __init_subclass__, but that is incompatible with dynamic type creation when we have\n # generic keys\n\n RegElem = Union[List[Type], Dict[Type, \"RegElem\"]]\n\n _reg: Dict[Type, RegElem] = {}\n _names: Dict[str, Tuple[Type, Set[Type]]] = {}\n _targets: Dict[str, Dict[Type, List[Type]]] = {}\n _modules = {Checker, Cracker, Decoder, ResourceLoader, Searcher, PolymorphicChecker}\n\n def _register_one(self, input_type, module_base, module_args):\n if len(module_args) == 0:\n self._reg.setdefault(module_base, []).append(input_type)\n return\n\n target_reg = self._reg.setdefault(module_base, {})\n # Seek to the given type\n for subtype in module_args[0:-1]:\n target_reg = target_reg.setdefault(subtype, {})\n target_reg.setdefault(module_args[-1], []).append(input_type)\n\n def _real_register(self, input_type: type, *args) -> Type:\n name = input_type.__name__.lower()\n name_target = self._names[name] = (input_type, set())\n\n if issubclass(input_type, Targeted):\n target = input_type.getTarget()\n else:\n target = None\n\n if issubclass(input_type, Searcher):\n module_type = module_base = Searcher\n module_args = ()\n else:\n module_type: Optional[Type] = None\n module_base = None\n\n # Work out what module type this is\n if len(args) == 0 and hasattr(input_type, \"__orig_bases__\"):\n for i in input_type.__orig_bases__:\n if module_type is not None:\n raise TypeError(\n f\"Type derived from multiple registrable base classes {i} and {module_type}\"\n )\n module_base = get_origin(i)\n if module_base not in self._modules:\n continue\n module_type = i\n else:\n for i in self._modules:\n if not issubclass(input_type, i):\n continue\n if module_type is not None:\n raise TypeError(\n f\"Type derived from multiple registrable base classes {i} and {module_type}\"\n )\n module_type = i\n if module_type is None:\n raise TypeError(\"No registrable base class\")\n\n # Replace input type with polymorphic checker if required\n if issubclass(input_type, Checker):\n if len(args) == 0:\n arg = [\n get_args(i)\n for i in input_type.__orig_bases__\n if get_origin(i) == Checker\n ][0]\n if len(arg) != 1:\n raise TypeError(\"No argument for Checker\")\n input_type = input_type.convert({arg[0]})\n else:\n input_type = input_type.convert(set(args))\n self._register_one(input_type, PolymorphicChecker, [])\n # Refresh the names with the new type\n name_target = self._names[name] = (input_type, {PolymorphicChecker})\n\n # Now handle the difference between register and register_multi\n if len(args) == 0:\n if module_type is PolymorphicChecker:\n module_base = PolymorphicChecker\n elif module_base is None:\n raise TypeError(\"No type argument given\")\n self._register_one(input_type, module_base, get_args(module_type))\n name_target[1].add(module_base)\n else:\n if module_base is not None:\n raise TypeError(f\"Redundant type argument for {module_type}\")\n module_base = module_type\n for module_args in args:\n # Correct missing brackets\n if not isinstance(module_args, tuple):\n module_args = (module_args,)\n\n self._register_one(input_type, module_base, module_args)\n name_target[1].add(module_type[module_args])\n\n name_target[1].add(module_type)\n\n if target is not None and issubclass(module_base, Targeted):\n self._targets.setdefault(target, {}).setdefault(module_type, []).append(\n input_type\n )\n\n return input_type\n\n def register(self, input_type):\n return self._real_register(input_type)\n\n def register_multi(self, *x):\n 
return lambda input_type: self._real_register(input_type, *x)\n\n def __getitem__(self, i: type) -> Optional[Any]:\n target_type = get_origin(i)\n # Check if this is a non-generic type, and return the whole dict if it is\n if target_type is None:\n return self._reg[i]\n\n target_subtypes = get_args(i)\n target_list = self._reg.setdefault(target_type, {})\n for subtype in target_subtypes:\n target_list = target_list.setdefault(subtype, {})\n return target_list\n\n def get_named(self, name: str, type_constraint: Type = None) -> Any:\n ret = self._names[name.lower()]\n if type_constraint and type_constraint not in ret[1]:\n raise TypeError(f\"Type mismatch: wanted {type_constraint}, got {ret[1]}\")\n return ret[0]\n\n def get_targeted(\n self, target: str, type_constraint: Type = None\n ) -> Optional[Union[Dict[Type, Set[Type]], Set[Type]]]:\n x = self._targets.get(target)\n if x is None or type_constraint is None:\n return x\n return x.get(type_constraint)\n\n def get_all_names(self) -> List[str]:\n return list(self._names.keys())\n\n def __str__(self):\n return f\"ciphey.iface.Registry {{_reg: {self._reg}, _names: {self._names}, _targets: {self._targets}}}\"\n\n\n_fwd.registry = Registry()\n\n# Path: ciphey/iface/__init__.py\nfrom ._config import Config\n\nfrom ._modules import (\n Checker,\n Cracker,\n CrackInfo,\n CrackResult,\n Decoder,\n DecoderComparer,\n Distribution,\n ParamSpec,\n PolymorphicChecker,\n ResourceLoader,\n Searcher,\n SearchLevel,\n SearchResult,\n T,\n Translation,\n U,\n WordList,\n pretty_search_results,\n)\nfrom ._registry import get_args, get_origin\n\nfrom ._fwd import registry\n\n# Path: ciphey/basemods/Checkers/any.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, ParamSpec, PolymorphicChecker, registry\n\n\n@registry.register\nclass Any(PolymorphicChecker):\n \"\"\"Should only be used for debugging, frankly\"\"\"\n\n def getExpectedRuntime(self, text) -> float:\n return 0 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n def check(self, text: str) -> Optional[str]:\n return \"\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/mathsHelper.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to provide helper functions for mathematics\n(oh, not entirely mathematics either. Some NLP stuff and sorting dicts. 
It's just a helper class\n)\n\"\"\"\n\nfrom collections import OrderedDict\nfrom string import punctuation\nfrom typing import Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\n\nclass mathsHelper:\n \"\"\"Class to provide helper functions for mathematics and other small things\"\"\"\n\n def __init__(self):\n # ETAOIN is the most popular letters in order\n self.ETAOIN = \"ETAOINSHRDLCUMWFGYPBVKJXQZ\"\n self.LETTERS = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n\n @staticmethod\n def gcd(a, b) -> int:\n \"\"\"Greatest common divisor.\n\n The Greatest Common Divisor of a and b using Euclid's Algorithm.\n\n Args:\n a -> num 1\n b -> num 2\n\n Returns:\n Returns GCD(a, b)\n\n \"\"\"\n # Return\n while a != 0:\n a, b = b % a, a\n return b\n\n @staticmethod\n def mod_inv(a: int, m: int) -> Optional[int]:\n \"\"\"\n Returns the modular inverse of a mod m, or None if it does not exist.\n\n The modular inverse of a is the number a_inv that satisfies the equation\n a_inv * a mod m === 1 mod m\n\n Note: This is a naive implementation, and runtime may be improved in several ways.\n For instance by checking if m is prime to perform a different calculation,\n or by using the extended euclidean algorithm.\n \"\"\"\n for i in range(1, m):\n if (m * i + 1) % a == 0:\n return (m * i + 1) // a\n return None\n\n @staticmethod\n def percentage(part: float, whole: float) -> float:\n \"\"\"Returns percentage.\n\n Just a normal algorithm to return the percent.\n\n Args:\n part -> part of the whole number\n whole -> the whole number\n\n Returns:\n Returns the percentage of part to whole.\n\n \"\"\"\n if part <= 0 or whole <= 0:\n return 0\n # works with percentages\n return 100 * float(part) / float(whole)\n\n def sort_prob_table(self, prob_table: dict) -> dict:\n \"\"\"Sorts the probability table.\n\n Sorts a dictionary of dictionaries (and all the sub-dictionaries).\n\n Args:\n prob_table -> The probability table returned by the neural network to sort.\n\n Returns:\n Returns the prob_table, but sorted.\n\n \"\"\"\n # for each object: prob table in dictionary\n max_overall: int = 0\n max_dict_pair: dict = {}\n highest_key = None\n empty_dict: dict = {}\n # sorts the prob table before we find max, and converts it to order dicts\n for key, value in prob_table.items():\n prob_table[key] = self.new_sort(value)\n prob_table[key] = dict(prob_table[key])\n\n # gets maximum key then sets it to the front\n counter_max: int = 0\n counter_prob: int = len(prob_table)\n while counter_max < counter_prob:\n max_overall = 0\n highest_key = None\n logging.debug(\n f\"Running while loop in sort_prob_table, counterMax is {counter_max}\"\n )\n for key, value in prob_table.items():\n logging.debug(f\"Sorting {key}\")\n maxLocal = 0\n # for each item in that table\n for key2, value2 in value.items():\n logging.debug(\n f\"Running key2 {key2}, value2 {value2} for loop for {value.items()}\"\n )\n maxLocal = maxLocal + value2\n logging.debug(\n f\"MaxLocal is {maxLocal} and maxOverall is {max_overall}\"\n )\n if maxLocal > max_overall:\n logging.debug(f\"New max local found {maxLocal}\")\n # because the dict doesn't reset\n max_dict_pair = {}\n max_overall = maxLocal\n # so eventually, we get the maximum dict pairing?\n max_dict_pair[key] = value\n highest_key = key\n logging.debug(f\"Highest key is {highest_key}\")\n # removes the highest key from the prob table\n logging.debug(\n f\"Prob table is {prob_table} and highest key is {highest_key}\"\n )\n logging.debug(f\"Removing {prob_table[highest_key]}\")\n del 
prob_table[highest_key]\n logging.debug(f\"Prob table after deletion is {prob_table}\")\n counter_max += 1\n empty_dict = {**empty_dict, **max_dict_pair}\n\n # returns the max dict (at the start) with the prob table\n # this way, it should always work on most likely first.\n logging.debug(\n f\"The prob table is {prob_table} and the maxDictPair is {max_dict_pair}\"\n )\n logging.debug(f\"The new sorted prob table is {empty_dict}\")\n return empty_dict\n\n @staticmethod\n def new_sort(new_dict: dict) -> dict:\n \"\"\"Uses OrderedDict to sort a dictionary.\n\n I think it's faster than my implementation.\n\n Args:\n new_dict -> the dictionary to sort\n\n Returns:\n Returns the dict, but sorted.\n\n \"\"\"\n # (f\"d is {d}\")\n logging.debug(f\"The old dictionary before new_sort() is {new_dict}\")\n sorted_i = OrderedDict(\n sorted(new_dict.items(), key=lambda x: x[1], reverse=True)\n )\n logging.debug(f\"The dictionary after new_sort() is {sorted_i}\")\n # sortedI = sort_dictionary(x)\n return sorted_i\n\n @staticmethod\n def is_ascii(s: str) -> bool:\n \"\"\"Returns the boolean value if is_ascii is an ascii char.\n\n Does what it says on the tree. Stolen from\n https://stackoverflow.com/questions/196345/how-to-check-if-a-string-in-python-is-in-ascii\n\n Args:\n s -> the char to check.\n\n Returns:\n Returns the boolean of the char.\n\n \"\"\"\n\n return bool(lambda s: len(s) == len(s.encode()))\n\n @staticmethod\n def strip_punctuation(text: str) -> str:\n \"\"\"Strips punctuation from a given string.\n\n Uses string.punctuation.\n\n Args:\n text -> the text to strip punctuation from.\n\n Returns:\n Returns string without punctuation.\n \"\"\"\n text: str = (str(text).translate(str.maketrans(\"\", \"\", punctuation))).strip(\n \"\\n\"\n )\n return text\n\n# Path: ciphey/basemods/Checkers/brandon.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to determine whether something is English or not.\n1. Calculate the Chi Squared score of a sentence\n2. If the score is significantly lower than the average score, it _might_ be English\n 2.1. 
If the score _might_ be English, then take the text and compare it to the sorted dictionary\n in O(n log n) time.\n It creates a percentage of \"How much of this text is in the dictionary?\"\n The dictionary contains:\n * 20,000 most common US words\n * 10,000 most common UK words (there's no repetition between the two)\n * The top 10,000 passwords\n If the word \"Looks like\" English (chi-squared) and if it contains English words, we can conclude it is\n very likely English. The alternative is doing the dictionary thing but with an entire 479k word dictionary (slower)\n 2.2. If the score is not English, but we haven't tested enough to create an average, then test it against\n the dictionary\n\nThings to optimise:\n* We only run the dictionary if it's 20% smaller than the average for chi squared\n* We consider it \"English\" if 45% of the text matches the dictionary\n* We run the dictionary if there is less than 10 total chisquared test\n\nHow to add a language:\n* Download your desired dictionary. Try to make it the most popular words, for example. Place this file into this\n folder with languagename.txt\nAs an example, this comes built in with english.txt\nFind the statistical frequency of each letter in that language.\nFor English, we have:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n}\nIn chisquared.py\nTo add your language, do:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n \"German\": [0.0973]\n}\nIn alphabetical order\nAnd you're.... Done! 
Make sure the name of the two match up\n\"\"\"\nimport sys\nfrom math import ceil\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nsys.path.append(\"..\")\ntry:\n import mathsHelper as mh\nexcept ModuleNotFoundError:\n import ciphey.mathsHelper as mh\n\n\n@registry.register\nclass Brandon(Checker[str]):\n \"\"\"\n Class designed to confirm whether something is **language** based on how many words of **language** appears\n Call confirmLanguage(text, language)\n * text: the text you want to confirm\n * language: the language you want to confirm\n\n Find out what language it is by using chisquared.py, the highest chisquared score is the language\n languageThreshold = 45\n if a string is 45% **language** words, then it's confirmed to be english\n \"\"\"\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually work this out\n # TODO its 0.2 seconds on average\n return 1e-4 # 100 \u00b5s\n\n wordlist: set\n\n def clean_text(self, text: str) -> set:\n \"\"\"Cleans the text ready to be checked\n\n Strips punctuation, makes it lower case, turns it into a set separated by spaces, removes duplicate words\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n text -> the text as a list, now cleaned\n\n \"\"\"\n # makes the text unique words and readable\n text = text.lower()\n text = self.mh.strip_punctuation(text)\n text = text.split(\" \")\n text = filter(lambda x: len(x) > 2, text)\n text = set(text)\n return text\n\n def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n \"\"\"Given text determine if it passes checker\n\n The checker uses the variable passed to it. I.E. Stopwords list, 1k words, dictionary\n\n Args:\n text -> The text to check\n threshold -> at what point do we return True? The percentage of text that is in var before we return True\n text_length -> the length of the text\n var -> the variable we are checking against. 
Stopwords list, 1k words list, dictionary list.\n Returns:\n boolean -> True for it passes the test, False for it fails the test.\"\"\"\n if text is None:\n logging.debug(\"Checker's text is None, so returning False\")\n return False\n if var is None:\n logging.debug(\"Checker's input var is None, so returning False\")\n return False\n\n percent = ceil(text_length * threshold)\n logging.debug(f\"Checker's chunks are size {percent}\")\n meet_threshold = 0\n location = 0\n end = percent\n\n if text_length <= 0:\n return False\n\n while location <= text_length:\n # chunks the text, so only gets THRESHOLD chunks of text at a time\n text = list(text)\n to_analyse = text[location:end]\n logging.debug(f\"To analyse is {to_analyse}\")\n for word in to_analyse:\n # if word is a stopword, + 1 to the counter\n if word in var:\n logging.debug(\n f\"{word} is in var, which means I am +=1 to the meet_threshold which is {meet_threshold}\"\n )\n meet_threshold += 1\n meet_threshold_percent = meet_threshold / text_length\n if meet_threshold_percent >= threshold:\n logging.debug(\n f\"Returning true since the percentage is {meet_threshold / text_length} and the threshold is {threshold}\"\n )\n # if we meet the threshold, return True\n # otherwise, go over again until we do\n # We do this in the for loop because if we're at 24% and THRESHOLD is 25\n # we don't want to wait THRESHOLD to return true, we want to return True ASAP\n return True\n location = end\n end = end + percent\n logging.debug(\n f\"The language proportion {meet_threshold_percent} is under the threshold {threshold}\"\n )\n return False\n\n def __init__(self, config: Config):\n # Suppresses warning\n super().__init__(config)\n self.mh = mh.mathsHelper()\n\n phases = config.get_resource(self._params()[\"phases\"])\n\n self.thresholds_phase1 = phases[\"1\"]\n self.thresholds_phase2 = phases[\"2\"]\n self.top1000Words = config.get_resource(self._params().get(\"top1000\"))\n self.wordlist = config.get_resource(self._params()[\"wordlist\"])\n self.stopwords = config.get_resource(self._params().get(\"stopwords\"))\n\n self.len_phase1 = len(self.thresholds_phase1)\n self.len_phase2 = len(self.thresholds_phase2)\n\n def check(self, text: str) -> Optional[str]:\n \"\"\"Checks to see if the text is in English\n\n Performs a decryption, but mainly parses the internal data packet and prints useful information.\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n bool -> True if the text is English, False otherwise.\n\n \"\"\"\n logging.debug(f'In Language Checker with \"{text}\"')\n text = self.clean_text(text)\n logging.debug(f'Text split to \"{text}\"')\n if text == \"\":\n logging.debug(\"Returning None from Brandon as the text cleaned is none.\")\n return None\n\n length_text = len(text)\n\n what_to_use = {}\n\n # this code decides what checker / threshold to use\n # if text is over or equal to maximum size, just use the maximum possible checker\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase1.keys()\n )\n logging.debug(self.thresholds_phase1)\n what_to_use = self.thresholds_phase1[str(what_to_use)]\n # def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n if \"check\" in what_to_use:\n # perform check 1k words\n result = self.checker(\n text, what_to_use[\"check\"], length_text, self.top1000Words\n )\n elif \"stop\" in what_to_use:\n # perform stopwords\n result = self.checker(\n text, what_to_use[\"stop\"], length_text, self.stopwords\n )\n elif \"dict\" in 
what_to_use:\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n # If result is None, no point doing it again in phase2\n if not result:\n return None\n else:\n logging.info(f\"It is neither stop or check, but instead {what_to_use}\")\n\n # return False if phase 1 fails\n if not result:\n return None\n else:\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase2.keys()\n )\n what_to_use = self.thresholds_phase2[str(what_to_use)]\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n return \"\" if result else None\n\n def calculateWhatChecker(self, length_text, key):\n \"\"\"Calculates what threshold / checker to use\n\n If the length of the text is over the maximum sentence length, use the last checker / threshold\n Otherwise, traverse the keys backwards until we find a key range that does not fit.\n So we traverse backwards and see if the sentence length is between current - 1 and current\n In this way, we find the absolute lowest checker / percentage threshold.\n We traverse backwards because if the text is longer than the max sentence length, we already know.\n In total, the keys are only 5 items long or so. It is not expensive to move backwards, nor is it expensive to move forwards.\n\n Args:\n length_text -> The length of the text\n key -> What key we want to use. I.E. Phase1 keys, Phase2 keys.\n Returns:\n what_to_use -> the key of the lowest checker.\"\"\"\n\n _keys = list(key)\n _keys = list(map(int, _keys))\n if length_text >= int(_keys[-1]):\n what_to_use = list(key)[_keys.index(_keys[-1])]\n else:\n # this algorithm finds the smallest possible fit for the text\n for counter, i in reversed(list(enumerate(_keys))):\n # [0, 110, 150]\n if i <= length_text:\n what_to_use = i\n return what_to_use\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"top1000\": ParamSpec(\n desc=\"A wordlist of the top 1000 words\",\n req=False,\n default=\"cipheydists::list::english1000\",\n ),\n \"wordlist\": ParamSpec(\n desc=\"A wordlist of all the words\",\n req=False,\n default=\"cipheydists::list::english\",\n ),\n \"stopwords\": ParamSpec(\n desc=\"A wordlist of StopWords\",\n req=False,\n default=\"cipheydists::list::englishStopWords\",\n ),\n \"threshold\": ParamSpec(\n desc=\"The minimum proportion (between 0 and 1) that must be in the dictionary\",\n req=False,\n default=0.45,\n ),\n \"phases\": ParamSpec(\n desc=\"Language-specific phase thresholds\",\n req=False,\n default=\"cipheydists::brandon::english\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/format.py\nimport json\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass JsonChecker(Checker[str]):\n\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying json checker\")\n\n # https://github.com/Ciphey/Ciphey/issues/389\n if text.isdigit():\n return None\n\n try:\n json.loads(text)\n return \"\"\n except ValueError:\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 1e-7 * len(text) # From benchmarks I found online\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/human.py\nfrom 
typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, registry\nfrom rich.console import Console\nfrom rich.markup import escape\n\nconsole = Console()\n\n\n@registry.register\nclass HumanChecker(Checker[str]):\n\n \"\"\"\n Uses the person's decision to determine plaintext\n \"\"\"\n\n def check(self, ctext: str) -> Optional[str]:\n with self._config().pause_spinner_handle():\n response = console.input(\n f\"Possible plaintext: [blue bold]{escape(ctext.__repr__())}[/blue bold] ([green]y[/green]/[red]N[/red]): \"\n )\n if response == \"y\":\n return \"\"\n elif response in (\"n\", \"\"):\n return None\n else:\n return self.check(ctext)\n\n def getExpectedRuntime(self, text: str) -> float:\n return 1 # About a second\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Checkers/quadgrams.py\nimport logging\nimport re\nfrom math import log10\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, Translation, registry\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Quadgrams(Checker[str]):\n\n \"\"\"\n Uses Quadgrams to determine plaintext\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying Quadgrams checker\")\n # Capitalize and remove everything that's not a letter\n ctext = re.sub(\"[^A-Z]\", \"\", ctext.upper())\n quadgrams = self.QUADGRAMS_DICT\n quadgrams_sum = sum(quadgrams.values())\n score = 0\n for key in quadgrams.keys():\n quadgrams[key] = float(quadgrams[key]) / quadgrams_sum\n floor = log10(0.01 / quadgrams_sum)\n for i in range(len(ctext) - 4 + 1):\n # Get all quadgrams from ctext and check if they're in the dict\n # If yes then add the score of those quadgrams to the total score\n if ctext[i : i + 4] in quadgrams:\n score += quadgrams[ctext[i : i + 4]]\n else:\n score += floor\n if len(ctext) > 0:\n score = score / len(ctext)\n logging.info(f\"Quadgrams is {score}\")\n # The default threshold was found to work the best from lots of testing\n if score > self.threshold:\n return \"\"\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The quadgrams dictionary to use\",\n req=False,\n default=\"cipheydists::dist::quadgrams\",\n ),\n \"score\": ParamSpec(\n desc=\"The score threshold to use\",\n req=False,\n default=0.00011,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.QUADGRAMS_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.threshold = float(self._params()[\"score\"])\n\n# Path: ciphey/basemods/Checkers/regex.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Regex(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = list(map(re.compile, self._params()[\"regex\"]))\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: 
{res}\")\n if res:\n return f\"Passed with regex {regex}. Want to contribute to Ciphey? Submit your regex here to allow Ciphey to automatically get this next time https://github.com/bee-san/pyWhat/wiki/Adding-your-own-Regex\\n\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"regex\": ParamSpec(\n req=True,\n desc=\"The regex that must be matched (in a substring)\",\n list=True,\n )\n }\n\n\n@registry.register\nclass RegexList(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = []\n for i in self._params()[\"resource\"]:\n self.regexes += [re.compile(regex) for regex in config.get_resource(i)]\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: {res}\")\n if res:\n return f\"passed with regex {regex}\"\n\n...\n", "needle_token_start": 1653, "needle_token_end": 1687, "code_context_ntokens": 15333, "output": ["```python\ndef __ge__(self, other: \"DecoderComparer\"):\n return self.value.priority() >= other.value.priority()\n```"]} +{"repo": "Ciphey/Ciphey", "name": "alpha", "language": "python", "path": "ciphey/basemods/Crackers/hash.py", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to handle a specific type of cryptographic hash cracking, integrating with a broader decryption framework.\n2. **Input**: The function accepts two parameters: a ciphertext (string) and a hash type (string).\n3. **Output**: The function returns `None`, indicating either a placeholder implementation or that the output is handled through side effects not visible in the provided snippet.\n4. **Procedure**: Currently, the function does not implement any operational code, as indicated by the direct return of `None`. This suggests it might be a stub or template for future development or integration.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/basemods/Searchers/ausearch.py\nimport bisect\nimport distutils\nimport math\nfrom copy import copy\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Any, Dict, Generic, List, Optional, TypeVar, Union\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import (\n Checker,\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n Decoder,\n...\n# Path: ciphey/basemods/Searchers/__init__.py\nfrom . import ausearch\n\n# Path: ciphey/basemods/__init__.py\nfrom . 
import Checkers, Crackers, Decoders, Resources, Searchers\n\n# Path: ciphey/ciphey.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\nhttps://github.com/Ciphey\nhttps://github.com/Ciphey/Ciphey/wiki\n\nThe cycle goes:\nmain -> argparsing (if needed) -> call_encryption -> new Ciphey object -> decrypt() -> produceProbTable ->\none_level_of_decryption -> decrypt_normal\n\"\"\"\nimport os\nimport warnings\nfrom typing import Any, Optional, Union\n\nimport click\nfrom appdirs import AppDirs\nimport logging\nfrom rich.logging import RichHandler\nfrom rich.console import Console\n\nfrom . import iface\n\nwarnings.filterwarnings(\"ignore\")\n\nconsole = Console()\n\n\ndef decrypt(config: iface.Config, ctext: Any) -> Union[str, bytes]:\n \"\"\"A simple alias for searching a ctext and makes the answer pretty\"\"\"\n res: Optional[iface.SearchResult] = config.objs[\"searcher\"].search(ctext)\n if res is None:\n return \"Failed to crack\"\n if config.verbosity < 0:\n return res.path[-1].result.value\n else:\n return iface.pretty_search_results(res)\n\n\ndef get_name(ctx, param, value):\n # reads from stdin if value was not supplied\n if not value and not click.get_text_stream(\"stdin\").isatty():\n click.get_text_stream(\"stdin\").read().strip()\n return click.get_text_stream(\"stdin\").read().strip()\n else:\n return value\n\n\ndef print_help(ctx):\n # prints help menu\n # if no arguments are passed\n click.echo(ctx.get_help())\n ctx.exit()\n\n\n@click.command()\n@click.option(\n \"-t\",\n \"--text\",\n help=\"The ciphertext you want to decrypt.\",\n type=str,\n)\n@click.option(\n \"-q\", \"--quiet\", help=\"Decrease verbosity\", type=int, count=True, default=None\n)\n@click.option(\n \"-g\",\n \"--greppable\",\n help=\"Only print the answer (useful for grep)\",\n type=bool,\n is_flag=True,\n default=None,\n)\n@click.option(\"-v\", \"--verbose\", count=True, type=int)\n@click.option(\"-C\", \"--checker\", help=\"Use the given checker\", default=None)\n@click.option(\n \"-c\",\n \"--config\",\n help=\"Uses the given config file. 
Defaults to appdirs.user_config_dir('ciphey', 'ciphey')/'config.yml'\",\n)\n@click.option(\"-w\", \"--wordlist\", help=\"Uses the given wordlist\")\n@click.option(\n \"-p\",\n \"--param\",\n help=\"Passes a parameter to the language checker\",\n multiple=True,\n)\n@click.option(\n \"-l\",\n \"--list-params\",\n help=\"List the parameters of the selected module\",\n type=bool,\n)\n@click.option(\n \"--searcher\",\n help=\"Select the searching algorithm to use\",\n)\n# HARLAN TODO XXX\n# I switched this to a boolean flag system\n# https://click.palletsprojects.com/en/7.x/options/#boolean-flags\n# True for bytes input, False for str\n@click.option(\n \"-b\",\n \"--bytes\",\n help=\"Forces ciphey to use binary mode for the input\",\n is_flag=True,\n default=None,\n)\n@click.option(\n \"--default-dist\",\n help=\"Sets the default character/byte distribution\",\n type=str,\n default=None,\n)\n@click.option(\n \"-m\",\n \"--module\",\n help=\"Adds a module from the given path\",\n type=click.Path(),\n multiple=True,\n)\n@click.option(\n \"-A\",\n \"--appdirs\",\n help=\"Print the location of where Ciphey wants the settings file to be\",\n type=bool,\n is_flag=True,\n)\n@click.option(\"-f\", \"--file\", type=click.File(\"rb\"), required=False)\n@click.argument(\"text_stdin\", callback=get_name, required=False)\ndef main(**kwargs):\n \"\"\"Ciphey - Automated Decryption Tool\n\n Documentation:\n https://github.com/Ciphey/Ciphey/wiki\\n\n Discord (support here, we're online most of the day):\n http://discord.skerritt.blog\\n\n GitHub:\n https://github.com/ciphey/ciphey\\n\n\n Ciphey is an automated decryption tool using smart artificial intelligence and natural language processing. Input encrypted text, get the decrypted text back.\n\n Examples:\\n\n Basic Usage: ciphey -t \"aGVsbG8gbXkgbmFtZSBpcyBiZWU=\"\n\n \"\"\"\n\n \"\"\"Function to deal with arguments. Either calls with args or not. 
Makes Pytest work.\n\n It gets the arguments in the function definition using locals()\n if withArgs is True, that means this is being called with command line args\n so go to arg_parsing() to get those args\n we then update locals() with the new command line args and remove \"withArgs\"\n This function then calls call_encryption(**result) which passes our dict of args\n to the function as its own arguments using dict unpacking.\n Returns:\n The output of the decryption.\n \"\"\"\n\n # if user wants to know where appdirs is\n # print and exit\n if \"appdirs\" in kwargs and kwargs[\"appdirs\"]:\n dirs = AppDirs(\"Ciphey\", \"Ciphey\")\n path_to_config = dirs.user_config_dir\n print(\n f\"The settings.yml file should be at {os.path.join(path_to_config, 'settings.yml')}\"\n )\n return None\n\n # Now we create the config object\n config = iface.Config()\n\n # Load the settings file into the config\n load_msg: str\n cfg_arg = kwargs[\"config\"]\n if cfg_arg is None:\n # Make sure that the config dir actually exists\n os.makedirs(iface.Config.get_default_dir(), exist_ok=True)\n config.load_file(create=True)\n load_msg = f\"Opened config file at {os.path.join(iface.Config.get_default_dir(), 'config.yml')}\"\n else:\n config.load_file(cfg_arg)\n load_msg = f\"Opened config file at {cfg_arg}\"\n\n # Load the verbosity, so that we can start logging\n verbosity = kwargs[\"verbose\"]\n quiet = kwargs[\"quiet\"]\n if verbosity is None:\n if quiet is not None:\n verbosity = -quiet\n elif quiet is not None:\n verbosity -= quiet\n if kwargs[\"greppable\"] is not None:\n verbosity -= 999\n # Use the existing value as a base\n config.verbosity += verbosity\n config.update_log_level(config.verbosity)\n logging.info(load_msg)\n logging.debug(f\"Got cmdline args {kwargs}\")\n\n # Now we load the modules\n module_arg = kwargs[\"module\"]\n if module_arg is not None:\n config.modules += list(module_arg)\n\n # We need to load formats BEFORE we instantiate objects\n if kwargs[\"bytes\"] is not None:\n config.update_format(\"bytes\")\n\n # Next, load the objects\n params = kwargs[\"param\"]\n if params is not None:\n for i in params:\n key, value = i.split(\"=\", 1)\n parent, name = key.split(\".\", 1)\n config.update_param(parent, name, value)\n config.update(\"checker\", kwargs[\"checker\"])\n config.update(\"searcher\", kwargs[\"searcher\"])\n config.update(\"default_dist\", kwargs[\"default_dist\"])\n\n config.complete_config()\n\n logging.debug(f\"Command line opts: {kwargs}\")\n logging.debug(f\"Config finalised: {config}\")\n\n # Finally, we load the plaintext\n if kwargs[\"text\"] is None:\n if kwargs[\"file\"] is not None:\n kwargs[\"text\"] = kwargs[\"file\"].read()\n elif kwargs[\"text_stdin\"] is not None:\n kwargs[\"text\"] = kwargs[\"text_stdin\"]\n else:\n # else print help menu\n print(\"[bold red]Error. No inputs were given to Ciphey. 
[bold red]\")\n\n @click.pass_context\n def all_procedure(ctx):\n print_help(ctx)\n\n all_procedure()\n\n return None\n\n if issubclass(config.objs[\"format\"], type(kwargs[\"text\"])):\n pass\n elif config.objs[\"format\"] == str and isinstance(kwargs[\"text\"], bytes):\n kwargs[\"text\"] = kwargs[\"text\"].decode(\"utf-8\")\n elif config.objs[\"format\"] == bytes and isinstance(kwargs[\"text\"], str):\n kwargs[\"text\"] = kwargs[\"text\"].encode(\"utf-8\")\n else:\n raise TypeError(f\"Cannot load type {config.format} from {type(kwargs['text'])}\")\n\n result: Optional[str]\n\n # if debug or quiet mode is on, run without spinner\n if config.verbosity != 0:\n result = decrypt(config, kwargs[\"text\"])\n else:\n # else, run with spinner if verbosity is 0\n with console.status(\"[bold green]Thinking...\", spinner=\"moon\") as status:\n config.set_spinner(status)\n result = decrypt(config, kwargs[\"text\"])\n if result is None:\n result = \"Could not find any solutions.\"\n\n console.print(result)\n\n# Path: ciphey/__init__.py\nfrom . import basemods, common, iface\nfrom .ciphey import decrypt\n\n# Path: ciphey/__main__.py\n#! /usr/bin/env python3\n\n\"\"\"\nCiphey: https://github.com/Ciphey/Ciphey\n\"\"\"\n\nimport platform\nimport sys\n\nif __name__ == \"__main__\":\n major = sys.version_info[0]\n minor = sys.version_info[1]\n\n python_version = (\n str(sys.version_info[0])\n + \".\"\n + str(sys.version_info[1])\n + \".\"\n + str(sys.version_info[2])\n )\n\n if major != 3 or minor < 6:\n print(\n f\"Ciphey requires Python 3.6+, you are using {python_version}. Please install a higher Python version. https://www.python.org/downloads/\"\n )\n print(\n \"Alternatively, visit our Discord and use the Ciphey bot in #bots http://discord.skerritt.blog\"\n )\n sys.exit(1)\n if platform.system() == \"Windows\":\n if minor > 8:\n print(\n \"Ciphey does not currently support Python 3.9 on Windows. Please use the Discord bot at http://discord.skerritt.blog\"\n )\n sys.exit(1)\n\n if sys.maxsize > 2 ** 32 is False:\n print(\n \"You are using Python 32 bit and Windows, Ciphey does not support this. 
Please upgrade to Python 64-bit here https://www.python.org/downloads/\"\n )\n sys.exit(1)\n from .ciphey import main\n\n main()\n\n# Path: ciphey/basemods/Checkers/entropy.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Entropy(Checker[str]):\n\n \"\"\"\n Uses entropy to determine plaintext\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying entropy checker\")\n pass\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n # Uses benchmark from Discord\n return 2e-7 * len(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/gtest.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass GTestChecker(Checker[str]):\n\n \"\"\"\n G-test of fitness, similar to Chi squared.\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying entropy checker\")\n pass\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 4e-7 * len(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Crackers/hash.py\n\"\"\"\nThis is Hashbuster but slightly modified to work with Ciphey.\nWhy reinvent the wheel?\nChanges (that I can remember)\n* timeout set, as hashbuster took AGES before timeout was set.\nhttps://github.com/s0md3v/Hash-Buster\n\"\"\"\n\nimport re\nfrom typing import Dict, List, Optional\n\nimport requests\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, T, registry\n\nthread_count = 4\n\n\n\ndef alpha(ctext, hashtype):\n return None\n\n\ndef beta(ctext, hashtype):\n try:\n response = requests.get(\n \"https://hashtoolkit.com/reverse-hash/?hash=\", ctext, timeout=5\n ).text\n except requests.exceptions.ReadTimeout as e:\n logging.info(f\"Beta failed timeout {e}\")\n match = re.search(r'/generate-hash/?text=.*?\"', response)\n if match:\n return match.group(1)\n return None\n\n\ndef gamma(ctext, hashtype):\n try:\n response = requests.get(\n \"https://www.nitrxgen.net/md5db/\" + ctext, timeout=5\n ).text\n except requests.exceptions.ReadTimeout as e:\n logging.info(f\"Gamma failed with {e}\")\n if response:\n return response\n else:\n return None\n\n\ndef delta(ctext, hashtype):\n return None\n\n\ndef theta(ctext, hashtype):\n try:\n response = requests.get(\n \"https://md5decrypt.net/Api/api.php?hash=%s&hash_type=%s&email=deanna_abshire@proxymail.eu&code=1152464b80a61728\"\n % (ctext, hashtype),\n timeout=5,\n ).text\n except requests.exceptions.ReadTimeout as e:\n logging.info(f\"Gamma failed with {e}\")\n if len(response) != 0:\n return response\n else:\n return None\n\n\nmd5 = [gamma, alpha, beta, theta, delta]\nsha1 = [alpha, beta, theta, delta]\nsha256 = [alpha, beta, theta]\nsha384 = [alpha, beta, theta]\nsha512 = [alpha, beta, theta]\n\n\nresult = {}\n\n\ndef crack(ctext):\n raise \"Error Crack is called\"\n\n\ndef threaded(ctext):\n resp = crack(ctext)\n if resp:\n print(ctext + \" : \" + resp)\n result[ctext] = resp\n\n\n@registry.register\nclass 
HashBuster(Cracker[str]):\n @staticmethod\n def getTarget() -> str:\n return \"hash\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def getInfo(self, ctext: T) -> CrackInfo:\n # TODO calculate these properly\n return CrackInfo(\n success_likelihood=0.5,\n success_runtime=5,\n failure_runtime=5,\n )\n\n def attemptCrack(self, ctext: T) -> List[CrackResult]:\n logging.info(\"Starting to crack hashes\")\n result = False\n\n candidates = []\n if len(ctext) == 32:\n for api in md5:\n r = api(ctext, \"md5\")\n if result is not None or r is not None:\n logging.debug(\"MD5 returns True {r}\")\n candidates.append(result, \"MD5\")\n elif len(ctext) == 40:\n for api in sha1:\n r = api(ctext, \"sha1\")\n if result is not None and r is not None:\n logging.debug(\"sha1 returns true\")\n candidates.append(result, \"SHA1\")\n elif len(ctext) == 64:\n for api in sha256:\n r = api(ctext, \"sha256\")\n if result is not None and r is not None:\n logging.debug(\"sha256 returns true\")\n candidates.append(result, \"SHA256\")\n elif len(ctext) == 96:\n for api in sha384:\n r = api(ctext, \"sha384\")\n if result is not None and r is not None:\n logging.debug(\"sha384 returns true\")\n candidates.append(result, \"SHA384\")\n elif len(ctext) == 128:\n for api in sha512:\n r = api(ctext, \"sha512\")\n if result is not None and r is not None:\n logging.debug(\"sha512 returns true\")\n candidates.append(result, \"SHA512\")\n\n # TODO what the fuck is this code?\n logging.debug(f\"Hash buster returning {result}\")\n # TODO add to 5.1 make this return multiple possible candidates\n return [CrackResult(value=candidates[0][0], misc_info=candidates[1][1])]\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Decoders/base58_flickr.py\nfrom typing import Dict, Optional\n\nimport base58\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base58_flickr(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base58 (Flickr) decoding\n \"\"\"\n FLICKR_ALPHABET = b\"123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ\"\n try:\n return base58.b58decode(ctext, alphabet=FLICKR_ALPHABET).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base58_flickr\"\n\n# Path: ciphey/basemods/Decoders/base64_url.py\nimport base64\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base64_url(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base64 URL decoding\n \"\"\"\n ctext_padding = ctext + \"=\" * (4 - len(ctext) % 4)\n try:\n return base64.urlsafe_b64decode(ctext_padding).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base64_url\"\n\n# Path: 
ciphey/basemods/Decoders/base65536.py\nfrom typing import Dict, Optional\n\nimport base65536\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base65536(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base65536 decoding\n \"\"\"\n try:\n return base65536.decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base65536\"\n\n# Path: ciphey/basemods/Decoders/z85.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\nfrom zmq.utils import z85\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Z85(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Z85 decoding\n \"\"\"\n ctext_len = len(ctext)\n if ctext_len % 5:\n logging.debug(\n f\"Failed to decode Z85 because length must be a multiple of 5, not '{ctext_len}'\"\n )\n return None\n try:\n return z85.decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"z85\"\n\n# Path: ciphey/basemods/Searchers/perfection.py\nfrom typing import Dict, Optional, Set\n\nfrom ciphey.iface import Config, ParamSpec, registry\n\nfrom .ausearch import AuSearch, Node\n\n\n@registry.register\nclass Perfection(AuSearch):\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def findBestNode(self, nodes: Set[Node]) -> Node:\n return next(iter(nodes))\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Searchers/astar.py\nimport cipheycore\n\n\nclass Node:\n \"\"\"\n A node has a value associated with it\n Calculated from the heuristic\n \"\"\"\n\n def __init__(\n self,\n config,\n h: float = None,\n edges: (any, float) = None,\n ctext: str = None,\n ):\n self.weight = h\n # Edges is a list of other nodes it can connect to\n self.edges = edges\n self.ctext = ctext\n self.h = h\n self.path = []\n self.information_content = config.cache.get_or_update(\n self.text,\n \"cipheycore::info_content\",\n lambda: cipheycore.info_content(self.ctext),\n )\n\n def __le__(self, node2):\n # if self is less than other\n return self.x <= node2.x\n\n def __lt__(self, node2):\n return self.x < node2.x\n\n def append_edge(self, edge):\n self.edges.append(edge)\n\n def get_edges(self):\n return self.edges\n\n\nclass Graph:\n # example of adjacency list (or rather map)\n # adjacency_list = {\n # 'A': [('B', 1), ('C', 3), ('D', 7)],\n # 'B': [('D', 5)],\n # 'C': [('D', 12)]\n # }\n\n def __init__(self, adjacency_list):\n \"\"\"\n adjacency list: basically the graph\n \"\"\"\n self.adjacency_list = adjacency_list\n self.original_input = cipheycore.info_content(input)\n\n def get_neighbors(self, v):\n try:\n return self.adjacency_list[v]\n except KeyError:\n # If we have exhausted the adjacency list\n return []\n\n # heuristic function with equal values 
for all nodes\n def heuristic(self, n: Node):\n return n.info_content / self.original_input\n\n def a_star_algorithm(self, start_node: Node, stop_node: Node):\n # TODO store the graph as an attribute\n # open_list is a list of nodes which have been visited, but who's neighbors\n # haven't all been inspected, starts off with the start node\n # closed_list is a list of nodes which have been visited\n # and who's neighbors have been inspected\n open_list = set([start_node])\n closed_list = set([])\n\n # g contains current distances from start_node to all other nodes\n # the default value (if it's not found in the map) is +infinity\n g = {}\n\n g[start_node] = 0\n\n # parents contains an adjacency map of all nodes\n parents = {}\n parents[start_node] = start_node\n\n while len(open_list) > 0:\n print(f\"The open list is {open_list}\")\n n = None\n\n # find a node with the lowest value of f() - evaluation function\n for v in open_list:\n # TODO if v == decoder, run the decoder\n print(f\"The for loop node v is {v}\")\n if n is None or g[v] + self.h(v) < g[n] + self.h(n):\n n = v\n print(f\"The value of n is {n}\")\n\n if n is None:\n print(\"Path does not exist!\")\n return None\n\n # if the current node is the stop_node\n # then we begin reconstructin the path from it to the start_node\n # NOTE Uncomment this for an exit condition\n # TODO Make it exit if decrypter returns True\n # TODO We need to append the decryption methods to each node\n # So when we reconstruct the path we can reconstruct the decryptions\n # used\n if n == stop_node:\n print(\"n is the stop node, we are stopping!\")\n reconst_path = []\n\n while parents[n] != n:\n reconst_path.append(n)\n n = parents[n]\n\n reconst_path.append(start_node)\n\n reconst_path.reverse()\n\n print(\"Path found: {}\".format(reconst_path))\n return reconst_path\n\n print(n)\n for (m, weight) in self.get_neighbors(n):\n print(f\"And the iteration is ({m}, {weight})\")\n # if the current node isn't in both open_list and closed_list\n # add it to open_list and note n as it's parent\n if m not in open_list and m not in closed_list:\n open_list.add(m)\n parents[m] = n\n g[m] = g[n] + weight\n\n # otherwise, check if it's quicker to first visit n, then m\n # and if it is, update parent data and g data\n # and if the node was in the closed_list, move it to open_list\n else:\n if g[m] > g[n] + weight:\n g[m] = g[n] + weight\n parents[m] = n\n\n if m in closed_list:\n closed_list.remove(m)\n open_list.add(m)\n\n # remove n from the open_list, and add it to closed_list\n # because all of his neighbors were inspected\n # open_list.remove(node)\n # closed_list.add(node)\n\n open_list.remove(n)\n closed_list.add(n)\n print(\"\\n\")\n\n print(\"Path does not exist!\")\n return None\n\n\nadjacency_list = {\n \"A\": [(\"B\", 1), (\"C\", 3), (\"D\", 7)],\n \"B\": [(\"D\", 5)],\n \"C\": [(\"D\", 12)],\n}\nA = Node(1)\nB = Node(7)\nC = Node(9)\nD = Node(16)\n\nA.edges = [(B, 1), (C, 3), (D, 7)]\nB.edges = [(D, 5)]\nC.edges = [(D, 12)]\n\n# TODO use a dictionary comprehension to make this\nadjacency_list = {\n A: A.edges,\n B: B.edges,\n C: C.edges,\n}\ngraph1 = Graph(adjacency_list)\ngraph1.a_star_algorithm(A, D)\n\n\"\"\"\nMaybe after it\n\"\"\"\n\n# Path: ciphey/basemods/Searchers/imperfection.py\nimport heapq\n\n\nclass Imperfection:\n \"\"\"The graph is a Node: [List of nodes]\n Where each item in the list of nodes can also have a node with a list of nodes\n\n The result is that we can keep track of edges, while also keeping it small\n\n To calculate current, 
we push the entire graph to A*\n\n And it calculates the next node to choose, as well as increasing the size\n of the graph with values\n\n We're using a heap, meaning the element at [0] is always the smallest element\n\n So we choose that and return it.\n\n\n The current A* implementation has an end, we simply do not let it end as LC will make it\n end far before it reaches Searcher again.\n\n Current is the start position, so if we say we always start at the start of the graph it'll\n go through the entire graph\n\n graph = {\n Node: [\n {Node :\n {\n node\n }\n }\n ]\n }\n\n For encodings we just do them straight out\n\n The last value of parents from abstract\n \"\"\"\n\n \"\"\"\n\n graph = {'A': ['B', 'C'],\n 'B': ['C', 'D'],\n 'C': ['D'],\n 'D': ['C'],\n 'E': ['F'],\n 'F': ['C']}\"\"\"\n\n def __init__(self):\n None\n\n def findBestNode(self, nodes):\n \"\"\"Finds the best decryption module\"\"\"\n return next(iter(nodes))\n\n # def aStar(self, graph, current, end):\n # \"\"\"The A* search algorithm\n\n # We're using heaps to find the minimum element (the one that will be the next current)\n # Heaps are like sets with O(1) lookup time, but maintain the lowest element as [0]\n # Sets insert in O(1), heaps in O(log N).\n\n # https://stackoverflow.com/questions/4159331/python-speed-up-an-a-star-pathfinding-algorithm\n\n # Current appears to be the list of all new tiles we can reach from current location\n\n # End is the end node, that won't actually run bc LC will make it return before it hits aSTar function\n # so tbh I'll just make it infinite unless something else forces a return\n\n # The graph is the actual data structure used. According to StackOverflow, it looks like this:\n\n # graph = {'A': ['B', 'C'],\n # 'B': ['C', 'D'],\n # 'C': ['D'],\n # 'D': ['C'],\n # 'E': ['F'],\n # 'F': ['C']}\n\n # \"\"\"\n\n # # Runs decodings first\n\n # openSet = set()\n # openHeap = []\n # closedSet = set()\n\n # def retracePath(c):\n # # Retraces a path back to the start\n # path = [c]\n # while c.parent is not None:\n # c = c.parent\n # path.append(c)\n # path.reverse()\n # return path\n\n # # Adds the current location (start) to the heap and set\n # openSet.add(current)\n # openHeap.append((0, current))\n\n # # while openSet contains items\n # while openSet:\n # # TODO change openSet to a heap?\n # # gets the 2nd element from the first element of the heap\n # # so the heap is (0, current)\n # # which means we pop current\n # # this makes me think that current isn't the first?\n # current = heapq.heappop(openHeap)[1]\n # # We don't actually want to end, so I'm commenting this:\n # # XXX\n # if current == end:\n # return retracePath(current)\n # # Removes it from todo and into done i think\n # # closedSet appears to be the set of things we have done\n # openSet.remove(current)\n # closedSet.add(current)\n\n # \"\"\"\n # Okay so our graph looks like this:\n # graph = {\n # Node: [\n # {Node :\n # {\n # node\n # }\n # }\n # ]\n # }\n # graph[current] **SHOULD** be the list of nodes which contains dictionaries of nodes\n\n # \"\"\"\n # for tile in graph[current]:\n # # ClosedSet appears to be the list of visited nodes\n # # TODO place this as a class attribute\n # if tile not in closedSet:\n # # This is the heuristic\n # # TODO expected_time/probability + k * heuristic, for some experimentally determined value of k\n # tile.H = (abs(end.x - tile.x) + abs(end.y - tile.y)) * 10\n\n # # if tile is not in the openSet, add it and then pop it from the heap\n # if tile not in openSet:\n # 
openSet.add(tile)\n # heapq.heappush(openHeap, (tile.H, tile))\n # # I have no idea where this code is called lol\n # tile.parent = current\n\n # # This returns Nothing\n # # I need to modify it so it finds the best item from Current\n # # So basically, return item 0 of openHeap\n # # return openHeap[0]\n # # Since the [0] item is always minimum\n # return []\n def aStar(self, graph, current, end):\n print(f\"The graph is {graph}\\nCurrent is {current}\\n and End is {end}\")\n openSet = set()\n openHeap = []\n closedSet = set()\n\n def retracePath(c):\n print(\"Calling retrace path\")\n path = [c]\n while c.parent is not None:\n c = c.parent\n path.append(c)\n path.reverse()\n return path\n\n print(\"\\n\")\n\n openSet.add(current)\n openHeap.append((0, current))\n while openSet:\n print(f\"Openset is {openSet}\")\n print(f\"OpenHeap is {openHeap}\")\n print(f\"ClosedSet is {closedSet}\")\n print(f\"Current is {current}\")\n print(f\"I am popping {openHeap} with the first element\")\n current = heapq.heappop(openHeap)[1]\n print(f\"Current is now {current}\")\n print(f\"Graph current is {graph[current]}\")\n if current == end:\n return retracePath(current)\n openSet.remove(current)\n closedSet.add(current)\n for tile in graph[current]:\n if tile not in closedSet:\n tile.H = (abs(end.x - tile.x) + abs(end.y - tile.y)) * 10\n tile.H = 1\n if tile not in openSet:\n openSet.add(tile)\n heapq.heappush(openHeap, (tile.H, tile))\n tile.parent = current\n print(\"\\n\")\n return []\n\n\nclass Node:\n \"\"\"\n A node has a value associated with it\n Calculated from the heuristic\n \"\"\"\n\n def __init__(self, h):\n self.h = h\n self.x = self.h\n self.y = 0.6\n\n def __le__(self, node2):\n # if self is less than other\n return self.x <= node2.x\n\n def __lt__(self, node2):\n return self.x < node2.x\n\n\nif __name__ == \"__main__\":\n obj = Imperfection()\n graph = {\n \"A\": [\"B\", \"C\"],\n \"B\": [\"C\", \"D\"],\n \"C\": [\"D\"],\n \"D\": [\"C\"],\n \"E\": [\"F\"],\n \"F\": [\"C\"],\n }\n # Makes the graph\n y = Node(0.5)\n x = Node(0.3)\n p = Node(0.7)\n q = Node(0.9)\n graph = {y: [x, p], p: q}\n\n print(obj.aStar(graph, y, q))\n\n", "needle_token_start": 4015, "needle_token_end": 4029, "code_context_ntokens": 10573, "output": ["```python\ndef alpha(ctext, hashtype):\n return None\n```"]} +{"repo": "Ciphey/Ciphey", "name": "strip_punctuation", "language": "python", "path": "ciphey/mathsHelper.py", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To remove all punctuation characters from a provided string, making it cleaner for text processing or analysis.\n2. **Input**: A single string which may contain various characters including punctuation.\n3. **Output**: A modified version of the input string with all punctuation characters removed.\n4. **Procedure**: The function utilizes a translation table to replace each punctuation character with an empty string, effectively removing them from the input string. 
It also trims newline characters from the ends of the string.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/iface/_config.py\nimport datetime\nimport os\nimport pydoc\nfrom typing import Any, Callable, Dict, List, Optional, Type, Union\n\nimport appdirs\nimport yaml\nimport logging\nfrom rich.logging import RichHandler\n\nfrom . import _fwd\nfrom ._modules import PolymorphicChecker, ResourceLoader, Searcher\n\n\nclass Cache:\n \"\"\"Used to track state between levels of recursion to stop infinite loops, and to optimise repeating actions\"\"\"\n\n def __init__(self):\n self._cache: Dict[Any, Dict[str, Any]] = {}\n\n def mark_ctext(self, ctext: Any) -> bool:\n if (isinstance(ctext, str) or isinstance(ctext, bytes)) and len(ctext) < 4:\n logging.debug(f\"Candidate {ctext.__repr__()} too short!\")\n return False\n\n if ctext in self._cache:\n logging.debug(f\"Deduped {ctext.__repr__()}\")\n return False\n\n logging.debug(f\"New ctext {ctext.__repr__()}\")\n\n self._cache[ctext] = {}\n return True\n\n def get_or_update(self, ctext: Any, keyname: str, get_value: Callable[[], Any]):\n # Should have been marked first\n target = self._cache[ctext]\n res = target.get(keyname)\n if res is not None:\n return res\n\n val = get_value()\n target[keyname] = val\n return val\n\n def try_get(self, ctext: Any, keyname: str):\n return self._cache[ctext].get(keyname)\n\n\ndef split_resource_name(full_name: str) -> (str, str):\n return full_name.split(\"::\", 1)\n\n\nclass Config:\n def __init__(self):\n self.verbosity: int = 0\n self.searcher: str = \"ausearch\"\n self.params: Dict[str, Dict[str, Union[str, List[str]]]] = {}\n self.format: str = \"str\"\n self.modules: List[str] = []\n self.checker: str = \"ezcheck\"\n self.default_dist: str = \"cipheydists::dist::english\"\n self.timeout: Optional[int] = None\n self._inst: Dict[type, Any] = {}\n self.objs: Dict[str, Any] = {}\n self.cache: Cache = Cache()\n\n @staticmethod\n def get_default_dir() -> str:\n return appdirs.user_config_dir(\"ciphey\")\n\n def merge_dict(self, config_file: Optional[Dict[str, Any]]):\n if config_file is None:\n return\n for a, b in config_file.items():\n self.update(a, b)\n\n def load_file(\n self,\n path: str = os.path.join(get_default_dir.__func__(), \"config.yml\"),\n create=False,\n ):\n try:\n with open(path, \"r+\") as file:\n return self.merge_dict(yaml.safe_load(file))\n except FileNotFoundError:\n if create:\n open(path, \"w+\")\n\n def instantiate(self, t: type) -> Any:\n \"\"\"\n Used to enable caching of a instantiated type after the configuration has settled\n \"\"\"\n # We cannot use set default as that would construct it again, and throw away the result\n res = self._inst.get(t)\n if res is not None:\n return res\n ret = t(self)\n self._inst[t] = ret\n return ret\n\n def __call__(self, t: type) -> Any:\n return self.instantiate(t)\n\n def update(self, attrname: str, value: Optional[Any]):\n if value is not None:\n setattr(self, attrname, value)\n\n def update_param(self, owner: str, name: str, value: Optional[Any]):\n if value is None:\n return\n\n target = self.params.setdefault(owner, {})\n\n if _fwd.registry.get_named(owner).getParams()[name].list:\n target.setdefault(name, []).append(value)\n else:\n target[name] = value\n\n def update_format(self, value: Optional[str]):\n if 
value is not None:\n self.format = value\n\n def load_objs(self):\n # Basic type conversion\n if self.timeout is not None:\n self.objs[\"timeout\"] = datetime.timedelta(seconds=int(self.timeout))\n self.objs[\"format\"] = pydoc.locate(self.format)\n\n # Checkers do not depend on any other config object\n logging.debug(f\"Registry is {_fwd.registry._reg[PolymorphicChecker]}\")\n self.objs[\"checker\"] = self(\n _fwd.registry.get_named(self.checker, PolymorphicChecker)\n )\n # Searchers only depend on checkers\n self.objs[\"searcher\"] = self(_fwd.registry.get_named(self.searcher, Searcher))\n\n def update_log_level(self, verbosity: Optional[int]):\n if verbosity is None:\n return\n self.verbosity = verbosity\n\n if verbosity == 0:\n self.verbosity = logging.WARNING\n elif verbosity == 1:\n self.verbosity = logging.INFO\n elif verbosity >= 2:\n self.verbosity = logging.DEBUG\n else:\n logging.disable(logging.CRITICAL)\n return\n\n...\n# Path: ciphey/iface/_registry.py\nfrom typing import Any, Dict, List, Optional, Set, Tuple, Type, Union\n\ntry:\n from typing import get_args, get_origin\nexcept ImportError:\n from typing_inspect import get_origin, get_args\n\nfrom . import _fwd\nfrom ._modules import *\n\n\nclass Registry:\n # I was planning on using __init_subclass__, but that is incompatible with dynamic type creation when we have\n # generic keys\n\n RegElem = Union[List[Type], Dict[Type, \"RegElem\"]]\n\n _reg: Dict[Type, RegElem] = {}\n _names: Dict[str, Tuple[Type, Set[Type]]] = {}\n _targets: Dict[str, Dict[Type, List[Type]]] = {}\n _modules = {Checker, Cracker, Decoder, ResourceLoader, Searcher, PolymorphicChecker}\n\n def _register_one(self, input_type, module_base, module_args):\n if len(module_args) == 0:\n self._reg.setdefault(module_base, []).append(input_type)\n return\n\n target_reg = self._reg.setdefault(module_base, {})\n # Seek to the given type\n for subtype in module_args[0:-1]:\n target_reg = target_reg.setdefault(subtype, {})\n target_reg.setdefault(module_args[-1], []).append(input_type)\n\n def _real_register(self, input_type: type, *args) -> Type:\n name = input_type.__name__.lower()\n name_target = self._names[name] = (input_type, set())\n\n if issubclass(input_type, Targeted):\n target = input_type.getTarget()\n else:\n target = None\n\n if issubclass(input_type, Searcher):\n module_type = module_base = Searcher\n module_args = ()\n else:\n module_type: Optional[Type] = None\n module_base = None\n\n # Work out what module type this is\n if len(args) == 0 and hasattr(input_type, \"__orig_bases__\"):\n for i in input_type.__orig_bases__:\n if module_type is not None:\n raise TypeError(\n f\"Type derived from multiple registrable base classes {i} and {module_type}\"\n )\n module_base = get_origin(i)\n if module_base not in self._modules:\n continue\n module_type = i\n else:\n for i in self._modules:\n if not issubclass(input_type, i):\n continue\n if module_type is not None:\n raise TypeError(\n f\"Type derived from multiple registrable base classes {i} and {module_type}\"\n )\n module_type = i\n if module_type is None:\n raise TypeError(\"No registrable base class\")\n\n # Replace input type with polymorphic checker if required\n if issubclass(input_type, Checker):\n if len(args) == 0:\n arg = [\n get_args(i)\n for i in input_type.__orig_bases__\n if get_origin(i) == Checker\n ][0]\n if len(arg) != 1:\n raise TypeError(\"No argument for Checker\")\n input_type = input_type.convert({arg[0]})\n else:\n input_type = input_type.convert(set(args))\n 
self._register_one(input_type, PolymorphicChecker, [])\n # Refresh the names with the new type\n name_target = self._names[name] = (input_type, {PolymorphicChecker})\n\n # Now handle the difference between register and register_multi\n if len(args) == 0:\n if module_type is PolymorphicChecker:\n module_base = PolymorphicChecker\n elif module_base is None:\n raise TypeError(\"No type argument given\")\n self._register_one(input_type, module_base, get_args(module_type))\n name_target[1].add(module_base)\n else:\n if module_base is not None:\n raise TypeError(f\"Redundant type argument for {module_type}\")\n module_base = module_type\n for module_args in args:\n # Correct missing brackets\n if not isinstance(module_args, tuple):\n module_args = (module_args,)\n\n self._register_one(input_type, module_base, module_args)\n name_target[1].add(module_type[module_args])\n\n name_target[1].add(module_type)\n\n if target is not None and issubclass(module_base, Targeted):\n self._targets.setdefault(target, {}).setdefault(module_type, []).append(\n input_type\n )\n\n return input_type\n\n def register(self, input_type):\n return self._real_register(input_type)\n\n def register_multi(self, *x):\n return lambda input_type: self._real_register(input_type, *x)\n\n def __getitem__(self, i: type) -> Optional[Any]:\n target_type = get_origin(i)\n # Check if this is a non-generic type, and return the whole dict if it is\n if target_type is None:\n return self._reg[i]\n\n target_subtypes = get_args(i)\n target_list = self._reg.setdefault(target_type, {})\n for subtype in target_subtypes:\n target_list = target_list.setdefault(subtype, {})\n return target_list\n\n def get_named(self, name: str, type_constraint: Type = None) -> Any:\n ret = self._names[name.lower()]\n if type_constraint and type_constraint not in ret[1]:\n raise TypeError(f\"Type mismatch: wanted {type_constraint}, got {ret[1]}\")\n return ret[0]\n\n def get_targeted(\n self, target: str, type_constraint: Type = None\n ) -> Optional[Union[Dict[Type, Set[Type]], Set[Type]]]:\n x = self._targets.get(target)\n if x is None or type_constraint is None:\n return x\n return x.get(type_constraint)\n\n def get_all_names(self) -> List[str]:\n return list(self._names.keys())\n\n def __str__(self):\n return f\"ciphey.iface.Registry {{_reg: {self._reg}, _names: {self._names}, _targets: {self._targets}}}\"\n\n\n_fwd.registry = Registry()\n\n# Path: ciphey/iface/__init__.py\nfrom ._config import Config\n\nfrom ._modules import (\n Checker,\n Cracker,\n CrackInfo,\n CrackResult,\n Decoder,\n DecoderComparer,\n Distribution,\n ParamSpec,\n PolymorphicChecker,\n ResourceLoader,\n Searcher,\n SearchLevel,\n SearchResult,\n T,\n Translation,\n U,\n WordList,\n pretty_search_results,\n)\nfrom ._registry import get_args, get_origin\n\nfrom ._fwd import registry\n\n# Path: ciphey/basemods/Checkers/any.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, ParamSpec, PolymorphicChecker, registry\n\n\n@registry.register\nclass Any(PolymorphicChecker):\n \"\"\"Should only be used for debugging, frankly\"\"\"\n\n def getExpectedRuntime(self, text) -> float:\n return 0 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n def check(self, text: str) -> Optional[str]:\n return \"\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/mathsHelper.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 
\u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to provide helper functions for mathematics\n(oh, not entirely mathematics either. Some NLP stuff and sorting dicts. It's just a helper class\n)\n\"\"\"\n\nfrom collections import OrderedDict\nfrom string import punctuation\nfrom typing import Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\n\nclass mathsHelper:\n \"\"\"Class to provide helper functions for mathematics and other small things\"\"\"\n\n def __init__(self):\n # ETAOIN is the most popular letters in order\n self.ETAOIN = \"ETAOINSHRDLCUMWFGYPBVKJXQZ\"\n self.LETTERS = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n\n @staticmethod\n def gcd(a, b) -> int:\n \"\"\"Greatest common divisor.\n\n The Greatest Common Divisor of a and b using Euclid's Algorithm.\n\n Args:\n a -> num 1\n b -> num 2\n\n Returns:\n Returns GCD(a, b)\n\n \"\"\"\n # Return\n while a != 0:\n a, b = b % a, a\n return b\n\n @staticmethod\n def mod_inv(a: int, m: int) -> Optional[int]:\n \"\"\"\n Returns the modular inverse of a mod m, or None if it does not exist.\n\n The modular inverse of a is the number a_inv that satisfies the equation\n a_inv * a mod m === 1 mod m\n\n Note: This is a naive implementation, and runtime may be improved in several ways.\n For instance by checking if m is prime to perform a different calculation,\n or by using the extended euclidean algorithm.\n \"\"\"\n for i in range(1, m):\n if (m * i + 1) % a == 0:\n return (m * i + 1) // a\n return None\n\n @staticmethod\n def percentage(part: float, whole: float) -> float:\n \"\"\"Returns percentage.\n\n Just a normal algorithm to return the percent.\n\n Args:\n part -> part of the whole number\n whole -> the whole number\n\n Returns:\n Returns the percentage of part to whole.\n\n \"\"\"\n if part <= 0 or whole <= 0:\n return 0\n # works with percentages\n return 100 * float(part) / float(whole)\n\n def sort_prob_table(self, prob_table: dict) -> dict:\n \"\"\"Sorts the probability table.\n\n Sorts a dictionary of dictionaries (and all the sub-dictionaries).\n\n Args:\n prob_table -> The probability table returned by the neural network to sort.\n\n Returns:\n Returns the prob_table, but sorted.\n\n \"\"\"\n # for each object: prob table in dictionary\n max_overall: int = 0\n max_dict_pair: dict = {}\n highest_key = None\n empty_dict: dict = {}\n # sorts the prob table before we find max, and converts it to order dicts\n for key, value in prob_table.items():\n prob_table[key] = self.new_sort(value)\n prob_table[key] = dict(prob_table[key])\n\n # gets maximum key then 
sets it to the front\n counter_max: int = 0\n counter_prob: int = len(prob_table)\n while counter_max < counter_prob:\n max_overall = 0\n highest_key = None\n logging.debug(\n f\"Running while loop in sort_prob_table, counterMax is {counter_max}\"\n )\n for key, value in prob_table.items():\n logging.debug(f\"Sorting {key}\")\n maxLocal = 0\n # for each item in that table\n for key2, value2 in value.items():\n logging.debug(\n f\"Running key2 {key2}, value2 {value2} for loop for {value.items()}\"\n )\n maxLocal = maxLocal + value2\n logging.debug(\n f\"MaxLocal is {maxLocal} and maxOverall is {max_overall}\"\n )\n if maxLocal > max_overall:\n logging.debug(f\"New max local found {maxLocal}\")\n # because the dict doesn't reset\n max_dict_pair = {}\n max_overall = maxLocal\n # so eventually, we get the maximum dict pairing?\n max_dict_pair[key] = value\n highest_key = key\n logging.debug(f\"Highest key is {highest_key}\")\n # removes the highest key from the prob table\n logging.debug(\n f\"Prob table is {prob_table} and highest key is {highest_key}\"\n )\n logging.debug(f\"Removing {prob_table[highest_key]}\")\n del prob_table[highest_key]\n logging.debug(f\"Prob table after deletion is {prob_table}\")\n counter_max += 1\n empty_dict = {**empty_dict, **max_dict_pair}\n\n # returns the max dict (at the start) with the prob table\n # this way, it should always work on most likely first.\n logging.debug(\n f\"The prob table is {prob_table} and the maxDictPair is {max_dict_pair}\"\n )\n logging.debug(f\"The new sorted prob table is {empty_dict}\")\n return empty_dict\n\n @staticmethod\n def new_sort(new_dict: dict) -> dict:\n \"\"\"Uses OrderedDict to sort a dictionary.\n\n I think it's faster than my implementation.\n\n Args:\n new_dict -> the dictionary to sort\n\n Returns:\n Returns the dict, but sorted.\n\n \"\"\"\n # (f\"d is {d}\")\n logging.debug(f\"The old dictionary before new_sort() is {new_dict}\")\n sorted_i = OrderedDict(\n sorted(new_dict.items(), key=lambda x: x[1], reverse=True)\n )\n logging.debug(f\"The dictionary after new_sort() is {sorted_i}\")\n # sortedI = sort_dictionary(x)\n return sorted_i\n\n @staticmethod\n def is_ascii(s: str) -> bool:\n \"\"\"Returns the boolean value if is_ascii is an ascii char.\n\n Does what it says on the tree. 
Stolen from\n https://stackoverflow.com/questions/196345/how-to-check-if-a-string-in-python-is-in-ascii\n\n Args:\n s -> the char to check.\n\n Returns:\n Returns the boolean of the char.\n\n \"\"\"\n\n return bool(lambda s: len(s) == len(s.encode()))\n\n @staticmethod\n def strip_punctuation(text: str) -> str:\n \"\"\"Strips punctuation from a given string.\n\n Uses string.punctuation.\n\n Args:\n text -> the text to strip punctuation from.\n\n Returns:\n Returns string without punctuation.\n \"\"\"\n text: str = (str(text).translate(str.maketrans(\"\", \"\", punctuation))).\nstrip(\n \"\\n\"\n )\n return text\n\n# Path: ciphey/basemods/Checkers/brandon.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to determine whether something is English or not.\n1. Calculate the Chi Squared score of a sentence\n2. If the score is significantly lower than the average score, it _might_ be English\n 2.1. If the score _might_ be English, then take the text and compare it to the sorted dictionary\n in O(n log n) time.\n It creates a percentage of \"How much of this text is in the dictionary?\"\n The dictionary contains:\n * 20,000 most common US words\n * 10,000 most common UK words (there's no repetition between the two)\n * The top 10,000 passwords\n If the word \"Looks like\" English (chi-squared) and if it contains English words, we can conclude it is\n very likely English. The alternative is doing the dictionary thing but with an entire 479k word dictionary (slower)\n 2.2. If the score is not English, but we haven't tested enough to create an average, then test it against\n the dictionary\n\nThings to optimise:\n* We only run the dictionary if it's 20% smaller than the average for chi squared\n* We consider it \"English\" if 45% of the text matches the dictionary\n* We run the dictionary if there is less than 10 total chisquared test\n\nHow to add a language:\n* Download your desired dictionary. Try to make it the most popular words, for example. 
Place this file into this\n folder with languagename.txt\nAs an example, this comes built in with english.txt\nFind the statistical frequency of each letter in that language.\nFor English, we have:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n}\nIn chisquared.py\nTo add your language, do:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n \"German\": [0.0973]\n}\nIn alphabetical order\nAnd you're.... Done! Make sure the name of the two match up\n\"\"\"\nimport sys\nfrom math import ceil\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nsys.path.append(\"..\")\ntry:\n import mathsHelper as mh\nexcept ModuleNotFoundError:\n import ciphey.mathsHelper as mh\n\n\n@registry.register\nclass Brandon(Checker[str]):\n \"\"\"\n Class designed to confirm whether something is **language** based on how many words of **language** appears\n Call confirmLanguage(text, language)\n * text: the text you want to confirm\n * language: the language you want to confirm\n\n Find out what language it is by using chisquared.py, the highest chisquared score is the language\n languageThreshold = 45\n if a string is 45% **language** words, then it's confirmed to be english\n \"\"\"\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually work this out\n # TODO its 0.2 seconds on average\n return 1e-4 # 100 \u00b5s\n\n wordlist: set\n\n def clean_text(self, text: str) -> set:\n \"\"\"Cleans the text ready to be checked\n\n Strips punctuation, makes it lower case, turns it into a set separated by spaces, removes duplicate words\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n text -> the text as a list, now cleaned\n\n \"\"\"\n # makes the text unique words and readable\n text = text.lower()\n text = self.mh.strip_punctuation(text)\n text = text.split(\" \")\n text = filter(lambda x: len(x) > 2, text)\n text = set(text)\n return text\n\n def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n \"\"\"Given text determine if it passes checker\n\n The checker uses the variable passed to it. I.E. Stopwords list, 1k words, dictionary\n\n Args:\n text -> The text to check\n threshold -> at what point do we return True? The percentage of text that is in var before we return True\n text_length -> the length of the text\n var -> the variable we are checking against. 
Stopwords list, 1k words list, dictionary list.\n Returns:\n boolean -> True for it passes the test, False for it fails the test.\"\"\"\n if text is None:\n logging.debug(\"Checker's text is None, so returning False\")\n return False\n if var is None:\n logging.debug(\"Checker's input var is None, so returning False\")\n return False\n\n percent = ceil(text_length * threshold)\n logging.debug(f\"Checker's chunks are size {percent}\")\n meet_threshold = 0\n location = 0\n end = percent\n\n if text_length <= 0:\n return False\n\n while location <= text_length:\n # chunks the text, so only gets THRESHOLD chunks of text at a time\n text = list(text)\n to_analyse = text[location:end]\n logging.debug(f\"To analyse is {to_analyse}\")\n for word in to_analyse:\n # if word is a stopword, + 1 to the counter\n if word in var:\n logging.debug(\n f\"{word} is in var, which means I am +=1 to the meet_threshold which is {meet_threshold}\"\n )\n meet_threshold += 1\n meet_threshold_percent = meet_threshold / text_length\n if meet_threshold_percent >= threshold:\n logging.debug(\n f\"Returning true since the percentage is {meet_threshold / text_length} and the threshold is {threshold}\"\n )\n # if we meet the threshold, return True\n # otherwise, go over again until we do\n # We do this in the for loop because if we're at 24% and THRESHOLD is 25\n # we don't want to wait THRESHOLD to return true, we want to return True ASAP\n return True\n location = end\n end = end + percent\n logging.debug(\n f\"The language proportion {meet_threshold_percent} is under the threshold {threshold}\"\n )\n return False\n\n def __init__(self, config: Config):\n # Suppresses warning\n super().__init__(config)\n self.mh = mh.mathsHelper()\n\n phases = config.get_resource(self._params()[\"phases\"])\n\n self.thresholds_phase1 = phases[\"1\"]\n self.thresholds_phase2 = phases[\"2\"]\n self.top1000Words = config.get_resource(self._params().get(\"top1000\"))\n self.wordlist = config.get_resource(self._params()[\"wordlist\"])\n self.stopwords = config.get_resource(self._params().get(\"stopwords\"))\n\n self.len_phase1 = len(self.thresholds_phase1)\n self.len_phase2 = len(self.thresholds_phase2)\n\n def check(self, text: str) -> Optional[str]:\n \"\"\"Checks to see if the text is in English\n\n Performs a decryption, but mainly parses the internal data packet and prints useful information.\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n bool -> True if the text is English, False otherwise.\n\n \"\"\"\n logging.debug(f'In Language Checker with \"{text}\"')\n text = self.clean_text(text)\n logging.debug(f'Text split to \"{text}\"')\n if text == \"\":\n logging.debug(\"Returning None from Brandon as the text cleaned is none.\")\n return None\n\n length_text = len(text)\n\n what_to_use = {}\n\n # this code decides what checker / threshold to use\n # if text is over or equal to maximum size, just use the maximum possible checker\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase1.keys()\n )\n logging.debug(self.thresholds_phase1)\n what_to_use = self.thresholds_phase1[str(what_to_use)]\n # def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n if \"check\" in what_to_use:\n # perform check 1k words\n result = self.checker(\n text, what_to_use[\"check\"], length_text, self.top1000Words\n )\n elif \"stop\" in what_to_use:\n # perform stopwords\n result = self.checker(\n text, what_to_use[\"stop\"], length_text, self.stopwords\n )\n elif \"dict\" in 
what_to_use:\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n # If result is None, no point doing it again in phase2\n if not result:\n return None\n else:\n logging.info(f\"It is neither stop or check, but instead {what_to_use}\")\n\n # return False if phase 1 fails\n if not result:\n return None\n else:\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase2.keys()\n )\n what_to_use = self.thresholds_phase2[str(what_to_use)]\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n return \"\" if result else None\n\n def calculateWhatChecker(self, length_text, key):\n \"\"\"Calculates what threshold / checker to use\n\n If the length of the text is over the maximum sentence length, use the last checker / threshold\n Otherwise, traverse the keys backwards until we find a key range that does not fit.\n So we traverse backwards and see if the sentence length is between current - 1 and current\n In this way, we find the absolute lowest checker / percentage threshold.\n We traverse backwards because if the text is longer than the max sentence length, we already know.\n In total, the keys are only 5 items long or so. It is not expensive to move backwards, nor is it expensive to move forwards.\n\n Args:\n length_text -> The length of the text\n key -> What key we want to use. I.E. Phase1 keys, Phase2 keys.\n Returns:\n what_to_use -> the key of the lowest checker.\"\"\"\n\n _keys = list(key)\n _keys = list(map(int, _keys))\n if length_text >= int(_keys[-1]):\n what_to_use = list(key)[_keys.index(_keys[-1])]\n else:\n # this algorithm finds the smallest possible fit for the text\n for counter, i in reversed(list(enumerate(_keys))):\n # [0, 110, 150]\n if i <= length_text:\n what_to_use = i\n return what_to_use\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"top1000\": ParamSpec(\n desc=\"A wordlist of the top 1000 words\",\n req=False,\n default=\"cipheydists::list::english1000\",\n ),\n \"wordlist\": ParamSpec(\n desc=\"A wordlist of all the words\",\n req=False,\n default=\"cipheydists::list::english\",\n ),\n \"stopwords\": ParamSpec(\n desc=\"A wordlist of StopWords\",\n req=False,\n default=\"cipheydists::list::englishStopWords\",\n ),\n \"threshold\": ParamSpec(\n desc=\"The minimum proportion (between 0 and 1) that must be in the dictionary\",\n req=False,\n default=0.45,\n ),\n \"phases\": ParamSpec(\n desc=\"Language-specific phase thresholds\",\n req=False,\n default=\"cipheydists::brandon::english\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/format.py\nimport json\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass JsonChecker(Checker[str]):\n\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying json checker\")\n\n # https://github.com/Ciphey/Ciphey/issues/389\n if text.isdigit():\n return None\n\n try:\n json.loads(text)\n return \"\"\n except ValueError:\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 1e-7 * len(text) # From benchmarks I found online\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/human.py\nfrom 
typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, registry\nfrom rich.console import Console\nfrom rich.markup import escape\n\nconsole = Console()\n\n\n@registry.register\nclass HumanChecker(Checker[str]):\n\n \"\"\"\n Uses the person's decision to determine plaintext\n \"\"\"\n\n def check(self, ctext: str) -> Optional[str]:\n with self._config().pause_spinner_handle():\n response = console.input(\n f\"Possible plaintext: [blue bold]{escape(ctext.__repr__())}[/blue bold] ([green]y[/green]/[red]N[/red]): \"\n )\n if response == \"y\":\n return \"\"\n elif response in (\"n\", \"\"):\n return None\n else:\n return self.check(ctext)\n\n def getExpectedRuntime(self, text: str) -> float:\n return 1 # About a second\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Checkers/quadgrams.py\nimport logging\nimport re\nfrom math import log10\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, Translation, registry\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Quadgrams(Checker[str]):\n\n \"\"\"\n Uses Quadgrams to determine plaintext\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying Quadgrams checker\")\n # Capitalize and remove everything that's not a letter\n ctext = re.sub(\"[^A-Z]\", \"\", ctext.upper())\n quadgrams = self.QUADGRAMS_DICT\n quadgrams_sum = sum(quadgrams.values())\n score = 0\n for key in quadgrams.keys():\n quadgrams[key] = float(quadgrams[key]) / quadgrams_sum\n floor = log10(0.01 / quadgrams_sum)\n for i in range(len(ctext) - 4 + 1):\n # Get all quadgrams from ctext and check if they're in the dict\n # If yes then add the score of those quadgrams to the total score\n if ctext[i : i + 4] in quadgrams:\n score += quadgrams[ctext[i : i + 4]]\n else:\n score += floor\n if len(ctext) > 0:\n score = score / len(ctext)\n logging.info(f\"Quadgrams is {score}\")\n # The default threshold was found to work the best from lots of testing\n if score > self.threshold:\n return \"\"\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The quadgrams dictionary to use\",\n req=False,\n default=\"cipheydists::dist::quadgrams\",\n ),\n \"score\": ParamSpec(\n desc=\"The score threshold to use\",\n req=False,\n default=0.00011,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.QUADGRAMS_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.threshold = float(self._params()[\"score\"])\n\n# Path: ciphey/basemods/Checkers/regex.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Regex(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = list(map(re.compile, self._params()[\"regex\"]))\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: 
{res}\")\n if res:\n return f\"Passed with regex {regex}. Want to contribute to Ciphey? Submit your regex here to allow Ciphey to automatically get this next time https://github.com/bee-san/pyWhat/wiki/Adding-your-own-Regex\\n\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"regex\": ParamSpec(\n req=True,\n desc=\"The regex that must be matched (in a substring)\",\n list=True,\n )\n }\n\n\n@registry.register\nclass RegexList(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = []\n for i in self._params()[\"resource\"]:\n self.regexes += [re.compile(regex) for regex in config.get_resource(i)]\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: {res}\")\n if res:\n return f\"passed with regex {regex}\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"resource\": ParamSpec(\n req=True,\n desc=\"A list of regexes that could be matched\",\n list=True,\n )\n }\n\n# Path: ciphey/basemods/Checkers/what.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\nimport logging\nfrom rich.logging import RichHandler\nfrom pywhat import identifier\nfrom rich.console import Console\n\nconsole = Console()\n\n\n@registry.register\nclass What(Checker[str]):\n\n \"\"\"\n Uses PyWhat to determine plaintext with regexes\n https://github.com/bee-san/pyWhat\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying PyWhat checker\")\n returned_regexes = self.id.identify(ctext)\n if returned_regexes[\"Regexes\"]:\n matched_regex = returned_regexes[\"Regexes\"]['text'][0][\"Regex Pattern\"]\n\n ret = f'The plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n human = (\n f'\\nI think the plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n )\n\n if \"Description\" in matched_regex and matched_regex[\"Description\"]:\n s = matched_regex[\"Description\"]\n # lowercases first letter so it doesn't look weird\n s = f\", which is {s[0].lower() + s[1:]}\\n\"\n ret += s\n human += s\n\n # if URL is attached, include that too.\n if \"URL\" in matched_regex and matched_regex[\"URL\"]:\n link = matched_regex[\"URL\"] + ctext.replace(\" \", \"\")\n ret += f\"\\nClick here to view in browser [#CAE4F1][link={link}]{link}[/link][/#CAE4F1]\\n\"\n\n # If greppable mode is on, don't print this\n if self.config.verbosity >= 0:\n # Print with full stop\n console.print(human)\n return ret\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.config = config\n self.id = identifier.Identifier()\n\n# Path: ciphey/basemods/Checkers/ezcheck.py\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nfrom .brandon import Brandon\nfrom .format import JsonChecker\nfrom .human import HumanChecker\nfrom .quadgrams import Quadgrams\nfrom .regex import RegexList\nfrom .what import What\n\n\n@registry.register\nclass EzCheck(Checker[str]):\n \"\"\"\n This object is effectively a 
prebuilt quorum (with requirement 1) of common patterns, followed by a human check\n \"\"\"\n\n def check(self, text: str) -> Optional[str]:\n for checker in self.checkers:\n res = checker.check(text)\n if (\n res is not None\n and (self.decider is None or self.decider.check(text)) is not None\n ):\n return res\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n return sum(\n i.getExpectedRuntime(text) for i in self.checkers\n ) + self.decider.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n self.checkers: List[Checker[str]] = []\n # Disable human checker for automated systems\n if config.verbosity >= 0:\n self.decider = config(HumanChecker)\n else:\n self.decider = None\n\n # We need to modify the config for each of the objects\n\n # First PyWhat, as it's the fastest\n self.checkers.append(config(What))\n\n # Next, the json checker\n self.checkers.append(config(JsonChecker))\n\n # Second to last, the quadgrams checker\n self.checkers.append(config(Quadgrams))\n\n # Finally, the Brandon checker, as it is the slowest\n self.checkers.append(config(Brandon))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/quorum.py\nfrom typing import Dict, Generic, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, _registry\n\n\nclass Quorum(Generic[T], Checker[T]):\n def check(self, text: T) -> Optional[str]:\n left = self._params().k\n results = []\n for checker in self.checkers:\n results.append(checker.check(text))\n if results[-1] is None:\n continue\n left -= 1\n # Early return check\n if left == 0:\n return str(results)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n if self._params().k is None:\n k = len(self._params()[\"checker\"])\n # These checks need to be separate, to make sure that we do not have zero members\n if self._params().k == 0 or self._params().k > len(self._params()[\"checker\"]):\n raise IndexError(\n \"k must be between 0 and the number of checkers (inclusive)\"\n )\n\n self.checkers = []\n for i in self._params()[\"checker\"]:\n # This enforces type consistency\n self.checkers.append(_registry.get_named(i, Checker[T]))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"checker\": ParamSpec(\n req=True, desc=\"The checkers to be used for analysis\", list=True\n ),\n \"k\": ParamSpec(\n req=False,\n desc=\"The minimum quorum size. Defaults to the number of checkers\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/__init__.py\nfrom . 
import any, brandon, ezcheck, format, human, quadgrams, quorum, regex, what\n\n# Path: ciphey/common.py\n\"\"\"Some useful adapters\"\"\"\nfrom typing import Any\n\n\ndef id_lambda(value: Any):\n \"\"\"\n A function used in dynamic class generation that abstracts away a constant return value (like in getName)\n \"\"\"\n return lambda *args: value\n\n\ndef fix_case(target: str, base: str) -> str:\n \"\"\"Returns the lower-case string target with the case of base\"\"\"\n ret = \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n return \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n\n# Path: ciphey/basemods/Crackers/affine.py\n# Community\n# by https://github.com/Ozzyz\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\nfrom ciphey.mathsHelper import mathsHelper\n\n\n@registry.register\nclass Affine(Cracker[str]):\n \"\"\"\n Each character in the Affine Cipher is encoded with the rule E(x) = (ax + b) mod m\n m is the size of the alphabet, while a and b are the keys in the cipher. a must be coprime to b.\n The Caesar cipher is a specific case of the Affine Cipher, with a=1 and b being the shift of the cipher.\n Decryption is performed by D(x) = a_inv (x - b) mod m where a_inv is the modular multiplicative inverse of a mod m.\n\n In this version of the Affine Cipher, we do not allow alphabets with several instances of the same letter in different cases.\n For instance, the alphabet 'ABCdef123' is allowed, but 'AaBbCc' is not.\n \"\"\"\n\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"affine\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Brute forces all the possible combinations of a and b to attempt to crack the cipher.\n \"\"\"\n logging.debug(\"Attempting affine\")\n candidates = []\n\n # a and b are coprime if gcd(a,b) is 1.\n possible_a = [\n a\n for a in range(1, self.alphabet_length)\n if mathsHelper.gcd(a, self.alphabet_length) == 1\n ]\n logging.info(\n f\"Trying Affine Cracker with {len(possible_a)} a-values and {self.alphabet_length} b-values\"\n )\n\n for a in possible_a:\n a_inv = mathsHelper.mod_inv(a, self.alphabet_length)\n # If there is no inverse, we cannot decrypt the text\n if a_inv is None:\n continue\n for b in range(self.alphabet_length):\n # Pass in lowered text. 
This means that we expect alphabets to not contain both 'a' and 'A'.\n translated = self.decrypt(ctext.lower(), a_inv, b, self.alphabet_length)\n\n candidate_probability = self.plaintext_probability(translated)\n if candidate_probability > self.plaintext_prob_threshold:\n candidates.append(\n CrackResult(\n value=fix_case(translated, ctext), key_info=f\"a={a}, b={b}\"\n )\n )\n logging.info(f\"Affine Cipher returned {len(candidates)} candidates\")\n return candidates\n\n def plaintext_probability(self, translated: str) -> float:\n \"\"\"\n Analyses the translated text and applies the chi squared test to see if it is a probable plaintext candidate\n Returns the probability of the chi-squared test.\n \"\"\"\n analysis = cipheycore.analyse_string(translated)\n return cipheycore.chisq_test(analysis, self.expected)\n\n def decrypt(self, text: str, a_inv: int, b: int, m: int) -> str:\n \"\"\"\n Each letter is decrypted at D(x) = a_inv (x - b) mod m where x is the char\n We treat the char value as its index in the alphabet, so if\n the alphabet is 'abcd....' and the char is 'b', it has the value 1.\n \"\"\"\n return \"\".join([self.decryptChar(char, a_inv, b, m) for char in text])\n\n def decryptChar(self, char: str, a_inv: int, b: int, m: int) -> str:\n\n # We lower the alphabet since both ctext and alphabet need to be in the same case in order\n # to perform the shifts. The translated text will have fixed case after the translation anyways.\n # This is only necessary if the specified alphabet is uppercase.\n alphabet = [x.lower() for x in self.group]\n\n # Preserve characters that are not in alphabet\n if char not in alphabet:\n return char\n char_idx = alphabet.index(char)\n decrypted_char_idx = (a_inv * (char_idx - b)) % m\n return alphabet[decrypted_char_idx]\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.alphabet_length = len(self.group)\n self.cache = config.cache\n self.plaintext_prob_threshold = 0.01\n\n# Path: ciphey/basemods/Crackers/ascii_shift.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 
\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Ascii_shift(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"ascii_shift\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ASCII shift cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ASCII shift returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ASCII shift cipher alphabet\",\n req=False,\n default=\"\"\"\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/baconian.py\nimport re\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\nimport logging\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Baconian(Cracker[str]):\n...\n", "needle_token_start": 5601, "needle_token_end": 5616, "code_context_ntokens": 16082, "output": ["```python\n@staticmethod\ndef strip_punctuation(text: str) -> str:\n \"\"\"Strips punctuation from 
a given string.\n\n Uses string.punctuation.\n\n Args:\n text -> the text to strip punctuation from.\n\n Returns:\n Returns string without punctuation.\n \"\"\"\n text: str = (str(text).translate(str.maketrans(\"\", \"\", punctuation))).\nstrip(\n \"\\n\"\n )\n return text\n```"]} +{"repo": "Ciphey/Ciphey", "name": "crackOne", "language": "python", "path": "ciphey/basemods/Crackers/vigenere.py", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to decrypt a cipher text that has been encrypted using a specific polyalphabetic substitution cipher method. It identifies potential decryption keys based on statistical analysis and returns a list of possible plaintexts.\n2. **Input**: The function takes three parameters: the cipher text as a string, a statistical analysis result of the cipher text, and the original cipher text for case correction purposes.\n3. **Output**: Outputs a list of decryption results, each containing the decrypted text, information about the decryption key used, and additional statistical information related to the decryption process.\n4. **Procedure**: The function performs the following steps:\n - It calls a decryption analysis function to get potential decryption keys based on the provided analysis.\n - Limits the number of potential keys to a predefined maximum.\n - Decrypts the cipher text using each of the potential keys.\n - Corrects the case of the decrypted text to match the original cipher text.\n - Compiles the decryption results, including the decrypted text, key information, and statistical data, into a list.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/basemods/Crackers/ascii_shift.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Ascii_shift(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n 
success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"ascii_shift\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ASCII shift cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ASCII shift returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ASCII shift cipher alphabet\",\n req=False,\n default=\"\"\"\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n...\n# Path: ciphey/basemods/Crackers/baconian.py\nimport re\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\nimport logging\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Baconian(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"baconian\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to decode both variants of the Baconian cipher.\n \"\"\"\n logging.debug(\"Attempting Baconian cracker\")\n candidates = []\n result = []\n ctext_decoded = \"\"\n ctext_decoded2 = \"\"\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext only contains A and B\n if bool(re.search(r\"[^AB]\", ctext)) is True:\n logging.debug(\"Failed to crack baconian due to non baconian character(s)\")\n return None\n\n # Make sure ctext is divisible by 5\n ctext_len = len(ctext)\n 
if ctext_len % 5:\n logging.debug(\n f\"Failed to decode Baconian because length must be a multiple of 5, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 5\n ctext = \" \".join(ctext[i : i + 5] for i in range(0, len(ctext), 5))\n ctext_split = ctext.split(\" \")\n baconian_keys = self.BACONIAN_DICT.keys()\n\n # Decode I=J and U=V variant\n for i in ctext_split:\n if i in baconian_keys:\n ctext_decoded += self.BACONIAN_DICT[i]\n\n # Decode variant that assigns each letter a unique code\n for i in ctext_split:\n if \"+\" + i in baconian_keys:\n ctext_decoded2 += self.BACONIAN_DICT[\"+\" + i]\n\n candidates.append(ctext_decoded)\n candidates.append(ctext_decoded2)\n for candidate in candidates:\n if candidate != \"\":\n if candidate == candidates[0]:\n result.append(CrackResult(value=candidate, key_info=\"I=J & U=V\"))\n else:\n result.append(CrackResult(value=candidate))\n logging.debug(f\"Baconian cracker - Returning results: {result}\")\n return result\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"dict\": ParamSpec(\n desc=\"The Baconian alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::baconian\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BACONIAN_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n\n# Path: ciphey/basemods/Crackers/caesar.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Caesar(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"caesar\"\n\n def 
attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying caesar cipher on {ctext}\")\n # Convert it to lower case\n #\n # TODO: handle different alphabets\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"Caesar returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(message, candidate.key, self.group)\n candidates.append(\n CrackResult(value=fix_case(translated, ctext), key_info=candidate.key)\n )\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n self.lower = util.strtobool(self.lower)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/rot47.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: 
brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Rot47(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"rot47\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ROT47 cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ROT47 returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ROT47 cipher alphabet\",\n req=False,\n default=\"\"\"!\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/soundex.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\n\n\n@registry.register\nclass Soundex(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"soundex\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to crack Soundex by generating all possible combinations.\n \"\"\"\n logging.debug(\"Attempting Soundex cracker\")\n word_list = []\n sentences = []\n result = []\n\n # Convert to uppercase and replace delimiters and 
whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext contains only A-Z and 0-9\n if bool(re.search(r\"[^A-Z0-9]\", ctext)) is True:\n logging.debug(\"Failed to crack soundex due to non soundex character(s)\")\n return None\n\n # Make sure ctext is divisible by 4\n ctext_len = len(ctext)\n if ctext_len % 4:\n logging.debug(\n f\"Failed to decode Soundex because length must be a multiple of 4, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 4\n ctext = \" \".join(ctext[i : i + 4] for i in range(0, len(ctext), 4))\n ctext_split = ctext.split(\" \")\n soundex_keys = self.SOUNDEX_DICT.keys()\n\n # Find all words that correspond to each given soundex code\n for code in ctext_split:\n if code in soundex_keys:\n word_list.append(self.SOUNDEX_DICT[code])\n\n logging.info(f\"Possible words for given encoded text: {word_list}\")\n\n # Find all possible sentences\n self.getSentenceCombo(\n word_list,\n sentences,\n self.frequency_dict,\n self.sentence_freq,\n self.word_freq,\n )\n\n sorted_sentences = self.sortlistwithdict(sentences, self.frequency_dict)\n\n for sentence in sorted_sentences:\n result.append(CrackResult(value=sentence))\n\n logging.debug(f\"Soundex cracker - Returning results: {result}\")\n return result\n\n def sortlistwithdict(self, listtosort, hashes):\n \"\"\"\n This function uses the sum of ranks (based on frequency) of each word in each\n sentence and sorts them according to it.\n \"\"\"\n return sorted(listtosort, key=lambda x: hashes[x])\n\n def getSentenceCombo(\n self, A, sentences, frequency_dict, sentence_freq, word_freq, result=\"\", n=0\n ):\n \"\"\"\n This function uses recursion to generate a list of sentences from all possible\n words for a given set of soundex codes.\n \"\"\"\n logging.debug(\"Creating all possible sentences from Soundex\")\n if n == len(A):\n sentences.append(result[1:])\n for word in result[1:].split():\n # Adding the rank of each word to find out the sentence's net frequency\n if word in word_freq:\n sentence_freq += word_freq.index(word)\n # If the word isn't in the frequency list then it's a very uncommon word\n # so we add a large number (5000)\n else:\n sentence_freq += 5000\n frequency_dict[result[1:]] = sentence_freq\n sentence_freq = 0\n return\n\n for word in A[n]:\n out = result + \" \" + word\n self.getSentenceCombo(\n A, sentences, frequency_dict, sentence_freq, word_freq, out, n + 1\n )\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The Soundex dictionary to use\",\n req=False,\n default=\"cipheydists::translate::soundex\",\n ),\n \"freq\": ParamSpec(\n desc=\"The word frequency dictionary to use\",\n req=False,\n default=\"cipheydists::list::English5000Freq\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.SOUNDEX_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.word_freq = config.get_resource(self._params()[\"freq\"], Translation)\n self.frequency_dict = {}\n self.sentence_freq = 0\n\n# Path: ciphey/basemods/Crackers/vigenere.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 
\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Vigenere(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n if self.keysize is not None:\n analysis = self.cache.get_or_update(\n ctext,\n f\"vigenere::{self.keysize}\",\n lambda: cipheycore.analyse_string(\n ctext.lower(), self.keysize, self.group\n ),\n )\n\n val = cipheycore.vigenere_detect(analysis, self.expected)\n\n logging.info(f\"Vigenere has likelihood {val}\")\n\n return CrackInfo(\n success_likelihood=val,\n # TODO: actually calculate runtimes\n success_runtime=1e-3,\n failure_runtime=1e-2,\n )\n\n likely_lens = self.cache.get_or_update(\n ctext,\n \"vigenere::likely_lens\",\n lambda: cipheycore.vigenere_likely_key_lens(\n ctext.lower(), self.expected, self.group, self.detect_p_value\n ),\n )\n\n # Filter out the lens that make no sense\n likely_lens = [i for i in likely_lens if i.len <= self.max_key_length]\n\n for keysize in likely_lens:\n # Store the analysis\n analysis = self.cache.get_or_update(\n ctext, f\"vigenere::{keysize.len}\", lambda: keysize.tab\n )\n if len(likely_lens) == 0:\n return CrackInfo(\n success_likelihood=0,\n # TODO: actually calculate runtimes\n success_runtime=2e-3,\n failure_runtime=2e-2,\n )\n\n logging.info(\n f\"Vigenere has likelihood {likely_lens[0].p_value} with lens {[i.len for i in likely_lens]}\"\n )\n\n return CrackInfo(\n success_likelihood=likely_lens[0].p_value,\n # TODO: actually calculate runtimes\n success_runtime=2e-4,\n failure_runtime=2e-4,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"vigenere\"\n\n def crackOne(\n self, ctext: str, analysis: cipheycore.windowed_analysis_res, real_ctext: str\n ) -> List[CrackResult]:\n possible_keys = cipheycore.vigenere_crack(\n analysis, self.expected, self.group, self.p_value\n )\n if len(possible_keys) > self.clamp:\n possible_keys = possible_keys[: self.clamp]\n \n logging.debug(\n f\"Vigenere crack got keys: {[[i for i in candidate.key] for candidate in possible_keys]}\"\n )\n return [\n CrackResult(\n value=fix_case(\n cipheycore.vigenere_decrypt(ctext, candidate.key, self.group),\n real_ctext,\n ),\n key_info=\"\".join([self.group[i] for i in candidate.key]),\n misc_info=f\"p-value was {candidate.p_value}\",\n )\n for candidate in possible_keys[: min(len(possible_keys), 10)]\n ]\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(\"Trying vigenere cipher\")\n # Convert it to lower case\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n # Analysis must be 
done here, where we know the case for the cache\n if self.keysize is not None:\n return self.crackOne(\n message,\n self.cache.get_or_update(\n ctext,\n f\"vigenere::{self.keysize}\",\n lambda: cipheycore.analyse_string(\n message, self.keysize, self.group\n ),\n ),\n ctext,\n )\n\n arrs = []\n likely_lens = self.cache.get_or_update(\n ctext,\n \"vigenere::likely_lens\",\n lambda: cipheycore.vigenere_likely_key_lens(\n message, self.expected, self.group\n ),\n )\n possible_lens = [i for i in likely_lens]\n possible_lens.sort(key=lambda i: i.p_value)\n logging.debug(f\"Got possible lengths {[i.len for i in likely_lens]}\")\n # TODO: work out length\n for i in possible_lens:\n arrs.extend(\n self.crackOne(\n message,\n self.cache.get_or_update(\n ctext,\n f\"vigenere::{i.len}\",\n lambda: cipheycore.analyse_string(message, i.len, self.group),\n ),\n ctext,\n )\n )\n\n logging.info(f\"Vigenere returned {len(arrs)} candidates\")\n return arrs\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"keysize\": ParamSpec(\n desc=\"A key size that should be used. If not given, will attempt to work it out\",\n req=False,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for windowed frequency analysis\",\n req=False,\n default=0.5,\n ),\n \"detect_p_value\": ParamSpec(\n desc=\"The p-value to use for the detection of Vigenere length\",\n req=False,\n default=0.01,\n ),\n \"clamp\": ParamSpec(\n desc=\"The maximum number of candidates that can be returned per key len\",\n req=False,\n default=10,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n self.lower = util.strtobool(self.lower)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.keysize = self._params().get(\"keysize\")\n if self.keysize is not None:\n self.keysize = int(self.keysize)\n self.p_value = float(self._params()[\"p_value\"])\n self.detect_p_value = float(self._params()[\"detect_p_value\"])\n self.clamp = int(self._params()[\"clamp\"])\n self.max_key_length = 16\n\n# Path: ciphey/basemods/Crackers/xandy.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Xandy(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def binary_to_ascii(variant):\n # Convert the binary string to an integer with base 2\n binary_int = int(variant, 2)\n byte_number = binary_int.bit_length() + 7 // 8\n\n # Convert the resulting int to a bytearray and then decode it to ASCII text\n binary_array = binary_int.to_bytes(byte_number, \"big\")\n try:\n ascii_text = binary_array.decode()\n logging.debug(f\"Found possible solution: {ascii_text[:32]}\")\n return ascii_text\n except 
UnicodeDecodeError as e:\n logging.debug(f\"Failed to crack X-Y due to a UnicodeDecodeError: {e}\")\n return \"\"\n\n @staticmethod\n def getTarget() -> str:\n return \"xandy\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Checks an input if it only consists of two or three different letters.\n If this is the case, it attempts to regard those letters as\n 0 and 1 (with the third characters as an optional delimiter) and then\n converts it to ASCII text.\n \"\"\"\n logging.debug(\"Attempting X-Y replacement\")\n variants = []\n candidates = []\n result = []\n\n # Convert the ctext to all-lowercase and regex-match & replace all whitespace\n ctext = re.sub(r\"\\s+\", \"\", ctext.lower(), flags=re.UNICODE)\n\n # cset contains every unique value in the ctext\n cset = list(set(list(ctext)))\n cset_len = len(cset)\n\n if not 1 < cset_len < 4:\n # We only consider inputs with two or three unique values\n logging.debug(\n \"Failed to crack X-Y due to not containing two or three unique values\"\n )\n return None\n\n logging.debug(f\"String contains {cset_len} unique values: {cset}\")\n\n # In case of three unique values, we regard the least frequent character as the delimiter\n if cset_len == 3:\n # Count each unique character in the set to determine the least frequent one\n counting_list = []\n for char in cset:\n counting_list.append(ctext.count(char))\n val, index = min((val, index) for (index, val) in enumerate(counting_list))\n delimiter = cset[index]\n logging.debug(\n f\"{delimiter} occurs {val} times and is the probable delimiter\"\n )\n # Remove the delimiter from the ctext and compute new cset\n ctext = ctext.replace(delimiter, \"\")\n cset = list(set(list(ctext)))\n\n # Form both variants of the substitution\n for i in range(2):\n if i:\n variants.append(ctext.replace(cset[0], \"1\").replace(cset[1], \"0\"))\n else:\n variants.append(ctext.replace(cset[0], \"0\").replace(cset[1], \"1\"))\n\n # Apply function to both variants and strip stray NULL characters\n for variant in variants:\n candidates.append(self.binary_to_ascii(variant).strip(\"\\x00\"))\n for i, candidate in enumerate(candidates):\n if candidate != \"\":\n keyinfo = f\"{cset[0]} -> {i} & {cset[1]} -> {str(int(not i))}\"\n result.append(CrackResult(value=candidate, key_info=keyinfo))\n logging.debug(f\"X-Y cracker - Returning results: {result}\")\n return result\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n )\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n\n# Path: ciphey/basemods/Crackers/xortool.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 
\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: bee-san\n\"\"\"\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom xortool_ciphey import tool_main\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass XorTool(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n # TODO: actually calculate runtimes\n success_runtime=1e-8,\n failure_runtime=1e-8,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"xortool\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.debug(\"Trying xortool cipher\")\n # TODO handle different charsets\n # TODO allow more config over xortool\n\n logging.debug(f\"{ctext}\")\n\n # https://github.com/Ciphey/xortool/discussions/4\n # for docs on this function\n try:\n result = tool_main.api(str.encode(ctext))\n except:\n logging.debug(\"Xor failed.\")\n return\n\n result = CrackResult(value=result[1][\"Dexored\"], key_info=result[0][\"keys\"])\n\n return [result]\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n ),\n }\n\n @staticmethod\n def score_utility() -> float:\n return 1.5\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = self._params()[\"p_value\"]\n\n# Path: ciphey/basemods/Crackers/__init__.py\nfrom . 
import (\n affine,\n ascii_shift,\n baconian,\n caesar,\n rot47,\n soundex,\n vigenere,\n xandy,\n xortool,\n)\n\n# Path: ciphey/basemods/Decoders/a1z26.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass A1z26(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs A1Z26 decoding\n \"\"\"\n logging.debug(\"Attempting A1Z26\")\n ctext_converted = []\n ctext_split = re.split(r\"[ ,;:\\-\\n]\", ctext)\n delimiters = set(sorted(re.sub(r\"[^ ,;:\\-\\n]\", \"\", ctext)))\n ctext_num = re.sub(r\"[,;:\\-\\s]\", \"\", ctext)\n ctext_decoded = \"\"\n if ctext_num.isnumeric() is False:\n logging.debug(\"Failed to decode A1Z26 due to non numeric character(s)\")\n return None\n try:\n for i in ctext_split:\n val = int(i)\n if val > 26 or val < 1:\n logging.debug(\n f\"Failed to decode A1Z26 due to invalid number '{val}'\"\n )\n return None\n val2 = int(i) + 96\n ctext_converted.append(chr(val2))\n ctext_decoded = \"\".join(ctext_converted)\n logging.info(\n f\"A1Z26 successful, returning '{ctext_decoded}' with delimiter(s) {delimiters}\"\n )\n return ctext_decoded\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"a1z26\"\n\n# Path: ciphey/basemods/Decoders/atbash.py\nfrom typing import Dict, Optional\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Atbash(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Takes an encoded string and attempts to decode it according to the Atbash cipher.\n\n The Atbash cipher is a very simple substitution cipher without a key.\n It operates by replacing every letter in the input by its 'counterpoint'\n in the alphabet. Example: A -> Z, B -> Y, ... 
, M -> N and vice versa.\n \"\"\"\n\n result = \"\"\n atbash_dict = {self.ALPHABET[i]: self.ALPHABET[::-1][i] for i in range(26)}\n\n for letter in ctext.lower():\n if letter in atbash_dict.keys():\n # Match every letter of the input to its atbash counterpoint\n result += atbash_dict[letter]\n else:\n # If the current character is not in the defined alphabet,\n # just accept it as-is (useful for numbers, punctuation, etc.)\n result += letter\n return fix_case(result, ctext)\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.1\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.ALPHABET = config.get_resource(self._params()[\"dict\"], WordList)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The alphabet used for the atbash operation.\",\n req=False,\n default=\"cipheydists::list::englishAlphabet\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"atbash\"\n\n# Path: ciphey/basemods/Decoders/base58_bitcoin.py\nfrom typing import Dict, Optional\n\nimport base58\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base58_bitcoin(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base58 (Bitcoin) decoding\n \"\"\"\n try:\n return base58.b58decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base58_bitcoin\"\n\n# Path: ciphey/basemods/Decoders/base58_ripple.py\nfrom typing import Dict, Optional\n\nimport base58\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base58_ripple(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base58 (Ripple) decoding\n \"\"\"\n try:\n return base58.b58decode(ctext, alphabet=base58.RIPPLE_ALPHABET).decode(\n \"utf-8\"\n )\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base58_ripple\"\n\n# Path: ciphey/basemods/Decoders/base62.py\nfrom typing import Dict, Optional\n\nimport base62\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base62(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base62 decoding\n \"\"\"\n try:\n return base62.decodebytes(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base62\"\n\n# Path: ciphey/basemods/Decoders/base69.py\n# Translated to Python and adapted for Ciphey from the JS original at https://github.com/pshihn/base69\n\n\nimport re\nfrom math import 
ceil\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Base69(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base69 decoding\n \"\"\"\n # Remove whitespace\n try:\n ctext = re.sub(r\"\\s+\", \"\", ctext, flags=re.UNICODE)\n extra_bytes = 0\n clen = len(ctext)\n\n if ctext[:-1] == \"=\":\n extra_bytes = int(ctext[clen - 2])\n\n CHUNK_COUNT = ceil(clen / 16)\n result = [0 for _ in range(CHUNK_COUNT * 7 - extra_bytes)]\n\n for i in range(CHUNK_COUNT):\n chunk_string = ctext[i * 16 : (i + 1) * 16]\n if extra_bytes and (i == CHUNK_COUNT - 1):\n insert = self.decode_chunk(chunk_string)\n for n, elem in enumerate(insert[0 : 7 - extra_bytes]):\n result[n + i * 7] = elem\n else:\n insert = self.decode_chunk(chunk_string)\n for n, elem in enumerate(insert):\n result[n + i * 7] = elem % 256\n return bytearray(result).decode().strip(\"\\x00\")\n except Exception:\n return None\n\n def decode_chunk(self, s: str):\n padded_bytes = s.endswith(\"=\")\n\n decoded = [0 for _ in range(8)]\n for i in range(8):\n decoded[i] = (\n 0\n if i == 7 and padded_bytes\n else self.chars_to_byte(s[i * 2 : i * 2 + 2])\n )\n\n result = [0 for _ in range(7)]\n for i in range(7):\n t1 = decoded[i] << (i + 1)\n t2 = decoded[i + 1] >> (7 - i - 1)\n result[i] = t1 | t2\n return result\n\n def chars_to_byte(self, s: str):\n return (69 * self.CHARS.index(s[1])) + (self.CHARS.index(s[0]))\n\n @staticmethod\n def priority() -> float:\n # If this becomes lower or equal to the reverse, it breaks.\n # So I'll set it to 0.2 for now since it is very fast anyways.\n return 0.2\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.CHARS = config.get_resource(self._params()[\"dict\"], WordList)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The charset used for the decoder.\",\n req=False,\n default=\"cipheydists::list::base69\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"base69\"\n\n# Path: ciphey/basemods/Decoders/base91.py\nfrom typing import Dict, Optional\n\nimport base91\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base91(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base91 decoding\n \"\"\"\n try:\n return base91.decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base91\"\n\n# Path: ciphey/basemods/Decoders/bases.py\nimport base64\nimport types\nfrom typing import Any, Callable, Optional\n\nimport logging\nfrom rich.logging import RichHandler\nimport re\n\nfrom ciphey.common import id_lambda\nfrom ciphey.iface import Decoder, registry\n\n\ndef _dispatch(self: Any, ctext: str, func: Callable[[str], bytes]) -> Optional[bytes]:\n logging.debug(f\"Attempting {self.getTarget()}\")\n\n try:\n # remove all whitespace\n ctext = re.sub(r\"\\s+\", \"\", ctext, re.UNICODE)\n result = func(ctext)\n logging.info(f\"{self.getTarget()} successful, returning {result}\")\n return result\n except ValueError:\n logging.debug(f\"Failed to decode {self.getTarget()}\")\n return 
None\n\n\n_bases = {\n \"base16\": (base64.b16decode, 0.4),\n \"base32\": (base64.b32decode, 0.01),\n \"base64\": (base64.b64decode, 0.4),\n \"base85\": (base64.b85decode, 0.01),\n \"ascii85\": (base64.a85decode, 0.1),\n}\n\n\ndef gen_class(name, decoder, priority, ns):\n ns[\"_get_func\"] = id_lambda(decoder)\n ns[\"decode\"] = lambda self, ctext: _dispatch(self, ctext, self._get_func())\n ns[\"getParams\"] = id_lambda(None)\n ns[\"getTarget\"] = id_lambda(name)\n ns[\"priority\"] = id_lambda(priority)\n ns[\"__init__\"] = lambda self, config: super(type(self), self).__init__(config)\n\n\nfor name, (decoder, priority) in _bases.items():\n t = types.new_class(\n name,\n (Decoder[str],),\n exec_body=lambda x: gen_class(name, decoder, priority, x),\n )\n\n registry.register(t)\n\n# Path: ciphey/basemods/Decoders/baudot.py\nimport re\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Baudot(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n result = \"\"\n switch_to_digit_map = 0\n if re.search(\"^[01]{5}$\", ctext.split()[0]):\n for i in ctext.split():\n if i == \"11011\":\n switch_to_digit_map = 1\n if i == \"11111\":\n switch_to_digit_map = 0\n if switch_to_digit_map == 1:\n result += self.BAUDOT_DICT[\"+\" + i]\n if switch_to_digit_map == 0:\n result += self.BAUDOT_DICT[i]\n return result\n else:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BAUDOT_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The baudot alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::baudot\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"baudot\"\n\n# Path: ciphey/basemods/Decoders/binary.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Binary(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n try:\n ctext = re.sub(r\"[^\\S \\n]\", \" \", ctext, flags=re.UNICODE)\n ctext = ctext.replace(\"\\n\", \" \")\n\n existing_split = self.try_split(ctext.split(\" \"))\n if existing_split is not None:\n return existing_split\n\n # Now we try our own grouping\n\n # Remove final bit of whitespace\n ctext = ctext.replace(\" \", \"\")\n # Split into bytes, and test\n return self.try_split([ctext[i : i + 8] for i in range(0, len(ctext), 8)])\n # Catch bad octal chars\n except ValueError:\n return None\n\n def try_split(self, split_text: List[str]):\n ret = []\n\n for i in split_text:\n if len(i) == 0:\n continue\n val = int(i, 2)\n if val > 255 or val < 0:\n return None\n ret.append(val)\n\n if len(ret) != 0:\n ret = bytes(ret)\n logging.info(f\"binary successful, returning {ret.__repr__()}\")\n return ret\n\n @staticmethod\n def priority() -> float:\n return 0.3\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"binary\"\n\n# Path: ciphey/basemods/Decoders/braille.py\nimport re\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\nimport logging\nfrom 
rich.logging import RichHandler\n\n\n@registry.register\nclass Braille(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Braille decoding\n \"\"\"\n logging.debug(\"Attempting Braille\")\n ctext_decoded = \"\"\n braille_matches = 0\n for symbol in self.BRAILLE_DICT_INV.values():\n if symbol in ctext:\n braille_matches += 1\n else:\n continue\n if braille_matches == 0:\n logging.debug(\"Failed to decode Braille due to invalid characters\")\n return None\n\n for pattern, value in self.BRAILLE_DICT.items():\n ctext = re.sub(pattern, value, ctext)\n\n wordArr = []\n for word in ctext.split(\" \"):\n # If two commas are in front of a word, uppercase the word and remove the comma\n if word[:2].find(\",,\") != -1:\n wordArr.append(word.replace(\",,\", \"\").upper())\n else:\n wordArr.append(word)\n\n result = []\n for word in wordArr:\n # If one comma is in front of a word, capitalize the word and remove the comma\n if word[0].find(\",\") != -1:\n result.append(word.replace(\",\", \"\").capitalize())\n else:\n result.append(word)\n ctext_decoded = \" \".join(result)\n logging.info(f\"Braille successful, returning '{ctext_decoded}'\")\n return ctext_decoded\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BRAILLE_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.BRAILLE_DICT_INV = {v: k for k, v in self.BRAILLE_DICT.items()}\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The Braille dictionary to use\",\n req=False,\n default=\"cipheydists::translate::braille\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"braille\"\n\n# Path: ciphey/basemods/Decoders/brainfuck.py\nimport re\nimport time\nfrom typing import Dict, Optional, Tuple\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Brainfuck(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Takes a ciphertext and treats it as a Brainfuck program,\n interpreting it and saving the output as a string to return.\n\n Brainfuck is a very simple, Turing-complete esoteric language.\n Below is a simplified interpreter that attempts to check whether a\n given ciphertext is a brainfuck program that would output a string.\n\n A program that can be \"decoded\" like this is one that:\n * Does not require user input (\",\" instruction)\n * Includes at least one putchar instruction (\".\")\n * Does not contain anything but the main 7 instructions,\n (excluding \",\") and whitespace\n\n Details:\n * This implementation wraps the memory pointer for \">\" and \"<\"\n * It is time-limited to 60 seconds, to prevent hangups\n * The program starts with 100 memory cells, chosen arbitrarily\n \"\"\"\n\n logging.debug(\"Attempting brainfuck\")\n\n result = \"\"\n memory = [0] * 100\n codeptr, memptr = 0, 0 # Instruction pointer and stack pointer\n timelimit = 60 # The timeout in seconds\n\n bracemap, isbf = self.bracemap_and_check(ctext)\n\n # If it doesn't appear to be valid brainfuck code\n if not isbf:\n logging.debug(\"Failed to interpret brainfuck due to invalid characters\")\n return None\n...\n", "needle_token_start": 7123, "needle_token_end": 7385, "code_context_ntokens": 16094, "output": ["```python\ndef crackOne(self, ctext: str, analysis: cipheycore.windowed_analysis_res, real_ctext: str) -> List[CrackResult]:\n 
possible_keys = cipheycore.vigenere_crack(\n analysis, self.expected, self.group, self.p_value\n )\n if len(possible_keys) > self.clamp:\n possible_keys = possible_keys[: self.clamp]\n \n logging.debug(\n f\"Vigenere crack got keys: {[[i for i in candidate.key] for candidate in possible_keys]}\"\n )\n return [\n CrackResult(\n value=fix_case(\n cipheycore.vigenere_decrypt(ctext, candidate.key, self.group),\n real_ctext,\n ),\n key_info=\"\".join([self.group[i] for i in candidate.key]),\n misc_info=f\"p-value was {candidate.p_value}\",\n )\n for candidate in possible_keys[: min(len(possible_keys), 10)]\n ]\n```"]} +{"repo": "Ciphey/Ciphey", "name": "new_sort", "language": "python", "path": "ciphey/mathsHelper.py", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To organize the elements of a dictionary in descending order based on their values.\n2. **Input**: A dictionary where keys are associated with numerical values.\n3. **Output**: A dictionary sorted in descending order of the values, maintaining the order of keys with equal values as in the original dictionary.\n4. **Procedure**: The function takes a dictionary, sorts it based on the values in descending order using a lambda function as the sorting key, and returns the sorted dictionary as an ordered dictionary to preserve the order of insertion.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/iface/_modules.py\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, Generic, List, NamedTuple, Optional, Set, Type, TypeVar\n\nfrom rich import box\nfrom rich.console import Console\nfrom rich.markup import escape\nfrom rich.table import Table\n\nfrom ._fwd import config as Config\n\nT = TypeVar(\"T\")\nU = TypeVar(\"U\")\n\nconsole = Console()\n\n\nclass ParamSpec(NamedTuple):\n \"\"\"\n Attributes:\n req Whether this argument is required\n desc A description of what this argument does\n default The default value for this argument. 
Ignored if req == True or configPath is not None\n config_ref The path to the config that should be the default value\n list Whether this parameter is in the form of a list, and can therefore be specified more than once\n visible Whether the user can tweak this via the command line\n \"\"\"\n\n req: bool\n desc: str\n default: Optional[Any] = None\n list: bool = False\n config_ref: Optional[List[str]] = None\n visible: bool = True\n\n\nclass ConfigurableModule(ABC):\n @staticmethod\n @abstractmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n \"\"\"\n Returns a dictionary of `argument name: argument specification`\n \"\"\"\n pass\n\n def _checkParams(self):\n \"\"\"\n Fills the given params dict with default values where arguments are not given,\n using None as the default value for default values\n \"\"\"\n\n params = self._params()\n config = self._config()\n\n for key, value in self.getParams().items():\n # If we already have it, then we don't need to do anything\n if key in params:\n continue\n # If we don't have it, but it's required, then fail\n if value.req:\n raise KeyError(\n f\"Missing required param {key} for {type(self).__name__.lower()}\"\n )\n # If it's a reference by default, fill that in\n if value.config_ref is not None:\n tmp = getattr(config, value.config_ref[0])\n params[key] = (\n tmp[value.config_ref[1:]] if len(value.config_ref) > 1 else tmp\n )\n # Otherwise, put in the default value (if it exists)\n elif value.default is not None:\n params[key] = value.default\n\n def _params(self):\n return self._params_obj\n\n def _config(self):\n return self._config_obj\n\n @abstractmethod\n def __init__(self, config: Config):\n self._config_obj = config\n if self.getParams() is not None:\n self._params_obj = config.params.setdefault(type(self).__name__.lower(), {})\n self._checkParams()\n\n\nclass Targeted(ABC):\n @staticmethod\n @abstractmethod\n def getTarget() -> str:\n \"\"\"Should return the target that this object attacks/decodes\"\"\"\n pass\n\n\nclass PolymorphicChecker(ConfigurableModule):\n @abstractmethod\n def check(self, text) -> Optional[str]:\n \"\"\"Should return some description (or an empty string) on success, otherwise return None\"\"\"\n pass\n\n @abstractmethod\n def getExpectedRuntime(self, text) -> float:\n pass\n\n def __call__(self, *args):\n return self.check(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass Checker(Generic[T], ConfigurableModule):\n @abstractmethod\n def check(self, text: T) -> Optional[str]:\n \"\"\"Should return some description (or an empty string) on success, otherwise return None\"\"\"\n pass\n\n @abstractmethod\n def getExpectedRuntime(self, text: T) -> float:\n pass\n\n def __call__(self, *args):\n return self.check(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n @classmethod\n def convert(cls, expected: Set[type]):\n class PolyWrapperClass(PolymorphicChecker):\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return cls.getParams()\n\n def check(self, text) -> Optional[str]:\n \"\"\"Should return some description (or an empty string) on success, otherwise return None\"\"\"\n if type(text) not in expected:\n return None\n else:\n return self._base.check(text)\n\n def getExpectedRuntime(self, text) -> float:\n if type(text) not in expected:\n return 0\n else:\n return self._base.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n # This is easier than 
inheritance\n self._base = cls(config)\n\n PolyWrapperClass.__name__ = cls.__name__\n\n return PolyWrapperClass\n\n\n# class Detector(Generic[T], ConfigurableModule, KnownUtility, Targeted):\n# @abstractmethod\n# def scoreLikelihood(self, ctext: T) -> Dict[str, float]:\n# \"\"\"Should return a dictionary of (cipher_name: score)\"\"\"\n# pass\n#\n# def __call__(self, *args): return self.scoreLikelihood(*args)\n#\n# @abstractmethod\n# def __init__(self, config: Config): super().__init__(config)\n\n\nclass Decoder(Generic[T], ConfigurableModule, Targeted):\n \"\"\"Represents the undoing of some encoding into a different (or the same) type\"\"\"\n\n @abstractmethod\n def decode(self, ctext: T) -> Optional[U]:\n pass\n\n @staticmethod\n @abstractmethod\n def priority() -> float:\n \"\"\"What proportion of decodings are this?\"\"\"\n pass\n\n def __call__(self, *args):\n return self.decode(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass DecoderComparer:\n value: Type[Decoder]\n\n def __le__(self, other: \"DecoderComparer\"):\n return self.value.priority() <= other.value.priority()\n\n def __ge__(self, other: \"DecoderComparer\"):\n return self.value.priority() >= other.value.priority()\n\n def __lt__(self, other: \"DecoderComparer\"):\n return self.value.priority() < other.value.priority() and self != other\n\n def __gt__(self, other: \"DecoderComparer\"):\n return self.value.priority() > other.value.priority() and self != other\n\n def __init__(self, value: Type[Decoder]):\n self.value = value\n\n def __repr__(self):\n return f\"\"\n\n\nclass CrackResult(NamedTuple):\n # TODO consider using Generic[T] again for value's type once\n # https://bugs.python.org/issue36517 is resolved\n value: Any\n key_info: Optional[str] = None\n misc_info: Optional[str] = None\n\n\nclass CrackInfo(NamedTuple):\n success_likelihood: float\n success_runtime: float\n failure_runtime: float\n\n\nclass Cracker(Generic[T], ConfigurableModule, Targeted):\n @abstractmethod\n def getInfo(self, ctext: T) -> CrackInfo:\n \"\"\"Should return some informed guesses on resource consumption when run on `ctext`\"\"\"\n pass\n\n @abstractmethod\n def attemptCrack(self, ctext: T) -> List[CrackResult]:\n \"\"\"\n This should attempt to crack the cipher `target`, and return a list of candidate solutions\n \"\"\"\n # FIXME: Actually CrackResult[T], but python complains\n pass\n\n def __call__(self, *args):\n return self.attemptCrack(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass ResourceLoader(Generic[T], ConfigurableModule):\n @abstractmethod\n def whatResources(self) -> Optional[Set[str]]:\n \"\"\"\n Return a set of the names of instances T you can provide.\n The names SHOULD be unique amongst ResourceLoaders of the same type\n\n These names will be exposed as f\"{self.__name__}::{name}\", use split_resource_name to recover this\n\n If you cannot reasonably determine what resources you provide, return None instead\n \"\"\"\n pass\n\n @abstractmethod\n def getResource(self, name: str) -> T:\n \"\"\"\n Returns the requested distribution\n\n The behavior is undefined if `name not in self.what_resources()`\n \"\"\"\n pass\n\n def __call__(self, *args):\n return self.getResource(*args)\n\n def __getitem__(self, *args):\n return self.getResource(*args)\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\nclass SearchLevel(NamedTuple):\n name: str\n result: CrackResult\n\n @staticmethod\n def 
input(ctext: Any):\n return SearchLevel(name=\"input\", result=CrackResult(ctext))\n\n\nclass SearchResult(NamedTuple):\n path: List[SearchLevel]\n check_res: str\n\n\nclass Searcher(ConfigurableModule):\n \"\"\"A very basic interface for code that plans out how to crack the ciphertext\"\"\"\n\n @abstractmethod\n def search(self, ctext: Any) -> Optional[SearchResult]:\n \"\"\"Returns the path to the correct ciphertext\"\"\"\n pass\n\n @abstractmethod\n def __init__(self, config: Config):\n super().__init__(config)\n\n\ndef pretty_search_results(res: SearchResult, display_intermediate: bool = False) -> str:\n # TODO what is display_intermediate\n ret: str = \"\"\n table = Table(show_header=False, box=box.ROUNDED, safe_box=False)\n # Only print the checker if we need to. Normal people don't know what\n # \"quadgrams\", \"brandon\", \"json checker\" is.\n # We print the checker if its regex or another language, so long as it starts with:\n # \"The\" like \"The plaintext is a Uniform Resource Locator (URL).\"\n...\n# Path: ciphey/iface/_config.py\nimport datetime\nimport os\nimport pydoc\nfrom typing import Any, Callable, Dict, List, Optional, Type, Union\n\nimport appdirs\nimport yaml\nimport logging\nfrom rich.logging import RichHandler\n\nfrom . import _fwd\nfrom ._modules import PolymorphicChecker, ResourceLoader, Searcher\n\n\nclass Cache:\n \"\"\"Used to track state between levels of recursion to stop infinite loops, and to optimise repeating actions\"\"\"\n\n def __init__(self):\n self._cache: Dict[Any, Dict[str, Any]] = {}\n\n def mark_ctext(self, ctext: Any) -> bool:\n if (isinstance(ctext, str) or isinstance(ctext, bytes)) and len(ctext) < 4:\n logging.debug(f\"Candidate {ctext.__repr__()} too short!\")\n return False\n\n if ctext in self._cache:\n logging.debug(f\"Deduped {ctext.__repr__()}\")\n return False\n\n logging.debug(f\"New ctext {ctext.__repr__()}\")\n\n self._cache[ctext] = {}\n return True\n\n def get_or_update(self, ctext: Any, keyname: str, get_value: Callable[[], Any]):\n # Should have been marked first\n target = self._cache[ctext]\n res = target.get(keyname)\n if res is not None:\n return res\n\n val = get_value()\n target[keyname] = val\n return val\n\n def try_get(self, ctext: Any, keyname: str):\n return self._cache[ctext].get(keyname)\n\n\ndef split_resource_name(full_name: str) -> (str, str):\n return full_name.split(\"::\", 1)\n\n\nclass Config:\n def __init__(self):\n self.verbosity: int = 0\n self.searcher: str = \"ausearch\"\n self.params: Dict[str, Dict[str, Union[str, List[str]]]] = {}\n self.format: str = \"str\"\n self.modules: List[str] = []\n self.checker: str = \"ezcheck\"\n self.default_dist: str = \"cipheydists::dist::english\"\n self.timeout: Optional[int] = None\n self._inst: Dict[type, Any] = {}\n self.objs: Dict[str, Any] = {}\n self.cache: Cache = Cache()\n\n @staticmethod\n def get_default_dir() -> str:\n return appdirs.user_config_dir(\"ciphey\")\n\n def merge_dict(self, config_file: Optional[Dict[str, Any]]):\n if config_file is None:\n return\n for a, b in config_file.items():\n self.update(a, b)\n\n def load_file(\n self,\n path: str = os.path.join(get_default_dir.__func__(), \"config.yml\"),\n create=False,\n ):\n try:\n with open(path, \"r+\") as file:\n return self.merge_dict(yaml.safe_load(file))\n except FileNotFoundError:\n if create:\n open(path, \"w+\")\n\n def instantiate(self, t: type) -> Any:\n \"\"\"\n Used to enable caching of a instantiated type after the configuration has settled\n \"\"\"\n # We cannot use set default as 
that would construct it again, and throw away the result\n res = self._inst.get(t)\n if res is not None:\n return res\n ret = t(self)\n self._inst[t] = ret\n return ret\n\n def __call__(self, t: type) -> Any:\n return self.instantiate(t)\n\n def update(self, attrname: str, value: Optional[Any]):\n if value is not None:\n setattr(self, attrname, value)\n\n def update_param(self, owner: str, name: str, value: Optional[Any]):\n if value is None:\n return\n\n target = self.params.setdefault(owner, {})\n\n if _fwd.registry.get_named(owner).getParams()[name].list:\n target.setdefault(name, []).append(value)\n else:\n target[name] = value\n\n def update_format(self, value: Optional[str]):\n if value is not None:\n self.format = value\n\n def load_objs(self):\n # Basic type conversion\n if self.timeout is not None:\n self.objs[\"timeout\"] = datetime.timedelta(seconds=int(self.timeout))\n self.objs[\"format\"] = pydoc.locate(self.format)\n\n # Checkers do not depend on any other config object\n logging.debug(f\"Registry is {_fwd.registry._reg[PolymorphicChecker]}\")\n self.objs[\"checker\"] = self(\n _fwd.registry.get_named(self.checker, PolymorphicChecker)\n )\n # Searchers only depend on checkers\n self.objs[\"searcher\"] = self(_fwd.registry.get_named(self.searcher, Searcher))\n\n def update_log_level(self, verbosity: Optional[int]):\n if verbosity is None:\n return\n self.verbosity = verbosity\n\n if verbosity == 0:\n self.verbosity = logging.WARNING\n elif verbosity == 1:\n self.verbosity = logging.INFO\n elif verbosity >= 2:\n self.verbosity = logging.DEBUG\n else:\n logging.disable(logging.CRITICAL)\n return\n\n # https://rich.readthedocs.io/en/latest/logging.html for more on RichHandler\n logging.basicConfig(\n level=self.verbosity,\n datefmt=\"[%X]\",\n handlers=[RichHandler(markup=True, rich_tracebacks=True)],\n )\n logging.debug(f\"Verbosity set to level {verbosity}\")\n\n def load_modules(self):\n import importlib.util\n\n for i in self.modules:\n spec = importlib.util.spec_from_file_location(\"ciphey.module_load_site\", i)\n mod = importlib.util.module_from_spec(spec)\n spec.loader.exec_module(mod)\n\n logging.info(f\"Loaded modules {_fwd.registry.get_all_names()}\")\n\n def complete_config(self) -> \"Config\":\n \"\"\"This does all the loading for the config, and then returns itself\"\"\"\n self.load_modules()\n self.load_objs()\n self.update_log_level(self.verbosity)\n return self\n\n def get_resource(self, res_name: str, t: Optional[Type] = None) -> Any:\n logging.debug(f\"Loading resource {res_name} of type {t}\")\n\n # FIXME: Actually returns obj of type `t`, but python is bad\n loader, name = split_resource_name(res_name)\n if t is None:\n return self(_fwd.registry.get_named(loader, ResourceLoader))(name)\n else:\n return self(_fwd.registry.get_named(loader, ResourceLoader[t]))(name)\n\n # Setter methods for cleaner library API\n def set_verbosity(self, i):\n self.update_log_level(i)\n return self\n\n def set_spinner(self, spinner):\n self.objs[\"spinner\"] = spinner\n\n def pause_spinner_handle(self):\n spinner = self.objs.get(\"spinner\")\n\n class PausedSpinner:\n def __enter__(self):\n if spinner is not None:\n spinner.stop()\n\n def __exit__(self, exc_type, exc_val, exc_tb):\n if spinner is not None:\n spinner.start()\n\n return PausedSpinner()\n\n @staticmethod\n def library_default():\n \"\"\"The default config for use in a library\"\"\"\n return Config().set_verbosity(-1)\n\n def __str__(self):\n return str(\n {\n \"verbosity\": self.verbosity,\n \"searcher\": 
self.searcher,\n \"params\": self.params,\n \"format\": self.format,\n \"modules\": self.modules,\n \"checker\": self.checker,\n \"default_dist\": self.default_dist,\n \"timeout\": self.timeout,\n }\n )\n\n\n_fwd.config = Config\n\n# Path: ciphey/iface/_registry.py\nfrom typing import Any, Dict, List, Optional, Set, Tuple, Type, Union\n\ntry:\n from typing import get_args, get_origin\nexcept ImportError:\n from typing_inspect import get_origin, get_args\n\nfrom . import _fwd\nfrom ._modules import *\n\n\nclass Registry:\n # I was planning on using __init_subclass__, but that is incompatible with dynamic type creation when we have\n # generic keys\n\n RegElem = Union[List[Type], Dict[Type, \"RegElem\"]]\n\n _reg: Dict[Type, RegElem] = {}\n _names: Dict[str, Tuple[Type, Set[Type]]] = {}\n _targets: Dict[str, Dict[Type, List[Type]]] = {}\n _modules = {Checker, Cracker, Decoder, ResourceLoader, Searcher, PolymorphicChecker}\n\n def _register_one(self, input_type, module_base, module_args):\n if len(module_args) == 0:\n self._reg.setdefault(module_base, []).append(input_type)\n return\n\n target_reg = self._reg.setdefault(module_base, {})\n # Seek to the given type\n for subtype in module_args[0:-1]:\n target_reg = target_reg.setdefault(subtype, {})\n target_reg.setdefault(module_args[-1], []).append(input_type)\n\n def _real_register(self, input_type: type, *args) -> Type:\n name = input_type.__name__.lower()\n name_target = self._names[name] = (input_type, set())\n\n if issubclass(input_type, Targeted):\n target = input_type.getTarget()\n else:\n target = None\n\n if issubclass(input_type, Searcher):\n module_type = module_base = Searcher\n module_args = ()\n else:\n module_type: Optional[Type] = None\n module_base = None\n\n # Work out what module type this is\n if len(args) == 0 and hasattr(input_type, \"__orig_bases__\"):\n for i in input_type.__orig_bases__:\n if module_type is not None:\n raise TypeError(\n f\"Type derived from multiple registrable base classes {i} and {module_type}\"\n )\n module_base = get_origin(i)\n if module_base not in self._modules:\n continue\n module_type = i\n else:\n for i in self._modules:\n if not issubclass(input_type, i):\n continue\n if module_type is not None:\n raise TypeError(\n f\"Type derived from multiple registrable base classes {i} and {module_type}\"\n )\n module_type = i\n if module_type is None:\n raise TypeError(\"No registrable base class\")\n\n # Replace input type with polymorphic checker if required\n if issubclass(input_type, Checker):\n if len(args) == 0:\n arg = [\n get_args(i)\n for i in input_type.__orig_bases__\n if get_origin(i) == Checker\n ][0]\n if len(arg) != 1:\n raise TypeError(\"No argument for Checker\")\n input_type = input_type.convert({arg[0]})\n else:\n input_type = input_type.convert(set(args))\n self._register_one(input_type, PolymorphicChecker, [])\n # Refresh the names with the new type\n name_target = self._names[name] = (input_type, {PolymorphicChecker})\n\n # Now handle the difference between register and register_multi\n if len(args) == 0:\n if module_type is PolymorphicChecker:\n module_base = PolymorphicChecker\n elif module_base is None:\n raise TypeError(\"No type argument given\")\n self._register_one(input_type, module_base, get_args(module_type))\n name_target[1].add(module_base)\n else:\n if module_base is not None:\n raise TypeError(f\"Redundant type argument for {module_type}\")\n module_base = module_type\n for module_args in args:\n # Correct missing brackets\n if not isinstance(module_args, tuple):\n 
module_args = (module_args,)\n\n self._register_one(input_type, module_base, module_args)\n name_target[1].add(module_type[module_args])\n\n name_target[1].add(module_type)\n\n if target is not None and issubclass(module_base, Targeted):\n self._targets.setdefault(target, {}).setdefault(module_type, []).append(\n input_type\n )\n\n return input_type\n\n def register(self, input_type):\n return self._real_register(input_type)\n\n def register_multi(self, *x):\n return lambda input_type: self._real_register(input_type, *x)\n\n def __getitem__(self, i: type) -> Optional[Any]:\n target_type = get_origin(i)\n # Check if this is a non-generic type, and return the whole dict if it is\n if target_type is None:\n return self._reg[i]\n\n target_subtypes = get_args(i)\n target_list = self._reg.setdefault(target_type, {})\n for subtype in target_subtypes:\n target_list = target_list.setdefault(subtype, {})\n return target_list\n\n def get_named(self, name: str, type_constraint: Type = None) -> Any:\n ret = self._names[name.lower()]\n if type_constraint and type_constraint not in ret[1]:\n raise TypeError(f\"Type mismatch: wanted {type_constraint}, got {ret[1]}\")\n return ret[0]\n\n def get_targeted(\n self, target: str, type_constraint: Type = None\n ) -> Optional[Union[Dict[Type, Set[Type]], Set[Type]]]:\n x = self._targets.get(target)\n if x is None or type_constraint is None:\n return x\n return x.get(type_constraint)\n\n def get_all_names(self) -> List[str]:\n return list(self._names.keys())\n\n def __str__(self):\n return f\"ciphey.iface.Registry {{_reg: {self._reg}, _names: {self._names}, _targets: {self._targets}}}\"\n\n\n_fwd.registry = Registry()\n\n# Path: ciphey/iface/__init__.py\nfrom ._config import Config\n\nfrom ._modules import (\n Checker,\n Cracker,\n CrackInfo,\n CrackResult,\n Decoder,\n DecoderComparer,\n Distribution,\n ParamSpec,\n PolymorphicChecker,\n ResourceLoader,\n Searcher,\n SearchLevel,\n SearchResult,\n T,\n Translation,\n U,\n WordList,\n pretty_search_results,\n)\nfrom ._registry import get_args, get_origin\n\nfrom ._fwd import registry\n\n# Path: ciphey/basemods/Checkers/any.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, ParamSpec, PolymorphicChecker, registry\n\n\n@registry.register\nclass Any(PolymorphicChecker):\n \"\"\"Should only be used for debugging, frankly\"\"\"\n\n def getExpectedRuntime(self, text) -> float:\n return 0 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n def check(self, text: str) -> Optional[str]:\n return \"\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/mathsHelper.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d 
\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to provide helper functions for mathematics\n(oh, not entirely mathematics either. Some NLP stuff and sorting dicts. It's just a helper class\n)\n\"\"\"\n\nfrom collections import OrderedDict\nfrom string import punctuation\nfrom typing import Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\n\nclass mathsHelper:\n \"\"\"Class to provide helper functions for mathematics and other small things\"\"\"\n\n def __init__(self):\n # ETAOIN is the most popular letters in order\n self.ETAOIN = \"ETAOINSHRDLCUMWFGYPBVKJXQZ\"\n self.LETTERS = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n\n @staticmethod\n def gcd(a, b) -> int:\n \"\"\"Greatest common divisor.\n\n The Greatest Common Divisor of a and b using Euclid's Algorithm.\n\n Args:\n a -> num 1\n b -> num 2\n\n Returns:\n Returns GCD(a, b)\n\n \"\"\"\n # Return\n while a != 0:\n a, b = b % a, a\n return b\n\n @staticmethod\n def mod_inv(a: int, m: int) -> Optional[int]:\n \"\"\"\n Returns the modular inverse of a mod m, or None if it does not exist.\n\n The modular inverse of a is the number a_inv that satisfies the equation\n a_inv * a mod m === 1 mod m\n\n Note: This is a naive implementation, and runtime may be improved in several ways.\n For instance by checking if m is prime to perform a different calculation,\n or by using the extended euclidean algorithm.\n \"\"\"\n for i in range(1, m):\n if (m * i + 1) % a == 0:\n return (m * i + 1) // a\n return None\n\n @staticmethod\n def percentage(part: float, whole: float) -> float:\n \"\"\"Returns percentage.\n\n Just a normal algorithm to return the percent.\n\n Args:\n part -> part of the whole number\n whole -> the whole number\n\n Returns:\n Returns the percentage of part to whole.\n\n \"\"\"\n if part <= 0 or whole <= 0:\n return 0\n # works with percentages\n return 100 * float(part) / float(whole)\n\n def sort_prob_table(self, prob_table: dict) -> dict:\n \"\"\"Sorts the probability table.\n\n Sorts a dictionary of dictionaries (and all the sub-dictionaries).\n\n Args:\n prob_table -> The probability table returned by the neural network to sort.\n\n Returns:\n Returns the prob_table, but sorted.\n\n \"\"\"\n # for each object: prob table in dictionary\n max_overall: int = 0\n max_dict_pair: dict = {}\n highest_key = None\n empty_dict: dict = {}\n # sorts the prob table before we find max, and converts it to order dicts\n for key, value in prob_table.items():\n prob_table[key] = self.new_sort(value)\n prob_table[key] = dict(prob_table[key])\n\n # gets maximum key then sets it to the front\n counter_max: int = 0\n counter_prob: int = len(prob_table)\n while counter_max < counter_prob:\n max_overall = 0\n highest_key = None\n logging.debug(\n f\"Running while loop in sort_prob_table, counterMax is {counter_max}\"\n )\n for key, value in prob_table.items():\n logging.debug(f\"Sorting {key}\")\n maxLocal = 0\n # for each item in that table\n for key2, value2 in value.items():\n logging.debug(\n f\"Running key2 {key2}, value2 {value2} for loop for {value.items()}\"\n )\n maxLocal = maxLocal + value2\n logging.debug(\n f\"MaxLocal is {maxLocal} and maxOverall is {max_overall}\"\n )\n if maxLocal > max_overall:\n logging.debug(f\"New 
max local found {maxLocal}\")\n                        # because the dict doesn't reset\n                        max_dict_pair = {}\n                        max_overall = maxLocal\n                        # so eventually, we get the maximum dict pairing?\n                        max_dict_pair[key] = value\n                        highest_key = key\n                logging.debug(f\"Highest key is {highest_key}\")\n            # removes the highest key from the prob table\n            logging.debug(\n                f\"Prob table is {prob_table} and highest key is {highest_key}\"\n            )\n            logging.debug(f\"Removing {prob_table[highest_key]}\")\n            del prob_table[highest_key]\n            logging.debug(f\"Prob table after deletion is {prob_table}\")\n            counter_max += 1\n            empty_dict = {**empty_dict, **max_dict_pair}\n\n        # returns the max dict (at the start) with the prob table\n        # this way, it should always work on most likely first.\n        logging.debug(\n            f\"The prob table is {prob_table} and the maxDictPair is {max_dict_pair}\"\n        )\n        logging.debug(f\"The new sorted prob table is {empty_dict}\")\n        return empty_dict\n\n    @staticmethod\n    def new_sort(new_dict: dict) -> dict:\n        \"\"\"Uses OrderedDict to sort a dictionary.\n\n        I think it's faster than my implementation.\n\n        Args:\n            new_dict -> the dictionary to sort\n\n        Returns:\n            Returns the dict, but sorted.\n\n        \"\"\"\n        # (f\"d is {d}\")\n        logging.debug(f\"The old dictionary before new_sort() is {new_dict}\")\n        sorted_i = OrderedDict(\n            sorted(new_dict.items(), key=lambda x: x[1], reverse=True)\n        )\n        logging.debug(f\"The dictionary after new_sort() is {sorted_i}\")\n        # sortedI = sort_dictionary(x)\n        return sorted_i\n\n    @staticmethod\n    def is_ascii(s: str) -> bool:\n        \"\"\"Returns the boolean value if is_ascii is an ascii char.\n\n        Does what it says on the tree. Stolen from\n        https://stackoverflow.com/questions/196345/how-to-check-if-a-string-in-python-is-in-ascii\n\n        Args:\n            s -> the char to check.\n\n        Returns:\n            Returns the boolean of the char.\n\n        \"\"\"\n\n        return bool(lambda s: len(s) == len(s.encode()))\n\n    @staticmethod\n    def strip_punctuation(text: str) -> str:\n        \"\"\"Strips punctuation from a given string.\n\n        Uses string.punctuation.\n\n        Args:\n            text -> the text to strip punctuation from.\n\n        Returns:\n            Returns string without punctuation.\n        \"\"\"\n        text: str = (str(text).translate(str.maketrans(\"\", \"\", punctuation))).strip(\n            \"\\n\"\n        )\n        return text\n\n# Path: ciphey/basemods/Checkers/brandon.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557  \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557   \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551  \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551     \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557   \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551     \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d    \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551     \u2588\u2588\u2551  \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557   \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to determine whether something is English or not.\n1. Calculate the Chi Squared score of a sentence\n2. 
If the score is significantly lower than the average score, it _might_ be English\n 2.1. If the score _might_ be English, then take the text and compare it to the sorted dictionary\n in O(n log n) time.\n It creates a percentage of \"How much of this text is in the dictionary?\"\n The dictionary contains:\n * 20,000 most common US words\n * 10,000 most common UK words (there's no repetition between the two)\n * The top 10,000 passwords\n If the word \"Looks like\" English (chi-squared) and if it contains English words, we can conclude it is\n very likely English. The alternative is doing the dictionary thing but with an entire 479k word dictionary (slower)\n 2.2. If the score is not English, but we haven't tested enough to create an average, then test it against\n the dictionary\n\nThings to optimise:\n* We only run the dictionary if it's 20% smaller than the average for chi squared\n* We consider it \"English\" if 45% of the text matches the dictionary\n* We run the dictionary if there is less than 10 total chisquared test\n\nHow to add a language:\n* Download your desired dictionary. Try to make it the most popular words, for example. Place this file into this\n folder with languagename.txt\nAs an example, this comes built in with english.txt\nFind the statistical frequency of each letter in that language.\nFor English, we have:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n}\nIn chisquared.py\nTo add your language, do:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n \"German\": [0.0973]\n}\nIn alphabetical order\nAnd you're.... Done! 
Make sure the name of the two match up\n\"\"\"\nimport sys\nfrom math import ceil\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nsys.path.append(\"..\")\ntry:\n import mathsHelper as mh\nexcept ModuleNotFoundError:\n import ciphey.mathsHelper as mh\n\n\n@registry.register\nclass Brandon(Checker[str]):\n \"\"\"\n Class designed to confirm whether something is **language** based on how many words of **language** appears\n Call confirmLanguage(text, language)\n * text: the text you want to confirm\n * language: the language you want to confirm\n\n Find out what language it is by using chisquared.py, the highest chisquared score is the language\n languageThreshold = 45\n if a string is 45% **language** words, then it's confirmed to be english\n \"\"\"\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually work this out\n # TODO its 0.2 seconds on average\n return 1e-4 # 100 \u00b5s\n\n wordlist: set\n\n def clean_text(self, text: str) -> set:\n \"\"\"Cleans the text ready to be checked\n\n Strips punctuation, makes it lower case, turns it into a set separated by spaces, removes duplicate words\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n text -> the text as a list, now cleaned\n\n \"\"\"\n # makes the text unique words and readable\n text = text.lower()\n text = self.mh.strip_punctuation(text)\n text = text.split(\" \")\n text = filter(lambda x: len(x) > 2, text)\n text = set(text)\n return text\n\n def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n \"\"\"Given text determine if it passes checker\n\n The checker uses the variable passed to it. I.E. Stopwords list, 1k words, dictionary\n\n Args:\n text -> The text to check\n threshold -> at what point do we return True? The percentage of text that is in var before we return True\n text_length -> the length of the text\n var -> the variable we are checking against. 
Stopwords list, 1k words list, dictionary list.\n Returns:\n boolean -> True for it passes the test, False for it fails the test.\"\"\"\n if text is None:\n logging.debug(\"Checker's text is None, so returning False\")\n return False\n if var is None:\n logging.debug(\"Checker's input var is None, so returning False\")\n return False\n\n percent = ceil(text_length * threshold)\n logging.debug(f\"Checker's chunks are size {percent}\")\n meet_threshold = 0\n location = 0\n end = percent\n\n if text_length <= 0:\n return False\n\n while location <= text_length:\n # chunks the text, so only gets THRESHOLD chunks of text at a time\n text = list(text)\n to_analyse = text[location:end]\n logging.debug(f\"To analyse is {to_analyse}\")\n for word in to_analyse:\n # if word is a stopword, + 1 to the counter\n if word in var:\n logging.debug(\n f\"{word} is in var, which means I am +=1 to the meet_threshold which is {meet_threshold}\"\n )\n meet_threshold += 1\n meet_threshold_percent = meet_threshold / text_length\n if meet_threshold_percent >= threshold:\n logging.debug(\n f\"Returning true since the percentage is {meet_threshold / text_length} and the threshold is {threshold}\"\n )\n # if we meet the threshold, return True\n # otherwise, go over again until we do\n # We do this in the for loop because if we're at 24% and THRESHOLD is 25\n # we don't want to wait THRESHOLD to return true, we want to return True ASAP\n return True\n location = end\n end = end + percent\n logging.debug(\n f\"The language proportion {meet_threshold_percent} is under the threshold {threshold}\"\n )\n return False\n\n def __init__(self, config: Config):\n # Suppresses warning\n super().__init__(config)\n self.mh = mh.mathsHelper()\n\n phases = config.get_resource(self._params()[\"phases\"])\n\n self.thresholds_phase1 = phases[\"1\"]\n self.thresholds_phase2 = phases[\"2\"]\n self.top1000Words = config.get_resource(self._params().get(\"top1000\"))\n self.wordlist = config.get_resource(self._params()[\"wordlist\"])\n self.stopwords = config.get_resource(self._params().get(\"stopwords\"))\n\n self.len_phase1 = len(self.thresholds_phase1)\n self.len_phase2 = len(self.thresholds_phase2)\n\n def check(self, text: str) -> Optional[str]:\n \"\"\"Checks to see if the text is in English\n\n Performs a decryption, but mainly parses the internal data packet and prints useful information.\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n bool -> True if the text is English, False otherwise.\n\n \"\"\"\n logging.debug(f'In Language Checker with \"{text}\"')\n text = self.clean_text(text)\n logging.debug(f'Text split to \"{text}\"')\n if text == \"\":\n logging.debug(\"Returning None from Brandon as the text cleaned is none.\")\n return None\n\n length_text = len(text)\n\n what_to_use = {}\n\n # this code decides what checker / threshold to use\n # if text is over or equal to maximum size, just use the maximum possible checker\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase1.keys()\n )\n logging.debug(self.thresholds_phase1)\n what_to_use = self.thresholds_phase1[str(what_to_use)]\n # def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n if \"check\" in what_to_use:\n # perform check 1k words\n result = self.checker(\n text, what_to_use[\"check\"], length_text, self.top1000Words\n )\n elif \"stop\" in what_to_use:\n # perform stopwords\n result = self.checker(\n text, what_to_use[\"stop\"], length_text, self.stopwords\n )\n elif \"dict\" in 
what_to_use:\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n # If result is None, no point doing it again in phase2\n if not result:\n return None\n else:\n logging.info(f\"It is neither stop or check, but instead {what_to_use}\")\n\n # return False if phase 1 fails\n if not result:\n return None\n else:\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase2.keys()\n )\n what_to_use = self.thresholds_phase2[str(what_to_use)]\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n return \"\" if result else None\n\n def calculateWhatChecker(self, length_text, key):\n \"\"\"Calculates what threshold / checker to use\n\n If the length of the text is over the maximum sentence length, use the last checker / threshold\n Otherwise, traverse the keys backwards until we find a key range that does not fit.\n So we traverse backwards and see if the sentence length is between current - 1 and current\n In this way, we find the absolute lowest checker / percentage threshold.\n We traverse backwards because if the text is longer than the max sentence length, we already know.\n In total, the keys are only 5 items long or so. It is not expensive to move backwards, nor is it expensive to move forwards.\n\n Args:\n length_text -> The length of the text\n key -> What key we want to use. I.E. Phase1 keys, Phase2 keys.\n Returns:\n what_to_use -> the key of the lowest checker.\"\"\"\n\n _keys = list(key)\n _keys = list(map(int, _keys))\n if length_text >= int(_keys[-1]):\n what_to_use = list(key)[_keys.index(_keys[-1])]\n else:\n # this algorithm finds the smallest possible fit for the text\n for counter, i in reversed(list(enumerate(_keys))):\n # [0, 110, 150]\n if i <= length_text:\n what_to_use = i\n return what_to_use\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"top1000\": ParamSpec(\n desc=\"A wordlist of the top 1000 words\",\n req=False,\n default=\"cipheydists::list::english1000\",\n ),\n \"wordlist\": ParamSpec(\n desc=\"A wordlist of all the words\",\n req=False,\n default=\"cipheydists::list::english\",\n ),\n \"stopwords\": ParamSpec(\n desc=\"A wordlist of StopWords\",\n req=False,\n default=\"cipheydists::list::englishStopWords\",\n ),\n \"threshold\": ParamSpec(\n desc=\"The minimum proportion (between 0 and 1) that must be in the dictionary\",\n req=False,\n default=0.45,\n ),\n \"phases\": ParamSpec(\n desc=\"Language-specific phase thresholds\",\n req=False,\n default=\"cipheydists::brandon::english\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/format.py\nimport json\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass JsonChecker(Checker[str]):\n\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying json checker\")\n\n # https://github.com/Ciphey/Ciphey/issues/389\n if text.isdigit():\n return None\n\n try:\n json.loads(text)\n return \"\"\n except ValueError:\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 1e-7 * len(text) # From benchmarks I found online\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/human.py\nfrom 
typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, registry\nfrom rich.console import Console\nfrom rich.markup import escape\n\nconsole = Console()\n\n\n@registry.register\nclass HumanChecker(Checker[str]):\n\n \"\"\"\n Uses the person's decision to determine plaintext\n \"\"\"\n\n def check(self, ctext: str) -> Optional[str]:\n with self._config().pause_spinner_handle():\n response = console.input(\n f\"Possible plaintext: [blue bold]{escape(ctext.__repr__())}[/blue bold] ([green]y[/green]/[red]N[/red]): \"\n )\n if response == \"y\":\n return \"\"\n elif response in (\"n\", \"\"):\n return None\n else:\n return self.check(ctext)\n\n def getExpectedRuntime(self, text: str) -> float:\n return 1 # About a second\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Checkers/quadgrams.py\nimport logging\nimport re\nfrom math import log10\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, Translation, registry\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Quadgrams(Checker[str]):\n\n \"\"\"\n Uses Quadgrams to determine plaintext\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying Quadgrams checker\")\n # Capitalize and remove everything that's not a letter\n ctext = re.sub(\"[^A-Z]\", \"\", ctext.upper())\n quadgrams = self.QUADGRAMS_DICT\n quadgrams_sum = sum(quadgrams.values())\n score = 0\n for key in quadgrams.keys():\n quadgrams[key] = float(quadgrams[key]) / quadgrams_sum\n floor = log10(0.01 / quadgrams_sum)\n for i in range(len(ctext) - 4 + 1):\n # Get all quadgrams from ctext and check if they're in the dict\n # If yes then add the score of those quadgrams to the total score\n if ctext[i : i + 4] in quadgrams:\n score += quadgrams[ctext[i : i + 4]]\n else:\n score += floor\n if len(ctext) > 0:\n score = score / len(ctext)\n logging.info(f\"Quadgrams is {score}\")\n # The default threshold was found to work the best from lots of testing\n if score > self.threshold:\n return \"\"\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The quadgrams dictionary to use\",\n req=False,\n default=\"cipheydists::dist::quadgrams\",\n ),\n \"score\": ParamSpec(\n desc=\"The score threshold to use\",\n req=False,\n default=0.00011,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.QUADGRAMS_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.threshold = float(self._params()[\"score\"])\n\n# Path: ciphey/basemods/Checkers/regex.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Regex(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = list(map(re.compile, self._params()[\"regex\"]))\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: 
{res}\")\n if res:\n return f\"Passed with regex {regex}. Want to contribute to Ciphey? Submit your regex here to allow Ciphey to automatically get this next time https://github.com/bee-san/pyWhat/wiki/Adding-your-own-Regex\\n\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"regex\": ParamSpec(\n req=True,\n desc=\"The regex that must be matched (in a substring)\",\n list=True,\n )\n }\n\n\n@registry.register\nclass RegexList(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = []\n for i in self._params()[\"resource\"]:\n self.regexes += [re.compile(regex) for regex in config.get_resource(i)]\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: {res}\")\n if res:\n return f\"passed with regex {regex}\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"resource\": ParamSpec(\n req=True,\n desc=\"A list of regexes that could be matched\",\n list=True,\n )\n }\n\n# Path: ciphey/basemods/Checkers/what.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\nimport logging\nfrom rich.logging import RichHandler\nfrom pywhat import identifier\nfrom rich.console import Console\n\nconsole = Console()\n\n\n@registry.register\nclass What(Checker[str]):\n\n \"\"\"\n Uses PyWhat to determine plaintext with regexes\n https://github.com/bee-san/pyWhat\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying PyWhat checker\")\n returned_regexes = self.id.identify(ctext)\n if returned_regexes[\"Regexes\"]:\n matched_regex = returned_regexes[\"Regexes\"]['text'][0][\"Regex Pattern\"]\n\n ret = f'The plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n human = (\n f'\\nI think the plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n )\n\n if \"Description\" in matched_regex and matched_regex[\"Description\"]:\n s = matched_regex[\"Description\"]\n # lowercases first letter so it doesn't look weird\n s = f\", which is {s[0].lower() + s[1:]}\\n\"\n ret += s\n human += s\n\n # if URL is attached, include that too.\n if \"URL\" in matched_regex and matched_regex[\"URL\"]:\n link = matched_regex[\"URL\"] + ctext.replace(\" \", \"\")\n ret += f\"\\nClick here to view in browser [#CAE4F1][link={link}]{link}[/link][/#CAE4F1]\\n\"\n\n # If greppable mode is on, don't print this\n if self.config.verbosity >= 0:\n # Print with full stop\n console.print(human)\n return ret\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.config = config\n self.id = identifier.Identifier()\n\n# Path: ciphey/basemods/Checkers/ezcheck.py\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nfrom .brandon import Brandon\nfrom .format import JsonChecker\nfrom .human import HumanChecker\nfrom .quadgrams import Quadgrams\nfrom .regex import RegexList\nfrom .what import What\n\n\n@registry.register\nclass EzCheck(Checker[str]):\n \"\"\"\n This object is effectively a 
prebuilt quorum (with requirement 1) of common patterns, followed by a human check\n \"\"\"\n\n def check(self, text: str) -> Optional[str]:\n for checker in self.checkers:\n res = checker.check(text)\n if (\n res is not None\n and (self.decider is None or self.decider.check(text)) is not None\n ):\n return res\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n return sum(\n i.getExpectedRuntime(text) for i in self.checkers\n ) + self.decider.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n self.checkers: List[Checker[str]] = []\n # Disable human checker for automated systems\n if config.verbosity >= 0:\n self.decider = config(HumanChecker)\n else:\n self.decider = None\n\n # We need to modify the config for each of the objects\n\n # First PyWhat, as it's the fastest\n self.checkers.append(config(What))\n\n # Next, the json checker\n self.checkers.append(config(JsonChecker))\n\n # Second to last, the quadgrams checker\n self.checkers.append(config(Quadgrams))\n\n # Finally, the Brandon checker, as it is the slowest\n self.checkers.append(config(Brandon))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/quorum.py\nfrom typing import Dict, Generic, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, _registry\n\n\nclass Quorum(Generic[T], Checker[T]):\n def check(self, text: T) -> Optional[str]:\n left = self._params().k\n results = []\n for checker in self.checkers:\n results.append(checker.check(text))\n...\n", "needle_token_start": 8705, "needle_token_end": 8905, "code_context_ntokens": 16045, "output": ["```python\nfrom collections import OrderedDict\n\ndef sort_prob_table(prob_table: dict) -> dict:\n \"\"\"Sorts the probability table.\n\n Sorts a dictionary of dictionaries (and all the sub-dictionaries).\n\n Args:\n prob_table -> The probability table returned by the neural network to sort.\n\n Returns:\n Returns the prob_table, but sorted.\n\n \"\"\"\n # for each object: prob table in dictionary\n max_overall: int = 0\n max_dict_pair: dict = {}\n highest_key = None\n empty_dict: dict = {}\n # sorts the prob table before we find max, and converts it to order dicts\n for key, value in prob_table.items():\n prob_table[key] = self.new_sort(value)\n prob_table[key] = dict(prob_table[key])\n\n # gets maximum key then sets it to the front\n counter_max: int = 0\n counter_prob: int = len(prob_table)\n while counter_max < counter_prob:\n max_overall = 0\n highest_key = None\n logging.debug(\n f\"Running while loop in sort_prob_table, counterMax is {counter_max}\"\n )\n for key, value in prob_table.items():\n logging.debug(f\"Sorting {key}\")\n maxLocal = 0\n # for each item in that table\n for key2, value2 in value.items():\n logging.debug(\n f\"Running key2 {key2}, value2 {value2} for loop for {value.items()}\"\n )\n maxLocal = maxLocal + value2\n logging.debug(\n f\"MaxLocal is {maxLocal} and maxOverall is {max_overall}\"\n )\n if maxLocal > max_overall:\n logging.debug(f\"New max local found {maxLocal}\")\n # because the dict doesn't reset\n max_dict_pair = {}\n max_overall = maxLocal\n # so eventually, we get the maximum dict pairing?\n max_dict_pair[key] = value\n highest_key = key\n logging.debug(f\"Highest key is {highest_key}\")\n # removes the highest key from the prob table\n logging.debug(\n f\"Prob table is {prob_table} and highest key is {highest_key}\"\n )\n logging.debug(f\"Removing {prob_table[highest_key]}\")\n del 
prob_table[highest_key]\n logging.debug(f\"Prob table after deletion is {prob_table}\")\n counter_max += 1\n empty_dict = {**empty_dict, **max_dict_pair}\n\n # returns the max dict (at the start) with the prob table\n # this way, it should always work on most likely first.\n logging.debug(\n f\"The prob table is {prob_table} and the maxDictPair is {max_dict_pair}\"\n )\n logging.debug(f\"The new sorted prob table is {empty_dict}\")\n return empty_dict\n```"]} +{"repo": "Ciphey/Ciphey", "name": "get_name", "language": "python", "path": "ciphey/ciphey.py", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to retrieve a string value, either from a command-line argument or from standard input if the argument is not provided.\n2. **Input**: It takes three parameters: a context object, a parameter object, and a value which is potentially a string.\n3. **Output**: Returns a string that is either the input value provided or read from standard input.\n4. **Procedure**: The function first checks if a value is provided. If not, it checks if the standard input is not a terminal (i.e., there is data being piped or redirected into the program). It then reads this input, strips any surrounding whitespace, and returns it. If a value is provided, or if there is no data on standard input, it simply returns the provided value.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/basemods/Decoders/galactic.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Galactic(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Takes a string written in the 'Standard Galactic Alphabet'\n (aka Minecraft Enchanting Table Symbols) and translates it to ASCII text.\n \"\"\"\n logging.debug(\"Attempting Standard Galactic Alphabet decoder\")\n\n # To avoid complications, only move forward with the decoding if we can\n # reasonably assume that the input string is written in the galactic alphabet\n galactic_matches = 0\n for symbol in self.GALACTIC_DICT.keys():\n # These symbols are assumed to be frequent enough in regular\n # text to be skipped when counting the matches. 
All others are counted.\n if symbol in ctext and symbol not in [\"!\", \"|\"]:\n galactic_matches += 1\n else:\n...\n# Path: ciphey/basemods/Decoders/gzip.py\nimport zlib\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Gzip(Decoder[bytes]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Gzip decoding\n \"\"\"\n try:\n return zlib.decompress(ctext, 16 + zlib.MAX_WBITS).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"gzip\"\n\n# Path: ciphey/basemods/Decoders/hexadecimal.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Hexadecimal(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Hexadecimal decoding\n \"\"\"\n ctext_decoded = \"\"\n try:\n ctext_decoded = bytearray.fromhex(ctext).decode(\"utf-8\")\n return ctext_decoded\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.015\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"hexadecimal\"\n\n# Path: ciphey/basemods/Decoders/leetspeak.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Leetspeak(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n for src, dst in self.translate.items():\n ctext = ctext.replace(src, dst)\n return ctext\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.translate = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The leetspeak dictionary to use\",\n req=False,\n default=\"cipheydists::translate::leet\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"leetspeak\"\n\n# Path: ciphey/basemods/Decoders/morse_code.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Morse_code(Decoder[str]):\n # A priority list for char/word boundaries\n BOUNDARIES = {\" \": 1, \"/\": 2, \"\\n\": 3}\n PURGE = {ord(c): None for c in BOUNDARIES.keys()}\n MAX_PRIORITY = 3\n ALLOWED = {\".\", \"-\", \" \", \"/\", \"\\n\"}\n MORSE_CODE_DICT: Dict[str, str]\n MORSE_CODE_DICT_INV: Dict[str, str]\n\n def decode(self, ctext: T) -> Optional[U]:\n logging.debug(\"Attempting Morse code decoder\")\n\n char_boundary = word_boundary = None\n\n char_boundary = word_boundary = None\n char_priority = word_priority = 0\n # Custom loop allows early break\n for i in ctext:\n i_priority = self.BOUNDARIES.get(i)\n if i_priority is None:\n if i in self.ALLOWED:\n continue\n logging.debug(f\"Non-morse char '{i}' found\")\n return None\n\n if i_priority <= char_priority or i == char_boundary or i == word_boundary:\n continue\n # Default to 
having a char boundary over a word boundary\n if (\n i_priority > word_priority\n and word_boundary is None\n and char_boundary is not None\n ):\n word_priority = i_priority\n word_boundary = i\n continue\n char_priority = i_priority\n char_boundary = i\n\n logging.debug(\n f\"Char boundary is unicode {ord(char_boundary)}, and word boundary is unicode {ord(word_boundary) if word_boundary is not None else None}\"\n )\n\n result = \"\"\n\n for word in ctext.split(word_boundary) if word_boundary else [ctext]:\n logging.debug(f\"Attempting to decode word {word}\")\n for char in word.split(char_boundary):\n char = char.translate(self.PURGE)\n if len(char) == 0:\n continue\n try:\n m = self.MORSE_CODE_DICT_INV[char]\n except KeyError:\n logging.debug(f\"Invalid codeword '{char}' found\")\n return None\n result = result + m\n # after every word add a space\n result = result + \" \"\n if len(result) == 0:\n logging.debug(\"Morse code failed to match\")\n return None\n # Remove trailing space\n result = result[:-1]\n logging.info(f\"Morse code successful, returning {result}\")\n return result.strip().upper()\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.MORSE_CODE_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.MORSE_CODE_DICT_INV = {v: k for k, v in self.MORSE_CODE_DICT.items()}\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The morse code dictionary to use\",\n req=False,\n default=\"cipheydists::translate::morse\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"morse_code\"\n\n# Path: ciphey/basemods/Decoders/multi_tap.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Multi_tap(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n result = \"\"\n for x in ctext.split():\n if x == self.SPACE_DIGIT: # Check if it's a space\n result += \" \"\n elif not Multi_tap.valid_code_part(x):\n return None\n else:\n result += self.decode_num_to_char(x)\n\n return result\n\n @staticmethod\n def valid_code_part(code: str) -> bool:\n if not code.isdigit():\n return False\n\n # if not all the digits are the same\n if not Multi_tap.is_all_dup(code):\n return False\n\n if int(code[0]) not in range(2, 10):\n return False\n\n if len(code) > 4:\n return False\n\n return True\n\n @staticmethod\n def decode_num_to_char(number: str) -> str:\n index = Multi_tap.calculate_index(number)\n return Multi_tap.number_index_to_char(index)\n\n @staticmethod\n def is_all_dup(code):\n return len(set(code)) == 1\n\n @staticmethod\n def calculate_index(number: str) -> int:\n first_number_as_int = int(number[0])\n\n number_index = Multi_tap.get_index_from_first_digit(first_number_as_int)\n\n # Add to index the number of the char : \"22\" -> index += 1\n num_rest_numbers = len(number) - 1\n number_index += num_rest_numbers\n\n return number_index\n\n @staticmethod\n def number_index_to_char(index_number: int) -> str:\n start_ascii_value = ord(\"A\")\n return chr(start_ascii_value + index_number)\n\n @staticmethod\n def get_index_from_first_digit(first_digit: int) -> int:\n number_index = 0\n if first_digit >= 8: # s have 4 chars\n number_index += 1\n\n first_digit -= 2 # start in 200\n\n number_index += first_digit * 3 # jump 3 every time\n\n return number_index\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, 
config: Config):\n super().__init__(config)\n self.SPACE_DIGIT = \"0\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"multi_tap\"\n\n# Path: ciphey/basemods/Decoders/octal.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Octal(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Octal decoding\n \"\"\"\n str_converted = []\n octal_seq = ctext.split(\" \")\n if len(octal_seq) == 1:\n # Concatted octal must be formed of octal triplets\n if len(ctext) % 3 != 0:\n return None\n octal_seq = [ctext[i : i + 3] for i in range(0, len(ctext), 3)]\n logging.debug(f\"Trying chunked octal {octal_seq}\")\n try:\n for octal_char in octal_seq:\n if len(octal_char) > 3:\n logging.debug(\"Octal subseq too long\")\n return None\n n = int(octal_char, 8)\n if (\n n < 0\n ): # n cannot be greater than 255, as we checked that with the earlier length check\n logging.debug(f\"Non octal char {octal_char}\")\n return None\n str_converted.append(n)\n\n return bytes(str_converted)\n # Catch bad octal chars\n except ValueError:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.025\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"octal\"\n\n# Path: ciphey/basemods/Decoders/reverse.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Reverse(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n return ctext[::-1]\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"reverse\"\n\n# Path: ciphey/basemods/Decoders/tap_code.py\n# by https://github.com/RustyDucky and https://github.com/lukasgabriel\n\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, Translation, U, registry\n\n\n@registry.register\nclass Tap_code(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Tap code decoding\n \"\"\"\n try:\n result = \"\"\n combinations = ctext.split(\" \")\n for fragment in combinations:\n result += self.TABLE.get(fragment)\n return result\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.06\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.TABLE = config.get_resource(self._params()[\"dict\"], Translation)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The table of letters used for the tap code interpretation.\",\n req=False,\n default=\"cipheydists::translate::tap_code\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"tap_code\"\n\n# Path: ciphey/basemods/Decoders/unicode.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Utf8(Decoder[bytes]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs UTF-8 decoding\n \"\"\"\n 
logging.debug(\"Attempting UTF-8 decoder\")\n result = \"\"\n try:\n result = ctext.decode(\"utf-8\")\n if result != ctext:\n logging.info(f\"UTF-8 successful, returning '{result}'\")\n return result\n else:\n return None\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.9\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"utf8\"\n\n# Path: ciphey/basemods/Decoders/url.py\nfrom typing import Dict, Optional\nfrom urllib.parse import unquote_plus\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Url(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs URL decoding\n \"\"\"\n logging.debug(\"Attempting URL\")\n result = \"\"\n try:\n result = unquote_plus(ctext, errors=\"strict\")\n if result != ctext:\n logging.info(f\"URL successful, returning '{result}'\")\n return result\n else:\n return None\n except Exception:\n logging.debug(\"Failed to decode URL\")\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"url\"\n\n# Path: ciphey/basemods/Decoders/uuencode.py\nfrom binascii import a2b_uu\nfrom codecs import decode\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Uuencode(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n UUEncode (Unix to Unix Encoding) is a symmetric encryption\n based on conversion of binary data (split into 6-bit blocks) into ASCII characters.\n\n This function decodes the input string 'ctext' if it has been encoded using 'uuencoder'\n It will return None otherwise\n \"\"\"\n logging.debug(\"Attempting UUencode\")\n result = \"\"\n try:\n # UUencoded messages may begin with prefix \"begin\" and end with suffix \"end\"\n # In that case, we use the codecs module in Python\n ctext_strip = ctext.strip()\n if ctext_strip.startswith(\"begin\") and ctext_strip.endswith(\"end\"):\n result = decode(bytes(ctext, \"utf-8\"), \"uu\").decode()\n else:\n # If there isn't a \"being\" prefix and \"end\" suffix, we use the binascii module instead\n # It is possible that the ctext has multiple lines, so convert each line and append\n ctext_split = list(filter(None, ctext.splitlines()))\n for _, value in enumerate(ctext_split):\n result += a2b_uu(value).decode(\"utf-8\")\n logging.info(f\"UUencode successful, returning '{result}'\")\n return result\n except Exception:\n logging.debug(\"Failed to decode UUencode\")\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"uuencode\"\n\n# Path: ciphey/basemods/Decoders/__init__.py\nfrom . 
import (\n a1z26,\n atbash,\n base58_bitcoin,\n base58_ripple,\n base62,\n base69,\n base91,\n bases,\n baudot,\n binary,\n braille,\n brainfuck,\n decimal,\n dna,\n dtmf,\n galactic,\n gzip,\n hexadecimal,\n leetspeak,\n morse_code,\n multi_tap,\n octal,\n reverse,\n tap_code,\n unicode,\n url,\n uuencode,\n)\n\n# Path: ciphey/basemods/Resources/cipheydists.py\nfrom functools import lru_cache\nfrom typing import Any, Dict, Optional, Set\n\nimport cipheydists\nimport logging\n\nfrom ciphey.iface import (\n Config,\n Distribution,\n ParamSpec,\n ResourceLoader,\n Translation,\n WordList,\n registry,\n)\n\n\n@registry.register_multi(WordList, Distribution, Translation)\nclass CipheyDists(ResourceLoader):\n # _wordlists: Set[str] = frozenset({\"english\", \"english1000\", \"englishStopWords\"})\n # _brandons: Set[str] = frozenset({\"english\"})\n # _dists: Set[str] = frozenset({\"twist\"})\n # _translates: Set[str] = frozenset({\"morse\"})\n _getters = {\n \"list\": cipheydists.get_list,\n \"dist\": cipheydists.get_dist,\n \"brandon\": cipheydists.get_brandon,\n \"translate\": cipheydists.get_translate,\n }\n\n def whatResources(self) -> Optional[Set[str]]:\n pass\n\n @lru_cache()\n def getResource(self, name: str) -> Any:\n logging.debug(f\"Loading cipheydists resource {name}\")\n prefix, name = name.split(\"::\", 1)\n return self._getters[prefix](name)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n# Path: ciphey/basemods/Resources/files.py\nimport csv\nimport json\nfrom functools import lru_cache\nfrom typing import Dict, Generic, Optional, Set\n\nfrom ciphey.iface import (\n Config,\n Distribution,\n ParamSpec,\n ResourceLoader,\n T,\n WordList,\n registry,\n)\n\n\n# We can use a generic resource loader here, as we can instantiate it later\n@registry.register_multi(WordList, Distribution)\nclass Json(ResourceLoader):\n def whatResources(self) -> T:\n return self._names\n\n @lru_cache()\n def getResource(self, name: str) -> T:\n prefix, name = name.split(\"::\", 1)\n return {\"wordlist\": (lambda js: {js}), \"dist\": (lambda js: js)}[prefix](\n json.load(open(self._paths[int(name) - 1]))\n )\n\n @staticmethod\n def getName() -> str:\n return \"json\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\"path\": ParamSpec(req=True, desc=\"The path to a JSON file\", list=True)}\n\n def __init__(self, config: Config):\n super().__init__(config)\n self._paths = self._params()[\"path\"]\n self._names = set(range(1, len(self._paths)))\n\n\n# We can use a generic resource loader here, as we can instantiate it later\n@registry.register_multi(WordList, Distribution)\nclass Csv(Generic[T], ResourceLoader[T]):\n def whatResources(self) -> Set[str]:\n return self._names\n\n @lru_cache()\n def getResource(self, name: str) -> T:\n prefix, name = name.split(\"::\", 1)\n return {\n \"wordlist\": (lambda reader: {i[0] for i in reader}),\n \"dist\": (lambda reader: {i[0]: float(i[1]) for i in reader}),\n }[prefix](csv.reader(open(self._paths[int(name) - 1])))\n\n @staticmethod\n def getName() -> str:\n return \"csv\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\"path\": ParamSpec(req=True, desc=\"The path to a CSV file\", list=True)}\n\n def __init__(self, config: Config):\n super().__init__(config)\n self._paths = self._params()[\"path\"]\n self._names = set(range(1, len(self._paths)))\n\n# Path: 
ciphey/basemods/Resources/__init__.py\nfrom . import cipheydists, files\n\n# Path: ciphey/basemods/Searchers/ausearch.py\nimport bisect\nimport distutils\nimport math\nfrom copy import copy\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Any, Dict, Generic, List, Optional, TypeVar, Union\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import (\n Checker,\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n Decoder,\n ParamSpec,\n Searcher,\n SearchLevel,\n SearchResult,\n T,\n registry,\n)\n\n\"\"\"\nWe are using a tree structure here, because that makes searching and tracing back easier\nAs such, when we encounter another possible parent, we remove that edge\n\"\"\"\n\n\nclass DuplicateNode(Exception):\n pass\n\n\n@dataclass\nclass AuSearchSuccessful(Exception):\n target: \"Node\"\n info: str\n\n\n@dataclass\nclass Node:\n # The root has no parent edge\n level: SearchLevel\n parent: Optional[\"Edge\"] = None\n depth: int = 0\n\n @staticmethod\n def decoding(\n config: Config, route: Union[Cracker, Decoder], result: Any, source: \"Node\"\n ) -> \"Node\":\n if not config.cache.mark_ctext(result):\n raise DuplicateNode()\n\n checker: Checker = config.objs[\"checker\"]\n ret = Node(\n parent=None,\n level=SearchLevel(\n name=type(route).__name__.lower(), result=CrackResult(value=result)\n ),\n depth=source.depth + 1,\n )\n edge = Edge(source=source, route=route, dest=ret)\n ret.parent = edge\n check_res = checker(result)\n if check_res is not None:\n raise AuSearchSuccessful(target=ret, info=check_res)\n return ret\n\n @staticmethod\n def cracker(config: Config, edge_template: \"Edge\", result: CrackResult) -> \"Node\":\n if not config.cache.mark_ctext(result.value):\n raise DuplicateNode()\n\n checker: Checker = config.objs[\"checker\"]\n # Edges do not directly contain containers, so this is fine\n edge = copy(edge_template)\n ret = Node(\n parent=edge,\n level=SearchLevel(name=type(edge.route).__name__.lower(), result=result),\n depth=edge.source.depth + 1,\n )\n edge.dest = ret\n check_res = checker(result.value)\n if check_res is not None:\n raise AuSearchSuccessful(target=ret, info=check_res)\n return ret\n\n @staticmethod\n def root(config: Config, ctext: Any):\n if not config.cache.mark_ctext(ctext):\n raise DuplicateNode()\n\n return Node(parent=None, level=SearchLevel.input(ctext))\n\n def get_path(self):\n if self.parent is None:\n return [self.level]\n return self.parent.source.get_path() + [self.level]\n\n\n@dataclass\nclass AusearchEdge:\n # TODO: This is just CrackInfo with failure probability added...\n success_probability: float\n failure_probability: float\n success_time: float\n failure_time: float\n\n def __init__(self, success_probability, success_time, failure_time):\n self.success_probability = success_probability\n self.failure_probability = 1.0 - success_probability\n self.success_time = success_time\n self.failure_time = failure_time\n\n\n@dataclass\nclass AusearchResult:\n weight: float\n index: int\n\n\ndef calculate_score(info: CrackInfo):\n return info.success_likelihood / \\\n (info.success_runtime * info.success_likelihood + info.failure_runtime * (1-info.success_likelihood))\n\n\n@dataclass\nclass Edge:\n source: Node\n route: Union[Cracker, Decoder]\n dest: Optional[Node] = None\n # Info is not filled in for Decoders\n score: Optional[float] = None\n\n\nPriorityType = TypeVar(\"PriorityType\")\n\n\nclass PriorityWorkQueue(Generic[PriorityType, T]):\n _sorted_priorities: List[PriorityType]\n 
_queues: Dict[Any, List[T]]\n\n def add_work(self, priority: PriorityType, work: List[T]) -> None:\n logging.debug(f\"\"\"Adding work at depth {priority}\"\"\")\n\n idx = bisect.bisect_left(self._sorted_priorities, priority)\n if (\n idx == len(self._sorted_priorities)\n or self._sorted_priorities[idx] != priority\n ):\n self._sorted_priorities.insert(idx, priority)\n self._queues.setdefault(priority, []).extend(work)\n\n def get_work(self) -> T:\n best_priority = self._sorted_priorities[0]\n target = self._queues[best_priority]\n ret = target.pop(0)\n if len(target) == 0:\n self._sorted_priorities.pop()\n return ret\n\n def get_work_chunk(self) -> List[T]:\n \"\"\"Returns the best work for now\"\"\"\n if len(self._sorted_priorities) == 0:\n return []\n best_priority = self._sorted_priorities.pop(0)\n return self._queues.pop(best_priority)\n\n def empty(self):\n return len(self._sorted_priorities) == 0\n\n def __init__(self):\n self._sorted_priorities = []\n self._queues = {}\n\n\n@registry.register\nclass AuSearch(Searcher):\n # Deeper paths get done later\n work: PriorityWorkQueue[int, Edge]\n\n @staticmethod\n def get_crackers_for(t: type):\n return registry[Cracker[t]]\n\n @lru_cache() # To save extra sorting\n def get_decoders_for(self, t: type):\n ret = registry[Decoder[t]]\n ret.sort(key=lambda x: x.priority(), reverse=True)\n return ret\n\n # def expand(self, edge: Edge) -> List[Edge]:\n # \"\"\"Evaluates the destination of the given, and adds its child edges to the pool\"\"\"\n # edge.dest = Node(parent=edge, level=edge.route(edge.source.level.result.value))\n\n def expand_crackers(self, node: Node) -> None:\n if node.depth >= self.max_cipher_depth:\n return\n\n res = node.level.result.value\n additional_work = []\n\n for i in self.get_crackers_for(type(res)):\n inst = self._config()(i)\n info = inst.getInfo(res)\n if info.success_likelihood < self.p_threshold:\n continue\n additional_work.append(\n Edge(source=node, route=inst, score=calculate_score(inst.getInfo(res)))\n )\n\n priority = min(node.depth, self.priority_cap)\n if self.invert_priority:\n priority = -priority\n\n self.work.add_work(priority, additional_work)\n\n def expand_decodings(self, node: Node) -> None:\n val = node.level.result.value\n\n for decoder in self.get_decoders_for(type(val)):\n inst = self._config()(decoder)\n res = inst(val)\n if res is None:\n continue\n try:\n new_node = Node.decoding(\n config=self._config(), route=inst, result=res, source=node\n )\n except DuplicateNode:\n continue\n\n logging.debug(\"Nesting encodings\")\n self.recursive_expand(new_node, False)\n\n def recursive_expand(self, node: Node, nested: bool = True) -> None:\n if node.depth >= self.max_depth:\n return\n\n logging.debug(f\"Expanding depth {node.depth}\")\n\n self.expand_decodings(node)\n\n # Doing this last allows us to catch simple nested encodings faster\n if not nested or self.enable_nested:\n self.expand_crackers(node)\n\n def search(self, ctext: Any) -> Optional[SearchResult]:\n logging.debug(\n f\"\"\"Beginning AuSearch with {\"inverted\" if self.invert_priority else \"normal\"} priority\"\"\"\n )\n\n try:\n root = Node.root(self._config(), ctext)\n except DuplicateNode:\n return None\n\n check_res = self._config().objs[\"checker\"](ctext)\n if check_res is not None:\n return SearchResult(check_res=check_res, path=[root.level])\n\n try:\n self.recursive_expand(root, False)\n\n while True:\n if self.work.empty():\n break\n # Get the highest level result\n chunk = self.work.get_work_chunk()\n chunk.sort(key=lambda i: 
i.score)\n # Work through all of this level's results\n while len(chunk) != 0:\n logging.debug(f\"{len(chunk)} remaining on this level\")\n # TODO Cyclic uses some tricky C++ here\n # I know because it's sorted the one at the back (the anti-weight)\n # is the most likely\n\n edge: Edge = chunk.pop(-1)\n\n # Expand the node\n res = edge.route(edge.source.level.result.value)\n if res is None:\n continue\n for i in res:\n try:\n node = Node.cracker(\n config=self._config(), edge_template=edge, result=i\n )\n self.recursive_expand(node)\n except DuplicateNode:\n continue\n\n except AuSearchSuccessful as e:\n logging.info(\"AuSearch succeeded\")\n return SearchResult(path=e.target.get_path(), check_res=e.info)\n\n logging.info(\"AuSearch failed\")\n\n def __init__(self, config: Config):\n super().__init__(config)\n self._checker: Checker = config.objs[\"checker\"]\n self.work = PriorityWorkQueue() # Has to be defined here because of sharing\n self.invert_priority = bool(\n distutils.util.strtobool(self._params()[\"invert_priority\"])\n )\n self.priority_cap = int(self._params()[\"priority_cap\"])\n self.enable_nested = bool(\n distutils.util.strtobool(self._params()[\"enable_nested\"])\n )\n self.max_cipher_depth = int(self._params()[\"max_cipher_depth\"])\n if self.max_cipher_depth == 0:\n self.max_cipher_depth = math.inf\n self.max_depth = int(self._params()[\"max_depth\"])\n if self.max_depth == 0:\n self.max_depth = math.inf\n self.p_threshold = float(self._params()[\"p_threshold\"])\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"enable_nested\": ParamSpec(\n req=False,\n desc=\"Enables nested ciphers. \"\n \"Incredibly slow, and not guaranteed to terminate\",\n default=\"False\",\n ),\n \"invert_priority\": ParamSpec(\n req=False,\n desc=\"Causes more complex encodings to be looked at first. \"\n \"Good for deeply buried encodings.\",\n default=\"False\",\n ),\n \"max_cipher_depth\": ParamSpec(\n req=False,\n desc=\"The depth at which we stop trying to crack ciphers. \"\n \"Set to 0 to disable\",\n default=\"0\",\n ),\n \"max_depth\": ParamSpec(\n req=False,\n desc=\"The depth at which we give up. \" \n \"Set to 0 to disable\",\n default=\"0\",\n ),\n \"priority_cap\": ParamSpec(\n req=False,\n desc=\"Sets the maximum depth before we give up ordering items.\",\n default=\"2\",\n ),\n \"p_threshold\": ParamSpec(\n req=False,\n desc=\"Will skip any crackers which have less than this likelihood of succeeding. \"\n \"Set to 0 to disable\",\n default=\"0.01\",\n ),\n }\n\n# Path: ciphey/basemods/Searchers/__init__.py\nfrom . import ausearch\n\n# Path: ciphey/basemods/__init__.py\nfrom . 
import Checkers, Crackers, Decoders, Resources, Searchers\n\n# Path: ciphey/ciphey.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\nhttps://github.com/Ciphey\nhttps://github.com/Ciphey/Ciphey/wiki\n\nThe cycle goes:\nmain -> argparsing (if needed) -> call_encryption -> new Ciphey object -> decrypt() -> produceProbTable ->\none_level_of_decryption -> decrypt_normal\n\"\"\"\nimport os\nimport warnings\nfrom typing import Any, Optional, Union\n\nimport click\nfrom appdirs import AppDirs\nimport logging\nfrom rich.logging import RichHandler\nfrom rich.console import Console\n\nfrom . import iface\n\nwarnings.filterwarnings(\"ignore\")\n\nconsole = Console()\n\n\ndef decrypt(config: iface.Config, ctext: Any) -> Union[str, bytes]:\n \"\"\"A simple alias for searching a ctext and makes the answer pretty\"\"\"\n res: Optional[iface.SearchResult] = config.objs[\"searcher\"].search(ctext)\n if res is None:\n return \"Failed to crack\"\n if config.verbosity < 0:\n return res.path[-1].result.value\n else:\n return iface.pretty_search_results(res)\n\n\ndef get_name(ctx, param, value):\n # reads from stdin if value was not supplied\n if not value and not click.get_text_stream(\"stdin\").isatty():\n click.get_text_stream(\"stdin\").read().strip()\n return click.get_text_stream(\"stdin\").read().strip()\n else:\n return value\n\n\ndef print_help(ctx):\n # prints help menu\n # if no \narguments are passed\n click.echo(ctx.get_help())\n ctx.exit()\n\n\n@click.command()\n@click.option(\n \"-t\",\n \"--text\",\n help=\"The ciphertext you want to decrypt.\",\n type=str,\n)\n@click.option(\n \"-q\", \"--quiet\", help=\"Decrease verbosity\", type=int, count=True, default=None\n)\n@click.option(\n \"-g\",\n \"--greppable\",\n help=\"Only print the answer (useful for grep)\",\n type=bool,\n is_flag=True,\n default=None,\n)\n@click.option(\"-v\", \"--verbose\", count=True, type=int)\n@click.option(\"-C\", \"--checker\", help=\"Use the given checker\", default=None)\n@click.option(\n \"-c\",\n \"--config\",\n help=\"Uses the given config file. 
Defaults to appdirs.user_config_dir('ciphey', 'ciphey')/'config.yml'\",\n)\n@click.option(\"-w\", \"--wordlist\", help=\"Uses the given wordlist\")\n@click.option(\n \"-p\",\n \"--param\",\n help=\"Passes a parameter to the language checker\",\n multiple=True,\n)\n@click.option(\n \"-l\",\n \"--list-params\",\n help=\"List the parameters of the selected module\",\n type=bool,\n)\n@click.option(\n \"--searcher\",\n help=\"Select the searching algorithm to use\",\n)\n# HARLAN TODO XXX\n# I switched this to a boolean flag system\n# https://click.palletsprojects.com/en/7.x/options/#boolean-flags\n# True for bytes input, False for str\n@click.option(\n \"-b\",\n \"--bytes\",\n help=\"Forces ciphey to use binary mode for the input\",\n is_flag=True,\n default=None,\n)\n@click.option(\n \"--default-dist\",\n help=\"Sets the default character/byte distribution\",\n type=str,\n default=None,\n)\n@click.option(\n \"-m\",\n \"--module\",\n help=\"Adds a module from the given path\",\n type=click.Path(),\n multiple=True,\n)\n@click.option(\n \"-A\",\n \"--appdirs\",\n help=\"Print the location of where Ciphey wants the settings file to be\",\n type=bool,\n is_flag=True,\n)\n@click.option(\"-f\", \"--file\", type=click.File(\"rb\"), required=False)\n@click.argument(\"text_stdin\", callback=get_name, required=False)\ndef main(**kwargs):\n \"\"\"Ciphey - Automated Decryption Tool\n\n Documentation:\n https://github.com/Ciphey/Ciphey/wiki\\n\n Discord (support here, we're online most of the day):\n http://discord.skerritt.blog\\n\n GitHub:\n https://github.com/ciphey/ciphey\\n\n\n Ciphey is an automated decryption tool using smart artificial intelligence and natural language processing. Input encrypted text, get the decrypted text back.\n\n Examples:\\n\n Basic Usage: ciphey -t \"aGVsbG8gbXkgbmFtZSBpcyBiZWU=\"\n\n \"\"\"\n\n \"\"\"Function to deal with arguments. Either calls with args or not. 
Makes Pytest work.\n\n It gets the arguments in the function definition using locals()\n if withArgs is True, that means this is being called with command line args\n so go to arg_parsing() to get those args\n we then update locals() with the new command line args and remove \"withArgs\"\n This function then calls call_encryption(**result) which passes our dict of args\n to the function as its own arguments using dict unpacking.\n Returns:\n The output of the decryption.\n \"\"\"\n\n # if user wants to know where appdirs is\n # print and exit\n if \"appdirs\" in kwargs and kwargs[\"appdirs\"]:\n dirs = AppDirs(\"Ciphey\", \"Ciphey\")\n path_to_config = dirs.user_config_dir\n print(\n f\"The settings.yml file should be at {os.path.join(path_to_config, 'settings.yml')}\"\n )\n return None\n\n # Now we create the config object\n config = iface.Config()\n\n # Load the settings file into the config\n load_msg: str\n cfg_arg = kwargs[\"config\"]\n if cfg_arg is None:\n # Make sure that the config dir actually exists\n os.makedirs(iface.Config.get_default_dir(), exist_ok=True)\n config.load_file(create=True)\n load_msg = f\"Opened config file at {os.path.join(iface.Config.get_default_dir(), 'config.yml')}\"\n else:\n config.load_file(cfg_arg)\n load_msg = f\"Opened config file at {cfg_arg}\"\n\n # Load the verbosity, so that we can start logging\n verbosity = kwargs[\"verbose\"]\n quiet = kwargs[\"quiet\"]\n if verbosity is None:\n if quiet is not None:\n verbosity = -quiet\n elif quiet is not None:\n verbosity -= quiet\n if kwargs[\"greppable\"] is not None:\n verbosity -= 999\n # Use the existing value as a base\n config.verbosity += verbosity\n config.update_log_level(config.verbosity)\n logging.info(load_msg)\n logging.debug(f\"Got cmdline args {kwargs}\")\n\n # Now we load the modules\n module_arg = kwargs[\"module\"]\n if module_arg is not None:\n config.modules += list(module_arg)\n\n # We need to load formats BEFORE we instantiate objects\n if kwargs[\"bytes\"] is not None:\n config.update_format(\"bytes\")\n\n # Next, load the objects\n params = kwargs[\"param\"]\n if params is not None:\n for i in params:\n key, value = i.split(\"=\", 1)\n parent, name = key.split(\".\", 1)\n config.update_param(parent, name, value)\n config.update(\"checker\", kwargs[\"checker\"])\n config.update(\"searcher\", kwargs[\"searcher\"])\n config.update(\"default_dist\", kwargs[\"default_dist\"])\n\n config.complete_config()\n\n logging.debug(f\"Command line opts: {kwargs}\")\n logging.debug(f\"Config finalised: {config}\")\n\n # Finally, we load the plaintext\n if kwargs[\"text\"] is None:\n if kwargs[\"file\"] is not None:\n kwargs[\"text\"] = kwargs[\"file\"].read()\n elif kwargs[\"text_stdin\"] is not None:\n kwargs[\"text\"] = kwargs[\"text_stdin\"]\n else:\n # else print help menu\n print(\"[bold red]Error. No inputs were given to Ciphey. 
[bold red]\")\n\n @click.pass_context\n def all_procedure(ctx):\n print_help(ctx)\n\n all_procedure()\n\n return None\n\n if issubclass(config.objs[\"format\"], type(kwargs[\"text\"])):\n pass\n elif config.objs[\"format\"] == str and isinstance(kwargs[\"text\"], bytes):\n kwargs[\"text\"] = kwargs[\"text\"].decode(\"utf-8\")\n elif config.objs[\"format\"] == bytes and isinstance(kwargs[\"text\"], str):\n kwargs[\"text\"] = kwargs[\"text\"].encode(\"utf-8\")\n else:\n raise TypeError(f\"Cannot load type {config.format} from {type(kwargs['text'])}\")\n\n result: Optional[str]\n\n # if debug or quiet mode is on, run without spinner\n if config.verbosity != 0:\n result = decrypt(config, kwargs[\"text\"])\n else:\n # else, run with spinner if verbosity is 0\n with console.status(\"[bold green]Thinking...\", spinner=\"moon\") as status:\n config.set_spinner(status)\n result = decrypt(config, kwargs[\"text\"])\n if result is None:\n result = \"Could not find any solutions.\"\n\n console.print(result)\n\n# Path: ciphey/__init__.py\nfrom . import basemods, common, iface\nfrom .ciphey import decrypt\n\n# Path: ciphey/__main__.py\n#! /usr/bin/env python3\n\n\"\"\"\nCiphey: https://github.com/Ciphey/Ciphey\n\"\"\"\n\nimport platform\nimport sys\n\nif __name__ == \"__main__\":\n major = sys.version_info[0]\n minor = sys.version_info[1]\n\n python_version = (\n str(sys.version_info[0])\n + \".\"\n + str(sys.version_info[1])\n + \".\"\n + str(sys.version_info[2])\n )\n\n if major != 3 or minor < 6:\n print(\n f\"Ciphey requires Python 3.6+, you are using {python_version}. Please install a higher Python version. https://www.python.org/downloads/\"\n )\n print(\n \"Alternatively, visit our Discord and use the Ciphey bot in #bots http://discord.skerritt.blog\"\n )\n sys.exit(1)\n if platform.system() == \"Windows\":\n if minor > 8:\n print(\n \"Ciphey does not currently support Python 3.9 on Windows. Please use the Discord bot at http://discord.skerritt.blog\"\n )\n sys.exit(1)\n\n if sys.maxsize > 2 ** 32 is False:\n print(\n \"You are using Python 32 bit and Windows, Ciphey does not support this. 
Please upgrade to Python 64-bit here https://www.python.org/downloads/\"\n )\n sys.exit(1)\n from .ciphey import main\n\n main()\n\n# Path: ciphey/basemods/Checkers/entropy.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Entropy(Checker[str]):\n\n \"\"\"\n Uses entropy to determine plaintext\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying entropy checker\")\n pass\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n # Uses benchmark from Discord\n return 2e-7 * len(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/gtest.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass GTestChecker(Checker[str]):\n\n \"\"\"\n G-test of fitness, similar to Chi squared.\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying entropy checker\")\n pass\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 4e-7 * len(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Crackers/hash.py\n\"\"\"\nThis is Hashbuster but slightly modified to work with Ciphey.\nWhy reinvent the wheel?\nChanges (that I can remember)\n* timeout set, as hashbuster took AGES before timeout was set.\nhttps://github.com/s0md3v/Hash-Buster\n\"\"\"\n\nimport re\nfrom typing import Dict, List, Optional\n\nimport requests\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, T, registry\n\nthread_count = 4\n\n\ndef alpha(ctext, hashtype):\n return None\n\n\ndef beta(ctext, hashtype):\n try:\n response = requests.get(\n \"https://hashtoolkit.com/reverse-hash/?hash=\", ctext, timeout=5\n ).text\n except requests.exceptions.ReadTimeout as e:\n logging.info(f\"Beta failed timeout {e}\")\n match = re.search(r'/generate-hash/?text=.*?\"', response)\n if match:\n return match.group(1)\n return None\n\n\ndef gamma(ctext, hashtype):\n try:\n response = requests.get(\n \"https://www.nitrxgen.net/md5db/\" + ctext, timeout=5\n ).text\n except requests.exceptions.ReadTimeout as e:\n logging.info(f\"Gamma failed with {e}\")\n if response:\n return response\n else:\n return None\n\n\ndef delta(ctext, hashtype):\n return None\n\n\ndef theta(ctext, hashtype):\n try:\n response = requests.get(\n \"https://md5decrypt.net/Api/api.php?hash=%s&hash_type=%s&email=deanna_abshire@proxymail.eu&code=1152464b80a61728\"\n % (ctext, hashtype),\n timeout=5,\n ).text\n except requests.exceptions.ReadTimeout as e:\n logging.info(f\"Gamma failed with {e}\")\n if len(response) != 0:\n return response\n else:\n return None\n\n\nmd5 = [gamma, alpha, beta, theta, delta]\nsha1 = [alpha, beta, theta, delta]\nsha256 = [alpha, beta, theta]\nsha384 = [alpha, beta, theta]\nsha512 = [alpha, beta, theta]\n\n\nresult = {}\n\n\ndef crack(ctext):\n raise \"Error Crack is called\"\n\n\ndef threaded(ctext):\n resp = crack(ctext)\n if resp:\n print(ctext + \" : \" + resp)\n result[ctext] = resp\n\n\n@registry.register\nclass HashBuster(Cracker[str]):\n 
@staticmethod\n def getTarget() -> str:\n return \"hash\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def getInfo(self, ctext: T) -> CrackInfo:\n # TODO calculate these properly\n return CrackInfo(\n success_likelihood=0.5,\n success_runtime=5,\n failure_runtime=5,\n )\n\n def attemptCrack(self, ctext: T) -> List[CrackResult]:\n logging.info(\"Starting to crack hashes\")\n result = False\n\n candidates = []\n if len(ctext) == 32:\n for api in md5:\n r = api(ctext, \"md5\")\n if result is not None or r is not None:\n logging.debug(\"MD5 returns True {r}\")\n candidates.append(result, \"MD5\")\n elif len(ctext) == 40:\n for api in sha1:\n r = api(ctext, \"sha1\")\n if result is not None and r is not None:\n logging.debug(\"sha1 returns true\")\n candidates.append(result, \"SHA1\")\n elif len(ctext) == 64:\n for api in sha256:\n r = api(ctext, \"sha256\")\n if result is not None and r is not None:\n logging.debug(\"sha256 returns true\")\n candidates.append(result, \"SHA256\")\n elif len(ctext) == 96:\n for api in sha384:\n r = api(ctext, \"sha384\")\n if result is not None and r is not None:\n logging.debug(\"sha384 returns true\")\n candidates.append(result, \"SHA384\")\n elif len(ctext) == 128:\n for api in sha512:\n r = api(ctext, \"sha512\")\n if result is not None and r is not None:\n logging.debug(\"sha512 returns true\")\n candidates.append(result, \"SHA512\")\n\n # TODO what the fuck is this code?\n logging.debug(f\"Hash buster returning {result}\")\n # TODO add to 5.1 make this return multiple possible candidates\n return [CrackResult(value=candidates[0][0], misc_info=candidates[1][1])]\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Decoders/base58_flickr.py\nfrom typing import Dict, Optional\n\nimport base58\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base58_flickr(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base58 (Flickr) decoding\n \"\"\"\n FLICKR_ALPHABET = b\"123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ\"\n try:\n return base58.b58decode(ctext, alphabet=FLICKR_ALPHABET).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base58_flickr\"\n\n# Path: ciphey/basemods/Decoders/base64_url.py\nimport base64\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base64_url(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base64 URL decoding\n \"\"\"\n ctext_padding = ctext + \"=\" * (4 - len(ctext) % 4)\n try:\n return base64.urlsafe_b64decode(ctext_padding).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base64_url\"\n\n# Path: 
ciphey/basemods/Decoders/base65536.py\nfrom typing import Dict, Optional\n\nimport base65536\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base65536(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base65536 decoding\n \"\"\"\n try:\n return base65536.decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base65536\"\n\n# Path: ciphey/basemods/Decoders/z85.py\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\nfrom zmq.utils import z85\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Z85(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Z85 decoding\n \"\"\"\n ctext_len = len(ctext)\n if ctext_len % 5:\n logging.debug(\n f\"Failed to decode Z85 because length must be a multiple of 5, not '{ctext_len}'\"\n )\n return None\n try:\n return z85.decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n...\n", "needle_token_start": 10368, "needle_token_end": 10477, "code_context_ntokens": 16038, "output": ["```python\ndef get_name(ctx, param, value):\n # reads from stdin if value was not supplied\n if not value and not click.get_text_stream(\"stdin\").isatty():\n click.get_text_stream(\"stdin\").read().strip()\n return click.get_text_stream(\"stdin\").read().strip()\n else:\n return value\n```"]} +{"repo": "Ciphey/Ciphey", "name": "binary_to_ascii", "language": "python", "path": "ciphey/basemods/Crackers/xandy.py", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert binary data encoded as a string into ASCII text.\n2. **Input**: A string representing binary data.\n3. **Output**: The ASCII representation of the input binary string if conversion is successful; otherwise, an empty string if there is a decoding error.\n4. **Procedure**: The function first converts the binary string to an integer using base 2. It then calculates the necessary number of bytes to represent this integer and converts the integer to a bytearray. This bytearray is then attempted to be decoded into ASCII text. 
If successful, the ASCII text is returned; if a Unicode decoding error occurs, an empty string is returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/basemods/Checkers/regex.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Regex(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = list(map(re.compile, self._params()[\"regex\"]))\n...\n# Path: ciphey/basemods/Checkers/what.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\nimport logging\nfrom rich.logging import RichHandler\nfrom pywhat import identifier\nfrom rich.console import Console\n\nconsole = Console()\n\n\n@registry.register\nclass What(Checker[str]):\n\n \"\"\"\n Uses PyWhat to determine plaintext with regexes\n https://github.com/bee-san/pyWhat\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying PyWhat checker\")\n returned_regexes = self.id.identify(ctext)\n if returned_regexes[\"Regexes\"]:\n matched_regex = returned_regexes[\"Regexes\"]['text'][0][\"Regex Pattern\"]\n\n ret = f'The plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n human = (\n f'\\nI think the plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n )\n\n if \"Description\" in matched_regex and matched_regex[\"Description\"]:\n s = matched_regex[\"Description\"]\n # lowercases first letter so it doesn't look weird\n s = f\", which is {s[0].lower() + s[1:]}\\n\"\n ret += s\n human += s\n\n # if URL is attached, include that too.\n if \"URL\" in matched_regex and matched_regex[\"URL\"]:\n link = matched_regex[\"URL\"] + ctext.replace(\" \", \"\")\n ret += f\"\\nClick here to view in browser [#CAE4F1][link={link}]{link}[/link][/#CAE4F1]\\n\"\n\n # If greppable mode is on, don't print this\n if self.config.verbosity >= 0:\n # Print with full stop\n console.print(human)\n return ret\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.config = config\n self.id = identifier.Identifier()\n\n# Path: ciphey/basemods/Checkers/ezcheck.py\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nfrom .brandon import Brandon\nfrom .format import JsonChecker\nfrom .human import HumanChecker\nfrom .quadgrams import Quadgrams\nfrom .regex import RegexList\nfrom .what import What\n\n\n@registry.register\nclass EzCheck(Checker[str]):\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns, followed by a human check\n \"\"\"\n\n def check(self, text: str) -> Optional[str]:\n for checker in self.checkers:\n res = checker.check(text)\n if (\n res is not None\n and (self.decider is None or self.decider.check(text)) is not None\n ):\n return res\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n return sum(\n 
i.getExpectedRuntime(text) for i in self.checkers\n ) + self.decider.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n self.checkers: List[Checker[str]] = []\n # Disable human checker for automated systems\n if config.verbosity >= 0:\n self.decider = config(HumanChecker)\n else:\n self.decider = None\n\n # We need to modify the config for each of the objects\n\n # First PyWhat, as it's the fastest\n self.checkers.append(config(What))\n\n # Next, the json checker\n self.checkers.append(config(JsonChecker))\n\n # Second to last, the quadgrams checker\n self.checkers.append(config(Quadgrams))\n\n # Finally, the Brandon checker, as it is the slowest\n self.checkers.append(config(Brandon))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/quorum.py\nfrom typing import Dict, Generic, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, _registry\n\n\nclass Quorum(Generic[T], Checker[T]):\n def check(self, text: T) -> Optional[str]:\n left = self._params().k\n results = []\n for checker in self.checkers:\n results.append(checker.check(text))\n if results[-1] is None:\n continue\n left -= 1\n # Early return check\n if left == 0:\n return str(results)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n if self._params().k is None:\n k = len(self._params()[\"checker\"])\n # These checks need to be separate, to make sure that we do not have zero members\n if self._params().k == 0 or self._params().k > len(self._params()[\"checker\"]):\n raise IndexError(\n \"k must be between 0 and the number of checkers (inclusive)\"\n )\n\n self.checkers = []\n for i in self._params()[\"checker\"]:\n # This enforces type consistency\n self.checkers.append(_registry.get_named(i, Checker[T]))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"checker\": ParamSpec(\n req=True, desc=\"The checkers to be used for analysis\", list=True\n ),\n \"k\": ParamSpec(\n req=False,\n desc=\"The minimum quorum size. Defaults to the number of checkers\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/__init__.py\nfrom . import any, brandon, ezcheck, format, human, quadgrams, quorum, regex, what\n\n# Path: ciphey/common.py\n\"\"\"Some useful adapters\"\"\"\nfrom typing import Any\n\n\ndef id_lambda(value: Any):\n \"\"\"\n A function used in dynamic class generation that abstracts away a constant return value (like in getName)\n \"\"\"\n return lambda *args: value\n\n\ndef fix_case(target: str, base: str) -> str:\n \"\"\"Returns the lower-case string target with the case of base\"\"\"\n ret = \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n return \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n\n# Path: ciphey/basemods/Crackers/affine.py\n# Community\n# by https://github.com/Ozzyz\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\nfrom ciphey.mathsHelper import mathsHelper\n\n\n@registry.register\nclass Affine(Cracker[str]):\n \"\"\"\n Each character in the Affine Cipher is encoded with the rule E(x) = (ax + b) mod m\n m is the size of the alphabet, while a and b are the keys in the cipher. 
a must be coprime to b.\n The Caesar cipher is a specific case of the Affine Cipher, with a=1 and b being the shift of the cipher.\n Decryption is performed by D(x) = a_inv (x - b) mod m where a_inv is the modular multiplicative inverse of a mod m.\n\n In this version of the Affine Cipher, we do not allow alphabets with several instances of the same letter in different cases.\n For instance, the alphabet 'ABCdef123' is allowed, but 'AaBbCc' is not.\n \"\"\"\n\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"affine\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Brute forces all the possible combinations of a and b to attempt to crack the cipher.\n \"\"\"\n logging.debug(\"Attempting affine\")\n candidates = []\n\n # a and b are coprime if gcd(a,b) is 1.\n possible_a = [\n a\n for a in range(1, self.alphabet_length)\n if mathsHelper.gcd(a, self.alphabet_length) == 1\n ]\n logging.info(\n f\"Trying Affine Cracker with {len(possible_a)} a-values and {self.alphabet_length} b-values\"\n )\n\n for a in possible_a:\n a_inv = mathsHelper.mod_inv(a, self.alphabet_length)\n # If there is no inverse, we cannot decrypt the text\n if a_inv is None:\n continue\n for b in range(self.alphabet_length):\n # Pass in lowered text. This means that we expect alphabets to not contain both 'a' and 'A'.\n translated = self.decrypt(ctext.lower(), a_inv, b, self.alphabet_length)\n\n candidate_probability = self.plaintext_probability(translated)\n if candidate_probability > self.plaintext_prob_threshold:\n candidates.append(\n CrackResult(\n value=fix_case(translated, ctext), key_info=f\"a={a}, b={b}\"\n )\n )\n logging.info(f\"Affine Cipher returned {len(candidates)} candidates\")\n return candidates\n\n def plaintext_probability(self, translated: str) -> float:\n \"\"\"\n Analyses the translated text and applies the chi squared test to see if it is a probable plaintext candidate\n Returns the probability of the chi-squared test.\n \"\"\"\n analysis = cipheycore.analyse_string(translated)\n return cipheycore.chisq_test(analysis, self.expected)\n\n def decrypt(self, text: str, a_inv: int, b: int, m: int) -> str:\n \"\"\"\n Each letter is decrypted at D(x) = a_inv (x - b) mod m where x is the char\n We treat the char value as its index in the alphabet, so if\n the alphabet is 'abcd....' and the char is 'b', it has the value 1.\n \"\"\"\n return \"\".join([self.decryptChar(char, a_inv, b, m) for char in text])\n\n def decryptChar(self, char: str, a_inv: int, b: int, m: int) -> str:\n\n # We lower the alphabet since both ctext and alphabet need to be in the same case in order\n # to perform the shifts. 
The translated text will have fixed case after the translation anyways.\n # This is only necessary if the specified alphabet is uppercase.\n alphabet = [x.lower() for x in self.group]\n\n # Preserve characters that are not in alphabet\n if char not in alphabet:\n return char\n char_idx = alphabet.index(char)\n decrypted_char_idx = (a_inv * (char_idx - b)) % m\n return alphabet[decrypted_char_idx]\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.alphabet_length = len(self.group)\n self.cache = config.cache\n self.plaintext_prob_threshold = 0.01\n\n# Path: ciphey/basemods/Crackers/ascii_shift.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Ascii_shift(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"ascii_shift\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ASCII shift cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ASCII shift returned {n_candidates} candidates\")\n\n if 
n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ASCII shift cipher alphabet\",\n req=False,\n default=\"\"\"\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/baconian.py\nimport re\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\nimport logging\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Baconian(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"baconian\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to decode both variants of the Baconian cipher.\n \"\"\"\n logging.debug(\"Attempting Baconian cracker\")\n candidates = []\n result = []\n ctext_decoded = \"\"\n ctext_decoded2 = \"\"\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext only contains A and B\n if bool(re.search(r\"[^AB]\", ctext)) is True:\n logging.debug(\"Failed to crack baconian due to non baconian character(s)\")\n return None\n\n # Make sure ctext is divisible by 5\n ctext_len = len(ctext)\n if ctext_len % 5:\n logging.debug(\n f\"Failed to decode Baconian because length must be a multiple of 5, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 5\n ctext = \" \".join(ctext[i : i + 5] for i in range(0, len(ctext), 5))\n ctext_split = ctext.split(\" \")\n baconian_keys = self.BACONIAN_DICT.keys()\n\n # Decode I=J and U=V variant\n for i in ctext_split:\n if i in baconian_keys:\n ctext_decoded += self.BACONIAN_DICT[i]\n\n # Decode variant that assigns each letter a unique code\n for i in ctext_split:\n if \"+\" + i in baconian_keys:\n ctext_decoded2 += self.BACONIAN_DICT[\"+\" + i]\n\n candidates.append(ctext_decoded)\n candidates.append(ctext_decoded2)\n for candidate in candidates:\n if candidate != \"\":\n if candidate == candidates[0]:\n 
result.append(CrackResult(value=candidate, key_info=\"I=J & U=V\"))\n else:\n result.append(CrackResult(value=candidate))\n logging.debug(f\"Baconian cracker - Returning results: {result}\")\n return result\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"dict\": ParamSpec(\n desc=\"The Baconian alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::baconian\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BACONIAN_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n\n# Path: ciphey/basemods/Crackers/caesar.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Caesar(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"caesar\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying caesar cipher on {ctext}\")\n # Convert it to lower case\n #\n # TODO: handle different alphabets\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"Caesar returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better 
results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(message, candidate.key, self.group)\n candidates.append(\n CrackResult(value=fix_case(translated, ctext), key_info=candidate.key)\n )\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n self.lower = util.strtobool(self.lower)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/rot47.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Rot47(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"rot47\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ROT47 cipher on 
{ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ROT47 returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ROT47 cipher alphabet\",\n req=False,\n default=\"\"\"!\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/soundex.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\n\n\n@registry.register\nclass Soundex(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"soundex\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to crack Soundex by generating all possible combinations.\n \"\"\"\n logging.debug(\"Attempting Soundex cracker\")\n word_list = []\n sentences = []\n result = []\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext contains only A-Z and 0-9\n if bool(re.search(r\"[^A-Z0-9]\", ctext)) is True:\n logging.debug(\"Failed to crack soundex due to non soundex character(s)\")\n return None\n\n # Make sure ctext is divisible by 4\n ctext_len = len(ctext)\n if ctext_len % 4:\n logging.debug(\n f\"Failed to decode Soundex because length must be a multiple of 4, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 4\n ctext = \" \".join(ctext[i : i + 4] for i in range(0, len(ctext), 4))\n ctext_split = ctext.split(\" \")\n soundex_keys = self.SOUNDEX_DICT.keys()\n\n # Find all words that correspond to each given soundex code\n for code in ctext_split:\n if code in soundex_keys:\n 
word_list.append(self.SOUNDEX_DICT[code])\n\n logging.info(f\"Possible words for given encoded text: {word_list}\")\n\n # Find all possible sentences\n self.getSentenceCombo(\n word_list,\n sentences,\n self.frequency_dict,\n self.sentence_freq,\n self.word_freq,\n )\n\n sorted_sentences = self.sortlistwithdict(sentences, self.frequency_dict)\n\n for sentence in sorted_sentences:\n result.append(CrackResult(value=sentence))\n\n logging.debug(f\"Soundex cracker - Returning results: {result}\")\n return result\n\n def sortlistwithdict(self, listtosort, hashes):\n \"\"\"\n This function uses the sum of ranks (based on frequency) of each word in each\n sentence and sorts them according to it.\n \"\"\"\n return sorted(listtosort, key=lambda x: hashes[x])\n\n def getSentenceCombo(\n self, A, sentences, frequency_dict, sentence_freq, word_freq, result=\"\", n=0\n ):\n \"\"\"\n This function uses recursion to generate a list of sentences from all possible\n words for a given set of soundex codes.\n \"\"\"\n logging.debug(\"Creating all possible sentences from Soundex\")\n if n == len(A):\n sentences.append(result[1:])\n for word in result[1:].split():\n # Adding the rank of each word to find out the sentence's net frequency\n if word in word_freq:\n sentence_freq += word_freq.index(word)\n # If the word isn't in the frequency list then it's a very uncommon word\n # so we add a large number (5000)\n else:\n sentence_freq += 5000\n frequency_dict[result[1:]] = sentence_freq\n sentence_freq = 0\n return\n\n for word in A[n]:\n out = result + \" \" + word\n self.getSentenceCombo(\n A, sentences, frequency_dict, sentence_freq, word_freq, out, n + 1\n )\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The Soundex dictionary to use\",\n req=False,\n default=\"cipheydists::translate::soundex\",\n ),\n \"freq\": ParamSpec(\n desc=\"The word frequency dictionary to use\",\n req=False,\n default=\"cipheydists::list::English5000Freq\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.SOUNDEX_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.word_freq = config.get_resource(self._params()[\"freq\"], Translation)\n self.frequency_dict = {}\n self.sentence_freq = 0\n\n# Path: ciphey/basemods/Crackers/vigenere.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, 
List, Optional, Union\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Vigenere(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n if self.keysize is not None:\n analysis = self.cache.get_or_update(\n ctext,\n f\"vigenere::{self.keysize}\",\n lambda: cipheycore.analyse_string(\n ctext.lower(), self.keysize, self.group\n ),\n )\n\n val = cipheycore.vigenere_detect(analysis, self.expected)\n\n logging.info(f\"Vigenere has likelihood {val}\")\n\n return CrackInfo(\n success_likelihood=val,\n # TODO: actually calculate runtimes\n success_runtime=1e-3,\n failure_runtime=1e-2,\n )\n\n likely_lens = self.cache.get_or_update(\n ctext,\n \"vigenere::likely_lens\",\n lambda: cipheycore.vigenere_likely_key_lens(\n ctext.lower(), self.expected, self.group, self.detect_p_value\n ),\n )\n\n # Filter out the lens that make no sense\n likely_lens = [i for i in likely_lens if i.len <= self.max_key_length]\n\n for keysize in likely_lens:\n # Store the analysis\n analysis = self.cache.get_or_update(\n ctext, f\"vigenere::{keysize.len}\", lambda: keysize.tab\n )\n if len(likely_lens) == 0:\n return CrackInfo(\n success_likelihood=0,\n # TODO: actually calculate runtimes\n success_runtime=2e-3,\n failure_runtime=2e-2,\n )\n\n logging.info(\n f\"Vigenere has likelihood {likely_lens[0].p_value} with lens {[i.len for i in likely_lens]}\"\n )\n\n return CrackInfo(\n success_likelihood=likely_lens[0].p_value,\n # TODO: actually calculate runtimes\n success_runtime=2e-4,\n failure_runtime=2e-4,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"vigenere\"\n\n def crackOne(\n self, ctext: str, analysis: cipheycore.windowed_analysis_res, real_ctext: str\n ) -> List[CrackResult]:\n possible_keys = cipheycore.vigenere_crack(\n analysis, self.expected, self.group, self.p_value\n )\n if len(possible_keys) > self.clamp:\n possible_keys = possible_keys[: self.clamp]\n logging.debug(\n f\"Vigenere crack got keys: {[[i for i in candidate.key] for candidate in possible_keys]}\"\n )\n return [\n CrackResult(\n value=fix_case(\n cipheycore.vigenere_decrypt(ctext, candidate.key, self.group),\n real_ctext,\n ),\n key_info=\"\".join([self.group[i] for i in candidate.key]),\n misc_info=f\"p-value was {candidate.p_value}\",\n )\n for candidate in possible_keys[: min(len(possible_keys), 10)]\n ]\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(\"Trying vigenere cipher\")\n # Convert it to lower case\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n # Analysis must be done here, where we know the case for the cache\n if self.keysize is not None:\n return self.crackOne(\n message,\n self.cache.get_or_update(\n ctext,\n f\"vigenere::{self.keysize}\",\n lambda: cipheycore.analyse_string(\n message, self.keysize, self.group\n ),\n ),\n ctext,\n )\n\n arrs = []\n likely_lens = self.cache.get_or_update(\n ctext,\n \"vigenere::likely_lens\",\n lambda: cipheycore.vigenere_likely_key_lens(\n message, self.expected, self.group\n ),\n )\n possible_lens = [i for i in likely_lens]\n possible_lens.sort(key=lambda i: i.p_value)\n logging.debug(f\"Got possible lengths {[i.len for i in likely_lens]}\")\n # TODO: work out length\n for i in possible_lens:\n arrs.extend(\n self.crackOne(\n message,\n self.cache.get_or_update(\n ctext,\n f\"vigenere::{i.len}\",\n lambda: cipheycore.analyse_string(message, 
i.len, self.group),\n ),\n ctext,\n )\n )\n\n logging.info(f\"Vigenere returned {len(arrs)} candidates\")\n return arrs\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"keysize\": ParamSpec(\n desc=\"A key size that should be used. If not given, will attempt to work it out\",\n req=False,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for windowed frequency analysis\",\n req=False,\n default=0.5,\n ),\n \"detect_p_value\": ParamSpec(\n desc=\"The p-value to use for the detection of Vigenere length\",\n req=False,\n default=0.01,\n ),\n \"clamp\": ParamSpec(\n desc=\"The maximum number of candidates that can be returned per key len\",\n req=False,\n default=10,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n self.lower = util.strtobool(self.lower)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.keysize = self._params().get(\"keysize\")\n if self.keysize is not None:\n self.keysize = int(self.keysize)\n self.p_value = float(self._params()[\"p_value\"])\n self.detect_p_value = float(self._params()[\"detect_p_value\"])\n self.clamp = int(self._params()[\"clamp\"])\n self.max_key_length = 16\n\n# Path: ciphey/basemods/Crackers/xandy.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Xandy(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n \ndef binary_to_ascii(variant):\n # Convert the binary string to an integer with base 2\n binary_int = int(variant, 2)\n byte_number = binary_int.bit_length() + 7 // 8\n\n # Convert the resulting int to a bytearray and then decode it to ASCII text\n binary_array = binary_int.to_bytes(byte_number, \"big\")\n try:\n ascii_text = binary_array.decode()\n logging.debug(f\"Found possible solution: {ascii_text[:32]}\")\n return ascii_text\n except UnicodeDecodeError as e:\n logging.debug(f\"Failed to crack X-Y due to a UnicodeDecodeError: {e}\")\n return \"\"\n\n @staticmethod\n def getTarget() -> str:\n return \"xandy\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Checks an input if it only consists of two or three different letters.\n If this is the case, it attempts to regard those letters as\n 0 and 1 (with the third characters as an optional delimiter) and then\n converts it to ASCII text.\n \"\"\"\n logging.debug(\"Attempting X-Y replacement\")\n variants = []\n candidates = []\n result = []\n\n # Convert the ctext to all-lowercase and regex-match & replace all whitespace\n ctext = re.sub(r\"\\s+\", \"\", ctext.lower(), flags=re.UNICODE)\n\n # cset contains every unique value in the ctext\n cset = list(set(list(ctext)))\n cset_len 
= len(cset)\n\n if not 1 < cset_len < 4:\n # We only consider inputs with two or three unique values\n logging.debug(\n \"Failed to crack X-Y due to not containing two or three unique values\"\n )\n return None\n\n logging.debug(f\"String contains {cset_len} unique values: {cset}\")\n\n # In case of three unique values, we regard the least frequent character as the delimiter\n if cset_len == 3:\n # Count each unique character in the set to determine the least frequent one\n counting_list = []\n for char in cset:\n counting_list.append(ctext.count(char))\n val, index = min((val, index) for (index, val) in enumerate(counting_list))\n delimiter = cset[index]\n logging.debug(\n f\"{delimiter} occurs {val} times and is the probable delimiter\"\n )\n # Remove the delimiter from the ctext and compute new cset\n ctext = ctext.replace(delimiter, \"\")\n cset = list(set(list(ctext)))\n\n # Form both variants of the substitution\n for i in range(2):\n if i:\n variants.append(ctext.replace(cset[0], \"1\").replace(cset[1], \"0\"))\n else:\n variants.append(ctext.replace(cset[0], \"0\").replace(cset[1], \"1\"))\n\n # Apply function to both variants and strip stray NULL characters\n for variant in variants:\n candidates.append(self.binary_to_ascii(variant).strip(\"\\x00\"))\n for i, candidate in enumerate(candidates):\n if candidate != \"\":\n keyinfo = f\"{cset[0]} -> {i} & {cset[1]} -> {str(int(not i))}\"\n result.append(CrackResult(value=candidate, key_info=keyinfo))\n logging.debug(f\"X-Y cracker - Returning results: {result}\")\n return result\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n )\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n\n# Path: ciphey/basemods/Crackers/xortool.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: bee-san\n\"\"\"\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom xortool_ciphey import tool_main\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass XorTool(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n # TODO: actually calculate runtimes\n 
success_runtime=1e-8,\n failure_runtime=1e-8,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"xortool\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.debug(\"Trying xortool cipher\")\n # TODO handle different charsets\n # TODO allow more config over xortool\n\n logging.debug(f\"{ctext}\")\n\n # https://github.com/Ciphey/xortool/discussions/4\n # for docs on this function\n try:\n result = tool_main.api(str.encode(ctext))\n except:\n logging.debug(\"Xor failed.\")\n return\n\n result = CrackResult(value=result[1][\"Dexored\"], key_info=result[0][\"keys\"])\n\n return [result]\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n ),\n }\n\n @staticmethod\n def score_utility() -> float:\n return 1.5\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = self._params()[\"p_value\"]\n\n# Path: ciphey/basemods/Crackers/__init__.py\nfrom . import (\n affine,\n ascii_shift,\n baconian,\n caesar,\n rot47,\n soundex,\n vigenere,\n xandy,\n xortool,\n)\n\n# Path: ciphey/basemods/Decoders/a1z26.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass A1z26(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs A1Z26 decoding\n \"\"\"\n logging.debug(\"Attempting A1Z26\")\n ctext_converted = []\n ctext_split = re.split(r\"[ ,;:\\-\\n]\", ctext)\n delimiters = set(sorted(re.sub(r\"[^ ,;:\\-\\n]\", \"\", ctext)))\n ctext_num = re.sub(r\"[,;:\\-\\s]\", \"\", ctext)\n ctext_decoded = \"\"\n if ctext_num.isnumeric() is False:\n logging.debug(\"Failed to decode A1Z26 due to non numeric character(s)\")\n return None\n try:\n for i in ctext_split:\n val = int(i)\n if val > 26 or val < 1:\n logging.debug(\n f\"Failed to decode A1Z26 due to invalid number '{val}'\"\n )\n return None\n val2 = int(i) + 96\n ctext_converted.append(chr(val2))\n ctext_decoded = \"\".join(ctext_converted)\n logging.info(\n f\"A1Z26 successful, returning '{ctext_decoded}' with delimiter(s) {delimiters}\"\n )\n return ctext_decoded\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"a1z26\"\n\n# Path: ciphey/basemods/Decoders/atbash.py\nfrom typing import Dict, Optional\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Atbash(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Takes an encoded string and attempts to decode it according to the Atbash cipher.\n\n The Atbash cipher is a very simple substitution cipher without a key.\n It operates by replacing every letter in the input by its 'counterpoint'\n in the alphabet. Example: A -> Z, B -> Y, ... 
, M -> N and vice versa.\n \"\"\"\n\n result = \"\"\n atbash_dict = {self.ALPHABET[i]: self.ALPHABET[::-1][i] for i in range(26)}\n\n for letter in ctext.lower():\n if letter in atbash_dict.keys():\n # Match every letter of the input to its atbash counterpoint\n result += atbash_dict[letter]\n else:\n # If the current character is not in the defined alphabet,\n # just accept it as-is (useful for numbers, punctuation, etc.)\n result += letter\n return fix_case(result, ctext)\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.1\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.ALPHABET = config.get_resource(self._params()[\"dict\"], WordList)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The alphabet used for the atbash operation.\",\n req=False,\n default=\"cipheydists::list::englishAlphabet\",\n )\n }\n\n @staticmethod\n def getTarget() -> str:\n return \"atbash\"\n\n# Path: ciphey/basemods/Decoders/base58_bitcoin.py\nfrom typing import Dict, Optional\n\nimport base58\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base58_bitcoin(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base58 (Bitcoin) decoding\n \"\"\"\n try:\n return base58.b58decode(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base58_bitcoin\"\n\n# Path: ciphey/basemods/Decoders/base58_ripple.py\nfrom typing import Dict, Optional\n\nimport base58\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base58_ripple(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base58 (Ripple) decoding\n \"\"\"\n try:\n return base58.b58decode(ctext, alphabet=base58.RIPPLE_ALPHABET).decode(\n \"utf-8\"\n )\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base58_ripple\"\n\n# Path: ciphey/basemods/Decoders/base62.py\nfrom typing import Dict, Optional\n\nimport base62\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, registry\n\n\n@registry.register\nclass Base62(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base62 decoding\n \"\"\"\n try:\n return base62.decodebytes(ctext).decode(\"utf-8\")\n except Exception:\n return None\n\n @staticmethod\n def priority() -> float:\n # Not expected to show up often, but also very fast to check.\n return 0.05\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n @staticmethod\n def getTarget() -> str:\n return \"base62\"\n\n# Path: ciphey/basemods/Decoders/base69.py\n# Translated to Python and adapted for Ciphey from the JS original at https://github.com/pshihn/base69\n\n\nimport re\nfrom math import 
ceil\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Config, Decoder, ParamSpec, T, U, WordList, registry\n\n\n@registry.register\nclass Base69(Decoder[str]):\n def decode(self, ctext: T) -> Optional[U]:\n \"\"\"\n Performs Base69 decoding\n \"\"\"\n # Remove whitespace\n try:\n ctext = re.sub(r\"\\s+\", \"\", ctext, flags=re.UNICODE)\n extra_bytes = 0\n clen = len(ctext)\n\n if ctext[:-1] == \"=\":\n extra_bytes = int(ctext[clen - 2])\n\n CHUNK_COUNT = ceil(clen / 16)\n result = [0 for _ in range(CHUNK_COUNT * 7 - extra_bytes)]\n\n for i in range(CHUNK_COUNT):\n...\n", "needle_token_start": 12001, "needle_token_end": 12183, "code_context_ntokens": 16098, "output": ["```python\ndef binary_to_ascii(variant):\n # Convert the binary string to an integer with base 2\n binary_int = int(variant, 2)\n byte_number = binary_int.bit_length() + 7 // 8\n\n # Convert the resulting int to a bytearray and then decode it to ASCII text\n binary_array = binary_int.to_bytes(byte_number, \"big\")\n try:\n ascii_text = binary_array.decode()\n logging.debug(f\"Found possible solution: {ascii_text[:32]}\")\n return ascii_text\n except UnicodeDecodeError as e:\n logging.debug(f\"Failed to crack X-Y due to a UnicodeDecodeError: {e}\")\n return \"\"\n```"]} +{"repo": "Ciphey/Ciphey", "name": "getSentenceCombo", "language": "python", "path": "ciphey/basemods/Crackers/soundex.py", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to recursively generate all possible sentences from a list of lists of words, each corresponding to a set of soundex codes, and calculate their frequency scores based on word usage.\n2. **Input**: The function takes a list of lists where each sublist contains words that match a soundex code, a list to store sentences, a dictionary to map sentences to their frequency scores, and two dictionaries for word and sentence frequencies. It also takes a string to accumulate words into sentences and an index to track recursion depth.\n3. **Output**: It does not return a value directly but populates the provided list with all possible sentence combinations and updates the dictionary with their corresponding frequency scores.\n4. **Procedure**: The function uses recursion to explore all combinations of words from the input list. For each word in the current sublist, it appends the word to the ongoing sentence and calls itself with the next index. Once the end of the list is reached, the complete sentence is stored, and its frequency score is calculated by summing the ranks of its words from the frequency dictionary. 
If a word is not found in the dictionary, a high penalty score is added to denote rarity.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/basemods/Checkers/brandon.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to determine whether something is English or not.\n1. Calculate the Chi Squared score of a sentence\n2. If the score is significantly lower than the average score, it _might_ be English\n 2.1. If the score _might_ be English, then take the text and compare it to the sorted dictionary\n in O(n log n) time.\n It creates a percentage of \"How much of this text is in the dictionary?\"\n The dictionary contains:\n * 20,000 most common US words\n * 10,000 most common UK words (there's no repetition between the two)\n * The top 10,000 passwords\n If the word \"Looks like\" English (chi-squared) and if it contains English words, we can conclude it is\n very likely English. The alternative is doing the dictionary thing but with an entire 479k word dictionary (slower)\n 2.2. If the score is not English, but we haven't tested enough to create an average, then test it against\n the dictionary\n\nThings to optimise:\n* We only run the dictionary if it's 20% smaller than the average for chi squared\n* We consider it \"English\" if 45% of the text matches the dictionary\n* We run the dictionary if there is less than 10 total chisquared test\n\nHow to add a language:\n* Download your desired dictionary. Try to make it the most popular words, for example. 
Place this file into this\n folder with languagename.txt\nAs an example, this comes built in with english.txt\nFind the statistical frequency of each letter in that language.\nFor English, we have:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n}\nIn chisquared.py\nTo add your language, do:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n \"German\": [0.0973]\n}\nIn alphabetical order\nAnd you're.... Done! Make sure the name of the two match up\n\"\"\"\nimport sys\nfrom math import ceil\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nsys.path.append(\"..\")\ntry:\n import mathsHelper as mh\nexcept ModuleNotFoundError:\n import ciphey.mathsHelper as mh\n\n\n@registry.register\nclass Brandon(Checker[str]):\n \"\"\"\n Class designed to confirm whether something is **language** based on how many words of **language** appears\n Call confirmLanguage(text, language)\n * text: the text you want to confirm\n * language: the language you want to confirm\n\n Find out what language it is by using chisquared.py, the highest chisquared score is the language\n languageThreshold = 45\n if a string is 45% **language** words, then it's confirmed to be english\n \"\"\"\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually work this out\n # TODO its 0.2 seconds on average\n return 1e-4 # 100 \u00b5s\n\n wordlist: set\n\n def clean_text(self, text: str) -> set:\n \"\"\"Cleans the text ready to be checked\n\n Strips punctuation, makes it lower case, turns it into a set separated by spaces, removes duplicate words\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n text -> the text as a list, now cleaned\n\n \"\"\"\n # makes the text unique words and readable\n text = text.lower()\n text = self.mh.strip_punctuation(text)\n text = text.split(\" \")\n text = filter(lambda x: len(x) > 2, text)\n text = set(text)\n return text\n\n def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n \"\"\"Given text determine if it passes checker\n\n The checker uses the variable passed to it. I.E. Stopwords list, 1k words, dictionary\n\n Args:\n text -> The text to check\n threshold -> at what point do we return True? The percentage of text that is in var before we return True\n text_length -> the length of the text\n var -> the variable we are checking against. 
Stopwords list, 1k words list, dictionary list.\n Returns:\n boolean -> True for it passes the test, False for it fails the test.\"\"\"\n if text is None:\n logging.debug(\"Checker's text is None, so returning False\")\n return False\n if var is None:\n logging.debug(\"Checker's input var is None, so returning False\")\n return False\n\n percent = ceil(text_length * threshold)\n logging.debug(f\"Checker's chunks are size {percent}\")\n meet_threshold = 0\n location = 0\n end = percent\n\n if text_length <= 0:\n return False\n\n while location <= text_length:\n # chunks the text, so only gets THRESHOLD chunks of text at a time\n text = list(text)\n to_analyse = text[location:end]\n logging.debug(f\"To analyse is {to_analyse}\")\n for word in to_analyse:\n # if word is a stopword, + 1 to the counter\n if word in var:\n logging.debug(\n f\"{word} is in var, which means I am +=1 to the meet_threshold which is {meet_threshold}\"\n )\n meet_threshold += 1\n meet_threshold_percent = meet_threshold / text_length\n if meet_threshold_percent >= threshold:\n logging.debug(\n f\"Returning true since the percentage is {meet_threshold / text_length} and the threshold is {threshold}\"\n )\n # if we meet the threshold, return True\n # otherwise, go over again until we do\n # We do this in the for loop because if we're at 24% and THRESHOLD is 25\n # we don't want to wait THRESHOLD to return true, we want to return True ASAP\n return True\n location = end\n end = end + percent\n logging.debug(\n f\"The language proportion {meet_threshold_percent} is under the threshold {threshold}\"\n )\n return False\n\n def __init__(self, config: Config):\n # Suppresses warning\n super().__init__(config)\n self.mh = mh.mathsHelper()\n\n phases = config.get_resource(self._params()[\"phases\"])\n\n self.thresholds_phase1 = phases[\"1\"]\n self.thresholds_phase2 = phases[\"2\"]\n self.top1000Words = config.get_resource(self._params().get(\"top1000\"))\n self.wordlist = config.get_resource(self._params()[\"wordlist\"])\n self.stopwords = config.get_resource(self._params().get(\"stopwords\"))\n\n self.len_phase1 = len(self.thresholds_phase1)\n self.len_phase2 = len(self.thresholds_phase2)\n\n def check(self, text: str) -> Optional[str]:\n \"\"\"Checks to see if the text is in English\n\n Performs a decryption, but mainly parses the internal data packet and prints useful information.\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n bool -> True if the text is English, False otherwise.\n\n \"\"\"\n logging.debug(f'In Language Checker with \"{text}\"')\n text = self.clean_text(text)\n logging.debug(f'Text split to \"{text}\"')\n if text == \"\":\n logging.debug(\"Returning None from Brandon as the text cleaned is none.\")\n return None\n\n length_text = len(text)\n\n what_to_use = {}\n...\n# Path: ciphey/basemods/Checkers/format.py\nimport json\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass JsonChecker(Checker[str]):\n\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying json checker\")\n\n # https://github.com/Ciphey/Ciphey/issues/389\n if text.isdigit():\n return None\n\n try:\n json.loads(text)\n return \"\"\n except ValueError:\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 1e-7 * 
len(text) # From benchmarks I found online\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/human.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, registry\nfrom rich.console import Console\nfrom rich.markup import escape\n\nconsole = Console()\n\n\n@registry.register\nclass HumanChecker(Checker[str]):\n\n \"\"\"\n Uses the person's decision to determine plaintext\n \"\"\"\n\n def check(self, ctext: str) -> Optional[str]:\n with self._config().pause_spinner_handle():\n response = console.input(\n f\"Possible plaintext: [blue bold]{escape(ctext.__repr__())}[/blue bold] ([green]y[/green]/[red]N[/red]): \"\n )\n if response == \"y\":\n return \"\"\n elif response in (\"n\", \"\"):\n return None\n else:\n return self.check(ctext)\n\n def getExpectedRuntime(self, text: str) -> float:\n return 1 # About a second\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Checkers/quadgrams.py\nimport logging\nimport re\nfrom math import log10\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, Translation, registry\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Quadgrams(Checker[str]):\n\n \"\"\"\n Uses Quadgrams to determine plaintext\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying Quadgrams checker\")\n # Capitalize and remove everything that's not a letter\n ctext = re.sub(\"[^A-Z]\", \"\", ctext.upper())\n quadgrams = self.QUADGRAMS_DICT\n quadgrams_sum = sum(quadgrams.values())\n score = 0\n for key in quadgrams.keys():\n quadgrams[key] = float(quadgrams[key]) / quadgrams_sum\n floor = log10(0.01 / quadgrams_sum)\n for i in range(len(ctext) - 4 + 1):\n # Get all quadgrams from ctext and check if they're in the dict\n # If yes then add the score of those quadgrams to the total score\n if ctext[i : i + 4] in quadgrams:\n score += quadgrams[ctext[i : i + 4]]\n else:\n score += floor\n if len(ctext) > 0:\n score = score / len(ctext)\n logging.info(f\"Quadgrams is {score}\")\n # The default threshold was found to work the best from lots of testing\n if score > self.threshold:\n return \"\"\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The quadgrams dictionary to use\",\n req=False,\n default=\"cipheydists::dist::quadgrams\",\n ),\n \"score\": ParamSpec(\n desc=\"The score threshold to use\",\n req=False,\n default=0.00011,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.QUADGRAMS_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.threshold = float(self._params()[\"score\"])\n\n# Path: ciphey/basemods/Checkers/regex.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Regex(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = list(map(re.compile, self._params()[\"regex\"]))\n 
logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: {res}\")\n if res:\n return f\"Passed with regex {regex}. Want to contribute to Ciphey? Submit your regex here to allow Ciphey to automatically get this next time https://github.com/bee-san/pyWhat/wiki/Adding-your-own-Regex\\n\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"regex\": ParamSpec(\n req=True,\n desc=\"The regex that must be matched (in a substring)\",\n list=True,\n )\n }\n\n\n@registry.register\nclass RegexList(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = []\n for i in self._params()[\"resource\"]:\n self.regexes += [re.compile(regex) for regex in config.get_resource(i)]\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: {res}\")\n if res:\n return f\"passed with regex {regex}\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"resource\": ParamSpec(\n req=True,\n desc=\"A list of regexes that could be matched\",\n list=True,\n )\n }\n\n# Path: ciphey/basemods/Checkers/what.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\nimport logging\nfrom rich.logging import RichHandler\nfrom pywhat import identifier\nfrom rich.console import Console\n\nconsole = Console()\n\n\n@registry.register\nclass What(Checker[str]):\n\n \"\"\"\n Uses PyWhat to determine plaintext with regexes\n https://github.com/bee-san/pyWhat\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying PyWhat checker\")\n returned_regexes = self.id.identify(ctext)\n if returned_regexes[\"Regexes\"]:\n matched_regex = returned_regexes[\"Regexes\"]['text'][0][\"Regex Pattern\"]\n\n ret = f'The plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n human = (\n f'\\nI think the plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n )\n\n if \"Description\" in matched_regex and matched_regex[\"Description\"]:\n s = matched_regex[\"Description\"]\n # lowercases first letter so it doesn't look weird\n s = f\", which is {s[0].lower() + s[1:]}\\n\"\n ret += s\n human += s\n\n # if URL is attached, include that too.\n if \"URL\" in matched_regex and matched_regex[\"URL\"]:\n link = matched_regex[\"URL\"] + ctext.replace(\" \", \"\")\n ret += f\"\\nClick here to view in browser [#CAE4F1][link={link}]{link}[/link][/#CAE4F1]\\n\"\n\n # If greppable mode is on, don't print this\n if self.config.verbosity >= 0:\n # Print with full stop\n console.print(human)\n return ret\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.config = config\n self.id = identifier.Identifier()\n\n# Path: ciphey/basemods/Checkers/ezcheck.py\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nfrom .brandon import 
Brandon\nfrom .format import JsonChecker\nfrom .human import HumanChecker\nfrom .quadgrams import Quadgrams\nfrom .regex import RegexList\nfrom .what import What\n\n\n@registry.register\nclass EzCheck(Checker[str]):\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns, followed by a human check\n \"\"\"\n\n def check(self, text: str) -> Optional[str]:\n for checker in self.checkers:\n res = checker.check(text)\n if (\n res is not None\n and (self.decider is None or self.decider.check(text)) is not None\n ):\n return res\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n return sum(\n i.getExpectedRuntime(text) for i in self.checkers\n ) + self.decider.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n self.checkers: List[Checker[str]] = []\n # Disable human checker for automated systems\n if config.verbosity >= 0:\n self.decider = config(HumanChecker)\n else:\n self.decider = None\n\n # We need to modify the config for each of the objects\n\n # First PyWhat, as it's the fastest\n self.checkers.append(config(What))\n\n # Next, the json checker\n self.checkers.append(config(JsonChecker))\n\n # Second to last, the quadgrams checker\n self.checkers.append(config(Quadgrams))\n\n # Finally, the Brandon checker, as it is the slowest\n self.checkers.append(config(Brandon))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/quorum.py\nfrom typing import Dict, Generic, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, _registry\n\n\nclass Quorum(Generic[T], Checker[T]):\n def check(self, text: T) -> Optional[str]:\n left = self._params().k\n results = []\n for checker in self.checkers:\n results.append(checker.check(text))\n if results[-1] is None:\n continue\n left -= 1\n # Early return check\n if left == 0:\n return str(results)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n if self._params().k is None:\n k = len(self._params()[\"checker\"])\n # These checks need to be separate, to make sure that we do not have zero members\n if self._params().k == 0 or self._params().k > len(self._params()[\"checker\"]):\n raise IndexError(\n \"k must be between 0 and the number of checkers (inclusive)\"\n )\n\n self.checkers = []\n for i in self._params()[\"checker\"]:\n # This enforces type consistency\n self.checkers.append(_registry.get_named(i, Checker[T]))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"checker\": ParamSpec(\n req=True, desc=\"The checkers to be used for analysis\", list=True\n ),\n \"k\": ParamSpec(\n req=False,\n desc=\"The minimum quorum size. Defaults to the number of checkers\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/__init__.py\nfrom . 
import any, brandon, ezcheck, format, human, quadgrams, quorum, regex, what\n\n# Path: ciphey/common.py\n\"\"\"Some useful adapters\"\"\"\nfrom typing import Any\n\n\ndef id_lambda(value: Any):\n \"\"\"\n A function used in dynamic class generation that abstracts away a constant return value (like in getName)\n \"\"\"\n return lambda *args: value\n\n\ndef fix_case(target: str, base: str) -> str:\n \"\"\"Returns the lower-case string target with the case of base\"\"\"\n ret = \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n return \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n\n# Path: ciphey/basemods/Crackers/affine.py\n# Community\n# by https://github.com/Ozzyz\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\nfrom ciphey.mathsHelper import mathsHelper\n\n\n@registry.register\nclass Affine(Cracker[str]):\n \"\"\"\n Each character in the Affine Cipher is encoded with the rule E(x) = (ax + b) mod m\n m is the size of the alphabet, while a and b are the keys in the cipher. a must be coprime to b.\n The Caesar cipher is a specific case of the Affine Cipher, with a=1 and b being the shift of the cipher.\n Decryption is performed by D(x) = a_inv (x - b) mod m where a_inv is the modular multiplicative inverse of a mod m.\n\n In this version of the Affine Cipher, we do not allow alphabets with several instances of the same letter in different cases.\n For instance, the alphabet 'ABCdef123' is allowed, but 'AaBbCc' is not.\n \"\"\"\n\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"affine\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Brute forces all the possible combinations of a and b to attempt to crack the cipher.\n \"\"\"\n logging.debug(\"Attempting affine\")\n candidates = []\n\n # a and b are coprime if gcd(a,b) is 1.\n possible_a = [\n a\n for a in range(1, self.alphabet_length)\n if mathsHelper.gcd(a, self.alphabet_length) == 1\n ]\n logging.info(\n f\"Trying Affine Cracker with {len(possible_a)} a-values and {self.alphabet_length} b-values\"\n )\n\n for a in possible_a:\n a_inv = mathsHelper.mod_inv(a, self.alphabet_length)\n # If there is no inverse, we cannot decrypt the text\n if a_inv is None:\n continue\n for b in range(self.alphabet_length):\n # Pass in lowered text. 
This means that we expect alphabets to not contain both 'a' and 'A'.\n translated = self.decrypt(ctext.lower(), a_inv, b, self.alphabet_length)\n\n candidate_probability = self.plaintext_probability(translated)\n if candidate_probability > self.plaintext_prob_threshold:\n candidates.append(\n CrackResult(\n value=fix_case(translated, ctext), key_info=f\"a={a}, b={b}\"\n )\n )\n logging.info(f\"Affine Cipher returned {len(candidates)} candidates\")\n return candidates\n\n def plaintext_probability(self, translated: str) -> float:\n \"\"\"\n Analyses the translated text and applies the chi squared test to see if it is a probable plaintext candidate\n Returns the probability of the chi-squared test.\n \"\"\"\n analysis = cipheycore.analyse_string(translated)\n return cipheycore.chisq_test(analysis, self.expected)\n\n def decrypt(self, text: str, a_inv: int, b: int, m: int) -> str:\n \"\"\"\n Each letter is decrypted at D(x) = a_inv (x - b) mod m where x is the char\n We treat the char value as its index in the alphabet, so if\n the alphabet is 'abcd....' and the char is 'b', it has the value 1.\n \"\"\"\n return \"\".join([self.decryptChar(char, a_inv, b, m) for char in text])\n\n def decryptChar(self, char: str, a_inv: int, b: int, m: int) -> str:\n\n # We lower the alphabet since both ctext and alphabet need to be in the same case in order\n # to perform the shifts. The translated text will have fixed case after the translation anyways.\n # This is only necessary if the specified alphabet is uppercase.\n alphabet = [x.lower() for x in self.group]\n\n # Preserve characters that are not in alphabet\n if char not in alphabet:\n return char\n char_idx = alphabet.index(char)\n decrypted_char_idx = (a_inv * (char_idx - b)) % m\n return alphabet[decrypted_char_idx]\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.alphabet_length = len(self.group)\n self.cache = config.cache\n self.plaintext_prob_threshold = 0.01\n\n# Path: ciphey/basemods/Crackers/ascii_shift.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 
\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Ascii_shift(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"ascii_shift\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ASCII shift cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ASCII shift returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ASCII shift cipher alphabet\",\n req=False,\n default=\"\"\"\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/baconian.py\nimport re\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\nimport logging\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Baconian(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return 
\"baconian\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to decode both variants of the Baconian cipher.\n \"\"\"\n logging.debug(\"Attempting Baconian cracker\")\n candidates = []\n result = []\n ctext_decoded = \"\"\n ctext_decoded2 = \"\"\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext only contains A and B\n if bool(re.search(r\"[^AB]\", ctext)) is True:\n logging.debug(\"Failed to crack baconian due to non baconian character(s)\")\n return None\n\n # Make sure ctext is divisible by 5\n ctext_len = len(ctext)\n if ctext_len % 5:\n logging.debug(\n f\"Failed to decode Baconian because length must be a multiple of 5, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 5\n ctext = \" \".join(ctext[i : i + 5] for i in range(0, len(ctext), 5))\n ctext_split = ctext.split(\" \")\n baconian_keys = self.BACONIAN_DICT.keys()\n\n # Decode I=J and U=V variant\n for i in ctext_split:\n if i in baconian_keys:\n ctext_decoded += self.BACONIAN_DICT[i]\n\n # Decode variant that assigns each letter a unique code\n for i in ctext_split:\n if \"+\" + i in baconian_keys:\n ctext_decoded2 += self.BACONIAN_DICT[\"+\" + i]\n\n candidates.append(ctext_decoded)\n candidates.append(ctext_decoded2)\n for candidate in candidates:\n if candidate != \"\":\n if candidate == candidates[0]:\n result.append(CrackResult(value=candidate, key_info=\"I=J & U=V\"))\n else:\n result.append(CrackResult(value=candidate))\n logging.debug(f\"Baconian cracker - Returning results: {result}\")\n return result\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"dict\": ParamSpec(\n desc=\"The Baconian alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::baconian\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BACONIAN_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n\n# Path: ciphey/basemods/Crackers/caesar.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n\nimport 
cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Caesar(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"caesar\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying caesar cipher on {ctext}\")\n # Convert it to lower case\n #\n # TODO: handle different alphabets\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"Caesar returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(message, candidate.key, self.group)\n candidates.append(\n CrackResult(value=fix_case(translated, ctext), key_info=candidate.key)\n )\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n self.lower = util.strtobool(self.lower)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/rot47.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 
\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Rot47(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"rot47\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ROT47 cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ROT47 returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ROT47 cipher alphabet\",\n req=False,\n default=\"\"\"!\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/soundex.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import 
(\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\n\n\n@registry.register\nclass Soundex(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"soundex\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to crack Soundex by generating all possible combinations.\n \"\"\"\n logging.debug(\"Attempting Soundex cracker\")\n word_list = []\n sentences = []\n result = []\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext contains only A-Z and 0-9\n if bool(re.search(r\"[^A-Z0-9]\", ctext)) is True:\n logging.debug(\"Failed to crack soundex due to non soundex character(s)\")\n return None\n\n # Make sure ctext is divisible by 4\n ctext_len = len(ctext)\n if ctext_len % 4:\n logging.debug(\n f\"Failed to decode Soundex because length must be a multiple of 4, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 4\n ctext = \" \".join(ctext[i : i + 4] for i in range(0, len(ctext), 4))\n ctext_split = ctext.split(\" \")\n soundex_keys = self.SOUNDEX_DICT.keys()\n\n # Find all words that correspond to each given soundex code\n for code in ctext_split:\n if code in soundex_keys:\n word_list.append(self.SOUNDEX_DICT[code])\n\n logging.info(f\"Possible words for given encoded text: {word_list}\")\n\n # Find all possible sentences\n self.getSentenceCombo(\n word_list,\n sentences,\n self.frequency_dict,\n self.sentence_freq,\n self.word_freq,\n )\n\n sorted_sentences = self.sortlistwithdict(sentences, self.frequency_dict)\n\n for sentence in sorted_sentences:\n result.append(CrackResult(value=sentence))\n\n logging.debug(f\"Soundex cracker - Returning results: {result}\")\n return result\n\n def sortlistwithdict(self, listtosort, hashes):\n \"\"\"\n This function uses the sum of ranks (based on frequency) of each word in each\n sentence and sorts them according to it.\n \"\"\"\n return sorted(listtosort, key=lambda x: hashes[x])\n\n \ndef getSentenceCombo(\n self, A, sentences, frequency_dict, sentence_freq, word_freq, result=\"\", n=0\n ):\n \"\"\"\n This function uses recursion to generate a list of sentences from all possible\n words for a given set of soundex codes.\n \"\"\"\n logging.debug(\"Creating all possible sentences from Soundex\")\n if n == len(A):\n sentences.append(result[1:])\n for word in result[1:].split():\n # Adding the rank of each word to find out the sentence's net frequency\n if word in word_freq:\n sentence_freq += word_freq.index(word)\n # If the word isn't in the frequency list then it's a very uncommon word\n # so we add a large number (5000)\n else:\n sentence_freq += 5000\n frequency_dict[result[1:]] = sentence_freq\n sentence_freq = 0\n return\n\n for word in A[n]:\n out = result + \" \" + word\n self.getSentenceCombo(\n A, sentences, frequency_dict, sentence_freq, word_freq, out, n + 1\n )\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The Soundex dictionary to use\",\n req=False,\n default=\"cipheydists::translate::soundex\",\n ),\n \"freq\": ParamSpec(\n desc=\"The word frequency dictionary to use\",\n req=False,\n default=\"cipheydists::list::English5000Freq\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.SOUNDEX_DICT 
= config.get_resource(self._params()[\"dict\"], Translation)\n self.word_freq = config.get_resource(self._params()[\"freq\"], Translation)\n self.frequency_dict = {}\n self.sentence_freq = 0\n\n# Path: ciphey/basemods/Crackers/vigenere.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Vigenere(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n if self.keysize is not None:\n analysis = self.cache.get_or_update(\n ctext,\n f\"vigenere::{self.keysize}\",\n lambda: cipheycore.analyse_string(\n ctext.lower(), self.keysize, self.group\n ),\n )\n\n val = cipheycore.vigenere_detect(analysis, self.expected)\n\n logging.info(f\"Vigenere has likelihood {val}\")\n\n return CrackInfo(\n success_likelihood=val,\n # TODO: actually calculate runtimes\n success_runtime=1e-3,\n failure_runtime=1e-2,\n )\n\n likely_lens = self.cache.get_or_update(\n ctext,\n \"vigenere::likely_lens\",\n lambda: cipheycore.vigenere_likely_key_lens(\n ctext.lower(), self.expected, self.group, self.detect_p_value\n ),\n )\n\n # Filter out the lens that make no sense\n likely_lens = [i for i in likely_lens if i.len <= self.max_key_length]\n\n for keysize in likely_lens:\n # Store the analysis\n analysis = self.cache.get_or_update(\n ctext, f\"vigenere::{keysize.len}\", lambda: keysize.tab\n )\n if len(likely_lens) == 0:\n return CrackInfo(\n success_likelihood=0,\n # TODO: actually calculate runtimes\n success_runtime=2e-3,\n failure_runtime=2e-2,\n )\n\n logging.info(\n f\"Vigenere has likelihood {likely_lens[0].p_value} with lens {[i.len for i in likely_lens]}\"\n )\n\n return CrackInfo(\n success_likelihood=likely_lens[0].p_value,\n # TODO: actually calculate runtimes\n success_runtime=2e-4,\n failure_runtime=2e-4,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"vigenere\"\n\n def crackOne(\n self, ctext: str, analysis: cipheycore.windowed_analysis_res, real_ctext: str\n ) -> List[CrackResult]:\n possible_keys = cipheycore.vigenere_crack(\n analysis, self.expected, self.group, self.p_value\n )\n if len(possible_keys) > self.clamp:\n possible_keys = possible_keys[: self.clamp]\n logging.debug(\n f\"Vigenere crack got keys: 
{[[i for i in candidate.key] for candidate in possible_keys]}\"\n )\n return [\n CrackResult(\n value=fix_case(\n cipheycore.vigenere_decrypt(ctext, candidate.key, self.group),\n real_ctext,\n ),\n key_info=\"\".join([self.group[i] for i in candidate.key]),\n misc_info=f\"p-value was {candidate.p_value}\",\n )\n for candidate in possible_keys[: min(len(possible_keys), 10)]\n ]\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(\"Trying vigenere cipher\")\n # Convert it to lower case\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n # Analysis must be done here, where we know the case for the cache\n if self.keysize is not None:\n return self.crackOne(\n message,\n self.cache.get_or_update(\n ctext,\n f\"vigenere::{self.keysize}\",\n lambda: cipheycore.analyse_string(\n message, self.keysize, self.group\n ),\n ),\n ctext,\n )\n\n arrs = []\n likely_lens = self.cache.get_or_update(\n ctext,\n \"vigenere::likely_lens\",\n lambda: cipheycore.vigenere_likely_key_lens(\n message, self.expected, self.group\n ),\n )\n possible_lens = [i for i in likely_lens]\n possible_lens.sort(key=lambda i: i.p_value)\n logging.debug(f\"Got possible lengths {[i.len for i in likely_lens]}\")\n # TODO: work out length\n for i in possible_lens:\n arrs.extend(\n self.crackOne(\n message,\n self.cache.get_or_update(\n ctext,\n f\"vigenere::{i.len}\",\n lambda: cipheycore.analyse_string(message, i.len, self.group),\n ),\n ctext,\n )\n )\n\n logging.info(f\"Vigenere returned {len(arrs)} candidates\")\n return arrs\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"keysize\": ParamSpec(\n desc=\"A key size that should be used. 
If not given, will attempt to work it out\",\n req=False,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for windowed frequency analysis\",\n req=False,\n default=0.5,\n ),\n \"detect_p_value\": ParamSpec(\n desc=\"The p-value to use for the detection of Vigenere length\",\n req=False,\n default=0.01,\n ),\n \"clamp\": ParamSpec(\n desc=\"The maximum number of candidates that can be returned per key len\",\n req=False,\n default=10,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n...\n", "needle_token_start": 13535, "needle_token_end": 13841, "code_context_ntokens": 16111, "output": ["```python\ndef getSentenceCombo(\n self, A, sentences, frequency_dict, sentence_freq, word_freq, result=\"\", n=0\n ):\n \"\"\"\n This function uses recursion to generate a list of sentences from all possible\n words for a given set of soundex codes.\n \"\"\"\n logging.debug(\"Creating all possible sentences from Soundex\")\n if n == len(A):\n sentences.append(result[1:])\n for word in result[1:].split():\n # Adding the rank of each word to find out the sentence's net frequency\n if word in word_freq:\n sentence_freq += word_freq.index(word)\n # If the word isn't in the frequency list then it's a very uncommon word\n # so we add a large number (5000)\n else:\n sentence_freq += 5000\n frequency_dict[result[1:]] = sentence_freq\n sentence_freq = 0\n return\n\n for word in A[n]:\n out = result + \" \" + word\n self.getSentenceCombo(\n A, sentences, frequency_dict, sentence_freq, word_freq, out, n + 1\n )\n```"]} +{"repo": "Ciphey/Ciphey", "name": "sortlistwithdict", "language": "python", "path": "ciphey/basemods/Crackers/soundex.py", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to organize a list of sentences based on the cumulative frequency of the words they contain, with the aim of prioritizing sentences that are more likely to be meaningful or relevant based on their word usage.\n2. **Input**: The function takes two parameters: a list of sentences and a dictionary where keys are words and values are their corresponding frequency ranks.\n3. **Output**: It returns a list of sentences sorted from the most likely to the least likely to be relevant, based on the sum of the frequency ranks of the words in each sentence.\n4. 
**Procedure**: The function sorts the list of sentences by calculating the sum of the frequency ranks for each word in a sentence (using the provided dictionary) and then ordering the sentences from the lowest sum (highest priority) to the highest sum (lowest priority).\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "# Path: ciphey/mathsHelper.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to provide helper functions for mathematics\n(oh, not entirely mathematics either. Some NLP stuff and sorting dicts. 
It's just a helper class\n)\n\"\"\"\n\nfrom collections import OrderedDict\nfrom string import punctuation\nfrom typing import Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\n\nclass mathsHelper:\n \"\"\"Class to provide helper functions for mathematics and other small things\"\"\"\n\n def __init__(self):\n # ETAOIN is the most popular letters in order\n self.ETAOIN = \"ETAOINSHRDLCUMWFGYPBVKJXQZ\"\n self.LETTERS = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n\n @staticmethod\n def gcd(a, b) -> int:\n \"\"\"Greatest common divisor.\n\n The Greatest Common Divisor of a and b using Euclid's Algorithm.\n\n Args:\n a -> num 1\n b -> num 2\n\n Returns:\n Returns GCD(a, b)\n\n \"\"\"\n # Return\n while a != 0:\n a, b = b % a, a\n return b\n\n @staticmethod\n def mod_inv(a: int, m: int) -> Optional[int]:\n \"\"\"\n Returns the modular inverse of a mod m, or None if it does not exist.\n\n The modular inverse of a is the number a_inv that satisfies the equation\n a_inv * a mod m === 1 mod m\n\n Note: This is a naive implementation, and runtime may be improved in several ways.\n For instance by checking if m is prime to perform a different calculation,\n or by using the extended euclidean algorithm.\n \"\"\"\n for i in range(1, m):\n if (m * i + 1) % a == 0:\n return (m * i + 1) // a\n return None\n\n @staticmethod\n...\n# Path: ciphey/basemods/Checkers/brandon.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\nClass to determine whether something is English or not.\n1. Calculate the Chi Squared score of a sentence\n2. If the score is significantly lower than the average score, it _might_ be English\n 2.1. If the score _might_ be English, then take the text and compare it to the sorted dictionary\n in O(n log n) time.\n It creates a percentage of \"How much of this text is in the dictionary?\"\n The dictionary contains:\n * 20,000 most common US words\n * 10,000 most common UK words (there's no repetition between the two)\n * The top 10,000 passwords\n If the word \"Looks like\" English (chi-squared) and if it contains English words, we can conclude it is\n very likely English. The alternative is doing the dictionary thing but with an entire 479k word dictionary (slower)\n 2.2. 
If the score is not English, but we haven't tested enough to create an average, then test it against\n the dictionary\n\nThings to optimise:\n* We only run the dictionary if it's 20% smaller than the average for chi squared\n* We consider it \"English\" if 45% of the text matches the dictionary\n* We run the dictionary if there is less than 10 total chisquared test\n\nHow to add a language:\n* Download your desired dictionary. Try to make it the most popular words, for example. Place this file into this\n folder with languagename.txt\nAs an example, this comes built in with english.txt\nFind the statistical frequency of each letter in that language.\nFor English, we have:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n}\nIn chisquared.py\nTo add your language, do:\nself.languages = {\n \"English\":\n [0.0855, 0.0160, 0.0316, 0.0387, 0.1210,0.0218, 0.0209, 0.0496, 0.0733, 0.0022,0.0081, 0.0421, 0.0253, 0.0717,\n 0.0747,0.0207, 0.0010, 0.0633, 0.0673, 0.0894,0.0268, 0.0106, 0.0183, 0.0019, 0.0172,0.0011]\n \"German\": [0.0973]\n}\nIn alphabetical order\nAnd you're.... Done! Make sure the name of the two match up\n\"\"\"\nimport sys\nfrom math import ceil\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nsys.path.append(\"..\")\ntry:\n import mathsHelper as mh\nexcept ModuleNotFoundError:\n import ciphey.mathsHelper as mh\n\n\n@registry.register\nclass Brandon(Checker[str]):\n \"\"\"\n Class designed to confirm whether something is **language** based on how many words of **language** appears\n Call confirmLanguage(text, language)\n * text: the text you want to confirm\n * language: the language you want to confirm\n\n Find out what language it is by using chisquared.py, the highest chisquared score is the language\n languageThreshold = 45\n if a string is 45% **language** words, then it's confirmed to be english\n \"\"\"\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually work this out\n # TODO its 0.2 seconds on average\n return 1e-4 # 100 \u00b5s\n\n wordlist: set\n\n def clean_text(self, text: str) -> set:\n \"\"\"Cleans the text ready to be checked\n\n Strips punctuation, makes it lower case, turns it into a set separated by spaces, removes duplicate words\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n text -> the text as a list, now cleaned\n\n \"\"\"\n # makes the text unique words and readable\n text = text.lower()\n text = self.mh.strip_punctuation(text)\n text = text.split(\" \")\n text = filter(lambda x: len(x) > 2, text)\n text = set(text)\n return text\n\n def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n \"\"\"Given text determine if it passes checker\n\n The checker uses the variable passed to it. I.E. Stopwords list, 1k words, dictionary\n\n Args:\n text -> The text to check\n threshold -> at what point do we return True? The percentage of text that is in var before we return True\n text_length -> the length of the text\n var -> the variable we are checking against. 
Stopwords list, 1k words list, dictionary list.\n Returns:\n boolean -> True for it passes the test, False for it fails the test.\"\"\"\n if text is None:\n logging.debug(\"Checker's text is None, so returning False\")\n return False\n if var is None:\n logging.debug(\"Checker's input var is None, so returning False\")\n return False\n\n percent = ceil(text_length * threshold)\n logging.debug(f\"Checker's chunks are size {percent}\")\n meet_threshold = 0\n location = 0\n end = percent\n\n if text_length <= 0:\n return False\n\n while location <= text_length:\n # chunks the text, so only gets THRESHOLD chunks of text at a time\n text = list(text)\n to_analyse = text[location:end]\n logging.debug(f\"To analyse is {to_analyse}\")\n for word in to_analyse:\n # if word is a stopword, + 1 to the counter\n if word in var:\n logging.debug(\n f\"{word} is in var, which means I am +=1 to the meet_threshold which is {meet_threshold}\"\n )\n meet_threshold += 1\n meet_threshold_percent = meet_threshold / text_length\n if meet_threshold_percent >= threshold:\n logging.debug(\n f\"Returning true since the percentage is {meet_threshold / text_length} and the threshold is {threshold}\"\n )\n # if we meet the threshold, return True\n # otherwise, go over again until we do\n # We do this in the for loop because if we're at 24% and THRESHOLD is 25\n # we don't want to wait THRESHOLD to return true, we want to return True ASAP\n return True\n location = end\n end = end + percent\n logging.debug(\n f\"The language proportion {meet_threshold_percent} is under the threshold {threshold}\"\n )\n return False\n\n def __init__(self, config: Config):\n # Suppresses warning\n super().__init__(config)\n self.mh = mh.mathsHelper()\n\n phases = config.get_resource(self._params()[\"phases\"])\n\n self.thresholds_phase1 = phases[\"1\"]\n self.thresholds_phase2 = phases[\"2\"]\n self.top1000Words = config.get_resource(self._params().get(\"top1000\"))\n self.wordlist = config.get_resource(self._params()[\"wordlist\"])\n self.stopwords = config.get_resource(self._params().get(\"stopwords\"))\n\n self.len_phase1 = len(self.thresholds_phase1)\n self.len_phase2 = len(self.thresholds_phase2)\n\n def check(self, text: str) -> Optional[str]:\n \"\"\"Checks to see if the text is in English\n\n Performs a decryption, but mainly parses the internal data packet and prints useful information.\n\n Args:\n text -> The text we use to perform analysis on\n\n Returns:\n bool -> True if the text is English, False otherwise.\n\n \"\"\"\n logging.debug(f'In Language Checker with \"{text}\"')\n text = self.clean_text(text)\n logging.debug(f'Text split to \"{text}\"')\n if text == \"\":\n logging.debug(\"Returning None from Brandon as the text cleaned is none.\")\n return None\n\n length_text = len(text)\n\n what_to_use = {}\n\n # this code decides what checker / threshold to use\n # if text is over or equal to maximum size, just use the maximum possible checker\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase1.keys()\n )\n logging.debug(self.thresholds_phase1)\n what_to_use = self.thresholds_phase1[str(what_to_use)]\n # def checker(self, text: str, threshold: float, text_length: int, var: set) -> bool:\n if \"check\" in what_to_use:\n # perform check 1k words\n result = self.checker(\n text, what_to_use[\"check\"], length_text, self.top1000Words\n )\n elif \"stop\" in what_to_use:\n # perform stopwords\n result = self.checker(\n text, what_to_use[\"stop\"], length_text, self.stopwords\n )\n elif \"dict\" in 
what_to_use:\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n # If result is None, no point doing it again in phase2\n if not result:\n return None\n else:\n logging.info(f\"It is neither stop or check, but instead {what_to_use}\")\n\n # return False if phase 1 fails\n if not result:\n return None\n else:\n what_to_use = self.calculateWhatChecker(\n length_text, self.thresholds_phase2.keys()\n )\n what_to_use = self.thresholds_phase2[str(what_to_use)]\n result = self.checker(text, what_to_use[\"dict\"], length_text, self.wordlist)\n return \"\" if result else None\n\n def calculateWhatChecker(self, length_text, key):\n \"\"\"Calculates what threshold / checker to use\n\n If the length of the text is over the maximum sentence length, use the last checker / threshold\n Otherwise, traverse the keys backwards until we find a key range that does not fit.\n So we traverse backwards and see if the sentence length is between current - 1 and current\n In this way, we find the absolute lowest checker / percentage threshold.\n We traverse backwards because if the text is longer than the max sentence length, we already know.\n In total, the keys are only 5 items long or so. It is not expensive to move backwards, nor is it expensive to move forwards.\n\n Args:\n length_text -> The length of the text\n key -> What key we want to use. I.E. Phase1 keys, Phase2 keys.\n Returns:\n what_to_use -> the key of the lowest checker.\"\"\"\n\n _keys = list(key)\n _keys = list(map(int, _keys))\n if length_text >= int(_keys[-1]):\n what_to_use = list(key)[_keys.index(_keys[-1])]\n else:\n # this algorithm finds the smallest possible fit for the text\n for counter, i in reversed(list(enumerate(_keys))):\n # [0, 110, 150]\n if i <= length_text:\n what_to_use = i\n return what_to_use\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"top1000\": ParamSpec(\n desc=\"A wordlist of the top 1000 words\",\n req=False,\n default=\"cipheydists::list::english1000\",\n ),\n \"wordlist\": ParamSpec(\n desc=\"A wordlist of all the words\",\n req=False,\n default=\"cipheydists::list::english\",\n ),\n \"stopwords\": ParamSpec(\n desc=\"A wordlist of StopWords\",\n req=False,\n default=\"cipheydists::list::englishStopWords\",\n ),\n \"threshold\": ParamSpec(\n desc=\"The minimum proportion (between 0 and 1) that must be in the dictionary\",\n req=False,\n default=0.45,\n ),\n \"phases\": ParamSpec(\n desc=\"Language-specific phase thresholds\",\n req=False,\n default=\"cipheydists::brandon::english\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/format.py\nimport json\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass JsonChecker(Checker[str]):\n\n \"\"\"\n This object is effectively a prebuilt quorum (with requirement 1) of common patterns\n \"\"\"\n\n def check(self, text: T) -> Optional[str]:\n logging.debug(\"Trying json checker\")\n\n # https://github.com/Ciphey/Ciphey/issues/389\n if text.isdigit():\n return None\n\n try:\n json.loads(text)\n return \"\"\n except ValueError:\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 1e-7 * len(text) # From benchmarks I found online\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/human.py\nfrom 
typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, registry\nfrom rich.console import Console\nfrom rich.markup import escape\n\nconsole = Console()\n\n\n@registry.register\nclass HumanChecker(Checker[str]):\n\n \"\"\"\n Uses the person's decision to determine plaintext\n \"\"\"\n\n def check(self, ctext: str) -> Optional[str]:\n with self._config().pause_spinner_handle():\n response = console.input(\n f\"Possible plaintext: [blue bold]{escape(ctext.__repr__())}[/blue bold] ([green]y[/green]/[red]N[/red]): \"\n )\n if response == \"y\":\n return \"\"\n elif response in (\"n\", \"\"):\n return None\n else:\n return self.check(ctext)\n\n def getExpectedRuntime(self, text: str) -> float:\n return 1 # About a second\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n# Path: ciphey/basemods/Checkers/quadgrams.py\nimport logging\nimport re\nfrom math import log10\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, Translation, registry\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Quadgrams(Checker[str]):\n\n \"\"\"\n Uses Quadgrams to determine plaintext\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying Quadgrams checker\")\n # Capitalize and remove everything that's not a letter\n ctext = re.sub(\"[^A-Z]\", \"\", ctext.upper())\n quadgrams = self.QUADGRAMS_DICT\n quadgrams_sum = sum(quadgrams.values())\n score = 0\n for key in quadgrams.keys():\n quadgrams[key] = float(quadgrams[key]) / quadgrams_sum\n floor = log10(0.01 / quadgrams_sum)\n for i in range(len(ctext) - 4 + 1):\n # Get all quadgrams from ctext and check if they're in the dict\n # If yes then add the score of those quadgrams to the total score\n if ctext[i : i + 4] in quadgrams:\n score += quadgrams[ctext[i : i + 4]]\n else:\n score += floor\n if len(ctext) > 0:\n score = score / len(ctext)\n logging.info(f\"Quadgrams is {score}\")\n # The default threshold was found to work the best from lots of testing\n if score > self.threshold:\n return \"\"\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The quadgrams dictionary to use\",\n req=False,\n default=\"cipheydists::dist::quadgrams\",\n ),\n \"score\": ParamSpec(\n desc=\"The score threshold to use\",\n req=False,\n default=0.00011,\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.QUADGRAMS_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.threshold = float(self._params()[\"score\"])\n\n# Path: ciphey/basemods/Checkers/regex.py\nimport re\nfrom typing import Dict, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\n\n@registry.register\nclass Regex(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = list(map(re.compile, self._params()[\"regex\"]))\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: 
{res}\")\n if res:\n return f\"Passed with regex {regex}. Want to contribute to Ciphey? Submit your regex here to allow Ciphey to automatically get this next time https://github.com/bee-san/pyWhat/wiki/Adding-your-own-Regex\\n\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"regex\": ParamSpec(\n req=True,\n desc=\"The regex that must be matched (in a substring)\",\n list=True,\n )\n }\n\n\n@registry.register\nclass RegexList(Checker[str]):\n def getExpectedRuntime(self, text: T) -> float:\n return 1e-5 # TODO: actually calculate this\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.regexes = []\n for i in self._params()[\"resource\"]:\n self.regexes += [re.compile(regex) for regex in config.get_resource(i)]\n logging.debug(f\"There are {len(self.regexes)} regexes\")\n\n def check(self, text: str) -> Optional[str]:\n for regex in self.regexes:\n logging.debug(f\"Trying regex {regex} on {text}\")\n res = regex.search(text)\n logging.debug(f\"Results: {res}\")\n if res:\n return f\"passed with regex {regex}\"\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"resource\": ParamSpec(\n req=True,\n desc=\"A list of regexes that could be matched\",\n list=True,\n )\n }\n\n# Path: ciphey/basemods/Checkers/what.py\nfrom typing import Dict, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\nimport logging\nfrom rich.logging import RichHandler\nfrom pywhat import identifier\nfrom rich.console import Console\n\nconsole = Console()\n\n\n@registry.register\nclass What(Checker[str]):\n\n \"\"\"\n Uses PyWhat to determine plaintext with regexes\n https://github.com/bee-san/pyWhat\n \"\"\"\n\n def check(self, ctext: T) -> Optional[str]:\n logging.debug(\"Trying PyWhat checker\")\n returned_regexes = self.id.identify(ctext)\n if returned_regexes[\"Regexes\"]:\n matched_regex = returned_regexes[\"Regexes\"]['text'][0][\"Regex Pattern\"]\n\n ret = f'The plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n human = (\n f'\\nI think the plaintext is a [yellow]{matched_regex[\"Name\"]}[/yellow]'\n )\n\n if \"Description\" in matched_regex and matched_regex[\"Description\"]:\n s = matched_regex[\"Description\"]\n # lowercases first letter so it doesn't look weird\n s = f\", which is {s[0].lower() + s[1:]}\\n\"\n ret += s\n human += s\n\n # if URL is attached, include that too.\n if \"URL\" in matched_regex and matched_regex[\"URL\"]:\n link = matched_regex[\"URL\"] + ctext.replace(\" \", \"\")\n ret += f\"\\nClick here to view in browser [#CAE4F1][link={link}]{link}[/link][/#CAE4F1]\\n\"\n\n # If greppable mode is on, don't print this\n if self.config.verbosity >= 0:\n # Print with full stop\n console.print(human)\n return ret\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n # TODO: actually bench this\n return 2e-7 * len(text)\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return None\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.config = config\n self.id = identifier.Identifier()\n\n# Path: ciphey/basemods/Checkers/ezcheck.py\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, registry\n\nfrom .brandon import Brandon\nfrom .format import JsonChecker\nfrom .human import HumanChecker\nfrom .quadgrams import Quadgrams\nfrom .regex import RegexList\nfrom .what import What\n\n\n@registry.register\nclass EzCheck(Checker[str]):\n \"\"\"\n This object is effectively a 
prebuilt quorum (with requirement 1) of common patterns, followed by a human check\n \"\"\"\n\n def check(self, text: str) -> Optional[str]:\n for checker in self.checkers:\n res = checker.check(text)\n if (\n res is not None\n and (self.decider is None or self.decider.check(text)) is not None\n ):\n return res\n return None\n\n def getExpectedRuntime(self, text: T) -> float:\n return sum(\n i.getExpectedRuntime(text) for i in self.checkers\n ) + self.decider.getExpectedRuntime(text)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n self.checkers: List[Checker[str]] = []\n # Disable human checker for automated systems\n if config.verbosity >= 0:\n self.decider = config(HumanChecker)\n else:\n self.decider = None\n\n # We need to modify the config for each of the objects\n\n # First PyWhat, as it's the fastest\n self.checkers.append(config(What))\n\n # Next, the json checker\n self.checkers.append(config(JsonChecker))\n\n # Second to last, the quadgrams checker\n self.checkers.append(config(Quadgrams))\n\n # Finally, the Brandon checker, as it is the slowest\n self.checkers.append(config(Brandon))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n pass\n\n# Path: ciphey/basemods/Checkers/quorum.py\nfrom typing import Dict, Generic, Optional\n\nfrom ciphey.iface import Checker, Config, ParamSpec, T, _registry\n\n\nclass Quorum(Generic[T], Checker[T]):\n def check(self, text: T) -> Optional[str]:\n left = self._params().k\n results = []\n for checker in self.checkers:\n results.append(checker.check(text))\n if results[-1] is None:\n continue\n left -= 1\n # Early return check\n if left == 0:\n return str(results)\n\n def __init__(self, config: Config):\n super().__init__(config)\n\n if self._params().k is None:\n k = len(self._params()[\"checker\"])\n # These checks need to be separate, to make sure that we do not have zero members\n if self._params().k == 0 or self._params().k > len(self._params()[\"checker\"]):\n raise IndexError(\n \"k must be between 0 and the number of checkers (inclusive)\"\n )\n\n self.checkers = []\n for i in self._params()[\"checker\"]:\n # This enforces type consistency\n self.checkers.append(_registry.get_named(i, Checker[T]))\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"checker\": ParamSpec(\n req=True, desc=\"The checkers to be used for analysis\", list=True\n ),\n \"k\": ParamSpec(\n req=False,\n desc=\"The minimum quorum size. Defaults to the number of checkers\",\n ),\n }\n\n# Path: ciphey/basemods/Checkers/__init__.py\nfrom . 
import any, brandon, ezcheck, format, human, quadgrams, quorum, regex, what\n\n# Path: ciphey/common.py\n\"\"\"Some useful adapters\"\"\"\nfrom typing import Any\n\n\ndef id_lambda(value: Any):\n \"\"\"\n A function used in dynamic class generation that abstracts away a constant return value (like in getName)\n \"\"\"\n return lambda *args: value\n\n\ndef fix_case(target: str, base: str) -> str:\n \"\"\"Returns the lower-case string target with the case of base\"\"\"\n ret = \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n return \"\".join(\n [\n target[i].upper() if base[i].isupper() else target[i]\n for i in range(len(target))\n ]\n )\n\n# Path: ciphey/basemods/Crackers/affine.py\n# Community\n# by https://github.com/Ozzyz\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\nfrom ciphey.mathsHelper import mathsHelper\n\n\n@registry.register\nclass Affine(Cracker[str]):\n \"\"\"\n Each character in the Affine Cipher is encoded with the rule E(x) = (ax + b) mod m\n m is the size of the alphabet, while a and b are the keys in the cipher. a must be coprime to b.\n The Caesar cipher is a specific case of the Affine Cipher, with a=1 and b being the shift of the cipher.\n Decryption is performed by D(x) = a_inv (x - b) mod m where a_inv is the modular multiplicative inverse of a mod m.\n\n In this version of the Affine Cipher, we do not allow alphabets with several instances of the same letter in different cases.\n For instance, the alphabet 'ABCdef123' is allowed, but 'AaBbCc' is not.\n \"\"\"\n\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"affine\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Brute forces all the possible combinations of a and b to attempt to crack the cipher.\n \"\"\"\n logging.debug(\"Attempting affine\")\n candidates = []\n\n # a and b are coprime if gcd(a,b) is 1.\n possible_a = [\n a\n for a in range(1, self.alphabet_length)\n if mathsHelper.gcd(a, self.alphabet_length) == 1\n ]\n logging.info(\n f\"Trying Affine Cracker with {len(possible_a)} a-values and {self.alphabet_length} b-values\"\n )\n\n for a in possible_a:\n a_inv = mathsHelper.mod_inv(a, self.alphabet_length)\n # If there is no inverse, we cannot decrypt the text\n if a_inv is None:\n continue\n for b in range(self.alphabet_length):\n # Pass in lowered text. 
This means that we expect alphabets to not contain both 'a' and 'A'.\n translated = self.decrypt(ctext.lower(), a_inv, b, self.alphabet_length)\n\n candidate_probability = self.plaintext_probability(translated)\n if candidate_probability > self.plaintext_prob_threshold:\n candidates.append(\n CrackResult(\n value=fix_case(translated, ctext), key_info=f\"a={a}, b={b}\"\n )\n )\n logging.info(f\"Affine Cipher returned {len(candidates)} candidates\")\n return candidates\n\n def plaintext_probability(self, translated: str) -> float:\n \"\"\"\n Analyses the translated text and applies the chi squared test to see if it is a probable plaintext candidate\n Returns the probability of the chi-squared test.\n \"\"\"\n analysis = cipheycore.analyse_string(translated)\n return cipheycore.chisq_test(analysis, self.expected)\n\n def decrypt(self, text: str, a_inv: int, b: int, m: int) -> str:\n \"\"\"\n Each letter is decrypted at D(x) = a_inv (x - b) mod m where x is the char\n We treat the char value as its index in the alphabet, so if\n the alphabet is 'abcd....' and the char is 'b', it has the value 1.\n \"\"\"\n return \"\".join([self.decryptChar(char, a_inv, b, m) for char in text])\n\n def decryptChar(self, char: str, a_inv: int, b: int, m: int) -> str:\n\n # We lower the alphabet since both ctext and alphabet need to be in the same case in order\n # to perform the shifts. The translated text will have fixed case after the translation anyways.\n # This is only necessary if the specified alphabet is uppercase.\n alphabet = [x.lower() for x in self.group]\n\n # Preserve characters that are not in alphabet\n if char not in alphabet:\n return char\n char_idx = alphabet.index(char)\n decrypted_char_idx = (a_inv * (char_idx - b)) % m\n return alphabet[decrypted_char_idx]\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.alphabet_length = len(self.group)\n self.cache = config.cache\n self.plaintext_prob_threshold = 0.01\n\n# Path: ciphey/basemods/Crackers/ascii_shift.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 
\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Ascii_shift(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"ascii_shift\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ASCII shift cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ASCII shift returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ASCII shift cipher alphabet\",\n req=False,\n default=\"\"\"\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/baconian.py\nimport re\nfrom typing import Dict, List, Optional\n\nfrom ciphey.iface import (\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\nimport logging\nfrom rich.logging import RichHandler\n\n\n@registry.register\nclass Baconian(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return 
\"baconian\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to decode both variants of the Baconian cipher.\n \"\"\"\n logging.debug(\"Attempting Baconian cracker\")\n candidates = []\n result = []\n ctext_decoded = \"\"\n ctext_decoded2 = \"\"\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext only contains A and B\n if bool(re.search(r\"[^AB]\", ctext)) is True:\n logging.debug(\"Failed to crack baconian due to non baconian character(s)\")\n return None\n\n # Make sure ctext is divisible by 5\n ctext_len = len(ctext)\n if ctext_len % 5:\n logging.debug(\n f\"Failed to decode Baconian because length must be a multiple of 5, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 5\n ctext = \" \".join(ctext[i : i + 5] for i in range(0, len(ctext), 5))\n ctext_split = ctext.split(\" \")\n baconian_keys = self.BACONIAN_DICT.keys()\n\n # Decode I=J and U=V variant\n for i in ctext_split:\n if i in baconian_keys:\n ctext_decoded += self.BACONIAN_DICT[i]\n\n # Decode variant that assigns each letter a unique code\n for i in ctext_split:\n if \"+\" + i in baconian_keys:\n ctext_decoded2 += self.BACONIAN_DICT[\"+\" + i]\n\n candidates.append(ctext_decoded)\n candidates.append(ctext_decoded2)\n for candidate in candidates:\n if candidate != \"\":\n if candidate == candidates[0]:\n result.append(CrackResult(value=candidate, key_info=\"I=J & U=V\"))\n else:\n result.append(CrackResult(value=candidate))\n logging.debug(f\"Baconian cracker - Returning results: {result}\")\n return result\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"dict\": ParamSpec(\n desc=\"The Baconian alphabet dictionary to use\",\n req=False,\n default=\"cipheydists::translate::baconian\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.BACONIAN_DICT = config.get_resource(self._params()[\"dict\"], Translation)\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n\n# Path: ciphey/basemods/Crackers/caesar.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n\nimport 
cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.common import fix_case\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Caesar(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"caesar\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying caesar cipher on {ctext}\")\n # Convert it to lower case\n #\n # TODO: handle different alphabets\n if self.lower:\n message = ctext.lower()\n else:\n message = ctext\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"Caesar returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(message, candidate.key, self.group)\n candidates.append(\n CrackResult(value=fix_case(translated, ctext), key_info=candidate.key)\n )\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the caesar cipher alphabet\",\n req=False,\n default=\"abcdefghijklmnopqrstuvwxyz\",\n ),\n \"lower\": ParamSpec(\n desc=\"Whether or not the ciphertext should be converted to lowercase first\",\n req=False,\n default=True,\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.lower: Union[str, bool] = self._params()[\"lower\"]\n if not isinstance(self.lower, bool):\n self.lower = util.strtobool(self.lower)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/rot47.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 
\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport cipheycore\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import Config, Cracker, CrackInfo, CrackResult, ParamSpec, registry\n\n\n@registry.register\nclass Rot47(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n\n return CrackInfo(\n success_likelihood=cipheycore.caesar_detect(analysis, self.expected),\n # TODO: actually calculate runtimes\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"rot47\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n logging.info(f\"Trying ROT47 cipher on {ctext}\")\n\n logging.debug(\"Beginning cipheycore simple analysis\")\n\n # Hand it off to the core\n analysis = self.cache.get_or_update(\n ctext,\n \"cipheycore::simple_analysis\",\n lambda: cipheycore.analyse_string(ctext),\n )\n logging.debug(\"Beginning cipheycore::caesar\")\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n n_candidates = len(possible_keys)\n logging.info(f\"ROT47 returned {n_candidates} candidates\")\n\n if n_candidates == 0:\n logging.debug(\"Filtering for better results\")\n analysis = cipheycore.analyse_string(ctext, self.group)\n possible_keys = cipheycore.caesar_crack(\n analysis, self.expected, self.group, self.p_value\n )\n\n candidates = []\n\n for candidate in possible_keys:\n logging.debug(f\"Candidate {candidate.key} has prob {candidate.p_value}\")\n translated = cipheycore.caesar_decrypt(ctext, candidate.key, self.group)\n candidates.append(CrackResult(value=translated, key_info=candidate.key))\n\n return candidates\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"expected\": ParamSpec(\n desc=\"The expected distribution of the plaintext\",\n req=False,\n config_ref=[\"default_dist\"],\n ),\n \"group\": ParamSpec(\n desc=\"An ordered sequence of chars that make up the ROT47 cipher alphabet\",\n req=False,\n default=\"\"\"!\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"\"\",\n ),\n \"p_value\": ParamSpec(\n desc=\"The p-value to use for standard frequency analysis\",\n req=False,\n default=0.01,\n )\n # TODO: add \"filter\" param\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.group = list(self._params()[\"group\"])\n self.expected = config.get_resource(self._params()[\"expected\"])\n self.cache = config.cache\n self.p_value = float(self._params()[\"p_value\"])\n\n# Path: ciphey/basemods/Crackers/soundex.py\nimport re\nfrom typing import Dict, List, Optional\n\nimport logging\nfrom rich.logging import RichHandler\n\nfrom ciphey.iface import 
(\n Config,\n Cracker,\n CrackInfo,\n CrackResult,\n ParamSpec,\n Translation,\n registry,\n)\n\n\n@registry.register\nclass Soundex(Cracker[str]):\n def getInfo(self, ctext: str) -> CrackInfo:\n return CrackInfo(\n success_likelihood=0.1,\n success_runtime=1e-5,\n failure_runtime=1e-5,\n )\n\n @staticmethod\n def getTarget() -> str:\n return \"soundex\"\n\n def attemptCrack(self, ctext: str) -> List[CrackResult]:\n \"\"\"\n Attempts to crack Soundex by generating all possible combinations.\n \"\"\"\n logging.debug(\"Attempting Soundex cracker\")\n word_list = []\n sentences = []\n result = []\n\n # Convert to uppercase and replace delimiters and whitespace with nothing\n ctext = re.sub(r\"[,;:\\-\\s]\", \"\", ctext.upper())\n\n # Make sure ctext contains only A-Z and 0-9\n if bool(re.search(r\"[^A-Z0-9]\", ctext)) is True:\n logging.debug(\"Failed to crack soundex due to non soundex character(s)\")\n return None\n\n # Make sure ctext is divisible by 4\n ctext_len = len(ctext)\n if ctext_len % 4:\n logging.debug(\n f\"Failed to decode Soundex because length must be a multiple of 4, not '{ctext_len}'\"\n )\n return None\n\n # Split ctext into groups of 4\n ctext = \" \".join(ctext[i : i + 4] for i in range(0, len(ctext), 4))\n ctext_split = ctext.split(\" \")\n soundex_keys = self.SOUNDEX_DICT.keys()\n\n # Find all words that correspond to each given soundex code\n for code in ctext_split:\n if code in soundex_keys:\n word_list.append(self.SOUNDEX_DICT[code])\n\n logging.info(f\"Possible words for given encoded text: {word_list}\")\n\n # Find all possible sentences\n self.getSentenceCombo(\n word_list,\n sentences,\n self.frequency_dict,\n self.sentence_freq,\n self.word_freq,\n )\n\n sorted_sentences = self.sortlistwithdict(sentences, self.frequency_dict)\n\n for sentence in sorted_sentences:\n result.append(CrackResult(value=sentence))\n\n logging.debug(f\"Soundex cracker - Returning results: {result}\")\n return result\n\n \ndef sortlistwithdict(self, listtosort, hashes):\n \"\"\"\n This function uses the sum of ranks (based on frequency) of each word in each\n sentence and sorts them according to it.\n \"\"\"\n return sorted(listtosort, key=lambda x: hashes[x])\n\n def getSentenceCombo(\n self, A, sentences, frequency_dict, sentence_freq, word_freq, result=\"\", n=0\n ):\n \"\"\"\n This function uses recursion to generate a list of sentences from all possible\n words for a given set of soundex codes.\n \"\"\"\n logging.debug(\"Creating all possible sentences from Soundex\")\n if n == len(A):\n sentences.append(result[1:])\n for word in result[1:].split():\n # Adding the rank of each word to find out the sentence's net frequency\n if word in word_freq:\n sentence_freq += word_freq.index(word)\n # If the word isn't in the frequency list then it's a very uncommon word\n # so we add a large number (5000)\n else:\n sentence_freq += 5000\n frequency_dict[result[1:]] = sentence_freq\n sentence_freq = 0\n return\n\n for word in A[n]:\n out = result + \" \" + word\n self.getSentenceCombo(\n A, sentences, frequency_dict, sentence_freq, word_freq, out, n + 1\n )\n\n @staticmethod\n def getParams() -> Optional[Dict[str, ParamSpec]]:\n return {\n \"dict\": ParamSpec(\n desc=\"The Soundex dictionary to use\",\n req=False,\n default=\"cipheydists::translate::soundex\",\n ),\n \"freq\": ParamSpec(\n desc=\"The word frequency dictionary to use\",\n req=False,\n default=\"cipheydists::list::English5000Freq\",\n ),\n }\n\n def __init__(self, config: Config):\n super().__init__(config)\n self.SOUNDEX_DICT 
= config.get_resource(self._params()[\"dict\"], Translation)\n self.word_freq = config.get_resource(self._params()[\"freq\"], Translation)\n self.frequency_dict = {}\n self.sentence_freq = 0\n\n# Path: ciphey/basemods/Crackers/vigenere.py\n\"\"\"\n \u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2557 \u2588\u2588\u2557\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2557 \u2588\u2588\u2557\n\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2557\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u2550\u255d\u255a\u2588\u2588\u2557 \u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2554\u255d\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2557 \u255a\u2588\u2588\u2588\u2588\u2554\u255d\n\u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u2550\u255d \u2588\u2588\u2554\u2550\u2550\u2588\u2588\u2551\u2588\u2588\u2554\u2550\u2550\u255d \u255a\u2588\u2588\u2554\u255d\n\u255a\u2588\u2588\u2588\u2588\u2588\u2588\u2557\u2588\u2588\u2551\u2588\u2588\u2551 \u2588\u2588\u2551 \u2588\u2588\u2551\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2557 \u2588\u2588\u2551\n\u00a9 Brandon Skerritt\nGithub: brandonskerritt\n\"\"\"\nfrom distutils import util\nfrom typing import Dict, List, Optional, Union\n...\n", "needle_token_start": 15257, "needle_token_end": 15328, "code_context_ntokens": 16105, "output": ["```python\ndef sortlistwithdict(self, listtosort, hashes):\n \"\"\"\n This function uses the sum of ranks (based on frequency) of each word in each\n sentence and sorts them according to it.\n \"\"\"\n return sorted(listtosort, key=lambda x: hashes[x])\n```"]} +{"repo": "apache/logging-log4cxx", "name": "ConfiguratorWatchdog", "language": "cpp", "path": "src/main/cpp/optionconverter.cpp", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To monitor changes in a configuration file and reload the configuration if any changes are detected.\n2. **Input**: A configuration handler and the path to a configuration file.\n3. **Output**: There is no direct output; however, the effect is the reconfiguration of the system based on the updated file.\n4. **Procedure**: This component inherits from a file monitoring class, setting up a watch on the specified file. When a change is detected in the file, it uses the provided configuration handler to reapply the settings from the file.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/level.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n...\n// Path: src/main/cpp/optionconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n\nnamespace LOG4CXX_NS\n{\n\nclass ConfiguratorWatchdog : public helpers::FileWatchdog\n{\n\tspi::ConfiguratorPtr m_config;\n\tpublic:\n \nConfiguratorWatchdog(const spi::ConfiguratorPtr& config, const File& filename)\n : helpers::FileWatchdog(filename)\n , m_config(config)\n {\n }\n\n /**\n Call PropertyConfigurator#doConfigure(const String& configFileName,\n const spi::LoggerRepositoryPtr& hierarchy) with the\n filename to reconfigure log4cxx.\n */\n void doOnChange() override\n {\n m_config->doConfigure(file(), LogManager::getLoggerRepository());\n }\n};\n\n}\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\n\nLogString OptionConverter::convertSpecialChars(const LogString& s)\n{\n\tlogchar c;\n\tLogString sbuf;\n\n\tLogString::const_iterator i = s.begin();\n\n\twhile (i != s.end())\n\t{\n\t\tc = *i++;\n\n\t\tif (c == 0x5C /* '\\\\' */)\n\t\t{\n\t\t\tc = *i++;\n\n\t\t\tswitch (c)\n\t\t\t{\n\t\t\t\tcase 0x6E: //'n'\n\t\t\t\t\tc = 0x0A;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x72: //'r'\n\t\t\t\t\tc = 0x0D;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x74: //'t'\n\t\t\t\t\tc = 0x09;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x66: //'f'\n\t\t\t\t\tc = 0x0C;\n\t\t\t\t\tbreak;\n\n\t\t\t\tdefault:\n\t\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\n\t\tsbuf.append(1, c);\n\t}\n\n\treturn sbuf;\n}\n\n\nbool OptionConverter::toBoolean(const LogString& value, bool dEfault)\n{\n\tif (value.length() >= 4)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(value.substr(0, 4),\n\t\t\t\tLOG4CXX_STR(\"TRUE\"), LOG4CXX_STR(\"true\")))\n\t\t{\n\t\t\treturn true;\n\t\t}\n\t}\n\n\tif (dEfault && value.length() >= 5)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(value.substr(0, 5),\n\t\t\t\tLOG4CXX_STR(\"FALSE\"), LOG4CXX_STR(\"false\")))\n\t\t{\n\t\t\treturn false;\n\t\t}\n\t}\n\n\treturn dEfault;\n}\n\nint OptionConverter::toInt(const LogString& value, int dEfault)\n{\n\tLogString trimmed(StringHelper::trim(value));\n\n\tif (trimmed.empty())\n\t{\n\t\treturn dEfault;\n\t}\n\n\tLOG4CXX_ENCODE_CHAR(cvalue, trimmed);\n\n\treturn (int) atol(cvalue.c_str());\n}\n\nlong OptionConverter::toFileSize(const LogString& s, long dEfault)\n{\n\tif (s.empty())\n\t{\n\t\treturn dEfault;\n\t}\n\n\tsize_t index = s.find_first_of(LOG4CXX_STR(\"bB\"));\n\n\tif (index != LogString::npos && index > 0)\n\t{\n\t\tlong multiplier = 1;\n\t\tindex--;\n\n\t\tif (s[index] == 0x6B /* 'k' */ || s[index] == 0x4B /* 'K' */)\n\t\t{\n\t\t\tmultiplier = 1024;\n\t\t}\n\t\telse if (s[index] == 0x6D /* 'm' */ || s[index] == 0x4D 
/* 'M' */)\n\t\t{\n\t\t\tmultiplier = 1024 * 1024;\n\t\t}\n\t\telse if (s[index] == 0x67 /* 'g'*/ || s[index] == 0x47 /* 'G' */)\n\t\t{\n\t\t\tmultiplier = 1024 * 1024 * 1024;\n\t\t}\n\n\t\treturn toInt(s.substr(0, index), 1) * multiplier;\n\t}\n\n\treturn toInt(s, 1);\n}\n\nLogString OptionConverter::findAndSubst(const LogString& key, Properties& props)\n{\n\tLogString value(props.getProperty(key));\n\n\tif (value.empty())\n\t{\n\t\treturn value;\n\t}\n\n\ttry\n\t{\n\t\treturn substVars(value, props);\n\t}\n\tcatch (IllegalArgumentException& e)\n\t{\n\t\tLogLog::error(((LogString) LOG4CXX_STR(\"Bad option value [\"))\n\t\t\t+ value + LOG4CXX_STR(\"].\"), e);\n\t\treturn value;\n\t}\n}\n\nLogString OptionConverter::substVars(const LogString& val, Properties& props)\n{\n\tLogString sbuf;\n\tconst logchar delimStartArray[] = { 0x24, 0x7B, 0 };\n\tconst LogString delimStart(delimStartArray);\n\tconst logchar delimStop = 0x7D; // '}';\n\tconst size_t DELIM_START_LEN = 2;\n\tconst size_t DELIM_STOP_LEN = 1;\n\n\tsize_t i = 0;\n\n\twhile (true)\n\t{\n\t\tsize_t j = val.find(delimStart, i);\n\n\t\tif (j == val.npos)\n\t\t{\n\t\t\t// no more variables\n\t\t\tif (i == 0)\n\t\t\t{\n\t\t\t\t// this is a simple string\n\t\t\t\treturn val;\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t// add the tail string which contails no variables and return the result.\n\t\t\t\tsbuf.append(val.substr(i, val.length() - i));\n\t\t\t\treturn sbuf;\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\tsbuf.append(val.substr(i, j - i));\n\t\t\tsize_t k = val.find(delimStop, j);\n\n\t\t\tif (k == val.npos)\n\t\t\t{\n\t\t\t\tLogString msg(1, (logchar) 0x22 /* '\\\"' */);\n\t\t\t\tmsg.append(val);\n\t\t\t\tmsg.append(LOG4CXX_STR(\"\\\" has no closing brace. Opening brace at position \"));\n\t\t\t\tPool p;\n\t\t\t\tStringHelper::toString(j, p, msg);\n\t\t\t\tmsg.append(1, (logchar) 0x2E /* '.' 
*/);\n\t\t\t\tthrow IllegalArgumentException(msg);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tj += DELIM_START_LEN;\n\t\t\t\tLogString key = val.substr(j, k - j);\n\t\t\t\t// first try in System properties\n\t\t\t\tLogString replacement(getSystemProperty(key, LogString()));\n\n\t\t\t\t// then try props parameter\n\t\t\t\tif (replacement.empty())\n\t\t\t\t{\n\t\t\t\t\treplacement = props.getProperty(key);\n\t\t\t\t}\n\n\t\t\t\tif (!replacement.empty())\n\t\t\t\t{\n\t\t\t\t\t// Do variable substitution on the replacement string\n\t\t\t\t\t// such that we can solve \"Hello ${x2}\" as \"Hello p1\"\n\t\t\t\t\t// the where the properties are\n\t\t\t\t\t// x1=p1\n\t\t\t\t\t// x2=${x1}\n\t\t\t\t\tLogString recursiveReplacement = substVars(replacement, props);\n\t\t\t\t\tsbuf.append(recursiveReplacement);\n\t\t\t\t}\n\n\t\t\t\ti = k + DELIM_STOP_LEN;\n\t\t\t}\n\t\t}\n\t}\n}\n\nLogString OptionConverter::getSystemProperty(const LogString& key, const LogString& def)\n{\n\tif (!key.empty())\n\t{\n\t\tLogString value(System::getProperty(key));\n\n\t\tif (!value.empty())\n\t\t{\n\t\t\treturn value;\n\t\t}\n\t}\n\n\treturn def;\n}\n\nLevelPtr OptionConverter::toLevel(const LogString& value,\n\tconst LevelPtr& defaultValue)\n{\n\tsize_t hashIndex = value.find(LOG4CXX_STR(\"#\"));\n\n\tif (hashIndex == LogString::npos)\n\t{\n\t\tif (value.empty())\n\t\t{\n\t\t\treturn defaultValue;\n\t\t}\n\t\telse\n\t\t{\n\t\t\tLogLog::debug(\n\t\t\t\t((LogString) LOG4CXX_STR(\"OptionConverter::toLevel: no class name specified, level=[\"))\n\t\t\t\t+ value\n\t\t\t\t+ LOG4CXX_STR(\"]\"));\n\t\t\t// no class name specified : use standard Level class\n\t\t\treturn Level::toLevelLS(value, defaultValue);\n\t\t}\n\t}\n\n\tLogString clazz = value.substr(hashIndex + 1);\n\tLogString levelName = value.substr(0, hashIndex);\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"OptionConverter::toLevel: class=[\"))\n\t\t+ clazz + LOG4CXX_STR(\"], level=[\") + levelName + LOG4CXX_STR(\"]\"));\n\n\t// This is degenerate case but you never know.\n\tif (levelName.empty())\n\t{\n\t\treturn Level::toLevelLS(value, defaultValue);\n\t}\n\n\ttry\n\t{\n\t\tLevel::LevelClass& levelClass =\n\t\t\t(Level::LevelClass&)Loader::loadClass(clazz);\n\t\treturn levelClass.toLevel(levelName);\n\t}\n\tcatch (ClassNotFoundException&)\n\t{\n\t\tLogLog::warn(((LogString) LOG4CXX_STR(\"custom level class [\"))\n\t\t\t+ clazz + LOG4CXX_STR(\"] not found.\"));\n\t}\n\tcatch (Exception& oops)\n\t{\n\t\tLogLog::warn(\n\t\t\tLOG4CXX_STR(\"class [\") + clazz + LOG4CXX_STR(\"], level [\") + levelName +\n\t\t\tLOG4CXX_STR(\"] conversion) failed.\"), oops);\n\t}\n\tcatch (...)\n\t{\n\t\tLogLog::warn(\n\t\t\tLOG4CXX_STR(\"class [\") + clazz + LOG4CXX_STR(\"], level [\") + levelName +\n\t\t\tLOG4CXX_STR(\"] conversion) failed.\"));\n\t}\n\n\treturn defaultValue;\n}\n\n\nObjectPtr OptionConverter::instantiateByKey(Properties& props, const LogString& key,\n\tconst Class& superClass, const ObjectPtr& defaultValue)\n{\n\t// Get the value of the property in string form\n\tLogString className(findAndSubst(key, props));\n\n\tif (className.empty())\n\t{\n\t\tLogLog::error(\n\t\t\t((LogString) LOG4CXX_STR(\"Could not find value for key \")) + key);\n\t\treturn defaultValue;\n\t}\n\n\t// Trim className to avoid trailing spaces that cause problems.\n\treturn OptionConverter::instantiateByClassName(\n\t\t\tStringHelper::trim(className), superClass, defaultValue);\n}\n\nObjectPtr OptionConverter::instantiateByClassName(const LogString& className,\n\tconst Class& superClass, const ObjectPtr& 
defaultValue)\n{\n\tif (!className.empty())\n\t{\n\t\ttry\n\t\t{\n\t\t\tconst Class& classObj = Loader::loadClass(className);\n\t\t\tObjectPtr newObject = ObjectPtr(classObj.newInstance());\n\n\t\t\tif (!newObject->instanceof(superClass))\n\t\t\t{\n\t\t\t\treturn defaultValue;\n\t\t\t}\n\n\t\t\treturn newObject;\n\t\t}\n\t\tcatch (Exception& e)\n\t\t{\n\t\t\tLogLog::error(LOG4CXX_STR(\"Could not instantiate class [\") +\n\t\t\t\tclassName + LOG4CXX_STR(\"].\"), e);\n\t\t}\n\t}\n\n\treturn defaultValue;\n}\n\nvoid OptionConverter::selectAndConfigure(const File& configFileName,\n\tconst LogString& _clazz, spi::LoggerRepositoryPtr hierarchy, int delay)\n{\n\tConfiguratorPtr configurator;\n\tLogString clazz = _clazz;\n\n\tLogString filename(configFileName.getPath());\n\n#if LOG4CXX_HAS_DOMCONFIGURATOR\n\tif (clazz.empty()\n\t\t&& filename.length() > 4\n\t\t&& StringHelper::equalsIgnoreCase(\n\t\t\tfilename.substr(filename.length() - 4),\n\t\t\tLOG4CXX_STR(\".XML\"), LOG4CXX_STR(\".xml\")))\n\t{\n\t\tclazz = LOG4CXX_NS::xml::DOMConfigurator::getStaticClass().toString();\n\t}\n#endif\n\n\tif (!clazz.empty())\n\t{\n\t\tLogLog::debug(LOG4CXX_STR(\"Preferred configurator class: \") + clazz);\n\t\tconst Class& clazzObj = Loader::loadClass(clazz);\n\t\tObjectPtr obj = ObjectPtr(clazzObj.newInstance());\n\t\tconfigurator = LOG4CXX_NS::cast(obj);\n\n\t\tif (configurator == 0)\n\t\t{\n\t\t\tLogLog::error(LOG4CXX_STR(\"Could not instantiate configurator [\")\n\t\t\t\t+ clazz + LOG4CXX_STR(\"].\"));\n\t\t\treturn;\n\t\t}\n\t}\n\telse\n\t{\n\t\tconfigurator = std::make_shared();\n\t}\n\n\tif (0 < delay)\n\t{\n\t\tauto dog = new ConfiguratorWatchdog(configurator, configFileName);\n\t\tAPRInitializer::registerCleanup(dog);\n\t\tdog->setDelay(delay);\n\t\tdog->start();\n\t}\n\telse\n\t\tconfigurator->doConfigure(configFileName, hierarchy);\n}\n\n// Path: src/main/cpp/locale.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct Locale::LocalePrivate\n{\n\tLocalePrivate(const LogString& language1)\n\t\t: language(language1)\n\t{\n\t}\n\n\tLocalePrivate(const LogString& language1, const LogString& country1)\n\t\t: language(language1), country(country1)\n\t{\n\t}\n\n\tLocalePrivate(const LogString& language1, const LogString& country1,\n\t\tconst LogString& variant1)\n\t\t: language(language1), country(country1), variant(variant1)\n\t{\n\t}\n\n\tconst LogString language;\n\tconst LogString country;\n\tconst LogString variant;\n};\n\nLocale::Locale(const LogString& language1)\n\t: m_priv(std::make_unique(language1))\n{\n}\n\nLocale::Locale(const LogString& language1, const LogString& country1)\n\t: m_priv(std::make_unique(language1, country1))\n{\n}\n\nLocale::Locale(const LogString& language1, const LogString& country1,\n\tconst LogString& variant1)\n\t: m_priv(std::make_unique(language1, country1, variant1))\n{\n}\n\nLocale::~Locale() {}\n\nconst LogString& Locale::getLanguage() const\n{\n\treturn m_priv->language;\n}\n\nconst LogString& Locale::getCountry() const\n{\n\treturn m_priv->country;\n}\n\nconst LogString& Locale::getVariant() const\n{\n\treturn m_priv->variant;\n}\n\n\n// Path: src/main/cpp/charsetencoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n\n#include \n#include \n#include \n\n#ifdef LOG4CXX_HAS_WCSTOMBS\n\t#include \n#endif\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(CharsetEncoder)\n\nnamespace LOG4CXX_NS\n{\n\nnamespace helpers\n{\n\n#if APR_HAS_XLATE\n/**\n* A character encoder implemented using apr_xlate.\n*/\nclass APRCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tAPRCharsetEncoder(const LogString& topage) : pool()\n\t\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\t\tconst char* frompage = \"WCHAR_T\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tconst char* frompage = \"UTF-8\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\t\tconst char* frompage = \"UTF-16\";\n#endif\n\t\t\tstd::string tpage(Transcoder::encodeCharsetName(topage));\n\t\t\tapr_status_t stat = apr_xlate_open(&convset,\n\t\t\t\t\ttpage.c_str(),\n\t\t\t\t\tfrompage,\n\t\t\t\t\tpool.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IllegalArgumentException(topage);\n\t\t\t}\n\t\t}\n\n\t\tvirtual ~APRCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tapr_status_t stat;\n\t\t\tsize_t outbytes_left = out.remaining();\n\t\t\tsize_t initial_outbytes_left = outbytes_left;\n\t\t\tsize_t position = out.position();\n\n\t\t\tif (iter == in.end())\n\t\t\t{\n\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\tstat = apr_xlate_conv_buffer(convset, NULL, NULL,\n\t\t\t\t\t\tout.data() + position, &outbytes_left);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tLogString::size_type inOffset = (iter - in.begin());\n\t\t\t\tapr_size_t inbytes_left =\n\t\t\t\t\t(in.size() - inOffset) * sizeof(LogString::value_type);\n\t\t\t\tapr_size_t initial_inbytes_left = inbytes_left;\n\t\t\t\t{\n\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\tstat = apr_xlate_conv_buffer(convset,\n\t\t\t\t\t\t\t(const char*) (in.data() + inOffset),\n\t\t\t\t\t\t\t&inbytes_left,\n\t\t\t\t\t\t\tout.data() + position,\n\t\t\t\t\t\t\t&outbytes_left);\n\t\t\t\t}\n\t\t\t\titer += ((initial_inbytes_left - inbytes_left) / sizeof(LogString::value_type));\n\t\t\t}\n\n\t\t\tout.position(out.position() + (initial_outbytes_left - outbytes_left));\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tAPRCharsetEncoder(const APRCharsetEncoder&);\n\t\tAPRCharsetEncoder& operator=(const APRCharsetEncoder&);\n\t\tPool pool;\n\t\tstd::mutex mutex;\n\t\tapr_xlate_t* convset;\n};\n#endif\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_WCSTOMBS\n/**\n * A character encoder implemented using wcstombs.\n*/\nclass WcstombsCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tWcstombsCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\t/**\n\t\t * Converts a wchar_t to the default external multibyte encoding.\n\t\t */\n\t\tlog4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (iter != 
in.end())\n\t\t\t{\n\t\t\t\tsize_t outbytes_left = out.remaining();\n\t\t\t\tsize_t position = out.position();\n\t\t\t\tstd::wstring::size_type inOffset = (iter - in.begin());\n\t\t\t\tenum { BUFSIZE = 256 };\n\t\t\t\twchar_t buf[BUFSIZE];\n\t\t\t\tsize_t chunkSize = BUFSIZE - 1;\n\n\t\t\t\tif (chunkSize * MB_LEN_MAX > outbytes_left)\n\t\t\t\t{\n\t\t\t\t\tchunkSize = outbytes_left / MB_LEN_MAX;\n\t\t\t\t}\n\n\t\t\t\tif (chunkSize > in.length() - inOffset)\n\t\t\t\t{\n\t\t\t\t\tchunkSize = in.length() - inOffset;\n\t\t\t\t}\n\n\t\t\t\tmemset(buf, 0, BUFSIZE * sizeof(wchar_t));\n\t\t\t\tmemcpy(buf,\n\t\t\t\t\tin.data() + inOffset,\n\t\t\t\t\tchunkSize * sizeof(wchar_t));\n\t\t\t\tsize_t converted = wcstombs(out.data() + position, buf, outbytes_left);\n\n\t\t\t\tif (converted == (size_t) -1)\n\t\t\t\t{\n\t\t\t\t\tstat = APR_BADARG;\n\n\t\t\t\t\t//\n\t\t\t\t\t// if unconvertable character was encountered\n\t\t\t\t\t// repeatedly halve source to get fragment that\n\t\t\t\t\t// can be converted\n\t\t\t\t\tfor (chunkSize /= 2;\n\t\t\t\t\t\tchunkSize > 0;\n\t\t\t\t\t\tchunkSize /= 2)\n\t\t\t\t\t{\n\t\t\t\t\t\tbuf[chunkSize] = 0;\n\t\t\t\t\t\tconverted = wcstombs(out.data() + position, buf, outbytes_left);\n\n\t\t\t\t\t\tif (converted != (size_t) -1)\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\titer += chunkSize;\n\t\t\t\t\t\t\tout.position(out.position() + converted);\n\t\t\t\t\t\t\tbreak;\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\titer += chunkSize;\n\t\t\t\t\tout.position(out.position() + converted);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\n\tprivate:\n\t\tWcstombsCharsetEncoder(const WcstombsCharsetEncoder&);\n\t\tWcstombsCharsetEncoder& operator=(const WcstombsCharsetEncoder&);\n};\n#endif\n\n\n/**\n* Encodes a LogString to US-ASCII.\n*/\nclass USASCIICharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUSASCIICharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (iter != in.end())\n\t\t\t{\n\t\t\t\twhile (out.remaining() > 0 && iter != in.end())\n\t\t\t\t{\n\t\t\t\t\tLogString::const_iterator prev(iter);\n\t\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\t\tif (sv <= 0x7F)\n\t\t\t\t\t{\n\t\t\t\t\t\tout.put((char) sv);\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\titer = prev;\n\t\t\t\t\t\tstat = APR_BADARG;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tUSASCIICharsetEncoder(const USASCIICharsetEncoder&);\n\t\tUSASCIICharsetEncoder& operator=(const USASCIICharsetEncoder&);\n};\n\n/**\n* Converts a LogString to ISO-8859-1.\n*/\nclass ISOLatinCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tISOLatinCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (iter != in.end())\n\t\t\t{\n\t\t\t\twhile (out.remaining() > 0 && iter != in.end())\n\t\t\t\t{\n\t\t\t\t\tLogString::const_iterator prev(iter);\n\t\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\t\tif (sv <= 0xFF)\n\t\t\t\t\t{\n\t\t\t\t\t\tout.put((char) sv);\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\titer = prev;\n\t\t\t\t\t\tstat = APR_BADARG;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn 
stat;\n\t\t}\n\n\tprivate:\n\t\tISOLatinCharsetEncoder(const ISOLatinCharsetEncoder&);\n\t\tISOLatinCharsetEncoder& operator=(const ISOLatinCharsetEncoder&);\n};\n\n/**\n* Encodes a LogString to a byte array when the encodings are identical.\n*/\nclass TrivialCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tTrivialCharsetEncoder()\n\t\t{\n\t\t}\n\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tif (iter != in.end())\n\t\t\t{\n\t\t\t\tsize_t requested = in.length() - (iter - in.begin());\n\n\t\t\t\tif (requested > out.remaining() / sizeof(logchar))\n\t\t\t\t{\n\t\t\t\t\trequested = out.remaining() / sizeof(logchar);\n\t\t\t\t}\n\n\t\t\t\tmemcpy(out.current(),\n\t\t\t\t\t(const char*) in.data() + (iter - in.begin()),\n\t\t\t\t\trequested * sizeof(logchar));\n\t\t\t\titer += requested;\n\t\t\t\tout.position(out.position() + requested * sizeof(logchar));\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tTrivialCharsetEncoder(const TrivialCharsetEncoder&);\n\t\tTrivialCharsetEncoder& operator=(const TrivialCharsetEncoder&);\n};\n\n#if LOG4CXX_LOGCHAR_IS_UTF8\ntypedef TrivialCharsetEncoder UTF8CharsetEncoder;\n#else\n/**\n * Converts a LogString to UTF-8.\n */\nclass UTF8CharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUTF8CharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 8)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF8(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF8CharsetEncoder(const UTF8CharsetEncoder&);\n\t\tUTF8CharsetEncoder& operator=(const UTF8CharsetEncoder&);\n};\n#endif\n\n/**\n * Encodes a LogString to UTF16-BE.\n */\nclass UTF16BECharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUTF16BECharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 4)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF16BE(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF16BECharsetEncoder(const UTF16BECharsetEncoder&);\n\t\tUTF16BECharsetEncoder& operator=(const UTF16BECharsetEncoder&);\n};\n\n/**\n * Encodes a LogString to UTF16-LE.\n */\nclass UTF16LECharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUTF16LECharsetEncoder()\n\t\t{\n\t\t}\n\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 4)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF16LE(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\tprivate:\n\t\tUTF16LECharsetEncoder(const UTF16LECharsetEncoder&);\n\t\tUTF16LECharsetEncoder& operator=(const UTF16LECharsetEncoder&);\n};\n\n/**\n * Charset encoder that uses current locale settings.\n */\nclass LocaleCharsetEncoder : public 
CharsetEncoder\n{\n\tpublic:\n\t\tLocaleCharsetEncoder() : state()\n\t\t{\n\t\t}\n\t\tlog4cxx_status_t encode\n\t\t\t( const LogString& in\n\t\t\t, LogString::const_iterator& iter\n\t\t\t, ByteBuffer& out\n\t\t\t) override\n\t\t{\n\t\t\tlog4cxx_status_t result = APR_SUCCESS;\n#if !LOG4CXX_CHARSET_EBCDIC\n\t\t\tchar* current = out.current();\n\t\t\tsize_t remain = out.remaining();\n\t\t\tif (std::mbsinit(&this->state)) // ByteBuffer not partially encoded?\n\t\t\t{\n\t\t\t\t// Copy single byte characters\n\t\t\t\tfor (;\n\t\t\t\t\titer != in.end() && ((unsigned int) *iter) < 0x80 && 0 < remain;\n\t\t\t\t\titer++, remain--, current++)\n\t\t\t\t{\n\t\t\t\t\t*current = *iter;\n\t\t\t\t}\n\t\t\t}\n#endif\n\t\t\t// Encode characters that may require multiple bytes\n\t\t\twhile (iter != in.end() && MB_CUR_MAX <= remain)\n\t\t\t{\n\t\t\t\tauto ch = Transcoder::decode(in, iter);\n\t\t\t\tauto n = std::wcrtomb(current, ch, &this->state);\n\t\t\t\tif (static_cast(-1) == n) // not a valid wide character?\n\t\t\t\t{\n\t\t\t\t\tresult = APR_BADARG;\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tremain -= n;\n\t\t\t\tcurrent += n;\n\t\t\t}\n\t\t\tout.position(current - out.data());\n\t\t\treturn result;\n\t\t}\n\n\tprivate:\n\t\tstd::mbstate_t state;\n};\n\n\n} // namespace helpers\n\n} //namespace log4cxx\n\n\n\nCharsetEncoder::CharsetEncoder()\n{\n}\n\nCharsetEncoder::~CharsetEncoder()\n{\n}\n\nCharsetEncoderPtr CharsetEncoder::getDefaultEncoder()\n{\n\tstatic WideLife encoder(createDefaultEncoder());\n\n\t//\n\t// if invoked after static variable destruction\n\t// (if logging is called in the destructor of a static object)\n\t// then create a new decoder.\n\t//\n\tif (encoder.value() == 0)\n\t{\n\t\treturn CharsetEncoderPtr( createDefaultEncoder() );\n\t}\n\n\treturn encoder;\n}\n\nCharsetEncoder* CharsetEncoder::createDefaultEncoder()\n{\n#if LOG4CXX_CHARSET_UTF8\n\treturn new UTF8CharsetEncoder();\n#elif LOG4CXX_CHARSET_ISO88591\n\treturn new ISOLatinCharsetEncoder();\n#elif LOG4CXX_CHARSET_USASCII\n\treturn new USASCIICharsetEncoder();\n#elif LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_WCSTOMBS\n\treturn new WcstombsCharsetEncoder();\n#else\n\treturn new LocaleCharsetEncoder();\n#endif\n}\n\n\nCharsetEncoderPtr CharsetEncoder::getUTF8Encoder()\n{\n\treturn std::make_shared();\n}\n\n\n\nCharsetEncoderPtr CharsetEncoder::getEncoder(const LogString& charset)\n{\n\tif (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-8\"), LOG4CXX_STR(\"utf-8\"))\n\t\t|| StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP65001\"), LOG4CXX_STR(\"cp65001\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"C\"), LOG4CXX_STR(\"c\")) ||\n\t\tcharset == LOG4CXX_STR(\"646\") ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"US-ASCII\"), LOG4CXX_STR(\"us-ascii\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO646-US\"), LOG4CXX_STR(\"iso646-US\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ANSI_X3.4-1968\"), LOG4CXX_STR(\"ansi_x3.4-1968\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP20127\"), LOG4CXX_STR(\"cp20127\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-8859-1\"), LOG4CXX_STR(\"iso-8859-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-LATIN-1\"), LOG4CXX_STR(\"iso-latin-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP1252\"), LOG4CXX_STR(\"cp1252\")))\n\t{\n\t\treturn 
std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-16BE\"), LOG4CXX_STR(\"utf-16be\"))\n\t\t|| StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-16\"), LOG4CXX_STR(\"utf-16\"))\n\t\t|| StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP1200\"), LOG4CXX_STR(\"cp1200\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-16LE\"), LOG4CXX_STR(\"utf-16le\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"LOCALE\"), LOG4CXX_STR(\"locale\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\n#if APR_HAS_XLATE\n\treturn std::make_shared(charset);\n#else\n\tthrow IllegalArgumentException(charset);\n#endif\n}\n\n\nvoid CharsetEncoder::reset()\n{\n}\n\nvoid CharsetEncoder::flush(ByteBuffer& /* out */ )\n{\n}\n\n\nvoid CharsetEncoder::encode(CharsetEncoderPtr& enc,\n\tconst LogString& src,\n\tLogString::const_iterator& iter,\n\tByteBuffer& dst)\n{\n\tlog4cxx_status_t stat = enc->encode(src, iter, dst);\n\n\tif (stat != APR_SUCCESS && iter != src.end())\n\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR || LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\titer++;\n#elif LOG4CXX_LOGCHAR_IS_UTF8\n\n\t\t// advance past this character and all continuation characters\n\t\twhile ((*(++iter) & 0xC0) == 0x80);\n\n#else\n#error logchar is unrecognized\n#endif\n\t\tdst.put(Transcoder::LOSSCHAR);\n\t}\n}\n\nbool CharsetEncoder::isTriviallyCopyable(const LogString& src, const CharsetEncoderPtr& enc)\n{\n\tbool result;\n#if !LOG4CXX_CHARSET_EBCDIC\n\tif (dynamic_cast(enc.get()))\n\t{\n\t\tresult = src.end() == std::find_if(src.begin(), src.end()\n\t\t\t, [](const logchar& ch) -> bool { return 0x80 <= (unsigned int)ch; });\n\t}\n\telse\n#endif\n\t\tresult = !!dynamic_cast(enc.get());\n\treturn result;\n}\n\n// Path: src/main/cpp/colorstartpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(ColorStartPatternConverter)\n\n#define priv static_cast(m_priv.get())\n\nstatic LogString colorToANSISequence(const LogString& color, bool isForeground, Pool& pool){\n\tint numberToConvert = 0;\n\n\tif(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"BLACK\"), LOG4CXX_STR(\"black\"))){\n\t\tnumberToConvert = 30;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"RED\"), LOG4CXX_STR(\"red\"))){\n\t\tnumberToConvert = 31;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"GREEN\"), LOG4CXX_STR(\"green\"))){\n\t\tnumberToConvert = 32;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"YELLOW\"), LOG4CXX_STR(\"yellow\"))){\n\t\tnumberToConvert = 33;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"BLUE\"), LOG4CXX_STR(\"blue\"))){\n\t\tnumberToConvert = 34;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"MAGENTA\"), LOG4CXX_STR(\"magenta\"))){\n\t\tnumberToConvert = 35;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"CYAN\"), LOG4CXX_STR(\"cyan\"))){\n\t\tnumberToConvert = 36;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"WHITE\"), LOG4CXX_STR(\"white\"))){\n\t\tnumberToConvert = 37;\n\t}\n\n\tif( numberToConvert == 0 ){\n\t\treturn LOG4CXX_STR(\"\");\n\t}\n\tLogString ret;\n\tif( isForeground == false ){\n\t\tnumberToConvert += 10;\n\t}\n\tStringHelper::toString(numberToConvert, pool, ret);\n\treturn ret;\n}\n\nstatic LogString graphicsModeToANSISequence(const LogString& graphicsMode, Pool& pool){\n\tint numberToConvert = 0;\n\n\tif(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"BOLD\"), LOG4CXX_STR(\"bold\"))){\n\t\tnumberToConvert = 1;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"DIM\"), LOG4CXX_STR(\"dim\"))){\n\t\tnumberToConvert = 2;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"ITALIC\"), LOG4CXX_STR(\"italic\"))){\n\t\tnumberToConvert = 3;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"UNDERLINE\"), LOG4CXX_STR(\"underline\"))){\n\t\tnumberToConvert = 4;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"BLINKING\"), LOG4CXX_STR(\"blinking\"))){\n\t\tnumberToConvert = 5;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"INVERSE\"), LOG4CXX_STR(\"inverse\"))){\n\t\tnumberToConvert = 7;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"STRIKETHROUGH\"), LOG4CXX_STR(\"strikethrough\"))){\n\t\tnumberToConvert = 9;\n\t}\n\n\tif( numberToConvert == 0 ){\n\t\treturn LOG4CXX_STR(\"\");\n\t}\n\tLogString ret;\n\tStringHelper::toString(numberToConvert, pool, ret);\n\treturn ret;\n}\n\nstatic LogString convertSingleSequence(const LogString& sequence, Pool& pool){\n\tLogString strInParens;\n\tbool inParens = false;\n\tbool hasParens = false;\n\tsize_t x = 0;\n\n\tfor(x = 0; x < 
sequence.length(); x++){\n\t\tif( sequence[x] == '(' && !inParens ){\n\t\t\tinParens = true;\n\t\t\thasParens = true;\n\t\t\tcontinue;\n\t\t}else if( sequence[x] == '(' && inParens ){\n\t\t\t// Unbalanced parens - parse invalid\n\t\t\treturn LOG4CXX_STR(\"\");\n\t\t}\n\n\t\tif( sequence[x] == ')' && inParens ){\n\t\t\thasParens = true;\n\t\t\tinParens = false;\n\t\t\tbreak;\n\t\t}\n\n\t\tif( inParens ){\n\t\t\tstrInParens.push_back(sequence[x]);\n\t\t}\n\t}\n\n\tif( (x != (sequence.length() - 1) || inParens) && hasParens ){\n\t\t// Unbalanced parens, or more data in the string than we expected - parse invalid\n\t\treturn LOG4CXX_STR(\"\");\n\t}\n\n\tif(StringHelper::startsWith(sequence, LOG4CXX_STR(\"fg(\"))){\n\t\t// Parse foreground\n\t\treturn colorToANSISequence(strInParens, true, pool);\n\t}else if(StringHelper::startsWith(sequence, LOG4CXX_STR(\"bg(\"))){\n\t\treturn colorToANSISequence(strInParens, false, pool);\n\t}else{\n\t\treturn graphicsModeToANSISequence(sequence, pool);\n\t}\n}\n\nstruct ColorStartPatternConverter::ColorPatternConverterPrivate : public PatternConverterPrivate\n{\n\tColorPatternConverterPrivate( const LogString& name, const LogString& style ) :\n\t\tPatternConverterPrivate( name, style ){}\n\n\tLogString m_fatalColor;\n\tLogString m_errorColor;\n\tLogString m_warnColor;\n\tLogString m_infoColor;\n\tLogString m_debugColor;\n\tLogString m_traceColor;\n};\n\nColorStartPatternConverter::ColorStartPatternConverter() :\n\tLoggingEventPatternConverter(std::make_unique(LOG4CXX_STR(\"Color Start\"),\n\t\tLOG4CXX_STR(\"colorStart\")))\n{\n}\n\nPatternConverterPtr ColorStartPatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife instance = std::make_shared();\n\treturn instance;\n}\n\nvoid ColorStartPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& p) const\n{\n\n\tLOG4CXX_NS::LevelPtr lvl = event->getLevel();\n\n\tswitch (lvl->toInt())\n\t{\n\t\tcase LOG4CXX_NS::Level::FATAL_INT:\n\t\t\ttoAppendTo.append(priv->m_fatalColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::ERROR_INT:\n\t\t\ttoAppendTo.append(priv->m_errorColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::WARN_INT:\n\t\t\ttoAppendTo.append(priv->m_warnColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::INFO_INT:\n\t\t\ttoAppendTo.append(priv->m_infoColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::DEBUG_INT:\n\t\t\ttoAppendTo.append(priv->m_debugColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::TRACE_INT:\n\t\t\ttoAppendTo.append(priv->m_traceColor);\n\t\t\tbreak;\n\n\t\tdefault:\n\t\t\tbreak;\n\t}\n}\n\nvoid ColorStartPatternConverter::setFatalColor(const LogString& color){\n\tparseColor(color, &(priv->m_fatalColor));\n}\n\nvoid ColorStartPatternConverter::setErrorColor(const LogString& color){\n\tparseColor(color, &(priv->m_errorColor));\n}\n\nvoid ColorStartPatternConverter::setWarnColor(const LogString& color){\n\tparseColor(color, &(priv->m_warnColor));\n}\n\nvoid ColorStartPatternConverter::setInfoColor(const LogString& color){\n\tparseColor(color, &(priv->m_infoColor));\n}\n\nvoid ColorStartPatternConverter::setDebugColor(const LogString& color){\n\tparseColor(color, &(priv->m_debugColor));\n}\n\nvoid ColorStartPatternConverter::setTraceColor(const LogString& color){\n\tparseColor(color, &(priv->m_traceColor));\n}\n\nvoid ColorStartPatternConverter::parseColor(const LogString& color, LogString* result){\n\tLogString lower = StringHelper::toLowerCase(color);\n\tPool pool;\n\n\t// If the color we are trying to parse 
is blank, clear our result\n\tif(StringHelper::trim(color).empty() ||\n\t\t\tStringHelper::equalsIgnoreCase(color,\n\t\t\t\t\t\t\t\t\t\t LOG4CXX_STR(\"NONE\"),\n\t\t\t\t\t\t\t\t\t\t LOG4CXX_STR(\"none\"))){\n\t\tresult->clear();\n\t\treturn;\n\t}\n\n\tif( StringHelper::startsWith(lower, LOG4CXX_STR(\"\\\\x1b\")) ){\n\t\tif( color[color.size() - 1] != 'm' ){\n\t\t\t// In order for this to be a valid ANSI escape sequence,\n\t\t\t// it must end with an 'm'. If it does not, reject.\n\t\t\treturn;\n\t\t}\n\t\t// We start with an escape sequence, copy the data over after the escape byte\n\t\tresult->clear();\n\t\tresult->append(LOG4CXX_STR(\"\\x1b\"));\n\t\tfor( size_t x = 4; x < color.size(); x++ ){\n\t\t\tresult->push_back(color[x]);\n\t\t}\n\t}else{\n\t\t// We do not start with an escape sequence: try to parse color\n\t\t// Escape sequence information:\n\t\t// https://gist.github.com/fnky/458719343aabd01cfb17a3a4f7296797\n\t\t// https://en.wikipedia.org/wiki/ANSI_escape_code\n\t\tresult->clear();\n\t\tresult->append(LOG4CXX_STR(\"\\x1b[\"));\n\t\tLogString tmp;\n\t\tfor( size_t x = 0; x < color.size(); x++ ){\n\t\t\tif(color[x] == '|' ){\n\t\t\t\tLogString toAppend = convertSingleSequence(tmp, pool);\n\t\t\t\ttmp.clear();\n\t\t\t\tif(!toAppend.empty()){\n\t\t\t\t\tresult->push_back(';');\n\t\t\t\t\tresult->append(toAppend);\n\t\t\t\t}\n\t\t\t}else{\n\t\t\t\ttmp.push_back(color[x]);\n\t\t\t}\n\t\t}\n\t\tLogString toAppend = convertSingleSequence(tmp, pool);\n\t\ttmp.clear();\n\t\tif(!toAppend.empty()){\n\t\t\tresult->push_back(';');\n\t\t\tresult->append(toAppend);\n\t\t}\n\t\tresult->append(LOG4CXX_STR(\"m\"));\n\t}\n}\n\n// Path: src/main/cpp/literalpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\n...\n", "needle_token_start": 767, "needle_token_end": 816, "code_context_ntokens": 15188, "output": ["```cpp\nnamespace LOG4CXX_NS\n{\n\nclass ConfiguratorWatchdog : public helpers::FileWatchdog\n{\n\tspi::ConfiguratorPtr m_config;\n\tpublic:\n \nConfiguratorWatchdog(const spi::ConfiguratorPtr& config, const File& filename)\n : helpers::FileWatchdog(filename)\n , m_config(config)\n {\n }\n\n /**\n Call PropertyConfigurator#doConfigure(const String& configFileName,\n const spi::LoggerRepositoryPtr& hierarchy) with the\n filename to reconfigure log4cxx.\n */\n void doOnChange() override\n {\n m_config->doConfigure(file(), LogManager::getLoggerRepository());\n }\n};\n\n}\n```"]} +{"repo": "apache/logging-log4cxx", "name": "TelnetAppenderPriv", "language": "cpp", "path": "src/main/cpp/telnetappender.cpp", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this component is to manage the logging output via a network connection, allowing multiple clients to view log messages in real-time through a telnet session.\n2. **Input**: It takes two primary inputs: a port number for the telnet server to listen on, and a maximum number of client connections that can simultaneously connect to the telnet server.\n3. **Output**: There is no direct output in the traditional sense, as this component primarily handles network connections and streams logging data to connected clients.\n4. **Procedure**: This component initializes a server listening on a specified port with a capacity for a specified maximum number of connections. It manages these connections, ensuring that each connected client receives real-time logging data. The component uses UTF-8 encoding for transmitting data to maintain consistency and support international character sets.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/fileappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nIMPLEMENT_LOG4CXX_OBJECT(FileAppender)\n\n#define _priv static_cast(m_priv.get())\n\nFileAppender::FileAppender() :\n\tWriterAppender (std::make_unique())\n{\n}\n\nFileAppender::FileAppender\n\t( const LayoutPtr& layout1\n\t, const LogString& fileName1\n\t, bool append1\n\t, bool bufferedIO1\n\t, int bufferSize1\n\t)\n...\n// Path: src/main/cpp/transform.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\n\n\nvoid Transform::appendEscapingTags(\n\tLogString& buf, const LogString& input)\n{\n\t//Check if the string is zero length -- if so, return\n\t//what was sent in.\n\n\tif (input.length() == 0 )\n\t{\n\t\treturn;\n\t}\n\n\tlogchar specials[] = { 0x22 /* \" */, 0x26 /* & */, 0x3C /* < */, 0x3E /* > */, 0x00 };\n\tsize_t start = 0;\n\tsize_t special = input.find_first_of(specials, start);\n\n\twhile (special != LogString::npos)\n\t{\n\t\tif (special > start)\n\t\t{\n\t\t\tbuf.append(input, start, special - start);\n\t\t}\n\n\t\tswitch (input[special])\n\t\t{\n\t\t\tcase 0x22:\n\t\t\t\tbuf.append(LOG4CXX_STR(\""\"));\n\t\t\t\tbreak;\n\n\t\t\tcase 0x26:\n\t\t\t\tbuf.append(LOG4CXX_STR(\"&\"));\n\t\t\t\tbreak;\n\n\t\t\tcase 0x3C:\n\t\t\t\tbuf.append(LOG4CXX_STR(\"<\"));\n\t\t\t\tbreak;\n\n\t\t\tcase 0x3E:\n\t\t\t\tbuf.append(LOG4CXX_STR(\">\"));\n\t\t\t\tbreak;\n\n\t\t\tdefault:\n\t\t\t\tbuf.append(1, input[special]);\n\t\t\t\tbreak;\n\t\t}\n\n\t\tstart = special + 1;\n\n\t\tif (special < input.size())\n\t\t{\n\t\t\tspecial = input.find_first_of(specials, start);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tspecial = LogString::npos;\n\t\t}\n\t}\n\n\tif (start < input.size())\n\t{\n\t\tbuf.append(input, start, input.size() - start);\n\t}\n}\n\nvoid Transform::appendEscapingCDATA(\n\tLogString& buf, const LogString& input)\n{\n\tstatic const WideLife CDATA_END(LOG4CXX_STR(\"]]>\"));\n\tstatic const WideLife CDATA_EMBEDED_END(LOG4CXX_STR(\"]]>]]>\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace 
LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::net;\n\nIMPLEMENT_LOG4CXX_OBJECT(TelnetAppender)\n\nstruct TelnetAppender::TelnetAppenderPriv : public AppenderSkeletonPrivate\n{\n\t\nTelnetAppenderPriv( int port, int maxConnections ) : AppenderSkeletonPrivate(),\n\t\tport(port),\n\t\tconnections(maxConnections),\n\t\tencoding(LOG4CXX_STR(\"UTF-8\")),\n\t\tencoder(CharsetEncoder::getUTF8Encoder()),\n\t\tsh(),\n\t\tactiveConnections(0) {}\n\n\tint port;\n\tConnectionList connections;\n\tLogString encoding;\n\tLOG4CXX_NS::helpers::CharsetEncoderPtr encoder;\n\tstd::unique_ptr serverSocket;\n\tstd::thread sh;\n\tsize_t activeConnections;\n};\n\n#define _priv static_cast(m_priv.get())\n\n/** The default telnet server port */\nconst int TelnetAppender::DEFAULT_PORT = 23;\n\n/** The maximum number of concurrent connections */\nconst int TelnetAppender::MAX_CONNECTIONS = 20;\n\nTelnetAppender::TelnetAppender()\n\t: AppenderSkeleton (std::make_unique(DEFAULT_PORT, MAX_CONNECTIONS))\n{\n}\n\nTelnetAppender::~TelnetAppender()\n{\n\tfinalize();\n}\n\nvoid TelnetAppender::activateOptions(Pool& /* p */)\n{\n\tif (_priv->serverSocket == NULL)\n\t{\n\t\t_priv->serverSocket = ServerSocket::create(_priv->port);\n\t\t_priv->serverSocket->setSoTimeout(1000);\n\t}\n\n\t_priv->sh = ThreadUtility::instance()->createThread( LOG4CXX_STR(\"TelnetAppender\"), &TelnetAppender::acceptConnections, this );\n}\n\nvoid TelnetAppender::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"PORT\"), LOG4CXX_STR(\"port\")))\n\t{\n\t\tsetPort(OptionConverter::toInt(value, DEFAULT_PORT));\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"ENCODING\"), LOG4CXX_STR(\"encoding\")))\n\t{\n\t\tsetEncoding(value);\n\t}\n\telse\n\t{\n\t\tAppenderSkeleton::setOption(option, value);\n\t}\n}\n\nLogString TelnetAppender::getEncoding() const\n{\n\tstd::lock_guard lock(_priv->mutex);\n\treturn _priv->encoding;\n}\n\nvoid TelnetAppender::setEncoding(const LogString& value)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\t_priv->encoder = CharsetEncoder::getEncoder(value);\n\t_priv->encoding = value;\n}\n\n\nvoid TelnetAppender::close()\n{\n\tstd::lock_guard lock(_priv->mutex);\n\n\tif (_priv->closed)\n\t{\n\t\treturn;\n\t}\n\n\t_priv->closed = true;\n\n\tSocketPtr nullSocket;\n\n\tfor (auto& item : _priv->connections)\n\t{\n\t\tif (item)\n\t\t{\n\t\t\titem->close();\n\t\t\titem = nullSocket;\n\t\t}\n\t}\n\n\tif (_priv->serverSocket != NULL)\n\t{\n\t\ttry\n\t\t{\n\t\t\t_priv->serverSocket->close();\n\t\t}\n\t\tcatch (Exception&)\n\t\t{\n\t\t}\n\t}\n\n\tif ( _priv->sh.joinable() )\n\t{\n\t\t_priv->sh.join();\n\t}\n\n\t_priv->activeConnections = 0;\n}\n\n\nvoid TelnetAppender::write(ByteBuffer& buf)\n{\n\tfor (auto& item :_priv->connections)\n\t{\n\t\tif (item)\n\t\t{\n\t\t\ttry\n\t\t\t{\n\t\t\t\tByteBuffer b(buf.current(), buf.remaining());\n\t\t\t\titem->write(b);\n\t\t\t}\n\t\t\tcatch (Exception&)\n\t\t\t{\n\t\t\t\t// The client has closed the connection, remove it from our list:\n\t\t\t\titem.reset();\n\t\t\t\t_priv->activeConnections--;\n\t\t\t}\n\t\t}\n\t}\n}\n\nvoid TelnetAppender::writeStatus(const SocketPtr& socket, const LogString& msg, Pool& p)\n{\n\tsize_t bytesSize = msg.size() * 2;\n\tchar* bytes = p.pstralloc(bytesSize);\n\n\tLogString::const_iterator msgIter(msg.begin());\n\tByteBuffer buf(bytes, bytesSize);\n\n\twhile (msgIter != msg.end())\n\t{\n\t\t_priv->encoder->encode(msg, msgIter, 
buf);\n\t\tbuf.flip();\n\t\tsocket->write(buf);\n\t\tbuf.clear();\n\t}\n}\n\nvoid TelnetAppender::append(const spi::LoggingEventPtr& event, Pool& p)\n{\n\tsize_t count = _priv->activeConnections;\n\n\tif (count > 0)\n\t{\n\t\tLogString msg;\n\t\t_priv->layout->format(msg, event, _priv->pool);\n\t\tmsg.append(LOG4CXX_STR(\"\\r\\n\"));\n\t\tsize_t bytesSize = msg.size() * 2;\n\t\tchar* bytes = p.pstralloc(bytesSize);\n\n\t\tLogString::const_iterator msgIter(msg.begin());\n\t\tByteBuffer buf(bytes, bytesSize);\n\n\t\tstd::lock_guard lock(_priv->mutex);\n\n\t\twhile (msgIter != msg.end())\n\t\t{\n\t\t\tlog4cxx_status_t stat = _priv->encoder->encode(msg, msgIter, buf);\n\t\t\tbuf.flip();\n\t\t\twrite(buf);\n\t\t\tbuf.clear();\n\n\t\t\tif (CharsetEncoder::isError(stat))\n\t\t\t{\n\t\t\t\tLogString unrepresented(1, 0x3F /* '?' */);\n\t\t\t\tLogString::const_iterator unrepresentedIter(unrepresented.begin());\n\t\t\t\tstat = _priv->encoder->encode(unrepresented, unrepresentedIter, buf);\n\t\t\t\tbuf.flip();\n\t\t\t\twrite(buf);\n\t\t\t\tbuf.clear();\n\t\t\t\tmsgIter++;\n\t\t\t}\n\t\t}\n\t}\n}\n\nvoid TelnetAppender::acceptConnections()\n{\n\n\t// main loop; is left when This->closed is != 0 after an accept()\n\twhile (true)\n\t{\n\t\ttry\n\t\t{\n\t\t\tSocketPtr newClient = _priv->serverSocket->accept();\n\t\t\tbool done = _priv->closed;\n\n\t\t\tif (done)\n\t\t\t{\n\t\t\t\tPool p;\n\t\t\t\twriteStatus(newClient, LOG4CXX_STR(\"Log closed.\\r\\n\"), p);\n\t\t\t\tnewClient->close();\n\n\t\t\t\tbreak;\n\t\t\t}\n\n\t\t\tsize_t count = _priv->activeConnections;\n\n\t\t\tif (count >= _priv->connections.size())\n\t\t\t{\n\t\t\t\tPool p;\n\t\t\t\twriteStatus(newClient, LOG4CXX_STR(\"Too many connections.\\r\\n\"), p);\n\t\t\t\tnewClient->close();\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t//\n\t\t\t\t// find unoccupied connection\n\t\t\t\t//\n\t\t\t\tstd::lock_guard lock(_priv->mutex);\n\n\t\t\t\tfor (auto& item : _priv->connections)\n\t\t\t\t{\n\t\t\t\t\tif (!item)\n\t\t\t\t\t{\n\t\t\t\t\t\titem = newClient;\n\t\t\t\t\t\t_priv->activeConnections++;\n\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\n\t\t\t\tPool p;\n\t\t\t\tLogString oss(LOG4CXX_STR(\"TelnetAppender v1.0 (\"));\n\t\t\t\tStringHelper::toString((int) count + 1, p, oss);\n\t\t\t\toss += LOG4CXX_STR(\" active connections)\\r\\n\\r\\n\");\n\t\t\t\twriteStatus(newClient, oss, p);\n\t\t\t}\n\t\t}\n\t\tcatch (InterruptedIOException&)\n\t\t{\n\t\t\tif (_priv->closed)\n\t\t\t{\n\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\t\tcatch (Exception& e)\n\t\t{\n\t\t\tif (!_priv->closed)\n\t\t\t{\n\t\t\t\tLogLog::error(LOG4CXX_STR(\"Encountered error while in SocketHandler loop.\"), e);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\t}\n\n}\n\nint TelnetAppender::getPort() const\n{\n\treturn _priv->port;\n}\n\nvoid TelnetAppender::setPort(int port1)\n{\n\t_priv->port = port1;\n}\n\n// Path: src/main/cpp/outputdebugstringappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#if defined(_WIN32)\n#include \n#include \n#include \n\n#include \"windows.h\"\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::nt;\n\nIMPLEMENT_LOG4CXX_OBJECT(OutputDebugStringAppender)\n\nOutputDebugStringAppender::OutputDebugStringAppender()\n{\n}\n\nvoid OutputDebugStringAppender::append(const spi::LoggingEventPtr& event, Pool& p)\n{\n\tLogString buf;\n\tthis->m_priv->layout->format(buf, event, p);\n#if LOG4CXX_WCHAR_T_API\n\tLOG4CXX_ENCODE_WCHAR(wstr, buf);\n\t::OutputDebugStringW(wstr.c_str());\n#else\n\tLOG4CXX_ENCODE_CHAR(str, buf);\n\t::OutputDebugStringA(str.c_str());\n#endif\n}\n\n#endif\n\n\n// Path: src/main/cpp/charsetdecoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#define NOMINMAX /* tell windows not to define min/max macros */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(CharsetDecoder)\n\n\nnamespace LOG4CXX_NS\n{\nnamespace helpers\n{\n\n#if APR_HAS_XLATE\n/**\n * Converts from an arbitrary encoding to LogString\n * using apr_xlate. 
Requires real iconv implementation,\n* apr-iconv will crash in use.\n */\nclass APRCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\t/**\n\t\t * Creates a new instance.\n\t\t * @param frompage name of source encoding.\n\t\t */\n\t\tAPRCharsetDecoder(const LogString& frompage) : pool()\n\t\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\t\tconst char* topage = \"WCHAR_T\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tconst char* topage = \"UTF-8\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\t\tconst char* topage = \"UTF-16\";\n#endif\n\t\t\tstd::string fpage(Transcoder::encodeCharsetName(frompage));\n\t\t\tapr_status_t stat = apr_xlate_open(&convset,\n\t\t\t\t\ttopage,\n\t\t\t\t\tfpage.c_str(),\n\t\t\t\t\tpool.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IllegalArgumentException(frompage);\n\t\t\t}\n\t\t}\n\n\t\t/**\n\t\t * Destructor.\n\t\t */\n\t\tvirtual ~APRCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tenum { BUFSIZE = 256 };\n\t\t\tlogchar buf[BUFSIZE];\n\t\t\tconst apr_size_t initial_outbytes_left = BUFSIZE * sizeof(logchar);\n\t\t\tapr_status_t stat = APR_SUCCESS;\n\n\t\t\tif (in.remaining() == 0)\n\t\t\t{\n\t\t\t\tsize_t outbytes_left = initial_outbytes_left;\n\t\t\t\t{\n\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\tstat = apr_xlate_conv_buffer((apr_xlate_t*) convset,\n\t\t\t\t\t\t\tNULL, NULL, (char*) buf, &outbytes_left);\n\t\t\t\t}\n\t\t\t\tout.append(buf, (initial_outbytes_left - outbytes_left) / sizeof(logchar));\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\twhile (in.remaining() > 0 && stat == APR_SUCCESS)\n\t\t\t\t{\n\t\t\t\t\tsize_t inbytes_left = in.remaining();\n\t\t\t\t\tsize_t initial_inbytes_left = inbytes_left;\n\t\t\t\t\tsize_t pos = in.position();\n\t\t\t\t\tapr_size_t outbytes_left = initial_outbytes_left;\n\t\t\t\t\t{\n\t\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\t\tstat = apr_xlate_conv_buffer((apr_xlate_t*) convset,\n\t\t\t\t\t\t\t\tin.data() + pos,\n\t\t\t\t\t\t\t\t&inbytes_left,\n\t\t\t\t\t\t\t\t(char*) buf,\n\t\t\t\t\t\t\t\t&outbytes_left);\n\t\t\t\t\t}\n\t\t\t\t\tout.append(buf, (initial_outbytes_left - outbytes_left) / sizeof(logchar));\n\t\t\t\t\tin.position(pos + (initial_inbytes_left - inbytes_left));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tAPRCharsetDecoder(const APRCharsetDecoder&);\n\t\tAPRCharsetDecoder& operator=(const APRCharsetDecoder&);\n\t\tLOG4CXX_NS::helpers::Pool pool;\n\t\tstd::mutex mutex;\n\t\tapr_xlate_t* convset;\n};\n\n#endif\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_MBSRTOWCS\n/**\n* Converts from the default multi-byte string to\n* LogString using mbstowcs.\n*\n*/\nclass MbstowcsCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tMbstowcsCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~MbstowcsCharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tinline log4cxx_status_t append(LogString& out, const wchar_t* buf)\n\t\t{\n\t\t\tout.append(buf);\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\t\t\tenum { BUFSIZE = 256 };\n\t\t\twchar_t wbuf[BUFSIZE];\n\t\t\tchar cbuf[BUFSIZE*4];\n\n\t\t\tmbstate_t mbstate;\n\t\t\tmemset(&mbstate, 0, sizeof(mbstate));\n\n\t\t\twhile (in.remaining() > 0)\n\t\t\t{\n\t\t\t\tconst char* src = in.current();\n\n\t\t\t\tif (*src == 0)\n\t\t\t\t{\n\t\t\t\t\tout.append(1, (logchar) 0);\n\t\t\t\t\tin.position(in.position() + 
1);\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\tauto available = std::min(sizeof (cbuf) - 1, in.remaining());\n\t\t\t\t\tstrncpy(cbuf, src, available);\n\t\t\t\t\tcbuf[available] = 0;\n\t\t\t\t\tsrc = cbuf;\n\t\t\t\t\tsize_t wCharCount = mbsrtowcs(wbuf,\n\t\t\t\t\t\t\t&src,\n\t\t\t\t\t\t\tBUFSIZE - 1,\n\t\t\t\t\t\t\t&mbstate);\n\t\t\t\t\tauto converted = src - cbuf;\n\t\t\t\t\tin.position(in.position() + converted);\n\n\t\t\t\t\tif (wCharCount == (size_t) -1) // Illegal byte sequence?\n\t\t\t\t\t{\n\t\t\t\t\t\tLogString msg(LOG4CXX_STR(\"Illegal byte sequence at \"));\n\t\t\t\t\t\tmsg.append(std::to_wstring(in.position()));\n\t\t\t\t\t\tmsg.append(LOG4CXX_STR(\" of \"));\n\t\t\t\t\t\tmsg.append(std::to_wstring(in.limit()));\n\t\t\t\t\t\tLogLog::warn(msg);\n\t\t\t\t\t\tstat = APR_BADCH;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\twbuf[wCharCount] = 0;\n\t\t\t\t\t\tstat = append(out, wbuf);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\n\tprivate:\n\t\tMbstowcsCharsetDecoder(const MbstowcsCharsetDecoder&);\n\t\tMbstowcsCharsetDecoder& operator=(const MbstowcsCharsetDecoder&);\n};\n#endif\n\n\n/**\n* Decoder used when the external and internal charsets\n* are the same.\n*\n*/\nclass TrivialCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tTrivialCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~TrivialCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tsize_t remaining = in.remaining();\n\n\t\t\tif ( remaining > 0)\n\t\t\t{\n\t\t\t\tconst logchar* src = (const logchar*) (in.data() + in.position());\n\t\t\t\tsize_t count = remaining / sizeof(logchar);\n\t\t\t\tout.append(src, count);\n\t\t\t\tin.position(in.position() + remaining);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\n\n\tprivate:\n\t\tTrivialCharsetDecoder(const TrivialCharsetDecoder&);\n\t\tTrivialCharsetDecoder& operator=(const TrivialCharsetDecoder&);\n};\n\n\n#if LOG4CXX_LOGCHAR_IS_UTF8\ntypedef TrivialCharsetDecoder UTF8CharsetDecoder;\n#else\n/**\n* Converts from UTF-8 to std::wstring\n*\n*/\nclass UTF8CharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tUTF8CharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~UTF8CharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tif (in.remaining() > 0)\n\t\t\t{\n\t\t\t\tstd::string tmp(in.current(), in.remaining());\n\t\t\t\tstd::string::const_iterator iter = tmp.begin();\n\n\t\t\t\twhile (iter != tmp.end())\n\t\t\t\t{\n\t\t\t\t\tunsigned int sv = Transcoder::decode(tmp, iter);\n\n\t\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t\t{\n\t\t\t\t\t\tsize_t offset = iter - tmp.begin();\n\t\t\t\t\t\tin.position(in.position() + offset);\n\t\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\tTranscoder::encode(sv, out);\n\t\t\t\t\t}\n\t\t\t\t}\n\n\t\t\t\tin.position(in.limit());\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF8CharsetDecoder(const UTF8CharsetDecoder&);\n\t\tUTF8CharsetDecoder& operator=(const UTF8CharsetDecoder&);\n};\n#endif\n\n/**\n* Converts from ISO-8859-1 to LogString.\n*\n*/\nclass ISOLatinCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tISOLatinCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~ISOLatinCharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tif (in.remaining() > 0)\n\t\t\t{\n\n\t\t\t\tconst unsigned char* src = (unsigned 
char*) in.current();\n\t\t\t\tconst unsigned char* srcEnd = src + in.remaining();\n\n\t\t\t\twhile (src < srcEnd)\n\t\t\t\t{\n\t\t\t\t\tunsigned int sv = *(src++);\n\t\t\t\t\tTranscoder::encode(sv, out);\n\t\t\t\t}\n\n\t\t\t\tin.position(in.limit());\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\n\n\tprivate:\n\t\tISOLatinCharsetDecoder(const ISOLatinCharsetDecoder&);\n\t\tISOLatinCharsetDecoder& operator=(const ISOLatinCharsetDecoder&);\n};\n\n\n/**\n* Converts from US-ASCII to LogString.\n*\n*/\nclass USASCIICharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tUSASCIICharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~USASCIICharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (in.remaining() > 0)\n\t\t\t{\n\n\t\t\t\tconst unsigned char* src = (unsigned char*) in.current();\n\t\t\t\tconst unsigned char* srcEnd = src + in.remaining();\n\n\t\t\t\twhile (src < srcEnd)\n\t\t\t\t{\n\t\t\t\t\tunsigned char sv = *src;\n\n\t\t\t\t\tif (sv < 0x80)\n\t\t\t\t\t{\n\t\t\t\t\t\tsrc++;\n\t\t\t\t\t\tTranscoder::encode(sv, out);\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\tstat = APR_BADARG;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\n\t\t\t\tin.position(src - (const unsigned char*) in.data());\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\n\tprivate:\n\t\tUSASCIICharsetDecoder(const USASCIICharsetDecoder&);\n\t\tUSASCIICharsetDecoder& operator=(const USASCIICharsetDecoder&);\n};\n\n/**\n * Charset decoder that uses current locale settings.\n */\nclass LocaleCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tLocaleCharsetDecoder() : state()\n\t\t{\n\t\t}\n\t\tlog4cxx_status_t decode(ByteBuffer& in, LogString& out) override\n\t\t{\n\t\t\tlog4cxx_status_t result = APR_SUCCESS;\n\t\t\tconst char* p = in.current();\n\t\t\tsize_t i = in.position();\n\t\t\tsize_t remain = in.limit() - i;\n#if !LOG4CXX_CHARSET_EBCDIC\n\t\t\tif (std::mbsinit(&this->state)) // ByteBuffer not partially decoded?\n\t\t\t{\n\t\t\t\t// Copy single byte characters\n\t\t\t\tfor (; 0 < remain && ((unsigned int) *p) < 0x80; --remain, ++i, p++)\n\t\t\t\t{\n\t\t\t\t\tout.append(1, *p);\n\t\t\t\t}\n\t\t\t}\n#endif\n\t\t\t// Decode characters that may be represented by multiple bytes\n\t\t\twhile (0 < remain)\n\t\t\t{\n\t\t\t\twchar_t ch = 0;\n\t\t\t\tsize_t n = std::mbrtowc(&ch, p, remain, &this->state);\n\t\t\t\tif (0 == n) // NULL encountered?\n\t\t\t\t{\n\t\t\t\t\t++i;\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tif (static_cast(-1) == n) // decoding error?\n\t\t\t\t{\n\t\t\t\t\tresult = APR_BADARG;\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tif (static_cast(-2) == n) // incomplete sequence?\n\t\t\t\t{\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tTranscoder::encode(static_cast(ch), out);\n\t\t\t\tremain -= n;\n\t\t\t\ti += n;\n\t\t\t\tp += n;\n\t\t\t}\n\t\t\tin.position(i);\n\t\t\treturn result;\n\t\t}\n\n\tprivate:\n\t\tstd::mbstate_t state;\n};\n\n\n\n} // namespace helpers\n\n} //namespace log4cxx\n\n\nCharsetDecoder::CharsetDecoder()\n{\n}\n\n\nCharsetDecoder::~CharsetDecoder()\n{\n}\n\nCharsetDecoder* CharsetDecoder::createDefaultDecoder()\n{\n#if LOG4CXX_CHARSET_UTF8\n\treturn new UTF8CharsetDecoder();\n#elif LOG4CXX_CHARSET_ISO88591 || defined(_WIN32_WCE)\n\treturn new ISOLatinCharsetDecoder();\n#elif LOG4CXX_CHARSET_USASCII\n\treturn new USASCIICharsetDecoder();\n#elif LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_MBSRTOWCS\n\treturn new MbstowcsCharsetDecoder();\n#else\n\treturn new 
LocaleCharsetDecoder();\n#endif\n}\n\nCharsetDecoderPtr CharsetDecoder::getDefaultDecoder()\n{\n\tstatic WideLife decoder(createDefaultDecoder());\n\n\t//\n\t// if invoked after static variable destruction\n\t// (if logging is called in the destructor of a static object)\n\t// then create a new decoder.\n\t//\n\tif (decoder.value() == 0)\n\t{\n\t\treturn CharsetDecoderPtr( createDefaultDecoder() );\n\t}\n\n\treturn decoder;\n}\n\nCharsetDecoderPtr CharsetDecoder::getUTF8Decoder()\n{\n\tstatic WideLife decoder(new UTF8CharsetDecoder());\n\n\t//\n\t// if invoked after static variable destruction\n\t// (if logging is called in the destructor of a static object)\n\t// then create a new decoder.\n\t//\n\tif (decoder.value() == 0)\n\t{\n\t\treturn std::make_shared();\n\t}\n\n\treturn decoder;\n}\n\nCharsetDecoderPtr CharsetDecoder::getISOLatinDecoder()\n{\n\treturn std::make_shared();\n}\n\n\nCharsetDecoderPtr CharsetDecoder::getDecoder(const LogString& charset)\n{\n\tif (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-8\"), LOG4CXX_STR(\"utf-8\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF8\"), LOG4CXX_STR(\"utf8\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP65001\"), LOG4CXX_STR(\"cp65001\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"C\"), LOG4CXX_STR(\"c\")) ||\n\t\tcharset == LOG4CXX_STR(\"646\") ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"US-ASCII\"), LOG4CXX_STR(\"us-ascii\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO646-US\"), LOG4CXX_STR(\"iso646-US\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ANSI_X3.4-1968\"), LOG4CXX_STR(\"ansi_x3.4-1968\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP20127\"), LOG4CXX_STR(\"cp20127\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-8859-1\"), LOG4CXX_STR(\"iso-8859-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-LATIN-1\"), LOG4CXX_STR(\"iso-latin-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP1252\"), LOG4CXX_STR(\"cp1252\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"LOCALE\"), LOG4CXX_STR(\"locale\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\n#if APR_HAS_XLATE\n\treturn std::make_shared(charset);\n#else\n\tthrow IllegalArgumentException(charset);\n#endif\n}\n\n\n\n\n\n\n\n// Path: src/main/cpp/throwableinformationpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct ThrowableInformationPatternConverter::ThrowableInformationPatternConverterPrivate :\n\tpublic PatternConverterPrivate\n{\n\tThrowableInformationPatternConverterPrivate( const LogString& name, const LogString& style, bool shortReport ) :\n\t\tPatternConverterPrivate( name, style ),\n\t\tshortReport(shortReport) {}\n\n\t/**\n\t * If \"short\", only first line of throwable report will be formatted.\n\t */\n\tconst bool shortReport;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(ThrowableInformationPatternConverter)\n\nThrowableInformationPatternConverter::ThrowableInformationPatternConverter(bool shortReport1) :\n\tLoggingEventPatternConverter(\n\t\tstd::make_unique(\n\t\t\tLOG4CXX_STR(\"Throwable\"),\n\t\t\tLOG4CXX_STR(\"throwable\"),\n\t\t\tshortReport1))\n{\n}\n\nPatternConverterPtr ThrowableInformationPatternConverter::newInstance(\n\tconst std::vector& options)\n{\n\tif (options.size() > 0 && options[0].compare(LOG4CXX_STR(\"short\")) == 0)\n\t{\n\t\tstatic WideLife shortConverter = std::make_shared(true);\n\t\treturn shortConverter;\n\t}\n\n\tstatic WideLife converter = std::make_shared(false);\n\treturn converter;\n}\n\nvoid ThrowableInformationPatternConverter::format(\n\tconst LoggingEventPtr& /* event */,\n\tLogString& /* toAppendTo */,\n\tPool& /* p */) const\n{\n}\n\n/**\n * This converter obviously handles throwables.\n * @return true.\n */\nbool ThrowableInformationPatternConverter::handlesThrowable() const\n{\n\treturn true;\n}\n\n// Path: src/main/cpp/htmllayout.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nstruct HTMLLayout::HTMLLayoutPrivate\n{\n\tHTMLLayoutPrivate()\n\t\t: locationInfo(false)\n\t\t, title(LOG4CXX_STR(\"Log4cxx Log Messages\"))\n\t\t, dateFormat()\n\t\t, expectedPatternLength(100)\n\t\t{}\n\n\t// Print no location info by default\n\tbool locationInfo; //= false\n\n\tLogString title;\n\n\thelpers::ISO8601DateFormat dateFormat;\n\n\t// Expected length of a formatted event excluding the message text\n\tsize_t expectedPatternLength;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(HTMLLayout)\n\n\nHTMLLayout::HTMLLayout()\n\t: m_priv(std::make_unique())\n{\n\tm_priv->dateFormat.setTimeZone(TimeZone::getGMT());\n\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n}\n\nHTMLLayout::~HTMLLayout() {}\n\n\nvoid HTMLLayout::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\n\tif (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"TITLE\"), LOG4CXX_STR(\"title\")))\n\t{\n\t\tsetTitle(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"LOCATIONINFO\"), LOG4CXX_STR(\"locationinfo\")))\n\t{\n\t\tsetLocationInfo(OptionConverter::toBoolean(value, false));\n\t\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n\t}\n}\n\nvoid HTMLLayout::format(LogString& output,\n\tconst spi::LoggingEventPtr& event,\n\tPool& p) const\n{\n\toutput.reserve(m_priv->expectedPatternLength + event->getMessage().size());\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\n\tm_priv->dateFormat.format(output, event->getTimeStamp(), p);\n\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"getThreadName());\n\toutput.append(threadName);\n\toutput.append(LOG4CXX_STR(\" thread\\\">\"));\n\toutput.append(threadName);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\n\tif (event->getLevel()->equals(Level::getDebug()))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(event->getLevel()->toString());\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t}\n\telse if (event->getLevel()->isGreaterOrEqual(Level::getWarn()))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(event->getLevel()->toString());\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t}\n\telse\n\t{\n\t\toutput.append(event->getLevel()->toString());\n\t}\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"getLoggerName());\n\toutput.append(LOG4CXX_STR(\" logger\\\">\"));\n\tTransform::appendEscapingTags(output, event->getLoggerName());\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\tif (m_priv->locationInfo)\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\tconst LocationInfo& locInfo = event->getLocationInformation();\n\t\tLOG4CXX_DECODE_CHAR(fileName, 
locInfo.getFileName());\n\t\tTransform::appendEscapingTags(output, fileName);\n\t\toutput.append(1, (logchar) 0x3A /* ':' */);\n\t\tint line = event->getLocationInformation().getLineNumber();\n\n\t\tif (line != 0)\n\t\t{\n\t\t\tStringHelper::toString(line, p, output);\n\t\t}\n\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_EOL);\n\t}\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\tTransform::appendEscapingTags(output, event->getRenderedMessage());\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\tLogString ndcVal;\n\n\tif (event->getNDC(ndcVal))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_STR(\"NDC: \"));\n\t\tTransform::appendEscapingTags(output, ndcVal);\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_EOL);\n\t}\n}\n\nvoid HTMLLayout::appendHeader(LogString& output, Pool& p)\n{\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(m_priv->title);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"
\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"Log session start time \"));\n\n\tm_priv->dateFormat.format(output, Date::currentTime(), p);\n\n\toutput.append(LOG4CXX_STR(\"
\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"
\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\tif (m_priv->locationInfo)\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_EOL);\n\t}\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n}\n\nvoid HTMLLayout::appendFooter(LogString& output, Pool& /* pool */ )\n{\n\toutput.append(LOG4CXX_STR(\"
TimeThreadLevelLoggerFile:LineMessage
\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"
\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n}\n\nvoid HTMLLayout::setLocationInfo(bool locationInfoFlag)\n{\n\tm_priv->locationInfo = locationInfoFlag;\n}\n\nbool HTMLLayout::getLocationInfo() const\n{\n\treturn m_priv->locationInfo;\n}\n\nvoid HTMLLayout::setTitle(const LogString& title1)\n{\n\tm_priv->title.assign(title1);\n}\n\nconst LogString& HTMLLayout::getTitle() const\n{\n\treturn m_priv->title;\n}\n\nLogString HTMLLayout::getContentType() const\n{\n\treturn LOG4CXX_STR(\"text/html\");\n}\n\nbool HTMLLayout::ignoresThrowable() const\n{\n\treturn false;\n}\n\n// Path: src/main/cpp/logger.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n...\n", "needle_token_start": 2241, "needle_token_end": 2330, "code_context_ntokens": 15179, "output": ["```cpp\n// Path: src/main/cpp/telnetappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::net;\n\nIMPLEMENT_LOG4CXX_OBJECT(TelnetAppender)\n\nstruct TelnetAppender::TelnetAppenderPriv : public AppenderSkeletonPrivate\n{\n\t\nTelnetAppenderPriv( int port, int maxConnections ) : AppenderSkeletonPrivate(),\n\t\tport(port),\n\t\tconnections(maxConnections),\n\t\tencoding(LOG4CXX_STR(\"UTF-8\")),\n\t\tencoder(CharsetEncoder::getUTF8Encoder()),\n\t\tsh(),\n\t\tactiveConnections(0) {}\n\n\tint port;\n\tConnectionList connections;\n\tLogString encoding;\n\tLOG4CXX_NS::helpers::CharsetEncoderPtr encoder;\n\tstd::unique_ptr serverSocket;\n\tstd::thread sh;\n\tsize_t activeConnections;\n};\n\n#define _priv static_cast(m_priv.get())\n\n/** The default telnet server port */\nconst int TelnetAppender::DEFAULT_PORT = 23;\n\n/** The maximum number of concurrent connections */\nconst int TelnetAppender::MAX_CONNECTIONS = 20;\n\nTelnetAppender::TelnetAppender()\n\t: AppenderSkeleton (std::make_unique(DEFAULT_PORT, MAX_CONNECTIONS))\n{\n}\n\nTelnetAppender::~TelnetAppender()\n{\n\tfinalize();\n}\n\nvoid TelnetAppender::activateOptions(Pool& /* p */)\n{\n\tif (_priv->serverSocket == NULL)\n\t{\n\t\t_priv->serverSocket = 
ServerSocket::create(_priv->port);\n\t\t_priv->serverSocket->setSoTimeout(1000);\n\t}\n\n\t_priv->sh = ThreadUtility::instance()->createThread( LOG4CXX_STR(\"TelnetAppender\"), &TelnetAppender::acceptConnections, this );\n}\n\nvoid TelnetAppender::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"PORT\"), LOG4CXX_STR(\"port\")))\n\t{\n\t\tsetPort(OptionConverter::toInt(value, DEFAULT_PORT));\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"ENCODING\"), LOG4CXX_STR(\"encoding\")))\n\t{\n\t\tsetEncoding(value);\n\t}\n\telse\n\t{\n\t\tAppenderSkeleton::setOption(option, value);\n\t}\n}\n\nLogString TelnetAppender::getEncoding() const\n{\n\tstd::lock_guard lock(_priv->mutex);\n\treturn _priv->encoding;\n}\n\nvoid TelnetAppender::setEncoding(const LogString& value)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\t_priv->encoder = CharsetEncoder::getEncoder(value);\n\t_priv->encoding = value;\n}\n\n\nvoid TelnetAppender::close()\n{\n\tstd::lock_guard lock(_priv->mutex);\n\n\tif (_priv->closed)\n\t{\n\t\treturn;\n\t}\n\n\t_priv->closed = true;\n\n\tSocketPtr nullSocket;\n\n\tfor (auto& item : _priv->connections)\n\t{\n\t\tif (item)\n\t\t{\n\t\t\titem->close();\n\t\t\titem = nullSocket;\n\t\t}\n\t}\n\n\tif (_priv->serverSocket != NULL)\n\t{\n\t\ttry\n\t\t{\n\t\t\t_priv->serverSocket->close();\n\t\t}\n\t\tcatch (Exception&)\n\t\t{\n\t\t}\n\t}\n\n\tif ( _priv->sh.joinable() )\n\t{\n\t\t_priv->sh.join();\n\t}\n\n\t_priv->activeConnections = 0;\n}\n\n\nvoid TelnetAppender"]} +{"repo": "apache/logging-log4cxx", "name": "HierarchyPrivate", "language": "cpp", "path": "src/main/cpp/hierarchy.cpp", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: To manage and maintain the state and configuration of a logging hierarchy, ensuring proper initialization and tracking of various status flags and settings related to logging.\n2. **Input**: No direct inputs as it initializes internal state when an instance of the containing class is created.\n3. **Output**: Does not produce a direct output but initializes and maintains internal state variables that affect the behavior of the logging system.\n4. **Procedure**: Initializes internal state with default values, including flags for configuration status and warnings, and sets a default logging level threshold. This setup is crucial for the correct functioning and configuration management of the logging system.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/rollingfileappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::rolling;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nstruct RollingFileAppender::RollingFileAppenderPriv : public FileAppenderPriv\n{\n\tRollingFileAppenderPriv() :\n\t\tFileAppenderPriv(),\n\t\tfileLength(0) {}\n\n\t/**\n\t * Triggering policy.\n\t */\n\tTriggeringPolicyPtr triggeringPolicy;\n\n\t/**\n\t * Rolling policy.\n\t */\n\tRollingPolicyPtr rollingPolicy;\n\n\t/**\n\t * Length of current active log file.\n\t */\n\tsize_t fileLength;\n\n\t/**\n\t * save the loggingevent\n\t */\n\tspi::LoggingEventPtr _event;\n};\n\n#define _priv static_cast(m_priv.get())\n\nIMPLEMENT_LOG4CXX_OBJECT(RollingFileAppender)\n\n\n/**\n * Construct a new instance.\n */\nRollingFileAppender::RollingFileAppender() :\n\tFileAppender (std::make_unique())\n{\n}\n\nvoid RollingFileAppender::setOption(const LogString& option, const LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"MAXFILESIZE\"), LOG4CXX_STR(\"maxfilesize\"))\n\t\t|| StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"MAXIMUMFILESIZE\"), LOG4CXX_STR(\"maximumfilesize\")))\n\t{\n\t\tsetMaxFileSize(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"MAXBACKUPINDEX\"), LOG4CXX_STR(\"maxbackupindex\"))\n\t\t|| StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"MAXIMUMBACKUPINDEX\"), LOG4CXX_STR(\"maximumbackupindex\")))\n\t{\n\t\tsetMaxBackupIndex(StringHelper::toInt(value));\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"FILEDATEPATTERN\"), LOG4CXX_STR(\"filedatepattern\")))\n\t{\n\t\tsetDatePattern(value);\n\t}\n\telse\n\t{\n\t\tFileAppender::setOption(option, value);\n\t}\n}\n\nint RollingFileAppender::getMaxBackupIndex() const\n{\n\tint result = 1;\n\tif (auto fwrp = LOG4CXX_NS::cast(_priv->rollingPolicy))\n\t\tresult = fwrp->getMaxIndex();\n\treturn result;\n}\n\nvoid RollingFileAppender::setMaxBackupIndex(int maxBackups)\n{\n\tauto fwrp = LOG4CXX_NS::cast(_priv->rollingPolicy);\n\tif (!fwrp)\n\t{\n\t\tfwrp = std::make_shared();\n\t\tfwrp->setFileNamePattern(getFile() + LOG4CXX_STR(\".%i\"));\n\t\t_priv->rollingPolicy = fwrp;\n\t}\n\tfwrp->setMaxIndex(maxBackups);\n}\n\nsize_t RollingFileAppender::getMaximumFileSize() const\n{\n\tsize_t result = 10 * 1024 * 1024;\n\tif (auto sbtp = LOG4CXX_NS::cast(_priv->triggeringPolicy))\n\t\tresult = sbtp->getMaxFileSize();\n\treturn result;\n}\n\nvoid RollingFileAppender::setMaximumFileSize(size_t maxFileSize)\n{\n\tauto sbtp = LOG4CXX_NS::cast(_priv->triggeringPolicy);\n\tif (!sbtp)\n\t{\n\t\tsbtp = std::make_shared();\n\t\t_priv->triggeringPolicy = sbtp;\n\t}\n\tsbtp->setMaxFileSize(maxFileSize);\n}\n\nvoid RollingFileAppender::setMaxFileSize(const LogString& value)\n{\n\tsetMaximumFileSize(OptionConverter::toFileSize(value, long(getMaximumFileSize() + 1)));\n}\n\nLogString RollingFileAppender::makeFileNamePattern(const LogString& 
datePattern)\n{\n\tLogString result(getFile());\n\tbool inLiteral = false;\n\tbool inPattern = false;\n\n\tfor (size_t i = 0; i < datePattern.length(); i++)\n\t{\n\t\tif (datePattern[i] == 0x27 /* '\\'' */)\n\t\t{\n\t\t\tinLiteral = !inLiteral;\n\n\t\t\tif (inLiteral && inPattern)\n\t\t\t{\n\t\t\t\tresult.append(1, (logchar) 0x7D /* '}' */);\n\t\t\t\tinPattern = false;\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\tif (!inLiteral && !inPattern)\n\t\t\t{\n\t\t\t\tconst logchar dbrace[] = { 0x25, 0x64, 0x7B, 0 }; // \"%d{\"\n\t\t\t\tresult.append(dbrace);\n\t\t\t\tinPattern = true;\n\t\t\t}\n\n\t\t\tresult.append(1, datePattern[i]);\n\t\t}\n\t}\n\n\tif (inPattern)\n\t{\n\t\tresult.append(1, (logchar) 0x7D /* '}' */);\n\t}\n\treturn result;\n}\n\nvoid RollingFileAppender::setDatePattern(const LogString& newPattern)\n{\n\tauto tbrp = LOG4CXX_NS::cast(_priv->rollingPolicy);\n\tif (!tbrp)\n\t{\n\t\ttbrp = std::make_shared();\n\t\t_priv->rollingPolicy = tbrp;\n\t}\n\ttbrp->setFileNamePattern(makeFileNamePattern(newPattern));\n}\n\n/**\n * Prepare instance of use.\n */\nvoid RollingFileAppender::activateOptions(Pool& p)\n{\n\tif (!_priv->rollingPolicy)\n\t{\n\t\tLogLog::warn(LOG4CXX_STR(\"No rolling policy configured for the appender named [\")\n\t\t\t+ _priv->name + LOG4CXX_STR(\"].\"));\n\t\tauto fwrp = std::make_shared();\n\t\tfwrp->setFileNamePattern(getFile() + LOG4CXX_STR(\".%i\"));\n\t\t_priv->rollingPolicy = fwrp;\n\t}\n\n\t//\n\t// if no explicit triggering policy and rolling policy is both.\n\t//\n\tif (!_priv->triggeringPolicy)\n\t{\n\t\tTriggeringPolicyPtr trig = LOG4CXX_NS::cast(_priv->rollingPolicy);\n\n\t\tif (trig != NULL)\n\t\t{\n\t\t\t_priv->triggeringPolicy = trig;\n\t\t}\n\t}\n\n\tif (!_priv->triggeringPolicy)\n\t{\n\t\tLogLog::warn(LOG4CXX_STR(\"No triggering policy configured for the appender named [\")\n\t\t\t+ _priv->name + LOG4CXX_STR(\"].\"));\n\t\t_priv->triggeringPolicy = std::make_shared();\n\t}\n\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->triggeringPolicy->activateOptions(p);\n\t\t_priv->rollingPolicy->activateOptions(p);\n\n\t\ttry\n\t\t{\n\t\t\tRolloverDescriptionPtr rollover1 =\n\t\t\t\t_priv->rollingPolicy->initialize(getFile(), getAppend(), p);\n\n\t\t\tif (rollover1 != NULL)\n\t\t\t{\n\t\t\t\tActionPtr syncAction(rollover1->getSynchronous());\n\n\t\t\t\tif (syncAction != NULL)\n\t\t\t\t{\n\t\t\t\t\tsyncAction->execute(p);\n\t\t\t\t}\n\n\t\t\t\t_priv->fileName = rollover1->getActiveFileName();\n\t\t\t\t_priv->fileAppend = rollover1->getAppend();\n\n\t\t\t\t//\n\t\t\t\t// async action not yet implemented\n\t\t\t\t//\n\t\t\t\tActionPtr asyncAction(rollover1->getAsynchronous());\n\n\t\t\t\tif (asyncAction != NULL)\n\t\t\t\t{\n\t\t\t\t\tasyncAction->execute(p);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\tFile activeFile;\n\t\t\tactiveFile.setPath(getFile());\n\n\t\t\tif (getAppend())\n\t\t\t{\n\t\t\t\t_priv->fileLength = activeFile.length(p);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t_priv->fileLength = 0;\n\t\t\t}\n\n\t\t\tFileAppender::activateOptionsInternal(p);\n\t\t}\n\t\tcatch (std::exception&)\n\t\t{\n\t\t\tLogLog::warn(\n\t\t\t\tLogString(LOG4CXX_STR(\"Exception will initializing RollingFileAppender named \"))\n\t\t\t\t+ getName());\n\t\t}\n\t}\n}\n\n/**\n Implements the usual roll over behaviour.\n\n

If MaxBackupIndex is positive, then files\n {File.1, ..., File.MaxBackupIndex -1}\n are renamed to {File.2, ...,\n File.MaxBackupIndex}. Moreover, File is\n renamed File.1 and closed. A new File is\n created to receive further log output.\n\n

If MaxBackupIndex is equal to zero, then the\n File is truncated with no backup files created.\n\n * @return true if rollover performed.\n */\nbool RollingFileAppender::rollover(Pool& p)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\treturn rolloverInternal(p);\n}\n\nbool RollingFileAppender::rolloverInternal(Pool& p)\n{\n\t//\n\t// can't roll without a policy\n\t//\n\tif (_priv->rollingPolicy != NULL)\n\t{\n\n\t\t{\n\t\t\t\ttry\n\t\t\t\t{\n\t\t\t\t\tRolloverDescriptionPtr rollover1(_priv->rollingPolicy->rollover(this->getFile(), this->getAppend(), p));\n\n\t\t\t\t\tif (rollover1 != NULL)\n\t\t\t\t\t{\n\t\t\t\t\t\tif (rollover1->getActiveFileName() == getFile())\n\t\t\t\t\t\t{\n...\n// Path: src/main/cpp/hierarchy.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n#include \"assert.h\"\n\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\n\ntypedef std::map LoggerMap;\ntypedef std::map ProvisionNodeMap;\n\nstruct Hierarchy::HierarchyPrivate\n{\n\t\nHierarchyPrivate()\n\t\t: configured(false)\n\t\t, emittedNoAppenderWarning(false)\n\t\t, emittedNoResourceBundleWarning(false)\n\t\t, thresholdInt(Level::ALL_INT)\n\t{\n\t}\n\n\thelpers::Pool pool;\n\tmutable std::recursive_mutex mutex;\n\tmutable std::mutex configuredMutex;\n\tbool configured;\n\tbool emittedNoAppenderWarning;\n\tbool emittedNoResourceBundleWarning;\n\tint thresholdInt;\n\n\tspi::HierarchyEventListenerList listeners;\n\tLoggerPtr root;\n\tLevelPtr threshold;\n\tLoggerMap loggers;\n\tProvisionNodeMap provisionNodes;\n\n\tstd::vector allAppenders;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(Hierarchy)\n\nHierarchy::Hierarchy() :\n\tm_priv(std::make_unique())\n{\n}\n\nHierarchy::~Hierarchy()\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tif (auto& pLogger = item.second)\n\t\t{\n\t\t\tpLogger->removeHierarchy();\n\t\t\tpLogger->removeAllAppenders();\n\t\t}\n\t}\n\tif (m_priv->root)\n\t{\n\t\tm_priv->root->removeHierarchy();\n\t\tm_priv->root->removeAllAppenders();\n\t}\n}\n\nvoid Hierarchy::addHierarchyEventListener(const spi::HierarchyEventListenerPtr& listener)\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tif (std::find(m_priv->listeners.begin(), m_priv->listeners.end(), listener) != m_priv->listeners.end())\n\t{\n\t\tLogLog::warn(LOG4CXX_STR(\"Ignoring attempt to add an existent listener.\"));\n\t}\n\telse\n\t{\n\t\tm_priv->listeners.push_back(listener);\n\t}\n}\n\nvoid Hierarchy::removeHierarchyEventListener(const spi::HierarchyEventListenerPtr& listener)\n{\n\tstd::lock_guard 
lock(m_priv->mutex);\n\n auto found = std::find(m_priv->listeners.begin(), m_priv->listeners.end(), listener);\n if(found != m_priv->listeners.end()){\n m_priv->listeners.erase(found);\n }\n}\n\nvoid Hierarchy::clear()\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\tm_priv->loggers.clear();\n}\n\nvoid Hierarchy::emitNoAppenderWarning(const Logger* logger)\n{\n\tbool emitWarning = false;\n\t{\n\t\tstd::lock_guard lock(m_priv->mutex);\n\t\temitWarning = !m_priv->emittedNoAppenderWarning;\n\t\tm_priv->emittedNoAppenderWarning = true;\n\t}\n\n\t// No appender in hierarchy, warn user only once.\n\tif (emitWarning)\n\t{\n\t\tLogLog::warn(((LogString) LOG4CXX_STR(\"No appender could be found for logger (\"))\n\t\t\t+ logger->getName() + LOG4CXX_STR(\").\"));\n\t\tLogLog::warn(LOG4CXX_STR(\"Please initialize the log4cxx system properly.\"));\n\t}\n}\n\n\nLoggerPtr Hierarchy::exists(const LogString& name)\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tLoggerPtr logger;\n\tLoggerMap::iterator it = m_priv->loggers.find(name);\n\n\tif (it != m_priv->loggers.end())\n\t{\n\t\tlogger = it->second;\n\t}\n\n\n\treturn logger;\n}\n\nvoid Hierarchy::setThreshold(const LevelPtr& l)\n{\n\tif (l != 0)\n\t{\n\t\tstd::lock_guard lock(m_priv->mutex);\n\t\tsetThresholdInternal(l);\n\t}\n}\n\nvoid Hierarchy::setThreshold(const LogString& levelStr)\n{\n\tLevelPtr l(Level::toLevelLS(levelStr, 0));\n\n\tif (l != 0)\n\t{\n\t\tsetThreshold(l);\n\t}\n\telse\n\t{\n\t\tLogLog::warn(((LogString) LOG4CXX_STR(\"No level could be found named \\\"\"))\n\t\t\t+ levelStr + LOG4CXX_STR(\"\\\".\"));\n\t}\n}\n\nvoid Hierarchy::setThresholdInternal(const LevelPtr& l)\n{\n\tm_priv->thresholdInt = l->toInt();\n\tm_priv->threshold = l;\n\n\tif (m_priv->thresholdInt != Level::ALL_INT)\n\t{\n\t\tm_priv->configured = true;\n\t}\n}\n\nvoid Hierarchy::fireAddAppenderEvent(const Logger* logger, const Appender* appender)\n{\n\tsetConfigured(true);\n\tHierarchyEventListenerList clonedList;\n\t{\n\t\tstd::lock_guard lock(m_priv->mutex);\n\t\tclonedList = m_priv->listeners;\n\t}\n\n\tfor (auto& item : clonedList)\n\t\titem->addAppenderEvent(logger, appender);\n}\n\nvoid Hierarchy::fireRemoveAppenderEvent(const Logger* logger, const Appender* appender)\n\n{\n\tHierarchyEventListenerList clonedList;\n\t{\n\t\tstd::lock_guard lock(m_priv->mutex);\n\t\tclonedList = m_priv->listeners;\n\t}\n\tfor (auto& item : clonedList)\n\t\titem->removeAppenderEvent(logger, appender);\n}\n\nLevelPtr Hierarchy::getThreshold() const\n{\n\treturn m_priv->threshold ? 
m_priv->threshold : Level::getAll();\n}\n\nLoggerPtr Hierarchy::getLogger(const LogString& name)\n{\n\tstatic WideLife defaultFactory = std::make_shared();\n\treturn getLogger(name, defaultFactory);\n}\n\nLoggerPtr Hierarchy::getLogger(const LogString& name,\n\tconst spi::LoggerFactoryPtr& factory)\n{\n\tauto root = getRootLogger();\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tLoggerMap::iterator it = m_priv->loggers.find(name);\n\tLoggerPtr result;\n\n\tif (it != m_priv->loggers.end())\n\t{\n\t\tresult = it->second;\n\t}\n\tif (!result && factory)\n\t{\n\t\tLoggerPtr logger(factory->makeNewLoggerInstance(m_priv->pool, name));\n\t\tlogger->setHierarchy(this);\n\t\tm_priv->loggers.insert(LoggerMap::value_type(name, logger));\n\n\t\tProvisionNodeMap::iterator it2 = m_priv->provisionNodes.find(name);\n\n\t\tif (it2 != m_priv->provisionNodes.end())\n\t\t{\n\t\t\tupdateChildren(it2->second, logger);\n\t\t\tm_priv->provisionNodes.erase(it2);\n\t\t}\n\n\t\tupdateParents(logger, root);\n\t\tresult = logger;\n\t}\n\treturn result;\n\n}\n\nLoggerList Hierarchy::getCurrentLoggers() const\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tLoggerList v;\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tif (auto pLogger = item.second)\n\t\t\tv.push_back(pLogger);\n\t}\n\treturn v;\n}\n\nLoggerPtr Hierarchy::getRootLogger() const\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\tif (!m_priv->root)\n\t{\n\t\tm_priv->root = std::make_shared(m_priv->pool, Level::getDebug());\n\t\tm_priv->root->setHierarchy(const_cast(this));\n\t}\n\n\treturn m_priv->root;\n}\n\nbool Hierarchy::isDisabled(int level) const\n{\n\treturn m_priv->thresholdInt > level;\n}\n\nvoid Hierarchy::ensureIsConfigured(std::function configurator)\n{\n\tstd::unique_lock lock(m_priv->configuredMutex);\n\tif (!m_priv->configured)\n\t{\n\t\tconfigurator();\n\t\tm_priv->configured = true;\n\t}\n}\n\nvoid Hierarchy::resetConfiguration()\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tif (m_priv->root)\n\t{\n\t\tm_priv->root->setLevel(Level::getDebug());\n\t\tm_priv->root->setResourceBundle(0);\n\t}\n\tsetThresholdInternal(Level::getAll());\n\n\tshutdownInternal();\n\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tif (auto pLogger = item.second)\n\t\t{\n\t\t\tpLogger->setLevel(0);\n\t\t\tpLogger->setAdditivity(true);\n\t\t\tpLogger->setResourceBundle(0);\n\t\t}\n\t}\n}\n\nvoid Hierarchy::shutdown()\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tshutdownInternal();\n}\n\nvoid Hierarchy::shutdownInternal()\n{\n\tm_priv->configured = false;\n\n\t// begin by closing nested appenders\n\tif (m_priv->root)\n\t\tm_priv->root->closeNestedAppenders();\n\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tif (auto pLogger = item.second)\n\t\t\tpLogger->closeNestedAppenders();\n\t}\n\n\t// then, remove all appenders\n\tif (m_priv->root)\n\t\tm_priv->root->removeAllAppenders();\n\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tif (auto pLogger = item.second)\n\t\t\tpLogger->removeAllAppenders();\n\t}\n}\n\nvoid Hierarchy::updateParents(const LoggerPtr& logger, const LoggerPtr& root)\n{\n\tconst LogString name(logger->getName());\n\tsize_t length = name.size();\n\tbool parentFound = false;\n\n\n\t// if name = \"w.x.y.z\", loop through \"w.x.y\", \"w.x\" and \"w\", but not \"w.x.y.z\"\n\tfor (size_t i = name.find_last_of(0x2E /* '.' */, length - 1);\n\t\t(i != LogString::npos) && (i != 0);\n\t\ti = name.find_last_of(0x2E /* '.' 
*/, i - 1))\n\t{\n\t\tLogString substr = name.substr(0, i);\n\n\t\tLoggerMap::iterator it = m_priv->loggers.find(substr);\n\n\t\tif (it != m_priv->loggers.end())\n\t\t{\n\t\t\tparentFound = true;\n\t\t\tlogger->setParent( it->second );\n\t\t\tbreak; // no need to update the ancestors of the closest ancestor\n\t\t}\n\t\telse\n\t\t{\n\t\t\tProvisionNodeMap::iterator it2 = m_priv->provisionNodes.find(substr);\n\n\t\t\tif (it2 != m_priv->provisionNodes.end())\n\t\t\t{\n\t\t\t\tit2->second.push_back(logger);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tProvisionNode node(1, logger);\n\t\t\t\tm_priv->provisionNodes.insert(\n\t\t\t\t\tProvisionNodeMap::value_type(substr, node));\n\t\t\t}\n\t\t}\n\t}\n\n\t// If we could not find any existing parents, then link with root.\n\tif (!parentFound)\n\t{\n\t\tlogger->setParent( root );\n\t}\n}\n\nvoid Hierarchy::updateChildren(ProvisionNode& pn, const LoggerPtr& logger)\n{\n\tfor (auto& l : pn)\n\t{\n\t\t// Unless this child already points to a correct (lower) parent,\n\t\t// make logger.parent point to l.parent and l.parent to logger.\n\t\tif (!StringHelper::startsWith(l->getParent()->getName(), logger->getName()))\n\t\t{\n\t\t\tlogger->setParent( l->getParent() );\n\t\t\tl->setParent( logger );\n\t\t}\n\t}\n \n}\n\nvoid Hierarchy::updateChildren(const Logger* parent)\n{\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tfor (auto l = item.second; l; l = l->getParent())\n\t\t{\n\t\t\tif (l->getParent().get() == parent)\n\t\t\t{\n\t\t\t\titem.second->updateThreshold();\n\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\t}\n}\n\nvoid Hierarchy::setConfigured(bool newValue)\n{\n\tstd::unique_lock lock(m_priv->configuredMutex, std::try_to_lock);\n\tif (lock.owns_lock()) // Not being auto-configured?\n\t\tm_priv->configured = newValue;\n}\n\nbool Hierarchy::isConfigured()\n{\n\tstd::unique_lock lock(m_priv->configuredMutex); // Blocks while auto-configuration is active\n\treturn m_priv->configured;\n}\n\nHierarchyPtr Hierarchy::create()\n{\n\tHierarchyPtr ret(new Hierarchy);\n\treturn ret;\n}\n\nvoid Hierarchy::clearAppenders()\n{\n\tm_priv->allAppenders.clear();\n}\n\nvoid Hierarchy::addAppender(AppenderPtr appender)\n{\n\tm_priv->allAppenders.push_back(appender);\n}\n\nbool Hierarchy::removeLogger(const LogString& name, bool ifNotUsed)\n{\n\tauto parentRefCount = [this](const LoggerPtr& child) -> int\n\t{\n\t\tint result = 0;\n\t\tfor (auto& node : m_priv->provisionNodes)\n\t\t{\n\t\t\tif (node.second.end() != std::find(node.second.begin(), node.second.end(), child))\n\t\t\t\t++result;\n\t\t}\n\t\treturn result;\n\t};\n\tbool result = false;\n\tstd::lock_guard lock(m_priv->mutex);\n\tauto it = m_priv->loggers.find(name);\n\tif (it == m_priv->loggers.end())\n\t\t;\n\telse if (ifNotUsed && 1 + parentRefCount(it->second) < it->second.use_count())\n\t\t;\n\telse\n\t{\n\t\tfor (auto& node : m_priv->provisionNodes)\n\t\t{\n\t\t\tfor (size_t i = node.second.size(); 0 < i; )\n\t\t\t{\n\t\t\t\tif (node.second[--i] == it->second)\n\t\t\t\t\tnode.second.erase(node.second.begin() + i);\n\t\t\t}\n\t\t}\n\t\tm_priv->loggers.erase(it);\n\t\tresult = true;\n\t}\n\treturn result;\n}\n\n// Path: src/main/cpp/fileoutputstream.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct FileOutputStream::FileOutputStreamPrivate\n{\n\tFileOutputStreamPrivate() : fileptr(nullptr) {}\n\n\tPool pool;\n\tapr_file_t* fileptr;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(FileOutputStream)\n\nFileOutputStream::FileOutputStream(const LogString& filename,\n\tbool append) : m_priv(std::make_unique())\n{\n\tm_priv->fileptr = open(filename, append, m_priv->pool);\n}\n\nFileOutputStream::FileOutputStream(const logchar* filename,\n\tbool append) : m_priv(std::make_unique())\n{\n\tm_priv->fileptr = open(filename, append, m_priv->pool);\n}\n\napr_file_t* FileOutputStream::open(const LogString& filename,\n\tbool append, Pool& pool)\n{\n\tapr_fileperms_t perm = APR_OS_DEFAULT;\n\tapr_int32_t flags = APR_WRITE | APR_CREATE;\n\n\tif (append)\n\t{\n\t\tflags |= APR_APPEND;\n\t}\n\telse\n\t{\n\t\tflags |= APR_TRUNCATE;\n\t}\n\n\tFile fn;\n\tfn.setPath(filename);\n\tapr_file_t* fileptr = 0;\n\tapr_status_t stat = fn.open(&fileptr, flags, perm, pool);\n\n\tif (stat != APR_SUCCESS)\n\t{\n\t\tthrow IOException(stat);\n\t}\n\n\treturn fileptr;\n}\n\nFileOutputStream::~FileOutputStream()\n{\n\tif (m_priv->fileptr != NULL && !APRInitializer::isDestructed)\n\t{\n\t\tapr_file_close(m_priv->fileptr);\n\t}\n}\n\nvoid FileOutputStream::close(Pool& /* p */)\n{\n\tif (m_priv->fileptr != NULL)\n\t{\n\t\tapr_status_t stat = apr_file_close(m_priv->fileptr);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tm_priv->fileptr = NULL;\n\t}\n}\n\nvoid FileOutputStream::flush(Pool& /* p */)\n{\n}\n\nvoid FileOutputStream::write(ByteBuffer& buf, Pool& /* p */ )\n{\n\tif (m_priv->fileptr == NULL)\n\t{\n\t\tthrow IOException(-1);\n\t}\n\n\tsize_t nbytes = buf.remaining();\n\tsize_t pos = buf.position();\n\tconst char* data = buf.data();\n\n\twhile (nbytes > 0)\n\t{\n\t\tapr_status_t stat = apr_file_write(\n\t\t\t\tm_priv->fileptr, data + pos, &nbytes);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tpos += nbytes;\n\t\tbuf.position(pos);\n\t\tnbytes = buf.remaining();\n\t}\n}\n\napr_file_t* FileOutputStream::getFilePtr() const{\n\treturn m_priv->fileptr;\n}\n\n\n// Path: src/main/cpp/atexitregistry.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nnamespace\n{\n\tstruct AtExitRegistryImpl : public AtExitRegistry\n\t{\n\t\t~AtExitRegistryImpl()\n\t\t{\n\t\t\tstd::lock_guard lock(mutex);\n\t\t\twhile(!actions.empty())\n\t\t\t{\n\t\t\t\tstd::function action = std::move(actions.begin()->second);\n\t\t\t\tactions.erase(actions.begin());\n\t\t\t\taction();\n\t\t\t}\n\t\t}\n\n\t\tvoid add(void* key, std::function action)\n\t\t{\n\t\t\tstd::lock_guard lock(mutex);\n\t\t\tactions.emplace(key, std::move(action));\n\t\t}\n\n\t\tvoid del(void* key)\n\t\t{\n\t\t\tstd::lock_guard lock(mutex);\n\t\t\tactions.erase(key);\n\t\t}\n\n\tprivate:\n\t\tstd::recursive_mutex mutex;\n\t\tstd::map> actions;\n\t} s_instance;\n}\n\nAtExitRegistry& AtExitRegistry::instance()\n{\n\treturn s_instance;\n}\n\nvoid AtExitRegistry::add(void* key, std::function action)\n{\n\treturn s_instance.add(key, std::move(action));\n}\n\nvoid AtExitRegistry::del(void* key)\n{\n\treturn s_instance.del(key);\n}\n\n\n// Path: src/main/cpp/cacheddateformat.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#define __STDC_CONSTANT_MACROS\n#define NOMINMAX /* tell wnidows not to define min/max macros */\n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::pattern;\n\nstruct CachedDateFormat::CachedDateFormatPriv\n{\n\tCachedDateFormatPriv(DateFormatPtr dateFormat, int expiration1) :\n\t\tformatter(dateFormat),\n\t\tmillisecondStart(0),\n\t\tslotBegin(std::numeric_limits::min()),\n\t\tcache(50, 0x20),\n\t\texpiration(expiration1),\n\t\tpreviousTime(std::numeric_limits::min())\n\t{}\n\n\t/**\n\t * Wrapped formatter.\n\t */\n\tLOG4CXX_NS::helpers::DateFormatPtr formatter;\n\n\t/**\n\t * Index of initial digit of millisecond pattern or\n\t * UNRECOGNIZED_MILLISECONDS or NO_MILLISECONDS.\n\t */\n\tmutable int millisecondStart;\n\n\t/**\n\t * Integral second preceding the previous convered Date.\n\t */\n\tmutable log4cxx_time_t slotBegin;\n\n\n\t/**\n\t * Cache of previous conversion.\n\t */\n\tmutable LogString cache;\n\n\n\t/**\n\t * Maximum validity period for the cache.\n\t * Typically 1, use cache for duplicate requests only, or\n\t * 1000000, use cache for requests within the same integral second.\n\t */\n\tconst int expiration;\n\n\t/**\n\t * Date requested in previous conversion.\n\t */\n\tmutable log4cxx_time_t previousTime;\n};\n\n\n/**\n* Supported digit set. 
If the wrapped DateFormat uses\n* a different unit set, the millisecond pattern\n* will not be recognized and duplicate requests\n* will use the cache.\n*/\nconst logchar CachedDateFormat::digits[] = { 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0 };\n\n\n/**\n * First magic number (in microseconds) used to detect\n * the millisecond position.\n */\nconst int CachedDateFormat::magic1 = 654000;\n\n\n/**\n * Expected representation of first magic number in milliseconds.\n */\nconst logchar CachedDateFormat::magicString1[] = { 0x36, 0x35, 0x34, 0 };\n\n\n/**\n * Second magic number (in microseconds) used to detect\n * the millisecond position.\n */\nconst int CachedDateFormat::magic2 = 987000;\n\n\n/**\n * Expected representation of second magic number in milliseconds.\n */\nconst logchar CachedDateFormat::magicString2[] = { 0x39, 0x38, 0x37, 0};\n\n\n/**\n * Expected representation of 0 milliseconds.\n */\nconst logchar CachedDateFormat::zeroString[] = { 0x30, 0x30, 0x30, 0 };\n\n/**\n * Creates a new CachedDateFormat object.\n * @param dateFormat Date format, may not be null.\n * @param expiration maximum cached range in milliseconds.\n * If the dateFormat is known to be incompatible with the\n * caching algorithm, use a value of 0 to totally disable\n * caching or 1 to only use cache for duplicate requests.\n */\nCachedDateFormat::CachedDateFormat(const DateFormatPtr& dateFormat,\n\tint expiration1) :\n\tm_priv(std::make_unique(dateFormat, expiration1))\n{\n\tif (dateFormat == NULL)\n\t{\n\t\tthrow IllegalArgumentException(LOG4CXX_STR(\"dateFormat cannot be null\"));\n\t}\n\n\tif (expiration1 < 0)\n\t{\n\t\tthrow IllegalArgumentException(LOG4CXX_STR(\"expiration must be non-negative\"));\n\t}\n}\n\nCachedDateFormat::~CachedDateFormat() {}\n\n\n/**\n * Finds start of millisecond field in formatted time.\n * @param time long time, must be integral number of seconds\n * @param formatted String corresponding formatted string\n * @param formatter DateFormat date format\n * @return int position in string of first digit of milliseconds,\n * -1 indicates no millisecond field, -2 indicates unrecognized\n * field (likely RelativeTimeDateFormat)\n */\nint CachedDateFormat::findMillisecondStart(\n\tlog4cxx_time_t time, const LogString& formatted,\n\tconst DateFormatPtr& formatter,\n\tPool& pool)\n{\n\n\tlog4cxx_time_t slotBegin = (time / 1000000) * 1000000;\n\n\tif (slotBegin > time)\n\t{\n\t\tslotBegin -= 1000000;\n\t}\n\n\tint millis = (int) (time - slotBegin) / 1000;\n\n\t// the magic numbers are in microseconds\n\tint magic = magic1;\n\tLogString magicString(magicString1);\n\n\tif (millis == magic1 / 1000)\n\t{\n\t\tmagic = magic2;\n\t\tmagicString = magicString2;\n\t}\n\n\tLogString plusMagic;\n\tformatter->format(plusMagic, slotBegin + magic, pool);\n\n\t/**\n\t * If the string lengths differ then\n\t * we can't use the cache except for duplicate requests.\n\t */\n\tif (plusMagic.length() != formatted.length())\n\t{\n\t\treturn UNRECOGNIZED_MILLISECONDS;\n\t}\n\telse\n\t{\n\t\t// find first difference between values\n\t\tfor (LogString::size_type i = 0; i < formatted.length(); i++)\n\t\t{\n\t\t\tif (formatted[i] != plusMagic[i])\n\t\t\t{\n\t\t\t\t//\n\t\t\t\t// determine the expected digits for the base time\n\t\t\t\tconst logchar abc[] = { 0x41, 0x42, 0x43, 0 };\n\t\t\t\tLogString formattedMillis(abc);\n\t\t\t\tmillisecondFormat(millis, formattedMillis, 0);\n\n\t\t\t\tLogString plusZero;\n\t\t\t\tformatter->format(plusZero, slotBegin, pool);\n\n\t\t\t\t// Test if the next 1..3 
characters match the magic string, main problem is that magic\n\t\t\t\t// available millis in formatted can overlap. Therefore the current i is not always the\n\t\t\t\t// index of the first millis char, but may be already within the millis. Besides that\n\t\t\t\t// the millis can occur everywhere in formatted. See LOGCXX-420 and following.\n\t\t\t\tsize_t magicLength = magicString.length();\n\t\t\t\tsize_t overlapping = magicString.find(plusMagic[i]);\n\t\t\t\tint possibleRetVal = int(i - overlapping);\n\n\t\t\t\tif (plusZero.length() == formatted.length()\n\t\t\t\t\t&& regionMatches(magicString, 0, plusMagic, possibleRetVal, magicLength)\n\t\t\t\t\t&& regionMatches(formattedMillis, 0, formatted, possibleRetVal, magicLength)\n\t\t\t\t\t&& regionMatches(zeroString, 0, plusZero, possibleRetVal, magicLength)\n\t\t\t\t\t// The following will and should fail for patterns with more than one SSS because\n\t\t\t\t\t// we only seem to be able to change one SSS in e.g. format and need to reformat the\n\t\t\t\t\t// whole string in other cases.\n\t\t\t\t\t&& (formatted.length() == possibleRetVal + magicLength\n\t\t\t\t\t\t|| plusZero.compare(possibleRetVal + magicLength,\n\t\t\t\t\t\t\tLogString::npos, plusMagic, possibleRetVal + magicLength, LogString::npos) == 0))\n\t\t\t\t{\n\t\t\t\t\treturn possibleRetVal;\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\treturn UNRECOGNIZED_MILLISECONDS;\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t}\n\n\treturn NO_MILLISECONDS;\n}\n\n\n/**\n * Formats a millisecond count into a date/time string.\n *\n * @param now Number of milliseconds after midnight 1 Jan 1970 GMT.\n * @param sbuf the string buffer to write to\n */\nvoid CachedDateFormat::format(LogString& buf, log4cxx_time_t now, Pool& p) const\n{\n\n\t//\n\t// If the current requested time is identical to the previously\n\t// requested time, then append the cache contents.\n\t//\n\tif (now == m_priv->previousTime)\n\t{\n\t\tbuf.append(m_priv->cache);\n\t\treturn;\n\t}\n\n\t//\n\t// If millisecond pattern was not unrecognized\n\t// (that is if it was found or milliseconds did not appear)\n\t//\n\tif (m_priv->millisecondStart != UNRECOGNIZED_MILLISECONDS)\n\t{\n\t\t// Check if the cache is still valid.\n\t\t// If the requested time is within the same integral second\n\t\t// as the last request and a shorter expiration was not requested.\n\t\tif (now < m_priv->slotBegin + m_priv->expiration\n\t\t\t&& now >= m_priv->slotBegin\n\t\t\t&& now < m_priv->slotBegin + 1000000L)\n\t\t{\n\t\t\t//\n\t\t\t// if there was a millisecond field then update it\n\t\t\t//\n\t\t\tif (m_priv->millisecondStart >= 0)\n\t\t\t{\n\t\t\t\tmillisecondFormat((int) ((now - m_priv->slotBegin) / 1000), m_priv->cache, m_priv->millisecondStart);\n\t\t\t}\n\n\t\t\t//\n\t\t\t// update the previously requested time\n\t\t\t// (the slot begin should be unchanged)\n\t\t\tm_priv->previousTime = now;\n\t\t\tbuf.append(m_priv->cache);\n\n\t\t\treturn;\n\t\t}\n\t}\n\n\t//\n\t// could not use previous value.\n\t// Call underlying formatter to format date.\n\tm_priv->cache.erase(m_priv->cache.begin(), m_priv->cache.end());\n\tm_priv->formatter->format(m_priv->cache, now, p);\n\tbuf.append(m_priv->cache);\n\tm_priv->previousTime = now;\n\tm_priv->slotBegin = (m_priv->previousTime / 1000000) * 1000000;\n\n\tif (m_priv->slotBegin > m_priv->previousTime)\n\t{\n\t\tm_priv->slotBegin -= 1000000;\n\t}\n\n\t//\n\t// if the milliseconds field was previous found\n\t// then reevaluate in case it moved.\n\t//\n\tif (m_priv->millisecondStart >= 0)\n\t{\n\t\tm_priv->millisecondStart = 
findMillisecondStart(now, m_priv->cache, m_priv->formatter, p);\n\t}\n}\n\n/**\n * Formats a count of milliseconds (0-999) into a numeric representation.\n * @param millis Millisecond count between 0 and 999.\n * @buf String buffer, may not be null.\n * @offset Starting position in buffer, the length of the\n * buffer must be at least offset + 3.\n */\nvoid CachedDateFormat::millisecondFormat(int millis,\n\tLogString& buf,\n\tint offset)\n{\n\tbuf[offset] = digits[millis / 100];\n\tbuf[offset + 1] = digits[(millis / 10) % 10];\n\tbuf[offset + 2] = digits[millis % 10];\n}\n\n/**\n * Set timezone.\n *\n * @remarks Setting the timezone using getCalendar().setTimeZone()\n * will likely cause caching to misbehave.\n * @param timeZone TimeZone new timezone\n */\nvoid CachedDateFormat::setTimeZone(const TimeZonePtr& timeZone)\n{\n\tm_priv->formatter->setTimeZone(timeZone);\n\tm_priv->previousTime = std::numeric_limits::min();\n\tm_priv->slotBegin = std::numeric_limits::min();\n}\n\n\n\nvoid CachedDateFormat::numberFormat(LogString& s, int n, Pool& p) const\n{\n\tm_priv->formatter->numberFormat(s, n, p);\n}\n\n\n/**\n * Gets maximum cache validity for the specified SimpleDateTime\n * conversion pattern.\n * @param pattern conversion pattern, may not be null.\n * @returns Duration in microseconds from an integral second\n * that the cache will return consistent results.\n */\nint CachedDateFormat::getMaximumCacheValidity(const LogString& pattern)\n{\n\t//\n\t// If there are more \"S\" in the pattern than just one \"SSS\" then\n\t// (for example, \"HH:mm:ss,SSS SSS\"), then set the expiration to\n\t// one millisecond which should only perform duplicate request caching.\n\t//\n\tconst logchar S = 0x53;\n\tconst logchar SSS[] = { 0x53, 0x53, 0x53, 0 };\n\tsize_t firstS = pattern.find(S);\n\tsize_t len = pattern.length();\n\n\t//\n\t// if there are no S's or\n\t// three that start with the first S and no fourth S in the string\n\t//\n\tif (firstS == LogString::npos ||\n\t\t(len >= firstS + 3 && pattern.compare(firstS, 3, SSS) == 0\n\t\t\t&& (len == firstS + 3 ||\n\t\t\t\tpattern.find(S, firstS + 3) == LogString::npos)))\n\t{\n\t\treturn 1000000;\n\t}\n\n\treturn 1000;\n}\n\n\n/**\n* Tests if two string regions are equal.\n* @param target target string.\n* @param toffset character position in target to start comparison.\n* @param other other string.\n* @param ooffset character position in other to start comparison.\n* @param len length of region.\n* @return true if regions are equal.\n*/\nbool CachedDateFormat::regionMatches(\n\tconst LogString& target,\n\tsize_t toffset,\n\tconst LogString& other,\n\tsize_t ooffset,\n\tsize_t len)\n{\n\treturn target.compare(toffset, len, other, ooffset, len) == 0;\n}\n\n\n// Path: src/main/cpp/loader.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::filter;\n\nIMPLEMENT_LOG4CXX_OBJECT(Object)\nIMPLEMENT_LOG4CXX_OBJECT(OptionHandler)\nIMPLEMENT_LOG4CXX_OBJECT(ErrorHandler)\nIMPLEMENT_LOG4CXX_OBJECT(Appender)\nIMPLEMENT_LOG4CXX_OBJECT(Filter)\nIMPLEMENT_LOG4CXX_OBJECT(AppenderAttachable)\nIMPLEMENT_LOG4CXX_OBJECT(LoggerFactory)\nIMPLEMENT_LOG4CXX_OBJECT(LoggerRepository)\nIMPLEMENT_LOG4CXX_OBJECT(DenyAllFilter)\nIMPLEMENT_LOG4CXX_OBJECT(RepositorySelector)\nIMPLEMENT_LOG4CXX_OBJECT(XMLDOMNode)\nIMPLEMENT_LOG4CXX_OBJECT(XMLDOMDocument)\nIMPLEMENT_LOG4CXX_OBJECT(XMLDOMElement)\nIMPLEMENT_LOG4CXX_OBJECT(XMLDOMNodeList)\nIMPLEMENT_LOG4CXX_OBJECT(TriggeringEventEvaluator)\n\nconst Class& Loader::loadClass(const LogString& clazz)\n{\n\treturn Class::forName(clazz);\n}\n\n\nInputStreamPtr Loader::getResourceAsStream(const LogString& name)\n{\n\n\ttry\n\t{\n\t\treturn std::make_shared(name);\n\t}\n\tcatch (const IOException&)\n\t{\n\t}\n\n\treturn 0;\n}\n\n// Path: src/main/cpp/logmanager.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\nusing namespace LOG4CXX_NS;\n...\n", "needle_token_start": 3790, "needle_token_end": 3847, "code_context_ntokens": 15213, "output": ["```cpp\n// Path: src/main/cpp/hierarchy.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n#include \"assert.h\"\n\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\n\ntypedef std::map LoggerMap;\ntypedef std::map ProvisionNodeMap;\n\nstruct Hierarchy::HierarchyPrivate\n{\n\t\nHierarchyPrivate()\n\t\t: configured(false)\n\t\t, emittedNoAppenderWarning(false)\n\t\t, emittedNoResourceBundleWarning(false)\n\t\t, thresholdInt(Level::ALL_INT)\n\t{\n\t}\n\n\thelpers::Pool pool;\n\tmutable std::recursive_mutex mutex;\n\tmutable std::mutex configuredMutex;\n\tbool configured;\n\tbool emittedNoAppenderWarning;\n\tbool emittedNoResourceBundleWarning;\n\tint thresholdInt;\n\n\tspi::HierarchyEventListenerList listeners;\n\tLoggerPtr root;\n\tLevelPtr threshold;\n\tLoggerMap loggers;\n\tProvisionNodeMap provisionNodes;\n\n\tstd::vector allAppenders;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(Hierarchy)\n\nHierarchy::Hierarchy() :\n\tm_priv(std::make_unique())\n{\n}\n\nHierarchy::~Hierarchy()\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\tfor (auto& item : m_priv->loggers)\n\t{\n\t\tif (auto& pLogger = item.second)\n\t\t{\n\t\t\tpLogger->removeHierarchy();\n\t\t\tpLogger->removeAllAppenders();\n\t\t}\n\t}\n\tif (m_priv->root)\n\t{\n\t\tm_priv->root->removeHierarchy();\n\t\tm_priv->root->removeAllAppenders();\n\t}\n}\n\nvoid Hierarchy::addHierarchyEventListener(const spi::HierarchyEventListenerPtr& listener)\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n\tif (std::find(m_priv->listeners.begin(), m_priv->listeners.end(), listener) != m_priv->listeners.end())\n\t{\n\t\tLogLog::warn(LOG4CXX_STR(\"Ignoring attempt to add an existent listener.\"));\n\t}\n\telse\n\t{\n\t\tm_priv->listeners.push_back(listener);\n\t}\n}\n\nvoid Hierarchy::removeHierarchyEventListener(const spi::HierarchyEventListenerPtr& listener)\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\n auto found = std::find(m_priv->listeners.begin(), m_priv->listeners.end(), listener);\n if(found != m_priv->listeners.end()){\n m_priv->listeners.erase(found);\n }\n}\n\nvoid Hierarchy::clear()\n{\n\tstd::lock_guard lock(m_priv->mutex);\n\tm_priv->loggers.clear();\n}\n\nvoid Hierarchy::emitNoAppenderWarning(const Logger* logger)\n{\n\tbool emitWarning = false;\n\t{\n\t\tstd::lock_guard lock(m_priv->mutex);\n\t\temitWarning = !m_priv->emittedNoAppenderWarning;\n\t\tm_priv->emittedNoAppenderWarning = true;\n\t}\n\n\t// No appender in hierarchy, warn user only once.\n\tif (emitWarning)\n\t{\n\t\tLogLog::warn(((LogString) LOG4CXX_STR(\"No appender could be found for logger (\"))\n\t\t\t+ logger->getName() + LOG4CXX_STR(\").\"));\n\t\tLogLog::warn(LOG4CXX_STR(\"Please initialize the log4cxx system properly.\"));\n\t}\n}\n\n\nLoggerPtr Hierarchy::exists(const LogString& name)\n{\n\tstd::lock_guard"]} +{"repo": "apache/logging-log4cxx", "name": "decodeUTF16", "language": "cpp", "path": "src/main/cpp/transcoder.cpp", "position_ratio": 0.35, "description": "\nFunction Description:\n1. 
**Purpose**: The function decodes a sequence of UTF-16 encoded characters into a Unicode scalar value, handling surrogate pairs correctly.\n2. **Input**: The function takes a string and an iterator pointing to the current character in the string.\n3. **Output**: It returns an unsigned integer representing the Unicode scalar value of the character or characters pointed to by the iterator. If the character is part of a valid surrogate pair, it returns the combined scalar value; otherwise, it returns the value of the single character or an error code for unrecognized sequences.\n4. **Procedure**: \n - First, the function checks if the current character is a high-surrogate (in the range 0xD800 to 0xDBFF). If not, it simply returns the character's value and advances the iterator unless the character is 0xFFFF.\n - If the character is a high-surrogate, the function checks the next character to see if it is a low-surrogate (in the range 0xDC00 to 0xDFFF). If so, it calculates the scalar value from the surrogate pair, advances the iterator past both characters, and returns the scalar value.\n - If the sequence does not form a valid surrogate pair, the function returns an error code (0xFFFF) without advancing the iterator.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/resourcebundle.cpp\n/*\n...\n// Path: src/main/cpp/outputstreamwriter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(OutputStreamWriter)\n\nstruct OutputStreamWriter::OutputStreamWriterPrivate{\n\tOutputStreamWriterPrivate(OutputStreamPtr& out1) : out(out1), enc(CharsetEncoder::getDefaultEncoder()){}\n\n\tOutputStreamWriterPrivate(OutputStreamPtr& out1,\n\t\t\t\t\t\t\t CharsetEncoderPtr& enc1)\n\t\t: out(out1), enc(enc1){}\n\n\tOutputStreamPtr out;\n\tCharsetEncoderPtr enc;\n};\n\nOutputStreamWriter::OutputStreamWriter(OutputStreamPtr& out1)\n\t: m_priv(std::make_unique(out1))\n{\n\tif (out1 == 0)\n\t{\n\t\tthrow NullPointerException(LOG4CXX_STR(\"out parameter may not be null.\"));\n\t}\n}\n\nOutputStreamWriter::OutputStreamWriter(OutputStreamPtr& out1,\n\tCharsetEncoderPtr& enc1)\n\t: m_priv(std::make_unique(out1, enc1))\n{\n\tif (out1 == 0)\n\t{\n\t\tthrow NullPointerException(LOG4CXX_STR(\"out parameter may not be null.\"));\n\t}\n\n\tif (enc1 == 0)\n\t{\n\t\tthrow NullPointerException(LOG4CXX_STR(\"enc parameter may not be null.\"));\n\t}\n}\n\nOutputStreamWriter::~OutputStreamWriter()\n{\n}\n\nvoid OutputStreamWriter::close(Pool& p)\n{\n\tm_priv->out->close(p);\n}\n\nvoid OutputStreamWriter::flush(Pool& p)\n{\n\tm_priv->out->flush(p);\n}\n\nvoid OutputStreamWriter::write(const LogString& str, Pool& p)\n{\n\tif (str.empty())\n\t\treturn;\n\tif (CharsetEncoder::isTriviallyCopyable(str, m_priv->enc))\n\t{\n\t\tByteBuffer buf((char*)str.data(), str.size() * sizeof (logchar));\n\t\tm_priv->out->write(buf, p);\n\t}\n\telse\n\t{\n\t\tenum { BUFSIZE = 1024 };\n\t\tchar stackData[BUFSIZE];\n\t\tchar* rawbuf = stackData;\n\t\tsize_t bufSize = BUFSIZE;\n#ifdef LOG4CXX_MULTI_PROCESS\n\t\tstd::vector heapData;\n\t\t// Ensure the logging event is a single write system call to keep events from each process separate\n\t\tif (bufSize < str.length() * 2)\n\t\t{\n\t\t\theapData.resize(bufSize = str.length() * 2);\n\t\t\trawbuf = heapData.data();\n\t\t}\n#endif\n\t\tByteBuffer buf(rawbuf, bufSize);\n\t\tm_priv->enc->reset();\n\t\tLogString::const_iterator iter = str.begin();\n\n\t\twhile (iter != str.end())\n\t\t{\n\t\t\tCharsetEncoder::encode(m_priv->enc, str, iter, buf);\n\t\t\tbuf.flip();\n\t\t\tm_priv->out->write(buf, p);\n\t\t\tbuf.clear();\n\t\t}\n\n\t\tCharsetEncoder::encode(m_priv->enc, str, iter, buf);\n\t\tm_priv->enc->flush(buf);\n\t\tbuf.flip();\n\t\tm_priv->out->write(buf, p);\n\t}\n}\n\nOutputStreamPtr OutputStreamWriter::getOutputStreamPtr() const\n{\n\treturn m_priv->out;\n}\n\n\n// Path: src/main/cpp/transcoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\n#if LOG4CXX_CFSTRING_API\n\t#include \n#endif\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\n\nvoid Transcoder::decodeUTF8(const std::string& src, LogString& dst)\n{\n\tstd::string::const_iterator iter = src.begin();\n\n\twhile (iter != src.end())\n\t{\n\t\tunsigned int sv = decode(src, iter);\n\n\t\tif (sv != 0xFFFF)\n\t\t{\n\t\t\tencode(sv, dst);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tdst.append(1, LOSSCHAR);\n\t\t\titer++;\n\t\t}\n\t}\n}\n\nvoid Transcoder::encodeUTF8(const LogString& src, std::string& dst)\n{\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\tdst.append(src);\n#else\n\tLogString::const_iterator iter = src.begin();\n\n\twhile (iter != src.end())\n\t{\n\t\tunsigned int sv = decode(src, iter);\n\n\t\tif (sv != 0xFFFF)\n\t\t{\n\t\t\tencode(sv, dst);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tdst.append(1, LOSSCHAR);\n\t\t\titer++;\n\t\t}\n\t}\n\n#endif\n}\n\nchar* Transcoder::encodeUTF8(const LogString& src, Pool& p)\n{\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\treturn p.pstrdup(src);\n#else\n\tstd::string tmp;\n\tencodeUTF8(src, tmp);\n\treturn p.pstrdup(tmp);\n#endif\n}\n\n\nvoid Transcoder::encodeUTF8(unsigned int sv, ByteBuffer& dst)\n{\n\tsize_t bytes = encodeUTF8(sv, dst.current());\n\tdst.position(dst.position() + bytes);\n}\n\n\nsize_t Transcoder::encodeUTF8(unsigned int ch, char* dst)\n{\n\tif (ch < 0x80)\n\t{\n\t\tdst[0] = (char) ch;\n\t\treturn 1;\n\t}\n\telse if (ch < 0x800)\n\t{\n\t\tdst[0] = (char) (0xC0 + (ch >> 6));\n\t\tdst[1] = (char) (0x80 + (ch & 0x3F));\n\t\treturn 2;\n\t}\n\telse if (ch < 0x10000)\n\t{\n\t\tdst[0] = (char) (0xE0 + (ch >> 12));\n\t\tdst[1] = (char) (0x80 + ((ch >> 6) & 0x3F));\n\t\tdst[2] = (char) (0x80 + (ch & 0x3F));\n\t\treturn 3;\n\t}\n\telse if (ch <= 0x10FFFF)\n\t{\n\t\tdst[0] = (char) (0xF0 + (ch >> 18));\n\t\tdst[1] = (char) (0x80 + ((ch >> 12) & 0x3F));\n\t\tdst[2] = (char) (0x80 + ((ch >> 6) & 0x3F));\n\t\tdst[3] = (char) (0x80 + (ch & 0x3F));\n\t\treturn 4;\n\t}\n\telse\n\t{\n\t\t//\n\t\t// output UTF-8 encoding of 0xFFFF\n\t\t//\n\t\tdst[0] = (char) 0xEF;\n\t\tdst[1] = (char) 0xBF;\n\t\tdst[2] = (char) 0xBF;\n\t\treturn 3;\n\t}\n}\n\nvoid Transcoder::encodeUTF16BE(unsigned int sv, ByteBuffer& dst)\n{\n\tsize_t bytes = encodeUTF16BE(sv, dst.current());\n\tdst.position(dst.position() + bytes);\n}\n\n\nsize_t Transcoder::encodeUTF16BE(unsigned int ch, char* dst)\n{\n\tif (ch <= 0xFFFF)\n\t{\n\t\tdst[0] = (char) (ch >> 8);\n\t\tdst[1] = (char) (ch & 0xFF);\n\t\treturn 2;\n\t}\n\n\tif (ch <= 0x10FFFF)\n\t{\n\t\tunsigned char w = (unsigned char) ((ch >> 16) - 1);\n\t\tdst[0] = (char) (0xD8 + (w >> 2));\n\t\tdst[1] = (char) (((w & 0x03) << 6) + ((ch >> 10) & 0x3F));\n\t\tdst[2] = (char) (0xDC + ((ch & 0x30) >> 4));\n\t\tdst[3] = (char) (ch & 0xFF);\n\t\treturn 4;\n\t}\n\n\tdst[0] = dst[1] = (char) 0xFF;\n\treturn 2;\n}\n\nvoid Transcoder::encodeUTF16LE(unsigned int sv, ByteBuffer& dst)\n{\n\tsize_t bytes = encodeUTF16LE(sv, 
dst.current());\n\tdst.position(dst.position() + bytes);\n}\n\nsize_t Transcoder::encodeUTF16LE(unsigned int ch, char* dst)\n{\n\tif (ch <= 0xFFFF)\n\t{\n\t\tdst[1] = (char) (ch >> 8);\n\t\tdst[0] = (char) (ch & 0xFF);\n\t\treturn 2;\n\t}\n\n\tif (ch <= 0x10FFFF)\n\t{\n\t\tunsigned char w = (unsigned char) ((ch >> 16) - 1);\n\t\tdst[1] = (char) (0xD8 + (w >> 2));\n\t\tdst[0] = (char) (((w & 0x03) << 6) + ((ch >> 10) & 0x3F));\n\t\tdst[3] = (char) (0xDC + ((ch & 0x30) >> 4));\n\t\tdst[2] = (char) (ch & 0xFF);\n\t\treturn 4;\n\t}\n\n\tdst[0] = dst[1] = (char) 0xFF;\n\treturn 2;\n}\n\n\nunsigned int Transcoder::decode(const std::string& src,\n\tstd::string::const_iterator& iter)\n{\n\tstd::string::const_iterator start(iter);\n\tunsigned char ch1 = *(iter++);\n\n\tif (ch1 <= 0x7F)\n\t{\n\t\treturn ch1;\n\t}\n\n\t//\n\t// should not have continuation character here\n\t//\n\tif ((ch1 & 0xC0) != 0x80 && iter != src.end())\n\t{\n\t\tunsigned char ch2 = *(iter++);\n\n\t\t//\n\t\t// should be continuation\n\t\tif ((ch2 & 0xC0) != 0x80)\n\t\t{\n\t\t\titer = start;\n\t\t\treturn 0xFFFF;\n\t\t}\n\n\t\tif ((ch1 & 0xE0) == 0xC0)\n\t\t{\n\t\t\tif ((ch2 & 0xC0) == 0x80)\n\t\t\t{\n\t\t\t\tunsigned int rv = ((ch1 & 0x1F) << 6) + (ch2 & 0x3F);\n\n\t\t\t\tif (rv >= 0x80)\n\t\t\t\t{\n\t\t\t\t\treturn rv;\n\t\t\t\t}\n\t\t\t}\n\n\t\t\titer = start;\n\t\t\treturn 0xFFFF;\n\t\t}\n\n\t\tif (iter != src.end())\n\t\t{\n\t\t\tunsigned char ch3 = *(iter++);\n\n\t\t\t//\n\t\t\t// should be continuation\n\t\t\t//\n\t\t\tif ((ch3 & 0xC0) != 0x80)\n\t\t\t{\n\t\t\t\titer = start;\n\t\t\t\treturn 0xFFFF;\n\t\t\t}\n\n\t\t\tif ((ch1 & 0xF0) == 0xE0)\n\t\t\t{\n\t\t\t\tunsigned rv = ((ch1 & 0x0F) << 12)\n\t\t\t\t\t+ ((ch2 & 0x3F) << 6)\n\t\t\t\t\t+ (ch3 & 0x3F);\n\n\t\t\t\tif (rv <= 0x800)\n\t\t\t\t{\n\t\t\t\t\titer = start;\n\t\t\t\t\treturn 0xFFFF;\n\t\t\t\t}\n\n\t\t\t\treturn rv;\n\t\t\t}\n\n\t\t\tif (iter != src.end())\n\t\t\t{\n\t\t\t\tunsigned char ch4 = *(iter++);\n\n\t\t\t\tif ((ch4 & 0xC0) != 0x80)\n\t\t\t\t{\n\t\t\t\t\titer = start;\n\t\t\t\t\treturn 0xFFFF;\n\t\t\t\t}\n\n\t\t\t\tunsigned int rv = ((ch1 & 0x07) << 18)\n\t\t\t\t\t+ ((ch2 & 0x3F) << 12)\n\t\t\t\t\t+ ((ch3 & 0x3F) << 6)\n\t\t\t\t\t+ (ch4 & 0x3F);\n\n\t\t\t\tif (rv > 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn rv;\n\t\t\t\t}\n\n\t\t\t}\n\t\t}\n\t}\n\n\titer = start;\n\treturn 0xFFFF;\n}\n\n\nvoid Transcoder::encode(unsigned int sv, std::string& dst)\n{\n\tchar tmp[8];\n\tsize_t bytes = encodeUTF8(sv, tmp);\n\tdst.append(tmp, bytes);\n}\n\n\nvoid Transcoder::decode(const std::string& src, LogString& dst)\n{\n#if LOG4CXX_CHARSET_UTF8 && LOG4CXX_LOGCHAR_IS_UTF8\n\tdst.append(src);\n#else\n\tstatic CharsetDecoderPtr decoder(CharsetDecoder::getDefaultDecoder());\n\tdst.reserve(dst.size() + src.size());\n\tstd::string::const_iterator iter = src.begin();\n#if !LOG4CXX_CHARSET_EBCDIC\n\n\tfor (;\n\t\titer != src.end() && ((unsigned char) *iter) < 0x80;\n\t\titer++)\n\t{\n\t\tdst.append(1, *iter);\n\t}\n\n#endif\n\n\tif (iter != src.end())\n\t{\n\t\tsize_t offset = iter - src.begin();\n\t\tByteBuffer buf(const_cast(src.data() + offset), src.size() - offset);\n\n\t\twhile (buf.remaining() > 0)\n\t\t{\n\t\t\tlog4cxx_status_t stat = decoder->decode(buf, dst);\n\n\t\t\tif (CharsetDecoder::isError(stat))\n\t\t\t{\n\t\t\t\tdst.append(1, LOSSCHAR);\n\t\t\t\tbuf.position(buf.position() + 1);\n\t\t\t}\n\t\t}\n\n\t\tdecoder->decode(buf, dst);\n\t}\n\n#endif\n}\n\nchar* Transcoder::encode(const LogString& src, Pool& p)\n{\n#if LOG4CXX_CHARSET_UTF8 && LOG4CXX_LOGCHAR_IS_UTF8\n\treturn 
p.pstrdup(src);\n#else\n\tstd::string tmp;\n\tencode(src, tmp);\n\treturn p.pstrdup(tmp);\n#endif\n}\n\n\n\nvoid Transcoder::encode(const LogString& src, std::string& dst)\n{\n#if LOG4CXX_CHARSET_UTF8 && LOG4CXX_LOGCHAR_IS_UTF8\n\tdst.append(src);\n#else\n\tstatic CharsetEncoderPtr encoder(CharsetEncoder::getDefaultEncoder());\n\tdst.reserve(dst.size() + src.size());\n\tLogString::const_iterator iter = src.begin();\n#if !LOG4CXX_CHARSET_EBCDIC\n\n\tfor (;\n\t\titer != src.end() && ((unsigned int) *iter) < 0x80;\n\t\titer++)\n\t{\n\t\tdst.append(1, *iter);\n\t}\n\n#endif\n\n\tif (iter != src.end())\n\t{\n\t\tchar buf[BUFSIZE];\n\t\tByteBuffer out(buf, BUFSIZE);\n\n\t\twhile (iter != src.end())\n\t\t{\n\t\t\tlog4cxx_status_t stat = encoder->encode(src, iter, out);\n\t\t\tout.flip();\n\t\t\tdst.append(out.data(), out.limit());\n\t\t\tout.clear();\n\n\t\t\tif (CharsetEncoder::isError(stat))\n\t\t\t{\n\t\t\t\tdst.append(1, LOSSCHAR);\n\t\t\t\titer++;\n\t\t\t}\n\t\t}\n\n\t\tencoder->encode(src, iter, out);\n\t}\n\n#endif\n}\n\n\ntemplate\n\nstatic unsigned int decodeUTF16(const String& in, Iterator& iter)\n{\n\tunsigned int ch1 = *iter;\n\n\t//\n\t// if not surrogate pair\n\t//\n\tif (ch1 < 0xD800 || ch1 > 0xDFFF)\n\t{\n\t\t//\n\t\t// then advance iterator and return wchar_t value\n\t\t//\n\t\tif (ch1 != 0xFFFF)\n\t\t{\n\t\t\titer++;\n\t\t}\n\n\t\treturn ch1;\n\t}\n\telse if (ch1 < 0xDC00)\n\t{\n\t\t//\n\t\t// started with high-surrogate value\n\t\t// if there is an additional wchar_t\n\t\tIterator iter2 = iter + 1;\n\n\t\tif (iter2 != in.end())\n\t\t{\n\t\t\tunsigned int ch2 = *iter2;\n\n\t\t\t//\n\t\t\t// if it is a matching low surrogate then\n\t\t\t// advance the iterator and return the scalar value\n\t\t\tif (ch2 >= 0xDC00 && ch2 <= 0xDFFF)\n\t\t\t{\n\t\t\t\titer += 2;\n\t\t\t\treturn (ch1 - 0xD800) * 0x400 + (ch2 - 0xDC00) + 0x10000;\n\t\t\t}\n\t\t}\n\t}\n\n\t//\n\t// unrecognized value, do not advance iterator\n\t//\n\treturn 0xFFFF;\n}\n\ntemplate\nstatic void encodeUTF16(unsigned int sv, String& dst)\n{\n\tif (sv < 0x10000)\n\t{\n\t\tdst.append(1, sv);\n\t}\n\telse\n\t{\n\t\tunsigned char u = (unsigned char) (sv >> 16);\n\t\tunsigned char w = (unsigned char) (u - 1);\n\t\tunsigned short hs = (0xD800 + ((w & 0xF) << 6) + ((sv & 0xFFFF) >> 10));\n\t\tunsigned short ls = (0xDC00 + (sv & 0x3FF));\n\t\tdst.append(1, hs);\n\t\tdst.append(1, ls);\n\t}\n}\n\n\n\n#if LOG4CXX_WCHAR_T_API || LOG4CXX_LOGCHAR_IS_WCHAR_T || defined(WIN32) || defined(_WIN32)\nvoid Transcoder::decode(const std::wstring& src, LogString& dst)\n{\n#if LOG4CXX_LOGCHAR_IS_WCHAR_T\n\tdst.append(src, len);\n#else\n\tstd::wstring::const_iterator i = src.begin();\n\n\twhile (i != src.end())\n\t{\n\t\tunsigned int cp = decode(src, i);\n\n\t\tif (cp != 0xFFFF)\n\t\t{\n\t\t\tencode(cp, dst);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tdst.append(1, LOSSCHAR);\n\t\t\ti++;\n\t\t}\n\t}\n\n#endif\n}\n\nvoid Transcoder::encode(const LogString& src, std::wstring& dst)\n{\n#if LOG4CXX_LOGCHAR_IS_WCHAR_T\n\tdst.append(src);\n#else\n\n\tfor (LogString::const_iterator i = src.begin(); i != src.end();)\n\t{\n\t\tunsigned int cp = Transcoder::decode(src, i);\n\n\t\tif (cp != 0xFFFF)\n\t\t{\n\t\t\tencode(cp, dst);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tdst.append(1, LOSSCHAR);\n\t\t\ti++;\n\t\t}\n\t}\n\n#endif\n}\n\nwchar_t* Transcoder::wencode(const LogString& src, Pool& p)\n{\n#if LOG4CXX_LOGCHAR_IS_WCHAR_T\n\tstd::wstring& tmp = src;\n#else\n\tstd::wstring tmp;\n\tencode(src, tmp);\n#endif\n\twchar_t* dst = (wchar_t*) p.palloc((tmp.length() + 1) * 
sizeof(wchar_t));\n\tdst[tmp.length()] = 0;\n\tstd::memcpy(dst, tmp.data(), tmp.length() * sizeof(wchar_t));\n\treturn dst;\n}\n\n\nunsigned int Transcoder::decode(const std::wstring& in,\n\tstd::wstring::const_iterator& iter)\n{\n#if defined(__STDC_ISO_10646__)\n\treturn *(iter++);\n#else\n\treturn decodeUTF16(in, iter);\n#endif\n}\n\n\nvoid Transcoder::encode(unsigned int sv, std::wstring& dst)\n{\n#if defined(__STDC_ISO_10646__)\n\tdst.append(1, sv);\n#else\n\n\tif (sizeof(wchar_t) == 4)\n\t{\n\t\tdst.append(1, sv);\n\t}\n\telse\n\t{\n\t\tencodeUTF16(sv, dst);\n\t}\n\n#endif\n}\n\n#endif\n\n\n\n#if LOG4CXX_UNICHAR_API || LOG4CXX_LOGCHAR_IS_UNICHAR\nvoid Transcoder::decode(const std::basic_string& src, LogString& dst)\n{\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\tdst.append(src);\n#else\n\n\tfor (std::basic_string::const_iterator i = src.begin();\n\t\ti != src.end();)\n\t{\n\t\tunsigned int cp = decode(src, i);\n\t\tencode(cp, dst);\n\t}\n\n#endif\n}\n\nvoid Transcoder::encode(const LogString& src, std::basic_string& dst)\n{\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\tdst.append(src);\n#else\n\n\tfor (LogString::const_iterator i = src.begin();\n\t\ti != src.end();)\n\t{\n\t\tunsigned int cp = decode(src, i);\n\t\tencode(cp, dst);\n\t}\n\n#endif\n}\n\nunsigned int Transcoder::decode(const std::basic_string& in,\n\tstd::basic_string::const_iterator& iter)\n{\n\treturn decodeUTF16(in, iter);\n}\n\nvoid Transcoder::encode(unsigned int sv, std::basic_string& dst)\n{\n\tencodeUTF16(sv, dst);\n}\n\n#endif\n\n#if LOG4CXX_CFSTRING_API\nvoid Transcoder::decode(const CFStringRef& src, LogString& dst)\n{\n\tauto length = CFStringGetLength(src);\n#if defined(_DEBUG)\n\tPool pool;\n\tLogString msg(LOG4CXX_STR(\"Transcoder::decodeCFString\"));\n\tmsg += LOG4CXX_STR(\" length \");\n\tStringHelper::toString((size_t)length, pool, msg);\n\tLogLog::debug(msg);\n#endif\n\n\tif (length > 0)\n\t{\n\t\tstd::vector tmp(length);\n\t\tCFStringGetCharacters(src, CFRangeMake(0, length), &tmp[0]);\n\t\tfor (auto i = tmp.begin(); i != tmp.end(); )\n\t\t{\n\t\t\tunsigned int cp = decodeUTF16(tmp, i);\n\t\t\tencode(cp, dst);\n\t\t}\n\t}\n}\n\nCFStringRef Transcoder::encode(const LogString& src)\n{\n\tstd::basic_string tmp;\n\tfor (auto ch : src)\n\t\tencodeUTF16(ch, tmp);\n\treturn CFStringCreateWithCharacters(kCFAllocatorDefault, tmp.data(), tmp.size());\n}\n#endif // #if LOG4CXX_CFSTRING_API\n\n\nlogchar Transcoder::decode(char val)\n{\n#if LOG4CXX_CHARSET_EBCDIC\n\tLogString dst;\n\tTranscoder::decode(std::string(1, val), dst);\n\treturn dst[0];\n#else\n\treturn val;\n#endif\n}\n\nLogString Transcoder::decode(const char* val)\n{\n#if LOG4CXX_LOGCHAR_IS_UTF8 && !LOG4CXX_CHARSET_EBCDIC\n\treturn val;\n#else\n\tLogString dst;\n\tTranscoder::decode(val, dst);\n\treturn dst;\n#endif\n}\n\n\nstd::string Transcoder::encodeCharsetName(const LogString& val)\n{\n\tchar asciiTable[] = { ' ', '!', '\"', '#', '$', '%', '&', '\\'', '(', ')', '*', '+', ',', '-', '.', '/',\n\t\t\t'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?',\n\t\t\t'@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O',\n\t\t\t'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\\\', ']', '^', '_',\n\t\t\t'`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',\n\t\t\t'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~'\n\t\t};\n\tstd::string out;\n\n\tfor (auto& item : val)\n\t{\n\t\tif (item >= 0x20 && item < 0x7F)\n\t\t{\n\t\t\tout.append(1, asciiTable[item - 
0x20]);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tout.append(1, LOSSCHAR);\n\t\t}\n\t}\n\n\treturn out;\n}\n\n// Path: src/main/cpp/linelocationpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(LineLocationPatternConverter)\n\nLineLocationPatternConverter::LineLocationPatternConverter() :\n\tLoggingEventPatternConverter(LOG4CXX_STR(\"Line\"),\n\t\tLOG4CXX_STR(\"line\"))\n{\n}\n\nPatternConverterPtr LineLocationPatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife instance = std::make_shared();\n\treturn instance;\n}\n\nvoid LineLocationPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& p) const\n{\n\tStringHelper::toString(\n\t\tevent->getLocationInformation().getLineNumber(),\n\t\tp, toAppendTo);\n}\n\n// Path: src/main/cpp/integerpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(IntegerPatternConverter)\n\nIntegerPatternConverter::IntegerPatternConverter() :\n\tPatternConverter(LOG4CXX_STR(\"Integer\"),\n\t\tLOG4CXX_STR(\"integer\"))\n{\n}\n\nPatternConverterPtr IntegerPatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife instance = std::make_shared();\n\treturn instance;\n}\n\nvoid IntegerPatternConverter::format(\n\tconst ObjectPtr& obj,\n\tLogString& toAppendTo,\n\tPool& p) const\n{\n\tIntegerPtr i = LOG4CXX_NS::cast(obj);\n\n\tif (i != NULL)\n\t{\n\t\tStringHelper::toString(i->intValue(), p, toAppendTo);\n\t}\n}\n\n// Path: src/main/cpp/triggeringpolicy.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::rolling;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(TriggeringPolicy)\n\nTriggeringPolicy::~TriggeringPolicy()\n{\n}\n\n// Path: src/main/cpp/fallbackerrorhandler.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::varia;\n\nIMPLEMENT_LOG4CXX_OBJECT(FallbackErrorHandler)\n\nstruct FallbackErrorHandler::FallbackErrorHandlerPrivate\n{\n\tAppenderWeakPtr backup;\n\tAppenderWeakPtr primary;\n\tstd::vector loggers;\n\tbool errorReported = false;\n};\n\nFallbackErrorHandler::FallbackErrorHandler()\n\t: m_priv(std::make_unique())\n{\n}\n\nFallbackErrorHandler::~FallbackErrorHandler() {}\n\nvoid FallbackErrorHandler::setLogger(const LoggerPtr& logger)\n{\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: Adding logger [\"))\n\t\t+ logger->getName() + LOG4CXX_STR(\"].\"));\n\tm_priv->loggers.push_back(logger);\n}\n\nvoid FallbackErrorHandler::error(const LogString& message,\n\tconst std::exception& e,\n\tint errorCode) const\n{\n\terror(message, e, errorCode, 0);\n}\n\nvoid FallbackErrorHandler::error(const LogString& message,\n\tconst std::exception& e,\n\tint, const spi::LoggingEventPtr&) const\n{\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: The following error reported: \"))\n\t\t+ message, e);\n\tLogLog::debug(LOG4CXX_STR(\"FB: INITIATING FALLBACK PROCEDURE.\"));\n\n\tAppenderPtr primaryLocked = m_priv->primary.lock();\n\tAppenderPtr backupLocked = m_priv->backup.lock();\n\n\tif ( !primaryLocked || !backupLocked )\n\t{\n\t\treturn;\n\t}\n\n\tfor (LoggerPtr l : m_priv->loggers)\n\t{\n\t\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: Searching for [\"))\n\t\t\t+ primaryLocked->getName() + LOG4CXX_STR(\"] in logger [\")\n\t\t\t+ l->getName() + LOG4CXX_STR(\"].\"));\n\t\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: Replacing [\"))\n\t\t\t+ primaryLocked->getName() + LOG4CXX_STR(\"] by [\")\n\t\t\t+ backupLocked->getName() + LOG4CXX_STR(\"] in logger [\")\n\t\t\t+ l->getName() + LOG4CXX_STR(\"].\"));\n\t\tl->removeAppender(primaryLocked);\n\t\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: Adding appender [\"))\n\t\t\t+ backupLocked->getName() + LOG4CXX_STR(\"] to logger \")\n\t\t\t+ l->getName());\n\t\tl->addAppender(backupLocked);\n\t}\n\tm_priv->errorReported = true;\n}\n\nvoid FallbackErrorHandler::setAppender(const AppenderPtr& primary1)\n{\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: Setting primary appender to [\"))\n\t\t+ primary1->getName() + LOG4CXX_STR(\"].\"));\n\tm_priv->primary = primary1;\n}\n\nvoid FallbackErrorHandler::setBackupAppender(const AppenderPtr& backup1)\n{\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"FB: Setting backup appender to [\"))\n\t\t+ backup1->getName() + LOG4CXX_STR(\"].\"));\n\tm_priv->backup = backup1;\n\n\t// Make sure that we keep a reference to the appender around, since otherwise\n\t// the appender would be lost if it has no loggers that use it.\n\tLoggerRepository* repository = LogManager::getRootLogger()->getLoggerRepository();\n\tHierarchy* hierarchy = dynamic_cast(repository);\n\tif(hierarchy){\n\t\thierarchy->addAppender(backup1);\n\t}\n\n}\n\nvoid FallbackErrorHandler::activateOptions(Pool&)\n{\n}\n\nvoid 
FallbackErrorHandler::setOption(const LogString&, const LogString&)\n{\n}\n\nbool FallbackErrorHandler::errorReported() const\n{\n\treturn m_priv->errorReported;\n}\n\n// Path: src/main/cpp/ndc.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nNDC::NDC(const std::string& message)\n{\n\tpush(message);\n}\n\nNDC::~NDC()\n{\n\tpop();\n}\n\n\nLogString& NDC::getMessage(NDC::DiagnosticContext& ctx)\n{\n\treturn ctx.first;\n}\n\nLogString& NDC::getFullMessage(NDC::DiagnosticContext& ctx)\n{\n\treturn ctx.second;\n}\n\nvoid NDC::clear()\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\twhile (!stack.empty())\n\t\t{\n\t\t\tstack.pop();\n\t\t}\n\n\t\tdata->recycle();\n\t}\n}\n\nNDC::Stack* NDC::cloneStack()\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\treturn new Stack(stack);\n\t\t}\n\t}\n\n\treturn new Stack();\n}\n\nvoid NDC::inherit(NDC::Stack* stack)\n{\n\tif (stack != NULL)\n\t{\n\t\tThreadSpecificData::inherit(*stack);\n\t\tdelete stack;\n\t}\n}\n\n\nbool NDC::get(LogString& dest)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tdest.append(getFullMessage(stack.top()));\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\nint NDC::getDepth()\n{\n\tint size = 0;\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tsize = (int)data->getStack().size();\n\n\t\tif (size == 0)\n\t\t{\n\t\t\tdata->recycle();\n\t\t}\n\t}\n\n\treturn size;\n}\n\nLogString NDC::pop()\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tLogString value(getMessage(stack.top()));\n\t\t\tstack.pop();\n\t\t\tdata->recycle();\n\t\t\treturn value;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn LogString();\n}\n\nbool NDC::pop(std::string& dst)\n{\n\tbool retval = false;\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tTranscoder::encode(getMessage(stack.top()), dst);\n\t\t\tstack.pop();\n\t\t\tretval = true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn retval;\n}\n\nLogString NDC::peek()\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = 
data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\treturn getMessage(stack.top());\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn LogString();\n}\n\nbool NDC::peek(std::string& dst)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tTranscoder::encode(getMessage(stack.top()), dst);\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\nvoid NDC::pushLS(const LogString& message)\n{\n\tThreadSpecificData::push(message);\n}\n\nvoid NDC::push(const std::string& message)\n{\n\tLOG4CXX_DECODE_CHAR(msg, message);\n\tpushLS(msg);\n}\n\nvoid NDC::remove()\n{\n\tclear();\n}\n\nbool NDC::empty()\n{\n\tbool empty = true;\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\t\tempty = stack.empty();\n\n\t\tif (empty)\n\t\t{\n\t\t\tdata->recycle();\n\t\t}\n\t}\n\n\treturn empty;\n}\n\n#if LOG4CXX_WCHAR_T_API\nNDC::NDC(const std::wstring& message)\n{\n\tpush(message);\n}\n\nvoid NDC::push(const std::wstring& message)\n{\n\tLOG4CXX_DECODE_WCHAR(msg, message);\n\tpushLS(msg);\n}\n\nbool NDC::pop(std::wstring& dst)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tTranscoder::encode(getMessage(stack.top()), dst);\n\t\t\tstack.pop();\n\t\t\tdata->recycle();\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\nbool NDC::peek(std::wstring& dst)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tTranscoder::encode(getMessage(stack.top()), dst);\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\n#endif\n\n\n#if LOG4CXX_UNICHAR_API\nNDC::NDC(const std::basic_string& message)\n{\n\tpush(message);\n}\n\nvoid NDC::push(const std::basic_string& message)\n{\n\tLOG4CXX_DECODE_UNICHAR(msg, message);\n\tpushLS(msg);\n}\n\nbool NDC::pop(std::basic_string& dst)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tTranscoder::encode(getMessage(stack.top()), dst);\n\t\t\tstack.pop();\n\t\t\tdata->recycle();\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\nbool NDC::peek(std::basic_string& dst)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tTranscoder::encode(getMessage(stack.top()), dst);\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\n#endif\n\n\n#if LOG4CXX_CFSTRING_API\nNDC::NDC(const CFStringRef& message)\n{\n\tpush(message);\n}\n\nvoid NDC::push(const CFStringRef& message)\n{\n\tLOG4CXX_DECODE_CFSTRING(msg, message);\n\tpushLS(msg);\n}\n\nbool NDC::pop(CFStringRef& dst)\n{\n\tThreadSpecificData* data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tdst = Transcoder::encode(getMessage(stack.top()));\n\t\t\tstack.pop();\n\t\t\tdata->recycle();\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\nbool NDC::peek(CFStringRef& dst)\n{\n\tThreadSpecificData* 
data = ThreadSpecificData::getCurrentData();\n\n\tif (data != 0)\n\t{\n\t\tStack& stack = data->getStack();\n\n\t\tif (!stack.empty())\n\t\t{\n\t\t\tdst = Transcoder::encode(getMessage(stack.top()));\n\t\t\treturn true;\n\t\t}\n\n\t\tdata->recycle();\n\t}\n\n\treturn false;\n}\n\n#endif\n\n\n// Path: src/main/cpp/fileinputstream.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct FileInputStream::FileInputStreamPrivate\n{\n\tFileInputStreamPrivate() : fileptr(nullptr) {}\n\n\tPool pool;\n\tapr_file_t* fileptr;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(FileInputStream)\n\nFileInputStream::FileInputStream(const LogString& filename) :\n\tm_priv(std::make_unique())\n{\n\topen(filename);\n}\n\nFileInputStream::FileInputStream(const logchar* filename) :\n\tm_priv(std::make_unique())\n{\n\tLogString fn(filename);\n\topen(fn);\n}\n\n\nvoid FileInputStream::open(const LogString& filename)\n{\n\tapr_fileperms_t perm = APR_OS_DEFAULT;\n\tapr_int32_t flags = APR_READ;\n\tapr_status_t stat = File().setPath(filename).open(&m_priv->fileptr, flags, perm, m_priv->pool);\n\n\tif (stat != APR_SUCCESS)\n\t{\n\t\tthrow IOException(stat);\n\t}\n}\n\n\nFileInputStream::FileInputStream(const File& aFile) :\n\tm_priv(std::make_unique())\n{\n\tapr_fileperms_t perm = APR_OS_DEFAULT;\n\tapr_int32_t flags = APR_READ;\n\tapr_status_t stat = aFile.open(&m_priv->fileptr, flags, perm, m_priv->pool);\n\n\tif (stat != APR_SUCCESS)\n\t{\n\t\tthrow IOException(stat);\n\t}\n}\n\n\nFileInputStream::~FileInputStream()\n{\n\tif (m_priv->fileptr != NULL && !APRInitializer::isDestructed)\n\t{\n\t\tapr_file_close(m_priv->fileptr);\n\t}\n}\n\n\nvoid FileInputStream::close()\n{\n\tapr_status_t stat = apr_file_close(m_priv->fileptr);\n\n\tif (stat == APR_SUCCESS)\n\t{\n\t\tm_priv->fileptr = NULL;\n\t}\n\telse\n\t{\n\t\tthrow IOException(stat);\n\t}\n}\n\n\nint FileInputStream::read(ByteBuffer& buf)\n{\n\tapr_size_t bytesRead = buf.remaining();\n\tapr_status_t stat = apr_file_read(m_priv->fileptr, buf.current(), &bytesRead);\n\tint retval = -1;\n\n\tif (!APR_STATUS_IS_EOF(stat))\n\t{\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tbuf.position(buf.position() + bytesRead);\n\t\tretval = (int)bytesRead;\n\t}\n\n\treturn retval;\n}\n\n// Path: src/main/cpp/classnamepatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(ClassNamePatternConverter)\n\nClassNamePatternConverter::ClassNamePatternConverter(\n\tconst std::vector& options) :\n\tNamePatternConverter(LOG4CXX_STR(\"Class Name\"),\n\t\tLOG4CXX_STR(\"class name\"), options)\n{\n}\n\nPatternConverterPtr ClassNamePatternConverter::newInstance(\n\tconst std::vector& options)\n{\n\tif (options.size() == 0)\n\t{\n\t\tstatic WideLife def = std::make_shared(options);\n\t\treturn def;\n\t}\n\n\treturn std::make_shared(options);\n}\n\nvoid ClassNamePatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& /* p */) const\n{\n\tint initialLength = (int)toAppendTo.length();\n\tappend(toAppendTo, event->getLocationInformation().getClassName());\n\tabbreviate(initialLength, toAppendTo);\n}\n\n// Path: src/main/cpp/pool.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n...\n", "needle_token_start": 5095, "needle_token_end": 5486, "code_context_ntokens": 15141, "output": ["```cpp\ntemplate\nstatic unsigned int decodeUTF16(const String& in, Iterator& iter)\n{\n\tunsigned int ch1 = *iter;\n\n\t//\n\t// if not surrogate pair\n\t//\n\tif (ch1 < 0xD800 || ch1 > 0xDFFF)\n\t{\n\t\t//\n\t\t// then advance iterator and return wchar_t value\n\t\t//\n\t\tif (ch1 != 0xFFFF)\n\t\t{\n\t\t\titer++;\n\t\t}\n\n\t\treturn ch1;\n\t}\n\telse if (ch1 < 0xDC00)\n\t{\n\t\t//\n\t\t// started with high-surrogate value\n\t\t// if there is an additional wchar_t\n\t\tIterator iter2 = iter + 1;\n\n\t\tif (iter2 != in.end())\n\t\t{\n\t\t\tunsigned int ch2 = *iter2;\n\n\t\t\t//\n\t\t\t// if it is a matching low surrogate then\n\t\t\t// advance the iterator and return the scalar value\n\t\t\tif (ch2 >= 0xDC00 && ch2 <= 0xDFFF)\n\t\t\t{\n\t\t\t\titer += 2;\n\t\t\t\treturn (ch1 - 0xD800) * 0x400 + (ch2 - 0xDC00) + 0x10000;\n\t\t\t}\n\t\t}\n\t}\n\n\t//\n\t// unrecognized value, do not advance iterator\n\t//\n\treturn 0xFFFF;\n}\n```"]} +{"repo": "apache/logging-log4cxx", "name": "catch", "language": "cpp", "path": "src/main/cpp/odbcappender.cpp", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this function is to handle exceptions that occur during the execution of a prepared SQL statement in a database logging context. It ensures that any SQL-related errors are caught and reported appropriately.\n2. **Input**: The input includes a prepared SQL statement and a logging event that contains the data to be logged.\n3. **Output**: There is no direct output from this function; however, it affects the program by logging error messages to an error handler if an exception occurs.\n4. **Procedure**: The procedure involves attempting to execute a prepared SQL statement with the provided logging event data. If the execution fails (indicated by a negative return value), an SQL exception is thrown. 
This exception is then caught, and an error message is logged through an error handler, specifying the nature of the failure and the error code associated with the SQL execution failure.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include // std::pow\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#if LOG4CXX_HAVE_ODBC\n\t#if defined(WIN32) || defined(_WIN32)\n\t\t#include \n\t#endif\n\t#include \n#else\n\ttypedef void* SQLHSTMT;\n#endif\n#include \n#if defined(min)\n\t#undef min\n#endif\n#include \n#include \n\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::db;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::pattern;\n\nSQLException::SQLException(short fHandleType,\n\tvoid* hInput, const char* prolog,\n\tLOG4CXX_NS::helpers::Pool& p)\n\t: Exception(formatMessage(fHandleType, hInput, prolog, p))\n{\n}\n\n\nSQLException::SQLException(const char* msg)\n\t: Exception(msg)\n{\n}\n\nSQLException::SQLException(const SQLException& src)\n\t: Exception(src)\n{\n}\n\nconst char* SQLException::formatMessage(short fHandleType,\n\tvoid* hInput, const char* prolog, LOG4CXX_NS::helpers::Pool& p)\n{\n\tstd::string strReturn(prolog);\n\tstrReturn.append(\" - \");\n#if LOG4CXX_HAVE_ODBC\n\tSQLCHAR SqlState[6];\n\tSQLCHAR Msg[SQL_MAX_MESSAGE_LENGTH];\n\tSQLINTEGER NativeError;\n\tSQLSMALLINT i;\n\tSQLSMALLINT MsgLen;\n\tSQLRETURN rc2;\n\n\t// Get the status records.\n\ti = 1;\n\n\twhile ((rc2 = SQLGetDiagRecA(fHandleType, hInput, i, SqlState, &NativeError,\n\t\t\t\t\tMsg, sizeof(Msg), &MsgLen)) != SQL_NO_DATA)\n\t{\n\t\tstrReturn.append((char*) Msg);\n\t\ti++;\n\t}\n\n#else\n\tstrReturn.append(\"log4cxx built without ODBC support\");\n#endif\n\n\treturn apr_pstrdup((apr_pool_t*) p.getAPRPool(), strReturn.c_str());\n}\n\n\nIMPLEMENT_LOG4CXX_OBJECT(ODBCAppender)\n\n#define _priv static_cast(m_priv.get())\n\nODBCAppender::ODBCAppender()\n\t: AppenderSkeleton (std::make_unique(\n#if LOG4CXX_EVENTS_AT_EXIT\n\t\t[this] {\n\t\t\tstd::lock_guard lock(_priv->mutex);\n\t\t\tif(_priv->closed)\n\t\t\t\treturn;\n\t\t\ttry\n\t\t\t{\n\t\t\t\tflushBuffer(_priv->pool);\n\t\t\t}\n\t\t\tcatch (SQLException& e)\n\t\t\t{\n\t\t\t\t_priv->errorHandler->error(LOG4CXX_STR(\"Error flushing connection\"),\n\t\t\t\t\te, ErrorCode::GENERIC_FAILURE);\n\t\t\t}\n\t\t}\n#endif\n\t\t\t\t\t\t\t\t))\n{\n}\n\nODBCAppender::~ODBCAppender()\n{\n\tfinalize();\n}\n\n#define RULES_PUT(spec, cls) \\\n\tspecs.insert(PatternMap::value_type(LogString(LOG4CXX_STR(spec)), cls ::newInstance))\n\nstatic PatternMap getFormatSpecifiers()\n{\n\tPatternMap specs;\n\tif (specs.empty())\n\t{\n\t\tRULES_PUT(\"logger\", LoggerPatternConverter);\n\t\tRULES_PUT(\"class\", ClassNamePatternConverter);\n\t\tRULES_PUT(\"time\", DatePatternConverter);\n\t\tRULES_PUT(\"shortfilename\", ShortFileLocationPatternConverter);\n\t\tRULES_PUT(\"fullfilename\", FileLocationPatternConverter);\n\t\tRULES_PUT(\"location\", FullLocationPatternConverter);\n\t\tRULES_PUT(\"line\", 
LineLocationPatternConverter);\n\t\tRULES_PUT(\"message\", MessagePatternConverter);\n\t\tRULES_PUT(\"method\", MethodLocationPatternConverter);\n\t\tRULES_PUT(\"level\", LevelPatternConverter);\n\t\tRULES_PUT(\"thread\", ThreadPatternConverter);\n\t\tRULES_PUT(\"threadname\", ThreadUsernamePatternConverter);\n\t\tRULES_PUT(\"mdc\", MDCPatternConverter);\n\t\tRULES_PUT(\"ndc\", NDCPatternConverter);\n\t}\n\treturn specs;\n}\n\nvoid ODBCAppender::setOption(const LogString& option, const LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"BUFFERSIZE\"), LOG4CXX_STR(\"buffersize\")))\n\t{\n\t\tsetBufferSize((size_t)OptionConverter::toInt(value, 1));\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"PASSWORD\"), LOG4CXX_STR(\"password\")))\n\t{\n\t\tsetPassword(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"SQL\"), LOG4CXX_STR(\"sql\")))\n\t{\n\t\tsetSql(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"URL\"), LOG4CXX_STR(\"url\"))\n\t\t|| StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"DSN\"), LOG4CXX_STR(\"dsn\"))\n\t\t|| StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"CONNECTIONSTRING\"), LOG4CXX_STR(\"connectionstring\")) )\n\t{\n\t\tsetURL(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"USER\"), LOG4CXX_STR(\"user\")))\n\t{\n\t\tsetUser(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"COLUMNMAPPING\"), LOG4CXX_STR(\"columnmapping\")))\n\t{\n\t\t_priv->mappedName.push_back(value);\n\t}\n\telse\n\t{\n\t\tAppenderSkeleton::setOption(option, value);\n\t}\n}\n\n// Does ODBCAppender require a layout?\n\nbool ODBCAppender::requiresLayout() const\n{\n\treturn false;\n}\n\nvoid ODBCAppender::activateOptions(LOG4CXX_NS::helpers::Pool&)\n{\n#if !LOG4CXX_HAVE_ODBC\n\tLogLog::error(LOG4CXX_STR(\"Can not activate ODBCAppender unless compiled with ODBC support.\"));\n#else\n\tif (_priv->mappedName.empty())\n\t{\n\t\tLogLog::error(LOG4CXX_STR(\"ODBCAppender column mappings not defined, logging events will not be inserted\"));\n\t}\n\tauto specs = getFormatSpecifiers();\n\tfor (auto& name : _priv->mappedName)\n\t{\n\t\tauto lowerName = StringHelper::toLowerCase(name);\n\t\tauto pItem = specs.find(lowerName);\n\t\tif (specs.end() == pItem)\n\t\t{\n\t\t\tif (lowerName.size() < 5\n\t\t\t || lowerName.substr(0, 4) != LOG4CXX_STR(\"mdc{\"))\n\t\t\t\tLogLog::error(name + LOG4CXX_STR(\" is not a supported ColumnMapping value\"));\n\t\t\telse // A single MDC entry\n\t\t\t{\n\t\t\t\tauto index = lowerName.find(0x7D /* '}' */, 4);\n\t\t\t\tauto len = (lowerName.npos == index ? 
lowerName.size() : index) - 4;\n\t\t\t\tODBCAppenderPriv::DataBinding paramData{ 0, 0, 0, 0, 0 };\n\t\t\t\tparamData.converter = std::make_shared(lowerName.substr(4, len));\n\t\t\t\t_priv->parameterValue.push_back(paramData);\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\tODBCAppenderPriv::DataBinding paramData{ 0, 0, 0, 0, 0 };\n\t\t\tstd::vector options;\n\t\t\tif (LOG4CXX_STR(\"time\") == pItem->first)\n\t\t\t\toptions.push_back(LOG4CXX_STR(\"yyyy-MM-dd HH:mm:ss.SSSSSS\"));\n\t\t\tparamData.converter = LOG4CXX_NS::cast((pItem->second)(options));\n\t\t\t_priv->parameterValue.push_back(paramData);\n\t\t}\n\t}\n#endif\n}\n\n\nvoid ODBCAppender::append(const spi::LoggingEventPtr& event, LOG4CXX_NS::helpers::Pool& p)\n{\n#if LOG4CXX_HAVE_ODBC\n\t_priv->buffer.push_back(event);\n\n\tif (_priv->buffer.size() >= _priv->bufferSize)\n\t{\n\t\tflushBuffer(p);\n\t}\n\n#endif\n}\n\nLogString ODBCAppender::getLogStatement(const spi::LoggingEventPtr& event, LOG4CXX_NS::helpers::Pool& p) const\n{\n return LogString();\n}\n\nvoid ODBCAppender::execute(const LogString& sql, LOG4CXX_NS::helpers::Pool& p)\n{\n}\n\n/* The default behavior holds a single connection open until the appender\nis closed (typically when garbage collected).*/\nvoid ODBCAppender::closeConnection(ODBCAppender::SQLHDBC /* con */)\n{\n}\n\nODBCAppender::SQLHDBC ODBCAppender::getConnection(LOG4CXX_NS::helpers::Pool& p)\n{\n#if LOG4CXX_HAVE_ODBC\n\tSQLRETURN ret;\n\n\tif (_priv->env == SQL_NULL_HENV)\n\t{\n\t\tret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &_priv->env);\n\n\t\tif (ret < 0)\n\t\t{\n\t\t\tSQLException ex(SQL_HANDLE_ENV, _priv->env, \"Failed to allocate SQL handle\", p);\n\t\t\t_priv->env = SQL_NULL_HENV;\n\t\t\tthrow ex;\n\t\t}\n\n\t\tret = SQLSetEnvAttr(_priv->env, SQL_ATTR_ODBC_VERSION, (SQLPOINTER) SQL_OV_ODBC3, SQL_IS_INTEGER);\n\n\t\tif (ret < 0)\n\t\t{\n\t\t\tSQLException ex(SQL_HANDLE_ENV, _priv->env, \"Failed to set odbc version\", p);\n\t\t\tSQLFreeHandle(SQL_HANDLE_ENV, _priv->env);\n\t\t\t_priv->env = SQL_NULL_HENV;\n\t\t\tthrow ex;\n\t\t}\n\t}\n\n\tif (_priv->connection == SQL_NULL_HDBC)\n\t{\n\t\tret = SQLAllocHandle(SQL_HANDLE_DBC, _priv->env, &_priv->connection);\n\n\t\tif (ret < 0)\n\t\t{\n\t\t\tSQLException ex(SQL_HANDLE_DBC, _priv->connection, \"Failed to allocate sql handle\", p);\n\t\t\t_priv->connection = SQL_NULL_HDBC;\n\t\t\tthrow ex;\n\t\t}\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\tSQLWCHAR *wUser = nullptr, *wPwd = nullptr;\n\t\tif (!_priv->databaseUser.empty())\n\t\t\twUser = (SQLWCHAR*)_priv->databaseUser.c_str();\n\t\tif (!_priv->databasePassword.empty())\n\t\t\twPwd = (SQLWCHAR*)_priv->databasePassword.c_str();\n\t\tret = SQLConnectW(_priv->connection\n\t\t\t, (SQLWCHAR*)_priv->databaseURL.c_str(), SQL_NTS\n\t\t\t, wUser, SQL_NTS\n\t\t\t, wPwd, SQL_NTS\n\t\t\t);\n#elif LOG4CXX_LOGCHAR_IS_UTF8\n\t\tSQLCHAR *wUser = nullptr, *wPwd = nullptr;\n\t\tif (!_priv->databaseUser.empty())\n\t\t\twUser = (SQLCHAR*)_priv->databaseUser.c_str();\n\t\tif (!_priv->databasePassword.empty())\n\t\t\twPwd = (SQLCHAR*)_priv->databasePassword.c_str();\n\t\tret = SQLConnectA(_priv->connection\n\t\t\t, (SQLCHAR*)_priv->databaseURL.c_str(), SQL_NTS\n\t\t\t, wUser, SQL_NTS\n\t\t\t, wPwd, SQL_NTS\n\t\t\t);\n#else\n\t\tSQLWCHAR* wURL, *wUser = nullptr, *wPwd = nullptr;\n\t\tencode(&wURL, _priv->databaseURL, p);\n\t\tif (!_priv->databaseUser.empty())\n\t\t\tencode(&wUser, _priv->databaseUser, p);\n\t\tif (!_priv->databasePassword.empty())\n\t\t\tencode(&wPwd, _priv->databasePassword, p);\n\n\t\tret = SQLConnectW( 
_priv->connection\n\t\t\t, wURL, SQL_NTS\n\t\t\t, wUser, SQL_NTS\n\t\t\t, wPwd, SQL_NTS\n\t\t\t);\n#endif\n\n\t\tif (ret < 0)\n\t\t{\n\t\t\tSQLException ex(SQL_HANDLE_DBC, _priv->connection, \"Failed to connect to database\", p);\n\t\t\tSQLFreeHandle(SQL_HANDLE_DBC, _priv->connection);\n\t\t\t_priv->connection = SQL_NULL_HDBC;\n\t\t\tthrow ex;\n\t\t}\n\t}\n\n\treturn _priv->connection;\n#else\n\treturn 0;\n#endif\n}\n\nvoid ODBCAppender::close()\n{\n\tif (_priv->closed)\n\t{\n\t\treturn;\n\t}\n\n\tPool p;\n\n\ttry\n\t{\n\t\tflushBuffer(p);\n\t}\n\tcatch (SQLException& e)\n\t{\n\t\t_priv->errorHandler->error(LOG4CXX_STR(\"Error closing connection\"),\n\t\t\te, ErrorCode::GENERIC_FAILURE);\n\t}\n\n#if LOG4CXX_HAVE_ODBC\n\n\tif (_priv->connection != SQL_NULL_HDBC)\n\t{\n\t\tSQLDisconnect(_priv->connection);\n\t\tSQLFreeHandle(SQL_HANDLE_DBC, _priv->connection);\n\t}\n\n\tif (_priv->env != SQL_NULL_HENV)\n\t{\n\t\tSQLFreeHandle(SQL_HANDLE_ENV, _priv->env);\n\t}\n\n#endif\n\t_priv->closed = true;\n}\n\n#if LOG4CXX_HAVE_ODBC\nvoid ODBCAppender::ODBCAppenderPriv::setPreparedStatement(SQLHDBC con, Pool& p)\n{\n\tauto ret = SQLAllocHandle( SQL_HANDLE_STMT, con, &this->preparedStatement);\n\tif (ret < 0)\n\t{\n\t\tthrow SQLException( SQL_HANDLE_DBC, con, \"Failed to allocate statement handle.\", p);\n\t}\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\tret = SQLPrepareW(this->preparedStatement, (SQLWCHAR*)this->sqlStatement.c_str(), SQL_NTS);\n#elif LOG4CXX_LOGCHAR_IS_UTF8\n\tret = SQLPrepareA(this->preparedStatement, (SQLCHAR*)this->sqlStatement.c_str(), SQL_NTS);\n#else\n\tSQLWCHAR* wsql;\n\tencode(&wsql, this->sqlStatement, p);\n\tret = SQLPrepareW(this->preparedStatement, wsql, SQL_NTS);\n#endif\n\tif (ret < 0)\n\t{\n\t\tthrow SQLException(SQL_HANDLE_STMT, this->preparedStatement, \"Failed to prepare sql statement.\", p);\n\t}\n\n\tint parameterNumber = 0;\n\tfor (auto& item : this->parameterValue)\n\t{\n\t\t++parameterNumber;\n\t\tSQLSMALLINT targetType;\n\t\tSQLULEN targetMaxCharCount;\n\t\tSQLSMALLINT decimalDigits;\n\t\tSQLSMALLINT nullable;\n\t\tauto ret = SQLDescribeParam\n\t\t\t( this->preparedStatement\n\t\t\t, parameterNumber\n\t\t\t, &targetType\n\t\t\t, &targetMaxCharCount\n\t\t\t, &decimalDigits\n\t\t\t, &nullable\n\t\t\t);\n\t\tif (ret < 0)\n\t\t{\n\t\t\tthrow SQLException(SQL_HANDLE_STMT, this->preparedStatement, \"Failed to describe parameter\", p);\n\t\t}\n\t\tif (SQL_CHAR == targetType || SQL_VARCHAR == targetType || SQL_LONGVARCHAR == targetType)\n\t\t{\n\t\t\titem.paramType = SQL_C_CHAR;\n\t\t\titem.paramMaxCharCount = targetMaxCharCount;\n\t\t\titem.paramValueSize = (SQLINTEGER)(item.paramMaxCharCount) * sizeof(char) + sizeof(char);\n\t\t\titem.paramValue = (SQLPOINTER)p.palloc(item.paramValueSize + sizeof(char));\n\t\t}\n\t\telse if (SQL_WCHAR == targetType || SQL_WVARCHAR == targetType || SQL_WLONGVARCHAR == targetType)\n\t\t{\n\t\t\titem.paramType = SQL_C_WCHAR;\n\t\t\titem.paramMaxCharCount = targetMaxCharCount;\n\t\t\titem.paramValueSize = (SQLINTEGER)(targetMaxCharCount) * sizeof(wchar_t) + sizeof(wchar_t);\n\t\t\titem.paramValue = (SQLPOINTER)p.palloc(item.paramValueSize + sizeof(wchar_t));\n\t\t}\n\t\telse if (SQL_TYPE_TIMESTAMP == targetType || SQL_TYPE_DATE == targetType || SQL_TYPE_TIME == targetType\n\t\t\t|| SQL_DATETIME == targetType)\n\t\t{\n\t\t\titem.paramType = SQL_C_TYPE_TIMESTAMP;\n\t\t\titem.paramMaxCharCount = (0 <= decimalDigits) ? 
decimalDigits : 6;\n\t\t\titem.paramValueSize = sizeof(SQL_TIMESTAMP_STRUCT);\n\t\t\titem.paramValue = (SQLPOINTER)p.palloc(item.paramValueSize);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tif (SQL_INTEGER != targetType)\n\t\t\t{\n\t\t\t\tLogString msg(LOG4CXX_STR(\"Unexpected targetType (\"));\n\t\t\t\thelpers::StringHelper::toString(targetType, p, msg);\n\t\t\t\tmsg += LOG4CXX_STR(\") at parameter \");\n\t\t\t\thelpers::StringHelper::toString(parameterNumber, p, msg);\n\t\t\t\tmsg += LOG4CXX_STR(\" while preparing SQL\");\n\t\t\t\tLogLog::warn(msg);\n\t\t\t}\n\t\t\titem.paramMaxCharCount = 30;\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\titem.paramType = SQL_C_CHAR;\n\t\t\titem.paramValueSize = (SQLINTEGER)(item.paramMaxCharCount) * sizeof(char);\n\t\t\titem.paramValue = (SQLPOINTER)p.palloc(item.paramValueSize + sizeof(char));\n#else\n\t\t\titem.paramType = SQL_C_WCHAR;\n\t\t\titem.paramValueSize = (SQLINTEGER)(item.paramMaxCharCount) * sizeof(wchar_t);\n\t\t\titem.paramValue = (SQLPOINTER)p.palloc(item.paramValueSize + sizeof(wchar_t));\n#endif\n\t\t}\n\t\titem.strLen_or_Ind = SQL_NTS;\n\t\tret = SQLBindParameter\n\t\t\t( this->preparedStatement\n\t\t\t, parameterNumber\n\t\t\t, SQL_PARAM_INPUT\n\t\t\t, item.paramType // ValueType\n\t\t\t, targetType\n\t\t\t, targetMaxCharCount\n\t\t\t, decimalDigits\n\t\t\t, item.paramValue\n\t\t\t, item.paramValueSize\n\t\t\t, &item.strLen_or_Ind\n\t\t\t);\n\t\tif (ret < 0)\n\t\t{\n\t\t\tthrow SQLException(SQL_HANDLE_STMT, this->preparedStatement, \"Failed to bind parameter\", p);\n\t\t}\n\t}\n}\n\nvoid ODBCAppender::ODBCAppenderPriv::setParameterValues(const spi::LoggingEventPtr& event, Pool& p)\n{\n\tfor (auto& item : this->parameterValue)\n\t{\n\t\tif (!item.paramValue || item.paramValueSize <= 0)\n\t\t\t;\n\t\telse if (SQL_C_WCHAR == item.paramType)\n\t\t{\n\t\t\tLogString sbuf;\n\t\t\titem.converter->format(event, sbuf, p);\n#if LOG4CXX_LOGCHAR_IS_WCHAR_T\n\t\t\tstd::wstring& tmp = sbuf;\n#else\n\t\t\tstd::wstring tmp;\n\t\t\tTranscoder::encode(sbuf, tmp);\n#endif\n\t\t\tauto dst = (wchar_t*)item.paramValue;\n\t\t\tauto charCount = std::min(size_t(item.paramMaxCharCount), tmp.size());\n\t\t\tauto copySize = std::min(size_t(item.paramValueSize - 1), charCount * sizeof(wchar_t));\n\t\t\tstd::memcpy(dst, tmp.data(), copySize);\n\t\t\tdst[copySize / sizeof(wchar_t)] = 0;\n\t\t}\n\t\telse if (SQL_C_CHAR == item.paramType)\n\t\t{\n\t\t\tLogString sbuf;\n\t\t\titem.converter->format(event, sbuf, p);\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tstd::string& tmp = sbuf;\n#else\n\t\t\tstd::string tmp;\n\t\t\tTranscoder::encode(sbuf, tmp);\n#endif\n\t\t\tauto dst = (char*)item.paramValue;\n\t\t\tauto sz = std::min(size_t(item.paramMaxCharCount), tmp.size());\n\t\t\tauto copySize = std::min(size_t(item.paramValueSize - 1), sz * sizeof(char));\n\t\t\tstd::memcpy(dst, tmp.data(), copySize);\n\t\t\tdst[copySize] = 0;\n\t\t}\n\t\telse if (SQL_C_TYPE_TIMESTAMP == item.paramType)\n\t\t{\n\t\t\tapr_time_exp_t exploded;\n\t\t\tapr_status_t stat = this->timeZone->explode(&exploded, event->getTimeStamp());\n\t\t\tif (stat == APR_SUCCESS)\n\t\t\t{\n\t\t\t\tauto dst = (SQL_TIMESTAMP_STRUCT*)item.paramValue;\n\t\t\t\tdst->year = 1900 + exploded.tm_year;\n\t\t\t\tdst->month = 1 + exploded.tm_mon;\n\t\t\t\tdst->day = exploded.tm_mday;\n\t\t\t\tdst->hour = exploded.tm_hour;\n\t\t\t\tdst->minute = exploded.tm_min;\n\t\t\t\tdst->second = exploded.tm_sec;\n\t\t\t\t// Prevent '[ODBC SQL Server Driver]Datetime field overflow' by rounding to the target field precision\n\t\t\t\tint roundingExponent = 6 - 
(int)item.paramMaxCharCount;\n\t\t\t\tif (0 < roundingExponent)\n\t\t\t\t{\n\t\t\t\t\tint roundingDivisor = (int)std::pow(10, roundingExponent);\n\t\t\t\t\tdst->fraction = 1000 * roundingDivisor * ((exploded.tm_usec + roundingDivisor / 2) / roundingDivisor);\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t\tdst->fraction = 1000 * exploded.tm_usec;\n\t\t\t}\n\t\t}\n\t}\n}\n#endif\n\nvoid ODBCAppender::flushBuffer(Pool& p)\n{\n\tfor (auto& logEvent : _priv->buffer)\n\t{\n\t\tif (_priv->parameterValue.empty())\n\t\t\t_priv->errorHandler->error(LOG4CXX_STR(\"ODBCAppender column mappings not defined\"));\n#if LOG4CXX_HAVE_ODBC\n\t\telse try\n\t\t{\n\t\t\tif (0 == _priv->preparedStatement)\n\t\t\t\t_priv->setPreparedStatement(getConnection(p), p);\n\t\t\t_priv->setParameterValues(logEvent, p);\n\t\t\tauto ret = SQLExecute(_priv->preparedStatement);\n\t\t\tif (ret < 0)\n\t\t\t{\n\t\t\t\tthrow SQLException(SQL_HANDLE_STMT, _priv->preparedStatement, \"Failed to execute prepared statement\", p);\n\t\t\t}\n\t\t}\n\t\t\ncatch (SQLException& e)\n\t\t{\n\t\t\t_priv->errorHandler->error(LOG4CXX_STR(\"Failed to execute sql\"), e,\n\t\t\t\tErrorCode::FLUSH_FAILURE);\n\t\t}\n#endif\n\t}\n\n\t// clear the buffer of reported events\n\t_priv->buffer.clear();\n}\n\nvoid ODBCAppender::setSql(const LogString& s)\n{\n _priv->sqlStatement = s;\n}\n\n#if LOG4CXX_WCHAR_T_API || LOG4CXX_LOGCHAR_IS_WCHAR_T || defined(WIN32) || defined(_WIN32)\nvoid ODBCAppender::encode(wchar_t** dest, const LogString& src, Pool& p)\n{\n\t*dest = Transcoder::wencode(src, p);\n}\n#endif\n\nvoid ODBCAppender::encode(unsigned short** dest,\n\tconst LogString& src, Pool& p)\n{\n\t// worst case double number of characters from UTF-8 or wchar_t\n\t*dest = (unsigned short*)\n\t\tp.palloc((src.size() + 1) * 2 * sizeof(unsigned short));\n\tunsigned short* current = *dest;\n\n\tfor (LogString::const_iterator i = src.begin();\n\t\ti != src.end();)\n\t{\n\t\tunsigned int sv = Transcoder::decode(src, i);\n\n\t\tif (sv < 0x10000)\n\t\t{\n\t\t\t*current++ = (unsigned short) sv;\n\t\t}\n\t\telse\n\t\t{\n\t\t\tunsigned char u = (unsigned char) (sv >> 16);\n\t\t\tunsigned char w = (unsigned char) (u - 1);\n\t\t\tunsigned short hs = (0xD800 + ((w & 0xF) << 6) + ((sv & 0xFFFF) >> 10));\n\t\t\tunsigned short ls = (0xDC00 + (sv & 0x3FF));\n\t\t\t*current++ = (unsigned short) hs;\n\t\t\t*current++ = (unsigned short) ls;\n\t\t}\n\t}\n\n\t*current = 0;\n}\n\nconst LogString& ODBCAppender::getSql() const\n{\n\treturn _priv->sqlStatement;\n}\n\nvoid ODBCAppender::setUser(const LogString& user)\n{\n\t_priv->databaseUser = user;\n}\n\nvoid ODBCAppender::setURL(const LogString& url)\n{\n\t_priv->databaseURL = url;\n}\n\nvoid ODBCAppender::setPassword(const LogString& password)\n{\n\t_priv->databasePassword = password;\n}\n\nvoid ODBCAppender::setBufferSize(size_t newBufferSize)\n{\n\t_priv->bufferSize = newBufferSize;\n}\n\nconst LogString& ODBCAppender::getUser() const\n{\n\treturn _priv->databaseUser;\n}\n\nconst LogString& ODBCAppender::getURL() const\n{\n\treturn _priv->databaseURL;\n}\n\nconst LogString& ODBCAppender::getPassword() const\n{\n\treturn _priv->databasePassword;\n}\n\nsize_t ODBCAppender::getBufferSize() const\n{\n\treturn _priv->bufferSize;\n}\n\n\n// Path: src/main/cpp/stringtokenizer.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct StringTokenizer::StringTokenizerPrivate{\n\tStringTokenizerPrivate(const LogString& str, const LogString& delim1) : src(str), delim(delim1), pos(0){}\n\tLogString src;\n\tLogString delim;\n\tsize_t pos;\n};\n\n\nStringTokenizer::StringTokenizer(const LogString& str, const LogString& delim1)\n\t: m_priv(std::make_unique(str, delim1))\n{\n}\n\nStringTokenizer::~StringTokenizer()\n{\n}\n\nbool StringTokenizer::hasMoreTokens() const\n{\n\treturn (m_priv->pos != LogString::npos\n\t\t\t&& m_priv->src.find_first_not_of(m_priv->delim, m_priv->pos) != LogString::npos);\n}\n\nLogString StringTokenizer::nextToken()\n{\n\tif (m_priv->pos != LogString::npos)\n\t{\n\t\tsize_t nextPos = m_priv->src.find_first_not_of(m_priv->delim, m_priv->pos);\n\n\t\tif (nextPos != LogString::npos)\n\t\t{\n\t\t\tm_priv->pos = m_priv->src.find_first_of(m_priv->delim, nextPos);\n\n\t\t\tif (m_priv->pos == LogString::npos)\n\t\t\t{\n\t\t\t\treturn m_priv->src.substr(nextPos);\n\t\t\t}\n\n\t\t\treturn m_priv->src.substr(nextPos, m_priv->pos - nextPos);\n\t\t}\n\t}\n\n\tthrow NoSuchElementException();\n#if LOG4CXX_RETURN_AFTER_THROW\n\treturn LogString();\n#endif\n}\n\n// Path: src/main/cpp/resourcebundle.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(ResourceBundle)\n\nResourceBundlePtr ResourceBundle::getBundle(const LogString& baseName,\n\tconst Locale& locale)\n{\n\tstd::vector bundlesNames;\n\n\tif (!locale.getVariant().empty())\n\t{\n\t\tbundlesNames.push_back(baseName + LOG4CXX_STR(\"_\") +\n\t\t\tlocale.getLanguage() + LOG4CXX_STR(\"_\") +\n\t\t\tlocale.getCountry() + LOG4CXX_STR(\"_\") +\n\t\t\tlocale.getVariant());\n\t}\n\n\tif (!locale.getCountry().empty())\n\t{\n\t\tbundlesNames.push_back(baseName + LOG4CXX_STR(\"_\") +\n\t\t\tlocale.getLanguage() + LOG4CXX_STR(\"_\") +\n\t\t\tlocale.getCountry());\n\t}\n\n\tif (!locale.getLanguage().empty())\n\t{\n\t\tbundlesNames.push_back(baseName + LOG4CXX_STR(\"_\") +\n\t\t\tlocale.getLanguage());\n\t}\n\n\tbundlesNames.push_back(baseName);\n\n\tPropertyResourceBundlePtr resourceBundle, previous;\n\tfor (auto bundleName : bundlesNames)\n\t{\n\t\tPropertyResourceBundlePtr current;\n\n\t\t// Try loading a class which implements ResourceBundle\n\t\ttry\n\t\t{\n\t\t\tconst Class& classObj = Loader::loadClass(bundleName);\n\t\t\tObjectPtr obj = ObjectPtr(classObj.newInstance());\n\t\t\tcurrent = LOG4CXX_NS::cast(obj);\n\t\t}\n\t\tcatch (ClassNotFoundException&)\n\t\t{\n\t\t\tcurrent.reset();\n\t\t}\n\n\t\t// No class found, then try to create a PropertyResourceBundle from a file\n\t\tif (!current)\n\t\t{\n\t\t\tInputStreamPtr bundleStream =\n\t\t\t\tLoader::getResourceAsStream(\n\t\t\t\t\tbundleName + LOG4CXX_STR(\".properties\"));\n\n\t\t\tif (!bundleStream)\n\t\t\t{\n\t\t\t\tcontinue;\n\t\t\t}\n\n\t\t\ttry\n\t\t\t{\n\t\t\t\tcurrent = std::make_shared(bundleStream);\n\t\t\t}\n\t\t\tcatch (Exception&)\n\t\t\t{\n\t\t\t\tthrow;\n\t\t\t}\n\t\t}\n\n\t\t// Add the new resource bundle to the hierarchy\n\t\tif (!resourceBundle)\n\t\t{\n\t\t\tresourceBundle = current;\n\t\t\tprevious = current;\n\t\t}\n\t\telse\n\t\t{\n\t\t\tprevious->setParent(current);\n\t\t\tprevious = current;\n\t\t}\n\t}\n\n\t// no resource bundle found at all, then throw exception\n\tif (!resourceBundle)\n\t{\n\t\tthrow MissingResourceException(\n\t\t\t((LogString) LOG4CXX_STR(\"Missing resource bundle \")) + baseName);\n\t}\n\n\treturn resourceBundle;\n}\n\n\n// Path: src/main/cpp/outputstreamwriter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(OutputStreamWriter)\n\nstruct OutputStreamWriter::OutputStreamWriterPrivate{\n\tOutputStreamWriterPrivate(OutputStreamPtr& out1) : out(out1), enc(CharsetEncoder::getDefaultEncoder()){}\n\n\tOutputStreamWriterPrivate(OutputStreamPtr& out1,\n\t\t\t\t\t\t\t CharsetEncoderPtr& enc1)\n\t\t: out(out1), enc(enc1){}\n\n\tOutputStreamPtr out;\n\tCharsetEncoderPtr enc;\n};\n\nOutputStreamWriter::OutputStreamWriter(OutputStreamPtr& out1)\n\t: m_priv(std::make_unique(out1))\n{\n\tif (out1 == 0)\n\t{\n\t\tthrow NullPointerException(LOG4CXX_STR(\"out parameter may not be null.\"));\n\t}\n}\n\nOutputStreamWriter::OutputStreamWriter(OutputStreamPtr& out1,\n\tCharsetEncoderPtr& enc1)\n\t: m_priv(std::make_unique(out1, enc1))\n{\n\tif (out1 == 0)\n\t{\n\t\tthrow NullPointerException(LOG4CXX_STR(\"out parameter may not be null.\"));\n\t}\n\n\tif (enc1 == 0)\n\t{\n\t\tthrow NullPointerException(LOG4CXX_STR(\"enc parameter may not be null.\"));\n\t}\n}\n\nOutputStreamWriter::~OutputStreamWriter()\n{\n}\n\nvoid OutputStreamWriter::close(Pool& p)\n{\n\tm_priv->out->close(p);\n}\n\nvoid OutputStreamWriter::flush(Pool& p)\n{\n\tm_priv->out->flush(p);\n}\n\nvoid OutputStreamWriter::write(const LogString& str, Pool& p)\n{\n\tif (str.empty())\n\t\treturn;\n\tif (CharsetEncoder::isTriviallyCopyable(str, m_priv->enc))\n\t{\n\t\tByteBuffer buf((char*)str.data(), str.size() * sizeof (logchar));\n\t\tm_priv->out->write(buf, p);\n\t}\n\telse\n\t{\n\t\tenum { BUFSIZE = 1024 };\n\t\tchar stackData[BUFSIZE];\n\t\tchar* rawbuf = stackData;\n\t\tsize_t bufSize = BUFSIZE;\n#ifdef LOG4CXX_MULTI_PROCESS\n\t\tstd::vector heapData;\n\t\t// Ensure the logging event is a single write system call to keep events from each process separate\n\t\tif (bufSize < str.length() * 2)\n\t\t{\n\t\t\theapData.resize(bufSize = str.length() * 2);\n\t\t\trawbuf = heapData.data();\n\t\t}\n#endif\n\t\tByteBuffer buf(rawbuf, bufSize);\n\t\tm_priv->enc->reset();\n\t\tLogString::const_iterator iter = str.begin();\n\n\t\twhile (iter != str.end())\n\t\t{\n\t\t\tCharsetEncoder::encode(m_priv->enc, str, iter, buf);\n\t\t\tbuf.flip();\n\t\t\tm_priv->out->write(buf, p);\n\t\t\tbuf.clear();\n\t\t}\n\n\t\tCharsetEncoder::encode(m_priv->enc, str, iter, buf);\n\t\tm_priv->enc->flush(buf);\n\t\tbuf.flip();\n\t\tm_priv->out->write(buf, p);\n\t}\n}\n\nOutputStreamPtr OutputStreamWriter::getOutputStreamPtr() const\n{\n\treturn m_priv->out;\n}\n\n\n// Path: src/main/cpp/transcoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\n#if LOG4CXX_CFSTRING_API\n\t#include \n#endif\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\n\nvoid Transcoder::decodeUTF8(const std::string& src, LogString& dst)\n{\n\tstd::string::const_iterator iter = src.begin();\n\n\twhile (iter != src.end())\n\t{\n\t\tunsigned int sv = decode(src, iter);\n\n\t\tif (sv != 0xFFFF)\n\t\t{\n\t\t\tencode(sv, dst);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tdst.append(1, LOSSCHAR);\n\t\t\titer++;\n\t\t}\n\t}\n}\n\nvoid Transcoder::encodeUTF8(const LogString& src, std::string& dst)\n{\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\tdst.append(src);\n#else\n\tLogString::const_iterator iter = src.begin();\n\n\twhile (iter != src.end())\n\t{\n\t\tunsigned int sv = decode(src, iter);\n\n\t\tif (sv != 0xFFFF)\n\t\t{\n\t\t\tencode(sv, dst);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tdst.append(1, LOSSCHAR);\n\t\t\titer++;\n\t\t}\n\t}\n\n#endif\n}\n\nchar* Transcoder::encodeUTF8(const LogString& src, Pool& p)\n{\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\treturn p.pstrdup(src);\n#else\n\tstd::string tmp;\n\tencodeUTF8(src, tmp);\n\treturn p.pstrdup(tmp);\n#endif\n}\n\n\nvoid Transcoder::encodeUTF8(unsigned int sv, ByteBuffer& dst)\n{\n\tsize_t bytes = encodeUTF8(sv, dst.current());\n\tdst.position(dst.position() + bytes);\n}\n\n\nsize_t Transcoder::encodeUTF8(unsigned int ch, char* dst)\n{\n\tif (ch < 0x80)\n\t{\n\t\tdst[0] = (char) ch;\n\t\treturn 1;\n\t}\n\telse if (ch < 0x800)\n\t{\n\t\tdst[0] = (char) (0xC0 + (ch >> 6));\n\t\tdst[1] = (char) (0x80 + (ch & 0x3F));\n\t\treturn 2;\n\t}\n\telse if (ch < 0x10000)\n\t{\n\t\tdst[0] = (char) (0xE0 + (ch >> 12));\n\t\tdst[1] = (char) (0x80 + ((ch >> 6) & 0x3F));\n\t\tdst[2] = (char) (0x80 + (ch & 0x3F));\n\t\treturn 3;\n\t}\n\telse if (ch <= 0x10FFFF)\n\t{\n\t\tdst[0] = (char) (0xF0 + (ch >> 18));\n\t\tdst[1] = (char) (0x80 + ((ch >> 12) & 0x3F));\n\t\tdst[2] = (char) (0x80 + ((ch >> 6) & 0x3F));\n\t\tdst[3] = (char) (0x80 + (ch & 0x3F));\n\t\treturn 4;\n\t}\n\telse\n\t{\n\t\t//\n\t\t// output UTF-8 encoding of 0xFFFF\n\t\t//\n\t\tdst[0] = (char) 0xEF;\n\t\tdst[1] = (char) 0xBF;\n\t\tdst[2] = (char) 0xBF;\n\t\treturn 3;\n\t}\n}\n\nvoid Transcoder::encodeUTF16BE(unsigned int sv, ByteBuffer& dst)\n{\n\tsize_t bytes = encodeUTF16BE(sv, dst.current());\n\tdst.position(dst.position() + bytes);\n}\n\n\nsize_t Transcoder::encodeUTF16BE(unsigned int ch, char* dst)\n{\n\tif (ch <= 0xFFFF)\n\t{\n\t\tdst[0] = (char) (ch >> 8);\n\t\tdst[1] = (char) (ch & 0xFF);\n\t\treturn 2;\n\t}\n\n\tif (ch <= 0x10FFFF)\n\t{\n\t\tunsigned char w = (unsigned char) ((ch >> 16) - 1);\n\t\tdst[0] = (char) (0xD8 + (w >> 2));\n\t\tdst[1] = (char) (((w & 0x03) << 6) + ((ch >> 10) & 0x3F));\n\t\tdst[2] = (char) (0xDC + ((ch & 0x30) >> 4));\n\t\tdst[3] = (char) (ch & 0xFF);\n\t\treturn 4;\n\t}\n\n\tdst[0] = dst[1] = (char) 0xFF;\n\treturn 2;\n}\n\nvoid Transcoder::encodeUTF16LE(unsigned int sv, ByteBuffer& dst)\n{\n\tsize_t bytes = encodeUTF16LE(sv, 
dst.current());\n\tdst.position(dst.position() + bytes);\n}\n\nsize_t Transcoder::encodeUTF16LE(unsigned int ch, char* dst)\n{\n\tif (ch <= 0xFFFF)\n\t{\n\t\tdst[1] = (char) (ch >> 8);\n\t\tdst[0] = (char) (ch & 0xFF);\n\t\treturn 2;\n\t}\n\n\tif (ch <= 0x10FFFF)\n\t{\n\t\tunsigned char w = (unsigned char) ((ch >> 16) - 1);\n\t\tdst[1] = (char) (0xD8 + (w >> 2));\n\t\tdst[0] = (char) (((w & 0x03) << 6) + ((ch >> 10) & 0x3F));\n\t\tdst[3] = (char) (0xDC + ((ch & 0x30) >> 4));\n\t\tdst[2] = (char) (ch & 0xFF);\n\t\treturn 4;\n\t}\n\n\tdst[0] = dst[1] = (char) 0xFF;\n\treturn 2;\n}\n\n\nunsigned int Transcoder::decode(const std::string& src,\n\tstd::string::const_iterator& iter)\n{\n\tstd::string::const_iterator start(iter);\n\tunsigned char ch1 = *(iter++);\n\n\tif (ch1 <= 0x7F)\n\t{\n\t\treturn ch1;\n\t}\n\n\t//\n\t// should not have continuation character here\n\t//\n\tif ((ch1 & 0xC0) != 0x80 && iter != src.end())\n\t{\n\t\tunsigned char ch2 = *(iter++);\n\n\t\t//\n\t\t// should be continuation\n\t\tif ((ch2 & 0xC0) != 0x80)\n\t\t{\n\t\t\titer = start;\n\t\t\treturn 0xFFFF;\n\t\t}\n\n\t\tif ((ch1 & 0xE0) == 0xC0)\n\t\t{\n\t\t\tif ((ch2 & 0xC0) == 0x80)\n\t\t\t{\n\t\t\t\tunsigned int rv = ((ch1 & 0x1F) << 6) + (ch2 & 0x3F);\n\n\t\t\t\tif (rv >= 0x80)\n\t\t\t\t{\n\t\t\t\t\treturn rv;\n\t\t\t\t}\n\t\t\t}\n\n\t\t\titer = start;\n\t\t\treturn 0xFFFF;\n\t\t}\n\n\t\tif (iter != src.end())\n\t\t{\n\t\t\tunsigned char ch3 = *(iter++);\n\n\t\t\t//\n\t\t\t// should be continuation\n\t\t\t//\n\t\t\tif ((ch3 & 0xC0) != 0x80)\n\t\t\t{\n\t\t\t\titer = start;\n\t\t\t\treturn 0xFFFF;\n\t\t\t}\n\n\t\t\tif ((ch1 & 0xF0) == 0xE0)\n\t\t\t{\n\t\t\t\tunsigned rv = ((ch1 & 0x0F) << 12)\n\t\t\t\t\t+ ((ch2 & 0x3F) << 6)\n\t\t\t\t\t+ (ch3 & 0x3F);\n\n\t\t\t\tif (rv <= 0x800)\n\t\t\t\t{\n\t\t\t\t\titer = start;\n\t\t\t\t\treturn 0xFFFF;\n\t\t\t\t}\n\n\t\t\t\treturn rv;\n\t\t\t}\n\n\t\t\tif (iter != src.end())\n\t\t\t{\n\t\t\t\tunsigned char ch4 = *(iter++);\n\n\t\t\t\tif ((ch4 & 0xC0) != 0x80)\n\t\t\t\t{\n\t\t\t\t\titer = start;\n\t\t\t\t\treturn 0xFFFF;\n\t\t\t\t}\n\n\t\t\t\tunsigned int rv = ((ch1 & 0x07) << 18)\n\t\t\t\t\t+ ((ch2 & 0x3F) << 12)\n\t\t\t\t\t+ ((ch3 & 0x3F) << 6)\n\t\t\t\t\t+ (ch4 & 0x3F);\n\n\t\t\t\tif (rv > 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn rv;\n\t\t\t\t}\n\n\t\t\t}\n\t\t}\n\t}\n\n\titer = start;\n\treturn 0xFFFF;\n}\n\n\nvoid Transcoder::encode(unsigned int sv, std::string& dst)\n{\n\tchar tmp[8];\n\tsize_t bytes = encodeUTF8(sv, tmp);\n\tdst.append(tmp, bytes);\n}\n\n\nvoid Transcoder::decode(const std::string& src, LogString& dst)\n{\n#if LOG4CXX_CHARSET_UTF8 && LOG4CXX_LOGCHAR_IS_UTF8\n\tdst.append(src);\n#else\n\tstatic CharsetDecoderPtr decoder(CharsetDecoder::getDefaultDecoder());\n\tdst.reserve(dst.size() + src.size());\n\tstd::string::const_iterator iter = src.begin();\n#if !LOG4CXX_CHARSET_EBCDIC\n\n\tfor (;\n\t\titer != src.end() && ((unsigned char) *iter) < 0x80;\n\t\titer++)\n\t{\n\t\tdst.append(1, *iter);\n\t}\n\n#endif\n\n\tif (iter != src.end())\n\t{\n\t\tsize_t offset = iter - src.begin();\n\t\tByteBuffer buf(const_cast(src.data() + offset), src.size() - offset);\n\n\t\twhile (buf.remaining() > 0)\n\t\t{\n\t\t\tlog4cxx_status_t stat = decoder->decode(buf, dst);\n\n\t\t\tif (CharsetDecoder::isError(stat))\n\t\t\t{\n\t\t\t\tdst.append(1, LOSSCHAR);\n\t\t\t\tbuf.position(buf.position() + 1);\n\t\t\t}\n\t\t}\n\n\t\tdecoder->decode(buf, dst);\n\t}\n\n#endif\n}\n\nchar* Transcoder::encode(const LogString& src, Pool& p)\n{\n#if LOG4CXX_CHARSET_UTF8 && LOG4CXX_LOGCHAR_IS_UTF8\n\treturn 
p.pstrdup(src);\n#else\n\tstd::string tmp;\n\tencode(src, tmp);\n\treturn p.pstrdup(tmp);\n#endif\n}\n\n\n\nvoid Transcoder::encode(const LogString& src, std::string& dst)\n{\n#if LOG4CXX_CHARSET_UTF8 && LOG4CXX_LOGCHAR_IS_UTF8\n\tdst.append(src);\n#else\n\tstatic CharsetEncoderPtr encoder(CharsetEncoder::getDefaultEncoder());\n\tdst.reserve(dst.size() + src.size());\n\tLogString::const_iterator iter = src.begin();\n#if !LOG4CXX_CHARSET_EBCDIC\n\n\tfor (;\n\t\titer != src.end() && ((unsigned int) *iter) < 0x80;\n\t\titer++)\n\t{\n\t\tdst.append(1, *iter);\n\t}\n\n#endif\n\n\tif (iter != src.end())\n\t{\n\t\tchar buf[BUFSIZE];\n\t\tByteBuffer out(buf, BUFSIZE);\n\n\t\twhile (iter != src.end())\n\t\t{\n\t\t\tlog4cxx_status_t stat = encoder->encode(src, iter, out);\n\t\t\tout.flip();\n\t\t\tdst.append(out.data(), out.limit());\n\t\t\tout.clear();\n\n\t\t\tif (CharsetEncoder::isError(stat))\n\t\t\t{\n\t\t\t\tdst.append(1, LOSSCHAR);\n\t\t\t\titer++;\n\t\t\t}\n\t\t}\n\n\t\tencoder->encode(src, iter, out);\n\t}\n\n#endif\n}\n\n\ntemplate\nstatic unsigned int decodeUTF16(const String& in, Iterator& iter)\n{\n\tunsigned int ch1 = *iter;\n\n\t//\n\t// if not surrogate pair\n\t//\n\tif (ch1 < 0xD800 || ch1 > 0xDFFF)\n\t{\n\t\t//\n\t\t// then advance iterator and return wchar_t value\n\t\t//\n\t\tif (ch1 != 0xFFFF)\n\t\t{\n\t\t\titer++;\n\t\t}\n\n\t\treturn ch1;\n\t}\n\telse if (ch1 < 0xDC00)\n\t{\n\t\t//\n\t\t// started with high-surrogate value\n\t\t// if there is an additional wchar_t\n\t\tIterator iter2 = iter + 1;\n\n\t\tif (iter2 != in.end())\n\t\t{\n\t\t\tunsigned int ch2 = *iter2;\n\n\t\t\t//\n\t\t\t// if it is a matching low surrogate then\n\t\t\t// advance the iterator and return the scalar value\n\t\t\tif (ch2 >= 0xDC00 && ch2 <= 0xDFFF)\n\t\t\t{\n\t\t\t\titer += 2;\n\t\t\t\treturn (ch1 - 0xD800) * 0x400 + (ch2 - 0xDC00) + 0x10000;\n\t\t\t}\n\t\t}\n\t}\n\n\t//\n\t// unrecognized value, do not advance iterator\n\t//\n\treturn 0xFFFF;\n}\n\ntemplate\nstatic void encodeUTF16(unsigned int sv, String& dst)\n{\n\tif (sv < 0x10000)\n\t{\n\t\tdst.append(1, sv);\n\t}\n\telse\n\t{\n\t\tunsigned char u = (unsigned char) (sv >> 16);\n\t\tunsigned char w = (unsigned char) (u - 1);\n\t\tunsigned short hs = (0xD800 + ((w & 0xF) << 6) + ((sv & 0xFFFF) >> 10));\n...\n", "needle_token_start": 6881, "needle_token_end": 6938, "code_context_ntokens": 15209, "output": ["```cpp\nvoid ODBCAppender::execute(const LogString& sql, LOG4CXX_NS::helpers::Pool& p)\n{\n#if LOG4CXX_HAVE_ODBC\n\tSQLRETURN ret;\n\tSQLHSTMT stmt;\n\n\tret = SQLAllocHandle(SQL_HANDLE_STMT, getConnection(p), &stmt);\n\n\tif (ret < 0)\n\t{\n\t\tSQLException ex(SQL_HANDLE_STMT, stmt, \"Failed to allocate statement handle.\", p);\n\t\tSQLFreeHandle(SQL_HANDLE_STMT, stmt);\n\t\tthrow ex;\n\t}\n\n\tret = SQLPrepareA(stmt, (SQLCHAR*)sql.c_str(), SQL_NTS);\n\n\tif (ret < 0)\n\t{\n\t\tSQLException ex(SQL_HANDLE_STMT, stmt, \"Failed to prepare SQL statement.\", p);\n\t\tSQLFreeHandle(SQL_HANDLE_STMT, stmt);\n\t\tthrow ex;\n\t}\n\n\tret = SQLExecute(stmt);\n\n\tif (ret < 0)\n\t{\n\t\tSQLException ex(SQL_HANDLE_STMT, stmt, \"Failed to execute prepared statement\", p);\n\t\tSQLFreeHandle(SQL_HANDLE_STMT, stmt);\n\t\tthrow ex;\n\t}\n\n\tSQLFreeHandle(SQL_HANDLE_STMT, stmt);\n#endif\n}\n```"]} +{"repo": "apache/logging-log4cxx", "name": "StrftimeDateFormatPrivate", "language": "cpp", "path": "src/main/cpp/strftimedateformat.cpp", "position_ratio": 0.55, "description": "\nFunction Description:\n1. 
**Purpose**: The purpose of this component is to manage and format date and time strings according to specified time zones, primarily used in logging contexts.\n2. **Input**: The input typically includes a time zone identifier and a date-time value that needs formatting.\n3. **Output**: The output is a formatted date-time string adjusted according to the specified time zone.\n4. **Procedure**: The procedure involves initializing with a default time zone, parsing the time zone identifier if provided, calculating the time offset based on the parsed hours and minutes, and then formatting the date-time string according to the calculated time zone offset.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/xmllayout.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::xml;\n\nstruct XMLLayout::XMLLayoutPrivate\n{\n\tXMLLayoutPrivate()\n\t\t: locationInfo(false)\n\t\t, properties(false)\n\t\t, expectedPatternLength(100)\n\t\t{}\n\n\t// Print no location info by default\n\tbool locationInfo; //= false\n\tbool properties; // = false\n\n\t// Expected length of a formatted event excluding the message text\n\tsize_t expectedPatternLength;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(XMLLayout)\n\nXMLLayout::XMLLayout()\n\t: m_priv(std::make_unique())\n{\n\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n}\n\nXMLLayout::~XMLLayout() {}\n\nvoid XMLLayout::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"LOCATIONINFO\"), LOG4CXX_STR(\"locationinfo\")))\n\t{\n\t\tsetLocationInfo(OptionConverter::toBoolean(value, false));\n\t\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n\t}\n\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"PROPERTIES\"), LOG4CXX_STR(\"properties\")))\n\t{\n\t\tsetProperties(OptionConverter::toBoolean(value, false));\n\t\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n\t}\n}\n\nvoid XMLLayout::format(LogString& output,\n\tconst spi::LoggingEventPtr& event,\n\tPool& p) const\n{\n\toutput.reserve(m_priv->expectedPatternLength + event->getMessage().size());\n\toutput.append(LOG4CXX_STR(\"getLoggerName());\n\toutput.append(LOG4CXX_STR(\"\\\" timestamp=\\\"\"));\n\tStringHelper::toString(event->getTimeStamp() / 
1000L, p, output);\n\toutput.append(LOG4CXX_STR(\"\\\" level=\\\"\"));\n\tTransform::appendEscapingTags(output, event->getLevel()->toString());\n\toutput.append(LOG4CXX_STR(\"\\\" thread=\\\"\"));\n\tTransform::appendEscapingTags(output, event->getThreadName());\n\toutput.append(LOG4CXX_STR(\"\\\">\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"getRenderedMessage());\n\toutput.append(LOG4CXX_STR(\"]]>\"));\n\toutput.append(LOG4CXX_EOL);\n\n\tLogString ndc;\n\n\tif (event->getNDC(ndc))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_EOL);\n\t}\n\n\tif (m_priv->locationInfo)\n\t{\n\t\toutput.append(LOG4CXX_STR(\"getLocationInformation();\n\t\tLOG4CXX_DECODE_CHAR(className, locInfo.getClassName());\n\t\tTransform::appendEscapingTags(output, className);\n\t\toutput.append(LOG4CXX_STR(\"\\\" method=\\\"\"));\n\t\tLOG4CXX_DECODE_CHAR(method, locInfo.getMethodName());\n\t\tTransform::appendEscapingTags(output, method);\n\t\toutput.append(LOG4CXX_STR(\"\\\" file=\\\"\"));\n\t\tLOG4CXX_DECODE_CHAR(fileName, locInfo.getFileName());\n\t\tTransform::appendEscapingTags(output, fileName);\n\t\toutput.append(LOG4CXX_STR(\"\\\" line=\\\"\"));\n\t\tStringHelper::toString(locInfo.getLineNumber(), p, output);\n\t\toutput.append(LOG4CXX_STR(\"\\\"/>\"));\n\t\toutput.append(LOG4CXX_EOL);\n\t}\n\n\tif (m_priv->properties)\n\t{\n\t\tLoggingEvent::KeySet propertySet(event->getPropertyKeySet());\n\t\tLoggingEvent::KeySet keySet(event->getMDCKeySet());\n\n\t\tif (!(keySet.empty() && propertySet.empty()))\n\t\t{\n\t\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\t\toutput.append(LOG4CXX_EOL);\n\n\t\t\tfor (auto key : keySet)\n\t\t\t{\n\t\t\t\tLogString value;\n\n\t\t\t\tif (event->getMDC(key, value))\n\t\t\t\t{\n\t\t\t\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\t\t\t\toutput.append(LOG4CXX_EOL);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\tfor (auto key : propertySet)\n\t\t\t{\n\t\t\t\tLogString value;\n\n\t\t\t\tif (event->getProperty(key, value))\n\t\t\t\t{\n\t\t\t\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\t\t\t\toutput.append(LOG4CXX_EOL);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\t\toutput.append(LOG4CXX_EOL);\n\t\t}\n...\n// Path: src/main/cpp/simplelayout.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::spi;\n\nIMPLEMENT_LOG4CXX_OBJECT(SimpleLayout)\n\n\n\nvoid SimpleLayout::format(LogString& output,\n\tconst spi::LoggingEventPtr& event,\n\tLOG4CXX_NS::helpers::Pool&) const\n{\n\toutput.append(event->getLevel()->toString());\n\toutput.append(LOG4CXX_STR(\" - \"));\n\toutput.append(event->getRenderedMessage());\n\toutput.append(LOG4CXX_EOL);\n}\n\n// Path: src/main/cpp/writerappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\n#define _priv static_cast(m_priv.get())\n\nIMPLEMENT_LOG4CXX_OBJECT(WriterAppender)\n\nWriterAppender::WriterAppender() :\n\tAppenderSkeleton (std::make_unique())\n{\n}\n\nWriterAppender::WriterAppender(const LayoutPtr& layout1,\n\tLOG4CXX_NS::helpers::WriterPtr& writer1)\n\t: AppenderSkeleton (std::make_unique(layout1, writer1))\n{\n\tPool p;\n\tactivateOptions(p);\n}\n\nWriterAppender::WriterAppender(const LayoutPtr& layout1)\n\t: AppenderSkeleton (std::make_unique(layout1))\n{\n}\n\nWriterAppender::WriterAppender(std::unique_ptr priv)\n\t: AppenderSkeleton (std::move(priv))\n{\n\n}\n\nWriterAppender::~WriterAppender()\n{\n\tfinalize();\n}\n\nvoid WriterAppender::activateOptions(Pool& p)\n{\n\tint errors = 0;\n\n\tif (_priv->layout == 0)\n\t{\n\t\t_priv->errorHandler->error(\n\t\t\t((LogString) LOG4CXX_STR(\"No layout set for the appender named [\"))\n\t\t\t+ _priv->name + LOG4CXX_STR(\"].\"));\n\t\terrors++;\n\t}\n\n\tif (_priv->writer == 0)\n\t{\n\t\t_priv->errorHandler->error(\n\t\t\t((LogString) LOG4CXX_STR(\"No writer set for the appender named [\"))\n\t\t\t+ _priv->name + LOG4CXX_STR(\"].\"));\n\t\terrors++;\n\t}\n\n\tif (errors == 0)\n\t{\n\t\tAppenderSkeleton::activateOptions(p);\n\t}\n}\n\n\n\nvoid WriterAppender::append(const spi::LoggingEventPtr& event, Pool& pool1)\n{\n\n\tif (!checkEntryConditions())\n\t{\n\t\treturn;\n\t}\n\n\tsubAppend(event, pool1);\n}\n\n/**\n This method determines if there is a sense in attempting to append.\n\n

It checks whether there is a set output target and also if\n there is a set layout. If these checks fail, then the boolean\n value false is returned. */\nbool WriterAppender::checkEntryConditions() const\n{\n\tstatic bool warnedClosed = false;\n\tstatic bool warnedNoWriter = false;\n\tstatic bool warnedNoLayout = false;\n\n\tif (_priv->closed)\n\t{\n\t\tif (!warnedClosed)\n\t\t{\n\t\t\tLogLog::warn(LOG4CXX_STR(\"Not allowed to write to a closed appender.\"));\n\t\t\twarnedClosed = true;\n\t\t}\n\n\t\treturn false;\n\t}\n\n\tif (_priv->writer == 0)\n\t{\n\t\tif (!warnedNoWriter)\n\t\t{\n\t\t\t_priv->errorHandler->error(\n\t\t\t\tLogString(LOG4CXX_STR(\"No output stream or file set for the appender named [\")) +\n\t\t\t\t_priv->name + LOG4CXX_STR(\"].\"));\n\t\t\twarnedNoWriter = true;\n\t\t}\n\n\t\treturn false;\n\t}\n\n\tif (_priv->layout == 0)\n\t{\n\t\tif (!warnedNoLayout)\n\t\t{\n\t\t\t_priv->errorHandler->error(\n\t\t\t\tLogString(LOG4CXX_STR(\"No layout set for the appender named [\")) +\n\t\t\t\t_priv->name + LOG4CXX_STR(\"].\"));\n\t\t\twarnedNoLayout = true;\n\t\t}\n\t\treturn false;\n\t}\n\n\treturn true;\n}\n\n\n\n\n/**\n Close this appender instance. The underlying stream or writer is\n also closed.\n\n

Closed appenders cannot be reused.\n\n @see #setWriter\n */\nvoid WriterAppender::close()\n{\n\tstd::lock_guard lock(_priv->mutex);\n\n\tif (_priv->closed)\n\t{\n\t\treturn;\n\t}\n\n\t_priv->closed = true;\n\tcloseWriter();\n}\n\n/**\n * Close the underlying {@link java.io.Writer}.\n * */\nvoid WriterAppender::closeWriter()\n{\n\tif (_priv->writer != NULL)\n\t{\n\t\ttry\n\t\t{\n\t\t\t// before closing we have to output out layout's footer\n\t\t\t//\n\t\t\t// Using the object's pool since this is a one-shot operation\n\t\t\t// and pool is likely to be reclaimed soon when appender is destructed.\n\t\t\t//\n\t\t\twriteFooter(_priv->pool);\n\t\t\t_priv->writer->close(_priv->pool);\n\t\t\t_priv->writer = 0;\n\t\t}\n\t\tcatch (IOException& e)\n\t\t{\n\t\t\tLogLog::error(LogString(LOG4CXX_STR(\"Could not close writer for WriterAppender named \")) + _priv->name, e);\n\t\t}\n\t}\n\n}\n\n/**\n Returns an OutputStreamWriter when passed an OutputStream. The\n encoding used will depend on the value of the\n encoding property. If the encoding value is\n specified incorrectly the writer will be opened using the default\n system encoding (an error message will be printed to the loglog. */\nWriterPtr WriterAppender::createWriter(OutputStreamPtr& os)\n{\n\n\tLogString enc(getEncoding());\n\n\tCharsetEncoderPtr encoder;\n\n\tif (enc.empty())\n\t{\n\t\tencoder = CharsetEncoder::getDefaultEncoder();\n\t}\n\telse\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(enc,\n\t\t\t\tLOG4CXX_STR(\"utf-16\"), LOG4CXX_STR(\"UTF-16\")))\n\t\t{\n\t\t\tencoder = CharsetEncoder::getEncoder(LOG4CXX_STR(\"UTF-16BE\"));\n\t\t}\n\t\telse\n\t\t{\n\t\t\tencoder = CharsetEncoder::getEncoder(enc);\n\t\t}\n\n\t\tif (encoder == NULL)\n\t\t{\n\t\t\tencoder = CharsetEncoder::getDefaultEncoder();\n\t\t\tLogLog::warn(LOG4CXX_STR(\"Error initializing output writer.\"));\n\t\t\tLogLog::warn(LOG4CXX_STR(\"Unsupported encoding?\"));\n\t\t}\n\t}\n\n\treturn WriterPtr(new OutputStreamWriter(os, encoder));\n}\n\nLogString WriterAppender::getEncoding() const\n{\n\treturn _priv->encoding;\n}\n\nvoid WriterAppender::setEncoding(const LogString& enc)\n{\n\t_priv->encoding = enc;\n}\n\nvoid WriterAppender::subAppend(const spi::LoggingEventPtr& event, Pool& p)\n{\n\tLogString msg;\n\t_priv->layout->format(msg, event, p);\n\n\tif (_priv->writer != NULL)\n\t{\n\t\t_priv->writer->write(msg, p);\n\n\t\tif (_priv->immediateFlush)\n\t\t{\n\t\t\t_priv->writer->flush(p);\n\t\t}\n\t}\n}\n\n\nvoid WriterAppender::writeFooter(Pool& p)\n{\n\tif (_priv->layout != NULL)\n\t{\n\t\tLogString foot;\n\t\t_priv->layout->appendFooter(foot, p);\n\t\t_priv->writer->write(foot, p);\n\t}\n}\n\nvoid WriterAppender::writeHeader(Pool& p)\n{\n\tif (_priv->layout != NULL)\n\t{\n\t\tLogString header;\n\t\t_priv->layout->appendHeader(header, p);\n\t\t_priv->writer->write(header, p);\n\t}\n}\n\n\nvoid WriterAppender::setWriter(const WriterPtr& newWriter)\n{\n\tstd::unique_lock lock(_priv->mutex);\n\tsetWriterInternal(newWriter);\n}\n\nvoid WriterAppender::setWriterInternal(const WriterPtr& newWriter)\n{\n\t_priv->writer = newWriter;\n}\n\nbool WriterAppender::requiresLayout() const\n{\n\treturn true;\n}\n\nvoid WriterAppender::setOption(const LogString& option, const LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"ENCODING\"), LOG4CXX_STR(\"encoding\")))\n\t{\n\t\tsetEncoding(value);\n\t}\n\telse\n\t{\n\t\tAppenderSkeleton::setOption(option, value);\n\t}\n}\n\n\nvoid WriterAppender::setImmediateFlush(bool value)\n{\n\t_priv->immediateFlush = 
value;\n}\n\nbool WriterAppender::getImmediateFlush() const\n{\n\treturn _priv->immediateFlush;\n}\n\nconst LOG4CXX_NS::helpers::WriterPtr WriterAppender::getWriter() const{\n\treturn _priv->writer;\n}\n\n// Path: src/main/cpp/timezone.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#define __STDC_CONSTANT_MACROS\n#include \n#include \n#include \n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT( TimeZone )\n\nnamespace LOG4CXX_NS\n{\nnamespace helpers\n{\nnamespace TimeZoneImpl\n{\n/** Time zone object that represents GMT. */\nclass GMTTimeZone : public TimeZone\n{\n\tpublic:\n\t\t/** Class factory. */\n\t\tstatic const TimeZonePtr& getInstance()\n\t\t{\n\t\t\tstatic WideLife tz = std::make_shared();\n\t\t\treturn tz;\n\t\t}\n\n\t\t/** Explode time to human readable form. */\n\t\tlog4cxx_status_t explode( apr_time_exp_t* result, log4cxx_time_t input ) const\n\t\t{\n\t\t\tapr_status_t stat;\n\n\t\t\t// APR 1.1 and early mishandles microseconds on dates\n\t\t\t// before 1970, APR bug 32520\n\t\t\tif (LOG4CXX_UNLIKELY(input < 0 && apr_time_usec(input) < 0))\n\t\t\t{\n\t\t\t\tapr_time_t floorTime = (apr_time_sec(input) - 1) * APR_USEC_PER_SEC;\n\t\t\t\tstat = apr_time_exp_gmt(result, floorTime);\n\t\t\t\tresult->tm_usec = (int) (input - floorTime);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tstat = apr_time_exp_gmt( result, input );\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\t\tGMTTimeZone() : TimeZone( LOG4CXX_STR(\"GMT\") )\n\t\t{\n\t\t}\n};\n\n\n\n/** Time zone object that represents GMT. */\nclass LocalTimeZone : public TimeZone\n{\n\tpublic:\n\t\t/** Class factory. */\n\t\tstatic const TimeZonePtr& getInstance()\n\t\t{\n\t\t\tstatic WideLife tz = std::make_shared();\n\t\t\treturn tz;\n\t\t}\n\n\t\t/** Explode time to human readable form. 
*/\n\t\tlog4cxx_status_t explode( apr_time_exp_t* result, log4cxx_time_t input ) const\n\t\t{\n\t\t\tapr_status_t stat;\n\n\t\t\t// APR 1.1 and early mishandles microseconds on dates\n\t\t\t// before 1970, APR bug 32520\n\t\t\tif (LOG4CXX_UNLIKELY(input < 0 && apr_time_usec(input) < 0))\n\t\t\t{\n\t\t\t\tapr_time_t floorTime = (apr_time_sec(input) - 1) * APR_USEC_PER_SEC;\n\t\t\t\tstat = apr_time_exp_lt(result, floorTime);\n\t\t\t\tresult->tm_usec = (int) (input - floorTime);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tstat = apr_time_exp_lt( result, input );\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\t\tLocalTimeZone() : TimeZone( getTimeZoneName() )\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tstatic const LogString getTimeZoneName()\n\t\t{\n\t\t\tconst int MAX_TZ_LENGTH = 255;\n\t\t\tchar tzName[MAX_TZ_LENGTH];\n\t\t\tapr_size_t tzLength;\n\t\t\tapr_time_exp_t tm;\n\t\t\tapr_time_exp_lt(&tm, 0);\n\t\t\tapr_strftime(tzName, &tzLength, MAX_TZ_LENGTH, \"%Z\", &tm);\n\n\t\t\tif (tzLength == 0)\n\t\t\t{\n\t\t\t\tapr_strftime(tzName, &tzLength, MAX_TZ_LENGTH, \"%z\", &tm);\n\t\t\t}\n\n\t\t\ttzName[tzLength] = 0;\n\t\t\tLogString retval;\n\t\t\tLOG4CXX_NS::helpers::Transcoder::decode(tzName, retval);\n\t\t\treturn retval;\n\t\t}\n\n};\n\n\n\n/** Time zone object that represents a fixed offset from GMT. */\nclass FixedTimeZone : public TimeZone\n{\n\tpublic:\n\t\tFixedTimeZone( const LogString& name, apr_int32_t offset1 ) : TimeZone( name ), offset( offset1 )\n\t\t{\n\t\t}\n\n\t\t/** Explode time to human readable form. */\n\t\tlog4cxx_status_t explode( apr_time_exp_t* result, log4cxx_time_t input ) const\n\t\t{\n\t\t\tapr_status_t stat;\n\n\t\t\t// APR 1.1 and early mishandles microseconds on dates\n\t\t\t// before 1970, APR bug 32520\n\t\t\tif (LOG4CXX_UNLIKELY(input < 0 && apr_time_usec(input) < 0))\n\t\t\t{\n\t\t\t\tapr_time_t floorTime = (apr_time_sec(input) - 1) * APR_USEC_PER_SEC;\n\t\t\t\tstat = apr_time_exp_tz(result, floorTime, offset);\n\t\t\t\tresult->tm_usec = (int) (input - floorTime);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tstat = apr_time_exp_tz( result, input, offset );\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\tprivate:\n\t\tconst apr_int32_t offset;\n};\n\n}\n}\n}\n\n\n\nTimeZone::TimeZone( const LogString& id1 ) : id( id1 )\n{\n}\n\nTimeZone::~TimeZone()\n{\n}\n\nconst TimeZonePtr& TimeZone::getDefault()\n{\n\treturn LOG4CXX_NS::helpers::TimeZoneImpl::LocalTimeZone::getInstance();\n}\n\nconst TimeZonePtr& TimeZone::getGMT()\n{\n\treturn LOG4CXX_NS::helpers::TimeZoneImpl::GMTTimeZone::getInstance();\n}\n\nconst TimeZonePtr TimeZone::getTimeZone( const LogString& id )\n{\n\tconst logchar gmt[] = { 0x47, 0x4D, 0x54, 0 };\n\n\tif ( id == gmt )\n\t{\n\t\treturn LOG4CXX_NS::helpers::TimeZoneImpl::GMTTimeZone::getInstance();\n\t}\n\n\tif ( id.length() >= 5 && id.substr( 0, 3 ) == gmt )\n\t{\n\t\tint hours = 0;\n\t\tint minutes = 0;\n\t\tint sign = 1;\n\n\t\tif (id[3] == 0x2D /* '-' */)\n\t\t{\n\t\t\tsign = -1;\n\t\t}\n\n\t\tLogString off( id.substr( 4 ) );\n\n\t\tif ( id.length() >= 7 )\n\t\t{\n\t\t\tsize_t colonPos = off.find( 0x3A /* ':' */);\n\n\t\t\tif ( colonPos == LogString::npos )\n\t\t\t{\n\t\t\t\tminutes = StringHelper::toInt(off.substr(off.length() - 2));\n\t\t\t\thours = StringHelper::toInt(off.substr(0, off.length() - 2));\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tminutes = StringHelper::toInt(off.substr(colonPos + 1));\n\t\t\t\thours = StringHelper::toInt(off.substr(0, colonPos));\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\thours = StringHelper::toInt(off);\n\t\t}\n\n\t\tLogString 
s(gmt);\n\t\tPool p;\n\t\tLogString hh;\n\t\tStringHelper::toString(hours, p, hh);\n\n\t\tif (sign > 0)\n\t\t{\n\t\t\ts.append(1, (logchar) 0x2B /* '+' */);\n\t\t}\n\t\telse\n\t\t{\n\t\t\ts.append(1, (logchar) 0x2D /* '-' */);\n\t\t}\n\n\t\tif (hh.length() == 1)\n\t\t{\n\t\t\ts.append(1, (logchar) 0x30 /* '0' */);\n\t\t}\n\n\t\ts.append(hh);\n\t\ts.append(1, (logchar) 0x3A /*' :' */);\n\t\tLogString mm;\n\t\tStringHelper::toString(minutes, p, mm);\n\n\t\tif (mm.length() == 1)\n\t\t{\n\t\t\ts.append(1, (logchar) 0x30 /* '0' */);\n\t\t}\n\n\t\ts.append(mm);\n\t\tapr_int32_t offset = sign * (hours * 3600 + minutes * 60);\n\t\treturn std::make_shared( s, offset );\n\t}\n\n\tconst TimeZonePtr& ltz = getDefault();\n\n\tif ( ltz->getID() == id )\n\t{\n\t\treturn ltz;\n\t}\n\n\treturn getGMT();\n}\n\n\n// Path: src/main/cpp/strftimedateformat.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n\n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct StrftimeDateFormat::StrftimeDateFormatPrivate{\n\t\nStrftimeDateFormatPrivate() :\n\t\ttimeZone(TimeZone::getDefault())\n\t{}\n\n\t/**\n\t* Time zone.\n\t*/\n\tTimeZonePtr timeZone;\n\tstd::string pattern;\n};\n\n\nStrftimeDateFormat::StrftimeDateFormat(const LogString& fmt)\n\t: m_priv(std::make_unique())\n{\n\tLOG4CXX_NS::helpers::Transcoder::encode(fmt, m_priv->pattern);\n}\n\nStrftimeDateFormat::~StrftimeDateFormat()\n{\n}\n\n\nvoid StrftimeDateFormat::format(LogString& s, log4cxx_time_t time, Pool& /* p */ ) const\n{\n\tapr_time_exp_t exploded;\n\tapr_status_t stat = m_priv->timeZone->explode(&exploded, time);\n\n\tif (stat == APR_SUCCESS)\n\t{\n\t\tconst apr_size_t bufSize = 255;\n\t\tchar buf[bufSize];\n\t\tapr_size_t bufLen;\n\t\tstat = apr_strftime(buf, &bufLen, bufSize, m_priv->pattern.c_str(), &exploded);\n\n\t\tif (stat == APR_SUCCESS)\n\t\t{\n\t\t\tLOG4CXX_NS::helpers::Transcoder::decode(std::string(buf, bufLen), s);\n\t\t}\n\t}\n}\n\nvoid StrftimeDateFormat::setTimeZone(const TimeZonePtr& zone)\n{\n\tm_priv->timeZone = zone;\n}\n\n\n\n// Path: src/main/cpp/hexdump.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n/* Prevent error C2491: 'std::numpunct<_Elem>::id': definition of dllimport static data member not allowed */\n#if defined(_MSC_VER) && (LOG4CXX_UNICHAR_API || LOG4CXX_LOGCHAR_IS_UNICHAR)\n#define __FORCE_INSTANCE\n#endif\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\n\ntypedef std::basic_stringstream LogStream;\n\nLogString LOG4CXX_NS::hexdump(const void* bytes, uint32_t len, HexdumpFlags flags){\n\tLogString ret;\n\tconst uint8_t* bytes_u8 = static_cast(bytes);\n\tLogStream sstream;\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\tconst wchar_t fill_char = L'0';\n\tconst wchar_t space_fill_char = L' ';\n#else\n\tconst logchar fill_char = '0';\n\tconst logchar space_fill_char = ' ';\n#endif\n\n\tif(flags & HexdumpFlags::AddStartingNewline){\n\t\tsstream << LOG4CXX_EOL;\n\t}\n\n\tfor(uint32_t offset = 0; offset < len; offset += 16){\n\t\tif(offset != 0){\n\t\t\tsstream << LOG4CXX_EOL;\n\t\t}\n\n\t\t// Print out the offset\n\t\tsstream << std::hex << std::setw(8) << std::setfill(fill_char) << offset << std::resetiosflags(std::ios_base::fmtflags(0));\n\n\t\tsstream << std::setw(0) << LOG4CXX_STR(\" \");\n\n\t\t// Print out the first 8 bytes\n\t\tfor(int byte = 0; byte < 8; byte++){\n\t\t\tif(offset + byte >= len){\n\t\t\t\tsstream << LOG4CXX_STR(\" \");\n\t\t\t\tif(byte != 8){\n\t\t\t\t\tsstream << LOG4CXX_STR(\" \");\n\t\t\t\t}\n\t\t\t\tcontinue;\n\t\t\t}\n\n\t\t\tsstream << std::hex << std::setw(2) << std::setfill(fill_char) << static_cast(bytes_u8[offset + byte]) << std::resetiosflags(std::ios_base::fmtflags(0));\n\t\t\tsstream << std::setfill(space_fill_char);\n\t\t\tif(byte != 8){\n\t\t\t\tsstream << LOG4CXX_STR(\" \");\n\t\t\t}\n\t\t}\n\n\t\tsstream << LOG4CXX_STR(\" \");\n\n\t\t// Print out the last 8 bytes\n\t\tfor(int byte = 8; byte < 16; byte++){\n\t\t\tif(offset + byte >= len){\n\t\t\t\tsstream << LOG4CXX_STR(\" \");\n\t\t\t\tif(byte != 15){\n\t\t\t\t\tsstream << LOG4CXX_STR(\" \");\n\t\t\t\t}\n\t\t\t\tcontinue;\n\t\t\t}\n\n\t\t\tsstream << std::hex << std::setw(2) << std::setfill(fill_char) << static_cast(bytes_u8[offset + byte]) << std::resetiosflags(std::ios_base::fmtflags(0));\n\t\t\tif(byte != 15){\n\t\t\t\tsstream << LOG4CXX_STR(\" \");\n\t\t\t}\n\t\t}\n\n\t\t// Print out the ASCII text\n\t\tsstream << LOG4CXX_STR(\" |\");\n\t\tfor(int byte = 0; byte < 16; byte++){\n\t\t\tif(offset + byte >= len){\n\t\t\t\tbreak;\n\t\t\t}\n\t\t\tif(std::isprint(bytes_u8[offset + byte])){\n\t\t\t\tlogchar to_append = bytes_u8[offset + byte];\n\t\t\t\tsstream << to_append;\n\t\t\t}else{\n\t\t\t\tsstream << LOG4CXX_STR(\".\");\n\t\t\t}\n\t\t}\n\t\tsstream << LOG4CXX_STR(\"|\");\n\t}\n\n\tif(flags & HexdumpFlags::AddEndingNewline){\n\t\tsstream << LOG4CXX_EOL;\n\t}\n\n\treturn sstream.str();\n}\n\n// Path: src/main/cpp/inputstream.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(InputStream)\n\nInputStream::InputStream()\n{\n}\n\nInputStream::~InputStream()\n{\n}\n\n// Path: src/main/cpp/writer.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(Writer)\n\nWriter::Writer()\n{\n}\n\nWriter::~Writer()\n{\n}\n\n// Path: src/main/cpp/multiprocessrollingfileappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#ifndef MAX_FILE_LEN\n\t#define MAX_FILE_LEN 2048\n#endif\n#include \n#include \n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::rolling;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nstruct MultiprocessRollingFileAppender::MultiprocessRollingFileAppenderPriv : public FileAppenderPriv\n{\n\tMultiprocessRollingFileAppenderPriv() :\n\t\tFileAppenderPriv(),\n\t\tfileLength(0) {}\n\n\t/**\n\t * Triggering policy.\n\t */\n\tTriggeringPolicyPtr triggeringPolicy;\n\n\t/**\n\t * Rolling policy.\n\t */\n\tRollingPolicyPtr rollingPolicy;\n\n\t/**\n\t * Length of current active log file.\n\t */\n\tsize_t fileLength;\n\n\t/**\n\t * save the loggingevent\n\t */\n\tspi::LoggingEventPtr _event;\n};\n\n#define _priv static_cast(m_priv.get())\n\nIMPLEMENT_LOG4CXX_OBJECT(MultiprocessRollingFileAppender)\n\n\n/**\n * Construct a new instance.\n */\nMultiprocessRollingFileAppender::MultiprocessRollingFileAppender() :\n\tFileAppender (std::make_unique())\n{\n}\n\n/**\n * Prepare instance of use.\n */\nvoid MultiprocessRollingFileAppender::activateOptions(Pool& p)\n{\n\tif (_priv->rollingPolicy == NULL)\n\t{\n\t\tauto fwrp = std::make_shared();\n\t\tfwrp->setFileNamePattern(getFile() + LOG4CXX_STR(\".%i\"));\n\t\t_priv->rollingPolicy = fwrp;\n\t}\n\n\t//\n\t// if no explicit triggering policy and rolling policy is both.\n\t//\n\tif (_priv->triggeringPolicy == NULL)\n\t{\n\t\tTriggeringPolicyPtr trig = LOG4CXX_NS::cast(_priv->rollingPolicy);\n\n\t\tif (trig != NULL)\n\t\t{\n\t\t\t_priv->triggeringPolicy = trig;\n\t\t}\n\t}\n\n\tif (_priv->triggeringPolicy == NULL)\n\t{\n\t\t_priv->triggeringPolicy = std::make_shared();\n\t}\n\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->triggeringPolicy->activateOptions(p);\n\t\t_priv->rollingPolicy->activateOptions(p);\n\n\t\ttry\n\t\t{\n\t\t\tRolloverDescriptionPtr rollover1 =\n\t\t\t\t_priv->rollingPolicy->initialize(getFile(), getAppend(), p);\n\n\t\t\tif (rollover1 != NULL)\n\t\t\t{\n\t\t\t\tActionPtr syncAction(rollover1->getSynchronous());\n\n\t\t\t\tif (syncAction != NULL)\n\t\t\t\t{\n\t\t\t\t\tsyncAction->execute(p);\n\t\t\t\t}\n\n\t\t\t\t_priv->fileName = rollover1->getActiveFileName();\n\t\t\t\t_priv->fileAppend = rollover1->getAppend();\n\n\t\t\t\t//\n\t\t\t\t// async action not yet implemented\n\t\t\t\t//\n\t\t\t\tActionPtr asyncAction(rollover1->getAsynchronous());\n\n\t\t\t\tif (asyncAction != NULL)\n\t\t\t\t{\n\t\t\t\t\tasyncAction->execute(p);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\tFile activeFile;\n\t\t\tactiveFile.setPath(getFile());\n\n\t\t\tif (getAppend())\n\t\t\t{\n\t\t\t\t_priv->fileLength = activeFile.length(p);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t_priv->fileLength = 0;\n\t\t\t}\n\n\t\t\tFileAppender::activateOptionsInternal(p);\n\t\t}\n\t\tcatch (std::exception&)\n\t\t{\n\t\t\tLogLog::warn(\n\t\t\t\tLogString(LOG4CXX_STR(\"Exception will initializing RollingFileAppender named \"))\n\t\t\t\t+ 
getName());\n\t\t}\n\t}\n}\n\nvoid MultiprocessRollingFileAppender::releaseFileLock(apr_file_t* lock_file)\n{\n\tif (lock_file)\n\t{\n\t\tapr_status_t stat = apr_file_unlock(lock_file);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tLogLog::warn(LOG4CXX_STR(\"flock: unlock failed\"));\n\t\t}\n\n\t\tapr_file_close(lock_file);\n\t\tlock_file = NULL;\n\t}\n}\n\n/**\n Implements the usual roll over behaviour.\n\n
If MaxBackupIndex is positive, then files\n {File.1, ..., File.MaxBackupIndex -1}\n are renamed to {File.2, ...,\n File.MaxBackupIndex}. Moreover, File is\n renamed File.1 and closed. A new File is\n created to receive further log output.\n\n
If MaxBackupIndex is equal to zero, then the\n File is truncated with no backup files created.\n\n * @return true if rollover performed.\n */\nbool MultiprocessRollingFileAppender::rollover(Pool& p)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\treturn rolloverInternal(p);\n}\n\nbool MultiprocessRollingFileAppender::rolloverInternal(Pool& p)\n{\n\t//\n\t// can't roll without a policy\n\t//\n\tif (_priv->rollingPolicy != NULL)\n\t{\n\n\t\t{\n\t\t\tLogString fileName(getFile());\n\t\t\tRollingPolicyBasePtr basePolicy = LOG4CXX_NS::cast(_priv->rollingPolicy);\n\t\t\tapr_time_t n = apr_time_now();\n\t\t\tObjectPtr obj = std::make_shared(n);\n\t\t\tLogString fileNamePattern;\n\n\t\t\tif (basePolicy)\n\t\t\t{\n\t\t\t\tif (basePolicy->getPatternConverterList().size())\n\t\t\t\t{\n\t\t\t\t\t(*(basePolicy->getPatternConverterList().begin()))->format(obj, fileNamePattern, p);\n\t\t\t\t\tfileName = std::string(fileNamePattern);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\tbool bAlreadyRolled = true;\n\t\t\tchar szDirName[MAX_FILE_LEN] = {'\\0'};\n\t\t\tchar szBaseName[MAX_FILE_LEN] = {'\\0'};\n\t\t\tchar szUid[MAX_FILE_LEN] = {'\\0'};\n\t\t\tmemcpy(szDirName, fileName.c_str(), fileName.size() > MAX_FILE_LEN ? MAX_FILE_LEN : fileName.size());\n\t\t\tmemcpy(szBaseName, fileName.c_str(), fileName.size() > MAX_FILE_LEN ? MAX_FILE_LEN : fileName.size());\n\t\t\tapr_uid_t uid;\n\t\t\tapr_gid_t groupid;\n\t\t\tapr_status_t stat = apr_uid_current(&uid, &groupid, p.getAPRPool());\n\n\t\t\tif (stat == APR_SUCCESS)\n\t\t\t{\n#ifdef WIN32\n\t\t\t\tsnprintf(szUid, MAX_FILE_LEN, \"%p\", uid);\n#else\n\t\t\t\tsnprintf(szUid, MAX_FILE_LEN, \"%u\", (unsigned int)uid);\n#endif\n\t\t\t}\n\n\t\t\tLOG4CXX_NS::filesystem::path path = szDirName;\n\t\t\tconst auto lockname = path.parent_path() / (path.filename().string() + szUid + \".lock\");\n\t\t\tapr_file_t* lock_file;\n\t\t\tstat = apr_file_open(&lock_file, lockname.string().c_str(), APR_CREATE | APR_READ | APR_WRITE, APR_OS_DEFAULT, p.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tLogString err = LOG4CXX_STR(\"lockfile return error: open lockfile failed. \");\n\t\t\t\terr += (strerror(errno));\n\t\t\t\tLogLog::warn(err);\n\t\t\t\tbAlreadyRolled = false;\n\t\t\t\tlock_file = NULL;\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tstat = apr_file_lock(lock_file, APR_FLOCK_EXCLUSIVE);\n\n\t\t\t\tif (stat != APR_SUCCESS)\n\t\t\t\t{\n\t\t\t\t\tLogString err = LOG4CXX_STR(\"apr_file_lock: lock failed. 
\");\n\t\t\t\t\terr += (strerror(errno));\n\t\t\t\t\tLogLog::warn(err);\n\t\t\t\t\tbAlreadyRolled = false;\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\tif (_priv->_event)\n\t\t\t\t\t{\n\t\t\t\t\t\t_priv->triggeringPolicy->isTriggeringEvent(this, _priv->_event, getFile(), getFileLength());\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\tif (bAlreadyRolled)\n\t\t\t{\n\t\t\t\tapr_finfo_t finfo1, finfo2;\n\t\t\t\tapr_status_t st1, st2;\n\t\t\t\tconst WriterPtr writer = getWriter();\n\t\t\t\tconst FileOutputStreamPtr fos = LOG4CXX_NS::cast( writer );\n\t\t\t\tif( !fos ){\n\t\t\t\t\tLogLog::error( LOG4CXX_STR(\"Can't cast writer to FileOutputStream\") );\n\t\t\t\t\treturn false;\n\t\t\t\t}\n\t\t\t\tapr_file_t* _fd = fos->getFilePtr();\n\t\t\t\tst1 = apr_file_info_get(&finfo1, APR_FINFO_IDENT, _fd);\n\n\t\t\t\tif (st1 != APR_SUCCESS)\n\t\t\t\t{\n\t\t\t\t\tLogLog::warn(LOG4CXX_STR(\"apr_file_info_get failed\"));\n\t\t\t\t}\n\n\t\t\t\tLogString fname = getFile();\n\t\t\t\tst2 = apr_stat(&finfo2, fname.c_str(), APR_FINFO_IDENT, p.getAPRPool());\n\n\t\t\t\tif (st2 != APR_SUCCESS)\n\t\t\t\t{\n\t\t\t\t\tLogLog::warn(LOG4CXX_STR(\"apr_stat failed.\"));\n\t\t\t\t}\n\n\t\t\t\tbAlreadyRolled = ((st1 == APR_SUCCESS) && (st2 == APR_SUCCESS)\n\t\t\t\t\t\t&& ((finfo1.device != finfo2.device) || (finfo1.inode != finfo2.inode)));\n\t\t\t}\n\n\t\t\tif (!bAlreadyRolled)\n\t\t\t{\n\n\t\t\t\ttry\n\t\t\t\t{\n\t\t\t\t\tRolloverDescriptionPtr rollover1(_priv->rollingPolicy->rollover(this->getFile(), this->getAppend(), p));\n\n\t\t\t\t\tif (rollover1 != NULL)\n\t\t\t\t\t{\n\t\t\t\t\t\tif (rollover1->getActiveFileName() == getFile())\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tcloseWriter();\n\n\t\t\t\t\t\t\tbool success = true;\n\n\t\t\t\t\t\t\tif (rollover1->getSynchronous() != NULL)\n\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\tsuccess = false;\n\n\t\t\t\t\t\t\t\ttry\n\t\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\t\tsuccess = rollover1->getSynchronous()->execute(p);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\tcatch (std::exception& ex)\n\t\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\t\tLogLog::warn(LOG4CXX_STR(\"Exception on rollover\"));\n\t\t\t\t\t\t\t\t\tLogString exmsg;\n\t\t\t\t\t\t\t\t\tLOG4CXX_NS::helpers::Transcoder::decode(ex.what(), exmsg);\n\t\t\t\t\t\t\t\t\t_priv->errorHandler->error(exmsg, ex, 0);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\n\t\t\t\t\t\t\tif (success)\n\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\tif (rollover1->getAppend())\n\t\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\t\t_priv->fileLength = File().setPath(rollover1->getActiveFileName()).length(p);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse\n\t\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\t\t_priv->fileLength = 0;\n\t\t\t\t\t\t\t\t}\n\n\t\t\t\t\t\t\t\t//\n\t\t\t\t\t\t\t\t// async action not yet implemented\n\t\t\t\t\t\t\t\t//\n\t\t\t\t\t\t\t\tActionPtr asyncAction(rollover1->getAsynchronous());\n\n\t\t\t\t\t\t\t\tif (asyncAction != NULL)\n\t\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\t\tasyncAction->execute(p);\n\t\t\t\t\t\t\t\t}\n\n\t\t\t\t\t\t\t\tsetFileInternal(\n\t\t\t\t\t\t\t\t\trollover1->getActiveFileName(), rollover1->getAppend(),\n\t\t\t\t\t\t\t\t\t_priv->bufferedIO, _priv->bufferSize, p);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse\n\t\t\t\t\t\t\t{\n\t\t\t\t\t\t\t\tsetFileInternal(\n\t\t\t\t\t\t\t\t\trollover1->getActiveFileName(), true, _priv->bufferedIO, _priv->bufferSize, p);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tcloseWriter();\n\t\t\t\t\t\t\tsetFileInternal(rollover1->getActiveFileName());\n\t\t\t\t\t\t\t// Call activateOptions to create any intermediate directories(if 
required)\n\t\t\t\t\t\t\tFileAppender::activateOptionsInternal(p);\n\t\t\t\t\t\t\tOutputStreamPtr os(new FileOutputStream(\n\t\t\t\t\t\t\t\t\trollover1->getActiveFileName(), rollover1->getAppend()));\n\t\t\t\t\t\t\tWriterPtr newWriter(createWriter(os));\n\t\t\t\t\t\t\tsetWriterInternal(newWriter);\n\n\t\t\t\t\t\t\tbool success = true;\n\n...\n", "needle_token_start": 8334, "needle_token_end": 8356, "code_context_ntokens": 15213, "output": ["```cpp\n// Path: src/main/cpp/timezone.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#define __STDC_CONSTANT_MACROS\n#include \n#include \n#include \n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT( TimeZone )\n\nnamespace LOG4CXX_NS\n{\nnamespace helpers\n{\nnamespace TimeZoneImpl\n{\n/** Time zone object that represents GMT. */\nclass GMTTimeZone : public TimeZone\n{\n\tpublic:\n\t\t/** Class factory. */\n\t\tstatic const TimeZonePtr& getInstance()\n\t\t{\n\t\t\tstatic WideLife tz = std::make_shared();\n\t\t\treturn tz;\n\t\t}\n\n\t\t/** Explode time to human readable form. */\n\t\tlog4cxx_status_t explode( apr_time_exp_t* result, log4cxx_time_t input ) const\n\t\t{\n\t\t\tapr_status_t stat;\n\n\t\t\t// APR 1.1 and early mishandles microseconds on dates\n\t\t\t// before 1970, APR bug 32520\n\t\t\tif (LOG4CXX_UNLIKELY(input < 0 && apr_time_usec(input) < 0))\n\t\t\t{\n\t\t\t\tapr_time_t floorTime = (apr_time_sec(input) - 1) * APR_USEC_PER_SEC;\n\t\t\t\tstat = apr_time_exp_gmt(result, floorTime);\n\t\t\t\tresult->tm_usec = (int) (input - floorTime);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tstat = apr_time_exp_gmt( result, input );\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\t\tGMTTimeZone() : TimeZone( LOG4CXX_STR(\"GMT\") )\n\t\t{\n\t\t}\n};\n\n\n\n/** Time zone object that represents GMT. */\nclass LocalTimeZone : public TimeZone\n{\n\tpublic:\n\t\t/** Class factory. */\n\t\tstatic const TimeZonePtr& getInstance()\n\t\t{\n\t\t\tstatic WideLife tz = std::make_shared();\n\t\t\treturn tz;\n\t\t}\n\n\t\t/** Explode time to human readable form. 
*/\n\t\tlog4cxx_status_t explode( apr_time_exp_t* result, log4cxx_time_t input ) const\n\t\t{\n\t\t\tapr_status_t stat;\n\n\t\t\t// APR 1.1 and early mishandles microseconds on dates\n\t\t\t// before 1970, APR bug 32520\n\t\t\tif (LOG4CXX_UNLIKELY(input < 0 && apr_time_usec(input) < 0))\n\t\t\t{\n\t\t\t\tapr_time_t floorTime = (apr_time_sec(input) - 1) * APR_USEC_PER_SEC;\n\t\t\t\tstat = apr_time_exp_lt(result, floorTime);\n\t\t\t\tresult->tm_usec = (int) (input - floorTime);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tstat = apr_time_exp_lt( result, input );\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\t\tLocalTimeZone() : TimeZone( getTimeZoneName() )\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tstatic const LogString getTimeZoneName()\n\t\t{\n\t\t\tconst int MAX_TZ_LENGTH = 255;\n\t\t\tchar tzName[MAX_TZ_LENGTH];\n\t\t\tapr_size_t tzLength;\n\t\t\tapr_time_exp_t tm;\n\t\t\tapr_time_exp_lt(&tm, 0);\n\t\t\tapr_strftime(tzName, &tzLength, MAX_TZ_LENGTH, \"%Z\", &tm);\n\n\t\t\tif (tzLength == 0)\n\t\t\t{\n\t\t\t\tapr_strftime(tzName, &tzLength, MAX_TZ_LENGTH, \"%z\", &tm);\n\t\t\t}\n\n\t\t\ttzName[tzLength] = 0;\n\t\t\tLogString retval;\n\t\t\tLOG4CXX_NS::helpers::Transcoder::decode(t"]} +{"repo": "apache/logging-log4cxx", "name": "UTF8CharsetEncoder", "language": "cpp", "path": "src/main/cpp/charsetencoder.cpp", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert a sequence of characters into UTF-8 encoded bytes, which is a widely used character encoding system that supports a large range of characters from various scripts.\n2. **Input**: It accepts a sequence of characters along with a position indicator to start the conversion, and a buffer to store the encoded bytes.\n3. **Output**: The result is a status indicating the success or failure of the encoding process, and the buffer is filled with the UTF-8 encoded bytes.\n4. **Procedure**: The function iterates over the input characters, converting each character into its corresponding UTF-8 encoded form. It checks if the character can be encoded within the remaining space of the buffer. If a character cannot be encoded (due to space constraints or encoding issues), the process stops, and an appropriate status is returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/level.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT_WITH_CUSTOM_CLASS(Level, LevelClass)\n\nLevelPtr Level::getOff()\n{\n\tstatic WideLife offLevel = std::make_shared(Level::OFF_INT, LOG4CXX_STR(\"OFF\"), 0);\n\treturn offLevel;\n}\n\nLevelPtr Level::getFatal()\n{\n\tstatic WideLife fatalLevel = std::make_shared(Level::FATAL_INT, LOG4CXX_STR(\"FATAL\"), 0);\n\treturn fatalLevel;\n}\n\nLevelPtr Level::getError()\n{\n\tstatic WideLife errorLevel = std::make_shared(Level::ERROR_INT, LOG4CXX_STR(\"ERROR\"), 3);\n\treturn errorLevel;\n}\n\nLevelPtr Level::getWarn()\n{\n\tstatic WideLife warnLevel = std::make_shared(Level::WARN_INT, LOG4CXX_STR(\"WARN\"), 4);\n\treturn warnLevel;\n}\n\nLevelPtr Level::getInfo()\n{\n\tstatic WideLife infoLevel = std::make_shared(Level::INFO_INT, LOG4CXX_STR(\"INFO\"), 6);\n\treturn infoLevel;\n}\n\nLevelPtr Level::getDebug()\n{\n\tstatic WideLife debugLevel = std::make_shared(Level::DEBUG_INT, LOG4CXX_STR(\"DEBUG\"), 7);\n\treturn debugLevel;\n}\n\nLevelPtr Level::getTrace()\n{\n\tstatic WideLife traceLevel = std::make_shared(Level::TRACE_INT, LOG4CXX_STR(\"TRACE\"), 7);\n\treturn traceLevel;\n}\n\n\nLevelPtr Level::getAll()\n{\n\tstatic WideLife allLevel = std::make_shared(Level::ALL_INT, LOG4CXX_STR(\"ALL\"), 7);\n\treturn allLevel;\n}\n\n\n\nLevel::Level(int level1,\n\tconst LogString& name1, int syslogEquivalent1)\n\t: level(level1), name(name1), syslogEquivalent(syslogEquivalent1)\n{\n\tAPRInitializer::initialize();\n}\n\n\nLevelPtr Level::toLevelLS(const LogString& sArg)\n{\n\treturn toLevelLS(sArg, Level::getDebug());\n}\n\nLogString Level::toString() const\n{\n\treturn name;\n}\n\n\nLevelPtr Level::toLevel(int val)\n{\n\treturn toLevel(val, Level::getDebug());\n}\n\nconst Level::Data& Level::getData()\n{\n\tstatic Data data =\n\t\t{ getOff()\n\t\t, getFatal()\n\t\t, getError()\n\t\t, getWarn()\n\t\t, getInfo()\n\t\t, getDebug()\n\t\t, getTrace()\n\t\t, getAll()\n\t\t};\n\treturn data;\n}\n\n\nLevelPtr Level::toLevel(int val, const LevelPtr& defaultLevel)\n{\n\tswitch (val)\n\t{\n\t\tcase ALL_INT:\n\t\t\treturn getAll();\n\n\t\tcase DEBUG_INT:\n\t\t\treturn getDebug();\n\n\t\tcase TRACE_INT:\n\t\t\treturn getTrace();\n\n\t\tcase INFO_INT:\n\t\t\treturn getInfo();\n\n\t\tcase WARN_INT:\n\t\t\treturn getWarn();\n\n\t\tcase ERROR_INT:\n\t\t\treturn getError();\n\n\t\tcase FATAL_INT:\n\t\t\treturn getFatal();\n\n\t\tcase OFF_INT:\n\t\t\treturn getOff();\n\n\t\tdefault:\n\t\t\treturn defaultLevel;\n\t}\n}\n\nLevelPtr Level::toLevel(const std::string& sArg)\n{\n\treturn toLevel(sArg, Level::getDebug());\n}\n\nLevelPtr Level::toLevel(const std::string& sArg, const LevelPtr& defaultLevel)\n{\n...\n// Path: src/main/cpp/optionconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n\nnamespace LOG4CXX_NS\n{\n\nclass ConfiguratorWatchdog : public helpers::FileWatchdog\n{\n\tspi::ConfiguratorPtr m_config;\n\tpublic:\n ConfiguratorWatchdog(const spi::ConfiguratorPtr& config, const File& filename)\n : helpers::FileWatchdog(filename)\n , m_config(config)\n {\n }\n\n /**\n Call PropertyConfigurator#doConfigure(const String& configFileName,\n const spi::LoggerRepositoryPtr& hierarchy) with the\n filename to reconfigure log4cxx.\n */\n void doOnChange() override\n {\n m_config->doConfigure(file(), LogManager::getLoggerRepository());\n }\n};\n\n}\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\n\nLogString OptionConverter::convertSpecialChars(const LogString& s)\n{\n\tlogchar c;\n\tLogString sbuf;\n\n\tLogString::const_iterator i = s.begin();\n\n\twhile (i != s.end())\n\t{\n\t\tc = *i++;\n\n\t\tif (c == 0x5C /* '\\\\' */)\n\t\t{\n\t\t\tc = *i++;\n\n\t\t\tswitch (c)\n\t\t\t{\n\t\t\t\tcase 0x6E: //'n'\n\t\t\t\t\tc = 0x0A;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x72: //'r'\n\t\t\t\t\tc = 0x0D;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x74: //'t'\n\t\t\t\t\tc = 0x09;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x66: //'f'\n\t\t\t\t\tc = 0x0C;\n\t\t\t\t\tbreak;\n\n\t\t\t\tdefault:\n\t\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\n\t\tsbuf.append(1, c);\n\t}\n\n\treturn sbuf;\n}\n\n\nbool OptionConverter::toBoolean(const LogString& value, bool dEfault)\n{\n\tif (value.length() >= 4)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(value.substr(0, 4),\n\t\t\t\tLOG4CXX_STR(\"TRUE\"), LOG4CXX_STR(\"true\")))\n\t\t{\n\t\t\treturn true;\n\t\t}\n\t}\n\n\tif (dEfault && value.length() >= 5)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(value.substr(0, 5),\n\t\t\t\tLOG4CXX_STR(\"FALSE\"), LOG4CXX_STR(\"false\")))\n\t\t{\n\t\t\treturn false;\n\t\t}\n\t}\n\n\treturn dEfault;\n}\n\nint OptionConverter::toInt(const LogString& value, int dEfault)\n{\n\tLogString trimmed(StringHelper::trim(value));\n\n\tif (trimmed.empty())\n\t{\n\t\treturn dEfault;\n\t}\n\n\tLOG4CXX_ENCODE_CHAR(cvalue, trimmed);\n\n\treturn (int) atol(cvalue.c_str());\n}\n\nlong OptionConverter::toFileSize(const LogString& s, long dEfault)\n{\n\tif (s.empty())\n\t{\n\t\treturn dEfault;\n\t}\n\n\tsize_t index = s.find_first_of(LOG4CXX_STR(\"bB\"));\n\n\tif (index != LogString::npos && index > 0)\n\t{\n\t\tlong multiplier = 1;\n\t\tindex--;\n\n\t\tif (s[index] == 0x6B /* 'k' */ || s[index] == 0x4B /* 'K' */)\n\t\t{\n\t\t\tmultiplier = 1024;\n\t\t}\n\t\telse if (s[index] == 0x6D /* 'm' */ || s[index] == 0x4D 
/* 'M' */)\n\t\t{\n\t\t\tmultiplier = 1024 * 1024;\n\t\t}\n\t\telse if (s[index] == 0x67 /* 'g'*/ || s[index] == 0x47 /* 'G' */)\n\t\t{\n\t\t\tmultiplier = 1024 * 1024 * 1024;\n\t\t}\n\n\t\treturn toInt(s.substr(0, index), 1) * multiplier;\n\t}\n\n\treturn toInt(s, 1);\n}\n\nLogString OptionConverter::findAndSubst(const LogString& key, Properties& props)\n{\n\tLogString value(props.getProperty(key));\n\n\tif (value.empty())\n\t{\n\t\treturn value;\n\t}\n\n\ttry\n\t{\n\t\treturn substVars(value, props);\n\t}\n\tcatch (IllegalArgumentException& e)\n\t{\n\t\tLogLog::error(((LogString) LOG4CXX_STR(\"Bad option value [\"))\n\t\t\t+ value + LOG4CXX_STR(\"].\"), e);\n\t\treturn value;\n\t}\n}\n\nLogString OptionConverter::substVars(const LogString& val, Properties& props)\n{\n\tLogString sbuf;\n\tconst logchar delimStartArray[] = { 0x24, 0x7B, 0 };\n\tconst LogString delimStart(delimStartArray);\n\tconst logchar delimStop = 0x7D; // '}';\n\tconst size_t DELIM_START_LEN = 2;\n\tconst size_t DELIM_STOP_LEN = 1;\n\n\tsize_t i = 0;\n\n\twhile (true)\n\t{\n\t\tsize_t j = val.find(delimStart, i);\n\n\t\tif (j == val.npos)\n\t\t{\n\t\t\t// no more variables\n\t\t\tif (i == 0)\n\t\t\t{\n\t\t\t\t// this is a simple string\n\t\t\t\treturn val;\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t// add the tail string which contails no variables and return the result.\n\t\t\t\tsbuf.append(val.substr(i, val.length() - i));\n\t\t\t\treturn sbuf;\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\tsbuf.append(val.substr(i, j - i));\n\t\t\tsize_t k = val.find(delimStop, j);\n\n\t\t\tif (k == val.npos)\n\t\t\t{\n\t\t\t\tLogString msg(1, (logchar) 0x22 /* '\\\"' */);\n\t\t\t\tmsg.append(val);\n\t\t\t\tmsg.append(LOG4CXX_STR(\"\\\" has no closing brace. Opening brace at position \"));\n\t\t\t\tPool p;\n\t\t\t\tStringHelper::toString(j, p, msg);\n\t\t\t\tmsg.append(1, (logchar) 0x2E /* '.' 
*/);\n\t\t\t\tthrow IllegalArgumentException(msg);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tj += DELIM_START_LEN;\n\t\t\t\tLogString key = val.substr(j, k - j);\n\t\t\t\t// first try in System properties\n\t\t\t\tLogString replacement(getSystemProperty(key, LogString()));\n\n\t\t\t\t// then try props parameter\n\t\t\t\tif (replacement.empty())\n\t\t\t\t{\n\t\t\t\t\treplacement = props.getProperty(key);\n\t\t\t\t}\n\n\t\t\t\tif (!replacement.empty())\n\t\t\t\t{\n\t\t\t\t\t// Do variable substitution on the replacement string\n\t\t\t\t\t// such that we can solve \"Hello ${x2}\" as \"Hello p1\"\n\t\t\t\t\t// the where the properties are\n\t\t\t\t\t// x1=p1\n\t\t\t\t\t// x2=${x1}\n\t\t\t\t\tLogString recursiveReplacement = substVars(replacement, props);\n\t\t\t\t\tsbuf.append(recursiveReplacement);\n\t\t\t\t}\n\n\t\t\t\ti = k + DELIM_STOP_LEN;\n\t\t\t}\n\t\t}\n\t}\n}\n\nLogString OptionConverter::getSystemProperty(const LogString& key, const LogString& def)\n{\n\tif (!key.empty())\n\t{\n\t\tLogString value(System::getProperty(key));\n\n\t\tif (!value.empty())\n\t\t{\n\t\t\treturn value;\n\t\t}\n\t}\n\n\treturn def;\n}\n\nLevelPtr OptionConverter::toLevel(const LogString& value,\n\tconst LevelPtr& defaultValue)\n{\n\tsize_t hashIndex = value.find(LOG4CXX_STR(\"#\"));\n\n\tif (hashIndex == LogString::npos)\n\t{\n\t\tif (value.empty())\n\t\t{\n\t\t\treturn defaultValue;\n\t\t}\n\t\telse\n\t\t{\n\t\t\tLogLog::debug(\n\t\t\t\t((LogString) LOG4CXX_STR(\"OptionConverter::toLevel: no class name specified, level=[\"))\n\t\t\t\t+ value\n\t\t\t\t+ LOG4CXX_STR(\"]\"));\n\t\t\t// no class name specified : use standard Level class\n\t\t\treturn Level::toLevelLS(value, defaultValue);\n\t\t}\n\t}\n\n\tLogString clazz = value.substr(hashIndex + 1);\n\tLogString levelName = value.substr(0, hashIndex);\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"OptionConverter::toLevel: class=[\"))\n\t\t+ clazz + LOG4CXX_STR(\"], level=[\") + levelName + LOG4CXX_STR(\"]\"));\n\n\t// This is degenerate case but you never know.\n\tif (levelName.empty())\n\t{\n\t\treturn Level::toLevelLS(value, defaultValue);\n\t}\n\n\ttry\n\t{\n\t\tLevel::LevelClass& levelClass =\n\t\t\t(Level::LevelClass&)Loader::loadClass(clazz);\n\t\treturn levelClass.toLevel(levelName);\n\t}\n\tcatch (ClassNotFoundException&)\n\t{\n\t\tLogLog::warn(((LogString) LOG4CXX_STR(\"custom level class [\"))\n\t\t\t+ clazz + LOG4CXX_STR(\"] not found.\"));\n\t}\n\tcatch (Exception& oops)\n\t{\n\t\tLogLog::warn(\n\t\t\tLOG4CXX_STR(\"class [\") + clazz + LOG4CXX_STR(\"], level [\") + levelName +\n\t\t\tLOG4CXX_STR(\"] conversion) failed.\"), oops);\n\t}\n\tcatch (...)\n\t{\n\t\tLogLog::warn(\n\t\t\tLOG4CXX_STR(\"class [\") + clazz + LOG4CXX_STR(\"], level [\") + levelName +\n\t\t\tLOG4CXX_STR(\"] conversion) failed.\"));\n\t}\n\n\treturn defaultValue;\n}\n\n\nObjectPtr OptionConverter::instantiateByKey(Properties& props, const LogString& key,\n\tconst Class& superClass, const ObjectPtr& defaultValue)\n{\n\t// Get the value of the property in string form\n\tLogString className(findAndSubst(key, props));\n\n\tif (className.empty())\n\t{\n\t\tLogLog::error(\n\t\t\t((LogString) LOG4CXX_STR(\"Could not find value for key \")) + key);\n\t\treturn defaultValue;\n\t}\n\n\t// Trim className to avoid trailing spaces that cause problems.\n\treturn OptionConverter::instantiateByClassName(\n\t\t\tStringHelper::trim(className), superClass, defaultValue);\n}\n\nObjectPtr OptionConverter::instantiateByClassName(const LogString& className,\n\tconst Class& superClass, const ObjectPtr& 
defaultValue)\n{\n\tif (!className.empty())\n\t{\n\t\ttry\n\t\t{\n\t\t\tconst Class& classObj = Loader::loadClass(className);\n\t\t\tObjectPtr newObject = ObjectPtr(classObj.newInstance());\n\n\t\t\tif (!newObject->instanceof(superClass))\n\t\t\t{\n\t\t\t\treturn defaultValue;\n\t\t\t}\n\n\t\t\treturn newObject;\n\t\t}\n\t\tcatch (Exception& e)\n\t\t{\n\t\t\tLogLog::error(LOG4CXX_STR(\"Could not instantiate class [\") +\n\t\t\t\tclassName + LOG4CXX_STR(\"].\"), e);\n\t\t}\n\t}\n\n\treturn defaultValue;\n}\n\nvoid OptionConverter::selectAndConfigure(const File& configFileName,\n\tconst LogString& _clazz, spi::LoggerRepositoryPtr hierarchy, int delay)\n{\n\tConfiguratorPtr configurator;\n\tLogString clazz = _clazz;\n\n\tLogString filename(configFileName.getPath());\n\n#if LOG4CXX_HAS_DOMCONFIGURATOR\n\tif (clazz.empty()\n\t\t&& filename.length() > 4\n\t\t&& StringHelper::equalsIgnoreCase(\n\t\t\tfilename.substr(filename.length() - 4),\n\t\t\tLOG4CXX_STR(\".XML\"), LOG4CXX_STR(\".xml\")))\n\t{\n\t\tclazz = LOG4CXX_NS::xml::DOMConfigurator::getStaticClass().toString();\n\t}\n#endif\n\n\tif (!clazz.empty())\n\t{\n\t\tLogLog::debug(LOG4CXX_STR(\"Preferred configurator class: \") + clazz);\n\t\tconst Class& clazzObj = Loader::loadClass(clazz);\n\t\tObjectPtr obj = ObjectPtr(clazzObj.newInstance());\n\t\tconfigurator = LOG4CXX_NS::cast(obj);\n\n\t\tif (configurator == 0)\n\t\t{\n\t\t\tLogLog::error(LOG4CXX_STR(\"Could not instantiate configurator [\")\n\t\t\t\t+ clazz + LOG4CXX_STR(\"].\"));\n\t\t\treturn;\n\t\t}\n\t}\n\telse\n\t{\n\t\tconfigurator = std::make_shared();\n\t}\n\n\tif (0 < delay)\n\t{\n\t\tauto dog = new ConfiguratorWatchdog(configurator, configFileName);\n\t\tAPRInitializer::registerCleanup(dog);\n\t\tdog->setDelay(delay);\n\t\tdog->start();\n\t}\n\telse\n\t\tconfigurator->doConfigure(configFileName, hierarchy);\n}\n\n// Path: src/main/cpp/locale.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct Locale::LocalePrivate\n{\n\tLocalePrivate(const LogString& language1)\n\t\t: language(language1)\n\t{\n\t}\n\n\tLocalePrivate(const LogString& language1, const LogString& country1)\n\t\t: language(language1), country(country1)\n\t{\n\t}\n\n\tLocalePrivate(const LogString& language1, const LogString& country1,\n\t\tconst LogString& variant1)\n\t\t: language(language1), country(country1), variant(variant1)\n\t{\n\t}\n\n\tconst LogString language;\n\tconst LogString country;\n\tconst LogString variant;\n};\n\nLocale::Locale(const LogString& language1)\n\t: m_priv(std::make_unique(language1))\n{\n}\n\nLocale::Locale(const LogString& language1, const LogString& country1)\n\t: m_priv(std::make_unique(language1, country1))\n{\n}\n\nLocale::Locale(const LogString& language1, const LogString& country1,\n\tconst LogString& variant1)\n\t: m_priv(std::make_unique(language1, country1, variant1))\n{\n}\n\nLocale::~Locale() {}\n\nconst LogString& Locale::getLanguage() const\n{\n\treturn m_priv->language;\n}\n\nconst LogString& Locale::getCountry() const\n{\n\treturn m_priv->country;\n}\n\nconst LogString& Locale::getVariant() const\n{\n\treturn m_priv->variant;\n}\n\n\n// Path: src/main/cpp/charsetencoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n\n#include \n#include \n#include \n\n#ifdef LOG4CXX_HAS_WCSTOMBS\n\t#include \n#endif\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(CharsetEncoder)\n\nnamespace LOG4CXX_NS\n{\n\nnamespace helpers\n{\n\n#if APR_HAS_XLATE\n/**\n* A character encoder implemented using apr_xlate.\n*/\nclass APRCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tAPRCharsetEncoder(const LogString& topage) : pool()\n\t\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\t\tconst char* frompage = \"WCHAR_T\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tconst char* frompage = \"UTF-8\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\t\tconst char* frompage = \"UTF-16\";\n#endif\n\t\t\tstd::string tpage(Transcoder::encodeCharsetName(topage));\n\t\t\tapr_status_t stat = apr_xlate_open(&convset,\n\t\t\t\t\ttpage.c_str(),\n\t\t\t\t\tfrompage,\n\t\t\t\t\tpool.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IllegalArgumentException(topage);\n\t\t\t}\n\t\t}\n\n\t\tvirtual ~APRCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tapr_status_t stat;\n\t\t\tsize_t outbytes_left = out.remaining();\n\t\t\tsize_t initial_outbytes_left = outbytes_left;\n\t\t\tsize_t position = out.position();\n\n\t\t\tif (iter == in.end())\n\t\t\t{\n\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\tstat = apr_xlate_conv_buffer(convset, NULL, NULL,\n\t\t\t\t\t\tout.data() + position, &outbytes_left);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tLogString::size_type inOffset = (iter - in.begin());\n\t\t\t\tapr_size_t inbytes_left =\n\t\t\t\t\t(in.size() - inOffset) * sizeof(LogString::value_type);\n\t\t\t\tapr_size_t initial_inbytes_left = inbytes_left;\n\t\t\t\t{\n\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\tstat = apr_xlate_conv_buffer(convset,\n\t\t\t\t\t\t\t(const char*) (in.data() + inOffset),\n\t\t\t\t\t\t\t&inbytes_left,\n\t\t\t\t\t\t\tout.data() + position,\n\t\t\t\t\t\t\t&outbytes_left);\n\t\t\t\t}\n\t\t\t\titer += ((initial_inbytes_left - inbytes_left) / sizeof(LogString::value_type));\n\t\t\t}\n\n\t\t\tout.position(out.position() + (initial_outbytes_left - outbytes_left));\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tAPRCharsetEncoder(const APRCharsetEncoder&);\n\t\tAPRCharsetEncoder& operator=(const APRCharsetEncoder&);\n\t\tPool pool;\n\t\tstd::mutex mutex;\n\t\tapr_xlate_t* convset;\n};\n#endif\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_WCSTOMBS\n/**\n * A character encoder implemented using wcstombs.\n*/\nclass WcstombsCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tWcstombsCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\t/**\n\t\t * Converts a wchar_t to the default external multibyte encoding.\n\t\t */\n\t\tlog4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (iter != 
in.end())\n\t\t\t{\n\t\t\t\tsize_t outbytes_left = out.remaining();\n\t\t\t\tsize_t position = out.position();\n\t\t\t\tstd::wstring::size_type inOffset = (iter - in.begin());\n\t\t\t\tenum { BUFSIZE = 256 };\n\t\t\t\twchar_t buf[BUFSIZE];\n\t\t\t\tsize_t chunkSize = BUFSIZE - 1;\n\n\t\t\t\tif (chunkSize * MB_LEN_MAX > outbytes_left)\n\t\t\t\t{\n\t\t\t\t\tchunkSize = outbytes_left / MB_LEN_MAX;\n\t\t\t\t}\n\n\t\t\t\tif (chunkSize > in.length() - inOffset)\n\t\t\t\t{\n\t\t\t\t\tchunkSize = in.length() - inOffset;\n\t\t\t\t}\n\n\t\t\t\tmemset(buf, 0, BUFSIZE * sizeof(wchar_t));\n\t\t\t\tmemcpy(buf,\n\t\t\t\t\tin.data() + inOffset,\n\t\t\t\t\tchunkSize * sizeof(wchar_t));\n\t\t\t\tsize_t converted = wcstombs(out.data() + position, buf, outbytes_left);\n\n\t\t\t\tif (converted == (size_t) -1)\n\t\t\t\t{\n\t\t\t\t\tstat = APR_BADARG;\n\n\t\t\t\t\t//\n\t\t\t\t\t// if unconvertable character was encountered\n\t\t\t\t\t// repeatedly halve source to get fragment that\n\t\t\t\t\t// can be converted\n\t\t\t\t\tfor (chunkSize /= 2;\n\t\t\t\t\t\tchunkSize > 0;\n\t\t\t\t\t\tchunkSize /= 2)\n\t\t\t\t\t{\n\t\t\t\t\t\tbuf[chunkSize] = 0;\n\t\t\t\t\t\tconverted = wcstombs(out.data() + position, buf, outbytes_left);\n\n\t\t\t\t\t\tif (converted != (size_t) -1)\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\titer += chunkSize;\n\t\t\t\t\t\t\tout.position(out.position() + converted);\n\t\t\t\t\t\t\tbreak;\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\titer += chunkSize;\n\t\t\t\t\tout.position(out.position() + converted);\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\n\tprivate:\n\t\tWcstombsCharsetEncoder(const WcstombsCharsetEncoder&);\n\t\tWcstombsCharsetEncoder& operator=(const WcstombsCharsetEncoder&);\n};\n#endif\n\n\n/**\n* Encodes a LogString to US-ASCII.\n*/\nclass USASCIICharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUSASCIICharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (iter != in.end())\n\t\t\t{\n\t\t\t\twhile (out.remaining() > 0 && iter != in.end())\n\t\t\t\t{\n\t\t\t\t\tLogString::const_iterator prev(iter);\n\t\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\t\tif (sv <= 0x7F)\n\t\t\t\t\t{\n\t\t\t\t\t\tout.put((char) sv);\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\titer = prev;\n\t\t\t\t\t\tstat = APR_BADARG;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tUSASCIICharsetEncoder(const USASCIICharsetEncoder&);\n\t\tUSASCIICharsetEncoder& operator=(const USASCIICharsetEncoder&);\n};\n\n/**\n* Converts a LogString to ISO-8859-1.\n*/\nclass ISOLatinCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tISOLatinCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (iter != in.end())\n\t\t\t{\n\t\t\t\twhile (out.remaining() > 0 && iter != in.end())\n\t\t\t\t{\n\t\t\t\t\tLogString::const_iterator prev(iter);\n\t\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\t\tif (sv <= 0xFF)\n\t\t\t\t\t{\n\t\t\t\t\t\tout.put((char) sv);\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\titer = prev;\n\t\t\t\t\t\tstat = APR_BADARG;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn 
stat;\n\t\t}\n\n\tprivate:\n\t\tISOLatinCharsetEncoder(const ISOLatinCharsetEncoder&);\n\t\tISOLatinCharsetEncoder& operator=(const ISOLatinCharsetEncoder&);\n};\n\n/**\n* Encodes a LogString to a byte array when the encodings are identical.\n*/\nclass TrivialCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tTrivialCharsetEncoder()\n\t\t{\n\t\t}\n\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tif (iter != in.end())\n\t\t\t{\n\t\t\t\tsize_t requested = in.length() - (iter - in.begin());\n\n\t\t\t\tif (requested > out.remaining() / sizeof(logchar))\n\t\t\t\t{\n\t\t\t\t\trequested = out.remaining() / sizeof(logchar);\n\t\t\t\t}\n\n\t\t\t\tmemcpy(out.current(),\n\t\t\t\t\t(const char*) in.data() + (iter - in.begin()),\n\t\t\t\t\trequested * sizeof(logchar));\n\t\t\t\titer += requested;\n\t\t\t\tout.position(out.position() + requested * sizeof(logchar));\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tTrivialCharsetEncoder(const TrivialCharsetEncoder&);\n\t\tTrivialCharsetEncoder& operator=(const TrivialCharsetEncoder&);\n};\n\n#if LOG4CXX_LOGCHAR_IS_UTF8\ntypedef TrivialCharsetEncoder UTF8CharsetEncoder;\n#else\n/**\n * Converts a LogString to UTF-8.\n */\nclass UTF8CharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\t\nUTF8CharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 8)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF8(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF8CharsetEncoder(const UTF8CharsetEncoder&);\n\t\tUTF8CharsetEncoder& operator=(const UTF8CharsetEncoder&);\n};\n#endif\n\n/**\n * Encodes a LogString to UTF16-BE.\n */\nclass UTF16BECharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUTF16BECharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 4)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF16BE(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF16BECharsetEncoder(const UTF16BECharsetEncoder&);\n\t\tUTF16BECharsetEncoder& operator=(const UTF16BECharsetEncoder&);\n};\n\n/**\n * Encodes a LogString to UTF16-LE.\n */\nclass UTF16LECharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\tUTF16LECharsetEncoder()\n\t\t{\n\t\t}\n\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 4)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF16LE(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\tprivate:\n\t\tUTF16LECharsetEncoder(const UTF16LECharsetEncoder&);\n\t\tUTF16LECharsetEncoder& operator=(const UTF16LECharsetEncoder&);\n};\n\n/**\n * Charset encoder that uses current locale settings.\n */\nclass LocaleCharsetEncoder : public 
CharsetEncoder\n{\n\tpublic:\n\t\tLocaleCharsetEncoder() : state()\n\t\t{\n\t\t}\n\t\tlog4cxx_status_t encode\n\t\t\t( const LogString& in\n\t\t\t, LogString::const_iterator& iter\n\t\t\t, ByteBuffer& out\n\t\t\t) override\n\t\t{\n\t\t\tlog4cxx_status_t result = APR_SUCCESS;\n#if !LOG4CXX_CHARSET_EBCDIC\n\t\t\tchar* current = out.current();\n\t\t\tsize_t remain = out.remaining();\n\t\t\tif (std::mbsinit(&this->state)) // ByteBuffer not partially encoded?\n\t\t\t{\n\t\t\t\t// Copy single byte characters\n\t\t\t\tfor (;\n\t\t\t\t\titer != in.end() && ((unsigned int) *iter) < 0x80 && 0 < remain;\n\t\t\t\t\titer++, remain--, current++)\n\t\t\t\t{\n\t\t\t\t\t*current = *iter;\n\t\t\t\t}\n\t\t\t}\n#endif\n\t\t\t// Encode characters that may require multiple bytes\n\t\t\twhile (iter != in.end() && MB_CUR_MAX <= remain)\n\t\t\t{\n\t\t\t\tauto ch = Transcoder::decode(in, iter);\n\t\t\t\tauto n = std::wcrtomb(current, ch, &this->state);\n\t\t\t\tif (static_cast(-1) == n) // not a valid wide character?\n\t\t\t\t{\n\t\t\t\t\tresult = APR_BADARG;\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tremain -= n;\n\t\t\t\tcurrent += n;\n\t\t\t}\n\t\t\tout.position(current - out.data());\n\t\t\treturn result;\n\t\t}\n\n\tprivate:\n\t\tstd::mbstate_t state;\n};\n\n\n} // namespace helpers\n\n} //namespace log4cxx\n\n\n\nCharsetEncoder::CharsetEncoder()\n{\n}\n\nCharsetEncoder::~CharsetEncoder()\n{\n}\n\nCharsetEncoderPtr CharsetEncoder::getDefaultEncoder()\n{\n\tstatic WideLife encoder(createDefaultEncoder());\n\n\t//\n\t// if invoked after static variable destruction\n\t// (if logging is called in the destructor of a static object)\n\t// then create a new decoder.\n\t//\n\tif (encoder.value() == 0)\n\t{\n\t\treturn CharsetEncoderPtr( createDefaultEncoder() );\n\t}\n\n\treturn encoder;\n}\n\nCharsetEncoder* CharsetEncoder::createDefaultEncoder()\n{\n#if LOG4CXX_CHARSET_UTF8\n\treturn new UTF8CharsetEncoder();\n#elif LOG4CXX_CHARSET_ISO88591\n\treturn new ISOLatinCharsetEncoder();\n#elif LOG4CXX_CHARSET_USASCII\n\treturn new USASCIICharsetEncoder();\n#elif LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_WCSTOMBS\n\treturn new WcstombsCharsetEncoder();\n#else\n\treturn new LocaleCharsetEncoder();\n#endif\n}\n\n\nCharsetEncoderPtr CharsetEncoder::getUTF8Encoder()\n{\n\treturn std::make_shared();\n}\n\n\n\nCharsetEncoderPtr CharsetEncoder::getEncoder(const LogString& charset)\n{\n\tif (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-8\"), LOG4CXX_STR(\"utf-8\"))\n\t\t|| StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP65001\"), LOG4CXX_STR(\"cp65001\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"C\"), LOG4CXX_STR(\"c\")) ||\n\t\tcharset == LOG4CXX_STR(\"646\") ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"US-ASCII\"), LOG4CXX_STR(\"us-ascii\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO646-US\"), LOG4CXX_STR(\"iso646-US\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ANSI_X3.4-1968\"), LOG4CXX_STR(\"ansi_x3.4-1968\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP20127\"), LOG4CXX_STR(\"cp20127\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-8859-1\"), LOG4CXX_STR(\"iso-8859-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-LATIN-1\"), LOG4CXX_STR(\"iso-latin-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP1252\"), LOG4CXX_STR(\"cp1252\")))\n\t{\n\t\treturn 
std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-16BE\"), LOG4CXX_STR(\"utf-16be\"))\n\t\t|| StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-16\"), LOG4CXX_STR(\"utf-16\"))\n\t\t|| StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP1200\"), LOG4CXX_STR(\"cp1200\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-16LE\"), LOG4CXX_STR(\"utf-16le\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"LOCALE\"), LOG4CXX_STR(\"locale\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\n#if APR_HAS_XLATE\n\treturn std::make_shared(charset);\n#else\n\tthrow IllegalArgumentException(charset);\n#endif\n}\n\n\nvoid CharsetEncoder::reset()\n{\n}\n\nvoid CharsetEncoder::flush(ByteBuffer& /* out */ )\n{\n}\n\n\nvoid CharsetEncoder::encode(CharsetEncoderPtr& enc,\n\tconst LogString& src,\n\tLogString::const_iterator& iter,\n\tByteBuffer& dst)\n{\n\tlog4cxx_status_t stat = enc->encode(src, iter, dst);\n\n\tif (stat != APR_SUCCESS && iter != src.end())\n\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR || LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\titer++;\n#elif LOG4CXX_LOGCHAR_IS_UTF8\n\n\t\t// advance past this character and all continuation characters\n\t\twhile ((*(++iter) & 0xC0) == 0x80);\n\n#else\n#error logchar is unrecognized\n#endif\n\t\tdst.put(Transcoder::LOSSCHAR);\n\t}\n}\n\nbool CharsetEncoder::isTriviallyCopyable(const LogString& src, const CharsetEncoderPtr& enc)\n{\n\tbool result;\n#if !LOG4CXX_CHARSET_EBCDIC\n\tif (dynamic_cast(enc.get()))\n\t{\n\t\tresult = src.end() == std::find_if(src.begin(), src.end()\n\t\t\t, [](const logchar& ch) -> bool { return 0x80 <= (unsigned int)ch; });\n\t}\n\telse\n#endif\n\t\tresult = !!dynamic_cast(enc.get());\n\treturn result;\n}\n\n// Path: src/main/cpp/colorstartpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(ColorStartPatternConverter)\n\n#define priv static_cast(m_priv.get())\n\nstatic LogString colorToANSISequence(const LogString& color, bool isForeground, Pool& pool){\n\tint numberToConvert = 0;\n\n\tif(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"BLACK\"), LOG4CXX_STR(\"black\"))){\n\t\tnumberToConvert = 30;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"RED\"), LOG4CXX_STR(\"red\"))){\n\t\tnumberToConvert = 31;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"GREEN\"), LOG4CXX_STR(\"green\"))){\n\t\tnumberToConvert = 32;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"YELLOW\"), LOG4CXX_STR(\"yellow\"))){\n\t\tnumberToConvert = 33;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"BLUE\"), LOG4CXX_STR(\"blue\"))){\n\t\tnumberToConvert = 34;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"MAGENTA\"), LOG4CXX_STR(\"magenta\"))){\n\t\tnumberToConvert = 35;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"CYAN\"), LOG4CXX_STR(\"cyan\"))){\n\t\tnumberToConvert = 36;\n\t}else if(StringHelper::equalsIgnoreCase(color, LOG4CXX_STR(\"WHITE\"), LOG4CXX_STR(\"white\"))){\n\t\tnumberToConvert = 37;\n\t}\n\n\tif( numberToConvert == 0 ){\n\t\treturn LOG4CXX_STR(\"\");\n\t}\n\tLogString ret;\n\tif( isForeground == false ){\n\t\tnumberToConvert += 10;\n\t}\n\tStringHelper::toString(numberToConvert, pool, ret);\n\treturn ret;\n}\n\nstatic LogString graphicsModeToANSISequence(const LogString& graphicsMode, Pool& pool){\n\tint numberToConvert = 0;\n\n\tif(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"BOLD\"), LOG4CXX_STR(\"bold\"))){\n\t\tnumberToConvert = 1;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"DIM\"), LOG4CXX_STR(\"dim\"))){\n\t\tnumberToConvert = 2;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"ITALIC\"), LOG4CXX_STR(\"italic\"))){\n\t\tnumberToConvert = 3;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"UNDERLINE\"), LOG4CXX_STR(\"underline\"))){\n\t\tnumberToConvert = 4;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"BLINKING\"), LOG4CXX_STR(\"blinking\"))){\n\t\tnumberToConvert = 5;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"INVERSE\"), LOG4CXX_STR(\"inverse\"))){\n\t\tnumberToConvert = 7;\n\t}else if(StringHelper::equalsIgnoreCase(graphicsMode, LOG4CXX_STR(\"STRIKETHROUGH\"), LOG4CXX_STR(\"strikethrough\"))){\n\t\tnumberToConvert = 9;\n\t}\n\n\tif( numberToConvert == 0 ){\n\t\treturn LOG4CXX_STR(\"\");\n\t}\n\tLogString ret;\n\tStringHelper::toString(numberToConvert, pool, ret);\n\treturn ret;\n}\n\nstatic LogString convertSingleSequence(const LogString& sequence, Pool& pool){\n\tLogString strInParens;\n\tbool inParens = false;\n\tbool hasParens = false;\n\tsize_t x = 0;\n\n\tfor(x = 0; x < 
sequence.length(); x++){\n\t\tif( sequence[x] == '(' && !inParens ){\n\t\t\tinParens = true;\n\t\t\thasParens = true;\n\t\t\tcontinue;\n\t\t}else if( sequence[x] == '(' && inParens ){\n\t\t\t// Unbalanced parens - parse invalid\n\t\t\treturn LOG4CXX_STR(\"\");\n\t\t}\n\n\t\tif( sequence[x] == ')' && inParens ){\n\t\t\thasParens = true;\n\t\t\tinParens = false;\n\t\t\tbreak;\n\t\t}\n\n\t\tif( inParens ){\n\t\t\tstrInParens.push_back(sequence[x]);\n\t\t}\n\t}\n\n\tif( (x != (sequence.length() - 1) || inParens) && hasParens ){\n\t\t// Unbalanced parens, or more data in the string than we expected - parse invalid\n\t\treturn LOG4CXX_STR(\"\");\n\t}\n\n\tif(StringHelper::startsWith(sequence, LOG4CXX_STR(\"fg(\"))){\n\t\t// Parse foreground\n\t\treturn colorToANSISequence(strInParens, true, pool);\n\t}else if(StringHelper::startsWith(sequence, LOG4CXX_STR(\"bg(\"))){\n\t\treturn colorToANSISequence(strInParens, false, pool);\n\t}else{\n\t\treturn graphicsModeToANSISequence(sequence, pool);\n\t}\n}\n\nstruct ColorStartPatternConverter::ColorPatternConverterPrivate : public PatternConverterPrivate\n{\n\tColorPatternConverterPrivate( const LogString& name, const LogString& style ) :\n\t\tPatternConverterPrivate( name, style ){}\n\n\tLogString m_fatalColor;\n\tLogString m_errorColor;\n\tLogString m_warnColor;\n\tLogString m_infoColor;\n\tLogString m_debugColor;\n\tLogString m_traceColor;\n};\n\nColorStartPatternConverter::ColorStartPatternConverter() :\n\tLoggingEventPatternConverter(std::make_unique(LOG4CXX_STR(\"Color Start\"),\n\t\tLOG4CXX_STR(\"colorStart\")))\n{\n}\n\nPatternConverterPtr ColorStartPatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife instance = std::make_shared();\n\treturn instance;\n}\n\nvoid ColorStartPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& p) const\n{\n\n\tLOG4CXX_NS::LevelPtr lvl = event->getLevel();\n\n\tswitch (lvl->toInt())\n\t{\n\t\tcase LOG4CXX_NS::Level::FATAL_INT:\n\t\t\ttoAppendTo.append(priv->m_fatalColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::ERROR_INT:\n\t\t\ttoAppendTo.append(priv->m_errorColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::WARN_INT:\n\t\t\ttoAppendTo.append(priv->m_warnColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::INFO_INT:\n\t\t\ttoAppendTo.append(priv->m_infoColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::DEBUG_INT:\n\t\t\ttoAppendTo.append(priv->m_debugColor);\n\t\t\tbreak;\n\n\t\tcase LOG4CXX_NS::Level::TRACE_INT:\n\t\t\ttoAppendTo.append(priv->m_traceColor);\n\t\t\tbreak;\n\n\t\tdefault:\n...\n", "needle_token_start": 9822, "needle_token_end": 9837, "code_context_ntokens": 15168, "output": ["```cpp\n#if LOG4CXX_LOGCHAR_IS_UTF8\ntypedef TrivialCharsetEncoder UTF8CharsetEncoder;\n#else\n/**\n * Converts a LogString to UTF-8.\n */\nclass UTF8CharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\t\nUTF8CharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\twhile (iter != in.end() && out.remaining() >= 8)\n\t\t\t{\n\t\t\t\tunsigned int sv = Transcoder::decode(in, iter);\n\n\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t{\n\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t}\n\n\t\t\t\tTranscoder::encodeUTF8(sv, out);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF8CharsetEncoder(const UTF8CharsetEncoder&);\n\t\tUTF8CharsetEncoder& operator=(const UTF8CharsetEncoder&);\n};\n#endif\n```"]} +{"repo": "apache/logging-log4cxx", 
"name": "PropertiesPatternConverterPrivate", "language": "cpp", "path": "src/main/cpp/propertiespatternconverter.cpp", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The purpose is to handle the customization of logging output based on specific properties. It allows the configuration of log output to include specific property values dynamically.\n2. **Input**: The input includes a property name which specifies which property of the logging event should be included in the log output.\n3. **Output**: There is no direct output; however, it modifies the logging behavior by incorporating the specified property's value into the log output.\n4. **Procedure**: The procedure involves initializing with the name of the property to be used. When a logging event occurs, the specified property's value is retrieved and formatted according to the defined style, and then included in the log output.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/systemerrwriter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n...\n// Path: src/main/cpp/threadusernamepatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(ThreadUsernamePatternConverter)\n\nThreadUsernamePatternConverter::ThreadUsernamePatternConverter() :\n\tLoggingEventPatternConverter(LOG4CXX_STR(\"Thread Name\"),\n\t\tLOG4CXX_STR(\"Thread Name\"))\n{\n}\n\nPatternConverterPtr ThreadUsernamePatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife def = std::make_shared();\n\treturn def;\n}\n\nvoid ThreadUsernamePatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& /* p */) const\n{\n\ttoAppendTo.append(event->getThreadUserName());\n}\n\n\n// Path: src/main/cpp/file.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct File::FilePrivate{\n\tFilePrivate() :\n\t\tautoDelete(false)\n\t{}\n\n\tFilePrivate(LogString path) :\n\t\tpath(path),\n\t\tautoDelete(false)\n\t{}\n\n\tFilePrivate(LogString path, bool autoDelete) :\n\t\tpath(path),\n\t\tautoDelete(autoDelete)\n\t{}\n\n\tLogString path;\n\tbool autoDelete;\n};\n\nFile::File() :\n\tm_priv(std::make_unique())\n{\n}\n\ntemplate\nstatic LogString decodeLS(const S* src)\n{\n\tLogString dst;\n\n\tif (src != 0)\n\t{\n\t\tTranscoder::decode(src, dst);\n\t}\n\n\treturn dst;\n}\n\ntemplate\nstatic LogString decodeLS(const std::basic_string& src)\n{\n\tLogString dst;\n\tTranscoder::decode(src, dst);\n\treturn dst;\n}\n\n\nFile::File(const std::string& name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n\nFile::File(const char* name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n\n#if LOG4CXX_WCHAR_T_API\nFile::File(const std::wstring& name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n\nFile::File(const wchar_t* name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n#endif\n\n#if LOG4CXX_UNICHAR_API || LOG4CXX_LOGCHAR_IS_UNICHAR\nFile::File(const std::basic_string& name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n\nFile::File(const UniChar* name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n#endif\n\n#if 
LOG4CXX_CFSTRING_API\nFile::File(const CFStringRef& name1)\n\t: m_priv(std::make_unique(decodeLS(name1)))\n{\n}\n#endif\n\nFile::File(const File& src)\n\t: m_priv(std::make_unique(src.m_priv->path, src.m_priv->autoDelete))\n{\n}\n\nFile& File::operator=(const File& src)\n{\n\tif (this == &src)\n\t{\n\t\treturn *this;\n\t}\n\n\tm_priv->path.assign(src.m_priv->path);\n\tm_priv->autoDelete = src.m_priv->autoDelete;\n\n\treturn *this;\n}\n\n\nFile::~File()\n{\n\tif(m_priv->autoDelete){\n\t\tPool p;\n\t\tdeleteFile(p);\n\t}\n}\n\n\nLogString File::getPath() const\n{\n\treturn m_priv->path;\n}\n\nFile& File::setPath(const LogString& newName)\n{\n\tm_priv->path.assign(newName);\n\treturn *this;\n}\n\nLogString File::getName() const\n{\n\tconst logchar slashes[] = { 0x2F, 0x5C, 0 };\n\tsize_t lastSlash = m_priv->path.find_last_of(slashes);\n\n\tif (lastSlash != LogString::npos)\n\t{\n\t\treturn m_priv->path.substr(lastSlash + 1);\n\t}\n\n\treturn m_priv->path;\n}\n\nchar* File::getPath(Pool& p) const\n{\n\tint style = APR_FILEPATH_ENCODING_UNKNOWN;\n\tapr_filepath_encoding(&style, p.getAPRPool());\n\tchar* retval = NULL;\n\n\tif (style == APR_FILEPATH_ENCODING_UTF8)\n\t{\n\t\tretval = Transcoder::encodeUTF8(m_priv->path, p);\n\t}\n\telse\n\t{\n\t\tretval = Transcoder::encode(m_priv->path, p);\n\t}\n\n\treturn retval;\n}\n\nlog4cxx_status_t File::open(apr_file_t** file, int flags,\n\tint perm, Pool& p) const\n{\n\treturn apr_file_open(file, getPath(p), flags, perm, p.getAPRPool());\n}\n\n\n\nbool File::exists(Pool& p) const\n{\n\tapr_finfo_t finfo;\n\tapr_status_t rv = apr_stat(&finfo, getPath(p),\n\t\t\t0, p.getAPRPool());\n\treturn rv == APR_SUCCESS;\n}\n\nchar* File::convertBackSlashes(char* src)\n{\n\tfor (char* c = src; *c != 0; c++)\n\t{\n\t\tif (*c == '\\\\')\n\t\t{\n\t\t\t*c = '/';\n\t\t}\n\t}\n\n\treturn src;\n}\n\nbool File::deleteFile(Pool& p) const\n{\n\tapr_status_t rv = apr_file_remove(convertBackSlashes(getPath(p)),\n\t\t\tp.getAPRPool());\n\treturn rv == APR_SUCCESS;\n}\n\nbool File::renameTo(const File& dest, Pool& p) const\n{\n\tapr_status_t rv = apr_file_rename(convertBackSlashes(getPath(p)),\n\t\t\tconvertBackSlashes(dest.getPath(p)),\n\t\t\tp.getAPRPool());\n\treturn rv == APR_SUCCESS;\n}\n\n\nsize_t File::length(Pool& pool) const\n{\n\tapr_finfo_t finfo;\n\tapr_status_t rv = apr_stat(&finfo, getPath(pool),\n\t\t\tAPR_FINFO_SIZE, pool.getAPRPool());\n\n\tif (rv == APR_SUCCESS)\n\t{\n\t\treturn (size_t) finfo.size;\n\t}\n\n\treturn 0;\n}\n\n\nlog4cxx_time_t File::lastModified(Pool& pool) const\n{\n\tapr_finfo_t finfo;\n\tapr_status_t rv = apr_stat(&finfo, getPath(pool),\n\t\t\tAPR_FINFO_MTIME, pool.getAPRPool());\n\n\tif (rv == APR_SUCCESS)\n\t{\n\t\treturn finfo.mtime;\n\t}\n\n\treturn 0;\n}\n\n\nstd::vector File::list(Pool& p) const\n{\n\tapr_dir_t* dir;\n\tapr_finfo_t entry;\n\tstd::vector filenames;\n\n\tapr_status_t stat = apr_dir_open(&dir,\n\t\t\tconvertBackSlashes(getPath(p)),\n\t\t\tp.getAPRPool());\n\n\tif (stat == APR_SUCCESS)\n\t{\n\t\tint style = APR_FILEPATH_ENCODING_UNKNOWN;\n\t\tapr_filepath_encoding(&style, p.getAPRPool());\n\t\tstat = apr_dir_read(&entry, APR_FINFO_DIRENT, dir);\n\n\t\twhile (stat == APR_SUCCESS)\n\t\t{\n\t\t\tif (entry.name != NULL)\n\t\t\t{\n\t\t\t\tLogString filename;\n\n\t\t\t\tif (style == APR_FILEPATH_ENCODING_UTF8)\n\t\t\t\t{\n\t\t\t\t\tTranscoder::decodeUTF8(entry.name, filename);\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\tTranscoder::decode(entry.name, 
filename);\n\t\t\t\t}\n\n\t\t\t\tfilenames.push_back(filename);\n\t\t\t}\n\n\t\t\tstat = apr_dir_read(&entry, APR_FINFO_DIRENT, dir);\n\t\t}\n\n\t\tstat = apr_dir_close(dir);\n\t}\n\n\treturn filenames;\n}\n\nLogString File::getParent(Pool&) const\n{\n\tLogString::size_type slashPos = m_priv->path.rfind(LOG4CXX_STR('/'));\n\tLogString::size_type backPos = m_priv->path.rfind(LOG4CXX_STR('\\\\'));\n\n\tif (slashPos == LogString::npos)\n\t{\n\t\tslashPos = backPos;\n\t}\n\telse\n\t{\n\t\tif (backPos != LogString::npos && backPos > slashPos)\n\t\t{\n\t\t\tslashPos = backPos;\n\t\t}\n\t}\n\n\tLogString parent;\n\n\tif (slashPos != LogString::npos && slashPos > 0)\n\t{\n\t\tparent.assign(m_priv->path, 0, slashPos);\n\t}\n\n\treturn parent;\n}\n\nbool File::mkdirs(Pool& p) const\n{\n\tapr_status_t stat = apr_dir_make_recursive(convertBackSlashes(getPath(p)),\n\t\t\tAPR_OS_DEFAULT, p.getAPRPool());\n\treturn stat == APR_SUCCESS;\n}\n\nvoid File::setAutoDelete(bool autoDelete){\n\tm_priv->autoDelete = autoDelete;\n}\n\nbool File::getAutoDelete() const{\n\treturn m_priv->autoDelete;\n}\n\n// Path: src/main/cpp/systemoutwriter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(SystemOutWriter)\n\nSystemOutWriter::SystemOutWriter()\n{\n}\n\nSystemOutWriter::~SystemOutWriter()\n{\n}\n\nvoid SystemOutWriter::close(Pool& /* p */ )\n{\n}\n\nvoid SystemOutWriter::flush(Pool& /* p */ )\n{\n\tflush();\n}\n\nvoid SystemOutWriter::write(const LogString& str, Pool& /* p */ )\n{\n\twrite(str);\n}\n\nbool SystemOutWriter::isWide()\n{\n#if LOG4CXX_FORCE_WIDE_CONSOLE\n\treturn true;\n#elif LOG4CXX_FORCE_BYTE_CONSOLE || !LOG4CXX_HAS_FWIDE\n\treturn false;\n#else\n\treturn fwide(stdout, 0) > 0;\n#endif\n}\n\nvoid SystemOutWriter::write(const LogString& str)\n{\n#if LOG4CXX_WCHAR_T_API\n\n\tif (isWide())\n\t{\n\t\tLOG4CXX_ENCODE_WCHAR(msg, str);\n\t\tfputws(msg.c_str(), stdout);\n\t\treturn;\n\t}\n\n#endif\n\tLOG4CXX_ENCODE_CHAR(msg, str);\n\tfputs(msg.c_str(), stdout);\n}\n\nvoid SystemOutWriter::flush()\n{\n\tfflush(stdout);\n}\n\n// Path: src/main/cpp/gzcompressaction.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::rolling;\nusing namespace LOG4CXX_NS::helpers;\n\n#define priv static_cast(m_priv.get())\n\nstruct GZCompressAction::GZCompressActionPrivate : public ActionPrivate\n{\n\tGZCompressActionPrivate( const File& toRename,\n\t\tconst File& renameTo,\n\t\tbool deleteSource):\n\t\tsource(toRename), destination(renameTo), deleteSource(deleteSource) {}\n\n\tconst File source;\n\tFile destination;\n\tbool deleteSource;\n\tbool throwIOExceptionOnForkFailure = true;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(GZCompressAction)\n\nGZCompressAction::GZCompressAction(const File& src,\n\tconst File& dest,\n\tbool del)\n\t: Action(std::make_unique(\n\t\t\t src, dest, del))\n{\n}\n\nGZCompressAction::~GZCompressAction() {}\n\nbool GZCompressAction::execute(LOG4CXX_NS::helpers::Pool& p) const\n{\n\tif (priv->source.exists(p))\n\t{\n\t\tapr_pool_t* aprpool = p.getAPRPool();\n\t\tapr_procattr_t* attr;\n\t\tapr_status_t stat = apr_procattr_create(&attr, aprpool);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tstat = apr_procattr_io_set(attr, APR_NO_PIPE, APR_FULL_BLOCK, APR_FULL_BLOCK);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tstat = apr_procattr_cmdtype_set(attr, APR_PROGRAM_PATH);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\t//\n\t\t// set child process output to destination file\n\t\t//\n\t\tapr_file_t* child_out;\n\t\tapr_int32_t flags = APR_FOPEN_READ | APR_FOPEN_WRITE |\n\t\t\tAPR_FOPEN_CREATE | APR_FOPEN_TRUNCATE;\n\t\tstat = priv->destination.open(&child_out, flags, APR_OS_DEFAULT, p);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tstat = apr_procattr_child_out_set(attr, child_out, NULL);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\t//\n\t\t// redirect the child's error stream to this processes' error stream\n\t\t//\n\t\tapr_file_t* child_err;\n\t\tstat = apr_file_open_stderr(&child_err, aprpool);\n\n\t\tif (stat == APR_SUCCESS)\n\t\t{\n\t\t\tstat = apr_procattr_child_err_set(attr, child_err, NULL);\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IOException(stat);\n\t\t\t}\n\t\t}\n\n\t\tpriv->destination.setAutoDelete(true);\n\n\t\tconst char** args = (const char**)\n\t\t\tapr_palloc(aprpool, 4 * sizeof(*args));\n\t\tint i = 0;\n\t\targs[i++] = \"gzip\";\n\t\targs[i++] = \"-c\";\n\t\targs[i++] = Transcoder::encode(priv->source.getPath(), p);\n\t\targs[i++] = NULL;\n\n\t\tapr_proc_t pid;\n\t\tstat = apr_proc_create(&pid, \"gzip\", args, NULL, attr, aprpool);\n\n\t\tif (stat != APR_SUCCESS && priv->throwIOExceptionOnForkFailure)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}else if(stat != APR_SUCCESS && !priv->throwIOExceptionOnForkFailure)\n\t\t{\n\t\t\t/* If we fail here (to create the gzip child process),\n\t\t\t * skip the compression and consider the rotation to be\n\t\t\t * otherwise successful. 
The caller has already rotated\n\t\t\t * the log file (`source` here refers to the\n\t\t\t * uncompressed, rotated path, and `destination` the\n\t\t\t * same path with `.gz` appended). Remove the empty\n\t\t\t * destination file and leave source as-is.\n\t\t\t */\n\t\t\tLogLog::warn(LOG4CXX_STR(\"Failed to fork gzip during log rotation; leaving log file uncompressed\"));\n\t\t\tstat = apr_file_close(child_out);\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tLogLog::warn(LOG4CXX_STR(\"Failed to close abandoned .gz file; ignoring\"));\n\t\t\t}\n\t\t\treturn true;\n\t\t}\n\n\t\tapr_proc_wait(&pid, NULL, NULL, APR_WAIT);\n\t\tstat = apr_file_close(child_out);\n\n\t\tif (stat != APR_SUCCESS)\n\t\t{\n\t\t\tthrow IOException(stat);\n\t\t}\n\n\t\tpriv->destination.setAutoDelete(false);\n\n\t\tif (priv->deleteSource)\n\t\t{\n\t\t\tpriv->source.deleteFile(p);\n\t\t}\n\n\t\treturn true;\n\t}\n\n\treturn false;\n}\n\nvoid GZCompressAction::setThrowIOExceptionOnForkFailure(bool throwIO){\n\tpriv->throwIOExceptionOnForkFailure = throwIO;\n}\n\n\n// Path: src/main/cpp/relativetimepatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(RelativeTimePatternConverter)\n\nRelativeTimePatternConverter::RelativeTimePatternConverter() :\n\tLoggingEventPatternConverter(LOG4CXX_STR(\"Time\"),\n\t\tLOG4CXX_STR(\"time\"))\n{\n}\n\nPatternConverterPtr RelativeTimePatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife def = std::make_shared();\n\treturn def;\n}\n\nvoid RelativeTimePatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& p) const\n{\n\tlog4cxx_time_t delta = (event->getTimeStamp() - LoggingEvent::getStartTime()) / 1000;\n\tStringHelper::toString(delta, p, toAppendTo);\n}\n\n\n// Path: src/main/cpp/levelmatchfilter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::filter;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\n#define priv static_cast(m_priv.get())\n\nstruct LevelMatchFilter::LevelMatchFilterPrivate : public FilterPrivate\n{\n\tbool acceptOnMatch;\n\tLevelPtr levelToMatch;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(LevelMatchFilter)\n\nLevelMatchFilter::LevelMatchFilter()\n\t: Filter(std::make_unique())\n{\n\tpriv->acceptOnMatch = true;\n}\n\nLevelMatchFilter::~LevelMatchFilter() {}\n\nvoid LevelMatchFilter::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\n\n\tif (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"LEVELTOMATCH\"), LOG4CXX_STR(\"leveltomatch\")))\n\t{\n\t\tsetLevelToMatch(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"ACCEPTONMATCH\"), LOG4CXX_STR(\"acceptonmatch\")))\n\t{\n\t\tpriv->acceptOnMatch = OptionConverter::toBoolean(value, priv->acceptOnMatch);\n\t}\n}\n\nvoid LevelMatchFilter::setLevelToMatch(const LogString& levelToMatch1)\n{\n\tpriv->levelToMatch = OptionConverter::toLevel(levelToMatch1, priv->levelToMatch);\n}\n\nLogString LevelMatchFilter::getLevelToMatch() const\n{\n\treturn priv->levelToMatch->toString();\n}\n\nFilter::FilterDecision LevelMatchFilter::decide(\n\tconst LOG4CXX_NS::spi::LoggingEventPtr& event) const\n{\n\tif (priv->levelToMatch != 0 && priv->levelToMatch->equals(event->getLevel()))\n\t{\n\t\tif (priv->acceptOnMatch)\n\t\t{\n\t\t\treturn Filter::ACCEPT;\n\t\t}\n\t\telse\n\t\t{\n\t\t\treturn Filter::DENY;\n\t\t}\n\t}\n\telse\n\t{\n\t\treturn Filter::NEUTRAL;\n\t}\n}\n\nvoid LevelMatchFilter::setAcceptOnMatch(bool acceptOnMatch1)\n{\n\tpriv->acceptOnMatch = acceptOnMatch1;\n}\n\nbool LevelMatchFilter::getAcceptOnMatch() const\n{\n\treturn priv->acceptOnMatch;\n}\n\n// Path: src/main/cpp/fulllocationpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(FullLocationPatternConverter)\n\nFullLocationPatternConverter::FullLocationPatternConverter() :\n\tLoggingEventPatternConverter(LOG4CXX_STR(\"Full Location\"),\n\t\tLOG4CXX_STR(\"fullLocation\"))\n{\n}\n\nPatternConverterPtr FullLocationPatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife instance = std::make_shared();\n\treturn instance;\n}\n\nvoid FullLocationPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& p) const\n{\n\tappend(toAppendTo, event->getLocationInformation().getFileName());\n\ttoAppendTo.append(1, (logchar) 0x28 /* '(' */);\n\tStringHelper::toString(\n\t\tevent->getLocationInformation().getLineNumber(),\n\t\tp, toAppendTo);\n\ttoAppendTo.append(1, (logchar) 0x29 /* ')' */);\n}\n\n// Path: src/main/cpp/layout.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(Layout)\n\n\nLayout::~Layout() {}\n\nLogString Layout::getContentType() const\n{\n\treturn LOG4CXX_STR(\"text/plain\");\n}\n\nvoid Layout::appendHeader(LogString&, LOG4CXX_NS::helpers::Pool&) {}\n\nvoid Layout::appendFooter(LogString&, LOG4CXX_NS::helpers::Pool&) {}\n\n/**\n * The expected length of a formatted event excluding the message text\n */\nsize_t Layout::getFormattedEventCharacterCount() const\n{\n\tauto exampleEvent = std::make_shared\n\t\t( LOG4CXX_STR(\"example.logger\")\n\t\t, Level::getDebug()\n\t\t, LOG4CXX_LOCATION\n\t\t, LogString()\n\t\t);\n\tLogString text;\n\tPool pool;\n\tformat(text, exampleEvent, pool);\n\treturn text.size();\n}\n\n// Path: src/main/cpp/levelpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(LevelPatternConverter)\n\nLevelPatternConverter::LevelPatternConverter() :\n\tLoggingEventPatternConverter(LOG4CXX_STR(\"Level\"),\n\t\tLOG4CXX_STR(\"level\"))\n{\n}\n\nPatternConverterPtr LevelPatternConverter::newInstance(\n\tconst std::vector& /* options */)\n{\n\tstatic WideLife def = std::make_shared();\n\treturn def;\n}\n\nvoid LevelPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tLOG4CXX_NS::helpers::Pool& /* p */) const\n{\n\ttoAppendTo.append(event->getLevel()->toString());\n}\n\n\n/**\n * {@inheritDoc}\n */\nLogString LevelPatternConverter::getStyleClass(const ObjectPtr& obj) const\n{\n\tLoggingEventPtr e = LOG4CXX_NS::cast(obj);\n\n\tif (e != NULL)\n\t{\n\t\tint lint = e->getLevel()->toInt();\n\n\t\tswitch (lint)\n\t\t{\n\t\t\tcase Level::TRACE_INT:\n\t\t\t\treturn LOG4CXX_STR(\"level trace\");\n\n\t\t\tcase Level::DEBUG_INT:\n\t\t\t\treturn LOG4CXX_STR(\"level debug\");\n\n\t\t\tcase Level::INFO_INT:\n\t\t\t\treturn LOG4CXX_STR(\"level info\");\n\n\t\t\tcase Level::WARN_INT:\n\t\t\t\treturn LOG4CXX_STR(\"level warn\");\n\n\t\t\tcase Level::ERROR_INT:\n\t\t\t\treturn LOG4CXX_STR(\"level error\");\n\n\t\t\tcase Level::FATAL_INT:\n\t\t\t\treturn LOG4CXX_STR(\"level fatal\");\n\n\t\t\tdefault:\n\t\t\t\treturn LogString(LOG4CXX_STR(\"level \")) + e->getLevel()->toString();\n\t\t}\n\t}\n\n\treturn LOG4CXX_STR(\"level\");\n}\n\n// Path: src/main/cpp/consoleappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct ConsoleAppender::ConsoleAppenderPriv : public WriterAppender::WriterAppenderPriv\n{\n\tConsoleAppenderPriv(LogString target) :\n\t\tWriterAppenderPriv(),\n\t\ttarget(target) {}\n\n\tLogString target;\n};\n\n#define _priv static_cast(m_priv.get())\n\nIMPLEMENT_LOG4CXX_OBJECT(ConsoleAppender)\n\nConsoleAppender::ConsoleAppender()\n\t: WriterAppender (std::make_unique(getSystemOut()))\n{\n}\n\nConsoleAppender::ConsoleAppender(const LayoutPtr& layout)\n\t: WriterAppender (std::make_unique(getSystemOut()))\n{\n\tsetLayout(layout);\n\tPool p;\n\tsetWriter(std::make_shared());\n\tWriterAppender::activateOptions(p);\n}\n\nConsoleAppender::ConsoleAppender(const LayoutPtr& layout, const LogString& target)\n\t: WriterAppender (std::make_unique(target))\n{\n\tsetLayout(layout);\n\tsetTarget(target);\n\tPool p;\n\tConsoleAppender::activateOptions(p);\n}\n\nConsoleAppender::~ConsoleAppender()\n{\n\tfinalize();\n}\n\nconst LogString& ConsoleAppender::getSystemOut()\n{\n\tstatic const WideLife name(LOG4CXX_STR(\"System.out\"));\n\treturn name;\n}\n\nconst LogString& ConsoleAppender::getSystemErr()\n{\n\tstatic const WideLife name(LOG4CXX_STR(\"System.err\"));\n\treturn name;\n}\n\nvoid ConsoleAppender::setTarget(const LogString& value)\n{\n\tLogString v = StringHelper::trim(value);\n\n\tif (StringHelper::equalsIgnoreCase(v,\n\t\t\tLOG4CXX_STR(\"SYSTEM.OUT\"), LOG4CXX_STR(\"system.out\")))\n\t{\n\t\t_priv->target = getSystemOut();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(v,\n\t\t\tLOG4CXX_STR(\"SYSTEM.ERR\"), LOG4CXX_STR(\"system.err\")))\n\t{\n\t\t_priv->target = getSystemErr();\n\t}\n\telse\n\t{\n\t\ttargetWarn(value);\n\t}\n}\n\nLogString ConsoleAppender::getTarget() const\n{\n\treturn _priv->target;\n}\n\nvoid ConsoleAppender::targetWarn(const LogString& val)\n{\n\tLogLog::warn(((LogString) LOG4CXX_STR(\"[\"))\n\t\t+ val + LOG4CXX_STR(\"] should be system.out or system.err.\"));\n\tLogLog::warn(LOG4CXX_STR(\"Using previously set target, System.out by default.\"));\n}\n\nvoid ConsoleAppender::activateOptions(Pool& p)\n{\n\tif (StringHelper::equalsIgnoreCase(_priv->target,\n\t\t\tLOG4CXX_STR(\"SYSTEM.OUT\"), LOG4CXX_STR(\"system.out\")))\n\t{\n\t\tWriterPtr writer1 = std::make_shared();\n\t\tsetWriter(writer1);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(_priv->target,\n\t\t\tLOG4CXX_STR(\"SYSTEM.ERR\"), LOG4CXX_STR(\"system.err\")))\n\t{\n\t\tWriterPtr writer1 = std::make_shared();\n\t\tsetWriter(writer1);\n\t}\n\n\tWriterAppender::activateOptions(p);\n}\n\nvoid ConsoleAppender::setOption(const LogString& option, const LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"TARGET\"), LOG4CXX_STR(\"target\")))\n\t{\n\t\tsetTarget(value);\n\t}\n\telse\n\t{\n\t\tWriterAppender::setOption(option, value);\n\t}\n}\n\n\n\n\n\n\n\n// Path: src/main/cpp/propertiespatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor 
license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\n#define priv static_cast(m_priv.get())\n\nstruct PropertiesPatternConverter::PropertiesPatternConverterPrivate : public PatternConverterPrivate\n{\n\t\nPropertiesPatternConverterPrivate( const LogString& name, const LogString& style, const LogString& propertyName ) :\n\t\tPatternConverterPrivate( name, style ),\n\t\toption(propertyName) {}\n\n\t/**\n\t * Name of property to output.\n\t */\n\tconst LogString option;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(PropertiesPatternConverter)\n\nPropertiesPatternConverter::PropertiesPatternConverter(const LogString& name1,\n\tconst LogString& propertyName) :\n\tLoggingEventPatternConverter(\n\t\tstd::make_unique(name1, LOG4CXX_STR(\"property\"), propertyName))\n{\n}\n\nPatternConverterPtr PropertiesPatternConverter::newInstance(\n\tconst std::vector& options)\n{\n\tif (options.size() == 0)\n\t{\n\t\tstatic WideLife def = std::make_shared(\n\t\t\t\tLOG4CXX_STR(\"Properties\"), LOG4CXX_STR(\"\"));\n\t\treturn def;\n\t}\n\n\tLogString converterName(LOG4CXX_STR(\"Property{\"));\n\tconverterName.append(options[0]);\n\tconverterName.append(LOG4CXX_STR(\"}\"));\n\treturn std::make_shared(converterName, options[0]);\n}\n\nvoid PropertiesPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& /* p */) const\n{\n\tif (priv->option.length() == 0)\n\t{\n\t\ttoAppendTo.append(1, (logchar) 0x7B /* '{' */);\n\n\t\tfor (auto const& item : event->getMDCKeySet())\n\t\t{\n\t\t\ttoAppendTo.append(1, (logchar) 0x7B /* '{' */);\n\t\t\ttoAppendTo.append(item);\n\t\t\ttoAppendTo.append(1, (logchar) 0x2C /* ',' */);\n\t\t\tevent->getMDC(item, toAppendTo);\n\t\t\ttoAppendTo.append(1, (logchar) 0x7D /* '}' */);\n\t\t}\n\n\t\ttoAppendTo.append(1, (logchar) 0x7D /* '}' */);\n\n\t}\n\telse\n\t{\n\t\tevent->getMDC(priv->option, toAppendTo);\n\t}\n}\n\n\n// Path: src/main/cpp/basicconfigurator.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\n\nvoid BasicConfigurator::configure(const LayoutPtr& layoutArg)\n{\n\tLogManager::getLoggerRepository()->setConfigured(true);\n\tauto layout = layoutArg;\n\tif (!layout)\n\t{\n\t\tstatic const helpers::WideLife TTCC_CONVERSION_PATTERN(LOG4CXX_STR(\"%r [%t] %p %c %x - %m%n\"));\n\t\tlayout = std::make_shared(TTCC_CONVERSION_PATTERN);\n\t}\n\tauto appender = std::make_shared(layout);\n\tLogger::getRootLogger()->addAppender(appender);\n}\n\nvoid BasicConfigurator::configure(const AppenderPtr& appender)\n{\n\tLoggerPtr root = Logger::getRootLogger();\n\troot->addAppender(appender);\n}\n\nvoid BasicConfigurator::resetConfiguration()\n{\n\tLogManager::resetConfiguration();\n}\n\n// Path: src/main/cpp/threadutility.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"log4cxx/helpers/threadutility.h\"\n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \"log4cxx/private/log4cxx_private.h\"\n#include \"log4cxx/helpers/loglog.h\"\n#include \"log4cxx/helpers/transcoder.h\"\n\n#include \n#include \n\n#if WIN32\n\t#include \n\t#include \n#endif\n\nnamespace LOG4CXX_NS\n{\nnamespace helpers\n{\n\nstruct ThreadUtility::priv_data\n{\n\tpriv_data()\n\t{\n\t\tstart_pre = nullptr;\n\t\tstarted = nullptr;\n\t\tstart_post = nullptr;\n\t}\n\n\tThreadStartPre start_pre;\n\tThreadStarted started;\n\tThreadStartPost start_post;\n};\n\n#if LOG4CXX_HAS_PTHREAD_SIGMASK\n\tstatic thread_local sigset_t old_mask;\n\tstatic thread_local bool sigmask_valid;\n#endif\n\nThreadUtility::ThreadUtility() :\n\tm_priv( new priv_data() )\n{\n\t// Block signals by default.\n\tconfigureFuncs( std::bind( &ThreadUtility::preThreadBlockSignals, this ),\n\t\tnullptr,\n\t\tstd::bind( &ThreadUtility::postThreadUnblockSignals, this ) );\n}\n\nThreadUtility::~ThreadUtility() {}\n\nThreadUtility* ThreadUtility::instance()\n{\n\tstatic WideLife instance;\n\treturn &instance.value();\n}\n\nvoid ThreadUtility::configure( ThreadConfigurationType type )\n{\n\tauto utility = instance();\n\n\tif ( type == ThreadConfigurationType::NoConfiguration )\n\t{\n\t\tutility->configureFuncs( nullptr, nullptr, nullptr );\n\t}\n\telse if ( type == ThreadConfigurationType::NameThreadOnly )\n\t{\n\t\tutility->configureFuncs( nullptr,\n\t\t\tstd::bind( 
&ThreadUtility::threadStartedNameThread, utility,\n\t\t\t\tstd::placeholders::_1,\n\t\t\t\tstd::placeholders::_2,\n\t\t\t\tstd::placeholders::_3 ),\n\t\t\tnullptr );\n\t}\n\telse if ( type == ThreadConfigurationType::BlockSignalsOnly )\n\t{\n\t\tutility->configureFuncs( std::bind( &ThreadUtility::preThreadBlockSignals, utility ),\n\t\t\tnullptr,\n\t\t\tstd::bind( &ThreadUtility::postThreadUnblockSignals, utility ) );\n\t}\n\telse if ( type == ThreadConfigurationType::BlockSignalsAndNameThread )\n\t{\n\t\tutility->configureFuncs( std::bind( &ThreadUtility::preThreadBlockSignals, utility ),\n\t\t\tstd::bind( &ThreadUtility::threadStartedNameThread, utility,\n\t\t\t\tstd::placeholders::_1,\n\t\t\t\tstd::placeholders::_2,\n\t\t\t\tstd::placeholders::_3 ),\n\t\t\tstd::bind( &ThreadUtility::postThreadUnblockSignals, utility ) );\n\t}\n}\n\nvoid ThreadUtility::configureFuncs( ThreadStartPre pre_start,\n\tThreadStarted started,\n\tThreadStartPost post_start )\n{\n\tm_priv->start_pre = pre_start;\n\tm_priv->started = started;\n\tm_priv->start_post = post_start;\n}\n\nvoid ThreadUtility::preThreadBlockSignals()\n{\n#if LOG4CXX_HAS_PTHREAD_SIGMASK\n\tsigset_t set;\n\tsigfillset(&set);\n\n\tif ( pthread_sigmask(SIG_SETMASK, &set, &old_mask) < 0 )\n\t{\n\t\tLOGLOG_ERROR( LOG4CXX_STR(\"Unable to set thread sigmask\") );\n\t\tsigmask_valid = false;\n\t}\n\telse\n\t{\n\t\tsigmask_valid = true;\n\t}\n\n#endif /* LOG4CXX_HAS_PTHREAD_SIGMASK */\n}\n\nvoid ThreadUtility::threadStartedNameThread(LogString threadName,\n\tstd::thread::id /*threadId*/,\n\tstd::thread::native_handle_type nativeHandle)\n{\n#if LOG4CXX_HAS_PTHREAD_SETNAME\n\tLOG4CXX_ENCODE_CHAR(sthreadName, threadName);\n\tif (pthread_setname_np(static_cast(nativeHandle), sthreadName.c_str()) < 0) {\n\t\tLOGLOG_ERROR(LOG4CXX_STR(\"unable to set thread name\"));\n\t}\n#elif WIN32\n\ttypedef HRESULT (WINAPI *TSetThreadDescription)(HANDLE, PCWSTR);\n\tstatic struct initialiser\n\t{\n\t\tHMODULE hKernelBase;\n\t\tTSetThreadDescription SetThreadDescription;\n\t\tinitialiser()\n\t\t\t: hKernelBase(GetModuleHandleA(\"KernelBase.dll\"))\n\t\t\t, SetThreadDescription(nullptr)\n\t\t{\n\t\t\tif (hKernelBase)\n\t\t\t\tSetThreadDescription = reinterpret_cast(GetProcAddress(hKernelBase, \"SetThreadDescription\"));\n\t\t}\n\t} win32Func;\n\tif (win32Func.SetThreadDescription)\n\t{\n\t\tLOG4CXX_ENCODE_WCHAR(wthreadName, threadName);\n\t\tif(FAILED(win32Func.SetThreadDescription(static_cast(nativeHandle), wthreadName.c_str())))\n\t\t\tLOGLOG_ERROR( LOG4CXX_STR(\"unable to set thread name\") );\n\t}\n#endif\n}\n\nvoid ThreadUtility::postThreadUnblockSignals()\n{\n#if LOG4CXX_HAS_PTHREAD_SIGMASK\n\n\t// Only restore the signal mask if we were able to set it in the first place.\n\tif ( sigmask_valid )\n\t{\n\t\tif ( pthread_sigmask(SIG_SETMASK, &old_mask, nullptr) < 0 )\n\t\t{\n\t\t\tLOGLOG_ERROR( LOG4CXX_STR(\"Unable to set thread sigmask\") );\n\t\t}\n\t}\n\n#endif /* LOG4CXX_HAS_PTHREAD_SIGMASK */\n}\n\n\nThreadStartPre ThreadUtility::preStartFunction()\n{\n\treturn m_priv->start_pre;\n}\n\nThreadStarted ThreadUtility::threadStartedFunction()\n{\n\treturn m_priv->started;\n}\n\nThreadStartPost ThreadUtility::postStartFunction()\n{\n\treturn m_priv->start_post;\n}\n\n} //namespace helpers\n} //namespace log4cxx\n\n// Path: src/main/cpp/inetaddress.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. 
See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n\n#include \"apr_network_io.h\"\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(InetAddress)\n\nstruct InetAddress::InetAddressPrivate{\n\n\tLogString ipAddrString;\n\tLogString hostNameString;\n};\n\nUnknownHostException::UnknownHostException(const LogString& msg1)\n\t: Exception(msg1)\n{\n}\n\nUnknownHostException::UnknownHostException(const UnknownHostException& src)\n\t: Exception(src)\n{\n}\n\nUnknownHostException& UnknownHostException::operator=(const UnknownHostException& src)\n{\n\tException::operator=(src);\n\treturn *this;\n}\n\n\nInetAddress::InetAddress(const LogString& hostName, const LogString& hostAddr)\n\t: m_priv(std::make_unique())\n{\n\tm_priv->ipAddrString = hostAddr;\n\tm_priv->hostNameString = hostName;\n}\n\nInetAddress::~InetAddress(){}\n\n/** Determines all the IP addresses of a host, given the host's name.\n*/\nstd::vector InetAddress::getAllByName(const LogString& host)\n{\n\tLOG4CXX_ENCODE_CHAR(encodedHost, host);\n\n\t// retrieve information about the given host\n\tPool addrPool;\n\n\tapr_sockaddr_t* address = 0;\n\tapr_status_t status =\n\t\tapr_sockaddr_info_get(&address, encodedHost.c_str(),\n...\n", "needle_token_start": 11412, "needle_token_end": 11457, "code_context_ntokens": 15257, "output": ["```cpp\n#include \n#include \n#include \n#include \n#include \n\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\n#define priv static_cast(m_priv.get())\n\nstruct PropertiesPatternConverter::PropertiesPatternConverterPrivate : public PatternConverterPrivate\n{\n\t\nPropertiesPatternConverterPrivate( const LogString& name, const LogString& style, const LogString& propertyName ) :\n\t\tPatternConverterPrivate( name, style ),\n\t\toption(propertyName) {}\n\n\t/**\n\t * Name of property to output.\n\t */\n\tconst LogString option;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(PropertiesPatternConverter)\n\nPropertiesPatternConverter::PropertiesPatternConverter(const LogString& name1,\n\tconst LogString& propertyName) :\n\tLoggingEventPatternConverter(\n\t\tstd::make_unique(name1, LOG4CXX_STR(\"property\"), propertyName))\n{\n}\n\nPatternConverterPtr PropertiesPatternConverter::newInstance(\n\tconst std::vector& options)\n{\n\tif (options.size() == 0)\n\t{\n\t\tstatic WideLife def = std::make_shared(\n\t\t\t\tLOG4CXX_STR(\"Properties\"), LOG4CXX_STR(\"\"));\n\t\treturn def;\n\t}\n\n\tLogString converterName(LOG4CXX_STR(\"Property{\"));\n\tconverterName.append(options[0]);\n\tconverterName.append(LOG4CXX_STR(\"}\"));\n\treturn std::make_shared(converterName, options[0]);\n}\n\nvoid PropertiesPatternConverter::format(\n\tconst LoggingEventPtr& event,\n\tLogString& toAppendTo,\n\tPool& 
/* p */) const\n{\n\tif (priv->option.length() == 0)\n\t{\n\t\ttoAppendTo.append(1, (logchar) 0x7B /* '{' */);\n\n\t\tfor (auto const& item : event->getMDCKeySet())\n\t\t{\n\t\t\ttoAppendTo.append(1, (logchar) 0x7B /* '{' */);\n\t\t\ttoAppendTo.append(item);\n\t\t\ttoAppendTo.append(1, (logchar) 0x2C /* ',' */);\n\t\t\tevent->getMDC(item, toAppendTo);\n\t\t\ttoAppendTo.append(1, (logchar) 0x7D /* '}' */);\n\t\t}\n\n\t\ttoAppendTo.append(1, (logchar) 0x7D /* '}' */);\n\n\t}\n\telse\n\t{\n\t\tevent->getMDC(priv->option, toAppendTo);\n\t}\n}\n```"]} +{"repo": "apache/logging-log4cxx", "name": "ThrowableInformationPatternConverterPrivate", "language": "cpp", "path": "src/main/cpp/throwableinformationpatternconverter.cpp", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: To configure and handle the conversion of throwable information from logging events into a formatted string, with an option for abbreviated content.\n2. **Input**: Takes a name and style for formatting, along with a boolean to determine if the throwable information should be reported in a shortened form.\n3. **Output**: Does not directly output data; instead, it sets up the internal configuration for how throwable information should be formatted when converted.\n4. **Procedure**: Initializes the internal settings based on the provided parameters, setting up how throwable information will be processed and formatted during logging operations.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/fileappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nIMPLEMENT_LOG4CXX_OBJECT(FileAppender)\n\n#define _priv static_cast(m_priv.get())\n\nFileAppender::FileAppender() :\n\tWriterAppender (std::make_unique())\n{\n}\n\nFileAppender::FileAppender\n\t( const LayoutPtr& layout1\n\t, const LogString& fileName1\n\t, bool append1\n\t, bool bufferedIO1\n\t, int bufferSize1\n\t)\n\t: WriterAppender(std::make_unique(layout1, fileName1, append1, bufferedIO1, bufferSize1))\n{\n\tPool p;\n\tactivateOptions(p);\n}\n\nFileAppender::FileAppender\n\t( const LayoutPtr& layout1\n\t, const LogString& fileName1\n\t, bool append1\n\t)\n\t: WriterAppender(std::make_unique(layout1, fileName1, append1, false))\n{\n\tPool p;\n\tactivateOptions(p);\n}\n\nFileAppender::FileAppender(const LayoutPtr& layout1, const LogString& fileName1)\n\t: WriterAppender(std::make_unique(layout1, fileName1))\n{\n\tPool p;\n\tactivateOptions(p);\n}\n\nFileAppender::FileAppender(std::unique_ptr priv)\n\t: WriterAppender (std::move(priv))\n{\n}\n\nFileAppender::~FileAppender()\n{\n\tfinalize();\n}\n\nvoid FileAppender::setAppend(bool fileAppend1)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\t_priv->fileAppend = fileAppend1;\n}\n\nvoid FileAppender::setFile(const LogString& file)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\tsetFileInternal(file);\n}\n\nvoid FileAppender::setFileInternal(const LogString& file)\n{\n\t_priv->fileName = file;\n}\n\nvoid FileAppender::setBufferedIO(bool bufferedIO1)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\t_priv->bufferedIO = bufferedIO1;\n\n\tif (bufferedIO1)\n\t{\n\t\tsetImmediateFlush(false);\n\t}\n}\n\nvoid FileAppender::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"FILE\"), LOG4CXX_STR(\"file\"))\n\t\t|| StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"FILENAME\"), LOG4CXX_STR(\"filename\")))\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->fileName = stripDuplicateBackslashes(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"APPEND\"), LOG4CXX_STR(\"append\")))\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->fileAppend = OptionConverter::toBoolean(value, true);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"BUFFEREDIO\"), LOG4CXX_STR(\"bufferedio\")))\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->bufferedIO = OptionConverter::toBoolean(value, false);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"IMMEDIATEFLUSH\"), LOG4CXX_STR(\"immediateflush\")))\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->bufferedIO = !OptionConverter::toBoolean(value, false);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"BUFFERSIZE\"), LOG4CXX_STR(\"buffersize\")))\n\t{\n\t\tstd::lock_guard lock(_priv->mutex);\n\t\t_priv->bufferSize = OptionConverter::toFileSize(value, 8 * 
1024);\n\t}\n\telse\n\t{\n\t\tWriterAppender::setOption(option, value);\n\t}\n}\n\nvoid FileAppender::activateOptions(Pool& p)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\tactivateOptionsInternal(p);\n}\n\nvoid FileAppender::activateOptionsInternal(Pool& p)\n{\n\tint errors = 0;\n\n\tif (!_priv->fileName.empty())\n\t{\n\t\ttry\n\t\t{\n\t\t\tsetFileInternal(_priv->fileName, _priv->fileAppend, _priv->bufferedIO, _priv->bufferSize, p);\n\t\t}\n\t\tcatch (IOException& e)\n\t\t{\n\t\t\terrors++;\n\t\t\tLogString msg(LOG4CXX_STR(\"setFile(\"));\n\t\t\tmsg.append(_priv->fileName);\n\t\t\tmsg.append(1, (logchar) 0x2C /* ',' */);\n\t\t\tStringHelper::toString(_priv->fileAppend, msg);\n\t\t\tmsg.append(LOG4CXX_STR(\") call failed.\"));\n\t\t\t_priv->errorHandler->error(msg, e, ErrorCode::FILE_OPEN_FAILURE);\n\t\t}\n\t}\n\telse\n\t{\n\t\terrors++;\n\t\tLogLog::error(LogString(LOG4CXX_STR(\"File option not set for appender [\"))\n\t\t\t+ _priv->name + LOG4CXX_STR(\"].\"));\n\t\tLogLog::warn(LOG4CXX_STR(\"Are you using FileAppender instead of ConsoleAppender?\"));\n\t}\n\n\tif (errors == 0)\n\t{\n\t\tWriterAppender::activateOptions(p);\n\t}\n}\n\n\n/**\n * Replaces double backslashes (except the leading doubles of UNC's)\n * with single backslashes for compatibility with existing path\n * specifications that were working around use of\n * OptionConverter::convertSpecialChars in XML configuration files.\n *\n * @param src source string\n * @return modified string\n *\n *\n */\n...\n// Path: src/main/cpp/transform.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\n\n\nvoid Transform::appendEscapingTags(\n\tLogString& buf, const LogString& input)\n{\n\t//Check if the string is zero length -- if so, return\n\t//what was sent in.\n\n\tif (input.length() == 0 )\n\t{\n\t\treturn;\n\t}\n\n\tlogchar specials[] = { 0x22 /* \" */, 0x26 /* & */, 0x3C /* < */, 0x3E /* > */, 0x00 };\n\tsize_t start = 0;\n\tsize_t special = input.find_first_of(specials, start);\n\n\twhile (special != LogString::npos)\n\t{\n\t\tif (special > start)\n\t\t{\n\t\t\tbuf.append(input, start, special - start);\n\t\t}\n\n\t\tswitch (input[special])\n\t\t{\n\t\t\tcase 0x22:\n\t\t\t\tbuf.append(LOG4CXX_STR(\""\"));\n\t\t\t\tbreak;\n\n\t\t\tcase 0x26:\n\t\t\t\tbuf.append(LOG4CXX_STR(\"&\"));\n\t\t\t\tbreak;\n\n\t\t\tcase 0x3C:\n\t\t\t\tbuf.append(LOG4CXX_STR(\"<\"));\n\t\t\t\tbreak;\n\n\t\t\tcase 0x3E:\n\t\t\t\tbuf.append(LOG4CXX_STR(\">\"));\n\t\t\t\tbreak;\n\n\t\t\tdefault:\n\t\t\t\tbuf.append(1, input[special]);\n\t\t\t\tbreak;\n\t\t}\n\n\t\tstart = special + 1;\n\n\t\tif (special < input.size())\n\t\t{\n\t\t\tspecial = input.find_first_of(specials, start);\n\t\t}\n\t\telse\n\t\t{\n\t\t\tspecial = LogString::npos;\n\t\t}\n\t}\n\n\tif (start < input.size())\n\t{\n\t\tbuf.append(input, start, input.size() - start);\n\t}\n}\n\nvoid Transform::appendEscapingCDATA(\n\tLogString& buf, const LogString& input)\n{\n\tstatic const WideLife CDATA_END(LOG4CXX_STR(\"]]>\"));\n\tstatic const WideLife CDATA_EMBEDED_END(LOG4CXX_STR(\"]]>]]>\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::net;\n\nIMPLEMENT_LOG4CXX_OBJECT(TelnetAppender)\n\nstruct TelnetAppender::TelnetAppenderPriv : public AppenderSkeletonPrivate\n{\n\tTelnetAppenderPriv( int port, int maxConnections ) : AppenderSkeletonPrivate(),\n\t\tport(port),\n\t\tconnections(maxConnections),\n\t\tencoding(LOG4CXX_STR(\"UTF-8\")),\n\t\tencoder(CharsetEncoder::getUTF8Encoder()),\n\t\tsh(),\n\t\tactiveConnections(0) {}\n\n\tint port;\n\tConnectionList connections;\n\tLogString encoding;\n\tLOG4CXX_NS::helpers::CharsetEncoderPtr encoder;\n\tstd::unique_ptr serverSocket;\n\tstd::thread sh;\n\tsize_t activeConnections;\n};\n\n#define _priv static_cast(m_priv.get())\n\n/** The default telnet server port */\nconst int TelnetAppender::DEFAULT_PORT = 23;\n\n/** The maximum number of concurrent connections */\nconst int TelnetAppender::MAX_CONNECTIONS = 20;\n\nTelnetAppender::TelnetAppender()\n\t: AppenderSkeleton (std::make_unique(DEFAULT_PORT, MAX_CONNECTIONS))\n{\n}\n\nTelnetAppender::~TelnetAppender()\n{\n\tfinalize();\n}\n\nvoid TelnetAppender::activateOptions(Pool& /* p */)\n{\n\tif (_priv->serverSocket == NULL)\n\t{\n\t\t_priv->serverSocket = ServerSocket::create(_priv->port);\n\t\t_priv->serverSocket->setSoTimeout(1000);\n\t}\n\n\t_priv->sh = ThreadUtility::instance()->createThread( LOG4CXX_STR(\"TelnetAppender\"), &TelnetAppender::acceptConnections, this );\n}\n\nvoid 
TelnetAppender::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\tif (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"PORT\"), LOG4CXX_STR(\"port\")))\n\t{\n\t\tsetPort(OptionConverter::toInt(value, DEFAULT_PORT));\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option, LOG4CXX_STR(\"ENCODING\"), LOG4CXX_STR(\"encoding\")))\n\t{\n\t\tsetEncoding(value);\n\t}\n\telse\n\t{\n\t\tAppenderSkeleton::setOption(option, value);\n\t}\n}\n\nLogString TelnetAppender::getEncoding() const\n{\n\tstd::lock_guard lock(_priv->mutex);\n\treturn _priv->encoding;\n}\n\nvoid TelnetAppender::setEncoding(const LogString& value)\n{\n\tstd::lock_guard lock(_priv->mutex);\n\t_priv->encoder = CharsetEncoder::getEncoder(value);\n\t_priv->encoding = value;\n}\n\n\nvoid TelnetAppender::close()\n{\n\tstd::lock_guard lock(_priv->mutex);\n\n\tif (_priv->closed)\n\t{\n\t\treturn;\n\t}\n\n\t_priv->closed = true;\n\n\tSocketPtr nullSocket;\n\n\tfor (auto& item : _priv->connections)\n\t{\n\t\tif (item)\n\t\t{\n\t\t\titem->close();\n\t\t\titem = nullSocket;\n\t\t}\n\t}\n\n\tif (_priv->serverSocket != NULL)\n\t{\n\t\ttry\n\t\t{\n\t\t\t_priv->serverSocket->close();\n\t\t}\n\t\tcatch (Exception&)\n\t\t{\n\t\t}\n\t}\n\n\tif ( _priv->sh.joinable() )\n\t{\n\t\t_priv->sh.join();\n\t}\n\n\t_priv->activeConnections = 0;\n}\n\n\nvoid TelnetAppender::write(ByteBuffer& buf)\n{\n\tfor (auto& item :_priv->connections)\n\t{\n\t\tif (item)\n\t\t{\n\t\t\ttry\n\t\t\t{\n\t\t\t\tByteBuffer b(buf.current(), buf.remaining());\n\t\t\t\titem->write(b);\n\t\t\t}\n\t\t\tcatch (Exception&)\n\t\t\t{\n\t\t\t\t// The client has closed the connection, remove it from our list:\n\t\t\t\titem.reset();\n\t\t\t\t_priv->activeConnections--;\n\t\t\t}\n\t\t}\n\t}\n}\n\nvoid TelnetAppender::writeStatus(const SocketPtr& socket, const LogString& msg, Pool& p)\n{\n\tsize_t bytesSize = msg.size() * 2;\n\tchar* bytes = p.pstralloc(bytesSize);\n\n\tLogString::const_iterator msgIter(msg.begin());\n\tByteBuffer buf(bytes, bytesSize);\n\n\twhile (msgIter != msg.end())\n\t{\n\t\t_priv->encoder->encode(msg, msgIter, buf);\n\t\tbuf.flip();\n\t\tsocket->write(buf);\n\t\tbuf.clear();\n\t}\n}\n\nvoid TelnetAppender::append(const spi::LoggingEventPtr& event, Pool& p)\n{\n\tsize_t count = _priv->activeConnections;\n\n\tif (count > 0)\n\t{\n\t\tLogString msg;\n\t\t_priv->layout->format(msg, event, _priv->pool);\n\t\tmsg.append(LOG4CXX_STR(\"\\r\\n\"));\n\t\tsize_t bytesSize = msg.size() * 2;\n\t\tchar* bytes = p.pstralloc(bytesSize);\n\n\t\tLogString::const_iterator msgIter(msg.begin());\n\t\tByteBuffer buf(bytes, bytesSize);\n\n\t\tstd::lock_guard lock(_priv->mutex);\n\n\t\twhile (msgIter != msg.end())\n\t\t{\n\t\t\tlog4cxx_status_t stat = _priv->encoder->encode(msg, msgIter, buf);\n\t\t\tbuf.flip();\n\t\t\twrite(buf);\n\t\t\tbuf.clear();\n\n\t\t\tif (CharsetEncoder::isError(stat))\n\t\t\t{\n\t\t\t\tLogString unrepresented(1, 0x3F /* '?' 
*/);\n\t\t\t\tLogString::const_iterator unrepresentedIter(unrepresented.begin());\n\t\t\t\tstat = _priv->encoder->encode(unrepresented, unrepresentedIter, buf);\n\t\t\t\tbuf.flip();\n\t\t\t\twrite(buf);\n\t\t\t\tbuf.clear();\n\t\t\t\tmsgIter++;\n\t\t\t}\n\t\t}\n\t}\n}\n\nvoid TelnetAppender::acceptConnections()\n{\n\n\t// main loop; is left when This->closed is != 0 after an accept()\n\twhile (true)\n\t{\n\t\ttry\n\t\t{\n\t\t\tSocketPtr newClient = _priv->serverSocket->accept();\n\t\t\tbool done = _priv->closed;\n\n\t\t\tif (done)\n\t\t\t{\n\t\t\t\tPool p;\n\t\t\t\twriteStatus(newClient, LOG4CXX_STR(\"Log closed.\\r\\n\"), p);\n\t\t\t\tnewClient->close();\n\n\t\t\t\tbreak;\n\t\t\t}\n\n\t\t\tsize_t count = _priv->activeConnections;\n\n\t\t\tif (count >= _priv->connections.size())\n\t\t\t{\n\t\t\t\tPool p;\n\t\t\t\twriteStatus(newClient, LOG4CXX_STR(\"Too many connections.\\r\\n\"), p);\n\t\t\t\tnewClient->close();\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t//\n\t\t\t\t// find unoccupied connection\n\t\t\t\t//\n\t\t\t\tstd::lock_guard lock(_priv->mutex);\n\n\t\t\t\tfor (auto& item : _priv->connections)\n\t\t\t\t{\n\t\t\t\t\tif (!item)\n\t\t\t\t\t{\n\t\t\t\t\t\titem = newClient;\n\t\t\t\t\t\t_priv->activeConnections++;\n\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\n\t\t\t\tPool p;\n\t\t\t\tLogString oss(LOG4CXX_STR(\"TelnetAppender v1.0 (\"));\n\t\t\t\tStringHelper::toString((int) count + 1, p, oss);\n\t\t\t\toss += LOG4CXX_STR(\" active connections)\\r\\n\\r\\n\");\n\t\t\t\twriteStatus(newClient, oss, p);\n\t\t\t}\n\t\t}\n\t\tcatch (InterruptedIOException&)\n\t\t{\n\t\t\tif (_priv->closed)\n\t\t\t{\n\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\t\tcatch (Exception& e)\n\t\t{\n\t\t\tif (!_priv->closed)\n\t\t\t{\n\t\t\t\tLogLog::error(LOG4CXX_STR(\"Encountered error while in SocketHandler loop.\"), e);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\t}\n\n}\n\nint TelnetAppender::getPort() const\n{\n\treturn _priv->port;\n}\n\nvoid TelnetAppender::setPort(int port1)\n{\n\t_priv->port = port1;\n}\n\n// Path: src/main/cpp/outputdebugstringappender.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#if defined(_WIN32)\n#include \n#include \n#include \n\n#include \"windows.h\"\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::nt;\n\nIMPLEMENT_LOG4CXX_OBJECT(OutputDebugStringAppender)\n\nOutputDebugStringAppender::OutputDebugStringAppender()\n{\n}\n\nvoid OutputDebugStringAppender::append(const spi::LoggingEventPtr& event, Pool& p)\n{\n\tLogString buf;\n\tthis->m_priv->layout->format(buf, event, p);\n#if LOG4CXX_WCHAR_T_API\n\tLOG4CXX_ENCODE_WCHAR(wstr, buf);\n\t::OutputDebugStringW(wstr.c_str());\n#else\n\tLOG4CXX_ENCODE_CHAR(str, buf);\n\t::OutputDebugStringA(str.c_str());\n#endif\n}\n\n#endif\n\n\n// Path: src/main/cpp/charsetdecoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#define NOMINMAX /* tell windows not to define min/max macros */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(CharsetDecoder)\n\n\nnamespace LOG4CXX_NS\n{\nnamespace helpers\n{\n\n#if APR_HAS_XLATE\n/**\n * Converts from an arbitrary encoding to LogString\n * using apr_xlate. 
Requires real iconv implementation,\n* apr-iconv will crash in use.\n */\nclass APRCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\t/**\n\t\t * Creates a new instance.\n\t\t * @param frompage name of source encoding.\n\t\t */\n\t\tAPRCharsetDecoder(const LogString& frompage) : pool()\n\t\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\t\tconst char* topage = \"WCHAR_T\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tconst char* topage = \"UTF-8\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\t\tconst char* topage = \"UTF-16\";\n#endif\n\t\t\tstd::string fpage(Transcoder::encodeCharsetName(frompage));\n\t\t\tapr_status_t stat = apr_xlate_open(&convset,\n\t\t\t\t\ttopage,\n\t\t\t\t\tfpage.c_str(),\n\t\t\t\t\tpool.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IllegalArgumentException(frompage);\n\t\t\t}\n\t\t}\n\n\t\t/**\n\t\t * Destructor.\n\t\t */\n\t\tvirtual ~APRCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tenum { BUFSIZE = 256 };\n\t\t\tlogchar buf[BUFSIZE];\n\t\t\tconst apr_size_t initial_outbytes_left = BUFSIZE * sizeof(logchar);\n\t\t\tapr_status_t stat = APR_SUCCESS;\n\n\t\t\tif (in.remaining() == 0)\n\t\t\t{\n\t\t\t\tsize_t outbytes_left = initial_outbytes_left;\n\t\t\t\t{\n\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\tstat = apr_xlate_conv_buffer((apr_xlate_t*) convset,\n\t\t\t\t\t\t\tNULL, NULL, (char*) buf, &outbytes_left);\n\t\t\t\t}\n\t\t\t\tout.append(buf, (initial_outbytes_left - outbytes_left) / sizeof(logchar));\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\twhile (in.remaining() > 0 && stat == APR_SUCCESS)\n\t\t\t\t{\n\t\t\t\t\tsize_t inbytes_left = in.remaining();\n\t\t\t\t\tsize_t initial_inbytes_left = inbytes_left;\n\t\t\t\t\tsize_t pos = in.position();\n\t\t\t\t\tapr_size_t outbytes_left = initial_outbytes_left;\n\t\t\t\t\t{\n\t\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\t\tstat = apr_xlate_conv_buffer((apr_xlate_t*) convset,\n\t\t\t\t\t\t\t\tin.data() + pos,\n\t\t\t\t\t\t\t\t&inbytes_left,\n\t\t\t\t\t\t\t\t(char*) buf,\n\t\t\t\t\t\t\t\t&outbytes_left);\n\t\t\t\t\t}\n\t\t\t\t\tout.append(buf, (initial_outbytes_left - outbytes_left) / sizeof(logchar));\n\t\t\t\t\tin.position(pos + (initial_inbytes_left - inbytes_left));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tAPRCharsetDecoder(const APRCharsetDecoder&);\n\t\tAPRCharsetDecoder& operator=(const APRCharsetDecoder&);\n\t\tLOG4CXX_NS::helpers::Pool pool;\n\t\tstd::mutex mutex;\n\t\tapr_xlate_t* convset;\n};\n\n#endif\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_MBSRTOWCS\n/**\n* Converts from the default multi-byte string to\n* LogString using mbstowcs.\n*\n*/\nclass MbstowcsCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tMbstowcsCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~MbstowcsCharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tinline log4cxx_status_t append(LogString& out, const wchar_t* buf)\n\t\t{\n\t\t\tout.append(buf);\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\t\t\tenum { BUFSIZE = 256 };\n\t\t\twchar_t wbuf[BUFSIZE];\n\t\t\tchar cbuf[BUFSIZE*4];\n\n\t\t\tmbstate_t mbstate;\n\t\t\tmemset(&mbstate, 0, sizeof(mbstate));\n\n\t\t\twhile (in.remaining() > 0)\n\t\t\t{\n\t\t\t\tconst char* src = in.current();\n\n\t\t\t\tif (*src == 0)\n\t\t\t\t{\n\t\t\t\t\tout.append(1, (logchar) 0);\n\t\t\t\t\tin.position(in.position() + 
1);\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t{\n\t\t\t\t\tauto available = std::min(sizeof (cbuf) - 1, in.remaining());\n\t\t\t\t\tstrncpy(cbuf, src, available);\n\t\t\t\t\tcbuf[available] = 0;\n\t\t\t\t\tsrc = cbuf;\n\t\t\t\t\tsize_t wCharCount = mbsrtowcs(wbuf,\n\t\t\t\t\t\t\t&src,\n\t\t\t\t\t\t\tBUFSIZE - 1,\n\t\t\t\t\t\t\t&mbstate);\n\t\t\t\t\tauto converted = src - cbuf;\n\t\t\t\t\tin.position(in.position() + converted);\n\n\t\t\t\t\tif (wCharCount == (size_t) -1) // Illegal byte sequence?\n\t\t\t\t\t{\n\t\t\t\t\t\tLogString msg(LOG4CXX_STR(\"Illegal byte sequence at \"));\n\t\t\t\t\t\tmsg.append(std::to_wstring(in.position()));\n\t\t\t\t\t\tmsg.append(LOG4CXX_STR(\" of \"));\n\t\t\t\t\t\tmsg.append(std::to_wstring(in.limit()));\n\t\t\t\t\t\tLogLog::warn(msg);\n\t\t\t\t\t\tstat = APR_BADCH;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\twbuf[wCharCount] = 0;\n\t\t\t\t\t\tstat = append(out, wbuf);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\n\tprivate:\n\t\tMbstowcsCharsetDecoder(const MbstowcsCharsetDecoder&);\n\t\tMbstowcsCharsetDecoder& operator=(const MbstowcsCharsetDecoder&);\n};\n#endif\n\n\n/**\n* Decoder used when the external and internal charsets\n* are the same.\n*\n*/\nclass TrivialCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tTrivialCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~TrivialCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tsize_t remaining = in.remaining();\n\n\t\t\tif ( remaining > 0)\n\t\t\t{\n\t\t\t\tconst logchar* src = (const logchar*) (in.data() + in.position());\n\t\t\t\tsize_t count = remaining / sizeof(logchar);\n\t\t\t\tout.append(src, count);\n\t\t\t\tin.position(in.position() + remaining);\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\n\n\tprivate:\n\t\tTrivialCharsetDecoder(const TrivialCharsetDecoder&);\n\t\tTrivialCharsetDecoder& operator=(const TrivialCharsetDecoder&);\n};\n\n\n#if LOG4CXX_LOGCHAR_IS_UTF8\ntypedef TrivialCharsetDecoder UTF8CharsetDecoder;\n#else\n/**\n* Converts from UTF-8 to std::wstring\n*\n*/\nclass UTF8CharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tUTF8CharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~UTF8CharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tif (in.remaining() > 0)\n\t\t\t{\n\t\t\t\tstd::string tmp(in.current(), in.remaining());\n\t\t\t\tstd::string::const_iterator iter = tmp.begin();\n\n\t\t\t\twhile (iter != tmp.end())\n\t\t\t\t{\n\t\t\t\t\tunsigned int sv = Transcoder::decode(tmp, iter);\n\n\t\t\t\t\tif (sv == 0xFFFF)\n\t\t\t\t\t{\n\t\t\t\t\t\tsize_t offset = iter - tmp.begin();\n\t\t\t\t\t\tin.position(in.position() + offset);\n\t\t\t\t\t\treturn APR_BADARG;\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\tTranscoder::encode(sv, out);\n\t\t\t\t\t}\n\t\t\t\t}\n\n\t\t\t\tin.position(in.limit());\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\tprivate:\n\t\tUTF8CharsetDecoder(const UTF8CharsetDecoder&);\n\t\tUTF8CharsetDecoder& operator=(const UTF8CharsetDecoder&);\n};\n#endif\n\n/**\n* Converts from ISO-8859-1 to LogString.\n*\n*/\nclass ISOLatinCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tISOLatinCharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~ISOLatinCharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tif (in.remaining() > 0)\n\t\t\t{\n\n\t\t\t\tconst unsigned char* src = (unsigned 
char*) in.current();\n\t\t\t\tconst unsigned char* srcEnd = src + in.remaining();\n\n\t\t\t\twhile (src < srcEnd)\n\t\t\t\t{\n\t\t\t\t\tunsigned int sv = *(src++);\n\t\t\t\t\tTranscoder::encode(sv, out);\n\t\t\t\t}\n\n\t\t\t\tin.position(in.limit());\n\t\t\t}\n\n\t\t\treturn APR_SUCCESS;\n\t\t}\n\n\n\n\tprivate:\n\t\tISOLatinCharsetDecoder(const ISOLatinCharsetDecoder&);\n\t\tISOLatinCharsetDecoder& operator=(const ISOLatinCharsetDecoder&);\n};\n\n\n/**\n* Converts from US-ASCII to LogString.\n*\n*/\nclass USASCIICharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tUSASCIICharsetDecoder()\n\t\t{\n\t\t}\n\n\t\tvirtual ~USASCIICharsetDecoder()\n\t\t{\n\t\t}\n\n\tprivate:\n\n\t\tvirtual log4cxx_status_t decode(ByteBuffer& in,\n\t\t\tLogString& out)\n\t\t{\n\t\t\tlog4cxx_status_t stat = APR_SUCCESS;\n\n\t\t\tif (in.remaining() > 0)\n\t\t\t{\n\n\t\t\t\tconst unsigned char* src = (unsigned char*) in.current();\n\t\t\t\tconst unsigned char* srcEnd = src + in.remaining();\n\n\t\t\t\twhile (src < srcEnd)\n\t\t\t\t{\n\t\t\t\t\tunsigned char sv = *src;\n\n\t\t\t\t\tif (sv < 0x80)\n\t\t\t\t\t{\n\t\t\t\t\t\tsrc++;\n\t\t\t\t\t\tTranscoder::encode(sv, out);\n\t\t\t\t\t}\n\t\t\t\t\telse\n\t\t\t\t\t{\n\t\t\t\t\t\tstat = APR_BADARG;\n\t\t\t\t\t\tbreak;\n\t\t\t\t\t}\n\t\t\t\t}\n\n\t\t\t\tin.position(src - (const unsigned char*) in.data());\n\t\t\t}\n\n\t\t\treturn stat;\n\t\t}\n\n\n\n\tprivate:\n\t\tUSASCIICharsetDecoder(const USASCIICharsetDecoder&);\n\t\tUSASCIICharsetDecoder& operator=(const USASCIICharsetDecoder&);\n};\n\n/**\n * Charset decoder that uses current locale settings.\n */\nclass LocaleCharsetDecoder : public CharsetDecoder\n{\n\tpublic:\n\t\tLocaleCharsetDecoder() : state()\n\t\t{\n\t\t}\n\t\tlog4cxx_status_t decode(ByteBuffer& in, LogString& out) override\n\t\t{\n\t\t\tlog4cxx_status_t result = APR_SUCCESS;\n\t\t\tconst char* p = in.current();\n\t\t\tsize_t i = in.position();\n\t\t\tsize_t remain = in.limit() - i;\n#if !LOG4CXX_CHARSET_EBCDIC\n\t\t\tif (std::mbsinit(&this->state)) // ByteBuffer not partially decoded?\n\t\t\t{\n\t\t\t\t// Copy single byte characters\n\t\t\t\tfor (; 0 < remain && ((unsigned int) *p) < 0x80; --remain, ++i, p++)\n\t\t\t\t{\n\t\t\t\t\tout.append(1, *p);\n\t\t\t\t}\n\t\t\t}\n#endif\n\t\t\t// Decode characters that may be represented by multiple bytes\n\t\t\twhile (0 < remain)\n\t\t\t{\n\t\t\t\twchar_t ch = 0;\n\t\t\t\tsize_t n = std::mbrtowc(&ch, p, remain, &this->state);\n\t\t\t\tif (0 == n) // NULL encountered?\n\t\t\t\t{\n\t\t\t\t\t++i;\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tif (static_cast(-1) == n) // decoding error?\n\t\t\t\t{\n\t\t\t\t\tresult = APR_BADARG;\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tif (static_cast(-2) == n) // incomplete sequence?\n\t\t\t\t{\n\t\t\t\t\tbreak;\n\t\t\t\t}\n\t\t\t\tTranscoder::encode(static_cast(ch), out);\n\t\t\t\tremain -= n;\n\t\t\t\ti += n;\n\t\t\t\tp += n;\n\t\t\t}\n\t\t\tin.position(i);\n\t\t\treturn result;\n\t\t}\n\n\tprivate:\n\t\tstd::mbstate_t state;\n};\n\n\n\n} // namespace helpers\n\n} //namespace log4cxx\n\n\nCharsetDecoder::CharsetDecoder()\n{\n}\n\n\nCharsetDecoder::~CharsetDecoder()\n{\n}\n\nCharsetDecoder* CharsetDecoder::createDefaultDecoder()\n{\n#if LOG4CXX_CHARSET_UTF8\n\treturn new UTF8CharsetDecoder();\n#elif LOG4CXX_CHARSET_ISO88591 || defined(_WIN32_WCE)\n\treturn new ISOLatinCharsetDecoder();\n#elif LOG4CXX_CHARSET_USASCII\n\treturn new USASCIICharsetDecoder();\n#elif LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_MBSRTOWCS\n\treturn new MbstowcsCharsetDecoder();\n#else\n\treturn new 
LocaleCharsetDecoder();\n#endif\n}\n\nCharsetDecoderPtr CharsetDecoder::getDefaultDecoder()\n{\n\tstatic WideLife decoder(createDefaultDecoder());\n\n\t//\n\t// if invoked after static variable destruction\n\t// (if logging is called in the destructor of a static object)\n\t// then create a new decoder.\n\t//\n\tif (decoder.value() == 0)\n\t{\n\t\treturn CharsetDecoderPtr( createDefaultDecoder() );\n\t}\n\n\treturn decoder;\n}\n\nCharsetDecoderPtr CharsetDecoder::getUTF8Decoder()\n{\n\tstatic WideLife decoder(new UTF8CharsetDecoder());\n\n\t//\n\t// if invoked after static variable destruction\n\t// (if logging is called in the destructor of a static object)\n\t// then create a new decoder.\n\t//\n\tif (decoder.value() == 0)\n\t{\n\t\treturn std::make_shared();\n\t}\n\n\treturn decoder;\n}\n\nCharsetDecoderPtr CharsetDecoder::getISOLatinDecoder()\n{\n\treturn std::make_shared();\n}\n\n\nCharsetDecoderPtr CharsetDecoder::getDecoder(const LogString& charset)\n{\n\tif (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF-8\"), LOG4CXX_STR(\"utf-8\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"UTF8\"), LOG4CXX_STR(\"utf8\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP65001\"), LOG4CXX_STR(\"cp65001\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"C\"), LOG4CXX_STR(\"c\")) ||\n\t\tcharset == LOG4CXX_STR(\"646\") ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"US-ASCII\"), LOG4CXX_STR(\"us-ascii\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO646-US\"), LOG4CXX_STR(\"iso646-US\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ANSI_X3.4-1968\"), LOG4CXX_STR(\"ansi_x3.4-1968\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP20127\"), LOG4CXX_STR(\"cp20127\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-8859-1\"), LOG4CXX_STR(\"iso-8859-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"ISO-LATIN-1\"), LOG4CXX_STR(\"iso-latin-1\")) ||\n\t\tStringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"CP1252\"), LOG4CXX_STR(\"cp1252\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\telse if (StringHelper::equalsIgnoreCase(charset, LOG4CXX_STR(\"LOCALE\"), LOG4CXX_STR(\"locale\")))\n\t{\n\t\treturn std::make_shared();\n\t}\n\n#if APR_HAS_XLATE\n\treturn std::make_shared(charset);\n#else\n\tthrow IllegalArgumentException(charset);\n#endif\n}\n\n\n\n\n\n\n\n// Path: src/main/cpp/throwableinformationpatternconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct ThrowableInformationPatternConverter::ThrowableInformationPatternConverterPrivate :\n\tpublic PatternConverterPrivate\n{\n\t\nThrowableInformationPatternConverterPrivate( const LogString& name, const LogString& style, bool shortReport ) :\n\t\tPatternConverterPrivate( name, style ),\n\t\tshortReport(shortReport) {}\n\n\t/**\n\t * If \"short\", only first line of throwable report will be formatted.\n\t */\n\tconst bool shortReport;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(ThrowableInformationPatternConverter)\n\nThrowableInformationPatternConverter::ThrowableInformationPatternConverter(bool shortReport1) :\n\tLoggingEventPatternConverter(\n\t\tstd::make_unique(\n\t\t\tLOG4CXX_STR(\"Throwable\"),\n\t\t\tLOG4CXX_STR(\"throwable\"),\n\t\t\tshortReport1))\n{\n}\n\nPatternConverterPtr ThrowableInformationPatternConverter::newInstance(\n\tconst std::vector& options)\n{\n\tif (options.size() > 0 && options[0].compare(LOG4CXX_STR(\"short\")) == 0)\n\t{\n\t\tstatic WideLife shortConverter = std::make_shared(true);\n\t\treturn shortConverter;\n\t}\n\n\tstatic WideLife converter = std::make_shared(false);\n\treturn converter;\n}\n\nvoid ThrowableInformationPatternConverter::format(\n\tconst LoggingEventPtr& /* event */,\n\tLogString& /* toAppendTo */,\n\tPool& /* p */) const\n{\n}\n\n/**\n * This converter obviously handles throwables.\n * @return true.\n */\nbool ThrowableInformationPatternConverter::handlesThrowable() const\n{\n\treturn true;\n}\n\n// Path: src/main/cpp/htmllayout.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nstruct HTMLLayout::HTMLLayoutPrivate\n{\n\tHTMLLayoutPrivate()\n\t\t: locationInfo(false)\n\t\t, title(LOG4CXX_STR(\"Log4cxx Log Messages\"))\n\t\t, dateFormat()\n\t\t, expectedPatternLength(100)\n\t\t{}\n\n\t// Print no location info by default\n\tbool locationInfo; //= false\n\n\tLogString title;\n\n\thelpers::ISO8601DateFormat dateFormat;\n\n\t// Expected length of a formatted event excluding the message text\n\tsize_t expectedPatternLength;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(HTMLLayout)\n\n\nHTMLLayout::HTMLLayout()\n\t: m_priv(std::make_unique())\n{\n\tm_priv->dateFormat.setTimeZone(TimeZone::getGMT());\n\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n}\n\nHTMLLayout::~HTMLLayout() {}\n\n\nvoid HTMLLayout::setOption(const LogString& option,\n\tconst LogString& value)\n{\n\n\tif (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"TITLE\"), LOG4CXX_STR(\"title\")))\n\t{\n\t\tsetTitle(value);\n\t}\n\telse if (StringHelper::equalsIgnoreCase(option,\n\t\t\tLOG4CXX_STR(\"LOCATIONINFO\"), LOG4CXX_STR(\"locationinfo\")))\n\t{\n\t\tsetLocationInfo(OptionConverter::toBoolean(value, false));\n\t\tm_priv->expectedPatternLength = getFormattedEventCharacterCount() * 2;\n\t}\n}\n\nvoid HTMLLayout::format(LogString& output,\n\tconst spi::LoggingEventPtr& event,\n\tPool& p) const\n{\n\toutput.reserve(m_priv->expectedPatternLength + event->getMessage().size());\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\n\tm_priv->dateFormat.format(output, event->getTimeStamp(), p);\n\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"getThreadName());\n\toutput.append(threadName);\n\toutput.append(LOG4CXX_STR(\" thread\\\">\"));\n\toutput.append(threadName);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\n\tif (event->getLevel()->equals(Level::getDebug()))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(event->getLevel()->toString());\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t}\n\telse if (event->getLevel()->isGreaterOrEqual(Level::getWarn()))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(event->getLevel()->toString());\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t}\n\telse\n\t{\n\t\toutput.append(event->getLevel()->toString());\n\t}\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\toutput.append(LOG4CXX_STR(\"getLoggerName());\n\toutput.append(LOG4CXX_STR(\" logger\\\">\"));\n\tTransform::appendEscapingTags(output, event->getLoggerName());\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\tif (m_priv->locationInfo)\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\tconst LocationInfo& locInfo = event->getLocationInformation();\n\t\tLOG4CXX_DECODE_CHAR(fileName, 
locInfo.getFileName());\n\t\tTransform::appendEscapingTags(output, fileName);\n\t\toutput.append(1, (logchar) 0x3A /* ':' */);\n\t\tint line = event->getLocationInformation().getLineNumber();\n\n\t\tif (line != 0)\n\t\t{\n\t\t\tStringHelper::toString(line, p, output);\n\t\t}\n\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_EOL);\n\t}\n\n\toutput.append(LOG4CXX_STR(\"\"));\n\tTransform::appendEscapingTags(output, event->getRenderedMessage());\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\toutput.append(LOG4CXX_STR(\"\"));\n\toutput.append(LOG4CXX_EOL);\n\n\tLogString ndcVal;\n\n\tif (event->getNDC(ndcVal))\n\t{\n\t\toutput.append(LOG4CXX_STR(\"\"));\n\t\toutput.append(LOG4CXX_STR(\"NDC: \"));\n...\n", "needle_token_start": 12850, "needle_token_end": 12896, "code_context_ntokens": 15142, "output": ["```cpp\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::pattern;\nusing namespace LOG4CXX_NS::spi;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct ThrowableInformationPatternConverter::ThrowableInformationPatternConverterPrivate :\n\tpublic PatternConverterPrivate\n{\n\t\nThrowableInformationPatternConverterPrivate( const LogString& name, const LogString& style, bool shortReport ) :\n\t\tPatternConverterPrivate( name, style ),\n\t\tshortReport(shortReport) {}\n\n\t/**\n\t * If \"short\", only first line of throwable report will be formatted.\n\t */\n\tconst bool shortReport;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(ThrowableInformationPatternConverter)\n\nThrowableInformationPatternConverter::ThrowableInformationPatternConverter(bool shortReport1) :\n\tLoggingEventPatternConverter(\n\t\tstd::make_unique(\n\t\t\tLOG4CXX_STR(\"Throwable\"),\n\t\t\tLOG4CXX_STR(\"throwable\"),\n\t\t\tshortReport1))\n{\n}\n\nPatternConverterPtr ThrowableInformationPatternConverter::newInstance(\n\tconst std::vector& options)\n{\n\tif (options.size() > 0 && options[0].compare(LOG4CXX_STR(\"short\")) == 0)\n\t{\n\t\tstatic WideLife shortConverter = std::make_shared(true);\n\t\treturn shortConverter;\n\t}\n\n\tstatic WideLife converter = std::make_shared(false);\n\treturn converter;\n}\n\nvoid ThrowableInformationPatternConverter::format(\n\tconst LoggingEventPtr& /* event */,\n\tLogString& /* toAppendTo */,\n\tPool& /* p */) const\n{\n}\n\n/**\n * This converter obviously handles throwables.\n * @return true.\n */\nbool ThrowableInformationPatternConverter::handlesThrowable() const\n{\n\treturn true;\n}\n```"]} +{"repo": "apache/logging-log4cxx", "name": "APRCharsetEncoder", "language": "cpp", "path": "src/main/cpp/charsetencoder.cpp", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To convert characters from a native encoding format to a specified target encoding format using the Apache Portable Runtime (APR) library's translation facilities.\n2. **Input**: A string representing the target encoding format (e.g., \"UTF-8\", \"ISO-8859-1\").\n3. **Output**: There is no direct output from the constructor itself; however, it sets up an environment for subsequent character encoding operations, throwing an exception if the setup fails.\n4. **Procedure**: The procedure involves initializing a memory pool, determining the native character encoding format based on compile-time macros, and then attempting to create a translation handle that can convert from the native format to the specified target format. 
If the translation setup is successful, the environment is ready for encoding operations; otherwise, an exception is thrown indicating a problem with the specified target encoding.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/main/cpp/logger.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\nstruct Logger::LoggerPrivate\n{\n\tLoggerPrivate(Pool& p, const LogString& name1):\n\t\tname(name1),\n\t\trepositoryRaw(0),\n\t\taai(p),\n\t\tadditive(true),\n\t\tlevelData(Level::getData()) {}\n\n\t/**\n\tThe name of this logger.\n\t*/\n\tLogString name;\n\n\t/**\n\tThe assigned level of this logger. The\n\tlevel variable need not be assigned a value in\n\twhich case it is inherited form the hierarchy. */\n\tLevelPtr level;\n\n\t/**\n\tThe parent of this logger. All loggers have at least one\n\tancestor which is the root logger. */\n\tLoggerPtr parent;\n\n\t/** The resourceBundle for localized messages.\n\n\t@see setResourceBundle, getResourceBundle\n\t*/\n\thelpers::ResourceBundlePtr resourceBundle;\n\n\n\t// Loggers need to know what Hierarchy they are in\n\tLOG4CXX_NS::spi::LoggerRepository* repositoryRaw;\n\n\thelpers::AppenderAttachableImpl aai;\n\n\t/** Additivity is set to true by default, that is children inherit\n\t the appenders of their ancestors by default. If this variable is\n\t set to false then the appenders found in the\n\t ancestors of this logger are not used. However, the children\n\t of this logger will inherit its appenders, unless the children\n\t have their additivity flag set to false too. See\n\t the user manual for more details. 
*/\n\tbool additive;\n\n\tconst Level::Data& levelData;\n};\n\nIMPLEMENT_LOG4CXX_OBJECT(Logger)\n\nLogger::Logger(Pool& p, const LogString& name1)\n\t: m_priv(std::make_unique(p, name1))\n\t, m_threshold(0)\n{\n}\n\nLogger::~Logger()\n{\n}\n\nvoid Logger::addAppender(const AppenderPtr newAppender)\n{\n\tm_priv->aai.addAppender(newAppender);\n\tif (auto rep = getHierarchy())\n\t{\n\t\trep->fireAddAppenderEvent(this, newAppender.get());\n\t}\n}\n\nvoid Logger::reconfigure( const std::vector& appenders, bool additive1 )\n{\n\tm_priv->additive = additive1;\n\n\tm_priv->aai.removeAllAppenders();\n\n\tfor (auto const& item : appenders)\n\t{\n\t\tm_priv->aai.addAppender(item);\n\n\t\tif (auto rep = getHierarchy())\n\t\t{\n\t\t\trep->fireAddAppenderEvent(this, item.get());\n\t\t}\n\t}\n}\n\nvoid Logger::callAppenders(const spi::LoggingEventPtr& event, Pool& p) const\n{\n\tint writes = 0;\n\n\tfor (const Logger* logger = this;\n\t\tlogger != 0;\n\t\tlogger = logger->m_priv->parent.get())\n\t{\n\t\twrites += logger->m_priv->aai.appendLoopOnAppenders(event, p);\n\n\t\tif (!logger->m_priv->additive)\n\t\t{\n\t\t\tbreak;\n\t\t}\n\t}\n\n\tauto rep = getHierarchy();\n\n\tif (writes == 0 && rep)\n\t{\n\t\trep->emitNoAppenderWarning(const_cast(this));\n\t}\n}\n\nvoid Logger::closeNestedAppenders()\n{\n\tfor (auto& item : getAllAppenders())\n\t{\n\t\titem->close();\n\t}\n}\n\nvoid Logger::addEvent(const LevelPtr& level, std::string&& message, const LocationInfo& location) const\n{\n\tif (!getHierarchy()) // Has removeHierarchy() been called?\n\t\treturn;\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\tauto event = std::make_shared(m_priv->name, level, location, std::move(message));\n#else\n\tLOG4CXX_DECODE_CHAR(msg, message);\n\tauto event = std::make_shared(m_priv->name, level, location, std::move(msg));\n#endif\n\tPool p;\n\tcallAppenders(event, p);\n}\n\nvoid Logger::addInfoEvent(std::string&& message, const LocationInfo& location) const\n{\n\taddEvent(m_priv->levelData.Info, std::move(message), location);\n}\n\nvoid Logger::addDebugEvent(std::string&& message, const LocationInfo& location) const\n{\n\taddEvent(m_priv->levelData.Debug, std::move(message), location);\n}\n\nvoid Logger::addTraceEvent(std::string&& message, const LocationInfo& location) const\n{\n\taddEvent(m_priv->levelData.Trace, std::move(message), location);\n}\n\nvoid Logger::forcedLog(const LevelPtr& level, const std::string& message,\n\tconst LocationInfo& location) const\n{\n\tif (!getHierarchy()) // Has removeHierarchy() been called?\n\t\treturn;\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\tauto event = std::make_shared(m_priv->name, level, message, location);\n#else\n\tLOG4CXX_DECODE_CHAR(msg, message);\n\tauto event = std::make_shared(m_priv->name, level, location, std::move(msg));\n#endif\n\tPool p;\n\tcallAppenders(event, p);\n}\n\nvoid Logger::forcedLog(const LevelPtr& level1, const std::string& message) const\n{\n\tforcedLog(level1, message, LocationInfo::getLocationUnavailable());\n}\n\nvoid Logger::addEventLS(const LevelPtr& level, LogString&& message, const LocationInfo& location) const\n{\n\tif (!getHierarchy()) // Has removeHierarchy() been called?\n\t\treturn;\n\tauto event = std::make_shared(m_priv->name, level, location, std::move(message));\n\tPool p;\n\tcallAppenders(event, p);\n}\n\nvoid Logger::forcedLogLS(const LevelPtr& level1, const LogString& message,\n\tconst LocationInfo& location) const\n{\n\tif (!getHierarchy()) // Has removeHierarchy() been called?\n\t\treturn;\n\tauto event = std::make_shared(m_priv->name, level1, message, 
location);\n\tPool p;\n\tcallAppenders(event, p);\n}\n\n\nbool Logger::getAdditivity() const\n{\n\treturn m_priv->additive;\n}\n\nAppenderList Logger::getAllAppenders() const\n{\n\treturn m_priv->aai.getAllAppenders();\n}\n\nAppenderPtr Logger::getAppender(const LogString& name1) const\n{\n\treturn m_priv->aai.getAppender(name1);\n}\n\nconst LevelPtr& Logger::getEffectiveLevel() const\n{\n\tfor (const Logger* l = this; l != 0; l = l->m_priv->parent.get())\n\t{\n\t\tif (l->m_priv->level != 0)\n\t\t{\n\t\t\treturn l->m_priv->level;\n\t\t}\n\t}\n\n\tthrow NullPointerException(LOG4CXX_STR(\"No level specified for logger or ancestors.\"));\n#if LOG4CXX_RETURN_AFTER_THROW\n\treturn m_priv->level;\n#endif\n}\n\nLoggerRepository* Logger::getLoggerRepository() const\n{\n\treturn m_priv->repositoryRaw;\n}\n\nLoggerRepository* Logger::getHierarchy() const\n{\n\treturn m_priv->repositoryRaw;\n}\n\nResourceBundlePtr Logger::getResourceBundle() const\n{\n\tfor (const Logger* l = this; l != 0; l = l->m_priv->parent.get())\n\t{\n\t\tif (l->m_priv->resourceBundle != 0)\n\t\t{\n\t\t\treturn l->m_priv->resourceBundle;\n\t\t}\n\t}\n\n\t// It might be the case that there is no resource bundle\n\treturn 0;\n}\n\n\nLogString Logger::getResourceBundleString(const LogString& key) const\n{\n\tResourceBundlePtr rb = getResourceBundle();\n\n\t// This is one of the rare cases where we can use logging in order\n\t// to report errors from within log4j.\n\tif (rb == 0)\n\t{\n\t\treturn LogString();\n\t}\n\telse\n\t{\n\t\ttry\n\t\t{\n\t\t\treturn rb->getString(key);\n\t\t}\n\t\tcatch (MissingResourceException&)\n\t\t{\n\t\t\tlogLS(Level::getError(), LOG4CXX_STR(\"No resource is associated with key \\\"\") +\n\t\t\t\tkey + LOG4CXX_STR(\"\\\".\"), LocationInfo::getLocationUnavailable());\n\n\t\t\treturn LogString();\n\t\t}\n\t}\n}\n\n\nLoggerPtr Logger::getParent() const\n{\n\treturn m_priv->parent;\n}\n\nconst LevelPtr& Logger::getLevel() const\n{\n\treturn m_priv->level;\n}\n\nbool Logger::isAttached(const AppenderPtr appender) const\n{\n\treturn m_priv->aai.isAttached(appender);\n}\n\nbool Logger::isTraceEnabled() const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(Level::TRACE_INT))\n\t{\n\t\treturn false;\n\t}\n\n\treturn getEffectiveLevel()->toInt() <= Level::TRACE_INT;\n}\n\nbool Logger::isDebugEnabled() const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(Level::DEBUG_INT))\n\t{\n\t\treturn false;\n\t}\n\n\treturn getEffectiveLevel()->toInt() <= Level::DEBUG_INT;\n}\n\nbool Logger::isEnabledFor(const LevelPtr& level1) const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(level1->toInt()))\n\t{\n\t\treturn false;\n\t}\n\n\treturn level1->isGreaterOrEqual(getEffectiveLevel());\n}\n\n\nbool Logger::isInfoEnabled() const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(Level::INFO_INT))\n\t{\n\t\treturn false;\n\t}\n\n\treturn getEffectiveLevel()->toInt() <= Level::INFO_INT;\n}\n\nbool Logger::isErrorEnabled() const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(Level::ERROR_INT))\n\t{\n\t\treturn false;\n\t}\n\n\treturn getEffectiveLevel()->toInt() <= Level::ERROR_INT;\n}\n\nbool Logger::isWarnEnabled() const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(Level::WARN_INT))\n\t{\n\t\treturn false;\n\t}\n\n\treturn getEffectiveLevel()->toInt() <= Level::WARN_INT;\n}\n\nbool Logger::isFatalEnabled() const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(Level::FATAL_INT))\n\t{\n\t\treturn 
false;\n\t}\n\n\treturn getEffectiveLevel()->toInt() <= Level::FATAL_INT;\n}\n\n/*void Logger::l7dlog(const LevelPtr& level, const String& key,\n const char* file, int line)\n{\n\tauto rep = getHierarchy();\n\n if (!rep || rep->isDisabled(level->level))\n {\n return;\n }\n\n if (level->isGreaterOrEqual(getEffectiveLevel()))\n {\n String msg = getResourceBundleString(key);\n\n // if message corresponding to 'key' could not be found in the\n // resource bundle, then default to 'key'.\n if (msg.empty())\n {\n msg = key;\n }\n\n forcedLog(FQCN, level, msg, file, line);\n }\n}*/\n\n\n\nvoid Logger::l7dlog(const LevelPtr& level1, const LogString& key,\n\tconst LocationInfo& location, const std::vector& params) const\n{\n\tauto rep = getHierarchy();\n\n\tif (!rep || rep->isDisabled(level1->toInt()))\n\t{\n\t\treturn;\n\t}\n\n\tif (level1->isGreaterOrEqual(getEffectiveLevel()))\n\t{\n\t\tLogString pattern = getResourceBundleString(key);\n\t\tLogString msg;\n\n\t\tif (pattern.empty())\n\t\t{\n\t\t\tmsg = key;\n\t\t}\n\t\telse\n\t\t{\n\t\t\tmsg = StringHelper::format(pattern, params);\n\t\t}\n\n\t\taddEventLS(level1, std::move(msg), location);\n\t}\n}\n\nvoid Logger::l7dlog(const LevelPtr& level1, const std::string& key,\n\tconst LocationInfo& location) const\n{\n\tLOG4CXX_DECODE_CHAR(lkey, key);\n\n\tstd::vector values(0);\n\tl7dlog(level1, lkey, location, values);\n}\n\nvoid Logger::l7dlog(const LevelPtr& level1, const std::string& key,\n\tconst LocationInfo& location, const std::string& val1) const\n{\n\tLOG4CXX_DECODE_CHAR(lkey, key);\n...\n// Path: src/main/cpp/aprdatagramsocket.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n\nnamespace LOG4CXX_NS\n{\nnamespace helpers\n{\n\n#define _priv static_cast(m_priv.get())\n\nstruct APRDatagramSocket::APRDatagramSocketPriv : public DatagramSocketPriv {\n\tAPRDatagramSocketPriv() :\n\t\tDatagramSocketPriv(),\n\t\tsocket(nullptr)\n\t{}\n\n\tAPRDatagramSocketPriv(int port) :\n\t\tDatagramSocketPriv(port),\n\t\tsocket(nullptr)\n\t{}\n\n\tAPRDatagramSocketPriv(int port, InetAddressPtr localAddress) :\n\t\tDatagramSocketPriv(port, localAddress),\n\t\tsocket(nullptr)\n\t{}\n\n\t/** The APR socket */\n\tapr_socket_t* socket;\n\n\t/** The memory pool for the socket */\n\tPool socketPool;\n};\n\nAPRDatagramSocket::APRDatagramSocket() :\n\tDatagramSocket(std::make_unique()){\n\tinit();\n}\n\nAPRDatagramSocket::APRDatagramSocket(int port) :\n\tDatagramSocket(std::make_unique(port)){\n\tinit();\n}\n\nAPRDatagramSocket::APRDatagramSocket(int port, InetAddressPtr laddr) :\n\tDatagramSocket(std::make_unique(port, laddr)){\n\tinit();\n}\n\nvoid APRDatagramSocket::init()\n{\n\tapr_socket_t* newSocket;\n\tapr_status_t status =\n\t\tapr_socket_create(&newSocket, APR_INET, SOCK_DGRAM,\n\t\t\tAPR_PROTO_UDP, _priv->socketPool.getAPRPool());\n\t_priv->socket = newSocket;\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow SocketException(status);\n\t}\n}\n\nvoid APRDatagramSocket::receive(DatagramPacketPtr& p)\n{\n\tPool addrPool;\n\n\t// Create the address from which to receive the datagram packet\n\tLOG4CXX_ENCODE_CHAR(hostAddr, p->getAddress()->getHostAddress());\n\tapr_sockaddr_t* addr;\n\tapr_status_t status =\n\t\tapr_sockaddr_info_get(&addr, hostAddr.c_str(), APR_INET,\n\t\t\tp->getPort(), 0, addrPool.getAPRPool());\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow SocketException(status);\n\t}\n\n\t// receive the datagram packet\n\tapr_size_t len = p->getLength();\n\tstatus = apr_socket_recvfrom(addr, _priv->socket, 0,\n\t\t\t(char*)p->getData(), &len);\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow IOException(status);\n\t}\n}\n\nvoid APRDatagramSocket::send(DatagramPacketPtr& p)\n{\n\tPool addrPool;\n\n\t// create the adress to which to send the datagram packet\n\tLOG4CXX_ENCODE_CHAR(hostAddr, p->getAddress()->getHostAddress());\n\tapr_sockaddr_t* addr;\n\tapr_status_t status =\n\t\tapr_sockaddr_info_get(&addr, hostAddr.c_str(), APR_INET, p->getPort(),\n\t\t\t0, addrPool.getAPRPool());\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow SocketException(status);\n\t}\n\n\t// send the datagram packet\n\tapr_size_t len = p->getLength();\n\tstatus = apr_socket_sendto(_priv->socket, addr, 0,\n\t\t\t(char*)p->getData(), &len);\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow IOException(status);\n\t}\n}\n\nvoid APRDatagramSocket::close(){\n\tif (_priv->socket != 0)\n\t{\n\t\tapr_status_t status = apr_socket_close(_priv->socket);\n\n\t\tif (status != APR_SUCCESS)\n\t\t{\n\t\t\tthrow SocketException(status);\n\t\t}\n\n\t\t_priv->socket = 0;\n\t\t_priv->localPort = 0;\n\t}\n}\n\nvoid APRDatagramSocket::bind(int localPort1, InetAddressPtr localAddress1)\n{\n\tPool addrPool;\n\n\t// Create server socket address (including port 
number)\n\tLOG4CXX_ENCODE_CHAR(hostAddr, localAddress1->getHostAddress());\n\tapr_sockaddr_t* server_addr;\n\tapr_status_t status =\n\t\tapr_sockaddr_info_get(&server_addr, hostAddr.c_str(), APR_INET,\n\t\t\tlocalPort1, 0, addrPool.getAPRPool());\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow BindException(status);\n\t}\n\n\t// bind the socket to the address\n\tstatus = apr_socket_bind(_priv->socket, server_addr);\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow BindException(status);\n\t}\n\n\tm_priv->localPort = localPort1;\n\tm_priv->localAddress = localAddress1;\n}\n\n\nvoid APRDatagramSocket::connect(InetAddressPtr address1, int port1)\n{\n\tm_priv->address = address1;\n\tm_priv->port = port1;\n\n\tPool addrPool;\n\n\t// create socket address\n\tLOG4CXX_ENCODE_CHAR(hostAddr, address1->getHostAddress());\n\tapr_sockaddr_t* client_addr;\n\tapr_status_t status =\n\t\tapr_sockaddr_info_get(&client_addr, hostAddr.c_str(), APR_INET,\n\t\t\tm_priv->port, 0, addrPool.getAPRPool());\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow ConnectException(status);\n\t}\n\n\t// connect the socket\n\tstatus = apr_socket_connect(_priv->socket, client_addr);\n\n\tif (status != APR_SUCCESS)\n\t{\n\t\tthrow ConnectException(status);\n\t}\n}\n\n\nbool APRDatagramSocket::isClosed() const\n{\n\treturn _priv->socket != nullptr;\n}\n\n} //namespace helpers\n} //namespace log4cxx\n\n// Path: src/main/cpp/classregistration.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nClassRegistration::ClassRegistration(ClassAccessor accessor)\n{\n\tClass::registerClass((*accessor)());\n}\n\n\n\n// Path: src/main/cpp/level.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT_WITH_CUSTOM_CLASS(Level, LevelClass)\n\nLevelPtr Level::getOff()\n{\n\tstatic WideLife offLevel = std::make_shared(Level::OFF_INT, LOG4CXX_STR(\"OFF\"), 0);\n\treturn offLevel;\n}\n\nLevelPtr Level::getFatal()\n{\n\tstatic WideLife fatalLevel = std::make_shared(Level::FATAL_INT, LOG4CXX_STR(\"FATAL\"), 0);\n\treturn fatalLevel;\n}\n\nLevelPtr Level::getError()\n{\n\tstatic WideLife errorLevel = std::make_shared(Level::ERROR_INT, LOG4CXX_STR(\"ERROR\"), 3);\n\treturn errorLevel;\n}\n\nLevelPtr Level::getWarn()\n{\n\tstatic WideLife warnLevel = std::make_shared(Level::WARN_INT, LOG4CXX_STR(\"WARN\"), 4);\n\treturn warnLevel;\n}\n\nLevelPtr Level::getInfo()\n{\n\tstatic WideLife infoLevel = std::make_shared(Level::INFO_INT, LOG4CXX_STR(\"INFO\"), 6);\n\treturn infoLevel;\n}\n\nLevelPtr Level::getDebug()\n{\n\tstatic WideLife debugLevel = std::make_shared(Level::DEBUG_INT, LOG4CXX_STR(\"DEBUG\"), 7);\n\treturn debugLevel;\n}\n\nLevelPtr Level::getTrace()\n{\n\tstatic WideLife traceLevel = std::make_shared(Level::TRACE_INT, LOG4CXX_STR(\"TRACE\"), 7);\n\treturn traceLevel;\n}\n\n\nLevelPtr Level::getAll()\n{\n\tstatic WideLife allLevel = std::make_shared(Level::ALL_INT, LOG4CXX_STR(\"ALL\"), 7);\n\treturn allLevel;\n}\n\n\n\nLevel::Level(int level1,\n\tconst LogString& name1, int syslogEquivalent1)\n\t: level(level1), name(name1), syslogEquivalent(syslogEquivalent1)\n{\n\tAPRInitializer::initialize();\n}\n\n\nLevelPtr Level::toLevelLS(const LogString& sArg)\n{\n\treturn toLevelLS(sArg, Level::getDebug());\n}\n\nLogString Level::toString() const\n{\n\treturn name;\n}\n\n\nLevelPtr Level::toLevel(int val)\n{\n\treturn toLevel(val, Level::getDebug());\n}\n\nconst Level::Data& Level::getData()\n{\n\tstatic Data data =\n\t\t{ getOff()\n\t\t, getFatal()\n\t\t, getError()\n\t\t, getWarn()\n\t\t, getInfo()\n\t\t, getDebug()\n\t\t, getTrace()\n\t\t, getAll()\n\t\t};\n\treturn data;\n}\n\n\nLevelPtr Level::toLevel(int val, const LevelPtr& defaultLevel)\n{\n\tswitch (val)\n\t{\n\t\tcase ALL_INT:\n\t\t\treturn getAll();\n\n\t\tcase DEBUG_INT:\n\t\t\treturn getDebug();\n\n\t\tcase TRACE_INT:\n\t\t\treturn getTrace();\n\n\t\tcase INFO_INT:\n\t\t\treturn getInfo();\n\n\t\tcase WARN_INT:\n\t\t\treturn getWarn();\n\n\t\tcase ERROR_INT:\n\t\t\treturn getError();\n\n\t\tcase FATAL_INT:\n\t\t\treturn getFatal();\n\n\t\tcase OFF_INT:\n\t\t\treturn getOff();\n\n\t\tdefault:\n\t\t\treturn defaultLevel;\n\t}\n}\n\nLevelPtr Level::toLevel(const std::string& sArg)\n{\n\treturn toLevel(sArg, Level::getDebug());\n}\n\nLevelPtr Level::toLevel(const std::string& sArg, const LevelPtr& defaultLevel)\n{\n\tLOG4CXX_DECODE_CHAR(s, sArg);\n\treturn toLevelLS(s, defaultLevel);\n}\n\nvoid Level::toString(std::string& dst) const\n{\n\tTranscoder::encode(name, dst);\n}\n\n#if LOG4CXX_WCHAR_T_API\nLevelPtr Level::toLevel(const std::wstring& sArg)\n{\n\treturn toLevel(sArg, 
Level::getDebug());\n}\n\nLevelPtr Level::toLevel(const std::wstring& sArg, const LevelPtr& defaultLevel)\n{\n\tLOG4CXX_DECODE_WCHAR(s, sArg);\n\treturn toLevelLS(s, defaultLevel);\n}\n\nvoid Level::toString(std::wstring& dst) const\n{\n\tTranscoder::encode(name, dst);\n}\n\n#endif\n\n#if LOG4CXX_UNICHAR_API || LOG4CXX_LOGCHAR_IS_UNICHAR\nLevelPtr Level::toLevel(const std::basic_string& sArg)\n{\n\treturn toLevel(sArg, Level::getDebug());\n}\n\nLevelPtr Level::toLevel(const std::basic_string& sArg, const LevelPtr& defaultLevel)\n{\n\tLOG4CXX_DECODE_UNICHAR(s, sArg);\n\treturn toLevelLS(s, defaultLevel);\n}\n\nvoid Level::toString(std::basic_string& dst) const\n{\n\tTranscoder::encode(name, dst);\n}\n\n#endif\n\n#if LOG4CXX_CFSTRING_API\nLevelPtr Level::toLevel(const CFStringRef& sArg)\n{\n\treturn toLevel(sArg, Level::getDebug());\n}\n\nLevelPtr Level::toLevel(const CFStringRef& sArg, const LevelPtr& defaultLevel)\n{\n\tLogString s;\n\tTranscoder::decode(sArg, s);\n\treturn toLevelLS(s, defaultLevel);\n}\n\nvoid Level::toString(CFStringRef& dst) const\n{\n\tdst = Transcoder::encode(name);\n}\n#endif\n\n\nLevelPtr Level::toLevelLS(const LogString& sArg, const LevelPtr& defaultLevel)\n{\n\tconst LogString trimmed(StringHelper::trim(sArg));\n\tconst size_t len = trimmed.length();\n\n\tif (len == 4)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"INFO\"), LOG4CXX_STR(\"info\")))\n\t\t{\n\t\t\treturn getInfo();\n\t\t}\n\n\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"WARN\"), LOG4CXX_STR(\"warn\")))\n\t\t{\n\t\t\treturn getWarn();\n\t\t}\n\t}\n\telse\n\t{\n\t\tif (len == 5)\n\t\t{\n\t\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"DEBUG\"), LOG4CXX_STR(\"debug\")))\n\t\t\t{\n\t\t\t\treturn getDebug();\n\t\t\t}\n\n\t\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"TRACE\"), LOG4CXX_STR(\"trace\")))\n\t\t\t{\n\t\t\t\treturn getTrace();\n\t\t\t}\n\n\t\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"ERROR\"), LOG4CXX_STR(\"error\")))\n\t\t\t{\n\t\t\t\treturn getError();\n\t\t\t}\n\n\t\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"FATAL\"), LOG4CXX_STR(\"fatal\")))\n\t\t\t{\n\t\t\t\treturn getFatal();\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\tif (len == 3)\n\t\t\t{\n\t\t\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"OFF\"), LOG4CXX_STR(\"off\")))\n\t\t\t\t{\n\t\t\t\t\treturn getOff();\n\t\t\t\t}\n\n\t\t\t\tif (StringHelper::equalsIgnoreCase(trimmed, LOG4CXX_STR(\"ALL\"), LOG4CXX_STR(\"all\")))\n\t\t\t\t{\n\t\t\t\t\treturn getAll();\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t}\n\n\treturn defaultLevel;\n}\n\n\nbool Level::equals(const LevelPtr& level1) const\n{\n\treturn level1 && this->level == level1->level;\n}\n\nbool Level::isGreaterOrEqual(const LevelPtr& level1) const\n{\n\treturn level1 && this->level >= level1->level;\n}\n\n\n// Path: src/main/cpp/optionconverter.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n#include \n#include \n\nnamespace LOG4CXX_NS\n{\n\nclass ConfiguratorWatchdog : public helpers::FileWatchdog\n{\n\tspi::ConfiguratorPtr m_config;\n\tpublic:\n ConfiguratorWatchdog(const spi::ConfiguratorPtr& config, const File& filename)\n : helpers::FileWatchdog(filename)\n , m_config(config)\n {\n }\n\n /**\n Call PropertyConfigurator#doConfigure(const String& configFileName,\n const spi::LoggerRepositoryPtr& hierarchy) with the\n filename to reconfigure log4cxx.\n */\n void doOnChange() override\n {\n m_config->doConfigure(file(), LogManager::getLoggerRepository());\n }\n};\n\n}\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\nusing namespace LOG4CXX_NS::spi;\n\n\nLogString OptionConverter::convertSpecialChars(const LogString& s)\n{\n\tlogchar c;\n\tLogString sbuf;\n\n\tLogString::const_iterator i = s.begin();\n\n\twhile (i != s.end())\n\t{\n\t\tc = *i++;\n\n\t\tif (c == 0x5C /* '\\\\' */)\n\t\t{\n\t\t\tc = *i++;\n\n\t\t\tswitch (c)\n\t\t\t{\n\t\t\t\tcase 0x6E: //'n'\n\t\t\t\t\tc = 0x0A;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x72: //'r'\n\t\t\t\t\tc = 0x0D;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x74: //'t'\n\t\t\t\t\tc = 0x09;\n\t\t\t\t\tbreak;\n\n\t\t\t\tcase 0x66: //'f'\n\t\t\t\t\tc = 0x0C;\n\t\t\t\t\tbreak;\n\n\t\t\t\tdefault:\n\t\t\t\t\tbreak;\n\t\t\t}\n\t\t}\n\n\t\tsbuf.append(1, c);\n\t}\n\n\treturn sbuf;\n}\n\n\nbool OptionConverter::toBoolean(const LogString& value, bool dEfault)\n{\n\tif (value.length() >= 4)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(value.substr(0, 4),\n\t\t\t\tLOG4CXX_STR(\"TRUE\"), LOG4CXX_STR(\"true\")))\n\t\t{\n\t\t\treturn true;\n\t\t}\n\t}\n\n\tif (dEfault && value.length() >= 5)\n\t{\n\t\tif (StringHelper::equalsIgnoreCase(value.substr(0, 5),\n\t\t\t\tLOG4CXX_STR(\"FALSE\"), LOG4CXX_STR(\"false\")))\n\t\t{\n\t\t\treturn false;\n\t\t}\n\t}\n\n\treturn dEfault;\n}\n\nint OptionConverter::toInt(const LogString& value, int dEfault)\n{\n\tLogString trimmed(StringHelper::trim(value));\n\n\tif (trimmed.empty())\n\t{\n\t\treturn dEfault;\n\t}\n\n\tLOG4CXX_ENCODE_CHAR(cvalue, trimmed);\n\n\treturn (int) atol(cvalue.c_str());\n}\n\nlong OptionConverter::toFileSize(const LogString& s, long dEfault)\n{\n\tif (s.empty())\n\t{\n\t\treturn dEfault;\n\t}\n\n\tsize_t index = s.find_first_of(LOG4CXX_STR(\"bB\"));\n\n\tif (index != LogString::npos && index > 0)\n\t{\n\t\tlong multiplier = 1;\n\t\tindex--;\n\n\t\tif (s[index] == 0x6B /* 'k' */ || s[index] == 0x4B /* 'K' */)\n\t\t{\n\t\t\tmultiplier = 1024;\n\t\t}\n\t\telse if (s[index] == 0x6D /* 'm' */ || s[index] == 0x4D /* 'M' */)\n\t\t{\n\t\t\tmultiplier = 1024 * 1024;\n\t\t}\n\t\telse if (s[index] == 0x67 /* 'g'*/ || s[index] == 0x47 /* 'G' */)\n\t\t{\n\t\t\tmultiplier = 1024 * 1024 * 1024;\n\t\t}\n\n\t\treturn toInt(s.substr(0, index), 1) * multiplier;\n\t}\n\n\treturn toInt(s, 
1);\n}\n\nLogString OptionConverter::findAndSubst(const LogString& key, Properties& props)\n{\n\tLogString value(props.getProperty(key));\n\n\tif (value.empty())\n\t{\n\t\treturn value;\n\t}\n\n\ttry\n\t{\n\t\treturn substVars(value, props);\n\t}\n\tcatch (IllegalArgumentException& e)\n\t{\n\t\tLogLog::error(((LogString) LOG4CXX_STR(\"Bad option value [\"))\n\t\t\t+ value + LOG4CXX_STR(\"].\"), e);\n\t\treturn value;\n\t}\n}\n\nLogString OptionConverter::substVars(const LogString& val, Properties& props)\n{\n\tLogString sbuf;\n\tconst logchar delimStartArray[] = { 0x24, 0x7B, 0 };\n\tconst LogString delimStart(delimStartArray);\n\tconst logchar delimStop = 0x7D; // '}';\n\tconst size_t DELIM_START_LEN = 2;\n\tconst size_t DELIM_STOP_LEN = 1;\n\n\tsize_t i = 0;\n\n\twhile (true)\n\t{\n\t\tsize_t j = val.find(delimStart, i);\n\n\t\tif (j == val.npos)\n\t\t{\n\t\t\t// no more variables\n\t\t\tif (i == 0)\n\t\t\t{\n\t\t\t\t// this is a simple string\n\t\t\t\treturn val;\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\t// add the tail string which contails no variables and return the result.\n\t\t\t\tsbuf.append(val.substr(i, val.length() - i));\n\t\t\t\treturn sbuf;\n\t\t\t}\n\t\t}\n\t\telse\n\t\t{\n\t\t\tsbuf.append(val.substr(i, j - i));\n\t\t\tsize_t k = val.find(delimStop, j);\n\n\t\t\tif (k == val.npos)\n\t\t\t{\n\t\t\t\tLogString msg(1, (logchar) 0x22 /* '\\\"' */);\n\t\t\t\tmsg.append(val);\n\t\t\t\tmsg.append(LOG4CXX_STR(\"\\\" has no closing brace. Opening brace at position \"));\n\t\t\t\tPool p;\n\t\t\t\tStringHelper::toString(j, p, msg);\n\t\t\t\tmsg.append(1, (logchar) 0x2E /* '.' */);\n\t\t\t\tthrow IllegalArgumentException(msg);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tj += DELIM_START_LEN;\n\t\t\t\tLogString key = val.substr(j, k - j);\n\t\t\t\t// first try in System properties\n\t\t\t\tLogString replacement(getSystemProperty(key, LogString()));\n\n\t\t\t\t// then try props parameter\n\t\t\t\tif (replacement.empty())\n\t\t\t\t{\n\t\t\t\t\treplacement = props.getProperty(key);\n\t\t\t\t}\n\n\t\t\t\tif (!replacement.empty())\n\t\t\t\t{\n\t\t\t\t\t// Do variable substitution on the replacement string\n\t\t\t\t\t// such that we can solve \"Hello ${x2}\" as \"Hello p1\"\n\t\t\t\t\t// the where the properties are\n\t\t\t\t\t// x1=p1\n\t\t\t\t\t// x2=${x1}\n\t\t\t\t\tLogString recursiveReplacement = substVars(replacement, props);\n\t\t\t\t\tsbuf.append(recursiveReplacement);\n\t\t\t\t}\n\n\t\t\t\ti = k + DELIM_STOP_LEN;\n\t\t\t}\n\t\t}\n\t}\n}\n\nLogString OptionConverter::getSystemProperty(const LogString& key, const LogString& def)\n{\n\tif (!key.empty())\n\t{\n\t\tLogString value(System::getProperty(key));\n\n\t\tif (!value.empty())\n\t\t{\n\t\t\treturn value;\n\t\t}\n\t}\n\n\treturn def;\n}\n\nLevelPtr OptionConverter::toLevel(const LogString& value,\n\tconst LevelPtr& defaultValue)\n{\n\tsize_t hashIndex = value.find(LOG4CXX_STR(\"#\"));\n\n\tif (hashIndex == LogString::npos)\n\t{\n\t\tif (value.empty())\n\t\t{\n\t\t\treturn defaultValue;\n\t\t}\n\t\telse\n\t\t{\n\t\t\tLogLog::debug(\n\t\t\t\t((LogString) LOG4CXX_STR(\"OptionConverter::toLevel: no class name specified, level=[\"))\n\t\t\t\t+ value\n\t\t\t\t+ LOG4CXX_STR(\"]\"));\n\t\t\t// no class name specified : use standard Level class\n\t\t\treturn Level::toLevelLS(value, defaultValue);\n\t\t}\n\t}\n\n\tLogString clazz = value.substr(hashIndex + 1);\n\tLogString levelName = value.substr(0, hashIndex);\n\tLogLog::debug(((LogString) LOG4CXX_STR(\"OptionConverter::toLevel: class=[\"))\n\t\t+ clazz + LOG4CXX_STR(\"], level=[\") + levelName + 
LOG4CXX_STR(\"]\"));\n\n\t// This is degenerate case but you never know.\n\tif (levelName.empty())\n\t{\n\t\treturn Level::toLevelLS(value, defaultValue);\n\t}\n\n\ttry\n\t{\n\t\tLevel::LevelClass& levelClass =\n\t\t\t(Level::LevelClass&)Loader::loadClass(clazz);\n\t\treturn levelClass.toLevel(levelName);\n\t}\n\tcatch (ClassNotFoundException&)\n\t{\n\t\tLogLog::warn(((LogString) LOG4CXX_STR(\"custom level class [\"))\n\t\t\t+ clazz + LOG4CXX_STR(\"] not found.\"));\n\t}\n\tcatch (Exception& oops)\n\t{\n\t\tLogLog::warn(\n\t\t\tLOG4CXX_STR(\"class [\") + clazz + LOG4CXX_STR(\"], level [\") + levelName +\n\t\t\tLOG4CXX_STR(\"] conversion) failed.\"), oops);\n\t}\n\tcatch (...)\n\t{\n\t\tLogLog::warn(\n\t\t\tLOG4CXX_STR(\"class [\") + clazz + LOG4CXX_STR(\"], level [\") + levelName +\n\t\t\tLOG4CXX_STR(\"] conversion) failed.\"));\n\t}\n\n\treturn defaultValue;\n}\n\n\nObjectPtr OptionConverter::instantiateByKey(Properties& props, const LogString& key,\n\tconst Class& superClass, const ObjectPtr& defaultValue)\n{\n\t// Get the value of the property in string form\n\tLogString className(findAndSubst(key, props));\n\n\tif (className.empty())\n\t{\n\t\tLogLog::error(\n\t\t\t((LogString) LOG4CXX_STR(\"Could not find value for key \")) + key);\n\t\treturn defaultValue;\n\t}\n\n\t// Trim className to avoid trailing spaces that cause problems.\n\treturn OptionConverter::instantiateByClassName(\n\t\t\tStringHelper::trim(className), superClass, defaultValue);\n}\n\nObjectPtr OptionConverter::instantiateByClassName(const LogString& className,\n\tconst Class& superClass, const ObjectPtr& defaultValue)\n{\n\tif (!className.empty())\n\t{\n\t\ttry\n\t\t{\n\t\t\tconst Class& classObj = Loader::loadClass(className);\n\t\t\tObjectPtr newObject = ObjectPtr(classObj.newInstance());\n\n\t\t\tif (!newObject->instanceof(superClass))\n\t\t\t{\n\t\t\t\treturn defaultValue;\n\t\t\t}\n\n\t\t\treturn newObject;\n\t\t}\n\t\tcatch (Exception& e)\n\t\t{\n\t\t\tLogLog::error(LOG4CXX_STR(\"Could not instantiate class [\") +\n\t\t\t\tclassName + LOG4CXX_STR(\"].\"), e);\n\t\t}\n\t}\n\n\treturn defaultValue;\n}\n\nvoid OptionConverter::selectAndConfigure(const File& configFileName,\n\tconst LogString& _clazz, spi::LoggerRepositoryPtr hierarchy, int delay)\n{\n\tConfiguratorPtr configurator;\n\tLogString clazz = _clazz;\n\n\tLogString filename(configFileName.getPath());\n\n#if LOG4CXX_HAS_DOMCONFIGURATOR\n\tif (clazz.empty()\n\t\t&& filename.length() > 4\n\t\t&& StringHelper::equalsIgnoreCase(\n\t\t\tfilename.substr(filename.length() - 4),\n\t\t\tLOG4CXX_STR(\".XML\"), LOG4CXX_STR(\".xml\")))\n\t{\n\t\tclazz = LOG4CXX_NS::xml::DOMConfigurator::getStaticClass().toString();\n\t}\n#endif\n\n\tif (!clazz.empty())\n\t{\n\t\tLogLog::debug(LOG4CXX_STR(\"Preferred configurator class: \") + clazz);\n\t\tconst Class& clazzObj = Loader::loadClass(clazz);\n\t\tObjectPtr obj = ObjectPtr(clazzObj.newInstance());\n\t\tconfigurator = LOG4CXX_NS::cast(obj);\n\n\t\tif (configurator == 0)\n\t\t{\n\t\t\tLogLog::error(LOG4CXX_STR(\"Could not instantiate configurator [\")\n\t\t\t\t+ clazz + LOG4CXX_STR(\"].\"));\n\t\t\treturn;\n\t\t}\n\t}\n\telse\n\t{\n\t\tconfigurator = std::make_shared();\n\t}\n\n\tif (0 < delay)\n\t{\n\t\tauto dog = new ConfiguratorWatchdog(configurator, configFileName);\n\t\tAPRInitializer::registerCleanup(dog);\n\t\tdog->setDelay(delay);\n\t\tdog->start();\n\t}\n\telse\n\t\tconfigurator->doConfigure(configFileName, hierarchy);\n}\n\n// Path: src/main/cpp/locale.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) 
under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nstruct Locale::LocalePrivate\n{\n\tLocalePrivate(const LogString& language1)\n\t\t: language(language1)\n\t{\n\t}\n\n\tLocalePrivate(const LogString& language1, const LogString& country1)\n\t\t: language(language1), country(country1)\n\t{\n\t}\n\n\tLocalePrivate(const LogString& language1, const LogString& country1,\n\t\tconst LogString& variant1)\n\t\t: language(language1), country(country1), variant(variant1)\n\t{\n\t}\n\n\tconst LogString language;\n\tconst LogString country;\n\tconst LogString variant;\n};\n\nLocale::Locale(const LogString& language1)\n\t: m_priv(std::make_unique(language1))\n{\n}\n\nLocale::Locale(const LogString& language1, const LogString& country1)\n\t: m_priv(std::make_unique(language1, country1))\n{\n}\n\nLocale::Locale(const LogString& language1, const LogString& country1,\n\tconst LogString& variant1)\n\t: m_priv(std::make_unique(language1, country1, variant1))\n{\n}\n\nLocale::~Locale() {}\n\nconst LogString& Locale::getLanguage() const\n{\n\treturn m_priv->language;\n}\n\nconst LogString& Locale::getCountry() const\n{\n\treturn m_priv->country;\n}\n\nconst LogString& Locale::getVariant() const\n{\n\treturn m_priv->variant;\n}\n\n\n// Path: src/main/cpp/charsetencoder.cpp\n/*\n * Licensed to the Apache Software Foundation (ASF) under one or more\n * contributor license agreements. See the NOTICE file distributed with\n * this work for additional information regarding copyright ownership.\n * The ASF licenses this file to You under the Apache License, Version 2.0\n * (the \"License\"); you may not use this file except in compliance with\n * the License. 
You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n\n#include \n#include \n#include \n\n#ifdef LOG4CXX_HAS_WCSTOMBS\n\t#include \n#endif\n\nusing namespace LOG4CXX_NS;\nusing namespace LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(CharsetEncoder)\n\nnamespace LOG4CXX_NS\n{\n\nnamespace helpers\n{\n\n#if APR_HAS_XLATE\n/**\n* A character encoder implemented using apr_xlate.\n*/\nclass APRCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\t\nAPRCharsetEncoder(const LogString& topage) : pool()\n\t\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\t\tconst char* frompage = \"WCHAR_T\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tconst char* frompage = \"UTF-8\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\t\tconst char* frompage = \"UTF-16\";\n#endif\n\t\t\tstd::string tpage(Transcoder::encodeCharsetName(topage));\n\t\t\tapr_status_t stat = apr_xlate_open(&convset,\n\t\t\t\t\ttpage.c_str(),\n\t\t\t\t\tfrompage,\n\t\t\t\t\tpool.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IllegalArgumentException(topage);\n\t\t\t}\n\t\t}\n\n\t\tvirtual ~APRCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tapr_status_t stat;\n\t\t\tsize_t outbytes_left = out.remaining();\n\t\t\tsize_t initial_outbytes_left = outbytes_left;\n\t\t\tsize_t position = out.position();\n\n\t\t\tif (iter == in.end())\n\t\t\t{\n\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\tstat = apr_xlate_conv_buffer(convset, NULL, NULL,\n\t\t\t\t\t\tout.data() + position, &outbytes_left);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tLogString::size_type inOffset = (iter - in.begin());\n\t\t\t\tapr_size_t inbytes_left =\n\t\t\t\t\t(in.size() - inOffset) * sizeof(LogString::value_type);\n\t\t\t\tapr_size_t initial_inbytes_left = inbytes_left;\n\t\t\t\t{\n\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\tstat = apr_xlate_conv_buffer(convset,\n\t\t\t\t\t\t\t(const char*) (in.data() + inOffset),\n\t\t\t\t\t\t\t&inbytes_left,\n\t\t\t\t\t\t\tout.data() + position,\n\t\t\t\t\t\t\t&outbytes_left);\n\t\t\t\t}\n\t\t\t\titer += ((initial_inbytes_left - inbytes_left) / sizeof(LogString::value_type));\n\t\t\t}\n\n\t\t\tout.position(out.position() + (initial_outbytes_left - outbytes_left));\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tAPRCharsetEncoder(const APRCharsetEncoder&);\n\t\tAPRCharsetEncoder& operator=(const APRCharsetEncoder&);\n\t\tPool pool;\n\t\tstd::mutex mutex;\n\t\tapr_xlate_t* convset;\n};\n#endif\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_WCSTOMBS\n/**\n * A character encoder implemented using wcstombs.\n*/\n...\n", "needle_token_start": 14257, "needle_token_end": 14503, "code_context_ntokens": 15137, "output": ["```cpp\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if !defined(LOG4CXX)\n\t#define LOG4CXX 1\n#endif\n\n#include \n#include \n#include \n\n#ifdef LOG4CXX_HAS_WCSTOMBS\n\t#include \n#endif\n\nusing namespace LOG4CXX_NS;\nusing namespace 
LOG4CXX_NS::helpers;\n\nIMPLEMENT_LOG4CXX_OBJECT(CharsetEncoder)\n\nnamespace LOG4CXX_NS\n{\n\nnamespace helpers\n{\n\n#if APR_HAS_XLATE\n/**\n* A character encoder implemented using apr_xlate.\n*/\nclass APRCharsetEncoder : public CharsetEncoder\n{\n\tpublic:\n\t\t\nAPRCharsetEncoder(const LogString& topage) : pool()\n\t\t{\n#if LOG4CXX_LOGCHAR_IS_WCHAR\n\t\t\tconst char* frompage = \"WCHAR_T\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UTF8\n\t\t\tconst char* frompage = \"UTF-8\";\n#endif\n#if LOG4CXX_LOGCHAR_IS_UNICHAR\n\t\t\tconst char* frompage = \"UTF-16\";\n#endif\n\t\t\tstd::string tpage(Transcoder::encodeCharsetName(topage));\n\t\t\tapr_status_t stat = apr_xlate_open(&convset,\n\t\t\t\t\ttpage.c_str(),\n\t\t\t\t\tfrompage,\n\t\t\t\t\tpool.getAPRPool());\n\n\t\t\tif (stat != APR_SUCCESS)\n\t\t\t{\n\t\t\t\tthrow IllegalArgumentException(topage);\n\t\t\t}\n\t\t}\n\n\t\tvirtual ~APRCharsetEncoder()\n\t\t{\n\t\t}\n\n\t\tvirtual log4cxx_status_t encode(const LogString& in,\n\t\t\tLogString::const_iterator& iter,\n\t\t\tByteBuffer& out)\n\t\t{\n\t\t\tapr_status_t stat;\n\t\t\tsize_t outbytes_left = out.remaining();\n\t\t\tsize_t initial_outbytes_left = outbytes_left;\n\t\t\tsize_t position = out.position();\n\n\t\t\tif (iter == in.end())\n\t\t\t{\n\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\tstat = apr_xlate_conv_buffer(convset, NULL, NULL,\n\t\t\t\t\t\tout.data() + position, &outbytes_left);\n\t\t\t}\n\t\t\telse\n\t\t\t{\n\t\t\t\tLogString::size_type inOffset = (iter - in.begin());\n\t\t\t\tapr_size_t inbytes_left =\n\t\t\t\t\t(in.size() - inOffset) * sizeof(LogString::value_type);\n\t\t\t\tapr_size_t initial_inbytes_left = inbytes_left;\n\t\t\t\t{\n\t\t\t\t\tstd::unique_lock lock(mutex);\n\t\t\t\t\tstat = apr_xlate_conv_buffer(convset,\n\t\t\t\t\t\t\t(const char*) (in.data() + inOffset),\n\t\t\t\t\t\t\t&inbytes_left,\n\t\t\t\t\t\t\tout.data() + position,\n\t\t\t\t\t\t\t&outbytes_left);\n\t\t\t\t}\n\t\t\t\titer += ((initial_inbytes_left - inbytes_left) / sizeof(LogString::value_type));\n\t\t\t}\n\n\t\t\tout.position(out.position() + (initial_outbytes_left - outbytes_left));\n\t\t\treturn stat;\n\t\t}\n\n\tprivate:\n\t\tAPRCharsetEncoder(const APRCharsetEncoder&);\n\t\tAPRCharsetEncoder& operator=(const APRCharsetEncoder&);\n\t\tPool pool;\n\t\tstd::mutex mutex;\n\t\tapr_xlate_t* convset;\n};\n#endif\n\n#if LOG4CXX_LOGCHAR_IS_WCHAR && LOG4CXX_HAS_WCSTOMBS\n/**\n * A character encoder implemented using wcstombs.\n*/\n...\n```"]} +{"repo": "skypjack/uvw", "name": "try_read", "language": "cpp", "path": "src/uvw/util.h", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To attempt reading data using a provided function, handling potential buffer size issues dynamically.\n2. **Input**: A function that reads data into a buffer, additional arguments for the function, and a buffer size.\n3. **Output**: A string containing the data read, or an empty string if the read fails.\n4. **Procedure**: \n - Initially, a fixed-size buffer is used to attempt reading data.\n - If the buffer is too small (indicated by a specific error code), a larger buffer is dynamically allocated based on the required size.\n - The function is called again with the new buffer. 
If successful, the data is converted to a string; otherwise, an empty string is returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " */\nstruct uts_name {\n uts_name(std::shared_ptr init);\n\n /**\n * @brief Gets the operating system name (like \"Linux\").\n * @return The operating system name.\n */\n std::string sysname() const noexcept;\n\n /**\n * @brief Gets the operating system release (like \"2.6.28\").\n * @return The operating system release.\n */\n std::string release() const noexcept;\n\n /**\n * @brief Gets the operating system version.\n * @return The operating system version\n */\n std::string version() const noexcept;\n\n /**\n * @brief Gets the hardware identifier.\n * @return The hardware identifier.\n */\n std::string machine() const noexcept;\n\nprivate:\n std::shared_ptr uname;\n};\n\n/**\n * @brief The IPv4 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv4 {};\n\n/**\n * @brief The IPv6 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv6 {};\n\n/**\n * @brief Address representation.\n */\nstruct socket_address {\n std::string ip; /*!< Either an IPv4 or an IPv6. */\n unsigned int port; /*!< A valid service identifier. */\n};\n\n/**\n * \\brief CPU information.\n */\nstruct cpu_info {\n using cpu_time = decltype(uv_cpu_info_t::cpu_times);\n\n std::string model; /*!< The model of the CPU. */\n int speed; /*!< The frequency of the CPU. */\n\n /**\n * @brief CPU times.\n *\n * It is built up of the following data members: `user`, `nice`, `sys`,\n * `idle`, `irq`, all of them having type `uint64_t`.\n */\n cpu_time times;\n};\n\n/**\n * \\brief Interface address.\n */\nstruct interface_address {\n std::string name; /*!< The name of the interface (as an example _eth0_). */\n char physical[6]; /*!< The physical address. */\n bool internal; /*!< True if it is an internal interface (as an example _loopback_), false otherwise. */\n socket_address address; /*!< The address of the given interface. */\n socket_address netmask; /*!< The netmask of the given interface. 
*/\n};\n\nnamespace details {\n\nstatic constexpr std::size_t DEFAULT_SIZE = 128;\n\ntemplate\n\nstd::string try_read(F &&f, Args &&...args) noexcept {\n std::size_t size = DEFAULT_SIZE;\n char buf[DEFAULT_SIZE];\n std::string str{};\n auto err = std::forward(f)(args..., buf, &size);\n\n if(UV_ENOBUFS == err) {\n std::unique_ptr data{new char[size]};\n err = std::forward(f)(args..., data.get(), &size);\n\n if(0 == err) {\n str = data.get();\n }\n } else if(0 == err) {\n str.assign(buf, size);\n }\n\n return str;\n}\n\nvoid common_alloc_callback(uv_handle_t *, std::size_t suggested, uv_buf_t *buf);\n\ntemplate\nvoid common_alloc_callback(uv_handle_t *handle, std::size_t suggested, uv_buf_t *buf) {\n auto [alloc, size] = Alloc(*static_cast(handle->data), suggested);\n *buf = uv_buf_init(alloc, static_cast(size));\n}\n\nsockaddr ip_addr(const char *addr, unsigned int port);\nsocket_address sock_addr(const sockaddr_in &addr);\nsocket_address sock_addr(const sockaddr_in6 &addr);\nsocket_address sock_addr(const sockaddr &addr);\nsocket_address sock_addr(const sockaddr_storage &storage);\n\n} // namespace details\n\n/**\n * @brief Miscellaneous utilities.\n *\n * Miscellaneous functions that don\u2019t really belong to any other class.\n */\nstruct utilities {\n using malloc_func_type = void *(*)(size_t);\n using realloc_func_type = void *(*)(void *, size_t);\n using calloc_func_type = void *(*)(size_t, size_t);\n using free_func_type = void (*)(void *);\n\n /**\n * @brief OS dedicated utilities.\n */\n struct os {\n /**\n * @brief Returns the current process id.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_getpid)\n * for further details.\n *\n * @return The current process id.\n */\n static pid_type pid() noexcept;\n\n /**\n * @brief Returns the parent process id.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_getppid)\n * for further details.\n *\n * @return The parent process id.\n */\n static pid_type ppid() noexcept;\n\n /**\n * @brief Gets the current user's home directory.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_homedir)\n * for further details.\n *\n * @return The current user's home directory, an empty string in case of\n * errors.\n */\n static std::string homedir() noexcept;\n\n /**\n * @brief Gets the temp directory.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_tmpdir)\n * for further details.\n *\n * @return The temp directory, an empty string in case of errors.\n */\n static std::string tmpdir() noexcept;\n\n /**\n * @brief Retrieves an environment variable.\n * @param name The name of the variable to be retrieved.\n * @return The value of the environment variable, an empty string in\n * case of errors.\n */\n static std::string env(const std::string &name) noexcept;\n\n /**\n * @brief Creates, updates or deletes an environment variable.\n * @param name The name of the variable to be updated.\n * @param value The value to be used for the variable (an empty string\n * to unset it).\n * @return True in case of success, false otherwise.\n */\n static bool env(const std::string &name, const std::string &value) noexcept;\n\n /**\n * @brief Retrieves all environment variables and iterates them.\n *\n * Environment variables are passed one at a time to the callback in the\n * form of `std::string_view`s.
\n * The signature of the function call operator must be such that it\n * accepts two parameters, the name and the value of the i-th variable.\n *\n * @tparam Func Type of a function object to which to pass environment\n * variables.\n * @param func A function object to which to pass environment variables.\n * @return True in case of success, false otherwise.\n */\n template\n static std::enable_if_t, bool>\n env(Func func) noexcept {\n uv_env_item_t *items = nullptr;\n int count{};\n\n const bool ret = (uv_os_environ(&items, &count) == 0);\n\n if(ret) {\n for(int pos = 0; pos < count; ++pos) {\n func(std::string_view{items[pos].name}, std::string_view{items[pos].value});\n }\n\n uv_os_free_environ(items, count);\n }\n\n return ret;\n }\n\n /**\n * @brief Returns the hostname.\n * @return The hostname, an empty string in case of errors.\n */\n static std::string hostname() noexcept;\n\n /**\n * @brief Gets name and information about the current kernel.\n *\n * This function can be used to get name and information about the\n * current kernel. The populated data includes the operating system\n * name, release, version, and machine.\n *\n * @return Name and information about the current kernel.\n */\n static uts_name uname() noexcept;\n\n /**\n * @brief Gets a subset of the password file entry.\n *\n * This function can be used to get the subset of the password file\n * entry for the current effective uid (not the real uid).\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_get_passwd)\n * for further details.\n *\n * @return The accessible subset of the password file entry.\n */\n static passwd_info passwd() noexcept;\n\n /**\n * @brief Retrieves the scheduling priority of a process.\n *\n * The returned value is between -20 (high priority) and 19 (low priority).\n * A value that is out of range is returned in case of errors.\n *\n * @note\n * On Windows, the result won't equal necessarily the exact value of the\n * priority because of a mapping to a Windows priority class.\n *\n * @param pid A valid process id.\n * @return The scheduling priority of the process.\n */\n static int priority(pid_type pid);\n\n /**\n * @brief Sets the scheduling priority of a process.\n *\n * The returned value range is between -20 (high priority) and 19 (low\n * priority).\n *\n * @note\n * On Windows, the priority is mapped to a Windows priority class. When\n * retrieving the process priority, the result won't equal necessarily the\n * exact value of the priority.\n *\n * @param pid A valid process id.\n * @param prio The scheduling priority to set to the process.\n * @return True in case of success, false otherwise.\n */\n static bool priority(pid_type pid, int prio);\n };\n\n /**\n * @brief Gets the type of the handle given a category.\n * @param category A properly initialized handle category.\n * @return The actual type of the handle as defined by handle_type\n */\n static handle_type guess_handle(handle_category category) noexcept;\n\n /**\n * @brief Gets the type of the stream to be used with the given descriptor.\n *\n * Returns the type of stream that should be used with a given file\n * descriptor.
\n * Usually this will be used during initialization to guess the type of the\n * stdio streams.\n *\n * @param file A valid descriptor.\n * @return One of the following types:\n *\n * * `handle_type::UNKNOWN`\n * * `handle_type::PIPE`\n * * `handle_type::TCP`\n * * `handle_type::TTY`\n * * `handle_type::UDP`\n * * `handle_type::FILE`\n */\n static handle_type guess_handle(file_handle file) noexcept;\n\n /** @brief Gets information about the CPUs on the system.\n *\n * This function can be used to query the underlying system and get a set of\n * descriptors of all the available CPUs.\n *\n * @return A set of descriptors of all the available CPUs.\n */\n static std::vector cpu() noexcept;\n\n /**\n * @brief Gets a set of descriptors of all the available interfaces.\n *\n * This function can be used to query the underlying system and get a set of\n * descriptors of all the available interfaces, either internal or not.\n *\n * @return A set of descriptors of all the available interfaces.\n */\n static std::vector interface_addresses() noexcept;\n\n /**\n * @brief IPv6-capable implementation of\n * [if_indextoname](https://linux.die.net/man/3/if_indextoname).\n *\n * Mapping between network interface names and indexes.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_if_indextoname)\n * for further details.\n *\n * @param index Network interface index.\n * @return Network interface name.\n */\n static std::string index_to_name(unsigned int index) noexcept;\n\n /**\n * @brief Retrieves a network interface identifier.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_if_indextoiid)\n * for further details.\n *\n * @param index Network interface index.\n * @return Network interface identifier.\n */\n static std::string index_to_iid(unsigned int index) noexcept;\n\n /**\n * @brief Override the use of some standard library\u2019s functions.\n *\n * Override the use of the standard library\u2019s memory allocation\n * functions.
\n * This method must be invoked before any other `uvw` function is called or\n * after all resources have been freed and thus the underlying library\n * doesn\u2019t reference any allocated memory chunk.\n *\n * If any of the function pointers is _null_, the invokation will fail.\n *\n * @note\n * There is no protection against changing the allocator multiple times. If\n * the user changes it they are responsible for making sure the allocator is\n * changed while no memory was allocated with the previous allocator, or\n * that they are compatible.\n *\n * @param malloc_func Replacement function for _malloc_.\n * @param realloc_func Replacement function for _realloc_.\n * @param calloc_func Replacement function for _calloc_.\n * @param free_func Replacement function for _free_.\n * @return True in case of success, false otherwise.\n */\n static bool replace_allocator(malloc_func_type malloc_func, realloc_func_type realloc_func, calloc_func_type calloc_func, free_func_type free_func) noexcept;\n\n /**\n * @brief Gets the load average.\n * @return `[0,0,0]` on Windows (not available), the load average otherwise.\n */\n static std::array load_average() noexcept;\n\n /**\n * @brief Store the program arguments.\n *\n * Required for getting / setting the process title.\n *\n * @return Arguments that haven't been consumed internally.\n */\n static char **setup_args(int argc, char **argv);\n\n /**\n * @brief Gets the title of the current process.\n * @return The process title.\n */\n static std::string process_title();\n\n /**\n * @brief Sets the current process title.\n * @param title The process title to be set.\n * @return True in case of success, false otherwise.\n */\n static bool process_title(const std::string &title);\n\n /**\n * @brief Gets memory information (in bytes).\n * @return Memory information.\n */\n static uint64_t total_memory() noexcept;\n\n /**\n * @brief Gets the amount of memory available to the process (in bytes).\n *\n * Gets the amount of memory available to the process based on limits\n * imposed by the OS. If there is no such constraint, or the constraint is\n * unknown, `0` is returned.
\n * Note that it is not unusual for this value to be less than or greater\n * than `totalMemory`.\n *\n * @return Amount of memory available to the process.\n */\n static uint64_t constrained_memory() noexcept;\n\n /**\n * @brief Gets the amount of free memory still available to the process.\n * @return Amount of free memory still available to the process (in bytes).\n */\n static uint64_t available_memory() noexcept;\n\n /**\n * @brief Gets the current system uptime.\n * @return The current system uptime or 0 in case of errors.\n */\n static double uptime() noexcept;\n\n /**\n * @brief Gets the resource usage measures for the current process.\n * @return Resource usage measures, zeroes-filled object in case of errors.\n */\n static resource_usage rusage() noexcept;\n\n /**\n * @brief Gets the current system time from a high-resolution clock source.\n * @param source Clock source, either real-time or monotonic.\n * @return Current system time from the given high-resolution clock source.\n */\n static timespec64 gettime(clock_id source) noexcept;\n\n /**\n * @brief Gets the current high-resolution real time.\n *\n * The time is expressed in nanoseconds. It is relative to an arbitrary time\n * in the past. It is not related to the time of the day and therefore not\n * subject to clock drift. The primary use is for measuring performance\n * between interval.\n *\n * @return The current high-resolution real time.\n */\n static uint64_t hrtime() noexcept;\n\n /**\n * @brief Gets the executable path.\n * @return The executable path, an empty string in case of errors.\n */\n static std::string path() noexcept;\n\n /**\n * @brief Gets the current working directory.\n * @return The current working directory, an empty string in case of errors.\n */\n static std::string cwd() noexcept;\n\n /**\n * @brief Changes the current working directory.\n * @param dir The working directory to be set.\n * @return True in case of success, false otherwise.\n */\n static bool chdir(const std::string &dir) noexcept;\n\n /**\n * @brief Cross-platform implementation of\n * [`gettimeofday`](https://linux.die.net/man/2/gettimeofday)\n * @return The current time.\n */\n static timeval64 time_of_day() noexcept;\n\n /**\n * @brief Causes the calling thread to sleep for a while.\n * @param msec Number of milliseconds to sleep.\n */\n static void sleep(unsigned int msec) noexcept;\n\n /**\n * @brief Returns an estimate of the amount of parallelism a program should\n * use (always a non-zero value).\n * @return Estimate of the amount of parallelism a program should use.\n */\n static unsigned int available_parallelism() noexcept;\n};\n\n/**\n * @brief Helper type for visitors.\n * @tparam Func Types of function objects.\n */\ntemplate\nstruct overloaded: Func... {\n using Func::operator()...;\n};\n\n/**\n * @brief Deduction guide.\n * @tparam Func Types of function objects.\n */\ntemplate\noverloaded(Func...) 
-> overloaded;\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"util.cpp\"\n#endif\n\n#endif // UVW_UTIL_INCLUDE_H\n\n// Path: src/uvw/loop.cpp\n#ifdef UVW_AS_LIB\n# include \"loop.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE loop::loop(std::unique_ptr ptr) noexcept\n : uv_loop{std::move(ptr)} {}\n\nUVW_INLINE std::shared_ptr loop::create() {\n auto ptr = std::unique_ptr{new uv_loop_t, [](uv_loop_t *l) { delete l; }};\n auto curr = std::shared_ptr{new loop{std::move(ptr)}};\n\n if(uv_loop_init(curr->uv_loop.get())) {\n curr = nullptr;\n }\n\n return curr;\n}\n\nUVW_INLINE std::shared_ptr loop::create(uv_loop_t *res) {\n auto ptr = std::unique_ptr{res, [](uv_loop_t *) {}};\n return std::shared_ptr{new loop{std::move(ptr)}};\n}\n\nUVW_INLINE std::shared_ptr loop::get_default() {\n static std::weak_ptr ref;\n std::shared_ptr curr;\n\n if(ref.expired()) {\n auto def = uv_default_loop();\n\n if(def) {\n auto ptr = std::unique_ptr(def, [](uv_loop_t *) {});\n curr = std::shared_ptr{new loop{std::move(ptr)}};\n }\n\n ref = curr;\n } else {\n curr = ref.lock();\n }\n\n return curr;\n}\n\nUVW_INLINE loop::~loop() noexcept {\n if(uv_loop) {\n close();\n }\n}\n\nUVW_INLINE int loop::close() {\n int ret = 0;\n\n if(uv_loop) {\n ret = uv_loop_close(uv_loop.get());\n uv_loop.reset();\n }\n\n return ret;\n}\n\nUVW_INLINE int loop::run(run_mode mode) noexcept {\n return uv_run(uv_loop.get(), static_cast(mode));\n}\n\nUVW_INLINE bool loop::alive() const noexcept {\n return !!uv_loop_alive(uv_loop.get());\n}\n\nUVW_INLINE void loop::stop() noexcept {\n uv_stop(uv_loop.get());\n}\n\nUVW_INLINE int loop::descriptor() const noexcept {\n return uv_backend_fd(uv_loop.get());\n}\n\nUVW_INLINE std::pair loop::timeout() const noexcept {\n auto to = uv_backend_timeout(uv_loop.get());\n return std::make_pair(to == -1, time{to});\n}\n\nUVW_INLINE loop::time loop::idle_time() const noexcept {\n return time{uv_metrics_idle_time(uv_loop.get())};\n}\n\nUVW_INLINE metrics_type loop::metrics() const noexcept {\n metrics_type res{};\n uv_metrics_info(uv_loop.get(), &res);\n return res;\n}\n\nUVW_INLINE loop::time loop::now() const noexcept {\n return time{uv_now(uv_loop.get())};\n}\n\nUVW_INLINE void loop::update() const noexcept {\n return uv_update_time(uv_loop.get());\n}\n\nUVW_INLINE int loop::fork() noexcept {\n return uv_loop_fork(uv_loop.get());\n}\n\nUVW_INLINE void loop::data(std::shared_ptr ud) {\n user_data = std::move(ud);\n}\n\nUVW_INLINE const uv_loop_t *loop::raw() const noexcept {\n return uv_loop.get();\n}\n\nUVW_INLINE uv_loop_t *loop::raw() noexcept {\n return const_cast(const_cast(this)->raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/loop.h\n#ifndef UVW_LOOP_INCLUDE_H\n#define UVW_LOOP_INCLUDE_H\n\n#ifdef _WIN32\n# include \n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nclass async_handle;\nclass check_handle;\nclass fs_event_handle;\nclass fs_poll_handle;\nclass idle_handle;\nclass pipe_handle;\nclass poll_handle;\nclass prepare_handle;\nclass process_handle;\nclass signal_handle;\nclass tcp_handle;\nclass timer_handle;\nclass tty_handle;\nclass udp_handle;\n\nnamespace details {\n\nenum class uvw_loop_option : std::underlying_type_t {\n BLOCK_SIGNAL = UV_LOOP_BLOCK_SIGNAL,\n IDLE_TIME = UV_METRICS_IDLE_TIME\n};\n\nenum class uvw_run_mode : std::underlying_type_t {\n DEFAULT = UV_RUN_DEFAULT,\n ONCE = UV_RUN_ONCE,\n NOWAIT = UV_RUN_NOWAIT\n};\n\n} // 
namespace details\n\nusing metrics_type = uv_metrics_t; /*!< Library equivalent for uv_metrics_t. */\n\n/**\n * @brief The loop class.\n *\n * The event loop is the central part of `uvw`'s functionalities, as well as\n * `libuv`'s ones.
\n * It takes care of polling for I/O and scheduling callbacks to be run based on\n * different sources of events.\n */\nclass loop final: public emitter, public std::enable_shared_from_this {\n using deleter = void (*)(uv_loop_t *);\n\n template\n friend class resource;\n\n class uv_token {\n friend class loop;\n explicit uv_token(int) {}\n };\n\n template\n auto init(int, Type &value) -> decltype(value.init()) {\n return value.init();\n }\n\n template\n int init(char, Type &) {\n return 0;\n }\n\n loop(std::unique_ptr ptr) noexcept;\n\npublic:\n using token = uv_token;\n using time = std::chrono::duration;\n using option = details::uvw_loop_option;\n using run_mode = details::uvw_run_mode;\n\n /**\n * @brief Initializes a new loop instance.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create();\n\n /**\n * @brief Initializes a new loop instance from an existing resource.\n *\n * The lifetime of the resource must exceed that of the instance to which\n * it's associated. Management of the memory associated with the resource is\n * in charge of the user.\n *\n * @param res A valid pointer to a correctly initialized resource.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create(uv_loop_t *res);\n\n /**\n * @brief Gets the initialized default loop.\n *\n * It may return an empty pointer in case of failure.
\n * This function is just a convenient way for having a global loop\n * throughout an application, the default loop is in no way different than\n * the ones initialized with `create()`.
\n * As such, the default loop can be closed with `close()` so the resources\n * associated with it are freed (even if it is not strictly necessary).\n *\n * @return The initialized default loop.\n */\n static std::shared_ptr get_default();\n\n loop(const loop &) = delete;\n loop(loop &&other) = delete;\n\n loop &operator=(const loop &) = delete;\n loop &operator=(loop &&other) = delete;\n\n ~loop() noexcept;\n\n /**\n * @brief Sets additional loop options.\n *\n * You should normally call this before the first call to uv_run() unless\n * mentioned otherwise.
\n * Supported options:\n *\n * * `loop::option::BLOCK_SIGNAL`: Block a signal when polling for new\n * events. A second argument is required and it is the signal number.\n * * `loop::option::IDLE_TIME`: Accumulate the amount of idle time the event\n * loop spends in the event provider. This option is necessary to use\n * `idle_time()`.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_configure)\n * for further details.\n *\n * @return Underlying return value.\n */\n template\n int configure(option flag, Args &&...args) {\n return uv_loop_configure(uv_loop.get(), static_cast(flag), std::forward(args)...);\n }\n\n /**\n * @brief Creates resources of any type.\n *\n * This should be used as a default method to create resources.
\n * The arguments are the ones required for the specific resource.\n *\n * Use it as `loop->resource()`.\n *\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr resource(Args &&...args) {\n auto ptr = uninitialized_resource(std::forward(args)...);\n return (init(0, *ptr) == 0) ? ptr : nullptr;\n }\n\n /**\n * @brief Creates uninitialized resources of any type.\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr uninitialized_resource(Args &&...args) {\n return std::make_shared(token{0}, shared_from_this(), std::forward(args)...);\n }\n\n /**\n * @brief Releases all internal loop resources.\n *\n * Call this function only when the loop has finished executing and all open\n * handles and requests have been closed, or the loop will error.\n *\n * @return Underlying return value.\n */\n int close();\n\n /**\n * @brief Runs the event loop.\n *\n * Available modes are:\n *\n * * `loop::run_mode::DEFAULT`: Runs the event loop until there are no more\n * active and referenced handles or requests.\n * * `loop::run_mode::ONCE`: Poll for i/o once. Note that this function\n * blocks if there are no pending callbacks.\n * * `loop::run_mode::NOWAIT`: Poll for i/o once but don\u2019t block if there\n * are no pending callbacks.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_run)\n * for further details.\n *\n * @return Underlying return value.\n */\n int run(run_mode mode = run_mode::DEFAULT) noexcept;\n\n /**\n * @brief Checks if there are active resources.\n * @return True if there are active resources in the loop.\n */\n bool alive() const noexcept;\n\n /**\n * @brief Stops the event loop.\n *\n * It causes `run()` to end as soon as possible.
\n * This will happen not sooner than the next loop iteration.
\n * If this function was called before blocking for I/O, the loop won\u2019t block\n * for I/O on this iteration.\n */\n void stop() noexcept;\n\n /**\n * @brief Get backend file descriptor.\n *\n * Only kqueue, epoll and event ports are supported.
\n * This can be used in conjunction with `run(loop::run_mode::NOWAIT)` to\n * poll in one thread and run the event loop\u2019s callbacks in another.\n *\n * @return The backend file descriptor.\n */\n int descriptor() const noexcept;\n\n /**\n * @brief Gets the poll timeout.\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of valid timeout, false otherwise.\n * * Milliseconds (`std::chrono::duration`).\n */\n std::pair timeout() const noexcept;\n\n /**\n * @brief Returns the amount of time the event loop has been idle. The call\n * is thread safe.\n * @return The accumulated time spent idle.\n */\n time idle_time() const noexcept;\n\n /**\n * @brief Tracks various internal operations of the event loop.\n * @return Event loop metrics.\n */\n metrics_type metrics() const noexcept;\n\n /**\n * @brief Returns the current timestamp in milliseconds.\n *\n * The timestamp is cached at the start of the event loop tick.
\n * The timestamp increases monotonically from some arbitrary point in\n * time.
\n * Don\u2019t make assumptions about the starting point, you will only get\n * disappointed.\n *\n * @return The current timestamp in milliseconds (actual type is\n * `std::chrono::duration`).\n */\n time now() const noexcept;\n\n /**\n * @brief Updates the event loop\u2019s concept of _now_.\n *\n * The current time is cached at the start of the event loop tick in order\n * to reduce the number of time-related system calls.
\n * You won\u2019t normally need to call this function unless you have callbacks\n * that block the event loop for longer periods of time, where _longer_ is\n * somewhat subjective but probably on the order of a millisecond or more.\n */\n void update() const noexcept;\n\n /**\n * @brief Walks the list of handles.\n *\n * The callback is invoked once for each handle that is still active.\n *\n * @param callback A function to invoke once for each active handle.\n */\n template\n void walk(Func callback) {\n auto func = [](uv_handle_t *hndl, void *callback_func) {\n if(hndl->data) {\n auto &cb = *static_cast(callback_func);\n\n switch(utilities::guess_handle(handle_category{hndl->type})) {\n case handle_type::ASYNC:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::CHECK:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_EVENT:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::IDLE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PIPE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PREPARE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PROCESS:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::SIGNAL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TCP:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TIMER:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TTY:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::UDP:\n cb(*static_cast(hndl->data));\n break;\n default:\n // this handle isn't managed by uvw, let it be...\n break;\n }\n }\n };\n\n uv_walk(uv_loop.get(), func, &callback);\n }\n\n /**\n * @brief Reinitialize any kernel state necessary in the child process after\n * a fork(2) system call.\n *\n * Previously started watchers will continue to be started in the child\n * process.\n *\n * It is necessary to explicitly call this function on every event loop\n * created in the parent process that you plan to continue to use in the\n * child, including the default loop (even if you don\u2019t continue to use it\n * in the parent). This function must be called before calling any API\n * function using the loop in the child. Failure to do so will result in\n * undefined behaviour, possibly including duplicate events delivered to\n * both parent and child or aborting the child process.\n *\n * When possible, it is preferred to create a new loop in the child process\n * instead of reusing a loop created in the parent. New loops created in the\n * child process after the fork should not use this function.\n *\n * Note that this function is not implemented on Windows.
\n * Note also that this function is experimental in `libuv`. It may contain\n * bugs, and is subject to change or removal. API and ABI stability is not\n * guaranteed.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_fork)\n * for further details.\n *\n * @return Underlying return value.\n */\n int fork() noexcept;\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param ud User-defined arbitrary data.\n */\n void data(std::shared_ptr ud);\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const uv_loop_t *raw() const noexcept;\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n uv_loop_t *raw() noexcept;\n\nprivate:\n std::unique_ptr uv_loop;\n std::shared_ptr user_data{nullptr};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"loop.cpp\"\n#endif\n\n#endif // UVW_LOOP_INCLUDE_H\n\n// Path: src/uvw/uv_type.hpp\n#ifndef UVW_UV_TYPE_INCLUDE_H\n#define UVW_UV_TYPE_INCLUDE_H\n\n#include \n#include \n#include \n#include \"config.h\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/**\n * @brief Wrapper class for underlying types.\n *\n * It acts mainly as a wrapper around data structures of the underlying library.\n */\ntemplate\nstruct uv_type {\n explicit uv_type(loop::token, std::shared_ptr ref) noexcept\n : owner{std::move(ref)}, resource{} {}\n\n uv_type(const uv_type &) = delete;\n uv_type(uv_type &&) = delete;\n\n uv_type &operator=(const uv_type &) = delete;\n uv_type &operator=(uv_type &&) = delete;\n\n /**\n * @brief Gets the loop from which the resource was originated.\n * @return A reference to a loop instance.\n */\n loop &parent() const noexcept {\n return *owner;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const U *raw() const noexcept {\n return &resource;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n U *raw() noexcept {\n return &resource;\n }\n\nprotected:\n ~uv_type() = default;\n\nprivate:\n std::shared_ptr owner;\n U resource;\n};\n\n} // namespace uvw\n\n#endif // UVW_UV_TYPE_INCLUDE_H\n\n// Path: src/uvw/resource.hpp\n#ifndef UVW_RESOURCE_INCLUDE_H\n#define UVW_RESOURCE_INCLUDE_H\n\n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Common class for almost all the resources available in `uvw`.\n *\n * This is the base class for handles and requests.\n */\ntemplate\nclass resource: public uv_type, public emitter, public std::enable_shared_from_this {\nprotected:\n int leak_if(int err) noexcept {\n if(err == 0) {\n self_ptr = this->shared_from_this();\n }\n\n return err;\n }\n\n void self_reset() noexcept {\n self_ptr.reset();\n }\n\n bool has_self() const noexcept {\n return static_cast(self_ptr);\n }\n\npublic:\n explicit resource(loop::token token, std::shared_ptr ref)\n : uv_type{token, std::move(ref)} {\n this->raw()->data = this;\n }\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param udata User-defined arbitrary data.\n */\n void data(std::shared_ptr udata) {\n user_data = std::move(udata);\n }\n\nprivate:\n std::shared_ptr user_data{nullptr};\n std::shared_ptr self_ptr{nullptr};\n};\n\n} // namespace uvw\n\n#endif // UVW_RESOURCE_INCLUDE_H\n\n// Path: src/uvw/handle.hpp\n#ifndef UVW_HANDLE_INCLUDE_H\n#define UVW_HANDLE_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\n/*! @brief Close event. */\nstruct close_event {};\n\n/**\n * @brief Handle base class.\n *\n * Base type for all `uvw` handle types.\n */\ntemplate\nclass handle: public resource {\nprotected:\n static void close_callback(uv_handle_t *hndl) {\n handle &ref = *(static_cast(hndl->data));\n [[maybe_unused]] auto ptr = ref.shared_from_this();\n ref.self_reset();\n ref.publish(close_event{});\n }\n\n uv_handle_t *as_uv_handle() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_handle_t *as_uv_handle() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Gets the category of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. By means of this function, an opaque value that identifies the\n * category of the handle is made available to the users.\n *\n * @return The actual category of the handle.\n */\n handle_category category() const noexcept {\n return handle_category{as_uv_handle()->type};\n }\n\n /**\n * @brief Gets the type of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. 
By means of this function, the type of the underlying handle as\n * specified by handle_type is made available to the users.\n *\n * @return The actual type of the handle.\n */\n handle_type type() const noexcept {\n return utilities::guess_handle(category());\n }\n\n /**\n * @brief Checks if the handle is active.\n *\n * What _active_ means depends on the type of handle:\n *\n * * An async_handle handle is always active and cannot be deactivated,\n * except by closing it with uv_close().\n * * A pipe, tcp, udp, etc. handle - basically any handle that deals with\n * I/O - is active when it is doing something that involves I/O, like\n * reading, writing, connecting, accepting new connections, etc.\n * * A check, idle, timer, etc. handle is active when it has been started\n * with a call to `start()`.\n *\n * Rule of thumb: if a handle of type `foo_handle` has a `start()` member\n * method, then it\u2019s active from the moment that method is called. Likewise,\n * `stop()` deactivates the handle again.\n *\n * @return True if the handle is active, false otherwise.\n */\n bool active() const noexcept {\n return !!uv_is_active(as_uv_handle());\n }\n\n /**\n * @brief Checks if a handle is closing or closed.\n *\n * This function should only be used between the initialization of the\n * handle and the arrival of the close callback.\n *\n * @return True if the handle is closing or closed, false otherwise.\n */\n bool closing() const noexcept {\n return !!uv_is_closing(as_uv_handle());\n }\n\n /**\n * @brief Request handle to be closed.\n *\n * This **must** be called on each handle before memory is released.
\n * In-progress requests are cancelled and this can result in errors.\n *\n * The handle will emit a close event when finished.\n */\n void close() noexcept {\n if(!closing()) {\n uv_close(as_uv_handle(), &handle::close_callback);\n }\n }\n\n /**\n * @brief Reference the given handle.\n *\n * References are idempotent, that is, if a handle is already referenced\n * calling this function again will have no effect.\n */\n void reference() noexcept {\n uv_ref(as_uv_handle());\n }\n\n /**\n * @brief Unreference the given handle.\n *\n * References are idempotent, that is, if a handle is not referenced calling\n * this function again will have no effect.\n */\n void unreference() noexcept {\n uv_unref(as_uv_handle());\n }\n\n /**\n * @brief Checks if the given handle referenced.\n * @return True if the handle referenced, false otherwise.\n */\n bool referenced() const noexcept {\n return !!uv_has_ref(as_uv_handle());\n }\n\n /**\n * @brief Returns the size of the underlying handle type.\n * @return The size of the underlying handle type.\n */\n std::size_t size() const noexcept {\n return uv_handle_size(as_uv_handle()->type);\n }\n\n /**\n * @brief Gets the size of the send buffer used for the socket.\n *\n * Gets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipeand udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the send buffer, the underlying return value in case\n * of errors.\n */\n int send_buffer_size() {\n int value = 0;\n auto err = uv_send_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the send buffer used for the socket.\n *\n * Sets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int send_buffer_size(int value) {\n return uv_send_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the size of the receive buffer used for the socket.\n *\n * Gets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the receive buffer, the underlying return value in\n * case of errors.\n */\n int recv_buffer_size() {\n int value = 0;\n auto err = uv_recv_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the receive buffer used for the socket.\n *\n * Sets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int recv_buffer_size(int value) {\n return uv_recv_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the platform dependent file descriptor equivalent.\n *\n * Supported handles:\n *\n * * tcp_handle\n * * pipe_handle\n * * tty_handle\n * * udp_handle\n * * poll_handle\n *\n * If invoked on a different handle, one that doesn\u2019t have an attached file\n * descriptor yet or one which was closed, an invalid value is returned.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/handle.html#c.uv_fileno)\n * for further details.\n *\n * @return The file descriptor attached to the hande or a negative value in\n * case of errors.\n */\n os_file_descriptor fd() const {\n uv_os_fd_t fd;\n uv_fileno(as_uv_handle(), &fd);\n return fd;\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_HANDLE_INCLUDE_H\n\n// Path: src/uvw/prepare.h\n#ifndef UVW_PREPARE_INCLUDE_H\n#define UVW_PREPARE_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Prepare event. */\nstruct prepare_event {};\n\n/**\n * @brief The prepare handle.\n *\n * Prepare handles will emit a prepare event once per loop iteration, right\n * before polling for I/O.\n *\n * To create a `prepare_handle` through a `loop`, no arguments are required.\n */\nclass prepare_handle final: public handle {\n static void start_callback(uv_prepare_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A prepare event will be emitted once per loop iteration, right before\n * polling for I/O.\n *\n * The handle will start emitting prepare events when needed.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"prepare.cpp\"\n#endif\n\n#endif // UVW_PREPARE_INCLUDE_H\n\n// Path: src/uvw/prepare.cpp\n#ifdef UVW_AS_LIB\n# include \"prepare.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void prepare_handle::start_callback(uv_prepare_t *hndl) {\n prepare_handle &prepare = *(static_cast(hndl->data));\n prepare.publish(prepare_event{});\n}\n\nUVW_INLINE int prepare_handle::init() {\n return leak_if(uv_prepare_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int prepare_handle::start() {\n return uv_prepare_start(raw(), &start_callback);\n}\n\nUVW_INLINE int prepare_handle::stop() {\n return uv_prepare_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/request.hpp\n#ifndef UVW_REQUEST_INCLUDE_H\n#define UVW_REQUEST_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Request base class.\n *\n * Base type for all `uvw` request types.\n */\ntemplate\nclass request: public resource {\nprotected:\n static auto reserve(U *req) {\n auto ptr = static_cast(req->data)->shared_from_this();\n ptr->self_reset();\n return ptr;\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Cancels a pending request.\n *\n * This method fails if the request is executing or has finished\n * executing.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/request.html#c.uv_cancel)\n * for further details.\n *\n * @return Underlying return value.\n */\n int cancel() {\n return 
uv_cancel(reinterpret_cast(this->raw()));\n }\n\n /**\n * @brief Returns the size of the underlying request type.\n * @return The size of the underlying request type.\n */\n std::size_t size() const noexcept {\n return uv_req_size(reinterpret_cast(this->raw())->type);\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_REQUEST_INCLUDE_H\n\n// Path: src/uvw/stream.h\n#ifndef UVW_STREAM_INCLUDE_H\n#define UVW_STREAM_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"handle.hpp\"\n#include \"loop.h\"\n#include \"request.hpp\"\n\nnamespace uvw {\n\n/*! @brief Connect event. */\nstruct connect_event {};\n\n/*! @brief End event. */\nstruct end_event {};\n\n/*! @brief Listen event. */\nstruct listen_event {};\n\n/*! @brief Shutdown event. */\nstruct shutdown_event {};\n\n/*! @brief Write event. */\nstruct write_event {};\n\n/*! @brief Data event. */\nstruct data_event {\n explicit data_event(std::unique_ptr buf, std::size_t len) noexcept;\n\n std::unique_ptr data; /*!< A bunch of data read on the stream. */\n std::size_t length; /*!< The amount of data read on the stream. */\n};\n\nnamespace details {\n\nclass connect_req final: public request {\n static void connect_callback(uv_connect_t *req, int status);\n\npublic:\n using request::request;\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n std::forward(f)(raw(), std::forward(args)..., &connect_callback);\n return this->leak_if(0);\n }\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n return this->leak_if(std::forward(f)(raw(), std::forward(args)..., &connect_callback));\n }\n};\n\nclass shutdown_req final: public request {\n static void shoutdown_callback(uv_shutdown_t *req, int status);\n\npublic:\n using request::request;\n\n int shutdown(uv_stream_t *hndl);\n};\n\ntemplate\nclass write_req final: public request, uv_write_t, write_event> {\n...\n", "needle_token_start": 692, "needle_token_end": 871, "code_context_ntokens": 16064, "output": ["```cpp\ntemplate\nstd::string try_read(F &&f, Args &&...args) noexcept {\n std::size_t size = DEFAULT_SIZE;\n char buf[DEFAULT_SIZE];\n std::string str{};\n auto err = std::forward(f)(args..., buf, &size);\n\n if(UV_ENOBUFS == err) {\n std::unique_ptr data{new char[size]};\n err = std::forward(f)(args..., data.get(), &size);\n\n if(0 == err) {\n str = data.get();\n }\n } else if(0 == err) {\n str.assign(buf, size);\n }\n\n return str;\n}\n```"]} +{"repo": "skypjack/uvw", "name": "set", "language": "cpp", "path": "src/uvw/thread.h", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to store a variable in thread-local storage, ensuring that the variable is accessible only to the thread that set it.\n2. **Input**: The function takes a pointer to the variable that needs to be stored in the thread-local storage.\n3. **Output**: There is no return value; the function operates by side effect, storing the provided variable in a thread-specific location.\n4. 
**Procedure**: The function casts the provided variable pointer to a generic type and uses an underlying library function to associate this pointer with a thread-specific key, effectively storing the variable in the thread's local storage.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/check.cpp\n#ifdef UVW_AS_LIB\n# include \"check.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void check_handle::start_callback(uv_check_t *hndl) {\n check_handle &check = *(static_cast(hndl->data));\n check.publish(check_event{});\n}\n...\n// Path: src/uvw/enum.hpp\n#ifndef UVW_ENUM_INCLUDE_HPP\n#define UVW_ENUM_INCLUDE_HPP\n\n#include \n#include \"config.h\"\n\n/**\n * @brief Operator available for enums for which bitmask support is enabled.\n * @tparam Type Enum class type.\n * @param lhs The first value to use.\n * @param rhs The second value to use.\n * @return The result of invoking the operator on the underlying types of the\n * two values provided.\n */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator|(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) | static_cast>(rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator&(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) & static_cast>(rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator^(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) ^ static_cast>(rhs));\n}\n\n/**\n * @brief Operator available for enums for which bitmask support is enabled.\n * @tparam Type Enum class type.\n * @param value The value to use.\n * @return The result of invoking the operator on the underlying types of the\n * value provided.\n */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator~(const Type value) noexcept {\n return static_cast(~static_cast>(value));\n}\n\n/*! @copydoc operator~ */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM, bool{})>\noperator!(const Type value) noexcept {\n return !static_cast>(value);\n}\n\n/*! @copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator|=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs | rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator&=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs & rhs));\n}\n\n/*! 
@copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator^=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs ^ rhs));\n}\n\n#endif\n\n// Path: src/uvw/thread.h\n#ifndef UVW_THREAD_INCLUDE_H\n#define UVW_THREAD_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"loop.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_thread_create_flags : std::underlying_type_t {\n THREAD_NO_FLAGS = UV_THREAD_NO_FLAGS,\n THREAD_HAS_STACK_SIZE = UV_THREAD_HAS_STACK_SIZE\n};\n\n}\n\nclass thread;\nclass thread_local_storage;\nclass once;\nclass mutex;\nclass rwlock;\nclass semaphore;\nclass condition;\nclass barrier;\n\n/**\n * @brief The thread wrapper.\n *\n * To create a `thread` through a `loop`, arguments follow:\n *\n * * A callback invoked to initialize thread execution. The type must be such\n * that it can be assigned to an `std::function)>`.\n * * An optional payload the type of which is `std::shared_ptr`.\n */\nclass thread final: public uv_type {\n using internal_task = std::function)>;\n\n static void create_callback(void *arg);\n\npublic:\n using create_flags = details::uvw_thread_create_flags;\n using task = internal_task;\n using type = uv_thread_t;\n\n explicit thread(loop::token token, std::shared_ptr ref, task t, std::shared_ptr d = nullptr) noexcept;\n\n /**\n * @brief Obtains the identifier of the calling thread.\n * @return The identifier of the calling thread.\n */\n static type self() noexcept;\n\n /**\n * @brief Gets the CPU number on which the calling thread is running.\n * @return The CPU number on which the calling thread is running.\n */\n static int getcpu() noexcept;\n\n /**\n * @brief Compares thread by means of their identifiers.\n * @param tl A valid instance of a thread.\n * @param tr A valid instance of a thread.\n * @return True if the two threads are the same thread, false otherwise.\n */\n static bool equal(const thread &tl, const thread &tr) noexcept;\n\n ~thread() noexcept;\n\n /**\n * @brief Creates a new thread.\n * @return True in case of success, false otherwise.\n */\n bool run() noexcept;\n\n /**\n * @brief Creates a new thread.\n *\n * Available flags are:\n *\n * * `thread::create_flags::THREAD_NO_FLAGS`: no flags set.\n * * `thread::create_flags::THREAD_HAS_STACK_SIZE`: if set, `stack` specifies a\n * stack size for the new thread. 0 indicates that the default value should\n * be used (it behaves as if the flag was not set). Other values will be\n * rounded up to the nearest page boundary.\n *\n * @return True in case of success, false otherwise.\n */\n bool run(create_flags opts, std::size_t stack = {}) noexcept;\n\n /**\n * @brief Joins with a terminated thread.\n * @return True in case of success, false otherwise.\n */\n bool join() noexcept;\n\nprivate:\n std::shared_ptr data;\n task func;\n};\n\n/**\n * @brief The thread local storage wrapper.\n *\n * A storage area that can only be accessed by one thread. 
The variable can be\n * seen as a global variable that is only visible to a particular thread and not\n * the whole program.\n */\nclass thread_local_storage final: public uv_type {\npublic:\n explicit thread_local_storage(loop::token token, std::shared_ptr ref) noexcept;\n\n ~thread_local_storage() noexcept;\n\n /**\n * @brief Gets the value of a given variable.\n * @tparam T Type to which to cast the opaque storage area.\n * @return A pointer to the given variable.\n */\n template\n T *get() noexcept {\n return static_cast(uv_key_get(uv_type::raw()));\n }\n\n /**\n * @brief Sets the value of a given variable.\n * @tparam T Type of the variable to store aside.\n * @param value A valid pointer to the variable to store\n */\n template\n \nvoid set(T *value) noexcept {\n return uv_key_set(uv_type::raw(), value);\n }\n};\n\n/**\n * @brief The once wrapper.\n *\n * Runs a function once and only once. Concurrent calls to `once` will block all\n * callers except one (it\u2019s unspecified which one).\n */\nclass once final: public uv_type {\n static uv_once_t *guard() noexcept;\n\npublic:\n using uv_type::uv_type;\n\n /**\n * @brief Runs a function once and only once.\n *\n * The callback must be such that it's convertible to `void(*)(void)`. Free\n * functions and non-capturing lambdas are both viable solutions.\n *\n * @tparam F Type of the callback.\n * @param f A valid callback function.\n */\n template\n static void run(F &&f) noexcept {\n using callback_type = void (*)(void);\n static_assert(std::is_convertible_v);\n callback_type cb = f;\n uv_once(guard(), cb);\n }\n};\n\n/**\n * @brief The mutex wrapper.\n *\n * To create a `mutex` through a `loop`, arguments follow:\n *\n * * An option boolean that specifies if the mutex is a recursive one. The\n * default value is false, the mutex isn't recursive.\n */\nclass mutex final: public uv_type {\n friend class condition;\n\npublic:\n explicit mutex(loop::token token, std::shared_ptr ref, bool recursive = false) noexcept;\n\n ~mutex() noexcept;\n\n /**\n * @brief Locks the mutex.\n */\n void lock() noexcept;\n\n /**\n * @brief Tries to lock the mutex.\n * @return True in case of success, false otherwise.\n */\n bool try_lock() noexcept;\n\n /**\n * @brief Unlocks the mutex.\n */\n void unlock() noexcept;\n};\n\n/**\n * @brief The rwlock wrapper.\n */\nclass rwlock final: public uv_type {\npublic:\n explicit rwlock(loop::token token, std::shared_ptr ref) noexcept;\n\n ~rwlock() noexcept;\n\n /**\n * @brief Locks a read-write lock object for reading.\n */\n void rdlock() noexcept;\n\n /**\n * @brief Tries to lock a read-write lock object for reading.\n * @return True in case of success, false otherwise.\n */\n bool try_rdlock() noexcept;\n\n /**\n * @brief Unlocks a read-write lock object previously locked for reading.\n */\n void rdunlock() noexcept;\n\n /**\n * @brief Locks a read-write lock object for writing.\n */\n void wrlock() noexcept;\n\n /**\n * @brief Tries to lock a read-write lock object for writing.\n * @return True in case of success, false otherwise.\n */\n bool try_wrlock() noexcept;\n\n /**\n * @brief Unlocks a read-write lock object previously locked for writing.\n */\n void wrunlock() noexcept;\n};\n\n/**\n * @brief The semaphore wrapper.\n *\n * To create a `semaphore` through a `loop`, arguments follow:\n *\n * * An unsigned integer that specifies the initial value for the semaphore.\n */\nclass semaphore final: public uv_type {\npublic:\n explicit semaphore(loop::token token, std::shared_ptr ref, unsigned int value) 
noexcept;\n\n ~semaphore() noexcept;\n\n /**\n * @brief Unlocks a semaphore.\n */\n void post() noexcept;\n\n /**\n * @brief Locks a semaphore.\n */\n void wait() noexcept;\n\n /**\n * @brief Tries to lock a semaphore.\n * @return True in case of success, false otherwise.\n */\n bool try_wait() noexcept;\n};\n\n/**\n * @brief The condition wrapper.\n */\nclass condition final: public uv_type {\npublic:\n explicit condition(loop::token token, std::shared_ptr ref) noexcept;\n\n ~condition() noexcept;\n\n /**\n * @brief Signals a condition.\n *\n * This function shall unblock at least one of the threads that are blocked\n * on the specified condition variable (if any threads are blocked on it).\n */\n void signal() noexcept;\n\n /**\n * @brief Broadcasts a condition.\n *\n * This function shall unblock threads blocked on a condition variable.\n */\n void broadcast() noexcept;\n\n /**\n * @brief Waits on a condition.\n *\n * These function atomically releases the mutex and causes the calling\n * thread to block on the condition variable.\n *\n * @param mtx A mutex locked by the calling thread, otherwise expect\n * undefined behavior.\n */\n void wait(mutex &mtx) noexcept;\n\n /**\n * @brief Waits on a condition.\n *\n * These function atomically releases the mutex and causes the calling\n * thread to block on the condition variable.
\n * The functions returns with an error if the absolute time specified passes\n * (that is, system time equals or exceeds it) before the condition is\n * signaled or broadcasted, or if the absolute time specified has already\n * been passed at the time of the call.\n *\n * @param mtx A mutex locked by the calling thread, otherwise expect\n * undefined behavior.\n * @param timeout The maximum time to wait before to return.\n * @return True in case of success, false otherwise.\n */\n bool timed_wait(mutex &mtx, uint64_t timeout) noexcept;\n};\n\n/**\n * @brief The barrier wrapper.\n *\n * To create a `barrier` through a `loop`, arguments follow:\n *\n * * An unsigned integer that specifies the number of threads that must call\n * `wait` before any of them successfully return from the call. The value\n * specified must be greater than zero.\n */\nclass barrier final: public uv_type {\npublic:\n explicit barrier(loop::token token, std::shared_ptr ref, unsigned int count) noexcept;\n\n ~barrier() noexcept;\n\n /**\n * @brief Synchronizes at a barrier.\n * @return True in case of success, false otherwise.\n */\n bool wait() noexcept;\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"thread.cpp\"\n#endif\n\n#endif // UVW_THREAD_INCLUDE_H\n\n// Path: src/uvw/thread.cpp\n#ifdef UVW_AS_LIB\n# include \"thread.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE thread::thread(loop::token token, std::shared_ptr ref, task t, std::shared_ptr d) noexcept\n : uv_type{token, std::move(ref)},\n data{std::move(d)},\n func{std::move(t)} {}\n\nUVW_INLINE void thread::create_callback(void *arg) {\n thread &curr = *(static_cast(arg));\n curr.func(curr.data);\n}\n\nUVW_INLINE thread::type thread::self() noexcept {\n return uv_thread_self();\n}\n\nUVW_INLINE int thread::getcpu() noexcept {\n return uv_thread_getcpu();\n}\n\nUVW_INLINE bool thread::equal(const thread &tl, const thread &tr) noexcept {\n return !(0 == uv_thread_equal(tl.raw(), tr.raw()));\n}\n\nUVW_INLINE thread::~thread() noexcept {\n join();\n}\n\nUVW_INLINE bool thread::run() noexcept {\n return (0 == uv_thread_create(raw(), &create_callback, this));\n}\n\nUVW_INLINE bool thread::run(create_flags opts, std::size_t stack) noexcept {\n uv_thread_options_t params{static_cast(opts), stack};\n return (0 == uv_thread_create_ex(raw(), ¶ms, &create_callback, this));\n}\n\nUVW_INLINE bool thread::join() noexcept {\n return (0 == uv_thread_join(raw()));\n}\n\nUVW_INLINE thread_local_storage::thread_local_storage(loop::token token, std::shared_ptr ref) noexcept\n : uv_type{token, std::move(ref)} {\n uv_key_create(uv_type::raw());\n}\n\nUVW_INLINE thread_local_storage::~thread_local_storage() noexcept {\n uv_key_delete(uv_type::raw());\n}\n\nUVW_INLINE uv_once_t *once::guard() noexcept {\n static uv_once_t once = UV_ONCE_INIT;\n return &once;\n}\n\nUVW_INLINE mutex::mutex(loop::token token, std::shared_ptr ref, bool recursive) noexcept\n : uv_type{token, std::move(ref)} {\n if(recursive) {\n uv_mutex_init_recursive(raw());\n } else {\n uv_mutex_init(raw());\n }\n}\n\nUVW_INLINE mutex::~mutex() noexcept {\n uv_mutex_destroy(raw());\n}\n\nUVW_INLINE void mutex::lock() noexcept {\n uv_mutex_lock(raw());\n}\n\nUVW_INLINE bool mutex::try_lock() noexcept {\n return (0 == uv_mutex_trylock(raw()));\n}\n\nUVW_INLINE void mutex::unlock() noexcept {\n uv_mutex_unlock(raw());\n}\n\nUVW_INLINE rwlock::rwlock(loop::token token, std::shared_ptr ref) noexcept\n : uv_type{token, std::move(ref)} {\n uv_rwlock_init(raw());\n}\n\nUVW_INLINE 
rwlock::~rwlock() noexcept {\n uv_rwlock_destroy(raw());\n}\n\nUVW_INLINE void rwlock::rdlock() noexcept {\n uv_rwlock_rdlock(raw());\n}\n\nUVW_INLINE bool rwlock::try_rdlock() noexcept {\n return (0 == uv_rwlock_tryrdlock(raw()));\n}\n\nUVW_INLINE void rwlock::rdunlock() noexcept {\n uv_rwlock_rdunlock(raw());\n}\n\nUVW_INLINE void rwlock::wrlock() noexcept {\n uv_rwlock_wrlock(raw());\n}\n\nUVW_INLINE bool rwlock::try_wrlock() noexcept {\n return (0 == uv_rwlock_trywrlock(raw()));\n}\n\nUVW_INLINE void rwlock::wrunlock() noexcept {\n uv_rwlock_wrunlock(raw());\n}\n\nUVW_INLINE semaphore::semaphore(loop::token token, std::shared_ptr ref, unsigned int value) noexcept\n : uv_type{token, std::move(ref)} {\n uv_sem_init(raw(), value);\n}\n\nUVW_INLINE semaphore::~semaphore() noexcept {\n uv_sem_destroy(raw());\n}\n\nUVW_INLINE void semaphore::post() noexcept {\n uv_sem_post(raw());\n}\n\nUVW_INLINE void semaphore::wait() noexcept {\n uv_sem_wait(raw());\n}\n\nUVW_INLINE bool semaphore::try_wait() noexcept {\n return (0 == uv_sem_trywait(raw()));\n}\n\nUVW_INLINE condition::condition(loop::token token, std::shared_ptr ref) noexcept\n : uv_type{token, std::move(ref)} {\n uv_cond_init(raw());\n}\n\nUVW_INLINE condition::~condition() noexcept {\n uv_cond_destroy(raw());\n}\n\nUVW_INLINE void condition::signal() noexcept {\n uv_cond_signal(raw());\n}\n\nUVW_INLINE void condition::broadcast() noexcept {\n uv_cond_broadcast(raw());\n}\n\nUVW_INLINE void condition::wait(mutex &mtx) noexcept {\n uv_cond_wait(raw(), mtx.raw());\n}\n\nUVW_INLINE bool condition::timed_wait(mutex &mtx, uint64_t timeout) noexcept {\n return (0 == uv_cond_timedwait(raw(), mtx.raw(), timeout));\n}\n\nUVW_INLINE barrier::barrier(loop::token token, std::shared_ptr ref, unsigned int count) noexcept\n : uv_type{token, std::move(ref)} {\n uv_barrier_init(raw(), count);\n}\n\nUVW_INLINE barrier::~barrier() noexcept {\n uv_barrier_destroy(raw());\n}\n\nUVW_INLINE bool barrier::wait() noexcept {\n return (0 == uv_barrier_wait(raw()));\n}\n\n} // namespace uvw\n\n// Path: src/uvw/async.h\n#ifndef UVW_ASYNC_INCLUDE_H\n#define UVW_ASYNC_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Async event. */\nstruct async_event {};\n\n/**\n * @brief The async handle.\n *\n * Async handles allow the user to _wakeup_ the event loop and get an event\n * emitted from another thread.\n *\n * To create an `async_handle` through a `loop`, no arguments are required.\n */\nclass async_handle final: public handle {\n static void send_callback(uv_async_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n *\n * Unlike other handle initialization functions, it immediately starts the\n * handle.\n *\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Wakeups the event loop and emits the async event.\n *\n * It\u2019s safe to call this function from any thread.
\n * An async event is emitted on the loop thread.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/async.html#c.uv_async_send)\n * for further details.\n *\n * @return Underlying return value.\n */\n int send();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"async.cpp\"\n#endif\n\n#endif // UVW_ASYNC_INCLUDE_H\n\n// Path: src/uvw/async.cpp\n#ifdef UVW_AS_LIB\n# include \"async.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void async_handle::send_callback(uv_async_t *hndl) {\n async_handle &async = *(static_cast(hndl->data));\n async.publish(async_event{});\n}\n\nUVW_INLINE int async_handle::init() {\n return leak_if(uv_async_init(parent().raw(), raw(), &send_callback));\n}\n\nUVW_INLINE int async_handle::send() {\n return uv_async_send(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/fs.h\n#ifndef UVW_FS_INCLUDE_H\n#define UVW_FS_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"loop.h\"\n#include \"request.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_fs_type : std::underlying_type_t {\n UNKNOWN = UV_FS_UNKNOWN,\n CUSTOM = UV_FS_CUSTOM,\n OPEN = UV_FS_OPEN,\n CLOSE = UV_FS_CLOSE,\n READ = UV_FS_READ,\n WRITE = UV_FS_WRITE,\n SENDFILE = UV_FS_SENDFILE,\n STAT = UV_FS_STAT,\n LSTAT = UV_FS_LSTAT,\n FSTAT = UV_FS_FSTAT,\n FTRUNCATE = UV_FS_FTRUNCATE,\n UTIME = UV_FS_UTIME,\n FUTIME = UV_FS_FUTIME,\n ACCESS = UV_FS_ACCESS,\n CHMOD = UV_FS_CHMOD,\n FCHMOD = UV_FS_FCHMOD,\n FSYNC = UV_FS_FSYNC,\n FDATASYNC = UV_FS_FDATASYNC,\n UNLINK = UV_FS_UNLINK,\n RMDIR = UV_FS_RMDIR,\n MKDIR = UV_FS_MKDIR,\n MKDTEMP = UV_FS_MKDTEMP,\n RENAME = UV_FS_RENAME,\n SCANDIR = UV_FS_SCANDIR,\n LINK = UV_FS_LINK,\n SYMLINK = UV_FS_SYMLINK,\n READLINK = UV_FS_READLINK,\n CHOWN = UV_FS_CHOWN,\n FCHOWN = UV_FS_FCHOWN,\n REALPATH = UV_FS_REALPATH,\n COPYFILE = UV_FS_COPYFILE,\n LCHOWN = UV_FS_LCHOWN,\n OPENDIR = UV_FS_OPENDIR,\n READDIR = UV_FS_READDIR,\n CLOSEDIR = UV_FS_CLOSEDIR,\n STATFS = UV_FS_STATFS,\n MKSTEMP = UV_FS_MKSTEMP,\n LUTIME = UV_FS_LUTIME\n};\n\nenum class uvw_dirent_type_t : std::underlying_type_t {\n UNKNOWN = UV_DIRENT_UNKNOWN,\n FILE = UV_DIRENT_FILE,\n DIR = UV_DIRENT_DIR,\n LINK = UV_DIRENT_LINK,\n FIFO = UV_DIRENT_FIFO,\n SOCKET = UV_DIRENT_SOCKET,\n CHAR = UV_DIRENT_CHAR,\n BLOCK = UV_DIRENT_BLOCK\n};\n\nenum class uvw_file_open_flags : int {\n APPEND = UV_FS_O_APPEND,\n CREAT = UV_FS_O_CREAT,\n DIRECT = UV_FS_O_DIRECT,\n DIRECTORY = UV_FS_O_DIRECTORY,\n DSYNC = UV_FS_O_DSYNC,\n EXCL = UV_FS_O_EXCL,\n EXLOCK = UV_FS_O_EXLOCK,\n FILEMAP = UV_FS_O_FILEMAP,\n NOATIME = UV_FS_O_NOATIME,\n NOCTTY = UV_FS_O_NOCTTY,\n NOFOLLOW = UV_FS_O_NOFOLLOW,\n NONBLOCK = UV_FS_O_NONBLOCK,\n RANDOM = UV_FS_O_RANDOM,\n RDONLY = UV_FS_O_RDONLY,\n RDWR = UV_FS_O_RDWR,\n SEQUENTIAL = UV_FS_O_SEQUENTIAL,\n SHORT_LIVED = UV_FS_O_SHORT_LIVED,\n SYMLINK = UV_FS_O_SYMLINK,\n SYNC = UV_FS_O_SYNC,\n TEMPORARY = UV_FS_O_TEMPORARY,\n TRUNC = UV_FS_O_TRUNC,\n WRONLY = UV_FS_O_WRONLY,\n _UVW_ENUM = 0\n};\n\nenum class uvw_copy_file_flags : int {\n EXCL = UV_FS_COPYFILE_EXCL,\n FICLONE = UV_FS_COPYFILE_FICLONE,\n FICLONE_FORCE = UV_FS_COPYFILE_FICLONE_FORCE,\n _UVW_ENUM = 0\n};\n\nenum class uvw_symlink_flags : int {\n DIR = UV_FS_SYMLINK_DIR,\n JUNCTION = UV_FS_SYMLINK_JUNCTION,\n _UVW_ENUM = 0\n};\n\n} // namespace details\n\n/**\n * @brief Common fs event.\n *\n * Available types are:\n *\n * * `fs_request::fs_type::UNKNOWN`\n * * 
`fs_request::fs_type::CUSTOM`\n * * `fs_request::fs_type::OPEN`\n * * `fs_request::fs_type::CLOSE`\n * * `fs_request::fs_type::READ`\n * * `fs_request::fs_type::WRITE`\n * * `fs_request::fs_type::SENDFILE`\n * * `fs_request::fs_type::STAT`\n * * `fs_request::fs_type::LSTAT`\n * * `fs_request::fs_type::FSTAT`\n * * `fs_request::fs_type::FTRUNCATE`\n * * `fs_request::fs_type::UTIME`\n * * `fs_request::fs_type::FUTIME`\n * * `fs_request::fs_type::ACCESS`\n * * `fs_request::fs_type::CHMOD`\n * * `fs_request::fs_type::FCHMOD`\n * * `fs_request::fs_type::FSYNC`\n * * `fs_request::fs_type::FDATASYNC`\n * * `fs_request::fs_type::UNLINK`\n * * `fs_request::fs_type::RMDIR`\n * * `fs_request::fs_type::MKDIR`\n * * `fs_request::fs_type::MKDTEMP`\n * * `fs_request::fs_type::RENAME`\n * * `fs_request::fs_type::SCANDIR`\n * * `fs_request::fs_type::LINK`\n * * `fs_request::fs_type::SYMLINK`\n * * `fs_request::fs_type::READLINK`\n * * `fs_request::fs_type::CHOWN`\n * * `fs_request::fs_type::FCHOWN`\n * * `fs_request::fs_type::REALPATH`\n * * `fs_request::fs_type::COPYFILE`\n * * `fs_request::fs_type::LCHOWN`\n * * `fs_request::fs_type::OPENDIR`\n * * `fs_request::fs_type::READDIR`\n * * `fs_request::fs_type::CLOSEDIR`\n * * `fs_request::fs_type::STATFS`\n * * `fs_request::fs_type::MKSTEMP`\n * * `fs_request::fs_type::LUTIME`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html#c.uv_fs_type)\n * for further details.\n */\nstruct fs_event {\n using fs_type = details::uvw_fs_type;\n using entry_type = details::uvw_dirent_type_t;\n\n fs_event(const uv_fs_t &req, std::unique_ptr data)\n : fs_event{req} {\n read.data = std::move(data);\n }\n\n fs_event(const uv_fs_t &req)\n : type{req.fs_type},\n path{req.path},\n result{static_cast(req.result)} {\n switch(type) {\n case fs_type::STAT:\n case fs_type::LSTAT:\n case fs_type::FSTAT:\n stat = *static_cast(req.ptr);\n break;\n case fs_type::READLINK:\n readlink.data = static_cast(req.ptr);\n break;\n case fs_type::READDIR:\n dirent.name = static_cast(req.ptr)->dirents[0].name;\n dirent.type = static_cast(static_cast(req.ptr)->dirents[0].type);\n dirent.eos = !req.result;\n break;\n case fs_type::STATFS:\n statfs = *static_cast(req.ptr);\n break;\n default:\n // nothing to do here\n break;\n }\n }\n\n fs_type type; /*!< Actual event type. */\n const char *path; /*!< The path affecting the request. */\n std::size_t result; /*!< Result value for the specific type. */\n\n struct {\n std::unique_ptr data; /*!< A bunch of data read from the given path. */\n } read;\n\n struct {\n const char *data; /*!< The content of a symbolic link. */\n } readlink;\n\n file_info stat; /*!< An initialized instance of file_info. */\n fs_info statfs; /*!< An initialized instance of fs_info. */\n\n struct {\n const char *name; /*!< The name of the last entry. */\n entry_type type; /*!< The entry type. */\n bool eos; /*!< True if there a no more entries to read. 
*/\n } dirent;\n};\n\n/**\n * @brief Base class for fs/file request.\n *\n * Not directly instantiable, should not be used by the users of the library.\n */\ntemplate\nclass fs_request: public request {\nprotected:\n static void fs_request_callback(uv_fs_t *req) {\n if(auto ptr = request::reserve(req); req->result < 0) {\n ptr->publish(error_event{req->result});\n } else {\n ptr->publish(fs_event{*req});\n }\n }\n\npublic:\n using time = std::chrono::duration;\n using fs_type = details::uvw_fs_type;\n using entry_type = details::uvw_dirent_type_t;\n\n using request::request;\n};\n\n/**\n * @brief The file request.\n *\n * Cross-platform sync and async filesystem operations.
\n * All file operations are run on the threadpool.\n *\n * To create a `file_req` through a `loop`, no arguments are required.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html)\n * for further details.\n */\nclass file_req final: public fs_request {\n static constexpr uv_file BAD_FD = -1;\n\n static void fs_open_callback(uv_fs_t *req);\n static void fs_close_callback(uv_fs_t *req);\n static void fs_read_callback(uv_fs_t *req);\n\npublic:\n using file_open_flags = details::uvw_file_open_flags;\n\n using fs_request::fs_request;\n\n ~file_req() noexcept;\n\n /**\n * @brief Async [close](http://linux.die.net/man/2/close).\n *\n * Emit a `fs_event` event when completed.\n */\n void close();\n\n /**\n * @brief Sync [close](http://linux.die.net/man/2/close).\n * @return True in case of success, false otherwise.\n */\n bool close_sync();\n\n /**\n * @brief Async [open](http://linux.die.net/man/2/open).\n *\n * Emit a `fs_event` event when completed.\n *\n * Available flags are:\n *\n * * `file_req::file_open_flags::APPEND`\n * * `file_req::file_open_flags::CREAT`\n * * `file_req::file_open_flags::DIRECT`\n * * `file_req::file_open_flags::DIRECTORY`\n * * `file_req::file_open_flags::DSYNC`\n * * `file_req::file_open_flags::EXCL`\n * * `file_req::file_open_flags::EXLOCK`\n * * `file_req::file_open_flags::FILEMAP`\n * * `file_req::file_open_flags::NOATIME`\n * * `file_req::file_open_flags::NOCTTY`\n * * `file_req::file_open_flags::NOFOLLOW`\n * * `file_req::file_open_flags::NONBLOCK`\n * * `file_req::file_open_flags::RANDOM`\n * * `file_req::file_open_flags::RDONLY`\n * * `file_req::file_open_flags::RDWR`\n * * `file_req::file_open_flags::SEQUENTIAL`\n * * `file_req::file_open_flags::SHORT_LIVED`\n * * `file_req::file_open_flags::SYMLINK`\n * * `file_req::file_open_flags::SYNC`\n * * `file_req::file_open_flags::TEMPORARY`\n * * `file_req::file_open_flags::TRUNC`\n * * `file_req::file_open_flags::WRONLY`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html#file-open-constants)\n * for further details.\n *\n * @param path A valid path name for a file.\n * @param flags Flags made out of underlying constants.\n * @param mode Mode, as described in the official documentation.\n */\n void open(const std::string &path, file_open_flags flags, int mode);\n\n /**\n * @brief Sync [open](http://linux.die.net/man/2/open).\n *\n * Available flags are:\n *\n * * `file_req::file_open_flags::APPEND`\n * * `file_req::file_open_flags::CREAT`\n * * `file_req::file_open_flags::DIRECT`\n * * `file_req::file_open_flags::DIRECTORY`\n * * `file_req::file_open_flags::DSYNC`\n * * `file_req::file_open_flags::EXCL`\n * * `file_req::file_open_flags::EXLOCK`\n * * `file_req::file_open_flags::FILEMAP`\n * * `file_req::file_open_flags::NOATIME`\n * * `file_req::file_open_flags::NOCTTY`\n * * `file_req::file_open_flags::NOFOLLOW`\n * * `file_req::file_open_flags::NONBLOCK`\n * * `file_req::file_open_flags::RANDOM`\n * * `file_req::file_open_flags::RDONLY`\n * * `file_req::file_open_flags::RDWR`\n * * `file_req::file_open_flags::SEQUENTIAL`\n * * `file_req::file_open_flags::SHORT_LIVED`\n * * `file_req::file_open_flags::SYMLINK`\n * * `file_req::file_open_flags::SYNC`\n * * `file_req::file_open_flags::TEMPORARY`\n * * `file_req::file_open_flags::TRUNC`\n * * `file_req::file_open_flags::WRONLY`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html#file-open-constants)\n * for further details.\n *\n * @param path A valid path name for a file.\n 
* @param flags Flags made out of underlying constants.\n * @param mode Mode, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool open_sync(const std::string &path, file_open_flags flags, int mode);\n\n /**\n * @brief Async [read](http://linux.die.net/man/2/preadv).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param offset Offset, as described in the official documentation.\n * @param len Length, as described in the official documentation.\n */\n void read(int64_t offset, unsigned int len);\n\n /**\n * @brief Sync [read](http://linux.die.net/man/2/preadv).\n *\n * @param offset Offset, as described in the official documentation.\n * @param len Length, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * A `std::pair` composed as it follows:\n * * A bunch of data read from the given path.\n * * The amount of data read from the given path.\n */\n std::pair, std::size_t>> read_sync(int64_t offset, unsigned int len);\n\n /**\n * @brief Async [write](http://linux.die.net/man/2/pwritev).\n *\n * The request takes the ownership of the data and it is in charge of delete\n * them.\n *\n * Emit a `fs_event` event when completed.\n *\n * @param buf The data to be written.\n * @param len The lenght of the submitted data.\n * @param offset Offset, as described in the official documentation.\n */\n void write(std::unique_ptr buf, unsigned int len, int64_t offset);\n\n /**\n * @brief Async [write](http://linux.die.net/man/2/pwritev).\n *\n * The request doesn't take the ownership of the data. Be sure that their\n * lifetime overcome the one of the request.\n *\n * Emit a `fs_event` event when completed.\n *\n * @param buf The data to be written.\n * @param len The lenght of the submitted data.\n * @param offset Offset, as described in the official documentation.\n */\n void write(char *buf, unsigned int len, int64_t offset);\n\n /**\n * @brief Sync [write](http://linux.die.net/man/2/pwritev).\n *\n * @param buf The data to be written.\n * @param len The lenght of the submitted data.\n * @param offset Offset, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * The amount of data written to the given path.\n */\n std::pair write_sync(std::unique_ptr buf, unsigned int len, int64_t offset);\n\n /**\n * @brief Async [fstat](http://linux.die.net/man/2/fstat).\n *\n * Emit a `fs_event` event when completed.\n */\n void stat();\n\n /**\n * @brief Sync [fstat](http://linux.die.net/man/2/fstat).\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * An initialized instance of file_info.\n */\n std::pair stat_sync();\n\n /**\n * @brief Async [fsync](http://linux.die.net/man/2/fsync).\n *\n * Emit a `fs_event` event when completed.\n */\n void sync();\n\n /**\n * @brief Sync [fsync](http://linux.die.net/man/2/fsync).\n * @return True in case of success, false otherwise.\n */\n bool sync_sync();\n\n /**\n * @brief Async [fdatasync](http://linux.die.net/man/2/fdatasync).\n *\n * Emit a `fs_event` event when completed.\n */\n void datasync();\n\n /**\n * @brief Sync [fdatasync](http://linux.die.net/man/2/fdatasync).\n * @return True in case of success, false otherwise.\n */\n bool datasync_sync();\n\n /**\n * @brief Async 
[ftruncate](http://linux.die.net/man/2/ftruncate).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param offset Offset, as described in the official documentation.\n */\n void truncate(int64_t offset);\n\n /**\n * @brief Sync [ftruncate](http://linux.die.net/man/2/ftruncate).\n * @param offset Offset, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool truncate_sync(int64_t offset);\n\n /**\n * @brief Async [sendfile](http://linux.die.net/man/2/sendfile).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param out A valid instance of file_handle.\n * @param offset Offset, as described in the official documentation.\n * @param length Length, as described in the official documentation.\n */\n void sendfile(file_handle out, int64_t offset, std::size_t length);\n\n /**\n * @brief Sync [sendfile](http://linux.die.net/man/2/sendfile).\n *\n * @param out A valid instance of file_handle.\n * @param offset Offset, as described in the official documentation.\n * @param length Length, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * The amount of data transferred.\n */\n std::pair sendfile_sync(file_handle out, int64_t offset, std::size_t length);\n\n /**\n * @brief Async [fchmod](http://linux.die.net/man/2/fchmod).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param mode Mode, as described in the official documentation.\n */\n void chmod(int mode);\n\n /**\n * @brief Sync [fchmod](http://linux.die.net/man/2/fchmod).\n * @param mode Mode, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool chmod_sync(int mode);\n\n /**\n * @brief Async [futime](http://linux.die.net/man/3/futimes).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param atime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n * @param mtime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n */\n void futime(time atime, time mtime);\n\n /**\n * @brief Sync [futime](http://linux.die.net/man/3/futimes).\n * @param atime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n * @param mtime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool futime_sync(time atime, time mtime);\n\n /**\n * @brief Async [fchown](http://linux.die.net/man/2/fchown).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param uid UID, as described in the official documentation.\n * @param gid GID, as described in the official documentation.\n */\n void chown(uid_type uid, gid_type gid);\n\n /**\n * @brief Sync [fchown](http://linux.die.net/man/2/fchown).\n * @param uid UID, as described in the official documentation.\n * @param gid GID, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool chown_sync(uid_type uid, gid_type gid);\n\n /**\n * @brief Cast operator to file_handle.\n *\n * Cast operator to an internal representation of the underlying file\n * handle.\n *\n * @return A valid instance of file_handle (the descriptor can be invalid).\n */\n operator file_handle() const noexcept;\n\nprivate:\n std::unique_ptr current{nullptr};\n uv_buf_t buffer{};\n uv_file 
file{BAD_FD};\n};\n\n/**\n * @brief The fs request.\n *\n * Cross-platform sync and async filesystem operations.
\n * All file operations are run on the threadpool.\n *\n * To create a `fs_req` through a `loop`, no arguments are required.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html)\n * for further details.\n */\nclass fs_req final: public fs_request {\npublic:\n using copy_file_flags = details::uvw_copy_file_flags;\n using symlink_flags = details::uvw_symlink_flags;\n\n using fs_request::fs_request;\n\n ~fs_req() noexcept;\n\n /**\n * @brief Async [unlink](http://linux.die.net/man/2/unlink).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n */\n void unlink(const std::string &path);\n\n /**\n * @brief Sync [unlink](http://linux.die.net/man/2/unlink).\n * @param path Path, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool unlink_sync(const std::string &path);\n\n /**\n * @brief Async [mkdir](http://linux.die.net/man/2/mkdir).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n * @param mode Mode, as described in the official documentation.\n */\n void mkdir(const std::string &path, int mode);\n\n /**\n * @brief Sync [mkdir](http://linux.die.net/man/2/mkdir).\n * @param path Path, as described in the official documentation.\n * @param mode Mode, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool mkdir_sync(const std::string &path, int mode);\n\n /**\n * @brief Async [mktemp](http://linux.die.net/man/3/mkdtemp).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param tpl Template, as described in the official documentation.\n */\n void mkdtemp(const std::string &tpl);\n\n /**\n * @brief Sync [mktemp](http://linux.die.net/man/3/mkdtemp).\n *\n * @param tpl Template, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * The actual path of the newly created directory.\n */\n std::pair mkdtemp_sync(const std::string &tpl);\n\n /**\n * @brief Async [mkstemp](https://linux.die.net/man/3/mkstemp).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param tpl Template, as described in the official documentation.\n */\n void mkstemp(const std::string &tpl);\n\n /**\n * @brief Sync [mkstemp](https://linux.die.net/man/3/mkstemp).\n *\n * Returns a composed value where:\n *\n * * The first parameter indicates the created file path.\n * * The second parameter is the file descriptor as an integer.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html#c.uv_fs_mkstemp)\n * for further details.\n *\n * @param tpl Template, as described in the official documentation.\n *\n * @return A pair where:\n\n * * The first parameter is a boolean value that is true in case of success,\n * false otherwise.\n * * The second parameter is a composed value (see above).\n */\n std::pair> mkstemp_sync(const std::string &tpl);\n\n /**\n * @brief Async [lutime](http://linux.die.net/man/3/lutimes).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n * @param atime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n * @param mtime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n */\n void lutime(const std::string &path, time 
atime, time mtime);\n\n /**\n * @brief Sync [lutime](http://linux.die.net/man/3/lutimes).\n * @param path Path, as described in the official documentation.\n * @param atime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n * @param mtime `std::chrono::duration`, having the same meaning as\n * described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool lutime_sync(const std::string &path, time atime, time mtime);\n\n /**\n * @brief Async [rmdir](http://linux.die.net/man/2/rmdir).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n */\n void rmdir(const std::string &path);\n\n /**\n * @brief Sync [rmdir](http://linux.die.net/man/2/rmdir).\n * @param path Path, as described in the official documentation.\n * @return True in case of success, false otherwise.\n */\n bool rmdir_sync(const std::string &path);\n\n /**\n * @brief Async [scandir](http://linux.die.net/man/3/scandir).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n * @param flags Flags, as described in the official documentation.\n */\n void scandir(const std::string &path, int flags);\n\n /**\n * @brief Sync [scandir](http://linux.die.net/man/3/scandir).\n *\n * @param path Path, as described in the official documentation.\n * @param flags Flags, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * The number of directory entries selected.\n */\n std::pair scandir_sync(const std::string &path, int flags);\n\n /**\n * @brief Gets entries populated with the next directory entry data.\n *\n * Returns a composed value where:\n *\n * * The first parameter indicates the entry type (see below).\n * * The second parameter is a string that contains the actual value.\n *\n * Available entry types are:\n *\n * * `fs_req::entry_type::UNKNOWN`\n * * `fs_req::entry_type::FILE`\n * * `fs_req::entry_type::DIR`\n * * `fs_req::entry_type::LINK`\n * * `fs_req::entry_type::FIFO`\n * * `fs_req::entry_type::SOCKET`\n * * `fs_req::entry_type::CHAR`\n * * `fs_req::entry_type::BLOCK`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/fs.html#c.uv_dirent_t)\n * for further details.\n *\n * @return A pair where:\n *\n * * The first parameter is a boolean value that indicates if the current\n * entry is still valid.\n * * The second parameter is a composed value (see above).\n */\n std::pair> scandir_next();\n\n /**\n * @brief Async [stat](http://linux.die.net/man/2/stat).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n */\n void stat(const std::string &path);\n\n /**\n * @brief Sync [stat](http://linux.die.net/man/2/stat).\n *\n * @param path Path, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * An initialized instance of file_info.\n */\n std::pair stat_sync(const std::string &path);\n\n /**\n * @brief Async [lstat](http://linux.die.net/man/2/lstat).\n *\n * Emit a `fs_event` event when completed.\n *\n * @param path Path, as described in the official documentation.\n */\n void lstat(const std::string &path);\n\n /**\n * @brief Sync [lstat](http://linux.die.net/man/2/lstat).\n *\n * 
@param path Path, as described in the official documentation.\n *\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of success, false otherwise.\n * * An initialized instance of file_info.\n */\n...\n", "needle_token_start": 2380, "needle_token_end": 2411, "code_context_ntokens": 16040, "output": ["```cpp\ntemplate\nvoid set(T *value) noexcept {\n return uv_key_set(uv_type::raw(), value);\n}\n```"]} +{"repo": "skypjack/uvw", "name": "uv_token", "language": "cpp", "path": "src/uvw/loop.h", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this entity is to act as a control mechanism within a class, ensuring that certain operations are restricted to be performed only through the class itself or its friends, thereby encapsulating specific functionalities.\n2. **Input**: It takes an integer as an input, which is used solely to differentiate its constructor and does not carry semantic meaning beyond controlling access.\n3. **Output**: This entity does not produce an output as it is not a function but a constructor for creating instances with restricted access.\n4. **Procedure**: The procedure involves instantiating this entity by passing an integer, which effectively serves as a key or token, allowing the creation of an instance. This instance is then used internally within the class to manage or guard certain functionalities that are sensitive or should be restricted in scope.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/util.h\n#ifndef UVW_UTIL_INCLUDE_H\n#define UVW_UTIL_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_handle_type : std::underlying_type_t {\n UNKNOWN = UV_UNKNOWN_HANDLE,\n ASYNC = UV_ASYNC,\n CHECK = UV_CHECK,\n FS_EVENT = UV_FS_EVENT,\n FS_POLL = UV_FS_POLL,\n HANDLE = UV_HANDLE,\n IDLE = UV_IDLE,\n PIPE = UV_NAMED_PIPE,\n POLL = UV_POLL,\n PREPARE = UV_PREPARE,\n PROCESS = UV_PROCESS,\n STREAM = UV_STREAM,\n TCP = UV_TCP,\n TIMER = UV_TIMER,\n TTY = UV_TTY,\n UDP = UV_UDP,\n SIGNAL = UV_SIGNAL,\n FILE = UV_FILE\n};\n\nenum class uvw_clock_id : std::underlying_type_t {\n MONOTONIC = UV_CLOCK_MONOTONIC,\n REALTIME = UV_CLOCK_REALTIME\n};\n\ntemplate\nstruct uv_type_wrapper {\n using Type = T;\n\n constexpr uv_type_wrapper()\n : value{} {}\n\n constexpr uv_type_wrapper(Type val)\n : value{val} {}\n\n constexpr operator Type() const noexcept {\n return value;\n }\n\n bool operator==(uv_type_wrapper other) const noexcept {\n return value == other.value;\n }\n\nprivate:\n const Type value;\n};\n\ntemplate\nbool operator==(uv_type_wrapper lhs, uv_type_wrapper rhs) {\n return !(lhs == rhs);\n}\n\n} // namespace details\n\n/**\n * @brief Windows size representation.\n */\nstruct win_size {\n int width; /*!< The _width_ of the given window. */\n int height; /*!< The _height_ of the given window. */\n};\n\nusing handle_type = details::uvw_handle_type; /*!< The type of a handle. */\nusing handle_category = details::uv_type_wrapper; /*!< Utility class that wraps an internal handle type. */\nusing file_handle = details::uv_type_wrapper; /*!< Utility class that wraps an internal file handle. 
*/\nusing os_socket_handle = details::uv_type_wrapper; /*!< Utility class that wraps an os socket handle. */\nusing os_file_descriptor = details::uv_type_wrapper; /*!< Utility class that wraps an os file descriptor. */\nusing pid_type = details::uv_type_wrapper; /*!< Utility class that wraps a cross platform representation of a pid. */\nusing clock_id = details::uvw_clock_id; /*!< Utility class that wraps a clock source. */\n\nconstexpr file_handle std_in{0}; /*!< Placeholder for stdin descriptor. */\nconstexpr file_handle std_out{1}; /*!< Placeholder for stdout descriptor. */\nconstexpr file_handle std_err{2}; /*!< Placeholder for stderr descriptor. */\n\nusing time_spec = uv_timespec_t; /*!< Library equivalent for uv_timespec_t. */\nusing file_info = uv_stat_t; /*!< Library equivalent for uv_stat_t. */\nusing fs_info = uv_statfs_t; /*!< Library equivalent for uv_statfs_t. */\nusing uid_type = uv_uid_t; /*!< Library equivalent for uv_uid_t. */\nusing gid_type = uv_gid_t; /*!< Library equivalent for uv_gid_t. */\n\nusing timeval = uv_timeval_t; /*!< Library equivalent for uv_timeval_t. */\nusing timeval64 = uv_timeval64_t; /*!< Library equivalent for uv_timeval64_t. */\nusing timespec64 = uv_timespec64_t; /*!< Library equivalent for uv_timespec64_t. */\nusing resource_usage = uv_rusage_t; /*!< Library equivalent for uv_rusage_t. */\n\n/**\n * @brief Utility class.\n *\n * This class can be used to query the subset of the password file entry for the\n * current effective uid (not the real uid).\n *\n * \\sa utilities::passwd\n */\nstruct passwd_info {\n passwd_info(std::shared_ptr pwd);\n\n /**\n * @brief Gets the username.\n * @return The username of the current effective uid (not the real uid).\n */\n std::string username() const noexcept;\n\n /**\n * @brief Gets the uid.\n * @return The current effective uid (not the real uid).\n */\n decltype(uv_passwd_t::uid) uid() const noexcept;\n\n /**\n * @brief Gets the gid.\n * @return The gid of the current effective uid (not the real uid).\n */\n decltype(uv_passwd_t::gid) gid() const noexcept;\n\n /**\n * @brief Gets the shell.\n * @return The shell of the current effective uid (not the real uid).\n */\n std::string shell() const noexcept;\n\n /**\n * @brief Gets the homedir.\n * @return The homedir of the current effective uid (not the real uid).\n */\n std::string homedir() const noexcept;\n\n /**\n * @brief Checks if the instance contains valid data.\n * @return True if data are all valid, false otherwise.\n */\n operator bool() const noexcept;\n\nprivate:\n std::shared_ptr value;\n};\n\n/**\n * @brief Utility class.\n *\n * This class can be used to get name and information about the current kernel.\n * The populated data includes the operating system name, release, version, and\n * machine.\n *\n * \\sa utilities::uname\n */\nstruct uts_name {\n uts_name(std::shared_ptr init);\n\n /**\n * @brief Gets the operating system name (like \"Linux\").\n * @return The operating system name.\n */\n std::string sysname() const noexcept;\n\n /**\n * @brief Gets the operating system release (like \"2.6.28\").\n * @return The operating system release.\n */\n std::string release() const noexcept;\n\n /**\n * @brief Gets the operating system version.\n * @return The operating system version\n */\n std::string version() const noexcept;\n\n /**\n * @brief Gets the hardware identifier.\n * @return The hardware identifier.\n */\n std::string machine() const noexcept;\n\nprivate:\n std::shared_ptr uname;\n};\n\n/**\n * @brief The IPv4 tag.\n *\n * To be used 
as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv4 {};\n\n/**\n * @brief The IPv6 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv6 {};\n\n/**\n * @brief Address representation.\n */\nstruct socket_address {\n std::string ip; /*!< Either an IPv4 or an IPv6. */\n unsigned int port; /*!< A valid service identifier. */\n};\n\n/**\n * \\brief CPU information.\n */\nstruct cpu_info {\n using cpu_time = decltype(uv_cpu_info_t::cpu_times);\n\n std::string model; /*!< The model of the CPU. */\n int speed; /*!< The frequency of the CPU. */\n\n /**\n * @brief CPU times.\n...\n// Path: src/uvw/loop.cpp\n#ifdef UVW_AS_LIB\n# include \"loop.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE loop::loop(std::unique_ptr ptr) noexcept\n : uv_loop{std::move(ptr)} {}\n\nUVW_INLINE std::shared_ptr loop::create() {\n auto ptr = std::unique_ptr{new uv_loop_t, [](uv_loop_t *l) { delete l; }};\n auto curr = std::shared_ptr{new loop{std::move(ptr)}};\n\n if(uv_loop_init(curr->uv_loop.get())) {\n curr = nullptr;\n }\n\n return curr;\n}\n\nUVW_INLINE std::shared_ptr loop::create(uv_loop_t *res) {\n auto ptr = std::unique_ptr{res, [](uv_loop_t *) {}};\n return std::shared_ptr{new loop{std::move(ptr)}};\n}\n\nUVW_INLINE std::shared_ptr loop::get_default() {\n static std::weak_ptr ref;\n std::shared_ptr curr;\n\n if(ref.expired()) {\n auto def = uv_default_loop();\n\n if(def) {\n auto ptr = std::unique_ptr(def, [](uv_loop_t *) {});\n curr = std::shared_ptr{new loop{std::move(ptr)}};\n }\n\n ref = curr;\n } else {\n curr = ref.lock();\n }\n\n return curr;\n}\n\nUVW_INLINE loop::~loop() noexcept {\n if(uv_loop) {\n close();\n }\n}\n\nUVW_INLINE int loop::close() {\n int ret = 0;\n\n if(uv_loop) {\n ret = uv_loop_close(uv_loop.get());\n uv_loop.reset();\n }\n\n return ret;\n}\n\nUVW_INLINE int loop::run(run_mode mode) noexcept {\n return uv_run(uv_loop.get(), static_cast(mode));\n}\n\nUVW_INLINE bool loop::alive() const noexcept {\n return !!uv_loop_alive(uv_loop.get());\n}\n\nUVW_INLINE void loop::stop() noexcept {\n uv_stop(uv_loop.get());\n}\n\nUVW_INLINE int loop::descriptor() const noexcept {\n return uv_backend_fd(uv_loop.get());\n}\n\nUVW_INLINE std::pair loop::timeout() const noexcept {\n auto to = uv_backend_timeout(uv_loop.get());\n return std::make_pair(to == -1, time{to});\n}\n\nUVW_INLINE loop::time loop::idle_time() const noexcept {\n return time{uv_metrics_idle_time(uv_loop.get())};\n}\n\nUVW_INLINE metrics_type loop::metrics() const noexcept {\n metrics_type res{};\n uv_metrics_info(uv_loop.get(), &res);\n return res;\n}\n\nUVW_INLINE loop::time loop::now() const noexcept {\n return time{uv_now(uv_loop.get())};\n}\n\nUVW_INLINE void loop::update() const noexcept {\n return uv_update_time(uv_loop.get());\n}\n\nUVW_INLINE int loop::fork() noexcept {\n return uv_loop_fork(uv_loop.get());\n}\n\nUVW_INLINE void loop::data(std::shared_ptr ud) {\n user_data = std::move(ud);\n}\n\nUVW_INLINE const uv_loop_t *loop::raw() const noexcept {\n return uv_loop.get();\n}\n\nUVW_INLINE uv_loop_t *loop::raw() noexcept {\n return const_cast(const_cast(this)->raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/loop.h\n#ifndef UVW_LOOP_INCLUDE_H\n#define UVW_LOOP_INCLUDE_H\n\n#ifdef _WIN32\n# include \n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nclass async_handle;\nclass check_handle;\nclass fs_event_handle;\nclass 
fs_poll_handle;\nclass idle_handle;\nclass pipe_handle;\nclass poll_handle;\nclass prepare_handle;\nclass process_handle;\nclass signal_handle;\nclass tcp_handle;\nclass timer_handle;\nclass tty_handle;\nclass udp_handle;\n\nnamespace details {\n\nenum class uvw_loop_option : std::underlying_type_t {\n BLOCK_SIGNAL = UV_LOOP_BLOCK_SIGNAL,\n IDLE_TIME = UV_METRICS_IDLE_TIME\n};\n\nenum class uvw_run_mode : std::underlying_type_t {\n DEFAULT = UV_RUN_DEFAULT,\n ONCE = UV_RUN_ONCE,\n NOWAIT = UV_RUN_NOWAIT\n};\n\n} // namespace details\n\nusing metrics_type = uv_metrics_t; /*!< Library equivalent for uv_metrics_t. */\n\n/**\n * @brief The loop class.\n *\n * The event loop is the central part of `uvw`'s functionalities, as well as\n * `libuv`'s ones.
\n * It takes care of polling for I/O and scheduling callbacks to be run based on\n * different sources of events.\n */\nclass loop final: public emitter, public std::enable_shared_from_this {\n using deleter = void (*)(uv_loop_t *);\n\n template\n friend class resource;\n\n class uv_token {\n friend class loop;\n \nexplicit uv_token(int) {}\n };\n\n template\n auto init(int, Type &value) -> decltype(value.init()) {\n return value.init();\n }\n\n template\n int init(char, Type &) {\n return 0;\n }\n\n loop(std::unique_ptr ptr) noexcept;\n\npublic:\n using token = uv_token;\n using time = std::chrono::duration;\n using option = details::uvw_loop_option;\n using run_mode = details::uvw_run_mode;\n\n /**\n * @brief Initializes a new loop instance.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create();\n\n /**\n * @brief Initializes a new loop instance from an existing resource.\n *\n * The lifetime of the resource must exceed that of the instance to which\n * it's associated. Management of the memory associated with the resource is\n * in charge of the user.\n *\n * @param res A valid pointer to a correctly initialized resource.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create(uv_loop_t *res);\n\n /**\n * @brief Gets the initialized default loop.\n *\n * It may return an empty pointer in case of failure.
\n * This function is just a convenient way for having a global loop\n * throughout an application, the default loop is in no way different than\n * the ones initialized with `create()`.
\n * As such, the default loop can be closed with `close()` so the resources\n * associated with it are freed (even if it is not strictly necessary).\n *\n * @return The initialized default loop.\n */\n static std::shared_ptr get_default();\n\n loop(const loop &) = delete;\n loop(loop &&other) = delete;\n\n loop &operator=(const loop &) = delete;\n loop &operator=(loop &&other) = delete;\n\n ~loop() noexcept;\n\n /**\n * @brief Sets additional loop options.\n *\n * You should normally call this before the first call to uv_run() unless\n * mentioned otherwise.
\n * Supported options:\n *\n * * `loop::option::BLOCK_SIGNAL`: Block a signal when polling for new\n * events. A second argument is required and it is the signal number.\n * * `loop::option::IDLE_TIME`: Accumulate the amount of idle time the event\n * loop spends in the event provider. This option is necessary to use\n * `idle_time()`.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_configure)\n * for further details.\n *\n * @return Underlying return value.\n */\n template\n int configure(option flag, Args &&...args) {\n return uv_loop_configure(uv_loop.get(), static_cast(flag), std::forward(args)...);\n }\n\n /**\n * @brief Creates resources of any type.\n *\n * This should be used as a default method to create resources.
\n * The arguments are the ones required for the specific resource.\n *\n * Use it as `loop->resource()`.\n *\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr resource(Args &&...args) {\n auto ptr = uninitialized_resource(std::forward(args)...);\n return (init(0, *ptr) == 0) ? ptr : nullptr;\n }\n\n /**\n * @brief Creates uninitialized resources of any type.\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr uninitialized_resource(Args &&...args) {\n return std::make_shared(token{0}, shared_from_this(), std::forward(args)...);\n }\n\n /**\n * @brief Releases all internal loop resources.\n *\n * Call this function only when the loop has finished executing and all open\n * handles and requests have been closed, or the loop will error.\n *\n * @return Underlying return value.\n */\n int close();\n\n /**\n * @brief Runs the event loop.\n *\n * Available modes are:\n *\n * * `loop::run_mode::DEFAULT`: Runs the event loop until there are no more\n * active and referenced handles or requests.\n * * `loop::run_mode::ONCE`: Poll for i/o once. Note that this function\n * blocks if there are no pending callbacks.\n * * `loop::run_mode::NOWAIT`: Poll for i/o once but don\u2019t block if there\n * are no pending callbacks.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_run)\n * for further details.\n *\n * @return Underlying return value.\n */\n int run(run_mode mode = run_mode::DEFAULT) noexcept;\n\n /**\n * @brief Checks if there are active resources.\n * @return True if there are active resources in the loop.\n */\n bool alive() const noexcept;\n\n /**\n * @brief Stops the event loop.\n *\n * It causes `run()` to end as soon as possible.
\n * This will happen not sooner than the next loop iteration.
\n * If this function was called before blocking for I/O, the loop won\u2019t block\n * for I/O on this iteration.\n */\n void stop() noexcept;\n\n /**\n * @brief Get backend file descriptor.\n *\n * Only kqueue, epoll and event ports are supported.
\n * This can be used in conjunction with `run(loop::run_mode::NOWAIT)` to\n * poll in one thread and run the event loop\u2019s callbacks in another.\n *\n * @return The backend file descriptor.\n */\n int descriptor() const noexcept;\n\n /**\n * @brief Gets the poll timeout.\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of valid timeout, false otherwise.\n * * Milliseconds (`std::chrono::duration`).\n */\n std::pair timeout() const noexcept;\n\n /**\n * @brief Returns the amount of time the event loop has been idle. The call\n * is thread safe.\n * @return The accumulated time spent idle.\n */\n time idle_time() const noexcept;\n\n /**\n * @brief Tracks various internal operations of the event loop.\n * @return Event loop metrics.\n */\n metrics_type metrics() const noexcept;\n\n /**\n * @brief Returns the current timestamp in milliseconds.\n *\n * The timestamp is cached at the start of the event loop tick.
\n * The timestamp increases monotonically from some arbitrary point in\n * time.
\n * Don\u2019t make assumptions about the starting point, you will only get\n * disappointed.\n *\n * @return The current timestamp in milliseconds (actual type is\n * `std::chrono::duration`).\n */\n time now() const noexcept;\n\n /**\n * @brief Updates the event loop\u2019s concept of _now_.\n *\n * The current time is cached at the start of the event loop tick in order\n * to reduce the number of time-related system calls.
\n * You won\u2019t normally need to call this function unless you have callbacks\n * that block the event loop for longer periods of time, where _longer_ is\n * somewhat subjective but probably on the order of a millisecond or more.\n */\n void update() const noexcept;\n\n /**\n * @brief Walks the list of handles.\n *\n * The callback is invoked once for each handle that is still active.\n *\n * @param callback A function to invoke once for each active handle.\n */\n template\n void walk(Func callback) {\n auto func = [](uv_handle_t *hndl, void *callback_func) {\n if(hndl->data) {\n auto &cb = *static_cast(callback_func);\n\n switch(utilities::guess_handle(handle_category{hndl->type})) {\n case handle_type::ASYNC:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::CHECK:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_EVENT:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::IDLE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PIPE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PREPARE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PROCESS:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::SIGNAL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TCP:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TIMER:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TTY:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::UDP:\n cb(*static_cast(hndl->data));\n break;\n default:\n // this handle isn't managed by uvw, let it be...\n break;\n }\n }\n };\n\n uv_walk(uv_loop.get(), func, &callback);\n }\n\n /**\n * @brief Reinitialize any kernel state necessary in the child process after\n * a fork(2) system call.\n *\n * Previously started watchers will continue to be started in the child\n * process.\n *\n * It is necessary to explicitly call this function on every event loop\n * created in the parent process that you plan to continue to use in the\n * child, including the default loop (even if you don\u2019t continue to use it\n * in the parent). This function must be called before calling any API\n * function using the loop in the child. Failure to do so will result in\n * undefined behaviour, possibly including duplicate events delivered to\n * both parent and child or aborting the child process.\n *\n * When possible, it is preferred to create a new loop in the child process\n * instead of reusing a loop created in the parent. New loops created in the\n * child process after the fork should not use this function.\n *\n * Note that this function is not implemented on Windows.
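The dispatch switch above means a generic callable sees each handle as its concrete `uvw` type, which makes `walk()` convenient for the classic teardown idiom sketched here (with `loop` assumed to be a live `std::shared_ptr`):

```cpp
// Ask every handle that is still open to close, then drain the loop.
loop->walk([](auto &hndl) {
    // hndl arrives as the concrete handle (timer, tcp, pipe, ...), so the
    // usual handle interface shown earlier is available here.
    if(!hndl.closing()) {
        hndl.close();
    }
});

loop->run();
```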
\n * Note also that this function is experimental in `libuv`. It may contain\n * bugs, and is subject to change or removal. API and ABI stability is not\n * guaranteed.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_fork)\n * for further details.\n *\n * @return Underlying return value.\n */\n int fork() noexcept;\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param ud User-defined arbitrary data.\n */\n void data(std::shared_ptr ud);\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const uv_loop_t *raw() const noexcept;\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n uv_loop_t *raw() noexcept;\n\nprivate:\n std::unique_ptr uv_loop;\n std::shared_ptr user_data{nullptr};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"loop.cpp\"\n#endif\n\n#endif // UVW_LOOP_INCLUDE_H\n\n// Path: src/uvw/uv_type.hpp\n#ifndef UVW_UV_TYPE_INCLUDE_H\n#define UVW_UV_TYPE_INCLUDE_H\n\n#include \n#include \n#include \n#include \"config.h\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/**\n * @brief Wrapper class for underlying types.\n *\n * It acts mainly as a wrapper around data structures of the underlying library.\n */\ntemplate\nstruct uv_type {\n explicit uv_type(loop::token, std::shared_ptr ref) noexcept\n : owner{std::move(ref)}, resource{} {}\n\n uv_type(const uv_type &) = delete;\n uv_type(uv_type &&) = delete;\n\n uv_type &operator=(const uv_type &) = delete;\n uv_type &operator=(uv_type &&) = delete;\n\n /**\n * @brief Gets the loop from which the resource was originated.\n * @return A reference to a loop instance.\n */\n loop &parent() const noexcept {\n return *owner;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const U *raw() const noexcept {\n return &resource;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n U *raw() noexcept {\n return &resource;\n }\n\nprotected:\n ~uv_type() = default;\n\nprivate:\n std::shared_ptr owner;\n U resource;\n};\n\n} // namespace uvw\n\n#endif // UVW_UV_TYPE_INCLUDE_H\n\n// Path: src/uvw/resource.hpp\n#ifndef UVW_RESOURCE_INCLUDE_H\n#define UVW_RESOURCE_INCLUDE_H\n\n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Common class for almost all the resources available in `uvw`.\n *\n * This is the base class for handles and requests.\n */\ntemplate\nclass resource: public uv_type, public emitter, public std::enable_shared_from_this {\nprotected:\n int leak_if(int err) noexcept {\n if(err == 0) {\n self_ptr = this->shared_from_this();\n }\n\n return err;\n }\n\n void self_reset() noexcept {\n self_ptr.reset();\n }\n\n bool has_self() const noexcept {\n return static_cast(self_ptr);\n }\n\npublic:\n explicit resource(loop::token token, std::shared_ptr ref)\n : uv_type{token, std::move(ref)} {\n this->raw()->data = this;\n }\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param udata User-defined arbitrary data.\n */\n void data(std::shared_ptr udata) {\n user_data = std::move(udata);\n }\n\nprivate:\n std::shared_ptr user_data{nullptr};\n std::shared_ptr self_ptr{nullptr};\n};\n\n} // namespace uvw\n\n#endif // UVW_RESOURCE_INCLUDE_H\n\n// Path: src/uvw/handle.hpp\n#ifndef UVW_HANDLE_INCLUDE_H\n#define UVW_HANDLE_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\n/*! @brief Close event. */\nstruct close_event {};\n\n/**\n * @brief Handle base class.\n *\n * Base type for all `uvw` handle types.\n */\ntemplate\nclass handle: public resource {\nprotected:\n static void close_callback(uv_handle_t *hndl) {\n handle &ref = *(static_cast(hndl->data));\n [[maybe_unused]] auto ptr = ref.shared_from_this();\n ref.self_reset();\n ref.publish(close_event{});\n }\n\n uv_handle_t *as_uv_handle() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_handle_t *as_uv_handle() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Gets the category of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. By means of this function, an opaque value that identifies the\n * category of the handle is made available to the users.\n *\n * @return The actual category of the handle.\n */\n handle_category category() const noexcept {\n return handle_category{as_uv_handle()->type};\n }\n\n /**\n * @brief Gets the type of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. 
By means of this function, the type of the underlying handle as\n * specified by handle_type is made available to the users.\n *\n * @return The actual type of the handle.\n */\n handle_type type() const noexcept {\n return utilities::guess_handle(category());\n }\n\n /**\n * @brief Checks if the handle is active.\n *\n * What _active_ means depends on the type of handle:\n *\n * * An async_handle handle is always active and cannot be deactivated,\n * except by closing it with uv_close().\n * * A pipe, tcp, udp, etc. handle - basically any handle that deals with\n * I/O - is active when it is doing something that involves I/O, like\n * reading, writing, connecting, accepting new connections, etc.\n * * A check, idle, timer, etc. handle is active when it has been started\n * with a call to `start()`.\n *\n * Rule of thumb: if a handle of type `foo_handle` has a `start()` member\n * method, then it\u2019s active from the moment that method is called. Likewise,\n * `stop()` deactivates the handle again.\n *\n * @return True if the handle is active, false otherwise.\n */\n bool active() const noexcept {\n return !!uv_is_active(as_uv_handle());\n }\n\n /**\n * @brief Checks if a handle is closing or closed.\n *\n * This function should only be used between the initialization of the\n * handle and the arrival of the close callback.\n *\n * @return True if the handle is closing or closed, false otherwise.\n */\n bool closing() const noexcept {\n return !!uv_is_closing(as_uv_handle());\n }\n\n /**\n * @brief Request handle to be closed.\n *\n * This **must** be called on each handle before memory is released.
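Because `close()` is asynchronous, the close event is the reliable point at which to drop the last reference to a handle. A small sketch, assuming `loop` exists and using the `uvw::check_handle` declared in a header captured further below:

```cpp
auto handle = loop->resource();

handle->on([](const uvw::close_event &, uvw::check_handle &) {
    // The underlying uv handle has been released; it is now safe to let the
    // shared_ptr go out of scope.
});

handle->close();
```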
\n * In-progress requests are cancelled and this can result in errors.\n *\n * The handle will emit a close event when finished.\n */\n void close() noexcept {\n if(!closing()) {\n uv_close(as_uv_handle(), &handle::close_callback);\n }\n }\n\n /**\n * @brief Reference the given handle.\n *\n * References are idempotent, that is, if a handle is already referenced\n * calling this function again will have no effect.\n */\n void reference() noexcept {\n uv_ref(as_uv_handle());\n }\n\n /**\n * @brief Unreference the given handle.\n *\n * References are idempotent, that is, if a handle is not referenced calling\n * this function again will have no effect.\n */\n void unreference() noexcept {\n uv_unref(as_uv_handle());\n }\n\n /**\n * @brief Checks if the given handle referenced.\n * @return True if the handle referenced, false otherwise.\n */\n bool referenced() const noexcept {\n return !!uv_has_ref(as_uv_handle());\n }\n\n /**\n * @brief Returns the size of the underlying handle type.\n * @return The size of the underlying handle type.\n */\n std::size_t size() const noexcept {\n return uv_handle_size(as_uv_handle()->type);\n }\n\n /**\n * @brief Gets the size of the send buffer used for the socket.\n *\n * Gets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipeand udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the send buffer, the underlying return value in case\n * of errors.\n */\n int send_buffer_size() {\n int value = 0;\n auto err = uv_send_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the send buffer used for the socket.\n *\n * Sets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int send_buffer_size(int value) {\n return uv_send_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the size of the receive buffer used for the socket.\n *\n * Gets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the receive buffer, the underlying return value in\n * case of errors.\n */\n int recv_buffer_size() {\n int value = 0;\n auto err = uv_recv_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the receive buffer used for the socket.\n *\n * Sets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int recv_buffer_size(int value) {\n return uv_recv_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the platform dependent file descriptor equivalent.\n *\n * Supported handles:\n *\n * * tcp_handle\n * * pipe_handle\n * * tty_handle\n * * udp_handle\n * * poll_handle\n *\n * If invoked on a different handle, one that doesn\u2019t have an attached file\n * descriptor yet or one which was closed, an invalid value is returned.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/handle.html#c.uv_fileno)\n * for further details.\n *\n * @return The file descriptor attached to the hande or a negative value in\n * case of errors.\n */\n os_file_descriptor fd() const {\n uv_os_fd_t fd;\n uv_fileno(as_uv_handle(), &fd);\n return fd;\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_HANDLE_INCLUDE_H\n\n// Path: src/uvw/prepare.h\n#ifndef UVW_PREPARE_INCLUDE_H\n#define UVW_PREPARE_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Prepare event. */\nstruct prepare_event {};\n\n/**\n * @brief The prepare handle.\n *\n * Prepare handles will emit a prepare event once per loop iteration, right\n * before polling for I/O.\n *\n * To create a `prepare_handle` through a `loop`, no arguments are required.\n */\nclass prepare_handle final: public handle {\n static void start_callback(uv_prepare_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A prepare event will be emitted once per loop iteration, right before\n * polling for I/O.\n *\n * The handle will start emitting prepare events when needed.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"prepare.cpp\"\n#endif\n\n#endif // UVW_PREPARE_INCLUDE_H\n\n// Path: src/uvw/prepare.cpp\n#ifdef UVW_AS_LIB\n# include \"prepare.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void prepare_handle::start_callback(uv_prepare_t *hndl) {\n prepare_handle &prepare = *(static_cast(hndl->data));\n prepare.publish(prepare_event{});\n}\n\nUVW_INLINE int prepare_handle::init() {\n return leak_if(uv_prepare_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int prepare_handle::start() {\n return uv_prepare_start(raw(), &start_callback);\n}\n\nUVW_INLINE int prepare_handle::stop() {\n return uv_prepare_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/request.hpp\n#ifndef UVW_REQUEST_INCLUDE_H\n#define UVW_REQUEST_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Request base class.\n *\n * Base type for all `uvw` request types.\n */\ntemplate\nclass request: public resource {\nprotected:\n static auto reserve(U *req) {\n auto ptr = static_cast(req->data)->shared_from_this();\n ptr->self_reset();\n return ptr;\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Cancels a pending request.\n *\n * This method fails if the request is executing or has finished\n * executing.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/request.html#c.uv_cancel)\n * for further details.\n *\n * @return Underlying return value.\n */\n int cancel() {\n return 
uv_cancel(reinterpret_cast(this->raw()));\n }\n\n /**\n * @brief Returns the size of the underlying request type.\n * @return The size of the underlying request type.\n */\n std::size_t size() const noexcept {\n return uv_req_size(reinterpret_cast(this->raw())->type);\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_REQUEST_INCLUDE_H\n\n// Path: src/uvw/stream.h\n#ifndef UVW_STREAM_INCLUDE_H\n#define UVW_STREAM_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"handle.hpp\"\n#include \"loop.h\"\n#include \"request.hpp\"\n\nnamespace uvw {\n\n/*! @brief Connect event. */\nstruct connect_event {};\n\n/*! @brief End event. */\nstruct end_event {};\n\n/*! @brief Listen event. */\nstruct listen_event {};\n\n/*! @brief Shutdown event. */\nstruct shutdown_event {};\n\n/*! @brief Write event. */\nstruct write_event {};\n\n/*! @brief Data event. */\nstruct data_event {\n explicit data_event(std::unique_ptr buf, std::size_t len) noexcept;\n\n std::unique_ptr data; /*!< A bunch of data read on the stream. */\n std::size_t length; /*!< The amount of data read on the stream. */\n};\n\nnamespace details {\n\nclass connect_req final: public request {\n static void connect_callback(uv_connect_t *req, int status);\n\npublic:\n using request::request;\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n std::forward(f)(raw(), std::forward(args)..., &connect_callback);\n return this->leak_if(0);\n }\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n return this->leak_if(std::forward(f)(raw(), std::forward(args)..., &connect_callback));\n }\n};\n\nclass shutdown_req final: public request {\n static void shoutdown_callback(uv_shutdown_t *req, int status);\n\npublic:\n using request::request;\n\n int shutdown(uv_stream_t *hndl);\n};\n\ntemplate\nclass write_req final: public request, uv_write_t, write_event> {\n static void write_callback(uv_write_t *req, int status) {\n if(auto ptr = request, uv_write_t, write_event>::reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(write_event{});\n }\n }\n\npublic:\n write_req(loop::token token, std::shared_ptr parent, std::unique_ptr dt, unsigned int len)\n : request, uv_write_t, write_event>{token, std::move(parent)},\n data{std::move(dt)},\n buf{uv_buf_init(data.get(), len)} {}\n\n int write(uv_stream_t *hndl) {\n return this->leak_if(uv_write(this->raw(), hndl, &buf, 1, &write_callback));\n }\n\n int write(uv_stream_t *hndl, uv_stream_t *send) {\n return this->leak_if(uv_write2(this->raw(), hndl, &buf, 1, send, &write_callback));\n }\n\nprivate:\n std::unique_ptr data;\n uv_buf_t buf;\n};\n\n} // namespace details\n\n/**\n * @brief The stream handle.\n *\n * Stream handles provide an abstraction of a duplex communication channel.\n * The stream handle is an intermediate type, `uvw` provides three stream\n * implementations: tcp, pipe and tty handles.\n */\ntemplate\nclass stream_handle: public handle {\n using base = handle;\n\n template\n friend class stream_handle;\n\n static constexpr unsigned int DEFAULT_BACKLOG = 128;\n\n static void read_callback(uv_stream_t *hndl, ssize_t nread, const uv_buf_t *buf) {\n T &ref = *(static_cast(hndl->data));\n // data will be destroyed no matter of what the value of nread is\n std::unique_ptr data{buf->base};\n\n // nread == 0 is ignored (see 
http://docs.libuv.org/en/v1.x/stream.html)\n // equivalent to EAGAIN/EWOULDBLOCK, it shouldn't be treated as an error\n // for we don't have data to emit though, it's fine to suppress it\n\n if(nread == UV_EOF) {\n // end of stream\n ref.publish(end_event{});\n } else if(nread > 0) {\n // data available\n ref.publish(data_event{std::move(data), static_cast(nread)});\n } else if(nread < 0) {\n // transmission error\n ref.publish(error_event(nread));\n }\n }\n\n static void listen_callback(uv_stream_t *hndl, int status) {\n if(T &ref = *(static_cast(hndl->data)); status) {\n ref.publish(error_event{status});\n } else {\n ref.publish(listen_event{});\n }\n }\n\n uv_stream_t *as_uv_stream() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_stream_t *as_uv_stream() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n#ifdef _MSC_VER\n stream_handle(loop::token token, std::shared_ptr ref)\n : base{token, std::move(ref)} {}\n#else\n using base::base;\n#endif\n\n /**\n * @brief Shutdowns the outgoing (write) side of a duplex stream.\n *\n * It waits for pending write requests to complete. The handle should refer\n * to a initialized stream.
\n * A shutdown event will be emitted after shutdown is complete.\n *\n * @return Underlying return value.\n */\n int shutdown() {\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n auto shutdown = this->parent().template resource();\n shutdown->template on(listener);\n shutdown->template on(listener);\n\n return shutdown->shutdown(as_uv_stream());\n }\n\n /**\n * @brief Starts listening for incoming connections.\n *\n * When a new incoming connection is received, a listen event is\n * emitted.\n *\n * @param backlog Indicates the number of connections the kernel might\n * queue, same as listen(2).\n *\n * @return Underlying return value.\n */\n int listen(int backlog = DEFAULT_BACKLOG) {\n return uv_listen(as_uv_stream(), backlog, &listen_callback);\n }\n\n /**\n * @brief Accepts incoming connections.\n *\n * This call is used in conjunction with `listen()` to accept incoming\n * connections. Call this function after receiving a listen event to accept\n * the connection. Before calling this function, the submitted handle must\n * be initialized.\n *\n * When the listen event is emitted it is guaranteed that this function will\n * complete successfully the first time. If you attempt to use it more than\n * once, it may fail.
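Taken together, `listen()` and `accept()` support the usual server skeleton. The sketch below additionally assumes `uvw::tcp_handle` and its `bind(ip, port)` member, which are declared in headers not captured in this excerpt:

```cpp
auto server = loop->resource();

server->on([](const uvw::listen_event &, uvw::tcp_handle &srv) {
    // One accept per listen event, as suggested by the documentation.
    auto client = srv.parent().resource();
    srv.accept(*client);
    client->read();
});

server->bind("127.0.0.1", 4242); // assumed tcp_handle member
server->listen();
```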
\n * It is suggested to only call this function once per listen event.\n *\n * @note\n * Both the handles must be running on the same loop.\n *\n * @param ref An initialized handle to be used to accept the connection.\n * @return Underlying return value.\n */\n template\n int accept(S &ref) {\n return uv_accept(as_uv_stream(), ref.as_uv_stream());\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n *\n * A data event will be emitted several times until there is no more data to\n * read or `stop()` is called.
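On the connecting side, the data and end events described here are usually wired up before the connection attempt. Another hedged sketch, assuming `uvw::tcp_handle` with a `connect(ip, port)` member:

```cpp
auto client = loop->resource();

client->on([](const uvw::connect_event &, uvw::tcp_handle &sock) {
    sock.read(); // start emitting data events
});

client->on([](const uvw::data_event &event, uvw::tcp_handle &) {
    // event.data owns event.length bytes read from the stream.
});

client->on([](const uvw::end_event &, uvw::tcp_handle &sock) {
    sock.close(); // the peer has finished writing
});

client->connect("127.0.0.1", 4242); // assumed tcp_handle member
```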
\n * An end event will be emitted when there is no more data to read.\n *\n * @return Underlying return value.\n */\n int read() {\n return uv_read_start(as_uv_stream(), &details::common_alloc_callback, &read_callback);\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n * @sa read\n * @tparam Alloc Custom allocation function.\n * @return Underlying return value.\n */\n template\n int read() {\n return uv_read_start(as_uv_stream(), &details::common_alloc_callback, &read_callback);\n }\n\n /**\n * @brief Stops reading data from the stream.\n *\n * This function is idempotent and may be safely called on a stopped stream.\n *\n * @return Underlying return value.\n */\n int stop() {\n return uv_read_stop(as_uv_stream());\n }\n\n /**\n * @brief Writes data to the stream.\n *\n * Data are written in order. The handle takes the ownership of the data and\n * it is in charge of delete them.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(std::unique_ptr data, unsigned int len) {\n auto req = this->parent().template resource>(std::move(data), len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream());\n }\n\n /**\n * @brief Writes data to the stream.\n *\n * Data are written in order. The handle doesn't take the ownership of the\n * data. Be sure that their lifetime overcome the one of the request.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int write(char *data, unsigned int len) {\n auto req = this->parent().template resource>(std::unique_ptr{data, [](char *) {}}, len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream());\n }\n\n /**\n * @brief Extended write function for sending handles over a pipe handle.\n *\n * The pipe must be initialized with `ipc == true`.\n *\n * `send` must be a tcp or pipe handle, which is a server or a connection\n * (listening or connected state). 
Bound sockets or pipes will be assumed to\n * be servers.\n *\n * The handle takes the ownership of the data and it is in charge of delete\n * them.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param send The handle over which to write data.\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(S &send, std::unique_ptr data, unsigned int len) {\n auto req = this->parent().template resource>(std::move(data), len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream(), send.as_uv_stream());\n }\n\n /**\n * @brief Extended write function for sending handles over a pipe handle.\n *\n...\n", "needle_token_start": 3974, "needle_token_end": 3983, "code_context_ntokens": 16074, "output": ["```cpp\nclass uv_token {\n friend class loop;\n \nexplicit uv_token(int) {}\n};\n```"]} +{"repo": "skypjack/uvw", "name": "write_req", "language": "cpp", "path": "src/uvw/stream.h", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this function is to handle the asynchronous writing of data to a stream, managing memory and notifying upon completion or error.\n2. **Input**: It takes a token for event loop association, a shared pointer to the event loop, a unique pointer to the data buffer managed by a custom deleter, and the length of the data to be written.\n3. **Output**: It does not return a value directly but triggers events upon completion. If the write operation is successful, a success event is published; if there is an error, an error event with the status code is published.\n4. **Procedure**: The function initializes a write request with the provided data and length. It then sets up a callback to handle the result of the write operation. When the write operation completes, the callback checks the status. If there is an error, it publishes an error event; otherwise, it publishes a success event indicating the write operation was successful.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/uv_type.hpp\n#ifndef UVW_UV_TYPE_INCLUDE_H\n#define UVW_UV_TYPE_INCLUDE_H\n\n#include \n#include \n#include \n#include \"config.h\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/**\n * @brief Wrapper class for underlying types.\n *\n * It acts mainly as a wrapper around data structures of the underlying library.\n */\ntemplate\nstruct uv_type {\n explicit uv_type(loop::token, std::shared_ptr ref) noexcept\n : owner{std::move(ref)}, resource{} {}\n\n uv_type(const uv_type &) = delete;\n uv_type(uv_type &&) = delete;\n\n uv_type &operator=(const uv_type &) = delete;\n uv_type &operator=(uv_type &&) = delete;\n\n /**\n * @brief Gets the loop from which the resource was originated.\n * @return A reference to a loop instance.\n */\n loop &parent() const noexcept {\n return *owner;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const U *raw() const noexcept {\n return &resource;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n U *raw() noexcept {\n return &resource;\n }\n\nprotected:\n ~uv_type() = default;\n\nprivate:\n std::shared_ptr owner;\n U resource;\n};\n\n} // namespace uvw\n...\n// Path: src/uvw/resource.hpp\n#ifndef UVW_RESOURCE_INCLUDE_H\n#define UVW_RESOURCE_INCLUDE_H\n\n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Common class for almost all the resources available in `uvw`.\n *\n * This is the base class for handles and requests.\n */\ntemplate\nclass resource: public uv_type, public emitter, public std::enable_shared_from_this {\nprotected:\n int leak_if(int err) noexcept {\n if(err == 0) {\n self_ptr = this->shared_from_this();\n }\n\n return err;\n }\n\n void self_reset() noexcept {\n self_ptr.reset();\n }\n\n bool has_self() const noexcept {\n return static_cast(self_ptr);\n }\n\npublic:\n explicit resource(loop::token token, std::shared_ptr ref)\n : uv_type{token, std::move(ref)} {\n this->raw()->data = this;\n }\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param udata User-defined arbitrary data.\n */\n void data(std::shared_ptr udata) {\n user_data = std::move(udata);\n }\n\nprivate:\n std::shared_ptr user_data{nullptr};\n std::shared_ptr self_ptr{nullptr};\n};\n\n} // namespace uvw\n\n#endif // UVW_RESOURCE_INCLUDE_H\n\n// Path: src/uvw/handle.hpp\n#ifndef UVW_HANDLE_INCLUDE_H\n#define UVW_HANDLE_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\n/*! @brief Close event. */\nstruct close_event {};\n\n/**\n * @brief Handle base class.\n *\n * Base type for all `uvw` handle types.\n */\ntemplate\nclass handle: public resource {\nprotected:\n static void close_callback(uv_handle_t *hndl) {\n handle &ref = *(static_cast(hndl->data));\n [[maybe_unused]] auto ptr = ref.shared_from_this();\n ref.self_reset();\n ref.publish(close_event{});\n }\n\n uv_handle_t *as_uv_handle() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_handle_t *as_uv_handle() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Gets the category of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. By means of this function, an opaque value that identifies the\n * category of the handle is made available to the users.\n *\n * @return The actual category of the handle.\n */\n handle_category category() const noexcept {\n return handle_category{as_uv_handle()->type};\n }\n\n /**\n * @brief Gets the type of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. 
By means of this function, the type of the underlying handle as\n * specified by handle_type is made available to the users.\n *\n * @return The actual type of the handle.\n */\n handle_type type() const noexcept {\n return utilities::guess_handle(category());\n }\n\n /**\n * @brief Checks if the handle is active.\n *\n * What _active_ means depends on the type of handle:\n *\n * * An async_handle handle is always active and cannot be deactivated,\n * except by closing it with uv_close().\n * * A pipe, tcp, udp, etc. handle - basically any handle that deals with\n * I/O - is active when it is doing something that involves I/O, like\n * reading, writing, connecting, accepting new connections, etc.\n * * A check, idle, timer, etc. handle is active when it has been started\n * with a call to `start()`.\n *\n * Rule of thumb: if a handle of type `foo_handle` has a `start()` member\n * method, then it\u2019s active from the moment that method is called. Likewise,\n * `stop()` deactivates the handle again.\n *\n * @return True if the handle is active, false otherwise.\n */\n bool active() const noexcept {\n return !!uv_is_active(as_uv_handle());\n }\n\n /**\n * @brief Checks if a handle is closing or closed.\n *\n * This function should only be used between the initialization of the\n * handle and the arrival of the close callback.\n *\n * @return True if the handle is closing or closed, false otherwise.\n */\n bool closing() const noexcept {\n return !!uv_is_closing(as_uv_handle());\n }\n\n /**\n * @brief Request handle to be closed.\n *\n * This **must** be called on each handle before memory is released.
\n * In-progress requests are cancelled and this can result in errors.\n *\n * The handle will emit a close event when finished.\n */\n void close() noexcept {\n if(!closing()) {\n uv_close(as_uv_handle(), &handle::close_callback);\n }\n }\n\n /**\n * @brief Reference the given handle.\n *\n * References are idempotent, that is, if a handle is already referenced\n * calling this function again will have no effect.\n */\n void reference() noexcept {\n uv_ref(as_uv_handle());\n }\n\n /**\n * @brief Unreference the given handle.\n *\n * References are idempotent, that is, if a handle is not referenced calling\n * this function again will have no effect.\n */\n void unreference() noexcept {\n uv_unref(as_uv_handle());\n }\n\n /**\n * @brief Checks if the given handle referenced.\n * @return True if the handle referenced, false otherwise.\n */\n bool referenced() const noexcept {\n return !!uv_has_ref(as_uv_handle());\n }\n\n /**\n * @brief Returns the size of the underlying handle type.\n * @return The size of the underlying handle type.\n */\n std::size_t size() const noexcept {\n return uv_handle_size(as_uv_handle()->type);\n }\n\n /**\n * @brief Gets the size of the send buffer used for the socket.\n *\n * Gets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipeand udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the send buffer, the underlying return value in case\n * of errors.\n */\n int send_buffer_size() {\n int value = 0;\n auto err = uv_send_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the send buffer used for the socket.\n *\n * Sets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int send_buffer_size(int value) {\n return uv_send_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the size of the receive buffer used for the socket.\n *\n * Gets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the receive buffer, the underlying return value in\n * case of errors.\n */\n int recv_buffer_size() {\n int value = 0;\n auto err = uv_recv_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the receive buffer used for the socket.\n *\n * Sets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int recv_buffer_size(int value) {\n return uv_recv_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the platform dependent file descriptor equivalent.\n *\n * Supported handles:\n *\n * * tcp_handle\n * * pipe_handle\n * * tty_handle\n * * udp_handle\n * * poll_handle\n *\n * If invoked on a different handle, one that doesn\u2019t have an attached file\n * descriptor yet or one which was closed, an invalid value is returned.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/handle.html#c.uv_fileno)\n * for further details.\n *\n * @return The file descriptor attached to the hande or a negative value in\n * case of errors.\n */\n os_file_descriptor fd() const {\n uv_os_fd_t fd;\n uv_fileno(as_uv_handle(), &fd);\n return fd;\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_HANDLE_INCLUDE_H\n\n// Path: src/uvw/prepare.h\n#ifndef UVW_PREPARE_INCLUDE_H\n#define UVW_PREPARE_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Prepare event. */\nstruct prepare_event {};\n\n/**\n * @brief The prepare handle.\n *\n * Prepare handles will emit a prepare event once per loop iteration, right\n * before polling for I/O.\n *\n * To create a `prepare_handle` through a `loop`, no arguments are required.\n */\nclass prepare_handle final: public handle {\n static void start_callback(uv_prepare_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A prepare event will be emitted once per loop iteration, right before\n * polling for I/O.\n *\n * The handle will start emitting prepare events when needed.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"prepare.cpp\"\n#endif\n\n#endif // UVW_PREPARE_INCLUDE_H\n\n// Path: src/uvw/prepare.cpp\n#ifdef UVW_AS_LIB\n# include \"prepare.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void prepare_handle::start_callback(uv_prepare_t *hndl) {\n prepare_handle &prepare = *(static_cast(hndl->data));\n prepare.publish(prepare_event{});\n}\n\nUVW_INLINE int prepare_handle::init() {\n return leak_if(uv_prepare_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int prepare_handle::start() {\n return uv_prepare_start(raw(), &start_callback);\n}\n\nUVW_INLINE int prepare_handle::stop() {\n return uv_prepare_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/request.hpp\n#ifndef UVW_REQUEST_INCLUDE_H\n#define UVW_REQUEST_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Request base class.\n *\n * Base type for all `uvw` request types.\n */\ntemplate\nclass request: public resource {\nprotected:\n static auto reserve(U *req) {\n auto ptr = static_cast(req->data)->shared_from_this();\n ptr->self_reset();\n return ptr;\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Cancels a pending request.\n *\n * This method fails if the request is executing or has finished\n * executing.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/request.html#c.uv_cancel)\n * for further details.\n *\n * @return Underlying return value.\n */\n int cancel() {\n return 
uv_cancel(reinterpret_cast(this->raw()));\n }\n\n /**\n * @brief Returns the size of the underlying request type.\n * @return The size of the underlying request type.\n */\n std::size_t size() const noexcept {\n return uv_req_size(reinterpret_cast(this->raw())->type);\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_REQUEST_INCLUDE_H\n\n// Path: src/uvw/stream.h\n#ifndef UVW_STREAM_INCLUDE_H\n#define UVW_STREAM_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"handle.hpp\"\n#include \"loop.h\"\n#include \"request.hpp\"\n\nnamespace uvw {\n\n/*! @brief Connect event. */\nstruct connect_event {};\n\n/*! @brief End event. */\nstruct end_event {};\n\n/*! @brief Listen event. */\nstruct listen_event {};\n\n/*! @brief Shutdown event. */\nstruct shutdown_event {};\n\n/*! @brief Write event. */\nstruct write_event {};\n\n/*! @brief Data event. */\nstruct data_event {\n explicit data_event(std::unique_ptr buf, std::size_t len) noexcept;\n\n std::unique_ptr data; /*!< A bunch of data read on the stream. */\n std::size_t length; /*!< The amount of data read on the stream. */\n};\n\nnamespace details {\n\nclass connect_req final: public request {\n static void connect_callback(uv_connect_t *req, int status);\n\npublic:\n using request::request;\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n std::forward(f)(raw(), std::forward(args)..., &connect_callback);\n return this->leak_if(0);\n }\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n return this->leak_if(std::forward(f)(raw(), std::forward(args)..., &connect_callback));\n }\n};\n\nclass shutdown_req final: public request {\n static void shoutdown_callback(uv_shutdown_t *req, int status);\n\npublic:\n using request::request;\n\n int shutdown(uv_stream_t *hndl);\n};\n\ntemplate\nclass write_req final: public request, uv_write_t, write_event> {\n static void write_callback(uv_write_t *req, int status) {\n if(auto ptr = request, uv_write_t, write_event>::reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(write_event{});\n }\n }\n\npublic:\n \nwrite_req(loop::token token, std::shared_ptr parent, std::unique_ptr dt, unsigned int len)\n : request, uv_write_t, write_event>{token, std::move(parent)},\n data{std::move(dt)},\n buf{uv_buf_init(data.get(), len)} {}\n\n int write(uv_stream_t *hndl) {\n return this->leak_if(uv_write(this->raw(), hndl, &buf, 1, &write_callback));\n }\n\n int write(uv_stream_t *hndl, uv_stream_t *send) {\n return this->leak_if(uv_write2(this->raw(), hndl, &buf, 1, send, &write_callback));\n }\n\nprivate:\n std::unique_ptr data;\n uv_buf_t buf;\n};\n\n} // namespace details\n\n/**\n * @brief The stream handle.\n *\n * Stream handles provide an abstraction of a duplex communication channel.\n * The stream handle is an intermediate type, `uvw` provides three stream\n * implementations: tcp, pipe and tty handles.\n */\ntemplate\nclass stream_handle: public handle {\n using base = handle;\n\n template\n friend class stream_handle;\n\n static constexpr unsigned int DEFAULT_BACKLOG = 128;\n\n static void read_callback(uv_stream_t *hndl, ssize_t nread, const uv_buf_t *buf) {\n T &ref = *(static_cast(hndl->data));\n // data will be destroyed no matter of what the value of nread is\n std::unique_ptr data{buf->base};\n\n // nread == 0 is ignored (see 
http://docs.libuv.org/en/v1.x/stream.html)\n // equivalent to EAGAIN/EWOULDBLOCK, it shouldn't be treated as an error\n // for we don't have data to emit though, it's fine to suppress it\n\n if(nread == UV_EOF) {\n // end of stream\n ref.publish(end_event{});\n } else if(nread > 0) {\n // data available\n ref.publish(data_event{std::move(data), static_cast(nread)});\n } else if(nread < 0) {\n // transmission error\n ref.publish(error_event(nread));\n }\n }\n\n static void listen_callback(uv_stream_t *hndl, int status) {\n if(T &ref = *(static_cast(hndl->data)); status) {\n ref.publish(error_event{status});\n } else {\n ref.publish(listen_event{});\n }\n }\n\n uv_stream_t *as_uv_stream() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_stream_t *as_uv_stream() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n#ifdef _MSC_VER\n stream_handle(loop::token token, std::shared_ptr ref)\n : base{token, std::move(ref)} {}\n#else\n using base::base;\n#endif\n\n /**\n * @brief Shutdowns the outgoing (write) side of a duplex stream.\n *\n * It waits for pending write requests to complete. The handle should refer\n * to a initialized stream.
\n * A shutdown event will be emitted after shutdown is complete.\n *\n * @return Underlying return value.\n */\n int shutdown() {\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n auto shutdown = this->parent().template resource();\n shutdown->template on(listener);\n shutdown->template on(listener);\n\n return shutdown->shutdown(as_uv_stream());\n }\n\n /**\n * @brief Starts listening for incoming connections.\n *\n * When a new incoming connection is received, a listen event is\n * emitted.\n *\n * @param backlog Indicates the number of connections the kernel might\n * queue, same as listen(2).\n *\n * @return Underlying return value.\n */\n int listen(int backlog = DEFAULT_BACKLOG) {\n return uv_listen(as_uv_stream(), backlog, &listen_callback);\n }\n\n /**\n * @brief Accepts incoming connections.\n *\n * This call is used in conjunction with `listen()` to accept incoming\n * connections. Call this function after receiving a listen event to accept\n * the connection. Before calling this function, the submitted handle must\n * be initialized.\n *\n * When the listen event is emitted it is guaranteed that this function will\n * complete successfully the first time. If you attempt to use it more than\n * once, it may fail.
\n * It is suggested to only call this function once per listen event.\n *\n * @note\n * Both the handles must be running on the same loop.\n *\n * @param ref An initialized handle to be used to accept the connection.\n * @return Underlying return value.\n */\n template\n int accept(S &ref) {\n return uv_accept(as_uv_stream(), ref.as_uv_stream());\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n *\n * A data event will be emitted several times until there is no more data to\n * read or `stop()` is called.
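The write request defined earlier in this header surfaces through `write()`: ownership of the buffer moves into the request, and completion or failure comes back as a write or error event. A short sketch, with `client` assumed to be an already connected `uvw::tcp_handle`:

```cpp
client->on([](const uvw::write_event &, uvw::tcp_handle &) {
    // The buffer handed over below has been flushed to the stream.
});

client->on([](const uvw::error_event &, uvw::tcp_handle &) {
    // Failure path published by the write request.
});

// Requires  and .
auto data = std::make_unique(5);
std::memcpy(data.get(), "hello", 5);
client->write(std::move(data), 5); // the request takes ownership of the buffer
```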
\n * An end event will be emitted when there is no more data to read.\n *\n * @return Underlying return value.\n */\n int read() {\n return uv_read_start(as_uv_stream(), &details::common_alloc_callback, &read_callback);\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n * @sa read\n * @tparam Alloc Custom allocation function.\n * @return Underlying return value.\n */\n template\n int read() {\n return uv_read_start(as_uv_stream(), &details::common_alloc_callback, &read_callback);\n }\n\n /**\n * @brief Stops reading data from the stream.\n *\n * This function is idempotent and may be safely called on a stopped stream.\n *\n * @return Underlying return value.\n */\n int stop() {\n return uv_read_stop(as_uv_stream());\n }\n\n /**\n * @brief Writes data to the stream.\n *\n * Data are written in order. The handle takes the ownership of the data and\n * it is in charge of delete them.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(std::unique_ptr data, unsigned int len) {\n auto req = this->parent().template resource>(std::move(data), len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream());\n }\n\n /**\n * @brief Writes data to the stream.\n *\n * Data are written in order. The handle doesn't take the ownership of the\n * data. Be sure that their lifetime overcome the one of the request.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int write(char *data, unsigned int len) {\n auto req = this->parent().template resource>(std::unique_ptr{data, [](char *) {}}, len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream());\n }\n\n /**\n * @brief Extended write function for sending handles over a pipe handle.\n *\n * The pipe must be initialized with `ipc == true`.\n *\n * `send` must be a tcp or pipe handle, which is a server or a connection\n * (listening or connected state). 
Bound sockets or pipes will be assumed to\n * be servers.\n *\n * The handle takes the ownership of the data and it is in charge of delete\n * them.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param send The handle over which to write data.\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(S &send, std::unique_ptr data, unsigned int len) {\n auto req = this->parent().template resource>(std::move(data), len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream(), send.as_uv_stream());\n }\n\n /**\n * @brief Extended write function for sending handles over a pipe handle.\n *\n * The pipe must be initialized with `ipc == true`.\n *\n * `send` must be a tcp or pipe handle, which is a server or a connection\n * (listening or connected state). Bound sockets or pipes will be assumed to\n * be servers.\n *\n * The handle doesn't take the ownership of the data. Be sure that their\n * lifetime overcome the one of the request.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param send The handle over which to write data.\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(S &send, char *data, unsigned int len) {\n auto req = this->parent().template resource>(std::unique_ptr{data, [](char *) {}}, len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream(), send.as_uv_stream());\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `write()`, but won\u2019t queue a write request if it can\u2019t be\n * completed immediately.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_write(std::unique_ptr data, unsigned int len) {\n uv_buf_t bufs[] = {uv_buf_init(data.get(), len)};\n return uv_try_write(as_uv_stream(), bufs, 1);\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `try_write` for sending handles over a pipe.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @param send A valid handle suitable for the purpose.\n * @return Underlying return value.\n */\n template\n int try_write(std::unique_ptr data, unsigned int len, stream_handle &send) {\n uv_buf_t bufs[] = {uv_buf_init(data.get(), len)};\n return uv_try_write2(as_uv_stream(), bufs, 1, send.raw());\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `write()`, but won\u2019t queue a write request if it can\u2019t be\n * completed immediately.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_write(char *data, unsigned int len) {\n uv_buf_t bufs[] = {uv_buf_init(data, len)};\n return uv_try_write(as_uv_stream(), bufs, 1);\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `try_write` 
for sending handles over a pipe.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @param send A valid handle suitable for the purpose.\n * @return Underlying return value.\n */\n template\n int try_write(char *data, unsigned int len, stream_handle &send) {\n uv_buf_t bufs[] = {uv_buf_init(data, len)};\n return uv_try_write2(as_uv_stream(), bufs, 1, send.raw());\n }\n\n /**\n * @brief Checks if the stream is readable.\n * @return True if the stream is readable, false otherwise.\n */\n bool readable() const noexcept {\n return (uv_is_readable(as_uv_stream()) == 1);\n }\n\n /**\n * @brief Checks if the stream is writable.\n * @return True if the stream is writable, false otherwise.\n */\n bool writable() const noexcept {\n return (uv_is_writable(as_uv_stream()) == 1);\n }\n\n /**\n * @brief Enables or disables blocking mode for a stream.\n *\n * When blocking mode is enabled all writes complete synchronously. The\n * interface remains unchanged otherwise, e.g. completion or failure of the\n * operation will still be reported through events which are emitted\n * asynchronously.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/stream.html#c.uv_stream_set_blocking)\n * for further details.\n *\n * @param enable True to enable blocking mode, false otherwise.\n * @return True in case of success, false otherwise.\n */\n bool blocking(bool enable = false) {\n return (0 == uv_stream_set_blocking(as_uv_stream(), enable));\n }\n\n /**\n * @brief Gets the amount of queued bytes waiting to be sent.\n * @return Amount of queued bytes waiting to be sent.\n */\n size_t write_queue_size() const noexcept {\n return uv_stream_get_write_queue_size(as_uv_stream());\n }\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"stream.cpp\"\n#endif\n\n#endif // UVW_STREAM_INCLUDE_H\n\n// Path: src/uvw/stream.cpp\n#ifdef UVW_AS_LIB\n# include \"stream.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE data_event::data_event(std::unique_ptr buf, std::size_t len) noexcept\n : data{std::move(buf)},\n length{len} {}\n\nUVW_INLINE void details::connect_req::connect_callback(uv_connect_t *req, int status) {\n if(auto ptr = reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(connect_event{});\n }\n}\n\nUVW_INLINE void details::shutdown_req::shoutdown_callback(uv_shutdown_t *req, int status) {\n if(auto ptr = reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(shutdown_event{});\n }\n}\n\nUVW_INLINE int details::shutdown_req::shutdown(uv_stream_t *hndl) {\n return this->leak_if(uv_shutdown(raw(), hndl, &shoutdown_callback));\n}\n\n} // namespace uvw\n\n// Path: src/uvw/check.h\n#ifndef UVW_CHECK_INCLUDE_H\n#define UVW_CHECK_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Check event. 
*/\nstruct check_event {};\n\n/**\n * @brief The check handle.\n *\n * Check handles will emit a check event once per loop iteration, right after\n * polling for I/O.\n *\n * To create a `check_handle` through a `loop`, no arguments are required.\n */\nclass check_handle final: public handle {\n static void start_callback(uv_check_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A check event will be emitted once per loop iteration, right after\n * polling for I/O.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"check.cpp\"\n#endif\n\n#endif // UVW_CHECK_INCLUDE_H\n\n// Path: src/uvw/check.cpp\n#ifdef UVW_AS_LIB\n# include \"check.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void check_handle::start_callback(uv_check_t *hndl) {\n check_handle &check = *(static_cast(hndl->data));\n check.publish(check_event{});\n}\n\nUVW_INLINE int check_handle::init() {\n return leak_if(uv_check_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int check_handle::start() {\n return uv_check_start(raw(), &start_callback);\n}\n\nUVW_INLINE int check_handle::stop() {\n return uv_check_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/enum.hpp\n#ifndef UVW_ENUM_INCLUDE_HPP\n#define UVW_ENUM_INCLUDE_HPP\n\n#include \n#include \"config.h\"\n\n/**\n * @brief Operator available for enums for which bitmask support is enabled.\n * @tparam Type Enum class type.\n * @param lhs The first value to use.\n * @param rhs The second value to use.\n * @return The result of invoking the operator on the underlying types of the\n * two values provided.\n */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator|(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) | static_cast>(rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator&(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) & static_cast>(rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator^(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) ^ static_cast>(rhs));\n}\n\n/**\n * @brief Operator available for enums for which bitmask support is enabled.\n * @tparam Type Enum class type.\n * @param value The value to use.\n * @return The result of invoking the operator on the underlying types of the\n * value provided.\n */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator~(const Type value) noexcept {\n return static_cast(~static_cast>(value));\n}\n\n/*! @copydoc operator~ */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM, bool{})>\noperator!(const Type value) noexcept {\n return !static_cast>(value);\n}\n\n/*! @copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator|=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs | rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator&=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs & rhs));\n}\n\n/*! 
@copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator^=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs ^ rhs));\n}\n\n#endif\n\n// Path: src/uvw/thread.h\n#ifndef UVW_THREAD_INCLUDE_H\n#define UVW_THREAD_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"loop.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_thread_create_flags : std::underlying_type_t {\n THREAD_NO_FLAGS = UV_THREAD_NO_FLAGS,\n THREAD_HAS_STACK_SIZE = UV_THREAD_HAS_STACK_SIZE\n};\n\n}\n\nclass thread;\nclass thread_local_storage;\nclass once;\nclass mutex;\nclass rwlock;\nclass semaphore;\nclass condition;\nclass barrier;\n\n/**\n * @brief The thread wrapper.\n *\n * To create a `thread` through a `loop`, arguments follow:\n *\n * * A callback invoked to initialize thread execution. The type must be such\n * that it can be assigned to an `std::function)>`.\n * * An optional payload the type of which is `std::shared_ptr`.\n */\nclass thread final: public uv_type {\n using internal_task = std::function)>;\n\n static void create_callback(void *arg);\n\npublic:\n using create_flags = details::uvw_thread_create_flags;\n using task = internal_task;\n using type = uv_thread_t;\n\n explicit thread(loop::token token, std::shared_ptr ref, task t, std::shared_ptr d = nullptr) noexcept;\n\n /**\n * @brief Obtains the identifier of the calling thread.\n * @return The identifier of the calling thread.\n */\n static type self() noexcept;\n\n /**\n * @brief Gets the CPU number on which the calling thread is running.\n * @return The CPU number on which the calling thread is running.\n */\n static int getcpu() noexcept;\n\n /**\n * @brief Compares thread by means of their identifiers.\n * @param tl A valid instance of a thread.\n * @param tr A valid instance of a thread.\n * @return True if the two threads are the same thread, false otherwise.\n */\n static bool equal(const thread &tl, const thread &tr) noexcept;\n\n ~thread() noexcept;\n\n /**\n * @brief Creates a new thread.\n * @return True in case of success, false otherwise.\n */\n bool run() noexcept;\n\n /**\n * @brief Creates a new thread.\n *\n * Available flags are:\n *\n * * `thread::create_flags::THREAD_NO_FLAGS`: no flags set.\n * * `thread::create_flags::THREAD_HAS_STACK_SIZE`: if set, `stack` specifies a\n * stack size for the new thread. 0 indicates that the default value should\n * be used (it behaves as if the flag was not set). Other values will be\n * rounded up to the nearest page boundary.\n *\n * @return True in case of success, false otherwise.\n */\n bool run(create_flags opts, std::size_t stack = {}) noexcept;\n\n /**\n * @brief Joins with a terminated thread.\n * @return True in case of success, false otherwise.\n */\n bool join() noexcept;\n\nprivate:\n std::shared_ptr data;\n task func;\n};\n\n/**\n * @brief The thread local storage wrapper.\n *\n * A storage area that can only be accessed by one thread. 
The variable can be\n * seen as a global variable that is only visible to a particular thread and not\n * the whole program.\n */\nclass thread_local_storage final: public uv_type {\npublic:\n explicit thread_local_storage(loop::token token, std::shared_ptr ref) noexcept;\n\n ~thread_local_storage() noexcept;\n\n /**\n * @brief Gets the value of a given variable.\n * @tparam T Type to which to cast the opaque storage area.\n * @return A pointer to the given variable.\n */\n template\n T *get() noexcept {\n return static_cast(uv_key_get(uv_type::raw()));\n }\n\n /**\n * @brief Sets the value of a given variable.\n * @tparam T Type of the variable to store aside.\n * @param value A valid pointer to the variable to store\n */\n template\n void set(T *value) noexcept {\n return uv_key_set(uv_type::raw(), value);\n }\n};\n\n/**\n * @brief The once wrapper.\n *\n * Runs a function once and only once. Concurrent calls to `once` will block all\n * callers except one (it\u2019s unspecified which one).\n */\nclass once final: public uv_type {\n static uv_once_t *guard() noexcept;\n\npublic:\n using uv_type::uv_type;\n\n /**\n * @brief Runs a function once and only once.\n *\n * The callback must be such that it's convertible to `void(*)(void)`. Free\n * functions and non-capturing lambdas are both viable solutions.\n *\n * @tparam F Type of the callback.\n * @param f A valid callback function.\n */\n template\n static void run(F &&f) noexcept {\n using callback_type = void (*)(void);\n static_assert(std::is_convertible_v);\n callback_type cb = f;\n uv_once(guard(), cb);\n }\n};\n\n/**\n * @brief The mutex wrapper.\n *\n * To create a `mutex` through a `loop`, arguments follow:\n *\n * * An option boolean that specifies if the mutex is a recursive one. The\n * default value is false, the mutex isn't recursive.\n */\nclass mutex final: public uv_type {\n friend class condition;\n\npublic:\n explicit mutex(loop::token token, std::shared_ptr ref, bool recursive = false) noexcept;\n\n ~mutex() noexcept;\n\n /**\n * @brief Locks the mutex.\n */\n void lock() noexcept;\n\n /**\n * @brief Tries to lock the mutex.\n * @return True in case of success, false otherwise.\n */\n bool try_lock() noexcept;\n\n /**\n * @brief Unlocks the mutex.\n */\n void unlock() noexcept;\n};\n\n/**\n * @brief The rwlock wrapper.\n */\nclass rwlock final: public uv_type {\npublic:\n explicit rwlock(loop::token token, std::shared_ptr ref) noexcept;\n\n ~rwlock() noexcept;\n\n /**\n * @brief Locks a read-write lock object for reading.\n */\n void rdlock() noexcept;\n\n /**\n * @brief Tries to lock a read-write lock object for reading.\n * @return True in case of success, false otherwise.\n */\n bool try_rdlock() noexcept;\n\n /**\n * @brief Unlocks a read-write lock object previously locked for reading.\n */\n void rdunlock() noexcept;\n\n /**\n * @brief Locks a read-write lock object for writing.\n */\n void wrlock() noexcept;\n\n /**\n * @brief Tries to lock a read-write lock object for writing.\n * @return True in case of success, false otherwise.\n */\n bool try_wrlock() noexcept;\n\n /**\n * @brief Unlocks a read-write lock object previously locked for writing.\n */\n void wrunlock() noexcept;\n};\n\n/**\n * @brief The semaphore wrapper.\n *\n * To create a `semaphore` through a `loop`, arguments follow:\n *\n * * An unsigned integer that specifies the initial value for the semaphore.\n */\nclass semaphore final: public uv_type {\npublic:\n explicit semaphore(loop::token token, std::shared_ptr ref, unsigned int value) 
noexcept;\n\n ~semaphore() noexcept;\n\n /**\n * @brief Unlocks a semaphore.\n */\n void post() noexcept;\n\n /**\n * @brief Locks a semaphore.\n */\n void wait() noexcept;\n\n /**\n * @brief Tries to lock a semaphore.\n * @return True in case of success, false otherwise.\n */\n bool try_wait() noexcept;\n};\n\n/**\n * @brief The condition wrapper.\n */\nclass condition final: public uv_type {\npublic:\n explicit condition(loop::token token, std::shared_ptr ref) noexcept;\n\n ~condition() noexcept;\n\n /**\n * @brief Signals a condition.\n *\n * This function shall unblock at least one of the threads that are blocked\n * on the specified condition variable (if any threads are blocked on it).\n */\n void signal() noexcept;\n\n /**\n * @brief Broadcasts a condition.\n *\n * This function shall unblock threads blocked on a condition variable.\n */\n void broadcast() noexcept;\n\n /**\n * @brief Waits on a condition.\n *\n * These function atomically releases the mutex and causes the calling\n * thread to block on the condition variable.\n *\n * @param mtx A mutex locked by the calling thread, otherwise expect\n * undefined behavior.\n */\n void wait(mutex &mtx) noexcept;\n\n /**\n * @brief Waits on a condition.\n *\n * These function atomically releases the mutex and causes the calling\n * thread to block on the condition variable.
\n * The functions returns with an error if the absolute time specified passes\n * (that is, system time equals or exceeds it) before the condition is\n * signaled or broadcasted, or if the absolute time specified has already\n * been passed at the time of the call.\n *\n * @param mtx A mutex locked by the calling thread, otherwise expect\n * undefined behavior.\n * @param timeout The maximum time to wait before to return.\n * @return True in case of success, false otherwise.\n */\n bool timed_wait(mutex &mtx, uint64_t timeout) noexcept;\n};\n\n/**\n * @brief The barrier wrapper.\n *\n * To create a `barrier` through a `loop`, arguments follow:\n *\n * * An unsigned integer that specifies the number of threads that must call\n * `wait` before any of them successfully return from the call. The value\n * specified must be greater than zero.\n */\nclass barrier final: public uv_type {\npublic:\n explicit barrier(loop::token token, std::shared_ptr ref, unsigned int count) noexcept;\n\n ~barrier() noexcept;\n\n /**\n * @brief Synchronizes at a barrier.\n * @return True in case of success, false otherwise.\n */\n bool wait() noexcept;\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"thread.cpp\"\n#endif\n\n#endif // UVW_THREAD_INCLUDE_H\n\n// Path: src/uvw/thread.cpp\n#ifdef UVW_AS_LIB\n# include \"thread.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE thread::thread(loop::token token, std::shared_ptr ref, task t, std::shared_ptr d) noexcept\n : uv_type{token, std::move(ref)},\n data{std::move(d)},\n func{std::move(t)} {}\n\nUVW_INLINE void thread::create_callback(void *arg) {\n thread &curr = *(static_cast(arg));\n curr.func(curr.data);\n}\n\nUVW_INLINE thread::type thread::self() noexcept {\n return uv_thread_self();\n}\n\nUVW_INLINE int thread::getcpu() noexcept {\n return uv_thread_getcpu();\n}\n\nUVW_INLINE bool thread::equal(const thread &tl, const thread &tr) noexcept {\n return !(0 == uv_thread_equal(tl.raw(), tr.raw()));\n}\n\nUVW_INLINE thread::~thread() noexcept {\n join();\n}\n\nUVW_INLINE bool thread::run() noexcept {\n return (0 == uv_thread_create(raw(), &create_callback, this));\n}\n\nUVW_INLINE bool thread::run(create_flags opts, std::size_t stack) noexcept {\n uv_thread_options_t params{static_cast(opts), stack};\n return (0 == uv_thread_create_ex(raw(), ¶ms, &create_callback, this));\n}\n\nUVW_INLINE bool thread::join() noexcept {\n return (0 == uv_thread_join(raw()));\n}\n\nUVW_INLINE thread_local_storage::thread_local_storage(loop::token token, std::shared_ptr ref) noexcept\n : uv_type{token, std::move(ref)} {\n uv_key_create(uv_type::raw());\n}\n\nUVW_INLINE thread_local_storage::~thread_local_storage() noexcept {\n uv_key_delete(uv_type::raw());\n}\n\nUVW_INLINE uv_once_t *once::guard() noexcept {\n static uv_once_t once = UV_ONCE_INIT;\n return &once;\n}\n\nUVW_INLINE mutex::mutex(loop::token token, std::shared_ptr ref, bool recursive) noexcept\n : uv_type{token, std::move(ref)} {\n if(recursive) {\n uv_mutex_init_recursive(raw());\n } else {\n uv_mutex_init(raw());\n }\n}\n\nUVW_INLINE mutex::~mutex() noexcept {\n uv_mutex_destroy(raw());\n}\n\nUVW_INLINE void mutex::lock() noexcept {\n uv_mutex_lock(raw());\n}\n\nUVW_INLINE bool mutex::try_lock() noexcept {\n return (0 == uv_mutex_trylock(raw()));\n}\n\nUVW_INLINE void mutex::unlock() noexcept {\n uv_mutex_unlock(raw());\n}\n\nUVW_INLINE rwlock::rwlock(loop::token token, std::shared_ptr ref) noexcept\n : uv_type{token, std::move(ref)} {\n uv_rwlock_init(raw());\n}\n\nUVW_INLINE 
rwlock::~rwlock() noexcept {\n uv_rwlock_destroy(raw());\n}\n\nUVW_INLINE void rwlock::rdlock() noexcept {\n uv_rwlock_rdlock(raw());\n}\n\nUVW_INLINE bool rwlock::try_rdlock() noexcept {\n return (0 == uv_rwlock_tryrdlock(raw()));\n}\n\nUVW_INLINE void rwlock::rdunlock() noexcept {\n uv_rwlock_rdunlock(raw());\n}\n\nUVW_INLINE void rwlock::wrlock() noexcept {\n uv_rwlock_wrlock(raw());\n}\n\nUVW_INLINE bool rwlock::try_wrlock() noexcept {\n...\n", "needle_token_start": 5567, "needle_token_end": 5664, "code_context_ntokens": 16023, "output": ["```cpp\ntemplate\nclass write_req final: public request, uv_write_t, write_event> {\n static void write_callback(uv_write_t *req, int status) {\n if(auto ptr = request, uv_write_t, write_event>::reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(write_event{});\n }\n }\n\npublic:\n write_req(loop::token token, std::shared_ptr parent, std::unique_ptr dt, unsigned int len)\n : request, uv_write_t, write_event>{token, std::move(parent)},\n data{std::move(dt)},\n buf{uv_buf_init(data.get(), len)} {}\n\n int write(uv_stream_t *hndl) {\n return this->leak_if(uv_write(this->raw(), hndl, &buf, 1, &write_callback));\n }\n\n int write(uv_stream_t *hndl, uv_stream_t *send) {\n return this->leak_if(uv_write2(this->raw(), hndl, &buf, 1, send, &write_callback));\n }\n\nprivate:\n std::unique_ptr data;\n uv_buf_t buf;\n};\n```"]} +{"repo": "skypjack/uvw", "name": "recv", "language": "cpp", "path": "src/uvw/udp.h", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to prepare the UDP socket for receiving data. If the socket is not already bound, it automatically binds to an all-interfaces IPv4 address with a random port.\n2. **Input**: There are no parameters required for this function.\n3. **Output**: Returns an integer value that represents the success or failure of the operation based on the underlying system call.\n4. **Procedure**: The function checks if the socket is bound; if not, it binds the socket to a default address and port. It then sets up the socket to start receiving data, which will trigger a data event upon data arrival.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/tty.h\n#ifndef UVW_TTY_INCLUDE_H\n#define UVW_TTY_INCLUDE_H\n\n#include \n#include \n#include \n#include \"config.h\"\n#include \"stream.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nstruct reset_mode_memo {\n ~reset_mode_memo();\n};\n\nenum class uvw_tty_mode_t : std::underlying_type_t {\n NORMAL = UV_TTY_MODE_NORMAL,\n RAW = UV_TTY_MODE_RAW,\n IO = UV_TTY_MODE_IO\n};\n\nenum class uvw_tty_vtermstate_t : std::underlying_type_t {\n SUPPORTED = UV_TTY_SUPPORTED,\n UNSUPPORTED = UV_TTY_UNSUPPORTED\n};\n\n} // namespace details\n\n/**\n * @brief The tty handle.\n *\n * TTY handles represent a stream for the console.\n *\n * To create a `tty_handle` through a `loop`, arguments follow:\n *\n * * A valid file_handle. Usually the file descriptor will be:\n * * `uvw::std_in` or `0` for `stdin`\n * * `uvw::std_out` or `1` for `stdout`\n * * `uvw::std_err` or `2` for `stderr`\n * * A boolean value that specifies the plan on calling `read()` with this\n * stream. 
Remember that `stdin` is readable, `stdout` is not.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/tty.html#c.uv_tty_init)\n * for further details.\n */\nclass tty_handle final: public stream_handle {\n static std::shared_ptr mode_memo_handler();\n\npublic:\n using tty_mode = details::uvw_tty_mode_t;\n using tty_vtermstate = details::uvw_tty_vtermstate_t;\n\n explicit tty_handle(loop::token token, std::shared_ptr ref, file_handle desc, bool readable);\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Sets the TTY using the specified terminal mode.\n *\n * Available modes are:\n *\n * * `TTY::tty_mode::NORMAL`\n * * `TTY::tty_mode::RAW`\n * * `TTY::tty_mode::IO`\n *\n * See the official\n...\n// Path: src/uvw/tty.cpp\n#ifdef UVW_AS_LIB\n# include \"tty.h\"\n#endif\n\n#include \n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE details::reset_mode_memo::~reset_mode_memo() {\n uv_tty_reset_mode();\n}\n\nUVW_INLINE tty_handle::tty_handle(loop::token token, std::shared_ptr ref, file_handle desc, bool readable)\n : stream_handle{token, std::move(ref)},\n memo{mode_memo_handler()},\n fd{desc},\n rw{readable} {}\n\nUVW_INLINE std::shared_ptr tty_handle::mode_memo_handler() {\n static std::weak_ptr weak;\n auto shared = weak.lock();\n if(!shared) { weak = shared = std::make_shared(); }\n return shared;\n}\n\nUVW_INLINE int tty_handle::init() {\n return leak_if(uv_tty_init(parent().raw(), raw(), fd, rw));\n}\n\nUVW_INLINE bool tty_handle::mode(tty_handle::tty_mode m) {\n return (0 == uv_tty_set_mode(raw(), static_cast(m)));\n}\n\nUVW_INLINE bool tty_handle::reset_mode() noexcept {\n return (0 == uv_tty_reset_mode());\n}\n\nUVW_INLINE win_size tty_handle::get_win_size() {\n win_size size;\n\n if(0 != uv_tty_get_winsize(raw(), &size.width, &size.height)) {\n size.width = -1;\n size.height = -1;\n }\n\n return size;\n}\n\nUVW_INLINE void tty_handle::vterm_state(tty_handle::tty_vtermstate s) const noexcept {\n switch(s) {\n case tty_vtermstate::SUPPORTED:\n uv_tty_set_vterm_state(uv_tty_vtermstate_t::UV_TTY_SUPPORTED);\n break;\n case tty_vtermstate::UNSUPPORTED:\n uv_tty_set_vterm_state(uv_tty_vtermstate_t::UV_TTY_UNSUPPORTED);\n break;\n }\n}\n\nUVW_INLINE tty_handle::tty_vtermstate tty_handle::vterm_state() const noexcept {\n uv_tty_vtermstate_t state;\n uv_tty_get_vterm_state(&state);\n return tty_vtermstate{state};\n}\n\n} // namespace uvw\n\n// Path: src/uvw/udp.h\n#ifndef UVW_UDP_INCLUDE_H\n#define UVW_UDP_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"handle.hpp\"\n#include \"request.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\n/*! @brief Send event. */\nstruct send_event {};\n\n/*! @brief UDP data event. */\nstruct udp_data_event {\n explicit udp_data_event(socket_address sndr, std::unique_ptr buf, std::size_t len, bool part) noexcept;\n\n std::unique_ptr data; /*!< A bunch of data read on the stream. */\n std::size_t length; /*!< The amount of data read on the stream. */\n socket_address sender; /*!< A valid instance of socket_address. */\n bool partial; /*!< True if the message was truncated, false otherwise. 
*/\n};\n\nnamespace details {\n\nenum class uvw_udp_flags : std::underlying_type_t {\n IPV6ONLY = UV_UDP_IPV6ONLY,\n UDP_PARTIAL = UV_UDP_PARTIAL,\n REUSEADDR = UV_UDP_REUSEADDR,\n UDP_MMSG_CHUNK = UV_UDP_MMSG_CHUNK,\n UDP_MMSG_FREE = UV_UDP_MMSG_FREE,\n UDP_LINUX_RECVERR = UV_UDP_LINUX_RECVERR,\n UDP_RECVMMSG = UV_UDP_RECVMMSG,\n _UVW_ENUM = 0\n};\n\nenum class uvw_membership : std::underlying_type_t {\n LEAVE_GROUP = UV_LEAVE_GROUP,\n JOIN_GROUP = UV_JOIN_GROUP\n};\n\nclass send_req final: public request {\n static void udp_send_callback(uv_udp_send_t *req, int status);\n\npublic:\n using deleter = void (*)(char *);\n\n send_req(loop::token token, std::shared_ptr parent, std::unique_ptr dt, unsigned int len);\n\n int send(uv_udp_t *hndl, const struct sockaddr *addr);\n\nprivate:\n std::unique_ptr data;\n uv_buf_t buf;\n};\n\n} // namespace details\n\n/**\n * @brief The UDP handle.\n *\n * UDP handles encapsulate UDP communication for both clients and servers.
\n * By default, _ipv4_ is used as a template parameter. The handle already\n * supports _IPv6_ out-of-the-box by using `uvw::ipv6`.\n *\n * To create an `udp_handle` through a `loop`, arguments follow:\n *\n * * An optional integer value that indicates optional flags used to initialize\n * the socket.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/udp.html#c.uv_udp_init_ex)\n * for further details.\n */\nclass udp_handle final: public handle {\n static void recv_callback(uv_udp_t *hndl, ssize_t nread, const uv_buf_t *buf, const sockaddr *addr, unsigned flags);\n\npublic:\n using membership = details::uvw_membership;\n using udp_flags = details::uvw_udp_flags;\n using ipv4 = uvw::ipv4;\n using ipv6 = uvw::ipv6;\n\n explicit udp_handle(loop::token token, std::shared_ptr ref, unsigned int f = {});\n\n /**\n * @brief Initializes the handle. The actual socket is created lazily.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Opens an existing file descriptor or SOCKET as a UDP handle.\n *\n * The passed file descriptor or SOCKET is not checked for its type, but\n * it\u2019s required that it represents a valid datagram socket.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/udp.html#c.uv_udp_open)\n * for further details.\n *\n * @param socket A valid socket handle (either a file descriptor or a\n * SOCKET).\n *\n * @return Underlying return value.\n */\n int open(os_socket_handle socket);\n\n /**\n * @brief Associates the handle to a remote address and port (either IPv4 or\n * IPv6).\n *\n * Every message sent by this handle is automatically sent to the given\n * destination.
\n * Trying to call this function on an already connected handle isn't\n * allowed.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @return Underlying return value.\n */\n int connect(const sockaddr &addr);\n\n /**\n * @brief Associates the handle to a remote address and port (either IPv4 or\n * IPv6).\n *\n * Every message sent by this handle is automatically sent to the given\n * destination.
\n * Trying to call this function on an already connected handle isn't\n * allowed.\n *\n * @param ip The address to which to bind.\n * @param port The port to which to bind.\n * @return Underlying return value.\n */\n int connect(const std::string &ip, unsigned int port);\n\n /**\n * @brief Associates the handle to a remote address and port (either IPv4 or\n * IPv6).\n *\n * Every message sent by this handle is automatically sent to the given\n * destination.
\n * Trying to call this function on an already connected handle isn't\n * allowed.\n *\n * @param addr A valid instance of socket_address.\n * @return Underlying return value.\n */\n int connect(socket_address addr);\n\n /**\n * @brief Disconnects the handle.\n *\n * Trying to disconnect a handle that is not connected isn't allowed.\n *\n * @return Underlying return value.\n */\n int disconnect();\n\n /**\n * @brief Gets the remote address to which the handle is connected, if any.\n * @return A valid instance of socket_address, an empty one in case of\n * errors.\n */\n socket_address peer() const noexcept;\n\n /**\n * @brief Binds the UDP handle to an IP address and port.\n *\n * Available flags are:\n *\n * * `udp_handle::udp_flags::IPV6ONLY`\n * * `udp_handle::udp_flags::UDP_PARTIAL`\n * * `udp_handle::udp_flags::REUSEADDR`\n * * `udp_handle::udp_flags::UDP_MMSG_CHUNK`\n * * `udp_handle::udp_flags::UDP_MMSG_FREE`\n * * `udp_handle::udp_flags::UDP_LINUX_RECVERR`\n * * `udp_handle::udp_flags::UDP_RECVMMSG`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/udp.html#c.uv_udp_flags)\n * for further details.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @param opts Optional additional flags.\n * @return Underlying return value.\n */\n int bind(const sockaddr &addr, udp_flags opts = udp_flags::_UVW_ENUM);\n\n /**\n * @brief Binds the UDP handle to an IP address and port.\n *\n * Available flags are:\n *\n * * `udp_handle::udp_flags::IPV6ONLY`\n * * `udp_handle::udp_flags::UDP_PARTIAL`\n * * `udp_handle::udp_flags::REUSEADDR`\n * * `udp_handle::udp_flags::UDP_MMSG_CHUNK`\n * * `udp_handle::udp_flags::UDP_MMSG_FREE`\n * * `udp_handle::udp_flags::UDP_LINUX_RECVERR`\n * * `udp_handle::udp_flags::UDP_RECVMMSG`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/udp.html#c.uv_udp_flags)\n * for further details.\n *\n * @param ip The IP address to which to bind.\n * @param port The port to which to bind.\n * @param opts Optional additional flags.\n * @return Underlying return value.\n */\n int bind(const std::string &ip, unsigned int port, udp_flags opts = udp_flags::_UVW_ENUM);\n\n /**\n * @brief Binds the UDP handle to an IP address and port.\n *\n * Available flags are:\n *\n * * `udp_handle::udp_flags::IPV6ONLY`\n * * `udp_handle::udp_flags::UDP_PARTIAL`\n * * `udp_handle::udp_flags::REUSEADDR`\n * * `udp_handle::udp_flags::UDP_MMSG_CHUNK`\n * * `udp_handle::udp_flags::UDP_MMSG_FREE`\n * * `udp_handle::udp_flags::UDP_LINUX_RECVERR`\n * * `udp_handle::udp_flags::UDP_RECVMMSG`\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/udp.html#c.uv_udp_flags)\n * for further details.\n *\n * @param addr A valid instance of socket_address.\n * @param opts Optional additional flags.\n * @return Underlying return value.\n */\n int bind(socket_address addr, udp_flags opts = udp_flags::_UVW_ENUM);\n\n /**\n * @brief Get the local IP and port of the UDP handle.\n * @return A valid instance of socket_address, an empty one in case of\n * errors.\n */\n socket_address sock() const noexcept;\n\n /**\n * @brief Sets membership for a multicast address.\n *\n * Available values for `ms` are:\n *\n * * `udp_handle::membership::LEAVE_GROUP`\n * * `udp_handle::membership::JOIN_GROUP`\n *\n * @param multicast Multicast address to set membership for.\n * @param iface Interface address.\n * @param ms Action to be performed.\n * @return True in case of success, false otherwise.\n */\n bool multicast_membership(const 
std::string &multicast, const std::string &iface, membership ms);\n\n /**\n * @brief Sets IP multicast loop flag.\n *\n * This makes multicast packets loop back to local sockets.\n *\n * @param enable True to enable multicast loop, false otherwise.\n * @return True in case of success, false otherwise.\n */\n bool multicast_loop(bool enable = true);\n\n /**\n * @brief Sets the multicast ttl.\n * @param val A value in the range `[1, 255]`.\n * @return True in case of success, false otherwise.\n */\n bool multicast_ttl(int val);\n\n /**\n * @brief Sets the multicast interface to send or receive data on.\n * @param iface Interface address.\n * @return True in case of success, false otherwise.\n */\n bool multicast_interface(const std::string &iface);\n\n /**\n * @brief Sets broadcast on or off.\n * @param enable True to set broadcast on, false otherwise.\n * @return True in case of success, false otherwise.\n */\n bool broadcast(bool enable = false);\n\n /**\n * @brief Sets the time to live.\n * @param val A value in the range `[1, 255]`.\n * @return True in case of success, false otherwise.\n */\n bool ttl(int val);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * will be bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a\n * random port number.\n *\n * The handle takes the ownership of the data and it is in charge of delete\n * them.\n *\n * A send event will be emitted when the data have been sent.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int send(const sockaddr &addr, std::unique_ptr data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * will be bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a\n * random port number.\n *\n * The handle takes the ownership of the data and it is in charge of delete\n * them.\n *\n * A send event will be emitted when the data have been sent.\n *\n * @param ip The address to which to send data.\n * @param port The port to which to send data.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int send(const std::string &ip, unsigned int port, std::unique_ptr data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * will be bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a\n * random port number.\n *\n * The handle takes the ownership of the data and it is in charge of delete\n * them.\n *\n * A send event will be emitted when the data have been sent.\n *\n * @param addr A valid instance of socket_address.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int send(socket_address addr, std::unique_ptr data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * will be bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a\n * random port number.\n *\n * The handle doesn't take the ownership of the data. 
Be sure that their\n * lifetime overcome the one of the request.\n *\n * A send event will be emitted when the data have been sent.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int send(const sockaddr &addr, char *data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * will be bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a\n * random port number.\n *\n * The handle doesn't take the ownership of the data. Be sure that their\n * lifetime overcome the one of the request.\n *\n * A send event will be emitted when the data have been sent.\n *\n * @param ip The address to which to send data.\n * @param port The port to which to send data.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int send(const std::string &ip, unsigned int port, char *data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * will be bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a\n * random port number.\n *\n * The handle doesn't take the ownership of the data. Be sure that their\n * lifetime overcome the one of the request.\n *\n * A send event will be emitted when the data have been sent.\n *\n * @param addr A valid instance of socket_address.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int send(socket_address addr, char *data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Same as `send()`, but it won\u2019t queue a send request if it can\u2019t be\n * completed immediately.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_send(const sockaddr &addr, std::unique_ptr data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Same as `send()`, but it won\u2019t queue a send request if it can\u2019t be\n * completed immediately.\n *\n * @param ip The address to which to send data.\n * @param port The port to which to send data.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_send(const std::string &ip, unsigned int port, std::unique_ptr data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Same as `send()`, but it won\u2019t queue a send request if it can\u2019t be\n * completed immediately.\n *\n * @param addr A valid instance of socket_address.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_send(socket_address addr, std::unique_ptr data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Same as `send()`, but it won\u2019t queue a send request if it can\u2019t be\n * completed immediately.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int 
try_send(const sockaddr &addr, char *data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Same as `send()`, but it won\u2019t queue a send request if it can\u2019t be\n * completed immediately.\n *\n * @param ip The address to which to send data.\n * @param port The port to which to send data.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_send(const std::string &ip, unsigned int port, char *data, unsigned int len);\n\n /**\n * @brief Sends data over the UDP socket.\n *\n * Same as `send()`, but it won\u2019t queue a send request if it can\u2019t be\n * completed immediately.\n *\n * @param addr A valid instance of socket_address.\n * @param data The data to be sent.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_send(socket_address addr, char *data, unsigned int len);\n\n /**\n * @brief Prepares for receiving data.\n *\n * Note that if the socket has not previously been bound with `bind()`, it\n * is bound to `0.0.0.0` (the _all interfaces_ IPv4 address) and a random\n * port number.\n *\n * An UDP data event will be emitted when the handle receives data.\n *\n * @return Underlying return value.\n */\n int recv();\n\n /**\n * @brief Prepares for receiving data.\n * @sa recv\n * @tparam Alloc Custom allocation function.\n * @return Underlying return value.\n */\n template\n int recv() {\n retur\nn uv_udp_recv_start(raw(), &details::common_alloc_callback, &recv_callback);\n }\n\n /**\n * @brief Stops listening for incoming datagrams.\n * @return Underlying return value.\n */\n int stop();\n\n /**\n * @brief Gets the number of bytes queued for sending.\n *\n * It strictly shows how much information is currently queued.\n *\n * @return Number of bytes queued for sending.\n */\n size_t send_queue_size() const noexcept;\n\n /**\n * @brief Number of send requests currently in the queue awaiting to be\n * processed.\n * @return Number of send requests currently in the queue.\n */\n size_t send_queue_count() const noexcept;\n\nprivate:\n enum {\n DEFAULT,\n FLAGS\n } tag{DEFAULT};\n\n unsigned int flags{};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"udp.cpp\"\n#endif\n\n#endif // UVW_UDP_INCLUDE_H\n\n// Path: src/uvw/udp.cpp\n#ifdef UVW_AS_LIB\n# include \"udp.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE udp_data_event::udp_data_event(socket_address sndr, std::unique_ptr buf, std::size_t len, bool part) noexcept\n : data{std::move(buf)},\n length{len},\n sender{std::move(sndr)},\n partial{part} {}\n\nUVW_INLINE void details::send_req::udp_send_callback(uv_udp_send_t *req, int status) {\n if(auto ptr = reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(send_event{});\n }\n}\n\nUVW_INLINE details::send_req::send_req(loop::token token, std::shared_ptr parent, std::unique_ptr dt, unsigned int len)\n : request{token, std::move(parent)},\n data{std::move(dt)},\n buf{uv_buf_init(data.get(), len)} {}\n\nUVW_INLINE int details::send_req::send(uv_udp_t *hndl, const struct sockaddr *addr) {\n return this->leak_if(uv_udp_send(raw(), hndl, &buf, 1, addr, &udp_send_callback));\n}\n\nUVW_INLINE void udp_handle::recv_callback(uv_udp_t *hndl, ssize_t nread, const uv_buf_t *buf, const sockaddr *addr, unsigned flags) {\n udp_handle &udp = *(static_cast(hndl->data));\n // data will be destroyed no matter of what the value of nread is\n std::unique_ptr 
data{buf->base};\n\n if(nread > 0) {\n // data available (can be truncated)\n udp.publish(udp_data_event{details::sock_addr(*addr), std::move(data), static_cast(nread), !(0 == (flags & UV_UDP_PARTIAL))});\n } else if(nread == 0 && addr == nullptr) {\n // no more data to be read, doing nothing is fine\n } else if(nread == 0 && addr != nullptr) {\n // empty udp packet\n udp.publish(udp_data_event{details::sock_addr(*addr), std::move(data), static_cast(nread), false});\n } else {\n // transmission error\n udp.publish(error_event(nread));\n }\n}\n\nUVW_INLINE udp_handle::udp_handle(loop::token token, std::shared_ptr ref, unsigned int f)\n : handle{token, std::move(ref)}, tag{FLAGS}, flags{f} {}\n\nUVW_INLINE int udp_handle::init() {\n if(tag == FLAGS) {\n return leak_if(uv_udp_init_ex(parent().raw(), raw(), flags));\n } else {\n return leak_if(uv_udp_init(parent().raw(), raw()));\n }\n}\n\nUVW_INLINE int udp_handle::open(os_socket_handle socket) {\n return uv_udp_open(raw(), socket);\n}\n\nUVW_INLINE int udp_handle::connect(const sockaddr &addr) {\n return uv_udp_connect(raw(), &addr);\n}\n\nUVW_INLINE int udp_handle::connect(const std::string &ip, unsigned int port) {\n return connect(details::ip_addr(ip.data(), port));\n}\n\nUVW_INLINE int udp_handle::connect(socket_address addr) {\n return connect(addr.ip, addr.port);\n}\n\nUVW_INLINE int udp_handle::disconnect() {\n return uv_udp_connect(raw(), nullptr);\n}\n\nUVW_INLINE socket_address udp_handle::peer() const noexcept {\n sockaddr_storage storage;\n int len = sizeof(sockaddr_storage);\n uv_udp_getpeername(raw(), reinterpret_cast(&storage), &len);\n return details::sock_addr(storage);\n}\n\nUVW_INLINE int udp_handle::bind(const sockaddr &addr, udp_handle::udp_flags opts) {\n return uv_udp_bind(raw(), &addr, static_cast(opts));\n}\n\nUVW_INLINE int udp_handle::bind(const std::string &ip, unsigned int port, udp_flags opts) {\n return bind(details::ip_addr(ip.data(), port), opts);\n}\n\nUVW_INLINE int udp_handle::bind(socket_address addr, udp_flags opts) {\n return bind(addr.ip, addr.port, opts);\n}\n\nUVW_INLINE socket_address udp_handle::sock() const noexcept {\n sockaddr_storage storage;\n int len = sizeof(sockaddr_storage);\n uv_udp_getsockname(raw(), reinterpret_cast(&storage), &len);\n return details::sock_addr(storage);\n}\n\nUVW_INLINE bool udp_handle::multicast_membership(const std::string &multicast, const std::string &iface, membership ms) {\n return (0 == uv_udp_set_membership(raw(), multicast.data(), iface.data(), static_cast(ms)));\n}\n\nUVW_INLINE bool udp_handle::multicast_loop(bool enable) {\n return (0 == uv_udp_set_multicast_loop(raw(), enable));\n}\n\nUVW_INLINE bool udp_handle::multicast_ttl(int val) {\n return (0 == uv_udp_set_multicast_ttl(raw(), val > 255 ? 255 : val));\n}\n\nUVW_INLINE bool udp_handle::multicast_interface(const std::string &iface) {\n return (0 == uv_udp_set_multicast_interface(raw(), iface.data()));\n}\n\nUVW_INLINE bool udp_handle::broadcast(bool enable) {\n return (0 == uv_udp_set_broadcast(raw(), enable));\n}\n\nUVW_INLINE bool udp_handle::ttl(int val) {\n return (0 == uv_udp_set_ttl(raw(), val > 255 ? 
255 : val));\n}\n\nUVW_INLINE int udp_handle::send(const sockaddr &addr, std::unique_ptr data, unsigned int len) {\n auto req = parent().resource(std::unique_ptr{data.release(), [](char *ptr) { delete[] ptr; }}, len);\n\n auto listener = [ptr = shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->on(listener);\n req->on(listener);\n\n return req->send(raw(), &addr);\n}\n\nUVW_INLINE int udp_handle::send(const std::string &ip, unsigned int port, std::unique_ptr data, unsigned int len) {\n return send(details::ip_addr(ip.data(), port), std::move(data), len);\n}\n\nUVW_INLINE int udp_handle::send(socket_address addr, std::unique_ptr data, unsigned int len) {\n return send(addr.ip, addr.port, std::move(data), len);\n}\n\nUVW_INLINE int udp_handle::send(const sockaddr &addr, char *data, unsigned int len) {\n auto req = parent().resource(std::unique_ptr{data, [](char *) {}}, len);\n\n auto listener = [ptr = shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->on(listener);\n req->on(listener);\n\n return req->send(raw(), &addr);\n}\n\nUVW_INLINE int udp_handle::send(const std::string &ip, unsigned int port, char *data, unsigned int len) {\n return send(details::ip_addr(ip.data(), port), data, len);\n}\n\nUVW_INLINE int udp_handle::send(socket_address addr, char *data, unsigned int len) {\n return send(addr.ip, addr.port, data, len);\n}\n\nUVW_INLINE int udp_handle::try_send(const sockaddr &addr, std::unique_ptr data, unsigned int len) {\n uv_buf_t bufs[] = {uv_buf_init(data.get(), len)};\n return uv_udp_try_send(raw(), bufs, 1, &addr);\n}\n\nUVW_INLINE int udp_handle::try_send(const std::string &ip, unsigned int port, std::unique_ptr data, unsigned int len) {\n return try_send(details::ip_addr(ip.data(), port), std::move(data), len);\n}\n\nUVW_INLINE int udp_handle::try_send(socket_address addr, std::unique_ptr data, unsigned int len) {\n return try_send(addr.ip, addr.port, std::move(data), len);\n}\n\nUVW_INLINE int udp_handle::try_send(const sockaddr &addr, char *data, unsigned int len) {\n uv_buf_t bufs[] = {uv_buf_init(data, len)};\n return uv_udp_try_send(raw(), bufs, 1, &addr);\n}\n\nUVW_INLINE int udp_handle::try_send(const std::string &ip, unsigned int port, char *data, unsigned int len) {\n return try_send(details::ip_addr(ip.data(), port), data, len);\n}\n\nUVW_INLINE int udp_handle::try_send(socket_address addr, char *data, unsigned int len) {\n return try_send(addr.ip, addr.port, data, len);\n}\n\nUVW_INLINE int udp_handle::recv() {\n return uv_udp_recv_start(raw(), &details::common_alloc_callback, &recv_callback);\n}\n\nUVW_INLINE int udp_handle::stop() {\n return uv_udp_recv_stop(raw());\n}\n\nUVW_INLINE size_t udp_handle::send_queue_size() const noexcept {\n return uv_udp_get_send_queue_size(raw());\n}\n\nUVW_INLINE size_t udp_handle::send_queue_count() const noexcept {\n return uv_udp_get_send_queue_count(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/signal.h\n#ifndef UVW_SIGNAL_INCLUDE_H\n#define UVW_SIGNAL_INCLUDE_H\n\n#include \n#include \"config.h\"\n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Signal event. */\nstruct signal_event {\n explicit signal_event(int sig) noexcept;\n\n int signum; /*!< The signal being monitored by this handle. */\n};\n\n/**\n * @brief The signal handle.\n *\n * Signal handles implement Unix style signal handling on a per-event loop\n * bases.
\n * Reception of some signals is emulated on Windows.\n *\n * To create a `signal_handle` through a `loop`, no arguments are required.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/signal.html)\n * for further details.\n */\nclass signal_handle final: public handle {\n static void start_callback(uv_signal_t *hndl, int signum);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * The handle will start emitting signal events when needed.\n *\n * @param signum The signal to be monitored.\n * @return Underlying return value.\n */\n int start(int signum);\n\n /**\n * @brief Starts the handle.\n *\n * Same functionality as signal_handle::start but the signal handler is\n * reset the moment the signal is received.\n *\n * @param signum The signal to be monitored.\n * @return Underlying return value.\n */\n int one_shot(int signum);\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n\n /**\n * @brief Gets the signal being monitored.\n * @return The signal being monitored.\n */\n int signal() const noexcept;\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"signal.cpp\"\n#endif\n\n#endif // UVW_SIGNAL_INCLUDE_H\n\n// Path: src/uvw/signal.cpp\n#ifdef UVW_AS_LIB\n# include \"signal.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE signal_event::signal_event(int sig) noexcept\n : signum{sig} {}\n\nUVW_INLINE void signal_handle::start_callback(uv_signal_t *hndl, int signum) {\n signal_handle &signal = *(static_cast(hndl->data));\n signal.publish(signal_event{signum});\n}\n\nUVW_INLINE int signal_handle::init() {\n return leak_if(uv_signal_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int signal_handle::start(int signum) {\n return uv_signal_start(raw(), &start_callback, signum);\n}\n\nUVW_INLINE int signal_handle::one_shot(int signum) {\n return uv_signal_start_oneshot(raw(), &start_callback, signum);\n}\n\nUVW_INLINE int signal_handle::stop() {\n return uv_signal_stop(raw());\n}\n\nUVW_INLINE int signal_handle::signal() const noexcept {\n return raw()->signum;\n}\n\n} // namespace uvw\n\n// Path: src/uvw/tcp.h\n#ifndef UVW_TCP_INCLUDE_H\n#define UVW_TCP_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"request.hpp\"\n#include \"stream.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_tcp_flags : std::underlying_type_t {\n IPV6ONLY = UV_TCP_IPV6ONLY,\n _UVW_ENUM = 0\n};\n\n}\n\n/**\n * @brief The TCP handle.\n *\n * TCP handles are used to represent both TCP streams and servers.
\n * By default, _ipv4_ is used as a template parameter. The handle already\n * supports _IPv6_ out-of-the-box by using `uvw::ipv6`.\n *\n * To create a `tcp_handle` through a `loop`, arguments follow:\n *\n * * An optional integer value that indicates the flags used to initialize\n * the socket.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/tcp.html#c.uv_tcp_init_ex)\n * for further details.\n */\nclass tcp_handle final: public stream_handle {\npublic:\n using time = std::chrono::duration;\n using tcp_flags = details::uvw_tcp_flags;\n using ipv4 = uvw::ipv4;\n using ipv6 = uvw::ipv6;\n\n explicit tcp_handle(loop::token token, std::shared_ptr ref, unsigned int f = {});\n\n /**\n * @brief Initializes the handle. No socket is created as of yet.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Opens an existing file descriptor or SOCKET as a TCP handle.\n *\n * The passed file descriptor or SOCKET is not checked for its type, but\n * it\u2019s required that it represents a valid stream socket.\n *\n * @param socket A valid socket handle (either a file descriptor or a\n * SOCKET).\n *\n * @return Underlying return value.\n */\n int open(os_socket_handle socket);\n\n /**\n * @brief Enables/Disables Nagle\u2019s algorithm.\n * @param value True to enable it, false otherwise.\n * @return True in case of success, false otherwise.\n */\n bool no_delay(bool value = false);\n\n /**\n * @brief Enables/Disables TCP keep-alive.\n * @param enable True to enable it, false otherwise.\n * @param val Initial delay in seconds (use\n * `std::chrono::duration`).\n * @return True in case of success, false otherwise.\n */\n bool keep_alive(bool enable = false, time val = time{0});\n\n /**\n * @brief Enables/Disables simultaneous asynchronous accept requests.\n *\n * Enables/Disables simultaneous asynchronous accept requests that are\n * queued by the operating system when listening for new TCP\n * connections.
\n * This setting is used to tune a TCP server for the desired performance.\n * Having simultaneous accepts can significantly improve the rate of\n * accepting connections (which is why it is enabled by default) but may\n * lead to uneven load distribution in multi-process setups.\n *\n * @param enable True to enable it, false otherwise.\n * @return True in case of success, false otherwise.\n */\n bool simultaneous_accepts(bool enable = true);\n\n /**\n * @brief Binds the handle to an address and port.\n *\n * A successful call to this function does not guarantee that the call to\n * `listen()` or `connect()` will work properly.\n *\n * Available flags are:\n *\n * * `tcp_handle::tcp_flags::IPV6ONLY`: it disables dual-stack support and\n * only IPv6 is used.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @param opts Optional additional flags.\n * @return Underlying return value.\n */\n int bind(const sockaddr &addr, tcp_flags opts = tcp_flags::_UVW_ENUM);\n\n /**\n * @brief Binds the handle to an address and port.\n *\n * A successful call to this function does not guarantee that the call to\n * `listen()` or `connect()` will work properly.\n *\n * Available flags are:\n *\n * * `tcp_handle::tcp_flags::IPV6ONLY`: it disables dual-stack support and\n * only IPv6 is used.\n *\n * @param ip The address to which to bind.\n * @param port The port to which to bind.\n * @param opts Optional additional flags.\n * @return Underlying return value.\n */\n int bind(const std::string &ip, unsigned int port, tcp_flags opts = tcp_flags::_UVW_ENUM);\n\n /**\n * @brief Binds the handle to an address and port.\n *\n * A successful call to this function does not guarantee that the call to\n * `listen()` or `connect()` will work properly.\n *\n * Available flags are:\n *\n * * `tcp_handle::tcp_flags::IPV6ONLY`: it disables dual-stack support and\n * only IPv6 is used.\n *\n * @param addr A valid instance of socket_address.\n * @param opts Optional additional flags.\n * @return Underlying return value.\n */\n int bind(socket_address addr, tcp_flags opts = tcp_flags::_UVW_ENUM);\n\n /**\n * @brief Gets the current address to which the handle is bound.\n * @return A valid instance of socket_address, an empty one in case of\n * errors.\n */\n socket_address sock() const noexcept;\n\n /**\n * @brief Gets the address of the peer connected to the handle.\n * @return A valid instance of socket_address, an empty one in case of\n * errors.\n */\n socket_address peer() const noexcept;\n\n /**\n * @brief Establishes an IPv4 or IPv6 TCP connection.\n *\n * On Windows if the addr is initialized to point to an unspecified address\n * (`0.0.0.0` or `::`) it will be changed to point to localhost. 
This is\n * done to match the behavior of Linux systems.\n *\n * A connect event is emitted when the connection has been established.\n *\n * @param addr Initialized `sockaddr_in` or `sockaddr_in6` data structure.\n * @return Underlying return value.\n */\n int connect(const sockaddr &addr);\n\n /**\n * @brief Establishes an IPv4 or IPv6 TCP connection.\n *\n * A connect event is emitted when the connection has been established.\n *\n * @param ip The address to which to bind.\n * @param port The port to which to bind.\n * @return Underlying return value.\n */\n int connect(const std::string &ip, unsigned int port);\n\n /**\n * @brief Establishes an IPv4 or IPv6 TCP connection.\n *\n * A connect event is emitted when the connection has been established.\n *\n * @param addr A valid instance of socket_address.\n * @return Underlying return value.\n */\n int connect(socket_address addr);\n\n /**\n * @brief Resets a TCP connection by sending a RST packet.\n *\n * This is accomplished by setting the `SO_LINGER` socket option with a\n * linger interval of zero and then calling `close`.
\n * Due to some platform inconsistencies, mixing of `shutdown` and\n * `close_reset` calls is not allowed.\n *\n * A close event is emitted when the connection has been reset.\n *\n * @return Underlying return value.\n */\n int close_reset();\n\nprivate:\n enum {\n DEFAULT,\n FLAGS\n } tag;\n\n unsigned int flags;\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"tcp.cpp\"\n#endif\n\n#endif // UVW_TCP_INCLUDE_H\n\n// Path: src/uvw/tcp.cpp\n#ifdef UVW_AS_LIB\n# include \"tcp.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE tcp_handle::tcp_handle(loop::token token, std::shared_ptr ref, unsigned int f)\n : stream_handle{token, std::move(ref)}, tag{f ? FLAGS : DEFAULT}, flags{f} {}\n\nUVW_INLINE int tcp_handle::init() {\n if(tag == FLAGS) {\n return leak_if(uv_tcp_init_ex(parent().raw(), raw(), flags));\n } else {\n return leak_if(uv_tcp_init(parent().raw(), raw()));\n }\n}\n\nUVW_INLINE int tcp_handle::open(os_socket_handle socket) {\n return uv_tcp_open(raw(), socket);\n}\n\nUVW_INLINE bool tcp_handle::no_delay(bool value) {\n return (0 == uv_tcp_nodelay(raw(), value));\n}\n\nUVW_INLINE bool tcp_handle::keep_alive(bool enable, tcp_handle::time val) {\n return (0 == uv_tcp_keepalive(raw(), enable, val.count()));\n}\n\nUVW_INLINE bool tcp_handle::simultaneous_accepts(bool enable) {\n return (0 == uv_tcp_simultaneous_accepts(raw(), enable));\n}\n\nUVW_INLINE int tcp_handle::bind(const sockaddr &addr, tcp_flags opts) {\n return uv_tcp_bind(raw(), &addr, static_cast(opts));\n}\n\nUVW_INLINE int tcp_handle::bind(const std::string &ip, unsigned int port, tcp_flags opts) {\n return bind(details::ip_addr(ip.data(), port), opts);\n}\n\nUVW_INLINE int tcp_handle::bind(socket_address addr, tcp_flags opts) {\n return bind(addr.ip, addr.port, opts);\n}\n\nUVW_INLINE socket_address tcp_handle::sock() const noexcept {\n sockaddr_storage storage;\n int len = sizeof(sockaddr_storage);\n uv_tcp_getsockname(raw(), reinterpret_cast(&storage), &len);\n return details::sock_addr(storage);\n}\n\nUVW_INLINE socket_address tcp_handle::peer() const noexcept {\n sockaddr_storage storage;\n int len = sizeof(sockaddr_storage);\n uv_tcp_getpeername(raw(), reinterpret_cast(&storage), &len);\n return details::sock_addr(storage);\n}\n\nUVW_INLINE int tcp_handle::connect(const std::string &ip, unsigned int port) {\n return connect(details::ip_addr(ip.data(), port));\n}\n\nUVW_INLINE int tcp_handle::connect(socket_address addr) {\n return connect(addr.ip, addr.port);\n}\n\nUVW_INLINE int tcp_handle::connect(const sockaddr &addr) {\n auto listener = [ptr = shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n auto req = parent().resource();\n req->on(listener);\n req->on(listener);\n\n return req->connect(&uv_tcp_connect, raw(), &addr);\n}\n\nUVW_INLINE int tcp_handle::close_reset() {\n return uv_tcp_close_reset(raw(), &this->close_callback);\n}\n\n} // namespace uvw\n\n// Path: src/uvw/process.h\n#ifndef UVW_PROCESS_INCLUDE_H\n#define UVW_PROCESS_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"handle.hpp\"\n#include \"loop.h\"\n#include \"stream.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_process_flags : std::underlying_type_t {\n SETUID = UV_PROCESS_SETUID,\n SETGID = UV_PROCESS_SETGID,\n WINDOWS_VERBATIM_ARGUMENTS = UV_PROCESS_WINDOWS_VERBATIM_ARGUMENTS,\n DETACHED = UV_PROCESS_DETACHED,\n WINDOWS_HIDE = UV_PROCESS_WINDOWS_HIDE,\n 
WINDOWS_HIDE_CONSOLE = UV_PROCESS_WINDOWS_HIDE_CONSOLE,\n WINDOWS_HIDE_GUI = UV_PROCESS_WINDOWS_HIDE_GUI,\n _UVW_ENUM = 0\n};\n\nenum class uvw_stdio_flags : std::underlying_type_t {\n IGNORE_STREAM = UV_IGNORE,\n CREATE_PIPE = UV_CREATE_PIPE,\n INHERIT_FD = UV_INHERIT_FD,\n INHERIT_STREAM = UV_INHERIT_STREAM,\n READABLE_PIPE = UV_READABLE_PIPE,\n WRITABLE_PIPE = UV_WRITABLE_PIPE,\n OVERLAPPED_PIPE = UV_OVERLAPPED_PIPE,\n _UVW_ENUM = 0\n};\n\n} // namespace details\n\n/*! @brief Exit event. */\nstruct exit_event {\n explicit exit_event(int64_t code, int sig) noexcept;\n\n int64_t status; /*!< The exit status. */\n int signal; /*!< The signal that caused the process to terminate, if any. */\n};\n\n/**\n * @brief The process handle.\n *\n * Process handles will spawn a new process and allow the user to control it and\n * establish communication channels with it using streams.\n */\nclass process_handle final: public handle {\n static void exit_callback(uv_process_t *hndl, int64_t exit_status, int term_signal);\n\npublic:\n using process_flags = details::uvw_process_flags;\n using stdio_flags = details::uvw_stdio_flags;\n\n process_handle(loop::token token, std::shared_ptr ref);\n\n /**\n * @brief Disables inheritance for file descriptors/handles.\n *\n * Disables inheritance for file descriptors/handles that this process\n * inherited from its parent. The effect is that child processes spawned by\n * this process don\u2019t accidentally inherit these handles.
\n * It is recommended to call this function as early in your program as\n * possible, before the inherited file descriptors can be closed or\n * duplicated.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/process.html#c.uv_disable_stdio_inheritance)\n * for further details.\n */\n static void disable_stdio_inheritance() noexcept;\n\n /**\n * @brief kill Sends the specified signal to the given PID.\n * @param pid A valid process id.\n * @param signum A valid signal identifier.\n * @return True in case of success, false otherwise.\n */\n static bool kill(int pid, int signum) noexcept;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief spawn Starts the process.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/process.html)\n * for further details.\n *\n * @param file Path pointing to the program to be executed.\n * @param args Command line arguments.\n * @param env Optional environment for the new process.\n * @return Underlying return value.\n */\n int spawn(const char *file, char **args, char **env = nullptr);\n\n /**\n * @brief Sends the specified signal to the internal process handle.\n * @param signum A valid signal identifier.\n * @return Underlying return value.\n */\n int kill(int signum);\n\n /**\n * @brief Gets the PID of the spawned process.\n *\n * It\u2019s set after calling `spawn()`.\n *\n * @return The PID of the spawned process.\n */\n int pid() noexcept;\n\n /**\n * @brief Sets the current working directory for the subprocess.\n * @param path The working directory to be used when `spawn()` is invoked.\n * @return A reference to this process handle.\n */\n process_handle &cwd(const std::string &path) noexcept;\n\n /**\n * @brief Sets flags that control how `spawn()` behaves.\n *\n * Available flags are:\n *\n * * `process_handle::process_flags::SETUID`\n...\n", "needle_token_start": 7238, "needle_token_end": 7289, "code_context_ntokens": 16083, "output": ["```cpp\nUVW_INLINE int udp_handle::recv() {\n return uv_udp_recv_start(raw(), &details::common_alloc_callback, &recv_callback);\n}\n```"]} +{"repo": "skypjack/uvw", "name": "configure", "language": "cpp", "path": "src/uvw/loop.h", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To set additional options for an event loop, typically before it starts running, to modify its behavior based on specific needs such as blocking signals or tracking idle time.\n2. **Input**: A flag representing the option to be set and any necessary arguments that the specific option might require.\n3. **Output**: An integer indicating the success or failure of setting the option.\n4. **Procedure**: The method takes a flag and arguments, then applies these settings to the event loop using an underlying library call. 
The settings must be applied before the event loop starts processing events unless otherwise specified.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/util.h\n#ifndef UVW_UTIL_INCLUDE_H\n#define UVW_UTIL_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_handle_type : std::underlying_type_t {\n UNKNOWN = UV_UNKNOWN_HANDLE,\n ASYNC = UV_ASYNC,\n CHECK = UV_CHECK,\n FS_EVENT = UV_FS_EVENT,\n FS_POLL = UV_FS_POLL,\n HANDLE = UV_HANDLE,\n IDLE = UV_IDLE,\n PIPE = UV_NAMED_PIPE,\n POLL = UV_POLL,\n PREPARE = UV_PREPARE,\n PROCESS = UV_PROCESS,\n STREAM = UV_STREAM,\n TCP = UV_TCP,\n TIMER = UV_TIMER,\n TTY = UV_TTY,\n UDP = UV_UDP,\n SIGNAL = UV_SIGNAL,\n FILE = UV_FILE\n};\n\nenum class uvw_clock_id : std::underlying_type_t {\n MONOTONIC = UV_CLOCK_MONOTONIC,\n REALTIME = UV_CLOCK_REALTIME\n};\n\ntemplate\nstruct uv_type_wrapper {\n using Type = T;\n\n constexpr uv_type_wrapper()\n : value{} {}\n\n constexpr uv_type_wrapper(Type val)\n : value{val} {}\n\n constexpr operator Type() const noexcept {\n return value;\n }\n\n bool operator==(uv_type_wrapper other) const noexcept {\n return value == other.value;\n }\n\nprivate:\n const Type value;\n};\n\ntemplate\nbool operator==(uv_type_wrapper lhs, uv_type_wrapper rhs) {\n return !(lhs == rhs);\n}\n\n} // namespace details\n\n/**\n * @brief Windows size representation.\n */\nstruct win_size {\n int width; /*!< The _width_ of the given window. */\n int height; /*!< The _height_ of the given window. */\n};\n\nusing handle_type = details::uvw_handle_type; /*!< The type of a handle. */\nusing handle_category = details::uv_type_wrapper; /*!< Utility class that wraps an internal handle type. */\nusing file_handle = details::uv_type_wrapper; /*!< Utility class that wraps an internal file handle. */\nusing os_socket_handle = details::uv_type_wrapper; /*!< Utility class that wraps an os socket handle. */\nusing os_file_descriptor = details::uv_type_wrapper; /*!< Utility class that wraps an os file descriptor. */\nusing pid_type = details::uv_type_wrapper; /*!< Utility class that wraps a cross platform representation of a pid. */\nusing clock_id = details::uvw_clock_id; /*!< Utility class that wraps a clock source. */\n\nconstexpr file_handle std_in{0}; /*!< Placeholder for stdin descriptor. */\nconstexpr file_handle std_out{1}; /*!< Placeholder for stdout descriptor. */\nconstexpr file_handle std_err{2}; /*!< Placeholder for stderr descriptor. */\n\nusing time_spec = uv_timespec_t; /*!< Library equivalent for uv_timespec_t. */\nusing file_info = uv_stat_t; /*!< Library equivalent for uv_stat_t. */\nusing fs_info = uv_statfs_t; /*!< Library equivalent for uv_statfs_t. */\nusing uid_type = uv_uid_t; /*!< Library equivalent for uv_uid_t. */\nusing gid_type = uv_gid_t; /*!< Library equivalent for uv_gid_t. */\n\nusing timeval = uv_timeval_t; /*!< Library equivalent for uv_timeval_t. */\nusing timeval64 = uv_timeval64_t; /*!< Library equivalent for uv_timeval64_t. */\nusing timespec64 = uv_timespec64_t; /*!< Library equivalent for uv_timespec64_t. */\nusing resource_usage = uv_rusage_t; /*!< Library equivalent for uv_rusage_t. 
*/\n\n/**\n * @brief Utility class.\n *\n * This class can be used to query the subset of the password file entry for the\n * current effective uid (not the real uid).\n *\n * \\sa utilities::passwd\n */\nstruct passwd_info {\n passwd_info(std::shared_ptr pwd);\n\n /**\n * @brief Gets the username.\n * @return The username of the current effective uid (not the real uid).\n */\n std::string username() const noexcept;\n\n /**\n * @brief Gets the uid.\n * @return The current effective uid (not the real uid).\n */\n decltype(uv_passwd_t::uid) uid() const noexcept;\n\n /**\n * @brief Gets the gid.\n * @return The gid of the current effective uid (not the real uid).\n */\n decltype(uv_passwd_t::gid) gid() const noexcept;\n\n /**\n * @brief Gets the shell.\n * @return The shell of the current effective uid (not the real uid).\n */\n std::string shell() const noexcept;\n\n /**\n * @brief Gets the homedir.\n * @return The homedir of the current effective uid (not the real uid).\n */\n std::string homedir() const noexcept;\n\n /**\n * @brief Checks if the instance contains valid data.\n * @return True if data are all valid, false otherwise.\n */\n operator bool() const noexcept;\n\nprivate:\n std::shared_ptr value;\n};\n\n/**\n * @brief Utility class.\n *\n * This class can be used to get name and information about the current kernel.\n * The populated data includes the operating system name, release, version, and\n * machine.\n *\n * \\sa utilities::uname\n */\nstruct uts_name {\n uts_name(std::shared_ptr init);\n\n /**\n * @brief Gets the operating system name (like \"Linux\").\n * @return The operating system name.\n */\n std::string sysname() const noexcept;\n\n /**\n * @brief Gets the operating system release (like \"2.6.28\").\n * @return The operating system release.\n */\n std::string release() const noexcept;\n\n /**\n * @brief Gets the operating system version.\n * @return The operating system version\n */\n std::string version() const noexcept;\n\n /**\n * @brief Gets the hardware identifier.\n * @return The hardware identifier.\n */\n std::string machine() const noexcept;\n\nprivate:\n std::shared_ptr uname;\n};\n\n/**\n * @brief The IPv4 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv4 {};\n\n/**\n * @brief The IPv6 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv6 {};\n\n/**\n * @brief Address representation.\n */\nstruct socket_address {\n std::string ip; /*!< Either an IPv4 or an IPv6. */\n unsigned int port; /*!< A valid service identifier. */\n};\n\n/**\n * \\brief CPU information.\n */\nstruct cpu_info {\n using cpu_time = decltype(uv_cpu_info_t::cpu_times);\n\n std::string model; /*!< The model of the CPU. */\n int speed; /*!< The frequency of the CPU. */\n\n /**\n * @brief CPU times.\n *\n * It is built up of the following data members: `user`, `nice`, `sys`,\n * `idle`, `irq`, all of them having type `uint64_t`.\n */\n cpu_time times;\n};\n\n/**\n * \\brief Interface address.\n */\nstruct interface_address {\n std::string name; /*!< The name of the interface (as an example _eth0_). */\n char physical[6]; /*!< The physical address. */\n bool internal; /*!< True if it is an internal interface (as an example _loopback_), false otherwise. */\n socket_address address; /*!< The address of the given interface. */\n socket_address netmask; /*!< The netmask of the given interface. 
*/\n};\n\nnamespace details {\n\nstatic constexpr std::size_t DEFAULT_SIZE = 128;\n\ntemplate\nstd::string try_read(F &&f, Args &&...args) noexcept {\n std::size_t size = DEFAULT_SIZE;\n char buf[DEFAULT_SIZE];\n std::string str{};\n auto err = std::forward(f)(args..., buf, &size);\n\n if(UV_ENOBUFS == err) {\n std::unique_ptr data{new char[size]};\n err = std::forward(f)(args..., data.get(), &size);\n\n if(0 == err) {\n str = data.get();\n }\n } else if(0 == err) {\n str.assign(buf, size);\n }\n\n return str;\n}\n\nvoid common_alloc_callback(uv_handle_t *, std::size_t suggested, uv_buf_t *buf);\n\ntemplate\nvoid common_alloc_callback(uv_handle_t *handle, std::size_t suggested, uv_buf_t *buf) {\n auto [alloc, size] = Alloc(*static_cast(handle->data), suggested);\n *buf = uv_buf_init(alloc, static_cast(size));\n}\n\nsockaddr ip_addr(const char *addr, unsigned int port);\nsocket_address sock_addr(const sockaddr_in &addr);\nsocket_address sock_addr(const sockaddr_in6 &addr);\nsocket_address sock_addr(const sockaddr &addr);\nsocket_address sock_addr(const sockaddr_storage &storage);\n\n} // namespace details\n\n/**\n * @brief Miscellaneous utilities.\n *\n * Miscellaneous functions that don\u2019t really belong to any other class.\n */\nstruct utilities {\n using malloc_func_type = void *(*)(size_t);\n using realloc_func_type = void *(*)(void *, size_t);\n using calloc_func_type = void *(*)(size_t, size_t);\n using free_func_type = void (*)(void *);\n\n /**\n * @brief OS dedicated utilities.\n */\n struct os {\n /**\n * @brief Returns the current process id.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_getpid)\n * for further details.\n *\n * @return The current process id.\n */\n static pid_type pid() noexcept;\n\n /**\n * @brief Returns the parent process id.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_getppid)\n * for further details.\n *\n * @return The parent process id.\n */\n static pid_type ppid() noexcept;\n\n /**\n * @brief Gets the current user's home directory.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_homedir)\n * for further details.\n *\n * @return The current user's home directory, an empty string in case of\n * errors.\n */\n static std::string homedir() noexcept;\n\n /**\n * @brief Gets the temp directory.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_tmpdir)\n * for further details.\n *\n * @return The temp directory, an empty string in case of errors.\n */\n static std::string tmpdir() noexcept;\n\n /**\n * @brief Retrieves an environment variable.\n * @param name The name of the variable to be retrieved.\n * @return The value of the environment variable, an empty string in\n * case of errors.\n */\n static std::string env(const std::string &name) noexcept;\n\n /**\n * @brief Creates, updates or deletes an environment variable.\n * @param name The name of the variable to be updated.\n * @param value The value to be used for the variable (an empty string\n * to unset it).\n * @return True in case of success, false otherwise.\n */\n static bool env(const std::string &name, const std::string &value) noexcept;\n\n /**\n * @brief Retrieves all environment variables and iterates them.\n *\n * Environment variables are passed one at a time to the callback in the\n * form of `std::string_view`s.
\n * The signature of the function call operator must be such that it\n * accepts two parameters, the name and the value of the i-th variable.\n *\n * @tparam Func Type of a function object to which to pass environment\n * variables.\n * @param func A function object to which to pass environment variables.\n * @return True in case of success, false otherwise.\n */\n template\n static std::enable_if_t, bool>\n env(Func func) noexcept {\n uv_env_item_t *items = nullptr;\n int count{};\n\n const bool ret = (uv_os_environ(&items, &count) == 0);\n\n if(ret) {\n for(int pos = 0; pos < count; ++pos) {\n func(std::string_view{items[pos].name}, std::string_view{items[pos].value});\n }\n\n uv_os_free_environ(items, count);\n }\n\n return ret;\n }\n\n /**\n * @brief Returns the hostname.\n * @return The hostname, an empty string in case of errors.\n */\n static std::string hostname() noexcept;\n\n /**\n * @brief Gets name and information about the current kernel.\n *\n * This function can be used to get name and information about the\n * current kernel. The populated data includes the operating system\n * name, release, version, and machine.\n *\n * @return Name and information about the current kernel.\n */\n static uts_name uname() noexcept;\n\n /**\n * @brief Gets a subset of the password file entry.\n *\n * This function can be used to get the subset of the password file\n * entry for the current effective uid (not the real uid).\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_get_passwd)\n * for further details.\n *\n * @return The accessible subset of the password file entry.\n */\n static passwd_info passwd() noexcept;\n\n /**\n * @brief Retrieves the scheduling priority of a process.\n *\n * The returned value is between -20 (high priority) and 19 (low priority).\n * A value that is out of range is returned in case of errors.\n *\n * @note\n * On Windows, the result won't equal necessarily the exact value of the\n * priority because of a mapping to a Windows priority class.\n *\n * @param pid A valid process id.\n * @return The scheduling priority of the process.\n */\n static int priority(pid_type pid);\n\n /**\n * @brief Sets the scheduling priority of a process.\n *\n * The returned value range is between -20 (high priority) and 19 (low\n * priority).\n *\n * @note\n * On Windows, the priority is mapped to a Windows priority class. When\n * retrieving the process priority, the result won't equal necessarily the\n * exact value of the priority.\n *\n * @param pid A valid process id.\n * @param prio The scheduling priority to set to the process.\n * @return True in case of success, false otherwise.\n */\n static bool priority(pid_type pid, int prio);\n };\n\n /**\n * @brief Gets the type of the handle given a category.\n * @param category A properly initialized handle category.\n * @return The actual type of the handle as defined by handle_type\n */\n static handle_type guess_handle(handle_category category) noexcept;\n\n /**\n * @brief Gets the type of the stream to be used with the given descriptor.\n *\n * Returns the type of stream that should be used with a given file\n * descriptor.
\n * Usually this will be used during initialization to guess the type of the\n * stdio streams.\n *\n * @param file A valid descriptor.\n * @return One of the following types:\n *\n * * `handle_type::UNKNOWN`\n * * `handle_type::PIPE`\n * * `handle_type::TCP`\n * * `handle_type::TTY`\n * * `handle_type::UDP`\n * * `handle_type::FILE`\n */\n static handle_type guess_handle(file_handle file) noexcept;\n\n /** @brief Gets information about the CPUs on the system.\n *\n * This function can be used to query the underlying system and get a set of\n * descriptors of all the available CPUs.\n *\n * @return A set of descriptors of all the available CPUs.\n */\n static std::vector cpu() noexcept;\n\n /**\n * @brief Gets a set of descriptors of all the available interfaces.\n *\n * This function can be used to query the underlying system and get a set of\n * descriptors of all the available interfaces, either internal or not.\n *\n * @return A set of descriptors of all the available interfaces.\n */\n static std::vector interface_addresses() noexcept;\n\n /**\n * @brief IPv6-capable implementation of\n * [if_indextoname](https://linux.die.net/man/3/if_indextoname).\n *\n * Mapping between network interface names and indexes.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_if_indextoname)\n * for further details.\n *\n * @param index Network interface index.\n * @return Network interface name.\n */\n static std::string index_to_name(unsigned int index) noexcept;\n\n /**\n * @brief Retrieves a network interface identifier.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_if_indextoiid)\n * for further details.\n *\n * @param index Network interface index.\n * @return Network interface identifier.\n */\n static std::string index_to_iid(unsigned int index) noexcept;\n\n /**\n * @brief Override the use of some standard library\u2019s functions.\n *\n * Override the use of the standard library\u2019s memory allocation\n * functions.
\n * This method must be invoked before any other `uvw` function is called or\n * after all resources have been freed and thus the underlying library\n * doesn\u2019t reference any allocated memory chunk.\n *\n * If any of the function pointers is _null_, the invokation will fail.\n *\n * @note\n * There is no protection against changing the allocator multiple times. If\n * the user changes it they are responsible for making sure the allocator is\n * changed while no memory was allocated with the previous allocator, or\n * that they are compatible.\n *\n * @param malloc_func Replacement function for _malloc_.\n * @param realloc_func Replacement function for _realloc_.\n * @param calloc_func Replacement function for _calloc_.\n * @param free_func Replacement function for _free_.\n * @return True in case of success, false otherwise.\n */\n static bool replace_allocator(malloc_func_type malloc_func, realloc_func_type realloc_func, calloc_func_type calloc_func, free_func_type free_func) noexcept;\n\n /**\n * @brief Gets the load average.\n * @return `[0,0,0]` on Windows (not available), the load average otherwise.\n */\n static std::array load_average() noexcept;\n\n /**\n * @brief Store the program arguments.\n *\n * Required for getting / setting the process title.\n *\n * @return Arguments that haven't been consumed internally.\n */\n static char **setup_args(int argc, char **argv);\n\n /**\n * @brief Gets the title of the current process.\n * @return The process title.\n */\n static std::string process_title();\n\n /**\n * @brief Sets the current process title.\n * @param title The process title to be set.\n * @return True in case of success, false otherwise.\n */\n static bool process_title(const std::string &title);\n\n /**\n * @brief Gets memory information (in bytes).\n * @return Memory information.\n */\n static uint64_t total_memory() noexcept;\n\n /**\n * @brief Gets the amount of memory available to the process (in bytes).\n *\n * Gets the amount of memory available to the process based on limits\n * imposed by the OS. If there is no such constraint, or the constraint is\n * unknown, `0` is returned.
\n * Note that it is not unusual for this value to be less than or greater\n * than `totalMemory`.\n *\n * @return Amount of memory available to the process.\n */\n static uint64_t constrained_memory() noexcept;\n\n /**\n * @brief Gets the amount of free memory still available to the process.\n * @return Amount of free memory still available to the process (in bytes).\n */\n static uint64_t available_memory() noexcept;\n\n /**\n * @brief Gets the current system uptime.\n * @return The current system uptime or 0 in case of errors.\n */\n static double uptime() noexcept;\n\n /**\n * @brief Gets the resource usage measures for the current process.\n * @return Resource usage measures, zeroes-filled object in case of errors.\n */\n static resource_usage rusage() noexcept;\n\n /**\n * @brief Gets the current system time from a high-resolution clock source.\n * @param source Clock source, either real-time or monotonic.\n * @return Current system time from the given high-resolution clock source.\n */\n static timespec64 gettime(clock_id source) noexcept;\n\n /**\n * @brief Gets the current high-resolution real time.\n *\n * The time is expressed in nanoseconds. It is relative to an arbitrary time\n...\n// Path: src/uvw/loop.cpp\n#ifdef UVW_AS_LIB\n# include \"loop.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE loop::loop(std::unique_ptr ptr) noexcept\n : uv_loop{std::move(ptr)} {}\n\nUVW_INLINE std::shared_ptr loop::create() {\n auto ptr = std::unique_ptr{new uv_loop_t, [](uv_loop_t *l) { delete l; }};\n auto curr = std::shared_ptr{new loop{std::move(ptr)}};\n\n if(uv_loop_init(curr->uv_loop.get())) {\n curr = nullptr;\n }\n\n return curr;\n}\n\nUVW_INLINE std::shared_ptr loop::create(uv_loop_t *res) {\n auto ptr = std::unique_ptr{res, [](uv_loop_t *) {}};\n return std::shared_ptr{new loop{std::move(ptr)}};\n}\n\nUVW_INLINE std::shared_ptr loop::get_default() {\n static std::weak_ptr ref;\n std::shared_ptr curr;\n\n if(ref.expired()) {\n auto def = uv_default_loop();\n\n if(def) {\n auto ptr = std::unique_ptr(def, [](uv_loop_t *) {});\n curr = std::shared_ptr{new loop{std::move(ptr)}};\n }\n\n ref = curr;\n } else {\n curr = ref.lock();\n }\n\n return curr;\n}\n\nUVW_INLINE loop::~loop() noexcept {\n if(uv_loop) {\n close();\n }\n}\n\nUVW_INLINE int loop::close() {\n int ret = 0;\n\n if(uv_loop) {\n ret = uv_loop_close(uv_loop.get());\n uv_loop.reset();\n }\n\n return ret;\n}\n\nUVW_INLINE int loop::run(run_mode mode) noexcept {\n return uv_run(uv_loop.get(), static_cast(mode));\n}\n\nUVW_INLINE bool loop::alive() const noexcept {\n return !!uv_loop_alive(uv_loop.get());\n}\n\nUVW_INLINE void loop::stop() noexcept {\n uv_stop(uv_loop.get());\n}\n\nUVW_INLINE int loop::descriptor() const noexcept {\n return uv_backend_fd(uv_loop.get());\n}\n\nUVW_INLINE std::pair loop::timeout() const noexcept {\n auto to = uv_backend_timeout(uv_loop.get());\n return std::make_pair(to == -1, time{to});\n}\n\nUVW_INLINE loop::time loop::idle_time() const noexcept {\n return time{uv_metrics_idle_time(uv_loop.get())};\n}\n\nUVW_INLINE metrics_type loop::metrics() const noexcept {\n metrics_type res{};\n uv_metrics_info(uv_loop.get(), &res);\n return res;\n}\n\nUVW_INLINE loop::time loop::now() const noexcept {\n return time{uv_now(uv_loop.get())};\n}\n\nUVW_INLINE void loop::update() const noexcept {\n return uv_update_time(uv_loop.get());\n}\n\nUVW_INLINE int loop::fork() noexcept {\n return uv_loop_fork(uv_loop.get());\n}\n\nUVW_INLINE void loop::data(std::shared_ptr ud) {\n user_data = 
std::move(ud);\n}\n\nUVW_INLINE const uv_loop_t *loop::raw() const noexcept {\n return uv_loop.get();\n}\n\nUVW_INLINE uv_loop_t *loop::raw() noexcept {\n return const_cast(const_cast(this)->raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/loop.h\n#ifndef UVW_LOOP_INCLUDE_H\n#define UVW_LOOP_INCLUDE_H\n\n#ifdef _WIN32\n# include \n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nclass async_handle;\nclass check_handle;\nclass fs_event_handle;\nclass fs_poll_handle;\nclass idle_handle;\nclass pipe_handle;\nclass poll_handle;\nclass prepare_handle;\nclass process_handle;\nclass signal_handle;\nclass tcp_handle;\nclass timer_handle;\nclass tty_handle;\nclass udp_handle;\n\nnamespace details {\n\nenum class uvw_loop_option : std::underlying_type_t {\n BLOCK_SIGNAL = UV_LOOP_BLOCK_SIGNAL,\n IDLE_TIME = UV_METRICS_IDLE_TIME\n};\n\nenum class uvw_run_mode : std::underlying_type_t {\n DEFAULT = UV_RUN_DEFAULT,\n ONCE = UV_RUN_ONCE,\n NOWAIT = UV_RUN_NOWAIT\n};\n\n} // namespace details\n\nusing metrics_type = uv_metrics_t; /*!< Library equivalent for uv_metrics_t. */\n\n/**\n * @brief The loop class.\n *\n * The event loop is the central part of `uvw`'s functionalities, as well as\n * `libuv`'s ones.
\n * It takes care of polling for I/O and scheduling callbacks to be run based on\n * different sources of events.\n */\nclass loop final: public emitter, public std::enable_shared_from_this {\n using deleter = void (*)(uv_loop_t *);\n\n template\n friend class resource;\n\n class uv_token {\n friend class loop;\n explicit uv_token(int) {}\n };\n\n template\n auto init(int, Type &value) -> decltype(value.init()) {\n return value.init();\n }\n\n template\n int init(char, Type &) {\n return 0;\n }\n\n loop(std::unique_ptr ptr) noexcept;\n\npublic:\n using token = uv_token;\n using time = std::chrono::duration;\n using option = details::uvw_loop_option;\n using run_mode = details::uvw_run_mode;\n\n /**\n * @brief Initializes a new loop instance.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create();\n\n /**\n * @brief Initializes a new loop instance from an existing resource.\n *\n * The lifetime of the resource must exceed that of the instance to which\n * it's associated. Management of the memory associated with the resource is\n * in charge of the user.\n *\n * @param res A valid pointer to a correctly initialized resource.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create(uv_loop_t *res);\n\n /**\n * @brief Gets the initialized default loop.\n *\n * It may return an empty pointer in case of failure.
\n * This function is just a convenient way for having a global loop\n * throughout an application, the default loop is in no way different than\n * the ones initialized with `create()`.
\n * As such, the default loop can be closed with `close()` so the resources\n * associated with it are freed (even if it is not strictly necessary).\n *\n * @return The initialized default loop.\n */\n static std::shared_ptr get_default();\n\n loop(const loop &) = delete;\n loop(loop &&other) = delete;\n\n loop &operator=(const loop &) = delete;\n loop &operator=(loop &&other) = delete;\n\n ~loop() noexcept;\n\n /**\n * @brief Sets additional loop options.\n *\n * You should normally call this before the first call to uv_run() unless\n * mentioned otherwise.
\n * Supported options:\n *\n * * `loop::option::BLOCK_SIGNAL`: Block a signal when polling for new\n * events. A second argument is required and it is the signal number.\n * * `loop::option::IDLE_TIME`: Accumulate the amount of idle time the event\n * loop spends in the event provider. This option is necessary to use\n * `idle_time()`.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_configure)\n * for further details.\n *\n * @return Underlying return value.\n */\n template\n \nint configure(option flag, Args &&...args) {\n return uv_loop_configure(uv_loop.get(), static_cast(flag), std::forward(args)...);\n }\n\n /**\n * @brief Creates resources of any type.\n *\n * This should be used as a default method to create resources.
\n * The arguments are the ones required for the specific resource.\n *\n * Use it as `loop->resource()`.\n *\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr resource(Args &&...args) {\n auto ptr = uninitialized_resource(std::forward(args)...);\n return (init(0, *ptr) == 0) ? ptr : nullptr;\n }\n\n /**\n * @brief Creates uninitialized resources of any type.\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr uninitialized_resource(Args &&...args) {\n return std::make_shared(token{0}, shared_from_this(), std::forward(args)...);\n }\n\n /**\n * @brief Releases all internal loop resources.\n *\n * Call this function only when the loop has finished executing and all open\n * handles and requests have been closed, or the loop will error.\n *\n * @return Underlying return value.\n */\n int close();\n\n /**\n * @brief Runs the event loop.\n *\n * Available modes are:\n *\n * * `loop::run_mode::DEFAULT`: Runs the event loop until there are no more\n * active and referenced handles or requests.\n * * `loop::run_mode::ONCE`: Poll for i/o once. Note that this function\n * blocks if there are no pending callbacks.\n * * `loop::run_mode::NOWAIT`: Poll for i/o once but don\u2019t block if there\n * are no pending callbacks.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_run)\n * for further details.\n *\n * @return Underlying return value.\n */\n int run(run_mode mode = run_mode::DEFAULT) noexcept;\n\n /**\n * @brief Checks if there are active resources.\n * @return True if there are active resources in the loop.\n */\n bool alive() const noexcept;\n\n /**\n * @brief Stops the event loop.\n *\n * It causes `run()` to end as soon as possible.
\n * This will happen not sooner than the next loop iteration.
\n * If this function was called before blocking for I/O, the loop won\u2019t block\n * for I/O on this iteration.\n */\n void stop() noexcept;\n\n /**\n * @brief Get backend file descriptor.\n *\n * Only kqueue, epoll and event ports are supported.
\n * This can be used in conjunction with `run(loop::run_mode::NOWAIT)` to\n * poll in one thread and run the event loop\u2019s callbacks in another.\n *\n * @return The backend file descriptor.\n */\n int descriptor() const noexcept;\n\n /**\n * @brief Gets the poll timeout.\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of valid timeout, false otherwise.\n * * Milliseconds (`std::chrono::duration`).\n */\n std::pair timeout() const noexcept;\n\n /**\n * @brief Returns the amount of time the event loop has been idle. The call\n * is thread safe.\n * @return The accumulated time spent idle.\n */\n time idle_time() const noexcept;\n\n /**\n * @brief Tracks various internal operations of the event loop.\n * @return Event loop metrics.\n */\n metrics_type metrics() const noexcept;\n\n /**\n * @brief Returns the current timestamp in milliseconds.\n *\n * The timestamp is cached at the start of the event loop tick.
\n * The timestamp increases monotonically from some arbitrary point in\n * time.
\n * Don\u2019t make assumptions about the starting point, you will only get\n * disappointed.\n *\n * @return The current timestamp in milliseconds (actual type is\n * `std::chrono::duration`).\n */\n time now() const noexcept;\n\n /**\n * @brief Updates the event loop\u2019s concept of _now_.\n *\n * The current time is cached at the start of the event loop tick in order\n * to reduce the number of time-related system calls.
\n * You won\u2019t normally need to call this function unless you have callbacks\n * that block the event loop for longer periods of time, where _longer_ is\n * somewhat subjective but probably on the order of a millisecond or more.\n */\n void update() const noexcept;\n\n /**\n * @brief Walks the list of handles.\n *\n * The callback is invoked once for each handle that is still active.\n *\n * @param callback A function to invoke once for each active handle.\n */\n template\n void walk(Func callback) {\n auto func = [](uv_handle_t *hndl, void *callback_func) {\n if(hndl->data) {\n auto &cb = *static_cast(callback_func);\n\n switch(utilities::guess_handle(handle_category{hndl->type})) {\n case handle_type::ASYNC:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::CHECK:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_EVENT:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::IDLE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PIPE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PREPARE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PROCESS:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::SIGNAL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TCP:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TIMER:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TTY:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::UDP:\n cb(*static_cast(hndl->data));\n break;\n default:\n // this handle isn't managed by uvw, let it be...\n break;\n }\n }\n };\n\n uv_walk(uv_loop.get(), func, &callback);\n }\n\n /**\n * @brief Reinitialize any kernel state necessary in the child process after\n * a fork(2) system call.\n *\n * Previously started watchers will continue to be started in the child\n * process.\n *\n * It is necessary to explicitly call this function on every event loop\n * created in the parent process that you plan to continue to use in the\n * child, including the default loop (even if you don\u2019t continue to use it\n * in the parent). This function must be called before calling any API\n * function using the loop in the child. Failure to do so will result in\n * undefined behaviour, possibly including duplicate events delivered to\n * both parent and child or aborting the child process.\n *\n * When possible, it is preferred to create a new loop in the child process\n * instead of reusing a loop created in the parent. New loops created in the\n * child process after the fork should not use this function.\n *\n * Note that this function is not implemented on Windows.
\n * Note also that this function is experimental in `libuv`. It may contain\n * bugs, and is subject to change or removal. API and ABI stability is not\n * guaranteed.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_fork)\n * for further details.\n *\n * @return Underlying return value.\n */\n int fork() noexcept;\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param ud User-defined arbitrary data.\n */\n void data(std::shared_ptr ud);\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const uv_loop_t *raw() const noexcept;\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n uv_loop_t *raw() noexcept;\n\nprivate:\n std::unique_ptr uv_loop;\n std::shared_ptr user_data{nullptr};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"loop.cpp\"\n#endif\n\n#endif // UVW_LOOP_INCLUDE_H\n\n// Path: src/uvw/uv_type.hpp\n#ifndef UVW_UV_TYPE_INCLUDE_H\n#define UVW_UV_TYPE_INCLUDE_H\n\n#include \n#include \n#include \n#include \"config.h\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/**\n * @brief Wrapper class for underlying types.\n *\n * It acts mainly as a wrapper around data structures of the underlying library.\n */\ntemplate\nstruct uv_type {\n explicit uv_type(loop::token, std::shared_ptr ref) noexcept\n : owner{std::move(ref)}, resource{} {}\n\n uv_type(const uv_type &) = delete;\n uv_type(uv_type &&) = delete;\n\n uv_type &operator=(const uv_type &) = delete;\n uv_type &operator=(uv_type &&) = delete;\n\n /**\n * @brief Gets the loop from which the resource was originated.\n * @return A reference to a loop instance.\n */\n loop &parent() const noexcept {\n return *owner;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const U *raw() const noexcept {\n return &resource;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n U *raw() noexcept {\n return &resource;\n }\n\nprotected:\n ~uv_type() = default;\n\nprivate:\n std::shared_ptr owner;\n U resource;\n};\n\n} // namespace uvw\n\n#endif // UVW_UV_TYPE_INCLUDE_H\n\n// Path: src/uvw/resource.hpp\n#ifndef UVW_RESOURCE_INCLUDE_H\n#define UVW_RESOURCE_INCLUDE_H\n\n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Common class for almost all the resources available in `uvw`.\n *\n * This is the base class for handles and requests.\n */\ntemplate\nclass resource: public uv_type, public emitter, public std::enable_shared_from_this {\nprotected:\n int leak_if(int err) noexcept {\n if(err == 0) {\n self_ptr = this->shared_from_this();\n }\n\n return err;\n }\n\n void self_reset() noexcept {\n self_ptr.reset();\n }\n\n bool has_self() const noexcept {\n return static_cast(self_ptr);\n }\n\npublic:\n explicit resource(loop::token token, std::shared_ptr ref)\n : uv_type{token, std::move(ref)} {\n this->raw()->data = this;\n }\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param udata User-defined arbitrary data.\n */\n void data(std::shared_ptr udata) {\n user_data = std::move(udata);\n }\n\nprivate:\n std::shared_ptr user_data{nullptr};\n std::shared_ptr self_ptr{nullptr};\n};\n\n} // namespace uvw\n\n#endif // UVW_RESOURCE_INCLUDE_H\n\n// Path: src/uvw/handle.hpp\n#ifndef UVW_HANDLE_INCLUDE_H\n#define UVW_HANDLE_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\n/*! @brief Close event. */\nstruct close_event {};\n\n/**\n * @brief Handle base class.\n *\n * Base type for all `uvw` handle types.\n */\ntemplate\nclass handle: public resource {\nprotected:\n static void close_callback(uv_handle_t *hndl) {\n handle &ref = *(static_cast(hndl->data));\n [[maybe_unused]] auto ptr = ref.shared_from_this();\n ref.self_reset();\n ref.publish(close_event{});\n }\n\n uv_handle_t *as_uv_handle() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_handle_t *as_uv_handle() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Gets the category of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. By means of this function, an opaque value that identifies the\n * category of the handle is made available to the users.\n *\n * @return The actual category of the handle.\n */\n handle_category category() const noexcept {\n return handle_category{as_uv_handle()->type};\n }\n\n /**\n * @brief Gets the type of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. 
By means of this function, the type of the underlying handle as\n * specified by handle_type is made available to the users.\n *\n * @return The actual type of the handle.\n */\n handle_type type() const noexcept {\n return utilities::guess_handle(category());\n }\n\n /**\n * @brief Checks if the handle is active.\n *\n * What _active_ means depends on the type of handle:\n *\n * * An async_handle handle is always active and cannot be deactivated,\n * except by closing it with uv_close().\n * * A pipe, tcp, udp, etc. handle - basically any handle that deals with\n * I/O - is active when it is doing something that involves I/O, like\n * reading, writing, connecting, accepting new connections, etc.\n * * A check, idle, timer, etc. handle is active when it has been started\n * with a call to `start()`.\n *\n * Rule of thumb: if a handle of type `foo_handle` has a `start()` member\n * method, then it\u2019s active from the moment that method is called. Likewise,\n * `stop()` deactivates the handle again.\n *\n * @return True if the handle is active, false otherwise.\n */\n bool active() const noexcept {\n return !!uv_is_active(as_uv_handle());\n }\n\n /**\n * @brief Checks if a handle is closing or closed.\n *\n * This function should only be used between the initialization of the\n * handle and the arrival of the close callback.\n *\n * @return True if the handle is closing or closed, false otherwise.\n */\n bool closing() const noexcept {\n return !!uv_is_closing(as_uv_handle());\n }\n\n /**\n * @brief Request handle to be closed.\n *\n * This **must** be called on each handle before memory is released.
\n * In-progress requests are cancelled and this can result in errors.\n *\n * The handle will emit a close event when finished.\n */\n void close() noexcept {\n if(!closing()) {\n uv_close(as_uv_handle(), &handle::close_callback);\n }\n }\n\n /**\n * @brief Reference the given handle.\n *\n * References are idempotent, that is, if a handle is already referenced\n * calling this function again will have no effect.\n */\n void reference() noexcept {\n uv_ref(as_uv_handle());\n }\n\n /**\n * @brief Unreference the given handle.\n *\n * References are idempotent, that is, if a handle is not referenced calling\n * this function again will have no effect.\n */\n void unreference() noexcept {\n uv_unref(as_uv_handle());\n }\n\n /**\n * @brief Checks if the given handle referenced.\n * @return True if the handle referenced, false otherwise.\n */\n bool referenced() const noexcept {\n return !!uv_has_ref(as_uv_handle());\n }\n\n /**\n * @brief Returns the size of the underlying handle type.\n * @return The size of the underlying handle type.\n */\n std::size_t size() const noexcept {\n return uv_handle_size(as_uv_handle()->type);\n }\n\n /**\n * @brief Gets the size of the send buffer used for the socket.\n *\n * Gets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipeand udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the send buffer, the underlying return value in case\n * of errors.\n */\n int send_buffer_size() {\n int value = 0;\n auto err = uv_send_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the send buffer used for the socket.\n *\n * Sets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int send_buffer_size(int value) {\n return uv_send_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the size of the receive buffer used for the socket.\n *\n * Gets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the receive buffer, the underlying return value in\n * case of errors.\n */\n int recv_buffer_size() {\n int value = 0;\n auto err = uv_recv_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the receive buffer used for the socket.\n *\n * Sets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int recv_buffer_size(int value) {\n return uv_recv_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the platform dependent file descriptor equivalent.\n *\n * Supported handles:\n *\n * * tcp_handle\n * * pipe_handle\n * * tty_handle\n * * udp_handle\n * * poll_handle\n *\n * If invoked on a different handle, one that doesn\u2019t have an attached file\n * descriptor yet or one which was closed, an invalid value is returned.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/handle.html#c.uv_fileno)\n * for further details.\n *\n * @return The file descriptor attached to the hande or a negative value in\n * case of errors.\n */\n os_file_descriptor fd() const {\n uv_os_fd_t fd;\n uv_fileno(as_uv_handle(), &fd);\n return fd;\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_HANDLE_INCLUDE_H\n\n// Path: src/uvw/prepare.h\n#ifndef UVW_PREPARE_INCLUDE_H\n#define UVW_PREPARE_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Prepare event. */\nstruct prepare_event {};\n\n/**\n * @brief The prepare handle.\n *\n * Prepare handles will emit a prepare event once per loop iteration, right\n * before polling for I/O.\n *\n * To create a `prepare_handle` through a `loop`, no arguments are required.\n */\nclass prepare_handle final: public handle {\n static void start_callback(uv_prepare_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A prepare event will be emitted once per loop iteration, right before\n * polling for I/O.\n *\n * The handle will start emitting prepare events when needed.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"prepare.cpp\"\n#endif\n\n#endif // UVW_PREPARE_INCLUDE_H\n\n// Path: src/uvw/prepare.cpp\n#ifdef UVW_AS_LIB\n# include \"prepare.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void prepare_handle::start_callback(uv_prepare_t *hndl) {\n prepare_handle &prepare = *(static_cast(hndl->data));\n prepare.publish(prepare_event{});\n}\n\nUVW_INLINE int prepare_handle::init() {\n return leak_if(uv_prepare_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int prepare_handle::start() {\n return uv_prepare_start(raw(), &start_callback);\n}\n\nUVW_INLINE int prepare_handle::stop() {\n return uv_prepare_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/request.hpp\n#ifndef UVW_REQUEST_INCLUDE_H\n...\n", "needle_token_start": 8790, "needle_token_end": 8844, "code_context_ntokens": 16077, "output": ["```cpp\ntemplate\nint configure(option flag, Args &&...args) {\n return uv_loop_configure(uv_loop.get(), static_cast(flag), std::forward(args)...);\n}\n```"]} +{"repo": "skypjack/uvw", "name": "stream_handle", "language": "cpp", "path": "src/uvw/stream.h", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: Provides an abstraction layer for duplex communication channels, facilitating data transmission and reception across different types of stream-based connections such as TCP, pipes, and TTY.\n2. **Input**: Receives data through callbacks that handle various stream events like connection establishment, data reception, and stream closure.\n3. 
**Output**: Emits events corresponding to the stream's state changes and data flow, such as connection success, incoming data, or errors.\n4. **Procedure**: Initializes with a loop reference and sets up callbacks to handle stream events. It processes incoming data, handles connection lifecycle events, and manages errors and stream termination through a series of predefined callbacks.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/loop.h\n#ifndef UVW_LOOP_INCLUDE_H\n#define UVW_LOOP_INCLUDE_H\n\n#ifdef _WIN32\n# include \n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"util.h\"\n\nnamespace uvw {\n\nclass async_handle;\nclass check_handle;\nclass fs_event_handle;\nclass fs_poll_handle;\nclass idle_handle;\nclass pipe_handle;\nclass poll_handle;\nclass prepare_handle;\nclass process_handle;\nclass signal_handle;\nclass tcp_handle;\nclass timer_handle;\nclass tty_handle;\nclass udp_handle;\n\nnamespace details {\n\nenum class uvw_loop_option : std::underlying_type_t {\n BLOCK_SIGNAL = UV_LOOP_BLOCK_SIGNAL,\n IDLE_TIME = UV_METRICS_IDLE_TIME\n};\n\nenum class uvw_run_mode : std::underlying_type_t {\n DEFAULT = UV_RUN_DEFAULT,\n ONCE = UV_RUN_ONCE,\n NOWAIT = UV_RUN_NOWAIT\n};\n\n} // namespace details\n\nusing metrics_type = uv_metrics_t; /*!< Library equivalent for uv_metrics_t. */\n\n/**\n * @brief The loop class.\n *\n * The event loop is the central part of `uvw`'s functionalities, as well as\n * `libuv`'s ones.
\n * It takes care of polling for I/O and scheduling callbacks to be run based on\n * different sources of events.\n */\nclass loop final: public emitter, public std::enable_shared_from_this {\n using deleter = void (*)(uv_loop_t *);\n\n template\n friend class resource;\n\n class uv_token {\n friend class loop;\n explicit uv_token(int) {}\n };\n\n template\n auto init(int, Type &value) -> decltype(value.init()) {\n return value.init();\n }\n\n template\n int init(char, Type &) {\n return 0;\n }\n\n loop(std::unique_ptr ptr) noexcept;\n\npublic:\n using token = uv_token;\n using time = std::chrono::duration;\n using option = details::uvw_loop_option;\n using run_mode = details::uvw_run_mode;\n\n /**\n * @brief Initializes a new loop instance.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create();\n\n /**\n * @brief Initializes a new loop instance from an existing resource.\n *\n * The lifetime of the resource must exceed that of the instance to which\n * it's associated. Management of the memory associated with the resource is\n * in charge of the user.\n *\n * @param res A valid pointer to a correctly initialized resource.\n * @return A pointer to the newly created loop.\n */\n static std::shared_ptr create(uv_loop_t *res);\n\n /**\n * @brief Gets the initialized default loop.\n *\n * It may return an empty pointer in case of failure.
\n * This function is just a convenient way for having a global loop\n * throughout an application, the default loop is in no way different than\n * the ones initialized with `create()`.
\n * As such, the default loop can be closed with `close()` so the resources\n * associated with it are freed (even if it is not strictly necessary).\n *\n * @return The initialized default loop.\n */\n static std::shared_ptr get_default();\n\n loop(const loop &) = delete;\n loop(loop &&other) = delete;\n\n loop &operator=(const loop &) = delete;\n loop &operator=(loop &&other) = delete;\n\n ~loop() noexcept;\n\n /**\n * @brief Sets additional loop options.\n *\n * You should normally call this before the first call to uv_run() unless\n * mentioned otherwise.
\n * Supported options:\n *\n * * `loop::option::BLOCK_SIGNAL`: Block a signal when polling for new\n * events. A second argument is required and it is the signal number.\n * * `loop::option::IDLE_TIME`: Accumulate the amount of idle time the event\n * loop spends in the event provider. This option is necessary to use\n * `idle_time()`.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_configure)\n * for further details.\n *\n * @return Underlying return value.\n */\n template\n int configure(option flag, Args &&...args) {\n return uv_loop_configure(uv_loop.get(), static_cast(flag), std::forward(args)...);\n }\n\n /**\n * @brief Creates resources of any type.\n *\n * This should be used as a default method to create resources.
\n * The arguments are the ones required for the specific resource.\n *\n * Use it as `loop->resource()`.\n *\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr resource(Args &&...args) {\n auto ptr = uninitialized_resource(std::forward(args)...);\n return (init(0, *ptr) == 0) ? ptr : nullptr;\n }\n\n /**\n * @brief Creates uninitialized resources of any type.\n * @return A pointer to the newly created resource.\n */\n template\n std::shared_ptr uninitialized_resource(Args &&...args) {\n return std::make_shared(token{0}, shared_from_this(), std::forward(args)...);\n }\n\n /**\n * @brief Releases all internal loop resources.\n *\n * Call this function only when the loop has finished executing and all open\n * handles and requests have been closed, or the loop will error.\n *\n * @return Underlying return value.\n */\n int close();\n\n /**\n * @brief Runs the event loop.\n *\n * Available modes are:\n *\n * * `loop::run_mode::DEFAULT`: Runs the event loop until there are no more\n * active and referenced handles or requests.\n * * `loop::run_mode::ONCE`: Poll for i/o once. Note that this function\n * blocks if there are no pending callbacks.\n * * `loop::run_mode::NOWAIT`: Poll for i/o once but don\u2019t block if there\n * are no pending callbacks.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_run)\n * for further details.\n *\n * @return Underlying return value.\n */\n int run(run_mode mode = run_mode::DEFAULT) noexcept;\n\n /**\n * @brief Checks if there are active resources.\n * @return True if there are active resources in the loop.\n */\n bool alive() const noexcept;\n\n /**\n * @brief Stops the event loop.\n *\n * It causes `run()` to end as soon as possible.
\n * This will happen no sooner than the next loop iteration.
\n * If this function was called before blocking for I/O, the loop won\u2019t block\n * for I/O on this iteration.\n */\n void stop() noexcept;\n\n /**\n * @brief Get backend file descriptor.\n *\n * Only kqueue, epoll and event ports are supported.
\n * This can be used in conjunction with `run(loop::run_mode::NOWAIT)` to\n * poll in one thread and run the event loop\u2019s callbacks in another.\n *\n * @return The backend file descriptor.\n */\n int descriptor() const noexcept;\n\n /**\n * @brief Gets the poll timeout.\n * @return A `std::pair` composed as it follows:\n * * A boolean value that is true in case of valid timeout, false otherwise.\n * * Milliseconds (`std::chrono::duration`).\n */\n std::pair timeout() const noexcept;\n\n /**\n * @brief Returns the amount of time the event loop has been idle. The call\n * is thread safe.\n * @return The accumulated time spent idle.\n */\n time idle_time() const noexcept;\n\n /**\n * @brief Tracks various internal operations of the event loop.\n * @return Event loop metrics.\n */\n metrics_type metrics() const noexcept;\n\n /**\n * @brief Returns the current timestamp in milliseconds.\n *\n * The timestamp is cached at the start of the event loop tick.
\n * The timestamp increases monotonically from some arbitrary point in\n * time.
\n * Don\u2019t make assumptions about the starting point, you will only get\n * disappointed.\n *\n * @return The current timestamp in milliseconds (actual type is\n * `std::chrono::duration`).\n */\n time now() const noexcept;\n\n /**\n * @brief Updates the event loop\u2019s concept of _now_.\n *\n * The current time is cached at the start of the event loop tick in order\n * to reduce the number of time-related system calls.
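A short illustration of the caching behaviour of `now()` and the effect of `update()`; the blocking callback here is purely a stand-in, not something from this excerpt:

```cpp
#include <chrono>
#include <thread>
#include <uvw/loop.h>

// Imagine this runs inside a callback that blocks the loop for a while.
void long_running_callback(uvw::loop &loop) {
    const auto before = loop.now();

    // Simulate work that blocks the loop well beyond a millisecond.
    std::this_thread::sleep_for(std::chrono::milliseconds{50});

    // Still the value cached at the start of the tick: equal to `before`.
    const auto stale = loop.now();

    // Refresh the cached time explicitly, then read it again.
    loop.update();
    const auto fresh = loop.now();

    (void)before; (void)stale; (void)fresh;  // fresh - stale is roughly 50ms
}
```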
\n * You won\u2019t normally need to call this function unless you have callbacks\n * that block the event loop for longer periods of time, where _longer_ is\n * somewhat subjective but probably on the order of a millisecond or more.\n */\n void update() const noexcept;\n\n /**\n * @brief Walks the list of handles.\n *\n * The callback is invoked once for each handle that is still active.\n *\n * @param callback A function to invoke once for each active handle.\n */\n template\n void walk(Func callback) {\n auto func = [](uv_handle_t *hndl, void *callback_func) {\n if(hndl->data) {\n auto &cb = *static_cast(callback_func);\n\n switch(utilities::guess_handle(handle_category{hndl->type})) {\n case handle_type::ASYNC:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::CHECK:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_EVENT:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::FS_POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::IDLE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PIPE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::POLL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PREPARE:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::PROCESS:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::SIGNAL:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TCP:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TIMER:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::TTY:\n cb(*static_cast(hndl->data));\n break;\n case handle_type::UDP:\n cb(*static_cast(hndl->data));\n break;\n default:\n // this handle isn't managed by uvw, let it be...\n break;\n }\n }\n };\n\n uv_walk(uv_loop.get(), func, &callback);\n }\n\n /**\n * @brief Reinitialize any kernel state necessary in the child process after\n * a fork(2) system call.\n *\n * Previously started watchers will continue to be started in the child\n * process.\n *\n * It is necessary to explicitly call this function on every event loop\n * created in the parent process that you plan to continue to use in the\n * child, including the default loop (even if you don\u2019t continue to use it\n * in the parent). This function must be called before calling any API\n * function using the loop in the child. Failure to do so will result in\n * undefined behaviour, possibly including duplicate events delivered to\n * both parent and child or aborting the child process.\n *\n * When possible, it is preferred to create a new loop in the child process\n * instead of reusing a loop created in the parent. New loops created in the\n * child process after the fork should not use this function.\n *\n * Note that this function is not implemented on Windows.
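A common use of the `walk()` member function shown above is to close every handle that is still alive before tearing the loop down; a minimal sketch:

```cpp
#include <uvw/loop.h>

// Close whatever is still alive so that loop::close() can succeed afterwards.
void shutdown(uvw::loop &loop) {
    loop.walk([](auto &handle) {
        // The callback receives the concrete uvw handle type, so a generic
        // lambda covers every case of the switch in walk().
        handle.close();
    });

    // Let the pending close callbacks run, then release the loop resources.
    loop.run();
    loop.close();
}
```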
\n * Note also that this function is experimental in `libuv`. It may contain\n * bugs, and is subject to change or removal. API and ABI stability is not\n * guaranteed.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/loop.html#c.uv_loop_fork)\n * for further details.\n *\n * @return Underlying return value.\n */\n int fork() noexcept;\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param ud User-defined arbitrary data.\n */\n void data(std::shared_ptr ud);\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n...\n// Path: src/uvw/uv_type.hpp\n#ifndef UVW_UV_TYPE_INCLUDE_H\n#define UVW_UV_TYPE_INCLUDE_H\n\n#include \n#include \n#include \n#include \"config.h\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/**\n * @brief Wrapper class for underlying types.\n *\n * It acts mainly as a wrapper around data structures of the underlying library.\n */\ntemplate\nstruct uv_type {\n explicit uv_type(loop::token, std::shared_ptr ref) noexcept\n : owner{std::move(ref)}, resource{} {}\n\n uv_type(const uv_type &) = delete;\n uv_type(uv_type &&) = delete;\n\n uv_type &operator=(const uv_type &) = delete;\n uv_type &operator=(uv_type &&) = delete;\n\n /**\n * @brief Gets the loop from which the resource was originated.\n * @return A reference to a loop instance.\n */\n loop &parent() const noexcept {\n return *owner;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n const U *raw() const noexcept {\n return &resource;\n }\n\n /**\n * @brief Gets the underlying raw data structure.\n *\n * This function should not be used, unless you know exactly what you are\n * doing and what are the risks.
\n * Going raw is dangerous, mainly because the lifetime management of a loop,\n * a handle or a request is in charge to the library itself and users should\n * not work around it.\n *\n * @warning\n * Use this function at your own risk, but do not expect any support in case\n * of bugs.\n *\n * @return The underlying raw data structure.\n */\n U *raw() noexcept {\n return &resource;\n }\n\nprotected:\n ~uv_type() = default;\n\nprivate:\n std::shared_ptr owner;\n U resource;\n};\n\n} // namespace uvw\n\n#endif // UVW_UV_TYPE_INCLUDE_H\n\n// Path: src/uvw/resource.hpp\n#ifndef UVW_RESOURCE_INCLUDE_H\n#define UVW_RESOURCE_INCLUDE_H\n\n#include \n#include \n#include \"config.h\"\n#include \"emitter.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Common class for almost all the resources available in `uvw`.\n *\n * This is the base class for handles and requests.\n */\ntemplate\nclass resource: public uv_type, public emitter, public std::enable_shared_from_this {\nprotected:\n int leak_if(int err) noexcept {\n if(err == 0) {\n self_ptr = this->shared_from_this();\n }\n\n return err;\n }\n\n void self_reset() noexcept {\n self_ptr.reset();\n }\n\n bool has_self() const noexcept {\n return static_cast(self_ptr);\n }\n\npublic:\n explicit resource(loop::token token, std::shared_ptr ref)\n : uv_type{token, std::move(ref)} {\n this->raw()->data = this;\n }\n\n /**\n * @brief Gets user-defined data. `uvw` won't use this field in any case.\n * @return User-defined data if any, an invalid pointer otherwise.\n */\n template\n std::shared_ptr data() const {\n return std::static_pointer_cast(user_data);\n }\n\n /**\n * @brief Sets arbitrary data. `uvw` won't use this field in any case.\n * @param udata User-defined arbitrary data.\n */\n void data(std::shared_ptr udata) {\n user_data = std::move(udata);\n }\n\nprivate:\n std::shared_ptr user_data{nullptr};\n std::shared_ptr self_ptr{nullptr};\n};\n\n} // namespace uvw\n\n#endif // UVW_RESOURCE_INCLUDE_H\n\n// Path: src/uvw/handle.hpp\n#ifndef UVW_HANDLE_INCLUDE_H\n#define UVW_HANDLE_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n#include \"util.h\"\n\nnamespace uvw {\n\n/*! @brief Close event. */\nstruct close_event {};\n\n/**\n * @brief Handle base class.\n *\n * Base type for all `uvw` handle types.\n */\ntemplate\nclass handle: public resource {\nprotected:\n static void close_callback(uv_handle_t *hndl) {\n handle &ref = *(static_cast(hndl->data));\n [[maybe_unused]] auto ptr = ref.shared_from_this();\n ref.self_reset();\n ref.publish(close_event{});\n }\n\n uv_handle_t *as_uv_handle() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_handle_t *as_uv_handle() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Gets the category of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. By means of this function, an opaque value that identifies the\n * category of the handle is made available to the users.\n *\n * @return The actual category of the handle.\n */\n handle_category category() const noexcept {\n return handle_category{as_uv_handle()->type};\n }\n\n /**\n * @brief Gets the type of the handle.\n *\n * A base handle offers no functionality to promote it to the actual handle\n * type. 
By means of this function, the type of the underlying handle as\n * specified by handle_type is made available to the users.\n *\n * @return The actual type of the handle.\n */\n handle_type type() const noexcept {\n return utilities::guess_handle(category());\n }\n\n /**\n * @brief Checks if the handle is active.\n *\n * What _active_ means depends on the type of handle:\n *\n * * An async_handle handle is always active and cannot be deactivated,\n * except by closing it with uv_close().\n * * A pipe, tcp, udp, etc. handle - basically any handle that deals with\n * I/O - is active when it is doing something that involves I/O, like\n * reading, writing, connecting, accepting new connections, etc.\n * * A check, idle, timer, etc. handle is active when it has been started\n * with a call to `start()`.\n *\n * Rule of thumb: if a handle of type `foo_handle` has a `start()` member\n * method, then it\u2019s active from the moment that method is called. Likewise,\n * `stop()` deactivates the handle again.\n *\n * @return True if the handle is active, false otherwise.\n */\n bool active() const noexcept {\n return !!uv_is_active(as_uv_handle());\n }\n\n /**\n * @brief Checks if a handle is closing or closed.\n *\n * This function should only be used between the initialization of the\n * handle and the arrival of the close callback.\n *\n * @return True if the handle is closing or closed, false otherwise.\n */\n bool closing() const noexcept {\n return !!uv_is_closing(as_uv_handle());\n }\n\n /**\n * @brief Request handle to be closed.\n *\n * This **must** be called on each handle before memory is released.
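A small sketch of the close-and-wait pattern implied by `close()` and the close event above, using the `check_handle` that appears later in this context (include paths are assumptions):

```cpp
#include <iostream>
#include <uvw/check.h>
#include <uvw/loop.h>

int main() {
    auto loop = uvw::loop::get_default();
    auto check = loop->resource<uvw::check_handle>();

    // The close event fires once the underlying handle has been released.
    check->on<uvw::close_event>([](const auto &, auto &) {
        std::cout << "handle closed\n";
    });

    if(!check->closing()) {
        check->close();
    }

    return loop->run();
}
```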
\n * In-progress requests are cancelled and this can result in errors.\n *\n * The handle will emit a close event when finished.\n */\n void close() noexcept {\n if(!closing()) {\n uv_close(as_uv_handle(), &handle::close_callback);\n }\n }\n\n /**\n * @brief Reference the given handle.\n *\n * References are idempotent, that is, if a handle is already referenced\n * calling this function again will have no effect.\n */\n void reference() noexcept {\n uv_ref(as_uv_handle());\n }\n\n /**\n * @brief Unreference the given handle.\n *\n * References are idempotent, that is, if a handle is not referenced calling\n * this function again will have no effect.\n */\n void unreference() noexcept {\n uv_unref(as_uv_handle());\n }\n\n /**\n * @brief Checks if the given handle referenced.\n * @return True if the handle referenced, false otherwise.\n */\n bool referenced() const noexcept {\n return !!uv_has_ref(as_uv_handle());\n }\n\n /**\n * @brief Returns the size of the underlying handle type.\n * @return The size of the underlying handle type.\n */\n std::size_t size() const noexcept {\n return uv_handle_size(as_uv_handle()->type);\n }\n\n /**\n * @brief Gets the size of the send buffer used for the socket.\n *\n * Gets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the send buffer, the underlying return value in case\n * of errors.\n */\n int send_buffer_size() {\n int value = 0;\n auto err = uv_send_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the send buffer used for the socket.\n *\n * Sets the size of the send buffer that the operating system uses for the\n * socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int send_buffer_size(int value) {\n return uv_send_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the size of the receive buffer used for the socket.\n *\n * Gets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will return double the size of the original set value.\n *\n * @return The size of the receive buffer, the underlying return value in\n * case of errors.\n */\n int recv_buffer_size() {\n int value = 0;\n auto err = uv_recv_buffer_size(as_uv_handle(), &value);\n return err ? err : value;\n }\n\n /**\n * @brief Sets the size of the receive buffer used for the socket.\n *\n * Sets the size of the receive buffer that the operating system uses for\n * the socket.
\n * This function works for tcp, pipe and udp handles on Unix and for tcp and\n * udp handles on Windows.
\n * Note that Linux will set double the size.\n *\n * @return Underlying return value.\n */\n int recv_buffer_size(int value) {\n return uv_recv_buffer_size(as_uv_handle(), &value);\n }\n\n /**\n * @brief Gets the platform dependent file descriptor equivalent.\n *\n * Supported handles:\n *\n * * tcp_handle\n * * pipe_handle\n * * tty_handle\n * * udp_handle\n * * poll_handle\n *\n * If invoked on a different handle, one that doesn\u2019t have an attached file\n * descriptor yet or one which was closed, an invalid value is returned.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/handle.html#c.uv_fileno)\n * for further details.\n *\n * @return The file descriptor attached to the hande or a negative value in\n * case of errors.\n */\n os_file_descriptor fd() const {\n uv_os_fd_t fd;\n uv_fileno(as_uv_handle(), &fd);\n return fd;\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_HANDLE_INCLUDE_H\n\n// Path: src/uvw/prepare.h\n#ifndef UVW_PREPARE_INCLUDE_H\n#define UVW_PREPARE_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Prepare event. */\nstruct prepare_event {};\n\n/**\n * @brief The prepare handle.\n *\n * Prepare handles will emit a prepare event once per loop iteration, right\n * before polling for I/O.\n *\n * To create a `prepare_handle` through a `loop`, no arguments are required.\n */\nclass prepare_handle final: public handle {\n static void start_callback(uv_prepare_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A prepare event will be emitted once per loop iteration, right before\n * polling for I/O.\n *\n * The handle will start emitting prepare events when needed.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"prepare.cpp\"\n#endif\n\n#endif // UVW_PREPARE_INCLUDE_H\n\n// Path: src/uvw/prepare.cpp\n#ifdef UVW_AS_LIB\n# include \"prepare.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void prepare_handle::start_callback(uv_prepare_t *hndl) {\n prepare_handle &prepare = *(static_cast(hndl->data));\n prepare.publish(prepare_event{});\n}\n\nUVW_INLINE int prepare_handle::init() {\n return leak_if(uv_prepare_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int prepare_handle::start() {\n return uv_prepare_start(raw(), &start_callback);\n}\n\nUVW_INLINE int prepare_handle::stop() {\n return uv_prepare_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/request.hpp\n#ifndef UVW_REQUEST_INCLUDE_H\n#define UVW_REQUEST_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"resource.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Request base class.\n *\n * Base type for all `uvw` request types.\n */\ntemplate\nclass request: public resource {\nprotected:\n static auto reserve(U *req) {\n auto ptr = static_cast(req->data)->shared_from_this();\n ptr->self_reset();\n return ptr;\n }\n\npublic:\n using resource::resource;\n\n /**\n * @brief Cancels a pending request.\n *\n * This method fails if the request is executing or has finished\n * executing.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/request.html#c.uv_cancel)\n * for further details.\n *\n * @return Underlying return value.\n */\n int cancel() {\n return 
uv_cancel(reinterpret_cast(this->raw()));\n }\n\n /**\n * @brief Returns the size of the underlying request type.\n * @return The size of the underlying request type.\n */\n std::size_t size() const noexcept {\n return uv_req_size(reinterpret_cast(this->raw())->type);\n }\n};\n\n} // namespace uvw\n\n#endif // UVW_REQUEST_INCLUDE_H\n\n// Path: src/uvw/stream.h\n#ifndef UVW_STREAM_INCLUDE_H\n#define UVW_STREAM_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"handle.hpp\"\n#include \"loop.h\"\n#include \"request.hpp\"\n\nnamespace uvw {\n\n/*! @brief Connect event. */\nstruct connect_event {};\n\n/*! @brief End event. */\nstruct end_event {};\n\n/*! @brief Listen event. */\nstruct listen_event {};\n\n/*! @brief Shutdown event. */\nstruct shutdown_event {};\n\n/*! @brief Write event. */\nstruct write_event {};\n\n/*! @brief Data event. */\nstruct data_event {\n explicit data_event(std::unique_ptr buf, std::size_t len) noexcept;\n\n std::unique_ptr data; /*!< A bunch of data read on the stream. */\n std::size_t length; /*!< The amount of data read on the stream. */\n};\n\nnamespace details {\n\nclass connect_req final: public request {\n static void connect_callback(uv_connect_t *req, int status);\n\npublic:\n using request::request;\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n std::forward(f)(raw(), std::forward(args)..., &connect_callback);\n return this->leak_if(0);\n }\n\n template\n auto connect(F &&f, Args &&...args) -> std::enable_if_t(f)(raw(), std::forward(args)..., &connect_callback)), void>, int> {\n return this->leak_if(std::forward(f)(raw(), std::forward(args)..., &connect_callback));\n }\n};\n\nclass shutdown_req final: public request {\n static void shoutdown_callback(uv_shutdown_t *req, int status);\n\npublic:\n using request::request;\n\n int shutdown(uv_stream_t *hndl);\n};\n\ntemplate\nclass write_req final: public request, uv_write_t, write_event> {\n static void write_callback(uv_write_t *req, int status) {\n if(auto ptr = request, uv_write_t, write_event>::reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(write_event{});\n }\n }\n\npublic:\n write_req(loop::token token, std::shared_ptr parent, std::unique_ptr dt, unsigned int len)\n : request, uv_write_t, write_event>{token, std::move(parent)},\n data{std::move(dt)},\n buf{uv_buf_init(data.get(), len)} {}\n\n int write(uv_stream_t *hndl) {\n return this->leak_if(uv_write(this->raw(), hndl, &buf, 1, &write_callback));\n }\n\n int write(uv_stream_t *hndl, uv_stream_t *send) {\n return this->leak_if(uv_write2(this->raw(), hndl, &buf, 1, send, &write_callback));\n }\n\nprivate:\n std::unique_ptr data;\n uv_buf_t buf;\n};\n\n} // namespace details\n\n/**\n * @brief The stream handle.\n *\n * Stream handles provide an abstraction of a duplex communication channel.\n * The stream handle is an intermediate type, `uvw` provides three stream\n * implementations: tcp, pipe and tty handles.\n */\ntemplate\nclass stream_handle: public handle {\n using base = handle;\n\n template\n friend class stream_handle;\n\n static constexpr unsigned int DEFAULT_BACKLOG = 128;\n\n static void read_callback(uv_stream_t *hndl, ssize_t nread, const uv_buf_t *buf) {\n T &ref = *(static_cast(hndl->data));\n // data will be destroyed no matter of what the value of nread is\n std::unique_ptr data{buf->base};\n\n // nread == 0 is ignored (see 
http://docs.libuv.org/en/v1.x/stream.html)\n // equivalent to EAGAIN/EWOULDBLOCK, it shouldn't be treated as an error\n // for we don't have data to emit though, it's fine to suppress it\n\n if(nread == UV_EOF) {\n // end of stream\n ref.publish(end_event{});\n } else if(nread > 0) {\n // data available\n ref.publish(data_event{std::move(data), static_cast(nread)});\n } else if(nread < 0) {\n // transmission error\n ref.publish(error_event(nread));\n }\n }\n\n static void listen_callback(uv_stream_t *hndl, int status) {\n if(T &ref = *(static_cast(hndl->data)); status) {\n ref.publish(error_event{status});\n } else {\n ref.publish(listen_event{});\n }\n }\n\n uv_stream_t *as_uv_stream() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_stream_t *as_uv_stream() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n#ifdef _MSC_VER\n \nstream_handle(loop::token token, std::shared_ptr ref)\n : base{token, std::move(ref)} {}\n#else\n using base::base;\n#endif\n\n /**\n * @brief Shutdowns the outgoing (write) side of a duplex stream.\n *\n * It waits for pending write requests to complete. The handle should refer\n * to a initialized stream.
\n * A shutdown event will be emitted after shutdown is complete.\n *\n * @return Underlying return value.\n */\n int shutdown() {\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n auto shutdown = this->parent().template resource();\n shutdown->template on(listener);\n shutdown->template on(listener);\n\n return shutdown->shutdown(as_uv_stream());\n }\n\n /**\n * @brief Starts listening for incoming connections.\n *\n * When a new incoming connection is received, a listen event is\n * emitted.\n *\n * @param backlog Indicates the number of connections the kernel might\n * queue, same as listen(2).\n *\n * @return Underlying return value.\n */\n int listen(int backlog = DEFAULT_BACKLOG) {\n return uv_listen(as_uv_stream(), backlog, &listen_callback);\n }\n\n /**\n * @brief Accepts incoming connections.\n *\n * This call is used in conjunction with `listen()` to accept incoming\n * connections. Call this function after receiving a listen event to accept\n * the connection. Before calling this function, the submitted handle must\n * be initialized.\n *\n * When the listen event is emitted it is guaranteed that this function will\n * complete successfully the first time. If you attempt to use it more than\n * once, it may fail.
\n * It is suggested to only call this function once per listen event.\n *\n * @note\n * Both the handles must be running on the same loop.\n *\n * @param ref An initialized handle to be used to accept the connection.\n * @return Underlying return value.\n */\n template\n int accept(S &ref) {\n return uv_accept(as_uv_stream(), ref.as_uv_stream());\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n *\n * A data event will be emitted several times until there is no more data to\n * read or `stop()` is called.
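A hedged sketch of the listen/accept/read flow described above. `tcp_handle` is one of the stream implementations mentioned earlier but is not part of this excerpt, and both the header path and the port are assumptions:

```cpp
#include <uvw/loop.h>
#include <uvw/tcp.h>  // assumed header for the tcp stream implementation

int main() {
    auto loop = uvw::loop::get_default();
    auto server = loop->resource<uvw::tcp_handle>();

    server->on<uvw::listen_event>([](const auto &, uvw::tcp_handle &srv) {
        // One accept per listen event, as suggested above.
        auto client = srv.parent().resource<uvw::tcp_handle>();

        client->on<uvw::data_event>([](const uvw::data_event &evt, auto &) {
            // evt.data holds evt.length bytes read from the stream.
            (void)evt;
        });
        client->on<uvw::end_event>([](const auto &, auto &peer) {
            peer.close();  // remote side finished writing
        });

        srv.accept(*client);
        client->read();
    });

    server->bind("127.0.0.1", 4242);
    server->listen();
    return loop->run();
}
```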
\n * An end event will be emitted when there is no more data to read.\n *\n * @return Underlying return value.\n */\n int read() {\n return uv_read_start(as_uv_stream(), &details::common_alloc_callback, &read_callback);\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n * @sa read\n * @tparam Alloc Custom allocation function.\n * @return Underlying return value.\n */\n template\n int read() {\n return uv_read_start(as_uv_stream(), &details::common_alloc_callback, &read_callback);\n }\n\n /**\n * @brief Stops reading data from the stream.\n *\n * This function is idempotent and may be safely called on a stopped stream.\n *\n * @return Underlying return value.\n */\n int stop() {\n return uv_read_stop(as_uv_stream());\n }\n\n /**\n * @brief Writes data to the stream.\n *\n * Data are written in order. The handle takes the ownership of the data and\n * it is in charge of delete them.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(std::unique_ptr data, unsigned int len) {\n auto req = this->parent().template resource>(std::move(data), len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream());\n }\n\n /**\n * @brief Writes data to the stream.\n *\n * Data are written in order. The handle doesn't take the ownership of the\n * data. Be sure that their lifetime overcome the one of the request.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int write(char *data, unsigned int len) {\n auto req = this->parent().template resource>(std::unique_ptr{data, [](char *) {}}, len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream());\n }\n\n /**\n * @brief Extended write function for sending handles over a pipe handle.\n *\n * The pipe must be initialized with `ipc == true`.\n *\n * `send` must be a tcp or pipe handle, which is a server or a connection\n * (listening or connected state). 
Bound sockets or pipes will be assumed to\n * be servers.\n *\n * The handle takes the ownership of the data and it is in charge of delete\n * them.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param send The handle over which to write data.\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(S &send, std::unique_ptr data, unsigned int len) {\n auto req = this->parent().template resource>(std::move(data), len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream(), send.as_uv_stream());\n }\n\n /**\n * @brief Extended write function for sending handles over a pipe handle.\n *\n * The pipe must be initialized with `ipc == true`.\n *\n * `send` must be a tcp or pipe handle, which is a server or a connection\n * (listening or connected state). Bound sockets or pipes will be assumed to\n * be servers.\n *\n * The handle doesn't take the ownership of the data. Be sure that their\n * lifetime overcome the one of the request.\n *\n * A write event will be emitted when the data have been written.\n *\n * @param send The handle over which to write data.\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n template\n int write(S &send, char *data, unsigned int len) {\n auto req = this->parent().template resource>(std::unique_ptr{data, [](char *) {}}, len);\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n req->template on(listener);\n req->template on(listener);\n\n return req->write(as_uv_stream(), send.as_uv_stream());\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `write()`, but won\u2019t queue a write request if it can\u2019t be\n * completed immediately.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_write(std::unique_ptr data, unsigned int len) {\n uv_buf_t bufs[] = {uv_buf_init(data.get(), len)};\n return uv_try_write(as_uv_stream(), bufs, 1);\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `try_write` for sending handles over a pipe.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @param send A valid handle suitable for the purpose.\n * @return Underlying return value.\n */\n template\n int try_write(std::unique_ptr data, unsigned int len, stream_handle &send) {\n uv_buf_t bufs[] = {uv_buf_init(data.get(), len)};\n return uv_try_write2(as_uv_stream(), bufs, 1, send.raw());\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `write()`, but won\u2019t queue a write request if it can\u2019t be\n * completed immediately.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @return Underlying return value.\n */\n int try_write(char *data, unsigned int len) {\n uv_buf_t bufs[] = {uv_buf_init(data, len)};\n return uv_try_write(as_uv_stream(), bufs, 1);\n }\n\n /**\n * @brief Queues a write request if it can be completed immediately.\n *\n * Same as `try_write` 
for sending handles over a pipe.\n *\n * @param data The data to be written to the stream.\n * @param len The lenght of the submitted data.\n * @param send A valid handle suitable for the purpose.\n * @return Underlying return value.\n */\n template\n int try_write(char *data, unsigned int len, stream_handle &send) {\n uv_buf_t bufs[] = {uv_buf_init(data, len)};\n return uv_try_write2(as_uv_stream(), bufs, 1, send.raw());\n }\n\n /**\n * @brief Checks if the stream is readable.\n * @return True if the stream is readable, false otherwise.\n */\n bool readable() const noexcept {\n return (uv_is_readable(as_uv_stream()) == 1);\n }\n\n /**\n * @brief Checks if the stream is writable.\n * @return True if the stream is writable, false otherwise.\n */\n bool writable() const noexcept {\n return (uv_is_writable(as_uv_stream()) == 1);\n }\n\n /**\n * @brief Enables or disables blocking mode for a stream.\n *\n * When blocking mode is enabled all writes complete synchronously. The\n * interface remains unchanged otherwise, e.g. completion or failure of the\n * operation will still be reported through events which are emitted\n * asynchronously.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/stream.html#c.uv_stream_set_blocking)\n * for further details.\n *\n * @param enable True to enable blocking mode, false otherwise.\n * @return True in case of success, false otherwise.\n */\n bool blocking(bool enable = false) {\n return (0 == uv_stream_set_blocking(as_uv_stream(), enable));\n }\n\n /**\n * @brief Gets the amount of queued bytes waiting to be sent.\n * @return Amount of queued bytes waiting to be sent.\n */\n size_t write_queue_size() const noexcept {\n return uv_stream_get_write_queue_size(as_uv_stream());\n }\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"stream.cpp\"\n#endif\n\n#endif // UVW_STREAM_INCLUDE_H\n\n// Path: src/uvw/stream.cpp\n#ifdef UVW_AS_LIB\n# include \"stream.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE data_event::data_event(std::unique_ptr buf, std::size_t len) noexcept\n : data{std::move(buf)},\n length{len} {}\n\nUVW_INLINE void details::connect_req::connect_callback(uv_connect_t *req, int status) {\n if(auto ptr = reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(connect_event{});\n }\n}\n\nUVW_INLINE void details::shutdown_req::shoutdown_callback(uv_shutdown_t *req, int status) {\n if(auto ptr = reserve(req); status) {\n ptr->publish(error_event{status});\n } else {\n ptr->publish(shutdown_event{});\n }\n}\n\nUVW_INLINE int details::shutdown_req::shutdown(uv_stream_t *hndl) {\n return this->leak_if(uv_shutdown(raw(), hndl, &shoutdown_callback));\n}\n\n} // namespace uvw\n\n// Path: src/uvw/check.h\n#ifndef UVW_CHECK_INCLUDE_H\n#define UVW_CHECK_INCLUDE_H\n\n#include \n#include \"handle.hpp\"\n#include \"loop.h\"\n\nnamespace uvw {\n\n/*! @brief Check event. 
*/\nstruct check_event {};\n\n/**\n * @brief The check handle.\n *\n * Check handles will emit a check event once per loop iteration, right after\n * polling for I/O.\n *\n * To create a `check_handle` through a `loop`, no arguments are required.\n */\nclass check_handle final: public handle {\n static void start_callback(uv_check_t *hndl);\n\npublic:\n using handle::handle;\n\n /**\n * @brief Initializes the handle.\n * @return Underlying return value.\n */\n int init();\n\n /**\n * @brief Starts the handle.\n *\n * A check event will be emitted once per loop iteration, right after\n * polling for I/O.\n *\n * @return Underlying return value.\n */\n int start();\n\n /**\n * @brief Stops the handle.\n * @return Underlying return value.\n */\n int stop();\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"check.cpp\"\n#endif\n\n#endif // UVW_CHECK_INCLUDE_H\n\n// Path: src/uvw/check.cpp\n#ifdef UVW_AS_LIB\n# include \"check.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE void check_handle::start_callback(uv_check_t *hndl) {\n check_handle &check = *(static_cast(hndl->data));\n check.publish(check_event{});\n}\n\nUVW_INLINE int check_handle::init() {\n return leak_if(uv_check_init(parent().raw(), raw()));\n}\n\nUVW_INLINE int check_handle::start() {\n return uv_check_start(raw(), &start_callback);\n}\n\nUVW_INLINE int check_handle::stop() {\n return uv_check_stop(raw());\n}\n\n} // namespace uvw\n\n// Path: src/uvw/enum.hpp\n#ifndef UVW_ENUM_INCLUDE_HPP\n#define UVW_ENUM_INCLUDE_HPP\n\n#include \n#include \"config.h\"\n\n/**\n * @brief Operator available for enums for which bitmask support is enabled.\n * @tparam Type Enum class type.\n * @param lhs The first value to use.\n * @param rhs The second value to use.\n * @return The result of invoking the operator on the underlying types of the\n * two values provided.\n */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator|(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) | static_cast>(rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator&(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) & static_cast>(rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator^(const Type lhs, const Type rhs) noexcept {\n return static_cast(static_cast>(lhs) ^ static_cast>(rhs));\n}\n\n/**\n * @brief Operator available for enums for which bitmask support is enabled.\n * @tparam Type Enum class type.\n * @param value The value to use.\n * @return The result of invoking the operator on the underlying types of the\n * value provided.\n */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM)>\noperator~(const Type value) noexcept {\n return static_cast(~static_cast>(value));\n}\n\n/*! @copydoc operator~ */\ntemplate\n[[nodiscard]] constexpr std::enable_if_t, decltype(Type::_UVW_ENUM, bool{})>\noperator!(const Type value) noexcept {\n return !static_cast>(value);\n}\n\n/*! @copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator|=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs | rhs));\n}\n\n/*! @copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator&=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs & rhs));\n}\n\n/*! 
@copydoc operator| */\ntemplate\nconstexpr std::enable_if_t, decltype(Type::_UVW_ENUM) &>\noperator^=(Type &lhs, const Type rhs) noexcept {\n return (lhs = (lhs ^ rhs));\n}\n\n#endif\n\n// Path: src/uvw/thread.h\n#ifndef UVW_THREAD_INCLUDE_H\n#define UVW_THREAD_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"enum.hpp\"\n#include \"loop.h\"\n#include \"uv_type.hpp\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_thread_create_flags : std::underlying_type_t {\n THREAD_NO_FLAGS = UV_THREAD_NO_FLAGS,\n THREAD_HAS_STACK_SIZE = UV_THREAD_HAS_STACK_SIZE\n};\n\n}\n\nclass thread;\nclass thread_local_storage;\nclass once;\nclass mutex;\nclass rwlock;\nclass semaphore;\nclass condition;\nclass barrier;\n\n/**\n...\n", "needle_token_start": 10440, "needle_token_end": 10473, "code_context_ntokens": 16088, "output": ["```cpp\ntemplate\nclass stream_handle: public handle {\n using base = handle;\n\n template\n friend class stream_handle;\n\n static constexpr unsigned int DEFAULT_BACKLOG = 128;\n\n static void read_callback(uv_stream_t *hndl, ssize_t nread, const uv_buf_t *buf) {\n T &ref = *(static_cast(hndl->data));\n // data will be destroyed no matter of what the value of nread is\n std::unique_ptr data{buf->base};\n\n // nread == 0 is ignored (see http://docs.libuv.org/en/v1.x/stream.html)\n // equivalent to EAGAIN/EWOULDBLOCK, it shouldn't be treated as an error\n // for we don't have data to emit though, it's fine to suppress it\n\n if(nread == UV_EOF) {\n // end of stream\n ref.publish(end_event{});\n } else if(nread > 0) {\n // data available\n ref.publish(data_event{std::move(data), static_cast(nread)});\n } else if(nread < 0) {\n // transmission error\n ref.publish(error_event(nread));\n }\n }\n\n static void listen_callback(uv_stream_t *hndl, int status) {\n if(T &ref = *(static_cast(hndl->data)); status) {\n ref.publish(error_event{status});\n } else {\n ref.publish(listen_event{});\n }\n }\n\n uv_stream_t *as_uv_stream() {\n return reinterpret_cast(this->raw());\n }\n\n const uv_stream_t *as_uv_stream() const {\n return reinterpret_cast(this->raw());\n }\n\npublic:\n#ifdef _MSC_VER\n stream_handle(loop::token token, std::shared_ptr ref)\n : base{token, std::move(ref)} {}\n#else\n using base::base;\n#endif\n\n /**\n * @brief Shutdowns the outgoing (write) side of a duplex stream.\n *\n * It waits for pending write requests to complete. The handle should refer\n * to a initialized stream.
\n * A shutdown event will be emitted after shutdown is complete.\n *\n * @return Underlying return value.\n */\n int shutdown() {\n auto listener = [ptr = this->shared_from_this()](const auto &event, const auto &) {\n ptr->publish(event);\n };\n\n auto shutdown = this->parent().template resource();\n shutdown->template on(listener);\n shutdown->template on(listener);\n\n return shutdown->shutdown(as_uv_stream());\n }\n\n /**\n * @brief Starts listening for incoming connections.\n *\n * When a new incoming connection is received, a listen event is\n * emitted.\n *\n * @param backlog Indicates the number of connections the kernel might\n * queue, same as listen(2).\n *\n * @return Underlying return value.\n */\n int listen(int backlog = DEFAULT_BACKLOG) {\n return uv_listen(as_uv_stream(), backlog, &listen_callback);\n }\n\n /**\n * @brief Accepts incoming connections.\n *\n * This call is used in conjunction with `listen()` to accept incoming\n * connections. Call this function after receiving a listen event to accept\n * the connection. Before calling this function, the submitted handle must\n * be initialized.\n *\n * When the listen event is emitted it is guaranteed that this function will\n * complete successfully the first time. If you attempt to use it more than\n * once, it may fail.
\n * It is suggested to only call this function once per listen event.\n *\n * @note\n * Both the handles must be running on the same loop.\n *\n * @param ref An initialized handle to be used to accept the connection.\n * @return Underlying return value.\n */\n template\n int accept(S &ref) {\n return uv_accept(as_uv_stream(), ref.as_uv_stream());\n }\n\n /**\n * @brief Starts reading data from an incoming stream.\n *\n * A data event will be emitted several times until there is no more data to\n * read or `stop()` is called.
\n * An end event will be emitted when there is no more data to read.\n *\n * @return Underlying return value.\n */\n int read() {\n "]} +{"repo": "skypjack/uvw", "name": "env", "language": "cpp", "path": "src/uvw/util.h", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To manage environment variables by either retrieving, setting, or iterating through them.\n2. **Input**: For retrieval and setting, the name of the environment variable is required. For setting, an additional value is needed which can be an empty string to unset the variable. For iteration, a function object is required that can handle the name and value of each environment variable.\n3. **Output**: For retrieval, it returns the value of the environment variable or an empty string if not found or on error. For setting, it returns a boolean indicating success or failure. For iteration, it returns a boolean indicating success or failure of the iteration process.\n4. **Procedure**: Depending on the operation:\n - For retrieval, the function checks the system for the specified environment variable and returns its value.\n - For setting, the function updates or deletes the specified environment variable based on the value provided.\n - For iteration, the function retrieves all environment variables and passes each one to the provided function object.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/config.h\n#ifndef UVW_CONFIG_H\n#define UVW_CONFIG_H\n\n#ifndef UVW_AS_LIB\n# define UVW_INLINE inline\n#else\n# define UVW_INLINE\n#endif\n\n#endif\n\n// Path: src/uvw/type_info.hpp\n#ifndef UVW_TYPE_INFO_INCLUDE_HPP\n#define UVW_TYPE_INFO_INCLUDE_HPP\n\n#include \n#include \"config.h\"\n\nnamespace uvw {\n\n/**\n * @cond TURN_OFF_DOXYGEN\n * Internal details not to be documented.\n */\n\nnamespace internal {\n\n// Fowler-Noll-Vo hash function v. 
1a - the good\n[[nodiscard]] static constexpr std::uint32_t fnv1a(const char *curr) noexcept {\n constexpr std::uint32_t offset = 2166136261;\n constexpr std::uint32_t prime = 16777619;\n auto value = offset;\n\n while(*curr != 0) {\n auto curr_val_int = static_cast(*(curr++));\n value = (value ^ curr_val_int) * prime;\n }\n\n return value;\n}\n\n[[nodiscard]] static inline std::uint32_t counter() noexcept {\n static std::uint32_t cnt{};\n return cnt++;\n}\n\ntemplate\n[[nodiscard]] static std::uint32_t fake() noexcept {\n static std::uint32_t local = counter();\n return local;\n}\n\n} // namespace internal\n\n/**\n * Internal details not to be documented.\n * @endcond\n */\n\n/**\n * @brief Returns a numerical identifier for a given type.\n * @tparam Type The type for which to return the numerical identifier.\n * @return The numerical identifier of the give type.\n */\ntemplate\n[[nodiscard]] static constexpr std::uint32_t type() noexcept {\n#if defined __clang__ || defined __GNUC__\n return internal::fnv1a(__PRETTY_FUNCTION__);\n#elif defined _MSC_VER\n return internal::fnv1a(__FUNCSIG__);\n#else\n return internal::fake();\n#endif\n}\n\n} // namespace uvw\n\n#endif // UVW_TYPE_INFO_INCLUDE_HPP\n\n// Path: src/uvw/emitter.cpp\n#ifdef UVW_AS_LIB\n# include \"emitter.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE int error_event::translate(int sys) noexcept {\n return uv_translate_sys_error(sys);\n}\n\nUVW_INLINE const char *error_event::what() const noexcept {\n return uv_strerror(ec);\n}\n\nUVW_INLINE const char *error_event::name() const noexcept {\n return uv_err_name(ec);\n}\n\nUVW_INLINE int error_event::code() const noexcept {\n return ec;\n}\n\nUVW_INLINE error_event::operator bool() const noexcept {\n return ec < 0;\n}\n\n} // namespace uvw\n\n// Path: src/uvw/emitter.h\n#ifndef UVW_EMITTER_INCLUDE_H\n#define UVW_EMITTER_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"type_info.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Error event.\n *\n * Custom wrapper around error constants of `libuv`.\n */\nstruct error_event {\n template>>\n explicit error_event(Type val) noexcept\n : ec{static_cast(val)} {}\n\n /**\n * @brief Returns the `libuv` error code equivalent to the given platform dependent error code.\n *\n * It returns:\n * * POSIX error codes on Unix (the ones stored in errno).\n * * Win32 error codes on Windows (those returned by GetLastError() or WSAGetLastError()).\n *\n * If `sys` is already a `libuv` error code, it is simply returned.\n *\n * @param sys A platform dependent error code.\n * @return The `libuv` error code equivalent to the given platform dependent error code.\n */\n static int translate(int sys) noexcept;\n\n /**\n * @brief Returns the error message for the given error code.\n *\n * Leaks a few bytes of memory when you call it with an unknown error code.\n *\n * @return The error message for the given error code.\n */\n const char *what() const noexcept;\n\n /**\n * @brief Returns the error name for the given error code.\n *\n * Leaks a few bytes of memory when you call it with an unknown error code.\n *\n * @return The error name for the given error code.\n */\n const char *name() const noexcept;\n\n /**\n * @brief Gets the underlying error code, that is an error constant of `libuv`.\n * @return The underlying error code.\n */\n int code() const noexcept;\n\n /**\n * @brief Checks if the event contains a valid error code.\n * @return True in 
case of success, false otherwise.\n */\n explicit operator bool() const noexcept;\n\nprivate:\n const int ec;\n};\n\n/**\n * @brief Event emitter base class.\n *\n * Almost everything in `uvw` is an event emitter.
\n * This is the base class from which resources and loops inherit.\n */\ntemplate\nclass emitter {\npublic:\n template\n using listener_t = std::function;\n\nprivate:\n template\n const auto &handler() const noexcept {\n return std::get>(handlers);\n }\n\n template\n auto &handler() noexcept {\n return std::get>(handlers);\n }\n\nprotected:\n template\n void publish(Type event) {\n if(auto &listener = handler(); listener) {\n listener(event, *static_cast(this));\n }\n }\n\npublic:\n virtual ~emitter() noexcept {\n static_assert(std::is_base_of_v, Elem>);\n }\n\n /**\n * @brief Registers a long-lived listener with the event emitter.\n *\n * This method is used to register a listener with the emitter.
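A brief sketch of registering a listener as just described, using the `error_event` defined above (`tcp_handle` and its header are assumptions, as they are not part of this excerpt):

```cpp
#include <iostream>
#include <uvw/loop.h>
#include <uvw/tcp.h>  // assumed: any emitter-derived resource would do

int main() {
    auto loop = uvw::loop::get_default();
    auto tcp = loop->resource<uvw::tcp_handle>();

    // One long-lived listener per event type; registering again replaces it.
    tcp->on<uvw::error_event>([](const uvw::error_event &evt, auto &hndl) {
        std::cout << evt.name() << ": " << evt.what() << '\n';
        hndl.close();
    });

    // Nothing is expected to listen on this port, so an error event is likely.
    tcp->connect("127.0.0.1", 9);
    return loop->run();
}
```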
\n * A listener is usually defined as a callable object assignable to a\n * `std::function\n void on(listener_t f) {\n handler() = std::move(f);\n }\n\n /*! @brief Disconnects the listener for the given event type. */\n template\n void reset() noexcept {\n handler() = nullptr;\n }\n\n /*! @brief Disconnects all listeners. */\n void reset() noexcept {\n reset();\n (reset(), ...);\n }\n\n /**\n * @brief Checks if there is a listener registered for the specific event.\n * @return True if there is a listener registered for the specific event,\n * false otherwise.\n */\n template\n bool has() const noexcept {\n return static_cast(handler());\n }\n\nprivate:\n std::tuple, listener_t...> handlers{};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"emitter.cpp\"\n#endif\n\n#endif // UVW_EMITTER_INCLUDE_H\n\n// Path: src/uvw/util.cpp\n#ifdef UVW_AS_LIB\n# include \"util.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE passwd_info::passwd_info(std::shared_ptr pwd)\n : value{pwd} {}\n\nUVW_INLINE std::string passwd_info::username() const noexcept {\n return ((value && value->username) ? value->username : \"\");\n}\n\nUVW_INLINE decltype(uv_passwd_t::uid) passwd_info::uid() const noexcept {\n return (value ? value->uid : decltype(uv_passwd_t::uid){});\n}\n\nUVW_INLINE decltype(uv_passwd_t::gid) passwd_info::gid() const noexcept {\n return (value ? value->gid : decltype(uv_passwd_t::gid){});\n}\n\nUVW_INLINE std::string passwd_info::shell() const noexcept {\n return ((value && value->shell) ? value->shell : \"\");\n}\n\nUVW_INLINE std::string passwd_info::homedir() const noexcept {\n return ((value && value->homedir) ? value->homedir : \"\");\n}\n\nUVW_INLINE passwd_info::operator bool() const noexcept {\n return static_cast(value);\n}\n\nUVW_INLINE uts_name::uts_name(std::shared_ptr init)\n : uname{init} {}\n\nUVW_INLINE std::string uts_name::sysname() const noexcept {\n return uname ? uname->sysname : \"\";\n}\n\nUVW_INLINE std::string uts_name::release() const noexcept {\n return uname ? uname->release : \"\";\n}\n\nUVW_INLINE std::string uts_name::version() const noexcept {\n return uname ? uname->version : \"\";\n}\n\nUVW_INLINE std::string uts_name::machine() const noexcept {\n return uname ? 
uname->machine : \"\";\n}\n\nnamespace details {\n\nUVW_INLINE void common_alloc_callback(uv_handle_t *, std::size_t suggested, uv_buf_t *buf) {\n auto size = static_cast(suggested);\n *buf = uv_buf_init(new char[size], size);\n}\n\nUVW_INLINE sockaddr ip_addr(const char *addr, unsigned int port) {\n // explicitly cast to avoid `-Wsign-conversion` warnings\n // libuv internally just casts to an `unsigned short` anyway\n auto signed_port = static_cast(port);\n if(sockaddr_in addr_in; uv_ip4_addr(addr, signed_port, &addr_in) == 0) {\n return reinterpret_cast(addr_in);\n } else if(sockaddr_in6 addr_in6; uv_ip6_addr(addr, signed_port, &addr_in6) == 0) {\n return reinterpret_cast(addr_in6);\n }\n\n return {};\n}\n\nUVW_INLINE socket_address sock_addr(const sockaddr_in &addr) {\n if(char name[details::DEFAULT_SIZE]; uv_ip4_name(&addr, name, details::DEFAULT_SIZE) == 0) {\n return socket_address{std::string{name}, ntohs(addr.sin_port)};\n }\n\n return socket_address{};\n}\n\nUVW_INLINE socket_address sock_addr(const sockaddr_in6 &addr) {\n if(char name[details::DEFAULT_SIZE]; uv_ip6_name(&addr, name, details::DEFAULT_SIZE) == 0) {\n return socket_address{std::string{name}, ntohs(addr.sin6_port)};\n }\n\n return socket_address{};\n}\n\nUVW_INLINE socket_address sock_addr(const sockaddr &addr) {\n if(addr.sa_family == AF_INET) {\n return sock_addr(reinterpret_cast(addr));\n } else if(addr.sa_family == AF_INET6) {\n return sock_addr(reinterpret_cast(addr));\n }\n\n return socket_address{};\n}\n\nUVW_INLINE socket_address sock_addr(const sockaddr_storage &storage) {\n if(storage.ss_family == AF_INET) {\n return sock_addr(reinterpret_cast(storage));\n } else if(storage.ss_family == AF_INET6) {\n return sock_addr(reinterpret_cast(storage));\n }\n\n return socket_address{};\n}\n\n} // namespace details\n\nUVW_INLINE pid_type utilities::os::pid() noexcept {\n return uv_os_getpid();\n}\n\nUVW_INLINE pid_type utilities::os::ppid() noexcept {\n return uv_os_getppid();\n}\n\nUVW_INLINE std::string utilities::os::homedir() noexcept {\n return details::try_read(&uv_os_homedir);\n}\n\nUVW_INLINE std::string utilities::os::tmpdir() noexcept {\n return details::try_read(&uv_os_tmpdir);\n}\n\nUVW_INLINE std::string utilities::os::env(const std::string &name) noexcept {\n return details::try_read(&uv_os_getenv, name.c_str());\n}\n\nUVW_INLINE bool utilities::os::env(const std::string &name, const std::string &value) noexcept {\n return (0 == (value.empty() ? 
uv_os_unsetenv(name.c_str()) : uv_os_setenv(name.c_str(), value.c_str())));\n}\n\nUVW_INLINE std::string utilities::os::hostname() noexcept {\n return details::try_read(&uv_os_gethostname);\n}\n\nUVW_INLINE uts_name utilities::os::uname() noexcept {\n auto ptr = std::make_shared();\n uv_os_uname(ptr.get());\n return ptr;\n}\n\nUVW_INLINE passwd_info utilities::os::passwd() noexcept {\n auto deleter = [](uv_passwd_t *passwd) {\n uv_os_free_passwd(passwd);\n delete passwd;\n };\n\n std::shared_ptr ptr{new uv_passwd_t, std::move(deleter)};\n uv_os_get_passwd(ptr.get());\n return ptr;\n}\n\nUVW_INLINE int utilities::os::priority(pid_type pid) {\n int prio = 0;\n\n if(uv_os_getpriority(pid, &prio)) {\n prio = UV_PRIORITY_LOW + 1;\n }\n\n return prio;\n}\n\nUVW_INLINE bool utilities::os::priority(pid_type pid, int prio) {\n return 0 == uv_os_setpriority(pid, prio);\n}\n\nUVW_INLINE handle_type utilities::guess_handle(handle_category category) noexcept {\n switch(category) {\n case UV_ASYNC:\n return handle_type::ASYNC;\n case UV_CHECK:\n return handle_type::CHECK;\n case UV_FS_EVENT:\n return handle_type::FS_EVENT;\n case UV_FS_POLL:\n return handle_type::FS_POLL;\n case UV_HANDLE:\n return handle_type::HANDLE;\n case UV_IDLE:\n return handle_type::IDLE;\n case UV_NAMED_PIPE:\n return handle_type::PIPE;\n case UV_POLL:\n return handle_type::POLL;\n case UV_PREPARE:\n return handle_type::PREPARE;\n case UV_PROCESS:\n return handle_type::PROCESS;\n case UV_STREAM:\n return handle_type::STREAM;\n case UV_TCP:\n return handle_type::TCP;\n case UV_TIMER:\n return handle_type::TIMER;\n case UV_TTY:\n return handle_type::TTY;\n case UV_UDP:\n return handle_type::UDP;\n case UV_SIGNAL:\n return handle_type::SIGNAL;\n case UV_FILE:\n return handle_type::FILE;\n default:\n return handle_type::UNKNOWN;\n }\n}\n\nUVW_INLINE handle_type utilities::guess_handle(file_handle file) noexcept {\n handle_category category = uv_guess_handle(file);\n return guess_handle(category);\n}\n\nUVW_INLINE std::vector utilities::cpu() noexcept {\n std::vector cpuinfos;\n\n uv_cpu_info_t *infos;\n int count;\n\n if(0 == uv_cpu_info(&infos, &count)) {\n for(int next = 0; next < count; ++next) {\n cpuinfos.push_back({infos[next].model, infos[next].speed, infos[next].cpu_times});\n }\n\n uv_free_cpu_info(infos, count);\n }\n\n return cpuinfos;\n}\n\nUVW_INLINE std::vector utilities::interface_addresses() noexcept {\n std::vector interfaces;\n\n uv_interface_address_t *ifaces{nullptr};\n int count{0};\n\n if(0 == uv_interface_addresses(&ifaces, &count)) {\n for(int next = 0; next < count; ++next) {\n interface_address iface_addr;\n\n iface_addr.name = ifaces[next].name;\n std::copy(ifaces[next].phys_addr, (ifaces[next].phys_addr + 6), iface_addr.physical);\n iface_addr.internal = ifaces[next].is_internal == 0 ? 
false : true;\n\n if(ifaces[next].address.address4.sin_family == AF_INET) {\n iface_addr.address = details::sock_addr(ifaces[next].address.address4);\n iface_addr.netmask = details::sock_addr(ifaces[next].netmask.netmask4);\n } else if(ifaces[next].address.address4.sin_family == AF_INET6) {\n iface_addr.address = details::sock_addr(ifaces[next].address.address6);\n iface_addr.netmask = details::sock_addr(ifaces[next].netmask.netmask6);\n }\n\n interfaces.push_back(std::move(iface_addr));\n }\n\n uv_free_interface_addresses(ifaces, count);\n }\n\n return interfaces;\n}\n\nUVW_INLINE std::string utilities::index_to_name(unsigned int index) noexcept {\n return details::try_read(&uv_if_indextoname, index);\n}\n\nUVW_INLINE std::string utilities::index_to_iid(unsigned int index) noexcept {\n return details::try_read(&uv_if_indextoiid, index);\n}\n\nUVW_INLINE bool utilities::replace_allocator(malloc_func_type malloc_func, realloc_func_type realloc_func, calloc_func_type calloc_func, free_func_type free_func) noexcept {\n return (0 == uv_replace_allocator(malloc_func, realloc_func, calloc_func, free_func));\n}\n\nUVW_INLINE std::array utilities::load_average() noexcept {\n std::array avg;\n uv_loadavg(avg.data());\n return avg;\n}\n\nUVW_INLINE char **utilities::setup_args(int argc, char **argv) {\n return uv_setup_args(argc, argv);\n}\n\nUVW_INLINE std::string utilities::process_title() {\n std::size_t size = details::DEFAULT_SIZE;\n char buf[details::DEFAULT_SIZE];\n std::string str{};\n\n if(0 == uv_get_process_title(buf, size)) {\n str.assign(buf, size);\n }\n\n return str;\n}\n\nUVW_INLINE bool utilities::process_title(const std::string &title) {\n return (0 == uv_set_process_title(title.c_str()));\n}\n\nUVW_INLINE uint64_t utilities::total_memory() noexcept {\n return uv_get_total_memory();\n}\n\nUVW_INLINE uint64_t utilities::constrained_memory() noexcept {\n return uv_get_constrained_memory();\n}\n\nUVW_INLINE uint64_t utilities::available_memory() noexcept {\n return uv_get_available_memory();\n}\n\nUVW_INLINE double utilities::uptime() noexcept {\n double ret;\n\n if(0 != uv_uptime(&ret)) {\n ret = 0;\n }\n\n return ret;\n}\n\nUVW_INLINE resource_usage utilities::rusage() noexcept {\n resource_usage ru;\n auto err = uv_getrusage(&ru);\n return err ? resource_usage{} : ru;\n}\n\nUVW_INLINE timespec64 utilities::gettime(clock_id source) noexcept {\n timespec64 ts;\n auto err = uv_clock_gettime(static_cast(source), &ts);\n return err ? 
timespec64{} : ts;\n}\n\nUVW_INLINE uint64_t utilities::hrtime() noexcept {\n return uv_hrtime();\n}\n\nUVW_INLINE std::string utilities::path() noexcept {\n return details::try_read(&uv_exepath);\n}\n\nUVW_INLINE std::string utilities::cwd() noexcept {\n return details::try_read(&uv_cwd);\n}\n\nUVW_INLINE bool utilities::chdir(const std::string &dir) noexcept {\n return (0 == uv_chdir(dir.data()));\n}\n\nUVW_INLINE timeval64 utilities::time_of_day() noexcept {\n uv_timeval64_t ret;\n uv_gettimeofday(&ret);\n return ret;\n}\n\nUVW_INLINE void utilities::sleep(unsigned int msec) noexcept {\n uv_sleep(msec);\n}\n\nUVW_INLINE unsigned int utilities::available_parallelism() noexcept {\n return uv_available_parallelism();\n}\n\n} // namespace uvw\n\n// Path: src/uvw/util.h\n#ifndef UVW_UTIL_INCLUDE_H\n#define UVW_UTIL_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n\nnamespace uvw {\n\nnamespace details {\n\nenum class uvw_handle_type : std::underlying_type_t {\n UNKNOWN = UV_UNKNOWN_HANDLE,\n ASYNC = UV_ASYNC,\n CHECK = UV_CHECK,\n FS_EVENT = UV_FS_EVENT,\n FS_POLL = UV_FS_POLL,\n HANDLE = UV_HANDLE,\n IDLE = UV_IDLE,\n PIPE = UV_NAMED_PIPE,\n POLL = UV_POLL,\n PREPARE = UV_PREPARE,\n PROCESS = UV_PROCESS,\n STREAM = UV_STREAM,\n TCP = UV_TCP,\n TIMER = UV_TIMER,\n TTY = UV_TTY,\n UDP = UV_UDP,\n SIGNAL = UV_SIGNAL,\n FILE = UV_FILE\n};\n\nenum class uvw_clock_id : std::underlying_type_t {\n MONOTONIC = UV_CLOCK_MONOTONIC,\n REALTIME = UV_CLOCK_REALTIME\n};\n\ntemplate\nstruct uv_type_wrapper {\n using Type = T;\n\n constexpr uv_type_wrapper()\n : value{} {}\n\n constexpr uv_type_wrapper(Type val)\n : value{val} {}\n\n constexpr operator Type() const noexcept {\n return value;\n }\n\n bool operator==(uv_type_wrapper other) const noexcept {\n return value == other.value;\n }\n\nprivate:\n const Type value;\n};\n\ntemplate\nbool operator==(uv_type_wrapper lhs, uv_type_wrapper rhs) {\n return !(lhs == rhs);\n}\n\n} // namespace details\n\n/**\n * @brief Windows size representation.\n */\nstruct win_size {\n int width; /*!< The _width_ of the given window. */\n int height; /*!< The _height_ of the given window. */\n};\n\nusing handle_type = details::uvw_handle_type; /*!< The type of a handle. */\nusing handle_category = details::uv_type_wrapper; /*!< Utility class that wraps an internal handle type. */\nusing file_handle = details::uv_type_wrapper; /*!< Utility class that wraps an internal file handle. */\nusing os_socket_handle = details::uv_type_wrapper; /*!< Utility class that wraps an os socket handle. */\nusing os_file_descriptor = details::uv_type_wrapper; /*!< Utility class that wraps an os file descriptor. */\nusing pid_type = details::uv_type_wrapper; /*!< Utility class that wraps a cross platform representation of a pid. */\nusing clock_id = details::uvw_clock_id; /*!< Utility class that wraps a clock source. */\n\nconstexpr file_handle std_in{0}; /*!< Placeholder for stdin descriptor. */\nconstexpr file_handle std_out{1}; /*!< Placeholder for stdout descriptor. */\nconstexpr file_handle std_err{2}; /*!< Placeholder for stderr descriptor. */\n\nusing time_spec = uv_timespec_t; /*!< Library equivalent for uv_timespec_t. */\nusing file_info = uv_stat_t; /*!< Library equivalent for uv_stat_t. */\nusing fs_info = uv_statfs_t; /*!< Library equivalent for uv_statfs_t. */\nusing uid_type = uv_uid_t; /*!< Library equivalent for uv_uid_t. 
*/\nusing gid_type = uv_gid_t; /*!< Library equivalent for uv_gid_t. */\n\nusing timeval = uv_timeval_t; /*!< Library equivalent for uv_timeval_t. */\nusing timeval64 = uv_timeval64_t; /*!< Library equivalent for uv_timeval64_t. */\nusing timespec64 = uv_timespec64_t; /*!< Library equivalent for uv_timespec64_t. */\nusing resource_usage = uv_rusage_t; /*!< Library equivalent for uv_rusage_t. */\n\n/**\n * @brief Utility class.\n *\n * This class can be used to query the subset of the password file entry for the\n * current effective uid (not the real uid).\n *\n * \\sa utilities::passwd\n */\nstruct passwd_info {\n passwd_info(std::shared_ptr pwd);\n\n /**\n * @brief Gets the username.\n * @return The username of the current effective uid (not the real uid).\n */\n std::string username() const noexcept;\n\n /**\n * @brief Gets the uid.\n * @return The current effective uid (not the real uid).\n */\n decltype(uv_passwd_t::uid) uid() const noexcept;\n\n /**\n * @brief Gets the gid.\n * @return The gid of the current effective uid (not the real uid).\n */\n decltype(uv_passwd_t::gid) gid() const noexcept;\n\n /**\n * @brief Gets the shell.\n * @return The shell of the current effective uid (not the real uid).\n */\n std::string shell() const noexcept;\n\n /**\n * @brief Gets the homedir.\n * @return The homedir of the current effective uid (not the real uid).\n */\n std::string homedir() const noexcept;\n\n /**\n * @brief Checks if the instance contains valid data.\n * @return True if data are all valid, false otherwise.\n */\n operator bool() const noexcept;\n\nprivate:\n std::shared_ptr value;\n};\n\n/**\n * @brief Utility class.\n *\n * This class can be used to get name and information about the current kernel.\n * The populated data includes the operating system name, release, version, and\n * machine.\n *\n * \\sa utilities::uname\n */\nstruct uts_name {\n uts_name(std::shared_ptr init);\n\n /**\n * @brief Gets the operating system name (like \"Linux\").\n * @return The operating system name.\n */\n std::string sysname() const noexcept;\n\n /**\n * @brief Gets the operating system release (like \"2.6.28\").\n * @return The operating system release.\n */\n std::string release() const noexcept;\n\n /**\n * @brief Gets the operating system version.\n * @return The operating system version\n */\n std::string version() const noexcept;\n\n /**\n * @brief Gets the hardware identifier.\n * @return The hardware identifier.\n */\n std::string machine() const noexcept;\n\nprivate:\n std::shared_ptr uname;\n};\n\n/**\n * @brief The IPv4 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv4 {};\n\n/**\n * @brief The IPv6 tag.\n *\n * To be used as template parameter to switch between IPv4 and IPv6.\n */\nstruct ipv6 {};\n\n/**\n * @brief Address representation.\n */\nstruct socket_address {\n std::string ip; /*!< Either an IPv4 or an IPv6. */\n unsigned int port; /*!< A valid service identifier. */\n};\n\n/**\n * \\brief CPU information.\n */\nstruct cpu_info {\n using cpu_time = decltype(uv_cpu_info_t::cpu_times);\n\n std::string model; /*!< The model of the CPU. */\n int speed; /*!< The frequency of the CPU. */\n\n /**\n * @brief CPU times.\n *\n * It is built up of the following data members: `user`, `nice`, `sys`,\n * `idle`, `irq`, all of them having type `uint64_t`.\n */\n cpu_time times;\n};\n\n/**\n * \\brief Interface address.\n */\nstruct interface_address {\n std::string name; /*!< The name of the interface (as an example _eth0_). 
*/\n char physical[6]; /*!< The physical address. */\n bool internal; /*!< True if it is an internal interface (as an example _loopback_), false otherwise. */\n socket_address address; /*!< The address of the given interface. */\n socket_address netmask; /*!< The netmask of the given interface. */\n};\n\nnamespace details {\n\nstatic constexpr std::size_t DEFAULT_SIZE = 128;\n\ntemplate\nstd::string try_read(F &&f, Args &&...args) noexcept {\n std::size_t size = DEFAULT_SIZE;\n char buf[DEFAULT_SIZE];\n std::string str{};\n auto err = std::forward(f)(args..., buf, &size);\n\n if(UV_ENOBUFS == err) {\n std::unique_ptr data{new char[size]};\n err = std::forward(f)(args..., data.get(), &size);\n\n if(0 == err) {\n str = data.get();\n }\n } else if(0 == err) {\n str.assign(buf, size);\n }\n\n return str;\n}\n\nvoid common_alloc_callback(uv_handle_t *, std::size_t suggested, uv_buf_t *buf);\n\ntemplate\nvoid common_alloc_callback(uv_handle_t *handle, std::size_t suggested, uv_buf_t *buf) {\n auto [alloc, size] = Alloc(*static_cast(handle->data), suggested);\n *buf = uv_buf_init(alloc, static_cast(size));\n}\n\nsockaddr ip_addr(const char *addr, unsigned int port);\nsocket_address sock_addr(const sockaddr_in &addr);\nsocket_address sock_addr(const sockaddr_in6 &addr);\nsocket_address sock_addr(const sockaddr &addr);\nsocket_address sock_addr(const sockaddr_storage &storage);\n\n} // namespace details\n\n/**\n * @brief Miscellaneous utilities.\n *\n * Miscellaneous functions that don\u2019t really belong to any other class.\n */\nstruct utilities {\n using malloc_func_type = void *(*)(size_t);\n using realloc_func_type = void *(*)(void *, size_t);\n using calloc_func_type = void *(*)(size_t, size_t);\n using free_func_type = void (*)(void *);\n\n /**\n * @brief OS dedicated utilities.\n */\n struct os {\n /**\n * @brief Returns the current process id.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_getpid)\n * for further details.\n *\n * @return The current process id.\n */\n static pid_type pid() noexcept;\n\n /**\n * @brief Returns the parent process id.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_getppid)\n * for further details.\n *\n * @return The parent process id.\n */\n static pid_type ppid() noexcept;\n\n /**\n * @brief Gets the current user's home directory.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_homedir)\n * for further details.\n *\n * @return The current user's home directory, an empty string in case of\n * errors.\n */\n static std::string homedir() noexcept;\n\n /**\n * @brief Gets the temp directory.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_tmpdir)\n * for further details.\n *\n * @return The temp directory, an empty string in case of errors.\n */\n static std::string tmpdir() noexcept;\n\n /**\n * @brief Retrieves an environment variable.\n * @param name The name of the variable to be retrieved.\n * @return The value of the environment variable, an empty string in\n * case of errors.\n */\n static std::string env(const std::string &name) noexcept;\n\n /**\n * @brief Creates, updates or deletes an environment variable.\n * @param name The name of the variable to be updated.\n * @param value The value to be used for the variable (an empty string\n * to unset it).\n * @return True in case of success, false otherwise.\n */\n static bool env(const std::string &name, const std::string 
&value) noexcept;\n\n /**\n * @brief Retrieves all environment variables and iterates them.\n *\n * Environment variables are passed one at a time to the callback in the\n * form of `std::string_view`s.
\n * The signature of the function call operator must be such that it\n * accepts two parameters, the name and the value of the i-th variable.\n *\n * @tparam Func Type of a function object to which to pass environment\n * variables.\n * @param func A function object to which to pass environment variables.\n * @return True in case of success, false otherwise.\n */\n template\n st\natic std::enable_if_t, bool>\n env(Func func) noexcept {\n uv_env_item_t *items = nullptr;\n int count{};\n\n const bool ret = (uv_os_environ(&items, &count) == 0);\n\n if(ret) {\n for(int pos = 0; pos < count; ++pos) {\n func(std::string_view{items[pos].name}, std::string_view{items[pos].value});\n }\n\n uv_os_free_environ(items, count);\n }\n\n return ret;\n }\n\n /**\n * @brief Returns the hostname.\n * @return The hostname, an empty string in case of errors.\n */\n static std::string hostname() noexcept;\n\n /**\n * @brief Gets name and information about the current kernel.\n *\n * This function can be used to get name and information about the\n * current kernel. The populated data includes the operating system\n * name, release, version, and machine.\n *\n * @return Name and information about the current kernel.\n */\n static uts_name uname() noexcept;\n\n /**\n * @brief Gets a subset of the password file entry.\n *\n * This function can be used to get the subset of the password file\n * entry for the current effective uid (not the real uid).\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_os_get_passwd)\n * for further details.\n *\n * @return The accessible subset of the password file entry.\n */\n static passwd_info passwd() noexcept;\n\n /**\n * @brief Retrieves the scheduling priority of a process.\n *\n * The returned value is between -20 (high priority) and 19 (low priority).\n * A value that is out of range is returned in case of errors.\n *\n * @note\n * On Windows, the result won't equal necessarily the exact value of the\n * priority because of a mapping to a Windows priority class.\n *\n * @param pid A valid process id.\n * @return The scheduling priority of the process.\n */\n static int priority(pid_type pid);\n\n /**\n * @brief Sets the scheduling priority of a process.\n *\n * The returned value range is between -20 (high priority) and 19 (low\n * priority).\n *\n * @note\n * On Windows, the priority is mapped to a Windows priority class. When\n * retrieving the process priority, the result won't equal necessarily the\n * exact value of the priority.\n *\n * @param pid A valid process id.\n * @param prio The scheduling priority to set to the process.\n * @return True in case of success, false otherwise.\n */\n static bool priority(pid_type pid, int prio);\n };\n\n /**\n * @brief Gets the type of the handle given a category.\n * @param category A properly initialized handle category.\n * @return The actual type of the handle as defined by handle_type\n */\n static handle_type guess_handle(handle_category category) noexcept;\n\n /**\n * @brief Gets the type of the stream to be used with the given descriptor.\n *\n * Returns the type of stream that should be used with a given file\n * descriptor.
\n * Usually this will be used during initialization to guess the type of the\n * stdio streams.\n *\n * @param file A valid descriptor.\n * @return One of the following types:\n *\n * * `handle_type::UNKNOWN`\n * * `handle_type::PIPE`\n * * `handle_type::TCP`\n * * `handle_type::TTY`\n * * `handle_type::UDP`\n * * `handle_type::FILE`\n */\n static handle_type guess_handle(file_handle file) noexcept;\n\n /** @brief Gets information about the CPUs on the system.\n *\n * This function can be used to query the underlying system and get a set of\n * descriptors of all the available CPUs.\n *\n * @return A set of descriptors of all the available CPUs.\n */\n static std::vector cpu() noexcept;\n\n /**\n * @brief Gets a set of descriptors of all the available interfaces.\n *\n * This function can be used to query the underlying system and get a set of\n * descriptors of all the available interfaces, either internal or not.\n *\n * @return A set of descriptors of all the available interfaces.\n */\n static std::vector interface_addresses() noexcept;\n\n /**\n * @brief IPv6-capable implementation of\n * [if_indextoname](https://linux.die.net/man/3/if_indextoname).\n *\n * Mapping between network interface names and indexes.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_if_indextoname)\n * for further details.\n *\n * @param index Network interface index.\n * @return Network interface name.\n */\n static std::string index_to_name(unsigned int index) noexcept;\n\n /**\n * @brief Retrieves a network interface identifier.\n *\n * See the official\n * [documentation](http://docs.libuv.org/en/v1.x/misc.html#c.uv_if_indextoiid)\n * for further details.\n *\n * @param index Network interface index.\n * @return Network interface identifier.\n */\n static std::string index_to_iid(unsigned int index) noexcept;\n\n /**\n * @brief Override the use of some standard library\u2019s functions.\n *\n * Override the use of the standard library\u2019s memory allocation\n * functions.
\n * This method must be invoked before any other `uvw` function is called or\n * after all resources have been freed and thus the underlying library\n * doesn\u2019t reference any allocated memory chunk.\n *\n * If any of the function pointers is _null_, the invokation will fail.\n *\n * @note\n * There is no protection against changing the allocator multiple times. If\n * the user changes it they are responsible for making sure the allocator is\n * changed while no memory was allocated with the previous allocator, or\n * that they are compatible.\n *\n * @param malloc_func Replacement function for _malloc_.\n * @param realloc_func Replacement function for _realloc_.\n * @param calloc_func Replacement function for _calloc_.\n * @param free_func Replacement function for _free_.\n * @return True in case of success, false otherwise.\n */\n static bool replace_allocator(malloc_func_type malloc_func, realloc_func_type realloc_func, calloc_func_type calloc_func, free_func_type free_func) noexcept;\n\n /**\n * @brief Gets the load average.\n * @return `[0,0,0]` on Windows (not available), the load average otherwise.\n */\n static std::array load_average() noexcept;\n\n /**\n * @brief Store the program arguments.\n *\n * Required for getting / setting the process title.\n *\n * @return Arguments that haven't been consumed internally.\n */\n static char **setup_args(int argc, char **argv);\n\n /**\n * @brief Gets the title of the current process.\n * @return The process title.\n */\n static std::string process_title();\n\n /**\n * @brief Sets the current process title.\n * @param title The process title to be set.\n * @return True in case of success, false otherwise.\n */\n static bool process_title(const std::string &title);\n\n /**\n * @brief Gets memory information (in bytes).\n * @return Memory information.\n */\n static uint64_t total_memory() noexcept;\n\n /**\n * @brief Gets the amount of memory available to the process (in bytes).\n *\n * Gets the amount of memory available to the process based on limits\n * imposed by the OS. If there is no such constraint, or the constraint is\n * unknown, `0` is returned.
\n * Note that it is not unusual for this value to be less than or greater\n * than `totalMemory`.\n *\n * @return Amount of memory available to the process.\n */\n static uint64_t constrained_memory() noexcept;\n\n /**\n * @brief Gets the amount of free memory still available to the process.\n * @return Amount of free memory still available to the process (in bytes).\n */\n static uint64_t available_memory() noexcept;\n\n /**\n * @brief Gets the current system uptime.\n * @return The current system uptime or 0 in case of errors.\n */\n static double uptime() noexcept;\n\n /**\n * @brief Gets the resource usage measures for the current process.\n * @return Resource usage measures, zeroes-filled object in case of errors.\n */\n static resource_usage rusage() noexcept;\n\n /**\n * @brief Gets the current system time from a high-resolution clock source.\n * @param source Clock source, either real-time or monotonic.\n * @return Current system time from the given high-resolution clock source.\n */\n static timespec64 gettime(clock_id source) noexcept;\n\n /**\n * @brief Gets the current high-resolution real time.\n *\n * The time is expressed in nanoseconds. It is relative to an arbitrary time\n * in the past. It is not related to the time of the day and therefore not\n * subject to clock drift. The primary use is for measuring performance\n * between interval.\n *\n * @return The current high-resolution real time.\n */\n static uint64_t hrtime() noexcept;\n\n /**\n * @brief Gets the executable path.\n * @return The executable path, an empty string in case of errors.\n */\n static std::string path() noexcept;\n\n /**\n * @brief Gets the current working directory.\n * @return The current working directory, an empty string in case of errors.\n */\n static std::string cwd() noexcept;\n\n /**\n * @brief Changes the current working directory.\n * @param dir The working directory to be set.\n * @return True in case of success, false otherwise.\n */\n static bool chdir(const std::string &dir) noexcept;\n\n /**\n * @brief Cross-platform implementation of\n * [`gettimeofday`](https://linux.die.net/man/2/gettimeofday)\n * @return The current time.\n */\n static timeval64 time_of_day() noexcept;\n\n /**\n * @brief Causes the calling thread to sleep for a while.\n * @param msec Number of milliseconds to sleep.\n */\n static void sleep(unsigned int msec) noexcept;\n\n /**\n * @brief Returns an estimate of the amount of parallelism a program should\n * use (always a non-zero value).\n * @return Estimate of the amount of parallelism a program should use.\n */\n static unsigned int available_parallelism() noexcept;\n};\n\n/**\n * @brief Helper type for visitors.\n * @tparam Func Types of function objects.\n */\ntemplate\nstruct overloaded: Func... {\n using Func::operator()...;\n};\n\n/**\n * @brief Deduction guide.\n * @tparam Func Types of function objects.\n */\ntemplate\noverloaded(Func...) 
-> overloaded;\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"util.cpp\"\n#endif\n\n#endif // UVW_UTIL_INCLUDE_H\n\n// Path: src/uvw/loop.cpp\n#ifdef UVW_AS_LIB\n# include \"loop.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE loop::loop(std::unique_ptr ptr) noexcept\n : uv_loop{std::move(ptr)} {}\n\nUVW_INLINE std::shared_ptr loop::create() {\n auto ptr = std::unique_ptr{new uv_loop_t, [](uv_loop_t *l) { delete l; }};\n auto curr = std::shared_ptr{new loop{std::move(ptr)}};\n\n if(uv_loop_init(curr->uv_loop.get())) {\n curr = nullptr;\n }\n\n return curr;\n}\n\nUVW_INLINE std::shared_ptr loop::create(uv_loop_t *res) {\n auto ptr = std::unique_ptr{res, [](uv_loop_t *) {}};\n return std::shared_ptr{new loop{std::move(ptr)}};\n}\n\nUVW_INLINE std::shared_ptr loop::get_default() {\n static std::weak_ptr ref;\n std::shared_ptr curr;\n\n if(ref.expired()) {\n auto def = uv_default_loop();\n\n if(def) {\n auto ptr = std::unique_ptr(def, [](uv_loop_t *) {});\n curr = std::shared_ptr{new loop{std::move(ptr)}};\n }\n\n ref = curr;\n } else {\n curr = ref.lock();\n }\n\n return curr;\n}\n\nUVW_INLINE loop::~loop() noexcept {\n if(uv_loop) {\n close();\n }\n}\n\nUVW_INLINE int loop::close() {\n int ret = 0;\n\n if(uv_loop) {\n ret = uv_loop_close(uv_loop.get());\n uv_loop.reset();\n }\n\n return ret;\n}\n\nUVW_INLINE int loop::run(run_mode mode) noexcept {\n return uv_run(uv_loop.get(), static_cast(mode));\n}\n\nUVW_INLINE bool loop::alive() const noexcept {\n return !!uv_loop_alive(uv_loop.get());\n}\n\nUVW_INLINE void loop::stop() noexcept {\n uv_stop(uv_loop.get());\n}\n\nUVW_INLINE int loop::descriptor() const noexcept {\n return uv_backend_fd(uv_loop.get());\n}\n\nUVW_INLINE std::pair loop::timeout() const noexcept {\n auto to = uv_backend_timeout(uv_loop.get());\n return std::make_pair(to == -1, time{to});\n}\n\nUVW_INLINE loop::time loop::idle_time() const noexcept {\n return time{uv_metrics_idle_time(uv_loop.get())};\n}\n\nUVW_INLINE metrics_type loop::metrics() const noexcept {\n metrics_type res{};\n uv_metrics_info(uv_loop.get(), &res);\n return res;\n}\n\nUVW_INLINE loop::time loop::now() const noexcept {\n return time{uv_now(uv_loop.get())};\n}\n\nUVW_INLINE void loop::update() const noexcept {\n return uv_update_time(uv_loop.get());\n}\n\nUVW_INLINE int loop::fork() noexcept {\n...\n", "needle_token_start": 9978, "needle_token_end": 10157, "code_context_ntokens": 14097, "output": ["```cpp\n/**\n * @brief Retrieves an environment variable.\n * @param name The name of the variable to be retrieved.\n * @return The value of the environment variable, an empty string in case of errors.\n */\nstatic std::string env(const std::string &name) noexcept {\n return details::try_read(&uv_os_getenv, name.c_str());\n}\n\n/**\n * @brief Creates, updates or deletes an environment variable.\n * @param name The name of the variable to be updated.\n * @param value The value to be used for the variable (an empty string to unset it).\n * @return True in case of success, false otherwise.\n */\nstatic bool env(const std::string &name, const std::string &value) noexcept {\n return (0 == (value.empty() ? uv_os_unsetenv(name.c_str()) : uv_os_setenv(name.c_str(), value.c_str())));\n}\n\n/**\n * @brief Retrieves all environment variables and iterates them.\n *\n * Environment variables are passed one at a time to the callback in the form of `std::string_view`s.
\n * The signature of the function call operator must be such that it\n * accepts two parameters, the name and the value of the i-th variable.\n *\n * @tparam Func Type of a function object to which to pass environment variables.\n * @param func A function object to which to pass environment variables.\n * @return True in case of success, false otherwise.\n */\ntemplate\nstatic std::enable_if_t, bool>\nenv(Func func) noexcept {\n uv_env_item_t *items = nullptr;\n int count{};\n\n const bool ret = (uv_os_environ(&items, &count) == 0);\n\n if(ret) {\n for(int pos = 0; pos < count; ++pos) {\n func(std::string_view{items[pos].name}, std::string_view{items[pos].value});\n }\n\n uv_os_free_environ(items, count);\n }\n\n return ret;\n}\n```"]} +{"repo": "skypjack/uvw", "name": "fnv1a", "language": "cpp", "path": "src/uvw/type_info.hpp", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to compute a hash value for a string using a specific algorithm that is efficient and has a low collision rate, making it suitable for hash table implementations and other applications where a unique identifier for strings is needed.\n2. **Input**: It accepts a pointer to a constant character array (C-style string).\n3. **Output**: It returns a 32-bit unsigned integer representing the computed hash value of the input string.\n4. **Procedure**: The function initializes a hash value with a predefined constant. It then iterates through each character of the input string, combining the current hash value and the character's integer value using an exclusive OR operation, followed by multiplication with a prime number. This process repeats until the end of the string is reached, and the final hash value is returned.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/config.h\n#ifndef UVW_CONFIG_H\n#define UVW_CONFIG_H\n\n#ifndef UVW_AS_LIB\n# define UVW_INLINE inline\n#else\n# define UVW_INLINE\n#endif\n\n#endif\n\n// Path: src/uvw/type_info.hpp\n#ifndef UVW_TYPE_INFO_INCLUDE_HPP\n#define UVW_TYPE_INFO_INCLUDE_HPP\n\n#include \n#include \"config.h\"\n\nnamespace uvw {\n\n/**\n * @cond TURN_OFF_DOXYGEN\n * Internal details not to be documented.\n */\n\nnamespace internal {\n\n// Fowler-Noll-Vo hash function v. 
1a - the good\n\n[[nodiscard]] static constexpr std::uint32_t fnv1a(const char *curr) noexcept {\n constexpr std::uint32_t offset = 2166136261;\n constexpr std::uint32_t prime = 16777619;\n auto value = offset;\n\n while(*curr != 0) {\n auto curr_val_int = static_cast(*(curr++));\n value = (value ^ curr_val_int) * prime;\n }\n\n return value;\n}\n\n[[nodiscard]] static inline std::uint32_t counter() noexcept {\n static std::uint32_t cnt{};\n return cnt++;\n}\n\ntemplate\n[[nodiscard]] static std::uint32_t fake() noexcept {\n static std::uint32_t local = counter();\n return local;\n}\n\n} // namespace internal\n\n/**\n * Internal details not to be documented.\n * @endcond\n */\n\n/**\n * @brief Returns a numerical identifier for a given type.\n * @tparam Type The type for which to return the numerical identifier.\n * @return The numerical identifier of the give type.\n */\ntemplate\n[[nodiscard]] static constexpr std::uint32_t type() noexcept {\n#if defined __clang__ || defined __GNUC__\n return internal::fnv1a(__PRETTY_FUNCTION__);\n#elif defined _MSC_VER\n return internal::fnv1a(__FUNCSIG__);\n#else\n return internal::fake();\n#endif\n}\n\n} // namespace uvw\n\n#endif // UVW_TYPE_INFO_INCLUDE_HPP\n\n// Path: src/uvw/emitter.cpp\n#ifdef UVW_AS_LIB\n# include \"emitter.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE int error_event::translate(int sys) noexcept {\n return uv_translate_sys_error(sys);\n}\n\nUVW_INLINE const char *error_event::what() const noexcept {\n return uv_strerror(ec);\n}\n\nUVW_INLINE const char *error_event::name() const noexcept {\n return uv_err_name(ec);\n}\n\nUVW_INLINE int error_event::code() const noexcept {\n return ec;\n}\n\nUVW_INLINE error_event::operator bool() const noexcept {\n return ec < 0;\n}\n\n} // namespace uvw\n\n// Path: src/uvw/emitter.h\n#ifndef UVW_EMITTER_INCLUDE_H\n#define UVW_EMITTER_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"type_info.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Error event.\n *\n * Custom wrapper around error constants of `libuv`.\n */\nstruct error_event {\n template>>\n explicit error_event(Type val) noexcept\n : ec{static_cast(val)} {}\n\n /**\n * @brief Returns the `libuv` error code equivalent to the given platform dependent error code.\n *\n * It returns:\n * * POSIX error codes on Unix (the ones stored in errno).\n * * Win32 error codes on Windows (those returned by GetLastError() or WSAGetLastError()).\n *\n * If `sys` is already a `libuv` error code, it is simply returned.\n *\n * @param sys A platform dependent error code.\n * @return The `libuv` error code equivalent to the given platform dependent error code.\n */\n static int translate(int sys) noexcept;\n\n /**\n * @brief Returns the error message for the given error code.\n *\n * Leaks a few bytes of memory when you call it with an unknown error code.\n *\n * @return The error message for the given error code.\n */\n const char *what() const noexcept;\n\n /**\n * @brief Returns the error name for the given error code.\n *\n * Leaks a few bytes of memory when you call it with an unknown error code.\n *\n * @return The error name for the given error code.\n */\n const char *name() const noexcept;\n\n /**\n * @brief Gets the underlying error code, that is an error constant of `libuv`.\n * @return The underlying error code.\n */\n int code() const noexcept;\n\n /**\n * @brief Checks if the event contains a valid error code.\n * @return True 
in case of success, false otherwise.\n */\n explicit operator bool() const noexcept;\n\nprivate:\n const int ec;\n};\n\n/**\n * @brief Event emitter base class.\n *\n * Almost everything in `uvw` is an event emitter.
\n * This is the base class from which resources and loops inherit.\n */\ntemplate\nclass emitter {\npublic:\n template\n using listener_t = std::function;\n\nprivate:\n template\n const auto &handler() const noexcept {\n return std::get>(handlers);\n }\n\n template\n auto &handler() noexcept {\n return std::get>(handlers);\n }\n\nprotected:\n template\n void publish(Type event) {\n if(auto &listener = handler(); listener) {\n listener(event, *static_cast(this));\n }\n }\n\npublic:\n virtual ~emitter() noexcept {\n static_assert(std::is_base_of_v, Elem>);\n }\n\n /**\n * @brief Registers a long-lived listener with the event emitter.\n *\n * This method is used to register a listener with the emitter.
\n * A listener is usually defined as a callable object assignable to a\n * `std::function\n void on(listener_t f) {\n handler() = std::move(f);\n }\n\n /*! @brief Disconnects the listener for the given event type. */\n template\n void reset() noexcept {\n handler() = nullptr;\n }\n\n /*! @brief Disconnects all listeners. */\n void reset() noexcept {\n reset();\n (reset(), ...);\n }\n\n /**\n * @brief Checks if there is a listener registered for the specific event.\n * @return True if there is a listener registered for the specific event,\n * false otherwise.\n */\n template\n bool has() const noexcept {\n return static_cast(handler());\n }\n\nprivate:\n std::tuple, listener_t...> handlers{};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"emitter.cpp\"\n#endif\n\n#endif // UVW_EMITTER_INCLUDE_H\n\n// Path: src/uvw/util.cpp\n#ifdef UVW_AS_LIB\n# include \"util.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE passwd_info::passwd_info(std::shared_ptr pwd)\n : value{pwd} {}\n\nUVW_INLINE std::string passwd_info::username() const noexcept {\n return ((value && value->username) ? value->username : \"\");\n}\n\nUVW_INLINE decltype(uv_passwd_t::uid) passwd_info::uid() const noexcept {\n return (value ? value->uid : decltype(uv_passwd_t::uid){});\n}\n\nUVW_INLINE decltype(uv_passwd_t::gid) passwd_info::gid() const noexcept {\n return (value ? value->gid : decltype(uv_passwd_t::gid){});\n}\n\nUVW_INLINE std::string passwd_info::shell() const noexcept {\n return ((value && value->shell) ? value->shell : \"\");\n}\n\nUVW_INLINE std::string passwd_info::homedir() const noexcept {\n return ((value && value->homedir) ? value->homedir : \"\");\n}\n\nUVW_INLINE passwd_info::operator bool() const noexcept {\n return static_cast(value);\n}\n\nUVW_INLINE uts_name::uts_name(std::shared_ptr init)\n : uname{init} {}\n\nUVW_INLINE std::string uts_name::sysname() const noexcept {\n...\n", "needle_token_start": 211, "needle_token_end": 358, "code_context_ntokens": 2670, "output": ["```cpp\n[[nodiscard]] static constexpr std::uint32_t fnv1a(const char *curr) noexcept {\n constexpr std::uint32_t offset = 2166136261;\n constexpr std::uint32_t prime = 16777619;\n auto value = offset;\n\n while(*curr != 0) {\n auto curr_val_int = static_cast(*(curr++));\n value = (value ^ curr_val_int) * prime;\n }\n\n return value;\n}\n```"]} +{"repo": "skypjack/uvw", "name": "publish", "language": "cpp", "path": "src/uvw/emitter.h", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to notify registered listeners about specific events occurring within an object, allowing for event-driven responses.\n2. **Input**: The function takes a single parameter representing an event of a generic type.\n3. **Output**: There is no direct output from the function; however, it triggers side effects by invoking event listeners that may perform actions based on the event.\n4. **Procedure**: The function first retrieves the event listener associated with the event's type. 
If a listener is registered, it invokes this listener, passing the event and a reference to the object from which the event originated.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/uvw/config.h\n#ifndef UVW_CONFIG_H\n#define UVW_CONFIG_H\n\n#ifndef UVW_AS_LIB\n# define UVW_INLINE inline\n#else\n# define UVW_INLINE\n#endif\n\n#endif\n\n// Path: src/uvw/type_info.hpp\n#ifndef UVW_TYPE_INFO_INCLUDE_HPP\n#define UVW_TYPE_INFO_INCLUDE_HPP\n\n#include \n#include \"config.h\"\n\nnamespace uvw {\n\n/**\n * @cond TURN_OFF_DOXYGEN\n * Internal details not to be documented.\n */\n\nnamespace internal {\n\n// Fowler-Noll-Vo hash function v. 1a - the good\n[[nodiscard]] static constexpr std::uint32_t fnv1a(const char *curr) noexcept {\n constexpr std::uint32_t offset = 2166136261;\n constexpr std::uint32_t prime = 16777619;\n auto value = offset;\n\n while(*curr != 0) {\n auto curr_val_int = static_cast(*(curr++));\n value = (value ^ curr_val_int) * prime;\n }\n\n return value;\n}\n\n[[nodiscard]] static inline std::uint32_t counter() noexcept {\n static std::uint32_t cnt{};\n return cnt++;\n}\n\ntemplate\n[[nodiscard]] static std::uint32_t fake() noexcept {\n static std::uint32_t local = counter();\n return local;\n}\n\n} // namespace internal\n\n/**\n * Internal details not to be documented.\n * @endcond\n */\n\n/**\n * @brief Returns a numerical identifier for a given type.\n * @tparam Type The type for which to return the numerical identifier.\n * @return The numerical identifier of the give type.\n */\ntemplate\n[[nodiscard]] static constexpr std::uint32_t type() noexcept {\n#if defined __clang__ || defined __GNUC__\n return internal::fnv1a(__PRETTY_FUNCTION__);\n#elif defined _MSC_VER\n return internal::fnv1a(__FUNCSIG__);\n#else\n return internal::fake();\n#endif\n}\n\n} // namespace uvw\n\n#endif // UVW_TYPE_INFO_INCLUDE_HPP\n\n// Path: src/uvw/emitter.cpp\n#ifdef UVW_AS_LIB\n# include \"emitter.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE int error_event::translate(int sys) noexcept {\n return uv_translate_sys_error(sys);\n}\n\nUVW_INLINE const char *error_event::what() const noexcept {\n return uv_strerror(ec);\n}\n\nUVW_INLINE const char *error_event::name() const noexcept {\n return uv_err_name(ec);\n}\n\nUVW_INLINE int error_event::code() const noexcept {\n return ec;\n}\n\nUVW_INLINE error_event::operator bool() const noexcept {\n return ec < 0;\n}\n\n} // namespace uvw\n\n// Path: src/uvw/emitter.h\n#ifndef UVW_EMITTER_INCLUDE_H\n#define UVW_EMITTER_INCLUDE_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"config.h\"\n#include \"type_info.hpp\"\n\nnamespace uvw {\n\n/**\n * @brief Error event.\n *\n * Custom wrapper around error constants of `libuv`.\n */\nstruct error_event {\n template>>\n explicit error_event(Type val) noexcept\n : ec{static_cast(val)} {}\n\n /**\n * @brief Returns the `libuv` error code equivalent to the given platform dependent error code.\n *\n * It returns:\n * * POSIX error codes on Unix (the ones stored in errno).\n * * Win32 error codes on Windows (those returned by GetLastError() or WSAGetLastError()).\n *\n * If `sys` is already a `libuv` error code, it is simply returned.\n *\n * @param sys A platform dependent error code.\n * @return 
The `libuv` error code equivalent to the given platform dependent error code.\n */\n static int translate(int sys) noexcept;\n\n /**\n * @brief Returns the error message for the given error code.\n *\n * Leaks a few bytes of memory when you call it with an unknown error code.\n *\n * @return The error message for the given error code.\n */\n const char *what() const noexcept;\n\n /**\n * @brief Returns the error name for the given error code.\n *\n * Leaks a few bytes of memory when you call it with an unknown error code.\n *\n * @return The error name for the given error code.\n */\n const char *name() const noexcept;\n\n /**\n * @brief Gets the underlying error code, that is an error constant of `libuv`.\n * @return The underlying error code.\n */\n int code() const noexcept;\n\n /**\n * @brief Checks if the event contains a valid error code.\n * @return True in case of success, false otherwise.\n */\n explicit operator bool() const noexcept;\n\nprivate:\n const int ec;\n};\n\n/**\n * @brief Event emitter base class.\n *\n * Almost everything in `uvw` is an event emitter.
\n * This is the base class from which resources and loops inherit.\n */\ntemplate\nclass emitter {\npublic:\n template\n using listener_t = std::function;\n\nprivate:\n template\n const auto &handler() const noexcept {\n return std::get>(handlers);\n }\n\n template\n auto &handler() noexcept {\n return std::get>(handlers);\n }\n\nprotected:\n template\n \nvoid publish(Type event) {\n if(auto &listener = handler(); listener) {\n listener(event, *static_cast(this));\n }\n }\n\npublic:\n virtual ~emitter() noexcept {\n static_assert(std::is_base_of_v, Elem>);\n }\n\n /**\n * @brief Registers a long-lived listener with the event emitter.\n *\n * This method is used to register a listener with the emitter.
\n * A listener is usually defined as a callable object assignable to a\n * `std::function\n void on(listener_t f) {\n handler() = std::move(f);\n }\n\n /*! @brief Disconnects the listener for the given event type. */\n template\n void reset() noexcept {\n handler() = nullptr;\n }\n\n /*! @brief Disconnects all listeners. */\n void reset() noexcept {\n reset();\n (reset(), ...);\n }\n\n /**\n * @brief Checks if there is a listener registered for the specific event.\n * @return True if there is a listener registered for the specific event,\n * false otherwise.\n */\n template\n bool has() const noexcept {\n return static_cast(handler());\n }\n\nprivate:\n std::tuple, listener_t...> handlers{};\n};\n\n} // namespace uvw\n\n#ifndef UVW_AS_LIB\n# include \"emitter.cpp\"\n#endif\n\n#endif // UVW_EMITTER_INCLUDE_H\n\n// Path: src/uvw/util.cpp\n#ifdef UVW_AS_LIB\n# include \"util.h\"\n#endif\n\n#include \"config.h\"\n\nnamespace uvw {\n\nUVW_INLINE passwd_info::passwd_info(std::shared_ptr pwd)\n : value{pwd} {}\n\nUVW_INLINE std::string passwd_info::username() const noexcept {\n return ((value && value->username) ? value->username : \"\");\n}\n\nUVW_INLINE decltype(uv_passwd_t::uid) passwd_info::uid() const noexcept {\n return (value ? value->uid : decltype(uv_passwd_t::uid){});\n}\n\nUVW_INLINE decltype(uv_passwd_t::gid) passwd_info::gid() const noexcept {\n return (value ? value->gid : decltype(uv_passwd_t::gid){});\n}\n\nUVW_INLINE std::string passwd_info::shell() const noexcept {\n return ((value && value->shell) ? value->shell : \"\");\n}\n\n...\n", "needle_token_start": 1716, "needle_token_end": 1761, "code_context_ntokens": 2527, "output": ["```cpp\ntemplate\nvoid publish(Type event) {\n if(auto &listener = handler(); listener) {\n listener(event, *static_cast(this));\n }\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "CompateStringsCaseInsensitive", "language": "cpp", "path": "clickhouse/types/type_parser.cpp", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to compare two strings in a case-insensitive manner, primarily to determine their equality or order without considering the case of the characters.\n2. **Input**: It takes two string-like objects as parameters.\n3. **Output**: Returns an integer: 0 if both strings are equal (ignoring case), a positive value if the first string is lexicographically greater, and a negative value if the second string is greater.\n4. **Procedure**: The function first checks if the sizes of the two strings differ and returns accordingly. 
If the sizes are the same, it performs a case-insensitive comparison using platform-specific functions to compare the strings for a specified length equal to their common size.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "#include \"clickhouse/base/platform.h\" // for _win_\n\n#include \n#include \n#include \n#include \n#include \n\n#if defined _win_\n#include \n#else\n#include \n#endif\n\n\nnamespace clickhouse {\n\nbool TypeAst::operator==(const TypeAst & other) const {\n return meta == other.meta\n && code == other.code\n && name == other.name\n && value == other.value\n && std::equal(elements.begin(), elements.end(), other.elements.begin(), other.elements.end());\n}\n\nstatic const std::unordered_map kTypeCode = {\n { \"Void\", Type::Void },\n { \"Int8\", Type::Int8 },\n { \"Int16\", Type::Int16 },\n { \"Int32\", Type::Int32 },\n { \"Int64\", Type::Int64 },\n { \"Bool\", Type::UInt8 },\n { \"UInt8\", Type::UInt8 },\n { \"UInt16\", Type::UInt16 },\n { \"UInt32\", Type::UInt32 },\n { \"UInt64\", Type::UInt64 },\n { \"Float32\", Type::Float32 },\n { \"Float64\", Type::Float64 },\n { \"String\", Type::String },\n { \"FixedString\", Type::FixedString },\n { \"DateTime\", Type::DateTime },\n { \"DateTime64\", Type::DateTime64 },\n { \"Date\", Type::Date },\n { \"Date32\", Type::Date32 },\n { \"Array\", Type::Array },\n { \"Nullable\", Type::Nullable },\n { \"Tuple\", Type::Tuple },\n { \"Enum8\", Type::Enum8 },\n { \"Enum16\", Type::Enum16 },\n { \"UUID\", Type::UUID },\n { \"IPv4\", Type::IPv4 },\n { \"IPv6\", Type::IPv6 },\n { \"Int128\", Type::Int128 },\n// { \"UInt128\", Type::UInt128 },\n { \"Decimal\", Type::Decimal },\n { \"Decimal32\", Type::Decimal32 },\n { \"Decimal64\", Type::Decimal64 },\n { \"Decimal128\", Type::Decimal128 },\n { \"LowCardinality\", Type::LowCardinality },\n { \"Map\", Type::Map },\n { \"Point\", Type::Point },\n { \"Ring\", Type::Ring },\n { \"Polygon\", Type::Polygon },\n { \"MultiPolygon\", Type::MultiPolygon },\n};\n\ntemplate \n\ninline int CompateStringsCaseInsensitive(const L& left, const R& right) {\n int64_t size_diff = left.size() - right.size();\n if (size_diff != 0)\n return size_diff > 0 ? 
1 : -1;\n\n#if defined _win_\n return _strnicmp(left.data(), right.data(), left.size());\n#else\n return strncasecmp(left.data(), right.data(), left.size());\n#endif\n}\n\nstatic Type::Code GetTypeCode(const std::string& name) {\n auto it = kTypeCode.find(name);\n if (it != kTypeCode.end()) {\n return it->second;\n }\n\n return Type::Void;\n}\n\nstatic TypeAst::Meta GetTypeMeta(const StringView& name) {\n if (name == \"Array\") {\n return TypeAst::Array;\n }\n\n if (name == \"Null\") {\n return TypeAst::Null;\n }\n\n if (name == \"Nullable\") {\n return TypeAst::Nullable;\n }\n\n if (name == \"Tuple\") {\n return TypeAst::Tuple;\n }\n\n if (name == \"Enum8\" || name == \"Enum16\") {\n return TypeAst::Enum;\n }\n\n if (name == \"LowCardinality\") {\n return TypeAst::LowCardinality;\n }\n\n if (name == \"SimpleAggregateFunction\") {\n return TypeAst::SimpleAggregateFunction;\n }\n\n if (name == \"Map\") {\n return TypeAst::Map;\n }\n\n return TypeAst::Terminal;\n}\n\nbool ValidateAST(const TypeAst& ast) {\n // Void terminal that is not actually \"void\" produced when unknown type is encountered.\n if (ast.meta == TypeAst::Terminal\n && ast.code == Type::Void\n && CompateStringsCaseInsensitive(ast.name, std::string_view(\"void\")) != 0)\n //throw UnimplementedError(\"Unsupported type: \" + ast.name);\n return false;\n\n return true;\n}\n\n\nTypeParser::TypeParser(const StringView& name)\n : cur_(name.data())\n , end_(name.data() + name.size())\n , type_(nullptr)\n{\n}\n\nTypeParser::~TypeParser() = default;\n\nbool TypeParser::Parse(TypeAst* type) {\n type_ = type;\n open_elements_.push(type_);\n\n size_t processed_tokens = 0;\n do {\n const Token & token = NextToken();\n switch (token.type) {\n case Token::QuotedString:\n {\n type_->meta = TypeAst::Terminal;\n if (token.value.length() < 1)\n type_->value_string = {};\n else\n type_->value_string = token.value.substr(1, token.value.length() - 2).to_string();\n type_->code = Type::String;\n break;\n }\n case Token::Name:\n type_->meta = GetTypeMeta(token.value);\n type_->name = token.value.to_string();\n type_->code = GetTypeCode(type_->name);\n break;\n case Token::Number:\n type_->meta = TypeAst::Number;\n type_->value = std::stol(token.value.to_string());\n break;\n case Token::String:\n type_->meta = TypeAst::String;\n type_->value_string = std::string(token.value);\n break;\n case Token::LPar:\n type_->elements.emplace_back(TypeAst());\n open_elements_.push(type_);\n type_ = &type_->elements.back();\n break;\n case Token::RPar:\n type_ = open_elements_.top();\n open_elements_.pop();\n break;\n case Token::Assign:\n case Token::Comma:\n type_ = open_elements_.top();\n open_elements_.pop();\n type_->elements.emplace_back(TypeAst());\n open_elements_.push(type_);\n type_ = &type_->elements.back();\n break;\n case Token::EOS:\n {\n // Ubalanced braces, brackets, etc is an error.\n if (open_elements_.size() != 1)\n return false;\n\n // Empty input string, no tokens produced\n if (processed_tokens == 0)\n return false;\n\n return ValidateAST(*type);\n }\n case Token::Invalid:\n return false;\n }\n ++processed_tokens;\n } while (true);\n}\n\nTypeParser::Token TypeParser::NextToken() {\n for (; cur_ < end_; ++cur_) {\n switch (*cur_) {\n case ' ':\n case '\\n':\n case '\\t':\n case '\\0':\n continue;\n case '=':\n return Token{Token::Assign, StringView(cur_++, 1)};\n case '(':\n return Token{Token::LPar, StringView(cur_++, 1)};\n case ')':\n return Token{Token::RPar, StringView(cur_++, 1)};\n case ',':\n return Token{Token::Comma, 
StringView(cur_++, 1)};\n case '\\'':\n {\n const auto end_quote_length = 1;\n const StringView end_quote{cur_, end_quote_length};\n // Fast forward to the closing quote.\n const auto start = cur_++;\n for (; cur_ < end_ - end_quote_length; ++cur_) {\n // TODO (nemkov): handle escaping ?\n if (end_quote == StringView{cur_, end_quote_length}) {\n cur_ += end_quote_length;\n\n return Token{Token::QuotedString, StringView{start, cur_}};\n }\n }\n return Token{Token::QuotedString, StringView(cur_++, 1)};\n }\n\n default: {\n const char* st = cur_;\n\n if (*cur_ == '\\'') {\n for (st = ++cur_; cur_ < end_; ++cur_) {\n if (*cur_ == '\\'') {\n return Token{Token::String, StringView(st, cur_++ - st)};\n }\n }\n\n return Token{Token::Invalid, StringView()};\n }\n\n if (isalpha(*cur_) || *cur_ == '_') {\n for (; cur_ < end_; ++cur_) {\n if (!isalpha(*cur_) && !isdigit(*cur_) && *cur_ != '_') {\n break;\n }\n }\n\n return Token{Token::Name, StringView(st, cur_)};\n }\n\n if (isdigit(*cur_) || *cur_ == '-') {\n for (++cur_; cur_ < end_; ++cur_) {\n if (!isdigit(*cur_)) {\n break;\n }\n }\n\n return Token{Token::Number, StringView(st, cur_)};\n }\n\n return Token{Token::Invalid, StringView()};\n }\n }\n }\n\n return Token{Token::EOS, StringView()};\n}\n\n\nconst TypeAst* ParseTypeName(const std::string& type_name) {\n // Cache for type_name.\n // Usually we won't have too many type names in the cache, so do not try to\n // limit cache size.\n static std::map ast_cache;\n static std::mutex lock;\n\n std::lock_guard guard(lock);\n auto it = ast_cache.find(type_name);\n if (it != ast_cache.end()) {\n return &it->second;\n }\n\n auto& ast = ast_cache[type_name];\n if (TypeParser(type_name).Parse(&ast)) {\n return *\n }\n ast_cache.erase(type_name);\n return nullptr;\n}\n\n}\n\n// Path: clickhouse/version.h\n#pragma once\n\n#define CLICKHOUSE_CPP_VERSION_MAJOR 2\n#define CLICKHOUSE_CPP_VERSION_MINOR 5\n#define CLICKHOUSE_CPP_VERSION_PATCH 1\n\n#define CLICKHOUSE_CPP_VERSION_BUILD 0\n\n// Expecting each version component to be less than 99\n#define CLICKHOUSE_CPP_VERSION \\\n CLICKHOUSE_CPP_VERSION_MAJOR * 100 * 100 * 100 \\\n + CLICKHOUSE_CPP_VERSION_MINOR * 100 * 100 \\\n + CLICKHOUSE_CPP_VERSION_PATCH * 100 \\\n + CLICKHOUSE_CPP_VERSION_BUILD\n\n// Path: clickhouse/error_codes.h\n#pragma once\n\nnamespace clickhouse {\n\n// based on https://github.com/ClickHouse/ClickHouse/blob/54ae88a859507722821624476c3818152f944055/src/Common/ErrorCodes.cpp\n// (master on 28 Feb 2024)\nenum ErrorCodes {\n OK = 0,\n UNSUPPORTED_METHOD = 1,\n UNSUPPORTED_PARAMETER = 2,\n UNEXPECTED_END_OF_FILE = 3,\n EXPECTED_END_OF_FILE = 4,\n CANNOT_PARSE_TEXT = 6,\n INCORRECT_NUMBER_OF_COLUMNS = 7,\n THERE_IS_NO_COLUMN = 8,\n SIZES_OF_COLUMNS_DOESNT_MATCH = 9,\n NOT_FOUND_COLUMN_IN_BLOCK = 10,\n POSITION_OUT_OF_BOUND = 11,\n PARAMETER_OUT_OF_BOUND = 12,\n SIZES_OF_COLUMNS_IN_TUPLE_DOESNT_MATCH = 13,\n DUPLICATE_COLUMN = 15,\n NO_SUCH_COLUMN_IN_TABLE = 16,\n SIZE_OF_FIXED_STRING_DOESNT_MATCH = 19,\n NUMBER_OF_COLUMNS_DOESNT_MATCH = 20,\n CANNOT_READ_FROM_ISTREAM = 23,\n CANNOT_WRITE_TO_OSTREAM = 24,\n CANNOT_PARSE_ESCAPE_SEQUENCE = 25,\n CANNOT_PARSE_QUOTED_STRING = 26,\n CANNOT_PARSE_INPUT_ASSERTION_FAILED = 27,\n CANNOT_PRINT_FLOAT_OR_DOUBLE_NUMBER = 28,\n ATTEMPT_TO_READ_AFTER_EOF = 32,\n CANNOT_READ_ALL_DATA = 33,\n TOO_MANY_ARGUMENTS_FOR_FUNCTION = 34,\n TOO_FEW_ARGUMENTS_FOR_FUNCTION = 35,\n BAD_ARGUMENTS = 36,\n UNKNOWN_ELEMENT_IN_AST = 37,\n CANNOT_PARSE_DATE = 38,\n TOO_LARGE_SIZE_COMPRESSED = 39,\n CHECKSUM_DOESNT_MATCH = 
40,\n CANNOT_PARSE_DATETIME = 41,\n NUMBER_OF_ARGUMENTS_DOESNT_MATCH = 42,\n ILLEGAL_TYPE_OF_ARGUMENT = 43,\n ILLEGAL_COLUMN = 44,\n UNKNOWN_FUNCTION = 46,\n UNKNOWN_IDENTIFIER = 47,\n NOT_IMPLEMENTED = 48,\n LOGICAL_ERROR = 49,\n UNKNOWN_TYPE = 50,\n EMPTY_LIST_OF_COLUMNS_QUERIED = 51,\n COLUMN_QUERIED_MORE_THAN_ONCE = 52,\n TYPE_MISMATCH = 53,\n STORAGE_REQUIRES_PARAMETER = 55,\n UNKNOWN_STORAGE = 56,\n TABLE_ALREADY_EXISTS = 57,\n TABLE_METADATA_ALREADY_EXISTS = 58,\n ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER = 59,\n UNKNOWN_TABLE = 60,\n SYNTAX_ERROR = 62,\n UNKNOWN_AGGREGATE_FUNCTION = 63,\n CANNOT_GET_SIZE_OF_FIELD = 68,\n ARGUMENT_OUT_OF_BOUND = 69,\n CANNOT_CONVERT_TYPE = 70,\n CANNOT_WRITE_AFTER_END_OF_BUFFER = 71,\n CANNOT_PARSE_NUMBER = 72,\n UNKNOWN_FORMAT = 73,\n CANNOT_READ_FROM_FILE_DESCRIPTOR = 74,\n CANNOT_WRITE_TO_FILE_DESCRIPTOR = 75,\n CANNOT_OPEN_FILE = 76,\n CANNOT_CLOSE_FILE = 77,\n UNKNOWN_TYPE_OF_QUERY = 78,\n INCORRECT_FILE_NAME = 79,\n INCORRECT_QUERY = 80,\n UNKNOWN_DATABASE = 81,\n DATABASE_ALREADY_EXISTS = 82,\n DIRECTORY_DOESNT_EXIST = 83,\n DIRECTORY_ALREADY_EXISTS = 84,\n FORMAT_IS_NOT_SUITABLE_FOR_INPUT = 85,\n RECEIVED_ERROR_FROM_REMOTE_IO_SERVER = 86,\n CANNOT_SEEK_THROUGH_FILE = 87,\n CANNOT_TRUNCATE_FILE = 88,\n UNKNOWN_COMPRESSION_METHOD = 89,\n EMPTY_LIST_OF_COLUMNS_PASSED = 90,\n SIZES_OF_MARKS_FILES_ARE_INCONSISTENT = 91,\n EMPTY_DATA_PASSED = 92,\n UNKNOWN_AGGREGATED_DATA_VARIANT = 93,\n CANNOT_MERGE_DIFFERENT_AGGREGATED_DATA_VARIANTS = 94,\n CANNOT_READ_FROM_SOCKET = 95,\n CANNOT_WRITE_TO_SOCKET = 96,\n UNKNOWN_PACKET_FROM_CLIENT = 99,\n UNKNOWN_PACKET_FROM_SERVER = 100,\n UNEXPECTED_PACKET_FROM_CLIENT = 101,\n UNEXPECTED_PACKET_FROM_SERVER = 102,\n TOO_SMALL_BUFFER_SIZE = 104,\n FILE_DOESNT_EXIST = 107,\n NO_DATA_TO_INSERT = 108,\n CANNOT_BLOCK_SIGNAL = 109,\n CANNOT_UNBLOCK_SIGNAL = 110,\n CANNOT_MANIPULATE_SIGSET = 111,\n CANNOT_WAIT_FOR_SIGNAL = 112,\n THERE_IS_NO_SESSION = 113,\n CANNOT_CLOCK_GETTIME = 114,\n UNKNOWN_SETTING = 115,\n THERE_IS_NO_DEFAULT_VALUE = 116,\n INCORRECT_DATA = 117,\n ENGINE_REQUIRED = 119,\n CANNOT_INSERT_VALUE_OF_DIFFERENT_SIZE_INTO_TUPLE = 120,\n UNSUPPORTED_JOIN_KEYS = 121,\n INCOMPATIBLE_COLUMNS = 122,\n UNKNOWN_TYPE_OF_AST_NODE = 123,\n INCORRECT_ELEMENT_OF_SET = 124,\n INCORRECT_RESULT_OF_SCALAR_SUBQUERY = 125,\n ILLEGAL_INDEX = 127,\n TOO_LARGE_ARRAY_SIZE = 128,\n FUNCTION_IS_SPECIAL = 129,\n CANNOT_READ_ARRAY_FROM_TEXT = 130,\n TOO_LARGE_STRING_SIZE = 131,\n AGGREGATE_FUNCTION_DOESNT_ALLOW_PARAMETERS = 133,\n PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS = 134,\n ZERO_ARRAY_OR_TUPLE_INDEX = 135,\n UNKNOWN_ELEMENT_IN_CONFIG = 137,\n EXCESSIVE_ELEMENT_IN_CONFIG = 138,\n NO_ELEMENTS_IN_CONFIG = 139,\n SAMPLING_NOT_SUPPORTED = 141,\n NOT_FOUND_NODE = 142,\n UNKNOWN_OVERFLOW_MODE = 145,\n UNKNOWN_DIRECTION_OF_SORTING = 152,\n ILLEGAL_DIVISION = 153,\n DICTIONARIES_WAS_NOT_LOADED = 156,\n TOO_MANY_ROWS = 158,\n TIMEOUT_EXCEEDED = 159,\n TOO_SLOW = 160,\n TOO_MANY_COLUMNS = 161,\n TOO_DEEP_SUBQUERIES = 162,\n READONLY = 164,\n TOO_MANY_TEMPORARY_COLUMNS = 165,\n TOO_MANY_TEMPORARY_NON_CONST_COLUMNS = 166,\n TOO_DEEP_AST = 167,\n TOO_BIG_AST = 168,\n BAD_TYPE_OF_FIELD = 169,\n BAD_GET = 170,\n CANNOT_CREATE_DIRECTORY = 172,\n CANNOT_ALLOCATE_MEMORY = 173,\n CYCLIC_ALIASES = 174,\n MULTIPLE_EXPRESSIONS_FOR_ALIAS = 179,\n THERE_IS_NO_PROFILE = 180,\n ILLEGAL_FINAL = 181,\n ILLEGAL_PREWHERE = 182,\n UNEXPECTED_EXPRESSION = 183,\n ILLEGAL_AGGREGATION = 184,\n UNSUPPORTED_COLLATION_LOCALE = 186,\n COLLATION_COMPARISON_FAILED 
= 187,\n SIZES_OF_ARRAYS_DONT_MATCH = 190,\n SET_SIZE_LIMIT_EXCEEDED = 191,\n UNKNOWN_USER = 192,\n WRONG_PASSWORD = 193,\n REQUIRED_PASSWORD = 194,\n IP_ADDRESS_NOT_ALLOWED = 195,\n UNKNOWN_ADDRESS_PATTERN_TYPE = 196,\n DNS_ERROR = 198,\n UNKNOWN_QUOTA = 199,\n QUOTA_EXCEEDED = 201,\n TOO_MANY_SIMULTANEOUS_QUERIES = 202,\n NO_FREE_CONNECTION = 203,\n CANNOT_FSYNC = 204,\n ALIAS_REQUIRED = 206,\n AMBIGUOUS_IDENTIFIER = 207,\n EMPTY_NESTED_TABLE = 208,\n SOCKET_TIMEOUT = 209,\n NETWORK_ERROR = 210,\n EMPTY_QUERY = 211,\n UNKNOWN_LOAD_BALANCING = 212,\n UNKNOWN_TOTALS_MODE = 213,\n CANNOT_STATVFS = 214,\n NOT_AN_AGGREGATE = 215,\n QUERY_WITH_SAME_ID_IS_ALREADY_RUNNING = 216,\n CLIENT_HAS_CONNECTED_TO_WRONG_PORT = 217,\n TABLE_IS_DROPPED = 218,\n DATABASE_NOT_EMPTY = 219,\n DUPLICATE_INTERSERVER_IO_ENDPOINT = 220,\n NO_SUCH_INTERSERVER_IO_ENDPOINT = 221,\n UNEXPECTED_AST_STRUCTURE = 223,\n REPLICA_IS_ALREADY_ACTIVE = 224,\n NO_ZOOKEEPER = 225,\n NO_FILE_IN_DATA_PART = 226,\n UNEXPECTED_FILE_IN_DATA_PART = 227,\n BAD_SIZE_OF_FILE_IN_DATA_PART = 228,\n QUERY_IS_TOO_LARGE = 229,\n NOT_FOUND_EXPECTED_DATA_PART = 230,\n TOO_MANY_UNEXPECTED_DATA_PARTS = 231,\n NO_SUCH_DATA_PART = 232,\n BAD_DATA_PART_NAME = 233,\n NO_REPLICA_HAS_PART = 234,\n DUPLICATE_DATA_PART = 235,\n ABORTED = 236,\n NO_REPLICA_NAME_GIVEN = 237,\n FORMAT_VERSION_TOO_OLD = 238,\n CANNOT_MUNMAP = 239,\n CANNOT_MREMAP = 240,\n MEMORY_LIMIT_EXCEEDED = 241,\n TABLE_IS_READ_ONLY = 242,\n NOT_ENOUGH_SPACE = 243,\n UNEXPECTED_ZOOKEEPER_ERROR = 244,\n CORRUPTED_DATA = 246,\n INVALID_PARTITION_VALUE = 248,\n NO_SUCH_REPLICA = 251,\n TOO_MANY_PARTS = 252,\n REPLICA_ALREADY_EXISTS = 253,\n NO_ACTIVE_REPLICAS = 254,\n TOO_MANY_RETRIES_TO_FETCH_PARTS = 255,\n PARTITION_ALREADY_EXISTS = 256,\n PARTITION_DOESNT_EXIST = 257,\n UNION_ALL_RESULT_STRUCTURES_MISMATCH = 258,\n CLIENT_OUTPUT_FORMAT_SPECIFIED = 260,\n UNKNOWN_BLOCK_INFO_FIELD = 261,\n BAD_COLLATION = 262,\n CANNOT_COMPILE_CODE = 263,\n INCOMPATIBLE_TYPE_OF_JOIN = 264,\n NO_AVAILABLE_REPLICA = 265,\n MISMATCH_REPLICAS_DATA_SOURCES = 266,\n INFINITE_LOOP = 269,\n CANNOT_COMPRESS = 270,\n CANNOT_DECOMPRESS = 271,\n CANNOT_IO_SUBMIT = 272,\n CANNOT_IO_GETEVENTS = 273,\n AIO_READ_ERROR = 274,\n AIO_WRITE_ERROR = 275,\n INDEX_NOT_USED = 277,\n ALL_CONNECTION_TRIES_FAILED = 279,\n NO_AVAILABLE_DATA = 280,\n DICTIONARY_IS_EMPTY = 281,\n INCORRECT_INDEX = 282,\n UNKNOWN_DISTRIBUTED_PRODUCT_MODE = 283,\n WRONG_GLOBAL_SUBQUERY = 284,\n TOO_FEW_LIVE_REPLICAS = 285,\n UNSATISFIED_QUORUM_FOR_PREVIOUS_WRITE = 286,\n UNKNOWN_FORMAT_VERSION = 287,\n DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED = 288,\n REPLICA_IS_NOT_IN_QUORUM = 289,\n LIMIT_EXCEEDED = 290,\n DATABASE_ACCESS_DENIED = 291,\n MONGODB_CANNOT_AUTHENTICATE = 293,\n RECEIVED_EMPTY_DATA = 295,\n SHARD_HAS_NO_CONNECTIONS = 297,\n CANNOT_PIPE = 298,\n CANNOT_FORK = 299,\n CANNOT_DLSYM = 300,\n CANNOT_CREATE_CHILD_PROCESS = 301,\n CHILD_WAS_NOT_EXITED_NORMALLY = 302,\n CANNOT_SELECT = 303,\n CANNOT_WAITPID = 304,\n TABLE_WAS_NOT_DROPPED = 305,\n TOO_DEEP_RECURSION = 306,\n TOO_MANY_BYTES = 307,\n UNEXPECTED_NODE_IN_ZOOKEEPER = 308,\n FUNCTION_CANNOT_HAVE_PARAMETERS = 309,\n INVALID_CONFIG_PARAMETER = 318,\n UNKNOWN_STATUS_OF_INSERT = 319,\n VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE = 321,\n UNKNOWN_DATABASE_ENGINE = 336,\n UNFINISHED = 341,\n METADATA_MISMATCH = 342,\n SUPPORT_IS_DISABLED = 344,\n TABLE_DIFFERS_TOO_MUCH = 345,\n CANNOT_CONVERT_CHARSET = 346,\n CANNOT_LOAD_CONFIG = 347,\n CANNOT_INSERT_NULL_IN_ORDINARY_COLUMN = 349,\n AMBIGUOUS_COLUMN_NAME = 
352,\n INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE = 353,\n ZLIB_INFLATE_FAILED = 354,\n ZLIB_DEFLATE_FAILED = 355,\n INTO_OUTFILE_NOT_ALLOWED = 358,\n TABLE_SIZE_EXCEEDS_MAX_DROP_SIZE_LIMIT = 359,\n CANNOT_CREATE_CHARSET_CONVERTER = 360,\n SEEK_POSITION_OUT_OF_BOUND = 361,\n CURRENT_WRITE_BUFFER_IS_EXHAUSTED = 362,\n CANNOT_CREATE_IO_BUFFER = 363,\n RECEIVED_ERROR_TOO_MANY_REQUESTS = 364,\n SIZES_OF_NESTED_COLUMNS_ARE_INCONSISTENT = 366,\n ALL_REPLICAS_ARE_STALE = 369,\n DATA_TYPE_CANNOT_BE_USED_IN_TABLES = 370,\n INCONSISTENT_CLUSTER_DEFINITION = 371,\n SESSION_NOT_FOUND = 372,\n SESSION_IS_LOCKED = 373,\n INVALID_SESSION_TIMEOUT = 374,\n CANNOT_DLOPEN = 375,\n CANNOT_PARSE_UUID = 376,\n ILLEGAL_SYNTAX_FOR_DATA_TYPE = 377,\n DATA_TYPE_CANNOT_HAVE_ARGUMENTS = 378,\n CANNOT_KILL = 380,\n HTTP_LENGTH_REQUIRED = 381,\n CANNOT_LOAD_CATBOOST_MODEL = 382,\n CANNOT_APPLY_CATBOOST_MODEL = 383,\n PART_IS_TEMPORARILY_LOCKED = 384,\n MULTIPLE_STREAMS_REQUIRED = 385,\n NO_COMMON_TYPE = 386,\n DICTIONARY_ALREADY_EXISTS = 387,\n CANNOT_ASSIGN_OPTIMIZE = 388,\n INSERT_WAS_DEDUPLICATED = 389,\n CANNOT_GET_CREATE_TABLE_QUERY = 390,\n EXTERNAL_LIBRARY_ERROR = 391,\n QUERY_IS_PROHIBITED = 392,\n THERE_IS_NO_QUERY = 393,\n QUERY_WAS_CANCELLED = 394,\n FUNCTION_THROW_IF_VALUE_IS_NON_ZERO = 395,\n TOO_MANY_ROWS_OR_BYTES = 396,\n QUERY_IS_NOT_SUPPORTED_IN_MATERIALIZED_VIEW = 397,\n UNKNOWN_MUTATION_COMMAND = 398,\n FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT = 399,\n CANNOT_STAT = 400,\n FEATURE_IS_NOT_ENABLED_AT_BUILD_TIME = 401,\n CANNOT_IOSETUP = 402,\n INVALID_JOIN_ON_EXPRESSION = 403,\n BAD_ODBC_CONNECTION_STRING = 404,\n TOP_AND_LIMIT_TOGETHER = 406,\n DECIMAL_OVERFLOW = 407,\n BAD_REQUEST_PARAMETER = 408,\n EXTERNAL_SERVER_IS_NOT_RESPONDING = 410,\n PTHREAD_ERROR = 411,\n NETLINK_ERROR = 412,\n CANNOT_SET_SIGNAL_HANDLER = 413,\n ALL_REPLICAS_LOST = 415,\n REPLICA_STATUS_CHANGED = 416,\n EXPECTED_ALL_OR_ANY = 417,\n UNKNOWN_JOIN = 418,\n MULTIPLE_ASSIGNMENTS_TO_COLUMN = 419,\n CANNOT_UPDATE_COLUMN = 420,\n CANNOT_ADD_DIFFERENT_AGGREGATE_STATES = 421,\n UNSUPPORTED_URI_SCHEME = 422,\n CANNOT_GETTIMEOFDAY = 423,\n CANNOT_LINK = 424,\n SYSTEM_ERROR = 425,\n CANNOT_COMPILE_REGEXP = 427,\n FAILED_TO_GETPWUID = 429,\n MISMATCHING_USERS_FOR_PROCESS_AND_DATA = 430,\n ILLEGAL_SYNTAX_FOR_CODEC_TYPE = 431,\n UNKNOWN_CODEC = 432,\n ILLEGAL_CODEC_PARAMETER = 433,\n CANNOT_PARSE_PROTOBUF_SCHEMA = 434,\n NO_COLUMN_SERIALIZED_TO_REQUIRED_PROTOBUF_FIELD = 435,\n PROTOBUF_BAD_CAST = 436,\n PROTOBUF_FIELD_NOT_REPEATED = 437,\n DATA_TYPE_CANNOT_BE_PROMOTED = 438,\n CANNOT_SCHEDULE_TASK = 439,\n INVALID_LIMIT_EXPRESSION = 440,\n CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING = 441,\n BAD_DATABASE_FOR_TEMPORARY_TABLE = 442,\n NO_COLUMNS_SERIALIZED_TO_PROTOBUF_FIELDS = 443,\n UNKNOWN_PROTOBUF_FORMAT = 444,\n CANNOT_MPROTECT = 445,\n FUNCTION_NOT_ALLOWED = 446,\n HYPERSCAN_CANNOT_SCAN_TEXT = 447,\n BROTLI_READ_FAILED = 448,\n BROTLI_WRITE_FAILED = 449,\n BAD_TTL_EXPRESSION = 450,\n BAD_TTL_FILE = 451,\n SETTING_CONSTRAINT_VIOLATION = 452,\n MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES = 453,\n OPENSSL_ERROR = 454,\n SUSPICIOUS_TYPE_FOR_LOW_CARDINALITY = 455,\n UNKNOWN_QUERY_PARAMETER = 456,\n BAD_QUERY_PARAMETER = 457,\n CANNOT_UNLINK = 458,\n CANNOT_SET_THREAD_PRIORITY = 459,\n CANNOT_CREATE_TIMER = 460,\n CANNOT_SET_TIMER_PERIOD = 461,\n CANNOT_FCNTL = 463,\n CANNOT_PARSE_ELF = 464,\n CANNOT_PARSE_DWARF = 465,\n INSECURE_PATH = 466,\n CANNOT_PARSE_BOOL = 467,\n CANNOT_PTHREAD_ATTR = 468,\n VIOLATED_CONSTRAINT = 469,\n 
QUERY_IS_NOT_SUPPORTED_IN_LIVE_VIEW = 470,\n INVALID_SETTING_VALUE = 471,\n READONLY_SETTING = 472,\n DEADLOCK_AVOIDED = 473,\n INVALID_TEMPLATE_FORMAT = 474,\n INVALID_WITH_FILL_EXPRESSION = 475,\n WITH_TIES_WITHOUT_ORDER_BY = 476,\n INVALID_USAGE_OF_INPUT = 477,\n UNKNOWN_POLICY = 478,\n UNKNOWN_DISK = 479,\n UNKNOWN_PROTOCOL = 480,\n PATH_ACCESS_DENIED = 481,\n DICTIONARY_ACCESS_DENIED = 482,\n TOO_MANY_REDIRECTS = 483,\n INTERNAL_REDIS_ERROR = 484,\n CANNOT_GET_CREATE_DICTIONARY_QUERY = 487,\n INCORRECT_DICTIONARY_DEFINITION = 489,\n CANNOT_FORMAT_DATETIME = 490,\n UNACCEPTABLE_URL = 491,\n ACCESS_ENTITY_NOT_FOUND = 492,\n ACCESS_ENTITY_ALREADY_EXISTS = 493,\n ACCESS_STORAGE_READONLY = 495,\n QUOTA_REQUIRES_CLIENT_KEY = 496,\n ACCESS_DENIED = 497,\n LIMIT_BY_WITH_TIES_IS_NOT_SUPPORTED = 498,\n S3_ERROR = 499,\n AZURE_BLOB_STORAGE_ERROR = 500,\n CANNOT_CREATE_DATABASE = 501,\n CANNOT_SIGQUEUE = 502,\n AGGREGATE_FUNCTION_THROW = 503,\n FILE_ALREADY_EXISTS = 504,\n UNABLE_TO_SKIP_UNUSED_SHARDS = 507,\n UNKNOWN_ACCESS_TYPE = 508,\n INVALID_GRANT = 509,\n CACHE_DICTIONARY_UPDATE_FAIL = 510,\n UNKNOWN_ROLE = 511,\n SET_NON_GRANTED_ROLE = 512,\n UNKNOWN_PART_TYPE = 513,\n ACCESS_STORAGE_FOR_INSERTION_NOT_FOUND = 514,\n INCORRECT_ACCESS_ENTITY_DEFINITION = 515,\n AUTHENTICATION_FAILED = 516,\n CANNOT_ASSIGN_ALTER = 517,\n CANNOT_COMMIT_OFFSET = 518,\n NO_REMOTE_SHARD_AVAILABLE = 519,\n CANNOT_DETACH_DICTIONARY_AS_TABLE = 520,\n ATOMIC_RENAME_FAIL = 521,\n UNKNOWN_ROW_POLICY = 523,\n ALTER_OF_COLUMN_IS_FORBIDDEN = 524,\n INCORRECT_DISK_INDEX = 525,\n NO_SUITABLE_FUNCTION_IMPLEMENTATION = 527,\n CASSANDRA_INTERNAL_ERROR = 528,\n NOT_A_LEADER = 529,\n CANNOT_CONNECT_RABBITMQ = 530,\n CANNOT_FSTAT = 531,\n LDAP_ERROR = 532,\n UNKNOWN_RAID_TYPE = 535,\n CANNOT_RESTORE_FROM_FIELD_DUMP = 536,\n ILLEGAL_MYSQL_VARIABLE = 537,\n MYSQL_SYNTAX_ERROR = 538,\n CANNOT_BIND_RABBITMQ_EXCHANGE = 539,\n CANNOT_DECLARE_RABBITMQ_EXCHANGE = 540,\n CANNOT_CREATE_RABBITMQ_QUEUE_BINDING = 541,\n CANNOT_REMOVE_RABBITMQ_EXCHANGE = 542,\n UNKNOWN_MYSQL_DATATYPES_SUPPORT_LEVEL = 543,\n ROW_AND_ROWS_TOGETHER = 544,\n FIRST_AND_NEXT_TOGETHER = 545,\n NO_ROW_DELIMITER = 546,\n INVALID_RAID_TYPE = 547,\n UNKNOWN_VOLUME = 548,\n DATA_TYPE_CANNOT_BE_USED_IN_KEY = 549,\n UNRECOGNIZED_ARGUMENTS = 552,\n LZMA_STREAM_ENCODER_FAILED = 553,\n LZMA_STREAM_DECODER_FAILED = 554,\n ROCKSDB_ERROR = 555,\n SYNC_MYSQL_USER_ACCESS_ERROR = 556,\n UNKNOWN_UNION = 557,\n EXPECTED_ALL_OR_DISTINCT = 558,\n INVALID_GRPC_QUERY_INFO = 559,\n ZSTD_ENCODER_FAILED = 560,\n ZSTD_DECODER_FAILED = 561,\n TLD_LIST_NOT_FOUND = 562,\n CANNOT_READ_MAP_FROM_TEXT = 563,\n INTERSERVER_SCHEME_DOESNT_MATCH = 564,\n TOO_MANY_PARTITIONS = 565,\n CANNOT_RMDIR = 566,\n DUPLICATED_PART_UUIDS = 567,\n RAFT_ERROR = 568,\n MULTIPLE_COLUMNS_SERIALIZED_TO_SAME_PROTOBUF_FIELD = 569,\n DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD = 570,\n DATABASE_REPLICATION_FAILED = 571,\n TOO_MANY_QUERY_PLAN_OPTIMIZATIONS = 572,\n EPOLL_ERROR = 573,\n DISTRIBUTED_TOO_MANY_PENDING_BYTES = 574,\n UNKNOWN_SNAPSHOT = 575,\n KERBEROS_ERROR = 576,\n INVALID_SHARD_ID = 577,\n INVALID_FORMAT_INSERT_QUERY_WITH_DATA = 578,\n INCORRECT_PART_TYPE = 579,\n CANNOT_SET_ROUNDING_MODE = 580,\n TOO_LARGE_DISTRIBUTED_DEPTH = 581,\n NO_SUCH_PROJECTION_IN_TABLE = 582,\n ILLEGAL_PROJECTION = 583,\n PROJECTION_NOT_USED = 584,\n CANNOT_PARSE_YAML = 585,\n CANNOT_CREATE_FILE = 586,\n CONCURRENT_ACCESS_NOT_SUPPORTED = 587,\n DISTRIBUTED_BROKEN_BATCH_INFO = 588,\n DISTRIBUTED_BROKEN_BATCH_FILES = 589,\n CANNOT_SYSCONF 
= 590,\n SQLITE_ENGINE_ERROR = 591,\n DATA_ENCRYPTION_ERROR = 592,\n ZERO_COPY_REPLICATION_ERROR = 593,\n BZIP2_STREAM_DECODER_FAILED = 594,\n BZIP2_STREAM_ENCODER_FAILED = 595,\n INTERSECT_OR_EXCEPT_RESULT_STRUCTURES_MISMATCH = 596,\n NO_SUCH_ERROR_CODE = 597,\n BACKUP_ALREADY_EXISTS = 598,\n BACKUP_NOT_FOUND = 599,\n BACKUP_VERSION_NOT_SUPPORTED = 600,\n BACKUP_DAMAGED = 601,\n NO_BASE_BACKUP = 602,\n WRONG_BASE_BACKUP = 603,\n BACKUP_ENTRY_ALREADY_EXISTS = 604,\n BACKUP_ENTRY_NOT_FOUND = 605,\n BACKUP_IS_EMPTY = 606,\n CANNOT_RESTORE_DATABASE = 607,\n CANNOT_RESTORE_TABLE = 608,\n FUNCTION_ALREADY_EXISTS = 609,\n CANNOT_DROP_FUNCTION = 610,\n CANNOT_CREATE_RECURSIVE_FUNCTION = 611,\n POSTGRESQL_CONNECTION_FAILURE = 614,\n CANNOT_ADVISE = 615,\n UNKNOWN_READ_METHOD = 616,\n LZ4_ENCODER_FAILED = 617,\n LZ4_DECODER_FAILED = 618,\n POSTGRESQL_REPLICATION_INTERNAL_ERROR = 619,\n QUERY_NOT_ALLOWED = 620,\n CANNOT_NORMALIZE_STRING = 621,\n CANNOT_PARSE_CAPN_PROTO_SCHEMA = 622,\n CAPN_PROTO_BAD_CAST = 623,\n BAD_FILE_TYPE = 624,\n IO_SETUP_ERROR = 625,\n CANNOT_SKIP_UNKNOWN_FIELD = 626,\n BACKUP_ENGINE_NOT_FOUND = 627,\n OFFSET_FETCH_WITHOUT_ORDER_BY = 628,\n HTTP_RANGE_NOT_SATISFIABLE = 629,\n HAVE_DEPENDENT_OBJECTS = 630,\n UNKNOWN_FILE_SIZE = 631,\n UNEXPECTED_DATA_AFTER_PARSED_VALUE = 632,\n QUERY_IS_NOT_SUPPORTED_IN_WINDOW_VIEW = 633,\n MONGODB_ERROR = 634,\n CANNOT_POLL = 635,\n CANNOT_EXTRACT_TABLE_STRUCTURE = 636,\n INVALID_TABLE_OVERRIDE = 637,\n SNAPPY_UNCOMPRESS_FAILED = 638,\n SNAPPY_COMPRESS_FAILED = 639,\n NO_HIVEMETASTORE = 640,\n CANNOT_APPEND_TO_FILE = 641,\n CANNOT_PACK_ARCHIVE = 642,\n CANNOT_UNPACK_ARCHIVE = 643,\n NUMBER_OF_DIMENSIONS_MISMATCHED = 645,\n CANNOT_BACKUP_TABLE = 647,\n WRONG_DDL_RENAMING_SETTINGS = 648,\n INVALID_TRANSACTION = 649,\n SERIALIZATION_ERROR = 650,\n CAPN_PROTO_BAD_TYPE = 651,\n ONLY_NULLS_WHILE_READING_SCHEMA = 652,\n CANNOT_PARSE_BACKUP_SETTINGS = 653,\n WRONG_BACKUP_SETTINGS = 654,\n FAILED_TO_SYNC_BACKUP_OR_RESTORE = 655,\n UNKNOWN_STATUS_OF_TRANSACTION = 659,\n HDFS_ERROR = 660,\n CANNOT_SEND_SIGNAL = 661,\n FS_METADATA_ERROR = 662,\n INCONSISTENT_METADATA_FOR_BACKUP = 663,\n ACCESS_STORAGE_DOESNT_ALLOW_BACKUP = 664,\n CANNOT_CONNECT_NATS = 665,\n NOT_INITIALIZED = 667,\n INVALID_STATE = 668,\n NAMED_COLLECTION_DOESNT_EXIST = 669,\n NAMED_COLLECTION_ALREADY_EXISTS = 670,\n NAMED_COLLECTION_IS_IMMUTABLE = 671,\n INVALID_SCHEDULER_NODE = 672,\n RESOURCE_ACCESS_DENIED = 673,\n RESOURCE_NOT_FOUND = 674,\n CANNOT_PARSE_IPV4 = 675,\n CANNOT_PARSE_IPV6 = 676,\n THREAD_WAS_CANCELED = 677,\n IO_URING_INIT_FAILED = 678,\n IO_URING_SUBMIT_ERROR = 679,\n MIXED_ACCESS_PARAMETER_TYPES = 690,\n UNKNOWN_ELEMENT_OF_ENUM = 691,\n TOO_MANY_MUTATIONS = 692,\n AWS_ERROR = 693,\n ASYNC_LOAD_CYCLE = 694,\n ASYNC_LOAD_FAILED = 695,\n ASYNC_LOAD_CANCELED = 696,\n CANNOT_RESTORE_TO_NONENCRYPTED_DISK = 697,\n INVALID_REDIS_STORAGE_TYPE = 698,\n INVALID_REDIS_TABLE_STRUCTURE = 699,\n USER_SESSION_LIMIT_EXCEEDED = 700,\n CLUSTER_DOESNT_EXIST = 701,\n CLIENT_INFO_DOES_NOT_MATCH = 702,\n INVALID_IDENTIFIER = 703,\n QUERY_CACHE_USED_WITH_NONDETERMINISTIC_FUNCTIONS = 704,\n // TABLE_NOT_EMPTY = 705,\n // LIBSSH_ERROR = 706,\n // GCP_ERROR = 707,\n // ILLEGAL_STATISTIC = 708,\n // CANNOT_GET_REPLICATED_DATABASE_SNAPSHOT = 709,\n // FAULT_INJECTED = 710,\n // FILECACHE_ACCESS_DENIED = 711,\n // TOO_MANY_MATERIALIZED_VIEWS = 712,\n // UNEXPECTED_CLUSTER = 714,\n // CANNOT_DETECT_FORMAT = 715,\n // CANNOT_FORGET_PARTITION = 716,\n KEEPER_EXCEPTION = 999,\n POCO_EXCEPTION = 
1000,\n STD_EXCEPTION = 1001,\n UNKNOWN_EXCEPTION = 1002,\n};\n\n}\n\n// Path: clickhouse/base/string_utils.h\n#pragma once\n\n#include \"string_view.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \n[[deprecated(\"Not used by clickhosue-cpp itself, and will be removed in next major release (3.0) \")]]\ninline T FromString(const std::string& s) {\n std::istringstream iss(s);\n T result;\n iss >> result;\n return result;\n}\n\ntemplate \n[[deprecated(\"Not used by clickhosue-cpp itself, and will be removed in next major release (3.0) \")]]\ninline T FromString(const StringView& s) {\n std::istringstream iss((std::string(s)));\n T result;\n iss >> result;\n return result;\n}\n\n}\n\n", "needle_token_start": 747, "needle_token_end": 871, "code_context_ntokens": 14467, "output": ["```cpp\ntemplate \ninline int CompateStringsCaseInsensitive(const L& left, const R& right) {\n int64_t size_diff = left.size() - right.size();\n if (size_diff != 0)\n return size_diff > 0 ? 1 : -1;\n\n#if defined _win_\n return _strnicmp(left.data(), right.data(), left.size());\n#else\n return strncasecmp(left.data(), right.data(), left.size());\n#endif\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "ContainerToString", "language": "cpp", "path": "clickhouse/columns/itemview.cpp", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert a container of elements into a single string representation, with elements separated by a specified delimiter.\n2. **Input**: It takes a container (like a vector, list, etc.) and an optional string delimiter (default is \", \").\n3. **Output**: It returns a string that concatenates all the elements of the container, separated by the specified delimiter.\n4. **Procedure**: The function iterates through each element of the container, appending each element to a string stream. If the current element is not the last in the container, the delimiter is appended after the element. The loop avoids incrementing the iterator inside the loop body to handle the delimiter logic correctly.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/columns/lowcardinality.cpp\n#include \"lowcardinality.h\"\n\n#include \"string.h\"\n#include \"nullable.h\"\n#include \"../base/wire_format.h\"\n\n#include \n\n#include \n#include \n#include \n\n#include \n\nnamespace {\nusing namespace clickhouse;\n\nenum KeySerializationVersion {\n SharedDictionariesWithAdditionalKeys = 1,\n};\n\nenum IndexType {\n UInt8 = 0,\n UInt16,\n UInt32,\n UInt64,\n};\n\nconstexpr uint64_t IndexTypeMask = 0b11111111;\n\nenum IndexFlag {\n /// Need to read dictionary if it wasn't.\n NeedGlobalDictionaryBit = 1u << 8u,\n /// Need to read additional keys. Additional keys are stored before indexes as value N and N keys after them.\n HasAdditionalKeysBit = 1u << 9u,\n /// Need to update dictionary. 
It means that previous granule has different dictionary.\n NeedUpdateDictionary = 1u << 10u\n};\n\nColumnRef createIndexColumn(IndexType type) {\n switch (type) {\n case IndexType::UInt8:\n return std::make_shared();\n case IndexType::UInt16:\n return std::make_shared();\n case IndexType::UInt32:\n return std::make_shared();\n case IndexType::UInt64:\n return std::make_shared();\n }\n\n throw ValidationError(\"Invalid LowCardinality index type value: \" + std::to_string(static_cast(type)));\n}\n\nIndexType indexTypeFromIndexColumn(const Column & index_column) {\n switch (index_column.Type()->GetCode()) {\n case Type::UInt8:\n return IndexType::UInt8;\n case Type::UInt16:\n return IndexType::UInt16;\n case Type::UInt32:\n return IndexType::UInt32;\n case Type::UInt64:\n return IndexType::UInt64;\n default:\n throw ValidationError(\"Invalid index column type for LowCardinality column:\" + index_column.Type()->GetName());\n }\n}\n\ntemplate \ninline const ResultColumnType & column_down_cast(const ColumnType & c) {\n return dynamic_cast(c);\n}\n\ntemplate \ninline ResultColumnType & column_down_cast(ColumnType & c) {\n return dynamic_cast(c);\n}\n\n// std::visit-ish function to avoid including header, which is not present in older version of XCode.\ntemplate \ninline auto VisitIndexColumn(Vizitor && vizitor, ColumnType && col) {\n switch (col.Type()->GetCode()) {\n case Type::UInt8:\n return vizitor(column_down_cast(col));\n case Type::UInt16:\n return vizitor(column_down_cast(col));\n case Type::UInt32:\n return vizitor(column_down_cast(col));\n case Type::UInt64:\n return vizitor(column_down_cast(col));\n default:\n throw ValidationError(\"Invalid index column type \" + col.GetType().GetName());\n }\n}\n\n// A special NULL-item, which is expected at pos(0) in dictionary,\n// note that we distinguish empty string from NULL-value.\ninline auto GetNullItemForDictionary(const ColumnRef dictionary) {\n if (auto n = dictionary->As()) {\n return ItemView {};\n } else {\n return ItemView{dictionary->Type()->GetCode(), std::string_view{}};\n }\n}\n\n// A special default item, which is expected at pos(0) in dictionary,\n// note that we distinguish empty string from NULL-value.\ninline ItemView GetDefaultItemForDictionary(const ColumnRef dictionary) {\n if (auto n = dictionary->As()) {\n return GetDefaultItemForDictionary(n->Nested());\n } else {\n return ItemView{dictionary->Type()->GetCode(), std::string_view{}};\n }\n}\n\nvoid AppendToDictionary(Column& dictionary, const ItemView & item);\n\ninline void AppendNullableToDictionary(ColumnNullable& nullable, const ItemView & item) {\n auto nested = nullable.Nested();\n\n const bool isNullValue = item.type == Type::Void;\n\n if (isNullValue) {\n AppendToDictionary(*nested, GetNullItemForDictionary(nested));\n } else {\n const auto nestedType = nested->GetType().GetCode();\n if (nestedType != item.type) {\n throw ValidationError(\"Invalid value. 
Type expected: \" + nested->GetType().GetName());\n }\n\n AppendToDictionary(*nested, item);\n }\n\n nullable.Append(isNullValue);\n}\n\ninline void AppendToDictionary(Column& dictionary, const ItemView & item) {\n switch (dictionary.GetType().GetCode()) {\n case Type::FixedString:\n column_down_cast(dictionary).Append(item.get());\n return;\n case Type::String:\n column_down_cast(dictionary).Append(item.get());\n return;\n case Type::Nullable:\n AppendNullableToDictionary(column_down_cast(dictionary), item);\n return;\n default:\n throw ValidationError(\"Unexpected dictionary column type: \" + dictionary.GetType().GetName());\n }\n}\n\n}\n\nnamespace clickhouse {\nColumnLowCardinality::ColumnLowCardinality(ColumnRef dictionary_column)\n : Column(Type::CreateLowCardinality(dictionary_column->Type())),\n dictionary_column_(dictionary_column->CloneEmpty()), // safe way to get an column of the same type.\n index_column_(std::make_shared())\n{\n Setup(dictionary_column);\n}\n\nColumnLowCardinality::ColumnLowCardinality(std::shared_ptr dictionary_column)\n : Column(Type::CreateLowCardinality(dictionary_column->Type())),\n dictionary_column_(dictionary_column->CloneEmpty()), // safe way to get an column of the same type.\n index_column_(std::make_shared())\n{\n AppendNullItem();\n Setup(dictionary_column);\n}\n\nColumnLowCardinality::~ColumnLowCardinality()\n{}\n\nvoid ColumnLowCardinality::Reserve(size_t new_cap) {\n dictionary_column_->Reserve(new_cap);\n index_column_->Reserve(new_cap);\n}\n\nvoid ColumnLowCardinality::Setup(ColumnRef dictionary_column) {\n AppendDefaultItem();\n\n if (dictionary_column->Size() != 0) {\n // Add values, updating index_column_ and unique_items_map_.\n\n // TODO: it would be possible to eliminate copying\n // by adding InsertUnsafe(pos, ItemView) method to a Column\n // (to insert null-item at pos 0),\n // but that is too much work for now.\n for (size_t i = 0; i < dictionary_column->Size(); ++i) {\n AppendUnsafe(dictionary_column->GetItem(i));\n }\n }\n}\n\nstd::uint64_t ColumnLowCardinality::getDictionaryIndex(std::uint64_t item_index) const {\n return VisitIndexColumn([item_index](const auto & arg) -> std::uint64_t {\n return arg[item_index];\n }, *index_column_);\n}\n\nvoid ColumnLowCardinality::appendIndex(std::uint64_t item_index) {\n // TODO (nemkov): handle case when index should go from UInt8 to UInt16, etc.\n VisitIndexColumn([item_index](auto & arg) {\n arg.Append(static_cast::DataType>(item_index));\n }, *index_column_);\n}\n\nvoid ColumnLowCardinality::removeLastIndex() {\n VisitIndexColumn([](auto & arg) {\n arg.Erase(arg.Size() - 1);\n }, *index_column_);\n}\n\ndetails::LowCardinalityHashKey ColumnLowCardinality::computeHashKey(const ItemView & item) {\n static const auto hasher = std::hash{};\n if (item.type == Type::Void) {\n // to distinguish NULL of ColumnNullable and empty string.\n return {0u, 0u};\n }\n\n...\n// Path: clickhouse/columns/itemview.cpp\n#include \"../columns/itemview.h\"\n\n#include \n#include \n\nnamespace {\n\ntemplate \n\nstd::string ContainerToString(Container container, const char * separator = \", \") {\n std::stringstream sstr;\n const auto end = std::end(container);\n for (auto i = std::begin(container); i != end; /*intentionally no ++i*/) {\n const auto & elem = *i;\n sstr << elem;\n\n if (++i != end) {\n sstr << separator;\n }\n }\n\n return sstr.str();\n}\n\n}\n\nnamespace clickhouse {\n\nvoid ItemView::ValidateData(Type::Code type, DataType data) {\n\n auto AssertSize = [type, &data](std::initializer_list 
allowed_sizes) -> void {\n const auto end = std::end(allowed_sizes);\n if (std::find(std::begin(allowed_sizes), end, static_cast(data.size())) == end) {\n throw AssertionError(std::string(\"ItemView value size mismatch for \")\n + Type::TypeName(type)\n + \" expected: \" + ContainerToString(allowed_sizes, \" or \")\n + \", got: \" + std::to_string(data.size()));\n }\n };\n\n switch (type) {\n case Type::Code::Void:\n return AssertSize({0});\n\n case Type::Code::Int8:\n case Type::Code::UInt8:\n case Type::Code::Enum8:\n return AssertSize({1});\n\n case Type::Code::Int16:\n case Type::Code::UInt16:\n case Type::Code::Date:\n case Type::Code::Enum16:\n return AssertSize({2});\n\n case Type::Code::Int32:\n case Type::Code::UInt32:\n case Type::Code::Float32:\n case Type::Code::DateTime:\n case Type::Code::Date32:\n case Type::Code::IPv4:\n case Type::Code::Decimal32:\n return AssertSize({4});\n\n case Type::Code::Int64:\n case Type::Code::UInt64:\n case Type::Code::Float64:\n case Type::Code::DateTime64:\n case Type::Code::Decimal64:\n return AssertSize({8});\n\n case Type::Code::String:\n case Type::Code::FixedString:\n // value can be of any size\n return;\n\n case Type::Code::Array:\n case Type::Code::Nullable:\n case Type::Code::Tuple:\n case Type::Code::LowCardinality:\n case Type::Code::Map:\n throw AssertionError(\"Unsupported type in ItemView: \" + std::string(Type::TypeName(type)));\n\n case Type::Code::IPv6:\n case Type::Code::UUID:\n case Type::Code::Int128:\n case Type::Code::Decimal128:\n return AssertSize({16});\n\n case Type::Code::Decimal:\n // Could be either Decimal32, Decimal64 or Decimal128\n return AssertSize({4, 8, 16});\n\n default:\n throw UnimplementedError(\"Unknown type code:\" + std::to_string(static_cast(type)));\n }\n}\n\n}\n\n// Path: clickhouse/columns/numeric.cpp\n#include \"numeric.h\"\n#include \"utils.h\"\n\n#include \"../base/wire_format.h\"\n\nnamespace clickhouse {\n\ntemplate \nColumnVector::ColumnVector()\n : Column(Type::CreateSimple())\n{\n}\n\ntemplate \nColumnVector::ColumnVector(const std::vector & data)\n : Column(Type::CreateSimple())\n , data_(data)\n{\n}\n\ntemplate \nColumnVector::ColumnVector(std::vector && data)\n : Column(Type::CreateSimple())\n , data_(std::move(data))\n{\n}\n\ntemplate \nvoid ColumnVector::Append(const T& value) {\n data_.push_back(value);\n}\n\ntemplate \nvoid ColumnVector::Erase(size_t pos, size_t count) {\n const auto begin = std::min(pos, data_.size());\n const auto last = begin + std::min(data_.size() - begin, count);\n\n data_.erase(data_.begin() + begin, data_.begin() + last);\n}\n\ntemplate \nstd::vector& ColumnVector::GetWritableData() {\n return data_;\n}\n\ntemplate \nvoid ColumnVector::Reserve(size_t new_cap) {\n data_.reserve(new_cap);\n}\n\ntemplate \nsize_t ColumnVector::Capacity() const {\n return data_.capacity();\n}\n\ntemplate \nvoid ColumnVector::Clear() {\n data_.clear();\n}\n\ntemplate \nconst T& ColumnVector::At(size_t n) const {\n return data_.at(n);\n}\n\ntemplate \nvoid ColumnVector::Append(ColumnRef column) {\n if (auto col = column->As>()) {\n data_.insert(data_.end(), col->data_.begin(), col->data_.end());\n }\n}\n\ntemplate \nbool ColumnVector::LoadBody(InputStream* input, size_t rows) {\n data_.resize(rows);\n\n return WireFormat::ReadBytes(*input, data_.data(), data_.size() * sizeof(T));\n}\n\ntemplate \nvoid ColumnVector::SaveBody(OutputStream* output) {\n WireFormat::WriteBytes(*output, data_.data(), data_.size() * sizeof(T));\n}\n\ntemplate \nsize_t ColumnVector::Size() const {\n return 
data_.size();\n}\n\ntemplate \nColumnRef ColumnVector::Slice(size_t begin, size_t len) const {\n return std::make_shared>(SliceVector(data_, begin, len));\n}\n\ntemplate \nColumnRef ColumnVector::CloneEmpty() const {\n return std::make_shared>();\n}\n\ntemplate \nvoid ColumnVector::Swap(Column& other) {\n auto & col = dynamic_cast &>(other);\n data_.swap(col.data_);\n}\n\ntemplate \nItemView ColumnVector::GetItem(size_t index) const {\n return ItemView{type_->GetCode(), data_[index]};\n}\n\ntemplate class ColumnVector;\ntemplate class ColumnVector;\ntemplate class ColumnVector;\ntemplate class ColumnVector;\n\ntemplate class ColumnVector;\ntemplate class ColumnVector;\ntemplate class ColumnVector;\ntemplate class ColumnVector;\ntemplate class ColumnVector;\n\ntemplate class ColumnVector;\ntemplate class ColumnVector;\n\n}\n\n// Path: clickhouse/columns/decimal.cpp\n#include \"decimal.h\"\n\nnamespace\n{\nusing namespace clickhouse;\n\n#ifdef ABSL_HAVE_INTRINSIC_INT128\ntemplate \ninline bool addOverflow(const Int128 & l, const T & r, Int128 * result)\n{\n __int128 res;\n const auto ret_value = __builtin_add_overflow(static_cast<__int128>(l), static_cast<__int128>(r), &res);\n\n *result = res;\n return ret_value;\n}\n\ntemplate \ninline bool mulOverflow(const Int128 & l, const T & r, Int128 * result)\n{\n __int128 res;\n const auto ret_value = __builtin_mul_overflow(static_cast<__int128>(l), static_cast<__int128>(r), &res);\n\n *result = res;\n return ret_value;\n}\n\n#else\ntemplate \ninline bool getSignBit(const T & v)\n{\n return v < static_cast(0);\n}\n\ninline bool getSignBit(const Int128 & v)\n{\n// static constexpr Int128 zero {};\n// return v < zero;\n\n // Sign of the whole absl::int128 value is determined by sign of higher 64 bits.\n return absl::Int128High64(v) < 0;\n}\n\ninline bool addOverflow(const Int128 & l, const Int128 & r, Int128 * result)\n{\n // *result = l + r;\n // const auto result_sign = getSignBit(*result);\n // return l_sign == r_sign && l_sign != result_sign;\n\n // Based on code from:\n // https://wiki.sei.cmu.edu/confluence/display/c/INT32-C.+Ensure+that+operations+on+signed+integers+do+not+result+in+overflow#INT32C.Ensurethatoperationsonsignedintegersdonotresultinoverflow-CompliantSolution\n const auto r_positive = !getSignBit(r);\n\n if ((r_positive && (l > (std::numeric_limits::max() - r))) ||\n (!r_positive && (l < (std::numeric_limits::min() - r)))) {\n return true;\n }\n *result = l + r;\n\n return false;\n}\n\ntemplate \ninline bool mulOverflow(const Int128 & l, const T & r, Int128 * result)\n{\n // Based on code from:\n // https://wiki.sei.cmu.edu/confluence/display/c/INT32-C.+Ensure+that+operations+on+signed+integers+do+not+result+in+overflow#INT32C.Ensurethatoperationsonsignedintegersdonotresultinoverflow-CompliantSolution.3\n const auto l_positive = !getSignBit(l);\n const auto r_positive = !getSignBit(r);\n\n if (l_positive) {\n if (r_positive) {\n if (r != 0 && l > (std::numeric_limits::max() / r)) {\n return true;\n }\n } else {\n if (l != 0 && r < (std::numeric_limits::min() / l)) {\n return true;\n }\n }\n } else {\n if (r_positive) {\n if (r != 0 && l < (std::numeric_limits::min() / r)) {\n return true;\n }\n } else {\n if (l != 0 && (r < (std::numeric_limits::max() / l))) {\n return true;\n }\n }\n }\n\n *result = l * r;\n return false;\n}\n#endif\n\n}\n\nnamespace clickhouse {\n\nColumnDecimal::ColumnDecimal(size_t precision, size_t scale)\n : Column(Type::CreateDecimal(precision, scale))\n{\n if (precision <= 9) {\n data_ = 
std::make_shared();\n } else if (precision <= 18) {\n data_ = std::make_shared();\n } else {\n data_ = std::make_shared();\n }\n}\n\nColumnDecimal::ColumnDecimal(TypeRef type, ColumnRef data)\n : Column(type),\n data_(data)\n{\n}\n\nvoid ColumnDecimal::Append(const Int128& value) {\n if (data_->Type()->GetCode() == Type::Int32) {\n data_->As()->Append(static_cast(value));\n } else if (data_->Type()->GetCode() == Type::Int64) {\n data_->As()->Append(static_cast(value));\n } else {\n data_->As()->Append(static_cast(value));\n }\n}\n\nvoid ColumnDecimal::Append(const std::string& value) {\n Int128 int_value = 0;\n auto c = value.begin();\n auto end = value.end();\n bool sign = true;\n bool has_dot = false;\n\n size_t zeros = 0;\n\n while (c != end) {\n if (*c == '-') {\n sign = false;\n if (c != value.begin()) {\n break;\n }\n } else if (*c == '.' && !has_dot) {\n size_t distance = std::distance(c, end) - 1;\n auto scale = type_->As()->GetScale();\n\n if (distance <= scale) {\n zeros = scale - distance;\n } else {\n std::advance(end, scale - distance);\n }\n\n has_dot = true;\n } else if (*c >= '0' && *c <= '9') {\n if (mulOverflow(int_value, 10, &int_value) ||\n addOverflow(int_value, *c - '0', &int_value)) {\n throw AssertionError(\"value is too big for 128-bit integer\");\n }\n } else {\n throw ValidationError(std::string(\"unexpected symbol '\") + (*c) + \"' in decimal value\");\n }\n ++c;\n }\n\n if (c != end) {\n throw ValidationError(\"unexpected symbol '-' in decimal value\");\n }\n\n while (zeros) {\n if (mulOverflow(int_value, 10, &int_value)) {\n throw AssertionError(\"value is too big for 128-bit integer\");\n }\n --zeros;\n }\n\n Append(sign ? int_value : -int_value);\n}\n\nInt128 ColumnDecimal::At(size_t i) const {\n switch (data_->Type()->GetCode()) {\n case Type::Int32:\n return static_cast(data_->As()->At(i));\n case Type::Int64:\n return static_cast(data_->As()->At(i));\n case Type::Int128:\n return data_->As()->At(i);\n default:\n throw ValidationError(\"Invalid data_ column type in ColumnDecimal\");\n }\n}\n\nvoid ColumnDecimal::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\nvoid ColumnDecimal::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nbool ColumnDecimal::LoadBody(InputStream * input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\nvoid ColumnDecimal::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\nvoid ColumnDecimal::Clear() {\n data_->Clear();\n}\n\nsize_t ColumnDecimal::Size() const {\n return data_->Size();\n}\n\nColumnRef ColumnDecimal::Slice(size_t begin, size_t len) const {\n // coundn't use std::make_shared since this c-tor is private\n return ColumnRef{new ColumnDecimal(type_, data_->Slice(begin, len))};\n}\n\nColumnRef ColumnDecimal::CloneEmpty() const {\n // coundn't use std::make_shared since this c-tor is private\n return ColumnRef{new ColumnDecimal(type_, data_->CloneEmpty())};\n}\n\nvoid ColumnDecimal::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\nItemView ColumnDecimal::GetItem(size_t index) const {\n return ItemView{GetType().GetCode(), data_->GetItem(index)};\n}\n\nsize_t ColumnDecimal::GetScale() const\n{\n return type_->As()->GetScale();\n}\n\nsize_t ColumnDecimal::GetPrecision() const\n{\n return type_->As()->GetPrecision();\n}\n\n}\n\n// Path: clickhouse/columns/string.cpp\n#include \"string.h\"\n#include \"utils.h\"\n\n#include \"../base/wire_format.h\"\n\nnamespace {\n\nconstexpr size_t DEFAULT_BLOCK_SIZE = 
4096;\n\ntemplate \nsize_t ComputeTotalSize(const Container & strings, size_t begin = 0, size_t len = -1) {\n size_t result = 0;\n if (begin < strings.size()) {\n len = std::min(len, strings.size() - begin);\n\n for (size_t i = begin; i < begin + len; ++i)\n result += strings[i].size();\n }\n\n return result;\n}\n\n}\n\nnamespace clickhouse {\n\nColumnFixedString::ColumnFixedString(size_t n)\n : Column(Type::CreateString(n))\n , string_size_(n)\n{\n}\n\nvoid ColumnFixedString::Reserve(size_t new_cap) {\n data_.reserve(string_size_ * new_cap);\n}\n\nvoid ColumnFixedString::Append(std::string_view str) {\n if (str.size() > string_size_) {\n throw ValidationError(\"Expected string of length not greater than \"\n + std::to_string(string_size_) + \" bytes, received \"\n + std::to_string(str.size()) + \" bytes.\");\n }\n\n if (data_.capacity() - data_.size() < str.size()) {\n // round up to the next block size\n const auto new_size = (((data_.size() + string_size_) / DEFAULT_BLOCK_SIZE) + 1) * DEFAULT_BLOCK_SIZE;\n data_.reserve(new_size);\n }\n\n data_.insert(data_.size(), str);\n // Pad up to string_size_ with zeroes.\n if (str.size() < string_size_) {\n const auto padding_size = string_size_ - str.size();\n data_.resize(data_.size() + padding_size, char(0));\n }\n}\n\nvoid ColumnFixedString::Clear() {\n data_.clear();\n}\n\nstd::string_view ColumnFixedString::At(size_t n) const {\n const auto pos = n * string_size_;\n return std::string_view(&data_.at(pos), string_size_);\n}\n\nsize_t ColumnFixedString::FixedSize() const {\n return string_size_;\n}\n\nvoid ColumnFixedString::Append(ColumnRef column) {\n if (auto col = column->As()) {\n if (string_size_ == col->string_size_) {\n data_.insert(data_.end(), col->data_.begin(), col->data_.end());\n }\n }\n}\n\nbool ColumnFixedString::LoadBody(InputStream * input, size_t rows) {\n data_.resize(string_size_ * rows);\n if (!WireFormat::ReadBytes(*input, &data_[0], data_.size())) {\n return false;\n }\n\n return true;\n}\n\nvoid ColumnFixedString::SaveBody(OutputStream* output) {\n WireFormat::WriteBytes(*output, data_.data(), data_.size());\n}\n\nsize_t ColumnFixedString::Size() const {\n return data_.size() / string_size_;\n}\n\nColumnRef ColumnFixedString::Slice(size_t begin, size_t len) const {\n auto result = std::make_shared(string_size_);\n\n if (begin < Size()) {\n const auto b = begin * string_size_;\n const auto l = len * string_size_;\n result->data_ = data_.substr(b, std::min(data_.size() - b, l));\n }\n\n return result;\n}\n\nColumnRef ColumnFixedString::CloneEmpty() const {\n return std::make_shared(string_size_);\n}\n\nvoid ColumnFixedString::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n std::swap(string_size_, col.string_size_);\n data_.swap(col.data_);\n}\n\nItemView ColumnFixedString::GetItem(size_t index) const {\n return ItemView{Type::FixedString, this->At(index)};\n}\n\nstruct ColumnString::Block\n{\n using CharT = typename std::string::value_type;\n\n explicit Block(size_t starting_capacity)\n : size(0),\n capacity(starting_capacity),\n data_(new CharT[capacity])\n {}\n\n inline auto GetAvailable() const {\n return capacity - size;\n }\n\n std::string_view AppendUnsafe(std::string_view str) {\n const auto pos = &data_[size];\n\n memcpy(pos, str.data(), str.size());\n size += str.size();\n\n return std::string_view(pos, str.size());\n }\n\n auto GetCurrentWritePos() {\n return &data_[size];\n }\n\n std::string_view ConsumeTailAsStringViewUnsafe(size_t len) {\n const auto start = &data_[size];\n size += len;\n return 
std::string_view(start, len);\n }\n\n size_t size;\n const size_t capacity;\n std::unique_ptr data_;\n};\n\nColumnString::ColumnString()\n : Column(Type::CreateString())\n{\n}\n\nColumnString::ColumnString(size_t element_count)\n : Column(Type::CreateString())\n{\n items_.reserve(element_count);\n // 16 is arbitrary number, assumption that string values are about ~256 bytes long.\n blocks_.reserve(std::max(1, element_count / 16));\n}\n\nColumnString::ColumnString(const std::vector& data)\n : ColumnString()\n{\n items_.reserve(data.size());\n blocks_.emplace_back(ComputeTotalSize(data));\n\n for (const auto & s : data) {\n AppendUnsafe(s);\n }\n}\n\nColumnString::ColumnString(std::vector&& data)\n : ColumnString()\n{\n items_.reserve(data.size());\n\n for (auto&& d : data) {\n append_data_.emplace_back(std::move(d));\n auto& last_data = append_data_.back();\n items_.emplace_back(std::string_view{ last_data.data(),last_data.length() });\n }\n}\n\nColumnString::~ColumnString()\n{}\n\nvoid ColumnString::Reserve(size_t new_cap) {\n items_.reserve(new_cap);\n // 16 is arbitrary number, assumption that string values are about ~256 bytes long.\n blocks_.reserve(std::max(1, new_cap / 16));\n}\n\nvoid ColumnString::Append(std::string_view str) {\n if (blocks_.size() == 0 || blocks_.back().GetAvailable() < str.length()) {\n blocks_.emplace_back(std::max(DEFAULT_BLOCK_SIZE, str.size()));\n }\n\n items_.emplace_back(blocks_.back().AppendUnsafe(str));\n}\n\nvoid ColumnString::Append(const char* str) {\n Append(std::string_view(str, strlen(str)));\n}\n\nvoid ColumnString::Append(std::string&& steal_value) {\n append_data_.emplace_back(std::move(steal_value));\n auto& last_data = append_data_.back();\n items_.emplace_back(std::string_view{ last_data.data(),last_data.length() });\n}\n\nvoid ColumnString::AppendNoManagedLifetime(std::string_view str) {\n items_.emplace_back(str);\n}\n\nvoid ColumnString::AppendUnsafe(std::string_view str) {\n items_.emplace_back(blocks_.back().AppendUnsafe(str));\n}\n\nvoid ColumnString::Clear() {\n items_.clear();\n blocks_.clear();\n append_data_.clear();\n}\n\nstd::string_view ColumnString::At(size_t n) const {\n return items_.at(n);\n}\n\nvoid ColumnString::Append(ColumnRef column) {\n if (auto col = column->As()) {\n const auto total_size = ComputeTotalSize(col->items_);\n\n // TODO: fill up existing block with some items and then add a new one for the rest of items\n if (blocks_.size() == 0 || blocks_.back().GetAvailable() < total_size)\n blocks_.emplace_back(std::max(DEFAULT_BLOCK_SIZE, total_size));\n\n // Intentionally not doing items_.reserve() since that cripples performance.\n for (size_t i = 0; i < column->Size(); ++i) {\n this->AppendUnsafe((*col)[i]);\n }\n }\n}\n\nbool ColumnString::LoadBody(InputStream* input, size_t rows) {\n if (rows == 0) {\n items_.clear();\n blocks_.clear();\n\n return true;\n }\n\n decltype(items_) new_items;\n decltype(blocks_) new_blocks;\n\n new_items.reserve(rows);\n\n // Suboptimzal if the first row string is >DEFAULT_BLOCK_SIZE, but that must be a very rare case.\n Block * block = &new_blocks.emplace_back(DEFAULT_BLOCK_SIZE);\n\n for (size_t i = 0; i < rows; ++i) {\n uint64_t len;\n if (!WireFormat::ReadUInt64(*input, &len))\n return false;\n\n if (len > block->GetAvailable())\n block = &new_blocks.emplace_back(std::max(DEFAULT_BLOCK_SIZE, len));\n\n if (!WireFormat::ReadBytes(*input, block->GetCurrentWritePos(), len))\n return false;\n\n new_items.emplace_back(block->ConsumeTailAsStringViewUnsafe(len));\n }\n\n 
items_.swap(new_items);\n blocks_.swap(new_blocks);\n\n return true;\n}\n\nvoid ColumnString::SaveBody(OutputStream* output) {\n for (const auto & item : items_) {\n WireFormat::WriteString(*output, item);\n }\n}\n\nsize_t ColumnString::Size() const {\n return items_.size();\n}\n\nColumnRef ColumnString::Slice(size_t begin, size_t len) const {\n auto result = std::make_shared();\n\n if (begin < items_.size()) {\n len = std::min(len, items_.size() - begin);\n result->items_.reserve(len);\n\n result->blocks_.emplace_back(ComputeTotalSize(items_, begin, len));\n for (size_t i = begin; i < begin + len; ++i) {\n result->Append(items_[i]);\n }\n }\n\n return result;\n}\n\nColumnRef ColumnString::CloneEmpty() const {\n return std::make_shared();\n}\n\nvoid ColumnString::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n items_.swap(col.items_);\n blocks_.swap(col.blocks_);\n append_data_.swap(col.append_data_);\n}\n\nItemView ColumnString::GetItem(size_t index) const {\n return ItemView{Type::String, this->At(index)};\n}\n\n}\n\n// Path: clickhouse/columns/ip4.cpp\n#include \"ip4.h\"\n\n#include \"../base/socket.h\" // for platform-specific IPv4-related functions\n#include \n\nnamespace clickhouse {\n\nColumnIPv4::ColumnIPv4()\n : Column(Type::CreateIPv4())\n , data_(std::make_shared())\n{\n}\n\nColumnIPv4::ColumnIPv4(ColumnRef data)\n : Column(Type::CreateIPv4())\n , data_(data ? data->As() : nullptr)\n{\n if (!data_)\n throw ValidationError(\"Expecting ColumnUInt32, got \" + (data ? data->GetType().GetName() : \"null\"));\n}\n\nColumnIPv4::ColumnIPv4(std::vector&& data)\n : Column(Type::CreateIPv4())\n{\n for (auto& addr : data) {\n addr = htonl(addr);\n }\n data_ = std::make_shared(std::move(data));\n}\n\nvoid ColumnIPv4::Append(const std::string& str) {\n uint32_t address;\n if (inet_pton(AF_INET, str.c_str(), &address) != 1)\n throw ValidationError(\"invalid IPv4 format, ip: \" + str);\n data_->Append(htonl(address));\n}\n\nvoid ColumnIPv4::Append(uint32_t ip) {\n data_->Append(htonl(ip));\n}\n\nvoid ColumnIPv4::Append(in_addr ip) {\n data_->Append(htonl(ip.s_addr));\n}\n\nvoid ColumnIPv4::Clear() {\n data_->Clear();\n}\n\nin_addr ColumnIPv4::At(size_t n) const {\n in_addr addr;\n addr.s_addr = ntohl(data_->At(n));\n return addr;\n}\n\nin_addr ColumnIPv4::operator [] (size_t n) const {\n in_addr addr;\n addr.s_addr = ntohl(data_->operator[](n));\n return addr;\n}\n\nstd::string ColumnIPv4::AsString(size_t n) const {\n const auto& addr = this->At(n);\n\n char buf[INET_ADDRSTRLEN];\n const char* ip_str = inet_ntop(AF_INET, &addr, buf, INET_ADDRSTRLEN);\n\n if (ip_str == nullptr) {\n throw std::system_error(\n std::error_code(errno, std::generic_category()),\n \"Invalid IPv4 data\");\n }\n\n return ip_str;\n}\n\nvoid ColumnIPv4::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\nvoid ColumnIPv4::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nbool ColumnIPv4::LoadBody(InputStream * input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\nvoid ColumnIPv4::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\nsize_t ColumnIPv4::Size() const {\n return data_->Size();\n}\n\nColumnRef ColumnIPv4::Slice(size_t begin, size_t len) const {\n return std::make_shared(data_->Slice(begin, len));\n}\n\nColumnRef ColumnIPv4::CloneEmpty() const {\n return std::make_shared(data_->CloneEmpty());\n}\n\nvoid ColumnIPv4::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\nItemView 
ColumnIPv4::GetItem(size_t index) const {\n return ItemView(Type::IPv4, data_->GetItem(index));\n}\n\n}\n\n// Path: clickhouse/columns/tuple.cpp\n#include \"tuple.h\"\n\nnamespace clickhouse {\n\nstatic std::vector CollectTypes(const std::vector& columns) {\n std::vector types;\n for (const auto& col : columns) {\n types.push_back(col->Type());\n }\n return types;\n}\n\nColumnTuple::ColumnTuple(const std::vector& columns)\n : Column(Type::CreateTuple(CollectTypes(columns)))\n , columns_(columns)\n{\n}\n\nsize_t ColumnTuple::TupleSize() const {\n return columns_.size();\n}\n\nvoid ColumnTuple::Reserve(size_t new_cap) {\n for (auto& column : columns_) {\n column->Reserve(new_cap);\n } \n}\n\nvoid ColumnTuple::Append(ColumnRef column) {\n if (!this->Type()->IsEqual(column->Type())) {\n throw ValidationError(\n \"can't append column of type \" + column->Type()->GetName() + \" \"\n \"to column type \" + this->Type()->GetName());\n }\n const auto & source_tuple_column = column->As();\n for (size_t ci = 0; ci < columns_.size(); ++ci) {\n columns_[ci]->Append((*source_tuple_column)[ci]);\n }\n}\nsize_t ColumnTuple::Size() const {\n return columns_.empty() ? 0 : columns_[0]->Size();\n}\n\nColumnRef ColumnTuple::Slice(size_t begin, size_t len) const {\n std::vector sliced_columns;\n sliced_columns.reserve(columns_.size());\n for(const auto &column : columns_) {\n sliced_columns.push_back(column->Slice(begin, len));\n }\n\n return std::make_shared(sliced_columns);\n}\n\nColumnRef ColumnTuple::CloneEmpty() const {\n std::vector result_columns;\n result_columns.reserve(columns_.size());\n\n for(const auto &column : columns_) {\n result_columns.push_back(column->CloneEmpty());\n }\n\n return std::make_shared(result_columns);\n}\n\nbool ColumnTuple::LoadPrefix(InputStream* input, size_t rows) {\n for (auto ci = columns_.begin(); ci != columns_.end(); ++ci) {\n if (!(*ci)->LoadPrefix(input, rows)) {\n return false;\n }\n }\n\n return true;\n}\n\nbool ColumnTuple::LoadBody(InputStream* input, size_t rows) {\n for (auto ci = columns_.begin(); ci != columns_.end(); ++ci) {\n if (!(*ci)->LoadBody(input, rows)) {\n return false;\n }\n }\n\n return true;\n}\n\nvoid ColumnTuple::SavePrefix(OutputStream* output) {\n for (auto & column : columns_) {\n column->SavePrefix(output);\n }\n}\n\nvoid ColumnTuple::SaveBody(OutputStream* output) {\n for (auto & column : columns_) {\n column->SaveBody(output);\n }\n}\n\nvoid ColumnTuple::Clear() {\n columns_.clear();\n}\n\nvoid ColumnTuple::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n columns_.swap(col.columns_);\n}\n\n}\n\n// Path: clickhouse/columns/column.cpp\n#include \"column.h\"\n\nnamespace clickhouse {\n\nbool Column::LoadPrefix(InputStream*, size_t) {\n /// does nothing by default\n return true;\n}\n\nbool Column::Load(InputStream* input, size_t rows) {\n return LoadPrefix(input, rows) && LoadBody(input, rows);\n}\n\nvoid Column::SavePrefix(OutputStream*) {\n /// does nothing by default\n}\n\n/// Saves column data to output stream.\nvoid Column::Save(OutputStream* output) {\n SavePrefix(output);\n SaveBody(output);\n}\n\n}\n\n// Path: clickhouse/base/sslsocket.cpp\n#include \"sslsocket.h\"\n#include \"../client.h\"\n#include \"../exceptions.h\"\n\n#include \n\n#include \n#include \n#include \n#include \n\n\nnamespace {\n\nstd::string getCertificateInfo(X509* cert)\n{\n if (!cert)\n return \"No certificate\";\n\n std::unique_ptr mem_bio(BIO_new(BIO_s_mem()), &BIO_free);\n X509_print(mem_bio.get(), cert);\n\n char * data = nullptr;\n auto len = 
BIO_get_mem_data(mem_bio.get(), &data);\n if (len < 0)\n return \"Can't get certificate info due to BIO error \" + std::to_string(len);\n\n return std::string(data, len);\n}\n\nvoid throwSSLError(SSL * ssl, int error, const char * /*location*/, const char * /*statement*/, const std::string prefix = \"OpenSSL error: \") {\n const auto detail_error = ERR_get_error();\n auto reason = ERR_reason_error_string(detail_error);\n reason = reason ? reason : \"Unknown SSL error\";\n\n std::string reason_str = reason;\n if (ssl) {\n // Print certificate only if handshake isn't completed \n if (auto ssl_session = SSL_get_session(ssl); ssl_session && SSL_get_state(ssl) != TLS_ST_OK)\n reason_str += \"\\nServer certificate: \" + getCertificateInfo(SSL_SESSION_get0_peer(ssl_session));\n }\n\n// std::cerr << \"!!! SSL error at \" << location\n// << \"\\n\\tcaused by \" << statement\n// << \"\\n\\t: \"<< reason_str << \"(\" << error << \")\"\n// << \"\\n\\t last err: \" << ERR_peek_last_error()\n// << std::endl;\n\n throw clickhouse::OpenSSLError(prefix + std::to_string(error) + \" : \" + reason_str);\n}\n\nvoid configureSSL(const clickhouse::SSLParams::ConfigurationType & configuration, SSL * ssl, SSL_CTX * context = nullptr) {\n std::unique_ptr conf_ctx_holder(SSL_CONF_CTX_new(), SSL_CONF_CTX_free);\n auto conf_ctx = conf_ctx_holder.get();\n\n // To make both cmdline and flag file commands start with no prefix.\n SSL_CONF_CTX_set1_prefix(conf_ctx, \"\");\n // Allow all set of client commands, also turn on proper error reporting to reuse throwSSLError().\n SSL_CONF_CTX_set_flags(conf_ctx, SSL_CONF_FLAG_CMDLINE | SSL_CONF_FLAG_FILE | SSL_CONF_FLAG_CLIENT | SSL_CONF_FLAG_SHOW_ERRORS | SSL_CONF_FLAG_CERTIFICATE );\n if (ssl)\n SSL_CONF_CTX_set_ssl(conf_ctx, ssl);\n else if (context)\n SSL_CONF_CTX_set_ssl_ctx(conf_ctx, context);\n\n for (const auto & kv : configuration) {\n const int err = SSL_CONF_cmd(conf_ctx, kv.first.c_str(), (kv.second ? 
kv.second->c_str() : nullptr));\n // From the documentation:\n // 2 - both key and value used\n // 1 - only key used\n // 0 - error during processing\n // -2 - key not recodnized\n // -3 - missing value\n const bool value_present = !!kv.second;\n if (err == 2 || (err == 1 && !value_present))\n continue;\n else if (err == 0)\n throwSSLError(ssl, SSL_ERROR_NONE, nullptr, nullptr, \"Failed to configure OpenSSL with command '\" + kv.first + \"' \");\n else if (err == 1 && value_present)\n throw clickhouse::OpenSSLError(\"Failed to configure OpenSSL: command '\" + kv.first + \"' needs no value\");\n else if (err == -2)\n throw clickhouse::OpenSSLError(\"Failed to configure OpenSSL: unknown command '\" + kv.first + \"'\");\n else if (err == -3)\n throw clickhouse::OpenSSLError(\"Failed to configure OpenSSL: command '\" + kv.first + \"' requires a value\");\n else\n throw clickhouse::OpenSSLError(\"Failed to configure OpenSSL: command '\" + kv.first + \"' unknown error: \" + std::to_string(err));\n }\n}\n\n#define STRINGIFY_HELPER(x) #x\n#define STRINGIFY(x) STRINGIFY_HELPER(x)\n#define LOCATION __FILE__ \":\" STRINGIFY(__LINE__)\n\nstruct SSLInitializer {\n SSLInitializer() {\n SSL_library_init();\n SSLeay_add_ssl_algorithms();\n SSL_load_error_strings();\n }\n};\n\nSSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {\n static const SSLInitializer ssl_initializer;\n\n const SSL_METHOD *method = TLS_client_method();\n std::unique_ptr ctx(SSL_CTX_new(method), &SSL_CTX_free);\n\n if (!ctx)\n throw clickhouse::OpenSSLError(\"Failed to initialize SSL context\");\n\n#define HANDLE_SSL_CTX_ERROR(statement) do { \\\n if (const auto ret_code = (statement); !ret_code) \\\n throwSSLError(nullptr, static_cast(ERR_peek_error()), LOCATION, #statement); \\\n} while(false);\n\n if (context_params.use_default_ca_locations)\n HANDLE_SSL_CTX_ERROR(SSL_CTX_set_default_verify_paths(ctx.get()));\n if (!context_params.path_to_ca_directory.empty())\n HANDLE_SSL_CTX_ERROR(\n SSL_CTX_load_verify_locations(\n ctx.get(),\n nullptr,\n context_params.path_to_ca_directory.c_str())\n );\n\n for (const auto & f : context_params.path_to_ca_files)\n HANDLE_SSL_CTX_ERROR(SSL_CTX_load_verify_locations(ctx.get(), f.c_str(), nullptr));\n\n if (context_params.context_options != -1)\n SSL_CTX_set_options(ctx.get(), context_params.context_options);\n if (context_params.min_protocol_version != -1)\n HANDLE_SSL_CTX_ERROR(\n SSL_CTX_set_min_proto_version(ctx.get(), context_params.min_protocol_version));\n if (context_params.max_protocol_version != -1)\n HANDLE_SSL_CTX_ERROR(\n SSL_CTX_set_max_proto_version(ctx.get(), context_params.max_protocol_version));\n\n return ctx.release();\n#undef HANDLE_SSL_CTX_ERROR\n}\n\nauto convertConfiguration(const decltype(clickhouse::ClientOptions::SSLOptions::configuration) & configuration)\n{\n auto result = decltype(clickhouse::SSLParams::configuration){};\n for (const auto & cv : configuration)\n result.push_back({cv.command, cv.value});\n\n return result;\n}\n\nclickhouse::SSLParams GetSSLParams(const clickhouse::ClientOptions& opts) {\n const auto& ssl_options = *opts.ssl_options;\n return clickhouse::SSLParams{\n ssl_options.path_to_ca_files,\n ssl_options.path_to_ca_directory,\n ssl_options.use_default_ca_locations,\n ssl_options.context_options,\n ssl_options.min_protocol_version,\n ssl_options.max_protocol_version,\n ssl_options.use_sni,\n ssl_options.skip_verification,\n ssl_options.host_flags,\n convertConfiguration(ssl_options.configuration)\n };\n}\n\n}\n\nnamespace 
clickhouse {\n\nSSLContext::SSLContext(SSL_CTX & context)\n : context_(&context, &SSL_CTX_free)\n{\n SSL_CTX_up_ref(context_.get());\n}\n\nSSLContext::SSLContext(const SSLParams & context_params)\n : context_(prepareSSLContext(context_params), &SSL_CTX_free)\n{\n}\n\nSSL_CTX * SSLContext::getContext() {\n return context_.get();\n}\n\n// Allows caller to use returned value of `statement` if there was no error, throws exception otherwise.\n#define HANDLE_SSL_ERROR(SSL_PTR, statement) [&] { \\\n if (const auto ret_code = (statement); ret_code <= 0) { \\\n throwSSLError(SSL_PTR, SSL_get_error(SSL_PTR, static_cast(ret_code)), LOCATION, #statement); \\\n return static_cast>(0); \\\n } \\\n else \\\n return ret_code; \\\n} ()\n\n/* // debug macro for tracing SSL state\n#define LOG_SSL_STATE() std::cerr << \"!!!!\" << LOCATION << \" @\" << __FUNCTION__ \\\n << \"\\t\" << SSL_get_version(ssl_) << \" state: \" << SSL_state_string_long(ssl_) \\\n << \"\\n\\t handshake state: \" << SSL_get_state(ssl_) \\\n << std::endl\n*/\nSSLSocket::SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params,\n const SSLParams & ssl_params, SSLContext& context)\n : Socket(addr, timeout_params)\n , ssl_(SSL_new(context.getContext()), &SSL_free)\n{\n auto ssl = ssl_.get();\n if (!ssl)\n throw clickhouse::OpenSSLError(\"Failed to create SSL instance\");\n\n std::unique_ptr ip_addr(a2i_IPADDRESS(addr.Host().c_str()), &ASN1_OCTET_STRING_free);\n\n HANDLE_SSL_ERROR(ssl, SSL_set_fd(ssl, static_cast(handle_)));\n if (ssl_params.use_SNI)\n HANDLE_SSL_ERROR(ssl, SSL_set_tlsext_host_name(ssl, addr.Host().c_str()));\n\n if (ssl_params.host_flags != -1)\n SSL_set_hostflags(ssl, ssl_params.host_flags);\n HANDLE_SSL_ERROR(ssl, SSL_set1_host(ssl, addr.Host().c_str()));\n\n // DO NOT use SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr), since\n // we check verification result later, and that provides better error message.\n\n if (ssl_params.configuration.size() > 0)\n configureSSL(ssl_params.configuration, ssl);\n\n SSL_set_connect_state(ssl);\n HANDLE_SSL_ERROR(ssl, SSL_connect(ssl));\n HANDLE_SSL_ERROR(ssl, SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY));\n\n if (const auto verify_result = SSL_get_verify_result(ssl); !ssl_params.skip_verification && verify_result != X509_V_OK) {\n auto error_message = X509_verify_cert_error_string(verify_result);\n throw clickhouse::OpenSSLError(\"Failed to verify SSL connection, X509_v error: \"\n + std::to_string(verify_result)\n + \" \" + error_message\n + \"\\nServer certificate: \" + getCertificateInfo(SSL_get_peer_certificate(ssl)));\n }\n\n // Host name verification is done by OpenSSL itself, however if we are connecting to an ip-address,\n // no verification is made, so we have to do it manually.\n // Just in case if this is ever required, leave it here commented out.\n// if (ip_addr) {\n// // if hostname is actually an IP address\n// HANDLE_SSL_ERROR(ssl, X509_check_ip(\n// SSL_get_peer_certificate(ssl),\n// ASN1_STRING_get0_data(ip_addr.get()),\n// ASN1_STRING_length(ip_addr.get()),\n// 0));\n// }\n}\n\nvoid SSLSocket::validateParams(const SSLParams & ssl_params) {\n // We need either SSL or SSL_CTX to properly validate configuration, so create a temporary one.\n std::unique_ptr ctx(SSL_CTX_new(TLS_client_method()), &SSL_CTX_free);\n configureSSL(ssl_params.configuration, nullptr, ctx.get());\n}\n\n\nSSLSocketFactory::SSLSocketFactory(const ClientOptions& opts)\n : NonSecureSocketFactory()\n , ssl_params_(GetSSLParams(opts)) {\n if (opts.ssl_options->ssl_context) {\n 
ssl_context_ = std::make_unique(*opts.ssl_options->ssl_context);\n } else {\n ssl_context_ = std::make_unique(ssl_params_);\n }\n}\n\nSSLSocketFactory::~SSLSocketFactory() = default;\n\nstd::unique_ptr SSLSocketFactory::doConnect(const NetworkAddress& address, const ClientOptions& opts) {\n SocketTimeoutParams timeout_params { opts.connection_connect_timeout, opts.connection_recv_timeout, opts.connection_send_timeout };\n return std::make_unique(address, timeout_params, ssl_params_, *ssl_context_);\n}\n\nstd::unique_ptr SSLSocket::makeInputStream() const {\n return std::make_unique(ssl_.get());\n}\n\nstd::unique_ptr SSLSocket::makeOutputStream() const {\n return std::make_unique(ssl_.get());\n}\n\nSSLSocketInput::SSLSocketInput(SSL *ssl)\n : ssl_(ssl)\n{}\n\nsize_t SSLSocketInput::DoRead(void* buf, size_t len) {\n size_t actually_read;\n HANDLE_SSL_ERROR(ssl_, SSL_read_ex(ssl_, buf, len, &actually_read));\n return actually_read;\n}\n\nSSLSocketOutput::SSLSocketOutput(SSL *ssl)\n : ssl_(ssl)\n{}\n\nsize_t SSLSocketOutput::DoWrite(const void* data, size_t len) {\n if (len > std::numeric_limits::max())\n // FIXME(vnemkov): We should do multiple `SSL_write`s in this case.\n throw AssertionError(\"Failed to write too big chunk at once \"\n + std::to_string(len) + \" > \" + std::to_string(std::numeric_limits::max()));\n return static_cast(HANDLE_SSL_ERROR(ssl_, SSL_write(ssl_, data, static_cast(len))));\n}\n\n#undef HANDLE_SSL_ERROR\n\n}\n\n// Path: clickhouse/base/singleton.h\n#pragma once\n\nnamespace clickhouse {\n\ntemplate \nT* Singleton() {\n static T instance;\n return &instance;\n}\n\n}\n\n// Path: clickhouse/base/socket.cpp\n#include \"socket.h\"\n#include \"singleton.h\"\n#include \"../client.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#if !defined(_win_)\n# include \n# include \n# include \n# include \n# include \n# include \n#endif\n\nnamespace clickhouse {\n\n#if defined(_win_)\nchar const* windowsErrorCategory::name() const noexcept {\n return \"WindowsSocketError\";\n}\n\nstd::string windowsErrorCategory::message(int c) const {\n char error[UINT8_MAX];\n auto len = FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM, nullptr, static_cast(c), 0, error, sizeof(error), nullptr);\n if (len == 0) {\n return \"unknown\";\n }\n while (len && (error[len - 1] == '\\r' || error[len - 1] == '\\n')) {\n...\n", "needle_token_start": 2356, "needle_token_end": 2473, "code_context_ntokens": 16054, "output": ["```cpp\ntemplate \nstd::string ContainerToString(Container container, const char * separator = \", \") {\n std::stringstream sstr;\n const auto end = std::end(container);\n for (auto i = std::begin(container); i != end; /*intentionally no ++i*/) {\n const auto & elem = *i;\n sstr << elem;\n\n if (++i != end) {\n sstr << separator;\n }\n }\n\n return sstr.str();\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "MapValueView", "language": "cpp", "path": "clickhouse/columns/map.h", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: Provides a read-only view of a single row in a map-type column, allowing access to elements in a map structure with linear complexity.\n2. **Input**: Receives a view of an array representing a single map entry, which includes both keys and values.\n3. **Output**: Enables access to the elements of the map as pairs of keys and values.\n4. 
**Procedure**: Initializes with an array view of a map entry and facilitates the retrieval of elements as key-value pairs, though each access operation has a linear time complexity due to the underlying data structure.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/columns/nullable.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n\n#include \n\nnamespace clickhouse {\n\n/**\n * Represents column of Nullable(T).\n */\nclass ColumnNullable : public Column {\npublic:\n ColumnNullable(ColumnRef nested, ColumnRef nulls);\n\n /// Appends one null flag to the end of the column\n void Append(bool isnull);\n\n /// Returns null flag at given row number.\n bool IsNull(size_t n) const;\n\n /// Returns nested column.\n ColumnRef Nested() const;\n\n /// Returns nulls column.\n ColumnRef Nulls() const;\n\n...\n// Path: clickhouse/columns/lowcardinality.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n#include \"nullable.h\"\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnLowCardinalityT;\n\nnamespace details {\n\n/** LowCardinalityHashKey used as key in unique items hashmap to abstract away key value\n * (type of which depends on dictionary column) and to reduce likelehood of collisions.\n *\n * In order to dramatically reduce collision rate, we use 2 different hashes from 2 different hash functions.\n * First hash is used in hashtable (to calculate item position).\n * Second one is used as part of key value and accessed via `operator==()` upon collision resolution/detection.\n */\nusing LowCardinalityHashKey = std::pair;\n\nstruct LowCardinalityHashKeyHash {\n inline std::size_t operator()(const LowCardinalityHashKey &hash_key) const noexcept {\n return hash_key.first;\n }\n};\n\n}\n\n/*\n * LC column contains an \"invisible\" default item at the beginning of the collection. [default, ...]\n * If the nested type is Nullable, it contains a null-item at the beginning and a default item at the second position. [null, default, ...]\n * Null map is not serialized in LC columns. 
Instead, nulls are tracked by having an index of 0.\n * */\nclass ColumnLowCardinality : public Column {\npublic:\n using UniqueItems = std::unordered_map;\n\n template \n friend class ColumnLowCardinalityT;\n\nprivate:\n // IMPLEMENTATION NOTE: ColumnLowCardinalityT takes reference to underlying dictionary column object,\n // so make sure to NOT change address of the dictionary object (with reset(), swap()) or with anything else.\n ColumnRef dictionary_column_;\n ColumnRef index_column_;\n UniqueItems unique_items_map_;\n\npublic:\n ColumnLowCardinality(ColumnLowCardinality&& col) = default;\n // c-tor makes a deep copy of the dictionary_column.\n explicit ColumnLowCardinality(ColumnRef dictionary_column);\n explicit ColumnLowCardinality(std::shared_ptr dictionary_column);\n\n template \n explicit ColumnLowCardinality(std::shared_ptr> dictionary_column)\n : ColumnLowCardinality(dictionary_column->template As())\n {}\n\n ~ColumnLowCardinality();\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends another LowCardinality column to the end of this one, updating dictionary.\n void Append(ColumnRef /*column*/) override;\n\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data.\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of current column, with compacted dictionary\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n ItemView GetItem(size_t index) const override;\n\n size_t GetDictionarySize() const;\n TypeRef GetNestedType() const;\n\nprotected:\n std::uint64_t getDictionaryIndex(std::uint64_t item_index) const;\n void appendIndex(std::uint64_t item_index);\n void removeLastIndex();\n ColumnRef GetDictionary();\n\n void AppendUnsafe(const ItemView &);\n\nprivate:\n void Setup(ColumnRef dictionary_column);\n void AppendNullItem();\n void AppendDefaultItem();\n\npublic:\n static details::LowCardinalityHashKey computeHashKey(const ItemView &);\n};\n\n/** Type-aware wrapper that provides simple convenience interface for accessing/appending individual items.\n */\ntemplate \nclass ColumnLowCardinalityT : public ColumnLowCardinality {\n\n DictionaryColumnType& typed_dictionary_;\n const Type::Code type_;\n\npublic:\n using WrappedColumnType = DictionaryColumnType;\n // Type this column takes as argument of Append and returns with At() and operator[]\n using ValueType = typename DictionaryColumnType::ValueType;\n\n explicit ColumnLowCardinalityT(ColumnLowCardinality&& col)\n : ColumnLowCardinality(std::move(col))\n , typed_dictionary_(dynamic_cast(*GetDictionary()))\n , type_(GetTypeCode(typed_dictionary_))\n {\n }\n\n template \n explicit ColumnLowCardinalityT(Args &&... 
args)\n : ColumnLowCardinalityT(std::make_shared(std::forward(args)...))\n {}\n\n // Create LC column from existing T-column, making a deep copy of all contents.\n explicit ColumnLowCardinalityT(std::shared_ptr dictionary_col)\n : ColumnLowCardinality(dictionary_col)\n , typed_dictionary_(dynamic_cast(*GetDictionary()))\n , type_(GetTypeCode(typed_dictionary_))\n {}\n\n /// Extended interface to simplify reading/adding individual items.\n\n /// Returns element at given row number.\n inline ValueType At(size_t n) const {\n return typed_dictionary_.At(getDictionaryIndex(n));\n }\n\n /// Returns element at given row number.\n inline ValueType operator [] (size_t n) const {\n return typed_dictionary_[getDictionaryIndex(n)];\n }\n\n // so the non-virtual Append below doesn't shadow Append() from base class when compiled with older compilers.\n using ColumnLowCardinality::Append;\n\n inline void Append(const ValueType & value) {\n if constexpr (IsNullable) {\n if (value.has_value()) {\n AppendUnsafe(ItemView{type_, *value});\n } else {\n AppendUnsafe(ItemView{});\n }\n } else {\n AppendUnsafe(ItemView{type_, value});\n }\n }\n\n template \n inline void AppendMany(const T& container) {\n for (const auto & item : container) {\n Append(item);\n }\n }\n\n /** Create a ColumnLowCardinalityT from a ColumnLowCardinality, without copying data and offsets, but by\n * 'stealing' those from `col`.\n *\n * Ownership of column internals is transferred to returned object, original (argument) object\n * MUST NOT BE USED IN ANY WAY, it is only safe to dispose it.\n *\n * Throws an exception if `col` is of wrong type, it is safe to use original col in this case.\n * This is a static method to make such conversion verbose.\n */\n static auto Wrap(ColumnLowCardinality&& col) {\n return std::make_shared>(std::move(col));\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) { return Wrap(std::move(*col->AsStrict())); }\n\n ColumnRef Slice(size_t begin, size_t size) const override {\n return Wrap(ColumnLowCardinality::Slice(begin, size));\n }\n\n ColumnRef CloneEmpty() const override { return Wrap(ColumnLowCardinality::CloneEmpty()); }\n\nprivate:\n\n template \n static auto GetTypeCode(T& column) {\n if constexpr (IsNullable) {\n return GetTypeCode(*column.Nested()->template AsStrict());\n } else {\n return column.Type()->GetCode();\n }\n }\n};\n\n}\n\n// Path: clickhouse/base/projected_iterator.h\n#pragma once\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate ()(std::declval())),\n typename Value = std::decay_t>\nclass ProjectedIterator {\npublic:\n using value_type = Value;\n using reference = Reference;\n using pointer = Reference;\n using difference_type = typename std::iterator_traits::difference_type;\n using iterator_category = typename std::iterator_traits::iterator_category;\n\n ProjectedIterator() = default;\n\n inline ProjectedIterator(Iterator const& iterator, UnaryFunction functor)\n : iterator_(iterator)\n , functor_(std::move(functor)) {\n }\n\n inline UnaryFunction functor() const { return functor; }\n\n inline Iterator const& base() const { return iterator_; }\n\n inline reference operator*() const { return functor_(iterator_); }\n\n inline ProjectedIterator& operator++() {\n ++iterator_;\n return *this;\n }\n\n inline ProjectedIterator& operator--() {\n --iterator_;\n return *this;\n }\n\n inline bool operator==(const ProjectedIterator& other) const {\n 
return this->iterator_ == other.iterator_;\n }\n\n inline bool operator!=(const ProjectedIterator& other) const {\n return !(*this == other);\n }\n\nprivate:\n Iterator iterator_;\n UnaryFunction functor_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/map.h\n#pragma once\n\n#include \"../base/projected_iterator.h\"\n#include \"array.h\"\n#include \"column.h\"\n#include \"tuple.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnMapT;\n\n/**\n * Represents column of Map(K, V).\n */\nclass ColumnMap : public Column {\npublic:\n /** Create a map of given type, with actual values and offsets.\n *\n * Both `data` and `offsets` are used (and modified) internally bye ColumnArray.\n * Users are strongly advised against modifying contents of `data` or `offsets` afterwards.\n */\n explicit ColumnMap(ColumnRef data);\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column&) override;\n\n /// Converts map at pos n to column.\n /// Type of row is tuple {key, value}.\n ColumnRef GetAsColumn(size_t n) const;\n\nprotected:\n template \n friend class ColumnMapT;\n\n ColumnMap(ColumnMap&& map);\n\nprivate:\n std::shared_ptr data_;\n};\n\ntemplate \nclass ColumnMapT : public ColumnMap {\npublic:\n using KeyColumnType = K;\n using ValueColumnType = V;\n using Key = std::decay_t().At(0))>;\n using Value = std::decay_t().At(0))>;\n using TupleColumnType = ColumnTupleT;\n using ArrayColumnType = ColumnArrayT;\n\n ColumnMapT(ColumnRef data)\n : ColumnMap(data), typed_data_(data->AsStrict>()) {}\n\n ColumnMapT(std::shared_ptr keys, std::shared_ptr values)\n : ColumnMap(std::make_shared(std::make_shared(\n std::make_tuple(std::move(keys), std::move(values))))),\n typed_data_(data_->template As()) {}\n\n ColumnRef Slice(size_t begin, size_t len) const override {\n return std::make_shared>(typed_data_->Slice(begin, len));\n }\n\n ColumnRef CloneEmpty() const override {\n return std::make_shared>(typed_data_->CloneEmpty());\n }\n\n void Swap(Column& other) override {\n auto& col = dynamic_cast&>(other);\n col.typed_data_.swap(typed_data_);\n ColumnMap::Swap(other);\n }\n\n /// A single (row) value of the Map-column i.e. 
read-only map.\n /// It has a linear time complexity to access items\n /// Because data base type has same structure\n /// \"This lookup works now with a linear complexity.\"\n /// https://clickhouse.com/docs/en/sql-reference/data-types/map\n /// Convert it to a suitable container required to access more than one element\n\n class MapValueView {\n const typename ArrayColumnType::ArrayValueView data_;\n\n public:\n using ValueType = std::pair;\n\n \nMapValueView(typename ArrayColumnType::ArrayValueView data) : data_(std::move(data)) {}\n\n inline auto operator[](const Key& key) const { return (*Find(key)).second; }\n\n inline auto At(const Key& key) const {\n auto it = Find(key);\n if (it == end()) throw ValidationError(\"ColumnMap value key not found\");\n return (*it).second;\n }\n\n class Iterator {\n typename ArrayColumnType::ArrayValueView::Iterator data_iterator_;\n\n public:\n Iterator() = default;\n\n Iterator(typename ArrayColumnType::ArrayValueView::Iterator data_iterator)\n : data_iterator_(data_iterator) {}\n\n using ValueType = std::pair;\n using difference_type = size_t;\n using value_type = ValueType;\n using pointer = void;\n using reference = ValueType&;\n using iterator_category = std::forward_iterator_tag;\n\n inline auto operator*() const {\n auto tuple = *data_iterator_;\n return ValueType{std::get<0>(tuple), std::get<1>(tuple)};\n }\n\n inline Iterator& operator++() {\n ++data_iterator_;\n return *this;\n }\n\n inline bool operator==(const Iterator& other) const {\n return this->data_iterator_ == other.data_iterator_;\n }\n\n inline bool operator!=(const Iterator& other) const { return !(*this == other); }\n };\n\n // minimalistic stl-like container interface, hence the lowercase\n inline Iterator begin() const { return Iterator{data_.begin()}; }\n\n inline Iterator cbegin() const { return Iterator{data_.cbegin()}; }\n\n inline Iterator end() const { return Iterator{data_.end()}; }\n\n inline Iterator cend() const { return Iterator{data_.cend()}; }\n\n inline size_t size() const { return data_.size(); }\n\n // It is ugly to have both size() and Size(), but it is for compatitability with both STL\n // and rest of the clickhouse-cpp.\n inline size_t Size() const { return data_.Size(); }\n\n inline size_t Count(const Key& key) const {\n size_t result = 0;\n for (auto item : data_) {\n if (std::get<0>(item) == key) {\n ++result;\n }\n }\n return result;\n }\n\n inline Iterator Find(const Key& key) const {\n for (auto it = data_.begin(); it != data_.end(); ++it) {\n if (std::get<0>(*it) == key) {\n return Iterator{it};\n }\n }\n return end();\n }\n\n inline bool operator==(const MapValueView& other) const {\n if (size() != other.size()) {\n return false;\n }\n const auto make_index = [](const auto& data) {\n std::vector result{data.Size()};\n std::generate(result.begin(), result.end(), [i = 0] () mutable {return i++;});\n std::sort(result.begin(), result.end(), [&data](size_t l, size_t r) {return data[l] < data[r];});\n return result;\n };\n const auto index = make_index(data_);\n for (const auto& val : other.data_) {\n if (!std::binary_search(index.begin(), index.end(), val,\n [&data = data_](const auto& l, size_t r) {return l < data[r];})) {\n return false;\n }\n }\n return true;\n }\n\n inline bool operator!=(const MapValueView& other) const { return !(*this == other); }\n };\n\n inline auto At(size_t index) const { return MapValueView{typed_data_->At(index)}; }\n\n inline auto operator[](size_t index) const { return At(index); }\n\n using ColumnMap::Append;\n\n inline 
void Append(const MapValueView& value) { typed_data_->Append(value.data_); }\n\n inline void Append(const std::vector>& tuples) {\n typed_data_->Append(tuples.begin(), tuples.end());\n }\n\n template \n inline void Append(const T& value) {\n using BaseIter = decltype(value.begin());\n using KeyOfT = decltype(std::declval()->first);\n using ValOfT = decltype(std::declval()->second);\n using Functor = std::function(const BaseIter&)>;\n using Iterator = ProjectedIterator;\n\n Functor functor = [](const BaseIter& i) {\n return std::make_tuple(std::cref(i->first), std::cref(i->second));\n };\n\n typed_data_->Append(Iterator{value.begin(), functor}, Iterator{value.end(), functor});\n }\n\n static auto Wrap(ColumnMap&& col) {\n auto data = ArrayColumnType::Wrap(std::move(col.data_));\n return std::make_shared>(std::move(data));\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) { return Wrap(std::move(*col->AsStrict())); }\n\nprivate:\n std::shared_ptr typed_data_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/uuid.h\n#pragma once\n\n#include \"../base/uuid.h\"\n#include \"column.h\"\n#include \"numeric.h\"\n\nnamespace clickhouse {\n\n\n/**\n * Represents a UUID column.\n */\nclass ColumnUUID : public Column {\npublic:\n ColumnUUID();\n\n explicit ColumnUUID(ColumnRef data);\n\n /// Appends one element to the end of column.\n void Append(const UUID& value);\n\n /// Returns element at given row number.\n const UUID At(size_t n) const;\n\n /// Returns element at given row number.\n inline const UUID operator [] (size_t n) const { return At(n); }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t) const override;\n\nprivate:\n std::shared_ptr data_;\n};\n\n}\n\n// Path: clickhouse/client.h\n#pragma once\n\n#include \"query.h\"\n#include \"exceptions.h\"\n\n#include \"columns/array.h\"\n#include \"columns/date.h\"\n#include \"columns/decimal.h\"\n#include \"columns/enum.h\"\n#include \"columns/geo.h\"\n#include \"columns/ip4.h\"\n#include \"columns/ip6.h\"\n#include \"columns/lowcardinality.h\"\n#include \"columns/nullable.h\"\n#include \"columns/numeric.h\"\n#include \"columns/map.h\"\n#include \"columns/string.h\"\n#include \"columns/tuple.h\"\n#include \"columns/uuid.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\n\nnamespace clickhouse {\n\nstruct ServerInfo {\n std::string name;\n std::string timezone;\n std::string display_name;\n uint64_t version_major;\n uint64_t version_minor;\n uint64_t version_patch;\n uint64_t revision;\n};\n\n/// Methods of block compression.\nenum class CompressionMethod {\n None = -1,\n LZ4 = 1,\n};\n\nstruct Endpoint {\n std::string host;\n uint16_t port = 9000;\n 
inline bool operator==(const Endpoint& right) const {\n return host == right.host && port == right.port;\n }\n};\n\nenum class EndpointsIterationAlgorithm {\n RoundRobin = 0,\n};\n\nstruct ClientOptions {\n // Setter goes first, so it is possible to apply 'deprecated' annotation safely.\n#define DECLARE_FIELD(name, type, setter, default_value) \\\n inline auto & setter(const type& value) { \\\n name = value; \\\n return *this; \\\n } \\\n type name = default_value\n\n /// Hostname of the server.\n DECLARE_FIELD(host, std::string, SetHost, std::string());\n /// Service port.\n DECLARE_FIELD(port, uint16_t, SetPort, 9000);\n\n /** Set endpoints (host+port), only one is used.\n * Client tries to connect to those endpoints one by one, on the round-robin basis:\n * first default enpoint (set via SetHost() + SetPort()), then each of endpoints, from begin() to end(),\n * the first one to establish connection is used for the rest of the session.\n * If port isn't specified, default(9000) value will be used.\n */\n DECLARE_FIELD(endpoints, std::vector, SetEndpoints, {});\n\n /// Default database.\n DECLARE_FIELD(default_database, std::string, SetDefaultDatabase, \"default\");\n /// User name.\n DECLARE_FIELD(user, std::string, SetUser, \"default\");\n /// Access password.\n DECLARE_FIELD(password, std::string, SetPassword, std::string());\n\n /// By default all exceptions received during query execution will be\n /// passed to OnException handler. Set rethrow_exceptions to true to\n /// enable throwing exceptions with standard c++ exception mechanism.\n DECLARE_FIELD(rethrow_exceptions, bool, SetRethrowException, true);\n\n /// Ping server every time before execute any query.\n DECLARE_FIELD(ping_before_query, bool, SetPingBeforeQuery, false);\n /// Count of retry to send request to server.\n DECLARE_FIELD(send_retries, unsigned int, SetSendRetries, 1);\n /// Amount of time to wait before next retry.\n DECLARE_FIELD(retry_timeout, std::chrono::seconds, SetRetryTimeout, std::chrono::seconds(5));\n\n /// Compression method.\n DECLARE_FIELD(compression_method, CompressionMethod, SetCompressionMethod, CompressionMethod::None);\n\n /// TCP Keep alive options\n DECLARE_FIELD(tcp_keepalive, bool, TcpKeepAlive, false);\n DECLARE_FIELD(tcp_keepalive_idle, std::chrono::seconds, SetTcpKeepAliveIdle, std::chrono::seconds(60));\n DECLARE_FIELD(tcp_keepalive_intvl, std::chrono::seconds, SetTcpKeepAliveInterval, std::chrono::seconds(5));\n DECLARE_FIELD(tcp_keepalive_cnt, unsigned int, SetTcpKeepAliveCount, 3);\n\n // TCP options\n DECLARE_FIELD(tcp_nodelay, bool, TcpNoDelay, true);\n\n /// Connection socket connect timeout. If the timeout is negative then the connect operation will never timeout.\n DECLARE_FIELD(connection_connect_timeout, std::chrono::milliseconds, SetConnectionConnectTimeout, std::chrono::seconds(5));\n\n /// Connection socket timeout. 
If the timeout is set to zero then the operation will never timeout.\n DECLARE_FIELD(connection_recv_timeout, std::chrono::milliseconds, SetConnectionRecvTimeout, std::chrono::milliseconds(0));\n DECLARE_FIELD(connection_send_timeout, std::chrono::milliseconds, SetConnectionSendTimeout, std::chrono::milliseconds(0));\n\n /** It helps to ease migration of the old codebases, which can't afford to switch\n * to using ColumnLowCardinalityT or ColumnLowCardinality directly,\n * but still want to benefit from smaller on-wire LowCardinality bandwidth footprint.\n *\n * @see LowCardinalitySerializationAdaptor, CreateColumnByType\n */\n [[deprecated(\"Makes implementation of LC(X) harder and code uglier. Will be removed in next major release (3.0) \")]]\n DECLARE_FIELD(backward_compatibility_lowcardinality_as_wrapped_column, bool, SetBakcwardCompatibilityFeatureLowCardinalityAsWrappedColumn, false);\n\n /** Set max size data to compress if compression enabled.\n *\n * Allows choosing tradeoff between RAM\\CPU:\n * - Lower value reduces RAM usage, but slightly increases CPU usage.\n * - Higher value increases RAM usage but slightly decreases CPU usage.\n */\n DECLARE_FIELD(max_compression_chunk_size, unsigned int, SetMaxCompressionChunkSize, 65535);\n\n struct SSLOptions {\n /** There are two ways to configure an SSL connection:\n * - provide a pre-configured SSL_CTX, which is not modified and not owned by the Client.\n * - provide a set of options and allow the Client to create and configure SSL_CTX by itself.\n */\n\n /** Pre-configured SSL-context for SSL-connection.\n * If NOT null client DONES NOT take ownership of context and it must be valid for client lifetime.\n * If null client initlaizes OpenSSL and creates his own context, initializes it using\n * other options, like path_to_ca_files, path_to_ca_directory, use_default_ca_locations, etc.\n *\n * Either way context is used to create an SSL-connection, which is then configured with\n * whatever was provided as `configuration`, `host_flags`, `skip_verification` and `use_sni`.\n */\n SSL_CTX * ssl_context = nullptr;\n auto & SetExternalSSLContext(SSL_CTX * new_ssl_context) {\n ssl_context = new_ssl_context;\n return *this;\n }\n\n /** Means to validate the server-supplied certificate against trusted Certificate Authority (CA).\n * If no CAs are configured, the server's identity can't be validated, and the Client would err.\n * See https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_default_verify_paths.html\n */\n /// Load default CA certificates from default locations.\n DECLARE_FIELD(use_default_ca_locations, bool, SetUseDefaultCALocations, true);\n /// Path to the CA files to verify server certificate, may be empty.\n DECLARE_FIELD(path_to_ca_files, std::vector, SetPathToCAFiles, {});\n /// Path to the directory with CA files used to validate server certificate, may be empty.\n DECLARE_FIELD(path_to_ca_directory, std::string, SetPathToCADirectory, \"\");\n\n /** Min and max protocol versions to use, set with SSL_CTX_set_min_proto_version and SSL_CTX_set_max_proto_version\n * for details see https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_min_proto_version.html\n */\n DECLARE_FIELD(min_protocol_version, int, SetMinProtocolVersion, DEFAULT_VALUE);\n DECLARE_FIELD(max_protocol_version, int, SetMaxProtocolVersion, DEFAULT_VALUE);\n\n /** Options to be set with SSL_CTX_set_options,\n * for details see https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_options.html\n */\n DECLARE_FIELD(context_options, int, SetContextOptions, 
DEFAULT_VALUE);\n\n /** Use SNI at ClientHello\n */\n DECLARE_FIELD(use_sni, bool, SetUseSNI, true);\n\n /** Skip SSL session verification (server's certificate, etc).\n *\n * WARNING: settig to true will bypass all SSL session checks, which\n * is dangerous, but can be used against self-signed certificates, e.g. for testing purposes.\n */\n DECLARE_FIELD(skip_verification, bool, SetSkipVerification, false);\n\n /** Mode of verifying host ssl certificate against name of the host, set with SSL_set_hostflags.\n * For details see https://www.openssl.org/docs/man1.1.1/man3/SSL_set_hostflags.html\n */\n DECLARE_FIELD(host_flags, int, SetHostVerifyFlags, DEFAULT_VALUE);\n\n struct CommandAndValue {\n std::string command;\n std::optional value = std::nullopt;\n };\n /** Extra configuration options, set with SSL_CONF_cmd.\n * For deatils see https://www.openssl.org/docs/man1.1.1/man3/SSL_CONF_cmd.html\n *\n * Takes multiple pairs of command-value strings, all commands are supported,\n * and prefix is empty.\n * i.e. pass `sigalgs` or `SignatureAlgorithms` instead of `-sigalgs`.\n *\n * Rewrites any other options/flags if set in other ways.\n */\n DECLARE_FIELD(configuration, std::vector, SetConfiguration, {});\n\n static const int DEFAULT_VALUE = -1;\n };\n\n // By default SSL is turned off.\n std::optional ssl_options = std::nullopt;\n\n // Will throw an exception if client was built without SSL support.\n ClientOptions& SetSSLOptions(SSLOptions options);\n\n#undef DECLARE_FIELD\n};\n\nstd::ostream& operator<<(std::ostream& os, const ClientOptions& options);\nstd::ostream& operator<<(std::ostream& os, const Endpoint& options);\n\nclass SocketFactory;\n\n/**\n *\n */\nclass Client {\npublic:\n Client(const ClientOptions& opts);\n Client(const ClientOptions& opts,\n std::unique_ptr socket_factory);\n ~Client();\n\n /// Intends for execute arbitrary queries.\n void Execute(const Query& query);\n\n /// Intends for execute select queries. Data will be returned with\n /// one or more call of \\p cb.\n void Select(const std::string& query, SelectCallback cb);\n void Select(const std::string& query, const std::string& query_id, SelectCallback cb);\n\n /// Executes a select query which can be canceled by returning false from\n /// the data handler function \\p cb.\n void SelectCancelable(const std::string& query, SelectCancelableCallback cb);\n void SelectCancelable(const std::string& query, const std::string& query_id, SelectCancelableCallback cb);\n\n /// Alias for Execute.\n void Select(const Query& query);\n\n /// Intends for insert block of data into a table \\p table_name.\n void Insert(const std::string& table_name, const Block& block);\n void Insert(const std::string& table_name, const std::string& query_id, const Block& block);\n\n /// Ping server for aliveness.\n void Ping();\n\n /// Reset connection with initial params.\n void ResetConnection();\n\n const ServerInfo& GetServerInfo() const;\n\n /// Get current connected endpoint.\n /// In case when client is not connected to any endpoint, nullopt will returned.\n const std::optional& GetCurrentEndpoint() const;\n\n // Try to connect to different endpoints one by one only one time. 
If it doesn't work, throw an exception.\n void ResetConnectionEndpoint();\n\n struct Version\n {\n uint16_t major;\n uint16_t minor;\n uint16_t patch;\n uint16_t build;\n const char * extra;\n };\n\n static Version GetVersion();\n\nprivate:\n const ClientOptions options_;\n\n class Impl;\n std::unique_ptr impl_;\n};\n\n}\n\n// Path: clickhouse/protocol.h\n#pragma once\n\nnamespace clickhouse {\n\n /// Types of packets received from server\n namespace ServerCodes {\n enum {\n Hello = 0, /// Name, version, revision.\n Data = 1, /// `Block` of data, may be compressed.\n Exception = 2, /// Exception that occurred on server side during query execution.\n Progress = 3, /// Query execcution progress: rows and bytes read.\n Pong = 4, /// response to Ping sent by client.\n EndOfStream = 5, /// All packets were sent.\n ProfileInfo = 6, /// Profiling data\n Totals = 7, /// Block of totals, may be compressed.\n Extremes = 8, /// Block of mins and maxs, may be compressed.\n TablesStatusResponse = 9, /// Response to TableStatus.\n Log = 10, /// Query execution log.\n TableColumns = 11, /// Columns' description for default values calculation\n PartUUIDs = 12, /// List of unique parts ids.\n ReadTaskRequest = 13, /// String (UUID) describes a request for which next task is needed\n /// This is such an inverted logic, where server sends requests\n /// And client returns back response\n ProfileEvents = 14, /// Packet with profile events from server.\n };\n }\n\n /// Types of packets sent by client.\n namespace ClientCodes {\n enum {\n Hello = 0, /// Name, version, default database name.\n Query = 1, /** Query id, query settings, query processing stage,\n * compression status, and query text (no INSERT data).\n */\n Data = 2, /// Data `Block` (e.g. INSERT data), may be compressed.\n Cancel = 3, /// Cancel query.\n Ping = 4, /// Check server connection.\n };\n }\n\n /// Should we compress `Block`s of data\n namespace CompressionState {\n enum {\n Disable = 0,\n Enable = 1,\n };\n }\n\n namespace Stages {\n enum {\n Complete = 2,\n };\n }\n}\n\n// Path: clickhouse/base/input.h\n#pragma once\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream {\npublic:\n virtual ~InputStream() noexcept (false)\n { }\n\n /// Reads one byte from the stream.\n inline bool ReadByte(uint8_t* byte) {\n return DoRead(byte, sizeof(uint8_t)) == sizeof(uint8_t);\n }\n\n /// Reads some data from the stream.\n inline size_t Read(void* buf, size_t len) {\n return DoRead(buf, len);\n }\n\n // Skips a number of bytes. 
Returns false if an underlying read error occurs.\n virtual bool Skip(size_t bytes) = 0;\n\nprotected:\n virtual size_t DoRead(void* buf, size_t len) = 0;\n};\n\n\nclass ZeroCopyInput : public InputStream {\npublic:\n inline size_t Next(const void** buf, size_t len) {\n return DoNext(buf, len);\n }\n\n bool Skip(size_t bytes) override;\n\nprotected:\n virtual size_t DoNext(const void** ptr, size_t len) = 0;\n\n size_t DoRead(void* buf, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyInput stream backed by an in-memory array of bytes.\n */\nclass ArrayInput : public ZeroCopyInput {\npublic:\n ArrayInput() noexcept;\n ArrayInput(const void* buf, size_t len) noexcept;\n ~ArrayInput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return len_;\n }\n\n /// Current read position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return data_;\n }\n\n /// Whether there is more data in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n inline void Reset(const void* buf, size_t len) noexcept {\n data_ = static_cast(buf);\n len_ = len;\n }\n\nprivate:\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n const uint8_t* data_;\n size_t len_;\n};\n\n\nclass BufferedInput : public ZeroCopyInput {\npublic:\n BufferedInput(std::unique_ptr source, size_t buflen = 8192);\n ~BufferedInput() override;\n\n void Reset();\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n std::unique_ptr const source_;\n ArrayInput array_input_;\n std::vector buffer_;\n};\n\n}\n\n// Path: clickhouse/base/buffer.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nusing Buffer = std::vector;\n\n}\n\n// Path: clickhouse/base/output.h\n#pragma once\n\n#include \"buffer.h\"\n\n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass OutputStream {\npublic:\n virtual ~OutputStream()\n { }\n\n inline void Flush() {\n DoFlush();\n }\n\n inline size_t Write(const void* data, size_t len) {\n return DoWrite(data, len);\n }\n\nprotected:\n virtual void DoFlush() { }\n\n virtual size_t DoWrite(const void* data, size_t len) = 0;\n};\n\n\nclass ZeroCopyOutput : public OutputStream {\npublic:\n inline size_t Next(void** data, size_t size) {\n return DoNext(data, size);\n }\n\nprotected:\n // Obtains a buffer into which data can be written. 
Any data written\n // into this buffer will eventually (maybe instantly, maybe later on)\n // be written to the output.\n virtual size_t DoNext(void** data, size_t len) = 0;\n\n size_t DoWrite(const void* data, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyOutput stream backed by an in-memory array of bytes.\n */\nclass ArrayOutput : public ZeroCopyOutput {\npublic:\n ArrayOutput(void* buf, size_t len);\n ~ArrayOutput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return end_ - buf_;\n }\n\n /// Current write position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return buf_;\n }\n\n /// Whether there is more space in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n /// Initializes this stream with a new memory block.\n inline void Reset(void* buf, size_t len) noexcept {\n buf_ = static_cast(buf);\n end_ = buf_ + len;\n buffer_size_ = len;\n }\n\n /// Number of bytes written to the buffer.\n inline size_t Size() const noexcept {\n return buffer_size_ - Avail();\n }\n\nprotected:\n size_t DoNext(void** data, size_t len) override;\n\nprivate:\n uint8_t* buf_;\n uint8_t* end_;\n size_t buffer_size_;\n};\n\n\n/**\n * A ZeroCopyOutput stream backed by a vector.\n *\n * Doesn't Flush() in destructor, client must ensure to do it manually at some point.\n */\nclass BufferOutput : public ZeroCopyOutput {\npublic:\n BufferOutput(Buffer* buf);\n ~BufferOutput() override;\n\nprotected:\n size_t DoNext(void** data, size_t len) override;\n\nprivate:\n Buffer* buf_;\n size_t pos_;\n};\n\n/** BufferedOutput writes data to internal buffer first.\n *\n * Any data goes to underlying stream only if internal buffer is full\n * or when client invokes Flush() on this.\n *\n * Doesn't Flush() in destructor, client must ensure to do it manually at some point.\n */\nclass BufferedOutput : public ZeroCopyOutput {\npublic:\n explicit BufferedOutput(std::unique_ptr destination, size_t buflen = 8192);\n ~BufferedOutput() override;\n\n void Reset();\n\nprotected:\n void DoFlush() override;\n size_t DoNext(void** data, size_t len) override;\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n std::unique_ptr const destination_;\n Buffer buffer_;\n ArrayOutput array_output_;\n};\n\ntemplate \nvoid WriteUnaligned(void* buf, const T& value) {\n memcpy(buf, &value, sizeof(value));\n}\n\n}\n\n// Path: clickhouse/base/compressed.h\n#pragma once\n\n#include \"input.h\"\n#include \"output.h\"\n#include \"buffer.h\"\n\nnamespace clickhouse {\n\nclass CompressedInput : public ZeroCopyInput {\npublic:\n explicit CompressedInput(InputStream* input);\n ~CompressedInput() override;\n\nprotected:\n size_t DoNext(const void** ptr, size_t len) override;\n\n bool Decompress();\n\nprivate:\n InputStream* const input_;\n\n Buffer data_;\n ArrayInput mem_;\n};\n\nclass CompressedOutput : public OutputStream {\npublic:\n explicit CompressedOutput(OutputStream * destination, size_t max_compressed_chunk_size = 0);\n ~CompressedOutput() override;\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n void DoFlush() override;\n\nprivate:\n void Compress(const void * data, size_t len);\n void PreallocateCompressBuffer(size_t input_size);\n\nprivate:\n OutputStream * destination_;\n const size_t max_compressed_chunk_size_;\n Buffer compressed_buffer_;\n};\n\n}\n\n// Path: clickhouse/base/platform.h\n#pragma once\n\n#if defined(__linux__)\n# define _linux_\n#elif 
defined(_WIN64)\n# define _win64_\n# define _win32_\n#elif defined(__WIN32__) || defined(_WIN32)\n# define _win32_\n#elif defined(__APPLE__)\n# define _darwin_\n#endif\n\n#if defined(_win32_) || defined(_win64_)\n# define _win_\n# if !defined(_WIN32_WINNT) || (_WIN32_WINNT < 0x0600)\n# undef _WIN32_WINNT\n# define _WIN32_WINNT 0x0600 // The WSAPoll function is defined on Windows Vista and later.\n# endif\n# define WIN32_LEAN_AND_MEAN 1 // don't include too much header automatically\n#endif\n\n#if defined(_linux_) || defined (_darwin_)\n# define _unix_\n#endif\n\n#if defined(_MSC_VER)\n# undef NOMINMAX\n# define NOMINMAX\n# include \n# define ssize_t SSIZE_T\n# define HAVE_SSIZE_T 1\n#endif\n\n// Path: clickhouse/base/endpoints_iterator.h\n#pragma once\n\n#include \"clickhouse/client.h\"\n#include \n\nnamespace clickhouse {\n\nstruct ClientOptions;\n\n/**\n * Base class for iterating through endpoints.\n*/\nclass EndpointsIteratorBase\n{\n public:\n virtual ~EndpointsIteratorBase() = default;\n\n virtual Endpoint Next() = 0;\n};\n\nclass RoundRobinEndpointsIterator : public EndpointsIteratorBase\n{\n public:\n explicit RoundRobinEndpointsIterator(const std::vector& opts);\n Endpoint Next() override;\n\n ~RoundRobinEndpointsIterator() override;\n\n private:\n const std::vector& endpoints;\n size_t current_index;\n};\n\n}\n\n// Path: clickhouse/base/socket.h\n#pragma once\n\n#include \"platform.h\"\n#include \"input.h\"\n#include \"output.h\"\n#include \"endpoints_iterator.h\"\n\n#include \n#include \n#include \n\n#if defined(_win_)\n# include \n# include \n#else\n# include \n# include \n# include \n# include \n\n# if !defined(SOCKET)\n# define SOCKET int\n# endif\n#endif\n\n#include \n#include \n\nstruct addrinfo;\n\nnamespace clickhouse {\n\nstruct ClientOptions;\n\n/** Address of a host to establish connection to.\n *\n */\nclass NetworkAddress {\npublic:\n explicit NetworkAddress(const std::string& host,\n const std::string& port = \"0\");\n ~NetworkAddress();\n\n const struct addrinfo* Info() const;\n const std::string & Host() const;\n\nprivate:\n const std::string host_;\n struct addrinfo* info_;\n};\n\n#if defined(_win_)\n\nclass windowsErrorCategory : public std::error_category {\npublic:\n char const* name() const noexcept override final;\n std::string message(int c) const override final;\n\n static windowsErrorCategory const& category();\n};\n\n#endif\n\n#if defined(_unix_)\n\nclass getaddrinfoErrorCategory : public std::error_category {\npublic:\n char const* name() const noexcept override final;\n std::string message(int c) const override final;\n\n static getaddrinfoErrorCategory const& category();\n};\n\n#endif\n\n\nclass SocketBase {\npublic:\n virtual ~SocketBase();\n\n virtual std::unique_ptr makeInputStream() const = 0;\n virtual std::unique_ptr makeOutputStream() const = 0;\n};\n\n\nclass SocketFactory {\npublic:\n virtual ~SocketFactory();\n\n // TODO: move connection-related options to ConnectionOptions structure.\n\n virtual std::unique_ptr connect(const ClientOptions& opts, const Endpoint& endpoint) = 0;\n\n virtual void sleepFor(const std::chrono::milliseconds& duration);\n};\n\n\nstruct SocketTimeoutParams {\n std::chrono::milliseconds connect_timeout{ 5000 };\n std::chrono::milliseconds recv_timeout{ 0 };\n std::chrono::milliseconds send_timeout{ 0 };\n};\n\nclass Socket : public SocketBase {\npublic:\n Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);\n Socket(const NetworkAddress& addr);\n Socket(Socket&& other) noexcept;\n Socket& 
operator=(Socket&& other) noexcept;\n\n ~Socket() override;\n\n /// @params idle the time (in seconds) the connection needs to remain\n /// idle before TCP starts sending keepalive probes.\n /// @params intvl the time (in seconds) between individual keepalive probes.\n /// @params cnt the maximum number of keepalive probes TCP should send\n /// before dropping the connection.\n void SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept;\n\n /// @params nodelay whether to enable TCP_NODELAY\n void SetTcpNoDelay(bool nodelay) noexcept;\n\n std::unique_ptr makeInputStream() const override;\n std::unique_ptr makeOutputStream() const override;\n\nprotected:\n Socket(const Socket&) = delete;\n Socket& operator = (const Socket&) = delete;\n void Close();\n\n SOCKET handle_;\n};\n\n\nclass NonSecureSocketFactory : public SocketFactory {\npublic:\n ~NonSecureSocketFactory() override;\n\n std::unique_ptr connect(const ClientOptions& opts, const Endpoint& endpoint) override;\n\nprotected:\n virtual std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts);\n\n void setSocketOptions(Socket& socket, const ClientOptions& opts);\n};\n\n\nclass SocketInput : public InputStream {\npublic:\n explicit SocketInput(SOCKET s);\n ~SocketInput();\n\nprotected:\n bool Skip(size_t bytes) override;\n size_t DoRead(void* buf, size_t len) override;\n\nprivate:\n SOCKET s_;\n};\n\nclass SocketOutput : public OutputStream {\npublic:\n explicit SocketOutput(SOCKET s);\n ~SocketOutput();\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n SOCKET s_;\n};\n\nstatic struct NetrworkInitializer {\n NetrworkInitializer();\n} gNetrworkInitializer;\n\n}\n\n// Path: clickhouse/base/wire_format.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream;\nclass OutputStream;\n\nclass WireFormat {\npublic:\n template \n static bool ReadFixed(InputStream& input, T* value);\n static bool ReadString(InputStream& input, std::string* value);\n static bool SkipString(InputStream& input);\n static bool ReadBytes(InputStream& input, void* buf, size_t len);\n static bool ReadUInt64(InputStream& input, uint64_t* value);\n static bool ReadVarint64(InputStream& output, uint64_t* value);\n\n template \n static void WriteFixed(OutputStream& output, const T& value);\n static void WriteBytes(OutputStream& output, const void* buf, size_t len);\n static void WriteString(OutputStream& output, std::string_view value);\n static void WriteUInt64(OutputStream& output, const uint64_t value);\n static void WriteVarint64(OutputStream& output, uint64_t value);\n\nprivate:\n static bool ReadAll(InputStream& input, void* buf, size_t len);\n static void WriteAll(OutputStream& output, const void* buf, size_t len);\n};\n\ntemplate \ninline bool WireFormat::ReadFixed(InputStream& input, T* value) {\n return ReadAll(input, value, sizeof(T));\n}\n\ninline bool WireFormat::ReadString(InputStream& input, std::string* value) {\n uint64_t len = 0;\n if (ReadVarint64(input, &len)) {\n if (len > 0x00FFFFFFULL) {\n return false;\n }\n value->resize((size_t)len);\n return ReadAll(input, value->data(), (size_t)len);\n }\n\n return false;\n}\n\ninline bool WireFormat::ReadBytes(InputStream& input, void* buf, size_t len) {\n return ReadAll(input, buf, len);\n}\n\ninline bool WireFormat::ReadUInt64(InputStream& input, uint64_t* value) {\n return ReadVarint64(input, value);\n}\n\ntemplate \ninline void WireFormat::WriteFixed(OutputStream& output, const T& value) {\n WriteAll(output, &value, 
sizeof(T));\n}\n\ninline void WireFormat::WriteBytes(OutputStream& output, const void* buf, size_t len) {\n WriteAll(output, buf, len);\n}\n\ninline void WireFormat::WriteString(OutputStream& output, std::string_view value) {\n WriteVarint64(output, value.size());\n WriteAll(output, value.data(), value.size());\n}\n\ninline void WireFormat::WriteUInt64(OutputStream& output, const uint64_t value) {\n WriteVarint64(output, value);\n}\n\n}\n\n// Path: clickhouse/columns/factory.h\n#pragma once\n\n#include \"column.h\"\n\nnamespace clickhouse {\n\nstruct CreateColumnByTypeSettings\n{\n bool low_cardinality_as_wrapped_column = false;\n};\n\nColumnRef CreateColumnByType(const std::string& type_name, CreateColumnByTypeSettings settings = {});\n\n}\n\n// Path: clickhouse/base/sslsocket.h\n#pragma once\n\n#include \"socket.h\"\n\n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\ntypedef struct ssl_st SSL;\n\nnamespace clickhouse {\n\nstruct SSLParams\n{\n std::vector path_to_ca_files;\n std::string path_to_ca_directory;\n bool use_default_ca_locations;\n int context_options;\n int min_protocol_version;\n int max_protocol_version;\n bool use_SNI;\n bool skip_verification;\n int host_flags;\n using ConfigurationType = std::vector>>;\n ConfigurationType configuration;\n};\n\nclass SSLContext\n{\npublic:\n explicit SSLContext(SSL_CTX & context);\n explicit SSLContext(const SSLParams & context_params);\n ~SSLContext() = default;\n\n SSLContext(const SSLContext &) = delete;\n SSLContext& operator=(const SSLContext &) = delete;\n SSLContext(SSLContext &&) = delete;\n SSLContext& operator=(SSLContext &) = delete;\n\nprivate:\n friend class SSLSocket;\n SSL_CTX * getContext();\n\nprivate:\n std::unique_ptr context_;\n};\n\nclass SSLSocket : public Socket {\npublic:\n explicit SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params,\n const SSLParams& ssl_params, SSLContext& context);\n\n SSLSocket(SSLSocket &&) = default;\n ~SSLSocket() override = default;\n\n SSLSocket(const SSLSocket & ) = delete;\n SSLSocket& operator=(const SSLSocket & ) = delete;\n\n std::unique_ptr makeInputStream() const override;\n std::unique_ptr makeOutputStream() const override;\n\n static void validateParams(const SSLParams & ssl_params);\nprivate:\n std::unique_ptr ssl_;\n};\n\nclass SSLSocketFactory : public NonSecureSocketFactory {\npublic:\n explicit SSLSocketFactory(const ClientOptions& opts);\n ~SSLSocketFactory() override;\n\nprotected:\n std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts) override;\n\nprivate:\n const SSLParams ssl_params_;\n std::unique_ptr ssl_context_;\n};\n\nclass SSLSocketInput : public InputStream {\npublic:\n explicit SSLSocketInput(SSL *ssl);\n ~SSLSocketInput() = default;\n\n bool Skip(size_t /*bytes*/) override {\n return false;\n }\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n\nprivate:\n // Not owning\n SSL *ssl_;\n};\n\nclass SSLSocketOutput : public OutputStream {\npublic:\n explicit SSLSocketOutput(SSL *ssl);\n ~SSLSocketOutput() = default;\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n // Not owning\n SSL *ssl_;\n};\n\n}\n\n// Path: clickhouse/client.cpp\n#include \"client.h\"\n#include \"clickhouse/version.h\"\n#include \"protocol.h\"\n\n#include \"base/compressed.h\"\n#include \"base/socket.h\"\n#include \"base/wire_format.h\"\n\n#include \"columns/factory.h\"\n\n#include \n#include \n#include \n#include \n\n#if defined(WITH_OPENSSL)\n#include 
\"base/sslsocket.h\"\n#endif\n\n#define DBMS_NAME \"ClickHouse\"\n\n#define DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES 50264\n#define DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS 51554\n#define DBMS_MIN_REVISION_WITH_BLOCK_INFO 51903\n#define DBMS_MIN_REVISION_WITH_CLIENT_INFO 54032\n#define DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE 54058\n#define DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO 54060\n//#define DBMS_MIN_REVISION_WITH_TABLES_STATUS 54226\n#define DBMS_MIN_REVISION_WITH_TIME_ZONE_PARAMETER_IN_DATETIME_DATA_TYPE 54337\n#define DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME 54372\n#define DBMS_MIN_REVISION_WITH_VERSION_PATCH 54401\n...\n", "needle_token_start": 3969, "needle_token_end": 3993, "code_context_ntokens": 15941, "output": ["```cpp\nclass MapValueView {\n const typename ArrayColumnType::ArrayValueView data_;\n\npublic:\n using ValueType = std::pair;\n\n MapValueView(typename ArrayColumnType::ArrayValueView data) : data_(std::move(data)) {}\n\n inline auto operator[](const Key& key) const { return (*Find(key)).second; }\n\n inline auto At(const Key& key) const {\n auto it = Find(key);\n if (it == end()) throw ValidationError(\"ColumnMap value key not found\");\n return (*it).second;\n }\n\n class Iterator {\n typename ArrayColumnType::ArrayValueView::Iterator data_iterator_;\n\n public:\n Iterator() = default;\n\n Iterator(typename ArrayColumnType::ArrayValueView::Iterator data_iterator)\n : data_iterator_(data_iterator) {}\n\n using ValueType = std::pair;\n using difference_type = size_t;\n using value_type = ValueType;\n using pointer = void;\n using reference = ValueType&;\n using iterator_category = std::forward_iterator_tag;\n\n inline auto operator*() const {\n auto tuple = *data_iterator_;\n return ValueType{std::get<0>(tuple), std::get<1>(tuple)};\n }\n\n inline Iterator& operator++() {\n ++data_iterator_;\n return *this;\n }\n\n inline bool operator==(const Iterator& other) const {\n return this->data_iterator_ == other.data_iterator_;\n }\n\n inline bool operator!=(const Iterator& other) const { return !(*this == other); }\n };\n\n // minimalistic stl-like container interface, hence the lowercase\n inline Iterator begin() const { return Iterator{data_.begin()}; }\n\n inline Iterator cbegin() const { return Iterator{data_.cbegin()}; }\n\n inline Iterator end() const { return Iterator{data_.end()}; }\n\n inline Iterator cend() const { return Iterator{data_.cend()}; }\n\n inline size_t size() const { return data_.size(); }\n\n // It is ugly to have both size() and Size(), but it is for compatitability with both STL\n // and rest of the clickhouse-cpp.\n inline size_t Size() const { return data_.Size(); }\n\n inline size_t Count(const Key& key) const {\n size_t result = 0;\n for (auto item : data_) {\n if (std::get<0>(item) == key) {\n ++result;\n }\n }\n return result;\n }\n\n inline Iterator Find(const Key& key) const {\n for (auto it = data_.begin(); it != data_.end(); ++it) {\n if (std::get<0>(*it) == key) {\n return Iterator{it};\n }\n }\n return end();\n }\n\n inline bool operator==(const MapValueView& other) const {\n if (size() != other.size()) {\n return false;\n }\n const auto make_index = [](const auto& data) {\n std::vector result{data.Size()};\n std::generate(result.begin(), result.end(), [i = 0] () mutable {return i++;});\n std::sort(result.begin(), result.end(), [&data = data_](size_t l, size_t r) {return data[l] < data[r];});\n return result;\n };\n const auto index = make_index(data_);\n for (const auto& val : other.data_) {\n if 
(!std::binary_search(index.begin(), index.end(), val,\n [&data = data_](const auto& l, size_t r) {return l < data[r];})) {\n return false;\n }\n }\n return true;\n }\n\n inline bool operator!=(const MapValueView& other) const { return !(*this == other); }\n};\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "ServerException", "language": "cpp", "path": "clickhouse/exceptions.h", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To encapsulate and handle exceptions originating from the server, providing a structured way to manage and report errors that occur during server interactions.\n2. **Input**: Receives a unique pointer to an exception structure containing details such as error code, name, display text, and stack trace.\n3. **Output**: Does not return a value but initializes an object that inherits from a standard runtime error class, enriched with server-specific error information.\n4. **Procedure**: Constructs an object by taking a unique pointer to an exception structure, transferring ownership of this structure to the newly created object to maintain error details.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/types/types.h\n#pragma once\n\n#include \"absl/numeric/int128.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nusing Int128 = absl::int128;\nusing Int64 = int64_t;\n\nusing TypeRef = std::shared_ptr;\n\nclass Type {\npublic:\n enum Code {\n Void = 0,\n Int8,\n Int16,\n Int32,\n Int64,\n UInt8,\n UInt16,\n UInt32,\n UInt64,\n Float32,\n Float64,\n String,\n FixedString,\n DateTime,\n Date,\n Array,\n Nullable,\n Tuple,\n Enum8,\n Enum16,\n UUID,\n IPv4,\n IPv6,\n Int128,\n Decimal,\n Decimal32,\n Decimal64,\n Decimal128,\n LowCardinality,\n DateTime64,\n Date32,\n Map,\n Point,\n Ring,\n Polygon,\n MultiPolygon\n };\n\n using EnumItem = std::pair;\n\nprotected:\n Type(const Code code);\n\npublic:\n template \n auto* As() {\n return static_cast(this);\n }\n\n template \n const auto* As() const {\n return static_cast(this);\n }\n\n /// Type's code.\n Code GetCode() const { return code_; }\n\n /// String representation of the type.\n std::string GetName() const;\n\n /// Is given type same as current one.\n bool IsEqual(const Type& other) const {\n // Types are equal only if both code_ and type_unique_id_ are equal.\n return this == &other\n // GetTypeUniqueId() is relatively heavy, so avoid calling it when comparing obviously different types.\n || (this->GetCode() == other.GetCode() && this->GetTypeUniqueId() == other.GetTypeUniqueId());\n }\n\n bool IsEqual(const TypeRef& other) const { return IsEqual(*other); }\n\n /// Simple name, doesn't depend on parameters and\\or nested types, caller MUST NOT free returned value.\n static const char* TypeName(Code);\n\npublic:\n static TypeRef CreateArray(TypeRef item_type);\n\n static TypeRef CreateDate();\n\n static TypeRef CreateDate32();\n\n static TypeRef CreateDateTime(std::string timezone = std::string());\n\n static TypeRef CreateDateTime64(size_t precision, std::string timezone = std::string());\n\n static TypeRef CreateDecimal(size_t precision, size_t scale);\n\n static TypeRef CreateIPv4();\n\n static TypeRef CreateIPv6();\n\n static TypeRef CreateNothing();\n\n static TypeRef CreateNullable(TypeRef 
nested_type);\n\n template \n static TypeRef CreateSimple();\n\n static TypeRef CreateString();\n\n static TypeRef CreateString(size_t n);\n\n static TypeRef CreateTuple(const std::vector& item_types);\n\n static TypeRef CreateEnum8(const std::vector& enum_items);\n\n static TypeRef CreateEnum16(const std::vector& enum_items);\n\n static TypeRef CreateUUID();\n\n static TypeRef CreateLowCardinality(TypeRef item_type);\n\n static TypeRef CreateMap(TypeRef key_type, TypeRef value_type);\n\n static TypeRef CreatePoint();\n\n static TypeRef CreateRing();\n\n static TypeRef CreatePolygon();\n\n static TypeRef CreateMultiPolygon();\n\nprivate:\n uint64_t GetTypeUniqueId() const;\n\n const Code code_;\n mutable std::atomic type_unique_id_;\n};\n\ninline bool operator==(const Type & left, const Type & right) {\n if (&left == &right)\n return true;\n if (typeid(left) == typeid(right))\n return left.IsEqual(right);\n return false;\n}\n\ninline bool operator==(const TypeRef & left, const TypeRef & right) {\n return *left == *right;\n}\n\nclass ArrayType : public Type {\npublic:\n explicit ArrayType(TypeRef item_type);\n\n std::string GetName() const { return std::string(\"Array(\") + item_type_->GetName() + \")\"; }\n\n /// Type of array's elements.\n inline TypeRef GetItemType() const { return item_type_; }\n\nprivate:\n TypeRef item_type_;\n};\n\nclass DecimalType : public Type {\npublic:\n DecimalType(size_t precision, size_t scale);\n\n std::string GetName() const;\n friend class EnumType;\n friend class DateTimeType;\n\n inline size_t GetScale() const { return scale_; }\n inline size_t GetPrecision() const { return precision_; }\n\nprivate:\n const size_t precision_, scale_;\n};\n\nnamespace details\n{\nclass TypeWithTimeZoneMixin\n{\npublic:\n TypeWithTimeZoneMixin(std::string timezone);\n\n /// Timezone associated with a data column.\n const std::string & Timezone() const;\n\nprivate:\n std::string timezone_;\n};\n}\n\nclass DateTimeType : public Type, public details::TypeWithTimeZoneMixin {\npublic:\n explicit DateTimeType(std::string timezone);\n\n std::string GetName() const;\n};\n\nclass DateTime64Type: public Type, public details::TypeWithTimeZoneMixin {\npublic:\n explicit DateTime64Type(size_t precision, std::string timezone_);\n\n std::string GetName() const;\n\n inline size_t GetPrecision() const { return precision_; }\nprivate:\n size_t precision_;\n};\n\nclass EnumType : public Type {\npublic:\n EnumType(Type::Code type, const std::vector& items);\n\n std::string GetName() const;\n\n /// Methods to work with enum types.\n std::string_view GetEnumName(int16_t value) const;\n int16_t GetEnumValue(const std::string& name) const;\n bool HasEnumName(const std::string& name) const;\n bool HasEnumValue(int16_t value) const;\n\nprivate:\n using ValueToNameType = std::map;\n using NameToValueType = std::map;\n using ValueToNameIterator = ValueToNameType::const_iterator;\n\n ValueToNameType value_to_name_;\n NameToValueType name_to_value_;\n\npublic:\n ValueToNameIterator BeginValueToName() const;\n ValueToNameIterator EndValueToName() const;\n};\n\nclass FixedStringType : public Type {\npublic:\n explicit FixedStringType(size_t n);\n\n std::string GetName() const { return std::string(\"FixedString(\") + std::to_string(size_) + \")\"; }\n\n inline size_t GetSize() const { return size_; }\n\nprivate:\n size_t size_;\n};\n\nclass NullableType : public Type {\npublic:\n explicit NullableType(TypeRef nested_type);\n\n std::string GetName() const { return std::string(\"Nullable(\") + 
nested_type_->GetName() + \")\"; }\n\n /// Type of nested nullable element.\n TypeRef GetNestedType() const { return nested_type_; }\n\nprivate:\n TypeRef nested_type_;\n};\n\nclass TupleType : public Type {\npublic:\n explicit TupleType(const std::vector& item_types);\n\n std::string GetName() const;\n\n /// Type of nested Tuple element type.\n std::vector GetTupleType() const { return item_types_; }\n\nprivate:\n std::vector item_types_;\n};\n\nclass LowCardinalityType : public Type {\npublic:\n explicit LowCardinalityType(TypeRef nested_type);\n ~LowCardinalityType();\n\n std::string GetName() const { return std::string(\"LowCardinality(\") + nested_type_->GetName() + \")\"; }\n\n /// Type of nested nullable element.\n TypeRef GetNestedType() const { return nested_type_; }\n\nprivate:\n TypeRef nested_type_;\n};\n\nclass MapType : public Type {\npublic:\n explicit MapType(TypeRef key_type, TypeRef value_type);\n\n std::string GetName() const;\n\n /// Type of keys.\n TypeRef GetKeyType() const { return key_type_; }\n\n /// Type of values.\n TypeRef GetValueType() const { return value_type_; }\n\nprivate:\n TypeRef key_type_;\n TypeRef value_type_;\n};\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int8));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int16));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int64));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int128));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt8));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt16));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt64));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Float32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Float64));\n}\n\n} // namespace clickhouse\n\n// Path: clickhouse/server_exception.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\nstruct Exception {\n int code = 0;\n std::string name;\n std::string display_text;\n std::string stack_trace;\n /// Pointer to nested exception.\n std::unique_ptr nested;\n};\n\n}\n\n// Path: clickhouse/exceptions.h\n#pragma once\n\n#include \"server_exception.h\"\n\n#include \n\nnamespace clickhouse {\n\nclass Error : public std::runtime_error {\n using std::runtime_error::runtime_error;\n};\n\n// Caused by any user-related code, like invalid column types or arguments passed to any method.\nclass ValidationError : public Error {\n using Error::Error;\n};\n\n// Buffers+IO errors, failure to serialize/deserialize, checksum mismatches, etc.\nclass ProtocolError : public Error {\n using Error::Error;\n};\n\nclass UnimplementedError : public Error {\n using Error::Error;\n};\n\n// Internal validation error.\nclass AssertionError : public Error {\n using Error::Error;\n};\n\nclass OpenSSLError : public Error {\n using Error::Error;\n};\n\nclass LZ4Error : public Error {\n using Error::Error;\n};\n\n// Exception received from server.\nclass ServerException : public Error {\npublic:\n \nServerException(std::unique_ptr e)\n : Error(std::string())\n , exception_(std::move(e))\n {\n }\n\n int GetCode() const {\n 
return exception_->code;\n }\n\n const Exception& GetException() const {\n return *exception_;\n }\n\n const char* what() const noexcept override {\n return exception_->display_text.c_str();\n }\n\nprivate:\n std::unique_ptr exception_;\n};\nusing ServerError = ServerException;\n\n}\n\n// Path: clickhouse/columns/itemview.h\n#pragma once\n\n#include \"../types/types.h\"\n#include \"../exceptions.h\"\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\n/** ItemView is a view on a data stored in Column, safe-ish interface for reading values from Column.\n *\n * Data is not owned (hence the name View) and will be invalidated on column update, load\n * or destruction (basically on calling any non-const method of Column).\n * `type` reflects what is stored in `data` and can be almost any value-type\n * (except Nullable, Array, Tuple, LowCardinality).\n *\n */\nstruct ItemView {\n using DataType = std::string_view;\n\n const Type::Code type;\n const DataType data;\n\nprivate:\n template \n inline auto ConvertToStorageValue(const T& t) {\n if constexpr (std::is_same_v || std::is_same_v) {\n return std::string_view{t};\n } else if constexpr (std::is_fundamental_v || std::is_same_v>) {\n return std::string_view{reinterpret_cast(&t), sizeof(T)};\n } else {\n static_assert(!std::is_same_v, \"Unknown type, which can't be stored in ItemView\");\n return;\n }\n }\n\npublic:\n ItemView(Type::Code type, DataType data)\n : type(type),\n data(data)\n {\n ValidateData(type, data);\n }\n\n ItemView(Type::Code type, ItemView other)\n : type(type),\n data(other.data)\n {\n ValidateData(type, data);\n }\n\n explicit ItemView()\n : ItemView(Type::Void, std::string_view{})\n {}\n\n template \n explicit ItemView(Type::Code type, const T & value)\n : ItemView(type, ConvertToStorageValue(value))\n {}\n\n template \n auto get() const {\n using ValueType = std::remove_cv_t>;\n if constexpr (std::is_same_v || std::is_same_v) {\n return data;\n } else if constexpr (std::is_fundamental_v || std::is_same_v) {\n if (sizeof(ValueType) == data.size()) {\n return *reinterpret_cast(data.data());\n } else {\n throw AssertionError(\"Incompatitable value type and size. 
Requested size: \"\n + std::to_string(sizeof(ValueType)) + \" stored size: \" + std::to_string(data.size()));\n }\n }\n }\n\n inline std::string_view AsBinaryData() const {\n return data;\n }\n\n // Validate that value matches type, will throw an exception if validation fails.\n static void ValidateData(Type::Code type, DataType data);\n};\n\n}\n\n// Path: clickhouse/columns/column.h\n#pragma once\n\n#include \"../types/types.h\"\n#include \"../columns/itemview.h\"\n#include \"../exceptions.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream;\nclass OutputStream;\n\nusing ColumnRef = std::shared_ptr;\n\n/**\n * An abstract base of all columns classes.\n */\nclass Column : public std::enable_shared_from_this {\npublic:\n explicit inline Column(TypeRef type) : type_(type) {}\n\n virtual ~Column() {}\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr As() {\n return std::dynamic_pointer_cast(shared_from_this());\n }\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr As() const {\n return std::dynamic_pointer_cast(shared_from_this());\n }\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr AsStrict() {\n auto result = std::dynamic_pointer_cast(shared_from_this());\n if (!result) {\n throw ValidationError(\"Can't cast from \" + type_->GetName());\n }\n return result;\n }\n\n /// Get type object of the column.\n inline TypeRef Type() const { return type_; }\n inline const class Type& GetType() const { return *type_; }\n\n /// Appends content of given column to the end of current one.\n virtual void Append(ColumnRef column) = 0;\n\n /// Increase the capacity of the column for large block insertion.\n virtual void Reserve(size_t new_cap) = 0;\n\n /// Template method to load column data from input stream. It'll call LoadPrefix and LoadBody.\n /// Should be called only once from the client. Derived classes should not call it.\n bool Load(InputStream* input, size_t rows);\n\n /// Loads column prefix from input stream.\n virtual bool LoadPrefix(InputStream* input, size_t rows);\n\n /// Loads column data from input stream.\n virtual bool LoadBody(InputStream* input, size_t rows) = 0;\n\n /// Saves column prefix to output stream. Column types with prefixes must implement it.\n virtual void SavePrefix(OutputStream* output);\n\n /// Saves column body to output stream.\n virtual void SaveBody(OutputStream* output) = 0;\n\n /// Template method to save to output stream. It'll call SavePrefix and SaveBody respectively\n /// Should be called only once from the client. 
Derived classes should not call it.\n /// Save is split in Prefix and Body because some data types require prefixes and specific serialization order.\n /// For instance, Array(LowCardinality(X)) requires LowCardinality.key_version bytes to come before Array.offsets\n void Save(OutputStream* output);\n\n /// Clear column data .\n virtual void Clear() = 0;\n\n /// Returns count of rows in the column.\n virtual size_t Size() const = 0;\n\n /// Makes slice of the current column.\n virtual ColumnRef Slice(size_t begin, size_t len) const = 0;\n\n virtual ColumnRef CloneEmpty() const = 0;\n\n virtual void Swap(Column&) = 0;\n\n /// Get a view on raw item data if it is supported by column, will throw an exception if index is out of range.\n /// Please note that view is invalidated once column items are added or deleted, column is loaded from strean or destroyed.\n virtual ItemView GetItem(size_t) const {\n throw UnimplementedError(\"GetItem() is not supported for column of \" + type_->GetName());\n }\n\n friend void swap(Column& left, Column& right) {\n left.Swap(right);\n }\n\nprotected:\n TypeRef type_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/block.h\n#pragma once\n\n#include \"columns/column.h\"\n\nnamespace clickhouse {\n\nstruct BlockInfo {\n uint8_t is_overflows = 0;\n int32_t bucket_num = -1;\n};\n\nclass Block {\npublic:\n /// Allow to iterate over block's columns.\n class Iterator {\n public:\n Iterator(const Block& block);\n\n /// Name of column.\n const std::string& Name() const;\n\n /// Type of column.\n TypeRef Type() const;\n\n /// Reference to column object.\n ColumnRef Column() const;\n\n /// Move to next column, returns false if next call to IsValid() would return false;\n bool Next();\n\n /// Is the iterator still valid.\n bool IsValid() const;\n\n size_t ColumnIndex() const {\n return idx_;\n }\n\n Iterator& operator*() { return *this; }\n const Iterator& operator*() const { return *this; }\n\n bool operator==(const Iterator & other) const {\n return &block_ == &other.block_ && idx_ == other.idx_;\n }\n bool operator!=(const Iterator & other) const {\n return !(*this == other);\n }\n\n Iterator& operator++() {\n this->Next();\n return *this;\n }\n\n private:\n friend class Block;\n struct ConstructAtEndTag {};\n Iterator(const Block& block, ConstructAtEndTag at_end);\n Iterator() = delete;\n\n const Block& block_;\n size_t idx_;\n };\n\npublic:\n Block();\n Block(size_t cols, size_t rows);\n ~Block();\n\n /// Append named column to the block.\n void AppendColumn(const std::string& name, const ColumnRef& col);\n\n /// Count of columns in the block.\n size_t GetColumnCount() const;\n\n const BlockInfo& Info() const;\n\n /// Set block info\n void SetInfo(BlockInfo info);\n\n /// Count of rows in the block.\n size_t GetRowCount() const;\n\n size_t RefreshRowCount();\n\n const std::string& GetColumnName(size_t idx) const {\n return columns_.at(idx).name;\n }\n\n /// Reference to column by index in the block.\n ColumnRef operator [] (size_t idx) const;\n\n Iterator begin() const;\n Iterator end() const;\n Iterator cbegin() const { return begin(); }\n Iterator cend() const { return end(); }\n\nprivate:\n struct ColumnItem {\n std::string name;\n ColumnRef column;\n };\n\n BlockInfo info_;\n std::vector columns_;\n /// Count of rows in the block.\n size_t rows_;\n};\n\n}\n\n// Path: clickhouse/block.cpp\n#include \"block.h\"\n\n#include \"exceptions.h\"\n\n#include \n\nnamespace clickhouse {\n\nBlock::Iterator::Iterator(const Block& block)\n : block_(block)\n , 
idx_(0)\n{\n}\n\nBlock::Iterator::Iterator(const Block& block, Block::Iterator::ConstructAtEndTag /*at_end*/)\n : block_(block)\n , idx_(block.GetColumnCount())\n{}\n\nconst std::string& Block::Iterator::Name() const {\n return block_.columns_[idx_].name;\n}\n\nTypeRef Block::Iterator::Type() const {\n return block_.columns_[idx_].column->Type();\n}\n\nColumnRef Block::Iterator::Column() const {\n return block_.columns_[idx_].column;\n}\n\nbool Block::Iterator::Next() {\n ++idx_;\n return IsValid();\n}\n\nbool Block::Iterator::IsValid() const {\n return idx_ < block_.columns_.size();\n}\n\n\nBlock::Block()\n : rows_(0)\n{\n}\n\nBlock::Block(size_t cols, size_t rows)\n : rows_(rows)\n{\n columns_.reserve(cols);\n}\n\nBlock::~Block() = default;\n\nvoid Block::AppendColumn(const std::string& name, const ColumnRef& col) {\n if (columns_.empty()) {\n rows_ = col->Size();\n } else if (col->Size() != rows_) {\n throw ValidationError(\"all columns in block must have same count of rows. Name: [\"+name+\"], rows: [\"+std::to_string(rows_)+\"], columns: [\" + std::to_string(col->Size())+\"]\");\n }\n\n columns_.push_back(ColumnItem{name, col});\n}\n\n/// Count of columns in the block.\nsize_t Block::GetColumnCount() const {\n return columns_.size();\n}\n\nconst BlockInfo& Block::Info() const {\n return info_;\n}\n\n/// Set block info\nvoid Block::SetInfo(BlockInfo info) {\n info_ = std::move(info);\n}\n\n/// Count of rows in the block.\nsize_t Block::GetRowCount() const {\n return rows_;\n}\n\nsize_t Block::RefreshRowCount()\n{\n size_t rows = 0UL;\n\n for (size_t idx = 0UL; idx < columns_.size(); ++idx)\n {\n const std::string& name = columns_[idx].name;\n const ColumnRef& col = columns_[idx].column;\n\n if (idx == 0UL)\n rows = col->Size();\n else if (rows != col->Size())\n throw ValidationError(\"all columns in block must have same count of rows. Name: [\"+name+\"], rows: [\"+std::to_string(rows)+\"], columns: [\" + std::to_string(col->Size())+\"]\");\n }\n\n rows_ = rows;\n return rows_;\n}\n\nColumnRef Block::operator [] (size_t idx) const {\n if (idx < columns_.size()) {\n return columns_[idx].column;\n }\n\n throw std::out_of_range(\"column index is out of range. 
Index: [\"+std::to_string(idx)+\"], columns: [\" + std::to_string(columns_.size())+\"]\");\n}\n\nBlock::Iterator Block::begin() const {\n return Iterator(*this);\n}\n\nBlock::Iterator Block::end() const {\n return Iterator(*this, Iterator::ConstructAtEndTag{});\n}\n\n}\n\n// Path: clickhouse/base/uuid.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nusing UInt128 = std::pair;\n\nusing UUID = UInt128;\n\n}\n\n// Path: clickhouse/base/open_telemetry.h\n#pragma once\n\n#include \"uuid.h\"\n\n#include \n\nnamespace clickhouse::open_telemetry {\n\n/// See https://www.w3.org/TR/trace-context/ for trace_flags definition\nenum TraceFlags : uint8_t {\n TRACE_FLAG_NONE = 0,\n TRACE_FLAG_SAMPLED = 1,\n};\n\n/// The runtime info we need to create new OpenTelemetry spans.\nstruct TracingContext {\n UUID trace_id{};\n uint64_t span_id = 0;\n std::string tracestate;\n uint8_t trace_flags = TRACE_FLAG_NONE;\n};\n\n} // namespace clickhouse::open_telemetry\n\n// Path: clickhouse/query.h\n#pragma once\n\n#include \"block.h\"\n#include \"server_exception.h\"\n\n#include \"base/open_telemetry.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nstruct QuerySettingsField {\n enum Flags : uint64_t\n {\n IMPORTANT = 0x01,\n CUSTOM = 0x02,\n OBSOLETE = 0x04,\n };\n std::string value;\n uint64_t flags{0};\n};\n\nusing QuerySettings = std::unordered_map;\n\nstruct Profile {\n uint64_t rows = 0;\n uint64_t blocks = 0;\n uint64_t bytes = 0;\n uint64_t rows_before_limit = 0;\n bool applied_limit = false;\n bool calculated_rows_before_limit = false;\n};\n\n\nstruct Progress {\n uint64_t rows = 0;\n uint64_t bytes = 0;\n uint64_t total_rows = 0;\n uint64_t written_rows = 0;\n uint64_t written_bytes = 0;\n};\n\n\nclass QueryEvents {\npublic:\n virtual ~QueryEvents()\n { }\n\n /// Some data was received.\n virtual void OnData(const Block& block) = 0;\n virtual bool OnDataCancelable(const Block& block) = 0;\n\n virtual void OnServerException(const Exception& e) = 0;\n\n virtual void OnProfile(const Profile& profile) = 0;\n\n virtual void OnProgress(const Progress& progress) = 0;\n\n /** Handle query execution logs provided by server.\n * Amount of logs regulated by `send_logs_level` setting.\n * By-default only `fatal` log events are sent to the client side.\n */\n virtual void OnServerLog(const Block& block) = 0;\n\n /// Handle query execution profile events.\n virtual void OnProfileEvents(const Block& block) = 0;\n\n virtual void OnFinish() = 0;\n};\n\n\nusing ExceptionCallback = std::function;\nusing ProgressCallback = std::function;\nusing SelectCallback = std::function;\nusing SelectCancelableCallback = std::function;\nusing SelectServerLogCallback = std::function;\nusing ProfileEventsCallback = std::function;\nusing ProfileCallbak = std::function;\n\n\nclass Query : public QueryEvents {\npublic:\n Query();\n Query(const char* query, const char* query_id = nullptr);\n Query(const std::string& query, const std::string& query_id = default_query_id);\n ~Query() override;\n\n ///\n inline const std::string& GetText() const {\n return query_;\n }\n\n inline const std::string& GetQueryID() const {\n return query_id_;\n }\n\n inline const QuerySettings& GetQuerySettings() const {\n return query_settings_;\n }\n\n /// Set per query settings\n inline Query& SetQuerySettings(QuerySettings query_settings) {\n query_settings_ = std::move(query_settings);\n return *this;\n }\n\n /// Set per query setting\n inline Query& SetSetting(const std::string& key, const 
QuerySettingsField& value) {\n query_settings_[key] = value;\n return *this;\n }\n\n inline const std::optional& GetTracingContext() const {\n return tracing_context_;\n }\n\n /// Set tracing context for open telemetry signals\n inline Query& SetTracingContext(open_telemetry::TracingContext tracing_context) {\n tracing_context_ = std::move(tracing_context);\n return *this;\n }\n\n /// Set handler for receiving result data.\n inline Query& OnData(SelectCallback cb) {\n select_cb_ = std::move(cb);\n return *this;\n }\n\n inline Query& OnDataCancelable(SelectCancelableCallback cb) {\n select_cancelable_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving server's exception.\n inline Query& OnException(ExceptionCallback cb) {\n exception_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving a progress of query execution.\n inline Query& OnProgress(ProgressCallback cb) {\n progress_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving a server log of query exceution.\n inline Query& OnServerLog(SelectServerLogCallback cb) {\n select_server_log_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving profile events.\n inline Query& OnProfileEvents(ProfileEventsCallback cb) {\n profile_events_callback_cb_ = std::move(cb);\n return *this;\n }\n\n inline Query& OnProfile(ProfileCallbak cb) {\n profile_callback_cb_ = std::move(cb);\n return *this;\n }\n\n static const std::string default_query_id;\n\nprivate:\n void OnData(const Block& block) override {\n if (select_cb_) {\n select_cb_(block);\n }\n }\n\n bool OnDataCancelable(const Block& block) override {\n if (select_cancelable_cb_) {\n return select_cancelable_cb_(block);\n } else {\n return true;\n }\n }\n\n void OnServerException(const Exception& e) override {\n if (exception_cb_) {\n exception_cb_(e);\n }\n }\n\n void OnProfile(const Profile& profile) override {\n if (profile_callback_cb_)\n profile_callback_cb_(profile);\n }\n\n void OnProgress(const Progress& progress) override {\n if (progress_cb_) {\n progress_cb_(progress);\n }\n }\n\n void OnServerLog(const Block& block) override {\n if (select_server_log_cb_) {\n select_server_log_cb_(block);\n }\n }\n\n void OnProfileEvents(const Block& block) override {\n if (profile_events_callback_cb_) {\n profile_events_callback_cb_(block);\n }\n }\n\n void OnFinish() override {\n }\n\nprivate:\n const std::string query_;\n const std::string query_id_;\n std::optional tracing_context_;\n QuerySettings query_settings_;\n ExceptionCallback exception_cb_;\n ProgressCallback progress_cb_;\n SelectCallback select_cb_;\n SelectCancelableCallback select_cancelable_cb_;\n SelectServerLogCallback select_server_log_cb_;\n ProfileEventsCallback profile_events_callback_cb_;\n ProfileCallbak profile_callback_cb_;\n};\n\n}\n\n// Path: clickhouse/columns/numeric.h\n#pragma once\n\n#include \"column.h\"\n#include \"absl/numeric/int128.h\"\n\nnamespace clickhouse {\n\n/**\n * Represents various numeric columns.\n */\ntemplate \nclass ColumnVector : public Column {\npublic:\n using DataType = T;\n using ValueType = T;\n\n ColumnVector();\n\n explicit ColumnVector(const std::vector& data);\n explicit ColumnVector(std::vector && data);\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends one element to the end of column.\n void Append(const T& value);\n\n /// Returns element at given row number.\n const T& At(size_t n) const;\n\n /// Returns element at given row 
number.\n inline const T& operator [] (size_t n) const { return At(n); }\n\n void Erase(size_t pos, size_t count = 1);\n\n /// Get Raw Vector Contents\n std::vector& GetWritableData();\n\n /// Returns the capacity of the column\n size_t Capacity() const;\n\npublic:\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t index) const override;\n\nprivate:\n std::vector data_;\n};\n\nusing Int128 = absl::int128;\nusing Int64 = int64_t;\n\nusing ColumnUInt8 = ColumnVector;\nusing ColumnUInt16 = ColumnVector;\nusing ColumnUInt32 = ColumnVector;\nusing ColumnUInt64 = ColumnVector;\n\nusing ColumnInt8 = ColumnVector;\nusing ColumnInt16 = ColumnVector;\nusing ColumnInt32 = ColumnVector;\nusing ColumnInt64 = ColumnVector;\nusing ColumnInt128 = ColumnVector;\n\nusing ColumnFloat32 = ColumnVector;\nusing ColumnFloat64 = ColumnVector;\n\n}\n\n// Path: clickhouse/columns/utils.h\n#pragma once\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nstd::vector SliceVector(const std::vector& vec, size_t begin, size_t len) {\n std::vector result;\n\n if (begin < vec.size()) {\n len = std::min(len, vec.size() - begin);\n result.assign(vec.begin() + begin, vec.begin() + (begin + len));\n }\n\n return result;\n}\n\ntemplate \nstruct HasWrapMethod {\nprivate:\n static int detect(...);\n template \n static decltype(U::Wrap(std::move(std::declval()))) detect(const U&);\n\npublic:\n static constexpr bool value = !std::is_same()))>::value;\n};\n\ntemplate \ninline std::shared_ptr WrapColumn(ColumnRef&& column) {\n if constexpr (HasWrapMethod::value) {\n return T::Wrap(std::move(column));\n } else {\n return column->template AsStrict();\n }\n}\n\n}\n\n// Path: clickhouse/columns/array.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n#include \"utils.h\"\n\n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnArrayT;\n\n/**\n * Represents column of Array(T).\n */\nclass ColumnArray : public Column {\npublic:\n using ValueType = ColumnRef;\n\n /** Create an array of given type.\n *\n * `data` is used internally (and modified) by ColumnArray.\n * Users are strongly advised against supplying non-empty columns and/or modifying\n * contents of `data` afterwards.\n */\n explicit ColumnArray(ColumnRef data);\n\n /** Create an array of given type, with actual values and offsets.\n *\n * Both `data` and `offsets` are used (and modified) internally bye ColumnArray.\n * Users are strongly advised against modifying contents of `data` or `offsets` afterwards.\n */\n ColumnArray(ColumnRef data, std::shared_ptr offsets);\n\n /// Converts input column to array and appends as one row to the current column.\n void AppendAsColumn(ColumnRef array);\n\n /// Converts array at pos n to column.\n /// Type of element of result column same as type of array element.\n ColumnRef GetAsColumn(size_t n) const;\n\n /// Shorthand to get a column casted to a proper type.\n template \n auto GetAsColumnTyped(size_t n) const {\n 
return GetAsColumn(n)->AsStrict();\n }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column&) override;\n\n void OffsetsIncrease(size_t);\n\nprotected:\n template friend class ColumnArrayT;\n\n ColumnArray(ColumnArray&& array);\n\n size_t GetOffset(size_t n) const;\n size_t GetSize(size_t n) const;\n ColumnRef GetData();\n void AddOffset(size_t n);\n void Reset();\n\nprivate:\n ColumnRef data_;\n std::shared_ptr offsets_;\n};\n\ntemplate \nclass ColumnArrayT : public ColumnArray {\npublic:\n class ArrayValueView;\n using ValueType = ArrayValueView;\n using NestedColumnType = ColumnType;\n\n explicit ColumnArrayT(std::shared_ptr data)\n : ColumnArray(data)\n , typed_nested_data_(data)\n {}\n\n ColumnArrayT(std::shared_ptr data, std::shared_ptr offsets)\n : ColumnArray(data, offsets)\n , typed_nested_data_(data)\n {}\n\n template \n explicit ColumnArrayT(Args &&... args)\n : ColumnArrayT(std::make_shared(std::forward(args)...))\n {}\n\n /** Create a ColumnArrayT from a ColumnArray, without copying data and offsets, but by 'stealing' those from `col`.\n *\n * Ownership of column internals is transferred to returned object, original (argument) object\n * MUST NOT BE USED IN ANY WAY, it is only safe to dispose it.\n *\n * Throws an exception if `col` is of wrong type, it is safe to use original col in this case.\n * This is a static method to make such conversion verbose.\n */\n static auto Wrap(ColumnArray&& col) {\n auto nested_data = WrapColumn(col.GetData());\n return std::make_shared>(nested_data, col.offsets_);\n }\n\n static auto Wrap(Column&& col) {\n return Wrap(std::move(dynamic_cast(col)));\n }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) {\n return Wrap(std::move(*col->AsStrict()));\n }\n\n /// A single (row) value of the Array-column, i.e. 
readonly array of items.\n class ArrayValueView {\n const std::shared_ptr typed_nested_data_;\n const size_t offset_;\n const size_t size_;\n\n public:\n using ValueType = std::decay_t().At(0))>;\n\n ArrayValueView(std::shared_ptr data, size_t offset = 0, size_t size = std::numeric_limits::max())\n : typed_nested_data_(data)\n , offset_(offset)\n , size_(std::min(typed_nested_data_->Size() - offset, size))\n {}\n\n inline auto operator[](size_t index) const {\n return (*typed_nested_data_)[offset_ + index];\n }\n\n inline auto At(size_t index) const {\n if (index >= size_)\n throw ValidationError(\"ColumnArray value index out of bounds: \"\n + std::to_string(index) + \", max is \" + std::to_string(size_));\n return typed_nested_data_->At(offset_ + index);\n }\n\n class Iterator {\n const std::shared_ptr typed_nested_data_;\n const size_t offset_;\n const size_t size_;\n size_t index_;\n public:\n Iterator() = default;\n\n Iterator(std::shared_ptr typed_nested_data, size_t offset, size_t size, size_t index)\n : typed_nested_data_(typed_nested_data)\n , offset_(offset)\n , size_(size)\n , index_(index)\n {}\n\n using ValueType = typename ArrayValueView::ValueType;\n\n inline auto operator*() const {\n return typed_nested_data_->At(offset_ + index_);\n }\n\n inline Iterator& operator++() {\n ++index_;\n return *this;\n }\n\n inline bool operator==(const Iterator& other) const {\n return this->typed_nested_data_ == other.typed_nested_data_\n && this->offset_ == other.offset_\n && this->size_ == other.size_\n && this->index_ == other.index_;\n }\n\n inline bool operator!=(const Iterator& other) const {\n return !(*this == other);\n }\n };\n\n // minimalistic stl-like container interface, hence the lowercase\n inline Iterator begin() const {\n return Iterator{typed_nested_data_, offset_, size_, 0};\n }\n\n inline Iterator cbegin() const {\n return Iterator{typed_nested_data_, offset_, size_, 0};\n }\n\n inline Iterator end() const {\n return Iterator{typed_nested_data_, offset_, size_, size_};\n }\n\n inline Iterator cend() const {\n return Iterator{typed_nested_data_, offset_, size_, size_};\n }\n\n inline size_t size() const {\n return size_;\n }\n\n // It is ugly to have both size() and Size(), but it is for compatitability with both STL and rest of the clickhouse-cpp.\n inline size_t Size() const {\n return size_;\n }\n\n inline bool operator==(const ArrayValueView& other) const {\n if (size() != other.size()) {\n return false;\n }\n for (size_t i = 0; i < size_; ++i) {\n if ((*this)[i] != other[i]) {\n return false;\n }\n }\n return true;\n }\n\n inline bool operator!=(const ArrayValueView& other) const {\n return !(*this == other);\n }\n };\n\n inline auto At(size_t index) const {\n if (index >= Size())\n throw ValidationError(\"ColumnArray row index out of bounds: \"\n + std::to_string(index) + \", max is \" + std::to_string(Size()));\n\n return ArrayValueView{typed_nested_data_, GetOffset(index), GetSize(index)};\n }\n\n inline auto operator[](size_t index) const {\n return ArrayValueView{typed_nested_data_, GetOffset(index), GetSize(index)};\n }\n\n using ColumnArray::Append;\n\n template \n inline void Append(Container&& container) {\n using container_type = decltype(container);\n if constexpr (std::is_lvalue_reference_v || \n std::is_const_v>) {\n Append(std::begin(container), std::end(container));\n }\n else {\n Append(std::make_move_iterator(std::begin(container)),\n std::make_move_iterator(std::end(container)));\n }\n }\n\n template \n inline void Append(const 
std::initializer_list& container) {\n Append(std::begin(container), std::end(container));\n }\n\n template \n inline void Append(Begin begin, End end) {\n auto & nested_data = *typed_nested_data_;\n size_t counter = 0;\n\n while (begin != end) {\n nested_data.Append(*begin);\n ++begin;\n ++counter;\n }\n\n // Even if there are 0 items, increase counter, creating empty array item.\n AddOffset(counter);\n }\n\n ColumnRef Slice(size_t begin, size_t size) const override {\n return Wrap(ColumnArray::Slice(begin, size));\n }\n\n ColumnRef CloneEmpty() const override {\n return Wrap(ColumnArray::CloneEmpty());\n }\n\n void Swap(Column& other) override {\n auto & col = dynamic_cast &>(other);\n typed_nested_data_.swap(col.typed_nested_data_);\n ColumnArray::Swap(other);\n }\n\nprivate:\n /// Helper to allow wrapping a \"typeless\" ColumnArray\n ColumnArrayT(ColumnArray&& array, std::shared_ptr nested_data)\n : ColumnArray(std::move(array))\n , typed_nested_data_(std::move(nested_data))\n {}\n\n\nprivate:\n std::shared_ptr typed_nested_data_;\n};\n\n}\n\n// Path: clickhouse/columns/decimal.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n\nnamespace clickhouse {\n\n/**\n * Represents a column of decimal type.\n */\nclass ColumnDecimal : public Column {\npublic:\n using ValueType = Int128;\n\n ColumnDecimal(size_t precision, size_t scale);\n\n void Append(const Int128& value);\n void Append(const std::string& value);\n\n Int128 At(size_t i) const;\n inline auto operator[](size_t i) const { return At(i); }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n void Append(ColumnRef column) override;\n bool LoadBody(InputStream* input, size_t rows) override;\n void SaveBody(OutputStream* output) override;\n void Clear() override;\n size_t Size() const override;\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n ItemView GetItem(size_t index) const override;\n\n size_t GetScale() const;\n size_t GetPrecision() const;\n\nprivate:\n /// Depending on a precision it can be one of:\n /// - ColumnInt32\n /// - ColumnInt64\n /// - ColumnInt128\n ColumnRef data_;\n\n explicit ColumnDecimal(TypeRef type, ColumnRef data);\n};\n\n}\n\n// Path: clickhouse/columns/date.h\n#pragma once\n\n#include \"decimal.h\"\n#include \"numeric.h\"\n\n#include \n\nnamespace clickhouse {\n\n/** */\nclass ColumnDate : public Column {\npublic:\n using ValueType = std::time_t;\n\n ColumnDate();\n explicit ColumnDate(std::vector&& data);\n\n /// Appends one element to the end of column.\n /// The implementation is fundamentally wrong, ignores timezones, leap years and daylight saving.\n void Append(const std::time_t& value);\n...\n", "needle_token_start": 3074, "needle_token_end": 3113, "code_context_ntokens": 13439, "output": ["```cpp\nclass ServerException : public Error {\npublic:\n ServerException(std::unique_ptr e)\n : Error(std::string())\n , exception_(std::move(e))\n {\n }\n\n int GetCode() const {\n return exception_->code;\n }\n\n const Exception& GetException() const {\n return *exception_;\n }\n\n const char* what() const noexcept override {\n return exception_->display_text.c_str();\n }\n\nprivate:\n std::unique_ptr exception_;\n};\nusing ServerError = ServerException;\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "AppendMany", "language": "cpp", "path": "clickhouse/columns/lowcardinality.h", "position_ratio": 0.45, "description": "\nFunction 
Description:\n1. **Purpose**: The function is designed to add multiple elements to a low cardinality column in a batch, enhancing efficiency when processing collections of data.\n2. **Input**: Accepts a container (such as a list or vector) containing elements of a specific type that matches the column's data type.\n3. **Output**: There is no direct output returned by the function; however, it modifies the state of the column by appending each element from the input container.\n4. **Procedure**: Iterates over each element in the provided container, and for each element, it utilizes another function to safely append the element to the column, handling type-specific considerations such as nullability.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/columns/tuple.h\n#pragma once\n\n#include \"column.h\"\n#include \"utils.h\"\n\n#include \n\nnamespace clickhouse {\n\n/**\n * Represents column of Tuple([T]).\n */\nclass ColumnTuple : public Column {\npublic:\n ColumnTuple(const std::vector& columns);\n\n /// Returns count of columns in the tuple.\n size_t TupleSize() const;\n\n inline ColumnRef operator [] (size_t n) const {\n return columns_[n];\n }\n\n inline ColumnRef At(size_t n) const {\n return columns_[n];\n }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\nprivate:\n std::vector columns_;\n};\n\ntemplate \nclass ColumnTupleT : public ColumnTuple {\npublic:\n using TupleOfColumns = std::tuple...>;\n\n using ValueType = std::tuple().At(0))>...>;\n\n ColumnTupleT(std::tuple...> columns)\n : ColumnTuple(TupleToVector(columns)), typed_columns_(std::move(columns)) {}\n\n ColumnTupleT(std::vector columns)\n : ColumnTuple(columns), typed_columns_(VectorToTuple(std::move(columns))) {}\n\n ColumnTupleT(const std::initializer_list columns)\n : ColumnTuple(columns), typed_columns_(VectorToTuple(std::move(columns))) {}\n\n inline ValueType At(size_t index) const { return GetTupleOfValues(index); }\n\n inline ValueType operator[](size_t index) const { return GetTupleOfValues(index); }\n\n using ColumnTuple::Append;\n\n template \n inline void Append(std::tuple value) {\n AppendTuple(std::move(value));\n }\n\n /** Create a ColumnTupleT from a ColumnTuple, without copying data and offsets, but by\n * 'stealing' those from `col`.\n *\n * Ownership of column internals is transferred to returned object, original (argument) object\n * MUST NOT BE USED IN ANY WAY, it is only safe to dispose it.\n *\n * 
Throws an exception if `col` is of wrong type, it is safe to use original col in this case.\n * This is a static method to make such conversion verbose.\n */\n static auto Wrap(ColumnTuple&& col) {\n if (col.TupleSize() != std::tuple_size_v) {\n throw ValidationError(\"Can't wrap from \" + col.GetType().GetName());\n }\n return std::make_shared>(VectorToTuple(std::move(col)));\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n...\n// Path: clickhouse/columns/geo.h\n#pragma once\n\n#include \"array.h\"\n#include \"column.h\"\n#include \"numeric.h\"\n#include \"tuple.h\"\n\nnamespace clickhouse {\n\ntemplate \nclass ColumnGeo : public Column {\npublic:\n using ValueType = typename NestedColumnType::ValueType;\n\n ColumnGeo();\n\n explicit ColumnGeo(ColumnRef data);\n\n /// Appends one element to the end of column.\n template \n void Append(const T& value) {\n data_->Append(value);\n }\n\n /// Returns element at given row number.\n const ValueType At(size_t n) const;\n\n /// Returns element at given row number.\n inline const ValueType operator[](size_t n) const { return At(n); }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\nprivate:\n std::shared_ptr data_;\n};\n\n// /**\n// * Represents a Point column.\n// */\nusing ColumnPoint = ColumnGeo, Type::Code::Point>;\n\n/**\n * Represents a Ring column.\n */\nusing ColumnRing = ColumnGeo, Type::Code::Ring>;\n\n/**\n * Represents a Polygon column.\n */\nusing ColumnPolygon = ColumnGeo, Type::Code::Polygon>;\n\n/**\n * Represents a MultiPolygon column.\n */\nusing ColumnMultiPolygon = ColumnGeo, Type::Code::MultiPolygon>;\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/ip4.h\n#pragma once\n\n#include \"numeric.h\"\n\nstruct in_addr;\n\nnamespace clickhouse {\n\nclass ColumnIPv4 : public Column {\npublic:\n using DataType = in_addr;\n using ValueType = in_addr;\n\n ColumnIPv4();\n /** Takes ownership of the data, expects ColumnUInt32.\n * Modifying memory pointed by `data` from outside is UB.\n *\n * TODO: deprecate and remove as it is too dangerous and error-prone.\n */\n explicit ColumnIPv4(ColumnRef data);\n\n explicit ColumnIPv4(std::vector&& data);\n\n /// Appends one element to the column.\n void Append(const std::string& ip);\n\n /// @params ip numeric value with host byte order.\n void Append(uint32_t ip);\n\n ///\n void Append(in_addr ip);\n\n /// Returns element at given row number.\n in_addr At(size_t n) const;\n\n /// Returns element at given row number.\n in_addr operator [] (size_t n) const;\n\n std::string AsString(size_t n) const;\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n 
bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t index) const override;\n\nprivate:\n std::shared_ptr data_;\n};\n\n}\n\n// Path: clickhouse/columns/string.h\n#pragma once\n\n#include \"column.h\"\n\n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\n/**\n * Represents column of fixed-length strings.\n */\nclass ColumnFixedString : public Column {\npublic:\n using ValueType = std::string_view;\n\n explicit ColumnFixedString(size_t n);\n\n template \n ColumnFixedString(size_t n, const Values & values)\n : ColumnFixedString(n)\n {\n for (const auto & v : values)\n Append(v);\n }\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t) override;\n\n /// Appends one element to the column.\n void Append(std::string_view str);\n\n /// Returns element at given row number.\n std::string_view At(size_t n) const;\n\n /// Returns element at given row number.\n inline std::string_view operator [] (size_t n) const { return At(n); }\n\n /// Returns the max size of the fixed string\n size_t FixedSize() const;\n\npublic:\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t) const override;\n\nprivate:\n size_t string_size_;\n std::string data_;\n};\n\n/**\n * Represents column of variable-length strings.\n */\nclass ColumnString : public Column {\npublic:\n // Type this column takes as argument of Append and returns with At() and operator[]\n using ValueType = std::string_view;\n\n ColumnString();\n ~ColumnString();\n\n explicit ColumnString(size_t element_count);\n explicit ColumnString(const std::vector & data);\n explicit ColumnString(std::vector&& data);\n ColumnString& operator=(const ColumnString&) = delete;\n ColumnString(const ColumnString&) = delete;\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends one element to the column.\n void Append(std::string_view str);\n\n /// Appends one element to the column.\n void Append(const char* str);\n\n /// Appends one element to the column.\n void Append(std::string&& steal_value);\n\n /// Appends one element to the column.\n /// If str lifetime is managed elsewhere and guaranteed to outlive the Block sent to the server\n void AppendNoManagedLifetime(std::string_view str);\n\n /// Returns element at given row number.\n std::string_view At(size_t n) const;\n\n /// Returns element at given row number.\n inline std::string_view operator [] (size_t n) const { return At(n); }\n\npublic:\n /// Appends content of 
given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n ItemView GetItem(size_t) const override;\n\nprivate:\n void AppendUnsafe(std::string_view);\n\nprivate:\n struct Block;\n\n std::vector items_;\n std::vector blocks_;\n std::deque append_data_;\n};\n\n}\n\n// Path: clickhouse/columns/ip6.h\n#pragma once\n\n#include \"string.h\"\n#include \n\nstruct in6_addr;\n\nnamespace clickhouse {\n\nclass ColumnIPv6 : public Column {\npublic:\n using DataType = in6_addr;\n using ValueType = in6_addr;\n\n ColumnIPv6();\n /** Takes ownership of the data, expects ColumnFixedString.\n * Modifying memory pointed by `data` from outside is UB.\n *\n * TODO: deprecate and remove as it is too dangerous and error-prone.\n */\n explicit ColumnIPv6(ColumnRef data);\n\n /// Appends one element to the column.\n void Append(const std::string_view& str);\n\n void Append(const in6_addr* addr);\n void Append(const in6_addr& addr);\n\n /// Returns element at given row number.\n in6_addr At(size_t n) const;\n\n /// Returns element at given row number.\n in6_addr operator [] (size_t n) const;\n\n std::string AsString(size_t n) const;\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n ItemView GetItem(size_t index) const override;\n\nprivate:\n std::shared_ptr data_;\n};\n\n}\n\n// Path: clickhouse/columns/nullable.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n\n#include \n\nnamespace clickhouse {\n\n/**\n * Represents column of Nullable(T).\n */\nclass ColumnNullable : public Column {\npublic:\n ColumnNullable(ColumnRef nested, ColumnRef nulls);\n\n /// Appends one null flag to the end of the column\n void Append(bool isnull);\n\n /// Returns null flag at given row number.\n bool IsNull(size_t n) const;\n\n /// Returns nested column.\n ColumnRef Nested() const;\n\n /// Returns nulls column.\n ColumnRef Nulls() const;\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n 
void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column&) override;\n\n ItemView GetItem(size_t) const override;\n\nprivate:\n ColumnRef nested_;\n std::shared_ptr nulls_;\n};\n\ntemplate \nclass ColumnNullableT : public ColumnNullable {\npublic:\n using NestedColumnType = ColumnType;\n using ValueType = std::optional().At(0))>>;\n\n ColumnNullableT(std::shared_ptr data, std::shared_ptr nulls)\n : ColumnNullable(data, nulls)\n , typed_nested_data_(data)\n {}\n\n explicit ColumnNullableT(std::shared_ptr data)\n : ColumnNullableT(data, FillNulls(data->Size()))\n {}\n\n template \n explicit ColumnNullableT(Args &&... args)\n : ColumnNullableT(std::make_shared(std::forward(args)...))\n {}\n\n inline ValueType At(size_t index) const {\n return IsNull(index) ? ValueType{} : ValueType{typed_nested_data_->At(index)};\n }\n\n inline ValueType operator[](size_t index) const { return At(index); }\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override {\n ColumnNullable::Append(std::move(column));\n }\n\n inline void Append(ValueType value) {\n ColumnNullable::Append(!value.has_value());\n if (value.has_value()) {\n typed_nested_data_->Append(std::move(*value));\n } else {\n typed_nested_data_->Append(typename ValueType::value_type{});\n }\n }\n\n /** Create a ColumnNullableT from a ColumnNullable, without copying data and offsets, but by\n * 'stealing' those from `col`.\n *\n * Ownership of column internals is transferred to returned object, original (argument) object\n * MUST NOT BE USED IN ANY WAY, it is only safe to dispose it.\n *\n * Throws an exception if `col` is of wrong type, it is safe to use original col in this case.\n * This is a static method to make such conversion verbose.\n */\n static auto Wrap(ColumnNullable&& col) {\n return std::make_shared>(\n col.Nested()->AsStrict(),\n col.Nulls()->AsStrict()) ;\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) { return Wrap(std::move(*col->AsStrict())); }\n\n ColumnRef Slice(size_t begin, size_t size) const override {\n return Wrap(ColumnNullable::Slice(begin, size));\n }\n\n ColumnRef CloneEmpty() const override { return Wrap(ColumnNullable::CloneEmpty()); }\n\n void Swap(Column& other) override {\n auto& col = dynamic_cast&>(other);\n typed_nested_data_.swap(col.typed_nested_data_);\n ColumnNullable::Swap(other);\n }\n\nprivate:\n static inline auto FillNulls(size_t n){\n auto result = std::make_shared();\n for (size_t i = 0; i < n; ++i) {\n result->Append(0);\n }\n return result;\n }\n\n std::shared_ptr typed_nested_data_;\n};\n\ntemplate \nconstexpr bool IsNullable = std::is_base_of_v;\n\n}\n\n// Path: clickhouse/columns/lowcardinality.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n#include \"nullable.h\"\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnLowCardinalityT;\n\nnamespace details {\n\n/** LowCardinalityHashKey used as key in unique items hashmap to abstract away key value\n * (type of which depends on dictionary 
column) and to reduce likelehood of collisions.\n *\n * In order to dramatically reduce collision rate, we use 2 different hashes from 2 different hash functions.\n * First hash is used in hashtable (to calculate item position).\n * Second one is used as part of key value and accessed via `operator==()` upon collision resolution/detection.\n */\nusing LowCardinalityHashKey = std::pair;\n\nstruct LowCardinalityHashKeyHash {\n inline std::size_t operator()(const LowCardinalityHashKey &hash_key) const noexcept {\n return hash_key.first;\n }\n};\n\n}\n\n/*\n * LC column contains an \"invisible\" default item at the beginning of the collection. [default, ...]\n * If the nested type is Nullable, it contains a null-item at the beginning and a default item at the second position. [null, default, ...]\n * Null map is not serialized in LC columns. Instead, nulls are tracked by having an index of 0.\n * */\nclass ColumnLowCardinality : public Column {\npublic:\n using UniqueItems = std::unordered_map;\n\n template \n friend class ColumnLowCardinalityT;\n\nprivate:\n // IMPLEMENTATION NOTE: ColumnLowCardinalityT takes reference to underlying dictionary column object,\n // so make sure to NOT change address of the dictionary object (with reset(), swap()) or with anything else.\n ColumnRef dictionary_column_;\n ColumnRef index_column_;\n UniqueItems unique_items_map_;\n\npublic:\n ColumnLowCardinality(ColumnLowCardinality&& col) = default;\n // c-tor makes a deep copy of the dictionary_column.\n explicit ColumnLowCardinality(ColumnRef dictionary_column);\n explicit ColumnLowCardinality(std::shared_ptr dictionary_column);\n\n template \n explicit ColumnLowCardinality(std::shared_ptr> dictionary_column)\n : ColumnLowCardinality(dictionary_column->template As())\n {}\n\n ~ColumnLowCardinality();\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends another LowCardinality column to the end of this one, updating dictionary.\n void Append(ColumnRef /*column*/) override;\n\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data.\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of current column, with compacted dictionary\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n ItemView GetItem(size_t index) const override;\n\n size_t GetDictionarySize() const;\n TypeRef GetNestedType() const;\n\nprotected:\n std::uint64_t getDictionaryIndex(std::uint64_t item_index) const;\n void appendIndex(std::uint64_t item_index);\n void removeLastIndex();\n ColumnRef GetDictionary();\n\n void AppendUnsafe(const ItemView &);\n\nprivate:\n void Setup(ColumnRef dictionary_column);\n void AppendNullItem();\n void AppendDefaultItem();\n\npublic:\n static details::LowCardinalityHashKey computeHashKey(const ItemView &);\n};\n\n/** Type-aware wrapper that provides simple convenience interface for accessing/appending individual items.\n */\ntemplate \nclass ColumnLowCardinalityT : public ColumnLowCardinality {\n\n DictionaryColumnType& typed_dictionary_;\n const Type::Code 
type_;\n\npublic:\n using WrappedColumnType = DictionaryColumnType;\n // Type this column takes as argument of Append and returns with At() and operator[]\n using ValueType = typename DictionaryColumnType::ValueType;\n\n explicit ColumnLowCardinalityT(ColumnLowCardinality&& col)\n : ColumnLowCardinality(std::move(col))\n , typed_dictionary_(dynamic_cast(*GetDictionary()))\n , type_(GetTypeCode(typed_dictionary_))\n {\n }\n\n template \n explicit ColumnLowCardinalityT(Args &&... args)\n : ColumnLowCardinalityT(std::make_shared(std::forward(args)...))\n {}\n\n // Create LC column from existing T-column, making a deep copy of all contents.\n explicit ColumnLowCardinalityT(std::shared_ptr dictionary_col)\n : ColumnLowCardinality(dictionary_col)\n , typed_dictionary_(dynamic_cast(*GetDictionary()))\n , type_(GetTypeCode(typed_dictionary_))\n {}\n\n /// Extended interface to simplify reading/adding individual items.\n\n /// Returns element at given row number.\n inline ValueType At(size_t n) const {\n return typed_dictionary_.At(getDictionaryIndex(n));\n }\n\n /// Returns element at given row number.\n inline ValueType operator [] (size_t n) const {\n return typed_dictionary_[getDictionaryIndex(n)];\n }\n\n // so the non-virtual Append below doesn't shadow Append() from base class when compiled with older compilers.\n using ColumnLowCardinality::Append;\n\n inline void Append(const ValueType & value) {\n if constexpr (IsNullable) {\n if (value.has_value()) {\n AppendUnsafe(ItemView{type_, *value});\n } else {\n AppendUnsafe(ItemView{});\n }\n } else {\n AppendUnsafe(ItemView{type_, value});\n }\n }\n\n template \n \ninline void AppendMany(const T& container) {\n for (const auto & item : container) {\n Append(item);\n }\n }\n\n /** Create a ColumnLowCardinalityT from a ColumnLowCardinality, without copying data and offsets, but by\n * 'stealing' those from `col`.\n *\n * Ownership of column internals is transferred to returned object, original (argument) object\n * MUST NOT BE USED IN ANY WAY, it is only safe to dispose it.\n *\n * Throws an exception if `col` is of wrong type, it is safe to use original col in this case.\n * This is a static method to make such conversion verbose.\n */\n static auto Wrap(ColumnLowCardinality&& col) {\n return std::make_shared>(std::move(col));\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) { return Wrap(std::move(*col->AsStrict())); }\n\n ColumnRef Slice(size_t begin, size_t size) const override {\n return Wrap(ColumnLowCardinality::Slice(begin, size));\n }\n\n ColumnRef CloneEmpty() const override { return Wrap(ColumnLowCardinality::CloneEmpty()); }\n\nprivate:\n\n template \n static auto GetTypeCode(T& column) {\n if constexpr (IsNullable) {\n return GetTypeCode(*column.Nested()->template AsStrict());\n } else {\n return column.Type()->GetCode();\n }\n }\n};\n\n}\n\n// Path: clickhouse/base/projected_iterator.h\n#pragma once\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate ()(std::declval())),\n typename Value = std::decay_t>\nclass ProjectedIterator {\npublic:\n using value_type = Value;\n using reference = Reference;\n using pointer = Reference;\n using difference_type = typename std::iterator_traits::difference_type;\n using iterator_category = typename std::iterator_traits::iterator_category;\n\n ProjectedIterator() = default;\n\n inline ProjectedIterator(Iterator const& iterator, UnaryFunction functor)\n 
: iterator_(iterator)\n , functor_(std::move(functor)) {\n }\n\n inline UnaryFunction functor() const { return functor; }\n\n inline Iterator const& base() const { return iterator_; }\n\n inline reference operator*() const { return functor_(iterator_); }\n\n inline ProjectedIterator& operator++() {\n ++iterator_;\n return *this;\n }\n\n inline ProjectedIterator& operator--() {\n --iterator_;\n return *this;\n }\n\n inline bool operator==(const ProjectedIterator& other) const {\n return this->iterator_ == other.iterator_;\n }\n\n inline bool operator!=(const ProjectedIterator& other) const {\n return !(*this == other);\n }\n\nprivate:\n Iterator iterator_;\n UnaryFunction functor_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/map.h\n#pragma once\n\n#include \"../base/projected_iterator.h\"\n#include \"array.h\"\n#include \"column.h\"\n#include \"tuple.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnMapT;\n\n/**\n * Represents column of Map(K, V).\n */\nclass ColumnMap : public Column {\npublic:\n /** Create a map of given type, with actual values and offsets.\n *\n * Both `data` and `offsets` are used (and modified) internally bye ColumnArray.\n * Users are strongly advised against modifying contents of `data` or `offsets` afterwards.\n */\n explicit ColumnMap(ColumnRef data);\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column&) override;\n\n /// Converts map at pos n to column.\n /// Type of row is tuple {key, value}.\n ColumnRef GetAsColumn(size_t n) const;\n\nprotected:\n template \n friend class ColumnMapT;\n\n ColumnMap(ColumnMap&& map);\n\nprivate:\n std::shared_ptr data_;\n};\n\ntemplate \nclass ColumnMapT : public ColumnMap {\npublic:\n using KeyColumnType = K;\n using ValueColumnType = V;\n using Key = std::decay_t().At(0))>;\n using Value = std::decay_t().At(0))>;\n using TupleColumnType = ColumnTupleT;\n using ArrayColumnType = ColumnArrayT;\n\n ColumnMapT(ColumnRef data)\n : ColumnMap(data), typed_data_(data->AsStrict>()) {}\n\n ColumnMapT(std::shared_ptr keys, std::shared_ptr values)\n : ColumnMap(std::make_shared(std::make_shared(\n std::make_tuple(std::move(keys), std::move(values))))),\n typed_data_(data_->template As()) {}\n\n ColumnRef Slice(size_t begin, size_t len) const override {\n return std::make_shared>(typed_data_->Slice(begin, len));\n }\n\n ColumnRef CloneEmpty() const override {\n return std::make_shared>(typed_data_->CloneEmpty());\n }\n\n void Swap(Column& other) override {\n auto& col = dynamic_cast&>(other);\n col.typed_data_.swap(typed_data_);\n ColumnMap::Swap(other);\n }\n\n /// A single (row) value of the Map-column i.e. 
read-only map.\n /// It has a linear time complexity to access items\n /// Because data base type has same structure\n /// \"This lookup works now with a linear complexity.\"\n /// https://clickhouse.com/docs/en/sql-reference/data-types/map\n /// Convert it to a suitable container required to access more than one element\n\n class MapValueView {\n const typename ArrayColumnType::ArrayValueView data_;\n\n public:\n using ValueType = std::pair;\n\n MapValueView(typename ArrayColumnType::ArrayValueView data) : data_(std::move(data)) {}\n\n inline auto operator[](const Key& key) const { return (*Find(key)).second; }\n\n inline auto At(const Key& key) const {\n auto it = Find(key);\n if (it == end()) throw ValidationError(\"ColumnMap value key not found\");\n return (*it).second;\n }\n\n class Iterator {\n typename ArrayColumnType::ArrayValueView::Iterator data_iterator_;\n\n public:\n Iterator() = default;\n\n Iterator(typename ArrayColumnType::ArrayValueView::Iterator data_iterator)\n : data_iterator_(data_iterator) {}\n\n using ValueType = std::pair;\n using difference_type = size_t;\n using value_type = ValueType;\n using pointer = void;\n using reference = ValueType&;\n using iterator_category = std::forward_iterator_tag;\n\n inline auto operator*() const {\n auto tuple = *data_iterator_;\n return ValueType{std::get<0>(tuple), std::get<1>(tuple)};\n }\n\n inline Iterator& operator++() {\n ++data_iterator_;\n return *this;\n }\n\n inline bool operator==(const Iterator& other) const {\n return this->data_iterator_ == other.data_iterator_;\n }\n\n inline bool operator!=(const Iterator& other) const { return !(*this == other); }\n };\n\n // minimalistic stl-like container interface, hence the lowercase\n inline Iterator begin() const { return Iterator{data_.begin()}; }\n\n inline Iterator cbegin() const { return Iterator{data_.cbegin()}; }\n\n inline Iterator end() const { return Iterator{data_.end()}; }\n\n inline Iterator cend() const { return Iterator{data_.cend()}; }\n\n inline size_t size() const { return data_.size(); }\n\n // It is ugly to have both size() and Size(), but it is for compatitability with both STL\n // and rest of the clickhouse-cpp.\n inline size_t Size() const { return data_.Size(); }\n\n inline size_t Count(const Key& key) const {\n size_t result = 0;\n for (auto item : data_) {\n if (std::get<0>(item) == key) {\n ++result;\n }\n }\n return result;\n }\n\n inline Iterator Find(const Key& key) const {\n for (auto it = data_.begin(); it != data_.end(); ++it) {\n if (std::get<0>(*it) == key) {\n return Iterator{it};\n }\n }\n return end();\n }\n\n inline bool operator==(const MapValueView& other) const {\n if (size() != other.size()) {\n return false;\n }\n const auto make_index = [](const auto& data) {\n std::vector result{data.Size()};\n std::generate(result.begin(), result.end(), [i = 0] () mutable {return i++;});\n std::sort(result.begin(), result.end(), [&data](size_t l, size_t r) {return data[l] < data[r];});\n return result;\n };\n const auto index = make_index(data_);\n for (const auto& val : other.data_) {\n if (!std::binary_search(index.begin(), index.end(), val,\n [&data = data_](const auto& l, size_t r) {return l < data[r];})) {\n return false;\n }\n }\n return true;\n }\n\n inline bool operator!=(const MapValueView& other) const { return !(*this == other); }\n };\n\n inline auto At(size_t index) const { return MapValueView{typed_data_->At(index)}; }\n\n inline auto operator[](size_t index) const { return At(index); }\n\n using ColumnMap::Append;\n\n inline 
void Append(const MapValueView& value) { typed_data_->Append(value.data_); }\n\n inline void Append(const std::vector>& tuples) {\n typed_data_->Append(tuples.begin(), tuples.end());\n }\n\n template \n inline void Append(const T& value) {\n using BaseIter = decltype(value.begin());\n using KeyOfT = decltype(std::declval()->first);\n using ValOfT = decltype(std::declval()->second);\n using Functor = std::function(const BaseIter&)>;\n using Iterator = ProjectedIterator;\n\n Functor functor = [](const BaseIter& i) {\n return std::make_tuple(std::cref(i->first), std::cref(i->second));\n };\n\n typed_data_->Append(Iterator{value.begin(), functor}, Iterator{value.end(), functor});\n }\n\n static auto Wrap(ColumnMap&& col) {\n auto data = ArrayColumnType::Wrap(std::move(col.data_));\n return std::make_shared>(std::move(data));\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) { return Wrap(std::move(*col->AsStrict())); }\n\nprivate:\n std::shared_ptr typed_data_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/uuid.h\n#pragma once\n\n#include \"../base/uuid.h\"\n#include \"column.h\"\n#include \"numeric.h\"\n\nnamespace clickhouse {\n\n\n/**\n * Represents a UUID column.\n */\nclass ColumnUUID : public Column {\npublic:\n ColumnUUID();\n\n explicit ColumnUUID(ColumnRef data);\n\n /// Appends one element to the end of column.\n void Append(const UUID& value);\n\n /// Returns element at given row number.\n const UUID At(size_t n) const;\n\n /// Returns element at given row number.\n inline const UUID operator [] (size_t n) const { return At(n); }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t) const override;\n\nprivate:\n std::shared_ptr data_;\n};\n\n}\n\n// Path: clickhouse/client.h\n#pragma once\n\n#include \"query.h\"\n#include \"exceptions.h\"\n\n#include \"columns/array.h\"\n#include \"columns/date.h\"\n#include \"columns/decimal.h\"\n#include \"columns/enum.h\"\n#include \"columns/geo.h\"\n#include \"columns/ip4.h\"\n#include \"columns/ip6.h\"\n#include \"columns/lowcardinality.h\"\n#include \"columns/nullable.h\"\n#include \"columns/numeric.h\"\n#include \"columns/map.h\"\n#include \"columns/string.h\"\n#include \"columns/tuple.h\"\n#include \"columns/uuid.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\n\nnamespace clickhouse {\n\nstruct ServerInfo {\n std::string name;\n std::string timezone;\n std::string display_name;\n uint64_t version_major;\n uint64_t version_minor;\n uint64_t version_patch;\n uint64_t revision;\n};\n\n/// Methods of block compression.\nenum class CompressionMethod {\n None = -1,\n LZ4 = 1,\n};\n\nstruct Endpoint {\n std::string host;\n uint16_t port = 9000;\n 
inline bool operator==(const Endpoint& right) const {\n return host == right.host && port == right.port;\n }\n};\n\nenum class EndpointsIterationAlgorithm {\n RoundRobin = 0,\n};\n\nstruct ClientOptions {\n // Setter goes first, so it is possible to apply 'deprecated' annotation safely.\n#define DECLARE_FIELD(name, type, setter, default_value) \\\n inline auto & setter(const type& value) { \\\n name = value; \\\n return *this; \\\n } \\\n type name = default_value\n\n /// Hostname of the server.\n DECLARE_FIELD(host, std::string, SetHost, std::string());\n /// Service port.\n DECLARE_FIELD(port, uint16_t, SetPort, 9000);\n\n /** Set endpoints (host+port), only one is used.\n * Client tries to connect to those endpoints one by one, on the round-robin basis:\n * first default enpoint (set via SetHost() + SetPort()), then each of endpoints, from begin() to end(),\n * the first one to establish connection is used for the rest of the session.\n * If port isn't specified, default(9000) value will be used.\n */\n DECLARE_FIELD(endpoints, std::vector, SetEndpoints, {});\n\n /// Default database.\n DECLARE_FIELD(default_database, std::string, SetDefaultDatabase, \"default\");\n /// User name.\n DECLARE_FIELD(user, std::string, SetUser, \"default\");\n /// Access password.\n DECLARE_FIELD(password, std::string, SetPassword, std::string());\n\n /// By default all exceptions received during query execution will be\n /// passed to OnException handler. Set rethrow_exceptions to true to\n /// enable throwing exceptions with standard c++ exception mechanism.\n DECLARE_FIELD(rethrow_exceptions, bool, SetRethrowException, true);\n\n /// Ping server every time before execute any query.\n DECLARE_FIELD(ping_before_query, bool, SetPingBeforeQuery, false);\n /// Count of retry to send request to server.\n DECLARE_FIELD(send_retries, unsigned int, SetSendRetries, 1);\n /// Amount of time to wait before next retry.\n DECLARE_FIELD(retry_timeout, std::chrono::seconds, SetRetryTimeout, std::chrono::seconds(5));\n\n /// Compression method.\n DECLARE_FIELD(compression_method, CompressionMethod, SetCompressionMethod, CompressionMethod::None);\n\n /// TCP Keep alive options\n DECLARE_FIELD(tcp_keepalive, bool, TcpKeepAlive, false);\n DECLARE_FIELD(tcp_keepalive_idle, std::chrono::seconds, SetTcpKeepAliveIdle, std::chrono::seconds(60));\n DECLARE_FIELD(tcp_keepalive_intvl, std::chrono::seconds, SetTcpKeepAliveInterval, std::chrono::seconds(5));\n DECLARE_FIELD(tcp_keepalive_cnt, unsigned int, SetTcpKeepAliveCount, 3);\n\n // TCP options\n DECLARE_FIELD(tcp_nodelay, bool, TcpNoDelay, true);\n\n /// Connection socket connect timeout. If the timeout is negative then the connect operation will never timeout.\n DECLARE_FIELD(connection_connect_timeout, std::chrono::milliseconds, SetConnectionConnectTimeout, std::chrono::seconds(5));\n\n /// Connection socket timeout. 
If the timeout is set to zero then the operation will never timeout.\n DECLARE_FIELD(connection_recv_timeout, std::chrono::milliseconds, SetConnectionRecvTimeout, std::chrono::milliseconds(0));\n DECLARE_FIELD(connection_send_timeout, std::chrono::milliseconds, SetConnectionSendTimeout, std::chrono::milliseconds(0));\n\n /** It helps to ease migration of the old codebases, which can't afford to switch\n * to using ColumnLowCardinalityT or ColumnLowCardinality directly,\n * but still want to benefit from smaller on-wire LowCardinality bandwidth footprint.\n *\n * @see LowCardinalitySerializationAdaptor, CreateColumnByType\n */\n [[deprecated(\"Makes implementation of LC(X) harder and code uglier. Will be removed in next major release (3.0) \")]]\n DECLARE_FIELD(backward_compatibility_lowcardinality_as_wrapped_column, bool, SetBakcwardCompatibilityFeatureLowCardinalityAsWrappedColumn, false);\n\n /** Set max size data to compress if compression enabled.\n *\n * Allows choosing tradeoff between RAM\\CPU:\n * - Lower value reduces RAM usage, but slightly increases CPU usage.\n * - Higher value increases RAM usage but slightly decreases CPU usage.\n */\n DECLARE_FIELD(max_compression_chunk_size, unsigned int, SetMaxCompressionChunkSize, 65535);\n\n struct SSLOptions {\n /** There are two ways to configure an SSL connection:\n * - provide a pre-configured SSL_CTX, which is not modified and not owned by the Client.\n * - provide a set of options and allow the Client to create and configure SSL_CTX by itself.\n */\n\n /** Pre-configured SSL-context for SSL-connection.\n * If NOT null client DONES NOT take ownership of context and it must be valid for client lifetime.\n * If null client initlaizes OpenSSL and creates his own context, initializes it using\n * other options, like path_to_ca_files, path_to_ca_directory, use_default_ca_locations, etc.\n *\n * Either way context is used to create an SSL-connection, which is then configured with\n * whatever was provided as `configuration`, `host_flags`, `skip_verification` and `use_sni`.\n */\n SSL_CTX * ssl_context = nullptr;\n auto & SetExternalSSLContext(SSL_CTX * new_ssl_context) {\n ssl_context = new_ssl_context;\n return *this;\n }\n\n /** Means to validate the server-supplied certificate against trusted Certificate Authority (CA).\n * If no CAs are configured, the server's identity can't be validated, and the Client would err.\n * See https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_default_verify_paths.html\n */\n /// Load default CA certificates from default locations.\n DECLARE_FIELD(use_default_ca_locations, bool, SetUseDefaultCALocations, true);\n /// Path to the CA files to verify server certificate, may be empty.\n DECLARE_FIELD(path_to_ca_files, std::vector, SetPathToCAFiles, {});\n /// Path to the directory with CA files used to validate server certificate, may be empty.\n DECLARE_FIELD(path_to_ca_directory, std::string, SetPathToCADirectory, \"\");\n\n /** Min and max protocol versions to use, set with SSL_CTX_set_min_proto_version and SSL_CTX_set_max_proto_version\n * for details see https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_min_proto_version.html\n */\n DECLARE_FIELD(min_protocol_version, int, SetMinProtocolVersion, DEFAULT_VALUE);\n DECLARE_FIELD(max_protocol_version, int, SetMaxProtocolVersion, DEFAULT_VALUE);\n\n /** Options to be set with SSL_CTX_set_options,\n * for details see https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_options.html\n */\n DECLARE_FIELD(context_options, int, SetContextOptions, 
DEFAULT_VALUE);\n\n /** Use SNI at ClientHello\n */\n DECLARE_FIELD(use_sni, bool, SetUseSNI, true);\n\n /** Skip SSL session verification (server's certificate, etc).\n *\n * WARNING: settig to true will bypass all SSL session checks, which\n * is dangerous, but can be used against self-signed certificates, e.g. for testing purposes.\n */\n DECLARE_FIELD(skip_verification, bool, SetSkipVerification, false);\n\n /** Mode of verifying host ssl certificate against name of the host, set with SSL_set_hostflags.\n * For details see https://www.openssl.org/docs/man1.1.1/man3/SSL_set_hostflags.html\n */\n DECLARE_FIELD(host_flags, int, SetHostVerifyFlags, DEFAULT_VALUE);\n\n struct CommandAndValue {\n std::string command;\n std::optional value = std::nullopt;\n };\n /** Extra configuration options, set with SSL_CONF_cmd.\n * For deatils see https://www.openssl.org/docs/man1.1.1/man3/SSL_CONF_cmd.html\n *\n * Takes multiple pairs of command-value strings, all commands are supported,\n * and prefix is empty.\n * i.e. pass `sigalgs` or `SignatureAlgorithms` instead of `-sigalgs`.\n *\n * Rewrites any other options/flags if set in other ways.\n */\n DECLARE_FIELD(configuration, std::vector, SetConfiguration, {});\n\n static const int DEFAULT_VALUE = -1;\n };\n\n // By default SSL is turned off.\n std::optional ssl_options = std::nullopt;\n\n // Will throw an exception if client was built without SSL support.\n ClientOptions& SetSSLOptions(SSLOptions options);\n\n#undef DECLARE_FIELD\n};\n\nstd::ostream& operator<<(std::ostream& os, const ClientOptions& options);\nstd::ostream& operator<<(std::ostream& os, const Endpoint& options);\n\nclass SocketFactory;\n\n/**\n *\n */\nclass Client {\npublic:\n Client(const ClientOptions& opts);\n Client(const ClientOptions& opts,\n std::unique_ptr socket_factory);\n ~Client();\n\n /// Intends for execute arbitrary queries.\n void Execute(const Query& query);\n\n /// Intends for execute select queries. Data will be returned with\n /// one or more call of \\p cb.\n void Select(const std::string& query, SelectCallback cb);\n void Select(const std::string& query, const std::string& query_id, SelectCallback cb);\n\n /// Executes a select query which can be canceled by returning false from\n /// the data handler function \\p cb.\n void SelectCancelable(const std::string& query, SelectCancelableCallback cb);\n void SelectCancelable(const std::string& query, const std::string& query_id, SelectCancelableCallback cb);\n\n /// Alias for Execute.\n void Select(const Query& query);\n\n /// Intends for insert block of data into a table \\p table_name.\n void Insert(const std::string& table_name, const Block& block);\n void Insert(const std::string& table_name, const std::string& query_id, const Block& block);\n\n /// Ping server for aliveness.\n void Ping();\n\n /// Reset connection with initial params.\n void ResetConnection();\n\n const ServerInfo& GetServerInfo() const;\n\n /// Get current connected endpoint.\n /// In case when client is not connected to any endpoint, nullopt will returned.\n const std::optional& GetCurrentEndpoint() const;\n\n // Try to connect to different endpoints one by one only one time. 
If it doesn't work, throw an exception.\n void ResetConnectionEndpoint();\n\n struct Version\n {\n uint16_t major;\n uint16_t minor;\n uint16_t patch;\n uint16_t build;\n const char * extra;\n };\n\n static Version GetVersion();\n\nprivate:\n const ClientOptions options_;\n\n class Impl;\n std::unique_ptr impl_;\n};\n\n}\n\n// Path: clickhouse/protocol.h\n#pragma once\n\nnamespace clickhouse {\n\n /// Types of packets received from server\n namespace ServerCodes {\n enum {\n Hello = 0, /// Name, version, revision.\n Data = 1, /// `Block` of data, may be compressed.\n Exception = 2, /// Exception that occurred on server side during query execution.\n Progress = 3, /// Query execcution progress: rows and bytes read.\n Pong = 4, /// response to Ping sent by client.\n EndOfStream = 5, /// All packets were sent.\n ProfileInfo = 6, /// Profiling data\n Totals = 7, /// Block of totals, may be compressed.\n Extremes = 8, /// Block of mins and maxs, may be compressed.\n TablesStatusResponse = 9, /// Response to TableStatus.\n Log = 10, /// Query execution log.\n TableColumns = 11, /// Columns' description for default values calculation\n PartUUIDs = 12, /// List of unique parts ids.\n ReadTaskRequest = 13, /// String (UUID) describes a request for which next task is needed\n /// This is such an inverted logic, where server sends requests\n /// And client returns back response\n ProfileEvents = 14, /// Packet with profile events from server.\n };\n }\n\n /// Types of packets sent by client.\n namespace ClientCodes {\n enum {\n Hello = 0, /// Name, version, default database name.\n Query = 1, /** Query id, query settings, query processing stage,\n * compression status, and query text (no INSERT data).\n */\n Data = 2, /// Data `Block` (e.g. INSERT data), may be compressed.\n Cancel = 3, /// Cancel query.\n Ping = 4, /// Check server connection.\n };\n }\n\n /// Should we compress `Block`s of data\n namespace CompressionState {\n enum {\n Disable = 0,\n Enable = 1,\n };\n }\n\n namespace Stages {\n enum {\n Complete = 2,\n };\n }\n}\n\n// Path: clickhouse/base/input.h\n#pragma once\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream {\npublic:\n virtual ~InputStream() noexcept (false)\n { }\n\n /// Reads one byte from the stream.\n inline bool ReadByte(uint8_t* byte) {\n return DoRead(byte, sizeof(uint8_t)) == sizeof(uint8_t);\n }\n\n /// Reads some data from the stream.\n inline size_t Read(void* buf, size_t len) {\n return DoRead(buf, len);\n }\n\n // Skips a number of bytes. 
Returns false if an underlying read error occurs.\n virtual bool Skip(size_t bytes) = 0;\n\nprotected:\n virtual size_t DoRead(void* buf, size_t len) = 0;\n};\n\n\nclass ZeroCopyInput : public InputStream {\npublic:\n inline size_t Next(const void** buf, size_t len) {\n return DoNext(buf, len);\n }\n\n bool Skip(size_t bytes) override;\n\nprotected:\n virtual size_t DoNext(const void** ptr, size_t len) = 0;\n\n size_t DoRead(void* buf, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyInput stream backed by an in-memory array of bytes.\n */\nclass ArrayInput : public ZeroCopyInput {\npublic:\n ArrayInput() noexcept;\n ArrayInput(const void* buf, size_t len) noexcept;\n ~ArrayInput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return len_;\n }\n\n /// Current read position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return data_;\n }\n\n /// Whether there is more data in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n inline void Reset(const void* buf, size_t len) noexcept {\n data_ = static_cast(buf);\n len_ = len;\n }\n\nprivate:\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n const uint8_t* data_;\n size_t len_;\n};\n\n\nclass BufferedInput : public ZeroCopyInput {\npublic:\n BufferedInput(std::unique_ptr source, size_t buflen = 8192);\n ~BufferedInput() override;\n\n void Reset();\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n std::unique_ptr const source_;\n ArrayInput array_input_;\n std::vector buffer_;\n};\n\n}\n\n// Path: clickhouse/base/buffer.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nusing Buffer = std::vector;\n\n}\n\n// Path: clickhouse/base/output.h\n#pragma once\n\n...\n", "needle_token_start": 7112, "needle_token_end": 7147, "code_context_ntokens": 15933, "output": ["```cpp\ntemplate \ninline void AppendMany(const T& container) {\n for (const auto & item : container) {\n Append(item);\n }\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "GetEndpointsIterator", "language": "cpp", "path": "clickhouse/client.cpp", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To generate an iterator that cycles through a list of network connection points in a round-robin fashion.\n2. **Input**: A configuration object containing a list of network endpoints.\n3. **Output**: A unique pointer to an iterator base class, specifically initialized to iterate over the provided endpoints in a round-robin manner.\n4. **Procedure**: The function first checks if the list of endpoints in the input configuration is empty. If it is, an exception is thrown indicating the absence of endpoints. 
If not empty, the function creates and returns a unique pointer to a round-robin iterator initialized with the list of endpoints.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/client.h\n#pragma once\n\n#include \"query.h\"\n#include \"exceptions.h\"\n\n#include \"columns/array.h\"\n#include \"columns/date.h\"\n#include \"columns/decimal.h\"\n#include \"columns/enum.h\"\n#include \"columns/geo.h\"\n#include \"columns/ip4.h\"\n#include \"columns/ip6.h\"\n#include \"columns/lowcardinality.h\"\n#include \"columns/nullable.h\"\n#include \"columns/numeric.h\"\n#include \"columns/map.h\"\n#include \"columns/string.h\"\n#include \"columns/tuple.h\"\n#include \"columns/uuid.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\n\nnamespace clickhouse {\n\nstruct ServerInfo {\n std::string name;\n std::string timezone;\n std::string display_name;\n uint64_t version_major;\n uint64_t version_minor;\n uint64_t version_patch;\n uint64_t revision;\n};\n\n/// Methods of block compression.\nenum class CompressionMethod {\n None = -1,\n LZ4 = 1,\n};\n\nstruct Endpoint {\n std::string host;\n uint16_t port = 9000;\n inline bool operator==(const Endpoint& right) const {\n return host == right.host && port == right.port;\n }\n};\n\nenum class EndpointsIterationAlgorithm {\n RoundRobin = 0,\n};\n\nstruct ClientOptions {\n // Setter goes first, so it is possible to apply 'deprecated' annotation safely.\n#define DECLARE_FIELD(name, type, setter, default_value) \\\n inline auto & setter(const type& value) { \\\n name = value; \\\n return *this; \\\n } \\\n type name = default_value\n\n /// Hostname of the server.\n DECLARE_FIELD(host, std::string, SetHost, std::string());\n /// Service port.\n DECLARE_FIELD(port, uint16_t, SetPort, 9000);\n\n /** Set endpoints (host+port), only one is used.\n * Client tries to connect to those endpoints one by one, on the round-robin basis:\n * first default enpoint (set via SetHost() + SetPort()), then each of endpoints, from begin() to end(),\n * the first one to establish connection is used for the rest of the session.\n * If port isn't specified, default(9000) value will be used.\n */\n DECLARE_FIELD(endpoints, std::vector, SetEndpoints, {});\n\n /// Default database.\n DECLARE_FIELD(default_database, std::string, SetDefaultDatabase, \"default\");\n /// User name.\n DECLARE_FIELD(user, std::string, SetUser, \"default\");\n /// Access password.\n DECLARE_FIELD(password, std::string, SetPassword, std::string());\n\n /// By default all exceptions received during query execution will be\n /// passed to OnException handler. 
Set rethrow_exceptions to true to\n /// enable throwing exceptions with standard c++ exception mechanism.\n...\n// Path: clickhouse/protocol.h\n#pragma once\n\nnamespace clickhouse {\n\n /// Types of packets received from server\n namespace ServerCodes {\n enum {\n Hello = 0, /// Name, version, revision.\n Data = 1, /// `Block` of data, may be compressed.\n Exception = 2, /// Exception that occurred on server side during query execution.\n Progress = 3, /// Query execcution progress: rows and bytes read.\n Pong = 4, /// response to Ping sent by client.\n EndOfStream = 5, /// All packets were sent.\n ProfileInfo = 6, /// Profiling data\n Totals = 7, /// Block of totals, may be compressed.\n Extremes = 8, /// Block of mins and maxs, may be compressed.\n TablesStatusResponse = 9, /// Response to TableStatus.\n Log = 10, /// Query execution log.\n TableColumns = 11, /// Columns' description for default values calculation\n PartUUIDs = 12, /// List of unique parts ids.\n ReadTaskRequest = 13, /// String (UUID) describes a request for which next task is needed\n /// This is such an inverted logic, where server sends requests\n /// And client returns back response\n ProfileEvents = 14, /// Packet with profile events from server.\n };\n }\n\n /// Types of packets sent by client.\n namespace ClientCodes {\n enum {\n Hello = 0, /// Name, version, default database name.\n Query = 1, /** Query id, query settings, query processing stage,\n * compression status, and query text (no INSERT data).\n */\n Data = 2, /// Data `Block` (e.g. INSERT data), may be compressed.\n Cancel = 3, /// Cancel query.\n Ping = 4, /// Check server connection.\n };\n }\n\n /// Should we compress `Block`s of data\n namespace CompressionState {\n enum {\n Disable = 0,\n Enable = 1,\n };\n }\n\n namespace Stages {\n enum {\n Complete = 2,\n };\n }\n}\n\n// Path: clickhouse/base/input.h\n#pragma once\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream {\npublic:\n virtual ~InputStream() noexcept (false)\n { }\n\n /// Reads one byte from the stream.\n inline bool ReadByte(uint8_t* byte) {\n return DoRead(byte, sizeof(uint8_t)) == sizeof(uint8_t);\n }\n\n /// Reads some data from the stream.\n inline size_t Read(void* buf, size_t len) {\n return DoRead(buf, len);\n }\n\n // Skips a number of bytes. 
Returns false if an underlying read error occurs.\n virtual bool Skip(size_t bytes) = 0;\n\nprotected:\n virtual size_t DoRead(void* buf, size_t len) = 0;\n};\n\n\nclass ZeroCopyInput : public InputStream {\npublic:\n inline size_t Next(const void** buf, size_t len) {\n return DoNext(buf, len);\n }\n\n bool Skip(size_t bytes) override;\n\nprotected:\n virtual size_t DoNext(const void** ptr, size_t len) = 0;\n\n size_t DoRead(void* buf, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyInput stream backed by an in-memory array of bytes.\n */\nclass ArrayInput : public ZeroCopyInput {\npublic:\n ArrayInput() noexcept;\n ArrayInput(const void* buf, size_t len) noexcept;\n ~ArrayInput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return len_;\n }\n\n /// Current read position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return data_;\n }\n\n /// Whether there is more data in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n inline void Reset(const void* buf, size_t len) noexcept {\n data_ = static_cast(buf);\n len_ = len;\n }\n\nprivate:\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n const uint8_t* data_;\n size_t len_;\n};\n\n\nclass BufferedInput : public ZeroCopyInput {\npublic:\n BufferedInput(std::unique_ptr source, size_t buflen = 8192);\n ~BufferedInput() override;\n\n void Reset();\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n std::unique_ptr const source_;\n ArrayInput array_input_;\n std::vector buffer_;\n};\n\n}\n\n// Path: clickhouse/base/buffer.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nusing Buffer = std::vector;\n\n}\n\n// Path: clickhouse/base/output.h\n#pragma once\n\n#include \"buffer.h\"\n\n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass OutputStream {\npublic:\n virtual ~OutputStream()\n { }\n\n inline void Flush() {\n DoFlush();\n }\n\n inline size_t Write(const void* data, size_t len) {\n return DoWrite(data, len);\n }\n\nprotected:\n virtual void DoFlush() { }\n\n virtual size_t DoWrite(const void* data, size_t len) = 0;\n};\n\n\nclass ZeroCopyOutput : public OutputStream {\npublic:\n inline size_t Next(void** data, size_t size) {\n return DoNext(data, size);\n }\n\nprotected:\n // Obtains a buffer into which data can be written. 
Any data written\n // into this buffer will eventually (maybe instantly, maybe later on)\n // be written to the output.\n virtual size_t DoNext(void** data, size_t len) = 0;\n\n size_t DoWrite(const void* data, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyOutput stream backed by an in-memory array of bytes.\n */\nclass ArrayOutput : public ZeroCopyOutput {\npublic:\n ArrayOutput(void* buf, size_t len);\n ~ArrayOutput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return end_ - buf_;\n }\n\n /// Current write position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return buf_;\n }\n\n /// Whether there is more space in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n /// Initializes this stream with a new memory block.\n inline void Reset(void* buf, size_t len) noexcept {\n buf_ = static_cast(buf);\n end_ = buf_ + len;\n buffer_size_ = len;\n }\n\n /// Number of bytes written to the buffer.\n inline size_t Size() const noexcept {\n return buffer_size_ - Avail();\n }\n\nprotected:\n size_t DoNext(void** data, size_t len) override;\n\nprivate:\n uint8_t* buf_;\n uint8_t* end_;\n size_t buffer_size_;\n};\n\n\n/**\n * A ZeroCopyOutput stream backed by a vector.\n *\n * Doesn't Flush() in destructor, client must ensure to do it manually at some point.\n */\nclass BufferOutput : public ZeroCopyOutput {\npublic:\n BufferOutput(Buffer* buf);\n ~BufferOutput() override;\n\nprotected:\n size_t DoNext(void** data, size_t len) override;\n\nprivate:\n Buffer* buf_;\n size_t pos_;\n};\n\n/** BufferedOutput writes data to internal buffer first.\n *\n * Any data goes to underlying stream only if internal buffer is full\n * or when client invokes Flush() on this.\n *\n * Doesn't Flush() in destructor, client must ensure to do it manually at some point.\n */\nclass BufferedOutput : public ZeroCopyOutput {\npublic:\n explicit BufferedOutput(std::unique_ptr destination, size_t buflen = 8192);\n ~BufferedOutput() override;\n\n void Reset();\n\nprotected:\n void DoFlush() override;\n size_t DoNext(void** data, size_t len) override;\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n std::unique_ptr const destination_;\n Buffer buffer_;\n ArrayOutput array_output_;\n};\n\ntemplate \nvoid WriteUnaligned(void* buf, const T& value) {\n memcpy(buf, &value, sizeof(value));\n}\n\n}\n\n// Path: clickhouse/base/compressed.h\n#pragma once\n\n#include \"input.h\"\n#include \"output.h\"\n#include \"buffer.h\"\n\nnamespace clickhouse {\n\nclass CompressedInput : public ZeroCopyInput {\npublic:\n explicit CompressedInput(InputStream* input);\n ~CompressedInput() override;\n\nprotected:\n size_t DoNext(const void** ptr, size_t len) override;\n\n bool Decompress();\n\nprivate:\n InputStream* const input_;\n\n Buffer data_;\n ArrayInput mem_;\n};\n\nclass CompressedOutput : public OutputStream {\npublic:\n explicit CompressedOutput(OutputStream * destination, size_t max_compressed_chunk_size = 0);\n ~CompressedOutput() override;\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n void DoFlush() override;\n\nprivate:\n void Compress(const void * data, size_t len);\n void PreallocateCompressBuffer(size_t input_size);\n\nprivate:\n OutputStream * destination_;\n const size_t max_compressed_chunk_size_;\n Buffer compressed_buffer_;\n};\n\n}\n\n// Path: clickhouse/base/platform.h\n#pragma once\n\n#if defined(__linux__)\n# define _linux_\n#elif 
defined(_WIN64)\n# define _win64_\n# define _win32_\n#elif defined(__WIN32__) || defined(_WIN32)\n# define _win32_\n#elif defined(__APPLE__)\n# define _darwin_\n#endif\n\n#if defined(_win32_) || defined(_win64_)\n# define _win_\n# if !defined(_WIN32_WINNT) || (_WIN32_WINNT < 0x0600)\n# undef _WIN32_WINNT\n# define _WIN32_WINNT 0x0600 // The WSAPoll function is defined on Windows Vista and later.\n# endif\n# define WIN32_LEAN_AND_MEAN 1 // don't include too much header automatically\n#endif\n\n#if defined(_linux_) || defined (_darwin_)\n# define _unix_\n#endif\n\n#if defined(_MSC_VER)\n# undef NOMINMAX\n# define NOMINMAX\n# include \n# define ssize_t SSIZE_T\n# define HAVE_SSIZE_T 1\n#endif\n\n// Path: clickhouse/base/endpoints_iterator.h\n#pragma once\n\n#include \"clickhouse/client.h\"\n#include \n\nnamespace clickhouse {\n\nstruct ClientOptions;\n\n/**\n * Base class for iterating through endpoints.\n*/\nclass EndpointsIteratorBase\n{\n public:\n virtual ~EndpointsIteratorBase() = default;\n\n virtual Endpoint Next() = 0;\n};\n\nclass RoundRobinEndpointsIterator : public EndpointsIteratorBase\n{\n public:\n explicit RoundRobinEndpointsIterator(const std::vector& opts);\n Endpoint Next() override;\n\n ~RoundRobinEndpointsIterator() override;\n\n private:\n const std::vector& endpoints;\n size_t current_index;\n};\n\n}\n\n// Path: clickhouse/base/socket.h\n#pragma once\n\n#include \"platform.h\"\n#include \"input.h\"\n#include \"output.h\"\n#include \"endpoints_iterator.h\"\n\n#include \n#include \n#include \n\n#if defined(_win_)\n# include \n# include \n#else\n# include \n# include \n# include \n# include \n\n# if !defined(SOCKET)\n# define SOCKET int\n# endif\n#endif\n\n#include \n#include \n\nstruct addrinfo;\n\nnamespace clickhouse {\n\nstruct ClientOptions;\n\n/** Address of a host to establish connection to.\n *\n */\nclass NetworkAddress {\npublic:\n explicit NetworkAddress(const std::string& host,\n const std::string& port = \"0\");\n ~NetworkAddress();\n\n const struct addrinfo* Info() const;\n const std::string & Host() const;\n\nprivate:\n const std::string host_;\n struct addrinfo* info_;\n};\n\n#if defined(_win_)\n\nclass windowsErrorCategory : public std::error_category {\npublic:\n char const* name() const noexcept override final;\n std::string message(int c) const override final;\n\n static windowsErrorCategory const& category();\n};\n\n#endif\n\n#if defined(_unix_)\n\nclass getaddrinfoErrorCategory : public std::error_category {\npublic:\n char const* name() const noexcept override final;\n std::string message(int c) const override final;\n\n static getaddrinfoErrorCategory const& category();\n};\n\n#endif\n\n\nclass SocketBase {\npublic:\n virtual ~SocketBase();\n\n virtual std::unique_ptr makeInputStream() const = 0;\n virtual std::unique_ptr makeOutputStream() const = 0;\n};\n\n\nclass SocketFactory {\npublic:\n virtual ~SocketFactory();\n\n // TODO: move connection-related options to ConnectionOptions structure.\n\n virtual std::unique_ptr connect(const ClientOptions& opts, const Endpoint& endpoint) = 0;\n\n virtual void sleepFor(const std::chrono::milliseconds& duration);\n};\n\n\nstruct SocketTimeoutParams {\n std::chrono::milliseconds connect_timeout{ 5000 };\n std::chrono::milliseconds recv_timeout{ 0 };\n std::chrono::milliseconds send_timeout{ 0 };\n};\n\nclass Socket : public SocketBase {\npublic:\n Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);\n Socket(const NetworkAddress& addr);\n Socket(Socket&& other) noexcept;\n Socket& 
operator=(Socket&& other) noexcept;\n\n ~Socket() override;\n\n /// @params idle the time (in seconds) the connection needs to remain\n /// idle before TCP starts sending keepalive probes.\n /// @params intvl the time (in seconds) between individual keepalive probes.\n /// @params cnt the maximum number of keepalive probes TCP should send\n /// before dropping the connection.\n void SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept;\n\n /// @params nodelay whether to enable TCP_NODELAY\n void SetTcpNoDelay(bool nodelay) noexcept;\n\n std::unique_ptr makeInputStream() const override;\n std::unique_ptr makeOutputStream() const override;\n\nprotected:\n Socket(const Socket&) = delete;\n Socket& operator = (const Socket&) = delete;\n void Close();\n\n SOCKET handle_;\n};\n\n\nclass NonSecureSocketFactory : public SocketFactory {\npublic:\n ~NonSecureSocketFactory() override;\n\n std::unique_ptr connect(const ClientOptions& opts, const Endpoint& endpoint) override;\n\nprotected:\n virtual std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts);\n\n void setSocketOptions(Socket& socket, const ClientOptions& opts);\n};\n\n\nclass SocketInput : public InputStream {\npublic:\n explicit SocketInput(SOCKET s);\n ~SocketInput();\n\nprotected:\n bool Skip(size_t bytes) override;\n size_t DoRead(void* buf, size_t len) override;\n\nprivate:\n SOCKET s_;\n};\n\nclass SocketOutput : public OutputStream {\npublic:\n explicit SocketOutput(SOCKET s);\n ~SocketOutput();\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n SOCKET s_;\n};\n\nstatic struct NetrworkInitializer {\n NetrworkInitializer();\n} gNetrworkInitializer;\n\n}\n\n// Path: clickhouse/base/wire_format.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream;\nclass OutputStream;\n\nclass WireFormat {\npublic:\n template \n static bool ReadFixed(InputStream& input, T* value);\n static bool ReadString(InputStream& input, std::string* value);\n static bool SkipString(InputStream& input);\n static bool ReadBytes(InputStream& input, void* buf, size_t len);\n static bool ReadUInt64(InputStream& input, uint64_t* value);\n static bool ReadVarint64(InputStream& output, uint64_t* value);\n\n template \n static void WriteFixed(OutputStream& output, const T& value);\n static void WriteBytes(OutputStream& output, const void* buf, size_t len);\n static void WriteString(OutputStream& output, std::string_view value);\n static void WriteUInt64(OutputStream& output, const uint64_t value);\n static void WriteVarint64(OutputStream& output, uint64_t value);\n\nprivate:\n static bool ReadAll(InputStream& input, void* buf, size_t len);\n static void WriteAll(OutputStream& output, const void* buf, size_t len);\n};\n\ntemplate \ninline bool WireFormat::ReadFixed(InputStream& input, T* value) {\n return ReadAll(input, value, sizeof(T));\n}\n\ninline bool WireFormat::ReadString(InputStream& input, std::string* value) {\n uint64_t len = 0;\n if (ReadVarint64(input, &len)) {\n if (len > 0x00FFFFFFULL) {\n return false;\n }\n value->resize((size_t)len);\n return ReadAll(input, value->data(), (size_t)len);\n }\n\n return false;\n}\n\ninline bool WireFormat::ReadBytes(InputStream& input, void* buf, size_t len) {\n return ReadAll(input, buf, len);\n}\n\ninline bool WireFormat::ReadUInt64(InputStream& input, uint64_t* value) {\n return ReadVarint64(input, value);\n}\n\ntemplate \ninline void WireFormat::WriteFixed(OutputStream& output, const T& value) {\n WriteAll(output, &value, 
sizeof(T));\n}\n\ninline void WireFormat::WriteBytes(OutputStream& output, const void* buf, size_t len) {\n WriteAll(output, buf, len);\n}\n\ninline void WireFormat::WriteString(OutputStream& output, std::string_view value) {\n WriteVarint64(output, value.size());\n WriteAll(output, value.data(), value.size());\n}\n\ninline void WireFormat::WriteUInt64(OutputStream& output, const uint64_t value) {\n WriteVarint64(output, value);\n}\n\n}\n\n// Path: clickhouse/columns/factory.h\n#pragma once\n\n#include \"column.h\"\n\nnamespace clickhouse {\n\nstruct CreateColumnByTypeSettings\n{\n bool low_cardinality_as_wrapped_column = false;\n};\n\nColumnRef CreateColumnByType(const std::string& type_name, CreateColumnByTypeSettings settings = {});\n\n}\n\n// Path: clickhouse/base/sslsocket.h\n#pragma once\n\n#include \"socket.h\"\n\n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\ntypedef struct ssl_st SSL;\n\nnamespace clickhouse {\n\nstruct SSLParams\n{\n std::vector path_to_ca_files;\n std::string path_to_ca_directory;\n bool use_default_ca_locations;\n int context_options;\n int min_protocol_version;\n int max_protocol_version;\n bool use_SNI;\n bool skip_verification;\n int host_flags;\n using ConfigurationType = std::vector>>;\n ConfigurationType configuration;\n};\n\nclass SSLContext\n{\npublic:\n explicit SSLContext(SSL_CTX & context);\n explicit SSLContext(const SSLParams & context_params);\n ~SSLContext() = default;\n\n SSLContext(const SSLContext &) = delete;\n SSLContext& operator=(const SSLContext &) = delete;\n SSLContext(SSLContext &&) = delete;\n SSLContext& operator=(SSLContext &) = delete;\n\nprivate:\n friend class SSLSocket;\n SSL_CTX * getContext();\n\nprivate:\n std::unique_ptr context_;\n};\n\nclass SSLSocket : public Socket {\npublic:\n explicit SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params,\n const SSLParams& ssl_params, SSLContext& context);\n\n SSLSocket(SSLSocket &&) = default;\n ~SSLSocket() override = default;\n\n SSLSocket(const SSLSocket & ) = delete;\n SSLSocket& operator=(const SSLSocket & ) = delete;\n\n std::unique_ptr makeInputStream() const override;\n std::unique_ptr makeOutputStream() const override;\n\n static void validateParams(const SSLParams & ssl_params);\nprivate:\n std::unique_ptr ssl_;\n};\n\nclass SSLSocketFactory : public NonSecureSocketFactory {\npublic:\n explicit SSLSocketFactory(const ClientOptions& opts);\n ~SSLSocketFactory() override;\n\nprotected:\n std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts) override;\n\nprivate:\n const SSLParams ssl_params_;\n std::unique_ptr ssl_context_;\n};\n\nclass SSLSocketInput : public InputStream {\npublic:\n explicit SSLSocketInput(SSL *ssl);\n ~SSLSocketInput() = default;\n\n bool Skip(size_t /*bytes*/) override {\n return false;\n }\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n\nprivate:\n // Not owning\n SSL *ssl_;\n};\n\nclass SSLSocketOutput : public OutputStream {\npublic:\n explicit SSLSocketOutput(SSL *ssl);\n ~SSLSocketOutput() = default;\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n // Not owning\n SSL *ssl_;\n};\n\n}\n\n// Path: clickhouse/client.cpp\n#include \"client.h\"\n#include \"clickhouse/version.h\"\n#include \"protocol.h\"\n\n#include \"base/compressed.h\"\n#include \"base/socket.h\"\n#include \"base/wire_format.h\"\n\n#include \"columns/factory.h\"\n\n#include \n#include \n#include \n#include \n\n#if defined(WITH_OPENSSL)\n#include 
\"base/sslsocket.h\"\n#endif\n\n#define DBMS_NAME \"ClickHouse\"\n\n#define DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES 50264\n#define DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS 51554\n#define DBMS_MIN_REVISION_WITH_BLOCK_INFO 51903\n#define DBMS_MIN_REVISION_WITH_CLIENT_INFO 54032\n#define DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE 54058\n#define DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO 54060\n//#define DBMS_MIN_REVISION_WITH_TABLES_STATUS 54226\n#define DBMS_MIN_REVISION_WITH_TIME_ZONE_PARAMETER_IN_DATETIME_DATA_TYPE 54337\n#define DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME 54372\n#define DBMS_MIN_REVISION_WITH_VERSION_PATCH 54401\n#define DBMS_MIN_REVISION_WITH_LOW_CARDINALITY_TYPE 54405\n#define DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA 54410\n#define DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO 54420\n#define DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS 54429\n#define DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET 54441\n#define DBMS_MIN_REVISION_WITH_OPENTELEMETRY 54442\n#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448\n#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449\n#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451\n\n#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS\n\nnamespace clickhouse {\n\nstruct ClientInfo {\n uint8_t iface_type = 1; // TCP\n uint8_t query_kind;\n std::string initial_user;\n std::string initial_query_id;\n std::string quota_key;\n std::string os_user;\n std::string client_hostname;\n std::string client_name;\n std::string initial_address = \"[::ffff:127.0.0.1]:0\";\n uint64_t client_version_major = 0;\n uint64_t client_version_minor = 0;\n uint64_t client_version_patch = 0;\n uint32_t client_revision = 0;\n};\n\nstd::ostream& operator<<(std::ostream& os, const Endpoint& endpoint) {\n return os << endpoint.host << \":\" << endpoint.port;\n}\n\nstd::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {\n os << \"Client(\"\n << \" Endpoints : [\";\n size_t extra_endpoints = 0;\n\n if (!opt.host.empty()) {\n extra_endpoints = 1;\n os << opt.user << '@' << Endpoint{opt.host, opt.port};\n\n if (opt.endpoints.size())\n os << \", \";\n }\n\n for (size_t i = 0; i < opt.endpoints.size(); i++) {\n os << opt.user << '@' << opt.endpoints[i]\n << ((i == opt.endpoints.size() - 1) ? \"\" : \", \");\n }\n\n os << \"] (\" << opt.endpoints.size() + extra_endpoints << \" items )\"\n << \" ping_before_query:\" << opt.ping_before_query\n << \" send_retries:\" << opt.send_retries\n << \" retry_timeout:\" << opt.retry_timeout.count()\n << \" compression_method:\"\n << (opt.compression_method == CompressionMethod::LZ4 ? \"LZ4\" : \"None\");\n#if defined(WITH_OPENSSL)\n if (opt.ssl_options) {\n const auto & ssl_options = *opt.ssl_options;\n os << \" SSL (\"\n << \" ssl_context: \" << (ssl_options.ssl_context ? 
\"provided by user\" : \"created internally\")\n << \" use_default_ca_locations: \" << ssl_options.use_default_ca_locations\n << \" path_to_ca_files: \" << ssl_options.path_to_ca_files.size() << \" items\"\n << \" path_to_ca_directory: \" << ssl_options.path_to_ca_directory\n << \" min_protocol_version: \" << ssl_options.min_protocol_version\n << \" max_protocol_version: \" << ssl_options.max_protocol_version\n << \" context_options: \" << ssl_options.context_options\n << \")\";\n }\n#endif\n os << \")\";\n return os;\n}\n\nClientOptions& ClientOptions::SetSSLOptions(ClientOptions::SSLOptions options)\n{\n#ifdef WITH_OPENSSL\n ssl_options = options;\n return *this;\n#else\n (void)options;\n throw OpenSSLError(\"Library was built with no SSL support\");\n#endif\n}\n\nnamespace {\n\nstd::unique_ptr GetSocketFactory(const ClientOptions& opts) {\n (void)opts;\n#if defined(WITH_OPENSSL)\n if (opts.ssl_options)\n return std::make_unique(opts);\n else\n#endif\n return std::make_unique();\n}\n\n\nstd::unique_ptr GetEndpointsIterator(const ClientOptions& opts) {\n if (opts.endpoints.empty())\n {\n throw ValidationError(\"The list of endpoints is empty\");\n }\n\n return std::make_unique(opts.endpoints);\n}\n\n}\n\nclass Client::Impl {\npublic:\n Impl(const ClientOptions& opts);\n Impl(const ClientOptions& opts,\n std::unique_ptr socket_factory);\n ~Impl();\n\n void ExecuteQuery(Query query);\n\n void SendCancel();\n\n void Insert(const std::string& table_name, const std::string& query_id, const Block& block);\n\n void Ping();\n\n void ResetConnection();\n\n void ResetConnectionEndpoint();\n\n const ServerInfo& GetServerInfo() const;\n\n const std::optional& GetCurrentEndpoint() const;\n\nprivate:\n bool Handshake();\n\n bool ReceivePacket(uint64_t* server_packet = nullptr);\n\n void SendQuery(const Query& query);\n\n void SendData(const Block& block);\n\n bool SendHello();\n\n bool ReadBlock(InputStream& input, Block* block);\n\n bool ReceiveHello();\n\n /// Reads data packet form input stream.\n bool ReceiveData();\n\n /// Reads exception packet form input stream.\n bool ReceiveException(bool rethrow = false);\n\n void WriteBlock(const Block& block, OutputStream& output);\n\n void CreateConnection();\n\n void InitializeStreams(std::unique_ptr&& socket);\n\n inline size_t GetConnectionAttempts() const\n {\n return options_.endpoints.size() * options_.send_retries;\n }\n\nprivate:\n /// In case of network errors tries to reconnect to server and\n /// call fuc several times.\n void RetryGuard(std::function func);\n\n void RetryConnectToTheEndpoint(std::function& func);\n\nprivate:\n class EnsureNull {\n public:\n inline EnsureNull(QueryEvents* ev, QueryEvents** ptr)\n : ptr_(ptr)\n {\n if (ptr_) {\n *ptr_ = ev;\n }\n }\n\n inline ~EnsureNull() {\n if (ptr_) {\n *ptr_ = nullptr;\n }\n }\n\n private:\n QueryEvents** ptr_;\n\n };\n\n\n const ClientOptions options_;\n QueryEvents* events_;\n int compression_ = CompressionState::Disable;\n\n std::unique_ptr socket_factory_;\n\n std::unique_ptr input_;\n std::unique_ptr output_;\n std::unique_ptr socket_;\n std::unique_ptr endpoints_iterator;\n\n std::optional current_endpoint_;\n\n ServerInfo server_info_;\n};\n\nClientOptions modifyClientOptions(ClientOptions opts)\n{\n if (opts.host.empty())\n return opts;\n\n Endpoint default_endpoint({opts.host, opts.port});\n opts.endpoints.emplace(opts.endpoints.begin(), default_endpoint);\n return opts;\n}\n\nClient::Impl::Impl(const ClientOptions& opts)\n : Impl(opts, GetSocketFactory(opts)) 
{}\n\nClient::Impl::Impl(const ClientOptions& opts,\n std::unique_ptr socket_factory)\n : options_(modifyClientOptions(opts))\n , events_(nullptr)\n , socket_factory_(std::move(socket_factory))\n , endpoints_iterator(GetEndpointsIterator(options_))\n{\n CreateConnection();\n\n if (options_.compression_method != CompressionMethod::None) {\n compression_ = CompressionState::Enable;\n }\n}\n\nClient::Impl::~Impl()\n{ }\n\nvoid Client::Impl::ExecuteQuery(Query query) {\n EnsureNull en(static_cast(&query), &events_);\n\n if (options_.ping_before_query) {\n RetryGuard([this]() { Ping(); });\n }\n\n SendQuery(query);\n\n while (ReceivePacket()) {\n ;\n }\n}\n\nstd::string NameToQueryString(const std::string &input)\n{\n std::string output;\n output.reserve(input.size() + 2);\n output += '`';\n\n for (const auto & c : input) {\n if (c == '`') {\n //escape ` with ``\n output.append(\"``\");\n } else {\n output.push_back(c);\n }\n }\n\n output += '`';\n return output;\n}\n\nvoid Client::Impl::Insert(const std::string& table_name, const std::string& query_id, const Block& block) {\n if (options_.ping_before_query) {\n RetryGuard([this]() { Ping(); });\n }\n\n std::stringstream fields_section;\n const auto num_columns = block.GetColumnCount();\n\n for (unsigned int i = 0; i < num_columns; ++i) {\n if (i == num_columns - 1) {\n fields_section << NameToQueryString(block.GetColumnName(i));\n } else {\n fields_section << NameToQueryString(block.GetColumnName(i)) << \",\";\n }\n }\n\n Query query(\"INSERT INTO \" + table_name + \" ( \" + fields_section.str() + \" ) VALUES\", query_id);\n SendQuery(query);\n\n uint64_t server_packet;\n // Receive data packet.\n while (true) {\n bool ret = ReceivePacket(&server_packet);\n\n if (!ret) {\n throw ProtocolError(\"fail to receive data packet\");\n }\n if (server_packet == ServerCodes::Data) {\n break;\n }\n if (server_packet == ServerCodes::Progress) {\n continue;\n }\n }\n\n // Send data.\n SendData(block);\n // Send empty block as marker of\n // end of data.\n SendData(Block());\n\n // Wait for EOS.\n uint64_t eos_packet{0};\n while (ReceivePacket(&eos_packet)) {\n ;\n }\n\n if (eos_packet != ServerCodes::EndOfStream && eos_packet != ServerCodes::Exception\n && eos_packet != ServerCodes::Log && options_.rethrow_exceptions) {\n throw ProtocolError(std::string{\"unexpected packet from server while receiving end of query, expected (expected Exception, EndOfStream or Log, got: \"}\n + (eos_packet ? 
std::to_string(eos_packet) : \"nothing\") + \")\");\n }\n}\n\nvoid Client::Impl::Ping() {\n WireFormat::WriteUInt64(*output_, ClientCodes::Ping);\n output_->Flush();\n\n uint64_t server_packet;\n const bool ret = ReceivePacket(&server_packet);\n\n if (!ret || server_packet != ServerCodes::Pong) {\n throw ProtocolError(\"fail to ping server\");\n }\n}\n\nvoid Client::Impl::ResetConnection() {\n InitializeStreams(socket_factory_->connect(options_, current_endpoint_.value()));\n\n if (!Handshake()) {\n throw ProtocolError(\"fail to connect to \" + options_.host);\n }\n}\n\nvoid Client::Impl::ResetConnectionEndpoint() {\n current_endpoint_.reset();\n for (size_t i = 0; i < options_.endpoints.size();)\n {\n try\n {\n current_endpoint_ = endpoints_iterator->Next();\n ResetConnection();\n return;\n } catch (const std::system_error&) {\n if (++i == options_.endpoints.size())\n {\n current_endpoint_.reset();\n throw;\n }\n }\n }\n}\n\nvoid Client::Impl::CreateConnection() {\n // make sure to try to connect to each endpoint at least once even if `options_.send_retries` is 0\n const size_t max_attempts = (options_.send_retries ? options_.send_retries : 1);\n for (size_t i = 0; i < max_attempts;)\n {\n try\n {\n // Try to connect to each endpoint before throwing exception.\n ResetConnectionEndpoint();\n return;\n } catch (const std::system_error&) {\n if (++i >= max_attempts)\n {\n throw;\n }\n }\n }\n}\n\nconst ServerInfo& Client::Impl::GetServerInfo() const {\n return server_info_;\n}\n\n\nconst std::optional& Client::Impl::GetCurrentEndpoint() const {\n return current_endpoint_;\n}\n\nbool Client::Impl::Handshake() {\n if (!SendHello()) {\n return false;\n }\n if (!ReceiveHello()) {\n return false;\n }\n return true;\n}\n\nbool Client::Impl::ReceivePacket(uint64_t* server_packet) {\n uint64_t packet_type = 0;\n\n if (!WireFormat::ReadVarint64(*input_, &packet_type)) {\n return false;\n }\n if (server_packet) {\n *server_packet = packet_type;\n }\n\n switch (packet_type) {\n case ServerCodes::Data: {\n if (!ReceiveData()) {\n throw ProtocolError(\"can't read data packet from input stream\");\n }\n return true;\n }\n\n case ServerCodes::Exception: {\n ReceiveException();\n return false;\n }\n\n case ServerCodes::ProfileInfo: {\n Profile profile;\n\n if (!WireFormat::ReadUInt64(*input_, &profile.rows)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &profile.blocks)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &profile.bytes)) {\n return false;\n }\n if (!WireFormat::ReadFixed(*input_, &profile.applied_limit)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &profile.rows_before_limit)) {\n return false;\n }\n if (!WireFormat::ReadFixed(*input_, &profile.calculated_rows_before_limit)) {\n return false;\n }\n\n if (events_) {\n events_->OnProfile(profile);\n }\n\n return true;\n }\n\n case ServerCodes::Progress: {\n Progress info;\n\n if (!WireFormat::ReadUInt64(*input_, &info.rows)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &info.bytes)) {\n return false;\n }\n if constexpr(DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS) {\n if (!WireFormat::ReadUInt64(*input_, &info.total_rows)) {\n return false;\n }\n }\n if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)\n {\n if (!WireFormat::ReadUInt64(*input_, &info.written_rows)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &info.written_bytes)) {\n return false;\n }\n }\n\n if (events_) {\n events_->OnProgress(info);\n }\n\n 
return true;\n }\n\n case ServerCodes::Pong: {\n return true;\n }\n\n case ServerCodes::Hello: {\n return true;\n }\n\n case ServerCodes::EndOfStream: {\n if (events_) {\n events_->OnFinish();\n }\n return false;\n }\n\n case ServerCodes::Log: {\n // log tag\n if (!WireFormat::SkipString(*input_)) {\n return false;\n }\n Block block;\n\n // Use uncompressed stream since log blocks usually contain only one row\n if (!ReadBlock(*input_, &block)) {\n return false;\n }\n\n if (events_) {\n events_->OnServerLog(block);\n }\n return true;\n }\n\n case ServerCodes::TableColumns: {\n // external table name\n if (!WireFormat::SkipString(*input_)) {\n return false;\n }\n\n // columns metadata\n if (!WireFormat::SkipString(*input_)) {\n return false;\n }\n return true;\n }\n\n case ServerCodes::ProfileEvents: {\n if (!WireFormat::SkipString(*input_)) {\n return false;\n }\n\n Block block;\n if (!ReadBlock(*input_, &block)) {\n return false;\n }\n\n if (events_) {\n events_->OnProfileEvents(block);\n }\n return true;\n }\n\n default:\n throw UnimplementedError(\"unimplemented \" + std::to_string((int)packet_type));\n break;\n }\n}\n\nbool Client::Impl::ReadBlock(InputStream& input, Block* block) {\n // Additional information about block.\n if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {\n uint64_t num;\n BlockInfo info;\n\n // BlockInfo\n if (!WireFormat::ReadUInt64(input, &num)) {\n return false;\n }\n if (!WireFormat::ReadFixed(input, &info.is_overflows)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(input, &num)) {\n return false;\n }\n if (!WireFormat::ReadFixed(input, &info.bucket_num)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(input, &num)) {\n return false;\n }\n\n block->SetInfo(std::move(info));\n }\n\n uint64_t num_columns = 0;\n uint64_t num_rows = 0;\n\n if (!WireFormat::ReadUInt64(input, &num_columns)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(input, &num_rows)) {\n return false;\n }\n\n CreateColumnByTypeSettings create_column_settings;\n create_column_settings.low_cardinality_as_wrapped_column = options_.backward_compatibility_lowcardinality_as_wrapped_column;\n\n for (size_t i = 0; i < num_columns; ++i) {\n std::string name;\n std::string type;\n if (!WireFormat::ReadString(input, &name)) {\n return false;\n }\n if (!WireFormat::ReadString(input, &type)) {\n return false;\n }\n\n if (ColumnRef col = CreateColumnByType(type, create_column_settings)) {\n if (num_rows && !col->Load(&input, num_rows)) {\n throw ProtocolError(\"can't load column '\" + name + \"' of type \" + type);\n }\n\n block->AppendColumn(name, col);\n } else {\n throw UnimplementedError(std::string(\"unsupported column type: \") + type);\n }\n }\n\n return true;\n}\n\nbool Client::Impl::ReceiveData() {\n Block block;\n\n if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {\n if (!WireFormat::SkipString(*input_)) {\n return false;\n }\n }\n\n if (compression_ == CompressionState::Enable) {\n CompressedInput compressed(input_.get());\n if (!ReadBlock(compressed, &block)) {\n return false;\n }\n } else {\n if (!ReadBlock(*input_, &block)) {\n return false;\n }\n }\n\n if (events_) {\n events_->OnData(block);\n if (!events_->OnDataCancelable(block)) {\n SendCancel();\n }\n }\n\n return true;\n}\n\nbool Client::Impl::ReceiveException(bool rethrow) {\n std::unique_ptr e(new Exception);\n Exception* current = e.get();\n\n bool exception_received = true;\n do {\n bool has_nested = false;\n\n if (!WireFormat::ReadFixed(*input_, ¤t->code)) {\n 
exception_received = false;\n break;\n }\n if (!WireFormat::ReadString(*input_, ¤t->name)) {\n exception_received = false;\n break;\n }\n if (!WireFormat::ReadString(*input_, ¤t->display_text)) {\n exception_received = false;\n break;\n }\n if (!WireFormat::ReadString(*input_, ¤t->stack_trace)) {\n exception_received = false;\n break;\n }\n if (!WireFormat::ReadFixed(*input_, &has_nested)) {\n exception_received = false;\n break;\n }\n\n if (has_nested) {\n current->nested.reset(new Exception);\n current = current->nested.get();\n } else {\n break;\n }\n } while (true);\n\n if (events_) {\n events_->OnServerException(*e);\n }\n\n if (rethrow || options_.rethrow_exceptions) {\n throw ServerError(std::move(e));\n }\n\n return exception_received;\n}\n\nvoid Client::Impl::SendCancel() {\n WireFormat::WriteUInt64(*output_, ClientCodes::Cancel);\n output_->Flush();\n}\n\nvoid Client::Impl::SendQuery(const Query& query) {\n WireFormat::WriteUInt64(*output_, ClientCodes::Query);\n WireFormat::WriteString(*output_, query.GetQueryID());\n\n /// Client info.\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) {\n ClientInfo info;\n\n info.query_kind = 1;\n info.client_name = \"ClickHouse client\";\n info.client_version_major = CLICKHOUSE_CPP_VERSION_MAJOR;\n info.client_version_minor = CLICKHOUSE_CPP_VERSION_MINOR;\n info.client_version_patch = CLICKHOUSE_CPP_VERSION_PATCH;\n info.client_revision = DMBS_PROTOCOL_REVISION;\n\n\n WireFormat::WriteFixed(*output_, info.query_kind);\n WireFormat::WriteString(*output_, info.initial_user);\n WireFormat::WriteString(*output_, info.initial_query_id);\n WireFormat::WriteString(*output_, info.initial_address);\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME) {\n WireFormat::WriteFixed(*output_, 0);\n }\n WireFormat::WriteFixed(*output_, info.iface_type);\n\n WireFormat::WriteString(*output_, info.os_user);\n WireFormat::WriteString(*output_, info.client_hostname);\n WireFormat::WriteString(*output_, info.client_name);\n WireFormat::WriteUInt64(*output_, info.client_version_major);\n WireFormat::WriteUInt64(*output_, info.client_version_minor);\n WireFormat::WriteUInt64(*output_, info.client_revision);\n\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO)\n WireFormat::WriteString(*output_, info.quota_key);\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH)\n WireFormat::WriteUInt64(*output_, 0u);\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH) {\n WireFormat::WriteUInt64(*output_, info.client_version_patch);\n }\n\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_OPENTELEMETRY) {\n if (const auto& tracing_context = query.GetTracingContext()) {\n // Have OpenTelemetry header.\n WireFormat::WriteFixed(*output_, uint8_t(1));\n // No point writing these numbers with variable length, because they\n // are random and will probably require the full length anyway.\n WireFormat::WriteFixed(*output_, tracing_context->trace_id);\n WireFormat::WriteFixed(*output_, tracing_context->span_id);\n WireFormat::WriteString(*output_, tracing_context->tracestate);\n WireFormat::WriteFixed(*output_, tracing_context->trace_flags);\n } else {\n // Don't have OpenTelemetry header.\n WireFormat::WriteFixed(*output_, uint8_t(0));\n }\n } else {\n if (query.GetTracingContext()) {\n // Current implementation works only for server version >= v20.11.2.1-stable\n throw UnimplementedError(std::string(\"Can't send open telemetry tracing context to a server, server version is 
too old\"));\n }\n }\n }\n\n /// Per query settings\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) {\n for(const auto& [name, field] : query.GetQuerySettings()) {\n WireFormat::WriteString(*output_, name);\n WireFormat::WriteVarint64(*output_, field.flags);\n WireFormat::WriteString(*output_, field.value);\n }\n }\n else if (query.GetQuerySettings().size() > 0) {\n // Current implementation works only for server version >= v20.1.2.4-stable, since we do not implement binary settings serialization.\n throw UnimplementedError(std::string(\"Can't send query settings to a server, server version is too old\"));\n }\n // Empty string signals end of serialized settings\n WireFormat::WriteString(*output_, std::string());\n\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET) {\n WireFormat::WriteString(*output_, \"\");\n }\n\n WireFormat::WriteUInt64(*output_, Stages::Complete);\n WireFormat::WriteUInt64(*output_, compression_);\n WireFormat::WriteString(*output_, query.GetText());\n // Send empty block as marker of\n // end of data\n SendData(Block());\n\n output_->Flush();\n}\n\n\nvoid Client::Impl::WriteBlock(const Block& block, OutputStream& output) {\n // Additional information about block.\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {\n WireFormat::WriteUInt64(output, 1);\n WireFormat::WriteFixed(output, block.Info().is_overflows);\n WireFormat::WriteUInt64(output, 2);\n WireFormat::WriteFixed(output, block.Info().bucket_num);\n WireFormat::WriteUInt64(output, 0);\n }\n\n WireFormat::WriteUInt64(output, block.GetColumnCount());\n WireFormat::WriteUInt64(output, block.GetRowCount());\n\n for (Block::Iterator bi(block); bi.IsValid(); bi.Next()) {\n WireFormat::WriteString(output, bi.Name());\n WireFormat::WriteString(output, bi.Type()->GetName());\n\n // Empty columns are not serialized and occupy exactly 0 bytes.\n // ref https://github.com/ClickHouse/ClickHouse/blob/39b37a3240f74f4871c8c1679910e065af6bea19/src/Formats/NativeWriter.cpp#L163\n const bool containsData = block.GetRowCount() > 0;\n if (containsData) {\n bi.Column()->Save(&output);\n }\n }\n output.Flush();\n}\n\nvoid Client::Impl::SendData(const Block& block) {\n WireFormat::WriteUInt64(*output_, ClientCodes::Data);\n\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {\n WireFormat::WriteString(*output_, std::string());\n }\n\n if (compression_ == CompressionState::Enable) {\n assert(options_.compression_method == CompressionMethod::LZ4);\n\n std::unique_ptr compressed_output = std::make_unique(output_.get(), options_.max_compression_chunk_size);\n BufferedOutput buffered(std::move(compressed_output), options_.max_compression_chunk_size);\n\n WriteBlock(block, buffered);\n } else {\n WriteBlock(block, *output_);\n }\n\n output_->Flush();\n}\n\nvoid Client::Impl::InitializeStreams(std::unique_ptr&& socket) {\n std::unique_ptr output = std::make_unique(socket->makeOutputStream());\n std::unique_ptr input = std::make_unique(socket->makeInputStream());\n\n std::swap(input, input_);\n std::swap(output, output_);\n std::swap(socket, socket_);\n}\n\nbool Client::Impl::SendHello() {\n WireFormat::WriteUInt64(*output_, ClientCodes::Hello);\n WireFormat::WriteString(*output_, std::string(DBMS_NAME) + \" client\");\n WireFormat::WriteUInt64(*output_, CLICKHOUSE_CPP_VERSION_MAJOR);\n WireFormat::WriteUInt64(*output_, CLICKHOUSE_CPP_VERSION_MINOR);\n WireFormat::WriteUInt64(*output_, DMBS_PROTOCOL_REVISION);\n 
WireFormat::WriteString(*output_, options_.default_database);\n WireFormat::WriteString(*output_, options_.user);\n WireFormat::WriteString(*output_, options_.password);\n\n output_->Flush();\n\n return true;\n}\n\nbool Client::Impl::ReceiveHello() {\n uint64_t packet_type = 0;\n\n if (!WireFormat::ReadVarint64(*input_, &packet_type)) {\n return false;\n }\n\n if (packet_type == ServerCodes::Hello) {\n if (!WireFormat::ReadString(*input_, &server_info_.name)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &server_info_.version_major)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &server_info_.version_minor)) {\n return false;\n }\n if (!WireFormat::ReadUInt64(*input_, &server_info_.revision)) {\n return false;\n }\n\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE) {\n if (!WireFormat::ReadString(*input_, &server_info_.timezone)) {\n return false;\n }\n }\n\n if (server_info_.revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME) {\n if (!WireFormat::ReadString(*input_, &server_info_.display_name)) {\n return false;\n }\n }\n\n...\n", "needle_token_start": 8724, "needle_token_end": 8804, "code_context_ntokens": 15981, "output": ["```cpp\nstd::unique_ptr GetEndpointsIterator(const ClientOptions& opts) {\n if (opts.endpoints.empty())\n {\n throw ValidationError(\"The list of endpoints is empty\");\n }\n\n return std::make_unique(opts.endpoints);\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "swap", "language": "cpp", "path": "clickhouse/columns/column.h", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to exchange the contents of two instances of a class that represents a column in a database or data structure.\n2. **Input**: Two instances of the column class.\n3. **Output**: There is no return value; however, the internal state of the two column instances are exchanged.\n4. **Procedure**: This function directly swaps the internal data of the two column instances, effectively interchanging their states without the need for a temporary variable to hold the data of one during the swap. 
This is typically implemented to facilitate operations like reordering or algorithmic shuffling where elements need to be swapped efficiently.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/types/types.h\n#pragma once\n\n#include \"absl/numeric/int128.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nusing Int128 = absl::int128;\nusing Int64 = int64_t;\n\nusing TypeRef = std::shared_ptr;\n\nclass Type {\npublic:\n enum Code {\n Void = 0,\n Int8,\n Int16,\n Int32,\n Int64,\n UInt8,\n UInt16,\n UInt32,\n UInt64,\n Float32,\n Float64,\n String,\n FixedString,\n DateTime,\n Date,\n Array,\n Nullable,\n Tuple,\n Enum8,\n Enum16,\n UUID,\n IPv4,\n IPv6,\n Int128,\n Decimal,\n Decimal32,\n Decimal64,\n Decimal128,\n LowCardinality,\n DateTime64,\n Date32,\n Map,\n Point,\n Ring,\n Polygon,\n MultiPolygon\n };\n\n using EnumItem = std::pair;\n\nprotected:\n Type(const Code code);\n\npublic:\n template \n auto* As() {\n return static_cast(this);\n }\n\n template \n const auto* As() const {\n return static_cast(this);\n }\n\n /// Type's code.\n Code GetCode() const { return code_; }\n\n /// String representation of the type.\n std::string GetName() const;\n\n /// Is given type same as current one.\n bool IsEqual(const Type& other) const {\n // Types are equal only if both code_ and type_unique_id_ are equal.\n return this == &other\n // GetTypeUniqueId() is relatively heavy, so avoid calling it when comparing obviously different types.\n || (this->GetCode() == other.GetCode() && this->GetTypeUniqueId() == other.GetTypeUniqueId());\n }\n\n bool IsEqual(const TypeRef& other) const { return IsEqual(*other); }\n\n /// Simple name, doesn't depend on parameters and\\or nested types, caller MUST NOT free returned value.\n static const char* TypeName(Code);\n\npublic:\n static TypeRef CreateArray(TypeRef item_type);\n\n static TypeRef CreateDate();\n\n static TypeRef CreateDate32();\n\n static TypeRef CreateDateTime(std::string timezone = std::string());\n\n static TypeRef CreateDateTime64(size_t precision, std::string timezone = std::string());\n\n static TypeRef CreateDecimal(size_t precision, size_t scale);\n\n static TypeRef CreateIPv4();\n\n static TypeRef CreateIPv6();\n\n static TypeRef CreateNothing();\n\n static TypeRef CreateNullable(TypeRef nested_type);\n\n template \n static TypeRef CreateSimple();\n\n static TypeRef CreateString();\n\n static TypeRef CreateString(size_t n);\n\n static TypeRef CreateTuple(const std::vector& item_types);\n\n static TypeRef CreateEnum8(const std::vector& enum_items);\n\n static TypeRef CreateEnum16(const std::vector& enum_items);\n\n static TypeRef CreateUUID();\n\n static TypeRef CreateLowCardinality(TypeRef item_type);\n\n static TypeRef CreateMap(TypeRef key_type, TypeRef value_type);\n\n static TypeRef CreatePoint();\n\n static TypeRef CreateRing();\n\n static TypeRef CreatePolygon();\n\n static TypeRef CreateMultiPolygon();\n\nprivate:\n uint64_t GetTypeUniqueId() const;\n\n const Code code_;\n mutable std::atomic type_unique_id_;\n};\n\ninline bool operator==(const Type & left, const Type & right) {\n if (&left == &right)\n return true;\n if (typeid(left) == typeid(right))\n return left.IsEqual(right);\n return false;\n}\n\ninline bool operator==(const TypeRef & left, 
const TypeRef & right) {\n return *left == *right;\n}\n\nclass ArrayType : public Type {\npublic:\n explicit ArrayType(TypeRef item_type);\n\n std::string GetName() const { return std::string(\"Array(\") + item_type_->GetName() + \")\"; }\n\n /// Type of array's elements.\n inline TypeRef GetItemType() const { return item_type_; }\n\nprivate:\n TypeRef item_type_;\n};\n\nclass DecimalType : public Type {\npublic:\n DecimalType(size_t precision, size_t scale);\n\n std::string GetName() const;\n friend class EnumType;\n friend class DateTimeType;\n\n inline size_t GetScale() const { return scale_; }\n inline size_t GetPrecision() const { return precision_; }\n\nprivate:\n const size_t precision_, scale_;\n};\n\nnamespace details\n{\nclass TypeWithTimeZoneMixin\n{\npublic:\n TypeWithTimeZoneMixin(std::string timezone);\n\n /// Timezone associated with a data column.\n const std::string & Timezone() const;\n\nprivate:\n std::string timezone_;\n};\n}\n\nclass DateTimeType : public Type, public details::TypeWithTimeZoneMixin {\npublic:\n explicit DateTimeType(std::string timezone);\n\n std::string GetName() const;\n};\n\nclass DateTime64Type: public Type, public details::TypeWithTimeZoneMixin {\npublic:\n explicit DateTime64Type(size_t precision, std::string timezone_);\n\n std::string GetName() const;\n\n inline size_t GetPrecision() const { return precision_; }\nprivate:\n size_t precision_;\n};\n\nclass EnumType : public Type {\npublic:\n EnumType(Type::Code type, const std::vector& items);\n\n std::string GetName() const;\n\n /// Methods to work with enum types.\n std::string_view GetEnumName(int16_t value) const;\n int16_t GetEnumValue(const std::string& name) const;\n bool HasEnumName(const std::string& name) const;\n bool HasEnumValue(int16_t value) const;\n\nprivate:\n using ValueToNameType = std::map;\n using NameToValueType = std::map;\n using ValueToNameIterator = ValueToNameType::const_iterator;\n\n ValueToNameType value_to_name_;\n NameToValueType name_to_value_;\n\npublic:\n ValueToNameIterator BeginValueToName() const;\n ValueToNameIterator EndValueToName() const;\n};\n\nclass FixedStringType : public Type {\npublic:\n explicit FixedStringType(size_t n);\n\n std::string GetName() const { return std::string(\"FixedString(\") + std::to_string(size_) + \")\"; }\n\n inline size_t GetSize() const { return size_; }\n\nprivate:\n size_t size_;\n};\n\nclass NullableType : public Type {\npublic:\n explicit NullableType(TypeRef nested_type);\n\n std::string GetName() const { return std::string(\"Nullable(\") + nested_type_->GetName() + \")\"; }\n\n /// Type of nested nullable element.\n TypeRef GetNestedType() const { return nested_type_; }\n\nprivate:\n TypeRef nested_type_;\n};\n\nclass TupleType : public Type {\npublic:\n explicit TupleType(const std::vector& item_types);\n\n std::string GetName() const;\n\n /// Type of nested Tuple element type.\n std::vector GetTupleType() const { return item_types_; }\n\nprivate:\n std::vector item_types_;\n};\n\nclass LowCardinalityType : public Type {\npublic:\n explicit LowCardinalityType(TypeRef nested_type);\n ~LowCardinalityType();\n\n std::string GetName() const { return std::string(\"LowCardinality(\") + nested_type_->GetName() + \")\"; }\n\n /// Type of nested nullable element.\n TypeRef GetNestedType() const { return nested_type_; }\n\nprivate:\n TypeRef nested_type_;\n};\n\nclass MapType : public Type {\npublic:\n explicit MapType(TypeRef key_type, TypeRef value_type);\n\n std::string GetName() const;\n\n /// Type of keys.\n TypeRef 
GetKeyType() const { return key_type_; }\n\n /// Type of values.\n TypeRef GetValueType() const { return value_type_; }\n\nprivate:\n TypeRef key_type_;\n TypeRef value_type_;\n};\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int8));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int16));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int64));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int128));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt8));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt16));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt64));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Float32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Float64));\n}\n\n} // namespace clickhouse\n\n// Path: clickhouse/server_exception.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\nstruct Exception {\n int code = 0;\n std::string name;\n std::string display_text;\n std::string stack_trace;\n /// Pointer to nested exception.\n std::unique_ptr nested;\n};\n\n}\n\n// Path: clickhouse/exceptions.h\n#pragma once\n\n#include \"server_exception.h\"\n\n#include \n\nnamespace clickhouse {\n\nclass Error : public std::runtime_error {\n using std::runtime_error::runtime_error;\n};\n\n// Caused by any user-related code, like invalid column types or arguments passed to any method.\nclass ValidationError : public Error {\n using Error::Error;\n};\n\n// Buffers+IO errors, failure to serialize/deserialize, checksum mismatches, etc.\nclass ProtocolError : public Error {\n using Error::Error;\n};\n\nclass UnimplementedError : public Error {\n using Error::Error;\n};\n\n// Internal validation error.\nclass AssertionError : public Error {\n using Error::Error;\n};\n\nclass OpenSSLError : public Error {\n using Error::Error;\n};\n\nclass LZ4Error : public Error {\n using Error::Error;\n};\n\n// Exception received from server.\nclass ServerException : public Error {\npublic:\n ServerException(std::unique_ptr e)\n : Error(std::string())\n , exception_(std::move(e))\n {\n }\n\n int GetCode() const {\n return exception_->code;\n }\n\n const Exception& GetException() const {\n return *exception_;\n }\n\n const char* what() const noexcept override {\n return exception_->display_text.c_str();\n }\n\nprivate:\n std::unique_ptr exception_;\n};\nusing ServerError = ServerException;\n\n}\n\n// Path: clickhouse/columns/itemview.h\n#pragma once\n\n#include \"../types/types.h\"\n#include \"../exceptions.h\"\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\n/** ItemView is a view on a data stored in Column, safe-ish interface for reading values from Column.\n *\n * Data is not owned (hence the name View) and will be invalidated on column update, load\n * or destruction (basically on calling any non-const method of Column).\n * `type` reflects what is stored in `data` and can be almost any value-type\n * (except Nullable, Array, Tuple, LowCardinality).\n *\n */\nstruct ItemView {\n using DataType = std::string_view;\n\n const Type::Code type;\n const DataType data;\n\nprivate:\n 
template \n inline auto ConvertToStorageValue(const T& t) {\n if constexpr (std::is_same_v || std::is_same_v) {\n return std::string_view{t};\n } else if constexpr (std::is_fundamental_v || std::is_same_v>) {\n return std::string_view{reinterpret_cast(&t), sizeof(T)};\n } else {\n static_assert(!std::is_same_v, \"Unknown type, which can't be stored in ItemView\");\n return;\n }\n }\n\npublic:\n ItemView(Type::Code type, DataType data)\n : type(type),\n data(data)\n {\n ValidateData(type, data);\n }\n\n ItemView(Type::Code type, ItemView other)\n : type(type),\n data(other.data)\n {\n ValidateData(type, data);\n }\n\n explicit ItemView()\n : ItemView(Type::Void, std::string_view{})\n {}\n\n template \n explicit ItemView(Type::Code type, const T & value)\n : ItemView(type, ConvertToStorageValue(value))\n {}\n\n template \n auto get() const {\n using ValueType = std::remove_cv_t>;\n if constexpr (std::is_same_v || std::is_same_v) {\n return data;\n } else if constexpr (std::is_fundamental_v || std::is_same_v) {\n if (sizeof(ValueType) == data.size()) {\n return *reinterpret_cast(data.data());\n } else {\n throw AssertionError(\"Incompatitable value type and size. Requested size: \"\n + std::to_string(sizeof(ValueType)) + \" stored size: \" + std::to_string(data.size()));\n }\n }\n }\n\n inline std::string_view AsBinaryData() const {\n return data;\n }\n\n // Validate that value matches type, will throw an exception if validation fails.\n static void ValidateData(Type::Code type, DataType data);\n};\n\n}\n\n// Path: clickhouse/columns/column.h\n#pragma once\n\n#include \"../types/types.h\"\n#include \"../columns/itemview.h\"\n#include \"../exceptions.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream;\nclass OutputStream;\n\nusing ColumnRef = std::shared_ptr;\n\n/**\n * An abstract base of all columns classes.\n */\nclass Column : public std::enable_shared_from_this {\npublic:\n explicit inline Column(TypeRef type) : type_(type) {}\n\n virtual ~Column() {}\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr As() {\n return std::dynamic_pointer_cast(shared_from_this());\n }\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr As() const {\n return std::dynamic_pointer_cast(shared_from_this());\n }\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr AsStrict() {\n auto result = std::dynamic_pointer_cast(shared_from_this());\n if (!result) {\n throw ValidationError(\"Can't cast from \" + type_->GetName());\n }\n return result;\n }\n\n /// Get type object of the column.\n inline TypeRef Type() const { return type_; }\n inline const class Type& GetType() const { return *type_; }\n\n /// Appends content of given column to the end of current one.\n virtual void Append(ColumnRef column) = 0;\n\n /// Increase the capacity of the column for large block insertion.\n virtual void Reserve(size_t new_cap) = 0;\n\n /// Template method to load column data from input stream. It'll call LoadPrefix and LoadBody.\n /// Should be called only once from the client. Derived classes should not call it.\n bool Load(InputStream* input, size_t rows);\n\n /// Loads column prefix from input stream.\n virtual bool LoadPrefix(InputStream* input, size_t rows);\n\n /// Loads column data from input stream.\n virtual bool LoadBody(InputStream* input, size_t rows) = 0;\n\n /// Saves column prefix to output stream. 
Column types with prefixes must implement it.\n virtual void SavePrefix(OutputStream* output);\n\n /// Saves column body to output stream.\n virtual void SaveBody(OutputStream* output) = 0;\n\n /// Template method to save to output stream. It'll call SavePrefix and SaveBody respectively\n /// Should be called only once from the client. Derived classes should not call it.\n /// Save is split in Prefix and Body because some data types require prefixes and specific serialization order.\n /// For instance, Array(LowCardinality(X)) requires LowCardinality.key_version bytes to come before Array.offsets\n void Save(OutputStream* output);\n\n /// Clear column data .\n virtual void Clear() = 0;\n\n /// Returns count of rows in the column.\n virtual size_t Size() const = 0;\n\n /// Makes slice of the current column.\n virtual ColumnRef Slice(size_t begin, size_t len) const = 0;\n\n virtual ColumnRef CloneEmpty() const = 0;\n\n virtual void Swap(Column&) = 0;\n\n /// Get a view on raw item data if it is supported by column, will throw an exception if index is out of range.\n /// Please note that view is invalidated once column items are added or deleted, column is loaded from strean or destroyed.\n virtual ItemView GetItem(size_t) const {\n throw UnimplementedError(\"GetItem() is not supported for column of \" + type_->GetName());\n }\n\n friend \nvoid swap(Column& left, Column& right) {\n left.Swap(right);\n }\n\nprotected:\n TypeRef type_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/block.h\n#pragma once\n\n#include \"columns/column.h\"\n\nnamespace clickhouse {\n\nstruct BlockInfo {\n uint8_t is_overflows = 0;\n int32_t bucket_num = -1;\n};\n\nclass Block {\npublic:\n /// Allow to iterate over block's columns.\n class Iterator {\n public:\n Iterator(const Block& block);\n\n /// Name of column.\n const std::string& Name() const;\n\n /// Type of column.\n TypeRef Type() const;\n\n /// Reference to column object.\n ColumnRef Column() const;\n\n /// Move to next column, returns false if next call to IsValid() would return false;\n bool Next();\n\n /// Is the iterator still valid.\n bool IsValid() const;\n\n size_t ColumnIndex() const {\n return idx_;\n }\n\n Iterator& operator*() { return *this; }\n const Iterator& operator*() const { return *this; }\n\n bool operator==(const Iterator & other) const {\n return &block_ == &other.block_ && idx_ == other.idx_;\n }\n bool operator!=(const Iterator & other) const {\n return !(*this == other);\n }\n\n Iterator& operator++() {\n this->Next();\n return *this;\n }\n\n private:\n friend class Block;\n struct ConstructAtEndTag {};\n Iterator(const Block& block, ConstructAtEndTag at_end);\n Iterator() = delete;\n\n const Block& block_;\n size_t idx_;\n };\n\npublic:\n Block();\n Block(size_t cols, size_t rows);\n ~Block();\n\n /// Append named column to the block.\n void AppendColumn(const std::string& name, const ColumnRef& col);\n\n /// Count of columns in the block.\n size_t GetColumnCount() const;\n\n const BlockInfo& Info() const;\n\n /// Set block info\n void SetInfo(BlockInfo info);\n\n /// Count of rows in the block.\n size_t GetRowCount() const;\n\n size_t RefreshRowCount();\n\n const std::string& GetColumnName(size_t idx) const {\n return columns_.at(idx).name;\n }\n\n /// Reference to column by index in the block.\n ColumnRef operator [] (size_t idx) const;\n\n Iterator begin() const;\n Iterator end() const;\n Iterator cbegin() const { return begin(); }\n Iterator cend() const { return end(); }\n\nprivate:\n struct ColumnItem {\n 
std::string name;\n ColumnRef column;\n };\n\n BlockInfo info_;\n std::vector columns_;\n /// Count of rows in the block.\n size_t rows_;\n};\n\n}\n\n// Path: clickhouse/block.cpp\n#include \"block.h\"\n\n#include \"exceptions.h\"\n\n#include \n\nnamespace clickhouse {\n\nBlock::Iterator::Iterator(const Block& block)\n : block_(block)\n , idx_(0)\n{\n}\n\nBlock::Iterator::Iterator(const Block& block, Block::Iterator::ConstructAtEndTag /*at_end*/)\n : block_(block)\n , idx_(block.GetColumnCount())\n{}\n\nconst std::string& Block::Iterator::Name() const {\n return block_.columns_[idx_].name;\n}\n\nTypeRef Block::Iterator::Type() const {\n return block_.columns_[idx_].column->Type();\n}\n\nColumnRef Block::Iterator::Column() const {\n return block_.columns_[idx_].column;\n}\n\nbool Block::Iterator::Next() {\n ++idx_;\n return IsValid();\n}\n\nbool Block::Iterator::IsValid() const {\n return idx_ < block_.columns_.size();\n}\n\n\nBlock::Block()\n : rows_(0)\n{\n}\n\nBlock::Block(size_t cols, size_t rows)\n : rows_(rows)\n{\n columns_.reserve(cols);\n}\n\nBlock::~Block() = default;\n\nvoid Block::AppendColumn(const std::string& name, const ColumnRef& col) {\n if (columns_.empty()) {\n rows_ = col->Size();\n } else if (col->Size() != rows_) {\n throw ValidationError(\"all columns in block must have same count of rows. Name: [\"+name+\"], rows: [\"+std::to_string(rows_)+\"], columns: [\" + std::to_string(col->Size())+\"]\");\n }\n\n columns_.push_back(ColumnItem{name, col});\n}\n\n/// Count of columns in the block.\nsize_t Block::GetColumnCount() const {\n return columns_.size();\n}\n\nconst BlockInfo& Block::Info() const {\n return info_;\n}\n\n/// Set block info\nvoid Block::SetInfo(BlockInfo info) {\n info_ = std::move(info);\n}\n\n/// Count of rows in the block.\nsize_t Block::GetRowCount() const {\n return rows_;\n}\n\nsize_t Block::RefreshRowCount()\n{\n size_t rows = 0UL;\n\n for (size_t idx = 0UL; idx < columns_.size(); ++idx)\n {\n const std::string& name = columns_[idx].name;\n const ColumnRef& col = columns_[idx].column;\n\n if (idx == 0UL)\n rows = col->Size();\n else if (rows != col->Size())\n throw ValidationError(\"all columns in block must have same count of rows. Name: [\"+name+\"], rows: [\"+std::to_string(rows)+\"], columns: [\" + std::to_string(col->Size())+\"]\");\n }\n\n rows_ = rows;\n return rows_;\n}\n\nColumnRef Block::operator [] (size_t idx) const {\n if (idx < columns_.size()) {\n return columns_[idx].column;\n }\n\n throw std::out_of_range(\"column index is out of range. 
Index: [\"+std::to_string(idx)+\"], columns: [\" + std::to_string(columns_.size())+\"]\");\n}\n\nBlock::Iterator Block::begin() const {\n return Iterator(*this);\n}\n\nBlock::Iterator Block::end() const {\n return Iterator(*this, Iterator::ConstructAtEndTag{});\n}\n\n}\n\n// Path: clickhouse/base/uuid.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nusing UInt128 = std::pair;\n\nusing UUID = UInt128;\n\n}\n\n// Path: clickhouse/base/open_telemetry.h\n#pragma once\n\n#include \"uuid.h\"\n\n#include \n\nnamespace clickhouse::open_telemetry {\n\n/// See https://www.w3.org/TR/trace-context/ for trace_flags definition\nenum TraceFlags : uint8_t {\n TRACE_FLAG_NONE = 0,\n TRACE_FLAG_SAMPLED = 1,\n};\n\n/// The runtime info we need to create new OpenTelemetry spans.\nstruct TracingContext {\n UUID trace_id{};\n uint64_t span_id = 0;\n std::string tracestate;\n uint8_t trace_flags = TRACE_FLAG_NONE;\n};\n\n} // namespace clickhouse::open_telemetry\n\n// Path: clickhouse/query.h\n#pragma once\n\n#include \"block.h\"\n#include \"server_exception.h\"\n\n#include \"base/open_telemetry.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nstruct QuerySettingsField {\n enum Flags : uint64_t\n {\n IMPORTANT = 0x01,\n CUSTOM = 0x02,\n OBSOLETE = 0x04,\n };\n std::string value;\n uint64_t flags{0};\n};\n\nusing QuerySettings = std::unordered_map;\n\nstruct Profile {\n uint64_t rows = 0;\n uint64_t blocks = 0;\n uint64_t bytes = 0;\n uint64_t rows_before_limit = 0;\n bool applied_limit = false;\n bool calculated_rows_before_limit = false;\n};\n\n\nstruct Progress {\n uint64_t rows = 0;\n uint64_t bytes = 0;\n uint64_t total_rows = 0;\n uint64_t written_rows = 0;\n uint64_t written_bytes = 0;\n};\n\n\nclass QueryEvents {\npublic:\n virtual ~QueryEvents()\n { }\n\n /// Some data was received.\n virtual void OnData(const Block& block) = 0;\n virtual bool OnDataCancelable(const Block& block) = 0;\n\n virtual void OnServerException(const Exception& e) = 0;\n\n virtual void OnProfile(const Profile& profile) = 0;\n\n virtual void OnProgress(const Progress& progress) = 0;\n\n /** Handle query execution logs provided by server.\n * Amount of logs regulated by `send_logs_level` setting.\n * By-default only `fatal` log events are sent to the client side.\n */\n virtual void OnServerLog(const Block& block) = 0;\n\n /// Handle query execution profile events.\n virtual void OnProfileEvents(const Block& block) = 0;\n\n virtual void OnFinish() = 0;\n};\n\n\nusing ExceptionCallback = std::function;\nusing ProgressCallback = std::function;\nusing SelectCallback = std::function;\nusing SelectCancelableCallback = std::function;\nusing SelectServerLogCallback = std::function;\nusing ProfileEventsCallback = std::function;\nusing ProfileCallbak = std::function;\n\n\nclass Query : public QueryEvents {\npublic:\n Query();\n Query(const char* query, const char* query_id = nullptr);\n Query(const std::string& query, const std::string& query_id = default_query_id);\n ~Query() override;\n\n ///\n inline const std::string& GetText() const {\n return query_;\n }\n\n inline const std::string& GetQueryID() const {\n return query_id_;\n }\n\n inline const QuerySettings& GetQuerySettings() const {\n return query_settings_;\n }\n\n /// Set per query settings\n inline Query& SetQuerySettings(QuerySettings query_settings) {\n query_settings_ = std::move(query_settings);\n return *this;\n }\n\n /// Set per query setting\n inline Query& SetSetting(const std::string& key, const 
QuerySettingsField& value) {\n query_settings_[key] = value;\n return *this;\n }\n\n inline const std::optional& GetTracingContext() const {\n return tracing_context_;\n }\n\n /// Set tracing context for open telemetry signals\n inline Query& SetTracingContext(open_telemetry::TracingContext tracing_context) {\n tracing_context_ = std::move(tracing_context);\n return *this;\n }\n\n /// Set handler for receiving result data.\n inline Query& OnData(SelectCallback cb) {\n select_cb_ = std::move(cb);\n return *this;\n }\n\n inline Query& OnDataCancelable(SelectCancelableCallback cb) {\n select_cancelable_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving server's exception.\n inline Query& OnException(ExceptionCallback cb) {\n exception_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving a progress of query execution.\n inline Query& OnProgress(ProgressCallback cb) {\n progress_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving a server log of query exceution.\n inline Query& OnServerLog(SelectServerLogCallback cb) {\n select_server_log_cb_ = std::move(cb);\n return *this;\n }\n\n /// Set handler for receiving profile events.\n inline Query& OnProfileEvents(ProfileEventsCallback cb) {\n profile_events_callback_cb_ = std::move(cb);\n return *this;\n }\n\n inline Query& OnProfile(ProfileCallbak cb) {\n profile_callback_cb_ = std::move(cb);\n return *this;\n }\n\n static const std::string default_query_id;\n\nprivate:\n void OnData(const Block& block) override {\n if (select_cb_) {\n select_cb_(block);\n }\n }\n\n bool OnDataCancelable(const Block& block) override {\n if (select_cancelable_cb_) {\n return select_cancelable_cb_(block);\n } else {\n return true;\n }\n }\n\n void OnServerException(const Exception& e) override {\n if (exception_cb_) {\n exception_cb_(e);\n }\n }\n\n void OnProfile(const Profile& profile) override {\n if (profile_callback_cb_)\n profile_callback_cb_(profile);\n }\n\n void OnProgress(const Progress& progress) override {\n if (progress_cb_) {\n progress_cb_(progress);\n }\n }\n\n void OnServerLog(const Block& block) override {\n if (select_server_log_cb_) {\n select_server_log_cb_(block);\n }\n }\n\n void OnProfileEvents(const Block& block) override {\n if (profile_events_callback_cb_) {\n profile_events_callback_cb_(block);\n }\n }\n\n void OnFinish() override {\n }\n\nprivate:\n const std::string query_;\n const std::string query_id_;\n std::optional tracing_context_;\n QuerySettings query_settings_;\n ExceptionCallback exception_cb_;\n ProgressCallback progress_cb_;\n SelectCallback select_cb_;\n SelectCancelableCallback select_cancelable_cb_;\n SelectServerLogCallback select_server_log_cb_;\n ProfileEventsCallback profile_events_callback_cb_;\n ProfileCallbak profile_callback_cb_;\n};\n\n}\n\n// Path: clickhouse/columns/numeric.h\n#pragma once\n\n#include \"column.h\"\n#include \"absl/numeric/int128.h\"\n\nnamespace clickhouse {\n\n/**\n * Represents various numeric columns.\n */\ntemplate \nclass ColumnVector : public Column {\npublic:\n using DataType = T;\n using ValueType = T;\n\n ColumnVector();\n\n explicit ColumnVector(const std::vector& data);\n explicit ColumnVector(std::vector && data);\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends one element to the end of column.\n void Append(const T& value);\n\n /// Returns element at given row number.\n const T& At(size_t n) const;\n\n /// Returns element at given row 
number.\n inline const T& operator [] (size_t n) const { return At(n); }\n\n void Erase(size_t pos, size_t count = 1);\n\n /// Get Raw Vector Contents\n std::vector& GetWritableData();\n\n /// Returns the capacity of the column\n size_t Capacity() const;\n\npublic:\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t index) const override;\n\nprivate:\n std::vector data_;\n};\n\nusing Int128 = absl::int128;\nusing Int64 = int64_t;\n\nusing ColumnUInt8 = ColumnVector;\nusing ColumnUInt16 = ColumnVector;\nusing ColumnUInt32 = ColumnVector;\nusing ColumnUInt64 = ColumnVector;\n\nusing ColumnInt8 = ColumnVector;\nusing ColumnInt16 = ColumnVector;\nusing ColumnInt32 = ColumnVector;\nusing ColumnInt64 = ColumnVector;\nusing ColumnInt128 = ColumnVector;\n\nusing ColumnFloat32 = ColumnVector;\nusing ColumnFloat64 = ColumnVector;\n\n}\n\n// Path: clickhouse/columns/utils.h\n#pragma once\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nstd::vector SliceVector(const std::vector& vec, size_t begin, size_t len) {\n std::vector result;\n\n if (begin < vec.size()) {\n len = std::min(len, vec.size() - begin);\n result.assign(vec.begin() + begin, vec.begin() + (begin + len));\n }\n\n return result;\n}\n\ntemplate \nstruct HasWrapMethod {\nprivate:\n static int detect(...);\n template \n static decltype(U::Wrap(std::move(std::declval()))) detect(const U&);\n\npublic:\n static constexpr bool value = !std::is_same()))>::value;\n};\n\ntemplate \ninline std::shared_ptr WrapColumn(ColumnRef&& column) {\n if constexpr (HasWrapMethod::value) {\n return T::Wrap(std::move(column));\n } else {\n return column->template AsStrict();\n }\n}\n\n}\n\n// Path: clickhouse/columns/array.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n#include \"utils.h\"\n\n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnArrayT;\n\n/**\n * Represents column of Array(T).\n */\nclass ColumnArray : public Column {\npublic:\n using ValueType = ColumnRef;\n\n /** Create an array of given type.\n *\n * `data` is used internally (and modified) by ColumnArray.\n * Users are strongly advised against supplying non-empty columns and/or modifying\n * contents of `data` afterwards.\n */\n explicit ColumnArray(ColumnRef data);\n\n /** Create an array of given type, with actual values and offsets.\n *\n * Both `data` and `offsets` are used (and modified) internally bye ColumnArray.\n * Users are strongly advised against modifying contents of `data` or `offsets` afterwards.\n */\n ColumnArray(ColumnRef data, std::shared_ptr offsets);\n\n /// Converts input column to array and appends as one row to the current column.\n void AppendAsColumn(ColumnRef array);\n\n /// Converts array at pos n to column.\n /// Type of element of result column same as type of array element.\n ColumnRef GetAsColumn(size_t n) const;\n\n /// Shorthand to get a column casted to a proper type.\n template \n auto GetAsColumnTyped(size_t n) const {\n 
return GetAsColumn(n)->AsStrict();\n }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column&) override;\n\n void OffsetsIncrease(size_t);\n\n...\n", "needle_token_start": 5049, "needle_token_end": 5073, "code_context_ntokens": 10613, "output": ["```cpp\nvoid Column::Swap(Column& other) override {\n std::swap(data_, other.data_);\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "CreateColumnByType", "language": "cpp", "path": "clickhouse/columns/factory.cpp", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to generate a column object based on a specified data type description, handling various data structures and configurations.\n2. **Input**: It takes a string representing the type of the column and a settings object that configures specific behaviors during the column creation.\n3. **Output**: Returns a reference to the newly created column object, or a null reference if the creation fails or the type is unsupported.\n4. **Procedure**: \n - The function begins by parsing the input string to determine the structure and type of the column needed.\n - Depending on the parsed type, it constructs the appropriate column object. 
This includes handling complex types like enumerations, maps, and wrapped columns with specific settings.\n - For each type, it may recursively create nested columns or apply specific settings that affect how the column behaves or is represented.\n - If the type is unsupported or an error occurs during parsing or creation, it returns a null reference.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/client.cpp\n#include \"client.h\"\n#include \"clickhouse/version.h\"\n#include \"protocol.h\"\n\n#include \"base/compressed.h\"\n#include \"base/socket.h\"\n#include \"base/wire_format.h\"\n\n#include \"columns/factory.h\"\n\n#include \n#include \n#include \n#include \n\n#if defined(WITH_OPENSSL)\n#include \"base/sslsocket.h\"\n#endif\n\n#define DBMS_NAME \"ClickHouse\"\n\n#define DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES 50264\n#define DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS 51554\n#define DBMS_MIN_REVISION_WITH_BLOCK_INFO 51903\n#define DBMS_MIN_REVISION_WITH_CLIENT_INFO 54032\n#define DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE 54058\n#define DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO 54060\n//#define DBMS_MIN_REVISION_WITH_TABLES_STATUS 54226\n#define DBMS_MIN_REVISION_WITH_TIME_ZONE_PARAMETER_IN_DATETIME_DATA_TYPE 54337\n#define DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME 54372\n#define DBMS_MIN_REVISION_WITH_VERSION_PATCH 54401\n#define DBMS_MIN_REVISION_WITH_LOW_CARDINALITY_TYPE 54405\n#define DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA 54410\n#define DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO 54420\n#define DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS 54429\n#define DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET 54441\n#define DBMS_MIN_REVISION_WITH_OPENTELEMETRY 54442\n#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448\n#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449\n#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451\n\n#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS\n\nnamespace clickhouse {\n\nstruct ClientInfo {\n uint8_t iface_type = 1; // TCP\n uint8_t query_kind;\n std::string initial_user;\n std::string initial_query_id;\n std::string quota_key;\n std::string os_user;\n std::string client_hostname;\n std::string client_name;\n std::string initial_address = \"[::ffff:127.0.0.1]:0\";\n uint64_t client_version_major = 0;\n uint64_t client_version_minor = 0;\n uint64_t client_version_patch = 0;\n uint32_t client_revision = 0;\n};\n\nstd::ostream& operator<<(std::ostream& os, const Endpoint& endpoint) {\n return os << endpoint.host << \":\" << endpoint.port;\n}\n\nstd::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {\n os << \"Client(\"\n << \" Endpoints : [\";\n size_t extra_endpoints = 0;\n\n if (!opt.host.empty()) {\n extra_endpoints = 1;\n os << opt.user << '@' << Endpoint{opt.host, opt.port};\n\n if (opt.endpoints.size())\n os << \", \";\n }\n\n for (size_t i = 0; i < opt.endpoints.size(); i++) {\n os << opt.user << '@' << opt.endpoints[i]\n << ((i == opt.endpoints.size() - 1) ? 
\"\" : \", \");\n }\n\n os << \"] (\" << opt.endpoints.size() + extra_endpoints << \" items )\"\n << \" ping_before_query:\" << opt.ping_before_query\n << \" send_retries:\" << opt.send_retries\n << \" retry_timeout:\" << opt.retry_timeout.count()\n << \" compression_method:\"\n << (opt.compression_method == CompressionMethod::LZ4 ? \"LZ4\" : \"None\");\n#if defined(WITH_OPENSSL)\n if (opt.ssl_options) {\n const auto & ssl_options = *opt.ssl_options;\n os << \" SSL (\"\n << \" ssl_context: \" << (ssl_options.ssl_context ? \"provided by user\" : \"created internally\")\n << \" use_default_ca_locations: \" << ssl_options.use_default_ca_locations\n << \" path_to_ca_files: \" << ssl_options.path_to_ca_files.size() << \" items\"\n << \" path_to_ca_directory: \" << ssl_options.path_to_ca_directory\n << \" min_protocol_version: \" << ssl_options.min_protocol_version\n << \" max_protocol_version: \" << ssl_options.max_protocol_version\n << \" context_options: \" << ssl_options.context_options\n << \")\";\n }\n#endif\n os << \")\";\n return os;\n}\n\nClientOptions& ClientOptions::SetSSLOptions(ClientOptions::SSLOptions options)\n{\n#ifdef WITH_OPENSSL\n ssl_options = options;\n return *this;\n#else\n (void)options;\n throw OpenSSLError(\"Library was built with no SSL support\");\n#endif\n}\n\nnamespace {\n\nstd::unique_ptr GetSocketFactory(const ClientOptions& opts) {\n (void)opts;\n#if defined(WITH_OPENSSL)\n if (opts.ssl_options)\n return std::make_unique(opts);\n else\n#endif\n return std::make_unique();\n}\n\nstd::unique_ptr GetEndpointsIterator(const ClientOptions& opts) {\n if (opts.endpoints.empty())\n {\n throw ValidationError(\"The list of endpoints is empty\");\n }\n\n return std::make_unique(opts.endpoints);\n}\n\n}\n\nclass Client::Impl {\npublic:\n Impl(const ClientOptions& opts);\n Impl(const ClientOptions& opts,\n std::unique_ptr socket_factory);\n ~Impl();\n\n void ExecuteQuery(Query query);\n\n void SendCancel();\n\n void Insert(const std::string& table_name, const std::string& query_id, const Block& block);\n\n void Ping();\n\n void ResetConnection();\n\n void ResetConnectionEndpoint();\n\n const ServerInfo& GetServerInfo() const;\n\n const std::optional& GetCurrentEndpoint() const;\n\nprivate:\n bool Handshake();\n\n bool ReceivePacket(uint64_t* server_packet = nullptr);\n\n void SendQuery(const Query& query);\n\n void SendData(const Block& block);\n\n bool SendHello();\n\n bool ReadBlock(InputStream& input, Block* block);\n\n bool ReceiveHello();\n\n /// Reads data packet form input stream.\n bool ReceiveData();\n\n /// Reads exception packet form input stream.\n...\n// Path: clickhouse/query.cpp\n#include \"query.h\"\n\nnamespace clickhouse {\n\nconst std::string Query::default_query_id = {};\n\nQuery::Query()\n{ }\n\nQuery::Query(const char* query, const char* query_id)\n : query_(query)\n , query_id_(query_id ? 
std::string(query_id): default_query_id)\n{\n}\n\nQuery::Query(const std::string& query, const std::string& query_id)\n : query_(query)\n , query_id_(query_id)\n{\n}\n\nQuery::~Query()\n{ }\n\n}\n\n// Path: clickhouse/columns/map.cpp\n#include \"map.h\"\n\n#include \n\n#include \"../exceptions.h\"\n#include \"utils.h\"\n\nnamespace {\n\nusing namespace clickhouse;\n\nTypeRef GetMapType(const Type& data_type) {\n auto array = data_type.As();\n if (!array) {\n throw ValidationError(\"Wrong type \" + data_type.GetName() + \" of data for map\");\n }\n auto tuple = array->GetItemType()->As();\n if (!tuple) {\n throw ValidationError(\"Wrong type \" + data_type.GetName() + \" of data for map\");\n }\n auto types = tuple->GetTupleType();\n if (types.size() != 2) {\n throw ValidationError(\"Wrong type \" + data_type.GetName() + \" of data for map\");\n }\n return Type::CreateMap(types[0], types[1]);\n}\n\n} // namespace\n\nnamespace clickhouse {\n\nColumnMap::ColumnMap(ColumnRef data)\n : Column(GetMapType(data->GetType())), data_(data->As()) {\n}\n\nvoid ColumnMap::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\nvoid ColumnMap::Clear() {\n data_->Clear();\n}\n\nvoid ColumnMap::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nbool ColumnMap::LoadPrefix(InputStream* input, size_t rows) {\n return data_->LoadPrefix(input, rows);\n}\n\nbool ColumnMap::LoadBody(InputStream* input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\nvoid ColumnMap::SavePrefix(OutputStream* output) {\n data_->SavePrefix(output);\n}\n\nvoid ColumnMap::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\nsize_t ColumnMap::Size() const {\n return data_->Size();\n}\n\nColumnRef ColumnMap::Slice(size_t begin, size_t len) const {\n return std::make_shared(data_->Slice(begin, len));\n}\n\nColumnRef ColumnMap::CloneEmpty() const {\n return std::make_shared(data_->CloneEmpty());\n}\n\nvoid ColumnMap::Swap(Column& other) {\n auto& col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\nColumnRef ColumnMap::GetAsColumn(size_t n) const {\n return data_->GetAsColumn(n);\n}\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/array.cpp\n#include \"array.h\"\n#include \"numeric.h\"\n\n#include \n\nnamespace clickhouse {\n\nColumnArray::ColumnArray(ColumnRef data)\n : ColumnArray(data, std::make_shared())\n{\n}\n\nColumnArray::ColumnArray(ColumnRef data, std::shared_ptr offsets)\n : Column(Type::CreateArray(data->Type()))\n , data_(data)\n , offsets_(offsets)\n{\n}\n\nColumnArray::ColumnArray(ColumnArray&& other)\n : Column(other.Type())\n , data_(std::move(other.data_))\n , offsets_(std::move(other.offsets_))\n{\n}\n\nvoid ColumnArray::AppendAsColumn(ColumnRef array) {\n // appending data may throw (i.e. 
due to ype check failure), so do it first to avoid partly modified state.\n data_->Append(array);\n AddOffset(array->Size());\n}\n\nColumnRef ColumnArray::GetAsColumn(size_t n) const {\n if (n >= Size())\n throw ValidationError(\"Index is out ouf bounds: \" + std::to_string(n));\n\n return data_->Slice(GetOffset(n), GetSize(n));\n}\n\nColumnRef ColumnArray::Slice(size_t begin, size_t size) const {\n if (size && begin + size > Size())\n throw ValidationError(\"Slice indexes are out of bounds\");\n\n auto result = std::make_shared(data_->Slice(GetOffset(begin), GetOffset(begin + size) - GetOffset(begin)));\n for (size_t i = 0; i < size; i++)\n result->AddOffset(GetSize(begin + i));\n\n return result;\n}\n\nColumnRef ColumnArray::CloneEmpty() const {\n return std::make_shared(data_->CloneEmpty());\n}\n\nvoid ColumnArray::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n offsets_->Reserve(new_cap);\n}\n\nvoid ColumnArray::Append(ColumnRef column) {\n if (auto col = column->As()) {\n for (size_t i = 0; i < col->Size(); ++i) {\n AppendAsColumn(col->GetAsColumn(i));\n }\n }\n}\n\nbool ColumnArray::LoadPrefix(InputStream* input, size_t rows) {\n if (!rows) {\n return true;\n }\n\n return data_->LoadPrefix(input, rows);\n}\n\nbool ColumnArray::LoadBody(InputStream* input, size_t rows) {\n if (!rows) {\n return true;\n }\n if (!offsets_->LoadBody(input, rows)) {\n return false;\n }\n const auto nested_rows = (*offsets_)[rows - 1];\n if (nested_rows == 0) {\n return true;\n }\n if (!data_->LoadBody(input, nested_rows)) {\n return false;\n }\n return true;\n}\n\nvoid ColumnArray::SavePrefix(OutputStream* output) {\n data_->SavePrefix(output);\n}\n\nvoid ColumnArray::SaveBody(OutputStream* output) {\n offsets_->SaveBody(output);\n if (data_->Size() > 0) {\n data_->SaveBody(output);\n }\n}\n\nvoid ColumnArray::Clear() {\n offsets_->Clear();\n data_->Clear();\n}\n\nsize_t ColumnArray::Size() const {\n return offsets_->Size();\n}\n\nvoid ColumnArray::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n data_.swap(col.data_);\n offsets_.swap(col.offsets_);\n}\n\nvoid ColumnArray::OffsetsIncrease(size_t n) {\n offsets_->Append(n);\n}\n\nsize_t ColumnArray::GetOffset(size_t n) const {\n\n return (n == 0) ? 0 : (*offsets_)[n - 1];\n}\n\nvoid ColumnArray::AddOffset(size_t n) {\n if (offsets_->Size() == 0) {\n offsets_->Append(n);\n } else {\n offsets_->Append((*offsets_)[offsets_->Size() - 1] + n);\n }\n}\n\nsize_t ColumnArray::GetSize(size_t n) const {\n return (n == 0) ? 
(*offsets_)[n] : ((*offsets_)[n] - (*offsets_)[n - 1]);\n}\n\nColumnRef ColumnArray::GetData() {\n return data_;\n}\n\nvoid ColumnArray::Reset() {\n data_.reset();\n offsets_.reset();\n}\n\n}\n\n// Path: clickhouse/columns/geo.cpp\n#include \"geo.h\"\n\n#include \"utils.h\"\n\nnamespace {\nusing namespace ::clickhouse;\n\ntemplate \nTypeRef CreateGeoType() {\n if constexpr (type_code == Type::Code::Point) {\n return Type::CreatePoint();\n } else if constexpr (type_code == Type::Code::Ring) {\n return Type::CreateRing();\n } else if constexpr (type_code == Type::Code::Polygon) {\n return Type::CreatePolygon();\n } else if constexpr (type_code == Type::Code::MultiPolygon) {\n return Type::CreateMultiPolygon();\n }\n}\n\ntemplate \nstd::shared_ptr CreateColumn() {\n if constexpr (std::is_same_v>) {\n return std::make_shared>(\n std::make_tuple(std::make_shared(), std::make_shared()));\n } else {\n return std::make_shared();\n }\n}\n\n} // namespace\n\nnamespace clickhouse {\n\ntemplate \nColumnGeo::ColumnGeo()\n : Column(CreateGeoType()),\n data_(CreateColumn()) {\n}\n\ntemplate \nColumnGeo::ColumnGeo(ColumnRef data)\n : Column(CreateGeoType())\n , data_(WrapColumn(std::move(data))) {\n}\n\ntemplate \nvoid ColumnGeo::Clear() {\n data_->Clear();\n}\n\ntemplate \nconst typename ColumnGeo::ValueType ColumnGeo::At(size_t n) const {\n return data_->At(n);\n}\n\ntemplate\nvoid ColumnGeo::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\ntemplate \nvoid ColumnGeo::Append(ColumnRef column) {\n if (auto col = column->template As()) {\n data_->Append(col->data_->template As());\n }\n}\n\ntemplate \nbool ColumnGeo::LoadBody(InputStream* input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\ntemplate \nvoid ColumnGeo::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\ntemplate \nsize_t ColumnGeo::Size() const {\n return data_->Size();\n}\n\ntemplate \nColumnRef ColumnGeo::Slice(size_t begin, size_t len) const {\n return std::make_shared(data_->Slice(begin, len));\n}\n\ntemplate \nColumnRef ColumnGeo::CloneEmpty() const {\n return std::make_shared();\n}\n\ntemplate \nvoid ColumnGeo::Swap(Column& other) {\n auto& col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\ntemplate class ColumnGeo, Type::Code::Point>;\n\ntemplate class ColumnGeo, Type::Code::Ring>;\n\ntemplate class ColumnGeo, Type::Code::Polygon>;\n\ntemplate class ColumnGeo, Type::Code::MultiPolygon>;\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/ip6.cpp\n#include \"ip6.h\"\n#include \"../base/socket.h\" // for IPv6 platform-specific stuff\n#include \"../exceptions.h\"\n\n#include \n\nnamespace clickhouse {\n\nstatic_assert(sizeof(struct in6_addr) == 16, \"sizeof in6_addr should be 16 bytes\");\n\nColumnIPv6::ColumnIPv6()\n : Column(Type::CreateIPv6())\n , data_(std::make_shared(16))\n{\n}\n\nColumnIPv6::ColumnIPv6(ColumnRef data)\n : Column(Type::CreateIPv6())\n , data_(data ? data->As() : nullptr)\n{\n if (!data_ || data_->FixedSize() != sizeof(in6_addr))\n throw ValidationError(\"Expecting ColumnFixedString(16), got \" + (data ? 
data->GetType().GetName() : \"null\"));\n}\n\nvoid ColumnIPv6::Append(const std::string_view& str) {\n unsigned char buf[16];\n if (inet_pton(AF_INET6, str.data(), buf) != 1) {\n throw ValidationError(\"invalid IPv6 format, ip: \" + std::string(str));\n }\n data_->Append(std::string_view((const char*)buf, 16));\n}\n\nvoid ColumnIPv6::Append(const in6_addr* addr) {\n data_->Append(std::string_view((const char*)addr->s6_addr, 16));\n}\n\nvoid ColumnIPv6::Append(const in6_addr& addr) {\n Append(&addr);\n}\n\nvoid ColumnIPv6::Clear() {\n data_->Clear();\n}\n\nstd::string ColumnIPv6::AsString (size_t n) const {\n const auto& addr = this->At(n);\n\n char buf[INET6_ADDRSTRLEN];\n const char* ip_str = inet_ntop(AF_INET6, &addr, buf, INET6_ADDRSTRLEN);\n\n if (ip_str == nullptr) {\n throw std::system_error(\n std::error_code(errno, std::generic_category()),\n \"Invalid IPv6 data\");\n }\n\n return ip_str;\n}\n\nin6_addr ColumnIPv6::At(size_t n) const {\n return *reinterpret_cast(data_->At(n).data());\n}\n\nin6_addr ColumnIPv6::operator [] (size_t n) const {\n return *reinterpret_cast(data_->At(n).data());\n}\n\nvoid ColumnIPv6::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\nvoid ColumnIPv6::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nbool ColumnIPv6::LoadBody(InputStream* input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\nvoid ColumnIPv6::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\nsize_t ColumnIPv6::Size() const {\n return data_->Size();\n}\n\nColumnRef ColumnIPv6::Slice(size_t begin, size_t len) const {\n return std::make_shared(data_->Slice(begin, len));\n}\n\nColumnRef ColumnIPv6::CloneEmpty() const {\n return std::make_shared(data_->CloneEmpty());\n}\n\nvoid ColumnIPv6::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\nItemView ColumnIPv6::GetItem(size_t index) const {\n return ItemView{Type::IPv6, data_->GetItem(index)};\n}\n\n}\n\n// Path: clickhouse/columns/lowcardinalityadaptor.h\n#pragma once\n\n#include \"column.h\"\n#include \"lowcardinality.h\"\n\n#include \n\nnamespace clickhouse {\n\nclass OutputStream;\nclass CodedInputStream;\n\n/** Adapts any ColumnType to be serialized\\deserialized as LowCardinality,\n * and to be castable to ColumnType via ColumnPtr->As().\n *\n * It helps to ease migration of the old codebases, which can't afford to switch\n * to using ColumnLowCardinalityT or ColumnLowCardinality directly,\n * but still want to benefit from smaller on-wire LowCardinality bandwidth footprint.\n *\n * Not intended to be used by users directly.\n *\n * @see ClientOptions, CreateColumnByType\n */\ntemplate \nclass\n[[deprecated(\"Makes implementation of LC(X) harder and code uglier. 
Will be removed in next major release (3.0) \")]]\nLowCardinalitySerializationAdaptor : public AdaptedColumnType\n{\npublic:\n using AdaptedColumnType::AdaptedColumnType;\n\n bool LoadPrefix(InputStream* input, size_t rows) override {\n auto new_data_column = this->Slice(0, 0)->template As();\n ColumnLowCardinalityT low_cardinality_col(new_data_column);\n\n return low_cardinality_col.LoadPrefix(input, rows);\n }\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override {\n auto new_data_column = this->CloneEmpty()->template As();\n\n ColumnLowCardinalityT low_cardinality_col(new_data_column);\n if (!low_cardinality_col.LoadBody(input, rows))\n return false;\n\n // It safe to reuse `flat_data_column` later since ColumnLowCardinalityT makes a deep copy, but still check just in case.\n assert(new_data_column->Size() == 0);\n\n for (size_t i = 0; i < low_cardinality_col.Size(); ++i)\n new_data_column->Append(low_cardinality_col[i]);\n\n this->Swap(*new_data_column);\n return true;\n }\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override {\n ColumnLowCardinalityT(this->template As()).SaveBody(output);\n }\n};\n\n}\n\n// Path: clickhouse/columns/nothing.h\n\n#pragma once\n\n#include \"column.h\"\n#include \"../base/input.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\n/**\n * Represents dummy column of NULLs.\n */\nclass ColumnNothing : public Column {\npublic:\n ColumnNothing()\n : Column(Type::CreateNothing())\n , size_(0)\n {\n }\n\n explicit ColumnNothing(size_t n)\n : Column(Type::CreateNothing())\n , size_(n)\n {\n }\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t) override {};\n\n /// Appends one element to the column.\n void Append(std::unique_ptr) { ++size_; }\n\n /// Returns element at given row number.\n std::nullptr_t At(size_t) const { return nullptr; };\n\n /// Returns element at given row number.\n inline std::nullptr_t operator [] (size_t) const { return nullptr; };\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t len) const override {\n return std::make_shared(len);\n }\n\n ColumnRef CloneEmpty() const override {\n return std::make_shared();\n }\n\n ItemView GetItem(size_t /*index*/) const override { return ItemView{}; }\n\npublic:\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override {\n if (auto col = column->As()) {\n size_ += col->Size();\n }\n }\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override {\n input->Skip(rows);\n size_ += rows;\n return true;\n }\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream*) override {\n throw UnimplementedError(\"method SaveBody is not supported for Nothing column\");\n }\n\n /// Clear column data .\n void Clear() override { size_ = 0; }\n\n /// Returns count of rows in the column.\n size_t Size() const override { return size_; }\n\n void Swap(Column& other) override {\n auto & col = dynamic_cast(other);\n std::swap(size_, col.size_);\n }\n\nprivate:\n size_t size_;\n};\n\n}\n\n// Path: clickhouse/base/string_view.h\n#pragma once\n\n#include \n#include \n#include \n\n/**\n * A lightweight non-owning read-only view into a subsequence of a string.\n */\ntemplate <\n typename TChar,\n typename TTraits = std::char_traits\n>\nclass\n[[deprecated(\"Obsolete due to C++17's std::string_view. 
Will be removed in next major release (3.0) \")]]\nStringViewImpl {\npublic:\n using size_type = size_t;\n using traits_type = TTraits;\n using value_type = typename TTraits::char_type;\n\n static constexpr size_type npos = size_type(-1);\n\npublic:\n inline StringViewImpl() noexcept\n : data_(nullptr)\n , size_(0)\n {\n }\n\n constexpr inline StringViewImpl(const TChar* data, size_t len) noexcept\n : data_(data)\n , size_(len)\n {\n }\n\n template \n constexpr inline StringViewImpl(const TChar (&str)[len]) noexcept\n : data_(str)\n , size_(len - 1)\n {\n }\n\n inline StringViewImpl(const TChar* begin, const TChar* end) noexcept\n : data_(begin)\n , size_(end - begin)\n {\n assert(begin <= end);\n }\n\n inline StringViewImpl(const std::basic_string& str) noexcept\n : data_(str.data())\n , size_(str.size())\n {\n }\n\n inline TChar at(size_type pos) const {\n if (pos >= size_)\n throw std::out_of_range(\"pos must be less than len\");\n return data_[pos];\n }\n\n inline const TChar* data() const noexcept {\n return data_;\n }\n\n inline bool empty() const noexcept {\n return size_ == 0;\n }\n\n inline bool null() const noexcept {\n assert(size_ == 0);\n return data_ == nullptr;\n }\n\n inline size_type size() const noexcept {\n return size_;\n }\n\n // to mimic std::string and std::string_view\n inline size_type length() const noexcept {\n return size();\n }\n\npublic:\n // Returns a substring [pos, pos + count).\n // If the requested substring extends past the end of the string,\n // or if count == npos, the returned substring is [pos, size()).\n StringViewImpl substr(size_type pos, size_type count = npos) const {\n if (pos >= size_)\n throw std::out_of_range(\"pos must be less than len\");\n if (pos + count >= size_ || count == npos)\n return StringViewImpl(data_ + pos, size_ - pos);\n else\n return StringViewImpl(data_ + pos, count);\n }\n\n inline const std::basic_string to_string() const {\n return std::basic_string(data_, size_);\n }\n\npublic:\n inline operator bool () const noexcept {\n return !empty();\n }\n\n inline explicit operator const std::basic_string () const {\n return to_string();\n }\n\n inline TChar operator [] (size_type pos) const noexcept {\n return data_[pos];\n }\n\n inline bool operator < (const StringViewImpl& other) const noexcept {\n if (size_ < other.size_)\n return true;\n if (size_ > other.size_)\n return false;\n return TTraits::compare(data_, other.data_, size_) < 0;\n }\n\n inline bool operator == (const StringViewImpl& other) const noexcept {\n if (size_ == other.size_)\n return TTraits::compare(data_, other.data_, size_) == 0;\n return false;\n }\n\nprivate:\n const TChar* data_;\n size_t size_;\n};\n\n\n// It creates StringView from literal constant at compile time.\ntemplate \nconstexpr inline StringViewImpl MakeStringView(const TChar (&str)[size]) {\n return StringViewImpl(str, size - 1);\n}\n\n\nusing StringView = StringViewImpl;\n\n// Path: clickhouse/types/type_parser.h\n#pragma once\n\n#include \"../base/string_view.h\"\n#include \"types.h\"\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\nstruct TypeAst {\n enum Meta {\n Array,\n Assign,\n Null,\n Nullable,\n Number,\n String,\n Terminal,\n Tuple,\n Enum,\n LowCardinality,\n SimpleAggregateFunction,\n Map\n };\n\n /// Type's category.\n Meta meta;\n Type::Code code;\n /// Type's name.\n /// Need to cache TypeAst, so can't use StringView for name.\n std::string name;\n /// Value associated with the node,\n /// used for fixed-width types and enum values.\n int64_t value = 0;\n 
std::string value_string;\n /// Subelements of the type.\n /// Used to store enum's names and values as well.\n std::vector elements;\n\n bool operator==(const TypeAst & other) const;\n inline bool operator!=(const TypeAst & other) const {\n return !(*this == other);\n }\n};\n\n\nclass TypeParser {\n\n struct Token {\n enum Type {\n Invalid = 0,\n Assign,\n Name,\n Number,\n String,\n LPar,\n RPar,\n Comma,\n QuotedString, // string with quotation marks included\n EOS,\n };\n\n Type type;\n StringView value;\n };\n\npublic:\n explicit TypeParser(const StringView& name);\n ~TypeParser();\n\n bool Parse(TypeAst* type);\n\nprivate:\n Token NextToken();\n\nprivate:\n const char* cur_;\n const char* end_;\n\n TypeAst* type_;\n std::stack open_elements_;\n};\n\n\nconst TypeAst* ParseTypeName(const std::string& type_name);\n\n}\n\n// Path: clickhouse/columns/factory.cpp\n#include \"factory.h\"\n\n#include \"array.h\"\n#include \"date.h\"\n#include \"decimal.h\"\n#include \"enum.h\"\n#include \"geo.h\"\n#include \"ip4.h\"\n#include \"ip6.h\"\n#include \"lowcardinality.h\"\n#include \"lowcardinalityadaptor.h\"\n#include \"map.h\"\n#include \"nothing.h\"\n#include \"nullable.h\"\n#include \"numeric.h\"\n#include \"string.h\"\n#include \"tuple.h\"\n#include \"uuid.h\"\n\n\n#include \"../types/type_parser.h\"\n\n#include \"../exceptions.h\"\n\n#include \n#include \n\nnamespace clickhouse {\nnamespace {\n\n// Like Python's list's []:\n// * 0 - first element\n// * 1 - second element\n// * -1 - last element\n// * -2 - one before last, etc.\nconst auto& GetASTChildElement(const TypeAst & ast, int position) {\n if (static_cast(abs(position)) >= ast.elements.size())\n throw ValidationError(\"AST child element index out of bounds: \" + std::to_string(position));\n\n if (position < 0)\n position = static_cast(ast.elements.size() + position);\n\n return ast.elements[static_cast(position)];\n}\n\nstatic ColumnRef CreateTerminalColumn(const TypeAst& ast) {\n switch (ast.code) {\n case Type::Void:\n return std::make_shared();\n\n case Type::UInt8:\n return std::make_shared();\n case Type::UInt16:\n return std::make_shared();\n case Type::UInt32:\n return std::make_shared();\n case Type::UInt64:\n return std::make_shared();\n\n case Type::Int8:\n return std::make_shared();\n case Type::Int16:\n return std::make_shared();\n case Type::Int32:\n return std::make_shared();\n case Type::Int64:\n return std::make_shared();\n case Type::Int128:\n return std::make_shared();\n\n case Type::Float32:\n return std::make_shared();\n case Type::Float64:\n return std::make_shared();\n\n case Type::Decimal:\n return std::make_shared(GetASTChildElement(ast, 0).value, GetASTChildElement(ast, -1).value);\n case Type::Decimal32:\n return std::make_shared(9, GetASTChildElement(ast, 0).value);\n case Type::Decimal64:\n return std::make_shared(18, GetASTChildElement(ast, 0).value);\n case Type::Decimal128:\n return std::make_shared(38, GetASTChildElement(ast, 0).value);\n\n case Type::String:\n return std::make_shared();\n case Type::FixedString:\n return std::make_shared(GetASTChildElement(ast, 0).value);\n\n case Type::DateTime:\n if (ast.elements.empty()) {\n return std::make_shared();\n } else {\n return std::make_shared(GetASTChildElement(ast, 0).value_string);\n }\n case Type::DateTime64:\n if (ast.elements.empty()) {\n return nullptr;\n }\n if (ast.elements.size() == 1) {\n return std::make_shared(ast.elements[0].value);\n } else {\n return std::make_shared(ast.elements[0].value, ast.elements[1].value_string);\n }\n case 
Type::Date:\n return std::make_shared();\n case Type::Date32:\n return std::make_shared();\n\n case Type::IPv4:\n return std::make_shared();\n case Type::IPv6:\n return std::make_shared();\n\n case Type::UUID:\n return std::make_shared();\n\n case Type::Point:\n return std::make_shared();\n\n case Type::Ring:\n return std::make_shared();\n\n case Type::Polygon:\n return std::make_shared();\n\n case Type::MultiPolygon:\n return std::make_shared();\n\n default:\n return nullptr;\n }\n}\n\nstatic ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSettings settings) {\n switch (ast.meta) {\n case TypeAst::Array: {\n return std::make_shared(\n CreateColumnFromAst(GetASTChildElement(ast, 0), settings)\n );\n }\n\n case TypeAst::Nullable: {\n return std::make_shared(\n CreateColumnFromAst(GetASTChildElement(ast, 0), settings),\n std::make_shared()\n );\n }\n\n case TypeAst::Terminal: {\n return CreateTerminalColumn(ast);\n }\n\n case TypeAst::Tuple: {\n std::vector columns;\n\n columns.reserve(ast.elements.size());\n for (const auto& elem : ast.elements) {\n if (auto col = CreateColumnFromAst(elem, settings)) {\n columns.push_back(col);\n } else {\n return nullptr;\n }\n }\n\n return std::make_shared(columns);\n }\n\n case TypeAst::Enum: {\n std::vector enum_items;\n //ast.elements.size() minimum is 1.\n if ((ast.elements.size() % 2) != 0) {\n throw ValidationError(ast.name + \" content is not correct\");\n }\n\n enum_items.reserve(ast.elements.size() / 2);\n for (size_t i = 0; i < ast.elements.size(); i += 2) {\n enum_items.push_back(Type::EnumItem{\n ast.elements[i].value_string,\n static_cast(ast.elements[i + 1].value)\n });\n }\n\n if (ast.code == Type::Enum8) {\n return std::make_shared(\n Type::CreateEnum8(enum_items)\n );\n } else if (ast.code == Type::Enum16) {\n return std::make_shared(\n Type::CreateEnum16(enum_items)\n );\n }\n break;\n }\n case TypeAst::LowCardinality: {\n const auto nested = GetASTChildElement(ast, 0);\n if (settings.low_cardinality_as_wrapped_column) {\n switch (nested.code) {\n // TODO (nemkov): update this to maximize code reuse.\n case Type::String:\n return std::make_shared>();\n case Type::FixedString:\n return std::make_shared>(GetASTChildElement(nested, 0).value);\n case Type::Nullable:\n throw UnimplementedError(\"LowCardinality(\" + nested.name + \") is not supported with LowCardinalityAsWrappedColumn on\");\n default:\n throw UnimplementedError(\"LowCardinality(\" + nested.name + \") is not supported\");\n }\n }\n else {\n switch (nested.code) {\n // TODO (nemkov): update this to maximize code reuse.\n case Type::String:\n return std::make_shared>();\n case Type::FixedString:\n return std::make_shared>(GetASTChildElement(nested, 0).value);\n case Type::Nullable:\n return std::make_shared(\n std::make_shared(\n CreateColumnFromAst(GetASTChildElement(nested, 0), settings),\n std::make_shared()\n )\n );\n default:\n throw UnimplementedError(\"LowCardinality(\" + nested.name + \") is not supported\");\n }\n }\n }\n case TypeAst::SimpleAggregateFunction: {\n return CreateTerminalColumn(GetASTChildElement(ast, -1));\n }\n\n case TypeAst::Map: {\n if (ast.elements.size() != 2) {\n throw ValidationError(ast.name + \" content is not correct\");\n }\n\n std::vector columns;\n\n columns.reserve(ast.elements.size());\n for (const auto& elem : ast.elements) {\n if (auto col = CreateColumnFromAst(elem, settings)) {\n columns.push_back(col);\n } else {\n return nullptr;\n }\n }\n\n return std::make_shared(\n std::make_shared(\n 
std::make_shared(columns)));\n }\n\n case TypeAst::Assign:\n case TypeAst::Null:\n case TypeAst::Number:\n case TypeAst::String:\n break;\n }\n\n return nullptr;\n}\n\n} // namespace\n\n\n\nColumnRef CreateColumnByType(const std::string& type_name, CreateColumnByTypeSettings settings) {\n auto ast = ParseTypeName(type_name);\n if (ast != nullptr) {\n return CreateColumnFromAst(*ast, settings);\n }\n\n return nullptr;\n}\n\n}\n\n// Path: clickhouse/columns/nullable.cpp\n#include \"nullable.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\nColumnNullable::ColumnNullable(ColumnRef nested, ColumnRef nulls)\n : Column(Type::CreateNullable(nested->Type()))\n , nested_(nested)\n , nulls_(nulls->As())\n{\n if (nested_->Size() != nulls->Size()) {\n throw ValidationError(\"count of elements in nested and nulls should be the same\");\n }\n}\n\nvoid ColumnNullable::Append(bool isnull)\n{\n nulls_->Append(isnull ? 1 : 0);\n}\n\n\nbool ColumnNullable::IsNull(size_t n) const {\n return nulls_->At(n) != 0;\n}\n\nColumnRef ColumnNullable::Nested() const {\n return nested_;\n}\n\nColumnRef ColumnNullable::Nulls() const\n{\n return nulls_;\n}\n\nvoid ColumnNullable::Reserve(size_t new_cap) {\n nested_->Reserve(new_cap);\n nulls_->Reserve(new_cap);\n}\n\nvoid ColumnNullable::Append(ColumnRef column) {\n if (auto col = column->As()) {\n if (!col->nested_->Type()->IsEqual(nested_->Type())) {\n return;\n }\n\n nested_->Append(col->nested_);\n nulls_->Append(col->nulls_);\n }\n}\n\nvoid ColumnNullable::Clear() {\n nested_->Clear();\n nulls_->Clear();\n}\n\nbool ColumnNullable::LoadPrefix(InputStream* input, size_t rows) {\n return nested_->LoadPrefix(input, rows);\n}\n\nbool ColumnNullable::LoadBody(InputStream* input, size_t rows) {\n if (!nulls_->LoadBody(input, rows)) {\n return false;\n }\n if (!nested_->LoadBody(input, rows)) {\n return false;\n }\n return true;\n}\n\nvoid ColumnNullable::SavePrefix(OutputStream* output) {\n nested_->SavePrefix(output);\n}\n\nvoid ColumnNullable::SaveBody(OutputStream* output) {\n nulls_->SaveBody(output);\n nested_->SaveBody(output);\n}\n\nsize_t ColumnNullable::Size() const {\n return nulls_->Size();\n}\n\nColumnRef ColumnNullable::Slice(size_t begin, size_t len) const {\n return std::make_shared(nested_->Slice(begin, len), nulls_->Slice(begin, len));\n}\n\nColumnRef ColumnNullable::CloneEmpty() const {\n return std::make_shared(nested_->CloneEmpty(), nulls_->CloneEmpty());\n}\n\nvoid ColumnNullable::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n if (!nested_->Type()->IsEqual(col.nested_->Type()))\n throw ValidationError(\"Can't swap() Nullable columns of different types.\");\n\n nested_.swap(col.nested_);\n nulls_.swap(col.nulls_);\n}\n\nItemView ColumnNullable::GetItem(size_t index) const {\n if (IsNull(index))\n return ItemView();\n\n return nested_->GetItem(index);\n}\n\n}\n\n// Path: clickhouse/columns/enum.cpp\n#include \"enum.h\"\n#include \"utils.h\"\n\n#include \"../base/input.h\"\n#include \"../base/output.h\"\n#include \"../base/wire_format.h\"\n\nnamespace clickhouse {\n\ntemplate \nColumnEnum::ColumnEnum(TypeRef type)\n : Column(type)\n{\n}\n\ntemplate \nColumnEnum::ColumnEnum(TypeRef type, const std::vector& data)\n : Column(type)\n , data_(data)\n{\n}\n\ntemplate \nColumnEnum::ColumnEnum(TypeRef type, std::vector&& data)\n : Column(type)\n , data_(std::move(data))\n{\n}\n\ntemplate \nvoid ColumnEnum::Append(const T& value, bool checkValue) {\n if (checkValue) {\n // TODO: type_->HasEnumValue(value), \"Enum type doesn't have value \" + 
std::to_string(value);\n }\n data_.push_back(value);\n}\n\ntemplate \nvoid ColumnEnum::Append(const std::string& name) {\n data_.push_back(static_cast(type_->As()->GetEnumValue(name)));\n}\n\ntemplate \nvoid ColumnEnum::Clear() {\n data_.clear();\n}\n\ntemplate \nconst T& ColumnEnum::At(size_t n) const {\n return data_.at(n);\n}\n\ntemplate \nstd::string_view ColumnEnum::NameAt(size_t n) const {\n return type_->As()->GetEnumName(data_.at(n));\n}\n\ntemplate \nvoid ColumnEnum::SetAt(size_t n, const T& value, bool checkValue) {\n if (checkValue) {\n // TODO: type_->HasEnumValue(value), \"Enum type doesn't have value \" + std::to_string(value);\n }\n data_.at(n) = value;\n}\n\ntemplate \nvoid ColumnEnum::SetNameAt(size_t n, const std::string& name) {\n data_.at(n) = static_cast(type_->As()->GetEnumValue(name));\n}\n\ntemplate\nvoid ColumnEnum::Reserve(size_t new_cap) {\n data_.reserve(new_cap);\n}\n\ntemplate \nvoid ColumnEnum::Append(ColumnRef column) {\n if (auto col = column->As>()) {\n data_.insert(data_.end(), col->data_.begin(), col->data_.end());\n }\n}\n\ntemplate \nbool ColumnEnum::LoadBody(InputStream* input, size_t rows) {\n data_.resize(rows);\n return WireFormat::ReadBytes(*input, data_.data(), data_.size() * sizeof(T));\n}\n\ntemplate \nvoid ColumnEnum::SaveBody(OutputStream* output) {\n WireFormat::WriteBytes(*output, data_.data(), data_.size() * sizeof(T));\n}\n\ntemplate \nsize_t ColumnEnum::Size() const {\n return data_.size();\n}\n\ntemplate \nColumnRef ColumnEnum::Slice(size_t begin, size_t len) const {\n return std::make_shared>(type_, SliceVector(data_, begin, len));\n}\n\ntemplate \nColumnRef ColumnEnum::CloneEmpty() const {\n return std::make_shared>(type_);\n}\n\ntemplate \nvoid ColumnEnum::Swap(Column& other) {\n auto & col = dynamic_cast &>(other);\n data_.swap(col.data_);\n}\n\ntemplate \nItemView ColumnEnum::GetItem(size_t index) const {\n return ItemView{type_->GetCode(), data_[index]};\n}\n\ntemplate class ColumnEnum;\ntemplate class ColumnEnum;\n\n}\n\n// Path: clickhouse/columns/date.cpp\n#include \"date.h\"\n#include \n\nnamespace clickhouse {\n\nColumnDate::ColumnDate()\n : Column(Type::CreateDate())\n , data_(std::make_shared())\n{\n}\n\nColumnDate::ColumnDate(std::vector&& data)\n : Column(Type::CreateDate())\n , data_(std::make_shared(std::move(data)))\n{\n}\n\nvoid ColumnDate::Append(const std::time_t& value) {\n /// The implementation is fundamentally wrong, ignores timezones, leap years and daylight saving.\n data_->Append(static_cast(value / std::time_t(86400)));\n}\n\nvoid ColumnDate::Clear() {\n data_->Clear();\n}\n\nstd::time_t ColumnDate::At(size_t n) const {\n /// The implementation is fundamentally wrong, ignores timezones, leap years and daylight saving.\n return static_cast(data_->At(n)) * 86400;\n}\n\nvoid ColumnDate::AppendRaw(uint16_t value) {\n data_->Append(value);\n}\n\nuint16_t ColumnDate::RawAt(size_t n) const {\n return data_->At(n);\n}\n\nvoid ColumnDate::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nstd::vector& ColumnDate::GetWritableData() {\n return data_->GetWritableData();\n}\n\nvoid ColumnDate::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\nsize_t ColumnDate::Capacity() const {\n return data_->Capacity();\n}\n\nbool ColumnDate::LoadBody(InputStream* input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\nvoid ColumnDate::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\nsize_t ColumnDate::Size() const {\n return 
data_->Size();\n}\n\nColumnRef ColumnDate::Slice(size_t begin, size_t len) const {\n auto col = data_->Slice(begin, len)->As();\n auto result = std::make_shared();\n\n result->data_->Append(col);\n\n return result;\n}\n\nColumnRef ColumnDate::CloneEmpty() const {\n return std::make_shared();\n}\n\nvoid ColumnDate::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\nItemView ColumnDate::GetItem(size_t index) const {\n return ItemView(Type::Date, data_->GetItem(index));\n}\n\n\nColumnDate32::ColumnDate32()\n : Column(Type::CreateDate32())\n , data_(std::make_shared())\n{\n}\n\nColumnDate32::ColumnDate32(std::vector&& data)\n : Column(Type::CreateDate32())\n , data_(std::make_shared(std::move(data)))\n{\n}\n\nvoid ColumnDate32::Append(const std::time_t& value) {\n /// The implementation is fundamentally wrong, ignores timezones, leap years and daylight saving.\n data_->Append(static_cast(value / std::time_t(86400)));\n}\n\nvoid ColumnDate32::Clear() {\n data_->Clear();\n}\n\nstd::time_t ColumnDate32::At(size_t n) const {\n /// The implementation is fundamentally wrong, ignores timezones, leap years and daylight saving.\n return static_cast(data_->At(n)) * 86400;\n}\n\nvoid ColumnDate32::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nstd::vector& ColumnDate32::GetWritableData() {\n return data_->GetWritableData();\n}\n\nvoid ColumnDate32::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\nsize_t ColumnDate32::Capacity() const {\n return data_->Capacity();\n}\n\nvoid ColumnDate32::AppendRaw(int32_t value) {\n data_->Append(value);\n}\n\nint32_t ColumnDate32::RawAt(size_t n) const {\n return data_->At(n);\n}\n\nbool ColumnDate32::LoadBody(InputStream* input, size_t rows) {\n return data_->LoadBody(input, rows);\n}\n\nvoid ColumnDate32::SaveBody(OutputStream* output) {\n data_->SaveBody(output);\n}\n\nsize_t ColumnDate32::Size() const {\n return data_->Size();\n}\n\nColumnRef ColumnDate32::Slice(size_t begin, size_t len) const {\n auto col = data_->Slice(begin, len)->As();\n auto result = std::make_shared();\n\n result->data_->Append(col);\n\n return result;\n}\n\nColumnRef ColumnDate32::CloneEmpty() const {\n return std::make_shared();\n}\n\nvoid ColumnDate32::Swap(Column& other) {\n auto & col = dynamic_cast(other);\n data_.swap(col.data_);\n}\n\nItemView ColumnDate32::GetItem(size_t index) const {\n return ItemView{Type()->GetCode(), data_->GetItem(index)};\n}\n\nColumnDateTime::ColumnDateTime()\n : Column(Type::CreateDateTime())\n , data_(std::make_shared())\n{\n}\n\nColumnDateTime::ColumnDateTime(std::string timezone)\n : Column(Type::CreateDateTime(std::move(timezone)))\n , data_(std::make_shared())\n{\n}\n\nColumnDateTime::ColumnDateTime(std::vector&& data)\n : Column(Type::CreateDateTime())\n , data_(std::make_shared(std::move(data))) {\n}\n\nColumnDateTime::ColumnDateTime(std::string timezone, std::vector&& data)\n : Column(Type::CreateDateTime(std::move(timezone)))\n , data_(std::make_shared(std::move(data))) {\n}\n\nvoid ColumnDateTime::Append(const std::time_t& value) {\n data_->Append(static_cast(value));\n}\n\nstd::time_t ColumnDateTime::At(size_t n) const {\n return data_->At(n);\n}\n\nvoid ColumnDateTime::AppendRaw(uint32_t value) {\n data_->Append(value);\n}\n\nstd::string ColumnDateTime::Timezone() const {\n return type_->As()->Timezone();\n}\n\nvoid ColumnDateTime::Append(ColumnRef column) {\n if (auto col = column->As()) {\n data_->Append(col->data_);\n }\n}\n\nstd::vector& 
ColumnDateTime::GetWritableData() {\n return data_->GetWritableData();\n}\n\nvoid ColumnDateTime::Reserve(size_t new_cap) {\n data_->Reserve(new_cap);\n}\n\n...\n", "needle_token_start": 11960, "needle_token_end": 12032, "code_context_ntokens": 15987, "output": ["```cpp\n#include \"factory.h\"\n\n#include \"array.h\"\n#include \"date.h\"\n#include \"decimal.h\"\n#include \"enum.h\"\n#include \"geo.h\"\n#include \"ip4.h\"\n#include \"ip6.h\"\n#include \"lowcardinality.h\"\n#include \"lowcardinalityadaptor.h\"\n#include \"map.h\"\n#include \"nothing.h\"\n#include \"nullable.h\"\n#include \"numeric.h\"\n#include \"string.h\"\n#include \"tuple.h\"\n#include \"uuid.h\"\n\n#include \"../types/type_parser.h\"\n\n#include \"../exceptions.h\"\n\n#include \n#include \n\nnamespace clickhouse {\nnamespace {\n\n// Like Python's list's []:\n// * 0 - first element\n// * 1 - second element\n// * -1 - last element\n// * -2 - one before last, etc.\nconst auto& GetASTChildElement(const TypeAst & ast, int position) {\n if (static_cast(abs(position)) >= ast.elements.size())\n throw ValidationError(\"AST child element index out of bounds: \" + std::to_string(position));\n\n if (position < 0)\n position = static_cast(ast.elements.size() + position);\n\n return ast.elements[static_cast(position)];\n}\n\nstatic ColumnRef CreateTerminalColumn(const TypeAst& ast) {\n switch (ast.code) {\n case Type::Void:\n return std::make_shared();\n\n case Type::UInt8:\n return std::make_shared();\n case Type::UInt16:\n return std::make_shared();\n case Type::UInt32:\n return std::make_shared();\n case Type::UInt64:\n return std::make_shared();\n\n case Type::Int8:\n return std::make_shared();\n case Type::Int16:\n return std::make_shared();\n case Type::Int32:\n return std::make_shared();\n case Type::Int64:\n return std::make_shared();\n case Type::Int128:\n return std::make_shared();\n\n case Type::Float32:\n return std::make_shared();\n case Type::Float64:\n return std::make_shared();\n\n case Type::Decimal:\n return std::make_shared(GetASTChildElement(ast, 0).value, GetASTChildElement(ast, -1).value);\n case Type::Decimal32:\n return std::make_shared(9, GetASTChildElement(ast, 0).value);\n case Type::Decimal64:\n return std::make_shared(18, GetASTChildElement(ast, 0).value);\n case Type::Decimal128:\n return std::make_shared(38, GetASTChildElement(ast, 0).value);\n\n case Type::String:\n return std::make_shared();\n case Type::FixedString:\n return std::make_shared(GetASTChildElement(ast, 0).value);\n\n case Type::DateTime:\n if (ast.elements.empty()) {\n return std::make_shared();\n } else {\n return std::make_shared(GetASTChildElement(ast, 0).value_string);\n }\n case Type::DateTime64:\n if (ast.elements.empty()) {\n return nullptr;\n }\n if (ast.elements.size() == 1) {\n return std::make_shared(ast.elements[0].value);\n } else {\n return std::make_shared(ast.elements[0].value, ast.elements[1].value_string);\n }\n case Type::Date:\n return std::make_shared();\n case Type::Date32:\n return std::make_shared();\n\n case Type::IPv4:\n return std::make_shared();\n case Type::IPv6:\n return std::make_shared();\n\n case Type::UUID:\n return std::make_shared();\n\n case Type::Point:\n return std::make_shared();\n\n case Type::Ring:\n return std::make_shared();\n\n case Type::Polygon:\n return std::make_shared();\n\n case Type::MultiPolygon:\n return std::make_shared();\n\n default:\n return nullptr;\n }\n}\n\nstatic ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSettings settings) {\n switch (ast.meta) {\n 
case TypeAst::Array: {\n "]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "ConvertToStorageValue", "language": "cpp", "path": "clickhouse/columns/itemview.h", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert various data types into a uniform format suitable for storage in a specific data structure, ensuring compatibility and consistency in how data is represented.\n2. **Input**: Accepts a generic input of various data types, including strings, fundamental data types, and user-defined types.\n3. **Output**: Returns a uniform string view representation of the input data, facilitating standardized storage.\n4. **Procedure**: The function checks the type of the input data:\n - If the input is a string or string view, it directly returns a string view of the input.\n - If the input is a fundamental data type or a specific user-defined type, it converts the data to a string view by interpreting the data's memory representation.\n - If the input type is not supported, it triggers a compile-time assertion to indicate an error due to an unknown type.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/types/types.h\n#pragma once\n\n#include \"absl/numeric/int128.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nusing Int128 = absl::int128;\nusing Int64 = int64_t;\n\nusing TypeRef = std::shared_ptr;\n\nclass Type {\npublic:\n enum Code {\n Void = 0,\n Int8,\n Int16,\n Int32,\n Int64,\n UInt8,\n UInt16,\n UInt32,\n UInt64,\n Float32,\n Float64,\n String,\n FixedString,\n DateTime,\n Date,\n Array,\n Nullable,\n Tuple,\n Enum8,\n Enum16,\n UUID,\n IPv4,\n IPv6,\n Int128,\n Decimal,\n Decimal32,\n Decimal64,\n Decimal128,\n LowCardinality,\n DateTime64,\n Date32,\n Map,\n Point,\n Ring,\n Polygon,\n MultiPolygon\n };\n\n using EnumItem = std::pair;\n\nprotected:\n Type(const Code code);\n\npublic:\n template \n auto* As() {\n return static_cast(this);\n }\n\n template \n const auto* As() const {\n return static_cast(this);\n }\n\n /// Type's code.\n Code GetCode() const { return code_; }\n\n /// String representation of the type.\n std::string GetName() const;\n\n /// Is given type same as current one.\n bool IsEqual(const Type& other) const {\n // Types are equal only if both code_ and type_unique_id_ are equal.\n return this == &other\n // GetTypeUniqueId() is relatively heavy, so avoid calling it when comparing obviously different types.\n || (this->GetCode() == other.GetCode() && this->GetTypeUniqueId() == other.GetTypeUniqueId());\n }\n\n bool IsEqual(const TypeRef& other) const { return IsEqual(*other); }\n\n /// Simple name, doesn't depend on parameters and\\or nested types, caller MUST NOT free returned value.\n static const char* TypeName(Code);\n\npublic:\n static TypeRef CreateArray(TypeRef item_type);\n\n static TypeRef CreateDate();\n\n static TypeRef CreateDate32();\n\n static TypeRef CreateDateTime(std::string timezone = std::string());\n\n static TypeRef CreateDateTime64(size_t precision, std::string timezone = std::string());\n\n static TypeRef CreateDecimal(size_t precision, size_t scale);\n\n static TypeRef CreateIPv4();\n\n static TypeRef CreateIPv6();\n\n static TypeRef CreateNothing();\n\n static TypeRef CreateNullable(TypeRef 
nested_type);\n\n template \n static TypeRef CreateSimple();\n\n static TypeRef CreateString();\n\n static TypeRef CreateString(size_t n);\n\n static TypeRef CreateTuple(const std::vector& item_types);\n\n static TypeRef CreateEnum8(const std::vector& enum_items);\n\n static TypeRef CreateEnum16(const std::vector& enum_items);\n\n static TypeRef CreateUUID();\n\n static TypeRef CreateLowCardinality(TypeRef item_type);\n\n static TypeRef CreateMap(TypeRef key_type, TypeRef value_type);\n\n static TypeRef CreatePoint();\n\n static TypeRef CreateRing();\n\n static TypeRef CreatePolygon();\n\n static TypeRef CreateMultiPolygon();\n\nprivate:\n uint64_t GetTypeUniqueId() const;\n\n const Code code_;\n mutable std::atomic type_unique_id_;\n};\n\ninline bool operator==(const Type & left, const Type & right) {\n if (&left == &right)\n return true;\n if (typeid(left) == typeid(right))\n return left.IsEqual(right);\n return false;\n}\n\ninline bool operator==(const TypeRef & left, const TypeRef & right) {\n return *left == *right;\n}\n\nclass ArrayType : public Type {\npublic:\n explicit ArrayType(TypeRef item_type);\n\n std::string GetName() const { return std::string(\"Array(\") + item_type_->GetName() + \")\"; }\n\n /// Type of array's elements.\n inline TypeRef GetItemType() const { return item_type_; }\n\nprivate:\n TypeRef item_type_;\n};\n\nclass DecimalType : public Type {\npublic:\n DecimalType(size_t precision, size_t scale);\n\n std::string GetName() const;\n friend class EnumType;\n friend class DateTimeType;\n\n inline size_t GetScale() const { return scale_; }\n inline size_t GetPrecision() const { return precision_; }\n\nprivate:\n const size_t precision_, scale_;\n};\n\nnamespace details\n{\nclass TypeWithTimeZoneMixin\n{\npublic:\n TypeWithTimeZoneMixin(std::string timezone);\n\n /// Timezone associated with a data column.\n const std::string & Timezone() const;\n\nprivate:\n std::string timezone_;\n};\n}\n\nclass DateTimeType : public Type, public details::TypeWithTimeZoneMixin {\npublic:\n explicit DateTimeType(std::string timezone);\n\n std::string GetName() const;\n};\n\nclass DateTime64Type: public Type, public details::TypeWithTimeZoneMixin {\npublic:\n explicit DateTime64Type(size_t precision, std::string timezone_);\n\n std::string GetName() const;\n\n inline size_t GetPrecision() const { return precision_; }\nprivate:\n size_t precision_;\n};\n\nclass EnumType : public Type {\npublic:\n EnumType(Type::Code type, const std::vector& items);\n\n std::string GetName() const;\n\n /// Methods to work with enum types.\n std::string_view GetEnumName(int16_t value) const;\n int16_t GetEnumValue(const std::string& name) const;\n bool HasEnumName(const std::string& name) const;\n bool HasEnumValue(int16_t value) const;\n\nprivate:\n using ValueToNameType = std::map;\n using NameToValueType = std::map;\n using ValueToNameIterator = ValueToNameType::const_iterator;\n\n ValueToNameType value_to_name_;\n NameToValueType name_to_value_;\n\npublic:\n ValueToNameIterator BeginValueToName() const;\n ValueToNameIterator EndValueToName() const;\n};\n\nclass FixedStringType : public Type {\npublic:\n explicit FixedStringType(size_t n);\n\n std::string GetName() const { return std::string(\"FixedString(\") + std::to_string(size_) + \")\"; }\n\n inline size_t GetSize() const { return size_; }\n\nprivate:\n size_t size_;\n};\n\nclass NullableType : public Type {\npublic:\n explicit NullableType(TypeRef nested_type);\n\n std::string GetName() const { return std::string(\"Nullable(\") + 
nested_type_->GetName() + \")\"; }\n\n /// Type of nested nullable element.\n TypeRef GetNestedType() const { return nested_type_; }\n\nprivate:\n TypeRef nested_type_;\n};\n\nclass TupleType : public Type {\npublic:\n explicit TupleType(const std::vector& item_types);\n\n std::string GetName() const;\n\n /// Type of nested Tuple element type.\n std::vector GetTupleType() const { return item_types_; }\n\nprivate:\n std::vector item_types_;\n};\n\nclass LowCardinalityType : public Type {\npublic:\n explicit LowCardinalityType(TypeRef nested_type);\n ~LowCardinalityType();\n\n std::string GetName() const { return std::string(\"LowCardinality(\") + nested_type_->GetName() + \")\"; }\n\n /// Type of nested nullable element.\n TypeRef GetNestedType() const { return nested_type_; }\n\nprivate:\n TypeRef nested_type_;\n};\n\nclass MapType : public Type {\npublic:\n explicit MapType(TypeRef key_type, TypeRef value_type);\n\n std::string GetName() const;\n\n /// Type of keys.\n TypeRef GetKeyType() const { return key_type_; }\n\n /// Type of values.\n TypeRef GetValueType() const { return value_type_; }\n\nprivate:\n TypeRef key_type_;\n TypeRef value_type_;\n};\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int8));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int16));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int64));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Int128));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt8));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt16));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(UInt64));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Float32));\n}\n\ntemplate <>\ninline TypeRef Type::CreateSimple() {\n return TypeRef(new Type(Float64));\n}\n\n} // namespace clickhouse\n\n// Path: clickhouse/server_exception.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\nstruct Exception {\n int code = 0;\n std::string name;\n std::string display_text;\n std::string stack_trace;\n /// Pointer to nested exception.\n std::unique_ptr nested;\n};\n\n}\n\n// Path: clickhouse/exceptions.h\n#pragma once\n\n#include \"server_exception.h\"\n\n#include \n\nnamespace clickhouse {\n\nclass Error : public std::runtime_error {\n using std::runtime_error::runtime_error;\n};\n\n// Caused by any user-related code, like invalid column types or arguments passed to any method.\nclass ValidationError : public Error {\n using Error::Error;\n};\n\n// Buffers+IO errors, failure to serialize/deserialize, checksum mismatches, etc.\nclass ProtocolError : public Error {\n using Error::Error;\n};\n\nclass UnimplementedError : public Error {\n using Error::Error;\n};\n\n// Internal validation error.\nclass AssertionError : public Error {\n using Error::Error;\n};\n\nclass OpenSSLError : public Error {\n using Error::Error;\n};\n\nclass LZ4Error : public Error {\n using Error::Error;\n};\n\n// Exception received from server.\nclass ServerException : public Error {\npublic:\n ServerException(std::unique_ptr e)\n : Error(std::string())\n , exception_(std::move(e))\n {\n }\n\n int GetCode() const {\n return 
exception_->code;\n }\n\n const Exception& GetException() const {\n return *exception_;\n }\n\n const char* what() const noexcept override {\n return exception_->display_text.c_str();\n }\n\nprivate:\n std::unique_ptr exception_;\n};\nusing ServerError = ServerException;\n\n}\n\n// Path: clickhouse/columns/itemview.h\n#pragma once\n\n#include \"../types/types.h\"\n#include \"../exceptions.h\"\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\n/** ItemView is a view on a data stored in Column, safe-ish interface for reading values from Column.\n *\n * Data is not owned (hence the name View) and will be invalidated on column update, load\n * or destruction (basically on calling any non-const method of Column).\n * `type` reflects what is stored in `data` and can be almost any value-type\n * (except Nullable, Array, Tuple, LowCardinality).\n *\n */\nstruct ItemView {\n using DataType = std::string_view;\n\n const Type::Code type;\n const DataType data;\n\nprivate:\n template \n \ninline auto ConvertToStorageValue(const T& t) {\n if constexpr (std::is_same_v || std::is_same_v) {\n return std::string_view{t};\n } else if constexpr (std::is_fundamental_v || std::is_same_v>) {\n return std::string_view{reinterpret_cast(&t), sizeof(T)};\n } else {\n static_assert(!std::is_same_v, \"Unknown type, which can't be stored in ItemView\");\n return;\n }\n }\n\npublic:\n ItemView(Type::Code type, DataType data)\n : type(type),\n data(data)\n {\n ValidateData(type, data);\n }\n\n ItemView(Type::Code type, ItemView other)\n : type(type),\n data(other.data)\n {\n ValidateData(type, data);\n }\n\n explicit ItemView()\n : ItemView(Type::Void, std::string_view{})\n {}\n\n template \n explicit ItemView(Type::Code type, const T & value)\n : ItemView(type, ConvertToStorageValue(value))\n {}\n\n template \n auto get() const {\n using ValueType = std::remove_cv_t>;\n if constexpr (std::is_same_v || std::is_same_v) {\n return data;\n } else if constexpr (std::is_fundamental_v || std::is_same_v) {\n if (sizeof(ValueType) == data.size()) {\n return *reinterpret_cast(data.data());\n } else {\n throw AssertionError(\"Incompatitable value type and size. 
Requested size: \"\n + std::to_string(sizeof(ValueType)) + \" stored size: \" + std::to_string(data.size()));\n }\n }\n }\n\n inline std::string_view AsBinaryData() const {\n return data;\n }\n\n // Validate that value matches type, will throw an exception if validation fails.\n static void ValidateData(Type::Code type, DataType data);\n};\n\n}\n\n// Path: clickhouse/columns/column.h\n#pragma once\n\n#include \"../types/types.h\"\n#include \"../columns/itemview.h\"\n#include \"../exceptions.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream;\nclass OutputStream;\n\nusing ColumnRef = std::shared_ptr;\n\n/**\n * An abstract base of all columns classes.\n */\nclass Column : public std::enable_shared_from_this {\npublic:\n explicit inline Column(TypeRef type) : type_(type) {}\n\n virtual ~Column() {}\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr As() {\n return std::dynamic_pointer_cast(shared_from_this());\n }\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr As() const {\n return std::dynamic_pointer_cast(shared_from_this());\n }\n\n /// Downcast pointer to the specific column's subtype.\n template \n inline std::shared_ptr AsStrict() {\n auto result = std::dynamic_pointer_cast(shared_from_this());\n if (!result) {\n throw ValidationError(\"Can't cast from \" + type_->GetName());\n }\n return result;\n }\n\n /// Get type object of the column.\n inline TypeRef Type() const { return type_; }\n inline const class Type& GetType() const { return *type_; }\n\n /// Appends content of given column to the end of current one.\n virtual void Append(ColumnRef column) = 0;\n\n /// Increase the capacity of the column for large block insertion.\n virtual void Reserve(size_t new_cap) = 0;\n\n /// Template method to load column data from input stream. It'll call LoadPrefix and LoadBody.\n /// Should be called only once from the client. Derived classes should not call it.\n bool Load(InputStream* input, size_t rows);\n\n /// Loads column prefix from input stream.\n virtual bool LoadPrefix(InputStream* input, size_t rows);\n\n /// Loads column data from input stream.\n virtual bool LoadBody(InputStream* input, size_t rows) = 0;\n\n /// Saves column prefix to output stream. Column types with prefixes must implement it.\n virtual void SavePrefix(OutputStream* output);\n\n /// Saves column body to output stream.\n virtual void SaveBody(OutputStream* output) = 0;\n\n /// Template method to save to output stream. It'll call SavePrefix and SaveBody respectively\n /// Should be called only once from the client. 
Derived classes should not call it.\n /// Save is split in Prefix and Body because some data types require prefixes and specific serialization order.\n /// For instance, Array(LowCardinality(X)) requires LowCardinality.key_version bytes to come before Array.offsets\n void Save(OutputStream* output);\n\n /// Clear column data .\n virtual void Clear() = 0;\n\n /// Returns count of rows in the column.\n virtual size_t Size() const = 0;\n\n /// Makes slice of the current column.\n virtual ColumnRef Slice(size_t begin, size_t len) const = 0;\n\n virtual ColumnRef CloneEmpty() const = 0;\n\n virtual void Swap(Column&) = 0;\n\n /// Get a view on raw item data if it is supported by column, will throw an exception if index is out of range.\n /// Please note that view is invalidated once column items are added or deleted, column is loaded from strean or destroyed.\n virtual ItemView GetItem(size_t) const {\n throw UnimplementedError(\"GetItem() is not supported for column of \" + type_->GetName());\n }\n\n friend void swap(Column& left, Column& right) {\n left.Swap(right);\n }\n\nprotected:\n TypeRef type_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/block.h\n#pragma once\n\n#include \"columns/column.h\"\n\nnamespace clickhouse {\n\nstruct BlockInfo {\n uint8_t is_overflows = 0;\n int32_t bucket_num = -1;\n};\n\nclass Block {\npublic:\n /// Allow to iterate over block's columns.\n class Iterator {\n public:\n Iterator(const Block& block);\n\n /// Name of column.\n const std::string& Name() const;\n\n /// Type of column.\n TypeRef Type() const;\n\n /// Reference to column object.\n ColumnRef Column() const;\n\n /// Move to next column, returns false if next call to IsValid() would return false;\n bool Next();\n\n /// Is the iterator still valid.\n bool IsValid() const;\n\n size_t ColumnIndex() const {\n return idx_;\n }\n\n Iterator& operator*() { return *this; }\n const Iterator& operator*() const { return *this; }\n\n bool operator==(const Iterator & other) const {\n return &block_ == &other.block_ && idx_ == other.idx_;\n }\n bool operator!=(const Iterator & other) const {\n return !(*this == other);\n }\n\n Iterator& operator++() {\n this->Next();\n return *this;\n }\n\n private:\n friend class Block;\n struct ConstructAtEndTag {};\n Iterator(const Block& block, ConstructAtEndTag at_end);\n Iterator() = delete;\n\n const Block& block_;\n size_t idx_;\n };\n\npublic:\n Block();\n Block(size_t cols, size_t rows);\n ~Block();\n\n /// Append named column to the block.\n void AppendColumn(const std::string& name, const ColumnRef& col);\n\n /// Count of columns in the block.\n size_t GetColumnCount() const;\n\n const BlockInfo& Info() const;\n\n /// Set block info\n void SetInfo(BlockInfo info);\n\n /// Count of rows in the block.\n size_t GetRowCount() const;\n\n size_t RefreshRowCount();\n\n const std::string& GetColumnName(size_t idx) const {\n return columns_.at(idx).name;\n }\n\n /// Reference to column by index in the block.\n ColumnRef operator [] (size_t idx) const;\n\n Iterator begin() const;\n Iterator end() const;\n Iterator cbegin() const { return begin(); }\n Iterator cend() const { return end(); }\n\nprivate:\n struct ColumnItem {\n std::string name;\n ColumnRef column;\n };\n\n BlockInfo info_;\n std::vector columns_;\n /// Count of rows in the block.\n size_t rows_;\n};\n\n}\n\n// Path: clickhouse/block.cpp\n#include \"block.h\"\n\n#include \"exceptions.h\"\n\n#include \n\nnamespace clickhouse {\n\nBlock::Iterator::Iterator(const Block& block)\n : block_(block)\n , 
idx_(0)\n{\n}\n\n...\n", "needle_token_start": 3440, "needle_token_end": 3620, "code_context_ntokens": 5891, "output": ["```cpp\ntemplate <>\ninline auto ConvertToStorageValue(const T& t) {\n if constexpr (std::is_same_v || std::is_same_v) {\n return std::string_view{t};\n } else if constexpr (std::is_fundamental_v || std::is_same_v>) {\n return std::string_view{reinterpret_cast(&t), sizeof(T)};\n } else {\n static_assert(!std::is_same_v, \"Unknown type, which can't be stored in ItemView\");\n return;\n }\n}\n```"]} +{"repo": "ClickHouse/clickhouse-cpp", "name": "GetSocketFactory", "language": "cpp", "path": "clickhouse/client.cpp", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to create and return an appropriate socket factory based on the security settings specified in the client configuration. It decides whether to create a secure or non-secure socket factory.\n2. **Input**: The function takes a configuration object that includes client options, which may or may not include SSL settings.\n3. **Output**: Returns a unique pointer to a socket factory, which can either be secure or non-secure depending on the input configuration.\n4. **Procedure**: The function checks if SSL options are provided in the input configuration. If SSL options are present and the build supports SSL, it creates and returns a secure socket factory. If no SSL options are provided or the build does not support SSL, it returns a non-secure socket factory.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: clickhouse/columns/lowcardinality.h\n#pragma once\n\n#include \"column.h\"\n#include \"numeric.h\"\n#include \"nullable.h\"\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnLowCardinalityT;\n\nnamespace details {\n\n/** LowCardinalityHashKey used as key in unique items hashmap to abstract away key value\n * (type of which depends on dictionary column) and to reduce likelehood of collisions.\n *\n * In order to dramatically reduce collision rate, we use 2 different hashes from 2 different hash functions.\n * First hash is used in hashtable (to calculate item position).\n * Second one is used as part of key value and accessed via `operator==()` upon collision resolution/detection.\n */\nusing LowCardinalityHashKey = std::pair;\n\nstruct LowCardinalityHashKeyHash {\n inline std::size_t operator()(const LowCardinalityHashKey &hash_key) const noexcept {\n return hash_key.first;\n }\n};\n\n}\n\n/*\n * LC column contains an \"invisible\" default item at the beginning of the collection. [default, ...]\n * If the nested type is Nullable, it contains a null-item at the beginning and a default item at the second position. [null, default, ...]\n * Null map is not serialized in LC columns. 
Instead, nulls are tracked by having an index of 0.\n * */\nclass ColumnLowCardinality : public Column {\npublic:\n...\n// Path: clickhouse/base/projected_iterator.h\n#pragma once\n\n#include \n#include \n#include \n\nnamespace clickhouse {\n\ntemplate ()(std::declval())),\n typename Value = std::decay_t>\nclass ProjectedIterator {\npublic:\n using value_type = Value;\n using reference = Reference;\n using pointer = Reference;\n using difference_type = typename std::iterator_traits::difference_type;\n using iterator_category = typename std::iterator_traits::iterator_category;\n\n ProjectedIterator() = default;\n\n inline ProjectedIterator(Iterator const& iterator, UnaryFunction functor)\n : iterator_(iterator)\n , functor_(std::move(functor)) {\n }\n\n inline UnaryFunction functor() const { return functor; }\n\n inline Iterator const& base() const { return iterator_; }\n\n inline reference operator*() const { return functor_(iterator_); }\n\n inline ProjectedIterator& operator++() {\n ++iterator_;\n return *this;\n }\n\n inline ProjectedIterator& operator--() {\n --iterator_;\n return *this;\n }\n\n inline bool operator==(const ProjectedIterator& other) const {\n return this->iterator_ == other.iterator_;\n }\n\n inline bool operator!=(const ProjectedIterator& other) const {\n return !(*this == other);\n }\n\nprivate:\n Iterator iterator_;\n UnaryFunction functor_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/map.h\n#pragma once\n\n#include \"../base/projected_iterator.h\"\n#include \"array.h\"\n#include \"column.h\"\n#include \"tuple.h\"\n\n#include \n#include \n\nnamespace clickhouse {\n\ntemplate \nclass ColumnMapT;\n\n/**\n * Represents column of Map(K, V).\n */\nclass ColumnMap : public Column {\npublic:\n /** Create a map of given type, with actual values and offsets.\n *\n * Both `data` and `offsets` are used (and modified) internally bye ColumnArray.\n * Users are strongly advised against modifying contents of `data` or `offsets` afterwards.\n */\n explicit ColumnMap(ColumnRef data);\n\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column prefix from input stream.\n bool LoadPrefix(InputStream* input, size_t rows) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column prefix to output stream.\n void SavePrefix(OutputStream* output) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t, size_t) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column&) override;\n\n /// Converts map at pos n to column.\n /// Type of row is tuple {key, value}.\n ColumnRef GetAsColumn(size_t n) const;\n\nprotected:\n template \n friend class ColumnMapT;\n\n ColumnMap(ColumnMap&& map);\n\nprivate:\n std::shared_ptr data_;\n};\n\ntemplate \nclass ColumnMapT : public ColumnMap {\npublic:\n using KeyColumnType = K;\n using ValueColumnType = V;\n using Key = std::decay_t().At(0))>;\n using Value = std::decay_t().At(0))>;\n using TupleColumnType = ColumnTupleT;\n using ArrayColumnType = ColumnArrayT;\n\n ColumnMapT(ColumnRef data)\n : ColumnMap(data), 
typed_data_(data->AsStrict>()) {}\n\n ColumnMapT(std::shared_ptr keys, std::shared_ptr values)\n : ColumnMap(std::make_shared(std::make_shared(\n std::make_tuple(std::move(keys), std::move(values))))),\n typed_data_(data_->template As()) {}\n\n ColumnRef Slice(size_t begin, size_t len) const override {\n return std::make_shared>(typed_data_->Slice(begin, len));\n }\n\n ColumnRef CloneEmpty() const override {\n return std::make_shared>(typed_data_->CloneEmpty());\n }\n\n void Swap(Column& other) override {\n auto& col = dynamic_cast&>(other);\n col.typed_data_.swap(typed_data_);\n ColumnMap::Swap(other);\n }\n\n /// A single (row) value of the Map-column i.e. read-only map.\n /// It has a linear time complexity to access items\n /// Because data base type has same structure\n /// \"This lookup works now with a linear complexity.\"\n /// https://clickhouse.com/docs/en/sql-reference/data-types/map\n /// Convert it to a suitable container required to access more than one element\n\n class MapValueView {\n const typename ArrayColumnType::ArrayValueView data_;\n\n public:\n using ValueType = std::pair;\n\n MapValueView(typename ArrayColumnType::ArrayValueView data) : data_(std::move(data)) {}\n\n inline auto operator[](const Key& key) const { return (*Find(key)).second; }\n\n inline auto At(const Key& key) const {\n auto it = Find(key);\n if (it == end()) throw ValidationError(\"ColumnMap value key not found\");\n return (*it).second;\n }\n\n class Iterator {\n typename ArrayColumnType::ArrayValueView::Iterator data_iterator_;\n\n public:\n Iterator() = default;\n\n Iterator(typename ArrayColumnType::ArrayValueView::Iterator data_iterator)\n : data_iterator_(data_iterator) {}\n\n using ValueType = std::pair;\n using difference_type = size_t;\n using value_type = ValueType;\n using pointer = void;\n using reference = ValueType&;\n using iterator_category = std::forward_iterator_tag;\n\n inline auto operator*() const {\n auto tuple = *data_iterator_;\n return ValueType{std::get<0>(tuple), std::get<1>(tuple)};\n }\n\n inline Iterator& operator++() {\n ++data_iterator_;\n return *this;\n }\n\n inline bool operator==(const Iterator& other) const {\n return this->data_iterator_ == other.data_iterator_;\n }\n\n inline bool operator!=(const Iterator& other) const { return !(*this == other); }\n };\n\n // minimalistic stl-like container interface, hence the lowercase\n inline Iterator begin() const { return Iterator{data_.begin()}; }\n\n inline Iterator cbegin() const { return Iterator{data_.cbegin()}; }\n\n inline Iterator end() const { return Iterator{data_.end()}; }\n\n inline Iterator cend() const { return Iterator{data_.cend()}; }\n\n inline size_t size() const { return data_.size(); }\n\n // It is ugly to have both size() and Size(), but it is for compatitability with both STL\n // and rest of the clickhouse-cpp.\n inline size_t Size() const { return data_.Size(); }\n\n inline size_t Count(const Key& key) const {\n size_t result = 0;\n for (auto item : data_) {\n if (std::get<0>(item) == key) {\n ++result;\n }\n }\n return result;\n }\n\n inline Iterator Find(const Key& key) const {\n for (auto it = data_.begin(); it != data_.end(); ++it) {\n if (std::get<0>(*it) == key) {\n return Iterator{it};\n }\n }\n return end();\n }\n\n inline bool operator==(const MapValueView& other) const {\n if (size() != other.size()) {\n return false;\n }\n const auto make_index = [](const auto& data) {\n std::vector result{data.Size()};\n std::generate(result.begin(), result.end(), [i = 0] () mutable {return i++;});\n 
std::sort(result.begin(), result.end(), [&data](size_t l, size_t r) {return data[l] < data[r];});\n return result;\n };\n const auto index = make_index(data_);\n for (const auto& val : other.data_) {\n if (!std::binary_search(index.begin(), index.end(), val,\n [&data = data_](const auto& l, size_t r) {return l < data[r];})) {\n return false;\n }\n }\n return true;\n }\n\n inline bool operator!=(const MapValueView& other) const { return !(*this == other); }\n };\n\n inline auto At(size_t index) const { return MapValueView{typed_data_->At(index)}; }\n\n inline auto operator[](size_t index) const { return At(index); }\n\n using ColumnMap::Append;\n\n inline void Append(const MapValueView& value) { typed_data_->Append(value.data_); }\n\n inline void Append(const std::vector>& tuples) {\n typed_data_->Append(tuples.begin(), tuples.end());\n }\n\n template \n inline void Append(const T& value) {\n using BaseIter = decltype(value.begin());\n using KeyOfT = decltype(std::declval()->first);\n using ValOfT = decltype(std::declval()->second);\n using Functor = std::function(const BaseIter&)>;\n using Iterator = ProjectedIterator;\n\n Functor functor = [](const BaseIter& i) {\n return std::make_tuple(std::cref(i->first), std::cref(i->second));\n };\n\n typed_data_->Append(Iterator{value.begin(), functor}, Iterator{value.end(), functor});\n }\n\n static auto Wrap(ColumnMap&& col) {\n auto data = ArrayColumnType::Wrap(std::move(col.data_));\n return std::make_shared>(std::move(data));\n }\n\n static auto Wrap(Column&& col) { return Wrap(std::move(dynamic_cast(col))); }\n\n // Helper to simplify integration with other APIs\n static auto Wrap(ColumnRef&& col) { return Wrap(std::move(*col->AsStrict())); }\n\nprivate:\n std::shared_ptr typed_data_;\n};\n\n} // namespace clickhouse\n\n// Path: clickhouse/columns/uuid.h\n#pragma once\n\n#include \"../base/uuid.h\"\n#include \"column.h\"\n#include \"numeric.h\"\n\nnamespace clickhouse {\n\n\n/**\n * Represents a UUID column.\n */\nclass ColumnUUID : public Column {\npublic:\n ColumnUUID();\n\n explicit ColumnUUID(ColumnRef data);\n\n /// Appends one element to the end of column.\n void Append(const UUID& value);\n\n /// Returns element at given row number.\n const UUID At(size_t n) const;\n\n /// Returns element at given row number.\n inline const UUID operator [] (size_t n) const { return At(n); }\n\npublic:\n /// Increase the capacity of the column for large block insertion.\n void Reserve(size_t new_cap) override;\n\n /// Appends content of given column to the end of current one.\n void Append(ColumnRef column) override;\n\n /// Loads column data from input stream.\n bool LoadBody(InputStream* input, size_t rows) override;\n\n /// Saves column data to output stream.\n void SaveBody(OutputStream* output) override;\n\n /// Clear column data .\n void Clear() override;\n\n /// Returns count of rows in the column.\n size_t Size() const override;\n\n /// Makes slice of the current column.\n ColumnRef Slice(size_t begin, size_t len) const override;\n ColumnRef CloneEmpty() const override;\n void Swap(Column& other) override;\n\n ItemView GetItem(size_t) const override;\n\nprivate:\n std::shared_ptr data_;\n};\n\n}\n\n// Path: clickhouse/client.h\n#pragma once\n\n#include \"query.h\"\n#include \"exceptions.h\"\n\n#include \"columns/array.h\"\n#include \"columns/date.h\"\n#include \"columns/decimal.h\"\n#include \"columns/enum.h\"\n#include \"columns/geo.h\"\n#include \"columns/ip4.h\"\n#include \"columns/ip6.h\"\n#include \"columns/lowcardinality.h\"\n#include 
\"columns/nullable.h\"\n#include \"columns/numeric.h\"\n#include \"columns/map.h\"\n#include \"columns/string.h\"\n#include \"columns/tuple.h\"\n#include \"columns/uuid.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\n\nnamespace clickhouse {\n\nstruct ServerInfo {\n std::string name;\n std::string timezone;\n std::string display_name;\n uint64_t version_major;\n uint64_t version_minor;\n uint64_t version_patch;\n uint64_t revision;\n};\n\n/// Methods of block compression.\nenum class CompressionMethod {\n None = -1,\n LZ4 = 1,\n};\n\nstruct Endpoint {\n std::string host;\n uint16_t port = 9000;\n inline bool operator==(const Endpoint& right) const {\n return host == right.host && port == right.port;\n }\n};\n\nenum class EndpointsIterationAlgorithm {\n RoundRobin = 0,\n};\n\nstruct ClientOptions {\n // Setter goes first, so it is possible to apply 'deprecated' annotation safely.\n#define DECLARE_FIELD(name, type, setter, default_value) \\\n inline auto & setter(const type& value) { \\\n name = value; \\\n return *this; \\\n } \\\n type name = default_value\n\n /// Hostname of the server.\n DECLARE_FIELD(host, std::string, SetHost, std::string());\n /// Service port.\n DECLARE_FIELD(port, uint16_t, SetPort, 9000);\n\n /** Set endpoints (host+port), only one is used.\n * Client tries to connect to those endpoints one by one, on the round-robin basis:\n * first default enpoint (set via SetHost() + SetPort()), then each of endpoints, from begin() to end(),\n * the first one to establish connection is used for the rest of the session.\n * If port isn't specified, default(9000) value will be used.\n */\n DECLARE_FIELD(endpoints, std::vector, SetEndpoints, {});\n\n /// Default database.\n DECLARE_FIELD(default_database, std::string, SetDefaultDatabase, \"default\");\n /// User name.\n DECLARE_FIELD(user, std::string, SetUser, \"default\");\n /// Access password.\n DECLARE_FIELD(password, std::string, SetPassword, std::string());\n\n /// By default all exceptions received during query execution will be\n /// passed to OnException handler. Set rethrow_exceptions to true to\n /// enable throwing exceptions with standard c++ exception mechanism.\n DECLARE_FIELD(rethrow_exceptions, bool, SetRethrowException, true);\n\n /// Ping server every time before execute any query.\n DECLARE_FIELD(ping_before_query, bool, SetPingBeforeQuery, false);\n /// Count of retry to send request to server.\n DECLARE_FIELD(send_retries, unsigned int, SetSendRetries, 1);\n /// Amount of time to wait before next retry.\n DECLARE_FIELD(retry_timeout, std::chrono::seconds, SetRetryTimeout, std::chrono::seconds(5));\n\n /// Compression method.\n DECLARE_FIELD(compression_method, CompressionMethod, SetCompressionMethod, CompressionMethod::None);\n\n /// TCP Keep alive options\n DECLARE_FIELD(tcp_keepalive, bool, TcpKeepAlive, false);\n DECLARE_FIELD(tcp_keepalive_idle, std::chrono::seconds, SetTcpKeepAliveIdle, std::chrono::seconds(60));\n DECLARE_FIELD(tcp_keepalive_intvl, std::chrono::seconds, SetTcpKeepAliveInterval, std::chrono::seconds(5));\n DECLARE_FIELD(tcp_keepalive_cnt, unsigned int, SetTcpKeepAliveCount, 3);\n\n // TCP options\n DECLARE_FIELD(tcp_nodelay, bool, TcpNoDelay, true);\n\n /// Connection socket connect timeout. If the timeout is negative then the connect operation will never timeout.\n DECLARE_FIELD(connection_connect_timeout, std::chrono::milliseconds, SetConnectionConnectTimeout, std::chrono::seconds(5));\n\n /// Connection socket timeout. 
If the timeout is set to zero then the operation will never timeout.\n DECLARE_FIELD(connection_recv_timeout, std::chrono::milliseconds, SetConnectionRecvTimeout, std::chrono::milliseconds(0));\n DECLARE_FIELD(connection_send_timeout, std::chrono::milliseconds, SetConnectionSendTimeout, std::chrono::milliseconds(0));\n\n /** It helps to ease migration of the old codebases, which can't afford to switch\n * to using ColumnLowCardinalityT or ColumnLowCardinality directly,\n * but still want to benefit from smaller on-wire LowCardinality bandwidth footprint.\n *\n * @see LowCardinalitySerializationAdaptor, CreateColumnByType\n */\n [[deprecated(\"Makes implementation of LC(X) harder and code uglier. Will be removed in next major release (3.0) \")]]\n DECLARE_FIELD(backward_compatibility_lowcardinality_as_wrapped_column, bool, SetBakcwardCompatibilityFeatureLowCardinalityAsWrappedColumn, false);\n\n /** Set max size data to compress if compression enabled.\n *\n * Allows choosing tradeoff between RAM\\CPU:\n * - Lower value reduces RAM usage, but slightly increases CPU usage.\n * - Higher value increases RAM usage but slightly decreases CPU usage.\n */\n DECLARE_FIELD(max_compression_chunk_size, unsigned int, SetMaxCompressionChunkSize, 65535);\n\n struct SSLOptions {\n /** There are two ways to configure an SSL connection:\n * - provide a pre-configured SSL_CTX, which is not modified and not owned by the Client.\n * - provide a set of options and allow the Client to create and configure SSL_CTX by itself.\n */\n\n /** Pre-configured SSL-context for SSL-connection.\n * If NOT null client DONES NOT take ownership of context and it must be valid for client lifetime.\n * If null client initlaizes OpenSSL and creates his own context, initializes it using\n * other options, like path_to_ca_files, path_to_ca_directory, use_default_ca_locations, etc.\n *\n * Either way context is used to create an SSL-connection, which is then configured with\n * whatever was provided as `configuration`, `host_flags`, `skip_verification` and `use_sni`.\n */\n SSL_CTX * ssl_context = nullptr;\n auto & SetExternalSSLContext(SSL_CTX * new_ssl_context) {\n ssl_context = new_ssl_context;\n return *this;\n }\n\n /** Means to validate the server-supplied certificate against trusted Certificate Authority (CA).\n * If no CAs are configured, the server's identity can't be validated, and the Client would err.\n * See https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_default_verify_paths.html\n */\n /// Load default CA certificates from default locations.\n DECLARE_FIELD(use_default_ca_locations, bool, SetUseDefaultCALocations, true);\n /// Path to the CA files to verify server certificate, may be empty.\n DECLARE_FIELD(path_to_ca_files, std::vector, SetPathToCAFiles, {});\n /// Path to the directory with CA files used to validate server certificate, may be empty.\n DECLARE_FIELD(path_to_ca_directory, std::string, SetPathToCADirectory, \"\");\n\n /** Min and max protocol versions to use, set with SSL_CTX_set_min_proto_version and SSL_CTX_set_max_proto_version\n * for details see https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_min_proto_version.html\n */\n DECLARE_FIELD(min_protocol_version, int, SetMinProtocolVersion, DEFAULT_VALUE);\n DECLARE_FIELD(max_protocol_version, int, SetMaxProtocolVersion, DEFAULT_VALUE);\n\n /** Options to be set with SSL_CTX_set_options,\n * for details see https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_options.html\n */\n DECLARE_FIELD(context_options, int, SetContextOptions, 
DEFAULT_VALUE);\n\n /** Use SNI at ClientHello\n */\n DECLARE_FIELD(use_sni, bool, SetUseSNI, true);\n\n /** Skip SSL session verification (server's certificate, etc).\n *\n * WARNING: settig to true will bypass all SSL session checks, which\n * is dangerous, but can be used against self-signed certificates, e.g. for testing purposes.\n */\n DECLARE_FIELD(skip_verification, bool, SetSkipVerification, false);\n\n /** Mode of verifying host ssl certificate against name of the host, set with SSL_set_hostflags.\n * For details see https://www.openssl.org/docs/man1.1.1/man3/SSL_set_hostflags.html\n */\n DECLARE_FIELD(host_flags, int, SetHostVerifyFlags, DEFAULT_VALUE);\n\n struct CommandAndValue {\n std::string command;\n std::optional value = std::nullopt;\n };\n /** Extra configuration options, set with SSL_CONF_cmd.\n * For deatils see https://www.openssl.org/docs/man1.1.1/man3/SSL_CONF_cmd.html\n *\n * Takes multiple pairs of command-value strings, all commands are supported,\n * and prefix is empty.\n * i.e. pass `sigalgs` or `SignatureAlgorithms` instead of `-sigalgs`.\n *\n * Rewrites any other options/flags if set in other ways.\n */\n DECLARE_FIELD(configuration, std::vector, SetConfiguration, {});\n\n static const int DEFAULT_VALUE = -1;\n };\n\n // By default SSL is turned off.\n std::optional ssl_options = std::nullopt;\n\n // Will throw an exception if client was built without SSL support.\n ClientOptions& SetSSLOptions(SSLOptions options);\n\n#undef DECLARE_FIELD\n};\n\nstd::ostream& operator<<(std::ostream& os, const ClientOptions& options);\nstd::ostream& operator<<(std::ostream& os, const Endpoint& options);\n\nclass SocketFactory;\n\n/**\n *\n */\nclass Client {\npublic:\n Client(const ClientOptions& opts);\n Client(const ClientOptions& opts,\n std::unique_ptr socket_factory);\n ~Client();\n\n /// Intends for execute arbitrary queries.\n void Execute(const Query& query);\n\n /// Intends for execute select queries. Data will be returned with\n /// one or more call of \\p cb.\n void Select(const std::string& query, SelectCallback cb);\n void Select(const std::string& query, const std::string& query_id, SelectCallback cb);\n\n /// Executes a select query which can be canceled by returning false from\n /// the data handler function \\p cb.\n void SelectCancelable(const std::string& query, SelectCancelableCallback cb);\n void SelectCancelable(const std::string& query, const std::string& query_id, SelectCancelableCallback cb);\n\n /// Alias for Execute.\n void Select(const Query& query);\n\n /// Intends for insert block of data into a table \\p table_name.\n void Insert(const std::string& table_name, const Block& block);\n void Insert(const std::string& table_name, const std::string& query_id, const Block& block);\n\n /// Ping server for aliveness.\n void Ping();\n\n /// Reset connection with initial params.\n void ResetConnection();\n\n const ServerInfo& GetServerInfo() const;\n\n /// Get current connected endpoint.\n /// In case when client is not connected to any endpoint, nullopt will returned.\n const std::optional& GetCurrentEndpoint() const;\n\n // Try to connect to different endpoints one by one only one time. 
If it doesn't work, throw an exception.\n void ResetConnectionEndpoint();\n\n struct Version\n {\n uint16_t major;\n uint16_t minor;\n uint16_t patch;\n uint16_t build;\n const char * extra;\n };\n\n static Version GetVersion();\n\nprivate:\n const ClientOptions options_;\n\n class Impl;\n std::unique_ptr impl_;\n};\n\n}\n\n// Path: clickhouse/protocol.h\n#pragma once\n\nnamespace clickhouse {\n\n /// Types of packets received from server\n namespace ServerCodes {\n enum {\n Hello = 0, /// Name, version, revision.\n Data = 1, /// `Block` of data, may be compressed.\n Exception = 2, /// Exception that occurred on server side during query execution.\n Progress = 3, /// Query execcution progress: rows and bytes read.\n Pong = 4, /// response to Ping sent by client.\n EndOfStream = 5, /// All packets were sent.\n ProfileInfo = 6, /// Profiling data\n Totals = 7, /// Block of totals, may be compressed.\n Extremes = 8, /// Block of mins and maxs, may be compressed.\n TablesStatusResponse = 9, /// Response to TableStatus.\n Log = 10, /// Query execution log.\n TableColumns = 11, /// Columns' description for default values calculation\n PartUUIDs = 12, /// List of unique parts ids.\n ReadTaskRequest = 13, /// String (UUID) describes a request for which next task is needed\n /// This is such an inverted logic, where server sends requests\n /// And client returns back response\n ProfileEvents = 14, /// Packet with profile events from server.\n };\n }\n\n /// Types of packets sent by client.\n namespace ClientCodes {\n enum {\n Hello = 0, /// Name, version, default database name.\n Query = 1, /** Query id, query settings, query processing stage,\n * compression status, and query text (no INSERT data).\n */\n Data = 2, /// Data `Block` (e.g. INSERT data), may be compressed.\n Cancel = 3, /// Cancel query.\n Ping = 4, /// Check server connection.\n };\n }\n\n /// Should we compress `Block`s of data\n namespace CompressionState {\n enum {\n Disable = 0,\n Enable = 1,\n };\n }\n\n namespace Stages {\n enum {\n Complete = 2,\n };\n }\n}\n\n// Path: clickhouse/base/input.h\n#pragma once\n\n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream {\npublic:\n virtual ~InputStream() noexcept (false)\n { }\n\n /// Reads one byte from the stream.\n inline bool ReadByte(uint8_t* byte) {\n return DoRead(byte, sizeof(uint8_t)) == sizeof(uint8_t);\n }\n\n /// Reads some data from the stream.\n inline size_t Read(void* buf, size_t len) {\n return DoRead(buf, len);\n }\n\n // Skips a number of bytes. 
Returns false if an underlying read error occurs.\n virtual bool Skip(size_t bytes) = 0;\n\nprotected:\n virtual size_t DoRead(void* buf, size_t len) = 0;\n};\n\n\nclass ZeroCopyInput : public InputStream {\npublic:\n inline size_t Next(const void** buf, size_t len) {\n return DoNext(buf, len);\n }\n\n bool Skip(size_t bytes) override;\n\nprotected:\n virtual size_t DoNext(const void** ptr, size_t len) = 0;\n\n size_t DoRead(void* buf, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyInput stream backed by an in-memory array of bytes.\n */\nclass ArrayInput : public ZeroCopyInput {\npublic:\n ArrayInput() noexcept;\n ArrayInput(const void* buf, size_t len) noexcept;\n ~ArrayInput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return len_;\n }\n\n /// Current read position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return data_;\n }\n\n /// Whether there is more data in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n inline void Reset(const void* buf, size_t len) noexcept {\n data_ = static_cast(buf);\n len_ = len;\n }\n\nprivate:\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n const uint8_t* data_;\n size_t len_;\n};\n\n\nclass BufferedInput : public ZeroCopyInput {\npublic:\n BufferedInput(std::unique_ptr source, size_t buflen = 8192);\n ~BufferedInput() override;\n\n void Reset();\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n size_t DoNext(const void** ptr, size_t len) override;\n\nprivate:\n std::unique_ptr const source_;\n ArrayInput array_input_;\n std::vector buffer_;\n};\n\n}\n\n// Path: clickhouse/base/buffer.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nusing Buffer = std::vector;\n\n}\n\n// Path: clickhouse/base/output.h\n#pragma once\n\n#include \"buffer.h\"\n\n#include \n#include \n#include \n#include \n#include \n\nnamespace clickhouse {\n\nclass OutputStream {\npublic:\n virtual ~OutputStream()\n { }\n\n inline void Flush() {\n DoFlush();\n }\n\n inline size_t Write(const void* data, size_t len) {\n return DoWrite(data, len);\n }\n\nprotected:\n virtual void DoFlush() { }\n\n virtual size_t DoWrite(const void* data, size_t len) = 0;\n};\n\n\nclass ZeroCopyOutput : public OutputStream {\npublic:\n inline size_t Next(void** data, size_t size) {\n return DoNext(data, size);\n }\n\nprotected:\n // Obtains a buffer into which data can be written. 
Any data written\n // into this buffer will eventually (maybe instantly, maybe later on)\n // be written to the output.\n virtual size_t DoNext(void** data, size_t len) = 0;\n\n size_t DoWrite(const void* data, size_t len) override;\n};\n\n\n/**\n * A ZeroCopyOutput stream backed by an in-memory array of bytes.\n */\nclass ArrayOutput : public ZeroCopyOutput {\npublic:\n ArrayOutput(void* buf, size_t len);\n ~ArrayOutput() override;\n\n /// Number of bytes available in the stream.\n inline size_t Avail() const noexcept {\n return end_ - buf_;\n }\n\n /// Current write position in the memory block used by this stream.\n inline const uint8_t* Data() const noexcept {\n return buf_;\n }\n\n /// Whether there is more space in the stream.\n inline bool Exhausted() const noexcept {\n return !Avail();\n }\n\n /// Initializes this stream with a new memory block.\n inline void Reset(void* buf, size_t len) noexcept {\n buf_ = static_cast(buf);\n end_ = buf_ + len;\n buffer_size_ = len;\n }\n\n /// Number of bytes written to the buffer.\n inline size_t Size() const noexcept {\n return buffer_size_ - Avail();\n }\n\nprotected:\n size_t DoNext(void** data, size_t len) override;\n\nprivate:\n uint8_t* buf_;\n uint8_t* end_;\n size_t buffer_size_;\n};\n\n\n/**\n * A ZeroCopyOutput stream backed by a vector.\n *\n * Doesn't Flush() in destructor, client must ensure to do it manually at some point.\n */\nclass BufferOutput : public ZeroCopyOutput {\npublic:\n BufferOutput(Buffer* buf);\n ~BufferOutput() override;\n\nprotected:\n size_t DoNext(void** data, size_t len) override;\n\nprivate:\n Buffer* buf_;\n size_t pos_;\n};\n\n/** BufferedOutput writes data to internal buffer first.\n *\n * Any data goes to underlying stream only if internal buffer is full\n * or when client invokes Flush() on this.\n *\n * Doesn't Flush() in destructor, client must ensure to do it manually at some point.\n */\nclass BufferedOutput : public ZeroCopyOutput {\npublic:\n explicit BufferedOutput(std::unique_ptr destination, size_t buflen = 8192);\n ~BufferedOutput() override;\n\n void Reset();\n\nprotected:\n void DoFlush() override;\n size_t DoNext(void** data, size_t len) override;\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n std::unique_ptr const destination_;\n Buffer buffer_;\n ArrayOutput array_output_;\n};\n\ntemplate \nvoid WriteUnaligned(void* buf, const T& value) {\n memcpy(buf, &value, sizeof(value));\n}\n\n}\n\n// Path: clickhouse/base/compressed.h\n#pragma once\n\n#include \"input.h\"\n#include \"output.h\"\n#include \"buffer.h\"\n\nnamespace clickhouse {\n\nclass CompressedInput : public ZeroCopyInput {\npublic:\n explicit CompressedInput(InputStream* input);\n ~CompressedInput() override;\n\nprotected:\n size_t DoNext(const void** ptr, size_t len) override;\n\n bool Decompress();\n\nprivate:\n InputStream* const input_;\n\n Buffer data_;\n ArrayInput mem_;\n};\n\nclass CompressedOutput : public OutputStream {\npublic:\n explicit CompressedOutput(OutputStream * destination, size_t max_compressed_chunk_size = 0);\n ~CompressedOutput() override;\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n void DoFlush() override;\n\nprivate:\n void Compress(const void * data, size_t len);\n void PreallocateCompressBuffer(size_t input_size);\n\nprivate:\n OutputStream * destination_;\n const size_t max_compressed_chunk_size_;\n Buffer compressed_buffer_;\n};\n\n}\n\n// Path: clickhouse/base/platform.h\n#pragma once\n\n#if defined(__linux__)\n# define _linux_\n#elif 
defined(_WIN64)\n# define _win64_\n# define _win32_\n#elif defined(__WIN32__) || defined(_WIN32)\n# define _win32_\n#elif defined(__APPLE__)\n# define _darwin_\n#endif\n\n#if defined(_win32_) || defined(_win64_)\n# define _win_\n# if !defined(_WIN32_WINNT) || (_WIN32_WINNT < 0x0600)\n# undef _WIN32_WINNT\n# define _WIN32_WINNT 0x0600 // The WSAPoll function is defined on Windows Vista and later.\n# endif\n# define WIN32_LEAN_AND_MEAN 1 // don't include too much header automatically\n#endif\n\n#if defined(_linux_) || defined (_darwin_)\n# define _unix_\n#endif\n\n#if defined(_MSC_VER)\n# undef NOMINMAX\n# define NOMINMAX\n# include \n# define ssize_t SSIZE_T\n# define HAVE_SSIZE_T 1\n#endif\n\n// Path: clickhouse/base/endpoints_iterator.h\n#pragma once\n\n#include \"clickhouse/client.h\"\n#include \n\nnamespace clickhouse {\n\nstruct ClientOptions;\n\n/**\n * Base class for iterating through endpoints.\n*/\nclass EndpointsIteratorBase\n{\n public:\n virtual ~EndpointsIteratorBase() = default;\n\n virtual Endpoint Next() = 0;\n};\n\nclass RoundRobinEndpointsIterator : public EndpointsIteratorBase\n{\n public:\n explicit RoundRobinEndpointsIterator(const std::vector& opts);\n Endpoint Next() override;\n\n ~RoundRobinEndpointsIterator() override;\n\n private:\n const std::vector& endpoints;\n size_t current_index;\n};\n\n}\n\n// Path: clickhouse/base/socket.h\n#pragma once\n\n#include \"platform.h\"\n#include \"input.h\"\n#include \"output.h\"\n#include \"endpoints_iterator.h\"\n\n#include \n#include \n#include \n\n#if defined(_win_)\n# include \n# include \n#else\n# include \n# include \n# include \n# include \n\n# if !defined(SOCKET)\n# define SOCKET int\n# endif\n#endif\n\n#include \n#include \n\nstruct addrinfo;\n\nnamespace clickhouse {\n\nstruct ClientOptions;\n\n/** Address of a host to establish connection to.\n *\n */\nclass NetworkAddress {\npublic:\n explicit NetworkAddress(const std::string& host,\n const std::string& port = \"0\");\n ~NetworkAddress();\n\n const struct addrinfo* Info() const;\n const std::string & Host() const;\n\nprivate:\n const std::string host_;\n struct addrinfo* info_;\n};\n\n#if defined(_win_)\n\nclass windowsErrorCategory : public std::error_category {\npublic:\n char const* name() const noexcept override final;\n std::string message(int c) const override final;\n\n static windowsErrorCategory const& category();\n};\n\n#endif\n\n#if defined(_unix_)\n\nclass getaddrinfoErrorCategory : public std::error_category {\npublic:\n char const* name() const noexcept override final;\n std::string message(int c) const override final;\n\n static getaddrinfoErrorCategory const& category();\n};\n\n#endif\n\n\nclass SocketBase {\npublic:\n virtual ~SocketBase();\n\n virtual std::unique_ptr makeInputStream() const = 0;\n virtual std::unique_ptr makeOutputStream() const = 0;\n};\n\n\nclass SocketFactory {\npublic:\n virtual ~SocketFactory();\n\n // TODO: move connection-related options to ConnectionOptions structure.\n\n virtual std::unique_ptr connect(const ClientOptions& opts, const Endpoint& endpoint) = 0;\n\n virtual void sleepFor(const std::chrono::milliseconds& duration);\n};\n\n\nstruct SocketTimeoutParams {\n std::chrono::milliseconds connect_timeout{ 5000 };\n std::chrono::milliseconds recv_timeout{ 0 };\n std::chrono::milliseconds send_timeout{ 0 };\n};\n\nclass Socket : public SocketBase {\npublic:\n Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);\n Socket(const NetworkAddress& addr);\n Socket(Socket&& other) noexcept;\n Socket& 
operator=(Socket&& other) noexcept;\n\n ~Socket() override;\n\n /// @params idle the time (in seconds) the connection needs to remain\n /// idle before TCP starts sending keepalive probes.\n /// @params intvl the time (in seconds) between individual keepalive probes.\n /// @params cnt the maximum number of keepalive probes TCP should send\n /// before dropping the connection.\n void SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept;\n\n /// @params nodelay whether to enable TCP_NODELAY\n void SetTcpNoDelay(bool nodelay) noexcept;\n\n std::unique_ptr makeInputStream() const override;\n std::unique_ptr makeOutputStream() const override;\n\nprotected:\n Socket(const Socket&) = delete;\n Socket& operator = (const Socket&) = delete;\n void Close();\n\n SOCKET handle_;\n};\n\n\nclass NonSecureSocketFactory : public SocketFactory {\npublic:\n ~NonSecureSocketFactory() override;\n\n std::unique_ptr connect(const ClientOptions& opts, const Endpoint& endpoint) override;\n\nprotected:\n virtual std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts);\n\n void setSocketOptions(Socket& socket, const ClientOptions& opts);\n};\n\n\nclass SocketInput : public InputStream {\npublic:\n explicit SocketInput(SOCKET s);\n ~SocketInput();\n\nprotected:\n bool Skip(size_t bytes) override;\n size_t DoRead(void* buf, size_t len) override;\n\nprivate:\n SOCKET s_;\n};\n\nclass SocketOutput : public OutputStream {\npublic:\n explicit SocketOutput(SOCKET s);\n ~SocketOutput();\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n SOCKET s_;\n};\n\nstatic struct NetrworkInitializer {\n NetrworkInitializer();\n} gNetrworkInitializer;\n\n}\n\n// Path: clickhouse/base/wire_format.h\n#pragma once\n\n#include \n#include \n\nnamespace clickhouse {\n\nclass InputStream;\nclass OutputStream;\n\nclass WireFormat {\npublic:\n template \n static bool ReadFixed(InputStream& input, T* value);\n static bool ReadString(InputStream& input, std::string* value);\n static bool SkipString(InputStream& input);\n static bool ReadBytes(InputStream& input, void* buf, size_t len);\n static bool ReadUInt64(InputStream& input, uint64_t* value);\n static bool ReadVarint64(InputStream& output, uint64_t* value);\n\n template \n static void WriteFixed(OutputStream& output, const T& value);\n static void WriteBytes(OutputStream& output, const void* buf, size_t len);\n static void WriteString(OutputStream& output, std::string_view value);\n static void WriteUInt64(OutputStream& output, const uint64_t value);\n static void WriteVarint64(OutputStream& output, uint64_t value);\n\nprivate:\n static bool ReadAll(InputStream& input, void* buf, size_t len);\n static void WriteAll(OutputStream& output, const void* buf, size_t len);\n};\n\ntemplate \ninline bool WireFormat::ReadFixed(InputStream& input, T* value) {\n return ReadAll(input, value, sizeof(T));\n}\n\ninline bool WireFormat::ReadString(InputStream& input, std::string* value) {\n uint64_t len = 0;\n if (ReadVarint64(input, &len)) {\n if (len > 0x00FFFFFFULL) {\n return false;\n }\n value->resize((size_t)len);\n return ReadAll(input, value->data(), (size_t)len);\n }\n\n return false;\n}\n\ninline bool WireFormat::ReadBytes(InputStream& input, void* buf, size_t len) {\n return ReadAll(input, buf, len);\n}\n\ninline bool WireFormat::ReadUInt64(InputStream& input, uint64_t* value) {\n return ReadVarint64(input, value);\n}\n\ntemplate \ninline void WireFormat::WriteFixed(OutputStream& output, const T& value) {\n WriteAll(output, &value, 
sizeof(T));\n}\n\ninline void WireFormat::WriteBytes(OutputStream& output, const void* buf, size_t len) {\n WriteAll(output, buf, len);\n}\n\ninline void WireFormat::WriteString(OutputStream& output, std::string_view value) {\n WriteVarint64(output, value.size());\n WriteAll(output, value.data(), value.size());\n}\n\ninline void WireFormat::WriteUInt64(OutputStream& output, const uint64_t value) {\n WriteVarint64(output, value);\n}\n\n}\n\n// Path: clickhouse/columns/factory.h\n#pragma once\n\n#include \"column.h\"\n\nnamespace clickhouse {\n\nstruct CreateColumnByTypeSettings\n{\n bool low_cardinality_as_wrapped_column = false;\n};\n\nColumnRef CreateColumnByType(const std::string& type_name, CreateColumnByTypeSettings settings = {});\n\n}\n\n// Path: clickhouse/base/sslsocket.h\n#pragma once\n\n#include \"socket.h\"\n\n#include \n#include \n#include \n\ntypedef struct ssl_ctx_st SSL_CTX;\ntypedef struct ssl_st SSL;\n\nnamespace clickhouse {\n\nstruct SSLParams\n{\n std::vector path_to_ca_files;\n std::string path_to_ca_directory;\n bool use_default_ca_locations;\n int context_options;\n int min_protocol_version;\n int max_protocol_version;\n bool use_SNI;\n bool skip_verification;\n int host_flags;\n using ConfigurationType = std::vector>>;\n ConfigurationType configuration;\n};\n\nclass SSLContext\n{\npublic:\n explicit SSLContext(SSL_CTX & context);\n explicit SSLContext(const SSLParams & context_params);\n ~SSLContext() = default;\n\n SSLContext(const SSLContext &) = delete;\n SSLContext& operator=(const SSLContext &) = delete;\n SSLContext(SSLContext &&) = delete;\n SSLContext& operator=(SSLContext &) = delete;\n\nprivate:\n friend class SSLSocket;\n SSL_CTX * getContext();\n\nprivate:\n std::unique_ptr context_;\n};\n\nclass SSLSocket : public Socket {\npublic:\n explicit SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params,\n const SSLParams& ssl_params, SSLContext& context);\n\n SSLSocket(SSLSocket &&) = default;\n ~SSLSocket() override = default;\n\n SSLSocket(const SSLSocket & ) = delete;\n SSLSocket& operator=(const SSLSocket & ) = delete;\n\n std::unique_ptr makeInputStream() const override;\n std::unique_ptr makeOutputStream() const override;\n\n static void validateParams(const SSLParams & ssl_params);\nprivate:\n std::unique_ptr ssl_;\n};\n\nclass SSLSocketFactory : public NonSecureSocketFactory {\npublic:\n explicit SSLSocketFactory(const ClientOptions& opts);\n ~SSLSocketFactory() override;\n\nprotected:\n std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts) override;\n\nprivate:\n const SSLParams ssl_params_;\n std::unique_ptr ssl_context_;\n};\n\nclass SSLSocketInput : public InputStream {\npublic:\n explicit SSLSocketInput(SSL *ssl);\n ~SSLSocketInput() = default;\n\n bool Skip(size_t /*bytes*/) override {\n return false;\n }\n\nprotected:\n size_t DoRead(void* buf, size_t len) override;\n\nprivate:\n // Not owning\n SSL *ssl_;\n};\n\nclass SSLSocketOutput : public OutputStream {\npublic:\n explicit SSLSocketOutput(SSL *ssl);\n ~SSLSocketOutput() = default;\n\nprotected:\n size_t DoWrite(const void* data, size_t len) override;\n\nprivate:\n // Not owning\n SSL *ssl_;\n};\n\n}\n\n// Path: clickhouse/client.cpp\n#include \"client.h\"\n#include \"clickhouse/version.h\"\n#include \"protocol.h\"\n\n#include \"base/compressed.h\"\n#include \"base/socket.h\"\n#include \"base/wire_format.h\"\n\n#include \"columns/factory.h\"\n\n#include \n#include \n#include \n#include \n\n#if defined(WITH_OPENSSL)\n#include 
\"base/sslsocket.h\"\n#endif\n\n#define DBMS_NAME \"ClickHouse\"\n\n#define DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES 50264\n#define DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS 51554\n#define DBMS_MIN_REVISION_WITH_BLOCK_INFO 51903\n#define DBMS_MIN_REVISION_WITH_CLIENT_INFO 54032\n#define DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE 54058\n#define DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO 54060\n//#define DBMS_MIN_REVISION_WITH_TABLES_STATUS 54226\n#define DBMS_MIN_REVISION_WITH_TIME_ZONE_PARAMETER_IN_DATETIME_DATA_TYPE 54337\n#define DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME 54372\n#define DBMS_MIN_REVISION_WITH_VERSION_PATCH 54401\n#define DBMS_MIN_REVISION_WITH_LOW_CARDINALITY_TYPE 54405\n#define DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA 54410\n#define DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO 54420\n#define DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS 54429\n#define DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET 54441\n#define DBMS_MIN_REVISION_WITH_OPENTELEMETRY 54442\n#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448\n#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449\n#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451\n\n#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS\n\nnamespace clickhouse {\n\nstruct ClientInfo {\n uint8_t iface_type = 1; // TCP\n uint8_t query_kind;\n std::string initial_user;\n std::string initial_query_id;\n std::string quota_key;\n std::string os_user;\n std::string client_hostname;\n std::string client_name;\n std::string initial_address = \"[::ffff:127.0.0.1]:0\";\n uint64_t client_version_major = 0;\n uint64_t client_version_minor = 0;\n uint64_t client_version_patch = 0;\n uint32_t client_revision = 0;\n};\n\nstd::ostream& operator<<(std::ostream& os, const Endpoint& endpoint) {\n return os << endpoint.host << \":\" << endpoint.port;\n}\n\nstd::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {\n os << \"Client(\"\n << \" Endpoints : [\";\n size_t extra_endpoints = 0;\n\n if (!opt.host.empty()) {\n extra_endpoints = 1;\n os << opt.user << '@' << Endpoint{opt.host, opt.port};\n\n if (opt.endpoints.size())\n os << \", \";\n }\n\n for (size_t i = 0; i < opt.endpoints.size(); i++) {\n os << opt.user << '@' << opt.endpoints[i]\n << ((i == opt.endpoints.size() - 1) ? \"\" : \", \");\n }\n\n os << \"] (\" << opt.endpoints.size() + extra_endpoints << \" items )\"\n << \" ping_before_query:\" << opt.ping_before_query\n << \" send_retries:\" << opt.send_retries\n << \" retry_timeout:\" << opt.retry_timeout.count()\n << \" compression_method:\"\n << (opt.compression_method == CompressionMethod::LZ4 ? \"LZ4\" : \"None\");\n#if defined(WITH_OPENSSL)\n if (opt.ssl_options) {\n const auto & ssl_options = *opt.ssl_options;\n os << \" SSL (\"\n << \" ssl_context: \" << (ssl_options.ssl_context ? 
\"provided by user\" : \"created internally\")\n << \" use_default_ca_locations: \" << ssl_options.use_default_ca_locations\n << \" path_to_ca_files: \" << ssl_options.path_to_ca_files.size() << \" items\"\n << \" path_to_ca_directory: \" << ssl_options.path_to_ca_directory\n << \" min_protocol_version: \" << ssl_options.min_protocol_version\n << \" max_protocol_version: \" << ssl_options.max_protocol_version\n << \" context_options: \" << ssl_options.context_options\n << \")\";\n }\n#endif\n os << \")\";\n return os;\n}\n\nClientOptions& ClientOptions::SetSSLOptions(ClientOptions::SSLOptions options)\n{\n#ifdef WITH_OPENSSL\n ssl_options = options;\n return *this;\n#else\n (void)options;\n throw OpenSSLError(\"Library was built with no SSL support\");\n#endif\n}\n\nnamespace {\n\n\nstd::unique_ptr GetSocketFactory(const ClientOptions& opts) {\n (void)opts;\n#if defined(WITH_OPENSSL)\n if (opts.ssl_options)\n return std::make_unique(opts);\n else\n#endif\n return std::make_unique();\n}\n\nstd::unique_ptr GetEndpointsIterator(const ClientOptions& opts) {\n if (opts.endpoints.empty())\n {\n throw ValidationError(\"The list of endpoints is empty\");\n }\n\n return std::make_unique(opts.endpoints);\n}\n\n}\n\nclass Client::Impl {\npublic:\n Impl(const ClientOptions& opts);\n Impl(const ClientOptions& opts,\n std::unique_ptr socket_factory);\n ~Impl();\n\n void ExecuteQuery(Query query);\n\n void SendCancel();\n\n void Insert(const std::string& table_name, const std::string& query_id, const Block& block);\n\n void Ping();\n\n void ResetConnection();\n\n void ResetConnectionEndpoint();\n\n const ServerInfo& GetServerInfo() const;\n\n const std::optional& GetCurrentEndpoint() const;\n\nprivate:\n bool Handshake();\n\n bool ReceivePacket(uint64_t* server_packet = nullptr);\n\n void SendQuery(const Query& query);\n\n void SendData(const Block& block);\n\n bool SendHello();\n\n bool ReadBlock(InputStream& input, Block* block);\n\n bool ReceiveHello();\n\n /// Reads data packet form input stream.\n bool ReceiveData();\n\n /// Reads exception packet form input stream.\n bool ReceiveException(bool rethrow = false);\n\n void WriteBlock(const Block& block, OutputStream& output);\n\n void CreateConnection();\n\n void InitializeStreams(std::unique_ptr&& socket);\n\n inline size_t GetConnectionAttempts() const\n {\n return options_.endpoints.size() * options_.send_retries;\n }\n\nprivate:\n /// In case of network errors tries to reconnect to server and\n /// call fuc several times.\n void RetryGuard(std::function func);\n\n void RetryConnectToTheEndpoint(std::function& func);\n\nprivate:\n class EnsureNull {\n public:\n inline EnsureNull(QueryEvents* ev, QueryEvents** ptr)\n : ptr_(ptr)\n {\n if (ptr_) {\n *ptr_ = ev;\n }\n }\n\n inline ~EnsureNull() {\n if (ptr_) {\n *ptr_ = nullptr;\n }\n }\n\n private:\n QueryEvents** ptr_;\n\n };\n\n\n const ClientOptions options_;\n QueryEvents* events_;\n int compression_ = CompressionState::Disable;\n\n std::unique_ptr socket_factory_;\n\n std::unique_ptr input_;\n std::unique_ptr output_;\n std::unique_ptr socket_;\n std::unique_ptr endpoints_iterator;\n\n...\n", "needle_token_start": 15113, "needle_token_end": 15201, "code_context_ntokens": 15935, "output": ["```cpp\nstd::unique_ptr GetSocketFactory(const ClientOptions& opts) {\n (void)opts;\n#if defined(WITH_OPENSSL)\n if (opts.ssl_options)\n return std::make_unique(opts);\n else\n#endif\n return std::make_unique();\n}\n```"]} +{"repo": "polybar/polybar", "name": "make_mask", "language": "cpp", "path": 
"src/modules/bspwm.cpp", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to create a bitmask based on the states of a window manager module, allowing for the combination of two different states into a single bitmask.\n2. **Input**: Two enumeration values representing the states of a window manager module, with the second state having a default value indicating no state.\n3. **Output**: An unsigned integer representing the bitmask created from the input states.\n4. **Procedure**: The function checks if each input state is non-zero. For each non-zero state, it shifts a bit left by the state value minus one and combines these bits using a bitwise OR operation to form the final bitmask.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/modules/date.cpp\n#include \"modules/date.hpp\"\n\n#include \"drawtypes/label.hpp\"\n#include \"modules/meta/base.inl\"\n\nPOLYBAR_NS\n\nnamespace modules {\n template class module;\n\n date_module::date_module(const bar_settings& bar, string name_, const config& config)\n : timer_module(bar, move(name_), config) {\n if (!m_bar.locale.empty()) {\n datetime_stream.imbue(std::locale(m_bar.locale.c_str()));\n }\n\n m_router->register_action(EVENT_TOGGLE, [this]() { action_toggle(); });\n\n m_dateformat = m_conf.get(name(), \"date\", \"\"s);\n m_dateformat_alt = m_conf.get(name(), \"date-alt\", \"\"s);\n m_timeformat = m_conf.get(name(), \"time\", \"\"s);\n m_timeformat_alt = m_conf.get(name(), \"time-alt\", \"\"s);\n\n if (m_dateformat.empty() && m_timeformat.empty()) {\n throw module_error(\"No date or time format specified\");\n }\n\n set_interval(1s);\n\n m_formatter->add(DEFAULT_FORMAT, TAG_LABEL, {TAG_LABEL, TAG_DATE});\n\n if (m_formatter->has(TAG_DATE)) {\n m_log.warn(\"%s: The format tag `` is deprecated, use `

404 Not Found

\\n\";\n NotFound_tmp_stream << \"
drogon/\";\n NotFound_tmp_stream << NotFound_view_data.get(\"version\");\n NotFound_tmp_stream << \"
\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"\\n\";\n return NotFound_tmp_stream.str();\n}\n\n// Path: lib/src/LocalHostFilter.cc\n/**\n *\n * LocalHostFilter.cc\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpResponseImpl.h\"\n#include \nusing namespace drogon;\n\nvoid LocalHostFilter::doFilter(const HttpRequestPtr &req,\n FilterCallback &&fcb,\n FilterChainCallback &&fccb)\n{\n if (req->peerAddr().isLoopbackIp())\n {\n fccb();\n return;\n }\n auto res = drogon::HttpResponse::newNotFoundResponse(req);\n fcb(res);\n}\n\n// Path: lib/src/RedisResultSkipped.cc\n/**\n *\n * RedisClientSkipped.cc\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"drogon/nosql/RedisResult.h\"\n#include \"trantor/utils/Logger.h\"\n\nnamespace drogon\n{\nnamespace nosql\n{\nstd::string RedisResult::getStringForDisplaying() const noexcept\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n\nstd::string RedisResult::getStringForDisplayingWithIndent(\n size_t /*indent*/) const noexcept\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n\nstd::string RedisResult::asString() const noexcept(false)\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n\nRedisResultType RedisResult::type() const noexcept\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n\nstd::vector RedisResult::asArray() const noexcept(false)\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n\nlong long RedisResult::asInteger() const noexcept(false)\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n\nbool RedisResult::isNil() const noexcept\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n} // namespace nosql\n} // namespace drogon\n\n// Path: lib/src/TokenBucketRateLimiter.cc\n#include \"TokenBucketRateLimiter.h\"\n\nusing namespace drogon;\n\nTokenBucketRateLimiter::TokenBucketRateLimiter(\n size_t capacity,\n std::chrono::duration timeUnit)\n : capacity_(capacity),\n lastTime_(std::chrono::steady_clock::now()),\n timeUnit_(timeUnit),\n tokens_((double)capacity_)\n{\n}\n\n// implementation of the token bucket algorithm\nbool TokenBucketRateLimiter::isAllowed()\n{\n auto now = std::chrono::steady_clock::now();\n auto duration = std::chrono::duration_cast>(\n now - lastTime_);\n tokens_ += capacity_ * (duration / timeUnit_);\n if (tokens_ > capacity_)\n tokens_ = (double)capacity_;\n lastTime_ = now;\n if (tokens_ > 1.0)\n {\n tokens_ -= 1.0;\n return true;\n }\n return false;\n}\n\n// Path: lib/src/AccessLogger.cc\n/**\n *\n * @file AccessLogger.cc\n * @author An Tao\n 
*\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpUtils.h\"\n#include \n#include \n#include \n#include \n#include \n#if !defined _WIN32 && !defined __HAIKU__\n#include \n#include \n#elif defined __HAIKU__\n#include \n#else\n#include \n#endif\n#ifdef __FreeBSD__\n#include \n#endif\n\n#ifdef DROGON_SPDLOG_SUPPORT\n#include \n#include \n#include \n#include \n#ifdef _WIN32\n#include \n// Damn antedeluvian M$ macros\n#undef min\n#undef max\n#endif\n#ifndef _WIN32\n#include \n#include \n#define os_access access\n#elif !defined(_WIN32) || defined(__MINGW32__)\n#include \n#include \n#define os_access access\n#else\n#include \n#define os_access _waccess\n#define R_OK 04\n#define W_OK 02\n#endif\n#endif\n\nusing namespace drogon;\nusing namespace drogon::plugin;\n\nbool AccessLogger::useRealIp_ = false;\n\nvoid AccessLogger::initAndStart(const Json::Value &config)\n{\n useLocalTime_ = config.get(\"use_local_time\", true).asBool();\n showMicroseconds_ = config.get(\"show_microseconds\", true).asBool();\n timeFormat_ = config.get(\"custom_time_format\", \"\").asString();\n useCustomTimeFormat_ = !timeFormat_.empty();\n useRealIp_ = config.get(\"use_real_ip\", false).asBool();\n\n logFunctionMap_ = {{\"$request_path\", outputReqPath},\n {\"$path\", outputReqPath},\n {\"$date\",\n [this](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &resp) {\n outputDate(stream, req, resp);\n }},\n {\"$request_date\",\n [this](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &resp) {\n outputReqDate(stream, req, resp);\n }},\n {\"$request_query\", outputReqQuery},\n {\"$request_url\", outputReqURL},\n {\"$query\", outputReqQuery},\n {\"$url\", outputReqURL},\n {\"$request_version\", outputVersion},\n {\"$version\", outputVersion},\n {\"$request\", outputReqLine},\n {\"$remote_addr\", outputRemoteAddr},\n {\"$local_addr\", outputLocalAddr},\n {\"$request_len\", outputReqLength},\n {\"$body_bytes_received\", outputReqLength},\n {\"$method\", outputMethod},\n {\"$thread\", outputThreadNumber},\n {\"$response_len\", outputRespLength},\n {\"$body_bytes_sent\", outputRespLength},\n {\"$status\", outputStatusString},\n {\"$status_code\", outputStatusCode},\n {\"$processing_time\", outputProcessingTime},\n {\"$upstream_http_content-type\", outputRespContentType},\n {\"$upstream_http_content_type\", outputRespContentType}};\n auto format = config.get(\"log_format\", \"\").asString();\n if (format.empty())\n {\n format =\n \"$request_date $method $url [$body_bytes_received] ($remote_addr - \"\n \"$local_addr) $status $body_bytes_sent $processing_time\";\n }\n createLogFunctions(format);\n auto logPath = config.get(\"log_path\", \"\").asString();\n#ifdef DROGON_SPDLOG_SUPPORT\n auto logWithSpdlog = trantor::Logger::hasSpdLogSupport() &&\n config.get(\"use_spdlog\", false).asBool();\n if (logWithSpdlog)\n {\n logIndex_ = config.get(\"log_index\", 0).asInt();\n // Do nothing if already initialized...\n if (!trantor::Logger::getSpdLogger(logIndex_))\n {\n trantor::Logger::enableSpdLog(logIndex_);\n // Get the new logger & replace its sinks with the ones of the\n // config\n auto logger = trantor::Logger::getSpdLogger(logIndex_);\n std::vector sinks;\n while (!logPath.empty())\n {\n // 1. 
check existence of folder or try to create it\n auto fsLogPath =\n std::filesystem::path(utils::toNativePath(logPath));\n std::error_code fsErr;\n if (!std::filesystem::create_directories(fsLogPath, fsErr) &&\n fsErr)\n {\n LOG_ERROR << \"could not create log file path\";\n break;\n }\n // 2. check if we have rights to create files in the folder\n if (os_access(fsLogPath.native().c_str(), W_OK) != 0)\n {\n LOG_ERROR << \"cannot create files in log folder\";\n break;\n }\n std::filesystem::path fileName(\n config.get(\"log_file\", \"access.log\").asString());\n if (fileName.empty())\n fileName = \"access.log\";\n else\n fileName.replace_extension(\".log\");\n auto sizeLimit = config.get(\"log_size_limit\", 0).asUInt64();\n if (sizeLimit == 0)\n sizeLimit = config.get(\"size_limit\", 0).asUInt64();\n if (sizeLimit == 0) // 0 is not allowed by this sink\n sizeLimit = std::numeric_limits::max();\n std::size_t maxFiles = config.get(\"max_files\", 0).asUInt();\n sinks.push_back(\n std::make_shared(\n (fsLogPath / fileName).string(),\n sizeLimit,\n // spdlog limitation\n std::min(maxFiles, std::size_t(20000)),\n false));\n break;\n }\n if (sinks.empty())\n sinks.push_back(\n std::make_shared());\n#if defined(_WIN32) && defined(_DEBUG)\n // On Windows with debug, it may be interesting to have the logs\n // directly in the Visual Studio / WinDbg console\n sinks.push_back(std::make_shared());\n#endif\n logger->sinks() = sinks;\n // Override the pattern set in\n // trantor::Logger::getDefaultSpdLogger() and let AccessLogger\n // format the output\n logger->set_pattern(\"%v\");\n }\n }\n else\n#endif\n if (!logPath.empty())\n {\n auto fileName = config.get(\"log_file\", \"access.log\").asString();\n auto extension = std::string(\".log\");\n auto pos = fileName.rfind('.');\n if (pos != std::string::npos)\n {\n extension = fileName.substr(pos);\n fileName = fileName.substr(0, pos);\n }\n if (fileName.empty())\n {\n fileName = \"access\";\n }\n asyncFileLogger_.setFileName(fileName, extension, logPath);\n asyncFileLogger_.startLogging();\n logIndex_ = config.get(\"log_index\", 0).asInt();\n trantor::Logger::setOutputFunction(\n [&](const char *msg, const uint64_t len) {\n asyncFileLogger_.output(msg, len);\n },\n [&]() { asyncFileLogger_.flush(); },\n logIndex_);\n auto sizeLimit = config.get(\"log_size_limit\", 0).asUInt64();\n if (sizeLimit == 0)\n {\n // In earlier code, \"size_limit\" is taken instead of\n // \"log_size_limit\" as it said in the comment in AccessLogger.h.\n // In order to ensure backward compatibility we still take this\n // field as a fallback.\n sizeLimit = config.get(\"size_limit\", 0).asUInt64();\n }\n if (sizeLimit > 0)\n {\n asyncFileLogger_.setFileSizeLimit(sizeLimit);\n }\n auto maxFiles = config.get(\"max_files\", 0).asUInt();\n if (maxFiles >= 0)\n {\n asyncFileLogger_.setMaxFiles(maxFiles);\n }\n }\n drogon::app().registerPreSendingAdvice(\n [this](const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &resp) {\n logging(LOG_RAW_TO(logIndex_), req, resp);\n });\n}\n\nvoid AccessLogger::shutdown()\n{\n}\n\nvoid AccessLogger::logging(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &resp)\n{\n for (auto &func : logFunctions_)\n {\n func(stream, req, resp);\n }\n}\n\nvoid AccessLogger::createLogFunctions(std::string format)\n{\n std::string rawString;\n while (!format.empty())\n {\n LOG_TRACE << format;\n auto pos = format.find('$');\n if (pos != std::string::npos)\n {\n rawString += format.substr(0, pos);\n\n format 
= format.substr(pos);\n std::regex e{\"^\\\\$[a-zA-Z0-9\\\\-_]+\"};\n std::smatch m;\n if (std::regex_search(format, m, e))\n {\n if (!rawString.empty())\n {\n logFunctions_.emplace_back(\n [rawString](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &) {\n stream << rawString;\n });\n rawString.clear();\n }\n auto placeholder = m[0];\n logFunctions_.emplace_back(newLogFunction(placeholder));\n format = m.suffix().str();\n }\n else\n {\n rawString += '$';\n format = format.substr(1);\n }\n }\n else\n {\n rawString += format;\n break;\n }\n }\n if (!rawString.empty())\n {\n logFunctions_.emplace_back(\n [rawString =\n std::move(rawString)](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &) {\n stream << rawString << \"\\n\";\n });\n }\n else\n {\n logFunctions_.emplace_back(\n [](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &) { stream << \"\\n\"; });\n }\n}\n\nAccessLogger::LogFunction AccessLogger::newLogFunction(\n const std::string &placeholder)\n{\n auto iter = logFunctionMap_.find(placeholder);\n if (iter != logFunctionMap_.end())\n {\n return iter->second;\n }\n if (placeholder.find(\"$http_\") == 0 && placeholder.size() > 6)\n {\n auto headerName = placeholder.substr(6);\n return [headerName =\n std::move(headerName)](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &) {\n outputReqHeader(stream, req, headerName);\n };\n }\n if (placeholder.find(\"$cookie_\") == 0 && placeholder.size() > 8)\n {\n auto cookieName = placeholder.substr(8);\n return [cookieName =\n std::move(cookieName)](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &) {\n outputReqCookie(stream, req, cookieName);\n };\n }\n if (placeholder.find(\"$upstream_http_\") == 0 && placeholder.size() > 15)\n {\n auto headerName = placeholder.substr(15);\n return [headerName = std::move(\n headerName)](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &resp) {\n outputRespHeader(stream, resp, headerName);\n };\n }\n return [placeholder](trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &) {\n stream << placeholder;\n };\n}\n\nvoid AccessLogger::outputReqPath(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n stream << req->path();\n}\n\nvoid AccessLogger::outputDate(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &) const\n{\n if (useCustomTimeFormat_)\n {\n if (useLocalTime_)\n {\n stream << trantor::Date::now().toCustomedFormattedStringLocal(\n timeFormat_, showMicroseconds_);\n }\n else\n {\n stream << trantor::Date::now().toCustomedFormattedString(\n timeFormat_, showMicroseconds_);\n }\n }\n else\n {\n if (useLocalTime_)\n {\n stream << trantor::Date::now().toFormattedStringLocal(\n showMicroseconds_);\n }\n else\n {\n stream << trantor::Date::now().toFormattedString(showMicroseconds_);\n }\n }\n}\n\nvoid AccessLogger::outputReqDate(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &) const\n{\n if (useCustomTimeFormat_)\n {\n if (useLocalTime_)\n {\n stream << req->creationDate().toCustomedFormattedStringLocal(\n timeFormat_, showMicroseconds_);\n }\n else\n {\n stream << req->creationDate().toCustomedFormattedString(\n timeFormat_, showMicroseconds_);\n }\n 
}\n else\n {\n if (useLocalTime_)\n {\n stream << req->creationDate().toFormattedStringLocal(\n showMicroseconds_);\n }\n else\n {\n stream << req->creationDate().toFormattedString(showMicroseconds_);\n }\n }\n}\n\n//$request_query\nvoid AccessLogger::outputReqQuery(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n stream << req->query();\n}\n\n//$request_url\nvoid AccessLogger::outputReqURL(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n auto &query = req->query();\n if (query.empty())\n {\n stream << req->path();\n }\n else\n {\n stream << req->path() << '?' << query;\n }\n}\n\n//$request_version\nvoid AccessLogger::outputVersion(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n stream << req->versionString();\n}\n\n//$request\nvoid AccessLogger::outputReqLine(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n auto &query = req->query();\n if (query.empty())\n {\n stream << req->methodString() << \" \" << req->path() << \" \"\n << req->versionString();\n }\n else\n {\n stream << req->methodString() << \" \" << req->path() << '?' << query\n << \" \" << req->versionString();\n }\n}\n\nvoid AccessLogger::outputRemoteAddr(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n if (useRealIp_)\n {\n stream << RealIpResolver::GetRealAddr(req).toIpPort();\n }\n else\n {\n stream << req->peerAddr().toIpPort();\n }\n}\n\nvoid AccessLogger::outputLocalAddr(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n stream << req->localAddr().toIpPort();\n}\n\nvoid AccessLogger::outputReqLength(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n stream << req->body().length();\n}\n\nvoid AccessLogger::outputRespLength(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &resp)\n{\n stream << resp->body().length();\n}\n\nvoid AccessLogger::outputMethod(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &req,\n const drogon::HttpResponsePtr &)\n{\n stream << req->methodString();\n}\n\nvoid AccessLogger::outputThreadNumber(trantor::LogStream &stream,\n const drogon::HttpRequestPtr &,\n const drogon::HttpResponsePtr &)\n{\n#ifdef __linux__\n static thread_local pid_t threadId_{0};\n#else\n static thread_local uint64_t threadId_{0};\n#endif\n#ifdef __linux__\n if (threadId_ == 0)\n threadId_ = static_cast(::syscall(SYS_gettid));\n#elif defined __FreeBSD__\n if (threadId_ == 0)\n {\n threadId_ = pthread_getthreadid_np();\n }\n#elif defined __OpenBSD__\n if (threadId_ == 0)\n {\n threadId_ = getthrid();\n }\n#elif defined _WIN32 || defined __HAIKU__\n if (threadId_ == 0)\n {\n std::stringstream ss;\n ss << std::this_thread::get_id();\n threadId_ = std::stoull(ss.str());\n }\n#else\n if (threadId_ == 0)\n {\n pthread_threadid_np(NULL, &threadId_);\n }\n#endif\n stream << threadId_;\n}\n\n//$http_[header_name]\nvoid AccessLogger::outputReqHeader(trantor::LogStream &stream,\n...\n", "needle_token_start": 2400, "needle_token_end": 2407, "code_context_ntokens": 16109, "output": ["```cpp\nJson::Value JsonConfigAdapter::getJson(const std::string &content) const\n noexcept(false)\n{\n static std::once_flag once;\n static Json::CharReaderBuilder builder;\n std::call_once(once, []() { 
builder[\"collectComments\"] = false; });\n JSONCPP_STRING errs;\n std::unique_ptr reader(builder.newCharReader());\n Json::Value root;\n if (!reader->parse(\n content.c_str(), content.c_str() + content.size(), &root, &errs))\n {\n throw std::runtime_error(errs);\n }\n return root;\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "registerCustomExtensionMime", "language": "cpp", "path": "lib/src/HttpUtils.cc", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to add or update a custom file extension to MIME type mapping in a global map. This allows the system to recognize and handle files with custom extensions appropriately.\n2. **Input**: It takes two strings as parameters: the first represents the file extension (without the dot), and the second is the corresponding MIME type.\n3. **Output**: There is no return value; the function operates by side effects, modifying a global map.\n4. **Procedure**: The function first checks if the input extension is non-empty. If it is, it checks if this extension already has a MIME type associated with it in the map. If an association exists, a warning is logged indicating that the existing MIME type is being overwritten. Finally, the new MIME type is set for the given extension in the map.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " {\"rar\", {FT_ARCHIVE, CT_APPLICATION_VND_RAR}},\n {\"svg\", {FT_IMAGE, CT_IMAGE_SVG_XML}},\n {\"tar\", {FT_ARCHIVE, CT_APPLICATION_X_TAR}},\n {\"targa\", {FT_IMAGE, CT_IMAGE_X_TGA}},\n {\"tif\", {FT_IMAGE, CT_IMAGE_TIFF}},\n {\"tiff\", {FT_IMAGE, CT_IMAGE_TIFF}},\n {\"tga\", {FT_IMAGE, CT_IMAGE_X_TGA}},\n {\"tgz\", {FT_ARCHIVE, CT_APPLICATION_X_TGZ}},\n {\"ts\", {FT_MEDIA, CT_VIDEO_MPEG2TS}},\n {\"tta\", {FT_AUDIO, CT_AUDIO_X_TTA}},\n {\"ttf\", {FT_DOCUMENT, CT_APPLICATION_X_FONT_TRUETYPE}},\n {\"txt\", {FT_DOCUMENT, CT_TEXT_PLAIN}},\n {\"w64\", {FT_AUDIO, CT_AUDIO_WAVE}},\n {\"wav\", {FT_AUDIO, CT_AUDIO_WAVE}},\n {\"wave\", {FT_AUDIO, CT_AUDIO_WAVE}},\n {\"wasm\", {FT_DOCUMENT, CT_APPLICATION_WASM}},\n {\"weba\", {FT_AUDIO, CT_AUDIO_WEBM}},\n {\"webm\", {FT_MEDIA, CT_VIDEO_WEBM}},\n {\"webp\", {FT_IMAGE, CT_IMAGE_WEBP}},\n {\"wma\", {FT_AUDIO, CT_AUDIO_X_MS_WMA}},\n {\"woff\", {FT_DOCUMENT, CT_APPLICATION_FONT_WOFF}},\n {\"woff2\", {FT_DOCUMENT, CT_APPLICATION_FONT_WOFF2}},\n {\"wv\", {FT_AUDIO, CT_AUDIO_X_WAVPACK}},\n {\"xht\", {FT_DOCUMENT, CT_APPLICATION_XHTML}},\n {\"xhtml\", {FT_DOCUMENT, CT_APPLICATION_XHTML}},\n {\"xml\", {FT_DOCUMENT, CT_APPLICATION_XML}},\n {\"xsl\", {FT_DOCUMENT, CT_TEXT_XSL}},\n {\"xz\", {FT_ARCHIVE, CT_APPLICATION_X_XZ}},\n {\"zip\", {FT_ARCHIVE, CT_APPLICATION_ZIP}},\n {\"7z\", {FT_ARCHIVE, CT_APPLICATION_X_7Z}},\n };\n\nconst std::string_view &statusCodeToString(int code)\n{\n switch (code)\n {\n case 100:\n {\n static std::string_view sv = \"Continue\";\n return sv;\n }\n case 101:\n {\n static std::string_view sv = \"Switching Protocols\";\n return sv;\n }\n case 102:\n {\n static std::string_view sv = \"Processing\";\n return sv;\n }\n case 103:\n {\n static std::string_view sv = \"Early Hints\";\n return sv;\n }\n case 200:\n {\n static std::string_view sv = \"OK\";\n return sv;\n }\n case 201:\n {\n static std::string_view sv = \"Created\";\n return sv;\n }\n case 202:\n {\n static std::string_view sv 
= \"Accepted\";\n return sv;\n }\n case 203:\n {\n static std::string_view sv = \"Non-Authoritative Information\";\n return sv;\n }\n case 204:\n {\n static std::string_view sv = \"No Content\";\n return sv;\n }\n case 205:\n {\n static std::string_view sv = \"Reset Content\";\n return sv;\n }\n case 206:\n {\n static std::string_view sv = \"Partial Content\";\n return sv;\n }\n case 207:\n {\n static std::string_view sv = \"Multi-Status\";\n return sv;\n }\n case 208:\n {\n static std::string_view sv = \"Already Reported\";\n return sv;\n }\n case 226:\n {\n static std::string_view sv = \"IM Used\";\n return sv;\n }\n case 300:\n {\n static std::string_view sv = \"Multiple Choices\";\n return sv;\n }\n case 301:\n {\n static std::string_view sv = \"Moved Permanently\";\n return sv;\n }\n case 302:\n {\n static std::string_view sv = \"Found\";\n return sv;\n }\n case 303:\n {\n static std::string_view sv = \"See Other\";\n return sv;\n }\n case 304:\n {\n static std::string_view sv = \"Not Modified\";\n return sv;\n }\n case 305:\n {\n static std::string_view sv = \"Use Proxy\";\n return sv;\n }\n case 306:\n {\n static std::string_view sv = \"(Unused)\";\n return sv;\n }\n case 307:\n {\n static std::string_view sv = \"Temporary Redirect\";\n return sv;\n }\n case 308:\n {\n static std::string_view sv = \"Permanent Redirect\";\n return sv;\n }\n case 400:\n {\n static std::string_view sv = \"Bad Request\";\n return sv;\n }\n case 401:\n {\n static std::string_view sv = \"Unauthorized\";\n return sv;\n }\n case 402:\n {\n static std::string_view sv = \"Payment Required\";\n return sv;\n }\n case 403:\n {\n static std::string_view sv = \"Forbidden\";\n return sv;\n }\n case 404:\n {\n static std::string_view sv = \"Not Found\";\n return sv;\n }\n case 405:\n {\n static std::string_view sv = \"Method Not Allowed\";\n return sv;\n }\n case 406:\n {\n static std::string_view sv = \"Not Acceptable\";\n return sv;\n }\n case 407:\n {\n static std::string_view sv = \"Proxy Authentication Required\";\n return sv;\n }\n case 408:\n {\n static std::string_view sv = \"Request Time-out\";\n return sv;\n }\n case 409:\n {\n static std::string_view sv = \"Conflict\";\n return sv;\n }\n case 410:\n {\n static std::string_view sv = \"Gone\";\n return sv;\n }\n case 411:\n {\n static std::string_view sv = \"Length Required\";\n return sv;\n }\n case 412:\n {\n static std::string_view sv = \"Precondition Failed\";\n return sv;\n }\n case 413:\n {\n static std::string_view sv = \"Request Entity Too Large\";\n return sv;\n }\n case 414:\n {\n static std::string_view sv = \"Request-URI Too Large\";\n return sv;\n }\n case 415:\n {\n static std::string_view sv = \"Unsupported Media Type\";\n return sv;\n }\n case 416:\n {\n static std::string_view sv = \"Requested Range Not Satisfiable\";\n return sv;\n }\n case 417:\n {\n static std::string_view sv = \"Expectation Failed\";\n return sv;\n }\n case 418:\n {\n static std::string_view sv = \"I'm a Teapot\";\n return sv;\n }\n case 421:\n {\n static std::string_view sv = \"Misdirected Request\";\n return sv;\n }\n case 422:\n {\n static std::string_view sv = \"Unprocessable Entity\";\n return sv;\n }\n case 423:\n {\n static std::string_view sv = \"Locked\";\n return sv;\n }\n case 424:\n {\n static std::string_view sv = \"Failed Dependency\";\n return sv;\n }\n case 425:\n {\n static std::string_view sv = \"Too Early\";\n return sv;\n }\n case 426:\n {\n static std::string_view sv = \"Upgrade Required\";\n return sv;\n }\n case 428:\n {\n static std::string_view 
sv = \"Precondition Required\";\n return sv;\n }\n case 429:\n {\n static std::string_view sv = \"Too Many Requests\";\n return sv;\n }\n case 431:\n {\n static std::string_view sv = \"Request Header Fields Too Large\";\n return sv;\n }\n case 451:\n {\n static std::string_view sv = \"Unavailable For Legal Reasons\";\n return sv;\n }\n case 500:\n {\n static std::string_view sv = \"Internal Server Error\";\n return sv;\n }\n case 501:\n {\n static std::string_view sv = \"Not Implemented\";\n return sv;\n }\n case 502:\n {\n static std::string_view sv = \"Bad Gateway\";\n return sv;\n }\n case 503:\n {\n static std::string_view sv = \"Service Unavailable\";\n return sv;\n }\n case 504:\n {\n static std::string_view sv = \"Gateway Time-out\";\n return sv;\n }\n case 505:\n {\n static std::string_view sv = \"HTTP Version Not Supported\";\n return sv;\n }\n case 506:\n {\n static std::string_view sv = \"Variant Also Negotiates\";\n return sv;\n }\n case 507:\n {\n static std::string_view sv = \"Insufficient Storage\";\n return sv;\n }\n case 508:\n {\n static std::string_view sv = \"Loop Detected\";\n return sv;\n }\n case 510:\n {\n static std::string_view sv = \"Not Extended\";\n return sv;\n }\n case 511:\n {\n static std::string_view sv = \"Network Authentication Required\";\n return sv;\n }\n default:\n if (code >= 100 && code < 200)\n {\n static std::string_view sv = \"Informational\";\n return sv;\n }\n else if (code >= 200 && code < 300)\n {\n static std::string_view sv = \"Successful\";\n return sv;\n }\n else if (code >= 300 && code < 400)\n {\n static std::string_view sv = \"Redirection\";\n return sv;\n }\n else if (code >= 400 && code < 500)\n {\n static std::string_view sv = \"Bad Request\";\n return sv;\n }\n else if (code >= 500 && code < 600)\n {\n static std::string_view sv = \"Server Error\";\n return sv;\n }\n else\n {\n static std::string_view sv = \"Undefined Error\";\n return sv;\n }\n }\n}\n\nContentType getContentType(const std::string &fileName)\n{\n std::string extName;\n auto pos = fileName.rfind('.');\n if (pos != std::string::npos)\n {\n extName = fileName.substr(pos + 1);\n transform(extName.begin(),\n extName.end(),\n extName.begin(),\n [](unsigned char c) { return tolower(c); });\n }\n auto it = fileTypeDatabase_.find(extName);\n return (it == fileTypeDatabase_.end()) ? CT_APPLICATION_OCTET_STREAM\n : it->second.second;\n}\n\nContentType parseContentType(const std::string_view &contentType)\n{\n // Generate map from database for faster query\n static std::unordered_map contentTypeMap_;\n // Thread safe initialization\n static std::once_flag flag;\n std::call_once(flag, []() {\n for (const auto &e : mimeTypeDatabase_)\n {\n for (const auto &type : e.second.first)\n contentTypeMap_[type] = e.first;\n }\n });\n auto ext = contentType.find(';');\n if (ext != std::string_view::npos)\n return parseContentType(contentType.substr(0, ext));\n if (contentType == \"application/x-www-form-urlencoded\")\n return CT_APPLICATION_X_FORM;\n if (contentType == \"multipart/form-data\")\n return CT_MULTIPART_FORM_DATA;\n auto it = contentTypeMap_.find(contentType);\n return (it == contentTypeMap_.end()) ? CT_CUSTOM : it->second;\n}\n\nFileType parseFileType(const std::string_view &fileExtension)\n{\n std::string extName(fileExtension);\n transform(extName.begin(),\n extName.end(),\n extName.begin(),\n [](unsigned char c) { return tolower(c); });\n auto it = fileTypeDatabase_.find(extName);\n return (it == fileTypeDatabase_.end()) ? 
FT_CUSTOM : it->second.first;\n}\n\nFileType getFileType(ContentType contentType)\n{\n // Generate map from database for faster query\n static std::unordered_map fileTypeMap_;\n // Thread safe initialization\n static std::once_flag flag;\n std::call_once(flag, []() {\n for (const auto &e : fileTypeDatabase_)\n fileTypeMap_[e.second.second] = e.second.first;\n fileTypeMap_[CT_NONE] = FT_UNKNOWN;\n fileTypeMap_[CT_CUSTOM] = FT_CUSTOM;\n });\n auto it = fileTypeMap_.find(contentType);\n return (it == fileTypeMap_.end()) ? FT_UNKNOWN : it->second;\n}\n\nconst std::string_view &contentTypeToMime(ContentType contentType)\n{\n auto it = mimeTypeDatabase_.find(contentType);\n return (it == mimeTypeDatabase_.end())\n ? mimeTypeDatabase_.at(CT_APPLICATION_OCTET_STREAM).first.front()\n : (it->second.second.empty() ? it->second.first.front()\n : it->second.second);\n}\n\n\nvoid registerCustomExtensionMime(const std::string &ext,\n const std::string &mime)\n{\n if (ext.empty())\n return;\n auto &mimeStr = customMime[ext];\n if (!mimeStr.empty())\n {\n LOG_WARN << ext << \" has already been registered as type \" << mime\n << \". Overwriting.\";\n }\n mimeStr = mime;\n}\n\nconst std::string_view fileNameToMime(const std::string &fileName)\n{\n ContentType internalContentType = getContentType(fileName);\n if (internalContentType != CT_APPLICATION_OCTET_STREAM)\n return contentTypeToMime(internalContentType);\n\n std::string extName;\n auto pos = fileName.rfind('.');\n if (pos != std::string::npos)\n {\n extName = fileName.substr(pos + 1);\n transform(extName.begin(),\n extName.end(),\n extName.begin(),\n [](unsigned char c) { return tolower(c); });\n }\n auto it = customMime.find(extName);\n if (it == customMime.end())\n return \"\";\n return it->second;\n}\n\nstd::pair fileNameToContentTypeAndMime(\n const std::string &fileName)\n{\n ContentType internalContentType = getContentType(fileName);\n if (internalContentType != CT_APPLICATION_OCTET_STREAM)\n return {internalContentType, contentTypeToMime(internalContentType)};\n\n std::string extName;\n auto pos = fileName.rfind('.');\n if (pos != std::string::npos)\n {\n extName = fileName.substr(pos + 1);\n transform(extName.begin(),\n extName.end(),\n extName.begin(),\n [](unsigned char c) { return tolower(c); });\n }\n auto it = customMime.find(extName);\n if (it == customMime.end())\n return {CT_NONE, \"\"};\n return {CT_CUSTOM, it->second};\n}\n\nconst std::vector &getFileExtensions(ContentType contentType)\n{\n // Generate map from database for faster query\n static std::unordered_map>\n extensionMap_;\n static std::vector notFound_;\n // Thread safe initialization\n static std::once_flag flag;\n std::call_once(flag, []() {\n for (const auto &e : fileTypeDatabase_)\n if (!e.first.empty())\n extensionMap_[e.second.second].push_back(e.first);\n // Add deprecated\n extensionMap_[CT_APPLICATION_X_JAVASCRIPT] =\n extensionMap_[CT_TEXT_JAVASCRIPT];\n extensionMap_[CT_TEXT_XML] = extensionMap_[CT_APPLICATION_XML];\n });\n auto it = extensionMap_.find(contentType);\n if (it == extensionMap_.end())\n return notFound_;\n return it->second;\n}\n\n} // namespace drogon\n\n// Path: lib/src/ConfigAdapterManager.h\n#pragma once\n#include \"ConfigAdapterManager.h\"\n#include \"ConfigAdapter.h\"\n#include \n\nnamespace drogon\n{\nclass ConfigAdapterManager\n{\n public:\n static ConfigAdapterManager &instance();\n Json::Value getJson(const std::string &content, std::string ext) const\n noexcept(false);\n\n private:\n ConfigAdapterManager();\n std::map adapters_;\n};\n} // 
namespace drogon\n\n// Path: lib/src/ConfigAdapterManager.cc\n#include \"ConfigAdapterManager.h\"\n#include \"JsonConfigAdapter.h\"\n#include \"YamlConfigAdapter.h\"\n#include \n\nusing namespace drogon;\n#define REGISTER_CONFIG_ADAPTER(adapter) \\\n { \\\n auto adapterPtr = std::make_shared(); \\\n auto exts = adapterPtr->getExtensions(); \\\n for (auto ext : exts) \\\n { \\\n std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); \\\n adapters_[ext] = adapterPtr; \\\n } \\\n }\n\nConfigAdapterManager &ConfigAdapterManager::instance()\n{\n static ConfigAdapterManager instance;\n return instance;\n}\n\nJson::Value ConfigAdapterManager::getJson(const std::string &content,\n std::string ext) const noexcept(false)\n{\n std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);\n auto it = adapters_.find(ext);\n if (it == adapters_.end())\n {\n throw std::runtime_error(\"No valid parser for this config file!\");\n }\n return it->second->getJson(content);\n}\n\nConfigAdapterManager::ConfigAdapterManager()\n{\n REGISTER_CONFIG_ADAPTER(JsonConfigAdapter);\n REGISTER_CONFIG_ADAPTER(YamlConfigAdapter);\n}\n\n// Path: lib/src/DrTemplateBase.cc\n/**\n *\n * @file DrTemplateBase.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n#include \n#include \n\nusing namespace drogon;\n\nstd::shared_ptr DrTemplateBase::newTemplate(\n const std::string &templateName)\n{\n LOG_TRACE << \"http view name=\" << templateName;\n auto l = templateName.length();\n if (l >= 4 && templateName[l - 4] == '.' && templateName[l - 3] == 'c' &&\n templateName[l - 2] == 's' && templateName[l - 1] == 'p')\n {\n std::string::size_type pos = 0;\n std::string newName;\n newName.reserve(templateName.size());\n if (templateName[0] == '/' || templateName[0] == '\\\\')\n {\n pos = 1;\n }\n else if (templateName[0] == '.' &&\n (templateName[1] == '/' || templateName[1] == '\\\\'))\n {\n pos = 2;\n }\n while (pos < l - 4)\n {\n if (templateName[pos] == '/' || templateName[pos] == '\\\\')\n {\n newName.append(\"::\");\n }\n else\n {\n newName.append(1, templateName[pos]);\n }\n ++pos;\n }\n return std::shared_ptr(dynamic_cast(\n drogon::DrClassMap::newObject(newName)));\n }\n else\n {\n return std::shared_ptr(dynamic_cast(\n drogon::DrClassMap::newObject(templateName)));\n }\n}\n\n// Path: lib/src/DrClassMap.cc\n/**\n *\n * DrClassMap.cc\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n\nusing namespace drogon;\n\nnamespace drogon\n{\nnamespace internal\n{\nstatic std::unordered_map> &\ngetObjsMap()\n{\n static std::unordered_map>\n singleInstanceMap;\n return singleInstanceMap;\n}\n\nstatic std::mutex &getMapMutex()\n{\n static std::mutex mtx;\n return mtx;\n}\n\n} // namespace internal\n} // namespace drogon\n\nvoid DrClassMap::registerClass(const std::string &className,\n const DrAllocFunc &func,\n const DrSharedAllocFunc &sharedFunc)\n{\n LOG_TRACE << \"Register class:\" << className;\n getMap().insert(\n std::make_pair(className, std::make_pair(func, sharedFunc)));\n}\n\nDrObjectBase *DrClassMap::newObject(const std::string &className)\n{\n auto iter = getMap().find(className);\n if (iter != getMap().end())\n {\n return iter->second.first();\n }\n else\n return nullptr;\n}\n\nstd::shared_ptr DrClassMap::newSharedObject(\n const std::string &className)\n{\n auto iter = getMap().find(className);\n if (iter != getMap().end())\n {\n if (iter->second.second)\n return iter->second.second();\n else\n return std::shared_ptr(iter->second.first());\n }\n else\n return nullptr;\n}\n\nconst std::shared_ptr &DrClassMap::getSingleInstance(\n const std::string &className)\n{\n auto &mtx = internal::getMapMutex();\n auto &singleInstanceMap = internal::getObjsMap();\n {\n std::lock_guard lock(mtx);\n auto iter = singleInstanceMap.find(className);\n if (iter != singleInstanceMap.end())\n return iter->second;\n }\n auto newObj = newSharedObject(className);\n {\n std::lock_guard lock(mtx);\n auto ret = singleInstanceMap.insert(\n std::make_pair(className, std::move(newObj)));\n return ret.first->second;\n }\n}\n\nvoid DrClassMap::setSingleInstance(const std::shared_ptr &ins)\n{\n auto &mtx = internal::getMapMutex();\n auto &singleInstanceMap = internal::getObjsMap();\n std::lock_guard lock(mtx);\n singleInstanceMap[ins->className()] = ins;\n}\n\nstd::vector DrClassMap::getAllClassName()\n{\n std::vector ret;\n for (auto const &iter : getMap())\n {\n ret.push_back(iter.first);\n }\n return ret;\n}\n\nstd::unordered_map> &\nDrClassMap::getMap()\n{\n static std::unordered_map>\n map;\n return map;\n}\n\n// Path: lib/src/HttpFileImpl.h\n/**\n *\n * @file HttpFileImpl.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n#include \"HttpUtils.h\"\n#include \n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass HttpFileImpl\n{\n public:\n /// Return the file name;\n const std::string &getFileName() const noexcept\n {\n return fileName_;\n }\n\n /// Set the file name, usually called by the MultiPartParser parser.\n void setFileName(const std::string &fileName) noexcept\n {\n fileName_ = fileName;\n }\n\n void setFileName(std::string &&fileName) noexcept\n {\n fileName_ = std::move(fileName);\n }\n\n /// Return the file extension;\n std::string_view getFileExtension() const noexcept\n {\n return drogon::getFileExtension(fileName_);\n }\n\n /// Set the contents of the file, usually called by the MultiPartParser\n /// parser.\n void setFile(const char *data, size_t length) noexcept\n {\n fileContent_ = std::string_view{data, length};\n }\n\n /// Save the file to the file system.\n /**\n * The folder saving the file is app().getUploadPath().\n * The full path is app().getUploadPath()+\"/\"+this->getFileName()\n */\n int save() const noexcept;\n\n /// Save the file to @param path\n /**\n * @param path if the parameter is prefixed with \"/\", \"./\" or \"../\", or is\n * \".\" or \"..\", the full path is path+\"/\"+this->getFileName(),\n * otherwise the file is saved as\n * app().getUploadPath()+\"/\"+path+\"/\"+this->getFileName()\n */\n int save(const std::string &path) const noexcept;\n\n /// Save the file to file system with a new name\n /**\n * @param fileName if the parameter isn't prefixed with \"/\", \"./\" or \"../\",\n * the full path is app().getUploadPath()+\"/\"+filename, otherwise the file\n * is saved as the filename\n */\n int saveAs(const std::string &fileName) const noexcept;\n\n /// Return the file length.\n size_t fileLength() const noexcept\n {\n return fileContent_.length();\n }\n\n const char *fileData() const noexcept\n {\n return fileContent_.data();\n }\n\n const std::string_view &fileContent() const noexcept\n {\n return fileContent_;\n }\n\n /// Return the name of the item in multiple parts.\n const std::string &getItemName() const noexcept\n {\n return itemName_;\n }\n\n void setItemName(const std::string &itemName) noexcept\n {\n itemName_ = itemName;\n }\n\n void setItemName(std::string &&itemName) noexcept\n {\n itemName_ = std::move(itemName);\n }\n\n /// Return the type of file.\n FileType getFileType() const noexcept\n {\n auto ft = drogon::getFileType(contentType_);\n if ((ft != FT_UNKNOWN) && (ft != FT_CUSTOM))\n return ft;\n return parseFileType(getFileExtension());\n }\n\n /// Return md5 hash of the file\n std::string getMd5() const noexcept;\n // Return sha1 hash of the file\n std::string getSha256() const noexcept;\n // Return sha512 hash of the file\n std::string getSha3() const noexcept;\n // int saveTo(const std::string &pathAndFileName) const;\n int saveTo(const std::filesystem::path &pathAndFileName) const noexcept;\n\n void setRequest(const HttpRequestPtr &req) noexcept\n {\n requestPtr_ = req;\n }\n\n drogon::ContentType getContentType() const noexcept\n {\n return contentType_;\n }\n\n void setContentType(drogon::ContentType contentType) noexcept\n {\n contentType_ = contentType;\n }\n\n void setContentTransferEncoding(\n const std::string &contentTransferEncoding) noexcept\n {\n transferEncoding_ = contentTransferEncoding;\n 
}\n\n void setContentTransferEncoding(\n std::string &&contentTransferEncoding) noexcept\n {\n transferEncoding_ = std::move(contentTransferEncoding);\n }\n\n const std::string &getContentTransferEncoding() const noexcept\n {\n return transferEncoding_;\n }\n\n private:\n std::string fileName_;\n std::string itemName_;\n std::string transferEncoding_;\n std::string_view fileContent_;\n HttpRequestPtr requestPtr_;\n drogon::ContentType contentType_{drogon::CT_NONE};\n};\n} // namespace drogon\n\n// Path: lib/src/HttpFileImpl.cc\n/**\n *\n * @file HttpFileImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpFileImpl.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace drogon;\n\nint HttpFileImpl::save() const noexcept\n{\n return save(HttpAppFrameworkImpl::instance().getUploadPath());\n}\n\nint HttpFileImpl::save(const std::string &path) const noexcept\n{\n assert(!path.empty());\n if (fileName_.empty())\n return -1;\n std::filesystem::path fsUploadDir(utils::toNativePath(path));\n\n if (fsUploadDir.is_absolute())\n { // do nothing\n }\n else if ((!fsUploadDir.has_parent_path() ||\n (fsUploadDir.begin()->string() != \".\" &&\n fsUploadDir.begin()->string() != \"..\")))\n {\n fsUploadDir = utils::toNativePath(\n HttpAppFrameworkImpl::instance().getUploadPath()) /\n fsUploadDir;\n }\n else\n {\n fsUploadDir = std::filesystem::current_path() / fsUploadDir;\n }\n\n fsUploadDir = std::filesystem::weakly_canonical(fsUploadDir);\n\n if (!std::filesystem::exists(fsUploadDir))\n {\n LOG_TRACE << \"create path:\" << fsUploadDir;\n std::error_code err;\n std::filesystem::create_directories(fsUploadDir, err);\n if (err)\n {\n LOG_SYSERR;\n return -1;\n }\n }\n\n std::filesystem::path fsSaveToPath(std::filesystem::weakly_canonical(\n fsUploadDir / utils::toNativePath(fileName_)));\n LOG_TRACE << \"save to path:\" << fsSaveToPath;\n if (!std::equal(fsUploadDir.begin(),\n fsUploadDir.end(),\n fsSaveToPath.begin()))\n {\n LOG_ERROR\n << \"Attempt writing outside of upload directory detected. 
Path: \"\n << fileName_;\n return -1;\n }\n\n return saveTo(fsSaveToPath);\n}\n\nint HttpFileImpl::saveAs(const std::string &fileName) const noexcept\n{\n assert(!fileName.empty());\n std::filesystem::path fsFileName(utils::toNativePath(fileName));\n if (!fsFileName.is_absolute() && (!fsFileName.has_parent_path() ||\n (fsFileName.begin()->string() != \".\" &&\n fsFileName.begin()->string() != \"..\")))\n {\n std::filesystem::path fsUploadPath(utils::toNativePath(\n HttpAppFrameworkImpl::instance().getUploadPath()));\n fsFileName = fsUploadPath / fsFileName;\n }\n if (fsFileName.has_parent_path() &&\n !std::filesystem::exists(fsFileName.parent_path()))\n {\n LOG_TRACE << \"create path:\" << fsFileName.parent_path();\n std::error_code err;\n std::filesystem::create_directories(fsFileName.parent_path(), err);\n if (err)\n {\n LOG_SYSERR;\n return -1;\n }\n }\n return saveTo(fsFileName);\n}\n\nint HttpFileImpl::saveTo(\n const std::filesystem::path &pathAndFileName) const noexcept\n{\n LOG_TRACE << \"save uploaded file:\" << pathAndFileName;\n auto wPath = utils::toNativePath(pathAndFileName.native());\n std::ofstream file(wPath, std::ios::binary);\n if (file.is_open())\n {\n file.write(fileContent_.data(), fileContent_.size());\n file.close();\n return 0;\n }\n else\n {\n LOG_ERROR << \"save failed!\";\n return -1;\n }\n}\n\nstd::string HttpFileImpl::getMd5() const noexcept\n{\n return utils::getMd5(fileContent_.data(), fileContent_.size());\n}\n\nstd::string HttpFileImpl::getSha256() const noexcept\n{\n return utils::getSha256(fileContent_.data(), fileContent_.size());\n}\n\nstd::string HttpFileImpl::getSha3() const noexcept\n{\n return utils::getSha3(fileContent_.data(), fileContent_.size());\n}\n\nconst std::string &HttpFile::getFileName() const noexcept\n{\n return implPtr_->getFileName();\n}\n\nvoid HttpFile::setFileName(const std::string &fileName) noexcept\n{\n implPtr_->setFileName(fileName);\n}\n\nstd::string_view HttpFile::getFileExtension() const noexcept\n{\n return implPtr_->getFileExtension();\n}\n\nFileType HttpFile::getFileType() const noexcept\n{\n return implPtr_->getFileType();\n}\n\nvoid HttpFile::setFile(const char *data, size_t length) noexcept\n{\n implPtr_->setFile(data, length);\n}\n\nint HttpFile::save() const noexcept\n{\n return implPtr_->save();\n}\n\nint HttpFile::save(const std::string &path) const noexcept\n{\n return implPtr_->save(path);\n}\n\nint HttpFile::saveAs(const std::string &fileName) const noexcept\n{\n return implPtr_->saveAs(fileName);\n}\n\nsize_t HttpFile::fileLength() const noexcept\n{\n return implPtr_->fileLength();\n}\n\ndrogon::ContentType HttpFile::getContentType() const noexcept\n{\n return implPtr_->getContentType();\n}\n\nconst char *HttpFile::fileData() const noexcept\n{\n return implPtr_->fileData();\n}\n\nstd::string HttpFile::getMd5() const noexcept\n{\n return implPtr_->getMd5();\n}\n\nconst std::string &HttpFile::getContentTransferEncoding() const noexcept\n{\n return implPtr_->getContentTransferEncoding();\n}\n\nHttpFile::HttpFile(std::shared_ptr &&implPtr) noexcept\n : implPtr_(std::move(implPtr))\n{\n}\n\nconst std::string &HttpFile::getItemName() const noexcept\n{\n return implPtr_->getItemName();\n}\n\n// Path: lib/src/ListenerManager.cc\n/**\n *\n * @file ListenerManager.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"ListenerManager.h\"\n#include \n#include \n#include \n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpServer.h\"\n#ifndef _WIN32\n#include \n#include \n#endif\n\nnamespace drogon\n{\n#ifndef _WIN32\nclass DrogonFileLocker : public trantor::NonCopyable\n{\n public:\n DrogonFileLocker()\n {\n fd_ = open(\"/tmp/drogon.lock\", O_TRUNC | O_CREAT, 0666);\n flock(fd_, LOCK_EX);\n }\n\n ~DrogonFileLocker()\n {\n close(fd_);\n }\n\n private:\n int fd_{0};\n};\n\n#endif\n} // namespace drogon\n\nusing namespace trantor;\nusing namespace drogon;\n\nvoid ListenerManager::addListener(\n const std::string &ip,\n uint16_t port,\n bool useSSL,\n const std::string &certFile,\n const std::string &keyFile,\n bool useOldTLS,\n const std::vector> &sslConfCmds)\n{\n if (useSSL && !utils::supportsTls())\n LOG_ERROR << \"Can't use SSL without OpenSSL found in your system\";\n listeners_.emplace_back(\n ip, port, useSSL, certFile, keyFile, useOldTLS, sslConfCmds);\n}\n\nstd::vector ListenerManager::getListeners() const\n{\n std::vector listeners;\n for (auto &server : servers_)\n {\n listeners.emplace_back(server->address());\n }\n return listeners;\n}\n\nvoid ListenerManager::createListeners(\n const std::string &globalCertFile,\n const std::string &globalKeyFile,\n const std::vector> &sslConfCmds,\n const std::vector &ioLoops)\n{\n LOG_TRACE << \"thread num=\" << ioLoops.size();\n#ifdef __linux__\n for (size_t i = 0; i < ioLoops.size(); ++i)\n {\n for (auto const &listener : listeners_)\n {\n auto const &ip = listener.ip_;\n bool isIpv6 = (ip.find(':') != std::string::npos);\n InetAddress listenAddress(ip, listener.port_, isIpv6);\n if (listenAddress.isUnspecified())\n {\n LOG_FATAL << \"Failed to parse IP address '\" << ip\n << \"'. (Note: FQDN/domain names/hostnames are not \"\n \"supported. 
Including 'localhost')\";\n abort();\n }\n if (i == 0 && !app().reusePort())\n {\n DrogonFileLocker lock;\n // Check whether the port is in use.\n TcpServer server(HttpAppFrameworkImpl::instance().getLoop(),\n listenAddress,\n \"drogonPortTest\",\n true,\n false);\n }\n std::shared_ptr serverPtr =\n std::make_shared(ioLoops[i],\n listenAddress,\n \"drogon\");\n\n if (listener.useSSL_ && utils::supportsTls())\n {\n auto cert = listener.certFile_;\n auto key = listener.keyFile_;\n if (cert.empty())\n cert = globalCertFile;\n if (key.empty())\n key = globalKeyFile;\n if (cert.empty() || key.empty())\n {\n std::cerr\n << \"You can't use https without cert file or key file\"\n << std::endl;\n exit(1);\n }\n auto cmds = sslConfCmds;\n std::copy(listener.sslConfCmds_.begin(),\n listener.sslConfCmds_.end(),\n std::back_inserter(cmds));\n auto policy =\n trantor::TLSPolicy::defaultServerPolicy(cert, key);\n policy->setConfCmds(cmds).setUseOldTLS(listener.useOldTLS_);\n serverPtr->enableSSL(std::move(policy));\n }\n servers_.push_back(serverPtr);\n }\n }\n#else\n\n if (!listeners_.empty())\n {\n listeningThread_ =\n std::make_unique(\"DrogonListeningLoop\");\n listeningThread_->run();\n for (auto const &listener : listeners_)\n {\n auto ip = listener.ip_;\n bool isIpv6 = (ip.find(':') != std::string::npos);\n auto serverPtr = std::make_shared(\n listeningThread_->getLoop(),\n InetAddress(ip, listener.port_, isIpv6),\n \"drogon\");\n if (listener.useSSL_ && utils::supportsTls())\n {\n auto cert = listener.certFile_;\n auto key = listener.keyFile_;\n if (cert.empty())\n cert = globalCertFile;\n if (key.empty())\n key = globalKeyFile;\n if (cert.empty() || key.empty())\n {\n std::cerr\n << \"You can't use https without cert file or key file\"\n << std::endl;\n exit(1);\n }\n auto cmds = sslConfCmds;\n auto policy =\n trantor::TLSPolicy::defaultServerPolicy(cert, key);\n policy->setConfCmds(cmds).setUseOldTLS(listener.useOldTLS_);\n serverPtr->enableSSL(std::move(policy));\n }\n serverPtr->setIoLoops(ioLoops);\n servers_.push_back(serverPtr);\n }\n }\n#endif\n}\n\nvoid ListenerManager::startListening()\n{\n for (auto &server : servers_)\n {\n server->start();\n }\n}\n\nvoid ListenerManager::stopListening()\n{\n for (auto &serverPtr : servers_)\n {\n serverPtr->stop();\n }\n if (listeningThread_)\n {\n auto loop = listeningThread_->getLoop();\n assert(!loop->isInLoopThread());\n loop->quit();\n listeningThread_->wait();\n }\n}\n\n// Path: lib/src/RealIpResolver.cc\n/**\n *\n * @file RealIpResolver.cc\n * @author Nitromelon\n *\n * Copyright 2022, Nitromelon. All rights reserved.\n * https://github.com/drogonframework/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n\nusing namespace drogon;\nusing namespace drogon::plugin;\n\nstruct XForwardedForParser : public trantor::NonCopyable\n{\n explicit XForwardedForParser(std::string value)\n : value_(std::move(value)), start_(value_.c_str()), len_(value_.size())\n {\n }\n\n std::string getNext()\n {\n if (len_ == 0)\n {\n return {};\n }\n // Skip trailing separators\n const char *cur;\n for (cur = start_ + len_ - 1; cur > start_; --cur, --len_)\n {\n if (*cur != ' ' && *cur != ',')\n {\n break;\n }\n }\n for (; cur > start_; --cur)\n {\n if (*cur == ' ' || *cur == ',')\n {\n ++cur;\n break;\n }\n }\n std::string ip{cur, len_ - (cur - start_)};\n len_ = cur == start_ ? 
0 : cur - start_ - 1;\n return ip;\n }\n\n private:\n std::string value_;\n const char *start_;\n size_t len_;\n};\n\nstatic trantor::InetAddress parseAddress(const std::string &addr)\n{\n auto pos = addr.find(':');\n uint16_t port = 0;\n if (pos == std::string::npos)\n {\n return trantor::InetAddress(addr, 0);\n }\n try\n {\n port = std::stoi(addr.substr(pos + 1));\n }\n catch (const std::exception &ex)\n {\n (void)ex;\n LOG_ERROR << \"Error in ipv4 address: \" + addr;\n port = 0;\n }\n return trantor::InetAddress(addr.substr(0, pos), port);\n}\n\nvoid RealIpResolver::initAndStart(const Json::Value &config)\n{\n fromHeader_ = config.get(\"from_header\", \"x-forwarded-for\").asString();\n attributeKey_ = config.get(\"attribute_key\", \"real-ip\").asString();\n\n std::transform(fromHeader_.begin(),\n fromHeader_.end(),\n fromHeader_.begin(),\n [](unsigned char c) { return tolower(c); });\n if (fromHeader_ == \"x-forwarded-for\")\n {\n useXForwardedFor_ = true;\n }\n\n const Json::Value &trustIps = config[\"trust_ips\"];\n if (!trustIps.isArray())\n {\n throw std::runtime_error(\"Invalid trusted_ips. Should be array.\");\n }\n for (const auto &elem : trustIps)\n {\n std::string ipOrCidr = elem.asString();\n trustCIDRs_.emplace_back(ipOrCidr);\n }\n\n drogon::app().registerPreRoutingAdvice([this](const HttpRequestPtr &req) {\n const std::string &ipHeader = req->getHeader(fromHeader_);\n const trantor::InetAddress &peerAddr = req->getPeerAddr();\n if (ipHeader.empty() || !matchCidr(peerAddr))\n {\n // Target header is empty, or\n // direct peer is already a non-proxy\n req->attributes()->insert(attributeKey_, peerAddr);\n return;\n }\n // Use a header field which contains a single ip\n if (!useXForwardedFor_)\n {\n trantor::InetAddress addr = parseAddress(ipHeader);\n if (addr.isUnspecified())\n {\n req->attributes()->insert(attributeKey_, peerAddr);\n }\n else\n {\n req->attributes()->insert(attributeKey_, addr);\n }\n return;\n }\n // Use x-forwarded-for header, which may contains multiple ip address,\n // separated by comma\n XForwardedForParser parser(ipHeader);\n std::string ip;\n while (!(ip = parser.getNext()).empty())\n {\n trantor::InetAddress addr = parseAddress(ip);\n if (addr.isUnspecified() || matchCidr(addr))\n {\n continue;\n }\n req->attributes()->insert(attributeKey_, addr);\n return;\n }\n // No match, use peerAddr\n req->attributes()->insert(attributeKey_, peerAddr);\n });\n}\n\nvoid RealIpResolver::shutdown()\n{\n}\n\nconst trantor::InetAddress &RealIpResolver::GetRealAddr(\n const HttpRequestPtr &req)\n{\n auto *plugin = app().getPlugin();\n if (!plugin)\n {\n return req->getPeerAddr();\n }\n return plugin->getRealAddr(req);\n}\n\nconst trantor::InetAddress &RealIpResolver::getRealAddr(\n const HttpRequestPtr &req) const\n{\n const std::shared_ptr &attributesPtr = req->getAttributes();\n if (!attributesPtr->find(attributeKey_))\n {\n return req->getPeerAddr();\n }\n return attributesPtr->get(attributeKey_);\n}\n\nbool RealIpResolver::matchCidr(const trantor::InetAddress &addr) const\n{\n for (auto &cidr : trustCIDRs_)\n {\n if ((addr.ipNetEndian() & cidr.mask_) == cidr.addr_)\n {\n return true;\n }\n }\n return false;\n}\n\nRealIpResolver::CIDR::CIDR(const std::string &ipOrCidr)\n{\n // Find CIDR slash\n auto pos = ipOrCidr.find('/');\n std::string ipv4;\n if (pos != std::string::npos)\n {\n // parameter is a CIDR block\n std::string prefixLen = ipOrCidr.substr(pos + 1);\n ipv4 = ipOrCidr.substr(0, pos);\n uint16_t prefix = std::stoi(prefixLen);\n if (prefix > 32)\n {\n 
throw std::runtime_error(\"Bad CIDR block: \" + ipOrCidr);\n }\n mask_ = htonl(0xffffffffu << (32 - prefix));\n }\n else\n {\n // parameter is an IP\n ipv4 = ipOrCidr;\n mask_ = 0xffffffffu;\n }\n\n trantor::InetAddress addr(ipv4, 0);\n if (addr.isIpV6())\n {\n throw std::runtime_error(\"Ipv6 is not supported by RealIpResolver.\");\n }\n if (addr.isUnspecified())\n {\n throw std::runtime_error(\"Bad ipv4 address: \" + ipv4);\n }\n addr_ = addr.ipNetEndian() & mask_;\n}\n\n// Path: lib/src/FiltersFunction.cc\n/**\n *\n * @file FiltersFunction.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"FiltersFunction.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpResponseImpl.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \n\n#include \n\nnamespace drogon\n{\nnamespace filters_function\n{\nstatic void doFilterChains(\n const std::vector> &filters,\n size_t index,\n const HttpRequestImplPtr &req,\n std::shared_ptr>\n &&callbackPtr)\n{\n if (index < filters.size())\n {\n auto &filter = filters[index];\n filter->doFilter(\n req,\n [/*copy*/ callbackPtr](const HttpResponsePtr &resp) {\n (*callbackPtr)(resp);\n },\n [index, req, callbackPtr, &filters]() mutable {\n auto ioLoop = req->getLoop();\n if (ioLoop && !ioLoop->isInLoopThread())\n {\n ioLoop->queueInLoop(\n [&filters,\n index,\n req,\n callbackPtr = std::move(callbackPtr)]() mutable {\n doFilterChains(filters,\n index + 1,\n req,\n std::move(callbackPtr));\n });\n }\n else\n {\n doFilterChains(filters,\n index + 1,\n req,\n std::move(callbackPtr));\n }\n });\n }\n else\n {\n (*callbackPtr)(nullptr);\n }\n}\n\nstd::vector> createFilters(\n const std::vector &filterNames)\n{\n std::vector> filters;\n for (auto const &filter : filterNames)\n {\n auto object_ = DrClassMap::getSingleInstance(filter);\n auto filter_ = std::dynamic_pointer_cast(object_);\n if (filter_)\n filters.push_back(filter_);\n else\n {\n LOG_ERROR << \"filter \" << filter << \" not found\";\n }\n }\n return filters;\n}\n\nvoid doFilters(const std::vector> &filters,\n const HttpRequestImplPtr &req,\n std::function &&callback)\n{\n auto callbackPtr =\n std::make_shared>(std::move(callback));\n doFilterChains(filters, 0, req, std::move(callbackPtr));\n}\n\n} // namespace filters_function\n} // namespace drogon\n\n// Path: lib/src/RangeParser.cc\n/**\n *\n * RangeParser.h\n * He, Wanchen\n *\n * Copyright 2021, He,Wanchen. 
All rights reserved.\n * https://github.com/drogonframework/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"RangeParser.h\"\n\n#include \n\nusing namespace drogon;\n\nstatic constexpr size_t MAX_SIZE = std::numeric_limits::max();\nstatic constexpr size_t MAX_TEN = MAX_SIZE / 10;\nstatic constexpr size_t MAX_DIGIT = MAX_SIZE % 10;\n\n// clang-format off\n#define DR_SKIP_WHITESPACE(p) while (*p == ' ') { ++(p); }\n#define DR_ISDIGIT(p) ('0' <= *(p) && *(p) <= '9')\n#define DR_WOULD_OVERFLOW(base, digit) \\\n (static_cast(base) > MAX_TEN || \\\n (static_cast(base) >= MAX_TEN && \\\n static_cast(digit) - '0' > MAX_DIGIT))\n\n// clang-format on\n\n/** Following formats are valid range header according to rfc7233`\n * Range: =-\n * Range: =-\n * Range: =-, -\n * Range: =-, -, -\n * Range: =-\n */\n\nFileRangeParseResult drogon::parseRangeHeader(const std::string &rangeStr,\n size_t contentLength,\n std::vector &ranges)\n{\n if (rangeStr.size() < 7 || rangeStr.compare(0, 6, \"bytes=\") != 0)\n {\n return InvalidRange;\n }\n const char *iter = rangeStr.c_str() + 6;\n\n size_t totalSize = 0;\n while (true)\n {\n size_t start = 0;\n size_t end = 0;\n // If this is a suffix range: =-\n bool isSuffix = false;\n\n DR_SKIP_WHITESPACE(iter);\n\n if (*iter == '-')\n {\n isSuffix = true;\n ++iter;\n }\n // Parse start\n else\n {\n if (!DR_ISDIGIT(iter))\n {\n return InvalidRange;\n }\n while (DR_ISDIGIT(iter))\n {\n // integer out of range\n if (DR_WOULD_OVERFLOW(start, *iter))\n {\n return NotSatisfiable;\n }\n start = start * 10 + (*iter++ - '0');\n }\n DR_SKIP_WHITESPACE(iter);\n // should be separator now\n if (*iter++ != '-')\n {\n return InvalidRange;\n }\n DR_SKIP_WHITESPACE(iter);\n // If this is a prefix range =-\n if (*iter == ',' || *iter == '\\0')\n {\n end = contentLength;\n // Handle found\n if (start < end)\n {\n if (totalSize > MAX_SIZE - (end - start))\n {\n return NotSatisfiable;\n }\n totalSize += end - start;\n ranges.push_back({start, end});\n }\n if (*iter++ != ',')\n {\n break;\n }\n continue;\n }\n }\n\n // Parse end\n if (!DR_ISDIGIT(iter))\n {\n return InvalidRange;\n }\n while (DR_ISDIGIT(iter))\n {\n if (DR_WOULD_OVERFLOW(end, *iter))\n {\n return NotSatisfiable;\n }\n end = end * 10 + (*iter++ - '0');\n }\n DR_SKIP_WHITESPACE(iter);\n\n if (*iter != ',' && *iter != '\\0')\n {\n return InvalidRange;\n }\n if (isSuffix)\n {\n start = (end < contentLength) ? 
contentLength - end : 0;\n end = contentLength - 1;\n }\n // [start, end)\n if (end >= contentLength)\n {\n end = contentLength;\n }\n else\n {\n ++end;\n }\n\n // handle found\n if (start < end)\n {\n ranges.push_back({start, end});\n if (totalSize > MAX_SIZE - (end - start))\n {\n return NotSatisfiable;\n }\n totalSize += end - start;\n // We restrict the number to be under 100, to avoid malicious\n // requests.\n // Though rfc does not say anything about max number of ranges,\n // it does mention that server can ignore range header freely.\n if (ranges.size() > 100)\n {\n return InvalidRange;\n }\n }\n if (*iter++ != ',')\n {\n break;\n }\n }\n\n if (ranges.size() == 0 || totalSize > contentLength)\n {\n return NotSatisfiable;\n }\n...\n", "needle_token_start": 4019, "needle_token_end": 4127, "code_context_ntokens": 16101, "output": ["```cpp\nvoid registerCustomExtensionMime(const std::string &ext,\n const std::string &mime)\n{\n if (ext.empty())\n return;\n auto &mimeStr = customMime[ext];\n if (!mimeStr.empty())\n {\n LOG_WARN << ext << \" has already been registered as type \" << mime\n << \". Overwriting.\";\n }\n mimeStr = mime;\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "genHttpResponse", "language": "cpp", "path": "lib/src/HttpResponseImpl.cc", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To generate an HTTP response based on a specified template and data, or return a not-found response if the template does not exist.\n2. **Input**: A string representing the name of the template, a data structure containing the data to be populated in the template, and a pointer to the HTTP request.\n3. **Output**: A pointer to an HTTP response object, either containing the rendered template or a not-found status.\n4. **Procedure**: \n - The function attempts to create a new template object using the provided template name.\n - If the template object is successfully created, it generates the text for the HTTP response body using the provided data.\n - A new HTTP response object is then created, and the generated text is set as the body of this response.\n - If the template object cannot be created (i.e., the template does not exist), the function returns a not-found HTTP response, indicating that the requested template could not be found.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/HttpClientImpl.cc\n/**\n *\n * @file HttpClientImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpClientImpl.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpResponseImpl.h\"\n#include \"HttpResponseParser.h\"\n\n#include \n#include \n#include \n\nusing namespace trantor;\nusing namespace drogon;\nusing namespace std::placeholders;\n\nnamespace trantor\n{\nstatic const size_t kDefaultDNSTimeout{600};\n}\n\nvoid HttpClientImpl::createTcpClient()\n{\n LOG_TRACE << \"New TcpClient,\" << serverAddr_.toIpPort();\n tcpClientPtr_ =\n std::make_shared(loop_, serverAddr_, \"httpClient\");\n\n if (useSSL_ && utils::supportsTls())\n {\n LOG_TRACE << \"useOldTLS=\" << useOldTLS_;\n LOG_TRACE << \"domain=\" << domain_;\n auto policy = trantor::TLSPolicy::defaultClientPolicy();\n policy->setUseOldTLS(useOldTLS_)\n .setValidate(validateCert_)\n .setHostname(domain_)\n .setConfCmds(sslConfCmds_)\n .setCertPath(clientCertPath_)\n .setKeyPath(clientKeyPath_);\n tcpClientPtr_->enableSSL(std::move(policy));\n }\n\n auto thisPtr = shared_from_this();\n std::weak_ptr weakPtr = thisPtr;\n tcpClientPtr_->setSockOptCallback([weakPtr](int fd) {\n auto thisPtr = weakPtr.lock();\n if (!thisPtr)\n return;\n if (thisPtr->sockOptCallback_)\n thisPtr->sockOptCallback_(fd);\n });\n tcpClientPtr_->setConnectionCallback(\n [weakPtr](const trantor::TcpConnectionPtr &connPtr) {\n auto thisPtr = weakPtr.lock();\n if (!thisPtr)\n return;\n if (connPtr->connected())\n {\n connPtr->setContext(\n std::make_shared(connPtr));\n // send request;\n LOG_TRACE << \"Connection established!\";\n while (thisPtr->pipeliningCallbacks_.size() <=\n thisPtr->pipeliningDepth_ &&\n !thisPtr->requestsBuffer_.empty())\n {\n thisPtr->sendReq(connPtr,\n thisPtr->requestsBuffer_.front().first);\n thisPtr->pipeliningCallbacks_.push(\n std::move(thisPtr->requestsBuffer_.front()));\n thisPtr->requestsBuffer_.pop_front();\n }\n }\n else\n {\n LOG_TRACE << \"connection disconnect\";\n auto responseParser = connPtr->getContext();\n if (responseParser && responseParser->parseResponseOnClose() &&\n responseParser->gotAll())\n {\n auto &firstReq = thisPtr->pipeliningCallbacks_.front();\n if (firstReq.first->method() == Head)\n {\n responseParser->setForHeadMethod();\n }\n auto resp = responseParser->responseImpl();\n responseParser->reset();\n // temporary fix of dead tcpClientPtr_\n // TODO: fix HttpResponseParser when content-length absence\n thisPtr->tcpClientPtr_.reset();\n thisPtr->handleResponse(resp, std::move(firstReq), connPtr);\n if (!thisPtr->requestsBuffer_.empty())\n {\n thisPtr->createTcpClient();\n }\n return;\n }\n thisPtr->onError(ReqResult::NetworkFailure);\n }\n });\n tcpClientPtr_->setConnectionErrorCallback([weakPtr]() {\n auto thisPtr = weakPtr.lock();\n if (!thisPtr)\n return;\n // can't connect to server\n thisPtr->onError(ReqResult::BadServerAddress);\n });\n tcpClientPtr_->setMessageCallback(\n [weakPtr](const trantor::TcpConnectionPtr &connPtr,\n trantor::MsgBuffer *msg) {\n auto thisPtr = weakPtr.lock();\n if (thisPtr)\n {\n thisPtr->onRecvMessage(connPtr, msg);\n }\n });\n tcpClientPtr_->setSSLErrorCallback([weakPtr](SSLError err) {\n auto thisPtr = weakPtr.lock();\n if (!thisPtr)\n return;\n if (err == trantor::SSLError::kSSLHandshakeError)\n thisPtr->onError(ReqResult::HandshakeError);\n else if (err == trantor::SSLError::kSSLInvalidCertificate)\n 
thisPtr->onError(ReqResult::InvalidCertificate);\n else if (err == trantor::SSLError::kSSLProtocolError)\n thisPtr->onError(ReqResult::EncryptionFailure);\n else\n {\n LOG_FATAL << \"Invalid value for SSLError\";\n abort();\n }\n });\n tcpClientPtr_->connect();\n}\n\nHttpClientImpl::HttpClientImpl(trantor::EventLoop *loop,\n const trantor::InetAddress &addr,\n bool useSSL,\n bool useOldTLS,\n bool validateCert)\n : loop_(loop),\n serverAddr_(addr),\n useSSL_(useSSL),\n validateCert_(validateCert),\n useOldTLS_(useOldTLS)\n{\n}\n\nHttpClientImpl::HttpClientImpl(trantor::EventLoop *loop,\n const std::string &hostString,\n bool useOldTLS,\n bool validateCert)\n : loop_(loop), validateCert_(validateCert), useOldTLS_(useOldTLS)\n{\n auto lowerHost = hostString;\n std::transform(lowerHost.begin(),\n lowerHost.end(),\n lowerHost.begin(),\n [](unsigned char c) { return tolower(c); });\n if (lowerHost.find(\"https://\") == 0)\n {\n useSSL_ = true;\n lowerHost = lowerHost.substr(8);\n }\n else if (lowerHost.find(\"http://\") == 0)\n {\n useSSL_ = false;\n lowerHost = lowerHost.substr(7);\n }\n else\n {\n return;\n }\n auto pos = lowerHost.find(']');\n if (lowerHost[0] == '[' && pos != std::string::npos)\n {\n // ipv6\n domain_ = lowerHost.substr(1, pos - 1);\n if (lowerHost[pos + 1] == ':')\n {\n auto portStr = lowerHost.substr(pos + 2);\n pos = portStr.find('/');\n if (pos != std::string::npos)\n {\n portStr = portStr.substr(0, pos);\n }\n auto port = atoi(portStr.c_str());\n if (port > 0 && port < 65536)\n {\n serverAddr_ = InetAddress(domain_, port, true);\n }\n }\n else\n {\n if (useSSL_)\n {\n serverAddr_ = InetAddress(domain_, 443, true);\n }\n else\n {\n serverAddr_ = InetAddress(domain_, 80, true);\n }\n }\n }\n else\n {\n auto pos = lowerHost.find(':');\n if (pos != std::string::npos)\n {\n domain_ = lowerHost.substr(0, pos);\n auto portStr = lowerHost.substr(pos + 1);\n pos = portStr.find('/');\n if (pos != std::string::npos)\n {\n portStr = portStr.substr(0, pos);\n }\n auto port = atoi(portStr.c_str());\n if (port > 0 && port < 65536)\n {\n serverAddr_ = InetAddress(domain_, port);\n }\n }\n else\n {\n domain_ = lowerHost;\n pos = domain_.find('/');\n if (pos != std::string::npos)\n {\n domain_ = domain_.substr(0, pos);\n }\n if (useSSL_)\n {\n serverAddr_ = InetAddress(domain_, 443);\n }\n else\n {\n serverAddr_ = InetAddress(domain_, 80);\n }\n }\n }\n if (serverAddr_.isUnspecified())\n {\n isDomainName_ = true;\n }\n LOG_TRACE << \"userSSL=\" << useSSL_ << \" domain=\" << domain_;\n}\n\nHttpClientImpl::~HttpClientImpl()\n{\n LOG_TRACE << \"Deconstruction HttpClient\";\n if (resolverPtr_ && !(loop_->isInLoopThread()))\n {\n // Make sure the resolverPtr_ is destroyed in the correct thread.\n loop_->queueInLoop([resolverPtr = std::move(resolverPtr_)]() {});\n }\n}\n\nvoid HttpClientImpl::sendRequest(const drogon::HttpRequestPtr &req,\n const drogon::HttpReqCallback &callback,\n double timeout)\n{\n auto thisPtr = shared_from_this();\n loop_->runInLoop([thisPtr, req, callback = callback, timeout]() mutable {\n thisPtr->sendRequestInLoop(req, std::move(callback), timeout);\n });\n}\n\nvoid HttpClientImpl::sendRequest(const drogon::HttpRequestPtr &req,\n drogon::HttpReqCallback &&callback,\n double timeout)\n{\n auto thisPtr = shared_from_this();\n loop_->runInLoop(\n [thisPtr, req, callback = std::move(callback), timeout]() mutable {\n thisPtr->sendRequestInLoop(req, std::move(callback), timeout);\n });\n}\n\nstruct RequestCallbackParams\n{\n RequestCallbackParams(HttpReqCallback &&cb,\n 
HttpClientImplPtr client,\n HttpRequestPtr req)\n : callback(std::move(cb)),\n clientPtr(std::move(client)),\n requestPtr(std::move(req))\n {\n }\n\n const drogon::HttpReqCallback callback;\n const HttpClientImplPtr clientPtr;\n const HttpRequestPtr requestPtr;\n bool timeoutFlag{false};\n};\n\nvoid HttpClientImpl::sendRequestInLoop(const HttpRequestPtr &req,\n HttpReqCallback &&callback,\n double timeout)\n{\n if (timeout <= 0)\n {\n sendRequestInLoop(req, std::move(callback));\n return;\n }\n\n auto callbackParamsPtr =\n std::make_shared(std::move(callback),\n shared_from_this(),\n req);\n\n loop_->runAfter(\n timeout,\n [weakCallbackBackPtr =\n std::weak_ptr(callbackParamsPtr)] {\n auto callbackParamsPtr = weakCallbackBackPtr.lock();\n if (callbackParamsPtr != nullptr)\n {\n auto &thisPtr = callbackParamsPtr->clientPtr;\n if (callbackParamsPtr->timeoutFlag)\n {\n return;\n }\n\n callbackParamsPtr->timeoutFlag = true;\n\n for (auto iter = thisPtr->requestsBuffer_.begin();\n iter != thisPtr->requestsBuffer_.end();\n ++iter)\n {\n if (iter->first == callbackParamsPtr->requestPtr)\n {\n thisPtr->requestsBuffer_.erase(iter);\n break;\n }\n }\n\n (callbackParamsPtr->callback)(ReqResult::Timeout, nullptr);\n }\n });\n sendRequestInLoop(req,\n [callbackParamsPtr](ReqResult r,\n const HttpResponsePtr &resp) {\n if (callbackParamsPtr->timeoutFlag)\n {\n return;\n }\n callbackParamsPtr->timeoutFlag = true;\n (callbackParamsPtr->callback)(r, resp);\n });\n}\n\nstatic bool isValidIpAddr(const trantor::InetAddress &addr)\n{\n if (addr.portNetEndian() == 0)\n {\n return false;\n }\n if (!addr.isIpV6())\n {\n return addr.ipNetEndian() != 0;\n }\n // Is ipv6\n auto ipaddr = addr.ip6NetEndian();\n for (int i = 0; i < 4; ++i)\n {\n if (ipaddr[i] != 0)\n {\n return true;\n }\n }\n return false;\n}\n\nvoid HttpClientImpl::sendRequestInLoop(const drogon::HttpRequestPtr &req,\n drogon::HttpReqCallback &&callback)\n{\n loop_->assertInLoopThread();\n if (!static_cast(req.get())->passThrough())\n {\n req->addHeader(\"connection\", \"Keep-Alive\");\n if (!userAgent_.empty())\n req->addHeader(\"user-agent\", userAgent_);\n }\n // Set the host header if not already set\n if (req->getHeader(\"host\").empty())\n {\n if (onDefaultPort())\n {\n req->addHeader(\"host\", host());\n }\n else\n {\n...\n// Path: lib/src/AOPAdvice.h\n/**\n *\n * AOPAdvice.h\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n#include \"impl_forwards.h\"\n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass AopAdvice\n{\n public:\n static AopAdvice &instance()\n {\n static AopAdvice inst;\n return inst;\n }\n\n // Getters?\n bool hasPreRoutingAdvices() const\n {\n return !preRoutingAdvices_.empty();\n }\n\n bool hasPostRoutingAdvices() const\n {\n return !postRoutingAdvices_.empty();\n }\n\n bool hasPreHandlingAdvices() const\n {\n return !preHandlingAdvices_.empty();\n }\n\n // Setters?\n void registerNewConnectionAdvice(\n std::function advice)\n {\n newConnectionAdvices_.emplace_back(std::move(advice));\n }\n\n void registerHttpResponseCreationAdvice(\n std::function advice)\n {\n responseCreationAdvices_.emplace_back(std::move(advice));\n }\n\n void registerSyncAdvice(\n std::function advice)\n\n {\n syncAdvices_.emplace_back(std::move(advice));\n }\n\n void registerPreRoutingObserver(\n std::function advice)\n {\n preRoutingObservers_.emplace_back(std::move(advice));\n }\n\n void registerPreRoutingAdvice(\n std::function advice)\n {\n preRoutingAdvices_.emplace_back(std::move(advice));\n }\n\n void registerPostRoutingObserver(\n std::function advice)\n {\n postRoutingObservers_.emplace_back(std::move(advice));\n }\n\n void registerPostRoutingAdvice(\n std::function advice)\n {\n postRoutingAdvices_.emplace_back(std::move(advice));\n }\n\n void registerPreHandlingObserver(\n std::function advice)\n {\n preHandlingObservers_.emplace_back(std::move(advice));\n }\n\n void registerPreHandlingAdvice(\n std::function advice)\n {\n preHandlingAdvices_.emplace_back(std::move(advice));\n }\n\n void registerPostHandlingAdvice(\n std::function\n advice)\n {\n postHandlingAdvices_.emplace_back(std::move(advice));\n }\n\n void registerPreSendingAdvice(\n std::function\n advice)\n {\n preSendingAdvices_.emplace_back(std::move(advice));\n }\n\n // Executors\n bool passNewConnectionAdvices(const trantor::TcpConnectionPtr &conn) const;\n void passResponseCreationAdvices(const HttpResponsePtr &resp) const;\n\n HttpResponsePtr passSyncAdvices(const HttpRequestPtr &req) const;\n void passPreRoutingObservers(const HttpRequestImplPtr &req) const;\n void passPreRoutingAdvices(\n const HttpRequestImplPtr &req,\n std::function &&callback) const;\n void passPostRoutingObservers(const HttpRequestImplPtr &req) const;\n void passPostRoutingAdvices(\n const HttpRequestImplPtr &req,\n std::function &&callback) const;\n void passPreHandlingObservers(const HttpRequestImplPtr &req) const;\n void passPreHandlingAdvices(\n const HttpRequestImplPtr &req,\n std::function &&callback) const;\n void passPostHandlingAdvices(const HttpRequestImplPtr &req,\n const HttpResponsePtr &resp) const;\n void passPreSendingAdvices(const HttpRequestImplPtr &req,\n const HttpResponsePtr &resp) const;\n\n private:\n using SyncAdvice = std::function;\n using SyncReqObserver = std::function;\n using SyncObserver =\n std::function;\n using AsyncAdvice = std::function;\n\n // If we want to add aop functions anytime, we can add a mutex here\n\n std::vector>\n newConnectionAdvices_;\n std::vector>\n responseCreationAdvices_;\n\n std::vector syncAdvices_;\n std::vector preRoutingObservers_;\n std::vector preRoutingAdvices_;\n std::vector postRoutingObservers_;\n std::vector postRoutingAdvices_;\n std::vector preHandlingObservers_;\n std::vector 
preHandlingAdvices_;\n std::vector postHandlingAdvices_;\n std::vector preSendingAdvices_;\n};\n\n} // namespace drogon\n\n// Path: lib/src/HttpResponseImpl.cc\n/**\n *\n * @file HttpResponseImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpResponseImpl.h\"\n#include \"AOPAdvice.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpUtils.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace trantor;\nusing namespace drogon;\nusing namespace std::literals::string_literals;\nusing namespace std::placeholders;\n#ifdef _WIN32\n#undef max\n#endif\n\nnamespace drogon\n{\n// \"Fri, 23 Aug 2019 12:58:03 GMT\" length = 29\nstatic const size_t httpFullDateStringLength = 29;\n\n\nstatic inline HttpResponsePtr genHttpResponse(const std::string &viewName,\n const HttpViewData &data,\n const HttpRequestPtr &req)\n{\n auto templ = DrTemplateBase::newTemplate(viewName);\n if (templ)\n {\n auto res = HttpResponse::newHttpResponse();\n res->setBody(templ->genText(data));\n return res;\n }\n return drogon::HttpResponse::newNotFoundResponse(req);\n}\n} // namespace drogon\n\nHttpResponsePtr HttpResponse::newHttpResponse()\n{\n auto res = std::make_shared(k200OK, CT_TEXT_HTML);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpResponse(HttpStatusCode code,\n ContentType type)\n{\n auto res = std::make_shared(code, type);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpJsonResponse(const Json::Value &data)\n{\n auto res = std::make_shared(k200OK, CT_APPLICATION_JSON);\n res->setJsonObject(data);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpJsonResponse(Json::Value &&data)\n{\n auto res = std::make_shared(k200OK, CT_APPLICATION_JSON);\n res->setJsonObject(std::move(data));\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nconst char *HttpResponseImpl::versionString() const\n{\n const char *result = \"UNKNOWN\";\n switch (version_)\n {\n case Version::kHttp10:\n result = \"HTTP/1.0\";\n break;\n\n case Version::kHttp11:\n result = \"HTTP/1.1\";\n break;\n\n default:\n break;\n }\n return result;\n}\n\nvoid HttpResponseImpl::generateBodyFromJson() const\n{\n if (!jsonPtr_ || flagForSerializingJson_)\n {\n return;\n }\n flagForSerializingJson_ = true;\n static std::once_flag once;\n static Json::StreamWriterBuilder builder;\n std::call_once(once, []() {\n builder[\"commentStyle\"] = \"None\";\n builder[\"indentation\"] = \"\";\n if (!app().isUnicodeEscapingUsedInJson())\n {\n builder[\"emitUTF8\"] = true;\n }\n auto &precision = app().getFloatPrecisionInJson();\n if (precision.first != 0)\n {\n builder[\"precision\"] = precision.first;\n builder[\"precisionType\"] = precision.second;\n }\n });\n bodyPtr_ = std::make_shared(\n writeString(builder, *jsonPtr_));\n}\n\nHttpResponsePtr HttpResponse::newNotFoundResponse(const HttpRequestPtr &req)\n{\n auto loop = trantor::EventLoop::getEventLoopOfCurrentThread();\n auto &resp = HttpAppFrameworkImpl::instance().getCustom404Page();\n if (resp)\n {\n if (loop && loop->index() < app().getThreadNum())\n {\n return resp;\n }\n else\n {\n return HttpResponsePtr{new 
HttpResponseImpl(\n *static_cast(resp.get()))};\n }\n }\n else\n {\n if (HttpAppFrameworkImpl::instance().isUsingCustomErrorHandler())\n {\n return app().getCustomErrorHandler()(k404NotFound, req);\n }\n else if (loop && loop->index() < app().getThreadNum())\n {\n // If the current thread is an IO thread\n static std::once_flag threadOnce;\n static IOThreadStorage thread404Pages;\n std::call_once(threadOnce, [req = req] {\n thread404Pages.init([req = req](drogon::HttpResponsePtr &resp,\n size_t /*index*/) {\n HttpViewData data;\n data.insert(\"version\", drogon::getVersion());\n resp = HttpResponse::newHttpViewResponse(\"drogon::NotFound\",\n data);\n resp->setStatusCode(k404NotFound);\n resp->setExpiredTime(0);\n });\n });\n LOG_TRACE << \"Use cached 404 response\";\n return thread404Pages.getThreadData();\n }\n else\n {\n HttpViewData data;\n data.insert(\"version\", drogon::getVersion());\n auto notFoundResp =\n HttpResponse::newHttpViewResponse(\"drogon::NotFound\", data);\n notFoundResp->setStatusCode(k404NotFound);\n return notFoundResp;\n }\n }\n}\n\nHttpResponsePtr HttpResponse::newRedirectionResponse(\n const std::string &location,\n HttpStatusCode status)\n{\n auto res = std::make_shared();\n res->setStatusCode(status);\n res->redirect(location);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpViewResponse(const std::string &viewName,\n const HttpViewData &data,\n const HttpRequestPtr &req)\n{\n return genHttpResponse(viewName, data, req);\n}\n\nHttpResponsePtr HttpResponse::newFileResponse(\n const unsigned char *pBuffer,\n size_t bufferLength,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString)\n{\n // Make Raw HttpResponse\n auto resp = std::make_shared();\n\n // Set response body and length\n resp->setBody(\n std::string(reinterpret_cast(pBuffer), bufferLength));\n\n // Set status of message\n resp->setStatusCode(k200OK);\n\n // Check for type and assign proper content type in header\n if (!typeString.empty())\n {\n // auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_APPLICATION_OCTET_STREAM; // XXX: Is this Ok?\n static_cast(resp.get())\n ->setContentTypeCodeAndCustomString(type,\n typeString.c_str(),\n typeString.size());\n }\n else if (type != CT_NONE)\n {\n resp->setContentTypeCode(type);\n }\n else if (!attachmentFileName.empty())\n {\n resp->setContentTypeCode(drogon::getContentType(attachmentFileName));\n }\n else\n {\n resp->setContentTypeCode(\n CT_APPLICATION_OCTET_STREAM); // default content-type for file;\n }\n\n // Add additional header values\n if (!attachmentFileName.empty())\n {\n resp->addHeader(\"Content-Disposition\",\n \"attachment; filename=\" + attachmentFileName);\n }\n\n // Finalize and return response\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nHttpResponsePtr HttpResponse::newFileResponse(\n const std::string &fullPath,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString,\n const HttpRequestPtr &req)\n{\n return newFileResponse(\n fullPath, 0, 0, false, attachmentFileName, type, typeString, req);\n}\n\nHttpResponsePtr HttpResponse::newFileResponse(\n const std::string &fullPath,\n size_t offset,\n size_t length,\n bool setContentRange,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString,\n const HttpRequestPtr &req)\n{\n std::ifstream 
infile(utils::toNativePath(fullPath), std::ifstream::binary);\n LOG_TRACE << \"send http file:\" << fullPath << \" offset \" << offset\n << \" length \" << length;\n if (!infile)\n {\n auto resp = HttpResponse::newNotFoundResponse(req);\n return resp;\n }\n auto resp = std::make_shared();\n std::streambuf *pbuf = infile.rdbuf();\n size_t filesize =\n static_cast(pbuf->pubseekoff(0, std::ifstream::end));\n if (offset > filesize || length > filesize || // in case of overflow\n offset + length > filesize)\n {\n resp->setStatusCode(k416RequestedRangeNotSatisfiable);\n if (setContentRange)\n {\n char buf[64];\n snprintf(buf, sizeof(buf), \"bytes */%zu\", filesize);\n resp->addHeader(\"Content-Range\", std::string(buf));\n }\n return resp;\n }\n if (length == 0)\n {\n length = filesize - offset;\n }\n pbuf->pubseekoff(offset, std::ifstream::beg); // rewind\n\n if (HttpAppFrameworkImpl::instance().useSendfile() && length > 1024 * 200)\n // TODO : Is 200k an appropriate value? Or set it to be configurable\n {\n // The advantages of sendfile() can only be reflected in sending large\n // files.\n resp->setSendfile(fullPath);\n // Must set length with the right value! Content-Length header relies on\n // this value.\n resp->setSendfileRange(offset, length);\n }\n else\n {\n std::string str;\n str.resize(length);\n pbuf->sgetn(&str[0], length);\n resp->setBody(std::move(str));\n resp->setSendfileRange(offset, length);\n }\n\n // Set correct status code\n if (length < filesize)\n {\n resp->setStatusCode(k206PartialContent);\n }\n else\n {\n resp->setStatusCode(k200OK);\n }\n\n // Infer content type\n if (type == CT_NONE)\n {\n if (!typeString.empty())\n {\n auto r = static_cast(resp.get());\n // auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n else if (!attachmentFileName.empty())\n {\n resp->setContentTypeCode(\n drogon::getContentType(attachmentFileName));\n }\n else\n {\n resp->setContentTypeCode(drogon::getContentType(fullPath));\n }\n }\n else\n {\n if (typeString.empty())\n resp->setContentTypeCode(type);\n else\n {\n auto r = static_cast(resp.get());\n // auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n }\n\n // Set headers\n if (!attachmentFileName.empty())\n {\n resp->addHeader(\"Content-Disposition\",\n \"attachment; filename=\" + attachmentFileName);\n }\n if (setContentRange && length > 0)\n {\n char buf[128];\n snprintf(buf,\n sizeof(buf),\n \"bytes %zu-%zu/%zu\",\n offset,\n offset + length - 1,\n filesize);\n resp->addHeader(\"Content-Range\", std::string(buf));\n }\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nHttpResponsePtr HttpResponse::newStreamResponse(\n const std::function &callback,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString,\n const HttpRequestPtr &req)\n{\n LOG_TRACE << \"send stream as \"s\n << (attachmentFileName.empty() ? 
\"raw data\"s\n : \"file: \"s + attachmentFileName);\n if (!callback)\n {\n auto resp = HttpResponse::newNotFoundResponse();\n return resp;\n }\n auto resp = std::make_shared();\n resp->setStreamCallback(callback);\n resp->setStatusCode(k200OK);\n\n // Infer content type\n if (type == CT_NONE)\n {\n if (!typeString.empty())\n {\n auto r = static_cast(resp.get());\n auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n else if (!attachmentFileName.empty())\n {\n resp->setContentTypeCode(\n drogon::getContentType(attachmentFileName));\n }\n }\n else\n {\n if (typeString.empty())\n resp->setContentTypeCode(type);\n else\n {\n auto r = static_cast(resp.get());\n auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n }\n\n // Set headers\n if (!attachmentFileName.empty())\n {\n resp->addHeader(\"Content-Disposition\",\n \"attachment; filename=\" + attachmentFileName);\n }\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nHttpResponsePtr HttpResponse::newAsyncStreamResponse(\n const std::function &callback,\n bool disableKickoffTimeout)\n{\n if (!callback)\n {\n auto resp = HttpResponse::newNotFoundResponse();\n return resp;\n }\n auto resp = std::make_shared();\n resp->setAsyncStreamCallback(callback, disableKickoffTimeout);\n resp->setStatusCode(k200OK);\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nvoid HttpResponseImpl::makeHeaderString(trantor::MsgBuffer &buffer)\n{\n buffer.ensureWritableBytes(128);\n int len{0};\n if (version_ == Version::kHttp11)\n {\n if (customStatusCode_ >= 0)\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.1 %d \",\n customStatusCode_);\n }\n else\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.1 %d \",\n statusCode_);\n }\n }\n else\n {\n if (customStatusCode_ >= 0)\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.0 %d \",\n customStatusCode_);\n }\n else\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.0 %d \",\n statusCode_);\n }\n }\n buffer.hasWritten(len);\n\n if (!statusMessage_.empty())\n buffer.append(statusMessage_.data(), statusMessage_.length());\n buffer.append(\"\\r\\n\");\n generateBodyFromJson();\n if (!passThrough_)\n {\n buffer.ensureWritableBytes(64);\n if (streamCallback_ || asyncStreamCallback_)\n {\n // When the headers are created, it is time to set the transfer\n // encoding to chunked if the contents size is not specified\n if (!ifCloseConnection() &&\n headers_.find(\"content-length\") == headers_.end())\n {\n LOG_DEBUG << \"send stream with transfer-encoding chunked\";\n headers_[\"transfer-encoding\"] = \"chunked\";\n }\n len = 0;\n }\n else if (sendfileName_.empty())\n {\n auto bodyLength = bodyPtr_ ? 
bodyPtr_->length() : 0;\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n contentLengthFormatString(),\n bodyLength);\n }\n else\n {\n auto bodyLength = sendfileRange_.second;\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n contentLengthFormatString(),\n bodyLength);\n }\n buffer.hasWritten(len);\n if (headers_.find(\"connection\") == headers_.end())\n {\n if (closeConnection_)\n {\n buffer.append(\"connection: close\\r\\n\");\n }\n else if (version_ == Version::kHttp10)\n {\n buffer.append(\"connection: Keep-Alive\\r\\n\");\n }\n }\n\n if (!contentTypeString_.empty())\n {\n buffer.append(\"content-type: \");\n buffer.append(contentTypeString_);\n buffer.append(\"\\r\\n\");\n }\n if (HttpAppFrameworkImpl::instance().sendServerHeader())\n {\n buffer.append(\n HttpAppFrameworkImpl::instance().getServerHeaderString());\n }\n }\n\n for (auto it = headers_.begin(); it != headers_.end(); ++it)\n {\n buffer.append(it->first);\n buffer.append(\": \");\n buffer.append(it->second);\n buffer.append(\"\\r\\n\");\n }\n}\n\nvoid HttpResponseImpl::renderToBuffer(trantor::MsgBuffer &buffer)\n{\n if (expriedTime_ >= 0)\n {\n auto strPtr = renderToBuffer();\n buffer.append(strPtr->peek(), strPtr->readableBytes());\n return;\n }\n\n if (!fullHeaderString_)\n {\n makeHeaderString(buffer);\n }\n else\n {\n buffer.append(*fullHeaderString_);\n }\n\n // output cookies\n if (!cookies_.empty())\n {\n for (auto it = cookies_.begin(); it != cookies_.end(); ++it)\n {\n buffer.append(it->second.cookieString());\n }\n }\n\n // output Date header\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n buffer.append(\"date: \");\n buffer.append(utils::getHttpFullDate(trantor::Date::date()),\n httpFullDateStringLength);\n buffer.append(\"\\r\\n\\r\\n\");\n }\n else\n {\n buffer.append(\"\\r\\n\");\n }\n if (bodyPtr_)\n buffer.append(bodyPtr_->data(), bodyPtr_->length());\n}\n\nstd::shared_ptr HttpResponseImpl::renderToBuffer()\n{\n if (expriedTime_ >= 0)\n {\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n if (datePos_ != static_cast(-1))\n {\n auto now = trantor::Date::now();\n bool isDateChanged =\n ((now.microSecondsSinceEpoch() / MICRO_SECONDS_PRE_SEC) !=\n httpStringDate_);\n assert(httpString_);\n if (isDateChanged)\n {\n httpStringDate_ =\n now.microSecondsSinceEpoch() / MICRO_SECONDS_PRE_SEC;\n auto newDate = utils::getHttpFullDate(now);\n\n httpString_ =\n std::make_shared(*httpString_);\n memcpy((void *)&(*httpString_)[datePos_],\n newDate,\n httpFullDateStringLength);\n return httpString_;\n }\n\n return httpString_;\n }\n }\n else\n {\n if (httpString_)\n return httpString_;\n }\n }\n auto httpString = std::make_shared(256);\n if (!fullHeaderString_)\n {\n makeHeaderString(*httpString);\n }\n else\n {\n httpString->append(*fullHeaderString_);\n }\n\n // output cookies\n if (!cookies_.empty())\n {\n for (auto it = cookies_.begin(); it != cookies_.end(); ++it)\n {\n httpString->append(it->second.cookieString());\n }\n }\n\n // output Date header\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n httpString->append(\"date: \");\n auto datePos = httpString->readableBytes();\n httpString->append(utils::getHttpFullDate(trantor::Date::date()),\n httpFullDateStringLength);\n httpString->append(\"\\r\\n\\r\\n\");\n datePos_ = datePos;\n }\n else\n {\n httpString->append(\"\\r\\n\");\n }\n\n LOG_TRACE << \"response(no body):\"\n << std::string_view{httpString->peek(),\n 
httpString->readableBytes()};\n if (bodyPtr_)\n httpString->append(bodyPtr_->data(), bodyPtr_->length());\n if (expriedTime_ >= 0)\n {\n httpString_ = httpString;\n }\n return httpString;\n}\n\nstd::shared_ptr HttpResponseImpl::\n renderHeaderForHeadMethod()\n{\n auto httpString = std::make_shared(256);\n if (!fullHeaderString_)\n {\n makeHeaderString(*httpString);\n }\n else\n {\n httpString->append(*fullHeaderString_);\n }\n\n // output cookies\n if (!cookies_.empty())\n {\n for (auto it = cookies_.begin(); it != cookies_.end(); ++it)\n {\n httpString->append(it->second.cookieString());\n }\n }\n\n // output Date header\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n httpString->append(\"date: \");\n httpString->append(utils::getHttpFullDate(trantor::Date::date()),\n httpFullDateStringLength);\n httpString->append(\"\\r\\n\\r\\n\");\n }\n else\n {\n httpString->append(\"\\r\\n\");\n }\n\n return httpString;\n}\n\nvoid HttpResponseImpl::addHeader(const char *start,\n const char *colon,\n const char *end)\n{\n fullHeaderString_.reset();\n std::string field(start, colon);\n transform(field.begin(), field.end(), field.begin(), [](unsigned char c) {\n return tolower(c);\n });\n ++colon;\n while (colon < end && isspace(static_cast(*colon)))\n {\n ++colon;\n }\n std::string value(colon, end);\n while (!value.empty() &&\n isspace(static_cast(value[value.size() - 1])))\n {\n value.resize(value.size() - 1);\n }\n\n if (field == \"set-cookie\")\n {\n // LOG_INFO<<\"cookies!!!:\"<(cookie_name[cpos])))\n ++cpos;\n cookie_name = cookie_name.substr(cpos);\n ++epos;\n while (epos < coo.length() &&\n isspace(static_cast(coo[epos])))\n ++epos;\n cookie_value = coo.substr(epos);\n }\n else\n {\n std::string::size_type cpos = 0;\n while (cpos < coo.length() &&\n isspace(static_cast(coo[cpos])))\n ++cpos;\n cookie_name = coo.substr(cpos);\n }\n if (i == 0)\n {\n cookie.setKey(cookie_name);\n cookie.setValue(cookie_value);\n }\n else\n {\n std::transform(cookie_name.begin(),\n cookie_name.end(),\n cookie_name.begin(),\n [](unsigned char c) { return tolower(c); });\n if (cookie_name == \"path\")\n {\n cookie.setPath(cookie_value);\n }\n else if (cookie_name == \"domain\")\n {\n cookie.setDomain(cookie_value);\n }\n else if (cookie_name == \"expires\")\n {\n cookie.setExpiresDate(utils::getHttpDate(cookie_value));\n }\n else if (cookie_name == \"secure\")\n {\n cookie.setSecure(true);\n }\n else if (cookie_name == \"httponly\")\n {\n cookie.setHttpOnly(true);\n }\n else if (cookie_name == \"samesite\")\n {\n cookie.setSameSite(\n cookie.convertString2SameSite(cookie_value));\n }\n else if (cookie_name == \"max-age\")\n {\n cookie.setMaxAge(std::stoi(cookie_value));\n }\n }\n }\n if (!cookie.key().empty())\n {\n cookies_[cookie.key()] = cookie;\n }\n }\n else\n {\n headers_[std::move(field)] = std::move(value);\n }\n}\n\nvoid HttpResponseImpl::swap(HttpResponseImpl &that) noexcept\n{\n using std::swap;\n headers_.swap(that.headers_);\n cookies_.swap(that.cookies_);\n swap(statusCode_, that.statusCode_);\n swap(version_, that.version_);\n swap(statusMessage_, that.statusMessage_);\n swap(closeConnection_, that.closeConnection_);\n bodyPtr_.swap(that.bodyPtr_);\n swap(contentType_, that.contentType_);\n swap(flagForParsingContentType_, that.flagForParsingContentType_);\n swap(flagForParsingJson_, that.flagForParsingJson_);\n swap(sendfileName_, that.sendfileName_);\n swap(streamCallback_, that.streamCallback_);\n swap(asyncStreamCallback_, that.asyncStreamCallback_);\n 
jsonPtr_.swap(that.jsonPtr_);\n fullHeaderString_.swap(that.fullHeaderString_);\n httpString_.swap(that.httpString_);\n swap(datePos_, that.datePos_);\n swap(jsonParsingErrorPtr_, that.jsonParsingErrorPtr_);\n}\n\nvoid HttpResponseImpl::clear()\n{\n statusCode_ = kUnknown;\n version_ = Version::kHttp11;\n statusMessage_ = std::string_view{};\n fullHeaderString_.reset();\n jsonParsingErrorPtr_.reset();\n sendfileName_.clear();\n if (streamCallback_)\n {\n LOG_TRACE << \"Cleanup HttpResponse stream callback\";\n streamCallback_(nullptr, 0); // callback internal cleanup\n streamCallback_ = {};\n }\n if (asyncStreamCallback_)\n {\n // asyncStreamCallback_(nullptr);\n asyncStreamCallback_ = {};\n }\n headers_.clear();\n cookies_.clear();\n bodyPtr_.reset();\n jsonPtr_.reset();\n expriedTime_ = -1;\n datePos_ = std::string::npos;\n flagForParsingContentType_ = false;\n flagForParsingJson_ = false;\n}\n\nvoid HttpResponseImpl::parseJson() const\n{\n static std::once_flag once;\n static Json::CharReaderBuilder builder;\n std::call_once(once, []() {\n builder[\"collectComments\"] = false;\n builder[\"stackLimit\"] =\n static_cast(drogon::app().getJsonParserStackLimit());\n });\n JSONCPP_STRING errs;\n std::unique_ptr reader(builder.newCharReader());\n if (bodyPtr_)\n {\n jsonPtr_ = std::make_shared();\n if (!reader->parse(bodyPtr_->data(),\n bodyPtr_->data() + bodyPtr_->length(),\n jsonPtr_.get(),\n &errs))\n {\n LOG_ERROR << errs;\n LOG_ERROR << \"body: \" << bodyPtr_->getString();\n jsonPtr_.reset();\n jsonParsingErrorPtr_ =\n std::make_shared(std::move(errs));\n }\n else\n {\n jsonParsingErrorPtr_.reset();\n }\n }\n else\n {\n jsonPtr_.reset();\n jsonParsingErrorPtr_ =\n std::make_shared(\"empty response body\");\n }\n}\n\nbool HttpResponseImpl::shouldBeCompressed() const\n{\n if (streamCallback_ || asyncStreamCallback_ || !sendfileName_.empty() ||\n contentType() >= CT_APPLICATION_OCTET_STREAM ||\n getBody().length() < 1024 || !(getHeaderBy(\"content-encoding\").empty()))\n {\n return false;\n }\n return true;\n}\n\nvoid HttpResponseImpl::setContentTypeString(const char *typeString,\n size_t typeStringLength)\n{\n std::string sv(typeString, typeStringLength);\n auto contentType = parseContentType(sv);\n if (contentType == CT_NONE)\n contentType = CT_CUSTOM;\n contentType_ = contentType;\n contentTypeString_ = std::string(sv);\n flagForParsingContentType_ = true;\n}\n\n// Path: lib/src/ConfigLoader.h\n/**\n *\n * ConfigLoader.h\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n\nnamespace drogon\n{\nclass ConfigLoader : public trantor::NonCopyable\n{\n public:\n explicit ConfigLoader(const std::string &configFile) noexcept(false);\n explicit ConfigLoader(const Json::Value &data);\n explicit ConfigLoader(Json::Value &&data);\n ~ConfigLoader();\n\n const Json::Value &jsonValue() const\n {\n return configJsonRoot_;\n }\n\n void load() noexcept(false);\n\n private:\n std::string configFile_;\n Json::Value configJsonRoot_;\n};\n} // namespace drogon\n\n// Path: lib/src/HttpServer.h\n/**\n *\n * @file HttpServer.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \"impl_forwards.h\"\n\nstruct CallbackParamPack;\n\nnamespace drogon\n{\nclass ControllerBinderBase;\n\nclass HttpServer : trantor::NonCopyable\n{\n public:\n HttpServer(trantor::EventLoop *loop,\n const trantor::InetAddress &listenAddr,\n std::string name);\n\n ~HttpServer();\n\n void setIoLoops(const std::vector &ioLoops)\n {\n server_.setIoLoops(ioLoops);\n }\n\n void start();\n void stop();\n\n void enableSSL(trantor::TLSPolicyPtr policy)\n {\n server_.enableSSL(std::move(policy));\n }\n\n const trantor::InetAddress &address() const\n {\n return server_.address();\n }\n\n private:\n friend class HttpInternalForwardHelper;\n\n static void onConnection(const trantor::TcpConnectionPtr &conn);\n static void onMessage(const trantor::TcpConnectionPtr &,\n trantor::MsgBuffer *);\n static void onRequests(const trantor::TcpConnectionPtr &,\n const std::vector &,\n const std::shared_ptr &);\n\n struct HttpRequestParamPack\n {\n std::shared_ptr binderPtr;\n std::function callback;\n };\n\n struct WsRequestParamPack\n {\n std::shared_ptr binderPtr;\n std::function callback;\n WebSocketConnectionImplPtr wsConnPtr;\n };\n\n // Http request handling steps\n static void onHttpRequest(const HttpRequestImplPtr &,\n std::function &&);\n static void httpRequestRouting(\n const HttpRequestImplPtr &req,\n std::function &&callback);\n static void httpRequestHandling(\n const HttpRequestImplPtr &req,\n std::shared_ptr &&binderPtr,\n std::function &&callback);\n\n // Websocket request handling steps\n static void onWebsocketRequest(\n const HttpRequestImplPtr &,\n std::function &&,\n WebSocketConnectionImplPtr &&);\n static void websocketRequestRouting(\n const HttpRequestImplPtr &req,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr);\n static void websocketRequestHandling(\n const HttpRequestImplPtr &req,\n std::shared_ptr &&binderPtr,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr);\n\n // Http/Websocket shared handling steps\n template \n static void requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack);\n template \n static void requestPassFilters(const HttpRequestImplPtr &req, Pack &&pack);\n template \n static void requestPreHandling(const HttpRequestImplPtr &req, Pack &&pack);\n\n // Response buffering and sending\n static void handleResponse(\n const HttpResponsePtr &response,\n const std::shared_ptr ¶mPack,\n bool *respReadyPtr);\n static void sendResponse(const trantor::TcpConnectionPtr &,\n const HttpResponsePtr &,\n bool isHeadMethod);\n static void sendResponses(\n const trantor::TcpConnectionPtr &conn,\n const std::vector> &responses,\n trantor::MsgBuffer &buffer);\n\n trantor::TcpServer server_;\n};\n\nclass HttpInternalForwardHelper\n{\n public:\n static void forward(const HttpRequestImplPtr &req,\n std::function &&callback)\n {\n return HttpServer::onHttpRequest(req, std::move(callback));\n }\n};\n\n} // namespace drogon\n\n// Path: lib/src/ListenerManager.h\n/**\n *\n * @file ListenerManager.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"impl_forwards.h\"\n\nnamespace trantor\n{\nclass InetAddress;\n}\n\nnamespace drogon\n{\nclass ListenerManager : public trantor::NonCopyable\n{\n public:\n ~ListenerManager() = default;\n void addListener(const std::string &ip,\n uint16_t port,\n bool useSSL = false,\n const std::string &certFile = \"\",\n const std::string &keyFile = \"\",\n bool useOldTLS = false,\n const std::vector>\n &sslConfCmds = {});\n std::vector getListeners() const;\n void createListeners(\n const std::string &globalCertFile,\n const std::string &globalKeyFile,\n const std::vector> &sslConfCmds,\n const std::vector &ioLoops);\n void startListening();\n void stopListening();\n\n private:\n struct ListenerInfo\n {\n ListenerInfo(\n std::string ip,\n uint16_t port,\n bool useSSL,\n std::string certFile,\n std::string keyFile,\n bool useOldTLS,\n std::vector> sslConfCmds)\n : ip_(std::move(ip)),\n port_(port),\n useSSL_(useSSL),\n certFile_(std::move(certFile)),\n keyFile_(std::move(keyFile)),\n useOldTLS_(useOldTLS),\n sslConfCmds_(std::move(sslConfCmds))\n {\n }\n\n std::string ip_;\n uint16_t port_;\n bool useSSL_;\n std::string certFile_;\n std::string keyFile_;\n bool useOldTLS_;\n std::vector> sslConfCmds_;\n };\n\n std::vector listeners_;\n std::vector> servers_;\n\n // should have value when and only when on OS that one port can only be\n // listened by one thread\n std::unique_ptr listeningThread_;\n};\n\n} // namespace drogon\n\n// Path: lib/src/RedisClientManager.h\n/**\n *\n * RedisClientManager.h\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n...\n", "needle_token_start": 5570, "needle_token_end": 5689, "code_context_ntokens": 16124, "output": ["```cpp\nstatic inline HttpResponsePtr genHttpResponse(const std::string &viewName,\n const HttpViewData &data,\n const HttpRequestPtr &req)\n{\n auto templ = DrTemplateBase::newTemplate(viewName);\n if (templ)\n {\n auto res = HttpResponse::newHttpResponse();\n res->setBody(templ->genText(data));\n return res;\n }\n return drogon::HttpResponse::newNotFoundResponse(req);\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "splitStringToSet", "language": "cpp", "path": "lib/src/Utilities.cc", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to divide a given string into distinct substrings based on a specified delimiter, and then store these substrings in a set to ensure all entries are unique and automatically ordered.\n2. **Input**: The function takes two strings: one is the main string to be split, and the other is the delimiter used to split the string.\n3. **Output**: It returns a set of strings, where each element is a unique substring from the original string, split according to the specified delimiter.\n4. **Procedure**: The function starts by initializing two indices to manage the current position within the string and the position of the next delimiter. It iterates through the main string, using the delimiter to find the end of each substring. 
Each found substring is then added to the set (if it's not an empty string), ensuring no duplicates. This process repeats until the entire string has been processed. If any part of the string remains after the last delimiter, it is also added to the set.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/MultiPart.cc\n/**\n *\n * @file MultiPart.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpRequestImpl.h\"\n#include \"HttpUtils.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpFileImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#ifndef _WIN32\n#include \n#endif\n\nusing namespace drogon;\n\nconst std::vector &MultiPartParser::getFiles() const\n{\n return files_;\n}\n\nstd::unordered_map MultiPartParser::getFilesMap() const\n{\n std::unordered_map result;\n for (auto &file : files_)\n {\n result.emplace(file.getItemName(), file);\n }\n return result;\n}\n\nconst SafeStringMap &MultiPartParser::getParameters() const\n{\n return parameters_;\n}\n\nint MultiPartParser::parse(const HttpRequestPtr &req)\n{\n if (req->method() != Post && req->method() != Put)\n return -1;\n const std::string &contentType =\n static_cast(req.get())->getHeaderBy(\"content-type\");\n if (contentType.empty())\n {\n return -1;\n }\n std::string::size_type pos = contentType.find(';');\n if (pos == std::string::npos)\n return -1;\n\n std::string type = contentType.substr(0, pos);\n std::transform(type.begin(), type.end(), type.begin(), [](unsigned char c) {\n return tolower(c);\n });\n if (type != \"multipart/form-data\")\n return -1;\n pos = contentType.find(\"boundary=\");\n if (pos == std::string::npos)\n return -1;\n auto pos2 = contentType.find(';', pos);\n if (pos2 == std::string::npos)\n pos2 = contentType.size();\n return parse(req, contentType.data() + (pos + 9), pos2 - (pos + 9));\n}\n\nstatic std::pair parseLine(\n const char *begin,\n const char *end)\n{\n auto p = begin;\n while (p != end)\n {\n if (*p == ':')\n {\n if (p + 1 != end && *(p + 1) == ' ')\n {\n return std::make_pair(std::string_view(begin, p - begin),\n...\n// Path: lib/src/SharedLibManager.cc\n/**\n *\n * @file SharedLibManager.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"SharedLibManager.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nstatic void forEachFileIn(\n const std::string &path,\n const std::function &cb)\n{\n DIR *dp;\n struct dirent *dirp;\n struct stat st;\n\n /* open dirent directory */\n if ((dp = opendir(path.c_str())) == NULL)\n {\n // perror(\"opendir:\");\n LOG_ERROR << \"can't open dir,path:\" << path;\n return;\n }\n\n /**\n * read all files in this dir\n **/\n while ((dirp = readdir(dp)) != NULL)\n {\n /* ignore hidden files */\n if (dirp->d_name[0] == '.')\n continue;\n /* get dirent status */\n std::string filename = dirp->d_name;\n std::string fullname = path;\n fullname.append(\"/\").append(filename);\n if (stat(fullname.c_str(), &st) == -1)\n {\n perror(\"stat\");\n closedir(dp);\n return;\n }\n\n /* if dirent is a directory, find files recursively */\n if (S_ISDIR(st.st_mode))\n {\n forEachFileIn(fullname, cb);\n }\n else\n {\n cb(fullname, st);\n }\n }\n closedir(dp);\n return;\n}\n\nusing namespace drogon;\n\nSharedLibManager::SharedLibManager(const std::vector &libPaths,\n const std::string &outputPath)\n : libPaths_(libPaths), outputPath_(outputPath)\n{\n workingThread_.run();\n timeId_ =\n workingThread_.getLoop()->runEvery(5.0, [this]() { managerLibs(); });\n}\n\nSharedLibManager::~SharedLibManager()\n{\n workingThread_.getLoop()->invalidateTimer(timeId_);\n}\n\nvoid SharedLibManager::managerLibs()\n{\n for (auto const &libPath : libPaths_)\n {\n forEachFileIn(\n libPath,\n [this, libPath](const std::string &filename,\n const struct stat &st) {\n auto pos = filename.rfind('.');\n if (pos != std::string::npos)\n {\n auto exName = filename.substr(pos + 1);\n if (exName == \"csp\")\n {\n // compile\n auto lockFile = filename + \".lock\";\n std::ifstream fin(lockFile);\n if (fin)\n {\n return;\n }\n\n void *oldHandle = nullptr;\n if (dlMap_.find(filename) != dlMap_.end())\n {\n#if defined __linux__ || defined __HAIKU__\n if (st.st_mtim.tv_sec >\n dlMap_[filename].mTime.tv_sec)\n#elif defined _WIN32\n if (st.st_mtime > dlMap_[filename].mTime.tv_sec)\n#else\n if (st.st_mtimespec.tv_sec >\n dlMap_[filename].mTime.tv_sec)\n#endif\n {\n LOG_TRACE << \"new csp file:\" << filename;\n oldHandle = dlMap_[filename].handle;\n }\n else\n return;\n }\n\n {\n std::ofstream fout(lockFile);\n }\n\n auto srcFile = filename.substr(0, pos);\n if (!outputPath_.empty())\n {\n pos = srcFile.rfind(\"/\");\n if (pos != std::string::npos)\n {\n srcFile = srcFile.substr(pos + 1);\n }\n srcFile = outputPath_ + \"/\" + srcFile;\n }\n auto soFile = srcFile + \".so\";\n DLStat dlStat;\n if (!shouldCompileLib(soFile, st))\n {\n LOG_TRACE << \"Using already compiled library:\"\n << soFile;\n dlStat.handle = loadLib(soFile, oldHandle);\n }\n else\n {\n // generate source code and compile it.\n std::string cmd = \"drogon_ctl create view \";\n if (!outputPath_.empty())\n {\n cmd.append(filename).append(\" -o \").append(\n outputPath_);\n }\n else\n {\n cmd.append(filename).append(\" -o \").append(\n libPath);\n }\n srcFile.append(\".cc\");\n LOG_TRACE << cmd;\n auto r = system(cmd.c_str());\n // TODO: handle r\n (void)(r);\n dlStat.handle =\n compileAndLoadLib(srcFile, oldHandle);\n }\n#if defined __linux__ || defined __HAIKU__\n dlStat.mTime = st.st_mtim;\n#elif defined _WIN32\n dlStat.mTime.tv_sec = st.st_mtime;\n#else\n 
dlStat.mTime = st.st_mtimespec;\n#endif\n if (dlStat.handle)\n {\n dlMap_[filename] = dlStat;\n }\n else\n {\n dlStat.handle = dlMap_[filename].handle;\n dlMap_[filename] = dlStat;\n }\n workingThread_.getLoop()->runAfter(3.5, [lockFile]() {\n LOG_TRACE << \"remove file \" << lockFile;\n if (unlink(lockFile.c_str()) == -1)\n perror(\"\");\n });\n }\n }\n });\n }\n}\n\nvoid *SharedLibManager::compileAndLoadLib(const std::string &sourceFile,\n void *oldHld)\n{\n LOG_TRACE << \"src:\" << sourceFile;\n std::string cmd = COMPILER_COMMAND;\n cmd.append(\" \")\n .append(sourceFile)\n .append(\" \")\n .append(COMPILATION_FLAGS)\n .append(\" \")\n .append(INCLUDING_DIRS);\n if (std::string(COMPILER_ID).find(\"Clang\") != std::string::npos)\n cmd.append(\" -shared -fPIC -undefined dynamic_lookup -o \");\n else\n cmd.append(\" -shared -fPIC --no-gnu-unique -o \");\n auto pos = sourceFile.rfind('.');\n auto soFile = sourceFile.substr(0, pos);\n soFile.append(\".so\");\n cmd.append(soFile);\n LOG_TRACE << cmd;\n\n if (system(cmd.c_str()) == 0)\n {\n LOG_TRACE << \"Compiled successfully:\" << soFile;\n return loadLib(soFile, oldHld);\n }\n else\n {\n LOG_DEBUG << \"Could not compile library.\";\n return nullptr;\n }\n}\n\nbool SharedLibManager::shouldCompileLib(const std::string &soFile,\n const struct stat &sourceStat)\n{\n#if defined __linux__ || defined __HAIKU__\n auto sourceModifiedTime = sourceStat.st_mtim.tv_sec;\n#elif defined _WIN32\n auto sourceModifiedTime = sourceStat.st_mtime;\n#else\n auto sourceModifiedTime = sourceStat.st_mtimespec.tv_sec;\n#endif\n\n struct stat soStat;\n if (stat(soFile.c_str(), &soStat) == -1)\n {\n LOG_TRACE << \"Cannot determine modification time for:\" << soFile;\n return true;\n }\n\n#if defined __linux__ || defined __HAIKU__\n auto soModifiedTime = soStat.st_mtim.tv_sec;\n#elif defined _WIN32\n auto soModifiedTime = soStat.st_mtime;\n#else\n auto soModifiedTime = soStat.st_mtimespec.tv_sec;\n#endif\n\n return (sourceModifiedTime > soModifiedTime);\n}\n\nvoid *SharedLibManager::loadLib(const std::string &soFile, void *oldHld)\n{\n if (oldHld)\n {\n if (dlclose(oldHld) == 0)\n {\n LOG_TRACE << \"Successfully closed dynamic library:\" << oldHld;\n }\n else\n {\n LOG_TRACE << dlerror();\n }\n }\n auto Handle = dlopen(soFile.c_str(), RTLD_LAZY);\n if (!Handle)\n {\n LOG_ERROR << \"load \" << soFile << \" error!\";\n LOG_ERROR << dlerror();\n }\n else\n {\n LOG_TRACE << \"Successfully loaded library file \" << soFile;\n }\n\n return Handle;\n}\n\n// Path: lib/src/RedisClientSkipped.cc\n/**\n *\n * RedisClientSkipped.cc\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"drogon/nosql/RedisClient.h\"\n\nnamespace drogon\n{\nnamespace nosql\n{\nstd::shared_ptr RedisClient::newRedisClient(\n const trantor::InetAddress & /*serverAddress*/,\n size_t /*numberOfConnections*/,\n const std::string & /*password*/,\n const unsigned int /*db*/,\n const std::string & /*username*/)\n{\n LOG_FATAL << \"Redis is not supported by drogon, please install the \"\n \"hiredis library first.\";\n abort();\n}\n} // namespace nosql\n} // namespace drogon\n\n// Path: lib/src/Cookie.cc\n/**\n *\n * Cookie.cc\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \nusing namespace drogon;\n\nstd::string Cookie::cookieString() const\n{\n std::string ret = \"Set-Cookie: \";\n // reserve space to reduce frequency allocation\n ret.reserve(ret.size() + key_.size() + value_.size() + 30);\n ret.append(key_).append(\"=\").append(value_).append(\"; \");\n if (expiresDate_.microSecondsSinceEpoch() !=\n (std::numeric_limits::max)() &&\n expiresDate_.microSecondsSinceEpoch() >= 0)\n {\n ret.append(\"Expires=\")\n .append(utils::getHttpFullDate(expiresDate_))\n .append(\"; \");\n }\n if (maxAge_.has_value())\n {\n ret.append(\"Max-Age=\")\n .append(std::to_string(maxAge_.value()))\n .append(\"; \");\n }\n if (!domain_.empty())\n {\n ret.append(\"Domain=\").append(domain_).append(\"; \");\n }\n if (!path_.empty())\n {\n ret.append(\"Path=\").append(path_).append(\"; \");\n }\n if (sameSite_ != SameSite::kNull)\n {\n switch (sameSite_)\n {\n case SameSite::kLax:\n ret.append(\"SameSite=Lax; \");\n break;\n case SameSite::kStrict:\n ret.append(\"SameSite=Strict; \");\n break;\n case SameSite::kNone:\n ret.append(\"SameSite=None; \");\n // Cookies with SameSite=None must now also specify the Secure\n // attribute (they require a secure context/HTTPS).\n ret.append(\"Secure; \");\n break;\n default:\n // Lax replaced None as the default value to ensure that users\n // have reasonably robust defense against some CSRF attacks\n ret.append(\"SameSite=Lax; \");\n }\n }\n if (secure_ && sameSite_ != SameSite::kNone)\n {\n ret.append(\"Secure; \");\n }\n if (httpOnly_)\n {\n ret.append(\"HttpOnly; \");\n }\n ret.resize(ret.length() - 2); // delete last semicolon\n ret.append(\"\\r\\n\");\n return ret;\n}\n\n// Path: lib/src/Utilities.cc\n/**\n *\n * @file Utilities.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n#include \n#ifdef USE_BROTLI\n#include \n#include \n#endif\n#ifdef _WIN32\n#include \n#include \n#include \n#include \n#else\n#include \n#include \n#endif\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef _WIN32\n\nchar *strptime(const char *s, const char *f, struct tm *tm)\n{\n // std::get_time is defined such that its\n // format parameters are the exact same as strptime.\n std::istringstream input(s);\n input.imbue(std::locale(setlocale(LC_ALL, nullptr)));\n input >> std::get_time(tm, f);\n if (input.fail())\n {\n return nullptr;\n }\n return (char *)(s + input.tellg());\n}\n\ntime_t timegm(struct tm *tm)\n{\n struct tm my_tm;\n\n memcpy(&my_tm, tm, sizeof(struct tm));\n\n /* _mkgmtime() changes the value of the struct tm* you pass in, so\n * use a copy\n */\n return _mkgmtime(&my_tm);\n}\n#endif\n\n#ifdef __HAIKU__\n// HACK: Haiku has a timegm implementation. 
But it is not exposed\nextern \"C\" time_t timegm(struct tm *tm);\n#endif\n\nnamespace drogon\n{\nnamespace utils\n{\nstatic const std::string base64Chars =\n \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n \"abcdefghijklmnopqrstuvwxyz\"\n \"0123456789+/\";\n\nstatic const std::string urlBase64Chars =\n \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n \"abcdefghijklmnopqrstuvwxyz\"\n \"0123456789-_\";\n\nclass Base64CharMap\n{\n public:\n Base64CharMap()\n {\n char index = 0;\n for (int c = 'A'; c <= 'Z'; ++c)\n {\n charMap_[c] = index++;\n }\n for (int c = 'a'; c <= 'z'; ++c)\n {\n charMap_[c] = index++;\n }\n for (int c = '0'; c <= '9'; ++c)\n {\n charMap_[c] = index++;\n }\n charMap_[static_cast('+')] = charMap_[static_cast('-')] =\n index++;\n charMap_[static_cast('/')] = charMap_[static_cast('_')] =\n index;\n charMap_[0] = char(0xff);\n }\n\n char getIndex(const char c) const noexcept\n {\n return charMap_[static_cast(c)];\n }\n\n private:\n char charMap_[256]{0};\n};\n\nstatic const Base64CharMap base64CharMap;\n\nstatic inline bool isBase64(unsigned char c)\n{\n if (isalnum(c))\n return true;\n switch (c)\n {\n case '+':\n case '/':\n case '-':\n case '_':\n return true;\n }\n return false;\n}\n\nbool isInteger(std::string_view str)\n{\n for (auto c : str)\n if (c < '0' || c > '9')\n return false;\n return true;\n}\n\nbool isBase64(std::string_view str)\n{\n for (auto c : str)\n if (!isBase64(c))\n return false;\n return true;\n}\n\nstd::string genRandomString(int length)\n{\n static const char char_space[] =\n \"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\";\n static std::once_flag once;\n static const size_t len = strlen(char_space);\n static const int randMax = RAND_MAX - (RAND_MAX % len);\n std::call_once(once, []() {\n std::srand(static_cast(time(nullptr)));\n });\n\n int i;\n std::string str;\n str.resize(length);\n\n for (i = 0; i < length; ++i)\n {\n int x = std::rand();\n while (x >= randMax)\n {\n x = std::rand();\n }\n x = (x % len);\n str[i] = char_space[x];\n }\n\n return str;\n}\n\nstd::vector hexToBinaryVector(const char *ptr, size_t length)\n{\n assert(length % 2 == 0);\n std::vector ret(length / 2, '\\0');\n for (size_t i = 0; i < ret.size(); ++i)\n {\n auto p = i * 2;\n char c1 = ptr[p];\n if (c1 >= '0' && c1 <= '9')\n {\n c1 -= '0';\n }\n else if (c1 >= 'a' && c1 <= 'f')\n {\n c1 -= 'a';\n c1 += 10;\n }\n else if (c1 >= 'A' && c1 <= 'F')\n {\n c1 -= 'A';\n c1 += 10;\n }\n else\n {\n return std::vector();\n }\n char c2 = ptr[p + 1];\n if (c2 >= '0' && c2 <= '9')\n {\n c2 -= '0';\n }\n else if (c2 >= 'a' && c2 <= 'f')\n {\n c2 -= 'a';\n c2 += 10;\n }\n else if (c2 >= 'A' && c2 <= 'F')\n {\n c2 -= 'A';\n c2 += 10;\n }\n else\n {\n return std::vector();\n }\n ret[i] = c1 * 16 + c2;\n }\n return ret;\n}\n\nstd::string hexToBinaryString(const char *ptr, size_t length)\n{\n assert(length % 2 == 0);\n std::string ret(length / 2, '\\0');\n for (size_t i = 0; i < ret.length(); ++i)\n {\n auto p = i * 2;\n char c1 = ptr[p];\n if (c1 >= '0' && c1 <= '9')\n {\n c1 -= '0';\n }\n else if (c1 >= 'a' && c1 <= 'f')\n {\n c1 -= 'a';\n c1 += 10;\n }\n else if (c1 >= 'A' && c1 <= 'F')\n {\n c1 -= 'A';\n c1 += 10;\n }\n else\n {\n return \"\";\n }\n char c2 = ptr[p + 1];\n if (c2 >= '0' && c2 <= '9')\n {\n c2 -= '0';\n }\n else if (c2 >= 'a' && c2 <= 'f')\n {\n c2 -= 'a';\n c2 += 10;\n }\n else if (c2 >= 'A' && c2 <= 'F')\n {\n c2 -= 'A';\n c2 += 10;\n }\n else\n {\n return \"\";\n }\n ret[i] = c1 * 16 + c2;\n }\n return ret;\n}\n\nDROGON_EXPORT void binaryStringToHex(const char *ptr,\n size_t 
length,\n char *out,\n bool lowerCase)\n{\n for (size_t i = 0; i < length; ++i)\n {\n int value = (ptr[i] & 0xf0) >> 4;\n if (value < 10)\n {\n out[i * 2] = char(value + 48);\n }\n else\n {\n if (!lowerCase)\n {\n out[i * 2] = char(value + 55);\n }\n else\n {\n out[i * 2] = char(value + 87);\n }\n }\n\n value = (ptr[i] & 0x0f);\n if (value < 10)\n {\n out[i * 2 + 1] = char(value + 48);\n }\n else\n {\n if (!lowerCase)\n {\n out[i * 2 + 1] = char(value + 55);\n }\n else\n {\n out[i * 2 + 1] = char(value + 87);\n }\n }\n }\n}\n\nstd::string binaryStringToHex(const unsigned char *ptr,\n size_t length,\n bool lowercase)\n{\n std::string idString(length * 2, '\\0');\n binaryStringToHex((const char *)ptr, length, &idString[0], lowercase);\n return idString;\n}\n\n\nstd::set splitStringToSet(const std::string &str,\n const std::string &separator)\n{\n std::set ret;\n std::string::size_type pos1, pos2;\n pos2 = 0;\n pos1 = str.find(separator);\n while (pos1 != std::string::npos)\n {\n if (pos1 != 0)\n {\n std::string item = str.substr(pos2, pos1 - pos2);\n ret.insert(item);\n }\n pos2 = pos1 + separator.length();\n while (pos2 < str.length() &&\n str.substr(pos2, separator.length()) == separator)\n pos2 += separator.length();\n pos1 = str.find(separator, pos2);\n }\n if (pos2 < str.length())\n ret.insert(str.substr(pos2));\n return ret;\n}\n\ninline std::string createUuidString(const char *str, size_t len, bool lowercase)\n{\n assert(len == 16);\n std::string uuid(36, '\\0');\n binaryStringToHex(str, 4, &uuid[0], lowercase);\n uuid[8] = '-';\n binaryStringToHex(str + 4, 2, &uuid[9], lowercase);\n uuid[13] = '-';\n binaryStringToHex(str + 6, 2, &uuid[14], lowercase);\n uuid[18] = '-';\n binaryStringToHex(str + 8, 2, &uuid[19], lowercase);\n uuid[23] = '-';\n binaryStringToHex(str + 10, 6, &uuid[24], lowercase);\n return uuid;\n}\n\nstd::string getUuid(bool lowercase)\n{\n#if USE_OSSP_UUID\n uuid_t *uuid;\n uuid_create(&uuid);\n uuid_make(uuid, UUID_MAKE_V4);\n char *str{nullptr};\n size_t len{0};\n uuid_export(uuid, UUID_FMT_BIN, &str, &len);\n uuid_destroy(uuid);\n auto ret = createUuidString(str, len, lowercase);\n free(str);\n return ret;\n#elif defined __FreeBSD__ || defined __OpenBSD__\n uuid_t *uuid = new uuid_t;\n char *binstr = (char *)malloc(16);\n#if defined __FreeBSD__\n uuidgen(uuid, 1);\n#else\n uint32_t status;\n uuid_create(uuid, &status);\n#endif\n#if _BYTE_ORDER == _LITTLE_ENDIAN\n uuid_enc_le(binstr, uuid);\n#else /* _BYTE_ORDER != _LITTLE_ENDIAN */\n uuid_enc_be(binstr, uuid);\n#endif /* _BYTE_ORDER == _LITTLE_ENDIAN */\n delete uuid;\n auto ret = createUuidString(binstr, 16, lowercase);\n free(binstr);\n return ret;\n#elif defined _WIN32\n uuid_t uu;\n UuidCreate(&uu);\n char tempStr[100];\n auto len = snprintf(tempStr,\n sizeof(tempStr),\n \"%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x\",\n uu.Data1,\n uu.Data2,\n uu.Data3,\n uu.Data4[0],\n uu.Data4[1],\n uu.Data4[2],\n uu.Data4[3],\n uu.Data4[4],\n uu.Data4[5],\n uu.Data4[6],\n uu.Data4[7]);\n return std::string{tempStr, static_cast(len)};\n#else\n uuid_t uu;\n uuid_generate(uu);\n auto uuid = createUuidString((const char *)uu, 16, lowercase);\n return uuid;\n#endif\n}\n\nstd::string base64Encode(const unsigned char *bytes_to_encode,\n size_t in_len,\n bool url_safe,\n bool padded)\n{\n std::string ret;\n ret.reserve(base64EncodedLength(in_len, padded));\n int i = 0;\n unsigned char char_array_3[3];\n unsigned char char_array_4[4];\n\n const std::string &charSet = url_safe ? 
urlBase64Chars : base64Chars;\n\n while (in_len--)\n {\n char_array_3[i++] = *(bytes_to_encode++);\n if (i == 3)\n {\n char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;\n char_array_4[1] = ((char_array_3[0] & 0x03) << 4) +\n ((char_array_3[1] & 0xf0) >> 4);\n char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) +\n ((char_array_3[2] & 0xc0) >> 6);\n char_array_4[3] = char_array_3[2] & 0x3f;\n\n for (i = 0; (i < 4); ++i)\n ret += charSet[char_array_4[i]];\n i = 0;\n }\n }\n\n if (i)\n {\n for (int j = i; j < 3; ++j)\n char_array_3[j] = '\\0';\n\n char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;\n char_array_4[1] =\n ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);\n char_array_4[2] =\n ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);\n char_array_4[3] = char_array_3[2] & 0x3f;\n\n for (int j = 0; (j <= i); ++j)\n ret += charSet[char_array_4[j]];\n\n if (padded)\n while ((++i < 4))\n ret += '=';\n }\n return ret;\n}\n\nstd::vector base64DecodeToVector(std::string_view encoded_string)\n{\n auto in_len = encoded_string.size();\n int i = 0;\n int in_{0};\n char char_array_4[4], char_array_3[3];\n std::vector ret;\n ret.reserve(base64DecodedLength(in_len));\n\n while (in_len-- && (encoded_string[in_] != '='))\n {\n if (!isBase64(encoded_string[in_]))\n {\n ++in_;\n continue;\n }\n\n char_array_4[i++] = encoded_string[in_];\n ++in_;\n if (i == 4)\n {\n for (i = 0; i < 4; ++i)\n {\n char_array_4[i] = base64CharMap.getIndex(char_array_4[i]);\n }\n\n char_array_3[0] =\n (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);\n char_array_3[1] = ((char_array_4[1] & 0xf) << 4) +\n ((char_array_4[2] & 0x3c) >> 2);\n char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];\n\n for (i = 0; (i < 3); ++i)\n ret.push_back(char_array_3[i]);\n i = 0;\n }\n }\n\n if (i)\n {\n for (int j = i; j < 4; ++j)\n char_array_4[j] = 0;\n\n for (int j = 0; j < 4; ++j)\n {\n char_array_4[j] = base64CharMap.getIndex(char_array_4[j]);\n }\n\n char_array_3[0] =\n (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);\n char_array_3[1] =\n ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);\n char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];\n\n --i;\n for (int j = 0; (j < i); ++j)\n ret.push_back(char_array_3[j]);\n }\n\n return ret;\n}\n\nstd::string base64Decode(std::string_view encoded_string)\n{\n auto in_len = encoded_string.size();\n int i = 0;\n int in_{0};\n unsigned char char_array_4[4], char_array_3[3];\n std::string ret;\n ret.reserve(base64DecodedLength(in_len));\n\n while (in_len-- && (encoded_string[in_] != '='))\n {\n if (!isBase64(encoded_string[in_]))\n {\n ++in_;\n continue;\n }\n\n char_array_4[i++] = encoded_string[in_];\n ++in_;\n if (i == 4)\n {\n for (i = 0; i < 4; ++i)\n {\n char_array_4[i] = base64CharMap.getIndex(char_array_4[i]);\n }\n char_array_3[0] =\n (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);\n char_array_3[1] = ((char_array_4[1] & 0xf) << 4) +\n ((char_array_4[2] & 0x3c) >> 2);\n char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];\n\n for (i = 0; (i < 3); ++i)\n ret += char_array_3[i];\n i = 0;\n }\n }\n\n if (i)\n {\n for (int j = i; j < 4; ++j)\n char_array_4[j] = 0;\n\n for (int j = 0; j < 4; ++j)\n {\n char_array_4[j] = base64CharMap.getIndex(char_array_4[j]);\n }\n\n char_array_3[0] =\n (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);\n char_array_3[1] =\n ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);\n char_array_3[2] = ((char_array_4[2] & 0x3) << 
6) + char_array_4[3];\n\n --i;\n for (int j = 0; (j < i); ++j)\n ret += char_array_3[j];\n }\n\n return ret;\n}\n\nstatic std::string charToHex(char c)\n{\n std::string result;\n char first, second;\n\n first = (c & 0xF0) / 16;\n first += first > 9 ? 'A' - 10 : '0';\n second = c & 0x0F;\n second += second > 9 ? 'A' - 10 : '0';\n\n result.append(1, first);\n result.append(1, second);\n\n return result;\n}\n\nstd::string urlEncodeComponent(const std::string &src)\n{\n std::string result;\n std::string::const_iterator iter;\n\n for (iter = src.begin(); iter != src.end(); ++iter)\n {\n switch (*iter)\n {\n case ' ':\n result.append(1, '+');\n break;\n // alnum\n case 'A':\n case 'B':\n case 'C':\n case 'D':\n case 'E':\n case 'F':\n case 'G':\n case 'H':\n case 'I':\n case 'J':\n case 'K':\n case 'L':\n case 'M':\n case 'N':\n case 'O':\n case 'P':\n case 'Q':\n case 'R':\n case 'S':\n case 'T':\n case 'U':\n case 'V':\n case 'W':\n case 'X':\n case 'Y':\n case 'Z':\n case 'a':\n case 'b':\n case 'c':\n case 'd':\n case 'e':\n case 'f':\n case 'g':\n case 'h':\n case 'i':\n case 'j':\n case 'k':\n case 'l':\n case 'm':\n case 'n':\n case 'o':\n case 'p':\n case 'q':\n case 'r':\n case 's':\n case 't':\n case 'u':\n case 'v':\n case 'w':\n case 'x':\n case 'y':\n case 'z':\n case '0':\n case '1':\n case '2':\n case '3':\n case '4':\n case '5':\n case '6':\n case '7':\n case '8':\n case '9':\n // mark\n case '-':\n case '_':\n case '.':\n case '!':\n case '~':\n case '*':\n case '(':\n case ')':\n result.append(1, *iter);\n break;\n // escape\n default:\n result.append(1, '%');\n result.append(charToHex(*iter));\n break;\n }\n }\n\n return result;\n}\n\nstd::string urlEncode(const std::string &src)\n{\n std::string result;\n std::string::const_iterator iter;\n\n for (iter = src.begin(); iter != src.end(); ++iter)\n {\n switch (*iter)\n {\n case ' ':\n result.append(1, '+');\n break;\n // alnum\n case 'A':\n case 'B':\n case 'C':\n case 'D':\n case 'E':\n case 'F':\n case 'G':\n case 'H':\n case 'I':\n case 'J':\n case 'K':\n case 'L':\n case 'M':\n case 'N':\n case 'O':\n case 'P':\n case 'Q':\n case 'R':\n case 'S':\n case 'T':\n case 'U':\n case 'V':\n case 'W':\n case 'X':\n case 'Y':\n case 'Z':\n case 'a':\n case 'b':\n case 'c':\n case 'd':\n case 'e':\n case 'f':\n case 'g':\n case 'h':\n case 'i':\n case 'j':\n case 'k':\n case 'l':\n case 'm':\n case 'n':\n case 'o':\n case 'p':\n case 'q':\n case 'r':\n case 's':\n case 't':\n case 'u':\n case 'v':\n case 'w':\n case 'x':\n case 'y':\n case 'z':\n case '0':\n case '1':\n case '2':\n case '3':\n case '4':\n case '5':\n case '6':\n case '7':\n case '8':\n case '9':\n // mark\n case '-':\n case '_':\n case '.':\n case '!':\n case '~':\n case '*':\n case '\\'':\n case '(':\n case ')':\n case '&':\n case '=':\n case '/':\n case '\\\\':\n case '?':\n result.append(1, *iter);\n break;\n // escape\n default:\n result.append(1, '%');\n result.append(charToHex(*iter));\n break;\n }\n }\n\n return result;\n}\n\nbool needUrlDecoding(const char *begin, const char *end)\n{\n return std::find_if(begin, end, [](const char c) {\n return c == '+' || c == '%';\n }) != end;\n}\n\nstd::string urlDecode(const char *begin, const char *end)\n{\n std::string result;\n size_t len = end - begin;\n result.reserve(len * 2);\n int hex = 0;\n for (size_t i = 0; i < len; ++i)\n {\n switch (begin[i])\n {\n case '+':\n result += ' ';\n break;\n case '%':\n if ((i + 2) < len && isxdigit(begin[i + 1]) &&\n isxdigit(begin[i + 2]))\n {\n unsigned int x1 = begin[i + 1];\n 
if (x1 >= '0' && x1 <= '9')\n {\n x1 -= '0';\n }\n else if (x1 >= 'a' && x1 <= 'f')\n {\n x1 = x1 - 'a' + 10;\n }\n else if (x1 >= 'A' && x1 <= 'F')\n {\n x1 = x1 - 'A' + 10;\n }\n unsigned int x2 = begin[i + 2];\n if (x2 >= '0' && x2 <= '9')\n {\n x2 -= '0';\n }\n else if (x2 >= 'a' && x2 <= 'f')\n {\n x2 = x2 - 'a' + 10;\n }\n else if (x2 >= 'A' && x2 <= 'F')\n {\n x2 = x2 - 'A' + 10;\n }\n hex = x1 * 16 + x2;\n\n result += char(hex);\n i += 2;\n }\n else\n {\n result += '%';\n }\n break;\n default:\n result += begin[i];\n break;\n }\n }\n return result;\n}\n\n/* Compress gzip data */\nstd::string gzipCompress(const char *data, const size_t ndata)\n{\n z_stream strm = {nullptr,\n 0,\n 0,\n nullptr,\n 0,\n 0,\n nullptr,\n nullptr,\n nullptr,\n nullptr,\n nullptr,\n 0,\n 0,\n 0};\n if (data && ndata > 0)\n {\n if (deflateInit2(&strm,\n Z_DEFAULT_COMPRESSION,\n Z_DEFLATED,\n MAX_WBITS + 16,\n 8,\n Z_DEFAULT_STRATEGY) != Z_OK)\n {\n LOG_ERROR << \"deflateInit2 error!\";\n return std::string{};\n }\n std::string outstr;\n outstr.resize(compressBound(static_cast(ndata)));\n strm.next_in = (Bytef *)data;\n strm.avail_in = static_cast(ndata);\n int ret;\n do\n {\n if (strm.total_out >= outstr.size())\n {\n outstr.resize(strm.total_out * 2);\n }\n assert(outstr.size() >= strm.total_out);\n strm.avail_out = static_cast(outstr.size() - strm.total_out);\n strm.next_out = (Bytef *)outstr.data() + strm.total_out;\n ret = deflate(&strm, Z_FINISH); /* no bad return value */\n if (ret == Z_STREAM_ERROR)\n {\n (void)deflateEnd(&strm);\n return std::string{};\n }\n } while (strm.avail_out == 0);\n assert(strm.avail_in == 0);\n assert(ret == Z_STREAM_END); /* stream will be complete */\n outstr.resize(strm.total_out);\n /* clean up and return */\n (void)deflateEnd(&strm);\n return outstr;\n }\n return std::string{};\n}\n\nstd::string gzipDecompress(const char *data, const size_t ndata)\n{\n if (ndata == 0)\n return std::string(data, ndata);\n\n auto full_length = ndata;\n\n auto decompressed = std::string(full_length * 2, 0);\n bool done = false;\n\n z_stream strm = {nullptr,\n 0,\n 0,\n nullptr,\n 0,\n 0,\n nullptr,\n nullptr,\n nullptr,\n nullptr,\n nullptr,\n 0,\n 0,\n 0};\n strm.next_in = (Bytef *)data;\n strm.avail_in = static_cast(ndata);\n strm.total_out = 0;\n strm.zalloc = Z_NULL;\n strm.zfree = Z_NULL;\n if (inflateInit2(&strm, (15 + 32)) != Z_OK)\n {\n LOG_ERROR << \"inflateInit2 error!\";\n return std::string{};\n }\n while (!done)\n {\n // Make sure we have enough room and reset the lengths.\n if (strm.total_out >= decompressed.length())\n {\n decompressed.resize(decompressed.length() * 2);\n }\n strm.next_out = (Bytef *)decompressed.data() + strm.total_out;\n strm.avail_out =\n static_cast(decompressed.length() - strm.total_out);\n // Inflate another chunk.\n int status = inflate(&strm, Z_SYNC_FLUSH);\n if (status == Z_STREAM_END)\n {\n done = true;\n }\n else if (status != Z_OK)\n {\n break;\n }\n }\n if (inflateEnd(&strm) != Z_OK)\n return std::string{};\n // Set real length.\n if (done)\n {\n decompressed.resize(strm.total_out);\n return decompressed;\n }\n else\n {\n return std::string{};\n }\n}\n\nchar *getHttpFullDate(const trantor::Date &date)\n{\n static thread_local int64_t lastSecond = 0;\n static thread_local char lastTimeString[128] = {0};\n auto nowSecond = date.microSecondsSinceEpoch() / MICRO_SECONDS_PRE_SEC;\n if (nowSecond == lastSecond)\n {\n return lastTimeString;\n }\n lastSecond = nowSecond;\n date.toCustomedFormattedString(\"%a, %d %b %Y %H:%M:%S GMT\",\n 
lastTimeString,\n sizeof(lastTimeString));\n return lastTimeString;\n}\n\ntrantor::Date getHttpDate(const std::string &httpFullDateString)\n{\n static const std::array formats = {\n // RFC822 (default)\n \"%a, %d %b %Y %H:%M:%S\",\n // RFC 850 (deprecated)\n \"%a, %d-%b-%y %H:%M:%S\",\n // ansi asctime format\n \"%a %b %d %H:%M:%S %Y\",\n // weird RFC 850-hybrid thing that reddit uses\n \"%a, %d-%b-%Y %H:%M:%S\",\n };\n struct tm tmptm;\n for (const char *format : formats)\n {\n if (strptime(httpFullDateString.c_str(), format, &tmptm) != NULL)\n {\n auto epoch = timegm(&tmptm);\n return trantor::Date(epoch * MICRO_SECONDS_PRE_SEC);\n }\n }\n LOG_WARN << \"invalid datetime format: '\" << httpFullDateString << \"'\";\n return trantor::Date((std::numeric_limits::max)());\n}\n\nstd::string formattedString(const char *format, ...)\n{\n std::string strBuffer(128, 0);\n va_list ap, backup_ap;\n va_start(ap, format);\n va_copy(backup_ap, ap);\n auto result = vsnprintf((char *)strBuffer.data(),\n strBuffer.size(),\n format,\n backup_ap);\n va_end(backup_ap);\n if ((result >= 0) && ((std::string::size_type)result < strBuffer.size()))\n {\n strBuffer.resize(result);\n }\n else\n {\n while (true)\n {\n if (result < 0)\n {\n // Older snprintf() behavior. Just try doubling the buffer size\n strBuffer.resize(strBuffer.size() * 2);\n }\n else\n {\n strBuffer.resize(result + 1);\n }\n\n va_copy(backup_ap, ap);\n auto result = vsnprintf((char *)strBuffer.data(),\n strBuffer.size(),\n format,\n backup_ap);\n va_end(backup_ap);\n\n if ((result >= 0) &&\n ((std::string::size_type)result < strBuffer.size()))\n {\n strBuffer.resize(result);\n break;\n }\n }\n }\n va_end(ap);\n return strBuffer;\n}\n\nint createPath(const std::string &path)\n{\n if (path.empty())\n return 0;\n auto osPath{toNativePath(path)};\n if (osPath.back() != std::filesystem::path::preferred_separator)\n osPath.push_back(std::filesystem::path::preferred_separator);\n std::filesystem::path fsPath(osPath);\n std::error_code err;\n std::filesystem::create_directories(fsPath, err);\n if (err)\n {\n LOG_ERROR << \"Error \" << err.value() << \" creating path \" << osPath\n << \": \" << err.message();\n return -1;\n }\n return 0;\n}\n#ifdef USE_BROTLI\nstd::string brotliCompress(const char *data, const size_t ndata)\n{\n std::string ret;\n if (ndata == 0)\n return ret;\n ret.resize(BrotliEncoderMaxCompressedSize(ndata));\n size_t encodedSize{ret.size()};\n auto r = BrotliEncoderCompress(5,\n BROTLI_DEFAULT_WINDOW,\n BROTLI_DEFAULT_MODE,\n ndata,\n (const uint8_t *)(data),\n &encodedSize,\n (uint8_t *)(ret.data()));\n if (r == BROTLI_FALSE)\n ret.resize(0);\n else\n ret.resize(encodedSize);\n return ret;\n}\n\nstd::string brotliDecompress(const char *data, const size_t ndata)\n{\n if (ndata == 0)\n return std::string(data, ndata);\n\n size_t availableIn = ndata;\n auto nextIn = (const uint8_t *)(data);\n auto decompressed = std::string(availableIn * 3, 0);\n size_t availableOut = decompressed.size();\n auto nextOut = (uint8_t *)(decompressed.data());\n size_t totalOut{0};\n bool done = false;\n auto s = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);\n while (!done)\n {\n auto result = BrotliDecoderDecompressStream(\n s, &availableIn, &nextIn, &availableOut, &nextOut, &totalOut);\n if (result == BROTLI_DECODER_RESULT_SUCCESS)\n {\n decompressed.resize(totalOut);\n done = true;\n }\n else if (result == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT)\n {\n assert(totalOut == decompressed.size());\n decompressed.resize(totalOut * 2);\n nextOut = 
(uint8_t *)(decompressed.data() + totalOut);\n availableOut = totalOut;\n }\n else\n {\n decompressed.resize(0);\n done = true;\n }\n }\n BrotliDecoderDestroyInstance(s);\n return decompressed;\n}\n#else\nstd::string brotliCompress(const char * /*data*/, const size_t /*ndata*/)\n{\n LOG_ERROR << \"If you do not have the brotli package installed, you cannot \"\n \"use brotliCompress()\";\n abort();\n}\n\nstd::string brotliDecompress(const char * /*data*/, const size_t /*ndata*/)\n{\n LOG_ERROR << \"If you do not have the brotli package installed, you cannot \"\n \"use brotliDecompress()\";\n abort();\n}\n#endif\n\nstd::string getMd5(const char *data, const size_t dataLen)\n{\n return trantor::utils::toHexString(trantor::utils::md5(data, dataLen));\n}\n\nstd::string getSha1(const char *data, const size_t dataLen)\n{\n return trantor::utils::toHexString(trantor::utils::sha1(data, dataLen));\n}\n\nstd::string getSha256(const char *data, const size_t dataLen)\n{\n return trantor::utils::toHexString(trantor::utils::sha256(data, dataLen));\n}\n\nstd::string getSha3(const char *data, const size_t dataLen)\n{\n return trantor::utils::toHexString(trantor::utils::sha3(data, dataLen));\n}\n\nstd::string getBlake2b(const char *data, const size_t dataLen)\n{\n return trantor::utils::toHexString(trantor::utils::blake2b(data, dataLen));\n}\n\nvoid replaceAll(std::string &s, const std::string &from, const std::string &to)\n{\n size_t pos = 0;\n while ((pos = s.find(from, pos)) != std::string::npos)\n {\n s.replace(pos, from.size(), to);\n pos += to.size();\n }\n}\n\nbool supportsTls() noexcept\n{\n return trantor::utils::tlsBackend() != \"None\";\n}\n\nbool secureRandomBytes(void *ptr, size_t size)\n{\n return trantor::utils::secureRandomBytes(ptr, size);\n}\n\nstd::string secureRandomString(size_t size)\n{\n if (size == 0)\n return std::string();\n\n std::string ret(size, 0);\n const std::string_view chars =\n \"0123456789\"\n \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n \"abcdefghijklmnopqrstuvwxyz\"\n \"+-\";\n assert(chars.size() == 64);\n\n // batch up to 32 bytes of random data for efficiency. Calling\n // secureRandomBytes can be expensive.\n auto randByte = []() {\n thread_local trantor::utils::Hash256 hash;\n thread_local size_t i = 0;\n if (i == 0)\n {\n bool ok = trantor::utils::secureRandomBytes(&hash, sizeof(hash));\n if (!ok)\n throw std::runtime_error(\n \"Failed to generate random bytes for secureRandomString\");\n }\n unsigned char *hashBytes = reinterpret_cast(&hash);\n auto ret = hashBytes[i];\n i = (i + 1) % sizeof(hash);\n return ret;\n };\n\n for (size_t i = 0; i < size; ++i)\n ret[i] = chars[randByte() % 64];\n return ret;\n}\n\nnamespace internal\n{\nconst size_t fixedRandomNumber = []() {\n size_t res;\n utils::secureRandomBytes(&res, sizeof(res));\n return res;\n}();\n}\n\n} // namespace utils\n} // namespace drogon\n\n// Path: lib/src/SpinLock.h\n/**\n * SpinLock.h\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n#include \n...\n", "needle_token_start": 7115, "needle_token_end": 7355, "code_context_ntokens": 16163, "output": ["```cpp\nstd::set splitStringToSet(const std::string &str, const std::string &separator)\n{\n std::set ret;\n std::string::size_type pos1, pos2;\n pos2 = 0;\n pos1 = str.find(separator);\n while (pos1 != std::string::npos)\n {\n if (pos1 != 0)\n {\n std::string item = str.substr(pos2, pos1 - pos2);\n ret.insert(item);\n }\n pos2 = pos1 + separator.length();\n while (pos2 < str.length() &&\n str.substr(pos2, separator.length()) == separator)\n pos2 += separator.length();\n pos1 = str.find(separator, pos2);\n }\n if (pos2 < str.length())\n ret.insert(str.substr(pos2));\n return ret;\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "HttpMessageStringViewBody", "language": "cpp", "path": "lib/src/HttpMessageBody.h", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To manage a view of a constant segment of data, representing the body of an HTTP message without owning the data itself.\n2. **Input**: A pointer to the beginning of the data segment and the length of the segment.\n3. **Output**: Provides access to the data segment and its properties such as length, but does not modify the original data.\n4. **Procedure**: Initializes with a data pointer and length, setting the internal type to indicate it handles a view of data. Provides methods to access the data and its length, but does not support modifying the data.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/HttpAppFrameworkImpl.h\n/**\n *\n * @file HttpAppFrameworkImpl.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"SessionManager.h\"\n#include \"drogon/utils/Utilities.h\"\n#include \"impl_forwards.h\"\n\nnamespace trantor\n{\nclass EventLoopThreadPool;\n}\n\nnamespace drogon\n{\nHttpResponsePtr defaultErrorHandler(HttpStatusCode code,\n const HttpRequestPtr &req);\nvoid defaultExceptionHandler(const std::exception &,\n const HttpRequestPtr &,\n std::function &&);\n\nstruct InitBeforeMainFunction\n{\n explicit InitBeforeMainFunction(const std::function &func)\n {\n func();\n }\n};\n\nclass HttpAppFrameworkImpl final : public HttpAppFramework\n{\n public:\n HttpAppFrameworkImpl();\n\n const Json::Value &getCustomConfig() const override\n {\n return jsonConfig_[\"custom_config\"];\n }\n\n PluginBase *getPlugin(const std::string &name) override;\n std::shared_ptr getSharedPlugin(\n const std::string &name) override;\n void addPlugins(const Json::Value &configs) override;\n void addPlugin(const std::string &name,\n const std::vector &dependencies,\n const Json::Value &config) override;\n HttpAppFramework &addListener(\n const std::string &ip,\n uint16_t port,\n bool useSSL,\n const std::string &certFile,\n const std::string &keyFile,\n bool useOldTLS,\n const std::vector> &sslConfCmds)\n override;\n HttpAppFramework &setThreadNum(size_t threadNum) override;\n\n size_t getThreadNum() const override\n {\n return threadNum_;\n }\n\n HttpAppFramework &setSSLConfigCommands(\n const std::vector> &sslConfCmds)\n override;\n HttpAppFramework &setSSLFiles(const std::string &certPath,\n const std::string &keyPath) override;\n void run() override;\n HttpAppFramework ®isterWebSocketController(\n const std::string &pathName,\n const std::string &ctrlName,\n const std::vector &filtersAndMethods)\n override;\n HttpAppFramework ®isterHttpSimpleController(\n const std::string &pathName,\n const std::string &ctrlName,\n const std::vector &filtersAndMethods)\n override;\n\n HttpAppFramework &setCustom404Page(const HttpResponsePtr &resp,\n bool set404) override\n {\n if (set404)\n {\n resp->setStatusCode(k404NotFound);\n }\n custom404_ = resp;\n return *this;\n }\n\n HttpAppFramework &setCustomErrorHandler(\n std::function\n &&resp_generator) override;\n\n const HttpResponsePtr &getCustom404Page();\n\n void forward(const HttpRequestPtr &req,\n std::function &&callback,\n const std::string &hostString,\n double timeout) override;\n\n void forward(const HttpRequestImplPtr &req,\n std::function &&callback,\n const std::string &hostString,\n double timeout = 0);\n\n HttpAppFramework ®isterBeginningAdvice(\n const std::function &advice) override\n {\n beginningAdvices_.emplace_back(advice);\n return *this;\n }\n\n HttpAppFramework ®isterNewConnectionAdvice(\n const std::function &advice)\n override;\n\n HttpAppFramework ®isterHttpResponseCreationAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterSyncAdvice(\n const std::function &advice)\n override;\n\n HttpAppFramework ®isterPreRoutingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterPostRoutingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterPreHandlingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterPreRoutingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework 
®isterPostRoutingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterPreHandlingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterPostHandlingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework ®isterPreSendingAdvice(\n const std::function &advice) override;\n\n HttpAppFramework &setDefaultHandler(DefaultHandler handler) override;\n\n HttpAppFramework &setupFileLogger() override;\n\n HttpAppFramework &enableSession(\n const size_t timeout,\n Cookie::SameSite sameSite = Cookie::SameSite::kNull,\n const std::string &cookieKey = \"JSESSIONID\",\n int maxAge = -1,\n SessionManager::IdGeneratorCallback idGeneratorCallback =\n nullptr) override\n {\n useSession_ = true;\n sessionTimeout_ = timeout;\n sessionSameSite_ = sameSite;\n sessionCookieKey_ = cookieKey;\n sessionMaxAge_ = maxAge;\n return setSessionIdGenerator(idGeneratorCallback);\n }\n\n HttpAppFramework &setSessionIdGenerator(\n SessionManager::IdGeneratorCallback idGeneratorCallback = nullptr)\n {\n sessionIdGeneratorCallback_ =\n idGeneratorCallback ? idGeneratorCallback\n : []() { return utils::getUuid(true); };\n return *this;\n }\n\n HttpAppFramework &disableSession() override\n {\n useSession_ = false;\n return *this;\n }\n\n HttpAppFramework ®isterSessionStartAdvice(\n const AdviceStartSessionCallback &advice) override\n {\n sessionStartAdvices_.emplace_back(advice);\n return *this;\n }\n\n HttpAppFramework ®isterSessionDestroyAdvice(\n const AdviceDestroySessionCallback &advice) override\n {\n sessionDestroyAdvices_.emplace_back(advice);\n return *this;\n }\n\n const std::string &getDocumentRoot() const override\n {\n return rootPath_;\n }\n\n HttpAppFramework &setDocumentRoot(const std::string &rootPath) override\n {\n rootPath_ = rootPath;\n return *this;\n }\n\n HttpAppFramework &setStaticFileHeaders(\n const std::vector> &headers)\n override;\n\n HttpAppFramework &addALocation(\n const std::string &uriPrefix,\n const std::string &defaultContentType,\n const std::string &alias,\n bool isCaseSensitive,\n bool allowAll,\n bool isRecursive,\n const std::vector &filters) override;\n\n const std::string &getUploadPath() const override\n {\n return uploadPath_;\n }\n\n const std::shared_ptr &getResolver() const override\n {\n static auto resolver = trantor::Resolver::newResolver(getLoop());\n return resolver;\n }\n\n HttpAppFramework &setUploadPath(const std::string &uploadPath) override;\n HttpAppFramework &setFileTypes(\n const std::vector &types) override;\n#ifndef _WIN32\n HttpAppFramework &enableDynamicViewsLoading(\n const std::vector &libPaths,\n const std::string &outputPath) override;\n#endif\n HttpAppFramework &setMaxConnectionNum(size_t maxConnections) override;\n HttpAppFramework &setMaxConnectionNumPerIP(\n size_t maxConnectionsPerIP) override;\n HttpAppFramework &loadConfigFile(const std::string &fileName) noexcept(\n false) override;\n HttpAppFramework &loadConfigJson(const Json::Value &data) noexcept(\n false) override;\n HttpAppFramework &loadConfigJson(Json::Value &&data) noexcept(\n false) override;\n\n HttpAppFramework &enableRunAsDaemon() override\n {\n runAsDaemon_ = true;\n return *this;\n...\n// Path: lib/src/HttpUtils.h\n/**\n *\n * @file HttpUtils.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nconst std::string_view &contentTypeToMime(ContentType contentType);\nconst std::string_view &statusCodeToString(int code);\nContentType getContentType(const std::string &fileName);\nContentType parseContentType(const std::string_view &contentType);\nFileType parseFileType(const std::string_view &fileExtension);\nFileType getFileType(ContentType contentType);\nvoid registerCustomExtensionMime(const std::string &ext,\n const std::string &mime);\nconst std::string_view fileNameToMime(const std::string &fileName);\nstd::pair fileNameToContentTypeAndMime(\n const std::string &filename);\n\ninline std::string_view getFileExtension(const std::string &fileName)\n{\n auto pos = fileName.rfind('.');\n if (pos == std::string::npos)\n return \"\";\n return std::string_view(&fileName[pos + 1], fileName.length() - pos - 1);\n}\n\nconst std::vector &getFileExtensions(ContentType contentType);\n\ninline const std::vector &getFileExtensions(\n const std::string_view &contentType)\n{\n return getFileExtensions(parseContentType(contentType));\n}\n\ntemplate \ninline constexpr const char *contentLengthFormatString()\n{\n return \"content-length: %d\\r\\n\";\n}\n\ntemplate <>\ninline constexpr const char *contentLengthFormatString()\n{\n return \"content-length: %u\\r\\n\";\n}\n\ntemplate <>\ninline constexpr const char *contentLengthFormatString()\n{\n return \"content-length: %ld\\r\\n\";\n}\n\ntemplate <>\ninline constexpr const char *contentLengthFormatString()\n{\n return \"content-length: %lu\\r\\n\";\n}\n\ntemplate <>\ninline constexpr const char *contentLengthFormatString()\n{\n return \"content-length: %lld\\r\\n\";\n}\n\ntemplate <>\ninline constexpr const char *contentLengthFormatString()\n{\n return \"content-length: %llu\\r\\n\";\n}\n} // namespace drogon\n\n// Path: lib/src/CacheFile.h\n/**\n *\n * CacheFile.h\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass CacheFile : public trantor::NonCopyable\n{\n public:\n explicit CacheFile(const std::string &path, bool autoDelete = true);\n ~CacheFile();\n\n void append(const std::string &data)\n {\n append(data.data(), data.length());\n }\n\n void append(const char *data, size_t length);\n\n std::string_view getStringView()\n {\n if (data())\n return std::string_view(data_, dataLength_);\n return std::string_view();\n }\n\n private:\n char *data();\n size_t length();\n FILE *file_{nullptr};\n bool autoDelete_{true};\n const std::string path_;\n char *data_{nullptr};\n size_t dataLength_{0};\n};\n} // namespace drogon\n\n// Path: lib/src/HttpRequestImpl.h\n/**\n *\n * @file HttpRequestImpl.h\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \"HttpUtils.h\"\n#include \"CacheFile.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nenum class StreamDecompressStatus\n{\n TooLarge,\n DecompressError,\n NotSupported,\n Ok\n};\n\nclass HttpRequestImpl : public HttpRequest\n{\n public:\n friend class HttpRequestParser;\n\n explicit HttpRequestImpl(trantor::EventLoop *loop)\n : creationDate_(trantor::Date::now()), loop_(loop)\n {\n }\n\n void reset()\n {\n method_ = Invalid;\n previousMethod_ = Invalid;\n version_ = Version::kUnknown;\n flagForParsingJson_ = false;\n headers_.clear();\n cookies_.clear();\n flagForParsingParameters_ = false;\n path_.clear();\n originalPath_.clear();\n pathEncode_ = true;\n matchedPathPattern_ = \"\";\n query_.clear();\n parameters_.clear();\n jsonPtr_.reset();\n sessionPtr_.reset();\n attributesPtr_.reset();\n cacheFilePtr_.reset();\n expectPtr_.reset();\n content_.clear();\n contentType_ = CT_TEXT_PLAIN;\n flagForParsingContentType_ = false;\n contentTypeString_.clear();\n keepAlive_ = true;\n jsonParsingErrorPtr_.reset();\n peerCertificate_.reset();\n routingParams_.clear();\n }\n\n trantor::EventLoop *getLoop()\n {\n return loop_;\n }\n\n void setVersion(Version v)\n {\n version_ = v;\n if (version_ == Version::kHttp10)\n {\n keepAlive_ = false;\n }\n }\n\n Version version() const override\n {\n return version_;\n }\n\n const char *versionString() const override;\n\n bool setMethod(const char *start, const char *end);\n\n void setSecure(bool secure)\n {\n isOnSecureConnection_ = secure;\n }\n\n void setMethod(const HttpMethod method) override\n {\n previousMethod_ = method_;\n method_ = method;\n return;\n }\n\n HttpMethod method() const override\n {\n return method_;\n }\n\n bool isHead() const override\n {\n return (method_ == HttpMethod::Head) ||\n ((method_ == HttpMethod::Get) &&\n (previousMethod_ == HttpMethod::Head));\n }\n\n const char *methodString() const override;\n\n void setPath(const char *start, const char *end)\n {\n if (utils::needUrlDecoding(start, end))\n {\n originalPath_.append(start, end);\n path_ = utils::urlDecode(start, end);\n }\n else\n {\n path_.append(start, end);\n }\n }\n\n const std::vector &getRoutingParameters() const override\n {\n return routingParams_;\n }\n\n void setRoutingParameters(std::vector &¶ms) override\n {\n routingParams_ = std::move(params);\n }\n\n void setPath(const std::string &path) override\n {\n path_ = path;\n }\n\n void setPath(std::string &&path) override\n {\n path_ = std::move(path);\n }\n\n void setPathEncode(bool pathEncode) override\n {\n pathEncode_ = pathEncode;\n }\n\n const SafeStringMap ¶meters() const override\n {\n parseParametersOnce();\n return parameters_;\n }\n\n const std::string &getParameter(const std::string &key) const override\n {\n static const std::string defaultVal;\n parseParametersOnce();\n auto iter = parameters_.find(key);\n if (iter != parameters_.end())\n return iter->second;\n return defaultVal;\n }\n\n const std::string &path() const override\n {\n return path_;\n }\n\n const std::string &getOriginalPath() const override\n {\n return originalPath_.empty() ? 
path_ : originalPath_;\n }\n\n void setQuery(const char *start, const char *end)\n {\n query_.assign(start, end);\n }\n\n void setQuery(const std::string &query)\n {\n query_ = query;\n }\n\n std::string_view bodyView() const\n {\n if (cacheFilePtr_)\n {\n return cacheFilePtr_->getStringView();\n }\n return content_;\n }\n\n const char *bodyData() const override\n {\n if (cacheFilePtr_)\n {\n return cacheFilePtr_->getStringView().data();\n }\n return content_.data();\n }\n\n size_t bodyLength() const override\n {\n if (cacheFilePtr_)\n {\n return cacheFilePtr_->getStringView().length();\n }\n return content_.length();\n }\n\n void appendToBody(const char *data, size_t length);\n\n void reserveBodySize(size_t length);\n\n std::string_view queryView() const\n {\n return query_;\n }\n\n std::string_view contentView() const\n {\n if (cacheFilePtr_)\n return cacheFilePtr_->getStringView();\n return content_;\n }\n\n const std::string &query() const override\n {\n return query_;\n }\n\n const trantor::InetAddress &peerAddr() const override\n {\n return peer_;\n }\n\n const trantor::InetAddress &localAddr() const override\n {\n return local_;\n }\n\n const trantor::Date &creationDate() const override\n {\n return creationDate_;\n }\n\n const trantor::CertificatePtr &peerCertificate() const override\n {\n return peerCertificate_;\n }\n\n void setCreationDate(const trantor::Date &date)\n {\n creationDate_ = date;\n }\n\n void setPeerAddr(const trantor::InetAddress &peer)\n {\n peer_ = peer;\n }\n\n void setLocalAddr(const trantor::InetAddress &local)\n {\n local_ = local;\n }\n\n void setPeerCertificate(const trantor::CertificatePtr &cert)\n {\n peerCertificate_ = cert;\n }\n\n void addHeader(const char *start, const char *colon, const char *end);\n\n void removeHeader(std::string key) override\n {\n transform(key.begin(), key.end(), key.begin(), [](unsigned char c) {\n return tolower(c);\n });\n removeHeaderBy(key);\n }\n\n void removeHeaderBy(const std::string &lowerKey)\n {\n headers_.erase(lowerKey);\n }\n\n const std::string &getHeader(std::string field) const override\n {\n std::transform(field.begin(),\n field.end(),\n field.begin(),\n [](unsigned char c) { return tolower(c); });\n return getHeaderBy(field);\n }\n\n const std::string &getHeaderBy(const std::string &lowerField) const\n {\n static const std::string defaultVal;\n auto it = headers_.find(lowerField);\n if (it != headers_.end())\n {\n return it->second;\n }\n return defaultVal;\n }\n\n const std::string &getCookie(const std::string &field) const override\n {\n static const std::string defaultVal;\n auto it = cookies_.find(field);\n if (it != cookies_.end())\n {\n return it->second;\n }\n return defaultVal;\n }\n\n const SafeStringMap &headers() const override\n {\n return headers_;\n }\n\n const SafeStringMap &cookies() const override\n {\n return cookies_;\n }\n\n void setParameter(const std::string &key, const std::string &value) override\n {\n flagForParsingParameters_ = true;\n parameters_[key] = value;\n }\n\n const std::string &getContent() const\n {\n return content_;\n }\n\n void swap(HttpRequestImpl &that) noexcept;\n\n void setContent(const std::string &content)\n {\n content_ = content;\n }\n\n void setBody(const std::string &body) override\n {\n content_ = body;\n }\n\n void setBody(std::string &&body) override\n {\n content_ = std::move(body);\n }\n\n void addHeader(std::string field, const std::string &value) override\n {\n transform(field.begin(),\n field.end(),\n field.begin(),\n [](unsigned char c) { return 
tolower(c); });\n headers_[std::move(field)] = value;\n }\n\n void addHeader(std::string field, std::string &&value) override\n {\n transform(field.begin(),\n field.end(),\n field.begin(),\n [](unsigned char c) { return tolower(c); });\n headers_[std::move(field)] = std::move(value);\n }\n\n void addCookie(const std::string &key, const std::string &value) override\n {\n cookies_[key] = value;\n }\n\n void setPassThrough(bool flag) override\n {\n passThrough_ = flag;\n }\n\n bool passThrough() const\n {\n return passThrough_;\n }\n\n void appendToBuffer(trantor::MsgBuffer *output) const;\n\n const SessionPtr &session() const override\n {\n return sessionPtr_;\n }\n\n void setSession(const SessionPtr &session)\n {\n sessionPtr_ = session;\n }\n\n const AttributesPtr &attributes() const override\n {\n if (!attributesPtr_)\n {\n attributesPtr_ = std::make_shared();\n }\n return attributesPtr_;\n }\n\n const std::shared_ptr &jsonObject() const override\n {\n // Not multi-thread safe but good, because we basically call this\n // function in a single thread\n if (!flagForParsingJson_)\n {\n flagForParsingJson_ = true;\n parseJson();\n }\n return jsonPtr_;\n }\n\n void setCustomContentTypeString(const std::string &type) override\n {\n contentType_ = CT_NONE;\n flagForParsingContentType_ = true;\n bool haveHeader = type.find(\"content-type: \") == 0;\n bool haveCRLF = type.rfind(\"\\r\\n\") == type.size() - 2;\n\n size_t endOffset = 0;\n if (haveHeader)\n endOffset += 14;\n if (haveCRLF)\n endOffset += 2;\n contentTypeString_ = std::string(type.begin() + (haveHeader ? 14 : 0),\n type.end() - endOffset);\n }\n\n void setContentTypeCode(const ContentType type) override\n {\n contentType_ = type;\n flagForParsingContentType_ = true;\n auto &typeStr = contentTypeToMime(type);\n setContentType(std::string(typeStr.data(), typeStr.length()));\n }\n\n void setContentTypeString(const char *typeString,\n size_t typeStringLength) override;\n\n // void setContentTypeCodeAndCharacterSet(ContentType type, const\n // std::string &charSet = \"utf-8\") override\n // {\n // contentType_ = type;\n // setContentType(webContentTypeAndCharsetToString(type, charSet));\n // }\n\n ContentType contentType() const override\n {\n parseContentTypeAndString();\n return contentType_;\n }\n\n const char *matchedPathPatternData() const override\n {\n return matchedPathPattern_.data();\n }\n\n size_t matchedPathPatternLength() const override\n {\n return matchedPathPattern_.length();\n }\n\n void setMatchedPathPattern(const std::string &pathPattern)\n {\n matchedPathPattern_ = pathPattern;\n }\n\n const std::string &expect() const\n {\n static const std::string none{\"\"};\n if (expectPtr_)\n return *expectPtr_;\n return none;\n }\n\n bool keepAlive() const\n {\n return keepAlive_;\n }\n\n bool isOnSecureConnection() const noexcept override\n {\n return isOnSecureConnection_;\n }\n\n const std::string &getJsonError() const override\n {\n static const std::string none{\"\"};\n if (jsonParsingErrorPtr_)\n return *jsonParsingErrorPtr_;\n return none;\n }\n\n StreamDecompressStatus decompressBody();\n\n ~HttpRequestImpl();\n\n protected:\n friend class HttpRequest;\n\n void setContentType(const std::string &contentType)\n {\n contentTypeString_ = contentType;\n }\n\n void setContentType(std::string &&contentType)\n {\n contentTypeString_ = std::move(contentType);\n }\n\n void parseContentTypeAndString() const\n {\n if (!flagForParsingContentType_)\n {\n flagForParsingContentType_ = true;\n auto &contentTypeString = 
getHeaderBy(\"content-type\");\n if (contentTypeString == \"\")\n {\n contentType_ = CT_NONE;\n }\n else\n {\n auto pos = contentTypeString.find(';');\n if (pos != std::string::npos)\n {\n contentType_ = parseContentType(\n std::string_view(contentTypeString.data(), pos));\n }\n else\n {\n contentType_ =\n parseContentType(std::string_view(contentTypeString));\n }\n\n if (contentType_ == CT_NONE)\n contentType_ = CT_CUSTOM;\n contentTypeString_ = contentTypeString;\n }\n }\n }\n\n private:\n void parseParameters() const;\n\n void parseParametersOnce() const\n {\n // Not multi-thread safe but good, because we basically call this\n // function in a single thread\n if (!flagForParsingParameters_)\n {\n flagForParsingParameters_ = true;\n parseParameters();\n }\n }\n\n void createTmpFile();\n void parseJson() const;\n#ifdef USE_BROTLI\n StreamDecompressStatus decompressBodyBrotli() noexcept;\n#endif\n StreamDecompressStatus decompressBodyGzip() noexcept;\n mutable bool flagForParsingParameters_{false};\n mutable bool flagForParsingJson_{false};\n HttpMethod method_{Invalid};\n HttpMethod previousMethod_{Invalid};\n Version version_{Version::kUnknown};\n std::string path_;\n std::string originalPath_;\n bool pathEncode_{true};\n std::string_view matchedPathPattern_{\"\"};\n std::string query_;\n SafeStringMap headers_;\n SafeStringMap cookies_;\n mutable SafeStringMap parameters_;\n mutable std::shared_ptr jsonPtr_;\n SessionPtr sessionPtr_;\n mutable AttributesPtr attributesPtr_;\n trantor::InetAddress peer_;\n trantor::InetAddress local_;\n trantor::Date creationDate_;\n trantor::CertificatePtr peerCertificate_;\n std::unique_ptr cacheFilePtr_;\n mutable std::unique_ptr jsonParsingErrorPtr_;\n std::unique_ptr expectPtr_;\n bool keepAlive_{true};\n bool isOnSecureConnection_{false};\n bool passThrough_{false};\n std::vector routingParams_;\n\n protected:\n std::string content_;\n trantor::EventLoop *loop_;\n mutable ContentType contentType_{CT_TEXT_PLAIN};\n mutable bool flagForParsingContentType_{false};\n mutable std::string contentTypeString_;\n};\n\nusing HttpRequestImplPtr = std::shared_ptr;\n\ninline void swap(HttpRequestImpl &one, HttpRequestImpl &two) noexcept\n{\n one.swap(two);\n}\n\n} // namespace drogon\n\n// Path: lib/src/HttpMessageBody.h\n/**\n *\n * HttpMessageBody.h\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n#include \n#include \n#include \n\nnamespace drogon\n{\nclass HttpMessageBody\n{\n public:\n enum class BodyType\n {\n kNone = 0,\n kString,\n kStringView\n };\n\n BodyType bodyType()\n {\n return type_;\n }\n\n virtual const char *data() const\n {\n return nullptr;\n }\n\n virtual char *data()\n {\n return nullptr;\n }\n\n virtual size_t length() const\n {\n return 0;\n }\n\n virtual std::string_view getString() const = 0;\n\n virtual void append(const char * /*buf*/, size_t /*len*/)\n {\n }\n\n virtual ~HttpMessageBody()\n {\n }\n\n protected:\n BodyType type_{BodyType::kNone};\n};\n\nclass HttpMessageStringBody : public HttpMessageBody\n{\n public:\n HttpMessageStringBody()\n {\n type_ = BodyType::kString;\n }\n\n HttpMessageStringBody(const std::string &body) : body_(body)\n {\n type_ = BodyType::kString;\n }\n\n HttpMessageStringBody(std::string &&body) : body_(std::move(body))\n {\n type_ = BodyType::kString;\n }\n\n const char *data() const override\n {\n return body_.data();\n }\n\n char *data() override\n {\n return const_cast(body_.data());\n }\n\n size_t length() const override\n {\n return body_.length();\n }\n\n std::string_view getString() const override\n {\n return std::string_view{body_.data(), body_.length()};\n }\n\n void append(const char *buf, size_t len) override\n {\n body_.append(buf, len);\n }\n\n private:\n std::string body_;\n};\n\nclass HttpMessageStringViewBody : public HttpMessageBody\n{\n public:\n \nHttpMessageStringViewBody(const char *buf, size_t len) : body_(buf, len)\n {\n type_ = BodyType::kStringView;\n }\n\n const char *data() const override\n {\n return body_.data();\n }\n\n char *data() override\n {\n return const_cast(body_.data());\n }\n\n size_t length() const override\n {\n return body_.length();\n }\n\n std::string_view getString() const override\n {\n return body_;\n }\n\n private:\n std::string_view body_;\n};\n\n} // namespace drogon\n\n// Path: lib/src/HttpResponseImpl.h\n/**\n *\n * @file HttpResponseImpl.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \"HttpUtils.h\"\n#include \"HttpMessageBody.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass DROGON_EXPORT HttpResponseImpl : public HttpResponse\n{\n friend class HttpResponseParser;\n\n public:\n HttpResponseImpl() : creationDate_(trantor::Date::now())\n {\n }\n\n HttpResponseImpl(HttpStatusCode code, ContentType type)\n : statusCode_(code),\n statusMessage_(statusCodeToString(code)),\n creationDate_(trantor::Date::now()),\n contentType_(type),\n flagForParsingContentType_(true),\n contentTypeString_(contentTypeToMime(type))\n {\n }\n\n void setPassThrough(bool flag) override\n {\n passThrough_ = flag;\n }\n\n HttpStatusCode statusCode() const override\n {\n return statusCode_;\n }\n\n const trantor::Date &creationDate() const override\n {\n return creationDate_;\n }\n\n void setStatusCode(HttpStatusCode code) override\n {\n statusCode_ = code;\n setStatusMessage(statusCodeToString(code));\n }\n\n void setVersion(const Version v) override\n {\n version_ = v;\n if (version_ == Version::kHttp10)\n {\n closeConnection_ = true;\n }\n }\n\n Version version() const override\n {\n return version_;\n }\n\n const char *versionString() const override;\n\n void setCloseConnection(bool on) override\n {\n closeConnection_ = on;\n }\n\n bool ifCloseConnection() const override\n {\n return closeConnection_;\n }\n\n void setContentTypeCode(ContentType type) override\n {\n contentType_ = type;\n auto ct = contentTypeToMime(type);\n contentTypeString_ = std::string(ct.data(), ct.size());\n flagForParsingContentType_ = true;\n }\n\n // void setContentTypeCodeAndCharacterSet(ContentType type, const\n // std::string &charSet = \"utf-8\") override\n // {\n // contentType_ = type;\n // setContentType(webContentTypeAndCharsetToString(type, charSet));\n // }\n\n ContentType contentType() const override\n {\n parseContentTypeAndString();\n return contentType_;\n }\n\n const std::string &getHeader(std::string key) const override\n {\n transform(key.begin(), key.end(), key.begin(), [](unsigned char c) {\n return tolower(c);\n });\n return getHeaderBy(key);\n }\n\n void removeHeader(std::string key) override\n {\n transform(key.begin(), key.end(), key.begin(), [](unsigned char c) {\n return tolower(c);\n });\n removeHeaderBy(key);\n }\n\n const SafeStringMap &headers() const override\n {\n return headers_;\n }\n\n const std::string &getHeaderBy(const std::string &lowerKey) const\n {\n static const std::string defaultVal;\n auto iter = headers_.find(lowerKey);\n if (iter == headers_.end())\n {\n return defaultVal;\n }\n return iter->second;\n }\n\n void removeHeaderBy(const std::string &lowerKey)\n {\n fullHeaderString_.reset();\n headers_.erase(lowerKey);\n }\n\n void addHeader(std::string field, const std::string &value) override\n {\n fullHeaderString_.reset();\n transform(field.begin(),\n field.end(),\n field.begin(),\n [](unsigned char c) { return tolower(c); });\n headers_[std::move(field)] = value;\n }\n\n void addHeader(std::string field, std::string &&value) override\n {\n fullHeaderString_.reset();\n transform(field.begin(),\n field.end(),\n field.begin(),\n [](unsigned char c) { return tolower(c); });\n headers_[std::move(field)] = std::move(value);\n }\n\n void addHeader(const char *start, const 
char *colon, const char *end);\n\n void addCookie(const std::string &key, const std::string &value) override\n {\n cookies_[key] = Cookie(key, value);\n }\n\n void addCookie(const Cookie &cookie) override\n {\n cookies_[cookie.key()] = cookie;\n }\n\n void addCookie(Cookie &&cookie) override\n {\n cookies_[cookie.key()] = std::move(cookie);\n }\n\n const Cookie &getCookie(const std::string &key) const override\n {\n static const Cookie defaultCookie;\n auto it = cookies_.find(key);\n if (it != cookies_.end())\n {\n return it->second;\n }\n return defaultCookie;\n }\n\n const SafeStringMap &cookies() const override\n {\n return cookies_;\n }\n\n void removeCookie(const std::string &key) override\n {\n cookies_.erase(key);\n }\n\n void setBody(const std::string &body) override\n {\n bodyPtr_ = std::make_shared(body);\n if (passThrough_)\n {\n addHeader(\"content-length\", std::to_string(bodyPtr_->length()));\n }\n }\n\n void setBody(std::string &&body) override\n {\n bodyPtr_ = std::make_shared(std::move(body));\n if (passThrough_)\n {\n addHeader(\"content-length\", std::to_string(bodyPtr_->length()));\n }\n }\n\n void redirect(const std::string &url)\n {\n headers_[\"location\"] = url;\n }\n\n std::shared_ptr renderToBuffer();\n void renderToBuffer(trantor::MsgBuffer &buffer);\n std::shared_ptr renderHeaderForHeadMethod();\n void clear() override;\n\n void setExpiredTime(ssize_t expiredTime) override\n {\n expriedTime_ = expiredTime;\n datePos_ = std::string::npos;\n if (expriedTime_ < 0 && version_ == Version::kHttp10)\n {\n fullHeaderString_.reset();\n }\n }\n\n ssize_t expiredTime() const override\n {\n return expriedTime_;\n }\n\n const char *getBodyData() const override\n {\n if (!flagForSerializingJson_ && jsonPtr_)\n {\n generateBodyFromJson();\n }\n else if (!bodyPtr_)\n {\n return nullptr;\n }\n return bodyPtr_->data();\n }\n\n size_t getBodyLength() const override\n {\n if (bodyPtr_)\n return bodyPtr_->length();\n return 0;\n }\n\n void swap(HttpResponseImpl &that) noexcept;\n void parseJson() const;\n\n const std::shared_ptr &jsonObject() const override\n {\n // Not multi-thread safe but good, because we basically call this\n // function in a single thread\n if (!flagForParsingJson_)\n {\n flagForParsingJson_ = true;\n parseJson();\n }\n return jsonPtr_;\n }\n\n const std::string &getJsonError() const override\n {\n static const std::string none;\n if (jsonParsingErrorPtr_)\n return *jsonParsingErrorPtr_;\n return none;\n }\n\n void setJsonObject(const Json::Value &pJson)\n {\n flagForParsingJson_ = true;\n flagForSerializingJson_ = false;\n jsonPtr_ = std::make_shared(pJson);\n }\n\n void setJsonObject(Json::Value &&pJson)\n {\n flagForParsingJson_ = true;\n flagForSerializingJson_ = false;\n jsonPtr_ = std::make_shared(std::move(pJson));\n }\n\n bool shouldBeCompressed() const;\n void generateBodyFromJson() const;\n\n const std::string &sendfileName() const override\n {\n return sendfileName_;\n }\n\n const SendfileRange &sendfileRange() const override\n {\n return sendfileRange_;\n }\n\n const trantor::CertificatePtr &peerCertificate() const override\n {\n return peerCertificate_;\n }\n\n void setPeerCertificate(const trantor::CertificatePtr &cert)\n {\n peerCertificate_ = cert;\n }\n\n void setSendfile(const std::string &filename)\n {\n sendfileName_ = filename;\n }\n\n void setSendfileRange(size_t offset, size_t len)\n {\n sendfileRange_.first = offset;\n sendfileRange_.second = len;\n }\n\n const std::function &streamCallback()\n const override\n {\n return 
streamCallback_;\n }\n\n void setStreamCallback(\n const std::function &callback)\n {\n streamCallback_ = callback;\n }\n\n const std::function &asyncStreamCallback()\n const override\n {\n return asyncStreamCallback_;\n }\n\n void setAsyncStreamCallback(\n const std::function &callback,\n bool disableKickoffTimeout)\n {\n asyncStreamCallback_ = callback;\n asyncStreamDisableKickoff_ = disableKickoffTimeout;\n }\n\n bool asyncStreamKickoffDisabled() const\n {\n return asyncStreamDisableKickoff_;\n }\n\n void makeHeaderString()\n {\n fullHeaderString_ = std::make_shared(128);\n makeHeaderString(*fullHeaderString_);\n }\n\n std::string contentTypeString() const override\n {\n parseContentTypeAndString();\n return contentTypeString_;\n }\n\n void gunzip()\n {\n if (bodyPtr_)\n {\n auto gunzipBody =\n utils::gzipDecompress(bodyPtr_->data(), bodyPtr_->length());\n removeHeaderBy(\"content-encoding\");\n bodyPtr_ =\n std::make_shared(std::move(gunzipBody));\n addHeader(\"content-length\", std::to_string(bodyPtr_->length()));\n }\n }\n#ifdef USE_BROTLI\n void brDecompress()\n {\n if (bodyPtr_)\n {\n auto gunzipBody =\n utils::brotliDecompress(bodyPtr_->data(), bodyPtr_->length());\n removeHeaderBy(\"content-encoding\");\n bodyPtr_ =\n std::make_shared(std::move(gunzipBody));\n addHeader(\"content-length\", std::to_string(bodyPtr_->length()));\n }\n }\n#endif\n ~HttpResponseImpl() override = default;\n\n protected:\n void makeHeaderString(trantor::MsgBuffer &headerString);\n\n void parseContentTypeAndString() const\n {\n if (!flagForParsingContentType_)\n {\n flagForParsingContentType_ = true;\n auto &contentTypeString = getHeaderBy(\"content-type\");\n if (contentTypeString == \"\")\n {\n contentType_ = CT_NONE;\n }\n else\n {\n auto pos = contentTypeString.find(';');\n if (pos != std::string::npos)\n {\n contentType_ = parseContentType(\n std::string_view(contentTypeString.data(), pos));\n }\n else\n {\n contentType_ =\n parseContentType(std::string_view(contentTypeString));\n }\n\n if (contentType_ == CT_NONE)\n contentType_ = CT_CUSTOM;\n contentTypeString_ = contentTypeString;\n }\n }\n }\n\n private:\n void setBody(const char *body, size_t len) override\n {\n bodyPtr_ = std::make_shared(body, len);\n if (passThrough_)\n {\n addHeader(\"content-length\", std::to_string(bodyPtr_->length()));\n }\n }\n\n void setContentTypeCodeAndCustomString(ContentType type,\n const char *typeString,\n size_t typeStringLength) override\n {\n contentType_ = type;\n flagForParsingContentType_ = true;\n\n std::string_view sv(typeString, typeStringLength);\n bool haveHeader = sv.find(\"content-type: \") == 0;\n bool haveCRLF = sv.rfind(\"\\r\\n\") == sv.size() - 2;\n\n size_t endOffset = 0;\n if (haveHeader)\n endOffset += 14;\n if (haveCRLF)\n endOffset += 2;\n setContentType(std::string_view{typeString + (haveHeader ? 
14 : 0),\n typeStringLength - endOffset});\n }\n\n void setContentTypeString(const char *typeString,\n size_t typeStringLength) override;\n\n void setCustomStatusCode(int code,\n const char *message,\n size_t messageLength) override\n {\n assert(code >= 0);\n customStatusCode_ = code;\n statusMessage_ = std::string_view{message, messageLength};\n }\n\n SafeStringMap headers_;\n SafeStringMap cookies_;\n\n int customStatusCode_{-1};\n HttpStatusCode statusCode_{kUnknown};\n std::string_view statusMessage_;\n\n trantor::Date creationDate_;\n Version version_{Version::kHttp11};\n bool closeConnection_{false};\n mutable std::shared_ptr bodyPtr_;\n ssize_t expriedTime_{-1};\n std::string sendfileName_;\n SendfileRange sendfileRange_{0, 0};\n std::function streamCallback_;\n std::function asyncStreamCallback_;\n bool asyncStreamDisableKickoff_{false};\n\n mutable std::shared_ptr jsonPtr_;\n\n std::shared_ptr fullHeaderString_;\n trantor::CertificatePtr peerCertificate_;\n mutable std::shared_ptr httpString_;\n mutable size_t datePos_{static_cast(-1)};\n mutable int64_t httpStringDate_{-1};\n mutable bool flagForParsingJson_{false};\n mutable bool flagForSerializingJson_{true};\n mutable ContentType contentType_{CT_TEXT_PLAIN};\n mutable bool flagForParsingContentType_{false};\n mutable std::shared_ptr jsonParsingErrorPtr_;\n mutable std::string contentTypeString_{\"text/html; charset=utf-8\"};\n bool passThrough_{false};\n\n void setContentType(const std::string_view &contentType)\n {\n contentTypeString_ =\n std::string(contentType.data(), contentType.size());\n }\n\n void setStatusMessage(const std::string_view &message)\n {\n statusMessage_ = message;\n }\n};\n\nusing HttpResponseImplPtr = std::shared_ptr;\n\ninline void swap(HttpResponseImpl &one, HttpResponseImpl &two) noexcept\n{\n one.swap(two);\n}\n\n} // namespace drogon\n\n// Path: lib/src/RangeParser.h\n/**\n *\n * RangeParser.h\n * He, Wanchen\n *\n * Copyright 2021, He,Wanchen. All rights reserved.\n * https://github.com/drogonframework/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n#pragma once\n\n#include \n#include \n#include \n\nnamespace drogon\n{\n// [start, end)\nstruct FileRange\n{\n size_t start;\n size_t end;\n};\n\nenum FileRangeParseResult\n{\n InvalidRange = -1,\n NotSatisfiable = 0,\n SinglePart = 1,\n MultiPart = 2\n};\n\nFileRangeParseResult parseRangeHeader(const std::string &rangeStr,\n size_t contentLength,\n std::vector &ranges);\n\n} // namespace drogon\n\n// Path: lib/src/StaticFileRouter.cc\n/**\n *\n * StaticFileRouter.cc\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"StaticFileRouter.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpResponseImpl.h\"\n#include \"RangeParser.h\"\n#include \n#include \n#include \n#include \n#include \n#ifndef _WIN32\n#include \n#elif !defined(__MINGW32__)\n#define stat _wstati64\n#define S_ISREG(m) (((m) & 0170000) == (0100000))\n#define S_ISDIR(m) (((m) & 0170000) == (0040000))\n#endif\n#include \n#include \n\nusing namespace drogon;\n\nvoid StaticFileRouter::init(const std::vector &ioLoops)\n{\n // Max timeout up to about 70 days;\n staticFilesCacheMap_ = std::make_unique<\n IOThreadStorage>>>();\n staticFilesCacheMap_->init(\n [&ioLoops](std::unique_ptr> &mapPtr,\n size_t i) {\n assert(i == ioLoops[i]->index());\n mapPtr = std::make_unique>(ioLoops[i],\n 1.0,\n 4,\n 50);\n });\n staticFilesCache_ = std::make_unique<\n IOThreadStorage>>();\n ioLocationsPtr_ =\n std::make_shared>>();\n for (auto *loop : ioLoops)\n {\n loop->queueInLoop(\n [ioLocationsPtr = ioLocationsPtr_, locations = locations_] {\n **ioLocationsPtr = locations;\n });\n }\n}\n\nvoid StaticFileRouter::reset()\n{\n staticFilesCacheMap_.reset();\n staticFilesCache_.reset();\n ioLocationsPtr_.reset();\n locations_.clear();\n}\n\nvoid StaticFileRouter::route(\n const HttpRequestImplPtr &req,\n std::function &&callback)\n{\n const std::string &path = req->path();\n if (path.find(\"..\") != std::string::npos)\n {\n auto directories = utils::splitString(path, \"/\");\n int traversalDepth = 0;\n for (const auto &dir : directories)\n {\n if (dir == \"..\")\n {\n traversalDepth--;\n }\n else if (dir != \".\")\n {\n traversalDepth++;\n }\n\n if (traversalDepth < 0)\n {\n // Downloading files from the parent folder is forbidden.\n callback(app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n }\n }\n\n auto lPath = path;\n std::transform(lPath.begin(),\n lPath.end(),\n lPath.begin(),\n [](unsigned char c) { return tolower(c); });\n\n for (auto &location : **ioLocationsPtr_)\n {\n auto &URI = location.uriPrefix_;\n if (location.realLocation_.empty())\n {\n if (!location.alias_.empty())\n {\n if (location.alias_[0] == '/')\n {\n location.realLocation_ = location.alias_;\n }\n else\n {\n location.realLocation_ =\n HttpAppFrameworkImpl::instance().getDocumentRoot() +\n location.alias_;\n }\n }\n else\n {\n location.realLocation_ =\n HttpAppFrameworkImpl::instance().getDocumentRoot() +\n location.uriPrefix_;\n }\n if (location.realLocation_[location.realLocation_.length() - 1] !=\n '/')\n {\n location.realLocation_.append(1, '/');\n }\n if (!location.isCaseSensitive_)\n {\n std::transform(URI.begin(),\n URI.end(),\n URI.begin(),\n [](unsigned char c) { return tolower(c); });\n }\n }\n auto &tmpPath = location.isCaseSensitive_ ? 
path : lPath;\n if (tmpPath.length() >= URI.length() &&\n std::equal(tmpPath.begin(),\n tmpPath.begin() + URI.length(),\n URI.begin()))\n {\n std::string_view restOfThePath{path.data() + URI.length(),\n path.length() - URI.length()};\n auto pos = restOfThePath.rfind('/');\n if (pos != 0 && pos != std::string_view::npos &&\n !location.isRecursive_)\n {\n callback(app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n std::string filePath =\n location.realLocation_ +\n std::string{restOfThePath.data(), restOfThePath.length()};\n std::filesystem::path fsFilePath(utils::toNativePath(filePath));\n std::error_code err;\n if (!std::filesystem::exists(fsFilePath, err))\n {\n defaultHandler_(req, std::move(callback));\n return;\n }\n if (std::filesystem::is_directory(fsFilePath, err))\n {\n // Check if path is eligible for an implicit index.html\n if (implicitPageEnable_)\n {\n filePath = filePath + \"/\" + implicitPage_;\n }\n else\n {\n callback(app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n }\n else\n {\n if (!location.allowAll_)\n {\n pos = restOfThePath.rfind('.');\n if (pos == std::string_view::npos)\n {\n callback(\n app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n std::string extension{restOfThePath.data() + pos + 1,\n restOfThePath.length() - pos - 1};\n std::transform(extension.begin(),\n extension.end(),\n extension.begin(),\n [](unsigned char c) { return tolower(c); });\n if (fileTypeSet_.find(extension) == fileTypeSet_.end())\n {\n callback(\n app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n }\n }\n\n if (location.filters_.empty())\n {\n sendStaticFileResponse(filePath,\n req,\n std::move(callback),\n std::string_view{\n location.defaultContentType_});\n }\n else\n {\n filters_function::doFilters(\n location.filters_,\n req,\n [this,\n req,\n filePath = std::move(filePath),\n contentType =\n std::string_view{location.defaultContentType_},\n callback = std::move(callback)](\n const HttpResponsePtr &resp) mutable {\n if (resp)\n {\n callback(resp);\n }\n else\n {\n sendStaticFileResponse(filePath,\n req,\n std::move(callback),\n contentType);\n }\n });\n }\n return;\n }\n }\n std::string directoryPath =\n HttpAppFrameworkImpl::instance().getDocumentRoot() + path;\n std::filesystem::path fsDirectoryPath(utils::toNativePath(directoryPath));\n std::error_code err;\n if (std::filesystem::exists(fsDirectoryPath, err))\n {\n if (std::filesystem::is_directory(fsDirectoryPath, err))\n {\n // Check if path is eligible for an implicit index.html\n if (implicitPageEnable_)\n {\n std::string filePath = directoryPath + \"/\" + implicitPage_;\n sendStaticFileResponse(filePath, req, std::move(callback), \"\");\n return;\n }\n else\n {\n callback(app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n }\n else\n {\n // This is a normal page\n auto pos = path.rfind('.');\n if (pos == std::string::npos)\n {\n callback(app().getCustomErrorHandler()(k403Forbidden, req));\n return;\n }\n std::string filetype = lPath.substr(pos + 1);\n if (fileTypeSet_.find(filetype) != fileTypeSet_.end())\n {\n // LOG_INFO << \"file query!\" << path;\n std::string filePath = directoryPath;\n sendStaticFileResponse(filePath, req, std::move(callback), \"\");\n return;\n }\n }\n }\n defaultHandler_(req, std::move(callback));\n}\n\n...\n", "needle_token_start": 8737, "needle_token_end": 8778, "code_context_ntokens": 15988, "output": ["```cpp\nclass HttpMessageStringViewBody : public HttpMessageBody\n{\n public:\n HttpMessageStringViewBody(const 
char *buf, size_t len) : body_(buf, len)\n {\n type_ = BodyType::kStringView;\n }\n\n const char *data() const override\n {\n return body_.data();\n }\n\n char *data() override\n {\n return const_cast(body_.data());\n }\n\n size_t length() const override\n {\n return body_.length();\n }\n\n std::string_view getString() const override\n {\n return body_;\n }\n\n private:\n std::string_view body_;\n};\n```"]} +{"repo": "drogonframework/drogon", "name": "ListenerInfo", "language": "cpp", "path": "lib/src/ListenerManager.h", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To encapsulate the configuration details required for setting up a network listener, including SSL/TLS specifics.\n2. **Input**: Receives multiple parameters such as IP address, port number, SSL usage flag, certificate file path, key file path, TLS version flag, and SSL configuration commands.\n3. **Output**: Does not directly produce an output but initializes an object with the provided configuration for use in network operations.\n4. **Procedure**: Takes the provided inputs, moves some of them to optimize memory usage, and stores them within the object for later access when setting up network listeners.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/AOPAdvice.h\n/**\n *\n...\n// Path: lib/src/HttpResponseImpl.cc\n/**\n *\n * @file HttpResponseImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpResponseImpl.h\"\n#include \"AOPAdvice.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpUtils.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace trantor;\nusing namespace drogon;\nusing namespace std::literals::string_literals;\nusing namespace std::placeholders;\n#ifdef _WIN32\n#undef max\n#endif\n\nnamespace drogon\n{\n// \"Fri, 23 Aug 2019 12:58:03 GMT\" length = 29\nstatic const size_t httpFullDateStringLength = 29;\n\nstatic inline HttpResponsePtr genHttpResponse(const std::string &viewName,\n const HttpViewData &data,\n const HttpRequestPtr &req)\n{\n auto templ = DrTemplateBase::newTemplate(viewName);\n if (templ)\n {\n auto res = HttpResponse::newHttpResponse();\n res->setBody(templ->genText(data));\n return res;\n }\n return drogon::HttpResponse::newNotFoundResponse(req);\n}\n} // namespace drogon\n\nHttpResponsePtr HttpResponse::newHttpResponse()\n{\n auto res = std::make_shared(k200OK, CT_TEXT_HTML);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpResponse(HttpStatusCode code,\n ContentType type)\n{\n auto res = std::make_shared(code, type);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpJsonResponse(const Json::Value &data)\n{\n auto res = std::make_shared(k200OK, CT_APPLICATION_JSON);\n res->setJsonObject(data);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpJsonResponse(Json::Value &&data)\n{\n auto res = std::make_shared(k200OK, CT_APPLICATION_JSON);\n 
res->setJsonObject(std::move(data));\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nconst char *HttpResponseImpl::versionString() const\n{\n const char *result = \"UNKNOWN\";\n switch (version_)\n {\n case Version::kHttp10:\n result = \"HTTP/1.0\";\n break;\n\n case Version::kHttp11:\n result = \"HTTP/1.1\";\n break;\n\n default:\n break;\n }\n return result;\n}\n\nvoid HttpResponseImpl::generateBodyFromJson() const\n{\n if (!jsonPtr_ || flagForSerializingJson_)\n {\n return;\n }\n flagForSerializingJson_ = true;\n static std::once_flag once;\n static Json::StreamWriterBuilder builder;\n std::call_once(once, []() {\n builder[\"commentStyle\"] = \"None\";\n builder[\"indentation\"] = \"\";\n if (!app().isUnicodeEscapingUsedInJson())\n {\n builder[\"emitUTF8\"] = true;\n }\n auto &precision = app().getFloatPrecisionInJson();\n if (precision.first != 0)\n {\n builder[\"precision\"] = precision.first;\n builder[\"precisionType\"] = precision.second;\n }\n });\n bodyPtr_ = std::make_shared(\n writeString(builder, *jsonPtr_));\n}\n\nHttpResponsePtr HttpResponse::newNotFoundResponse(const HttpRequestPtr &req)\n{\n auto loop = trantor::EventLoop::getEventLoopOfCurrentThread();\n auto &resp = HttpAppFrameworkImpl::instance().getCustom404Page();\n if (resp)\n {\n if (loop && loop->index() < app().getThreadNum())\n {\n return resp;\n }\n else\n {\n return HttpResponsePtr{new HttpResponseImpl(\n *static_cast(resp.get()))};\n }\n }\n else\n {\n if (HttpAppFrameworkImpl::instance().isUsingCustomErrorHandler())\n {\n return app().getCustomErrorHandler()(k404NotFound, req);\n }\n else if (loop && loop->index() < app().getThreadNum())\n {\n // If the current thread is an IO thread\n static std::once_flag threadOnce;\n static IOThreadStorage thread404Pages;\n std::call_once(threadOnce, [req = req] {\n thread404Pages.init([req = req](drogon::HttpResponsePtr &resp,\n size_t /*index*/) {\n HttpViewData data;\n data.insert(\"version\", drogon::getVersion());\n resp = HttpResponse::newHttpViewResponse(\"drogon::NotFound\",\n data);\n resp->setStatusCode(k404NotFound);\n resp->setExpiredTime(0);\n });\n });\n LOG_TRACE << \"Use cached 404 response\";\n return thread404Pages.getThreadData();\n }\n else\n {\n HttpViewData data;\n data.insert(\"version\", drogon::getVersion());\n auto notFoundResp =\n HttpResponse::newHttpViewResponse(\"drogon::NotFound\", data);\n notFoundResp->setStatusCode(k404NotFound);\n return notFoundResp;\n }\n }\n}\n\nHttpResponsePtr HttpResponse::newRedirectionResponse(\n const std::string &location,\n HttpStatusCode status)\n{\n auto res = std::make_shared();\n res->setStatusCode(status);\n res->redirect(location);\n AopAdvice::instance().passResponseCreationAdvices(res);\n return res;\n}\n\nHttpResponsePtr HttpResponse::newHttpViewResponse(const std::string &viewName,\n const HttpViewData &data,\n const HttpRequestPtr &req)\n{\n return genHttpResponse(viewName, data, req);\n}\n\nHttpResponsePtr HttpResponse::newFileResponse(\n const unsigned char *pBuffer,\n size_t bufferLength,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString)\n{\n // Make Raw HttpResponse\n auto resp = std::make_shared();\n\n // Set response body and length\n resp->setBody(\n std::string(reinterpret_cast(pBuffer), bufferLength));\n\n // Set status of message\n resp->setStatusCode(k200OK);\n\n // Check for type and assign proper content type in header\n if (!typeString.empty())\n {\n // auto contentType = type;\n if (type == CT_NONE)\n type = 
parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_APPLICATION_OCTET_STREAM; // XXX: Is this Ok?\n static_cast(resp.get())\n ->setContentTypeCodeAndCustomString(type,\n typeString.c_str(),\n typeString.size());\n }\n else if (type != CT_NONE)\n {\n resp->setContentTypeCode(type);\n }\n else if (!attachmentFileName.empty())\n {\n resp->setContentTypeCode(drogon::getContentType(attachmentFileName));\n }\n else\n {\n resp->setContentTypeCode(\n CT_APPLICATION_OCTET_STREAM); // default content-type for file;\n }\n\n // Add additional header values\n if (!attachmentFileName.empty())\n {\n resp->addHeader(\"Content-Disposition\",\n \"attachment; filename=\" + attachmentFileName);\n }\n\n // Finalize and return response\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nHttpResponsePtr HttpResponse::newFileResponse(\n const std::string &fullPath,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString,\n const HttpRequestPtr &req)\n{\n return newFileResponse(\n fullPath, 0, 0, false, attachmentFileName, type, typeString, req);\n}\n\nHttpResponsePtr HttpResponse::newFileResponse(\n const std::string &fullPath,\n size_t offset,\n size_t length,\n bool setContentRange,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString,\n const HttpRequestPtr &req)\n{\n std::ifstream infile(utils::toNativePath(fullPath), std::ifstream::binary);\n LOG_TRACE << \"send http file:\" << fullPath << \" offset \" << offset\n << \" length \" << length;\n if (!infile)\n {\n auto resp = HttpResponse::newNotFoundResponse(req);\n return resp;\n }\n auto resp = std::make_shared();\n std::streambuf *pbuf = infile.rdbuf();\n size_t filesize =\n static_cast(pbuf->pubseekoff(0, std::ifstream::end));\n if (offset > filesize || length > filesize || // in case of overflow\n offset + length > filesize)\n {\n resp->setStatusCode(k416RequestedRangeNotSatisfiable);\n if (setContentRange)\n {\n char buf[64];\n snprintf(buf, sizeof(buf), \"bytes */%zu\", filesize);\n resp->addHeader(\"Content-Range\", std::string(buf));\n }\n return resp;\n }\n if (length == 0)\n {\n length = filesize - offset;\n }\n pbuf->pubseekoff(offset, std::ifstream::beg); // rewind\n\n if (HttpAppFrameworkImpl::instance().useSendfile() && length > 1024 * 200)\n // TODO : Is 200k an appropriate value? Or set it to be configurable\n {\n // The advantages of sendfile() can only be reflected in sending large\n // files.\n resp->setSendfile(fullPath);\n // Must set length with the right value! 
Content-Length header relies on\n // this value.\n resp->setSendfileRange(offset, length);\n }\n else\n {\n std::string str;\n str.resize(length);\n pbuf->sgetn(&str[0], length);\n resp->setBody(std::move(str));\n resp->setSendfileRange(offset, length);\n }\n\n // Set correct status code\n if (length < filesize)\n {\n resp->setStatusCode(k206PartialContent);\n }\n else\n {\n resp->setStatusCode(k200OK);\n }\n\n // Infer content type\n if (type == CT_NONE)\n {\n if (!typeString.empty())\n {\n auto r = static_cast(resp.get());\n // auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n else if (!attachmentFileName.empty())\n {\n resp->setContentTypeCode(\n drogon::getContentType(attachmentFileName));\n }\n else\n {\n resp->setContentTypeCode(drogon::getContentType(fullPath));\n }\n }\n else\n {\n if (typeString.empty())\n resp->setContentTypeCode(type);\n else\n {\n auto r = static_cast(resp.get());\n // auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n }\n\n // Set headers\n if (!attachmentFileName.empty())\n {\n resp->addHeader(\"Content-Disposition\",\n \"attachment; filename=\" + attachmentFileName);\n }\n if (setContentRange && length > 0)\n {\n char buf[128];\n snprintf(buf,\n sizeof(buf),\n \"bytes %zu-%zu/%zu\",\n offset,\n offset + length - 1,\n filesize);\n resp->addHeader(\"Content-Range\", std::string(buf));\n }\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nHttpResponsePtr HttpResponse::newStreamResponse(\n const std::function &callback,\n const std::string &attachmentFileName,\n ContentType type,\n const std::string &typeString,\n const HttpRequestPtr &req)\n{\n LOG_TRACE << \"send stream as \"s\n << (attachmentFileName.empty() ? 
\"raw data\"s\n : \"file: \"s + attachmentFileName);\n if (!callback)\n {\n auto resp = HttpResponse::newNotFoundResponse();\n return resp;\n }\n auto resp = std::make_shared();\n resp->setStreamCallback(callback);\n resp->setStatusCode(k200OK);\n\n // Infer content type\n if (type == CT_NONE)\n {\n if (!typeString.empty())\n {\n auto r = static_cast(resp.get());\n auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n else if (!attachmentFileName.empty())\n {\n resp->setContentTypeCode(\n drogon::getContentType(attachmentFileName));\n }\n }\n else\n {\n if (typeString.empty())\n resp->setContentTypeCode(type);\n else\n {\n auto r = static_cast(resp.get());\n auto contentType = type;\n if (type == CT_NONE)\n type = parseContentType(typeString);\n if (type == CT_NONE)\n type = CT_CUSTOM; // XXX: Is this Ok?\n r->setContentTypeCodeAndCustomString(type, typeString);\n }\n }\n\n // Set headers\n if (!attachmentFileName.empty())\n {\n resp->addHeader(\"Content-Disposition\",\n \"attachment; filename=\" + attachmentFileName);\n }\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nHttpResponsePtr HttpResponse::newAsyncStreamResponse(\n const std::function &callback,\n bool disableKickoffTimeout)\n{\n if (!callback)\n {\n auto resp = HttpResponse::newNotFoundResponse();\n return resp;\n }\n auto resp = std::make_shared();\n resp->setAsyncStreamCallback(callback, disableKickoffTimeout);\n resp->setStatusCode(k200OK);\n AopAdvice::instance().passResponseCreationAdvices(resp);\n return resp;\n}\n\nvoid HttpResponseImpl::makeHeaderString(trantor::MsgBuffer &buffer)\n{\n buffer.ensureWritableBytes(128);\n int len{0};\n if (version_ == Version::kHttp11)\n {\n if (customStatusCode_ >= 0)\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.1 %d \",\n customStatusCode_);\n }\n else\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.1 %d \",\n statusCode_);\n }\n }\n else\n {\n if (customStatusCode_ >= 0)\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.0 %d \",\n customStatusCode_);\n }\n else\n {\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n \"HTTP/1.0 %d \",\n statusCode_);\n }\n }\n buffer.hasWritten(len);\n\n if (!statusMessage_.empty())\n buffer.append(statusMessage_.data(), statusMessage_.length());\n buffer.append(\"\\r\\n\");\n generateBodyFromJson();\n if (!passThrough_)\n {\n buffer.ensureWritableBytes(64);\n if (streamCallback_ || asyncStreamCallback_)\n {\n // When the headers are created, it is time to set the transfer\n // encoding to chunked if the contents size is not specified\n if (!ifCloseConnection() &&\n headers_.find(\"content-length\") == headers_.end())\n {\n LOG_DEBUG << \"send stream with transfer-encoding chunked\";\n headers_[\"transfer-encoding\"] = \"chunked\";\n }\n len = 0;\n }\n else if (sendfileName_.empty())\n {\n auto bodyLength = bodyPtr_ ? 
bodyPtr_->length() : 0;\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n contentLengthFormatString(),\n bodyLength);\n }\n else\n {\n auto bodyLength = sendfileRange_.second;\n len = snprintf(buffer.beginWrite(),\n buffer.writableBytes(),\n contentLengthFormatString(),\n bodyLength);\n }\n buffer.hasWritten(len);\n if (headers_.find(\"connection\") == headers_.end())\n {\n if (closeConnection_)\n {\n buffer.append(\"connection: close\\r\\n\");\n }\n else if (version_ == Version::kHttp10)\n {\n buffer.append(\"connection: Keep-Alive\\r\\n\");\n }\n }\n\n if (!contentTypeString_.empty())\n {\n buffer.append(\"content-type: \");\n buffer.append(contentTypeString_);\n buffer.append(\"\\r\\n\");\n }\n if (HttpAppFrameworkImpl::instance().sendServerHeader())\n {\n buffer.append(\n HttpAppFrameworkImpl::instance().getServerHeaderString());\n }\n }\n\n for (auto it = headers_.begin(); it != headers_.end(); ++it)\n {\n buffer.append(it->first);\n buffer.append(\": \");\n buffer.append(it->second);\n buffer.append(\"\\r\\n\");\n }\n}\n\nvoid HttpResponseImpl::renderToBuffer(trantor::MsgBuffer &buffer)\n{\n if (expriedTime_ >= 0)\n {\n auto strPtr = renderToBuffer();\n buffer.append(strPtr->peek(), strPtr->readableBytes());\n return;\n }\n\n if (!fullHeaderString_)\n {\n makeHeaderString(buffer);\n }\n else\n {\n buffer.append(*fullHeaderString_);\n }\n\n // output cookies\n if (!cookies_.empty())\n {\n for (auto it = cookies_.begin(); it != cookies_.end(); ++it)\n {\n buffer.append(it->second.cookieString());\n }\n }\n\n // output Date header\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n buffer.append(\"date: \");\n buffer.append(utils::getHttpFullDate(trantor::Date::date()),\n httpFullDateStringLength);\n buffer.append(\"\\r\\n\\r\\n\");\n }\n else\n {\n buffer.append(\"\\r\\n\");\n }\n if (bodyPtr_)\n buffer.append(bodyPtr_->data(), bodyPtr_->length());\n}\n\nstd::shared_ptr HttpResponseImpl::renderToBuffer()\n{\n if (expriedTime_ >= 0)\n {\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n if (datePos_ != static_cast(-1))\n {\n auto now = trantor::Date::now();\n bool isDateChanged =\n ((now.microSecondsSinceEpoch() / MICRO_SECONDS_PRE_SEC) !=\n httpStringDate_);\n assert(httpString_);\n if (isDateChanged)\n {\n httpStringDate_ =\n now.microSecondsSinceEpoch() / MICRO_SECONDS_PRE_SEC;\n auto newDate = utils::getHttpFullDate(now);\n\n httpString_ =\n std::make_shared(*httpString_);\n memcpy((void *)&(*httpString_)[datePos_],\n newDate,\n httpFullDateStringLength);\n return httpString_;\n }\n\n return httpString_;\n }\n }\n else\n {\n if (httpString_)\n return httpString_;\n }\n }\n auto httpString = std::make_shared(256);\n if (!fullHeaderString_)\n {\n makeHeaderString(*httpString);\n }\n else\n {\n httpString->append(*fullHeaderString_);\n }\n\n // output cookies\n if (!cookies_.empty())\n {\n for (auto it = cookies_.begin(); it != cookies_.end(); ++it)\n {\n httpString->append(it->second.cookieString());\n }\n }\n\n // output Date header\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n httpString->append(\"date: \");\n auto datePos = httpString->readableBytes();\n httpString->append(utils::getHttpFullDate(trantor::Date::date()),\n httpFullDateStringLength);\n httpString->append(\"\\r\\n\\r\\n\");\n datePos_ = datePos;\n }\n else\n {\n httpString->append(\"\\r\\n\");\n }\n\n LOG_TRACE << \"response(no body):\"\n << std::string_view{httpString->peek(),\n 
httpString->readableBytes()};\n if (bodyPtr_)\n httpString->append(bodyPtr_->data(), bodyPtr_->length());\n if (expriedTime_ >= 0)\n {\n httpString_ = httpString;\n }\n return httpString;\n}\n\nstd::shared_ptr HttpResponseImpl::\n renderHeaderForHeadMethod()\n{\n auto httpString = std::make_shared(256);\n if (!fullHeaderString_)\n {\n makeHeaderString(*httpString);\n }\n else\n {\n httpString->append(*fullHeaderString_);\n }\n\n // output cookies\n if (!cookies_.empty())\n {\n for (auto it = cookies_.begin(); it != cookies_.end(); ++it)\n {\n httpString->append(it->second.cookieString());\n }\n }\n\n // output Date header\n if (!passThrough_ &&\n drogon::HttpAppFrameworkImpl::instance().sendDateHeader())\n {\n httpString->append(\"date: \");\n httpString->append(utils::getHttpFullDate(trantor::Date::date()),\n httpFullDateStringLength);\n httpString->append(\"\\r\\n\\r\\n\");\n }\n else\n {\n httpString->append(\"\\r\\n\");\n }\n\n return httpString;\n}\n\nvoid HttpResponseImpl::addHeader(const char *start,\n const char *colon,\n const char *end)\n{\n fullHeaderString_.reset();\n std::string field(start, colon);\n transform(field.begin(), field.end(), field.begin(), [](unsigned char c) {\n return tolower(c);\n });\n ++colon;\n while (colon < end && isspace(static_cast(*colon)))\n {\n ++colon;\n }\n std::string value(colon, end);\n while (!value.empty() &&\n isspace(static_cast(value[value.size() - 1])))\n {\n value.resize(value.size() - 1);\n }\n\n if (field == \"set-cookie\")\n {\n // LOG_INFO<<\"cookies!!!:\"<(cookie_name[cpos])))\n ++cpos;\n cookie_name = cookie_name.substr(cpos);\n ++epos;\n while (epos < coo.length() &&\n isspace(static_cast(coo[epos])))\n ++epos;\n cookie_value = coo.substr(epos);\n }\n else\n {\n std::string::size_type cpos = 0;\n while (cpos < coo.length() &&\n isspace(static_cast(coo[cpos])))\n ++cpos;\n cookie_name = coo.substr(cpos);\n }\n if (i == 0)\n {\n cookie.setKey(cookie_name);\n cookie.setValue(cookie_value);\n }\n else\n {\n std::transform(cookie_name.begin(),\n cookie_name.end(),\n cookie_name.begin(),\n [](unsigned char c) { return tolower(c); });\n if (cookie_name == \"path\")\n {\n cookie.setPath(cookie_value);\n }\n else if (cookie_name == \"domain\")\n {\n cookie.setDomain(cookie_value);\n }\n else if (cookie_name == \"expires\")\n {\n cookie.setExpiresDate(utils::getHttpDate(cookie_value));\n }\n else if (cookie_name == \"secure\")\n {\n cookie.setSecure(true);\n }\n else if (cookie_name == \"httponly\")\n {\n cookie.setHttpOnly(true);\n }\n else if (cookie_name == \"samesite\")\n {\n cookie.setSameSite(\n cookie.convertString2SameSite(cookie_value));\n }\n else if (cookie_name == \"max-age\")\n {\n cookie.setMaxAge(std::stoi(cookie_value));\n }\n }\n }\n if (!cookie.key().empty())\n {\n cookies_[cookie.key()] = cookie;\n }\n }\n else\n {\n headers_[std::move(field)] = std::move(value);\n }\n}\n\nvoid HttpResponseImpl::swap(HttpResponseImpl &that) noexcept\n{\n using std::swap;\n headers_.swap(that.headers_);\n cookies_.swap(that.cookies_);\n swap(statusCode_, that.statusCode_);\n swap(version_, that.version_);\n swap(statusMessage_, that.statusMessage_);\n swap(closeConnection_, that.closeConnection_);\n bodyPtr_.swap(that.bodyPtr_);\n swap(contentType_, that.contentType_);\n swap(flagForParsingContentType_, that.flagForParsingContentType_);\n swap(flagForParsingJson_, that.flagForParsingJson_);\n swap(sendfileName_, that.sendfileName_);\n swap(streamCallback_, that.streamCallback_);\n swap(asyncStreamCallback_, that.asyncStreamCallback_);\n 
jsonPtr_.swap(that.jsonPtr_);\n fullHeaderString_.swap(that.fullHeaderString_);\n httpString_.swap(that.httpString_);\n swap(datePos_, that.datePos_);\n swap(jsonParsingErrorPtr_, that.jsonParsingErrorPtr_);\n}\n\nvoid HttpResponseImpl::clear()\n{\n statusCode_ = kUnknown;\n version_ = Version::kHttp11;\n statusMessage_ = std::string_view{};\n fullHeaderString_.reset();\n jsonParsingErrorPtr_.reset();\n sendfileName_.clear();\n if (streamCallback_)\n {\n LOG_TRACE << \"Cleanup HttpResponse stream callback\";\n streamCallback_(nullptr, 0); // callback internal cleanup\n streamCallback_ = {};\n }\n if (asyncStreamCallback_)\n {\n // asyncStreamCallback_(nullptr);\n asyncStreamCallback_ = {};\n }\n headers_.clear();\n cookies_.clear();\n bodyPtr_.reset();\n jsonPtr_.reset();\n expriedTime_ = -1;\n datePos_ = std::string::npos;\n flagForParsingContentType_ = false;\n flagForParsingJson_ = false;\n}\n\nvoid HttpResponseImpl::parseJson() const\n{\n static std::once_flag once;\n static Json::CharReaderBuilder builder;\n std::call_once(once, []() {\n builder[\"collectComments\"] = false;\n builder[\"stackLimit\"] =\n static_cast(drogon::app().getJsonParserStackLimit());\n });\n JSONCPP_STRING errs;\n std::unique_ptr reader(builder.newCharReader());\n if (bodyPtr_)\n {\n jsonPtr_ = std::make_shared();\n if (!reader->parse(bodyPtr_->data(),\n bodyPtr_->data() + bodyPtr_->length(),\n jsonPtr_.get(),\n &errs))\n {\n LOG_ERROR << errs;\n LOG_ERROR << \"body: \" << bodyPtr_->getString();\n jsonPtr_.reset();\n jsonParsingErrorPtr_ =\n std::make_shared(std::move(errs));\n }\n else\n {\n jsonParsingErrorPtr_.reset();\n }\n }\n else\n {\n jsonPtr_.reset();\n jsonParsingErrorPtr_ =\n std::make_shared(\"empty response body\");\n }\n}\n\nbool HttpResponseImpl::shouldBeCompressed() const\n{\n if (streamCallback_ || asyncStreamCallback_ || !sendfileName_.empty() ||\n contentType() >= CT_APPLICATION_OCTET_STREAM ||\n getBody().length() < 1024 || !(getHeaderBy(\"content-encoding\").empty()))\n {\n return false;\n }\n return true;\n}\n\nvoid HttpResponseImpl::setContentTypeString(const char *typeString,\n size_t typeStringLength)\n{\n std::string sv(typeString, typeStringLength);\n auto contentType = parseContentType(sv);\n if (contentType == CT_NONE)\n contentType = CT_CUSTOM;\n contentType_ = contentType;\n contentTypeString_ = std::string(sv);\n flagForParsingContentType_ = true;\n}\n\n// Path: lib/src/ConfigLoader.h\n/**\n *\n * ConfigLoader.h\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n\nnamespace drogon\n{\nclass ConfigLoader : public trantor::NonCopyable\n{\n public:\n explicit ConfigLoader(const std::string &configFile) noexcept(false);\n explicit ConfigLoader(const Json::Value &data);\n explicit ConfigLoader(Json::Value &&data);\n ~ConfigLoader();\n\n const Json::Value &jsonValue() const\n {\n return configJsonRoot_;\n }\n\n void load() noexcept(false);\n\n private:\n std::string configFile_;\n Json::Value configJsonRoot_;\n};\n} // namespace drogon\n\n// Path: lib/src/HttpServer.h\n/**\n *\n * @file HttpServer.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \"impl_forwards.h\"\n\nstruct CallbackParamPack;\n\nnamespace drogon\n{\nclass ControllerBinderBase;\n\nclass HttpServer : trantor::NonCopyable\n{\n public:\n HttpServer(trantor::EventLoop *loop,\n const trantor::InetAddress &listenAddr,\n std::string name);\n\n ~HttpServer();\n\n void setIoLoops(const std::vector &ioLoops)\n {\n server_.setIoLoops(ioLoops);\n }\n\n void start();\n void stop();\n\n void enableSSL(trantor::TLSPolicyPtr policy)\n {\n server_.enableSSL(std::move(policy));\n }\n\n const trantor::InetAddress &address() const\n {\n return server_.address();\n }\n\n private:\n friend class HttpInternalForwardHelper;\n\n static void onConnection(const trantor::TcpConnectionPtr &conn);\n static void onMessage(const trantor::TcpConnectionPtr &,\n trantor::MsgBuffer *);\n static void onRequests(const trantor::TcpConnectionPtr &,\n const std::vector &,\n const std::shared_ptr &);\n\n struct HttpRequestParamPack\n {\n std::shared_ptr binderPtr;\n std::function callback;\n };\n\n struct WsRequestParamPack\n {\n std::shared_ptr binderPtr;\n std::function callback;\n WebSocketConnectionImplPtr wsConnPtr;\n };\n\n // Http request handling steps\n static void onHttpRequest(const HttpRequestImplPtr &,\n std::function &&);\n static void httpRequestRouting(\n const HttpRequestImplPtr &req,\n std::function &&callback);\n static void httpRequestHandling(\n const HttpRequestImplPtr &req,\n std::shared_ptr &&binderPtr,\n std::function &&callback);\n\n // Websocket request handling steps\n static void onWebsocketRequest(\n const HttpRequestImplPtr &,\n std::function &&,\n WebSocketConnectionImplPtr &&);\n static void websocketRequestRouting(\n const HttpRequestImplPtr &req,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr);\n static void websocketRequestHandling(\n const HttpRequestImplPtr &req,\n std::shared_ptr &&binderPtr,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr);\n\n // Http/Websocket shared handling steps\n template \n static void requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack);\n template \n static void requestPassFilters(const HttpRequestImplPtr &req, Pack &&pack);\n template \n static void requestPreHandling(const HttpRequestImplPtr &req, Pack &&pack);\n\n // Response buffering and sending\n static void handleResponse(\n const HttpResponsePtr &response,\n const std::shared_ptr ¶mPack,\n bool *respReadyPtr);\n static void sendResponse(const trantor::TcpConnectionPtr &,\n const HttpResponsePtr &,\n bool isHeadMethod);\n static void sendResponses(\n const trantor::TcpConnectionPtr &conn,\n const std::vector> &responses,\n trantor::MsgBuffer &buffer);\n\n trantor::TcpServer server_;\n};\n\nclass HttpInternalForwardHelper\n{\n public:\n static void forward(const HttpRequestImplPtr &req,\n std::function &&callback)\n {\n return HttpServer::onHttpRequest(req, std::move(callback));\n }\n};\n\n} // namespace drogon\n\n// Path: lib/src/ListenerManager.h\n/**\n *\n * @file ListenerManager.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"impl_forwards.h\"\n\nnamespace trantor\n{\nclass InetAddress;\n}\n\nnamespace drogon\n{\nclass ListenerManager : public trantor::NonCopyable\n{\n public:\n ~ListenerManager() = default;\n void addListener(const std::string &ip,\n uint16_t port,\n bool useSSL = false,\n const std::string &certFile = \"\",\n const std::string &keyFile = \"\",\n bool useOldTLS = false,\n const std::vector>\n &sslConfCmds = {});\n std::vector getListeners() const;\n void createListeners(\n const std::string &globalCertFile,\n const std::string &globalKeyFile,\n const std::vector> &sslConfCmds,\n const std::vector &ioLoops);\n void startListening();\n void stopListening();\n\n private:\n struct ListenerInfo\n {\n \nListenerInfo(\n std::string ip,\n uint16_t port,\n bool useSSL,\n std::string certFile,\n std::string keyFile,\n bool useOldTLS,\n std::vector> sslConfCmds)\n : ip_(std::move(ip)),\n port_(port),\n useSSL_(useSSL),\n certFile_(std::move(certFile)),\n keyFile_(std::move(keyFile)),\n useOldTLS_(useOldTLS),\n sslConfCmds_(std::move(sslConfCmds))\n {\n }\n\n std::string ip_;\n uint16_t port_;\n bool useSSL_;\n std::string certFile_;\n std::string keyFile_;\n bool useOldTLS_;\n std::vector> sslConfCmds_;\n };\n\n std::vector listeners_;\n std::vector> servers_;\n\n // should have value when and only when on OS that one port can only be\n // listened by one thread\n std::unique_ptr listeningThread_;\n};\n\n} // namespace drogon\n\n// Path: lib/src/RedisClientManager.h\n/**\n *\n * RedisClientManager.h\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nnamespace nosql\n{\nclass RedisClientManager : public trantor::NonCopyable\n{\n public:\n void createRedisClients(const std::vector &ioLoops);\n\n RedisClientPtr getRedisClient(const std::string &name)\n {\n assert(redisClientsMap_.find(name) != redisClientsMap_.end());\n return redisClientsMap_[name];\n }\n\n RedisClientPtr getFastRedisClient(const std::string &name)\n {\n auto iter = redisFastClientsMap_.find(name);\n assert(iter != redisFastClientsMap_.end());\n return iter->second.getThreadData();\n }\n\n void createRedisClient(const std::string &name,\n const std::string &host,\n unsigned short port,\n const std::string &username,\n const std::string &password,\n size_t connectionNum,\n bool isFast,\n double timeout,\n unsigned int db);\n // bool areAllRedisClientsAvailable() const noexcept;\n\n ~RedisClientManager();\n\n private:\n std::map redisClientsMap_;\n std::map> redisFastClientsMap_;\n\n struct RedisInfo\n {\n std::string name_;\n std::string addr_;\n std::string username_;\n std::string password_;\n unsigned short port_;\n bool isFast_;\n size_t connectionNumber_;\n double timeout_;\n unsigned int db_;\n };\n\n std::vector redisInfos_;\n};\n} // namespace nosql\n} // namespace drogon\n\n// Path: lib/src/SharedLibManager.h\n/**\n *\n * SharedLibManager.h\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass SharedLibManager : public trantor::NonCopyable\n{\n public:\n SharedLibManager(const std::vector &libPaths,\n const std::string &outputPath);\n ~SharedLibManager();\n\n private:\n void managerLibs();\n std::vector libPaths_;\n std::string outputPath_;\n\n struct DLStat\n {\n void *handle{nullptr};\n struct timespec mTime = {0, 0};\n };\n\n std::unordered_map dlMap_;\n void *compileAndLoadLib(const std::string &sourceFile, void *oldHld);\n void *loadLib(const std::string &soFile, void *oldHld);\n bool shouldCompileLib(const std::string &soFile,\n const struct stat &sourceStat);\n trantor::TimerId timeId_;\n trantor::EventLoopThread workingThread_;\n};\n} // namespace drogon\n\n// Path: lib/src/HttpAppFrameworkImpl.cc\n/**\n *\n * @file HttpAppFrameworkImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpAppFrameworkImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"AOPAdvice.h\"\n#include \"ConfigLoader.h\"\n#include \"DbClientManager.h\"\n#include \"HttpClientImpl.h\"\n#include \"HttpConnectionLimit.h\"\n#include \"HttpControllersRouter.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpResponseImpl.h\"\n#include \"HttpServer.h\"\n#include \"HttpUtils.h\"\n#include \"ListenerManager.h\"\n#include \"PluginsManager.h\"\n#include \"RedisClientManager.h\"\n#include \"SessionManager.h\"\n#include \"SharedLibManager.h\"\n#include \"StaticFileRouter.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \n#include \n#ifndef _WIN32\n#include \n#include \n#define os_access access\n#elif !defined(_WIN32) || defined(__MINGW32__)\n#include \n#include \n#define os_access access\n#else\n#include \n#define os_access _waccess\n#define R_OK 04\n#define W_OK 02\n#endif\n\n#ifdef DROGON_SPDLOG_SUPPORT\n#include \n#include \n#include \n#include \n#ifdef _WIN32\n#include \n// Damn antedeluvian M$ macros\n#undef min\n#undef max\n#endif\n#endif // DROGON_SPDLOG_SUPPORT\n\nusing namespace drogon;\nusing namespace std::placeholders;\n\nHttpAppFrameworkImpl::HttpAppFrameworkImpl()\n : listenerManagerPtr_(new ListenerManager),\n pluginsManagerPtr_(new PluginsManager),\n dbClientManagerPtr_(new orm::DbClientManager),\n redisClientManagerPtr_(new nosql::RedisClientManager),\n uploadPath_(rootPath_ + \"uploads\")\n{\n}\n\nstatic std::function f = [] {\n LOG_TRACE << \"Initialize the main event loop in the main thread\";\n};\n\n/// Make sure that the main event loop is initialized in the main thread.\ndrogon::InitBeforeMainFunction drogon::HttpAppFrameworkImpl::initFirst_([]() {\n HttpAppFrameworkImpl::instance().getLoop()->runInLoop(f);\n});\n\nnamespace drogon\n{\nstd::string getVersion()\n{\n return DROGON_VERSION;\n}\n\nstd::string getGitCommit()\n{\n return DROGON_VERSION_SHA1;\n}\n\nHttpResponsePtr defaultErrorHandler(HttpStatusCode code, const HttpRequestPtr &)\n{\n return std::make_shared(code, CT_TEXT_HTML);\n}\n\nvoid defaultExceptionHandler(\n const std::exception &e,\n const HttpRequestPtr &req,\n std::function &&callback)\n{\n std::string 
pathWithQuery = req->path();\n if (req->query().empty() == false)\n pathWithQuery += \"?\" + req->query();\n LOG_ERROR << \"Unhandled exception in \" << pathWithQuery\n << \", what(): \" << e.what();\n const auto &handler = app().getCustomErrorHandler();\n callback(handler(k500InternalServerError, req));\n}\n\nstatic void godaemon()\n{\n printf(\"Initializing daemon mode\\n\");\n#ifndef _WIN32\n if (getppid() != 1)\n {\n pid_t pid;\n pid = fork();\n if (pid > 0)\n exit(0); // parent\n if (pid < 0)\n {\n perror(\"fork\");\n exit(1);\n }\n setsid();\n }\n\n // redirect stdin/stdout/stderr to /dev/null\n close(0);\n close(1);\n close(2);\n\n int ret = open(\"/dev/null\", O_RDWR);\n (void)ret;\n ret = dup(0);\n (void)ret;\n ret = dup(0);\n (void)ret;\n umask(0);\n#else\n LOG_ERROR << \"Cannot run as daemon in Windows\";\n exit(1);\n#endif\n\n return;\n}\n\nstatic void TERMFunction(int sig)\n{\n if (sig == SIGTERM)\n {\n LOG_WARN << \"SIGTERM signal received.\";\n HttpAppFrameworkImpl::instance().getTermSignalHandler()();\n }\n else if (sig == SIGINT)\n {\n LOG_WARN << \"SIGINT signal received.\";\n HttpAppFrameworkImpl::instance().getIntSignalHandler()();\n }\n}\n\n} // namespace drogon\n\nHttpAppFrameworkImpl::~HttpAppFrameworkImpl() noexcept\n{\n// Destroy the following objects before the loop destruction\n#ifndef _WIN32\n sharedLibManagerPtr_.reset();\n#endif\n sessionManagerPtr_.reset();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setStaticFilesCacheTime(int cacheTime)\n{\n StaticFileRouter::instance().setStaticFilesCacheTime(cacheTime);\n return *this;\n}\n\nint HttpAppFrameworkImpl::staticFilesCacheTime() const\n{\n return StaticFileRouter::instance().staticFilesCacheTime();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setGzipStatic(bool useGzipStatic)\n{\n StaticFileRouter::instance().setGzipStatic(useGzipStatic);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setBrStatic(bool useGzipStatic)\n{\n StaticFileRouter::instance().setBrStatic(useGzipStatic);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setImplicitPageEnable(\n bool useImplicitPage)\n{\n StaticFileRouter::instance().setImplicitPageEnable(useImplicitPage);\n return *this;\n}\n\nbool HttpAppFrameworkImpl::isImplicitPageEnabled() const\n{\n return StaticFileRouter::instance().isImplicitPageEnabled();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setImplicitPage(\n const std::string &implicitPageFile)\n{\n StaticFileRouter::instance().setImplicitPage(implicitPageFile);\n return *this;\n}\n\nconst std::string &HttpAppFrameworkImpl::getImplicitPage() const\n{\n return StaticFileRouter::instance().getImplicitPage();\n}\n#ifndef _WIN32\nHttpAppFramework &HttpAppFrameworkImpl::enableDynamicViewsLoading(\n const std::vector &libPaths,\n const std::string &outputPath)\n{\n assert(!running_);\n\n for (auto const &libpath : libPaths)\n {\n if (libpath[0] == '/' ||\n (libpath.length() >= 2 && libpath[0] == '.' && libpath[1] == '/') ||\n (libpath.length() >= 3 && libpath[0] == '.' && libpath[1] == '.' 
&&\n libpath[2] == '/') ||\n libpath == \".\" || libpath == \"..\")\n {\n libFilePaths_.push_back(libpath);\n }\n else\n {\n if (rootPath_[rootPath_.length() - 1] == '/')\n libFilePaths_.push_back(rootPath_ + libpath);\n else\n libFilePaths_.push_back(rootPath_ + \"/\" + libpath);\n }\n }\n libFileOutputPath_ = outputPath;\n if (!libFileOutputPath_.empty())\n {\n if (drogon::utils::createPath(libFileOutputPath_) == -1)\n {\n LOG_FATAL << \"Can't create \" << libFileOutputPath_\n << \" path for dynamic views\";\n exit(-1);\n }\n }\n\n return *this;\n}\n#endif\nHttpAppFramework &HttpAppFrameworkImpl::setFileTypes(\n const std::vector &types)\n{\n StaticFileRouter::instance().setFileTypes(types);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::registerWebSocketController(\n const std::string &pathName,\n const std::string &ctrlName,\n const std::vector &filtersAndMethods)\n{\n assert(!routersInit_);\n HttpControllersRouter::instance().registerWebSocketController(\n pathName, ctrlName, filtersAndMethods);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::registerHttpSimpleController(\n const std::string &pathName,\n const std::string &ctrlName,\n const std::vector &filtersAndMethods)\n{\n assert(!routersInit_);\n HttpControllersRouter::instance().registerHttpSimpleController(\n pathName, ctrlName, filtersAndMethods);\n return *this;\n}\n\nvoid HttpAppFrameworkImpl::registerHttpController(\n const std::string &pathPattern,\n const internal::HttpBinderBasePtr &binder,\n const std::vector &validMethods,\n const std::vector &filters,\n const std::string &handlerName)\n{\n assert(!pathPattern.empty());\n assert(binder);\n assert(!routersInit_);\n HttpControllersRouter::instance().addHttpPath(\n pathPattern, binder, validMethods, filters, handlerName);\n}\n\nvoid HttpAppFrameworkImpl::registerHttpControllerViaRegex(\n const std::string ®Exp,\n const internal::HttpBinderBasePtr &binder,\n const std::vector &validMethods,\n const std::vector &filters,\n const std::string &handlerName)\n{\n assert(!regExp.empty());\n assert(binder);\n assert(!routersInit_);\n HttpControllersRouter::instance().addHttpRegex(\n regExp, binder, validMethods, filters, handlerName);\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setThreadNum(size_t threadNum)\n{\n if (threadNum == 0)\n {\n threadNum_ = std::thread::hardware_concurrency();\n return *this;\n }\n threadNum_ = threadNum;\n return *this;\n}\n\nPluginBase *HttpAppFrameworkImpl::getPlugin(const std::string &name)\n{\n return pluginsManagerPtr_->getPlugin(name);\n}\n\nstd::shared_ptr HttpAppFrameworkImpl::getSharedPlugin(\n const std::string &name)\n{\n return pluginsManagerPtr_->getSharedPlugin(name);\n}\n\nvoid HttpAppFrameworkImpl::addPlugin(\n const std::string &name,\n const std::vector &dependencies,\n const Json::Value &config)\n{\n assert(!isRunning());\n Json::Value pluginConfig;\n pluginConfig[\"name\"] = name;\n Json::Value deps(Json::arrayValue);\n for (const auto dep : dependencies)\n {\n deps.append(dep);\n }\n pluginConfig[\"dependencies\"] = deps;\n pluginConfig[\"config\"] = config;\n auto &plugins = jsonRuntimeConfig_[\"plugins\"];\n plugins.append(pluginConfig);\n}\n\nvoid HttpAppFrameworkImpl::addPlugins(const Json::Value &configs)\n{\n assert(!isRunning());\n assert(configs.isArray());\n auto &plugins = jsonRuntimeConfig_[\"plugins\"];\n for (const auto config : configs)\n {\n plugins.append(config);\n }\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::addListener(\n const std::string &ip,\n uint16_t port,\n bool useSSL,\n 
const std::string &certFile,\n const std::string &keyFile,\n bool useOldTLS,\n const std::vector> &sslConfCmds)\n{\n assert(!running_);\n listenerManagerPtr_->addListener(\n ip, port, useSSL, certFile, keyFile, useOldTLS, sslConfCmds);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setMaxConnectionNum(\n size_t maxConnections)\n{\n HttpConnectionLimit::instance().setMaxConnectionNum(maxConnections);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setMaxConnectionNumPerIP(\n size_t maxConnectionsPerIP)\n{\n HttpConnectionLimit::instance().setMaxConnectionNumPerIP(\n maxConnectionsPerIP);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::loadConfigFile(\n const std::string &fileName)\n{\n ConfigLoader loader(fileName);\n loader.load();\n jsonConfig_ = loader.jsonValue();\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::loadConfigJson(const Json::Value &data)\n{\n ConfigLoader loader(data);\n loader.load();\n jsonConfig_ = loader.jsonValue();\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::loadConfigJson(Json::Value &&data)\n{\n ConfigLoader loader(std::move(data));\n loader.load();\n jsonConfig_ = loader.jsonValue();\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setLogPath(\n const std::string &logPath,\n const std::string &logfileBaseName,\n size_t logfileSize,\n size_t maxFiles,\n bool useSpdlog)\n{\n#ifdef DROGON_SPDLOG_SUPPORT\n logWithSpdlog_ = trantor::Logger::hasSpdLogSupport() && useSpdlog;\n#endif\n if (logPath.empty())\n return *this;\n // std::filesystem does not provide a method to check access permissions, so\n // keep existing code\n if (os_access(utils::toNativePath(logPath).c_str(), 0) != 0)\n {\n std::cerr << \"Log path does not exist!\\n\";\n exit(1);\n }\n if (os_access(utils::toNativePath(logPath).c_str(), R_OK | W_OK) != 0)\n {\n std::cerr << \"Unable to access log path!\\n\";\n exit(1);\n }\n logPath_ = logPath;\n logfileBaseName_ = logfileBaseName;\n logfileSize_ = logfileSize;\n logfileMaxNum_ = maxFiles;\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setLogLevel(\n trantor::Logger::LogLevel level)\n{\n trantor::Logger::setLogLevel(level);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setLogLocalTime(bool on)\n{\n trantor::Logger::setDisplayLocalTime(on);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setSSLConfigCommands(\n const std::vector> &sslConfCmds)\n{\n sslConfCmds_ = sslConfCmds;\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setSSLFiles(const std::string &certPath,\n const std::string &keyPath)\n{\n sslCertPath_ = certPath;\n sslKeyPath_ = keyPath;\n return *this;\n}\n\nvoid HttpAppFrameworkImpl::run()\n{\n if (!getLoop()->isInLoopThread())\n {\n getLoop()->moveToCurrentThread();\n }\n LOG_TRACE << \"Start to run...\";\n // Create dirs for cache files\n...\n", "needle_token_start": 10425, "needle_token_end": 10581, "code_context_ntokens": 16125, "output": ["```cpp\nvoid ListenerManager::addListener(const std::string &ip,\n uint16_t port,\n bool useSSL,\n const std::string &certFile,\n const std::string &keyFile,\n bool useOldTLS,\n const std::vector> &sslConfCmds)\n{\n listeners_.emplace_back(ip, port, useSSL, certFile, keyFile, useOldTLS, sslConfCmds);\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "handleException", "language": "cpp", "path": "lib/src/HttpBinder.cc", "position_ratio": 0.75, "description": "\nFunction Description:\n1. 
**Purpose**: To manage exceptions that occur during HTTP request handling by delegating the exception handling to a framework-specific handler.\n2. **Input**: The function takes three parameters: an exception reference, an HTTP request pointer, and a callback function.\n3. **Output**: There is no direct output; however, the exception handling response is managed through the callback function.\n4. **Procedure**: The function invokes a registered exception handler from the application framework, passing it the exception, the HTTP request, and the callback function to handle the response appropriately.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/HttpAppFrameworkImpl.cc\n/**\n *\n * @file HttpAppFrameworkImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpAppFrameworkImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"AOPAdvice.h\"\n#include \"ConfigLoader.h\"\n#include \"DbClientManager.h\"\n#include \"HttpClientImpl.h\"\n#include \"HttpConnectionLimit.h\"\n#include \"HttpControllersRouter.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpResponseImpl.h\"\n#include \"HttpServer.h\"\n#include \"HttpUtils.h\"\n#include \"ListenerManager.h\"\n#include \"PluginsManager.h\"\n#include \"RedisClientManager.h\"\n#include \"SessionManager.h\"\n#include \"SharedLibManager.h\"\n#include \"StaticFileRouter.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \n#include \n#ifndef _WIN32\n#include \n#include \n#define os_access access\n#elif !defined(_WIN32) || defined(__MINGW32__)\n#include \n#include \n#define os_access access\n#else\n#include \n#define os_access _waccess\n#define R_OK 04\n#define W_OK 02\n#endif\n\n#ifdef DROGON_SPDLOG_SUPPORT\n#include \n#include \n#include \n#include \n#ifdef _WIN32\n#include \n// Damn antedeluvian M$ macros\n#undef min\n#undef max\n#endif\n#endif // DROGON_SPDLOG_SUPPORT\n\nusing namespace drogon;\nusing namespace std::placeholders;\n\nHttpAppFrameworkImpl::HttpAppFrameworkImpl()\n : listenerManagerPtr_(new ListenerManager),\n pluginsManagerPtr_(new PluginsManager),\n dbClientManagerPtr_(new orm::DbClientManager),\n redisClientManagerPtr_(new nosql::RedisClientManager),\n uploadPath_(rootPath_ + \"uploads\")\n{\n}\n\nstatic std::function f = [] {\n LOG_TRACE << \"Initialize the main event loop in the main thread\";\n};\n\n/// Make sure that the main event loop is initialized in the main thread.\ndrogon::InitBeforeMainFunction drogon::HttpAppFrameworkImpl::initFirst_([]() {\n HttpAppFrameworkImpl::instance().getLoop()->runInLoop(f);\n});\n\nnamespace drogon\n{\nstd::string getVersion()\n{\n return DROGON_VERSION;\n}\n\nstd::string getGitCommit()\n{\n return DROGON_VERSION_SHA1;\n}\n\nHttpResponsePtr defaultErrorHandler(HttpStatusCode code, const HttpRequestPtr &)\n{\n return std::make_shared(code, CT_TEXT_HTML);\n}\n\nvoid defaultExceptionHandler(\n const std::exception &e,\n const HttpRequestPtr &req,\n std::function &&callback)\n{\n std::string pathWithQuery = req->path();\n if (req->query().empty() == false)\n pathWithQuery += \"?\" + 
req->query();\n LOG_ERROR << \"Unhandled exception in \" << pathWithQuery\n << \", what(): \" << e.what();\n const auto &handler = app().getCustomErrorHandler();\n callback(handler(k500InternalServerError, req));\n}\n\nstatic void godaemon()\n{\n printf(\"Initializing daemon mode\\n\");\n#ifndef _WIN32\n if (getppid() != 1)\n {\n pid_t pid;\n pid = fork();\n if (pid > 0)\n exit(0); // parent\n if (pid < 0)\n {\n perror(\"fork\");\n exit(1);\n }\n setsid();\n }\n\n // redirect stdin/stdout/stderr to /dev/null\n close(0);\n close(1);\n close(2);\n\n int ret = open(\"/dev/null\", O_RDWR);\n (void)ret;\n ret = dup(0);\n (void)ret;\n ret = dup(0);\n (void)ret;\n umask(0);\n#else\n LOG_ERROR << \"Cannot run as daemon in Windows\";\n exit(1);\n#endif\n\n return;\n}\n\nstatic void TERMFunction(int sig)\n{\n if (sig == SIGTERM)\n {\n LOG_WARN << \"SIGTERM signal received.\";\n HttpAppFrameworkImpl::instance().getTermSignalHandler()();\n }\n else if (sig == SIGINT)\n {\n LOG_WARN << \"SIGINT signal received.\";\n HttpAppFrameworkImpl::instance().getIntSignalHandler()();\n }\n}\n\n} // namespace drogon\n\nHttpAppFrameworkImpl::~HttpAppFrameworkImpl() noexcept\n{\n// Destroy the following objects before the loop destruction\n#ifndef _WIN32\n sharedLibManagerPtr_.reset();\n#endif\n sessionManagerPtr_.reset();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setStaticFilesCacheTime(int cacheTime)\n{\n StaticFileRouter::instance().setStaticFilesCacheTime(cacheTime);\n return *this;\n}\n\nint HttpAppFrameworkImpl::staticFilesCacheTime() const\n{\n return StaticFileRouter::instance().staticFilesCacheTime();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setGzipStatic(bool useGzipStatic)\n{\n StaticFileRouter::instance().setGzipStatic(useGzipStatic);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setBrStatic(bool useGzipStatic)\n{\n StaticFileRouter::instance().setBrStatic(useGzipStatic);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setImplicitPageEnable(\n bool useImplicitPage)\n{\n StaticFileRouter::instance().setImplicitPageEnable(useImplicitPage);\n return *this;\n}\n\nbool HttpAppFrameworkImpl::isImplicitPageEnabled() const\n{\n return StaticFileRouter::instance().isImplicitPageEnabled();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setImplicitPage(\n const std::string &implicitPageFile)\n{\n StaticFileRouter::instance().setImplicitPage(implicitPageFile);\n return *this;\n}\n\nconst std::string &HttpAppFrameworkImpl::getImplicitPage() const\n{\n return StaticFileRouter::instance().getImplicitPage();\n}\n#ifndef _WIN32\nHttpAppFramework &HttpAppFrameworkImpl::enableDynamicViewsLoading(\n const std::vector &libPaths,\n const std::string &outputPath)\n{\n assert(!running_);\n\n for (auto const &libpath : libPaths)\n {\n if (libpath[0] == '/' ||\n (libpath.length() >= 2 && libpath[0] == '.' && libpath[1] == '/') ||\n (libpath.length() >= 3 && libpath[0] == '.' && libpath[1] == '.' 
&&\n libpath[2] == '/') ||\n libpath == \".\" || libpath == \"..\")\n {\n libFilePaths_.push_back(libpath);\n }\n else\n {\n if (rootPath_[rootPath_.length() - 1] == '/')\n libFilePaths_.push_back(rootPath_ + libpath);\n else\n libFilePaths_.push_back(rootPath_ + \"/\" + libpath);\n }\n }\n libFileOutputPath_ = outputPath;\n if (!libFileOutputPath_.empty())\n {\n if (drogon::utils::createPath(libFileOutputPath_) == -1)\n {\n LOG_FATAL << \"Can't create \" << libFileOutputPath_\n << \" path for dynamic views\";\n exit(-1);\n }\n }\n\n return *this;\n}\n#endif\nHttpAppFramework &HttpAppFrameworkImpl::setFileTypes(\n const std::vector &types)\n{\n StaticFileRouter::instance().setFileTypes(types);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::registerWebSocketController(\n const std::string &pathName,\n const std::string &ctrlName,\n const std::vector &filtersAndMethods)\n{\n assert(!routersInit_);\n HttpControllersRouter::instance().registerWebSocketController(\n pathName, ctrlName, filtersAndMethods);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::registerHttpSimpleController(\n const std::string &pathName,\n const std::string &ctrlName,\n const std::vector &filtersAndMethods)\n{\n assert(!routersInit_);\n HttpControllersRouter::instance().registerHttpSimpleController(\n pathName, ctrlName, filtersAndMethods);\n return *this;\n}\n\nvoid HttpAppFrameworkImpl::registerHttpController(\n const std::string &pathPattern,\n const internal::HttpBinderBasePtr &binder,\n const std::vector &validMethods,\n const std::vector &filters,\n const std::string &handlerName)\n{\n assert(!pathPattern.empty());\n assert(binder);\n assert(!routersInit_);\n HttpControllersRouter::instance().addHttpPath(\n pathPattern, binder, validMethods, filters, handlerName);\n}\n\nvoid HttpAppFrameworkImpl::registerHttpControllerViaRegex(\n const std::string ®Exp,\n const internal::HttpBinderBasePtr &binder,\n const std::vector &validMethods,\n const std::vector &filters,\n const std::string &handlerName)\n{\n assert(!regExp.empty());\n assert(binder);\n assert(!routersInit_);\n HttpControllersRouter::instance().addHttpRegex(\n regExp, binder, validMethods, filters, handlerName);\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setThreadNum(size_t threadNum)\n{\n if (threadNum == 0)\n {\n threadNum_ = std::thread::hardware_concurrency();\n return *this;\n }\n threadNum_ = threadNum;\n return *this;\n}\n\nPluginBase *HttpAppFrameworkImpl::getPlugin(const std::string &name)\n{\n return pluginsManagerPtr_->getPlugin(name);\n}\n\nstd::shared_ptr HttpAppFrameworkImpl::getSharedPlugin(\n const std::string &name)\n{\n return pluginsManagerPtr_->getSharedPlugin(name);\n}\n\nvoid HttpAppFrameworkImpl::addPlugin(\n const std::string &name,\n const std::vector &dependencies,\n const Json::Value &config)\n{\n assert(!isRunning());\n Json::Value pluginConfig;\n pluginConfig[\"name\"] = name;\n Json::Value deps(Json::arrayValue);\n for (const auto dep : dependencies)\n {\n deps.append(dep);\n }\n pluginConfig[\"dependencies\"] = deps;\n pluginConfig[\"config\"] = config;\n auto &plugins = jsonRuntimeConfig_[\"plugins\"];\n plugins.append(pluginConfig);\n}\n\nvoid HttpAppFrameworkImpl::addPlugins(const Json::Value &configs)\n{\n assert(!isRunning());\n assert(configs.isArray());\n auto &plugins = jsonRuntimeConfig_[\"plugins\"];\n for (const auto config : configs)\n {\n plugins.append(config);\n }\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::addListener(\n const std::string &ip,\n uint16_t port,\n bool useSSL,\n 
const std::string &certFile,\n const std::string &keyFile,\n bool useOldTLS,\n const std::vector> &sslConfCmds)\n{\n assert(!running_);\n listenerManagerPtr_->addListener(\n ip, port, useSSL, certFile, keyFile, useOldTLS, sslConfCmds);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setMaxConnectionNum(\n size_t maxConnections)\n{\n HttpConnectionLimit::instance().setMaxConnectionNum(maxConnections);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setMaxConnectionNumPerIP(\n size_t maxConnectionsPerIP)\n{\n HttpConnectionLimit::instance().setMaxConnectionNumPerIP(\n maxConnectionsPerIP);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::loadConfigFile(\n const std::string &fileName)\n{\n ConfigLoader loader(fileName);\n loader.load();\n jsonConfig_ = loader.jsonValue();\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::loadConfigJson(const Json::Value &data)\n{\n ConfigLoader loader(data);\n loader.load();\n jsonConfig_ = loader.jsonValue();\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::loadConfigJson(Json::Value &&data)\n{\n ConfigLoader loader(std::move(data));\n loader.load();\n jsonConfig_ = loader.jsonValue();\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setLogPath(\n const std::string &logPath,\n const std::string &logfileBaseName,\n size_t logfileSize,\n size_t maxFiles,\n bool useSpdlog)\n{\n#ifdef DROGON_SPDLOG_SUPPORT\n logWithSpdlog_ = trantor::Logger::hasSpdLogSupport() && useSpdlog;\n#endif\n if (logPath.empty())\n return *this;\n // std::filesystem does not provide a method to check access permissions, so\n // keep existing code\n if (os_access(utils::toNativePath(logPath).c_str(), 0) != 0)\n {\n std::cerr << \"Log path does not exist!\\n\";\n exit(1);\n }\n if (os_access(utils::toNativePath(logPath).c_str(), R_OK | W_OK) != 0)\n {\n std::cerr << \"Unable to access log path!\\n\";\n exit(1);\n }\n logPath_ = logPath;\n logfileBaseName_ = logfileBaseName;\n logfileSize_ = logfileSize;\n logfileMaxNum_ = maxFiles;\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setLogLevel(\n trantor::Logger::LogLevel level)\n{\n trantor::Logger::setLogLevel(level);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setLogLocalTime(bool on)\n{\n trantor::Logger::setDisplayLocalTime(on);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setSSLConfigCommands(\n const std::vector> &sslConfCmds)\n{\n sslConfCmds_ = sslConfCmds;\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setSSLFiles(const std::string &certPath,\n const std::string &keyPath)\n{\n sslCertPath_ = certPath;\n sslKeyPath_ = keyPath;\n return *this;\n}\n\nvoid HttpAppFrameworkImpl::run()\n{\n if (!getLoop()->isInLoopThread())\n {\n getLoop()->moveToCurrentThread();\n }\n LOG_TRACE << \"Start to run...\";\n // Create dirs for cache files\n for (int i = 0; i < 256; ++i)\n {\n char dirName[4];\n snprintf(dirName, sizeof(dirName), \"%02x\", i);\n std::transform(dirName, dirName + 2, dirName, [](unsigned char c) {\n return toupper(c);\n });\n utils::createPath(getUploadPath() + \"/tmp/\" + dirName);\n }\n if (runAsDaemon_)\n {\n // go daemon!\n godaemon();\n#ifdef __linux__\n getLoop()->resetTimerQueue();\n#endif\n getLoop()->resetAfterFork();\n }\n // set relaunching\n if (relaunchOnError_)\n {\n#ifndef _WIN32\n while (true)\n {\n int child_status = 0;\n auto child_pid = fork();\n if (child_pid < 0)\n {\n LOG_ERROR << \"fork error\";\n abort();\n }\n else if (child_pid == 0)\n {\n // child\n break;\n }\n 
waitpid(child_pid, &child_status, 0);\n sleep(1);\n LOG_INFO << \"start new process\";\n }\n#ifdef __linux__\n getLoop()->resetTimerQueue();\n#endif\n getLoop()->resetAfterFork();\n#endif\n }\n if (handleSigterm_)\n {\n#ifdef WIN32\n signal(SIGTERM, TERMFunction);\n signal(SIGINT, TERMFunction);\n#else\n struct sigaction sa;\n sa.sa_handler = TERMFunction;\n sigemptyset(&sa.sa_mask);\n sa.sa_flags = 0;\n if (sigaction(SIGINT, &sa, NULL) == -1)\n {\n LOG_ERROR << \"sigaction() failed, can't set SIGINT handler\";\n abort();\n }\n if (sigaction(SIGTERM, &sa, NULL) == -1)\n {\n LOG_ERROR << \"sigaction() failed, can't set SIGTERM handler\";\n abort();\n }\n#endif\n }\n setupFileLogger();\n if (relaunchOnError_)\n {\n LOG_INFO << \"Start child process\";\n }\n\n#ifndef _WIN32\n if (!libFilePaths_.empty())\n {\n sharedLibManagerPtr_ =\n std::make_unique(libFilePaths_,\n libFileOutputPath_);\n }\n#endif\n\n // Create IO threads\n ioLoopThreadPool_ =\n std::make_unique(threadNum_,\n \"DrogonIoLoop\");\n std::vector ioLoops = ioLoopThreadPool_->getLoops();\n for (size_t i = 0; i < threadNum_; ++i)\n {\n ioLoops[i]->setIndex(i);\n }\n getLoop()->setIndex(threadNum_);\n\n // Create all listeners.\n listenerManagerPtr_->createListeners(sslCertPath_,\n sslKeyPath_,\n sslConfCmds_,\n ioLoops);\n\n // A fast database client instance should be created in the main event\n // loop, so put the main loop into ioLoops.\n ioLoops.push_back(getLoop());\n dbClientManagerPtr_->createDbClients(ioLoops);\n redisClientManagerPtr_->createRedisClients(ioLoops);\n if (useSession_)\n {\n sessionManagerPtr_ =\n std::make_unique(getLoop(),\n sessionTimeout_,\n sessionStartAdvices_,\n sessionDestroyAdvices_,\n sessionIdGeneratorCallback_);\n }\n // now start running!!\n running_ = true;\n // Initialize plugins\n auto &pluginConfig = jsonConfig_[\"plugins\"];\n const auto &runtumePluginConfig = jsonRuntimeConfig_[\"plugins\"];\n if (!pluginConfig.isNull())\n {\n if (!runtumePluginConfig.isNull() && runtumePluginConfig.isArray())\n {\n for (const auto &plugin : runtumePluginConfig)\n {\n pluginConfig.append(plugin);\n }\n }\n }\n else\n {\n jsonConfig_[\"plugins\"] = runtumePluginConfig;\n }\n if (!pluginConfig.isNull())\n {\n pluginsManagerPtr_->initializeAllPlugins(pluginConfig,\n [](PluginBase *plugin) {\n LOG_TRACE\n << \"new plugin:\"\n << plugin->className();\n // TODO: new plugin\n });\n }\n routersInit_ = true;\n HttpControllersRouter::instance().init(ioLoops);\n StaticFileRouter::instance().init(ioLoops);\n getLoop()->queueInLoop([this]() {\n for (auto &adv : beginningAdvices_)\n {\n adv();\n }\n beginningAdvices_.clear();\n // Let listener event loops run when everything is ready.\n listenerManagerPtr_->startListening();\n });\n // start all loops\n // TODO: when should IOLoops start?\n // In before, IOLoops are started in `listenerManagerPtr_->startListening()`\n // It should be fine for them to start anywhere before `startListening()`.\n // However, we should consider other components.\n ioLoopThreadPool_->start();\n getLoop()->loop();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setUploadPath(\n const std::string &uploadPath)\n{\n assert(!uploadPath.empty());\n\n std::filesystem::path fsUploadPath(utils::toNativePath(uploadPath));\n if (!fsUploadPath.is_absolute())\n {\n std::filesystem::path fsRoot(utils::toNativePath(rootPath_));\n fsUploadPath = fsRoot / fsUploadPath;\n }\n uploadPath_ = utils::fromNativePath(fsUploadPath.native());\n return *this;\n}\n\nvoid HttpAppFrameworkImpl::findSessionForRequest(const 
HttpRequestImplPtr &req)\n{\n if (useSession_)\n {\n std::string sessionId = req->getCookie(sessionCookieKey_);\n bool needSetSessionid = false;\n if (sessionId.empty())\n {\n sessionId = sessionIdGeneratorCallback_();\n needSetSessionid = true;\n }\n req->setSession(\n sessionManagerPtr_->getSession(sessionId, needSetSessionid));\n }\n}\n\nstd::vector HttpAppFrameworkImpl::getHandlersInfo() const\n{\n return HttpControllersRouter::instance().getHandlersInfo();\n}\n\nHttpResponsePtr HttpAppFrameworkImpl::handleSessionForResponse(\n const HttpRequestImplPtr &req,\n const HttpResponsePtr &resp)\n{\n if (useSession_)\n {\n auto &sessionPtr = req->getSession();\n if (!sessionPtr)\n {\n return resp;\n }\n if (sessionPtr->needToChangeSessionId())\n {\n sessionManagerPtr_->changeSessionId(sessionPtr);\n }\n if (sessionPtr->needSetToClient())\n {\n if (resp->expiredTime() >= 0)\n {\n auto newResp = std::make_shared(\n *static_cast(resp.get()));\n newResp->setExpiredTime(-1); // make it temporary\n auto sessionid =\n Cookie(sessionCookieKey_, sessionPtr->sessionId());\n sessionid.setPath(\"/\");\n sessionid.setSameSite(sessionSameSite_);\n if (sessionMaxAge_ >= 0)\n sessionid.setMaxAge(sessionMaxAge_);\n newResp->addCookie(std::move(sessionid));\n sessionPtr->hasSet();\n\n return newResp;\n }\n else\n {\n auto sessionid =\n Cookie(sessionCookieKey_, sessionPtr->sessionId());\n sessionid.setPath(\"/\");\n sessionid.setSameSite(sessionSameSite_);\n if (sessionMaxAge_ >= 0)\n sessionid.setMaxAge(sessionMaxAge_);\n resp->addCookie(std::move(sessionid));\n sessionPtr->hasSet();\n\n return resp;\n }\n }\n else if (resp->version() != req->version())\n {\n auto newResp = std::make_shared(\n *static_cast(resp.get()));\n newResp->setVersion(req->version());\n newResp->setExpiredTime(-1); // make it temporary\n\n return newResp;\n }\n else\n {\n return resp;\n }\n }\n else\n {\n if (resp->expiredTime() >= 0 && resp->version() != req->version())\n {\n auto newResp = std::make_shared(\n *static_cast(resp.get()));\n newResp->setVersion(req->version());\n newResp->setExpiredTime(-1); // make it temporary\n\n return newResp;\n }\n else\n {\n return resp;\n }\n }\n}\n\ntrantor::EventLoop *HttpAppFrameworkImpl::getLoop() const\n{\n static trantor::EventLoop loop;\n return &loop;\n}\n\ntrantor::EventLoop *HttpAppFrameworkImpl::getIOLoop(size_t id) const\n{\n if (!ioLoopThreadPool_)\n {\n LOG_WARN << \"Please call getIOLoop() after drogon::app().run()\";\n return nullptr;\n }\n auto n = ioLoopThreadPool_->size();\n if (id >= n)\n {\n LOG_TRACE << \"Loop id (\" << id << \") out of range [0-\" << n << \").\";\n id %= n;\n LOG_TRACE << \"Rounded to : \" << id;\n }\n return ioLoopThreadPool_->getLoop(id);\n}\n\nHttpAppFramework &HttpAppFramework::instance()\n{\n return HttpAppFrameworkImpl::instance();\n}\n\nvoid HttpAppFrameworkImpl::forward(\n const HttpRequestPtr &req,\n std::function &&callback,\n const std::string &hostString,\n double timeout)\n{\n forward(std::dynamic_pointer_cast(req),\n std::move(callback),\n hostString,\n timeout);\n}\n\nvoid HttpAppFrameworkImpl::forward(\n const HttpRequestImplPtr &req,\n std::function &&callback,\n const std::string &hostString,\n double timeout)\n{\n if (hostString.empty())\n {\n HttpInternalForwardHelper::forward(req, std::move(callback));\n }\n else\n {\n /// A tiny implementation of a reverse proxy.\n static std::unordered_map clientsMap;\n HttpClientImplPtr clientPtr;\n static std::mutex mtx;\n {\n std::lock_guard lock(mtx);\n auto iter = clientsMap.find(hostString);\n 
if (iter != clientsMap.end())\n {\n clientPtr = iter->second;\n }\n else\n {\n clientPtr = std::make_shared(\n trantor::EventLoop::getEventLoopOfCurrentThread()\n ? trantor::EventLoop::getEventLoopOfCurrentThread()\n : getLoop(),\n hostString);\n clientsMap[hostString] = clientPtr;\n }\n }\n req->setPassThrough(true);\n clientPtr->sendRequest(\n req,\n [callback = std::move(callback), req](ReqResult result,\n const HttpResponsePtr &resp) {\n if (result == ReqResult::Ok)\n {\n resp->setPassThrough(true);\n callback(resp);\n }\n else\n {\n callback(HttpResponse::newNotFoundResponse(req));\n }\n },\n timeout);\n }\n}\n\norm::DbClientPtr HttpAppFrameworkImpl::getDbClient(const std::string &name)\n{\n return dbClientManagerPtr_->getDbClient(name);\n}\n\norm::DbClientPtr HttpAppFrameworkImpl::getFastDbClient(const std::string &name)\n{\n return dbClientManagerPtr_->getFastDbClient(name);\n}\n\nnosql::RedisClientPtr HttpAppFrameworkImpl::getRedisClient(\n const std::string &name)\n{\n return redisClientManagerPtr_->getRedisClient(name);\n}\n\nnosql::RedisClientPtr HttpAppFrameworkImpl::getFastRedisClient(\n const std::string &name)\n{\n return redisClientManagerPtr_->getFastRedisClient(name);\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::createDbClient(\n const std::string &dbType,\n const std::string &host,\n const unsigned short port,\n const std::string &databaseName,\n const std::string &userName,\n const std::string &password,\n const size_t connectionNum,\n const std::string &filename,\n const std::string &name,\n const bool isFast,\n const std::string &characterSet,\n double timeout,\n const bool autoBatch)\n{\n assert(!running_);\n dbClientManagerPtr_->createDbClient(dbType,\n host,\n port,\n databaseName,\n userName,\n password,\n connectionNum,\n filename,\n name,\n isFast,\n characterSet,\n timeout,\n autoBatch);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::createRedisClient(\n const std::string &ip,\n unsigned short port,\n const std::string &name,\n const std::string &password,\n size_t connectionNum,\n bool isFast,\n double timeout,\n unsigned int db,\n const std::string &username)\n{\n assert(!running_);\n redisClientManagerPtr_->createRedisClient(\n name, ip, port, username, password, connectionNum, isFast, timeout, db);\n return *this;\n}\n\nvoid HttpAppFrameworkImpl::quit()\n{\n if (getLoop()->isRunning())\n {\n getLoop()->queueInLoop([this]() {\n // Release members in the reverse order of initialization\n listenerManagerPtr_->stopListening();\n listenerManagerPtr_.reset();\n StaticFileRouter::instance().reset();\n HttpControllersRouter::instance().reset();\n pluginsManagerPtr_.reset();\n redisClientManagerPtr_.reset();\n dbClientManagerPtr_.reset();\n running_ = false;\n getLoop()->quit();\n for (trantor::EventLoop *loop : ioLoopThreadPool_->getLoops())\n {\n loop->quit();\n }\n ioLoopThreadPool_->wait();\n });\n }\n}\n\nconst HttpResponsePtr &HttpAppFrameworkImpl::getCustom404Page()\n{\n if (!custom404_)\n {\n return custom404_;\n }\n auto loop = trantor::EventLoop::getEventLoopOfCurrentThread();\n if (loop && loop->index() < app().getThreadNum())\n {\n // If the current thread is an IO thread\n static IOThreadStorage thread404Pages;\n static std::once_flag once;\n std::call_once(once, [this] {\n thread404Pages.init(\n [this](HttpResponsePtr &resp, size_t /*index*/) {\n resp = std::make_shared(\n *static_cast(custom404_.get()));\n });\n });\n return thread404Pages.getThreadData();\n }\n else\n {\n return custom404_;\n }\n}\n\nHttpAppFramework 
&HttpAppFrameworkImpl::setStaticFileHeaders(\n const std::vector> &headers)\n{\n StaticFileRouter::instance().setStaticFileHeaders(headers);\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::addALocation(\n const std::string &uriPrefix,\n const std::string &defaultContentType,\n const std::string &alias,\n bool isCaseSensitive,\n bool allowAll,\n bool isRecursive,\n const std::vector &filters)\n{\n StaticFileRouter::instance().addALocation(uriPrefix,\n defaultContentType,\n alias,\n isCaseSensitive,\n allowAll,\n isRecursive,\n filters);\n return *this;\n}\n\nbool HttpAppFrameworkImpl::areAllDbClientsAvailable() const noexcept\n{\n return dbClientManagerPtr_->areAllDbClientsAvailable();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setCustomErrorHandler(\n std::function\n &&resp_generator)\n{\n customErrorHandler_ = std::move(resp_generator);\n usingCustomErrorHandler_ = true;\n return *this;\n}\n\nconst std::function &\nHttpAppFrameworkImpl::getCustomErrorHandler() const\n{\n return customErrorHandler_;\n}\n\nstd::vector HttpAppFrameworkImpl::getListeners() const\n{\n return listenerManagerPtr_->getListeners();\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setDefaultHandler(\n DefaultHandler handler)\n{\n StaticFileRouter::instance().setDefaultHandler(std::move(handler));\n return *this;\n}\n\nHttpAppFramework &HttpAppFrameworkImpl::setupFileLogger()\n{\n#ifdef DROGON_SPDLOG_SUPPORT\n if (logWithSpdlog_)\n {\n...\n// Path: lib/src/ConfigAdapter.h\n#pragma once\n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass ConfigAdapter\n{\n public:\n virtual ~ConfigAdapter() = default;\n virtual Json::Value getJson(const std::string &content) const\n noexcept(false) = 0;\n virtual std::vector getExtensions() const = 0;\n};\n\nusing ConfigAdapterPtr = std::shared_ptr;\n\n} // namespace drogon\n\n// Path: lib/src/JsonConfigAdapter.h\n#pragma once\n#include \"ConfigAdapter.h\"\n\nnamespace drogon\n{\nclass JsonConfigAdapter : public ConfigAdapter\n{\n public:\n JsonConfigAdapter() = default;\n ~JsonConfigAdapter() override = default;\n Json::Value getJson(const std::string &content) const\n noexcept(false) override;\n std::vector getExtensions() const override;\n};\n} // namespace drogon\n\n// Path: lib/src/JsonConfigAdapter.cc\n#include \"JsonConfigAdapter.h\"\n#include \n#include \n\nusing namespace drogon;\n\nJson::Value JsonConfigAdapter::getJson(const std::string &content) const\n noexcept(false)\n{\n static std::once_flag once;\n static Json::CharReaderBuilder builder;\n std::call_once(once, []() { builder[\"collectComments\"] = false; });\n JSONCPP_STRING errs;\n std::unique_ptr reader(builder.newCharReader());\n Json::Value root;\n if (!reader->parse(\n content.c_str(), content.c_str() + content.size(), &root, &errs))\n {\n throw std::runtime_error(errs);\n }\n return root;\n}\n\nstd::vector JsonConfigAdapter::getExtensions() const\n{\n return {\"json\"};\n}\n\n// Path: lib/src/Histogram.cc\n#include \nusing namespace drogon;\nusing namespace drogon::monitoring;\n\nvoid Histogram::observe(double value)\n{\n std::lock_guard lock(mutex_);\n if (maxAge_ > std::chrono::seconds(0) &&\n timerId_ == trantor::InvalidTimerId)\n {\n std::weak_ptr weakPtr =\n std::dynamic_pointer_cast(shared_from_this());\n timerId_ = loopPtr_->runEvery(maxAge_ / timeBucketCount_, [weakPtr]() {\n auto thisPtr = weakPtr.lock();\n if (!thisPtr)\n return;\n thisPtr->rotateTimeBuckets();\n });\n }\n auto ¤tBucket = timeBuckets_.back();\n currentBucket.sum += value;\n 
currentBucket.count += 1;\n for (size_t i = 0; i < bucketBoundaries_.size(); i++)\n {\n if (value <= bucketBoundaries_[i])\n {\n currentBucket.buckets[i] += 1;\n break;\n }\n }\n if (value > bucketBoundaries_.back())\n {\n currentBucket.buckets.back() += 1;\n }\n}\n\nstd::vector Histogram::collect() const\n{\n std::vector samples;\n std::lock_guard guard(mutex_);\n size_t count{0};\n for (size_t i = 0; i < bucketBoundaries_.size(); i++)\n {\n Sample sample;\n for (auto &bucket : timeBuckets_)\n {\n count += bucket.buckets[i];\n }\n sample.name = name_ + \"_bucket\";\n sample.exLabels.emplace_back(\"le\",\n std::to_string(bucketBoundaries_[i]));\n sample.value = count;\n samples.emplace_back(std::move(sample));\n }\n Sample sample;\n for (auto &bucket : timeBuckets_)\n {\n count += bucket.buckets.back();\n }\n sample.name = name_ + \"_bucket\";\n sample.exLabels.emplace_back(\"le\", \"+Inf\");\n sample.value = count;\n samples.emplace_back(std::move(sample));\n double sum{0};\n uint64_t totalCount{0};\n for (auto &bucket : timeBuckets_)\n {\n sum += bucket.sum;\n totalCount += bucket.count;\n }\n Sample sumSample;\n sumSample.name = name_ + \"_sum\";\n sumSample.value = sum;\n samples.emplace_back(std::move(sumSample));\n Sample countSample;\n countSample.name = name_ + \"_count\";\n countSample.value = totalCount;\n samples.emplace_back(std::move(countSample));\n return samples;\n}\n\n// Path: lib/src/GlobalFilters.cc\n#include \n#include \n#include \n#include \"FiltersFunction.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpAppFrameworkImpl.h\"\n\nusing namespace drogon::plugin;\n\nvoid GlobalFilters::initAndStart(const Json::Value &config)\n{\n if (config.isMember(\"filters\") && config[\"filters\"].isArray())\n {\n auto &filters = config[\"filters\"];\n\n for (auto const &filter : filters)\n {\n if (filter.isString())\n {\n auto filterPtr = std::dynamic_pointer_cast(\n drogon::DrClassMap::getSingleInstance(filter.asString()));\n if (filterPtr)\n {\n filters_.push_back(filterPtr);\n }\n else\n {\n LOG_ERROR << \"Filter \" << filter.asString()\n << \" not found!\";\n }\n }\n }\n }\n if (config.isMember(\"exempt\"))\n {\n auto exempt = config[\"exempt\"];\n if (exempt.isArray())\n {\n std::string regexStr;\n for (auto const &ex : exempt)\n {\n if (ex.isString())\n {\n regexStr.append(\"(\").append(ex.asString()).append(\")|\");\n }\n else\n {\n LOG_ERROR << \"exempt must be a string array!\";\n }\n }\n if (!regexStr.empty())\n {\n regexStr.pop_back();\n exemptPegex_ = std::regex(regexStr);\n regexFlag_ = true;\n }\n }\n else if (exempt.isString())\n {\n exemptPegex_ = std::regex(exempt.asString());\n regexFlag_ = true;\n }\n else\n {\n LOG_ERROR << \"exempt must be a string or string array!\";\n }\n }\n std::weak_ptr weakPtr = shared_from_this();\n drogon::app().registerPreRoutingAdvice(\n [weakPtr](const drogon::HttpRequestPtr &req,\n drogon::AdviceCallback &&acb,\n drogon::AdviceChainCallback &&accb) {\n auto thisPtr = weakPtr.lock();\n if (!thisPtr)\n {\n accb();\n return;\n }\n if (thisPtr->regexFlag_)\n {\n if (std::regex_match(req->path(), thisPtr->exemptPegex_))\n {\n accb();\n return;\n }\n }\n\n drogon::filters_function::doFilters(\n thisPtr->filters_,\n std::static_pointer_cast(req),\n [acb = std::move(acb),\n accb = std::move(accb)](const HttpResponsePtr &resp) {\n if (resp)\n {\n acb(resp);\n }\n else\n {\n accb();\n }\n });\n });\n}\n\nvoid GlobalFilters::shutdown()\n{\n filters_.clear();\n}\n\n// Path: lib/src/HttpControllerBinder.cc\n/**\n *\n * @file 
HttpControllerBinder.cc\n * @author Nitromelon\n *\n * Copyright 2023, Nitromelon. All rights reserved.\n * https://github.com/drogonframework/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpControllerBinder.h\"\n#include \"HttpResponseImpl.h\"\n#include \n#include \n\nnamespace drogon\n{\n\nvoid HttpSimpleControllerBinder::handleRequest(\n const HttpRequestImplPtr &req,\n std::function &&callback) const\n{\n // Binders without controller should be removed after run()\n assert(controller_);\n\n try\n {\n auto cb = callback; // copy\n controller_->asyncHandleHttpRequest(req, std::move(cb));\n }\n catch (const std::exception &e)\n {\n app().getExceptionHandler()(e, req, std::move(callback));\n return;\n }\n catch (...)\n {\n LOG_ERROR << \"Exception not derived from std::exception\";\n return;\n }\n}\n\nvoid HttpControllerBinder::handleRequest(\n const HttpRequestImplPtr &req,\n std::function &&callback) const\n{\n auto ¶msVector = req->getRoutingParameters();\n std::deque params(paramsVector.begin(), paramsVector.end());\n binderPtr_->handleHttpRequest(params, req, std::move(callback));\n}\n\nvoid WebsocketControllerBinder::handleRequest(\n const HttpRequestImplPtr &req,\n std::function &&callback) const\n{\n std::string wsKey = req->getHeaderBy(\"sec-websocket-key\");\n wsKey.append(\"258EAFA5-E914-47DA-95CA-C5AB0DC85B11\");\n unsigned char accKey[20];\n auto sha1 = trantor::utils::sha1(wsKey.c_str(), wsKey.length());\n memcpy(accKey, &sha1, sizeof(sha1));\n auto base64Key = utils::base64Encode(accKey, sizeof(accKey));\n auto resp = HttpResponse::newHttpResponse();\n resp->setStatusCode(k101SwitchingProtocols);\n resp->addHeader(\"Upgrade\", \"websocket\");\n resp->addHeader(\"Connection\", \"Upgrade\");\n resp->addHeader(\"Sec-WebSocket-Accept\", base64Key);\n callback(resp);\n}\n\nvoid WebsocketControllerBinder::handleNewConnection(\n const HttpRequestImplPtr &req,\n const WebSocketConnectionImplPtr &wsConnPtr) const\n{\n auto ctrlPtr = controller_;\n assert(ctrlPtr);\n wsConnPtr->setMessageCallback(\n [ctrlPtr](std::string &&message,\n const WebSocketConnectionImplPtr &connPtr,\n const WebSocketMessageType &type) {\n ctrlPtr->handleNewMessage(connPtr, std::move(message), type);\n });\n wsConnPtr->setCloseCallback(\n [ctrlPtr](const WebSocketConnectionImplPtr &connPtr) {\n ctrlPtr->handleConnectionClosed(connPtr);\n });\n ctrlPtr->handleNewConnection(req, wsConnPtr);\n}\n\n} // namespace drogon\n\n// Path: lib/src/HttpBinder.cc\n/**\n *\n * HttpBinder.h\n * Martin Chang\n *\n * Copyright 2021, Martin Chang. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n\nnamespace drogon\n{\nnamespace internal\n{\n\nvoid handleException(const std::exception &e,\n const HttpRequestPtr &req,\n std::function &&callback)\n{\n app().getExceptionHandler()(e, req, std::move(callback));\n}\n} // namespace internal\n} // namespace drogon\n\n// Path: lib/src/YamlConfigAdapter.h\n#pragma once\n#include \"ConfigAdapter.h\"\n\nnamespace drogon\n{\nclass YamlConfigAdapter : public ConfigAdapter\n{\n public:\n YamlConfigAdapter() = default;\n ~YamlConfigAdapter() override = default;\n Json::Value getJson(const std::string &content) const\n noexcept(false) override;\n std::vector getExtensions() const override;\n};\n} // namespace drogon\n\n// Path: lib/src/YamlConfigAdapter.cc\n#include \"YamlConfigAdapter.h\"\n#ifdef HAS_YAML_CPP\n#include \n#endif\nusing namespace drogon;\n#ifdef HAS_YAML_CPP\nnamespace YAML\n{\nstatic bool yaml2json(const Node &node, Json::Value &jsonValue)\n{\n if (node.IsNull())\n {\n return false;\n }\n else if (node.IsScalar())\n {\n if (node.Tag() != \"!\")\n {\n try\n {\n jsonValue = node.as();\n return true;\n }\n catch (const YAML::BadConversion &e)\n {\n }\n try\n {\n jsonValue = node.as();\n return true;\n }\n catch (const YAML::BadConversion &e)\n {\n }\n try\n {\n jsonValue = node.as();\n return true;\n }\n catch (const YAML::BadConversion &e)\n {\n }\n }\n\n Json::Value v(node.Scalar());\n jsonValue.swapPayload(v);\n return true;\n }\n else if (node.IsSequence())\n {\n for (std::size_t i = 0; i < node.size(); i++)\n {\n Json::Value v;\n if (yaml2json(node[i], v))\n {\n jsonValue.append(v);\n }\n else\n {\n return false;\n }\n }\n\n return true;\n }\n else if (node.IsMap())\n {\n for (YAML::const_iterator it = node.begin(); it != node.end(); ++it)\n {\n Json::Value v;\n if (yaml2json(it->second, v))\n {\n jsonValue[it->first.Scalar()] = v;\n }\n else\n {\n return false;\n }\n }\n\n return true;\n }\n\n return false;\n}\n\ntemplate <>\nstruct convert\n{\n static bool decode(const Node &node, Json::Value &rhs)\n {\n return yaml2json(node, rhs);\n };\n};\n} // namespace YAML\n\n#endif\nJson::Value YamlConfigAdapter::getJson(const std::string &content) const\n noexcept(false)\n{\n#if HAS_YAML_CPP\n // parse yaml file\n YAML::Node config = YAML::Load(content);\n if (!config.IsNull())\n {\n return config.as();\n }\n else\n return Json::Value();\n#else\n throw std::runtime_error(\"please install yaml-cpp library\");\n#endif\n}\n\nstd::vector YamlConfigAdapter::getExtensions() const\n{\n return {\"yaml\", \"yml\"};\n}\n\n// Path: lib/src/FixedWindowRateLimiter.h\n#pragma once\n\n#include \n#include \n\nnamespace drogon\n{\nclass FixedWindowRateLimiter : public RateLimiter\n{\n public:\n FixedWindowRateLimiter(size_t capacity,\n std::chrono::duration timeUnit);\n bool isAllowed() override;\n ~FixedWindowRateLimiter() noexcept override = default;\n\n private:\n size_t capacity_;\n size_t currentRequests_{0};\n std::chrono::steady_clock::time_point lastTime_;\n std::chrono::duration timeUnit_;\n};\n} // namespace drogon\n\n// Path: lib/src/SlidingWindowRateLimiter.h\n#pragma once\n#include \n#include \n\nnamespace drogon\n{\nclass SlidingWindowRateLimiter : public RateLimiter\n{\n public:\n SlidingWindowRateLimiter(size_t capacity,\n std::chrono::duration timeUnit);\n bool isAllowed() override;\n ~SlidingWindowRateLimiter() noexcept override = 
default;\n\n private:\n size_t capacity_;\n size_t currentRequests_{0};\n size_t previousRequests_{0};\n std::chrono::steady_clock::time_point unitStartTime_;\n std::chrono::steady_clock::time_point lastTime_;\n std::chrono::duration timeUnit_;\n};\n} // namespace drogon\n\n// Path: lib/src/TokenBucketRateLimiter.h\n#pragma once\n\n#include \n\nnamespace drogon\n{\nclass TokenBucketRateLimiter : public RateLimiter\n{\n public:\n TokenBucketRateLimiter(size_t capacity,\n std::chrono::duration timeUnit);\n bool isAllowed() override;\n ~TokenBucketRateLimiter() noexcept override = default;\n\n private:\n size_t capacity_;\n std::chrono::steady_clock::time_point lastTime_;\n std::chrono::duration timeUnit_;\n double tokens_;\n};\n} // namespace drogon\n\n// Path: lib/src/RateLimiter.cc\n#include \n#include \"FixedWindowRateLimiter.h\"\n#include \"SlidingWindowRateLimiter.h\"\n#include \"TokenBucketRateLimiter.h\"\n\nusing namespace drogon;\n\nRateLimiterPtr RateLimiter::newRateLimiter(\n RateLimiterType type,\n size_t capacity,\n std::chrono::duration timeUnit)\n{\n switch (type)\n {\n case RateLimiterType::kFixedWindow:\n return std::make_shared(capacity, timeUnit);\n case RateLimiterType::kSlidingWindow:\n return std::make_shared(capacity,\n timeUnit);\n case RateLimiterType::kTokenBucket:\n return std::make_shared(capacity, timeUnit);\n }\n return std::make_shared(capacity, timeUnit);\n}\n\n// Path: lib/src/SlashRemover.cc\n#include \n#include \n#include \n#include \"drogon/utils/FunctionTraits.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace drogon;\nusing namespace drogon::plugin;\nusing std::string;\nusing std::string_view;\n\nenum removeSlashMode : uint8_t\n{\n trailing = 1 << 0,\n duplicate = 1 << 1,\n both = trailing | duplicate,\n};\n\n/// Returns the index before the trailing slashes,\n/// or 0 if only contains slashes\nstatic inline size_t findTrailingSlashes(string_view url)\n{\n auto len = url.size();\n // Must be at least 2 chars and end with a slash\n if (len < 2 || url.back() != '/')\n return string::npos;\n\n size_t a = len - 1; // We already know the last char is '/',\n // we will use pre-decrement to account for this\n while (--a > 0 && url[a] == '/')\n ; // We know the first char is '/', so don't check for 0\n return a;\n}\n\nstatic inline void removeTrailingSlashes(string &url,\n size_t start,\n string_view originalUrl)\n{\n url = originalUrl.substr(0, start + 1);\n}\n\n/// Returns the index of the 2nd duplicate slash\nstatic inline size_t findDuplicateSlashes(string_view url)\n{\n size_t len = url.size();\n if (len < 2)\n return string::npos;\n\n bool startedPair = true; // Always starts with a slash\n for (size_t a = 1; a < len; ++a)\n {\n if (url[a] != '/') // Broken pair\n {\n startedPair = false;\n continue;\n }\n if (startedPair) // Matching pair\n return a;\n startedPair = true;\n }\n\n return string::npos;\n}\n\nstatic inline void removeDuplicateSlashes(string &url, size_t start)\n{\n // +1 because we don't need to look at the same character again,\n // which was found by `findDuplicateSlashes`, it saves one iteration\n for (size_t b = (start--) + 1, len = url.size(); b < len; ++b)\n {\n const char c = url[b];\n if (c != '/' || url[start] != '/')\n {\n ++start;\n url[start] = c;\n }\n }\n url.resize(start + 1);\n}\n\nstatic inline std::pair findExcessiveSlashes(string_view url)\n{\n size_t len = url.size();\n if (len < 2) // Must have at least 2 characters to count as either trailing\n // or duplicate slash\n return 
{string::npos, string::npos};\n\n // Trail finder\n size_t trailIdx = len; // The pre-decrement will put it on last char\n while (--trailIdx > 0 && url[trailIdx] == '/')\n ; // We know first char is '/', no need to check it\n\n // Filled with '/'\n if (trailIdx == 0)\n return {\n 0, // Only keep first slash\n string::npos, // No duplicate\n };\n\n // Look for a duplicate pair\n size_t dupIdx = 1;\n for (bool startedPair = true; dupIdx < trailIdx;\n ++dupIdx) // Always starts with a slash\n {\n if (url[dupIdx] != '/') // Broken pair\n {\n startedPair = false;\n continue;\n }\n if (startedPair) // Matching pair\n break;\n startedPair = true;\n }\n\n // Found no duplicate\n if (dupIdx == trailIdx)\n return {\n trailIdx != len - 1\n ? // If has gone past last char, then there is a trailing slash\n trailIdx\n : string::npos, // No trail\n string::npos, // No duplicate\n };\n\n // Duplicate found\n return {\n trailIdx != len - 1\n ? // If has gone past last char, then there is a trailing slash\n trailIdx\n : string::npos, // No trail\n dupIdx,\n };\n}\n\nstatic inline void removeExcessiveSlashes(string &url,\n std::pair start,\n string_view originalUrl)\n{\n if (start.first != string::npos)\n removeTrailingSlashes(url, start.first, originalUrl);\n else\n url = originalUrl;\n\n if (start.second != string::npos)\n removeDuplicateSlashes(url, start.second);\n}\n\nstatic inline bool handleReq(const drogon::HttpRequestPtr &req,\n uint8_t removeMode)\n{\n switch (removeMode)\n {\n case trailing:\n {\n auto find = findTrailingSlashes(req->path());\n if (find == string::npos)\n return false;\n\n string newPath;\n removeTrailingSlashes(newPath, find, req->path());\n req->setPath(std::move(newPath));\n break;\n }\n case duplicate:\n {\n auto find = findDuplicateSlashes(req->path());\n if (find == string::npos)\n return false;\n\n string newPath = req->path();\n removeDuplicateSlashes(newPath, find);\n req->setPath(std::move(newPath));\n break;\n }\n case both:\n default:\n {\n auto find = findExcessiveSlashes(req->path());\n if (find.first == string::npos && find.second == string::npos)\n return false;\n\n string newPath;\n removeExcessiveSlashes(newPath, find, req->path());\n req->setPath(std::move(newPath));\n break;\n }\n }\n return true;\n}\n\nvoid SlashRemover::initAndStart(const Json::Value &config)\n{\n trailingSlashes_ = config.get(\"remove_trailing_slashes\", true).asBool();\n duplicateSlashes_ = config.get(\"remove_duplicate_slashes\", true).asBool();\n redirect_ = config.get(\"redirect\", true).asBool();\n const uint8_t removeMode =\n (trailingSlashes_ * trailing) | (duplicateSlashes_ * duplicate);\n if (!removeMode)\n return;\n auto redirector = app().getPlugin();\n if (!redirector)\n {\n LOG_ERROR << \"Redirector plugin is not found!\";\n return;\n }\n auto func = [removeMode](const HttpRequestPtr &req) -> bool {\n return handleReq(req, removeMode);\n };\n if (redirect_)\n {\n redirector->registerPathRewriteHandler(std::move(func));\n }\n else\n {\n redirector->registerForwardHandler(std::move(func));\n }\n}\n\nvoid SlashRemover::shutdown()\n{\n LOG_TRACE << \"SlashRemover plugin is shutdown!\";\n}\n\n// Path: lib/src/NotFound.cc\n// this file is generated by program automatically,don't modify it!\n\n/**\n *\n * NotFound.cc\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace drogon;\n\nstd::string NotFound::genText(const HttpViewData &NotFound_view_data)\n{\n std::stringstream NotFound_tmp_stream;\n NotFound_tmp_stream << \"\\n\";\n NotFound_tmp_stream << \"404 Not Found\\n\";\n...\n", "needle_token_start": 12096, "needle_token_end": 12158, "code_context_ntokens": 16146, "output": ["```cpp\nvoid defaultExceptionHandler(\n const std::exception &e,\n const HttpRequestPtr &req,\n std::function &&callback)\n{\n std::string pathWithQuery = req->path();\n if (req->query().empty() == false)\n pathWithQuery += \"?\" + req->query();\n LOG_ERROR << \"Unhandled exception in \" << pathWithQuery\n << \", what(): \" << e.what();\n const auto &handler = app().getCustomErrorHandler();\n callback(handler(k500InternalServerError, req));\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "bytesSize", "language": "cpp", "path": "lib/src/ConfigLoader.cc", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert a string representation of a size, potentially suffixed with a unit (e.g., K, M, G, T), into its numerical value in bytes.\n2. **Input**: A string representing the size, which may include a unit suffix to indicate multiplication factors (e.g., \"K\" for kilobytes, \"M\" for megabytes).\n3. **Output**: A boolean indicating success or failure of the conversion, and a size_t variable that holds the computed size in bytes.\n4. **Procedure**: \n - The function first checks if the input string is empty, setting the size to an undefined large value and returning true.\n - If not empty, it initializes the size to 1 and checks the last character of the string for a unit suffix to determine the multiplication factor.\n - The suffix is removed, and the remaining part of the string is converted to a numerical value.\n - The function then multiplies this value by the determined factor to get the size in bytes.\n - It checks for overflow during multiplication and returns false if the conversion fails or if the final computed size exceeds the maximum allowable size.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/HttpUtils.cc\n/**\n *\n * @file HttpUtils.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpUtils.h\"\n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nstatic std::unordered_map customMime;\n\n// https://en.wikipedia.org/wiki/List_of_file_formats\n// https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types\n// https://www.digipres.org/formats/mime-types/\n// https://www.iana.org/assignments/media-types/media-types.xhtml\n// content type -> list of corresponding mime types, the first being the default\n// (more standard) one + the mime type to return in contentTypeToMime() when not\n// empty (mainly to return the charset for text types)\nstatic const std::unordered_map<\n ContentType,\n std::pair, std::string_view>>\n mimeTypeDatabase_{\n {CT_NONE, {{\"\"}, \"\"}},\n {CT_APPLICATION_OCTET_STREAM, {{\"application/octet-stream\"}, \"\"}},\n {CT_APPLICATION_X_FORM, {{\"application/x-www-form-urlencoded\"}, \"\"}},\n {CT_MULTIPART_FORM_DATA, {{\"multipart/form-data\"}, \"\"}},\n {CT_APPLICATION_GZIP, {{\"application/gzip\"}, \"\"}},\n {CT_APPLICATION_JSON,\n {{\"application/json\"}, \"application/json; charset=utf-8\"}},\n {CT_APPLICATION_FONT_WOFF, {{\"application/font-woff\"}, \"\"}},\n {CT_APPLICATION_FONT_WOFF2, {{\"application/font-woff2\"}, \"\"}},\n {CT_APPLICATION_JAVA_ARCHIVE,\n {{\"application/java-archive\", \"application/x-java-archive\"}, \"\"}},\n {CT_APPLICATION_MSWORD, {{\"application/msword\"}, \"\"}},\n {CT_APPLICATION_MSWORDX,\n {{\"application/\"\n \"vnd.openxmlformats-officedocument.wordprocessingml.document\"},\n \"\"}},\n {CT_APPLICATION_PDF, {{\"application/pdf\"}, \"\"}},\n {CT_APPLICATION_VND_MS_FONTOBJ,\n {{\"application/vnd.ms-fontobject\"}, \"\"}},\n {CT_APPLICATION_VND_RAR, {{\"application/vnd.rar\"}, \"\"}},\n {CT_APPLICATION_WASM, {{\"application/wasm\"}, \"\"}},\n {CT_APPLICATION_X_BZIP, {{\"application/x-bzip\"}, \"\"}},\n {CT_APPLICATION_X_BZIP2, {{\"application/x-bzip2\"}, \"\"}},\n {CT_APPLICATION_X_7Z, {{\"application/x-7z-compressed\"}, \"\"}},\n {CT_APPLICATION_X_HTTPD_PHP, {{\"application/x-httpd-php\"}, \"\"}},\n {CT_APPLICATION_X_JAVASCRIPT,\n {{\"application/x-javascript\"},\n \"application/x-javascript; charset=utf-8\"}},\n {CT_APPLICATION_X_FONT_OPENTYPE,\n {{\"application/x-font-opentype\", \"font/otf\"}, \"\"}},\n {CT_APPLICATION_X_FONT_TRUETYPE,\n {{\"application/x-font-truetype\", \"font/ttf\"}, \"\"}},\n {CT_APPLICATION_X_TAR, {{\"application/x-tar\"}, \"\"}},\n {CT_APPLICATION_X_TGZ, {{\"application/x-tgz\"}, \"\"}},\n {CT_APPLICATION_X_XZ, {{\"application/x-xz\", \"application/x-lzma\"}, \"\"}},\n {CT_APPLICATION_XHTML,\n {{\"application/xhtml+xml\", \"application/xhtml\"},\n \"application/xhtml+xml; charset=utf-8\"}},\n {CT_APPLICATION_XML,\n {{\"application/xml\"}, \"application/xml; charset=utf-8\"}},\n {CT_APPLICATION_ZIP, {{\"application/zip\"}, \"\"}},\n {CT_AUDIO_AAC, {{\"audio/aac\", \"audio/aacp\"}, \"\"}},\n {CT_AUDIO_AC3, {{\"audio/ac3\"}, \"\"}},\n {CT_AUDIO_AIFF, {{\"audio/aiff\", \"audio/x-aiff\"}, \"\"}},\n {CT_AUDIO_FLAC, {{\"audio/flac\"}, \"\"}},\n {CT_AUDIO_MATROSKA, {{\"audio/matroska\", \"audio/x-matroska\"}, \"\"}},\n {CT_AUDIO_MPEG, {{\"audio/mpeg\"}, \"\"}},\n {CT_AUDIO_MPEG4, {{\"audio/mp4\", \"audio/x-m4a\"}, \"\"}},\n {CT_AUDIO_OGG, {{\"audio/ogg\"}, \"\"}},\n {CT_AUDIO_WAVE, {{\"audio/wav\", \"audio/x-wav\"}, \"\"}},\n {CT_AUDIO_X_APE, 
{{\"audio/x-ape\"}, \"\"}},\n {CT_AUDIO_X_MS_WMA, {{\"audio/x-ms-wma\"}, \"\"}},\n {CT_AUDIO_X_TTA, {{\"audio/x-tta\"}, \"\"}},\n {CT_AUDIO_X_WAVPACK, {{\"audio/x-wavpack\"}, \"\"}},\n {CT_AUDIO_WEBM, {{\"audio/webm\"}, \"\"}},\n {CT_IMAGE_APNG, {{\"image/apng\"}, \"\"}},\n {CT_IMAGE_AVIF, {{\"image/avif\"}, \"\"}},\n {CT_IMAGE_BMP, {{\"image/bmp\"}, \"\"}},\n {CT_IMAGE_GIF, {{\"image/gif\"}, \"\"}},\n {CT_IMAGE_ICNS, {{\"image/icns\"}, \"\"}},\n {CT_IMAGE_JP2, {{\"image/jp2\", \"image/jpx\", \"image/jpm\"}, \"\"}},\n {CT_IMAGE_JPG, {{\"image/jpeg\"}, \"\"}},\n {CT_IMAGE_PNG, {{\"image/png\"}, \"\"}},\n {CT_IMAGE_SVG_XML, {{\"image/svg+xml\"}, \"\"}},\n...\n// Path: lib/src/ConfigAdapterManager.h\n#pragma once\n#include \"ConfigAdapterManager.h\"\n#include \"ConfigAdapter.h\"\n#include \n\nnamespace drogon\n{\nclass ConfigAdapterManager\n{\n public:\n static ConfigAdapterManager &instance();\n Json::Value getJson(const std::string &content, std::string ext) const\n noexcept(false);\n\n private:\n ConfigAdapterManager();\n std::map adapters_;\n};\n} // namespace drogon\n\n// Path: lib/src/ConfigAdapterManager.cc\n#include \"ConfigAdapterManager.h\"\n#include \"JsonConfigAdapter.h\"\n#include \"YamlConfigAdapter.h\"\n#include \n\nusing namespace drogon;\n#define REGISTER_CONFIG_ADAPTER(adapter) \\\n { \\\n auto adapterPtr = std::make_shared(); \\\n auto exts = adapterPtr->getExtensions(); \\\n for (auto ext : exts) \\\n { \\\n std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); \\\n adapters_[ext] = adapterPtr; \\\n } \\\n }\n\nConfigAdapterManager &ConfigAdapterManager::instance()\n{\n static ConfigAdapterManager instance;\n return instance;\n}\n\nJson::Value ConfigAdapterManager::getJson(const std::string &content,\n std::string ext) const noexcept(false)\n{\n std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);\n auto it = adapters_.find(ext);\n if (it == adapters_.end())\n {\n throw std::runtime_error(\"No valid parser for this config file!\");\n }\n return it->second->getJson(content);\n}\n\nConfigAdapterManager::ConfigAdapterManager()\n{\n REGISTER_CONFIG_ADAPTER(JsonConfigAdapter);\n REGISTER_CONFIG_ADAPTER(YamlConfigAdapter);\n}\n\n// Path: lib/src/DrTemplateBase.cc\n/**\n *\n * @file DrTemplateBase.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n#include \n#include \n\nusing namespace drogon;\n\nstd::shared_ptr DrTemplateBase::newTemplate(\n const std::string &templateName)\n{\n LOG_TRACE << \"http view name=\" << templateName;\n auto l = templateName.length();\n if (l >= 4 && templateName[l - 4] == '.' && templateName[l - 3] == 'c' &&\n templateName[l - 2] == 's' && templateName[l - 1] == 'p')\n {\n std::string::size_type pos = 0;\n std::string newName;\n newName.reserve(templateName.size());\n if (templateName[0] == '/' || templateName[0] == '\\\\')\n {\n pos = 1;\n }\n else if (templateName[0] == '.' 
&&\n (templateName[1] == '/' || templateName[1] == '\\\\'))\n {\n pos = 2;\n }\n while (pos < l - 4)\n {\n if (templateName[pos] == '/' || templateName[pos] == '\\\\')\n {\n newName.append(\"::\");\n }\n else\n {\n newName.append(1, templateName[pos]);\n }\n ++pos;\n }\n return std::shared_ptr(dynamic_cast(\n drogon::DrClassMap::newObject(newName)));\n }\n else\n {\n return std::shared_ptr(dynamic_cast(\n drogon::DrClassMap::newObject(templateName)));\n }\n}\n\n// Path: lib/src/DrClassMap.cc\n/**\n *\n * DrClassMap.cc\n * An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n\nusing namespace drogon;\n\nnamespace drogon\n{\nnamespace internal\n{\nstatic std::unordered_map> &\ngetObjsMap()\n{\n static std::unordered_map>\n singleInstanceMap;\n return singleInstanceMap;\n}\n\nstatic std::mutex &getMapMutex()\n{\n static std::mutex mtx;\n return mtx;\n}\n\n} // namespace internal\n} // namespace drogon\n\nvoid DrClassMap::registerClass(const std::string &className,\n const DrAllocFunc &func,\n const DrSharedAllocFunc &sharedFunc)\n{\n LOG_TRACE << \"Register class:\" << className;\n getMap().insert(\n std::make_pair(className, std::make_pair(func, sharedFunc)));\n}\n\nDrObjectBase *DrClassMap::newObject(const std::string &className)\n{\n auto iter = getMap().find(className);\n if (iter != getMap().end())\n {\n return iter->second.first();\n }\n else\n return nullptr;\n}\n\nstd::shared_ptr DrClassMap::newSharedObject(\n const std::string &className)\n{\n auto iter = getMap().find(className);\n if (iter != getMap().end())\n {\n if (iter->second.second)\n return iter->second.second();\n else\n return std::shared_ptr(iter->second.first());\n }\n else\n return nullptr;\n}\n\nconst std::shared_ptr &DrClassMap::getSingleInstance(\n const std::string &className)\n{\n auto &mtx = internal::getMapMutex();\n auto &singleInstanceMap = internal::getObjsMap();\n {\n std::lock_guard lock(mtx);\n auto iter = singleInstanceMap.find(className);\n if (iter != singleInstanceMap.end())\n return iter->second;\n }\n auto newObj = newSharedObject(className);\n {\n std::lock_guard lock(mtx);\n auto ret = singleInstanceMap.insert(\n std::make_pair(className, std::move(newObj)));\n return ret.first->second;\n }\n}\n\nvoid DrClassMap::setSingleInstance(const std::shared_ptr &ins)\n{\n auto &mtx = internal::getMapMutex();\n auto &singleInstanceMap = internal::getObjsMap();\n std::lock_guard lock(mtx);\n singleInstanceMap[ins->className()] = ins;\n}\n\nstd::vector DrClassMap::getAllClassName()\n{\n std::vector ret;\n for (auto const &iter : getMap())\n {\n ret.push_back(iter.first);\n }\n return ret;\n}\n\nstd::unordered_map> &\nDrClassMap::getMap()\n{\n static std::unordered_map>\n map;\n return map;\n}\n\n// Path: lib/src/HttpFileImpl.h\n/**\n *\n * @file HttpFileImpl.h\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n#include \"HttpUtils.h\"\n#include \n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace drogon\n{\nclass HttpFileImpl\n{\n public:\n /// Return the file name;\n const std::string &getFileName() const noexcept\n {\n return fileName_;\n }\n\n /// Set the file name, usually called by the MultiPartParser parser.\n void setFileName(const std::string &fileName) noexcept\n {\n fileName_ = fileName;\n }\n\n void setFileName(std::string &&fileName) noexcept\n {\n fileName_ = std::move(fileName);\n }\n\n /// Return the file extension;\n std::string_view getFileExtension() const noexcept\n {\n return drogon::getFileExtension(fileName_);\n }\n\n /// Set the contents of the file, usually called by the MultiPartParser\n /// parser.\n void setFile(const char *data, size_t length) noexcept\n {\n fileContent_ = std::string_view{data, length};\n }\n\n /// Save the file to the file system.\n /**\n * The folder saving the file is app().getUploadPath().\n * The full path is app().getUploadPath()+\"/\"+this->getFileName()\n */\n int save() const noexcept;\n\n /// Save the file to @param path\n /**\n * @param path if the parameter is prefixed with \"/\", \"./\" or \"../\", or is\n * \".\" or \"..\", the full path is path+\"/\"+this->getFileName(),\n * otherwise the file is saved as\n * app().getUploadPath()+\"/\"+path+\"/\"+this->getFileName()\n */\n int save(const std::string &path) const noexcept;\n\n /// Save the file to file system with a new name\n /**\n * @param fileName if the parameter isn't prefixed with \"/\", \"./\" or \"../\",\n * the full path is app().getUploadPath()+\"/\"+filename, otherwise the file\n * is saved as the filename\n */\n int saveAs(const std::string &fileName) const noexcept;\n\n /// Return the file length.\n size_t fileLength() const noexcept\n {\n return fileContent_.length();\n }\n\n const char *fileData() const noexcept\n {\n return fileContent_.data();\n }\n\n const std::string_view &fileContent() const noexcept\n {\n return fileContent_;\n }\n\n /// Return the name of the item in multiple parts.\n const std::string &getItemName() const noexcept\n {\n return itemName_;\n }\n\n void setItemName(const std::string &itemName) noexcept\n {\n itemName_ = itemName;\n }\n\n void setItemName(std::string &&itemName) noexcept\n {\n itemName_ = std::move(itemName);\n }\n\n /// Return the type of file.\n FileType getFileType() const noexcept\n {\n auto ft = drogon::getFileType(contentType_);\n if ((ft != FT_UNKNOWN) && (ft != FT_CUSTOM))\n return ft;\n return parseFileType(getFileExtension());\n }\n\n /// Return md5 hash of the file\n std::string getMd5() const noexcept;\n // Return sha1 hash of the file\n std::string getSha256() const noexcept;\n // Return sha512 hash of the file\n std::string getSha3() const noexcept;\n // int saveTo(const std::string &pathAndFileName) const;\n int saveTo(const std::filesystem::path &pathAndFileName) const noexcept;\n\n void setRequest(const HttpRequestPtr &req) noexcept\n {\n requestPtr_ = req;\n }\n\n drogon::ContentType getContentType() const noexcept\n {\n return contentType_;\n }\n\n void setContentType(drogon::ContentType contentType) noexcept\n {\n contentType_ = contentType;\n }\n\n void setContentTransferEncoding(\n const std::string &contentTransferEncoding) noexcept\n {\n transferEncoding_ = contentTransferEncoding;\n 
}\n\n void setContentTransferEncoding(\n std::string &&contentTransferEncoding) noexcept\n {\n transferEncoding_ = std::move(contentTransferEncoding);\n }\n\n const std::string &getContentTransferEncoding() const noexcept\n {\n return transferEncoding_;\n }\n\n private:\n std::string fileName_;\n std::string itemName_;\n std::string transferEncoding_;\n std::string_view fileContent_;\n HttpRequestPtr requestPtr_;\n drogon::ContentType contentType_{drogon::CT_NONE};\n};\n} // namespace drogon\n\n// Path: lib/src/HttpFileImpl.cc\n/**\n *\n * @file HttpFileImpl.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpFileImpl.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n\nusing namespace drogon;\n\nint HttpFileImpl::save() const noexcept\n{\n return save(HttpAppFrameworkImpl::instance().getUploadPath());\n}\n\nint HttpFileImpl::save(const std::string &path) const noexcept\n{\n assert(!path.empty());\n if (fileName_.empty())\n return -1;\n std::filesystem::path fsUploadDir(utils::toNativePath(path));\n\n if (fsUploadDir.is_absolute())\n { // do nothing\n }\n else if ((!fsUploadDir.has_parent_path() ||\n (fsUploadDir.begin()->string() != \".\" &&\n fsUploadDir.begin()->string() != \"..\")))\n {\n fsUploadDir = utils::toNativePath(\n HttpAppFrameworkImpl::instance().getUploadPath()) /\n fsUploadDir;\n }\n else\n {\n fsUploadDir = std::filesystem::current_path() / fsUploadDir;\n }\n\n fsUploadDir = std::filesystem::weakly_canonical(fsUploadDir);\n\n if (!std::filesystem::exists(fsUploadDir))\n {\n LOG_TRACE << \"create path:\" << fsUploadDir;\n std::error_code err;\n std::filesystem::create_directories(fsUploadDir, err);\n if (err)\n {\n LOG_SYSERR;\n return -1;\n }\n }\n\n std::filesystem::path fsSaveToPath(std::filesystem::weakly_canonical(\n fsUploadDir / utils::toNativePath(fileName_)));\n LOG_TRACE << \"save to path:\" << fsSaveToPath;\n if (!std::equal(fsUploadDir.begin(),\n fsUploadDir.end(),\n fsSaveToPath.begin()))\n {\n LOG_ERROR\n << \"Attempt writing outside of upload directory detected. 
Path: \"\n << fileName_;\n return -1;\n }\n\n return saveTo(fsSaveToPath);\n}\n\nint HttpFileImpl::saveAs(const std::string &fileName) const noexcept\n{\n assert(!fileName.empty());\n std::filesystem::path fsFileName(utils::toNativePath(fileName));\n if (!fsFileName.is_absolute() && (!fsFileName.has_parent_path() ||\n (fsFileName.begin()->string() != \".\" &&\n fsFileName.begin()->string() != \"..\")))\n {\n std::filesystem::path fsUploadPath(utils::toNativePath(\n HttpAppFrameworkImpl::instance().getUploadPath()));\n fsFileName = fsUploadPath / fsFileName;\n }\n if (fsFileName.has_parent_path() &&\n !std::filesystem::exists(fsFileName.parent_path()))\n {\n LOG_TRACE << \"create path:\" << fsFileName.parent_path();\n std::error_code err;\n std::filesystem::create_directories(fsFileName.parent_path(), err);\n if (err)\n {\n LOG_SYSERR;\n return -1;\n }\n }\n return saveTo(fsFileName);\n}\n\nint HttpFileImpl::saveTo(\n const std::filesystem::path &pathAndFileName) const noexcept\n{\n LOG_TRACE << \"save uploaded file:\" << pathAndFileName;\n auto wPath = utils::toNativePath(pathAndFileName.native());\n std::ofstream file(wPath, std::ios::binary);\n if (file.is_open())\n {\n file.write(fileContent_.data(), fileContent_.size());\n file.close();\n return 0;\n }\n else\n {\n LOG_ERROR << \"save failed!\";\n return -1;\n }\n}\n\nstd::string HttpFileImpl::getMd5() const noexcept\n{\n return utils::getMd5(fileContent_.data(), fileContent_.size());\n}\n\nstd::string HttpFileImpl::getSha256() const noexcept\n{\n return utils::getSha256(fileContent_.data(), fileContent_.size());\n}\n\nstd::string HttpFileImpl::getSha3() const noexcept\n{\n return utils::getSha3(fileContent_.data(), fileContent_.size());\n}\n\nconst std::string &HttpFile::getFileName() const noexcept\n{\n return implPtr_->getFileName();\n}\n\nvoid HttpFile::setFileName(const std::string &fileName) noexcept\n{\n implPtr_->setFileName(fileName);\n}\n\nstd::string_view HttpFile::getFileExtension() const noexcept\n{\n return implPtr_->getFileExtension();\n}\n\nFileType HttpFile::getFileType() const noexcept\n{\n return implPtr_->getFileType();\n}\n\nvoid HttpFile::setFile(const char *data, size_t length) noexcept\n{\n implPtr_->setFile(data, length);\n}\n\nint HttpFile::save() const noexcept\n{\n return implPtr_->save();\n}\n\nint HttpFile::save(const std::string &path) const noexcept\n{\n return implPtr_->save(path);\n}\n\nint HttpFile::saveAs(const std::string &fileName) const noexcept\n{\n return implPtr_->saveAs(fileName);\n}\n\nsize_t HttpFile::fileLength() const noexcept\n{\n return implPtr_->fileLength();\n}\n\ndrogon::ContentType HttpFile::getContentType() const noexcept\n{\n return implPtr_->getContentType();\n}\n\nconst char *HttpFile::fileData() const noexcept\n{\n return implPtr_->fileData();\n}\n\nstd::string HttpFile::getMd5() const noexcept\n{\n return implPtr_->getMd5();\n}\n\nconst std::string &HttpFile::getContentTransferEncoding() const noexcept\n{\n return implPtr_->getContentTransferEncoding();\n}\n\nHttpFile::HttpFile(std::shared_ptr &&implPtr) noexcept\n : implPtr_(std::move(implPtr))\n{\n}\n\nconst std::string &HttpFile::getItemName() const noexcept\n{\n return implPtr_->getItemName();\n}\n\n// Path: lib/src/ListenerManager.cc\n/**\n *\n * @file ListenerManager.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"ListenerManager.h\"\n#include \n#include \n#include \n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpServer.h\"\n#ifndef _WIN32\n#include \n#include \n#endif\n\nnamespace drogon\n{\n#ifndef _WIN32\nclass DrogonFileLocker : public trantor::NonCopyable\n{\n public:\n DrogonFileLocker()\n {\n fd_ = open(\"/tmp/drogon.lock\", O_TRUNC | O_CREAT, 0666);\n flock(fd_, LOCK_EX);\n }\n\n ~DrogonFileLocker()\n {\n close(fd_);\n }\n\n private:\n int fd_{0};\n};\n\n#endif\n} // namespace drogon\n\nusing namespace trantor;\nusing namespace drogon;\n\nvoid ListenerManager::addListener(\n const std::string &ip,\n uint16_t port,\n bool useSSL,\n const std::string &certFile,\n const std::string &keyFile,\n bool useOldTLS,\n const std::vector> &sslConfCmds)\n{\n if (useSSL && !utils::supportsTls())\n LOG_ERROR << \"Can't use SSL without OpenSSL found in your system\";\n listeners_.emplace_back(\n ip, port, useSSL, certFile, keyFile, useOldTLS, sslConfCmds);\n}\n\nstd::vector ListenerManager::getListeners() const\n{\n std::vector listeners;\n for (auto &server : servers_)\n {\n listeners.emplace_back(server->address());\n }\n return listeners;\n}\n\nvoid ListenerManager::createListeners(\n const std::string &globalCertFile,\n const std::string &globalKeyFile,\n const std::vector> &sslConfCmds,\n const std::vector &ioLoops)\n{\n LOG_TRACE << \"thread num=\" << ioLoops.size();\n#ifdef __linux__\n for (size_t i = 0; i < ioLoops.size(); ++i)\n {\n for (auto const &listener : listeners_)\n {\n auto const &ip = listener.ip_;\n bool isIpv6 = (ip.find(':') != std::string::npos);\n InetAddress listenAddress(ip, listener.port_, isIpv6);\n if (listenAddress.isUnspecified())\n {\n LOG_FATAL << \"Failed to parse IP address '\" << ip\n << \"'. (Note: FQDN/domain names/hostnames are not \"\n \"supported. 
Including 'localhost')\";\n abort();\n }\n if (i == 0 && !app().reusePort())\n {\n DrogonFileLocker lock;\n // Check whether the port is in use.\n TcpServer server(HttpAppFrameworkImpl::instance().getLoop(),\n listenAddress,\n \"drogonPortTest\",\n true,\n false);\n }\n std::shared_ptr serverPtr =\n std::make_shared(ioLoops[i],\n listenAddress,\n \"drogon\");\n\n if (listener.useSSL_ && utils::supportsTls())\n {\n auto cert = listener.certFile_;\n auto key = listener.keyFile_;\n if (cert.empty())\n cert = globalCertFile;\n if (key.empty())\n key = globalKeyFile;\n if (cert.empty() || key.empty())\n {\n std::cerr\n << \"You can't use https without cert file or key file\"\n << std::endl;\n exit(1);\n }\n auto cmds = sslConfCmds;\n std::copy(listener.sslConfCmds_.begin(),\n listener.sslConfCmds_.end(),\n std::back_inserter(cmds));\n auto policy =\n trantor::TLSPolicy::defaultServerPolicy(cert, key);\n policy->setConfCmds(cmds).setUseOldTLS(listener.useOldTLS_);\n serverPtr->enableSSL(std::move(policy));\n }\n servers_.push_back(serverPtr);\n }\n }\n#else\n\n if (!listeners_.empty())\n {\n listeningThread_ =\n std::make_unique(\"DrogonListeningLoop\");\n listeningThread_->run();\n for (auto const &listener : listeners_)\n {\n auto ip = listener.ip_;\n bool isIpv6 = (ip.find(':') != std::string::npos);\n auto serverPtr = std::make_shared(\n listeningThread_->getLoop(),\n InetAddress(ip, listener.port_, isIpv6),\n \"drogon\");\n if (listener.useSSL_ && utils::supportsTls())\n {\n auto cert = listener.certFile_;\n auto key = listener.keyFile_;\n if (cert.empty())\n cert = globalCertFile;\n if (key.empty())\n key = globalKeyFile;\n if (cert.empty() || key.empty())\n {\n std::cerr\n << \"You can't use https without cert file or key file\"\n << std::endl;\n exit(1);\n }\n auto cmds = sslConfCmds;\n auto policy =\n trantor::TLSPolicy::defaultServerPolicy(cert, key);\n policy->setConfCmds(cmds).setUseOldTLS(listener.useOldTLS_);\n serverPtr->enableSSL(std::move(policy));\n }\n serverPtr->setIoLoops(ioLoops);\n servers_.push_back(serverPtr);\n }\n }\n#endif\n}\n\nvoid ListenerManager::startListening()\n{\n for (auto &server : servers_)\n {\n server->start();\n }\n}\n\nvoid ListenerManager::stopListening()\n{\n for (auto &serverPtr : servers_)\n {\n serverPtr->stop();\n }\n if (listeningThread_)\n {\n auto loop = listeningThread_->getLoop();\n assert(!loop->isInLoopThread());\n loop->quit();\n listeningThread_->wait();\n }\n}\n\n// Path: lib/src/RealIpResolver.cc\n/**\n *\n * @file RealIpResolver.cc\n * @author Nitromelon\n *\n * Copyright 2022, Nitromelon. All rights reserved.\n * https://github.com/drogonframework/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \n#include \n#include \n\nusing namespace drogon;\nusing namespace drogon::plugin;\n\nstruct XForwardedForParser : public trantor::NonCopyable\n{\n explicit XForwardedForParser(std::string value)\n : value_(std::move(value)), start_(value_.c_str()), len_(value_.size())\n {\n }\n\n std::string getNext()\n {\n if (len_ == 0)\n {\n return {};\n }\n // Skip trailing separators\n const char *cur;\n for (cur = start_ + len_ - 1; cur > start_; --cur, --len_)\n {\n if (*cur != ' ' && *cur != ',')\n {\n break;\n }\n }\n for (; cur > start_; --cur)\n {\n if (*cur == ' ' || *cur == ',')\n {\n ++cur;\n break;\n }\n }\n std::string ip{cur, len_ - (cur - start_)};\n len_ = cur == start_ ? 
0 : cur - start_ - 1;\n return ip;\n }\n\n private:\n std::string value_;\n const char *start_;\n size_t len_;\n};\n\nstatic trantor::InetAddress parseAddress(const std::string &addr)\n{\n auto pos = addr.find(':');\n uint16_t port = 0;\n if (pos == std::string::npos)\n {\n return trantor::InetAddress(addr, 0);\n }\n try\n {\n port = std::stoi(addr.substr(pos + 1));\n }\n catch (const std::exception &ex)\n {\n (void)ex;\n LOG_ERROR << \"Error in ipv4 address: \" + addr;\n port = 0;\n }\n return trantor::InetAddress(addr.substr(0, pos), port);\n}\n\nvoid RealIpResolver::initAndStart(const Json::Value &config)\n{\n fromHeader_ = config.get(\"from_header\", \"x-forwarded-for\").asString();\n attributeKey_ = config.get(\"attribute_key\", \"real-ip\").asString();\n\n std::transform(fromHeader_.begin(),\n fromHeader_.end(),\n fromHeader_.begin(),\n [](unsigned char c) { return tolower(c); });\n if (fromHeader_ == \"x-forwarded-for\")\n {\n useXForwardedFor_ = true;\n }\n\n const Json::Value &trustIps = config[\"trust_ips\"];\n if (!trustIps.isArray())\n {\n throw std::runtime_error(\"Invalid trusted_ips. Should be array.\");\n }\n for (const auto &elem : trustIps)\n {\n std::string ipOrCidr = elem.asString();\n trustCIDRs_.emplace_back(ipOrCidr);\n }\n\n drogon::app().registerPreRoutingAdvice([this](const HttpRequestPtr &req) {\n const std::string &ipHeader = req->getHeader(fromHeader_);\n const trantor::InetAddress &peerAddr = req->getPeerAddr();\n if (ipHeader.empty() || !matchCidr(peerAddr))\n {\n // Target header is empty, or\n // direct peer is already a non-proxy\n req->attributes()->insert(attributeKey_, peerAddr);\n return;\n }\n // Use a header field which contains a single ip\n if (!useXForwardedFor_)\n {\n trantor::InetAddress addr = parseAddress(ipHeader);\n if (addr.isUnspecified())\n {\n req->attributes()->insert(attributeKey_, peerAddr);\n }\n else\n {\n req->attributes()->insert(attributeKey_, addr);\n }\n return;\n }\n // Use x-forwarded-for header, which may contains multiple ip address,\n // separated by comma\n XForwardedForParser parser(ipHeader);\n std::string ip;\n while (!(ip = parser.getNext()).empty())\n {\n trantor::InetAddress addr = parseAddress(ip);\n if (addr.isUnspecified() || matchCidr(addr))\n {\n continue;\n }\n req->attributes()->insert(attributeKey_, addr);\n return;\n }\n // No match, use peerAddr\n req->attributes()->insert(attributeKey_, peerAddr);\n });\n}\n\nvoid RealIpResolver::shutdown()\n{\n}\n\nconst trantor::InetAddress &RealIpResolver::GetRealAddr(\n const HttpRequestPtr &req)\n{\n auto *plugin = app().getPlugin();\n if (!plugin)\n {\n return req->getPeerAddr();\n }\n return plugin->getRealAddr(req);\n}\n\nconst trantor::InetAddress &RealIpResolver::getRealAddr(\n const HttpRequestPtr &req) const\n{\n const std::shared_ptr &attributesPtr = req->getAttributes();\n if (!attributesPtr->find(attributeKey_))\n {\n return req->getPeerAddr();\n }\n return attributesPtr->get(attributeKey_);\n}\n\nbool RealIpResolver::matchCidr(const trantor::InetAddress &addr) const\n{\n for (auto &cidr : trustCIDRs_)\n {\n if ((addr.ipNetEndian() & cidr.mask_) == cidr.addr_)\n {\n return true;\n }\n }\n return false;\n}\n\nRealIpResolver::CIDR::CIDR(const std::string &ipOrCidr)\n{\n // Find CIDR slash\n auto pos = ipOrCidr.find('/');\n std::string ipv4;\n if (pos != std::string::npos)\n {\n // parameter is a CIDR block\n std::string prefixLen = ipOrCidr.substr(pos + 1);\n ipv4 = ipOrCidr.substr(0, pos);\n uint16_t prefix = std::stoi(prefixLen);\n if (prefix > 32)\n {\n 
throw std::runtime_error(\"Bad CIDR block: \" + ipOrCidr);\n }\n mask_ = htonl(0xffffffffu << (32 - prefix));\n }\n else\n {\n // parameter is an IP\n ipv4 = ipOrCidr;\n mask_ = 0xffffffffu;\n }\n\n trantor::InetAddress addr(ipv4, 0);\n if (addr.isIpV6())\n {\n throw std::runtime_error(\"Ipv6 is not supported by RealIpResolver.\");\n }\n if (addr.isUnspecified())\n {\n throw std::runtime_error(\"Bad ipv4 address: \" + ipv4);\n }\n addr_ = addr.ipNetEndian() & mask_;\n}\n\n// Path: lib/src/FiltersFunction.cc\n/**\n *\n * @file FiltersFunction.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"FiltersFunction.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpResponseImpl.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \n\n#include \n\nnamespace drogon\n{\nnamespace filters_function\n{\nstatic void doFilterChains(\n const std::vector> &filters,\n size_t index,\n const HttpRequestImplPtr &req,\n std::shared_ptr>\n &&callbackPtr)\n{\n if (index < filters.size())\n {\n auto &filter = filters[index];\n filter->doFilter(\n req,\n [/*copy*/ callbackPtr](const HttpResponsePtr &resp) {\n (*callbackPtr)(resp);\n },\n [index, req, callbackPtr, &filters]() mutable {\n auto ioLoop = req->getLoop();\n if (ioLoop && !ioLoop->isInLoopThread())\n {\n ioLoop->queueInLoop(\n [&filters,\n index,\n req,\n callbackPtr = std::move(callbackPtr)]() mutable {\n doFilterChains(filters,\n index + 1,\n req,\n std::move(callbackPtr));\n });\n }\n else\n {\n doFilterChains(filters,\n index + 1,\n req,\n std::move(callbackPtr));\n }\n });\n }\n else\n {\n (*callbackPtr)(nullptr);\n }\n}\n\nstd::vector> createFilters(\n const std::vector &filterNames)\n{\n std::vector> filters;\n for (auto const &filter : filterNames)\n {\n auto object_ = DrClassMap::getSingleInstance(filter);\n auto filter_ = std::dynamic_pointer_cast(object_);\n if (filter_)\n filters.push_back(filter_);\n else\n {\n LOG_ERROR << \"filter \" << filter << \" not found\";\n }\n }\n return filters;\n}\n\nvoid doFilters(const std::vector> &filters,\n const HttpRequestImplPtr &req,\n std::function &&callback)\n{\n auto callbackPtr =\n std::make_shared>(std::move(callback));\n doFilterChains(filters, 0, req, std::move(callbackPtr));\n}\n\n} // namespace filters_function\n} // namespace drogon\n\n// Path: lib/src/RangeParser.cc\n/**\n *\n * RangeParser.h\n * He, Wanchen\n *\n * Copyright 2021, He,Wanchen. 
All rights reserved.\n * https://github.com/drogonframework/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"RangeParser.h\"\n\n#include \n\nusing namespace drogon;\n\nstatic constexpr size_t MAX_SIZE = std::numeric_limits::max();\nstatic constexpr size_t MAX_TEN = MAX_SIZE / 10;\nstatic constexpr size_t MAX_DIGIT = MAX_SIZE % 10;\n\n// clang-format off\n#define DR_SKIP_WHITESPACE(p) while (*p == ' ') { ++(p); }\n#define DR_ISDIGIT(p) ('0' <= *(p) && *(p) <= '9')\n#define DR_WOULD_OVERFLOW(base, digit) \\\n (static_cast(base) > MAX_TEN || \\\n (static_cast(base) >= MAX_TEN && \\\n static_cast(digit) - '0' > MAX_DIGIT))\n\n// clang-format on\n\n/** Following formats are valid range header according to rfc7233`\n * Range: =-\n * Range: =-\n * Range: =-, -\n * Range: =-, -, -\n * Range: =-\n */\n\nFileRangeParseResult drogon::parseRangeHeader(const std::string &rangeStr,\n size_t contentLength,\n std::vector &ranges)\n{\n if (rangeStr.size() < 7 || rangeStr.compare(0, 6, \"bytes=\") != 0)\n {\n return InvalidRange;\n }\n const char *iter = rangeStr.c_str() + 6;\n\n size_t totalSize = 0;\n while (true)\n {\n size_t start = 0;\n size_t end = 0;\n // If this is a suffix range: =-\n bool isSuffix = false;\n\n DR_SKIP_WHITESPACE(iter);\n\n if (*iter == '-')\n {\n isSuffix = true;\n ++iter;\n }\n // Parse start\n else\n {\n if (!DR_ISDIGIT(iter))\n {\n return InvalidRange;\n }\n while (DR_ISDIGIT(iter))\n {\n // integer out of range\n if (DR_WOULD_OVERFLOW(start, *iter))\n {\n return NotSatisfiable;\n }\n start = start * 10 + (*iter++ - '0');\n }\n DR_SKIP_WHITESPACE(iter);\n // should be separator now\n if (*iter++ != '-')\n {\n return InvalidRange;\n }\n DR_SKIP_WHITESPACE(iter);\n // If this is a prefix range =-\n if (*iter == ',' || *iter == '\\0')\n {\n end = contentLength;\n // Handle found\n if (start < end)\n {\n if (totalSize > MAX_SIZE - (end - start))\n {\n return NotSatisfiable;\n }\n totalSize += end - start;\n ranges.push_back({start, end});\n }\n if (*iter++ != ',')\n {\n break;\n }\n continue;\n }\n }\n\n // Parse end\n if (!DR_ISDIGIT(iter))\n {\n return InvalidRange;\n }\n while (DR_ISDIGIT(iter))\n {\n if (DR_WOULD_OVERFLOW(end, *iter))\n {\n return NotSatisfiable;\n }\n end = end * 10 + (*iter++ - '0');\n }\n DR_SKIP_WHITESPACE(iter);\n\n if (*iter != ',' && *iter != '\\0')\n {\n return InvalidRange;\n }\n if (isSuffix)\n {\n start = (end < contentLength) ? contentLength - end : 0;\n end = contentLength - 1;\n }\n // [start, end)\n if (end >= contentLength)\n {\n end = contentLength;\n }\n else\n {\n ++end;\n }\n\n // handle found\n if (start < end)\n {\n ranges.push_back({start, end});\n if (totalSize > MAX_SIZE - (end - start))\n {\n return NotSatisfiable;\n }\n totalSize += end - start;\n // We restrict the number to be under 100, to avoid malicious\n // requests.\n // Though rfc does not say anything about max number of ranges,\n // it does mention that server can ignore range header freely.\n if (ranges.size() > 100)\n {\n return InvalidRange;\n }\n }\n if (*iter++ != ',')\n {\n break;\n }\n }\n\n if (ranges.size() == 0 || totalSize > contentLength)\n {\n return NotSatisfiable;\n }\n return ranges.size() == 1 ? SinglePart : MultiPart;\n}\n\n#undef DR_SKIP_WHITESPACE\n#undef DR_ISDIGIT\n#undef DR_WOULD_OVERFLOW\n\n// Path: lib/src/ConfigLoader.cc\n/**\n *\n * @file ConfigLoader.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"ConfigLoader.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(_WIN32)\n#include \n#define os_access access\n#else\n#include \n#ifndef __MINGW32__\n#define os_access _waccess\n#define R_OK 04\n#define W_OK 02\n#else\n#define os_access access\n#endif\n#endif\n\n#include \n#include \"ConfigAdapterManager.h\"\n#include \n\nusing namespace drogon;\n\n\nstatic bool bytesSize(std::string &sizeStr, size_t &size)\n{\n if (sizeStr.empty())\n {\n size = -1;\n return true;\n }\n else\n {\n size = 1;\n switch (sizeStr[sizeStr.length() - 1])\n {\n case 'k':\n case 'K':\n size = 1024;\n sizeStr.resize(sizeStr.length() - 1);\n break;\n case 'M':\n case 'm':\n size = (1024 * 1024);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n case 'g':\n case 'G':\n size = (1024 * 1024 * 1024);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n#if ((ULONG_MAX) != (UINT_MAX))\n // 64bit system\n case 't':\n case 'T':\n size = (1024L * 1024L * 1024L * 1024L);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n#endif\n case '0':\n case '1':\n case '2':\n case '3':\n case '4':\n case '5':\n case '7':\n case '8':\n case '9':\n break;\n default:\n return false;\n break;\n }\n std::istringstream iss(sizeStr);\n size_t tmpSize;\n iss >> tmpSize;\n if (iss.fail())\n {\n return false;\n }\n if ((size_t(-1) / tmpSize) >= size)\n size *= tmpSize;\n else\n {\n size = -1;\n }\n return true;\n }\n}\n\nConfigLoader::ConfigLoader(const std::string &configFile)\n{\n if (os_access(drogon::utils::toNativePath(configFile).c_str(), 0) != 0)\n {\n throw std::runtime_error(\"Config file \" + configFile + \" not found!\");\n }\n if (os_access(drogon::utils::toNativePath(configFile).c_str(), R_OK) != 0)\n {\n throw std::runtime_error(\"No permission to read config file \" +\n configFile);\n }\n configFile_ = configFile;\n auto pos = configFile.find_last_of('.');\n if (pos == std::string::npos)\n {\n throw std::runtime_error(\"Invalid config file name!\");\n }\n auto ext = configFile.substr(pos + 1);\n std::ifstream infile(drogon::utils::toNativePath(configFile).c_str(),\n std::ifstream::in);\n // get the content of the infile\n std::string content((std::istreambuf_iterator(infile)),\n std::istreambuf_iterator());\n try\n {\n configJsonRoot_ =\n ConfigAdapterManager::instance().getJson(content, std::move(ext));\n }\n catch (std::exception &e)\n {\n throw std::runtime_error(\"Error reading config file \" + configFile +\n \": \" + e.what());\n }\n}\n\nConfigLoader::ConfigLoader(const Json::Value &data) : configJsonRoot_(data)\n{\n}\n\nConfigLoader::ConfigLoader(Json::Value &&data)\n : configJsonRoot_(std::move(data))\n{\n}\n\nConfigLoader::~ConfigLoader()\n{\n}\n\nstatic void loadLogSetting(const Json::Value &log)\n{\n if (!log)\n return;\n auto useSpdlog = log.get(\"use_spdlog\", false).asBool();\n auto logPath = log.get(\"log_path\", \"\").asString();\n auto baseName = log.get(\"logfile_base_name\", \"\").asString();\n auto logSize = log.get(\"log_size_limit\", 100000000).asUInt64();\n auto maxFiles = log.get(\"max_files\", 0).asUInt();\n HttpAppFrameworkImpl::instance().setLogPath(\n logPath, baseName, logSize, maxFiles, useSpdlog);\n auto logLevel = log.get(\"log_level\", \"DEBUG\").asString();\n if (logLevel == \"TRACE\")\n {\n 
trantor::Logger::setLogLevel(trantor::Logger::kTrace);\n }\n else if (logLevel == \"DEBUG\")\n {\n trantor::Logger::setLogLevel(trantor::Logger::kDebug);\n }\n else if (logLevel == \"INFO\")\n {\n trantor::Logger::setLogLevel(trantor::Logger::kInfo);\n }\n else if (logLevel == \"WARN\")\n {\n trantor::Logger::setLogLevel(trantor::Logger::kWarn);\n }\n auto localTime = log.get(\"display_local_time\", false).asBool();\n trantor::Logger::setDisplayLocalTime(localTime);\n}\n\nstatic void loadControllers(const Json::Value &controllers)\n{\n if (!controllers)\n return;\n for (auto const &controller : controllers)\n {\n auto path = controller.get(\"path\", \"\").asString();\n auto ctrlName = controller.get(\"controller\", \"\").asString();\n if (path == \"\" || ctrlName == \"\")\n continue;\n std::vector constraints;\n if (!controller[\"http_methods\"].isNull())\n {\n for (auto const &method : controller[\"http_methods\"])\n {\n auto strMethod = method.asString();\n std::transform(strMethod.begin(),\n strMethod.end(),\n strMethod.begin(),\n [](unsigned char c) { return tolower(c); });\n if (strMethod == \"get\")\n {\n constraints.push_back(Get);\n }\n else if (strMethod == \"post\")\n {\n constraints.push_back(Post);\n }\n else if (strMethod == \"head\") // The branch never work\n {\n constraints.push_back(Head);\n }\n else if (strMethod == \"put\")\n {\n constraints.push_back(Put);\n }\n else if (strMethod == \"delete\")\n {\n constraints.push_back(Delete);\n }\n else if (strMethod == \"patch\")\n {\n constraints.push_back(Patch);\n }\n }\n }\n if (!controller[\"filters\"].isNull())\n {\n for (auto const &filter : controller[\"filters\"])\n {\n constraints.push_back(filter.asString());\n }\n }\n drogon::app().registerHttpSimpleController(path, ctrlName, constraints);\n }\n}\n\nstatic void loadApp(const Json::Value &app)\n{\n if (!app)\n return;\n // threads number\n auto threadsNum = app.get(\"threads_num\", 1).asUInt64();\n if (threadsNum == 1)\n {\n threadsNum = app.get(\"number_of_threads\", 1).asUInt64();\n }\n if (threadsNum == 0)\n {\n // set the number to the number of processors.\n threadsNum = std::thread::hardware_concurrency();\n LOG_TRACE << \"The number of processors is \" << threadsNum;\n }\n if (threadsNum < 1)\n threadsNum = 1;\n drogon::app().setThreadNum(threadsNum);\n // session\n auto enableSession = app.get(\"enable_session\", false).asBool();\n if (enableSession)\n {\n auto timeout = app.get(\"session_timeout\", 0).asUInt64();\n auto sameSite = app.get(\"session_same_site\", \"Null\").asString();\n auto cookieKey = app.get(\"session_cookie_key\", \"JSESSIONID\").asString();\n auto maxAge = app.get(\"session_max_age\", -1).asInt();\n drogon::app().enableSession(timeout,\n Cookie::convertString2SameSite(sameSite),\n cookieKey,\n maxAge);\n }\n else\n drogon::app().disableSession();\n // document root\n auto documentRoot = app.get(\"document_root\", \"\").asString();\n if (documentRoot != \"\")\n {\n drogon::app().setDocumentRoot(documentRoot);\n }\n if (!app[\"static_file_headers\"].empty())\n {\n if (app[\"static_file_headers\"].isArray())\n {\n std::vector> headers;\n for (auto &header : app[\"static_file_headers\"])\n {\n headers.emplace_back(\n std::make_pair(header[\"name\"].asString(),\n header[\"value\"].asString()));\n }\n drogon::app().setStaticFileHeaders(headers);\n }\n else\n {\n throw std::runtime_error(\n \"The static_file_headers option must be an array\");\n }\n }\n // upload path\n auto uploadPath = app.get(\"upload_path\", \"uploads\").asString();\n 
drogon::app().setUploadPath(uploadPath);\n // file types\n auto fileTypes = app[\"file_types\"];\n if (fileTypes.isArray() && !fileTypes.empty())\n {\n std::vector types;\n for (auto const &fileType : fileTypes)\n {\n types.push_back(fileType.asString());\n LOG_TRACE << \"file type:\" << types.back();\n }\n drogon::app().setFileTypes(types);\n }\n // locations\n if (app.isMember(\"locations\"))\n {\n auto &locations = app[\"locations\"];\n if (!locations.isArray())\n {\n throw std::runtime_error(\"The locations option must be an array\");\n }\n for (auto &location : locations)\n {\n auto uri = location.get(\"uri_prefix\", \"\").asString();\n if (uri.empty())\n continue;\n auto defaultContentType =\n location.get(\"default_content_type\", \"\").asString();\n auto alias = location.get(\"alias\", \"\").asString();\n auto isCaseSensitive =\n location.get(\"is_case_sensitive\", false).asBool();\n...\n", "needle_token_start": 13407, "needle_token_end": 13892, "code_context_ntokens": 16094, "output": ["```cpp\nstatic bool bytesSize(std::string &sizeStr, size_t &size)\n{\n if (sizeStr.empty())\n {\n size = -1;\n return true;\n }\n else\n {\n size = 1;\n switch (sizeStr[sizeStr.length() - 1])\n {\n case 'k':\n case 'K':\n size = 1024;\n sizeStr.resize(sizeStr.length() - 1);\n break;\n case 'M':\n case 'm':\n size = (1024 * 1024);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n case 'g':\n case 'G':\n size = (1024 * 1024 * 1024);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n#if ((ULONG_MAX) != (UINT_MAX))\n // 64bit system\n case 't':\n case 'T':\n size = (1024L * 1024L * 1024L * 1024L);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n#endif\n case '0':\n case '1':\n case '2':\n case '3':\n case '4':\n case '5':\n case '7':\n case '8':\n case '9':\n break;\n default:\n return false;\n break;\n }\n std::istringstream iss(sizeStr);\n size_t tmpSize;\n iss >> tmpSize;\n if (iss.fail())\n {\n return false;\n }\n if ((size_t(-1) / tmpSize) >= size)\n size *= tmpSize;\n else\n {\n size = -1;\n }\n return true;\n }\n}\n```"]} +{"repo": "drogonframework/drogon", "name": "ChunkingParams", "language": "cpp", "path": "lib/src/HttpServer.cc", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The purpose is to manage the process of dynamically fetching data in chunks for transmission over a network, particularly useful in scenarios where data needs to be sent in segments.\n2. **Input**: The input is a callback function that specifies how to retrieve the data chunks. This function takes a buffer and a size, and returns the amount of data actually filled into the buffer.\n3. **Output**: There is no direct output; however, the callback function indirectly determines the data chunks that are output through its return value, which indicates the size of the data filled.\n4. **Procedure**: The procedure involves initializing with a callback function that is responsible for providing data chunks. When data needs to be sent, this callback is invoked with a buffer and size, and it should fill the buffer with up to 'size' bytes of data, returning the number of bytes actually written.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/src/ConfigLoader.cc\n/**\n *\n * @file ConfigLoader.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"ConfigLoader.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#if !defined(_WIN32)\n#include \n#define os_access access\n#else\n#include \n#ifndef __MINGW32__\n#define os_access _waccess\n#define R_OK 04\n#define W_OK 02\n#else\n#define os_access access\n#endif\n#endif\n\n#include \n#include \"ConfigAdapterManager.h\"\n#include \n\nusing namespace drogon;\n\nstatic bool bytesSize(std::string &sizeStr, size_t &size)\n{\n if (sizeStr.empty())\n {\n size = -1;\n return true;\n }\n else\n {\n size = 1;\n switch (sizeStr[sizeStr.length() - 1])\n {\n case 'k':\n case 'K':\n size = 1024;\n sizeStr.resize(sizeStr.length() - 1);\n break;\n case 'M':\n case 'm':\n size = (1024 * 1024);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n case 'g':\n case 'G':\n size = (1024 * 1024 * 1024);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n#if ((ULONG_MAX) != (UINT_MAX))\n // 64bit system\n case 't':\n case 'T':\n size = (1024L * 1024L * 1024L * 1024L);\n sizeStr.resize(sizeStr.length() - 1);\n break;\n#endif\n case '0':\n case '1':\n case '2':\n case '3':\n case '4':\n case '5':\n case '7':\n case '8':\n case '9':\n break;\n default:\n return false;\n break;\n }\n std::istringstream iss(sizeStr);\n size_t tmpSize;\n iss >> tmpSize;\n if (iss.fail())\n {\n return false;\n }\n if ((size_t(-1) / tmpSize) >= size)\n size *= tmpSize;\n else\n {\n size = -1;\n }\n return true;\n }\n}\n\nConfigLoader::ConfigLoader(const std::string &configFile)\n{\n if (os_access(drogon::utils::toNativePath(configFile).c_str(), 0) != 0)\n {\n throw std::runtime_error(\"Config file \" + configFile + \" not found!\");\n }\n if (os_access(drogon::utils::toNativePath(configFile).c_str(), R_OK) != 0)\n {\n throw std::runtime_error(\"No permission to read config file \" +\n configFile);\n }\n configFile_ = configFile;\n auto pos = configFile.find_last_of('.');\n if (pos == std::string::npos)\n {\n throw std::runtime_error(\"Invalid config file name!\");\n }\n auto ext = configFile.substr(pos + 1);\n std::ifstream infile(drogon::utils::toNativePath(configFile).c_str(),\n std::ifstream::in);\n // get the content of the infile\n std::string content((std::istreambuf_iterator(infile)),\n std::istreambuf_iterator());\n try\n {\n configJsonRoot_ =\n ConfigAdapterManager::instance().getJson(content, std::move(ext));\n }\n catch (std::exception &e)\n {\n throw std::runtime_error(\"Error reading config file \" + configFile +\n \": \" + e.what());\n }\n}\n\nConfigLoader::ConfigLoader(const Json::Value &data) : configJsonRoot_(data)\n{\n}\n\nConfigLoader::ConfigLoader(Json::Value &&data)\n : configJsonRoot_(std::move(data))\n{\n}\n\nConfigLoader::~ConfigLoader()\n{\n}\n\nstatic void loadLogSetting(const Json::Value &log)\n{\n if (!log)\n return;\n auto useSpdlog = log.get(\"use_spdlog\", false).asBool();\n auto logPath = log.get(\"log_path\", \"\").asString();\n auto baseName = log.get(\"logfile_base_name\", \"\").asString();\n auto logSize = log.get(\"log_size_limit\", 100000000).asUInt64();\n auto maxFiles = log.get(\"max_files\", 0).asUInt();\n HttpAppFrameworkImpl::instance().setLogPath(\n logPath, baseName, logSize, maxFiles, useSpdlog);\n auto logLevel = log.get(\"log_level\", \"DEBUG\").asString();\n if (logLevel == \"TRACE\")\n {\n 
trantor::Logger::setLogLevel(trantor::Logger::kTrace);\n }\n else if (logLevel == \"DEBUG\")\n {\n trantor::Logger::setLogLevel(trantor::Logger::kDebug);\n }\n else if (logLevel == \"INFO\")\n {\n trantor::Logger::setLogLevel(trantor::Logger::kInfo);\n }\n else if (logLevel == \"WARN\")\n {\n trantor::Logger::setLogLevel(trantor::Logger::kWarn);\n }\n auto localTime = log.get(\"display_local_time\", false).asBool();\n trantor::Logger::setDisplayLocalTime(localTime);\n}\n\nstatic void loadControllers(const Json::Value &controllers)\n{\n if (!controllers)\n return;\n for (auto const &controller : controllers)\n {\n auto path = controller.get(\"path\", \"\").asString();\n auto ctrlName = controller.get(\"controller\", \"\").asString();\n if (path == \"\" || ctrlName == \"\")\n continue;\n std::vector constraints;\n if (!controller[\"http_methods\"].isNull())\n {\n for (auto const &method : controller[\"http_methods\"])\n {\n auto strMethod = method.asString();\n std::transform(strMethod.begin(),\n strMethod.end(),\n strMethod.begin(),\n [](unsigned char c) { return tolower(c); });\n if (strMethod == \"get\")\n {\n constraints.push_back(Get);\n }\n else if (strMethod == \"post\")\n {\n constraints.push_back(Post);\n }\n else if (strMethod == \"head\") // The branch never work\n {\n constraints.push_back(Head);\n }\n else if (strMethod == \"put\")\n {\n constraints.push_back(Put);\n }\n else if (strMethod == \"delete\")\n {\n constraints.push_back(Delete);\n }\n else if (strMethod == \"patch\")\n {\n constraints.push_back(Patch);\n }\n }\n }\n if (!controller[\"filters\"].isNull())\n {\n for (auto const &filter : controller[\"filters\"])\n {\n constraints.push_back(filter.asString());\n }\n }\n drogon::app().registerHttpSimpleController(path, ctrlName, constraints);\n }\n}\n\nstatic void loadApp(const Json::Value &app)\n{\n if (!app)\n return;\n // threads number\n auto threadsNum = app.get(\"threads_num\", 1).asUInt64();\n if (threadsNum == 1)\n {\n threadsNum = app.get(\"number_of_threads\", 1).asUInt64();\n }\n if (threadsNum == 0)\n {\n // set the number to the number of processors.\n threadsNum = std::thread::hardware_concurrency();\n LOG_TRACE << \"The number of processors is \" << threadsNum;\n }\n if (threadsNum < 1)\n threadsNum = 1;\n drogon::app().setThreadNum(threadsNum);\n // session\n auto enableSession = app.get(\"enable_session\", false).asBool();\n if (enableSession)\n {\n auto timeout = app.get(\"session_timeout\", 0).asUInt64();\n auto sameSite = app.get(\"session_same_site\", \"Null\").asString();\n auto cookieKey = app.get(\"session_cookie_key\", \"JSESSIONID\").asString();\n auto maxAge = app.get(\"session_max_age\", -1).asInt();\n drogon::app().enableSession(timeout,\n Cookie::convertString2SameSite(sameSite),\n cookieKey,\n maxAge);\n }\n else\n drogon::app().disableSession();\n // document root\n auto documentRoot = app.get(\"document_root\", \"\").asString();\n if (documentRoot != \"\")\n {\n drogon::app().setDocumentRoot(documentRoot);\n }\n if (!app[\"static_file_headers\"].empty())\n {\n if (app[\"static_file_headers\"].isArray())\n {\n std::vector> headers;\n for (auto &header : app[\"static_file_headers\"])\n {\n headers.emplace_back(\n std::make_pair(header[\"name\"].asString(),\n header[\"value\"].asString()));\n }\n drogon::app().setStaticFileHeaders(headers);\n }\n else\n {\n throw std::runtime_error(\n \"The static_file_headers option must be an array\");\n }\n }\n // upload path\n auto uploadPath = app.get(\"upload_path\", \"uploads\").asString();\n 
drogon::app().setUploadPath(uploadPath);\n // file types\n auto fileTypes = app[\"file_types\"];\n if (fileTypes.isArray() && !fileTypes.empty())\n {\n std::vector types;\n for (auto const &fileType : fileTypes)\n {\n types.push_back(fileType.asString());\n LOG_TRACE << \"file type:\" << types.back();\n }\n drogon::app().setFileTypes(types);\n }\n // locations\n if (app.isMember(\"locations\"))\n {\n auto &locations = app[\"locations\"];\n if (!locations.isArray())\n {\n throw std::runtime_error(\"The locations option must be an array\");\n }\n for (auto &location : locations)\n {\n auto uri = location.get(\"uri_prefix\", \"\").asString();\n if (uri.empty())\n continue;\n auto defaultContentType =\n location.get(\"default_content_type\", \"\").asString();\n auto alias = location.get(\"alias\", \"\").asString();\n auto isCaseSensitive =\n location.get(\"is_case_sensitive\", false).asBool();\n auto allAll = location.get(\"allow_all\", true).asBool();\n auto isRecursive = location.get(\"is_recursive\", true).asBool();\n if (!location[\"filters\"].isNull())\n {\n if (location[\"filters\"].isArray())\n {\n std::vector filters;\n for (auto const &filter : location[\"filters\"])\n {\n filters.push_back(filter.asString());\n }\n drogon::app().addALocation(uri,\n defaultContentType,\n alias,\n isCaseSensitive,\n allAll,\n isRecursive,\n filters);\n }\n else\n {\n throw std::runtime_error(\"the filters of location '\" + uri +\n \"' should be an array\");\n }\n }\n else\n {\n drogon::app().addALocation(uri,\n defaultContentType,\n alias,\n isCaseSensitive,\n allAll,\n isRecursive);\n }\n }\n }\n // max connections\n auto maxConns = app.get(\"max_connections\", 0).asUInt64();\n if (maxConns > 0)\n {\n drogon::app().setMaxConnectionNum(maxConns);\n }\n // max connections per IP\n auto maxConnsPerIP = app.get(\"max_connections_per_ip\", 0).asUInt64();\n if (maxConnsPerIP > 0)\n {\n drogon::app().setMaxConnectionNumPerIP(maxConnsPerIP);\n }\n#ifndef _WIN32\n // dynamic views\n auto enableDynamicViews = app.get(\"load_dynamic_views\", false).asBool();\n if (enableDynamicViews)\n {\n auto viewsPaths = app[\"dynamic_views_path\"];\n if (viewsPaths.isArray() && viewsPaths.size() > 0)\n {\n std::vector paths;\n for (auto const &viewsPath : viewsPaths)\n {\n paths.push_back(viewsPath.asString());\n LOG_TRACE << \"views path:\" << paths.back();\n }\n auto outputPath =\n app.get(\"dynamic_views_output_path\", \"\").asString();\n drogon::app().enableDynamicViewsLoading(paths, outputPath);\n }\n }\n#endif\n auto stackLimit = app.get(\"json_parser_stack_limit\", 1000).asUInt64();\n drogon::app().setJsonParserStackLimit(stackLimit);\n auto unicodeEscaping =\n app.get(\"enable_unicode_escaping_in_json\", true).asBool();\n drogon::app().setUnicodeEscapingInJson(unicodeEscaping);\n auto &precision = app[\"float_precision_in_json\"];\n if (!precision.isNull())\n {\n auto precisionLength = precision.get(\"precision\", 0).asUInt64();\n auto precisionType =\n precision.get(\"precision_type\", \"significant\").asString();\n drogon::app().setFloatPrecisionInJson((unsigned int)precisionLength,\n precisionType);\n }\n // log\n loadLogSetting(app[\"log\"]);\n // run as daemon\n auto runAsDaemon = app.get(\"run_as_daemon\", false).asBool();\n if (runAsDaemon)\n {\n drogon::app().enableRunAsDaemon();\n }\n // handle SIGTERM\n auto handleSigterm = app.get(\"handle_sig_term\", true).asBool();\n if (!handleSigterm)\n {\n drogon::app().disableSigtermHandling();\n }\n // relaunch\n auto relaunch = app.get(\"relaunch_on_error\", 
false).asBool();\n if (relaunch)\n {\n drogon::app().enableRelaunchOnError();\n }\n auto useSendfile = app.get(\"use_sendfile\", true).asBool();\n drogon::app().enableSendfile(useSendfile);\n auto useGzip = app.get(\"use_gzip\", true).asBool();\n drogon::app().enableGzip(useGzip);\n auto useBr = app.get(\"use_brotli\", false).asBool();\n drogon::app().enableBrotli(useBr);\n auto staticFilesCacheTime = app.get(\"static_files_cache_time\", 5).asInt();\n drogon::app().setStaticFilesCacheTime(staticFilesCacheTime);\n loadControllers(app[\"simple_controllers_map\"]);\n // Kick off idle connections\n auto kickOffTimeout = app.get(\"idle_connection_timeout\", 60).asUInt64();\n drogon::app().setIdleConnectionTimeout(kickOffTimeout);\n auto server = app.get(\"server_header_field\", \"\").asString();\n if (!server.empty())\n drogon::app().setServerHeaderField(server);\n auto sendServerHeader = app.get(\"enable_server_header\", true).asBool();\n drogon::app().enableServerHeader(sendServerHeader);\n auto sendDateHeader = app.get(\"enable_date_header\", true).asBool();\n drogon::app().enableDateHeader(sendDateHeader);\n auto keepaliveReqs = app.get(\"keepalive_requests\", 0).asUInt64();\n drogon::app().setKeepaliveRequestsNumber(keepaliveReqs);\n auto pipeliningReqs = app.get(\"pipelining_requests\", 0).asUInt64();\n drogon::app().setPipeliningRequestsNumber(pipeliningReqs);\n auto useGzipStatic = app.get(\"gzip_static\", true).asBool();\n drogon::app().setGzipStatic(useGzipStatic);\n auto useBrStatic = app.get(\"br_static\", true).asBool();\n drogon::app().setBrStatic(useBrStatic);\n auto maxBodySize = app.get(\"client_max_body_size\", \"1M\").asString();\n size_t size;\n if (bytesSize(maxBodySize, size))\n {\n drogon::app().setClientMaxBodySize(size);\n }\n else\n {\n throw std::runtime_error(\"Error format of client_max_body_size\");\n }\n auto maxMemoryBodySize =\n app.get(\"client_max_memory_body_size\", \"64K\").asString();\n if (bytesSize(maxMemoryBodySize, size))\n {\n drogon::app().setClientMaxMemoryBodySize(size);\n }\n else\n {\n throw std::runtime_error(\"Error format of client_max_memory_body_size\");\n }\n auto maxWsMsgSize =\n app.get(\"client_max_websocket_message_size\", \"128K\").asString();\n if (bytesSize(maxWsMsgSize, size))\n {\n drogon::app().setClientMaxWebSocketMessageSize(size);\n }\n else\n {\n throw std::runtime_error(\n \"Error format of client_max_websocket_message_size\");\n }\n drogon::app().enableReusePort(app.get(\"reuse_port\", false).asBool());\n drogon::app().setHomePage(app.get(\"home_page\", \"index.html\").asString());\n drogon::app().setImplicitPageEnable(\n app.get(\"use_implicit_page\", true).asBool());\n drogon::app().setImplicitPage(\n app.get(\"implicit_page\", \"index.html\").asString());\n auto mimes = app[\"mime\"];\n if (!mimes.isNull())\n {\n auto names = mimes.getMemberNames();\n for (const auto &mime : names)\n {\n auto ext = mimes[mime];\n std::vector exts;\n if (ext.isString())\n exts.push_back(ext.asString());\n else if (ext.isArray())\n {\n for (const auto &extension : ext)\n exts.push_back(extension.asString());\n }\n\n for (const auto &extension : exts)\n drogon::app().registerCustomExtensionMime(extension, mime);\n }\n }\n bool enableCompressedRequests =\n app.get(\"enabled_compressed_request\", false).asBool();\n drogon::app().enableCompressedRequest(enableCompressedRequests);\n}\n\nstatic void loadDbClients(const Json::Value &dbClients)\n{\n if (!dbClients)\n return;\n for (auto const &client : dbClients)\n {\n auto type = 
client.get(\"rdbms\", \"postgresql\").asString();\n std::transform(type.begin(),\n type.end(),\n type.begin(),\n [](unsigned char c) { return tolower(c); });\n auto host = client.get(\"host\", \"127.0.0.1\").asString();\n auto port = client.get(\"port\", 5432).asUInt();\n auto dbname = client.get(\"dbname\", \"\").asString();\n if (dbname == \"\" && type != \"sqlite3\")\n {\n throw std::runtime_error(\n \"Please configure dbname in the configuration file\");\n }\n auto user = client.get(\"user\", \"postgres\").asString();\n auto password = client.get(\"passwd\", \"\").asString();\n if (password.empty())\n {\n password = client.get(\"password\", \"\").asString();\n }\n auto connNum = client.get(\"connection_number\", 1).asUInt();\n if (connNum == 1)\n {\n connNum = client.get(\"number_of_connections\", 1).asUInt();\n }\n auto name = client.get(\"name\", \"default\").asString();\n auto filename = client.get(\"filename\", \"\").asString();\n auto isFast = client.get(\"is_fast\", false).asBool();\n auto characterSet = client.get(\"characterSet\", \"\").asString();\n if (characterSet.empty())\n {\n characterSet = client.get(\"client_encoding\", \"\").asString();\n }\n auto timeout = client.get(\"timeout\", -1.0).asDouble();\n auto autoBatch = client.get(\"auto_batch\", false).asBool();\n drogon::app().createDbClient(type,\n host,\n (unsigned short)port,\n dbname,\n user,\n password,\n connNum,\n filename,\n name,\n isFast,\n characterSet,\n timeout,\n autoBatch);\n }\n}\n\nstatic void loadRedisClients(const Json::Value &redisClients)\n{\n if (!redisClients)\n return;\n for (auto const &client : redisClients)\n {\n std::promise promise;\n auto future = promise.get_future();\n auto host = client.get(\"host\", \"127.0.0.1\").asString();\n trantor::Resolver::newResolver()->resolve(\n host, [&promise](const trantor::InetAddress &address) {\n promise.set_value(address.toIp());\n });\n auto port = client.get(\"port\", 6379).asUInt();\n auto username = client.get(\"username\", \"\").asString();\n auto password = client.get(\"passwd\", \"\").asString();\n if (password.empty())\n {\n password = client.get(\"password\", \"\").asString();\n }\n auto connNum = client.get(\"connection_number\", 1).asUInt();\n if (connNum == 1)\n {\n connNum = client.get(\"number_of_connections\", 1).asUInt();\n }\n auto name = client.get(\"name\", \"default\").asString();\n auto isFast = client.get(\"is_fast\", false).asBool();\n auto timeout = client.get(\"timeout\", -1.0).asDouble();\n auto db = client.get(\"db\", 0).asUInt();\n auto hostIp = future.get();\n drogon::app().createRedisClient(hostIp,\n port,\n name,\n password,\n connNum,\n isFast,\n timeout,\n db,\n username);\n }\n}\n\nstatic void loadListeners(const Json::Value &listeners)\n{\n if (!listeners)\n return;\n LOG_TRACE << \"Has \" << listeners.size() << \" listeners\";\n for (auto const &listener : listeners)\n {\n auto addr = listener.get(\"address\", \"0.0.0.0\").asString();\n auto port = (uint16_t)listener.get(\"port\", 0).asUInt();\n auto useSSL = listener.get(\"https\", false).asBool();\n auto cert = listener.get(\"cert\", \"\").asString();\n auto key = listener.get(\"key\", \"\").asString();\n auto useOldTLS = listener.get(\"use_old_tls\", false).asBool();\n std::vector> sslConfCmds;\n if (listener.isMember(\"ssl_conf\"))\n {\n for (const auto &opt : listener[\"ssl_conf\"])\n {\n...\n// Path: lib/src/IntranetIpFilter.cc\n/**\n *\n * IntranetIpFilter.cc\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpResponseImpl.h\"\n#include \nusing namespace drogon;\n\nvoid IntranetIpFilter::doFilter(const HttpRequestPtr &req,\n FilterCallback &&fcb,\n FilterChainCallback &&fccb)\n{\n if (req->peerAddr().isIntranetIp())\n {\n fccb();\n return;\n }\n auto res = drogon::HttpResponse::newNotFoundResponse(req);\n fcb(res);\n}\n\n// Path: lib/src/FixedWindowRateLimiter.cc\n#include \"FixedWindowRateLimiter.h\"\n\nusing namespace drogon;\n\nFixedWindowRateLimiter::FixedWindowRateLimiter(\n size_t capacity,\n std::chrono::duration timeUnit)\n : capacity_(capacity),\n lastTime_(std::chrono::steady_clock::now()),\n timeUnit_(timeUnit)\n{\n}\n\n// implementation of the fixed window algorithm\n\nbool FixedWindowRateLimiter::isAllowed()\n{\n auto now = std::chrono::steady_clock::now();\n auto duration = std::chrono::duration_cast>(\n now - lastTime_);\n if (duration >= timeUnit_)\n {\n currentRequests_ = 0;\n lastTime_ = now;\n }\n if (currentRequests_ < capacity_)\n {\n currentRequests_++;\n return true;\n }\n return false;\n}\n\n// Path: lib/src/SlidingWindowRateLimiter.cc\n#include \"SlidingWindowRateLimiter.h\"\n#include \n\nusing namespace drogon;\n\nSlidingWindowRateLimiter::SlidingWindowRateLimiter(\n size_t capacity,\n std::chrono::duration timeUnit)\n : capacity_(capacity),\n unitStartTime_(std::chrono::steady_clock::now()),\n lastTime_(unitStartTime_),\n timeUnit_(timeUnit)\n{\n}\n\n// implementation of the sliding window algorithm\nbool SlidingWindowRateLimiter::isAllowed()\n{\n auto now = std::chrono::steady_clock::now();\n unitStartTime_ =\n unitStartTime_ +\n std::chrono::duration_cast(\n std::chrono::duration(\n static_cast(\n (uint64_t)(std::chrono::duration_cast<\n std::chrono::duration>(\n now - unitStartTime_)\n .count() /\n timeUnit_.count())) *\n timeUnit_.count()));\n\n if (unitStartTime_ > lastTime_)\n {\n auto duration =\n std::chrono::duration_cast>(\n unitStartTime_ - lastTime_);\n auto startTime = lastTime_;\n if (duration >= timeUnit_)\n {\n previousRequests_ = 0;\n }\n else\n {\n previousRequests_ = currentRequests_;\n }\n currentRequests_ = 0;\n }\n auto coef = std::chrono::duration_cast>(\n now - unitStartTime_) /\n timeUnit_;\n assert(coef <= 1.0);\n auto count = previousRequests_ * (1.0 - coef) + currentRequests_;\n if (count < capacity_)\n {\n currentRequests_++;\n lastTime_ = now;\n return true;\n }\n return false;\n}\n\n// Path: lib/src/HttpRequestParser.h\n/**\n *\n * HttpRequestParser.h\n * An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#pragma once\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"impl_forwards.h\"\n\nnamespace drogon\n{\nclass HttpRequestParser : public trantor::NonCopyable,\n public std::enable_shared_from_this\n{\n public:\n enum class HttpRequestParseStatus\n {\n kExpectMethod,\n kExpectRequestLine,\n kExpectHeaders,\n kExpectBody,\n kExpectChunkLen,\n kExpectChunkBody,\n kExpectLastEmptyChunk,\n kGotAll,\n };\n\n explicit HttpRequestParser(const trantor::TcpConnectionPtr &connPtr);\n\n // return false if any error\n int parseRequest(trantor::MsgBuffer *buf);\n\n bool gotAll() const\n {\n return status_ == HttpRequestParseStatus::kGotAll;\n }\n\n void reset();\n\n const HttpRequestImplPtr &requestImpl() const\n {\n return request_;\n }\n\n bool firstReq()\n {\n if (firstRequest_)\n {\n firstRequest_ = false;\n return true;\n }\n return false;\n }\n\n const WebSocketConnectionImplPtr &webSocketConn() const\n {\n return websockConnPtr_;\n }\n\n void setWebsockConnection(const WebSocketConnectionImplPtr &conn)\n {\n websockConnPtr_ = conn;\n }\n\n // to support request pipelining(rfc2616-8.1.2.2)\n void pushRequestToPipelining(const HttpRequestPtr &, bool isHeadMethod);\n bool pushResponseToPipelining(const HttpRequestPtr &, HttpResponsePtr);\n void popReadyResponses(std::vector> &);\n\n size_t numberOfRequestsInPipelining() const\n {\n return requestPipelining_.size();\n }\n\n bool emptyPipelining()\n {\n return requestPipelining_.empty();\n }\n\n bool isStop() const\n {\n return stopWorking_;\n }\n\n void stop()\n {\n stopWorking_ = true;\n }\n\n size_t numberOfRequestsParsed() const\n {\n return requestsCounter_;\n }\n\n trantor::MsgBuffer &getBuffer()\n {\n return sendBuffer_;\n }\n\n std::vector> &getResponseBuffer()\n {\n assert(loop_->isInLoopThread());\n if (!responseBuffer_)\n {\n responseBuffer_ =\n std::unique_ptr>>(\n new std::vector>);\n }\n return *responseBuffer_;\n }\n\n std::vector &getRequestBuffer()\n {\n assert(loop_->isInLoopThread());\n if (!requestBuffer_)\n {\n requestBuffer_ = std::unique_ptr>(\n new std::vector);\n }\n return *requestBuffer_;\n }\n\n private:\n HttpRequestImplPtr makeRequestForPool(HttpRequestImpl *p);\n void shutdownConnection(HttpStatusCode code);\n bool processRequestLine(const char *begin, const char *end);\n HttpRequestParseStatus status_;\n trantor::EventLoop *loop_;\n HttpRequestImplPtr request_;\n bool firstRequest_{true};\n WebSocketConnectionImplPtr websockConnPtr_;\n std::deque>>\n requestPipelining_;\n size_t requestsCounter_{0};\n std::weak_ptr conn_;\n bool stopWorking_{false};\n trantor::MsgBuffer sendBuffer_;\n std::unique_ptr>>\n responseBuffer_;\n std::unique_ptr> requestBuffer_;\n std::vector requestsPool_;\n size_t currentChunkLength_{0};\n size_t currentContentLength_{0};\n};\n\n} // namespace drogon\n\n// Path: lib/src/HttpServer.cc\n/**\n *\n * @file HttpServer.cc\n * @author An Tao\n *\n * Copyright 2018, An Tao. 
All rights reserved.\n * https://github.com/an-tao/drogon\n * Use of this source code is governed by a MIT license\n * that can be found in the License file.\n *\n * Drogon\n *\n */\n\n#include \"HttpServer.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"AOPAdvice.h\"\n#include \"FiltersFunction.h\"\n#include \"HttpAppFrameworkImpl.h\"\n#include \"HttpConnectionLimit.h\"\n#include \"HttpControllerBinder.h\"\n#include \"HttpRequestImpl.h\"\n#include \"HttpRequestParser.h\"\n#include \"HttpResponseImpl.h\"\n#include \"HttpControllersRouter.h\"\n#include \"StaticFileRouter.h\"\n#include \"WebSocketConnectionImpl.h\"\n\n#if COZ_PROFILING\n#include \n#else\n#define COZ_PROGRESS\n#define COZ_PROGRESS_NAMED(name)\n#define COZ_BEGIN(name)\n#define COZ_END(name)\n#endif\n\nusing namespace std::placeholders;\nusing namespace drogon;\nusing namespace trantor;\n\nstatic inline bool isWebSocket(const HttpRequestImplPtr &req);\nstatic inline HttpResponsePtr tryDecompressRequest(\n const HttpRequestImplPtr &req);\nstatic inline bool passSyncAdvices(\n const HttpRequestImplPtr &req,\n const std::shared_ptr &requestParser,\n bool shouldBePipelined,\n bool isHeadMethod);\nstatic inline HttpResponsePtr getCompressedResponse(\n const HttpRequestImplPtr &req,\n const HttpResponsePtr &response,\n bool isHeadMethod);\n\nstatic void handleInvalidHttpMethod(\n const HttpRequestImplPtr &req,\n std::function &&callback);\n\nstatic void handleHttpOptions(\n const HttpRequestImplPtr &req,\n const std::string &allowMethods,\n std::function &&callback);\n\nHttpServer::HttpServer(EventLoop *loop,\n const InetAddress &listenAddr,\n std::string name)\n#ifdef __linux__\n : server_(loop, listenAddr, std::move(name))\n#else\n : server_(loop, listenAddr, std::move(name), true, app().reusePort())\n#endif\n{\n server_.setConnectionCallback(onConnection);\n server_.setRecvMessageCallback(onMessage);\n server_.kickoffIdleConnections(\n HttpAppFrameworkImpl::instance().getIdleConnectionTimeout());\n}\n\nHttpServer::~HttpServer() = default;\n\nvoid HttpServer::start()\n{\n LOG_TRACE << \"HttpServer[\" << server_.name() << \"] starts listening on \"\n << server_.ipPort();\n server_.start();\n}\n\nvoid HttpServer::stop()\n{\n server_.stop();\n}\n\nvoid HttpServer::onConnection(const TcpConnectionPtr &conn)\n{\n if (conn->connected())\n {\n auto parser = std::make_shared(conn);\n parser->reset();\n conn->setContext(parser);\n if (!HttpConnectionLimit::instance().tryAddConnection(conn))\n {\n LOG_ERROR << \"too much connections!force close!\";\n conn->forceClose();\n return;\n }\n if (!AopAdvice::instance().passNewConnectionAdvices(conn))\n {\n conn->forceClose();\n }\n }\n else if (conn->disconnected())\n {\n LOG_TRACE << \"conn disconnected!\";\n HttpConnectionLimit::instance().releaseConnection(conn);\n auto requestParser = conn->getContext();\n if (requestParser)\n {\n if (requestParser->webSocketConn())\n {\n requestParser->webSocketConn()->onClose();\n }\n conn->clearContext();\n }\n }\n}\n\nvoid HttpServer::onMessage(const TcpConnectionPtr &conn, MsgBuffer *buf)\n{\n if (!conn->hasContext())\n return;\n auto requestParser = conn->getContext();\n if (!requestParser)\n return;\n if (requestParser->webSocketConn())\n {\n // Websocket payload\n requestParser->webSocketConn()->onNewMessage(conn, buf);\n return;\n }\n\n auto &requests = requestParser->getRequestBuffer();\n // With the pipelining feature or web socket, it is possible to receive\n // multiple messages at once, so the while loop is 
necessary\n while (buf->readableBytes() > 0)\n {\n if (requestParser->isStop())\n {\n // The number of requests has reached the limit.\n buf->retrieveAll();\n return;\n }\n int parseRes = requestParser->parseRequest(buf);\n if (parseRes < 0)\n {\n requestParser->reset();\n conn->forceClose();\n return;\n }\n if (parseRes == 0)\n {\n break;\n }\n auto &req = requestParser->requestImpl();\n req->setPeerAddr(conn->peerAddr());\n req->setLocalAddr(conn->localAddr());\n req->setCreationDate(trantor::Date::date());\n req->setSecure(conn->isSSLConnection());\n req->setPeerCertificate(conn->peerCertificate());\n requests.push_back(req);\n requestParser->reset();\n }\n onRequests(conn, requests, requestParser);\n requests.clear();\n}\n\nstruct CallbackParamPack\n{\n CallbackParamPack(trantor::TcpConnectionPtr conn,\n HttpRequestImplPtr req,\n std::shared_ptr loopFlag,\n std::shared_ptr requestParser,\n bool isHeadMethod)\n : conn_(std::move(conn)),\n req_(std::move(req)),\n loopFlag_(std::move(loopFlag)),\n requestParser_(std::move(requestParser)),\n isHeadMethod_(isHeadMethod)\n {\n }\n\n trantor::TcpConnectionPtr conn_;\n HttpRequestImplPtr req_;\n std::shared_ptr loopFlag_;\n std::shared_ptr requestParser_;\n bool isHeadMethod_;\n std::atomic responseSent_{false};\n};\n\nvoid HttpServer::onRequests(\n const TcpConnectionPtr &conn,\n const std::vector &requests,\n const std::shared_ptr &requestParser)\n{\n if (requests.empty())\n return;\n\n // will only be checked for the first request\n if (requestParser->firstReq() && requests.size() == 1 &&\n isWebSocket(requests[0]))\n {\n auto &req = requests[0];\n if (passSyncAdvices(req,\n requestParser,\n false /* Not pipelined */,\n false /* Not HEAD */))\n {\n auto wsConn = std::make_shared(conn);\n wsConn->setPingMessage(\"\", std::chrono::seconds{30});\n onWebsocketRequest(\n req,\n [conn, wsConn, requestParser, req](\n const HttpResponsePtr &resp0) mutable {\n if (conn->connected())\n {\n auto resp = HttpAppFrameworkImpl::instance()\n .handleSessionForResponse(req, resp0);\n AopAdvice::instance().passPreSendingAdvices(req, resp);\n if (resp->statusCode() == k101SwitchingProtocols)\n {\n requestParser->setWebsockConnection(wsConn);\n }\n auto httpString =\n ((HttpResponseImpl *)resp.get())->renderToBuffer();\n conn->send(httpString);\n COZ_PROGRESS\n }\n },\n std::move(wsConn));\n return;\n }\n\n // flush response for not passing sync advices\n if (conn->connected() && !requestParser->getResponseBuffer().empty())\n {\n sendResponses(conn,\n requestParser->getResponseBuffer(),\n requestParser->getBuffer());\n requestParser->getResponseBuffer().clear();\n }\n return;\n }\n\n if (HttpAppFrameworkImpl::instance().keepaliveRequestsNumber() > 0 &&\n requestParser->numberOfRequestsParsed() >=\n HttpAppFrameworkImpl::instance().keepaliveRequestsNumber())\n {\n requestParser->stop();\n conn->shutdown();\n return;\n }\n if (HttpAppFrameworkImpl::instance().pipeliningRequestsNumber() > 0 &&\n requestParser->numberOfRequestsInPipelining() + requests.size() >=\n HttpAppFrameworkImpl::instance().pipeliningRequestsNumber())\n {\n requestParser->stop();\n conn->shutdown();\n return;\n }\n if (!conn->connected())\n {\n return;\n }\n auto loopFlagPtr = std::make_shared(true);\n\n for (auto &req : requests)\n {\n bool isHeadMethod = (req->method() == Head);\n if (isHeadMethod)\n {\n req->setMethod(Get);\n }\n bool reqPipelined = false;\n if (!requestParser->emptyPipelining())\n {\n requestParser->pushRequestToPipelining(req, isHeadMethod);\n reqPipelined = true;\n }\n 
if (!passSyncAdvices(req, requestParser, reqPipelined, isHeadMethod))\n {\n continue;\n }\n\n // Optimization: Avoids dynamic allocation when copying the callback in\n // handlers (ex: copying callback into lambda captures in DB calls)\n bool respReady{false};\n auto paramPack = std::make_shared(\n conn, req, loopFlagPtr, requestParser, isHeadMethod);\n\n auto errResp = tryDecompressRequest(req);\n if (errResp)\n {\n handleResponse(errResp, paramPack, &respReady);\n }\n else\n {\n // `handleResponse()` callback may be called synchronously. In this\n // case, the generated response should not be sent right away, but\n // be queued in buffer instead. Those ready responses will be sent\n // together after the end of the for loop.\n //\n // By doing this, we could reduce some system calls when sending\n // through socket. In order to achieve this, we create a\n // `respReady` variable.\n onHttpRequest(req,\n [respReadyPtr = &respReady,\n paramPack = std::move(paramPack)](\n const HttpResponsePtr &response) {\n handleResponse(response, paramPack, respReadyPtr);\n });\n }\n if (!reqPipelined && !respReady)\n {\n requestParser->pushRequestToPipelining(req, isHeadMethod);\n }\n }\n *loopFlagPtr = false;\n if (conn->connected() && !requestParser->getResponseBuffer().empty())\n {\n sendResponses(conn,\n requestParser->getResponseBuffer(),\n requestParser->getBuffer());\n requestParser->getResponseBuffer().clear();\n }\n}\n\nvoid HttpServer::onHttpRequest(\n const HttpRequestImplPtr &req,\n std::function &&callback)\n{\n LOG_TRACE << \"new request:\" << req->peerAddr().toIpPort() << \"->\"\n << req->localAddr().toIpPort();\n LOG_TRACE << \"Headers \" << req->methodString() << \" \" << req->path();\n LOG_TRACE << \"http path=\" << req->path();\n if (req->method() == Options && (req->path() == \"*\" || req->path() == \"/*\"))\n {\n auto resp = HttpResponse::newHttpResponse();\n resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN);\n resp->addHeader(\"Allow\", \"GET,HEAD,POST,PUT,DELETE,OPTIONS,PATCH\");\n resp->setExpiredTime(0);\n callback(resp);\n return;\n }\n\n // TODO: move session related codes to its own singleton class\n HttpAppFrameworkImpl::instance().findSessionForRequest(req);\n // pre-routing aop\n auto &aop = AopAdvice::instance();\n aop.passPreRoutingObservers(req);\n if (!aop.hasPreRoutingAdvices())\n {\n httpRequestRouting(req, std::move(callback));\n return;\n }\n aop.passPreRoutingAdvices(req,\n [req, callback = std::move(callback)](\n const HttpResponsePtr &resp) mutable {\n if (resp)\n {\n callback(resp);\n }\n else\n {\n httpRequestRouting(req,\n std::move(callback));\n }\n });\n}\n\nvoid HttpServer::httpRequestRouting(\n const HttpRequestImplPtr &req,\n std::function &&callback)\n{\n // How to access router here?? 
Make router class singleton?\n RouteResult result = HttpControllersRouter::instance().route(req);\n if (result.result == RouteResult::Success)\n {\n HttpRequestParamPack pack{std::move(result.binderPtr),\n std::move(callback)};\n requestPostRouting(req, std::move(pack));\n return;\n }\n if (result.result == RouteResult::MethodNotAllowed)\n {\n handleInvalidHttpMethod(req, std::move(callback));\n return;\n }\n\n // Fallback to static file router\n // TODO: make this router a plugin\n if (req->path() == \"/\" &&\n !HttpAppFrameworkImpl::instance().getHomePage().empty())\n {\n req->setPath(\"/\" + HttpAppFrameworkImpl::instance().getHomePage());\n }\n StaticFileRouter::instance().route(req, std::move(callback));\n}\n\ntemplate \nvoid HttpServer::requestPostRouting(const HttpRequestImplPtr &req, Pack &&pack)\n{\n // post-routing aop\n auto &aop = AopAdvice::instance();\n aop.passPostRoutingObservers(req);\n if (!aop.hasPostRoutingAdvices())\n {\n requestPassFilters(req, std::forward(pack));\n return;\n }\n aop.passPostRoutingAdvices(req,\n [req, pack = std::forward(pack)](\n const HttpResponsePtr &resp) mutable {\n if (resp)\n {\n pack.callback(resp);\n }\n else\n {\n requestPassFilters(req, std::move(pack));\n }\n });\n}\n\ntemplate \nvoid HttpServer::requestPassFilters(const HttpRequestImplPtr &req, Pack &&pack)\n{\n // pass filters\n auto &filters = pack.binderPtr->filters_;\n if (filters.empty())\n {\n requestPreHandling(req, std::forward(pack));\n return;\n }\n filters_function::doFilters(filters,\n req,\n [req, pack = std::forward(pack)](\n const HttpResponsePtr &resp) mutable {\n if (resp)\n {\n pack.callback(resp);\n }\n else\n {\n requestPreHandling(req,\n std::move(pack));\n }\n });\n}\n\ntemplate \nvoid HttpServer::requestPreHandling(const HttpRequestImplPtr &req, Pack &&pack)\n{\n if (req->method() == Options)\n {\n handleHttpOptions(req,\n *pack.binderPtr->corsMethods_,\n std::move(pack.callback));\n return;\n }\n\n // pre-handling aop\n auto &aop = AopAdvice::instance();\n aop.passPreHandlingObservers(req);\n if (!aop.hasPreHandlingAdvices())\n {\n if constexpr (std::is_same_v, HttpRequestParamPack>)\n {\n httpRequestHandling(req,\n std::move(pack.binderPtr),\n std::move(pack.callback));\n }\n else\n {\n websocketRequestHandling(req,\n std::move(pack.binderPtr),\n std::move(pack.callback),\n std::move(pack.wsConnPtr));\n }\n return;\n }\n aop.passPreHandlingAdvices(\n req,\n [req,\n pack = std::forward(pack)](const HttpResponsePtr &resp) mutable {\n if (resp)\n {\n pack.callback(resp);\n return;\n }\n if constexpr (std::is_same_v,\n HttpRequestParamPack>)\n {\n httpRequestHandling(req,\n std::move(pack.binderPtr),\n std::move(pack.callback));\n }\n else\n {\n websocketRequestHandling(req,\n std::move(pack.binderPtr),\n std::move(pack.callback),\n std::move(pack.wsConnPtr));\n }\n });\n}\n\nvoid HttpServer::httpRequestHandling(\n const HttpRequestImplPtr &req,\n std::shared_ptr &&binderPtr,\n std::function &&callback)\n{\n // Check cached response\n auto &cachedResp = *(binderPtr->responseCache_);\n if (cachedResp)\n {\n if (cachedResp->expiredTime() == 0 ||\n (trantor::Date::now() <\n cachedResp->creationDate().after(\n static_cast(cachedResp->expiredTime()))))\n {\n // use cached response!\n LOG_TRACE << \"Use cached response\";\n\n // post-handling aop\n AopAdvice::instance().passPostHandlingAdvices(req, cachedResp);\n callback(cachedResp);\n return;\n }\n else\n {\n cachedResp.reset();\n }\n }\n\n auto &binderRef = *binderPtr;\n binderRef.handleRequest(\n req,\n // This is the 
actual callback being passed to controller\n [req, binderPtr = std::move(binderPtr), callback = std::move(callback)](\n const HttpResponsePtr &resp) mutable {\n // Check if we need to cache the response\n if (resp->expiredTime() >= 0 && resp->statusCode() != k404NotFound)\n {\n static_cast(resp.get())->makeHeaderString();\n auto loop = req->getLoop();\n if (loop->isInLoopThread())\n {\n binderPtr->responseCache_.setThreadData(resp);\n }\n else\n {\n loop->queueInLoop(\n [binderPtr = std::move(binderPtr), resp]() {\n binderPtr->responseCache_.setThreadData(resp);\n });\n }\n }\n // post-handling aop\n AopAdvice::instance().passPostHandlingAdvices(req, resp);\n callback(resp);\n });\n}\n\nvoid HttpServer::onWebsocketRequest(\n const HttpRequestImplPtr &req,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr)\n{\n HttpAppFrameworkImpl::instance().findSessionForRequest(req);\n // pre-routing aop\n auto &aop = AopAdvice::instance();\n aop.passPreRoutingObservers(req);\n if (!aop.hasPreRoutingAdvices())\n {\n websocketRequestRouting(req, std::move(callback), std::move(wsConnPtr));\n return;\n }\n aop.passPreRoutingAdvices(\n req,\n [req, wsConnPtr = std::move(wsConnPtr), callback = std::move(callback)](\n const HttpResponsePtr &resp) mutable {\n if (resp)\n {\n callback(resp);\n }\n else\n {\n websocketRequestRouting(req,\n std::move(callback),\n std::move(wsConnPtr));\n }\n });\n}\n\nvoid HttpServer::websocketRequestRouting(\n const HttpRequestImplPtr &req,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr)\n{\n RouteResult result = HttpControllersRouter::instance().routeWs(req);\n\n if (result.result == RouteResult::Success)\n {\n WsRequestParamPack pack{std::move(result.binderPtr),\n std::move(callback),\n std::move(wsConnPtr)};\n requestPostRouting(req, std::move(pack));\n return;\n }\n if (result.result == RouteResult::MethodNotAllowed)\n {\n handleInvalidHttpMethod(req, std::move(callback));\n return;\n }\n\n // Not found\n auto resp = drogon::HttpResponse::newNotFoundResponse(req);\n resp->setCloseConnection(true);\n callback(resp);\n}\n\nvoid HttpServer::websocketRequestHandling(\n const HttpRequestImplPtr &req,\n std::shared_ptr &&binderPtr,\n std::function &&callback,\n WebSocketConnectionImplPtr &&wsConnPtr)\n{\n binderPtr->handleRequest(\n req,\n [req, callback = std::move(callback)](const HttpResponsePtr &resp) {\n AopAdvice::instance().passPostHandlingAdvices(req, resp);\n callback(resp);\n });\n\n // TODO: more elegant?\n static_cast(binderPtr.get())\n ->handleNewConnection(req, wsConnPtr);\n}\n\nvoid HttpServer::handleResponse(\n const HttpResponsePtr &response,\n const std::shared_ptr ¶mPack,\n bool *respReadyPtr)\n{\n auto &conn = paramPack->conn_;\n auto &req = paramPack->req_;\n auto &requestParser = paramPack->requestParser_;\n auto &loopFlagPtr = paramPack->loopFlag_;\n const bool isHeadMethod = paramPack->isHeadMethod_;\n\n if (!response)\n return;\n if (!conn->connected())\n return;\n\n if (paramPack->responseSent_.exchange(true, std::memory_order_acq_rel))\n {\n LOG_ERROR << \"Sending more than 1 response for request. 
\"\n \"Ignoring later response\";\n return;\n }\n\n auto resp =\n HttpAppFrameworkImpl::instance().handleSessionForResponse(req,\n response);\n resp->setVersion(req->getVersion());\n resp->setCloseConnection(!req->keepAlive());\n AopAdvice::instance().passPreSendingAdvices(req, resp);\n\n auto newResp = getCompressedResponse(req, resp, isHeadMethod);\n if (conn->getLoop()->isInLoopThread())\n {\n /*\n * A client that supports persistent connections MAY\n * \"pipeline\" its requests (i.e., send multiple requests\n * without waiting for each response). A server MUST send\n * its responses to those requests in the same order that\n * the requests were received. rfc2616-8.1.1.2\n */\n if (requestParser->emptyPipelining())\n {\n // response must have arrived synchronously\n assert(*loopFlagPtr);\n // TODO: change to weakPtr to be sure. But may drop performance.\n *respReadyPtr = true;\n requestParser->getResponseBuffer().emplace_back(std::move(newResp),\n isHeadMethod);\n return;\n }\n if (requestParser->pushResponseToPipelining(req, std::move(newResp)))\n {\n auto &responseBuffer = requestParser->getResponseBuffer();\n requestParser->popReadyResponses(responseBuffer);\n if (!*loopFlagPtr)\n {\n // We have passed the point where `onRequests()` sends\n // responses. So, at here we should send ready responses from\n // the beginning of pipeline queue.\n sendResponses(conn, responseBuffer, requestParser->getBuffer());\n responseBuffer.clear();\n }\n }\n }\n else\n {\n conn->getLoop()->queueInLoop(\n [conn, req, requestParser, newResp = std::move(newResp)]() mutable {\n if (!conn->connected())\n {\n return;\n }\n if (requestParser->pushResponseToPipelining(req,\n std::move(newResp)))\n {\n std::vector> responses;\n requestParser->popReadyResponses(responses);\n sendResponses(conn, responses, requestParser->getBuffer());\n }\n });\n }\n}\n\nstruct ChunkingParams\n{\n using DataCallback = std::function;\n\n \nexplicit ChunkingParams(DataCallback cb) : dataCallback(std::move(cb))\n {\n }\n\n DataCallback dataCallback;\n bool bFinished{false};\n#ifndef NDEBUG // defined by CMake for release build\n std::size_t nDataReturned{0};\n#endif\n};\n\nstatic std::size_t chunkingCallback(\n const std::shared_ptr &cbParams,\n char *pBuffer,\n std::size_t nSize)\n{\n if (!cbParams)\n return 0;\n // Cleanup\n if (pBuffer == nullptr)\n {\n LOG_TRACE << \"Chunking callback cleanup\";\n if (cbParams && cbParams->dataCallback)\n {\n cbParams->dataCallback(pBuffer, nSize);\n cbParams->dataCallback = {};\n }\n return 0;\n }\n // Terminal chunk already returned\n if (cbParams->bFinished)\n {\n LOG_TRACE << \"Chunking callback has no more data\";\n#ifndef NDEBUG // defined by CMake for release build\n LOG_TRACE << \"Chunking callback: total data returned: \"\n << cbParams->nDataReturned << \" bytes\";\n#endif\n return 0;\n }\n\n // Reserve size to prepend the chunk size & append cr/lf, and get data\n struct\n {\n std::size_t operator()(std::size_t n)\n {\n return n == 0 ? 
0 : 1 + (*this)(n >> 4);\n }\n } neededDigits;\n\n auto nHeaderSize = neededDigits(nSize) + 2;\n auto nDataSize =\n cbParams->dataCallback(pBuffer + nHeaderSize, nSize - nHeaderSize - 2);\n if (nDataSize == 0)\n {\n // Terminal chunk + cr/lf\n cbParams->bFinished = true;\n#ifdef _WIN32\n memcpy_s(pBuffer, nSize, \"0\\r\\n\\r\\n\", 5);\n#else\n memcpy(pBuffer, \"0\\r\\n\\r\\n\", 5);\n#endif\n LOG_TRACE << \"Chunking callback: no more data, return last chunk of \"\n \"size 0 & end of message\";\n return 5;\n }\n // Non-terminal chunks\n pBuffer[nHeaderSize + nDataSize] = '\\r';\n pBuffer[nHeaderSize + nDataSize + 1] = '\\n';\n // The spec does not say if the chunk size is allowed tohave leading zeroes\n // Use a fixed size header with leading zeroes\n // (tested to work with Chrome, Firefox, Safari, Edge, wget, curl and VLC)\n#ifdef _WIN32\n char pszFormat[]{\"%04llx\\r\"};\n#else\n char pszFormat[]{\"%04lx\\r\"};\n#endif\n pszFormat[2] = '0' + char(nHeaderSize - 2);\n snprintf(pBuffer, nHeaderSize, pszFormat, nDataSize);\n pBuffer[nHeaderSize - 1] = '\\n';\n LOG_TRACE << \"Chunking callback: return chunk of size \" << nDataSize;\n...\n", "needle_token_start": 15404, "needle_token_end": 15430, "code_context_ntokens": 16235, "output": ["```cpp\nstatic std::size_t chunkingCallback(\n const std::shared_ptr &cbParams,\n char *pBuffer,\n std::size_t nSize)\n{\n if (!cbParams)\n return 0;\n // Cleanup\n if (pBuffer == nullptr)\n {\n LOG_TRACE << \"Chunking callback cleanup\";\n if (cbParams && cbParams->dataCallback)\n {\n cbParams->dataCallback(pBuffer, nSize);\n cbParams->dataCallback = {};\n }\n return 0;\n }\n // Terminal chunk already returned\n if (cbParams->bFinished)\n {\n LOG_TRACE << \"Chunking callback has no more data\";\n#ifndef NDEBUG // defined by CMake for release build\n LOG_TRACE << \"Chunking callback: total data returned: \"\n << cbParams->nDataReturned << \" bytes\";\n#endif\n return 0;\n }\n\n // Reserve size to prepend the chunk size & append cr/lf, and get data\n struct\n {\n std::size_t operator()(std::size_t n)\n {\n return n == 0 ? 0 : 1 + (*this)(n >> 4);\n }\n } neededDigits;\n\n auto nHeaderSize = neededDigits(nSize) + 2;\n auto nDataSize =\n cbParams->dataCallback(pBuffer + nHeaderSize, nSize - nHeaderSize - 2);\n if (nDataSize == 0)\n {\n // Terminal chunk + cr/lf\n cbParams->bFinished = true;\n#ifdef _WIN32\n memcpy_s(pBuffer, nSize, \"0\\r\\n\\r\\n\", 5);\n#else\n memcpy(pBuffer, \"0\\r\\n\\r\\n\", 5);\n#endif\n LOG_TRACE << \"Chunking callback: no more data, return last chunk of \"\n \"size 0 & end of message\";\n return 5;\n }\n // Non-terminal chunks\n pBuffer[nHeaderSize + nDataSize] = '\\r';\n pBuffer[nHeaderSize + nDataSize + 1] = '\\n';\n // The spec does not say if the chunk size is allowed tohave leading zeroes\n // Use a fixed size header with leading zeroes\n // (tested to work with Chrome, Firefox, Safari, Edge, wget, curl and VLC)\n#ifdef _WIN32\n char pszFormat[]{\"%04llx\\r\"};\n#else\n char pszFormat[]{\"%04lx\\r\"};\n#endif\n pszFormat[2] = '0' + char(nHeaderSize - 2);\n snprintf(pBuffer, nHeaderSize, pszFormat, nDataSize);\n pBuffer[nHeaderSize - 1] = '\\n';\n LOG_TRACE << \"Chunking callback: return chunk of size \" << nDataSize;\n return nHeaderSize + nDataSize + 2;\n}\n```"]} +{"repo": "sass/node-sass", "name": "compatibility", "language": "cpp", "path": "src/libsass/src/plugins.cpp", "position_ratio": 0.05, "description": "\nFunction Description:\n1. 
**Purpose**: To determine if a plugin's version is compatible with the current library version, focusing on major version compatibility.\n2. **Input**: A string representing the version of the plugin.\n3. **Output**: A boolean value; `true` if the versions are compatible up to the major version, `false` otherwise.\n4. **Procedure**: The function first checks if either the input version or the library's version is marked as \"not available\" (`\"[na]\"`). If either is, it returns `false`. If both versions are available, it locates the position of the second period in the library's version to isolate the major version. If the second period is not found, it directly compares the full version strings. If found, it compares the substrings up to the second period to check major version compatibility.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/environment.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"environment.hpp\"\n\nnamespace Sass {\n\n template \n Environment::Environment(bool is_shadow)\n : local_frame_(environment_map()),\n parent_(0), is_shadow_(false)\n { }\n template \n Environment::Environment(Environment* env, bool is_shadow)\n : local_frame_(environment_map()),\n parent_(env), is_shadow_(is_shadow)\n { }\n template \n Environment::Environment(Environment& env, bool is_shadow)\n : local_frame_(environment_map()),\n parent_(&env), is_shadow_(is_shadow)\n { }\n\n // link parent to create a stack\n template \n void Environment::link(Environment& env) { parent_ = &env; }\n template \n void Environment::link(Environment* env) { parent_ = env; }\n\n // this is used to find the global frame\n // which is the second last on the stack\n template \n bool Environment::is_lexical() const\n {\n return !! parent_ && parent_->parent_;\n }\n\n // only match the real root scope\n // there is still a parent around\n // not sure what it is actually use for\n // I guess we store functions etc. there\n template \n bool Environment::is_global() const\n {\n...\n// Path: src/libsass/src/plugins.cpp\n#include \"sass.hpp\"\n#include \n#include \"output.hpp\"\n#include \"plugins.hpp\"\n\n#ifdef _WIN32\n#include \n#else\n#include \n#include \n#include \n#include \n#endif\n\nnamespace Sass {\n\n Plugins::Plugins(void) { }\n Plugins::~Plugins(void)\n {\n for (auto function : functions) {\n sass_delete_function(function);\n }\n for (auto importer : importers) {\n sass_delete_importer(importer);\n }\n for (auto header : headers) {\n sass_delete_importer(header);\n }\n }\n\n // check if plugin is compatible with this version\n // plugins may be linked static against libsass\n // we try to be compatible between major versions\n \ninline bool compatibility(const char* their_version)\n {\n// const char* their_version = \"3.1.2\";\n // first check if anyone has an unknown version\n const char* our_version = libsass_version();\n if (!strcmp(their_version, \"[na]\")) return false;\n if (!strcmp(our_version, \"[na]\")) return false;\n\n // find the position of the second dot\n size_t pos = std::string(our_version).find('.', 0);\n if (pos != std::string::npos) pos = std::string(our_version).find('.', pos + 1);\n\n // if we do not have two dots we fallback to compare complete string\n if (pos == std::string::npos) { return strcmp(their_version, our_version) ? 
0 : 1; }\n // otherwise only compare up to the second dot (major versions)\n else { return strncmp(their_version, our_version, pos) ? 0 : 1; }\n\n }\n\n // load one specific plugin\n bool Plugins::load_plugin (const std::string& path)\n {\n\n typedef const char* (*__plugin_version__)(void);\n typedef Sass_Function_List (*__plugin_load_fns__)(void);\n typedef Sass_Importer_List (*__plugin_load_imps__)(void);\n\n if (LOAD_LIB(plugin, path))\n {\n // try to load initial function to query libsass version suppor\n if (LOAD_LIB_FN(__plugin_version__, plugin_version, \"libsass_get_version\"))\n {\n // get the libsass version of the plugin\n if (!compatibility(plugin_version())) return false;\n // try to get import address for \"libsass_load_functions\"\n if (LOAD_LIB_FN(__plugin_load_fns__, plugin_load_functions, \"libsass_load_functions\"))\n {\n Sass_Function_List fns = plugin_load_functions(), _p = fns;\n while (fns && *fns) { functions.push_back(*fns); ++ fns; }\n sass_free_memory(_p); // only delete the container, items not yet\n }\n // try to get import address for \"libsass_load_importers\"\n if (LOAD_LIB_FN(__plugin_load_imps__, plugin_load_importers, \"libsass_load_importers\"))\n {\n Sass_Importer_List imps = plugin_load_importers(), _p = imps;\n while (imps && *imps) { importers.push_back(*imps); ++ imps; }\n sass_free_memory(_p); // only delete the container, items not yet\n }\n // try to get import address for \"libsass_load_headers\"\n if (LOAD_LIB_FN(__plugin_load_imps__, plugin_load_headers, \"libsass_load_headers\"))\n {\n Sass_Importer_List imps = plugin_load_headers(), _p = imps;\n while (imps && *imps) { headers.push_back(*imps); ++ imps; }\n sass_free_memory(_p); // only delete the container, items not yet\n }\n // success\n return true;\n }\n else\n {\n // print debug message to stderr (should not happen)\n std::cerr << \"failed loading 'libsass_support' in <\" << path << \">\" << std::endl;\n if (const char* dlsym_error = dlerror()) std::cerr << dlsym_error << std::endl;\n CLOSE_LIB(plugin);\n }\n }\n else\n {\n // print debug message to stderr (should not happen)\n std::cerr << \"failed loading plugin <\" << path << \">\" << std::endl;\n if (const char* dlopen_error = dlerror()) std::cerr << dlopen_error << std::endl;\n }\n\n return false;\n\n }\n\n size_t Plugins::load_plugins(const std::string& path)\n {\n\n // count plugins\n size_t loaded = 0;\n\n #ifdef _WIN32\n\n try\n {\n\n // use wchar (utf16)\n WIN32_FIND_DATAW data;\n // trailing slash is guaranteed\n std::string globsrch(path + \"*.dll\");\n // convert to wide chars (utf16) for system call\n std::wstring wglobsrch(UTF_8::convert_to_utf16(globsrch));\n HANDLE hFile = FindFirstFileW(wglobsrch.c_str(), &data);\n // check if system called returned a result\n // ToDo: maybe we should print a debug message\n if (hFile == INVALID_HANDLE_VALUE) return -1;\n\n // read directory\n while (true)\n {\n try\n {\n // the system will report the filenames with wide chars (utf16)\n std::string entry = UTF_8::convert_from_utf16(data.cFileName);\n // check if file ending matches exactly\n if (!ends_with(entry, \".dll\")) continue;\n // load the plugin and increase counter\n if (load_plugin(path + entry)) ++ loaded;\n // check if there should be more entries\n if (GetLastError() == ERROR_NO_MORE_FILES) break;\n // load next entry (check for return type)\n if (!FindNextFileW(hFile, &data)) break;\n }\n catch (...)\n {\n // report the error to the console (should not happen)\n // seems like we got strange data from the system call?\n 
std::cerr << \"filename in plugin path has invalid utf8?\" << std::endl;\n }\n }\n }\n catch (utf8::invalid_utf8)\n {\n // report the error to the console (should not happen)\n // implementors should make sure to provide valid utf8\n std::cerr << \"plugin path contains invalid utf8\" << std::endl;\n }\n\n #else\n\n DIR *dp;\n struct dirent *dirp;\n if((dp = opendir(path.c_str())) == NULL) return -1;\n while ((dirp = readdir(dp)) != NULL) {\n #if __APPLE__\n if (!ends_with(dirp->d_name, \".dylib\")) continue;\n #else\n if (!ends_with(dirp->d_name, \".so\")) continue;\n #endif\n if (load_plugin(path + dirp->d_name)) ++ loaded;\n }\n closedir(dp);\n\n #endif\n return loaded;\n\n }\n\n}\n\n// Path: src/libsass/src/values.hpp\n#ifndef SASS_VALUES_H\n#define SASS_VALUES_H\n\n#include \"ast.hpp\"\n\nnamespace Sass {\n\n union Sass_Value* ast_node_to_sass_value (const Expression_Ptr val);\n Value_Ptr sass_value_to_ast_node (const union Sass_Value* val);\n\n}\n#endif\n\n// Path: src/libsass/src/listize.hpp\n#ifndef SASS_LISTIZE_H\n#define SASS_LISTIZE_H\n\n#include \n#include \n\n#include \"ast.hpp\"\n#include \"context.hpp\"\n#include \"operation.hpp\"\n#include \"environment.hpp\"\n\nnamespace Sass {\n\n struct Backtrace;\n\n class Listize : public Operation_CRTP {\n\n Expression_Ptr fallback_impl(AST_Node_Ptr n);\n\n public:\n Listize();\n ~Listize() { }\n\n Expression_Ptr operator()(Selector_List_Ptr);\n Expression_Ptr operator()(Complex_Selector_Ptr);\n Expression_Ptr operator()(Compound_Selector_Ptr);\n\n template \n Expression_Ptr fallback(U x) { return fallback_impl(x); }\n };\n\n}\n\n#endif\n\n// Path: src/libsass/src/functions.hpp\n#ifndef SASS_FUNCTIONS_H\n#define SASS_FUNCTIONS_H\n\n#include \"listize.hpp\"\n#include \"position.hpp\"\n#include \"environment.hpp\"\n#include \"ast_fwd_decl.hpp\"\n#include \"sass/functions.h\"\n\n#define BUILT_IN(name) Expression_Ptr \\\nname(Env& env, Env& d_env, Context& ctx, Signature sig, ParserState pstate, Backtraces traces, std::vector selector_stack)\n\nnamespace Sass {\n struct Backtrace;\n typedef const char* Signature;\n typedef Expression_Ptr (*Native_Function)(Env&, Env&, Context&, Signature, ParserState, Backtraces, std::vector);\n\n Definition_Ptr make_native_function(Signature, Native_Function, Context& ctx);\n Definition_Ptr make_c_function(Sass_Function_Entry c_func, Context& ctx);\n\n std::string function_name(Signature);\n\n namespace Functions {\n\n extern Signature rgb_sig;\n extern Signature rgba_4_sig;\n extern Signature rgba_2_sig;\n extern Signature red_sig;\n extern Signature green_sig;\n extern Signature blue_sig;\n extern Signature mix_sig;\n extern Signature hsl_sig;\n extern Signature hsla_sig;\n extern Signature hue_sig;\n extern Signature saturation_sig;\n extern Signature lightness_sig;\n extern Signature adjust_hue_sig;\n extern Signature lighten_sig;\n extern Signature darken_sig;\n extern Signature saturate_sig;\n extern Signature desaturate_sig;\n extern Signature grayscale_sig;\n extern Signature complement_sig;\n extern Signature invert_sig;\n extern Signature alpha_sig;\n extern Signature opacity_sig;\n extern Signature opacify_sig;\n extern Signature fade_in_sig;\n extern Signature transparentize_sig;\n extern Signature fade_out_sig;\n extern Signature adjust_color_sig;\n extern Signature scale_color_sig;\n extern Signature change_color_sig;\n extern Signature ie_hex_str_sig;\n extern Signature unquote_sig;\n extern Signature quote_sig;\n extern Signature str_length_sig;\n extern Signature str_insert_sig;\n extern 
Signature str_index_sig;\n extern Signature str_slice_sig;\n extern Signature to_upper_case_sig;\n extern Signature to_lower_case_sig;\n extern Signature percentage_sig;\n extern Signature round_sig;\n extern Signature ceil_sig;\n extern Signature floor_sig;\n extern Signature abs_sig;\n extern Signature min_sig;\n extern Signature max_sig;\n extern Signature inspect_sig;\n extern Signature random_sig;\n extern Signature length_sig;\n extern Signature nth_sig;\n extern Signature index_sig;\n extern Signature join_sig;\n extern Signature append_sig;\n extern Signature zip_sig;\n extern Signature list_separator_sig;\n extern Signature type_of_sig;\n extern Signature unit_sig;\n extern Signature unitless_sig;\n extern Signature comparable_sig;\n extern Signature variable_exists_sig;\n extern Signature global_variable_exists_sig;\n extern Signature function_exists_sig;\n extern Signature mixin_exists_sig;\n extern Signature feature_exists_sig;\n extern Signature call_sig;\n extern Signature not_sig;\n extern Signature if_sig;\n extern Signature map_get_sig;\n extern Signature map_merge_sig;\n extern Signature map_remove_sig;\n extern Signature map_keys_sig;\n extern Signature map_values_sig;\n extern Signature map_has_key_sig;\n extern Signature keywords_sig;\n extern Signature set_nth_sig;\n extern Signature unique_id_sig;\n extern Signature selector_nest_sig;\n extern Signature selector_append_sig;\n extern Signature selector_extend_sig;\n extern Signature selector_replace_sig;\n extern Signature selector_unify_sig;\n extern Signature is_superselector_sig;\n extern Signature simple_selectors_sig;\n extern Signature selector_parse_sig;\n extern Signature is_bracketed_sig;\n extern Signature content_exists_sig;\n extern Signature get_function_sig;\n\n BUILT_IN(rgb);\n BUILT_IN(rgba_4);\n BUILT_IN(rgba_2);\n BUILT_IN(red);\n BUILT_IN(green);\n BUILT_IN(blue);\n BUILT_IN(mix);\n BUILT_IN(hsl);\n BUILT_IN(hsla);\n BUILT_IN(hue);\n BUILT_IN(saturation);\n BUILT_IN(lightness);\n BUILT_IN(adjust_hue);\n BUILT_IN(lighten);\n BUILT_IN(darken);\n BUILT_IN(saturate);\n BUILT_IN(desaturate);\n BUILT_IN(grayscale);\n BUILT_IN(complement);\n BUILT_IN(invert);\n BUILT_IN(alpha);\n BUILT_IN(opacify);\n BUILT_IN(transparentize);\n BUILT_IN(adjust_color);\n BUILT_IN(scale_color);\n BUILT_IN(change_color);\n BUILT_IN(ie_hex_str);\n BUILT_IN(sass_unquote);\n BUILT_IN(sass_quote);\n BUILT_IN(str_length);\n BUILT_IN(str_insert);\n BUILT_IN(str_index);\n BUILT_IN(str_slice);\n BUILT_IN(to_upper_case);\n BUILT_IN(to_lower_case);\n BUILT_IN(percentage);\n BUILT_IN(round);\n BUILT_IN(ceil);\n BUILT_IN(floor);\n BUILT_IN(abs);\n BUILT_IN(min);\n BUILT_IN(max);\n BUILT_IN(inspect);\n BUILT_IN(random);\n BUILT_IN(length);\n BUILT_IN(nth);\n BUILT_IN(index);\n BUILT_IN(join);\n BUILT_IN(append);\n BUILT_IN(zip);\n BUILT_IN(list_separator);\n BUILT_IN(type_of);\n BUILT_IN(unit);\n BUILT_IN(unitless);\n BUILT_IN(comparable);\n BUILT_IN(variable_exists);\n BUILT_IN(global_variable_exists);\n BUILT_IN(function_exists);\n BUILT_IN(mixin_exists);\n BUILT_IN(feature_exists);\n BUILT_IN(call);\n BUILT_IN(sass_not);\n BUILT_IN(sass_if);\n BUILT_IN(map_get);\n BUILT_IN(map_merge);\n BUILT_IN(map_remove);\n BUILT_IN(map_keys);\n BUILT_IN(map_values);\n BUILT_IN(map_has_key);\n BUILT_IN(keywords);\n BUILT_IN(set_nth);\n BUILT_IN(unique_id);\n BUILT_IN(selector_nest);\n BUILT_IN(selector_append);\n BUILT_IN(selector_extend);\n BUILT_IN(selector_replace);\n BUILT_IN(selector_unify);\n BUILT_IN(is_superselector);\n 
BUILT_IN(simple_selectors);\n BUILT_IN(selector_parse);\n BUILT_IN(is_bracketed);\n BUILT_IN(content_exists);\n BUILT_IN(get_function);\n }\n}\n\n#endif\n\n// Path: src/libsass/src/sass_functions.hpp\n#ifndef SASS_SASS_FUNCTIONS_H\n#define SASS_SASS_FUNCTIONS_H\n\n#include \"sass.h\"\n#include \"environment.hpp\"\n#include \"functions.hpp\"\n\n// Struct to hold custom function callback\nstruct Sass_Function {\n char* signature;\n Sass_Function_Fn function;\n void* cookie;\n};\n\n// External import entry\nstruct Sass_Import {\n char* imp_path; // path as found in the import statement\n char *abs_path; // path after importer has resolved it\n char* source;\n char* srcmap;\n // error handling\n char* error;\n size_t line;\n size_t column;\n};\n\n// External environments\nstruct Sass_Env {\n // links to parent frames\n Sass::Env* frame;\n};\n\n// External call entry\nstruct Sass_Callee {\n const char* name;\n const char* path;\n size_t line;\n size_t column;\n enum Sass_Callee_Type type;\n struct Sass_Env env;\n};\n\n// Struct to hold importer callback\nstruct Sass_Importer {\n Sass_Importer_Fn importer;\n double priority;\n void* cookie;\n};\n\n#endif\n// Path: src/libsass/src/sass_functions.cpp\n#include \"sass.hpp\"\n#include \n#include \"util.hpp\"\n#include \"context.hpp\"\n#include \"values.hpp\"\n#include \"sass/functions.h\"\n#include \"sass_functions.hpp\"\n\nextern \"C\" {\n using namespace Sass;\n\n Sass_Function_List ADDCALL sass_make_function_list(size_t length)\n {\n return (Sass_Function_List) calloc(length + 1, sizeof(Sass_Function_Entry));\n }\n\n Sass_Function_Entry ADDCALL sass_make_function(const char* signature, Sass_Function_Fn function, void* cookie)\n {\n Sass_Function_Entry cb = (Sass_Function_Entry) calloc(1, sizeof(Sass_Function));\n if (cb == 0) return 0;\n cb->signature = sass_copy_c_string(signature);\n cb->function = function;\n cb->cookie = cookie;\n return cb;\n }\n\n void ADDCALL sass_delete_function(Sass_Function_Entry entry)\n {\n free(entry->signature);\n free(entry);\n }\n\n // Deallocator for the allocated memory\n void ADDCALL sass_delete_function_list(Sass_Function_List list)\n {\n Sass_Function_List it = list;\n if (list == 0) return;\n while(*list) {\n sass_delete_function(*list);\n ++list;\n }\n free(it);\n }\n\n // Setters and getters for callbacks on function lists\n Sass_Function_Entry ADDCALL sass_function_get_list_entry(Sass_Function_List list, size_t pos) { return list[pos]; }\n void sass_function_set_list_entry(Sass_Function_List list, size_t pos, Sass_Function_Entry cb) { list[pos] = cb; }\n\n const char* ADDCALL sass_function_get_signature(Sass_Function_Entry cb) { return cb->signature; }\n Sass_Function_Fn ADDCALL sass_function_get_function(Sass_Function_Entry cb) { return cb->function; }\n void* ADDCALL sass_function_get_cookie(Sass_Function_Entry cb) { return cb->cookie; }\n\n Sass_Importer_Entry ADDCALL sass_make_importer(Sass_Importer_Fn importer, double priority, void* cookie)\n {\n Sass_Importer_Entry cb = (Sass_Importer_Entry) calloc(1, sizeof(Sass_Importer));\n if (cb == 0) return 0;\n cb->importer = importer;\n cb->priority = priority;\n cb->cookie = cookie;\n return cb;\n }\n\n Sass_Importer_Fn ADDCALL sass_importer_get_function(Sass_Importer_Entry cb) { return cb->importer; }\n double ADDCALL sass_importer_get_priority (Sass_Importer_Entry cb) { return cb->priority; }\n void* ADDCALL sass_importer_get_cookie(Sass_Importer_Entry cb) { return cb->cookie; }\n\n // Just in case we have some stray import structs\n void ADDCALL 
sass_delete_importer (Sass_Importer_Entry cb)\n {\n free(cb);\n }\n\n // Creator for sass custom importer function list\n Sass_Importer_List ADDCALL sass_make_importer_list(size_t length)\n {\n return (Sass_Importer_List) calloc(length + 1, sizeof(Sass_Importer_Entry));\n }\n\n // Deallocator for the allocated memory\n void ADDCALL sass_delete_importer_list(Sass_Importer_List list)\n {\n Sass_Importer_List it = list;\n if (list == 0) return;\n while(*list) {\n sass_delete_importer(*list);\n ++list;\n }\n free(it);\n }\n\n Sass_Importer_Entry ADDCALL sass_importer_get_list_entry(Sass_Importer_List list, size_t idx) { return list[idx]; }\n void ADDCALL sass_importer_set_list_entry(Sass_Importer_List list, size_t idx, Sass_Importer_Entry cb) { list[idx] = cb; }\n\n // Creator for sass custom importer return argument list\n Sass_Import_List ADDCALL sass_make_import_list(size_t length)\n {\n return (Sass_Import**) calloc(length + 1, sizeof(Sass_Import*));\n }\n\n // Creator for a single import entry returned by the custom importer inside the list\n // We take ownership of the memory for source and srcmap (freed when context is destroyd)\n Sass_Import_Entry ADDCALL sass_make_import(const char* imp_path, const char* abs_path, char* source, char* srcmap)\n {\n Sass_Import* v = (Sass_Import*) calloc(1, sizeof(Sass_Import));\n if (v == 0) return 0;\n v->imp_path = imp_path ? sass_copy_c_string(imp_path) : 0;\n v->abs_path = abs_path ? sass_copy_c_string(abs_path) : 0;\n v->source = source;\n v->srcmap = srcmap;\n v->error = 0;\n v->line = -1;\n v->column = -1;\n return v;\n }\n\n // Older style, but somehow still valid - keep around or deprecate?\n Sass_Import_Entry ADDCALL sass_make_import_entry(const char* path, char* source, char* srcmap)\n {\n return sass_make_import(path, path, source, srcmap);\n }\n\n // Upgrade a normal import entry to throw an error (original path can be re-used by error reporting)\n Sass_Import_Entry ADDCALL sass_import_set_error(Sass_Import_Entry import, const char* error, size_t line, size_t col)\n {\n if (import == 0) return 0;\n if (import->error) free(import->error);\n import->error = error ? sass_copy_c_string(error) : 0;\n import->line = line ? line : -1;\n import->column = col ? 
col : -1;\n return import;\n }\n\n // Setters and getters for entries on the import list\n void ADDCALL sass_import_set_list_entry(Sass_Import_List list, size_t idx, Sass_Import_Entry entry) { list[idx] = entry; }\n Sass_Import_Entry ADDCALL sass_import_get_list_entry(Sass_Import_List list, size_t idx) { return list[idx]; }\n\n // Deallocator for the allocated memory\n void ADDCALL sass_delete_import_list(Sass_Import_List list)\n {\n Sass_Import_List it = list;\n if (list == 0) return;\n while(*list) {\n sass_delete_import(*list);\n ++list;\n }\n free(it);\n }\n\n // Just in case we have some stray import structs\n void ADDCALL sass_delete_import(Sass_Import_Entry import)\n {\n free(import->imp_path);\n free(import->abs_path);\n free(import->source);\n free(import->srcmap);\n free(import->error);\n free(import);\n }\n\n // Getter for callee entry\n const char* ADDCALL sass_callee_get_name(Sass_Callee_Entry entry) { return entry->name; }\n const char* ADDCALL sass_callee_get_path(Sass_Callee_Entry entry) { return entry->path; }\n size_t ADDCALL sass_callee_get_line(Sass_Callee_Entry entry) { return entry->line; }\n size_t ADDCALL sass_callee_get_column(Sass_Callee_Entry entry) { return entry->column; }\n enum Sass_Callee_Type ADDCALL sass_callee_get_type(Sass_Callee_Entry entry) { return entry->type; }\n Sass_Env_Frame ADDCALL sass_callee_get_env (Sass_Callee_Entry entry) { return &entry->env; }\n\n // Getters and Setters for environments (lexical, local and global)\n union Sass_Value* ADDCALL sass_env_get_lexical (Sass_Env_Frame env, const char* name) {\n Expression_Ptr ex = Cast((*env->frame)[name]);\n return ex != NULL ? ast_node_to_sass_value(ex) : NULL;\n }\n void ADDCALL sass_env_set_lexical (Sass_Env_Frame env, const char* name, union Sass_Value* val) {\n (*env->frame)[name] = sass_value_to_ast_node(val);\n }\n union Sass_Value* ADDCALL sass_env_get_local (Sass_Env_Frame env, const char* name) {\n Expression_Ptr ex = Cast(env->frame->get_local(name));\n return ex != NULL ? ast_node_to_sass_value(ex) : NULL;\n }\n void ADDCALL sass_env_set_local (Sass_Env_Frame env, const char* name, union Sass_Value* val) {\n env->frame->set_local(name, sass_value_to_ast_node(val));\n }\n union Sass_Value* ADDCALL sass_env_get_global (Sass_Env_Frame env, const char* name) {\n Expression_Ptr ex = Cast(env->frame->get_global(name));\n return ex != NULL ? 
ast_node_to_sass_value(ex) : NULL;\n }\n void ADDCALL sass_env_set_global (Sass_Env_Frame env, const char* name, union Sass_Value* val) {\n env->frame->set_global(name, sass_value_to_ast_node(val));\n }\n\n // Getter for import entry\n const char* ADDCALL sass_import_get_imp_path(Sass_Import_Entry entry) { return entry->imp_path; }\n const char* ADDCALL sass_import_get_abs_path(Sass_Import_Entry entry) { return entry->abs_path; }\n const char* ADDCALL sass_import_get_source(Sass_Import_Entry entry) { return entry->source; }\n const char* ADDCALL sass_import_get_srcmap(Sass_Import_Entry entry) { return entry->srcmap; }\n\n // Getter for import error entry\n size_t ADDCALL sass_import_get_error_line(Sass_Import_Entry entry) { return entry->line; }\n size_t ADDCALL sass_import_get_error_column(Sass_Import_Entry entry) { return entry->column; }\n const char* ADDCALL sass_import_get_error_message(Sass_Import_Entry entry) { return entry->error; }\n\n // Explicit functions to take ownership of the memory\n // Resets our own property since we do not know if it is still alive\n char* ADDCALL sass_import_take_source(Sass_Import_Entry entry) { char* ptr = entry->source; entry->source = 0; return ptr; }\n char* ADDCALL sass_import_take_srcmap(Sass_Import_Entry entry) { char* ptr = entry->srcmap; entry->srcmap = 0; return ptr; }\n\n}\n\n// Path: src/libsass/src/subset_map.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"subset_map.hpp\"\n\nnamespace Sass {\n\n void Subset_Map::put(const Compound_Selector_Obj& sel, const SubSetMapPair& value)\n {\n if (sel->empty()) throw std::runtime_error(\"internal error: subset map keys may not be empty\");\n size_t index = values_.size();\n values_.push_back(value);\n for (size_t i = 0, S = sel->length(); i < S; ++i)\n {\n hash_[(*sel)[i]].push_back(std::make_pair(sel, index));\n }\n }\n\n std::vector Subset_Map::get_kv(const Compound_Selector_Obj& sel)\n {\n SimpleSelectorDict dict(sel->begin(), sel->end()); // XXX Set\n std::vector indices;\n for (size_t i = 0, S = sel->length(); i < S; ++i) {\n if (!hash_.count((*sel)[i])) {\n continue;\n }\n const std::vector >& subsets = hash_[(*sel)[i]];\n for (const std::pair& item : subsets) {\n bool include = true;\n for (const Simple_Selector_Obj& it : item.first->elements()) {\n auto found = dict.find(it);\n if (found == dict.end()) {\n include = false;\n break;\n }\n }\n if (include) indices.push_back(item.second);\n }\n }\n sort(indices.begin(), indices.end());\n std::vector::iterator indices_end = unique(indices.begin(), indices.end());\n indices.resize(distance(indices.begin(), indices_end));\n\n std::vector results;\n for (size_t i = 0, S = indices.size(); i < S; ++i) {\n results.push_back(values_[indices[i]]);\n }\n return results;\n }\n\n std::vector Subset_Map::get_v(const Compound_Selector_Obj& sel)\n {\n return get_kv(sel);\n }\n\n}\n// Path: src/libsass/src/to_value.hpp\n#ifndef SASS_TO_VALUE_H\n#define SASS_TO_VALUE_H\n\n#include \"operation.hpp\"\n#include \"sass/values.h\"\n#include \"ast_fwd_decl.hpp\"\n\nnamespace Sass {\n\n class To_Value : public Operation_CRTP {\n\n Value_Ptr fallback_impl(AST_Node_Ptr n);\n\n private:\n\n Context& ctx;\n\n public:\n\n To_Value(Context& ctx)\n : ctx(ctx)\n { }\n ~To_Value() { }\n using Operation::operator();\n\n Value_Ptr operator()(Argument_Ptr);\n Value_Ptr operator()(Boolean_Ptr);\n Value_Ptr operator()(Number_Ptr);\n Value_Ptr operator()(Color_Ptr);\n Value_Ptr operator()(String_Constant_Ptr);\n Value_Ptr operator()(String_Quoted_Ptr);\n Value_Ptr 
operator()(Custom_Warning_Ptr);\n Value_Ptr operator()(Custom_Error_Ptr);\n Value_Ptr operator()(List_Ptr);\n Value_Ptr operator()(Map_Ptr);\n Value_Ptr operator()(Null_Ptr);\n Value_Ptr operator()(Function_Ptr);\n\n // convert to string via `To_String`\n Value_Ptr operator()(Selector_List_Ptr);\n Value_Ptr operator()(Binary_Expression_Ptr);\n\n // fallback throws error\n template \n Value_Ptr fallback(U x) { return fallback_impl(x); }\n };\n\n}\n\n#endif\n\n// Path: src/libsass/src/to_value.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"to_value.hpp\"\n\nnamespace Sass {\n\n Value_Ptr To_Value::fallback_impl(AST_Node_Ptr n)\n {\n // throw a runtime error if this happens\n // we want a well defined set of possible nodes\n throw std::runtime_error(\"invalid node for to_value\");\n }\n\n // Custom_Error is a valid value\n Value_Ptr To_Value::operator()(Custom_Error_Ptr e)\n {\n return e;\n }\n\n // Custom_Warning is a valid value\n Value_Ptr To_Value::operator()(Custom_Warning_Ptr w)\n {\n return w;\n }\n\n // Boolean is a valid value\n Value_Ptr To_Value::operator()(Boolean_Ptr b)\n {\n return b;\n }\n\n // Number is a valid value\n Value_Ptr To_Value::operator()(Number_Ptr n)\n {\n return n;\n }\n\n // Color is a valid value\n Value_Ptr To_Value::operator()(Color_Ptr c)\n {\n return c;\n }\n\n // String_Constant is a valid value\n Value_Ptr To_Value::operator()(String_Constant_Ptr s)\n {\n return s;\n }\n\n // String_Quoted is a valid value\n Value_Ptr To_Value::operator()(String_Quoted_Ptr s)\n {\n return s;\n }\n\n // List is a valid value\n Value_Ptr To_Value::operator()(List_Ptr l)\n {\n List_Obj ll = SASS_MEMORY_NEW(List,\n l->pstate(),\n l->length(),\n l->separator(),\n l->is_arglist(),\n l->is_bracketed());\n for (size_t i = 0, L = l->length(); i < L; ++i) {\n ll->append((*l)[i]->perform(this));\n }\n return ll.detach();\n }\n\n // Map is a valid value\n Value_Ptr To_Value::operator()(Map_Ptr m)\n {\n return m;\n }\n\n // Null is a valid value\n Value_Ptr To_Value::operator()(Null_Ptr n)\n {\n return n;\n }\n\n // Function is a valid value\n Value_Ptr To_Value::operator()(Function_Ptr n)\n {\n return n;\n }\n\n // Argument returns its value\n Value_Ptr To_Value::operator()(Argument_Ptr arg)\n {\n if (!arg->name().empty()) return 0;\n return arg->value()->perform(this);\n }\n\n // Selector_List is converted to a string\n Value_Ptr To_Value::operator()(Selector_List_Ptr s)\n {\n return SASS_MEMORY_NEW(String_Quoted,\n s->pstate(),\n s->to_string(ctx.c_options));\n }\n\n // Binary_Expression is converted to a string\n Value_Ptr To_Value::operator()(Binary_Expression_Ptr s)\n {\n return SASS_MEMORY_NEW(String_Quoted,\n s->pstate(),\n s->to_string(ctx.c_options));\n }\n\n};\n\n// Path: src/libsass/src/json.hpp\n/*\n Copyright (C) 2011 Joseph A. 
Adams (joeyadams3.14159@gmail.com)\n All rights reserved.\n\n Permission is hereby granted, free of charge, to any person obtaining a copy\n of this software and associated documentation files (the \"Software\"), to deal\n in the Software without restriction, including without limitation the rights\n to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n copies of the Software, and to permit persons to whom the Software is\n furnished to do so, subject to the following conditions:\n\n The above copyright notice and this permission notice shall be included in\n all copies or substantial portions of the Software.\n\n THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n THE SOFTWARE.\n*/\n\n#ifndef CCAN_JSON_H\n#define CCAN_JSON_H\n\n#include \n#include \n\ntypedef enum {\n JSON_NULL,\n JSON_BOOL,\n JSON_STRING,\n JSON_NUMBER,\n JSON_ARRAY,\n JSON_OBJECT,\n} JsonTag;\n\ntypedef struct JsonNode JsonNode;\n\nstruct JsonNode\n{\n /* only if parent is an object or array (NULL otherwise) */\n JsonNode *parent;\n JsonNode *prev, *next;\n\n /* only if parent is an object (NULL otherwise) */\n char *key; /* Must be valid UTF-8. */\n\n JsonTag tag;\n union {\n /* JSON_BOOL */\n bool bool_;\n\n /* JSON_STRING */\n char *string_; /* Must be valid UTF-8. */\n\n /* JSON_NUMBER */\n double number_;\n\n /* JSON_ARRAY */\n /* JSON_OBJECT */\n struct {\n JsonNode *head, *tail;\n } children;\n };\n};\n\n/*** Encoding, decoding, and validation ***/\n\nJsonNode *json_decode (const char *json);\nchar *json_encode (const JsonNode *node);\nchar *json_encode_string (const char *str);\nchar *json_stringify (const JsonNode *node, const char *space);\nvoid json_delete (JsonNode *node);\n\nbool json_validate (const char *json);\n\n/*** Lookup and traversal ***/\n\nJsonNode *json_find_element (JsonNode *array, int index);\nJsonNode *json_find_member (JsonNode *object, const char *key);\n\nJsonNode *json_first_child (const JsonNode *node);\n\n#define json_foreach(i, object_or_array) \\\n for ((i) = json_first_child(object_or_array); \\\n (i) != NULL; \\\n (i) = (i)->next)\n\n/*** Construction and manipulation ***/\n\nJsonNode *json_mknull(void);\nJsonNode *json_mkbool(bool b);\nJsonNode *json_mkstring(const char *s);\nJsonNode *json_mknumber(double n);\nJsonNode *json_mkarray(void);\nJsonNode *json_mkobject(void);\n\nvoid json_append_element(JsonNode *array, JsonNode *element);\nvoid json_prepend_element(JsonNode *array, JsonNode *element);\nvoid json_append_member(JsonNode *object, const char *key, JsonNode *value);\nvoid json_prepend_member(JsonNode *object, const char *key, JsonNode *value);\n\nvoid json_remove_from_parent(JsonNode *node);\n\n/*** Debugging ***/\n\n/*\n * Look for structure and encoding problems in a JsonNode or its descendents.\n *\n * If a problem is detected, return false, writing a description of the problem\n * to errmsg (unless errmsg is NULL).\n */\nbool json_check(const JsonNode *node, char errmsg[256]);\n\n#endif\n\n// Path: src/libsass/src/sass_context.cpp\n#include \"sass.hpp\"\n#include \n#include \n#include \n#include \n#include \n\n#include \"sass.h\"\n#include 
\"ast.hpp\"\n#include \"file.hpp\"\n#include \"json.hpp\"\n#include \"util.hpp\"\n#include \"context.hpp\"\n#include \"sass_context.hpp\"\n#include \"sass_functions.hpp\"\n#include \"ast_fwd_decl.hpp\"\n#include \"error_handling.hpp\"\n\n#define LFEED \"\\n\"\n\n// C++ helper\nnamespace Sass {\n // see sass_copy_c_string(std::string str)\n static inline JsonNode* json_mkstream(const std::stringstream& stream)\n {\n // hold on to string on stack!\n std::string str(stream.str());\n return json_mkstring(str.c_str());\n }\n\n static int handle_error(Sass_Context* c_ctx) {\n try {\n throw;\n }\n catch (Exception::Base& e) {\n std::stringstream msg_stream;\n std::string cwd(Sass::File::get_cwd());\n std::string msg_prefix(e.errtype());\n bool got_newline = false;\n msg_stream << msg_prefix << \": \";\n const char* msg = e.what();\n while (msg && *msg) {\n if (*msg == '\\r') {\n got_newline = true;\n }\n else if (*msg == '\\n') {\n got_newline = true;\n }\n else if (got_newline) {\n msg_stream << std::string(msg_prefix.size() + 2, ' ');\n got_newline = false;\n }\n msg_stream << *msg;\n ++msg;\n }\n if (!got_newline) msg_stream << \"\\n\";\n\n if (e.traces.empty()) {\n // we normally should have some traces, still here as a fallback\n std::string rel_path(Sass::File::abs2rel(e.pstate.path, cwd, cwd));\n msg_stream << std::string(msg_prefix.size() + 2, ' ');\n msg_stream << \" on line \" << e.pstate.line + 1 << \" of \" << rel_path << \"\\n\";\n }\n else {\n std::string rel_path(Sass::File::abs2rel(e.pstate.path, cwd, cwd));\n msg_stream << traces_to_string(e.traces, \" \");\n }\n\n // now create the code trace (ToDo: maybe have util functions?)\n if (e.pstate.line != std::string::npos && e.pstate.column != std::string::npos) {\n size_t lines = e.pstate.line;\n const char* line_beg = e.pstate.src;\n // scan through src until target line\n // move line_beg pointer to line start\n while (line_beg && *line_beg && lines != 0) {\n if (*line_beg == '\\n') --lines;\n utf8::unchecked::next(line_beg); \n }\n const char* line_end = line_beg;\n // move line_end before next newline character\n while (line_end && *line_end && *line_end != '\\n') {\n if (*line_end == '\\n') break;\n if (*line_end == '\\r') break;\n utf8::unchecked::next(line_end); \n }\n if (line_end && *line_end != 0) ++ line_end;\n size_t line_len = line_end - line_beg;\n size_t move_in = 0; size_t shorten = 0;\n size_t left_chars = 42; size_t max_chars = 76;\n // reported excerpt should not exceed `max_chars` chars\n if (e.pstate.column > line_len) left_chars = e.pstate.column;\n if (e.pstate.column > left_chars) move_in = e.pstate.column - left_chars;\n if (line_len > max_chars + move_in) shorten = line_len - move_in - max_chars;\n utf8::advance(line_beg, move_in, line_end);\n utf8::retreat(line_end, shorten, line_beg);\n std::string sanitized; std::string marker(e.pstate.column - move_in, '-');\n utf8::replace_invalid(line_beg, line_end, std::back_inserter(sanitized));\n msg_stream << \">> \" << sanitized << \"\\n\";\n msg_stream << \" \" << marker << \"^\\n\";\n }\n\n JsonNode* json_err = json_mkobject();\n json_append_member(json_err, \"status\", json_mknumber(1));\n json_append_member(json_err, \"file\", json_mkstring(e.pstate.path));\n json_append_member(json_err, \"line\", json_mknumber((double)(e.pstate.line + 1)));\n json_append_member(json_err, \"column\", json_mknumber((double)(e.pstate.column + 1)));\n json_append_member(json_err, \"message\", json_mkstring(e.what()));\n json_append_member(json_err, \"formatted\", 
json_mkstream(msg_stream));\n try { c_ctx->error_json = json_stringify(json_err, \" \"); }\n catch (...) {}\n c_ctx->error_message = sass_copy_string(msg_stream.str());\n c_ctx->error_text = sass_copy_c_string(e.what());\n c_ctx->error_status = 1;\n c_ctx->error_file = sass_copy_c_string(e.pstate.path);\n c_ctx->error_line = e.pstate.line + 1;\n c_ctx->error_column = e.pstate.column + 1;\n c_ctx->error_src = e.pstate.src;\n c_ctx->output_string = 0;\n c_ctx->source_map_string = 0;\n json_delete(json_err);\n }\n catch (std::bad_alloc& ba) {\n std::stringstream msg_stream;\n JsonNode* json_err = json_mkobject();\n msg_stream << \"Unable to allocate memory: \" << ba.what() << std::endl;\n json_append_member(json_err, \"status\", json_mknumber(2));\n json_append_member(json_err, \"message\", json_mkstring(ba.what()));\n json_append_member(json_err, \"formatted\", json_mkstream(msg_stream));\n try { c_ctx->error_json = json_stringify(json_err, \" \"); }\n catch (...) {}\n c_ctx->error_message = sass_copy_string(msg_stream.str());\n c_ctx->error_text = sass_copy_c_string(ba.what());\n c_ctx->error_status = 2;\n c_ctx->output_string = 0;\n c_ctx->source_map_string = 0;\n json_delete(json_err);\n }\n catch (std::exception& e) {\n std::stringstream msg_stream;\n JsonNode* json_err = json_mkobject();\n msg_stream << \"Internal Error: \" << e.what() << std::endl;\n json_append_member(json_err, \"status\", json_mknumber(3));\n json_append_member(json_err, \"message\", json_mkstring(e.what()));\n json_append_member(json_err, \"formatted\", json_mkstream(msg_stream));\n try { c_ctx->error_json = json_stringify(json_err, \" \"); }\n catch (...) {}\n c_ctx->error_message = sass_copy_string(msg_stream.str());\n c_ctx->error_text = sass_copy_c_string(e.what());\n c_ctx->error_status = 3;\n c_ctx->output_string = 0;\n c_ctx->source_map_string = 0;\n json_delete(json_err);\n }\n catch (std::string& e) {\n std::stringstream msg_stream;\n JsonNode* json_err = json_mkobject();\n msg_stream << \"Internal Error: \" << e << std::endl;\n json_append_member(json_err, \"status\", json_mknumber(4));\n json_append_member(json_err, \"message\", json_mkstring(e.c_str()));\n json_append_member(json_err, \"formatted\", json_mkstream(msg_stream));\n try { c_ctx->error_json = json_stringify(json_err, \" \"); }\n catch (...) {}\n c_ctx->error_message = sass_copy_string(msg_stream.str());\n c_ctx->error_text = sass_copy_c_string(e.c_str());\n c_ctx->error_status = 4;\n c_ctx->output_string = 0;\n c_ctx->source_map_string = 0;\n json_delete(json_err);\n }\n catch (const char* e) {\n std::stringstream msg_stream;\n JsonNode* json_err = json_mkobject();\n msg_stream << \"Internal Error: \" << e << std::endl;\n json_append_member(json_err, \"status\", json_mknumber(4));\n json_append_member(json_err, \"message\", json_mkstring(e));\n json_append_member(json_err, \"formatted\", json_mkstream(msg_stream));\n try { c_ctx->error_json = json_stringify(json_err, \" \"); }\n catch (...) {}\n c_ctx->error_message = sass_copy_string(msg_stream.str());\n c_ctx->error_text = sass_copy_c_string(e);\n c_ctx->error_status = 4;\n c_ctx->output_string = 0;\n c_ctx->source_map_string = 0;\n json_delete(json_err);\n }\n catch (...) 
{\n std::stringstream msg_stream;\n JsonNode* json_err = json_mkobject();\n msg_stream << \"Unknown error occurred\" << std::endl;\n json_append_member(json_err, \"status\", json_mknumber(5));\n json_append_member(json_err, \"message\", json_mkstring(\"unknown\"));\n try { c_ctx->error_json = json_stringify(json_err, \" \"); }\n catch (...) {}\n c_ctx->error_message = sass_copy_string(msg_stream.str());\n c_ctx->error_text = sass_copy_c_string(\"unknown\");\n c_ctx->error_status = 5;\n c_ctx->output_string = 0;\n c_ctx->source_map_string = 0;\n json_delete(json_err);\n }\n return c_ctx->error_status;\n }\n\n // allow one error handler to throw another error\n // this can happen with invalid utf8 and json lib\n static int handle_errors(Sass_Context* c_ctx) {\n try { return handle_error(c_ctx); }\n catch (...) { return handle_error(c_ctx); }\n }\n\n static Block_Obj sass_parse_block(Sass_Compiler* compiler) throw()\n {\n\n // assert valid pointer\n if (compiler == 0) return 0;\n // The cpp context must be set by now\n Context* cpp_ctx = compiler->cpp_ctx;\n Sass_Context* c_ctx = compiler->c_ctx;\n // We will take care to wire up the rest\n compiler->cpp_ctx->c_compiler = compiler;\n compiler->state = SASS_COMPILER_PARSED;\n\n try {\n\n // get input/output path from options\n std::string input_path = safe_str(c_ctx->input_path);\n std::string output_path = safe_str(c_ctx->output_path);\n\n // maybe skip some entries of included files\n // we do not include stdin for data contexts\n bool skip = c_ctx->type == SASS_CONTEXT_DATA;\n\n // dispatch parse call\n Block_Obj root(cpp_ctx->parse());\n // abort on errors\n if (!root) return 0;\n\n // skip all prefixed files? (ToDo: check srcmap)\n // IMO source-maps should point to headers already\n // therefore don't skip it for now. re-enable or\n // remove completely once this is tested\n size_t headers = cpp_ctx->head_imports;\n\n // copy the included files on to the context (dont forget to free later)\n if (copy_strings(cpp_ctx->get_included_files(skip, headers), &c_ctx->included_files) == NULL)\n throw(std::bad_alloc());\n\n // return parsed block\n return root;\n\n }\n // pass errors to generic error handler\n catch (...) { handle_errors(c_ctx); }\n\n // error\n return 0;\n\n }\n\n}\n\nextern \"C\" {\n using namespace Sass;\n\n static void sass_clear_options (struct Sass_Options* options);\n static void sass_reset_options (struct Sass_Options* options);\n static void copy_options(struct Sass_Options* to, struct Sass_Options* from) {\n // do not overwrite ourself\n if (to == from) return;\n // free assigned memory\n sass_clear_options(to);\n // move memory\n *to = *from;\n // Reset pointers on source\n sass_reset_options(from);\n }\n\n #define IMPLEMENT_SASS_OPTION_ACCESSOR(type, option) \\\n type ADDCALL sass_option_get_##option (struct Sass_Options* options) { return options->option; } \\\n void ADDCALL sass_option_set_##option (struct Sass_Options* options, type option) { options->option = option; }\n #define IMPLEMENT_SASS_OPTION_STRING_GETTER(type, option, def) \\\n type ADDCALL sass_option_get_##option (struct Sass_Options* options) { return safe_str(options->option, def); }\n #define IMPLEMENT_SASS_OPTION_STRING_SETTER(type, option, def) \\\n void ADDCALL sass_option_set_##option (struct Sass_Options* options, type option) \\\n { free(options->option); options->option = option || def ? sass_copy_c_string(option ? 
option : def) : 0; }\n #define IMPLEMENT_SASS_OPTION_STRING_ACCESSOR(type, option, def) \\\n IMPLEMENT_SASS_OPTION_STRING_GETTER(type, option, def) \\\n IMPLEMENT_SASS_OPTION_STRING_SETTER(type, option, def)\n\n #define IMPLEMENT_SASS_CONTEXT_GETTER(type, option) \\\n type ADDCALL sass_context_get_##option (struct Sass_Context* ctx) { return ctx->option; }\n #define IMPLEMENT_SASS_CONTEXT_TAKER(type, option) \\\n type sass_context_take_##option (struct Sass_Context* ctx) \\\n { type foo = ctx->option; ctx->option = 0; return foo; }\n\n\n // generic compilation function (not exported, use file/data compile instead)\n static Sass_Compiler* sass_prepare_context (Sass_Context* c_ctx, Context* cpp_ctx) throw()\n {\n try {\n // register our custom functions\n if (c_ctx->c_functions) {\n auto this_func_data = c_ctx->c_functions;\n while (this_func_data && *this_func_data) {\n cpp_ctx->add_c_function(*this_func_data);\n ++this_func_data;\n }\n }\n\n // register our custom headers\n if (c_ctx->c_headers) {\n auto this_head_data = c_ctx->c_headers;\n while (this_head_data && *this_head_data) {\n cpp_ctx->add_c_header(*this_head_data);\n ++this_head_data;\n }\n }\n\n // register our custom importers\n if (c_ctx->c_importers) {\n auto this_imp_data = c_ctx->c_importers;\n while (this_imp_data && *this_imp_data) {\n cpp_ctx->add_c_importer(*this_imp_data);\n ++this_imp_data;\n }\n }\n\n // reset error status\n c_ctx->error_json = 0;\n c_ctx->error_text = 0;\n c_ctx->error_message = 0;\n c_ctx->error_status = 0;\n // reset error position\n c_ctx->error_src = 0;\n c_ctx->error_file = 0;\n c_ctx->error_line = std::string::npos;\n c_ctx->error_column = std::string::npos;\n\n // allocate a new compiler instance\n void* ctxmem = calloc(1, sizeof(struct Sass_Compiler));\n if (ctxmem == 0) { std::cerr << \"Error allocating memory for context\" << std::endl; return 0; }\n Sass_Compiler* compiler = (struct Sass_Compiler*) ctxmem;\n compiler->state = SASS_COMPILER_CREATED;\n\n // store in sass compiler\n compiler->c_ctx = c_ctx;\n compiler->cpp_ctx = cpp_ctx;\n cpp_ctx->c_compiler = compiler;\n\n // use to parse block\n return compiler;\n\n }\n // pass errors to generic error handler\n catch (...) { handle_errors(c_ctx); }\n\n // error\n return 0;\n\n }\n\n // generic compilation function (not exported, use file/data compile instead)\n static int sass_compile_context (Sass_Context* c_ctx, Context* cpp_ctx)\n {\n\n // prepare sass compiler with context and options\n...\n", "needle_token_start": 674, "needle_token_end": 931, "code_context_ntokens": 16165, "output": ["```cpp\ninline bool compatibility(const char* their_version)\n{\n // const char* their_version = \"3.1.2\";\n // first check if anyone has an unknown version\n const char* our_version = libsass_version();\n if (!strcmp(their_version, \"[na]\")) return false;\n if (!strcmp(our_version, \"[na]\")) return false;\n\n // find the position of the second dot\n size_t pos = std::string(our_version).find('.', 0);\n if (pos != std::string::npos) pos = std::string(our_version).find('.', pos + 1);\n\n // if we do not have two dots we fallback to compare complete string\n if (pos == std::string::npos) { return strcmp(their_version, our_version) ? 0 : 1; }\n // otherwise only compare up to the second dot (major versions)\n else { return strncmp(their_version, our_version, pos) ? 
0 : 1; }\n\n}\n```"]} +{"repo": "sass/node-sass", "name": "Sass_Inspect_Options", "language": "cpp", "path": "src/libsass/src/sass.hpp", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this structure is to configure the output characteristics of CSS code generated from SASS/SCSS files, including the style format, numerical precision, and color compression settings.\n2. **Input**: This structure takes three parameters: a style format for the output CSS (from a predefined set of styles), a numerical precision for fractional values, and a boolean indicating whether to compress color values in selectors.\n3. **Output**: There is no direct output from this structure itself; rather, it influences the output of the CSS generation process by specifying how the CSS should be formatted and how precise numerical values should be.\n4. **Procedure**: Upon initialization, the structure is configured with default values or with values provided by the user. These settings are then used during the CSS generation process to determine the output style, precision of numerical values, and whether colors in selectors should be compressed.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/sass_types/string.cpp\n#include \n#include \"string.h\"\n#include \"../create_string.h\"\n\nnamespace SassTypes\n{\n String::String(Sass_Value* v) : SassValueWrapper(v) {}\n\n Sass_Value* String::construct(const std::vector> raw_val, Sass_Value **out) {\n char const* value = \"\";\n\n if (raw_val.size() >= 1) {\n if (!raw_val[0]->IsString()) {\n return fail(\"Argument should be a string.\", out);\n }\n\n value = create_string(raw_val[0]);\n...\n// Path: src/sass_types/null.cpp\n#include \n#include \"null.h\"\n\nnamespace SassTypes\n{\n Nan::Persistent Null::constructor;\n bool Null::constructor_locked = false;\n\n Null::Null() {\n value = sass_make_null();\n }\n\n Null& Null::get_singleton() {\n static Null singleton_instance;\n return singleton_instance;\n }\n\n v8::Local Null::get_constructor() {\n Nan::EscapableHandleScope scope;\n v8::Local conslocal;\n if (constructor.IsEmpty()) {\n v8::Local tpl = Nan::New(New);\n\n tpl->SetClassName(Nan::New(\"SassNull\").ToLocalChecked());\n tpl->InstanceTemplate()->SetInternalFieldCount(1);\n\n conslocal = Nan::GetFunction(tpl).ToLocalChecked();\n constructor.Reset(conslocal);\n\n get_singleton().js_object.Reset(Nan::NewInstance(conslocal).ToLocalChecked());\n Nan::SetInternalFieldPointer(Nan::New(get_singleton().js_object), 0, &get_singleton());\n Nan::Set(conslocal, Nan::New(\"NULL\").ToLocalChecked(), Nan::New(get_singleton().js_object));\n\n constructor_locked = true;\n } else {\n conslocal = Nan::New(constructor);\n }\n\n return scope.Escape(conslocal);\n }\n\n v8::Local Null::get_js_object() {\n return Nan::New(this->js_object);\n }\n\n NAN_METHOD(Null::New) {\n\n if (info.IsConstructCall()) {\n if (constructor_locked) {\n return Nan::ThrowTypeError(\"Cannot instantiate SassNull\");\n }\n }\n else {\n info.GetReturnValue().Set(get_singleton().get_js_object());\n }\n }\n}\n\n// Path: src/sass_types/boolean.cpp\n#include \n#include \"boolean.h\"\n\nnamespace SassTypes\n{\n Nan::Persistent Boolean::constructor;\n bool Boolean::constructor_locked = false;\n\n Boolean::Boolean(bool _value) {\n value = 
sass_make_boolean(_value);\n }\n\n Boolean& Boolean::get_singleton(bool v) {\n static Boolean instance_false(false), instance_true(true);\n return v ? instance_true : instance_false;\n }\n\n v8::Local Boolean::get_constructor() {\n Nan::EscapableHandleScope scope;\n v8::Local conslocal;\n if (constructor.IsEmpty()) {\n v8::Local tpl = Nan::New(New);\n\n tpl->SetClassName(Nan::New(\"SassBoolean\").ToLocalChecked());\n tpl->InstanceTemplate()->SetInternalFieldCount(1);\n Nan::SetPrototypeTemplate(tpl, \"getValue\", Nan::New(GetValue));\n\n conslocal = Nan::GetFunction(tpl).ToLocalChecked();\n constructor.Reset(conslocal);\n\n get_singleton(false).js_object.Reset(Nan::NewInstance(conslocal).ToLocalChecked());\n Nan::SetInternalFieldPointer(Nan::New(get_singleton(false).js_object), 0, &get_singleton(false));\n Nan::Set(conslocal, Nan::New(\"FALSE\").ToLocalChecked(), Nan::New(get_singleton(false).js_object));\n\n get_singleton(true).js_object.Reset(Nan::NewInstance(conslocal).ToLocalChecked());\n Nan::SetInternalFieldPointer(Nan::New(get_singleton(true).js_object), 0, &get_singleton(true));\n Nan::Set(conslocal, Nan::New(\"TRUE\").ToLocalChecked(), Nan::New(get_singleton(true).js_object));\n\n constructor_locked = true;\n } else {\n conslocal = Nan::New(constructor);\n }\n\n return scope.Escape(conslocal);\n }\n\n v8::Local Boolean::get_js_object() {\n return Nan::New(this->js_object);\n }\n\n v8::Local Boolean::get_js_boolean() {\n return sass_boolean_get_value(this->value) ? Nan::True() : Nan::False();\n }\n\n NAN_METHOD(Boolean::New) {\n if (info.IsConstructCall()) {\n if (constructor_locked) {\n return Nan::ThrowTypeError(\"Cannot instantiate SassBoolean\");\n }\n }\n else {\n if (info.Length() != 1 || !info[0]->IsBoolean()) {\n return Nan::ThrowTypeError(\"Expected one boolean argument\");\n }\n\n info.GetReturnValue().Set(get_singleton(Nan::To(info[0]).FromJust()).get_js_object());\n }\n }\n\n NAN_METHOD(Boolean::GetValue) {\n info.GetReturnValue().Set(Boolean::Unwrap(info.This())->get_js_boolean());\n }\n}\n\n// Path: src/libsass/src/sass.hpp\n// must be the first include in all compile units\n#ifndef SASS_SASS_H\n#define SASS_SASS_H\n\n// undefine extensions macro to tell sys includes\n// that we do not want any macros to be exported\n// mainly fixes an issue on SmartOS (SEC macro)\n#undef __EXTENSIONS__\n\n#ifdef _MSC_VER\n#pragma warning(disable : 4005)\n#endif\n\n// aplies to MSVC and MinGW\n#ifdef _WIN32\n// we do not want the ERROR macro\n# define NOGDI\n// we do not want the min/max macro\n# define NOMINMAX\n// we do not want the IN/OUT macro\n# define _NO_W32_PSEUDO_MODIFIERS\n#endif\n\n\n// should we be case insensitive\n// when dealing with files or paths\n#ifndef FS_CASE_SENSITIVE\n# ifdef _WIN32\n# define FS_CASE_SENSITIVE 0\n# else\n# define FS_CASE_SENSITIVE 1\n# endif\n#endif\n\n// path separation char\n#ifndef PATH_SEP\n# ifdef _WIN32\n# define PATH_SEP ';'\n# else\n# define PATH_SEP ':'\n# endif\n#endif\n\n\n// include C-API header\n#include \"sass/base.h\"\n\n// For C++ helper\n#include \n\n// output behaviours\nnamespace Sass {\n\n // create some C++ aliases for the most used options\n const static Sass_Output_Style NESTED = SASS_STYLE_NESTED;\n const static Sass_Output_Style COMPACT = SASS_STYLE_COMPACT;\n const static Sass_Output_Style EXPANDED = SASS_STYLE_EXPANDED;\n const static Sass_Output_Style COMPRESSED = SASS_STYLE_COMPRESSED;\n // only used internal to trigger ruby inspect behavior\n const static Sass_Output_Style INSPECT = SASS_STYLE_INSPECT;\n const 
static Sass_Output_Style TO_SASS = SASS_STYLE_TO_SASS;\n\n // helper to aid dreaded MSVC debug mode\n // see implementation for more details\n char* sass_copy_string(std::string str);\n\n}\n\n// input behaviours\nenum Sass_Input_Style {\n SASS_CONTEXT_NULL,\n SASS_CONTEXT_FILE,\n SASS_CONTEXT_DATA,\n SASS_CONTEXT_FOLDER\n};\n\n// simple linked list\nstruct string_list {\n string_list* next;\n char* string;\n};\n\n// sass config options structure\nstruct Sass_Inspect_Options {\n\n // Output style for the generated css code\n // A value from above SASS_STYLE_* constants\n enum Sass_Output_Style output_style;\n\n // Precision for fractional numbers\n int precision;\n\n // Do not compress colors in selectors\n bool in_selector;\n\n // initialization list (constructor with defaults)\n \nSass_Inspect_Options(Sass_Output_Style style = Sass::NESTED,\n int precision = 5, bool in_selector = false)\n : output_style(style), precision(precision), in_selector(in_selector)\n { }\n\n};\n\n// sass config options structure\nstruct Sass_Output_Options : Sass_Inspect_Options {\n\n // String to be used for indentation\n const char* indent;\n // String to be used to for line feeds\n const char* linefeed;\n\n // Emit comments in the generated CSS indicating\n // the corresponding source line.\n bool source_comments;\n\n // initialization list (constructor with defaults)\n Sass_Output_Options(struct Sass_Inspect_Options opt,\n const char* indent = \" \",\n const char* linefeed = \"\\n\",\n bool source_comments = false)\n : Sass_Inspect_Options(opt),\n indent(indent), linefeed(linefeed),\n source_comments(source_comments)\n { }\n\n // initialization list (constructor with defaults)\n Sass_Output_Options(Sass_Output_Style style = Sass::NESTED,\n int precision = 5,\n const char* indent = \" \",\n const char* linefeed = \"\\n\",\n bool source_comments = false)\n : Sass_Inspect_Options(style, precision),\n indent(indent), linefeed(linefeed),\n source_comments(source_comments)\n { }\n\n};\n\n#endif\n\n// Path: src/libsass/src/memory/SharedPtr.hpp\n#ifndef SASS_MEMORY_SHARED_PTR_H\n#define SASS_MEMORY_SHARED_PTR_H\n\n#include \"sass/base.h\"\n\n#include \n\nnamespace Sass {\n\n class SharedPtr;\n\n ///////////////////////////////////////////////////////////////////////////////\n // Use macros for the allocation task, since overloading operator `new`\n // has been proven to be flaky under certain compilers (see comment below).\n ///////////////////////////////////////////////////////////////////////////////\n\n #ifdef DEBUG_SHARED_PTR\n\n #define SASS_MEMORY_NEW(Class, ...) \\\n ((Class*)(new Class(__VA_ARGS__))->trace(__FILE__, __LINE__)) \\\n\n #define SASS_MEMORY_COPY(obj) \\\n ((obj)->copy(__FILE__, __LINE__)) \\\n\n #define SASS_MEMORY_CLONE(obj) \\\n ((obj)->clone(__FILE__, __LINE__)) \\\n\n #else\n\n #define SASS_MEMORY_NEW(Class, ...) 
\\\n new Class(__VA_ARGS__) \\\n\n #define SASS_MEMORY_COPY(obj) \\\n ((obj)->copy()) \\\n\n #define SASS_MEMORY_CLONE(obj) \\\n ((obj)->clone()) \\\n\n #endif\n\n class SharedObj {\n protected:\n friend class SharedPtr;\n friend class Memory_Manager;\n #ifdef DEBUG_SHARED_PTR\n static std::vector all;\n std::string file;\n size_t line;\n #endif\n static bool taint;\n long refcounter;\n // long refcount;\n bool detached;\n #ifdef DEBUG_SHARED_PTR\n bool dbg;\n #endif\n public:\n #ifdef DEBUG_SHARED_PTR\n static void dumpMemLeaks();\n SharedObj* trace(std::string file, size_t line) {\n this->file = file;\n this->line = line;\n return this;\n }\n #endif\n SharedObj();\n #ifdef DEBUG_SHARED_PTR\n std::string getDbgFile() {\n return file;\n }\n size_t getDbgLine() {\n return line;\n }\n void setDbg(bool dbg) {\n this->dbg = dbg;\n }\n #endif\n static void setTaint(bool val) {\n taint = val;\n }\n virtual ~SharedObj();\n long getRefCount() {\n return refcounter;\n }\n };\n\n\n class SharedPtr {\n protected:\n SharedObj* node;\n protected:\n void decRefCount();\n void incRefCount();\n public:\n // the empty constructor\n SharedPtr()\n : node(NULL) {};\n // the create constructor\n SharedPtr(SharedObj* ptr);\n // the copy constructor\n SharedPtr(const SharedPtr& obj);\n // the move constructor\n SharedPtr(SharedPtr&& obj);\n // copy assignment operator\n SharedPtr& operator=(const SharedPtr& obj);\n // move assignment operator\n SharedPtr& operator=(SharedPtr&& obj);\n // pure virtual destructor\n virtual ~SharedPtr() = 0;\n public:\n SharedObj* obj () const {\n return node;\n };\n SharedObj* operator-> () const {\n return node;\n };\n bool isNull () {\n return node == NULL;\n };\n bool isNull () const {\n return node == NULL;\n };\n SharedObj* detach() const {\n if (node) {\n node->detached = true;\n }\n return node;\n };\n operator bool() const {\n return node != NULL;\n };\n\n };\n\n template < class T >\n class SharedImpl : private SharedPtr {\n public:\n SharedImpl()\n : SharedPtr(NULL) {};\n SharedImpl(T* node)\n : SharedPtr(node) {};\n template < class U >\n SharedImpl(SharedImpl obj)\n : SharedPtr(static_cast(obj.ptr())) {}\n SharedImpl(T&& node)\n : SharedPtr(node) {};\n SharedImpl(const T& node)\n : SharedPtr(node) {};\n // the copy constructor\n SharedImpl(const SharedImpl& impl)\n : SharedPtr(impl.node) {};\n // the move constructor\n SharedImpl(SharedImpl&& impl)\n : SharedPtr(impl.node) {};\n // copy assignment operator\n SharedImpl& operator=(const SharedImpl& rhs) {\n if (node) decRefCount();\n node = rhs.node;\n incRefCount();\n return *this;\n }\n // move assignment operator\n SharedImpl& operator=(SharedImpl&& rhs) {\n // don't move our self\n if (this != &rhs) {\n if (node) decRefCount();\n node = std::move(rhs.node);\n rhs.node = NULL;\n }\n return *this;\n }\n ~SharedImpl() {};\n public:\n operator T*() const {\n return static_cast(this->obj());\n }\n operator T&() const {\n return *static_cast(this->obj());\n }\n T& operator* () const {\n return *static_cast(this->obj());\n };\n T* operator-> () const {\n return static_cast(this->obj());\n };\n T* ptr () const {\n return static_cast(this->obj());\n };\n T* detach() const {\n if (this->obj() == NULL) return NULL;\n return static_cast(SharedPtr::detach());\n }\n bool isNull() const {\n return this->obj() == NULL;\n }\n bool operator<(const T& rhs) const {\n return *this->ptr() < rhs;\n };\n operator bool() const {\n return this->obj() != NULL;\n };\n };\n\n}\n\n#endif\n// Path: src/libsass/src/ast_fwd_decl.hpp\n#ifndef 
SASS_AST_FWD_DECL_H\n#define SASS_AST_FWD_DECL_H\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"memory/SharedPtr.hpp\"\n#include \"sass/functions.h\"\n\n/////////////////////////////////////////////\n// Forward declarations for the AST visitors.\n/////////////////////////////////////////////\nnamespace Sass {\n\n class AST_Node;\n typedef AST_Node* AST_Node_Ptr;\n typedef AST_Node const* AST_Node_Ptr_Const;\n\n class Has_Block;\n typedef Has_Block* Has_Block_Ptr;\n typedef Has_Block const* Has_Block_Ptr_Const;\n\n class Simple_Selector;\n typedef Simple_Selector* Simple_Selector_Ptr;\n typedef Simple_Selector const* Simple_Selector_Ptr_Const;\n\n class PreValue;\n typedef PreValue* PreValue_Ptr;\n typedef PreValue const* PreValue_Ptr_Const;\n class Thunk;\n typedef Thunk* Thunk_Ptr;\n typedef Thunk const* Thunk_Ptr_Const;\n class Block;\n typedef Block* Block_Ptr;\n typedef Block const* Block_Ptr_Const;\n class Expression;\n typedef Expression* Expression_Ptr;\n typedef Expression const* Expression_Ptr_Const;\n class Statement;\n typedef Statement* Statement_Ptr;\n typedef Statement const* Statement_Ptr_Const;\n class Value;\n typedef Value* Value_Ptr;\n typedef Value const* Value_Ptr_Const;\n class Declaration;\n typedef Declaration* Declaration_Ptr;\n typedef Declaration const* Declaration_Ptr_Const;\n class Ruleset;\n typedef Ruleset* Ruleset_Ptr;\n typedef Ruleset const* Ruleset_Ptr_Const;\n class Bubble;\n typedef Bubble* Bubble_Ptr;\n typedef Bubble const* Bubble_Ptr_Const;\n class Trace;\n typedef Trace* Trace_Ptr;\n typedef Trace const* Trace_Ptr_Const;\n\n class Media_Block;\n typedef Media_Block* Media_Block_Ptr;\n typedef Media_Block const* Media_Block_Ptr_Const;\n class Supports_Block;\n typedef Supports_Block* Supports_Block_Ptr;\n typedef Supports_Block const* Supports_Block_Ptr_Const;\n class Directive;\n typedef Directive* Directive_Ptr;\n typedef Directive const* Directive_Ptr_Const;\n\n\n class Keyframe_Rule;\n typedef Keyframe_Rule* Keyframe_Rule_Ptr;\n typedef Keyframe_Rule const* Keyframe_Rule_Ptr_Const;\n class At_Root_Block;\n typedef At_Root_Block* At_Root_Block_Ptr;\n typedef At_Root_Block const* At_Root_Block_Ptr_Const;\n class Assignment;\n typedef Assignment* Assignment_Ptr;\n typedef Assignment const* Assignment_Ptr_Const;\n\n class Import;\n typedef Import* Import_Ptr;\n typedef Import const* Import_Ptr_Const;\n class Import_Stub;\n typedef Import_Stub* Import_Stub_Ptr;\n typedef Import_Stub const* Import_Stub_Ptr_Const;\n class Warning;\n typedef Warning* Warning_Ptr;\n typedef Warning const* Warning_Ptr_Const;\n\n class Error;\n typedef Error* Error_Ptr;\n typedef Error const* Error_Ptr_Const;\n class Debug;\n typedef Debug* Debug_Ptr;\n typedef Debug const* Debug_Ptr_Const;\n class Comment;\n typedef Comment* Comment_Ptr;\n typedef Comment const* Comment_Ptr_Const;\n\n class If;\n typedef If* If_Ptr;\n typedef If const* If_Ptr_Const;\n class For;\n typedef For* For_Ptr;\n typedef For const* For_Ptr_Const;\n class Each;\n typedef Each* Each_Ptr;\n typedef Each const* Each_Ptr_Const;\n class While;\n typedef While* While_Ptr;\n typedef While const* While_Ptr_Const;\n class Return;\n typedef Return* Return_Ptr;\n typedef Return const* Return_Ptr_Const;\n class Content;\n typedef Content* Content_Ptr;\n typedef Content const* Content_Ptr_Const;\n class Extension;\n typedef Extension* Extension_Ptr;\n typedef Extension const* Extension_Ptr_Const;\n class Definition;\n typedef Definition* 
Definition_Ptr;\n typedef Definition const* Definition_Ptr_Const;\n\n class List;\n typedef List* List_Ptr;\n typedef List const* List_Ptr_Const;\n class Map;\n typedef Map* Map_Ptr;\n typedef Map const* Map_Ptr_Const;\n class Function;\n typedef Function* Function_Ptr;\n typedef Function const* Function_Ptr_Const;\n\n class Mixin_Call;\n typedef Mixin_Call* Mixin_Call_Ptr;\n typedef Mixin_Call const* Mixin_Call_Ptr_Const;\n class Binary_Expression;\n typedef Binary_Expression* Binary_Expression_Ptr;\n typedef Binary_Expression const* Binary_Expression_Ptr_Const;\n class Unary_Expression;\n typedef Unary_Expression* Unary_Expression_Ptr;\n typedef Unary_Expression const* Unary_Expression_Ptr_Const;\n class Function_Call;\n typedef Function_Call* Function_Call_Ptr;\n typedef Function_Call const* Function_Call_Ptr_Const;\n class Function_Call_Schema;\n typedef Function_Call_Schema* Function_Call_Schema_Ptr;\n typedef Function_Call_Schema const* Function_Call_Schema_Ptr_Const;\n class Custom_Warning;\n typedef Custom_Warning* Custom_Warning_Ptr;\n typedef Custom_Warning const* Custom_Warning_Ptr_Const;\n class Custom_Error;\n typedef Custom_Error* Custom_Error_Ptr;\n typedef Custom_Error const* Custom_Error_Ptr_Const;\n\n class Variable;\n typedef Variable* Variable_Ptr;\n typedef Variable const* Variable_Ptr_Const;\n class Number;\n typedef Number* Number_Ptr;\n typedef Number const* Number_Ptr_Const;\n class Color;\n typedef Color* Color_Ptr;\n typedef Color const* Color_Ptr_Const;\n class Boolean;\n typedef Boolean* Boolean_Ptr;\n typedef Boolean const* Boolean_Ptr_Const;\n class String;\n typedef String* String_Ptr;\n typedef String const* String_Ptr_Const;\n\n class String_Schema;\n typedef String_Schema* String_Schema_Ptr;\n typedef String_Schema const* String_Schema_Ptr_Const;\n class String_Constant;\n typedef String_Constant* String_Constant_Ptr;\n typedef String_Constant const* String_Constant_Ptr_Const;\n class String_Quoted;\n typedef String_Quoted* String_Quoted_Ptr;\n typedef String_Quoted const* String_Quoted_Ptr_Const;\n\n class Media_Query;\n typedef Media_Query* Media_Query_Ptr;\n typedef Media_Query const* Media_Query_Ptr_Const;\n class Media_Query_Expression;\n typedef Media_Query_Expression* Media_Query_Expression_Ptr;\n typedef Media_Query_Expression const* Media_Query_Expression_Ptr_Const;\n class Supports_Condition;\n typedef Supports_Condition* Supports_Condition_Ptr;\n typedef Supports_Condition const* Supports_Condition_Ptr_Const;\n class Supports_Operator;\n typedef Supports_Operator* Supports_Operator_Ptr;\n typedef Supports_Operator const* Supports_Operator_Ptr_Const;\n class Supports_Negation;\n typedef Supports_Negation* Supports_Negation_Ptr;\n typedef Supports_Negation const* Supports_Negation_Ptr_Const;\n class Supports_Declaration;\n typedef Supports_Declaration* Supports_Declaration_Ptr;\n typedef Supports_Declaration const* Supports_Declaration_Ptr_Const;\n class Supports_Interpolation;\n typedef Supports_Interpolation* Supports_Interpolation_Ptr;\n typedef Supports_Interpolation const* Supports_Interpolation_Ptr_Const;\n\n\n class Null;\n typedef Null* Null_Ptr;\n typedef Null const* Null_Ptr_Const;\n\n class At_Root_Query;\n typedef At_Root_Query* At_Root_Query_Ptr;\n typedef At_Root_Query const* At_Root_Query_Ptr_Const;\n class Parent_Selector;\n typedef Parent_Selector* Parent_Selector_Ptr;\n typedef Parent_Selector const* Parent_Selector_Ptr_Const;\n class Parameter;\n typedef Parameter* Parameter_Ptr;\n typedef Parameter const* 
Parameter_Ptr_Const;\n class Parameters;\n typedef Parameters* Parameters_Ptr;\n typedef Parameters const* Parameters_Ptr_Const;\n class Argument;\n typedef Argument* Argument_Ptr;\n typedef Argument const* Argument_Ptr_Const;\n class Arguments;\n typedef Arguments* Arguments_Ptr;\n typedef Arguments const* Arguments_Ptr_Const;\n class Selector;\n typedef Selector* Selector_Ptr;\n typedef Selector const* Selector_Ptr_Const;\n\n\n class Selector_Schema;\n typedef Selector_Schema* Selector_Schema_Ptr;\n typedef Selector_Schema const* Selector_Schema_Ptr_Const;\n class Placeholder_Selector;\n typedef Placeholder_Selector* Placeholder_Selector_Ptr;\n typedef Placeholder_Selector const* Placeholder_Selector_Ptr_Const;\n class Element_Selector;\n typedef Element_Selector* Element_Selector_Ptr;\n typedef Element_Selector const* Element_Selector_Ptr_Const;\n class Class_Selector;\n typedef Class_Selector* Class_Selector_Ptr;\n typedef Class_Selector const* Class_Selector_Ptr_Const;\n class Id_Selector;\n typedef Id_Selector* Id_Selector_Ptr;\n typedef Id_Selector const* Id_Selector_Ptr_Const;\n class Attribute_Selector;\n typedef Attribute_Selector* Attribute_Selector_Ptr;\n typedef Attribute_Selector const* Attribute_Selector_Ptr_Const;\n\n class Pseudo_Selector;\n typedef Pseudo_Selector* Pseudo_Selector_Ptr;\n typedef Pseudo_Selector const * Pseudo_Selector_Ptr_Const;\n class Wrapped_Selector;\n typedef Wrapped_Selector* Wrapped_Selector_Ptr;\n typedef Wrapped_Selector const * Wrapped_Selector_Ptr_Const;\n class Compound_Selector;\n typedef Compound_Selector* Compound_Selector_Ptr;\n typedef Compound_Selector const * Compound_Selector_Ptr_Const;\n class Complex_Selector;\n typedef Complex_Selector* Complex_Selector_Ptr;\n typedef Complex_Selector const * Complex_Selector_Ptr_Const;\n class Selector_List;\n typedef Selector_List* Selector_List_Ptr;\n typedef Selector_List const * Selector_List_Ptr_Const;\n\n\n // common classes\n class Context;\n class Expand;\n class Eval;\n\n // declare classes that are instances of memory nodes\n // #define IMPL_MEM_OBJ(type) using type##_Obj = SharedImpl\n #define IMPL_MEM_OBJ(type) typedef SharedImpl type##_Obj\n\n IMPL_MEM_OBJ(AST_Node);\n IMPL_MEM_OBJ(Statement);\n IMPL_MEM_OBJ(Block);\n IMPL_MEM_OBJ(Ruleset);\n IMPL_MEM_OBJ(Bubble);\n IMPL_MEM_OBJ(Trace);\n IMPL_MEM_OBJ(Media_Block);\n IMPL_MEM_OBJ(Supports_Block);\n IMPL_MEM_OBJ(Directive);\n IMPL_MEM_OBJ(Keyframe_Rule);\n IMPL_MEM_OBJ(At_Root_Block);\n IMPL_MEM_OBJ(Declaration);\n IMPL_MEM_OBJ(Assignment);\n IMPL_MEM_OBJ(Import);\n IMPL_MEM_OBJ(Import_Stub);\n IMPL_MEM_OBJ(Warning);\n IMPL_MEM_OBJ(Error);\n IMPL_MEM_OBJ(Debug);\n IMPL_MEM_OBJ(Comment);\n IMPL_MEM_OBJ(PreValue);\n IMPL_MEM_OBJ(Has_Block);\n IMPL_MEM_OBJ(Thunk);\n IMPL_MEM_OBJ(If);\n IMPL_MEM_OBJ(For);\n IMPL_MEM_OBJ(Each);\n IMPL_MEM_OBJ(While);\n IMPL_MEM_OBJ(Return);\n IMPL_MEM_OBJ(Content);\n IMPL_MEM_OBJ(Extension);\n IMPL_MEM_OBJ(Definition);\n IMPL_MEM_OBJ(Mixin_Call);\n IMPL_MEM_OBJ(Value);\n IMPL_MEM_OBJ(Expression);\n IMPL_MEM_OBJ(List);\n IMPL_MEM_OBJ(Map);\n IMPL_MEM_OBJ(Function);\n IMPL_MEM_OBJ(Binary_Expression);\n IMPL_MEM_OBJ(Unary_Expression);\n IMPL_MEM_OBJ(Function_Call);\n IMPL_MEM_OBJ(Function_Call_Schema);\n IMPL_MEM_OBJ(Custom_Warning);\n IMPL_MEM_OBJ(Custom_Error);\n IMPL_MEM_OBJ(Variable);\n IMPL_MEM_OBJ(Number);\n IMPL_MEM_OBJ(Color);\n IMPL_MEM_OBJ(Boolean);\n IMPL_MEM_OBJ(String_Schema);\n IMPL_MEM_OBJ(String);\n IMPL_MEM_OBJ(String_Constant);\n IMPL_MEM_OBJ(String_Quoted);\n IMPL_MEM_OBJ(Media_Query);\n 
IMPL_MEM_OBJ(Media_Query_Expression);\n IMPL_MEM_OBJ(Supports_Condition);\n IMPL_MEM_OBJ(Supports_Operator);\n IMPL_MEM_OBJ(Supports_Negation);\n IMPL_MEM_OBJ(Supports_Declaration);\n IMPL_MEM_OBJ(Supports_Interpolation);\n IMPL_MEM_OBJ(At_Root_Query);\n IMPL_MEM_OBJ(Null);\n IMPL_MEM_OBJ(Parent_Selector);\n IMPL_MEM_OBJ(Parameter);\n IMPL_MEM_OBJ(Parameters);\n IMPL_MEM_OBJ(Argument);\n IMPL_MEM_OBJ(Arguments);\n IMPL_MEM_OBJ(Selector);\n IMPL_MEM_OBJ(Selector_Schema);\n IMPL_MEM_OBJ(Simple_Selector);\n IMPL_MEM_OBJ(Placeholder_Selector);\n IMPL_MEM_OBJ(Element_Selector);\n IMPL_MEM_OBJ(Class_Selector);\n IMPL_MEM_OBJ(Id_Selector);\n IMPL_MEM_OBJ(Attribute_Selector);\n IMPL_MEM_OBJ(Pseudo_Selector);\n IMPL_MEM_OBJ(Wrapped_Selector);\n IMPL_MEM_OBJ(Compound_Selector);\n IMPL_MEM_OBJ(Complex_Selector);\n IMPL_MEM_OBJ(Selector_List);\n\n // ###########################################################################\n // Implement compare, order and hashing operations for AST Nodes\n // ###########################################################################\n\n struct HashNodes {\n template \n size_t operator() (const T& ex) const {\n return ex.isNull() ? 0 : ex->hash();\n }\n };\n struct OrderNodes {\n template \n bool operator() (const T& lhs, const T& rhs) const {\n return !lhs.isNull() && !rhs.isNull() && *lhs < *rhs;\n }\n };\n struct CompareNodes {\n template \n bool operator() (const T& lhs, const T& rhs) const {\n // code around sass logic issue. 1px == 1 is true\n // but both items are still different keys in maps\n if (dynamic_cast(lhs.ptr()))\n if (dynamic_cast(rhs.ptr()))\n return lhs->hash() == rhs->hash();\n return !lhs.isNull() && !rhs.isNull() && *lhs == *rhs;\n }\n };\n\n // ###########################################################################\n // some often used typedefs\n // ###########################################################################\n\n typedef std::unordered_map<\n Expression_Obj, // key\n Expression_Obj, // value\n HashNodes, // hasher\n CompareNodes // compare\n > ExpressionMap;\n typedef std::unordered_set<\n Expression_Obj, // value\n HashNodes, // hasher\n CompareNodes // compare\n > ExpressionSet;\n\n typedef std::string SubSetMapKey;\n typedef std::vector SubSetMapKeys;\n\n typedef std::pair SubSetMapPair;\n typedef std::pair SubSetMapLookup;\n typedef std::vector SubSetMapPairs;\n typedef std::vector SubSetMapLookups;\n\n typedef std::pair SubSetMapResult;\n typedef std::vector SubSetMapResults;\n\n typedef std::deque ComplexSelectorDeque;\n typedef std::set SimpleSelectorSet;\n typedef std::set ComplexSelectorSet;\n typedef std::set CompoundSelectorSet;\n typedef std::unordered_set SimpleSelectorDict;\n\n typedef std::vector* ImporterStack;\n\n // only to switch implementations for testing\n #define environment_map std::map\n\n // ###########################################################################\n // explicit type conversion functions\n // ###########################################################################\n\n template\n T* Cast(AST_Node* ptr);\n\n template\n const T* Cast(const AST_Node* ptr);\n\n // sometimes you know the class you want to cast to is final\n // in this case a simple typeid check is faster and safe to use\n\n #define DECLARE_BASE_CAST(T) \\\n template<> T* Cast(AST_Node* ptr); \\\n template<> const T* Cast(const AST_Node* ptr); \\\n\n // ###########################################################################\n // implement specialization for final classes\n // 
###########################################################################\n\n DECLARE_BASE_CAST(AST_Node)\n DECLARE_BASE_CAST(Expression)\n DECLARE_BASE_CAST(Statement)\n DECLARE_BASE_CAST(Has_Block)\n DECLARE_BASE_CAST(PreValue)\n DECLARE_BASE_CAST(Value)\n DECLARE_BASE_CAST(List)\n DECLARE_BASE_CAST(String)\n DECLARE_BASE_CAST(String_Constant)\n DECLARE_BASE_CAST(Supports_Condition)\n DECLARE_BASE_CAST(Selector)\n DECLARE_BASE_CAST(Simple_Selector)\n\n}\n\n#endif\n\n// Path: src/libsass/src/util.hpp\n#ifndef SASS_UTIL_H\n#define SASS_UTIL_H\n\n#include \n#include \n#include \n#include \"sass.hpp\"\n#include \"sass/base.h\"\n#include \"ast_fwd_decl.hpp\"\n\n#define SASS_ASSERT(cond, msg) assert(cond && msg)\n\nnamespace Sass {\n\n double round(double val, size_t precision = 0);\n double sass_strtod(const char* str);\n const char* safe_str(const char *, const char* = \"\");\n void free_string_array(char **);\n char **copy_strings(const std::vector&, char ***, int = 0);\n std::string read_css_string(const std::string& str, bool css = true);\n std::string evacuate_escapes(const std::string& str);\n std::string string_to_output(const std::string& str);\n std::string comment_to_string(const std::string& text);\n std::string read_hex_escapes(const std::string& str);\n std::string escape_string(const std::string& str);\n void newline_to_space(std::string& str);\n\n std::string quote(const std::string&, char q = 0);\n std::string unquote(const std::string&, char* q = 0, bool keep_utf8_sequences = false, bool strict = true);\n char detect_best_quotemark(const char* s, char qm = '\"');\n\n bool is_hex_doublet(double n);\n bool is_color_doublet(double r, double g, double b);\n\n bool peek_linefeed(const char* start);\n\n namespace Util {\n\n std::string rtrim(const std::string& str);\n\n std::string normalize_underscores(const std::string& str);\n std::string normalize_decimals(const std::string& str);\n\n bool isPrintable(Ruleset_Ptr r, Sass_Output_Style style = NESTED);\n bool isPrintable(Supports_Block_Ptr r, Sass_Output_Style style = NESTED);\n bool isPrintable(Media_Block_Ptr r, Sass_Output_Style style = NESTED);\n bool isPrintable(Comment_Ptr b, Sass_Output_Style style = NESTED);\n bool isPrintable(Block_Obj b, Sass_Output_Style style = NESTED);\n bool isPrintable(String_Constant_Ptr s, Sass_Output_Style style = NESTED);\n bool isPrintable(String_Quoted_Ptr s, Sass_Output_Style style = NESTED);\n bool isPrintable(Declaration_Ptr d, Sass_Output_Style style = NESTED);\n bool isAscii(const char chr);\n\n }\n}\n#endif\n\n// Path: src/libsass/src/units.hpp\n#ifndef SASS_UNITS_H\n#define SASS_UNITS_H\n\n#include \n#include \n#include \n#include \n\nnamespace Sass {\n\n const double PI = std::acos(-1);\n\n enum UnitClass {\n LENGTH = 0x000,\n ANGLE = 0x100,\n TIME = 0x200,\n FREQUENCY = 0x300,\n RESOLUTION = 0x400,\n INCOMMENSURABLE = 0x500\n };\n\n enum UnitType {\n\n // size units\n IN = UnitClass::LENGTH,\n CM,\n PC,\n MM,\n PT,\n PX,\n\n // angle units\n DEG = ANGLE,\n GRAD,\n RAD,\n TURN,\n\n // time units\n SEC = TIME,\n MSEC,\n\n // frequency units\n HERTZ = FREQUENCY,\n KHERTZ,\n\n // resolutions units\n DPI = RESOLUTION,\n DPCM,\n DPPX,\n\n // for unknown units\n UNKNOWN = INCOMMENSURABLE\n\n };\n\n class Units {\n public:\n std::vector numerators;\n std::vector denominators;\n public:\n // default constructor\n Units() :\n numerators(),\n denominators()\n { }\n // copy constructor\n Units(const Units* ptr) :\n numerators(ptr->numerators),\n denominators(ptr->denominators)\n { }\n // 
convert to string\n std::string unit() const;\n // get if units are empty\n bool is_unitless() const;\n // return if valid for css\n bool is_valid_css_unit() const;\n // reduce units for output\n // returns conversion factor\n double reduce();\n // normalize units for compare\n // returns conversion factor\n double normalize();\n // compare operations\n bool operator< (const Units& rhs) const;\n bool operator== (const Units& rhs) const;\n // factor to convert into given units\n double convert_factor(const Units&) const;\n };\n\n extern const double size_conversion_factors[6][6];\n extern const double angle_conversion_factors[4][4];\n extern const double time_conversion_factors[2][2];\n extern const double frequency_conversion_factors[2][2];\n extern const double resolution_conversion_factors[3][3];\n\n UnitType get_main_unit(const UnitClass unit);\n enum Sass::UnitType string_to_unit(const std::string&);\n const char* unit_to_string(Sass::UnitType unit);\n enum Sass::UnitClass get_unit_type(Sass::UnitType unit);\n std::string get_unit_class(Sass::UnitType unit);\n std::string unit_to_class(const std::string&);\n // throws incompatibleUnits exceptions\n double conversion_factor(const std::string&, const std::string&);\n double conversion_factor(UnitType, UnitType, UnitClass, UnitClass);\n double convert_units(const std::string&, const std::string&, int&, int&);\n\n}\n\n#endif\n\n// Path: src/libsass/src/b64/cencode.h\n/*\ncencode.h - c header for a base64 encoding algorithm\n\nThis is part of the libb64 project, and has been placed in the public domain.\nFor details, see http://sourceforge.net/projects/libb64\n*/\n\n#ifndef BASE64_CENCODE_H\n#define BASE64_CENCODE_H\n\ntypedef enum\n{\n step_A, step_B, step_C\n} base64_encodestep;\n\ntypedef struct\n{\n\tbase64_encodestep step;\n\tchar result;\n\tint stepcount;\n} base64_encodestate;\n\nvoid base64_init_encodestate(base64_encodestate* state_in);\n\nchar base64_encode_value(char value_in);\n\nint base64_encode_block(const char* plaintext_in, int length_in, char* code_out, base64_encodestate* state_in);\n\nint base64_encode_blockend(char* code_out, base64_encodestate* state_in);\n\n#endif /* BASE64_CENCODE_H */\n\n\n// Path: src/libsass/src/b64/encode.h\n// :mode=c++:\n/*\nencode.h - c++ wrapper for a base64 encoding algorithm\n\nThis is part of the libb64 project, and has been placed in the public domain.\nFor details, see http://sourceforge.net/projects/libb64\n*/\n#ifndef BASE64_ENCODE_H\n#define BASE64_ENCODE_H\n\n#include \n\nnamespace base64\n{\n\textern \"C\"\n\t{\n\t\t#include \"cencode.h\"\n\t}\n\n\tstruct encoder\n\t{\n\t\tbase64_encodestate _state;\n\t\tint _buffersize;\n\n\t\tencoder(int buffersize_in = BUFFERSIZE)\n\t\t: _buffersize(buffersize_in)\n\t\t{\n\t\t\tbase64_init_encodestate(&_state);\n\t\t}\n\n\t\tint encode(char value_in)\n\t\t{\n\t\t\treturn base64_encode_value(value_in);\n\t\t}\n\n\t\tint encode(const char* code_in, const int length_in, char* plaintext_out)\n\t\t{\n\t\t\treturn base64_encode_block(code_in, length_in, plaintext_out, &_state);\n\t\t}\n\n\t\tint encode_end(char* plaintext_out)\n\t\t{\n\t\t\treturn base64_encode_blockend(plaintext_out, &_state);\n\t\t}\n\n\t\tvoid encode(std::istream& istream_in, std::ostream& ostream_in)\n\t\t{\n\t\t\tbase64_init_encodestate(&_state);\n\t\t\t//\n\t\t\tconst int N = _buffersize;\n\t\t\tchar* plaintext = new char[N];\n\t\t\tchar* code = new char[2*N];\n\t\t\tint plainlength;\n\t\t\tint codelength;\n\n\t\t\tdo\n\t\t\t{\n\t\t\t\tistream_in.read(plaintext, 
N);\n\t\t\t\tplainlength = static_cast(istream_in.gcount());\n\t\t\t\t//\n\t\t\t\tcodelength = encode(plaintext, plainlength, code);\n\t\t\t\tostream_in.write(code, codelength);\n\t\t\t}\n\t\t\twhile (istream_in.good() && plainlength > 0);\n\n\t\t\tcodelength = encode_end(code);\n\t\t\tostream_in.write(code, codelength);\n\t\t\t//\n\t\t\tbase64_init_encodestate(&_state);\n\n\t\t\tdelete [] code;\n\t\t\tdelete [] plaintext;\n\t\t}\n\t};\n\n} // namespace base64\n\n#endif // BASE64_ENCODE_H\n\n\n// Path: src/libsass/src/kwd_arg_macros.hpp\n#ifndef SASS_KWD_ARG_MACROS_H\n#define SASS_KWD_ARG_MACROS_H\n\n// Example usage:\n// KWD_ARG_SET(Args) {\n// KWD_ARG(Args, string, foo);\n// KWD_ARG(Args, int, bar);\n// ...\n// };\n//\n// ... and later ...\n//\n// something(Args().foo(\"hey\").bar(3));\n\n#define KWD_ARG_SET(set_name) class set_name\n\n#define KWD_ARG(set_name, type, name) \\\nprivate: \\\n type name##_; \\\npublic: \\\n set_name& name(type name##__) { \\\n name##_ = name##__; \\\n return *this; \\\n } \\\n type name() { return name##_; } \\\nprivate:\n\n#endif\n\n// Path: src/libsass/src/sass_context.hpp\n#ifndef SASS_SASS_CONTEXT_H\n#define SASS_SASS_CONTEXT_H\n\n#include \"sass/base.h\"\n#include \"sass/context.h\"\n#include \"ast_fwd_decl.hpp\"\n\n// sass config options structure\nstruct Sass_Options : Sass_Output_Options {\n\n // embed sourceMappingUrl as data uri\n bool source_map_embed;\n\n // embed include contents in maps\n bool source_map_contents;\n\n // create file urls for sources\n bool source_map_file_urls;\n\n // Disable sourceMappingUrl in css output\n bool omit_source_map_url;\n\n // Treat source_string as sass (as opposed to scss)\n bool is_indented_syntax_src;\n\n // The input path is used for source map\n // generation. It can be used to define\n // something with string compilation or to\n // overload the input file path. It is\n // set to \"stdin\" for data contexts and\n // to the input file on file contexts.\n char* input_path;\n\n // The output path is used for source map\n // generation. 
LibSass will not write to\n // this file, it is just used to create\n // information in source-maps etc.\n char* output_path;\n\n // Colon-separated list of paths\n // Semicolon-separated on Windows\n // Maybe use array interface instead?\n char* include_path;\n char* plugin_path;\n\n // Include paths (linked string list)\n struct string_list* include_paths;\n // Plugin paths (linked string list)\n struct string_list* plugin_paths;\n\n // Path to source map file\n // Enables source map generation\n // Used to create sourceMappingUrl\n char* source_map_file;\n\n // Directly inserted in source maps\n char* source_map_root;\n\n // Custom functions that can be called from sccs code\n Sass_Function_List c_functions;\n\n // List of custom importers\n Sass_Importer_List c_importers;\n\n // List of custom headers\n Sass_Importer_List c_headers;\n\n};\n\n\n// base for all contexts\nstruct Sass_Context : Sass_Options\n{\n\n // store context type info\n enum Sass_Input_Style type;\n\n // generated output data\n char* output_string;\n\n // generated source map json\n char* source_map_string;\n\n // error status\n int error_status;\n char* error_json;\n char* error_text;\n char* error_message;\n // error position\n char* error_file;\n size_t error_line;\n size_t error_column;\n const char* error_src;\n\n // report imported files\n char** included_files;\n\n};\n\n// struct for file compilation\nstruct Sass_File_Context : Sass_Context {\n\n // no additional fields required\n // input_path is already on options\n\n};\n\n// struct for data compilation\nstruct Sass_Data_Context : Sass_Context {\n\n // provided source string\n char* source_string;\n char* srcmap_string;\n\n};\n\n// link c and cpp context\nstruct Sass_Compiler {\n // progress status\n Sass_Compiler_State state;\n // original c context\n Sass_Context* c_ctx;\n // Sass::Context\n Sass::Context* cpp_ctx;\n // Sass::Block\n Sass::Block_Obj root;\n};\n\n#endif\n// Path: src/libsass/src/ast_def_macros.hpp\n#ifndef SASS_AST_DEF_MACROS_H\n#define SASS_AST_DEF_MACROS_H\n\n// Helper class to switch a flag and revert once we go out of scope\ntemplate \nclass LocalOption {\n private:\n T* var; // pointer to original variable\n T orig; // copy of the original option\n public:\n LocalOption(T& var)\n {\n this->var = &var;\n this->orig = var;\n }\n LocalOption(T& var, T orig)\n {\n this->var = &var;\n this->orig = var;\n *(this->var) = orig;\n }\n void reset()\n {\n *(this->var) = this->orig;\n }\n ~LocalOption() {\n *(this->var) = this->orig;\n }\n};\n\n#define LOCAL_FLAG(name,opt) LocalOption flag_##name(name, opt)\n#define LOCAL_COUNT(name,opt) LocalOption cnt_##name(name, opt)\n\n#define NESTING_GUARD(name) \\\n LocalOption cnt_##name(name, name + 1); \\\n if (name > MAX_NESTING) throw Exception::NestingLimitError(pstate, traces); \\\n\n#define ATTACH_OPERATIONS()\\\nvirtual void perform(Operation* op) { (*op)(this); }\\\nvirtual AST_Node_Ptr perform(Operation* op) { return (*op)(this); }\\\nvirtual Statement_Ptr perform(Operation* op) { return (*op)(this); }\\\nvirtual Expression_Ptr perform(Operation* op) { return (*op)(this); }\\\nvirtual Selector_Ptr perform(Operation* op) { return (*op)(this); }\\\nvirtual std::string perform(Operation* op) { return (*op)(this); }\\\nvirtual union Sass_Value* perform(Operation* op) { return (*op)(this); }\\\nvirtual Value_Ptr perform(Operation* op) { return (*op)(this); }\n\n#define ADD_PROPERTY(type, name)\\\nprotected:\\\n type name##_;\\\npublic:\\\n type name() const { return name##_; }\\\n type name(type 
name##__) { return name##_ = name##__; }\\\nprivate:\n\n#define HASH_PROPERTY(type, name)\\\nprotected:\\\n type name##_;\\\npublic:\\\n type name() const { return name##_; }\\\n type name(type name##__) { hash_ = 0; return name##_ = name##__; }\\\nprivate:\n\n#define ADD_CONSTREF(type, name) \\\nprotected: \\\n type name##_; \\\npublic: \\\n const type& name() const { return name##_; } \\\n void name(type name##__) { name##_ = name##__; } \\\nprivate:\n\n#define HASH_CONSTREF(type, name) \\\nprotected: \\\n type name##_; \\\npublic: \\\n const type& name() const { return name##_; } \\\n void name(type name##__) { hash_ = 0; name##_ = name##__; } \\\nprivate:\n\n#endif\n\n// Path: src/libsass/src/environment.hpp\n#ifndef SASS_ENVIRONMENT_H\n#define SASS_ENVIRONMENT_H\n\n#include \n#include \"ast_fwd_decl.hpp\"\n#include \"ast_def_macros.hpp\"\n\nnamespace Sass {\n\n typedef environment_map::iterator EnvIter;\n\n class EnvResult {\n public:\n EnvIter it;\n bool found;\n public:\n EnvResult(EnvIter it, bool found)\n : it(it), found(found) {}\n };\n\n template \n class Environment {\n // TODO: test with map\n environment_map local_frame_;\n ADD_PROPERTY(Environment*, parent)\n ADD_PROPERTY(bool, is_shadow)\n\n public:\n Environment(bool is_shadow = false);\n Environment(Environment* env, bool is_shadow = false);\n Environment(Environment& env, bool is_shadow = false);\n\n // link parent to create a stack\n void link(Environment& env);\n void link(Environment* env);\n\n // this is used to find the global frame\n // which is the second last on the stack\n bool is_lexical() const;\n\n // only match the real root scope\n // there is still a parent around\n // not sure what it is actually use for\n // I guess we store functions etc. there\n bool is_global() const;\n\n // scope operates on the current frame\n\n environment_map& local_frame();\n\n bool has_local(const std::string& key) const;\n\n EnvResult find_local(const std::string& key);\n\n T& get_local(const std::string& key);\n\n // set variable on the current frame\n void set_local(const std::string& key, const T& val);\n void set_local(const std::string& key, T&& val);\n\n void del_local(const std::string& key);\n\n // global operates on the global frame\n // which is the second last on the stack\n Environment* global_env();\n // get the env where the variable already exists\n // if it does not yet exist, we return current env\n Environment* lexical_env(const std::string& key);\n\n bool has_global(const std::string& key);\n\n T& get_global(const std::string& key);\n\n // set a variable on the global frame\n void set_global(const std::string& key, const T& val);\n void set_global(const std::string& key, T&& val);\n\n void del_global(const std::string& key);\n\n // see if we have a lexical variable\n // move down the stack but stop before we\n // reach the global frame (is not included)\n bool has_lexical(const std::string& key) const;\n\n // see if we have a lexical we could update\n // either update already existing lexical value\n // or we create a new one on the current frame\n void set_lexical(const std::string& key, T&& val);\n void set_lexical(const std::string& key, const T& val);\n\n // look on the full stack for key\n // include all scopes available\n bool has(const std::string& key) const;\n\n // look on the full stack for key\n // include all scopes available\n EnvResult find(const std::string& key);\n\n // use array access for getter and setter functions\n T& operator[](const std::string& key);\n\n #ifdef DEBUG\n size_t 
print(std::string prefix = \"\");\n #endif\n\n };\n\n // define typedef for our use case\n typedef Environment Env;\n\n}\n\n#endif\n\n// Path: src/libsass/src/base64vlq.hpp\n#ifndef SASS_BASE64VLQ_H\n#define SASS_BASE64VLQ_H\n\n#include \n\nnamespace Sass {\n\n class Base64VLQ {\n\n public:\n\n std::string encode(const int number) const;\n\n private:\n\n char base64_encode(const int number) const;\n\n int to_vlq_signed(const int number) const;\n\n static const char* CHARACTERS;\n\n static const int VLQ_BASE_SHIFT;\n static const int VLQ_BASE;\n static const int VLQ_BASE_MASK;\n static const int VLQ_CONTINUATION_BIT;\n };\n\n}\n\n#endif\n\n// Path: src/libsass/src/position.hpp\n#ifndef SASS_POSITION_H\n#define SASS_POSITION_H\n\n#include \n#include \n// #include \n\nnamespace Sass {\n\n\n class Offset {\n\n public: // c-tor\n Offset(const char chr);\n Offset(const char* string);\n Offset(const std::string& text);\n Offset(const size_t line, const size_t column);\n\n // return new position, incremented by the given string\n Offset add(const char* begin, const char* end);\n Offset inc(const char* begin, const char* end) const;\n\n // init/create instance from const char substring\n static Offset init(const char* beg, const char* end);\n\n public: // overload operators for position\n void operator+= (const Offset &pos);\n bool operator== (const Offset &pos) const;\n bool operator!= (const Offset &pos) const;\n Offset operator+ (const Offset &off) const;\n Offset operator- (const Offset &off) const;\n\n public: // overload output stream operator\n // friend std::ostream& operator<<(std::ostream& strm, const Offset& off);\n\n public:\n Offset off() { return *this; }\n\n public:\n size_t line;\n size_t column;\n\n };\n\n class Position : public Offset {\n\n public: // c-tor\n Position(const size_t file); // line(0), column(0)\n Position(const size_t file, const Offset& offset);\n Position(const size_t line, const size_t column); // file(-1)\n Position(const size_t file, const size_t line, const size_t column);\n\n public: // overload operators for position\n void operator+= (const Offset &off);\n bool operator== (const Position &pos) const;\n bool operator!= (const Position &pos) const;\n const Position operator+ (const Offset &off) const;\n const Offset operator- (const Offset &off) const;\n // return new position, incremented by the given string\n Position add(const char* begin, const char* end);\n Position inc(const char* begin, const char* end) const;\n\n public: // overload output stream operator\n // friend std::ostream& operator<<(std::ostream& strm, const Position& pos);\n\n public:\n size_t file;\n\n };\n\n // Token type for representing lexed chunks of text\n class Token {\n public:\n const char* prefix;\n const char* begin;\n const char* end;\n\n Token()\n...\n", "needle_token_start": 2364, "needle_token_end": 2429, "code_context_ntokens": 16025, "output": ["```cpp\n// Path: src/libsass/src/sass.hpp\n// must be the first include in all compile units\n#ifndef SASS_SASS_H\n#define SASS_SASS_H\n\n// undefine extensions macro to tell sys includes\n// that we do not want any macros to be exported\n// mainly fixes an issue on SmartOS (SEC macro)\n#undef __EXTENSIONS__\n\n#ifdef _MSC_VER\n#pragma warning(disable : 4005)\n#endif\n\n// aplies to MSVC and MinGW\n#ifdef _WIN32\n// we do not want the ERROR macro\n# define NOGDI\n// we do not want the min/max macro\n# define NOMINMAX\n// we do not want the IN/OUT macro\n# define _NO_W32_PSEUDO_MODIFIERS\n#endif\n\n\n// should we be case 
insensitive\n// when dealing with files or paths\n#ifndef FS_CASE_SENSITIVE\n# ifdef _WIN32\n# define FS_CASE_SENSITIVE 0\n# else\n# define FS_CASE_SENSITIVE 1\n# endif\n#endif\n\n// path separation char\n#ifndef PATH_SEP\n# ifdef _WIN32\n# define PATH_SEP ';'\n# else\n# define PATH_SEP ':'\n# endif\n#endif\n\n\n// include C-API header\n#include \"sass/base.h\"\n\n// For C++ helper\n#include \n\n// output behaviours\nnamespace Sass {\n\n // create some C++ aliases for the most used options\n const static Sass_Output_Style NESTED = SASS_STYLE_NESTED;\n const static Sass_Output_Style COMPACT = SASS_STYLE_COMPACT;\n const static Sass_Output_Style EXPANDED = SASS_STYLE_EXPANDED;\n const static Sass_Output_Style COMPRESSED = SASS_STYLE_COMPRESSED;\n // only used internal to trigger ruby inspect behavior\n const static Sass_Output_Style INSPECT = SASS_STYLE_INSPECT;\n const static Sass_Output_Style TO_SASS = SASS_STYLE_TO_SASS;\n\n // helper to aid dreaded MSVC debug mode\n // see implementation for more details\n char* sass_copy_string(std::string str);\n\n}\n\n// input behaviours\nenum Sass_Input_Style {\n SASS_CONTEXT_NULL,\n SASS_CONTEXT_FILE,\n SASS_CONTEXT_DATA,\n SASS_CONTEXT_FOLDER\n};\n\n// simple linked list\nstruct string_list {\n string_list* next;\n char* string;\n};\n\n// sass config options structure\nstruct Sass_Inspect_Options {\n\n // Output style for the generated css code\n // A value from above SASS_STYLE_* constants\n enum Sass_Output_Style output_style;\n\n // Precision for fractional numbers\n int precision;\n\n // Do not compress colors in selectors\n bool in_selector;\n\n // initialization list (constructor with defaults)\n \nSass_Inspect_Options(Sass_Output_Style style = Sass::NESTED,\n int precision = 5, bool in_selector = false)\n : output_style(style), precision(precision), in_selector(in_selector)\n { }\n\n};\n\n// sass config options structure\nstruct Sass_Output_Options : Sass_Inspect_Options {\n\n // String to be used for indentation\n const char* indent;\n // String to be used to for line feeds\n const char* linefeed;\n\n // Emit comments in the generated CSS indicating\n // the corresponding source line.\n bool source_comments;\n\n // initialization list (constructor with defaults)\n Sass_Output_Options(struct Sass_Inspect_Options opt,\n const char* indent = \" \",\n const char* linefeed = \"\\n\",\n bool source_comments = false)\n : Sass_Inspect_Options(opt),\n indent(indent), linefeed(linefeed),\n source_comments(source_comments)\n { }\n\n // initialization list (constructor with defaults)\n Sass_Output_Options(Sass_Output_Style style = Sass::NESTED,\n int precision = 5,\n const char* indent = \" \",\n const char* linefeed = \"\\n\",\n bool source_comments = false)\n : Sass_Inspect_Options(style, precision),\n indent(indent), linefeed(linefeed),\n source_comments(source_comments)\n { }\n\n};\n\n#endif\n```"]} +{"repo": "sass/node-sass", "name": "Thunk", "language": "cpp", "path": "src/libsass/src/ast.hpp", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this class is to facilitate delayed evaluation of expressions within a specific environment. This is useful in scenarios where the evaluation needs to be deferred until a certain condition or context is met.\n2. **Input**: It takes an expression object and an optional environment pointer. 
The expression object represents the computation or value to be evaluated later, and the environment pointer refers to the scope or context in which this expression should be evaluated.\n3. **Output**: There is no direct output since this class primarily handles the storage of an expression and its evaluation context. The actual output would be generated when the stored expression is eventually evaluated within its stored environment.\n4. **Procedure**: Upon instantiation, the class stores the provided expression and environment. The evaluation of the expression is deferred until explicitly invoked, at which point the expression is evaluated in the context of the stored environment, ensuring that the correct scope and conditions are applied.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "\n std::string type() const { return \"string\"; }\n static std::string type_name() { return \"string\"; }\n\n bool is_left_interpolant(void) const;\n bool is_right_interpolant(void) const;\n // void has_interpolants(bool tc) { }\n bool has_interpolants() {\n for (auto el : elements()) {\n if (el->is_interpolant()) return true;\n }\n return false;\n }\n virtual void rtrim();\n\n virtual size_t hash()\n {\n if (hash_ == 0) {\n for (auto string : elements())\n hash_combine(hash_, string->hash());\n }\n return hash_;\n }\n\n virtual void set_delayed(bool delayed) {\n is_delayed(delayed);\n }\n\n virtual bool operator==(const Expression& rhs) const;\n ATTACH_AST_OPERATIONS(String_Schema)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////////////\n // Flat strings -- the lowest level of raw textual data.\n ////////////////////////////////////////////////////////\n class String_Constant : public String {\n ADD_PROPERTY(char, quote_mark)\n ADD_PROPERTY(bool, can_compress_whitespace)\n HASH_CONSTREF(std::string, value)\n protected:\n size_t hash_;\n public:\n String_Constant(const String_Constant* ptr)\n : String(ptr),\n quote_mark_(ptr->quote_mark_),\n can_compress_whitespace_(ptr->can_compress_whitespace_),\n value_(ptr->value_),\n hash_(ptr->hash_)\n { }\n String_Constant(ParserState pstate, std::string val, bool css = true)\n : String(pstate), quote_mark_(0), can_compress_whitespace_(false), value_(read_css_string(val, css)), hash_(0)\n { }\n String_Constant(ParserState pstate, const char* beg, bool css = true)\n : String(pstate), quote_mark_(0), can_compress_whitespace_(false), value_(read_css_string(std::string(beg), css)), hash_(0)\n { }\n String_Constant(ParserState pstate, const char* beg, const char* end, bool css = true)\n : String(pstate), quote_mark_(0), can_compress_whitespace_(false), value_(read_css_string(std::string(beg, end-beg), css)), hash_(0)\n { }\n String_Constant(ParserState pstate, const Token& tok, bool css = true)\n : String(pstate), quote_mark_(0), can_compress_whitespace_(false), value_(read_css_string(std::string(tok.begin, tok.end), css)), hash_(0)\n { }\n std::string type() const { return \"string\"; }\n static std::string type_name() { return \"string\"; }\n virtual bool is_invisible() const;\n virtual void rtrim();\n\n virtual size_t hash()\n {\n if (hash_ == 0) {\n hash_ = std::hash()(value_);\n }\n return hash_;\n }\n\n virtual bool operator==(const Expression& rhs) const;\n virtual std::string inspect() const; // quotes are 
forced on inspection\n\n // static char auto_quote() { return '*'; }\n static char double_quote() { return '\"'; }\n static char single_quote() { return '\\''; }\n\n ATTACH_AST_OPERATIONS(String_Constant)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////////////\n // Possibly quoted string (unquote on instantiation)\n ////////////////////////////////////////////////////////\n class String_Quoted : public String_Constant {\n public:\n String_Quoted(ParserState pstate, std::string val, char q = 0,\n bool keep_utf8_escapes = false, bool skip_unquoting = false,\n bool strict_unquoting = true, bool css = true)\n : String_Constant(pstate, val, css)\n {\n if (skip_unquoting == false) {\n value_ = unquote(value_, "e_mark_, keep_utf8_escapes, strict_unquoting);\n }\n if (q && quote_mark_) quote_mark_ = q;\n }\n String_Quoted(const String_Quoted* ptr)\n : String_Constant(ptr)\n { }\n virtual bool operator==(const Expression& rhs) const;\n virtual std::string inspect() const; // quotes are forced on inspection\n ATTACH_AST_OPERATIONS(String_Quoted)\n ATTACH_OPERATIONS()\n };\n\n /////////////////\n // Media queries.\n /////////////////\n class Media_Query : public Expression,\n public Vectorized {\n ADD_PROPERTY(String_Obj, media_type)\n ADD_PROPERTY(bool, is_negated)\n ADD_PROPERTY(bool, is_restricted)\n public:\n Media_Query(ParserState pstate,\n String_Obj t = 0, size_t s = 0, bool n = false, bool r = false)\n : Expression(pstate), Vectorized(s),\n media_type_(t), is_negated_(n), is_restricted_(r)\n { }\n Media_Query(const Media_Query* ptr)\n : Expression(ptr),\n Vectorized(*ptr),\n media_type_(ptr->media_type_),\n is_negated_(ptr->is_negated_),\n is_restricted_(ptr->is_restricted_)\n { }\n ATTACH_AST_OPERATIONS(Media_Query)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////////\n // Media expressions (for use inside media queries).\n ////////////////////////////////////////////////////\n class Media_Query_Expression : public Expression {\n ADD_PROPERTY(Expression_Obj, feature)\n ADD_PROPERTY(Expression_Obj, value)\n ADD_PROPERTY(bool, is_interpolated)\n public:\n Media_Query_Expression(ParserState pstate,\n Expression_Obj f, Expression_Obj v, bool i = false)\n : Expression(pstate), feature_(f), value_(v), is_interpolated_(i)\n { }\n Media_Query_Expression(const Media_Query_Expression* ptr)\n : Expression(ptr),\n feature_(ptr->feature_),\n value_(ptr->value_),\n is_interpolated_(ptr->is_interpolated_)\n { }\n ATTACH_AST_OPERATIONS(Media_Query_Expression)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////\n // `@supports` rule.\n ////////////////////\n class Supports_Block : public Has_Block {\n ADD_PROPERTY(Supports_Condition_Obj, condition)\n public:\n Supports_Block(ParserState pstate, Supports_Condition_Obj condition, Block_Obj block = 0)\n : Has_Block(pstate, block), condition_(condition)\n { statement_type(SUPPORTS); }\n Supports_Block(const Supports_Block* ptr)\n : Has_Block(ptr), condition_(ptr->condition_)\n { statement_type(SUPPORTS); }\n bool bubbles() { return true; }\n ATTACH_AST_OPERATIONS(Supports_Block)\n ATTACH_OPERATIONS()\n };\n\n //////////////////////////////////////////////////////\n // The abstract superclass of all Supports conditions.\n //////////////////////////////////////////////////////\n class Supports_Condition : public Expression {\n public:\n Supports_Condition(ParserState pstate)\n : Expression(pstate)\n { }\n Supports_Condition(const Supports_Condition* ptr)\n : Expression(ptr)\n { }\n virtual bool 
needs_parens(Supports_Condition_Obj cond) const { return false; }\n ATTACH_AST_OPERATIONS(Supports_Condition)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////////////////\n // An operator condition (e.g. `CONDITION1 and CONDITION2`).\n ////////////////////////////////////////////////////////////\n class Supports_Operator : public Supports_Condition {\n public:\n enum Operand { AND, OR };\n private:\n ADD_PROPERTY(Supports_Condition_Obj, left);\n ADD_PROPERTY(Supports_Condition_Obj, right);\n ADD_PROPERTY(Operand, operand);\n public:\n Supports_Operator(ParserState pstate, Supports_Condition_Obj l, Supports_Condition_Obj r, Operand o)\n : Supports_Condition(pstate), left_(l), right_(r), operand_(o)\n { }\n Supports_Operator(const Supports_Operator* ptr)\n : Supports_Condition(ptr),\n left_(ptr->left_),\n right_(ptr->right_),\n operand_(ptr->operand_)\n { }\n virtual bool needs_parens(Supports_Condition_Obj cond) const;\n ATTACH_AST_OPERATIONS(Supports_Operator)\n ATTACH_OPERATIONS()\n };\n\n //////////////////////////////////////////\n // A negation condition (`not CONDITION`).\n //////////////////////////////////////////\n class Supports_Negation : public Supports_Condition {\n private:\n ADD_PROPERTY(Supports_Condition_Obj, condition);\n public:\n Supports_Negation(ParserState pstate, Supports_Condition_Obj c)\n : Supports_Condition(pstate), condition_(c)\n { }\n Supports_Negation(const Supports_Negation* ptr)\n : Supports_Condition(ptr), condition_(ptr->condition_)\n { }\n virtual bool needs_parens(Supports_Condition_Obj cond) const;\n ATTACH_AST_OPERATIONS(Supports_Negation)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////////////////\n // A declaration condition (e.g. `(feature: value)`).\n /////////////////////////////////////////////////////\n class Supports_Declaration : public Supports_Condition {\n private:\n ADD_PROPERTY(Expression_Obj, feature);\n ADD_PROPERTY(Expression_Obj, value);\n public:\n Supports_Declaration(ParserState pstate, Expression_Obj f, Expression_Obj v)\n : Supports_Condition(pstate), feature_(f), value_(v)\n { }\n Supports_Declaration(const Supports_Declaration* ptr)\n : Supports_Condition(ptr),\n feature_(ptr->feature_),\n value_(ptr->value_)\n { }\n virtual bool needs_parens(Supports_Condition_Obj cond) const { return false; }\n ATTACH_AST_OPERATIONS(Supports_Declaration)\n ATTACH_OPERATIONS()\n };\n\n ///////////////////////////////////////////////\n // An interpolation condition (e.g. 
`#{$var}`).\n ///////////////////////////////////////////////\n class Supports_Interpolation : public Supports_Condition {\n private:\n ADD_PROPERTY(Expression_Obj, value);\n public:\n Supports_Interpolation(ParserState pstate, Expression_Obj v)\n : Supports_Condition(pstate), value_(v)\n { }\n Supports_Interpolation(const Supports_Interpolation* ptr)\n : Supports_Condition(ptr),\n value_(ptr->value_)\n { }\n virtual bool needs_parens(Supports_Condition_Obj cond) const { return false; }\n ATTACH_AST_OPERATIONS(Supports_Interpolation)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////////////\n // At root expressions (for use inside @at-root).\n /////////////////////////////////////////////////\n class At_Root_Query : public Expression {\n private:\n ADD_PROPERTY(Expression_Obj, feature)\n ADD_PROPERTY(Expression_Obj, value)\n public:\n At_Root_Query(ParserState pstate, Expression_Obj f = 0, Expression_Obj v = 0, bool i = false)\n : Expression(pstate), feature_(f), value_(v)\n { }\n At_Root_Query(const At_Root_Query* ptr)\n : Expression(ptr),\n feature_(ptr->feature_),\n value_(ptr->value_)\n { }\n bool exclude(std::string str);\n ATTACH_AST_OPERATIONS(At_Root_Query)\n ATTACH_OPERATIONS()\n };\n\n ///////////\n // At-root.\n ///////////\n class At_Root_Block : public Has_Block {\n ADD_PROPERTY(At_Root_Query_Obj, expression)\n public:\n At_Root_Block(ParserState pstate, Block_Obj b = 0, At_Root_Query_Obj e = 0)\n : Has_Block(pstate, b), expression_(e)\n { statement_type(ATROOT); }\n At_Root_Block(const At_Root_Block* ptr)\n : Has_Block(ptr), expression_(ptr->expression_)\n { statement_type(ATROOT); }\n bool bubbles() { return true; }\n bool exclude_node(Statement_Obj s) {\n if (expression() == 0)\n {\n return s->statement_type() == Statement::RULESET;\n }\n\n if (s->statement_type() == Statement::DIRECTIVE)\n {\n if (Directive_Obj dir = Cast(s))\n {\n std::string keyword(dir->keyword());\n if (keyword.length() > 0) keyword.erase(0, 1);\n return expression()->exclude(keyword);\n }\n }\n if (s->statement_type() == Statement::MEDIA)\n {\n return expression()->exclude(\"media\");\n }\n if (s->statement_type() == Statement::RULESET)\n {\n return expression()->exclude(\"rule\");\n }\n if (s->statement_type() == Statement::SUPPORTS)\n {\n return expression()->exclude(\"supports\");\n }\n if (Directive_Obj dir = Cast(s))\n {\n if (dir->is_keyframes()) return expression()->exclude(\"keyframes\");\n }\n return false;\n }\n ATTACH_AST_OPERATIONS(At_Root_Block)\n ATTACH_OPERATIONS()\n };\n\n //////////////////\n // The null value.\n //////////////////\n class Null : public Value {\n public:\n Null(ParserState pstate) : Value(pstate) { concrete_type(NULL_VAL); }\n Null(const Null* ptr) : Value(ptr) { concrete_type(NULL_VAL); }\n std::string type() const { return \"null\"; }\n static std::string type_name() { return \"null\"; }\n bool is_invisible() const { return true; }\n operator bool() { return false; }\n bool is_false() { return true; }\n\n virtual size_t hash()\n {\n return -1;\n }\n\n virtual bool operator== (const Expression& rhs) const;\n\n ATTACH_AST_OPERATIONS(Null)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////\n // Thunks for delayed evaluation.\n /////////////////////////////////\n class Thunk : public Expression {\n ADD_PROPERTY(Expression_Obj, expression)\n ADD_PROPERTY(Env*, environment)\n public:\n \nThunk(ParserState pstate, Expression_Obj exp, Env* env = 0)\n : Expression(pstate), expression_(exp), environment_(env)\n { }\n };\n\n 
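The `Thunk` node defined just above only stores an expression together with the environment it should eventually be evaluated in; it performs no evaluation itself. A minimal usage sketch, assuming this header is included and that `pstate`, `expr`, and `env` are an existing `ParserState`, `Expression_Obj`, and `Env*` (these variable names are illustrative, not taken from the original source):

```cpp
// Capture an already-parsed expression plus its environment for later evaluation.
Thunk delayed(pstate, expr, env);            // env is optional and defaults to 0
Expression_Obj body = delayed.expression();  // getter generated by ADD_PROPERTY
Env* captured = delayed.environment();       // getter generated by ADD_PROPERTY
```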
/////////////////////////////////////////////////////////\n // Individual parameter objects for mixins and functions.\n /////////////////////////////////////////////////////////\n class Parameter : public AST_Node {\n ADD_CONSTREF(std::string, name)\n ADD_PROPERTY(Expression_Obj, default_value)\n ADD_PROPERTY(bool, is_rest_parameter)\n public:\n Parameter(ParserState pstate,\n std::string n, Expression_Obj def = 0, bool rest = false)\n : AST_Node(pstate), name_(n), default_value_(def), is_rest_parameter_(rest)\n {\n // tried to come up with a spec test for this, but it does no longer\n // get past the parser (it error out earlier). A spec test was added!\n // if (default_value_ && is_rest_parameter_) {\n // error(\"variable-length parameter may not have a default value\", pstate_);\n // }\n }\n Parameter(const Parameter* ptr)\n : AST_Node(ptr),\n name_(ptr->name_),\n default_value_(ptr->default_value_),\n is_rest_parameter_(ptr->is_rest_parameter_)\n {\n // tried to come up with a spec test for this, but it does no longer\n // get past the parser (it error out earlier). A spec test was added!\n // if (default_value_ && is_rest_parameter_) {\n // error(\"variable-length parameter may not have a default value\", pstate_);\n // }\n }\n ATTACH_AST_OPERATIONS(Parameter)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////////////////////////////////////\n // Parameter lists -- in their own class to facilitate context-sensitive\n // error checking (e.g., ensuring that all optional parameters follow all\n // required parameters).\n /////////////////////////////////////////////////////////////////////////\n class Parameters : public AST_Node, public Vectorized {\n ADD_PROPERTY(bool, has_optional_parameters)\n ADD_PROPERTY(bool, has_rest_parameter)\n protected:\n void adjust_after_pushing(Parameter_Obj p)\n {\n if (p->default_value()) {\n if (has_rest_parameter()) {\n coreError(\"optional parameters may not be combined with variable-length parameters\", p->pstate());\n }\n has_optional_parameters(true);\n }\n else if (p->is_rest_parameter()) {\n if (has_rest_parameter()) {\n coreError(\"functions and mixins cannot have more than one variable-length parameter\", p->pstate());\n }\n has_rest_parameter(true);\n }\n else {\n if (has_rest_parameter()) {\n coreError(\"required parameters must precede variable-length parameters\", p->pstate());\n }\n if (has_optional_parameters()) {\n coreError(\"required parameters must precede optional parameters\", p->pstate());\n }\n }\n }\n public:\n Parameters(ParserState pstate)\n : AST_Node(pstate),\n Vectorized(),\n has_optional_parameters_(false),\n has_rest_parameter_(false)\n { }\n Parameters(const Parameters* ptr)\n : AST_Node(ptr),\n Vectorized(*ptr),\n has_optional_parameters_(ptr->has_optional_parameters_),\n has_rest_parameter_(ptr->has_rest_parameter_)\n { }\n ATTACH_AST_OPERATIONS(Parameters)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////\n // Abstract base class for CSS selectors.\n /////////////////////////////////////////\n class Selector : public Expression {\n // ADD_PROPERTY(bool, has_reference)\n // line break before list separator\n ADD_PROPERTY(bool, has_line_feed)\n // line break after list separator\n ADD_PROPERTY(bool, has_line_break)\n // maybe we have optional flag\n ADD_PROPERTY(bool, is_optional)\n // parent block pointers\n\n // must not be a reference counted object\n // otherwise we create circular references\n ADD_PROPERTY(Media_Block_Ptr, media_block)\n protected:\n size_t hash_;\n 
public:\n Selector(ParserState pstate)\n : Expression(pstate),\n has_line_feed_(false),\n has_line_break_(false),\n is_optional_(false),\n media_block_(0),\n hash_(0)\n { concrete_type(SELECTOR); }\n Selector(const Selector* ptr)\n : Expression(ptr),\n // has_reference_(ptr->has_reference_),\n has_line_feed_(ptr->has_line_feed_),\n has_line_break_(ptr->has_line_break_),\n is_optional_(ptr->is_optional_),\n media_block_(ptr->media_block_),\n hash_(ptr->hash_)\n { concrete_type(SELECTOR); }\n virtual ~Selector() = 0;\n virtual size_t hash() = 0;\n virtual unsigned long specificity() const = 0;\n virtual void set_media_block(Media_Block_Ptr mb) {\n media_block(mb);\n }\n virtual bool has_parent_ref() const {\n return false;\n }\n virtual bool has_real_parent_ref() const {\n return false;\n }\n // dispatch to correct handlers\n virtual bool operator<(const Selector& rhs) const = 0;\n virtual bool operator==(const Selector& rhs) const = 0;\n ATTACH_VIRTUAL_AST_OPERATIONS(Selector);\n };\n inline Selector::~Selector() { }\n\n /////////////////////////////////////////////////////////////////////////\n // Interpolated selectors -- the interpolated String will be expanded and\n // re-parsed into a normal selector class.\n /////////////////////////////////////////////////////////////////////////\n class Selector_Schema : public AST_Node {\n ADD_PROPERTY(String_Obj, contents)\n ADD_PROPERTY(bool, connect_parent);\n // must not be a reference counted object\n // otherwise we create circular references\n ADD_PROPERTY(Media_Block_Ptr, media_block)\n // store computed hash\n size_t hash_;\n public:\n Selector_Schema(ParserState pstate, String_Obj c)\n : AST_Node(pstate),\n contents_(c),\n connect_parent_(true),\n media_block_(NULL),\n hash_(0)\n { }\n Selector_Schema(const Selector_Schema* ptr)\n : AST_Node(ptr),\n contents_(ptr->contents_),\n connect_parent_(ptr->connect_parent_),\n media_block_(ptr->media_block_),\n hash_(ptr->hash_)\n { }\n virtual bool has_parent_ref() const;\n virtual bool has_real_parent_ref() const;\n virtual bool operator<(const Selector& rhs) const;\n virtual bool operator==(const Selector& rhs) const;\n // selector schema is not yet a final selector, so we do not\n // have a specificity for it yet. 
We need to\n virtual unsigned long specificity() const { return 0; }\n virtual size_t hash() {\n if (hash_ == 0) {\n hash_combine(hash_, contents_->hash());\n }\n return hash_;\n }\n ATTACH_AST_OPERATIONS(Selector_Schema)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////\n // Abstract base class for simple selectors.\n ////////////////////////////////////////////\n class Simple_Selector : public Selector {\n ADD_CONSTREF(std::string, ns)\n ADD_CONSTREF(std::string, name)\n ADD_PROPERTY(Simple_Type, simple_type)\n ADD_PROPERTY(bool, has_ns)\n public:\n Simple_Selector(ParserState pstate, std::string n = \"\")\n : Selector(pstate), ns_(\"\"), name_(n), has_ns_(false)\n {\n simple_type(SIMPLE);\n size_t pos = n.find('|');\n // found some namespace\n if (pos != std::string::npos) {\n has_ns_ = true;\n ns_ = n.substr(0, pos);\n name_ = n.substr(pos + 1);\n }\n }\n Simple_Selector(const Simple_Selector* ptr)\n : Selector(ptr),\n ns_(ptr->ns_),\n name_(ptr->name_),\n has_ns_(ptr->has_ns_)\n { simple_type(SIMPLE); }\n virtual std::string ns_name() const\n {\n std::string name(\"\");\n if (has_ns_)\n name += ns_ + \"|\";\n return name + name_;\n }\n virtual size_t hash()\n {\n if (hash_ == 0) {\n hash_combine(hash_, std::hash()(SELECTOR));\n hash_combine(hash_, std::hash()(ns()));\n hash_combine(hash_, std::hash()(name()));\n }\n return hash_;\n }\n // namespace compare functions\n bool is_ns_eq(const Simple_Selector& r) const;\n // namespace query functions\n bool is_universal_ns() const\n {\n return has_ns_ && ns_ == \"*\";\n }\n bool has_universal_ns() const\n {\n return !has_ns_ || ns_ == \"*\";\n }\n bool is_empty_ns() const\n {\n return !has_ns_ || ns_ == \"\";\n }\n bool has_empty_ns() const\n {\n return has_ns_ && ns_ == \"\";\n }\n bool has_qualified_ns() const\n {\n return has_ns_ && ns_ != \"\" && ns_ != \"*\";\n }\n // name query functions\n bool is_universal() const\n {\n return name_ == \"*\";\n }\n\n virtual bool has_placeholder() {\n return false;\n }\n\n virtual ~Simple_Selector() = 0;\n virtual Compound_Selector_Ptr unify_with(Compound_Selector_Ptr);\n virtual bool has_parent_ref() const { return false; };\n virtual bool has_real_parent_ref() const { return false; };\n virtual bool is_pseudo_element() const { return false; }\n\n virtual bool is_superselector_of(Compound_Selector_Obj sub) { return false; }\n\n virtual bool operator==(const Selector& rhs) const;\n virtual bool operator==(const Simple_Selector& rhs) const;\n inline bool operator!=(const Simple_Selector& rhs) const { return !(*this == rhs); }\n\n bool operator<(const Selector& rhs) const;\n bool operator<(const Simple_Selector& rhs) const;\n // default implementation should work for most of the simple selectors (otherwise overload)\n ATTACH_VIRTUAL_AST_OPERATIONS(Simple_Selector);\n ATTACH_OPERATIONS();\n };\n inline Simple_Selector::~Simple_Selector() { }\n\n\n //////////////////////////////////\n // The Parent Selector Expression.\n //////////////////////////////////\n // parent selectors can occur in selectors but also\n // inside strings in declarations (Compound_Selector).\n // only one simple parent selector means the first case.\n class Parent_Selector : public Simple_Selector {\n ADD_PROPERTY(bool, real)\n public:\n Parent_Selector(ParserState pstate, bool r = true)\n : Simple_Selector(pstate, \"&\"), real_(r)\n { /* has_reference(true); */ }\n Parent_Selector(const Parent_Selector* ptr)\n : Simple_Selector(ptr), real_(ptr->real_)\n { /* has_reference(true); */ }\n bool 
is_real_parent_ref() const { return real(); };\n virtual bool has_parent_ref() const { return true; };\n virtual bool has_real_parent_ref() const { return is_real_parent_ref(); };\n virtual unsigned long specificity() const\n {\n return 0;\n }\n std::string type() const { return \"selector\"; }\n static std::string type_name() { return \"selector\"; }\n ATTACH_AST_OPERATIONS(Parent_Selector)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////////////////////////////////////\n // Placeholder selectors (e.g., \"%foo\") for use in extend-only selectors.\n /////////////////////////////////////////////////////////////////////////\n class Placeholder_Selector : public Simple_Selector {\n public:\n Placeholder_Selector(ParserState pstate, std::string n)\n : Simple_Selector(pstate, n)\n { }\n Placeholder_Selector(const Placeholder_Selector* ptr)\n : Simple_Selector(ptr)\n { }\n virtual unsigned long specificity() const\n {\n return Constants::Specificity_Base;\n }\n virtual bool has_placeholder() {\n return true;\n }\n virtual ~Placeholder_Selector() {};\n ATTACH_AST_OPERATIONS(Placeholder_Selector)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////////////////////////////////\n // Element selectors (and the universal selector) -- e.g., div, span, *.\n /////////////////////////////////////////////////////////////////////\n class Element_Selector : public Simple_Selector {\n public:\n Element_Selector(ParserState pstate, std::string n)\n : Simple_Selector(pstate, n)\n { }\n Element_Selector(const Element_Selector* ptr)\n : Simple_Selector(ptr)\n { }\n virtual unsigned long specificity() const\n {\n if (name() == \"*\") return 0;\n else return Constants::Specificity_Element;\n }\n virtual Simple_Selector_Ptr unify_with(Simple_Selector_Ptr);\n virtual Compound_Selector_Ptr unify_with(Compound_Selector_Ptr);\n virtual bool operator==(const Simple_Selector& rhs) const;\n virtual bool operator==(const Element_Selector& rhs) const;\n virtual bool operator<(const Simple_Selector& rhs) const;\n virtual bool operator<(const Element_Selector& rhs) const;\n ATTACH_AST_OPERATIONS(Element_Selector)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////\n // Class selectors -- i.e., .foo.\n ////////////////////////////////////////////////\n class Class_Selector : public Simple_Selector {\n public:\n Class_Selector(ParserState pstate, std::string n)\n : Simple_Selector(pstate, n)\n { }\n Class_Selector(const Class_Selector* ptr)\n : Simple_Selector(ptr)\n { }\n virtual unsigned long specificity() const\n {\n return Constants::Specificity_Class;\n }\n virtual Compound_Selector_Ptr unify_with(Compound_Selector_Ptr);\n ATTACH_AST_OPERATIONS(Class_Selector)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////\n // ID selectors -- i.e., #foo.\n ////////////////////////////////////////////////\n class Id_Selector : public Simple_Selector {\n public:\n Id_Selector(ParserState pstate, std::string n)\n : Simple_Selector(pstate, n)\n { }\n Id_Selector(const Id_Selector* ptr)\n : Simple_Selector(ptr)\n { }\n virtual unsigned long specificity() const\n {\n return Constants::Specificity_ID;\n }\n virtual Compound_Selector_Ptr unify_with(Compound_Selector_Ptr);\n ATTACH_AST_OPERATIONS(Id_Selector)\n ATTACH_OPERATIONS()\n };\n\n ///////////////////////////////////////////////////\n // Attribute selectors -- e.g., [src*=\".jpg\"], etc.\n ///////////////////////////////////////////////////\n class Attribute_Selector : public 
Simple_Selector {\n ADD_CONSTREF(std::string, matcher)\n // this cannot be changed to obj atm!!!!!!????!!!!!!!\n ADD_PROPERTY(String_Obj, value) // might be interpolated\n ADD_PROPERTY(char, modifier);\n public:\n Attribute_Selector(ParserState pstate, std::string n, std::string m, String_Obj v, char o = 0)\n : Simple_Selector(pstate, n), matcher_(m), value_(v), modifier_(o)\n { simple_type(ATTR_SEL); }\n Attribute_Selector(const Attribute_Selector* ptr)\n : Simple_Selector(ptr),\n matcher_(ptr->matcher_),\n value_(ptr->value_),\n modifier_(ptr->modifier_)\n { simple_type(ATTR_SEL); }\n virtual size_t hash()\n {\n if (hash_ == 0) {\n hash_combine(hash_, Simple_Selector::hash());\n hash_combine(hash_, std::hash()(matcher()));\n if (value_) hash_combine(hash_, value_->hash());\n }\n return hash_;\n }\n virtual unsigned long specificity() const\n {\n return Constants::Specificity_Attr;\n }\n virtual bool operator==(const Simple_Selector& rhs) const;\n virtual bool operator==(const Attribute_Selector& rhs) const;\n virtual bool operator<(const Simple_Selector& rhs) const;\n virtual bool operator<(const Attribute_Selector& rhs) const;\n ATTACH_AST_OPERATIONS(Attribute_Selector)\n ATTACH_OPERATIONS()\n };\n\n //////////////////////////////////////////////////////////////////\n // Pseudo selectors -- e.g., :first-child, :nth-of-type(...), etc.\n //////////////////////////////////////////////////////////////////\n /* '::' starts a pseudo-element, ':' a pseudo-class */\n /* Except :first-line, :first-letter, :before and :after */\n /* Note that pseudo-elements are restricted to one per selector */\n /* and occur only in the last simple_selector_sequence. */\n inline bool is_pseudo_class_element(const std::string& name)\n {\n return name == \":before\" ||\n name == \":after\" ||\n name == \":first-line\" ||\n name == \":first-letter\";\n }\n\n // Pseudo Selector cannot have any namespace?\n class Pseudo_Selector : public Simple_Selector {\n ADD_PROPERTY(String_Obj, expression)\n public:\n Pseudo_Selector(ParserState pstate, std::string n, String_Obj expr = 0)\n : Simple_Selector(pstate, n), expression_(expr)\n { simple_type(PSEUDO_SEL); }\n Pseudo_Selector(const Pseudo_Selector* ptr)\n : Simple_Selector(ptr), expression_(ptr->expression_)\n { simple_type(PSEUDO_SEL); }\n\n // A pseudo-element is made of two colons (::) followed by the name.\n // The `::` notation is introduced by the current document in order to\n // establish a discrimination between pseudo-classes and pseudo-elements.\n // For compatibility with existing style sheets, user agents must also\n // accept the previous one-colon notation for pseudo-elements introduced\n // in CSS levels 1 and 2 (namely, :first-line, :first-letter, :before and\n // :after). 
This compatibility is not allowed for the new pseudo-elements\n // introduced in this specification.\n virtual bool is_pseudo_element() const\n {\n return (name_[0] == ':' && name_[1] == ':')\n || is_pseudo_class_element(name_);\n }\n virtual size_t hash()\n {\n if (hash_ == 0) {\n hash_combine(hash_, Simple_Selector::hash());\n if (expression_) hash_combine(hash_, expression_->hash());\n }\n return hash_;\n }\n virtual unsigned long specificity() const\n {\n if (is_pseudo_element())\n return Constants::Specificity_Element;\n return Constants::Specificity_Pseudo;\n }\n virtual bool operator==(const Simple_Selector& rhs) const;\n virtual bool operator==(const Pseudo_Selector& rhs) const;\n virtual bool operator<(const Simple_Selector& rhs) const;\n virtual bool operator<(const Pseudo_Selector& rhs) const;\n virtual Compound_Selector_Ptr unify_with(Compound_Selector_Ptr);\n ATTACH_AST_OPERATIONS(Pseudo_Selector)\n ATTACH_OPERATIONS()\n };\n\n /////////////////////////////////////////////////\n // Wrapped selector -- pseudo selector that takes a list of selectors as argument(s) e.g., :not(:first-of-type), :-moz-any(ol p.blah, ul, menu, dir)\n /////////////////////////////////////////////////\n class Wrapped_Selector : public Simple_Selector {\n ADD_PROPERTY(Selector_List_Obj, selector)\n public:\n Wrapped_Selector(ParserState pstate, std::string n, Selector_List_Obj sel)\n : Simple_Selector(pstate, n), selector_(sel)\n { simple_type(WRAPPED_SEL); }\n Wrapped_Selector(const Wrapped_Selector* ptr)\n : Simple_Selector(ptr), selector_(ptr->selector_)\n { simple_type(WRAPPED_SEL); }\n virtual bool is_superselector_of(Wrapped_Selector_Obj sub);\n // Selectors inside the negation pseudo-class are counted like any\n // other, but the negation itself does not count as a pseudo-class.\n virtual size_t hash();\n virtual bool has_parent_ref() const;\n virtual bool has_real_parent_ref() const;\n virtual unsigned long specificity() const;\n virtual bool find ( bool (*f)(AST_Node_Obj) );\n virtual bool operator==(const Simple_Selector& rhs) const;\n virtual bool operator==(const Wrapped_Selector& rhs) const;\n virtual bool operator<(const Simple_Selector& rhs) const;\n virtual bool operator<(const Wrapped_Selector& rhs) const;\n virtual void cloneChildren();\n ATTACH_AST_OPERATIONS(Wrapped_Selector)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////////////////////////////////\n // Simple selector sequences. 
Maintains flags indicating whether it contains\n // any parent references or placeholders, to simplify expansion.\n ////////////////////////////////////////////////////////////////////////////\n class Compound_Selector : public Selector, public Vectorized {\n private:\n ComplexSelectorSet sources_;\n ADD_PROPERTY(bool, extended);\n ADD_PROPERTY(bool, has_parent_reference);\n protected:\n void adjust_after_pushing(Simple_Selector_Obj s)\n {\n // if (s->has_reference()) has_reference(true);\n // if (s->has_placeholder()) has_placeholder(true);\n }\n public:\n Compound_Selector(ParserState pstate, size_t s = 0)\n : Selector(pstate),\n Vectorized(s),\n extended_(false),\n has_parent_reference_(false)\n { }\n Compound_Selector(const Compound_Selector* ptr)\n : Selector(ptr),\n Vectorized(*ptr),\n extended_(ptr->extended_),\n has_parent_reference_(ptr->has_parent_reference_)\n { }\n bool contains_placeholder() {\n for (size_t i = 0, L = length(); i < L; ++i) {\n if ((*this)[i]->has_placeholder()) return true;\n }\n return false;\n };\n\n void append(Simple_Selector_Ptr element);\n\n bool is_universal() const\n {\n return length() == 1 && (*this)[0]->is_universal();\n }\n\n Complex_Selector_Obj to_complex();\n Compound_Selector_Ptr unify_with(Compound_Selector_Ptr rhs);\n // virtual Placeholder_Selector_Ptr find_placeholder();\n virtual bool has_parent_ref() const;\n virtual bool has_real_parent_ref() const;\n Simple_Selector_Ptr base() const {\n if (length() == 0) return 0;\n // ToDo: why is this needed?\n if (Cast((*this)[0]))\n return (*this)[0];\n return 0;\n }\n virtual bool is_superselector_of(Compound_Selector_Obj sub, std::string wrapped = \"\");\n virtual bool is_superselector_of(Complex_Selector_Obj sub, std::string wrapped = \"\");\n virtual bool is_superselector_of(Selector_List_Obj sub, std::string wrapped = \"\");\n virtual size_t hash()\n {\n if (Selector::hash_ == 0) {\n hash_combine(Selector::hash_, std::hash()(SELECTOR));\n if (length()) hash_combine(Selector::hash_, Vectorized::hash());\n }\n return Selector::hash_;\n }\n virtual unsigned long specificity() const\n {\n int sum = 0;\n for (size_t i = 0, L = length(); i < L; ++i)\n { sum += (*this)[i]->specificity(); }\n return sum;\n }\n\n virtual bool has_placeholder()\n {\n if (length() == 0) return false;\n if (Simple_Selector_Obj ss = elements().front()) {\n if (ss->has_placeholder()) return true;\n }\n return false;\n }\n\n bool is_empty_reference()\n {\n return length() == 1 &&\n Cast((*this)[0]);\n }\n\n virtual bool find ( bool (*f)(AST_Node_Obj) );\n virtual bool operator<(const Selector& rhs) const;\n virtual bool operator==(const Selector& rhs) const;\n virtual bool operator<(const Compound_Selector& rhs) const;\n virtual bool operator==(const Compound_Selector& rhs) const;\n inline bool operator!=(const Compound_Selector& rhs) const { return !(*this == rhs); }\n\n ComplexSelectorSet& sources() { return sources_; }\n void clearSources() { sources_.clear(); }\n void mergeSources(ComplexSelectorSet& sources);\n\n Compound_Selector_Ptr minus(Compound_Selector_Ptr rhs);\n virtual void cloneChildren();\n ATTACH_AST_OPERATIONS(Compound_Selector)\n ATTACH_OPERATIONS()\n };\n\n ////////////////////////////////////////////////////////////////////////////\n // General selectors -- i.e., simple sequences combined with one of the four\n // CSS selector combinators (\">\", \"+\", \"~\", and whitespace). 
Essentially a\n // linked list.\n ////////////////////////////////////////////////////////////////////////////\n class Complex_Selector : public Selector {\n public:\n enum Combinator { ANCESTOR_OF, PARENT_OF, PRECEDES, ADJACENT_TO, REFERENCE };\n private:\n HASH_CONSTREF(Combinator, combinator)\n HASH_PROPERTY(Compound_Selector_Obj, head)\n HASH_PROPERTY(Complex_Selector_Obj, tail)\n HASH_PROPERTY(String_Obj, reference);\n public:\n bool contains_placeholder() {\n if (head() && head()->contains_placeholder()) return true;\n if (tail() && tail()->contains_placeholder()) return true;\n return false;\n };\n Complex_Selector(ParserState pstate,\n Combinator c = ANCESTOR_OF,\n Compound_Selector_Obj h = 0,\n Complex_Selector_Obj t = 0,\n String_Obj r = 0)\n : Selector(pstate),\n combinator_(c),\n head_(h), tail_(t),\n reference_(r)\n {}\n Complex_Selector(const Complex_Selector* ptr)\n : Selector(ptr),\n combinator_(ptr->combinator_),\n head_(ptr->head_), tail_(ptr->tail_),\n reference_(ptr->reference_)\n {};\n virtual bool has_parent_ref() const;\n virtual bool has_real_parent_ref() const;\n\n Complex_Selector_Obj skip_empty_reference()\n {\n if ((!head_ || !head_->length() || head_->is_empty_reference()) &&\n combinator() == Combinator::ANCESTOR_OF)\n {\n if (!tail_) return 0;\n tail_->has_line_feed_ = this->has_line_feed_;\n // tail_->has_line_break_ = this->has_line_break_;\n return tail_->skip_empty_reference();\n }\n return this;\n }\n\n // can still have a tail\n bool is_empty_ancestor() const\n {\n return (!head() || head()->length() == 0) &&\n combinator() == Combinator::ANCESTOR_OF;\n }\n\n Selector_List_Ptr tails(Selector_List_Ptr tails);\n\n // front returns the first real tail\n // skips over parent and empty ones\n Complex_Selector_Obj first();\n // last returns the last real tail\n Complex_Selector_Obj last();\n\n // some shortcuts that should be removed\n Complex_Selector_Obj innermost() { return last(); };\n\n size_t length() const;\n Selector_List_Ptr resolve_parent_refs(std::vector& pstack, Backtraces& traces, bool implicit_parent = true);\n virtual bool is_superselector_of(Compound_Selector_Obj sub, std::string wrapping = \"\");\n virtual bool is_superselector_of(Complex_Selector_Obj sub, std::string wrapping = \"\");\n virtual bool is_superselector_of(Selector_List_Obj sub, std::string wrapping = \"\");\n Selector_List_Ptr unify_with(Complex_Selector_Ptr rhs);\n Combinator clear_innermost();\n void append(Complex_Selector_Obj, Backtraces& traces);\n void set_innermost(Complex_Selector_Obj, Combinator);\n virtual size_t hash()\n {\n if (hash_ == 0) {\n hash_combine(hash_, std::hash()(SELECTOR));\n hash_combine(hash_, std::hash()(combinator_));\n if (head_) hash_combine(hash_, head_->hash());\n if (tail_) hash_combine(hash_, tail_->hash());\n }\n return hash_;\n }\n virtual unsigned long specificity() const\n {\n int sum = 0;\n if (head()) sum += head()->specificity();\n if (tail()) sum += tail()->specificity();\n return sum;\n }\n virtual void set_media_block(Media_Block_Ptr mb) {\n media_block(mb);\n if (tail_) tail_->set_media_block(mb);\n if (head_) head_->set_media_block(mb);\n }\n virtual bool has_placeholder() {\n if (head_ && head_->has_placeholder()) return true;\n if (tail_ && tail_->has_placeholder()) return true;\n return false;\n }\n virtual bool find ( bool (*f)(AST_Node_Obj) );\n virtual bool operator<(const Selector& rhs) const;\n virtual bool operator==(const Selector& rhs) const;\n virtual bool operator<(const Complex_Selector& rhs) const;\n virtual bool 
operator==(const Complex_Selector& rhs) const;\n inline bool operator!=(const Complex_Selector& rhs) const { return !(*this == rhs); }\n const ComplexSelectorSet sources()\n {\n //s = Set.new\n //seq.map {|sseq_or_op| s.merge sseq_or_op.sources if sseq_or_op.is_a?(SimpleSequence)}\n //s\n\n ComplexSelectorSet srcs;\n\n Compound_Selector_Obj pHead = head();\n Complex_Selector_Obj pTail = tail();\n\n if (pHead) {\n const ComplexSelectorSet& headSources = pHead->sources();\n srcs.insert(headSources.begin(), headSources.end());\n }\n\n if (pTail) {\n const ComplexSelectorSet& tailSources = pTail->sources();\n srcs.insert(tailSources.begin(), tailSources.end());\n }\n\n return srcs;\n }\n void addSources(ComplexSelectorSet& sources) {\n // members.map! {|m| m.is_a?(SimpleSequence) ? m.with_more_sources(sources) : m}\n Complex_Selector_Ptr pIter = this;\n while (pIter) {\n Compound_Selector_Ptr pHead = pIter->head();\n\n if (pHead) {\n pHead->mergeSources(sources);\n }\n\n pIter = pIter->tail();\n }\n }\n void clearSources() {\n Complex_Selector_Ptr pIter = this;\n while (pIter) {\n Compound_Selector_Ptr pHead = pIter->head();\n\n if (pHead) {\n pHead->clearSources();\n }\n\n pIter = pIter->tail();\n }\n }\n\n virtual void cloneChildren();\n ATTACH_AST_OPERATIONS(Complex_Selector)\n ATTACH_OPERATIONS()\n };\n\n ///////////////////////////////////\n // Comma-separated selector groups.\n ///////////////////////////////////\n class Selector_List : public Selector, public Vectorized {\n ADD_PROPERTY(Selector_Schema_Obj, schema)\n ADD_CONSTREF(std::vector, wspace)\n protected:\n void adjust_after_pushing(Complex_Selector_Obj c);\n public:\n Selector_List(ParserState pstate, size_t s = 0)\n : Selector(pstate),\n Vectorized(s),\n schema_(NULL),\n wspace_(0)\n { }\n Selector_List(const Selector_List* ptr)\n : Selector(ptr),\n Vectorized(*ptr),\n schema_(ptr->schema_),\n wspace_(ptr->wspace_)\n { }\n std::string type() const { return \"list\"; }\n // remove parent selector references\n // basically unwraps parsed selectors\n virtual bool has_parent_ref() const;\n virtual bool has_real_parent_ref() const;\n void remove_parent_selectors();\n Selector_List_Ptr resolve_parent_refs(std::vector& pstack, Backtraces& traces, bool implicit_parent = true);\n virtual bool is_superselector_of(Compound_Selector_Obj sub, std::string wrapping = \"\");\n virtual bool is_superselector_of(Complex_Selector_Obj sub, std::string wrapping = \"\");\n virtual bool is_superselector_of(Selector_List_Obj sub, std::string wrapping = \"\");\n Selector_List_Ptr unify_with(Selector_List_Ptr);\n void populate_extends(Selector_List_Obj, Subset_Map&);\n Selector_List_Obj eval(Eval& eval);\n virtual size_t hash()\n {\n if (Selector::hash_ == 0) {\n hash_combine(Selector::hash_, std::hash()(SELECTOR));\n hash_combine(Selector::hash_, Vectorized::hash());\n }\n return Selector::hash_;\n }\n virtual unsigned long specificity() const\n {\n unsigned long sum = 0;\n unsigned long specificity;\n for (size_t i = 0, L = length(); i < L; ++i)\n {\n specificity = (*this)[i]->specificity();\n if (sum < specificity) sum = specificity;\n }\n return sum;\n }\n virtual void set_media_block(Media_Block_Ptr mb) {\n media_block(mb);\n for (Complex_Selector_Obj cs : elements()) {\n cs->set_media_block(mb);\n }\n }\n virtual bool has_placeholder() {\n for (Complex_Selector_Obj cs : elements()) {\n if (cs->has_placeholder()) return true;\n }\n return false;\n }\n virtual bool find ( bool (*f)(AST_Node_Obj) );\n virtual bool operator<(const Selector& rhs) 
const;\n virtual bool operator==(const Selector& rhs) const;\n virtual bool operator<(const Selector_List& rhs) const;\n virtual bool operator==(const Selector_List& rhs) const;\n // Selector Lists can be compared to comma lists\n virtual bool operator==(const Expression& rhs) const;\n virtual void cloneChildren();\n ATTACH_AST_OPERATIONS(Selector_List)\n ATTACH_OPERATIONS()\n };\n\n // compare function for sorting and probably other other uses\n struct cmp_complex_selector { inline bool operator() (const Complex_Selector_Obj l, const Complex_Selector_Obj r) { return (*l < *r); } };\n struct cmp_compound_selector { inline bool operator() (const Compound_Selector_Obj l, const Compound_Selector_Obj r) { return (*l < *r); } };\n struct cmp_simple_selector { inline bool operator() (const Simple_Selector_Obj l, const Simple_Selector_Obj r) { return (*l < *r); } };\n\n}\n\n#ifdef __clang__\n\n#pragma clang diagnostic pop\n\n#endif\n\n#endif\n\n// Path: src/libsass/src/ast_fwd_decl.cpp\n#include \"ast.hpp\"\n\nnamespace Sass {\n\n #define IMPLEMENT_BASE_CAST(T) \\\n template<> \\\n T* Cast(AST_Node* ptr) { \\\n return dynamic_cast(ptr); \\\n }; \\\n \\\n template<> \\\n const T* Cast(const AST_Node* ptr) { \\\n return dynamic_cast(ptr); \\\n }; \\\n\n IMPLEMENT_BASE_CAST(AST_Node)\n IMPLEMENT_BASE_CAST(Expression)\n IMPLEMENT_BASE_CAST(Statement)\n IMPLEMENT_BASE_CAST(Has_Block)\n IMPLEMENT_BASE_CAST(PreValue)\n IMPLEMENT_BASE_CAST(Value)\n IMPLEMENT_BASE_CAST(List)\n IMPLEMENT_BASE_CAST(String)\n IMPLEMENT_BASE_CAST(String_Constant)\n IMPLEMENT_BASE_CAST(Supports_Condition)\n IMPLEMENT_BASE_CAST(Selector)\n IMPLEMENT_BASE_CAST(Simple_Selector)\n\n}\n\n// Path: src/libsass/src/backtrace.cpp\n#include \"backtrace.hpp\"\n\nnamespace Sass {\n\n const std::string traces_to_string(Backtraces traces, std::string indent) {\n\n std::stringstream ss;\n std::string cwd(File::get_cwd());\n\n bool first = true;\n size_t i_beg = traces.size() - 1;\n size_t i_end = std::string::npos;\n for (size_t i = i_beg; i != i_end; i --) {\n\n const Backtrace& trace = traces[i];\n\n // make path relative to the current directory\n std::string rel_path(File::abs2rel(trace.pstate.path, cwd, cwd));\n\n // skip functions on error cases (unsure why ruby sass does this)\n // if (trace.caller.substr(0, 6) == \", in f\") continue;\n\n if (first) {\n ss << indent;\n ss << \"on line \";\n ss << trace.pstate.line + 1;\n ss << \" of \" << rel_path;\n // ss << trace.caller;\n first = false;\n } else {\n ss << trace.caller;\n ss << std::endl;\n ss << indent;\n ss << \"from line \";\n ss << trace.pstate.line + 1;\n ss << \" of \" << rel_path;\n }\n\n }\n\n ss << std::endl;\n return ss.str();\n\n }\n\n};\n\n// Path: src/libsass/src/environment.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"environment.hpp\"\n\nnamespace Sass {\n\n template \n Environment::Environment(bool is_shadow)\n : local_frame_(environment_map()),\n parent_(0), is_shadow_(false)\n { }\n template \n Environment::Environment(Environment* env, bool is_shadow)\n : local_frame_(environment_map()),\n parent_(env), is_shadow_(is_shadow)\n { }\n template \n Environment::Environment(Environment& env, bool is_shadow)\n : local_frame_(environment_map()),\n parent_(&env), is_shadow_(is_shadow)\n { }\n\n // link parent to create a stack\n template \n void Environment::link(Environment& env) { parent_ = &env; }\n template \n void Environment::link(Environment* env) { parent_ = env; }\n\n // this is used to find the global frame\n // which is the second last on the 
stack\n template \n bool Environment::is_lexical() const\n {\n return !! parent_ && parent_->parent_;\n }\n\n // only match the real root scope\n // there is still a parent around\n // not sure what it is actually use for\n // I guess we store functions etc. there\n template \n bool Environment::is_global() const\n {\n return parent_ && ! parent_->parent_;\n }\n\n template \n environment_map& Environment::local_frame() {\n return local_frame_;\n }\n\n template \n bool Environment::has_local(const std::string& key) const\n { return local_frame_.find(key) != local_frame_.end(); }\n\n template EnvResult\n Environment::find_local(const std::string& key)\n {\n auto end = local_frame_.end();\n auto it = local_frame_.find(key);\n return EnvResult(it, it != end);\n }\n\n template \n T& Environment::get_local(const std::string& key)\n { return local_frame_[key]; }\n\n template \n void Environment::set_local(const std::string& key, const T& val)\n {\n local_frame_[key] = val;\n }\n template \n void Environment::set_local(const std::string& key, T&& val)\n {\n local_frame_[key] = val;\n }\n\n template \n void Environment::del_local(const std::string& key)\n { local_frame_.erase(key); }\n\n template \n Environment* Environment::global_env()\n {\n Environment* cur = this;\n while (cur->is_lexical()) {\n cur = cur->parent_;\n }\n return cur;\n }\n\n template \n bool Environment::has_global(const std::string& key)\n { return global_env()->has(key); }\n\n template \n T& Environment::get_global(const std::string& key)\n { return (*global_env())[key]; }\n\n template \n void Environment::set_global(const std::string& key, const T& val)\n {\n global_env()->local_frame_[key] = val;\n }\n template \n void Environment::set_global(const std::string& key, T&& val)\n {\n global_env()->local_frame_[key] = val;\n }\n\n template \n void Environment::del_global(const std::string& key)\n { global_env()->local_frame_.erase(key); }\n\n template \n Environment* Environment::lexical_env(const std::string& key)\n {\n Environment* cur = this;\n while (cur) {\n if (cur->has_local(key)) {\n return cur;\n }\n cur = cur->parent_;\n }\n return this;\n }\n\n // see if we have a lexical variable\n // move down the stack but stop before we\n // reach the global frame (is not included)\n template \n bool Environment::has_lexical(const std::string& key) const\n {\n auto cur = this;\n while (cur->is_lexical()) {\n if (cur->has_local(key)) return true;\n cur = cur->parent_;\n }\n return false;\n }\n\n // see if we have a lexical we could update\n // either update already existing lexical value\n // or if flag is set, we create one if no lexical found\n template \n void Environment::set_lexical(const std::string& key, const T& val)\n {\n Environment* cur = this;\n bool shadow = false;\n while ((cur && cur->is_lexical()) || shadow) {\n EnvResult rv(cur->find_local(key));\n if (rv.found) {\n rv.it->second = val;\n return;\n }\n shadow = cur->is_shadow();\n cur = cur->parent_;\n }\n set_local(key, val);\n }\n // this one moves the value\n template \n void Environment::set_lexical(const std::string& key, T&& val)\n {\n Environment* cur = this;\n bool shadow = false;\n while ((cur && cur->is_lexical()) || shadow) {\n EnvResult rv(cur->find_local(key));\n if (rv.found) {\n rv.it->second = val;\n return;\n }\n shadow = cur->is_shadow();\n cur = cur->parent_;\n }\n set_local(key, val);\n }\n\n // look on the full stack for key\n // include all scopes available\n template \n bool Environment::has(const std::string& key) const\n {\n auto cur = 
this;\n while (cur) {\n if (cur->has_local(key)) {\n return true;\n }\n cur = cur->parent_;\n }\n return false;\n }\n\n // look on the full stack for key\n // include all scopes available\n template EnvResult\n Environment::find(const std::string& key)\n {\n auto cur = this;\n while (true) {\n EnvResult rv(cur->find_local(key));\n if (rv.found) return rv;\n cur = cur->parent_;\n if (!cur) return rv;\n }\n };\n\n // use array access for getter and setter functions\n template \n T& Environment::operator[](const std::string& key)\n {\n auto cur = this;\n while (cur) {\n if (cur->has_local(key)) {\n return cur->get_local(key);\n }\n cur = cur->parent_;\n...\n", "needle_token_start": 4044, "needle_token_end": 4085, "code_context_ntokens": 16259, "output": ["```cpp\nclass Thunk : public Expression {\n ADD_PROPERTY(Expression_Obj, expression)\n ADD_PROPERTY(Env*, environment)\npublic:\n Thunk(ParserState pstate, Expression_Obj exp, Env* env = 0)\n : Expression(pstate), expression_(exp), environment_(env)\n { }\n};\n```"]} +{"repo": "sass/node-sass", "name": "vector_to_string", "language": "cpp", "path": "src/libsass/src/paths.hpp", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert a list of elements into a human-readable string format, encapsulating the elements within square brackets and separating them with commas.\n2. **Input**: A vector containing elements of a generic type.\n3. **Output**: A string representation of the vector, where elements are enclosed in square brackets and separated by commas. If the vector is empty, it returns \"[]\".\n4. **Procedure**: The function initializes a string stream and begins by adding an opening square bracket. If the vector is not empty, it inserts the first element directly. For vectors with more than one element, it iterates through the remaining elements, appending each with a preceding comma and a space for formatting. 
Finally, it closes with a square bracket and converts the stream contents into a string for output.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/ast.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"context.hpp\"\n#include \"node.hpp\"\n#include \"eval.hpp\"\n#include \"extend.hpp\"\n#include \"emitter.hpp\"\n#include \"color_maps.hpp\"\n#include \"ast_fwd_decl.hpp\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace Sass {\n\n static Null sass_null(ParserState(\"null\"));\n\n bool Wrapped_Selector::find ( bool (*f)(AST_Node_Obj) )\n {\n // check children first\n if (selector_) {\n if (selector_->find(f)) return true;\n }\n // execute last\n return f(this);\n }\n\n bool Selector_List::find ( bool (*f)(AST_Node_Obj) )\n {\n // check children first\n for (Complex_Selector_Obj sel : elements()) {\n if (sel->find(f)) return true;\n }\n // execute last\n return f(this);\n }\n\n bool Compound_Selector::find ( bool (*f)(AST_Node_Obj) )\n {\n // check children first\n for (Simple_Selector_Obj sel : elements()) {\n if (sel->find(f)) return true;\n }\n // execute last\n return f(this);\n }\n\n bool Complex_Selector::find ( bool (*f)(AST_Node_Obj) )\n {\n // check children first\n if (head_ && head_->find(f)) return true;\n if (tail_ && tail_->find(f)) return true;\n // execute last\n return f(this);\n }\n\n bool Supports_Operator::needs_parens(Supports_Condition_Obj cond) const {\n if (Supports_Operator_Obj op = Cast(cond)) {\n return op->operand() != operand();\n }\n return Cast(cond) != NULL;\n }\n\n bool Supports_Negation::needs_parens(Supports_Condition_Obj cond) const {\n return Cast(cond) ||\n Cast(cond);\n }\n\n void str_rtrim(std::string& str, const std::string& delimiters = \" \\f\\n\\r\\t\\v\")\n {\n str.erase( str.find_last_not_of( delimiters ) + 1 );\n }\n\n void String_Constant::rtrim()\n {\n str_rtrim(value_);\n }\n\n void String_Schema::rtrim()\n {\n if (!empty()) {\n if (String_Ptr str = Cast(last())) str->rtrim();\n }\n }\n\n void Argument::set_delayed(bool delayed)\n {\n if (value_) value_->set_delayed(delayed);\n is_delayed(delayed);\n }\n\n void Arguments::set_delayed(bool delayed)\n {\n for (Argument_Obj arg : elements()) {\n if (arg) arg->set_delayed(delayed);\n }\n is_delayed(delayed);\n }\n\n\n bool At_Root_Query::exclude(std::string str)\n {\n bool with = feature() && unquote(feature()->to_string()).compare(\"with\") == 0;\n List_Ptr l = static_cast(value().ptr());\n std::string v;\n\n if (with)\n {\n if (!l || l->length() == 0) return str.compare(\"rule\") != 0;\n for (size_t i = 0, L = l->length(); i < L; ++i)\n {\n v = unquote((*l)[i]->to_string());\n if (v.compare(\"all\") == 0 || v == str) return false;\n }\n return true;\n }\n else\n {\n if (!l || !l->length()) return str.compare(\"rule\") == 0;\n for (size_t i = 0, L = l->length(); i < L; ++i)\n {\n v = unquote((*l)[i]->to_string());\n if (v.compare(\"all\") == 0 || v == str) return true;\n }\n return false;\n }\n }\n\n void AST_Node::update_pstate(const ParserState& pstate)\n {\n pstate_.offset += pstate - pstate_ + pstate.offset;\n }\n\n bool Simple_Selector::is_ns_eq(const Simple_Selector& r) const\n {\n // https://github.com/sass/sass/issues/2229\n if ((has_ns_ == r.has_ns_) ||\n (has_ns_ && ns_.empty()) ||\n (r.has_ns_ 
&& r.ns_.empty())\n ) {\n if (ns_.empty() && r.ns() == \"*\") return false;\n else if (r.ns().empty() && ns() == \"*\") return false;\n else return ns() == r.ns();\n }\n return false;\n }\n\n bool Compound_Selector::operator< (const Compound_Selector& rhs) const\n {\n size_t L = std::min(length(), rhs.length());\n for (size_t i = 0; i < L; ++i)\n {\n Simple_Selector_Obj l = (*this)[i];\n Simple_Selector_Obj r = rhs[i];\n if (!l && !r) return false;\n else if (!r) return false;\n else if (!l) return true;\n else if (*l != *r)\n { return *l < *r; }\n }\n // just compare the length now\n return length() < rhs.length();\n }\n\n bool Compound_Selector::has_parent_ref() const\n {\n for (Simple_Selector_Obj s : *this) {\n if (s && s->has_parent_ref()) return true;\n }\n return false;\n }\n\n bool Compound_Selector::has_real_parent_ref() const\n {\n for (Simple_Selector_Obj s : *this) {\n if (s && s->has_real_parent_ref()) return true;\n }\n return false;\n }\n\n bool Complex_Selector::has_parent_ref() const\n {\n return (head() && head()->has_parent_ref()) ||\n (tail() && tail()->has_parent_ref());\n }\n\n bool Complex_Selector::has_real_parent_ref() const\n {\n return (head() && head()->has_real_parent_ref()) ||\n (tail() && tail()->has_real_parent_ref());\n }\n\n bool Complex_Selector::operator< (const Complex_Selector& rhs) const\n {\n // const iterators for tails\n Complex_Selector_Ptr_Const l = this;\n Complex_Selector_Ptr_Const r = &rhs;\n Compound_Selector_Ptr l_h = NULL;\n Compound_Selector_Ptr r_h = NULL;\n if (l) l_h = l->head();\n if (r) r_h = r->head();\n // process all tails\n while (true)\n {\n #ifdef DEBUG\n // skip empty ancestor first\n if (l && l->is_empty_ancestor())\n {\n l_h = NULL;\n l = l->tail();\n if(l) l_h = l->head();\n continue;\n }\n // skip empty ancestor first\n if (r && r->is_empty_ancestor())\n {\n r_h = NULL;\n r = r->tail();\n if (r) r_h = r->head();\n continue;\n }\n #endif\n // check for valid selectors\n if (!l) return !!r;\n if (!r) return false;\n // both are null\n else if (!l_h && !r_h)\n {\n // check combinator after heads\n if (l->combinator() != r->combinator())\n { return l->combinator() < r->combinator(); }\n // advance to next tails\n l = l->tail();\n r = r->tail();\n // fetch the next headers\n l_h = NULL; r_h = NULL;\n if (l) l_h = l->head();\n if (r) r_h = r->head();\n }\n // one side is null\n else if (!r_h) return true;\n else if (!l_h) return false;\n // heads ok and equal\n else if (*l_h == *r_h)\n {\n // check combinator after heads\n if (l->combinator() != r->combinator())\n { return l->combinator() < r->combinator(); }\n // advance to next tails\n l = l->tail();\n r = r->tail();\n // fetch the next headers\n l_h = NULL; r_h = NULL;\n if (l) l_h = l->head();\n if (r) r_h = r->head();\n }\n // heads are not equal\n else return *l_h < *r_h;\n }\n }\n\n bool Complex_Selector::operator== (const Complex_Selector& rhs) const\n {\n // const iterators for tails\n Complex_Selector_Ptr_Const l = this;\n Complex_Selector_Ptr_Const r = &rhs;\n Compound_Selector_Ptr l_h = NULL;\n Compound_Selector_Ptr r_h = NULL;\n if (l) l_h = l->head();\n if (r) r_h = r->head();\n // process all tails\n while (true)\n {\n #ifdef DEBUG\n // skip empty ancestor first\n if (l && l->is_empty_ancestor())\n {\n l_h = NULL;\n l = l->tail();\n if (l) l_h = l->head();\n continue;\n }\n // skip empty ancestor first\n if (r && r->is_empty_ancestor())\n {\n r_h = NULL;\n r = r->tail();\n if (r) r_h = r->head();\n continue;\n }\n #endif\n // check the pointers\n if (!r) return 
!l;\n if (!l) return !r;\n // both are null\n if (!l_h && !r_h)\n {\n // check combinator after heads\n if (l->combinator() != r->combinator())\n { return l->combinator() < r->combinator(); }\n // advance to next tails\n l = l->tail();\n r = r->tail();\n // fetch the next heads\n l_h = NULL; r_h = NULL;\n if (l) l_h = l->head();\n if (r) r_h = r->head();\n }\n // equals if other head is empty\n else if ((!l_h && !r_h) ||\n (!l_h && r_h->empty()) ||\n (!r_h && l_h->empty()) ||\n (l_h && r_h && *l_h == *r_h))\n {\n // check combinator after heads\n if (l->combinator() != r->combinator())\n { return l->combinator() == r->combinator(); }\n // advance to next tails\n l = l->tail();\n r = r->tail();\n // fetch the next heads\n l_h = NULL; r_h = NULL;\n if (l) l_h = l->head();\n if (r) r_h = r->head();\n }\n // abort\n else break;\n }\n // unreachable\n return false;\n }\n\n Compound_Selector_Ptr Compound_Selector::unify_with(Compound_Selector_Ptr rhs)\n {\n if (empty()) return rhs;\n Compound_Selector_Obj unified = SASS_MEMORY_COPY(rhs);\n for (size_t i = 0, L = length(); i < L; ++i)\n {\n if (unified.isNull()) break;\n unified = at(i)->unify_with(unified);\n }\n return unified.detach();\n }\n\n bool Complex_Selector::operator== (const Selector& rhs) const\n {\n if (const Selector_List* sl = Cast(&rhs)) return *this == *sl;\n if (const Simple_Selector* sp = Cast(&rhs)) return *this == *sp;\n if (const Complex_Selector* cs = Cast(&rhs)) return *this == *cs;\n if (const Compound_Selector* ch = Cast(&rhs)) return *this == *ch;\n throw std::runtime_error(\"invalid selector base classes to compare\");\n }\n\n\n bool Complex_Selector::operator< (const Selector& rhs) const\n {\n if (const Selector_List* sl = Cast(&rhs)) return *this < *sl;\n if (const Simple_Selector* sp = Cast(&rhs)) return *this < *sp;\n if (const Complex_Selector* cs = Cast(&rhs)) return *this < *cs;\n if (const Compound_Selector* ch = Cast(&rhs)) return *this < *ch;\n throw std::runtime_error(\"invalid selector base classes to compare\");\n }\n\n bool Compound_Selector::operator== (const Selector& rhs) const\n {\n if (const Selector_List* sl = Cast(&rhs)) return *this == *sl;\n if (const Simple_Selector* sp = Cast(&rhs)) return *this == *sp;\n if (const Complex_Selector* cs = Cast(&rhs)) return *this == *cs;\n if (const Compound_Selector* ch = Cast(&rhs)) return *this == *ch;\n throw std::runtime_error(\"invalid selector base classes to compare\");\n }\n\n bool Compound_Selector::operator< (const Selector& rhs) const\n {\n if (const Selector_List* sl = Cast(&rhs)) return *this < *sl;\n if (const Simple_Selector* sp = Cast(&rhs)) return *this < *sp;\n if (const Complex_Selector* cs = Cast(&rhs)) return *this < *cs;\n if (const Compound_Selector* ch = Cast(&rhs)) return *this < *ch;\n throw std::runtime_error(\"invalid selector base classes to compare\");\n }\n\n bool Selector_Schema::operator== (const Selector& rhs) const\n {\n if (const Selector_List* sl = Cast(&rhs)) return *this == *sl;\n if (const Simple_Selector* sp = Cast(&rhs)) return *this == *sp;\n if (const Complex_Selector* cs = Cast(&rhs)) return *this == *cs;\n if (const Compound_Selector* ch = Cast(&rhs)) return *this == *ch;\n throw std::runtime_error(\"invalid selector base classes to compare\");\n }\n\n bool Selector_Schema::operator< (const Selector& rhs) const\n {\n if (const Selector_List* sl = Cast(&rhs)) return *this < *sl;\n if (const Simple_Selector* sp = Cast(&rhs)) return *this < *sp;\n if (const Complex_Selector* cs = Cast(&rhs)) return *this < *cs;\n if 
(const Compound_Selector* ch = Cast(&rhs)) return *this < *ch;\n throw std::runtime_error(\"invalid selector base classes to compare\");\n }\n\n bool Simple_Selector::operator== (const Selector& rhs) const\n {\n if (Simple_Selector_Ptr_Const sp = Cast(&rhs)) return *this == *sp;\n return false;\n }\n\n bool Simple_Selector::operator< (const Selector& rhs) const\n {\n if (Simple_Selector_Ptr_Const sp = Cast(&rhs)) return *this < *sp;\n return false;\n }\n\n bool Simple_Selector::operator== (const Simple_Selector& rhs) const\n {\n // solve the double dispatch problem by using RTTI information via dynamic cast\n if (const Pseudo_Selector* lhs = Cast(this)) {return *lhs == rhs; }\n else if (const Wrapped_Selector* lhs = Cast(this)) {return *lhs == rhs; }\n else if (const Element_Selector* lhs = Cast(this)) {return *lhs == rhs; }\n else if (const Attribute_Selector* lhs = Cast(this)) {return *lhs == rhs; }\n else if (name_ == rhs.name_)\n { return is_ns_eq(rhs); }\n else return false;\n }\n\n bool Simple_Selector::operator< (const Simple_Selector& rhs) const\n {\n // solve the double dispatch problem by using RTTI information via dynamic cast\n if (const Pseudo_Selector* lhs = Cast(this)) {return *lhs < rhs; }\n else if (const Wrapped_Selector* lhs = Cast(this)) {return *lhs < rhs; }\n else if (const Element_Selector* lhs = Cast(this)) {return *lhs < rhs; }\n else if (const Attribute_Selector* lhs = Cast(this)) {return *lhs < rhs; }\n if (is_ns_eq(rhs))\n { return name_ < rhs.name_; }\n return ns_ < rhs.ns_;\n }\n\n bool Selector_List::operator== (const Selector& rhs) const\n {\n // solve the double dispatch problem by using RTTI information via dynamic cast\n if (Selector_List_Ptr_Const sl = Cast(&rhs)) { return *this == *sl; }\n else if (Complex_Selector_Ptr_Const cpx = Cast(&rhs)) { return *this == *cpx; }\n else if (Compound_Selector_Ptr_Const cpd = Cast(&rhs)) { return *this == *cpd; }\n // no compare method\n return this == &rhs;\n }\n\n // Selector lists can be compared to comma lists\n bool Selector_List::operator== (const Expression& rhs) const\n {\n // solve the double dispatch problem by using RTTI information via dynamic cast\n if (List_Ptr_Const ls = Cast(&rhs)) { return *ls == *this; }\n if (Selector_Ptr_Const ls = Cast(&rhs)) { return *this == *ls; }\n // compare invalid (maybe we should error?)\n return false;\n }\n\n bool Selector_List::operator== (const Selector_List& rhs) const\n {\n // for array access\n size_t i = 0, n = 0;\n size_t iL = length();\n size_t nL = rhs.length();\n // create temporary vectors and sort them\n std::vector l_lst = this->elements();\n...\n// Path: src/libsass/src/base64vlq.cpp\n#include \"sass.hpp\"\n#include \"base64vlq.hpp\"\n\nnamespace Sass {\n\n std::string Base64VLQ::encode(const int number) const\n {\n std::string encoded = \"\";\n\n int vlq = to_vlq_signed(number);\n\n do {\n int digit = vlq & VLQ_BASE_MASK;\n vlq >>= VLQ_BASE_SHIFT;\n if (vlq > 0) {\n digit |= VLQ_CONTINUATION_BIT;\n }\n encoded += base64_encode(digit);\n } while (vlq > 0);\n\n return encoded;\n }\n\n char Base64VLQ::base64_encode(const int number) const\n {\n int index = number;\n if (index < 0) index = 0;\n if (index > 63) index = 63;\n return CHARACTERS[index];\n }\n\n int Base64VLQ::to_vlq_signed(const int number) const\n {\n return (number < 0) ? 
((-number) << 1) + 1 : (number << 1) + 0;\n }\n\n const char* Base64VLQ::CHARACTERS = \"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/\";\n\n const int Base64VLQ::VLQ_BASE_SHIFT = 5;\n const int Base64VLQ::VLQ_BASE = 1 << VLQ_BASE_SHIFT;\n const int Base64VLQ::VLQ_BASE_MASK = VLQ_BASE - 1;\n const int Base64VLQ::VLQ_CONTINUATION_BIT = VLQ_BASE;\n\n}\n\n// Path: src/libsass/src/paths.hpp\n#ifndef SASS_PATHS_H\n#define SASS_PATHS_H\n\n#include \n#include \n#include \n\n\ntemplate\n\nstd::string vector_to_string(std::vector v)\n{\n std::stringstream buffer;\n buffer << \"[\";\n\n if (!v.empty())\n { buffer << v[0]; }\n else\n { buffer << \"]\"; }\n\n if (v.size() == 1)\n { buffer << \"]\"; }\n else\n {\n for (size_t i = 1, S = v.size(); i < S; ++i) buffer << \", \" << v[i];\n buffer << \"]\";\n }\n\n return buffer.str();\n}\n\nnamespace Sass {\n\n\n template\n std::vector > paths(std::vector > strata, size_t from_end = 0)\n {\n if (strata.empty()) {\n return std::vector >();\n }\n\n size_t end = strata.size() - from_end;\n if (end <= 1) {\n std::vector > starting_points;\n starting_points.reserve(strata[0].size());\n for (size_t i = 0, S = strata[0].size(); i < S; ++i) {\n std::vector starting_point;\n starting_point.push_back(strata[0][i]);\n starting_points.push_back(starting_point);\n }\n return starting_points;\n }\n\n std::vector > up_to_here = paths(strata, from_end + 1);\n std::vector here = strata[end-1];\n\n std::vector > branches;\n branches.reserve(up_to_here.size() * here.size());\n for (size_t i = 0, S1 = up_to_here.size(); i < S1; ++i) {\n for (size_t j = 0, S2 = here.size(); j < S2; ++j) {\n std::vector branch = up_to_here[i];\n branch.push_back(here[j]);\n branches.push_back(branch);\n }\n }\n\n return branches;\n }\n\n}\n\n#endif\n\n// Path: src/libsass/src/debug.hpp\n#ifndef SASS_DEBUG_H\n#define SASS_DEBUG_H\n\n#include \n\n#ifndef UINT32_MAX\n #define UINT32_MAX 0xffffffffU\n#endif\n\nenum dbg_lvl_t : uint32_t {\n NONE = 0,\n TRIM = 1,\n CHUNKS = 2,\n SUBWEAVE = 4,\n WEAVE = 8,\n EXTEND_COMPOUND = 16,\n EXTEND_COMPLEX = 32,\n LCS = 64,\n EXTEND_OBJECT = 128,\n ALL = UINT32_MAX\n};\n\n#ifdef DEBUG\n\n#ifndef DEBUG_LVL\nconst uint32_t debug_lvl = UINT32_MAX;\n#else\nconst uint32_t debug_lvl = (DEBUG_LVL);\n#endif // DEBUG_LVL\n\n#define DEBUG_PRINT(lvl, x) if((lvl) & debug_lvl) { std::cerr << x; }\n#define DEBUG_PRINTLN(lvl, x) if((lvl) & debug_lvl) { std::cerr << x << std::endl; }\n#define DEBUG_EXEC(lvl, x) if((lvl) & debug_lvl) { x; }\n\n#else // DEBUG\n\n#define DEBUG_PRINT(lvl, x)\n#define DEBUG_PRINTLN(lvl, x)\n#define DEBUG_EXEC(lvl, x)\n\n#endif // DEBUG\n\n#endif // SASS_DEBUG\n\n// Path: src/libsass/src/sass_util.hpp\n#ifndef SASS_SASS_UTIL_H\n#define SASS_SASS_UTIL_H\n\n#include \"ast.hpp\"\n#include \"node.hpp\"\n#include \"debug.hpp\"\n\nnamespace Sass {\n\n\n\n\n /*\n This is for ports of functions in the Sass:Util module.\n */\n\n\n /*\n # Return a Node collection of all possible paths through the given Node collection of Node collections.\n #\n # @param arrs [NodeCollection>]\n # @return [NodeCollection>]\n #\n # @example\n # paths([[1, 2], [3, 4], [5]]) #=>\n # # [[1, 3, 5],\n # # [2, 3, 5],\n # # [1, 4, 5],\n # # [2, 4, 5]]\n */\n Node paths(const Node& arrs);\n\n\n /*\n This class is a default implementation of a Node comparator that can be passed to the lcs function below.\n It uses operator== for equality comparision. 
It then returns one if the Nodes are equal.\n */\n class DefaultLcsComparator {\n public:\n bool operator()(const Node& one, const Node& two, Node& out) const {\n // TODO: Is this the correct C++ interpretation?\n // block ||= proc {|a, b| a == b && a}\n if (one == two) {\n out = one;\n return true;\n }\n\n return false;\n }\n };\n\n\n typedef std::vector > LCSTable;\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_backtrace.\n\n # Computes a single longest common subsequence for arrays x and y.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Reading_out_an_LCS\n */\n template\n Node lcs_backtrace(const LCSTable& c, const Node& x, const Node& y, int i, int j, const ComparatorType& comparator) {\n DEBUG_PRINTLN(LCS, \"LCSBACK: X=\" << x << \" Y=\" << y << \" I=\" << i << \" J=\" << j)\n\n if (i == 0 || j == 0) {\n DEBUG_PRINTLN(LCS, \"RETURNING EMPTY\")\n return Node::createCollection();\n }\n\n NodeDeque& xChildren = *(x.collection());\n NodeDeque& yChildren = *(y.collection());\n\n Node compareOut = Node::createNil();\n if (comparator(xChildren[i], yChildren[j], compareOut)) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER ELEM COMPARE\")\n Node result = lcs_backtrace(c, x, y, i - 1, j - 1, comparator);\n result.collection()->push_back(compareOut);\n return result;\n }\n\n if (c[i][j - 1] > c[i - 1][j]) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER TABLE COMPARE\")\n return lcs_backtrace(c, x, y, i, j - 1, comparator);\n }\n\n DEBUG_PRINTLN(LCS, \"FINAL RETURN\")\n return lcs_backtrace(c, x, y, i - 1, j, comparator);\n }\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_table.\n\n # Calculates the memoization table for the Least Common Subsequence algorithm.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Computing_the_length_of_the_LCS\n */\n template\n void lcs_table(const Node& x, const Node& y, const ComparatorType& comparator, LCSTable& out) {\n DEBUG_PRINTLN(LCS, \"LCSTABLE: X=\" << x << \" Y=\" << y)\n\n NodeDeque& xChildren = *(x.collection());\n NodeDeque& yChildren = *(y.collection());\n\n LCSTable c(xChildren.size(), std::vector(yChildren.size()));\n\n // These shouldn't be necessary since the vector will be initialized to 0 already.\n // x.size.times {|i| c[i][0] = 0}\n // y.size.times {|j| c[0][j] = 0}\n\n for (size_t i = 1; i < xChildren.size(); i++) {\n for (size_t j = 1; j < yChildren.size(); j++) {\n Node compareOut = Node::createNil();\n\n if (comparator(xChildren[i], yChildren[j], compareOut)) {\n c[i][j] = c[i - 1][j - 1] + 1;\n } else {\n c[i][j] = std::max(c[i][j - 1], c[i - 1][j]);\n }\n }\n }\n\n out = c;\n }\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs.\n\n # Computes a single longest common subsequence for `x` and `y`.\n # If there are more than one longest common subsequences,\n # the one returned is that which starts first in `x`.\n\n # @param x [NodeCollection]\n # @param y [NodeCollection]\n # @comparator An equality check between elements of `x` and `y`.\n # @return [NodeCollection] The LCS\n\n http://en.wikipedia.org/wiki/Longest_common_subsequence_problem\n */\n template\n Node lcs(Node& x, Node& y, const ComparatorType& comparator) {\n DEBUG_PRINTLN(LCS, \"LCS: X=\" << x << \" Y=\" << y)\n\n Node newX = Node::createCollection();\n newX.collection()->push_back(Node::createNil());\n newX.plus(x);\n\n Node newY = Node::createCollection();\n newY.collection()->push_back(Node::createNil());\n newY.plus(y);\n\n LCSTable table;\n lcs_table(newX, newY, comparator, 
table);\n\n return lcs_backtrace(table, newX, newY, static_cast(newX.collection()->size()) - 1, static_cast(newY.collection()->size()) - 1, comparator);\n }\n\n\n /*\n This is the equivalent of ruby sass' Sass::Util.flatten and [].flatten.\n Sass::Util.flatten requires the number of levels to flatten, while\n [].flatten doesn't and will flatten the entire array. This function\n supports both.\n\n # Flattens the first `n` nested arrays. If n == -1, all arrays will be flattened\n #\n # @param arr [NodeCollection] The array to flatten\n # @param n [int] The number of levels to flatten\n # @return [NodeCollection] The flattened array\n */\n Node flatten(Node& arr, int n = -1);\n\n\n /*\n This is the equivalent of ruby's Sass::Util.group_by_to_a.\n\n # Performs the equivalent of `enum.group_by.to_a`, but with a guaranteed\n # order. Unlike [#hash_to_a], the resulting order isn't sorted key order;\n # instead, it's the same order as `#group_by` has under Ruby 1.9 (key\n # appearance order).\n #\n # @param enum [Enumerable]\n # @return [Array<[Object, Array]>] An array of pairs.\n\n TODO: update @param and @return once I know what those are.\n\n The following is the modified version of the ruby code that was more portable to C++. You\n should be able to drop it into ruby 3.2.19 and get the same results from ruby sass.\n\n def group_by_to_a(enum, &block)\n order = {}\n\n arr = []\n\n grouped = {}\n\n for e in enum do\n key = block[e]\n unless order.include?(key)\n order[key] = order.size\n end\n\n if not grouped.has_key?(key) then\n grouped[key] = [e]\n else\n grouped[key].push(e)\n end\n end\n\n grouped.each do |key, vals|\n arr[order[key]] = [key, vals]\n end\n\n arr\n end\n\n */\n template\n void group_by_to_a(std::vector& enumeration, KeyFunctorType& keyFunc, std::vector > >& arr /*out*/) {\n\n std::map order;\n\n std::map > grouped;\n\n for (typename std::vector::iterator enumIter = enumeration.begin(), enumIterEnd = enumeration.end(); enumIter != enumIterEnd; enumIter++) {\n EnumType& e = *enumIter;\n\n KeyType key = keyFunc(e);\n\n if (grouped.find(key->hash()) == grouped.end()) {\n order.insert(std::make_pair((unsigned int)order.size(), key));\n\n std::vector newCollection;\n newCollection.push_back(e);\n grouped.insert(std::make_pair(key->hash(), newCollection));\n } else {\n std::vector& collection = grouped.at(key->hash());\n collection.push_back(e);\n }\n }\n\n for (unsigned int index = 0; index < order.size(); index++) {\n KeyType& key = order.at(index);\n std::vector& values = grouped.at(key->hash());\n\n std::pair > grouping = std::make_pair(key, values);\n\n arr.push_back(grouping);\n }\n }\n\n\n}\n\n#endif\n\n// Path: src/libsass/src/extend.cpp\n#include \"sass.hpp\"\n#include \"extend.hpp\"\n#include \"context.hpp\"\n#include \"backtrace.hpp\"\n#include \"paths.hpp\"\n#include \"parser.hpp\"\n#include \"expand.hpp\"\n#include \"node.hpp\"\n#include \"sass_util.hpp\"\n#include \"remove_placeholders.hpp\"\n#include \"debug.hpp\"\n#include \n#include \n#include \n\n/*\n NOTES:\n\n - The print* functions print to cerr. This allows our testing frameworks (like sass-spec) to ignore the output, which\n is very helpful when debugging. 
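The sass_util.hpp comments quoted above describe the classic two-phase LCS computation: build a memoization table (lcs_table), then walk it backwards to read one subsequence out (lcs_backtrace). A minimal sketch of that approach over plain integers, assuming nothing about Sass Nodes or comparators, might look like this; the function names mirror the quoted ones but the code is illustrative only.

```cpp
// Standalone sketch of the two-phase LCS approach described above:
// build the DP table, then backtrace through it (ties resolved as in the quoted code).
#include <algorithm>
#include <iostream>
#include <vector>

using Table = std::vector<std::vector<int>>;

// c[i][j] = length of the LCS of x[0..i) and y[0..j).
static Table lcs_table(const std::vector<int>& x, const std::vector<int>& y) {
  Table c(x.size() + 1, std::vector<int>(y.size() + 1, 0));
  for (size_t i = 1; i <= x.size(); ++i) {
    for (size_t j = 1; j <= y.size(); ++j) {
      if (x[i - 1] == y[j - 1])
        c[i][j] = c[i - 1][j - 1] + 1;
      else
        c[i][j] = std::max(c[i][j - 1], c[i - 1][j]);
    }
  }
  return c;
}

// Recursively read one LCS out of the table.
static std::vector<int> lcs_backtrace(const Table& c, const std::vector<int>& x,
                                      const std::vector<int>& y, size_t i, size_t j) {
  if (i == 0 || j == 0) return {};
  if (x[i - 1] == y[j - 1]) {
    std::vector<int> result = lcs_backtrace(c, x, y, i - 1, j - 1);
    result.push_back(x[i - 1]);
    return result;
  }
  if (c[i][j - 1] > c[i - 1][j]) return lcs_backtrace(c, x, y, i, j - 1);
  return lcs_backtrace(c, x, y, i - 1, j);
}

int main() {
  std::vector<int> x = {1, 2, 3, 4, 5};
  std::vector<int> y = {2, 4, 5, 1, 3};
  Table c = lcs_table(x, y);
  for (int v : lcs_backtrace(c, x, y, x.size(), y.size()))
    std::cout << v << ' ';  // prints: 2 4 5
  std::cout << '\n';
  return 0;
}
```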
The format of the output is mainly to wrap things in square brackets to match what\n ruby already outputs (to make comparisons easier).\n\n - For the direct porting effort, we're trying to port method-for-method until we get all the tests passing.\n Where applicable, I've tried to include the ruby code above the function for reference until all our tests pass.\n The ruby code isn't always directly portable, so I've tried to include any modified ruby code that was actually\n used for the porting.\n\n - DO NOT try to optimize yet. We get a tremendous benefit out of comparing the output of each stage of the extend to the ruby\n output at the same stage. This makes it much easier to determine where problems are. Try to keep as close to\n the ruby code as you can until we have all the sass-spec tests passing. Then, we should optimize. However, if you see\n something that could probably be optimized, let's not forget it. Add a // TODO: or // IMPROVEMENT: comment.\n\n - Coding conventions in this file (these may need to be changed before merging back into master)\n - Very basic hungarian notation:\n p prefix for pointers (pSelector)\n no prefix for value types and references (selector)\n - Use STL iterators where possible\n - prefer verbose naming over terse naming\n - use typedefs for STL container types for make maintenance easier\n\n - You may see a lot of comments that say \"// TODO: is this the correct combinator?\". See the comment referring to combinators\n in extendCompoundSelector for a more extensive explanation of my confusion. I think our divergence in data model from ruby\n sass causes this to be necessary.\n\n\n GLOBAL TODOS:\n\n - wrap the contents of the print functions in DEBUG preprocesser conditionals so they will be optimized away in non-debug mode.\n\n - consider making the extend* functions member functions to avoid passing around ctx and subset_map map around. 
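The debugging workflow described in these notes relies on the channel-masked DEBUG_* macros from the debug.hpp header quoted earlier. A self-contained miniature of that pattern, with the channel values copied from the quoted enum and everything else illustrative, could look like this: disabled channels (or non-debug builds) cost nothing at runtime.

```cpp
// Self-contained miniature of the debug.hpp pattern quoted earlier: a compile-time
// bitmask of channels gates each debug statement.
#include <cstdint>
#include <iostream>

enum dbg_lvl_t : uint32_t { NONE = 0, TRIM = 1, CHUNKS = 2, LCS = 64 };

#ifndef DEBUG_LVL
#define DEBUG_LVL (TRIM | LCS)  // which channels this build prints
#endif
const uint32_t debug_lvl = (DEBUG_LVL);

#define DEBUG_PRINTLN(lvl, x) if ((lvl) & debug_lvl) { std::cerr << x << std::endl; }

int main() {
  DEBUG_PRINTLN(TRIM, "trim: dropping redundant selector")  // printed: TRIM bit is set
  DEBUG_PRINTLN(CHUNKS, "chunks: " << 3 << " groups")       // suppressed: CHUNKS bit is not set
  return 0;
}
```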
This has the\n drawback that the implementation details of the operator are then exposed to the outside world, which is not ideal and\n can cause additional compile time dependencies.\n\n - mark the helper methods in this file static to given them compilation unit linkage.\n\n - implement parent directive matching\n\n - fix compilation warnings for unused Extend members if we really don't need those references anymore.\n */\n\n\nnamespace Sass {\n\n\n\n#ifdef DEBUG\n\n // TODO: move the ast specific ostream operators into ast.hpp/ast.cpp\n std::ostream& operator<<(std::ostream& os, const Complex_Selector::Combinator combinator) {\n switch (combinator) {\n case Complex_Selector::ANCESTOR_OF: os << \"\\\" \\\"\"; break;\n case Complex_Selector::PARENT_OF: os << \"\\\">\\\"\"; break;\n case Complex_Selector::PRECEDES: os << \"\\\"~\\\"\"; break;\n case Complex_Selector::ADJACENT_TO: os << \"\\\"+\\\"\"; break;\n case Complex_Selector::REFERENCE: os << \"\\\"/\\\"\"; break;\n }\n\n return os;\n }\n\n\n std::ostream& operator<<(std::ostream& os, Compound_Selector& compoundSelector) {\n for (size_t i = 0, L = compoundSelector.length(); i < L; ++i) {\n if (i > 0) os << \", \";\n os << compoundSelector[i]->to_string();\n }\n return os;\n }\n\n std::ostream& operator<<(std::ostream& os, Simple_Selector& simpleSelector) {\n os << simpleSelector.to_string();\n return os;\n }\n\n // Print a string representation of a Compound_Selector\n static void printSimpleSelector(Simple_Selector* pSimpleSelector, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n if (pSimpleSelector) {\n std::cerr << \"[\" << *pSimpleSelector << \"]\";\n } else {\n std::cerr << \"NULL\";\n }\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n // Print a string representation of a Compound_Selector\n static void printCompoundSelector(Compound_Selector_Ptr pCompoundSelector, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n if (pCompoundSelector) {\n std::cerr << \"[\" << *pCompoundSelector << \"]\";\n } else {\n std::cerr << \"NULL\";\n }\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n\n std::ostream& operator<<(std::ostream& os, Complex_Selector& complexSelector) {\n\n os << \"[\";\n Complex_Selector_Ptr pIter = &complexSelector;\n bool first = true;\n while (pIter) {\n if (pIter->combinator() != Complex_Selector::ANCESTOR_OF) {\n if (!first) {\n os << \", \";\n }\n first = false;\n os << pIter->combinator();\n }\n\n if (!first) {\n os << \", \";\n }\n first = false;\n\n if (pIter->head()) {\n os << pIter->head()->to_string();\n } else {\n os << \"NULL_HEAD\";\n }\n\n pIter = pIter->tail();\n }\n os << \"]\";\n\n return os;\n }\n\n\n // Print a string representation of a Complex_Selector\n static void printComplexSelector(Complex_Selector_Ptr pComplexSelector, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n if (pComplexSelector) {\n std::cerr << *pComplexSelector;\n } else {\n std::cerr << \"NULL\";\n }\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n static void printSelsNewSeqPairCollection(SubSetMapLookups& collection, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n bool first = true;\n std::cerr << \"[\";\n for(SubSetMapLookup& pair : collection) {\n if (first) {\n first = false;\n } else {\n std::cerr << \", \";\n }\n std::cerr << \"[\";\n Compound_Selector_Ptr pSels = pair.first;\n Complex_Selector_Ptr 
pNewSelector = pair.second;\n std::cerr << \"[\" << *pSels << \"], \";\n printComplexSelector(pNewSelector, NULL, false);\n }\n std::cerr << \"]\";\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n // Print a string representation of a ComplexSelectorSet\n static void printSourcesSet(ComplexSelectorSet& sources, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n // Convert to a deque of strings so we can sort since order doesn't matter in a set. This should cut down on\n // the differences we see when debug printing.\n typedef std::deque SourceStrings;\n SourceStrings sourceStrings;\n for (ComplexSelectorSet::iterator iterator = sources.begin(), iteratorEnd = sources.end(); iterator != iteratorEnd; ++iterator) {\n Complex_Selector_Ptr pSource = *iterator;\n std::stringstream sstream;\n sstream << complexSelectorToNode(pSource);\n sourceStrings.push_back(sstream.str());\n }\n\n // Sort to get consistent output\n std::sort(sourceStrings.begin(), sourceStrings.end());\n\n std::cerr << \"ComplexSelectorSet[\";\n for (SourceStrings::iterator iterator = sourceStrings.begin(), iteratorEnd = sourceStrings.end(); iterator != iteratorEnd; ++iterator) {\n std::string source = *iterator;\n if (iterator != sourceStrings.begin()) {\n std::cerr << \", \";\n }\n std::cerr << source;\n }\n std::cerr << \"]\";\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n\n std::ostream& operator<<(std::ostream& os, SubSetMapPairs& entries) {\n os << \"SUBSET_MAP_ENTRIES[\";\n\n for (SubSetMapPairs::iterator iterator = entries.begin(), endIterator = entries.end(); iterator != endIterator; ++iterator) {\n Complex_Selector_Obj pExtComplexSelector = iterator->first; // The selector up to where the @extend is (ie, the thing to merge)\n Compound_Selector_Obj pExtCompoundSelector = iterator->second; // The stuff after the @extend\n\n if (iterator != entries.begin()) {\n os << \", \";\n }\n\n os << \"(\";\n\n if (pExtComplexSelector) {\n std::cerr << *pExtComplexSelector;\n } else {\n std::cerr << \"NULL\";\n }\n\n os << \" -> \";\n\n if (pExtCompoundSelector) {\n std::cerr << *pExtCompoundSelector;\n } else {\n std::cerr << \"NULL\";\n }\n\n os << \")\";\n\n }\n\n os << \"]\";\n\n return os;\n }\n#endif\n\n static bool parentSuperselector(Complex_Selector_Ptr pOne, Complex_Selector_Ptr pTwo) {\n // TODO: figure out a better way to create a Complex_Selector from scratch\n // TODO: There's got to be a better way. 
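printSourcesSet above stringifies the members of a pointer-keyed set and sorts the rendered strings so that debug output stays stable from run to run. A small sketch of that trick with placeholder types (not the libsass AST) might be:

```cpp
// Sketch of the "stringify then sort" trick used by printSourcesSet above to get
// deterministic debug output from a pointer-keyed set. Types are placeholders.
#include <algorithm>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

struct Thing { int id; };

static std::string render(const Thing* t) {
  std::ostringstream ss;
  ss << "Thing(" << t->id << ")";
  return ss.str();
}

// Pointer order differs between runs, so sort the rendered strings instead.
static void print_sorted(const std::set<const Thing*>& items) {
  std::vector<std::string> rendered;
  for (const Thing* t : items) rendered.push_back(render(t));
  std::sort(rendered.begin(), rendered.end());
  std::cerr << "[";
  for (size_t i = 0; i < rendered.size(); ++i) {
    if (i) std::cerr << ", ";
    std::cerr << rendered[i];
  }
  std::cerr << "]\n";
}

int main() {
  Thing a{1}, b{2}, c{3};
  std::set<const Thing*> s{&c, &a, &b};  // iteration order follows addresses
  print_sorted(s);                       // always prints Thing(1), Thing(2), Thing(3)
  return 0;
}
```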
This got ugly quick...\n Element_Selector_Obj fakeParent = SASS_MEMORY_NEW(Element_Selector, ParserState(\"[FAKE]\"), \"temp\");\n Compound_Selector_Obj fakeHead = SASS_MEMORY_NEW(Compound_Selector, ParserState(\"[FAKE]\"), 1 /*size*/);\n fakeHead->elements().push_back(fakeParent);\n Complex_Selector_Obj fakeParentContainer = SASS_MEMORY_NEW(Complex_Selector, ParserState(\"[FAKE]\"), Complex_Selector::ANCESTOR_OF, fakeHead /*head*/, NULL /*tail*/);\n\n pOne->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n pTwo->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n\n bool isSuperselector = pOne->is_superselector_of(pTwo);\n\n pOne->clear_innermost();\n pTwo->clear_innermost();\n\n return isSuperselector;\n }\n\n void nodeToComplexSelectorDeque(const Node& node, ComplexSelectorDeque& out) {\n for (NodeDeque::iterator iter = node.collection()->begin(), iterEnd = node.collection()->end(); iter != iterEnd; iter++) {\n Node& child = *iter;\n out.push_back(nodeToComplexSelector(child));\n }\n }\n\n Node complexSelectorDequeToNode(const ComplexSelectorDeque& deque) {\n Node result = Node::createCollection();\n\n for (ComplexSelectorDeque::const_iterator iter = deque.begin(), iterEnd = deque.end(); iter != iterEnd; iter++) {\n Complex_Selector_Obj pChild = *iter;\n result.collection()->push_back(complexSelectorToNode(pChild));\n }\n\n return result;\n }\n\n class LcsCollectionComparator {\n public:\n LcsCollectionComparator() {}\n\n bool operator()(Complex_Selector_Obj pOne, Complex_Selector_Obj pTwo, Complex_Selector_Obj& pOut) const {\n /*\n This code is based on the following block from ruby sass' subweave\n do |s1, s2|\n next s1 if s1 == s2\n next unless s1.first.is_a?(SimpleSequence) && s2.first.is_a?(SimpleSequence)\n next s2 if parent_superselector?(s1, s2)\n next s1 if parent_superselector?(s2, s1)\n end\n */\n\n if (*pOne == *pTwo) {\n pOut = pOne;\n return true;\n }\n\n if (pOne->combinator() != Complex_Selector::ANCESTOR_OF || pTwo->combinator() != Complex_Selector::ANCESTOR_OF) {\n return false;\n }\n\n if (parentSuperselector(pOne, pTwo)) {\n pOut = pTwo;\n return true;\n }\n\n if (parentSuperselector(pTwo, pOne)) {\n pOut = pOne;\n return true;\n }\n\n return false;\n }\n };\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_backtrace.\n\n # Computes a single longest common subsequence for arrays x and y.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Reading_out_an_LCS\n */\n void lcs_backtrace(const LCSTable& c, ComplexSelectorDeque& x, ComplexSelectorDeque& y, int i, int j, const LcsCollectionComparator& comparator, ComplexSelectorDeque& out) {\n //DEBUG_PRINTLN(LCS, \"LCSBACK: X=\" << x << \" Y=\" << y << \" I=\" << i << \" J=\" << j)\n // TODO: make printComplexSelectorDeque and use DEBUG_EXEC AND DEBUG_PRINTLN HERE to get equivalent output\n\n if (i == 0 || j == 0) {\n DEBUG_PRINTLN(LCS, \"RETURNING EMPTY\")\n return;\n }\n\n\n Complex_Selector_Obj pCompareOut;\n if (comparator(x[i], y[j], pCompareOut)) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER ELEM COMPARE\")\n lcs_backtrace(c, x, y, i - 1, j - 1, comparator, out);\n out.push_back(pCompareOut);\n return;\n }\n\n if (c[i][j - 1] > c[i - 1][j]) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER TABLE COMPARE\")\n lcs_backtrace(c, x, y, i, j - 1, comparator, out);\n return;\n }\n\n DEBUG_PRINTLN(LCS, \"FINAL RETURN\")\n lcs_backtrace(c, x, y, i - 1, j, comparator, out);\n return;\n }\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_table.\n\n # Calculates 
the memoization table for the Least Common Subsequence algorithm.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Computing_the_length_of_the_LCS\n */\n void lcs_table(const ComplexSelectorDeque& x, const ComplexSelectorDeque& y, const LcsCollectionComparator& comparator, LCSTable& out) {\n //DEBUG_PRINTLN(LCS, \"LCSTABLE: X=\" << x << \" Y=\" << y)\n // TODO: make printComplexSelectorDeque and use DEBUG_EXEC AND DEBUG_PRINTLN HERE to get equivalent output\n\n LCSTable c(x.size(), std::vector(y.size()));\n\n // These shouldn't be necessary since the vector will be initialized to 0 already.\n // x.size.times {|i| c[i][0] = 0}\n // y.size.times {|j| c[0][j] = 0}\n\n for (size_t i = 1; i < x.size(); i++) {\n for (size_t j = 1; j < y.size(); j++) {\n Complex_Selector_Obj pCompareOut;\n\n if (comparator(x[i], y[j], pCompareOut)) {\n c[i][j] = c[i - 1][j - 1] + 1;\n } else {\n c[i][j] = std::max(c[i][j - 1], c[i - 1][j]);\n }\n }\n }\n\n out = c;\n }\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs.\n\n # Computes a single longest common subsequence for `x` and `y`.\n # If there are more than one longest common subsequences,\n # the one returned is that which starts first in `x`.\n\n # @param x [NodeCollection]\n # @param y [NodeCollection]\n # @comparator An equality check between elements of `x` and `y`.\n # @return [NodeCollection] The LCS\n\n http://en.wikipedia.org/wiki/Longest_common_subsequence_problem\n */\n void lcs(ComplexSelectorDeque& x, ComplexSelectorDeque& y, const LcsCollectionComparator& comparator, ComplexSelectorDeque& out) {\n //DEBUG_PRINTLN(LCS, \"LCS: X=\" << x << \" Y=\" << y)\n // TODO: make printComplexSelectorDeque and use DEBUG_EXEC AND DEBUG_PRINTLN HERE to get equivalent output\n\n x.push_front(NULL);\n y.push_front(NULL);\n\n LCSTable table;\n lcs_table(x, y, comparator, table);\n\n return lcs_backtrace(table, x, y, static_cast(x.size()) - 1, static_cast(y.size()) - 1, comparator, out);\n }\n\n\n /*\n This is the equivalent of ruby's Sequence.trim.\n\n The following is the modified version of the ruby code that was more portable to C++. You\n should be able to drop it into ruby 3.2.19 and get the same results from ruby sass.\n\n # Avoid truly horrific quadratic behavior. TODO: I think there\n # may be a way to get perfect trimming without going quadratic.\n return seqses if seqses.size > 100\n\n # Keep the results in a separate array so we can be sure we aren't\n # comparing against an already-trimmed selector. 
This ensures that two\n # identical selectors don't mutually trim one another.\n result = seqses.dup\n\n # This is n^2 on the sequences, but only comparing between\n # separate sequences should limit the quadratic behavior.\n seqses.each_with_index do |seqs1, i|\n tempResult = []\n\n for seq1 in seqs1 do\n max_spec = 0\n for seq in _sources(seq1) do\n max_spec = [max_spec, seq.specificity].max\n end\n\n\n isMoreSpecificOuter = false\n for seqs2 in result do\n if seqs1.equal?(seqs2) then\n next\n end\n\n # Second Law of Extend: the specificity of a generated selector\n # should never be less than the specificity of the extending\n # selector.\n #\n # See https://github.com/nex3/sass/issues/324.\n isMoreSpecificInner = false\n for seq2 in seqs2 do\n isMoreSpecificInner = _specificity(seq2) >= max_spec && _superselector?(seq2, seq1)\n if isMoreSpecificInner then\n break\n end\n end\n\n if isMoreSpecificInner then\n isMoreSpecificOuter = true\n break\n end\n end\n\n if !isMoreSpecificOuter then\n tempResult.push(seq1)\n end\n end\n\n result[i] = tempResult\n\n end\n\n result\n */\n /*\n - IMPROVEMENT: We could probably work directly in the output trimmed deque.\n */\n Node Extend::trim(Node& seqses, bool isReplace) {\n // See the comments in the above ruby code before embarking on understanding this function.\n\n // Avoid poor performance in extreme cases.\n if (seqses.collection()->size() > 100) {\n return seqses;\n }\n\n\n DEBUG_PRINTLN(TRIM, \"TRIM: \" << seqses)\n\n\n Node result = Node::createCollection();\n result.plus(seqses);\n\n DEBUG_PRINTLN(TRIM, \"RESULT INITIAL: \" << result)\n\n // Normally we use the standard STL iterators, but in this case, we need to access the result collection by index since we're\n // iterating the input collection, computing a value, and then setting the result in the output collection. We have to keep track\n // of the index manually.\n int toTrimIndex = 0;\n\n for (NodeDeque::iterator seqsesIter = seqses.collection()->begin(), seqsesIterEnd = seqses.collection()->end(); seqsesIter != seqsesIterEnd; ++seqsesIter) {\n Node& seqs1 = *seqsesIter;\n\n DEBUG_PRINTLN(TRIM, \"SEQS1: \" << seqs1 << \" \" << toTrimIndex)\n\n Node tempResult = Node::createCollection();\n tempResult.got_line_feed = seqs1.got_line_feed;\n\n for (NodeDeque::iterator seqs1Iter = seqs1.collection()->begin(), seqs1EndIter = seqs1.collection()->end(); seqs1Iter != seqs1EndIter; ++seqs1Iter) {\n Node& seq1 = *seqs1Iter;\n\n Complex_Selector_Obj pSeq1 = nodeToComplexSelector(seq1);\n\n // Compute the maximum specificity. This requires looking at the \"sources\" of the sequence. See SimpleSequence.sources in the ruby code\n // for a good description of sources.\n //\n // TODO: I'm pretty sure there's a bug in the sources code. It was implemented for sass-spec's 182_test_nested_extend_loop test.\n // While the test passes, I compared the state of each trim call to verify correctness. The last trim call had incorrect sources. We\n // had an extra source that the ruby version did not have. Without a failing test case, this is going to be extra hard to find. My\n // best guess at this point is that we're cloning an object somewhere and maintaining the sources when we shouldn't be. This is purely\n // a guess though.\n unsigned long maxSpecificity = isReplace ? 
pSeq1->specificity() : 0;\n ComplexSelectorSet sources = pSeq1->sources();\n\n DEBUG_PRINTLN(TRIM, \"TRIM SEQ1: \" << seq1)\n DEBUG_EXEC(TRIM, printSourcesSet(sources, \"TRIM SOURCES: \"))\n\n for (ComplexSelectorSet::iterator sourcesSetIterator = sources.begin(), sourcesSetIteratorEnd = sources.end(); sourcesSetIterator != sourcesSetIteratorEnd; ++sourcesSetIterator) {\n const Complex_Selector_Obj& pCurrentSelector = *sourcesSetIterator;\n maxSpecificity = std::max(maxSpecificity, pCurrentSelector->specificity());\n }\n\n DEBUG_PRINTLN(TRIM, \"MAX SPECIFICITY: \" << maxSpecificity)\n\n bool isMoreSpecificOuter = false;\n\n int resultIndex = 0;\n\n for (NodeDeque::iterator resultIter = result.collection()->begin(), resultIterEnd = result.collection()->end(); resultIter != resultIterEnd; ++resultIter) {\n Node& seqs2 = *resultIter;\n\n DEBUG_PRINTLN(TRIM, \"SEQS1: \" << seqs1)\n DEBUG_PRINTLN(TRIM, \"SEQS2: \" << seqs2)\n\n // Do not compare the same sequence to itself. The ruby call we're trying to\n // emulate is: seqs1.equal?(seqs2). equal? is an object comparison, not an equivalency comparision.\n // Since we have the same pointers in seqes and results, we can do a pointer comparision. seqs1 is\n // derived from seqses and seqs2 is derived from result.\n if (seqs1.collection() == seqs2.collection()) {\n DEBUG_PRINTLN(TRIM, \"CONTINUE\")\n continue;\n }\n\n bool isMoreSpecificInner = false;\n\n for (NodeDeque::iterator seqs2Iter = seqs2.collection()->begin(), seqs2IterEnd = seqs2.collection()->end(); seqs2Iter != seqs2IterEnd; ++seqs2Iter) {\n Node& seq2 = *seqs2Iter;\n\n Complex_Selector_Obj pSeq2 = nodeToComplexSelector(seq2);\n\n DEBUG_PRINTLN(TRIM, \"SEQ2 SPEC: \" << pSeq2->specificity())\n DEBUG_PRINTLN(TRIM, \"IS SPEC: \" << pSeq2->specificity() << \" >= \" << maxSpecificity << \" \" << (pSeq2->specificity() >= maxSpecificity ? \"true\" : \"false\"))\n DEBUG_PRINTLN(TRIM, \"IS SUPER: \" << (pSeq2->is_superselector_of(pSeq1) ? \"true\" : \"false\"))\n\n isMoreSpecificInner = pSeq2->specificity() >= maxSpecificity && pSeq2->is_superselector_of(pSeq1);\n\n if (isMoreSpecificInner) {\n DEBUG_PRINTLN(TRIM, \"FOUND MORE SPECIFIC\")\n break;\n }\n }\n\n // If we found something more specific, we're done. Let the outer loop know and stop iterating.\n if (isMoreSpecificInner) {\n isMoreSpecificOuter = true;\n break;\n }\n\n resultIndex++;\n }\n\n if (!isMoreSpecificOuter) {\n...\n", "needle_token_start": 5601, "needle_token_end": 5748, "code_context_ntokens": 16070, "output": ["```cpp\n#include \n#include \n#include \n\ntemplate\nstd::string vector_to_string(std::vector v)\n{\n std::stringstream buffer;\n buffer << \"[\";\n\n if (!v.empty())\n { buffer << v[0]; }\n else\n { buffer << \"]\"; }\n\n if (v.size() == 1)\n { buffer << \"]\"; }\n else\n {\n for (size_t i = 1, S = v.size(); i < S; ++i) buffer << \", \" << v[i];\n buffer << \"]\";\n }\n\n return buffer.str();\n}\n```"]} +{"repo": "sass/node-sass", "name": "debug_subset_map", "language": "cpp", "path": "src/libsass/src/debugger.hpp", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to output a detailed textual representation of a mapping structure used in a style sheet language processor. This helps in debugging by showing how different abstract syntax tree (AST) nodes are related within the map.\n2. **Input**: The function accepts a mapping structure that associates AST nodes with each other, typically representing relationships or transformations in a stylesheet.\n3. 
**Output**: There is no return value; instead, the function outputs to the standard error stream. It prints a detailed breakdown of each key-value pair in the mapping, showing the structure and content of the associated AST nodes.\n4. **Procedure**: The function begins by checking if the indentation level is at the base level and, if so, prints a start marker. It then iterates over each entry in the map, calling another function to print the AST nodes that form the key and the value of each entry with appropriate indentation. After iterating through all entries, if the indentation level is still at the base, it prints an end marker.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " std::cerr << \" [@media:\" << block->media_block() << \"]\";\n std::cerr << \" \" << block->tabs() << std::endl;\n } else if (Cast(node)) {\n Import_Stub_Ptr block = Cast(node);\n std::cerr << ind << \"Import_Stub \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << block->imp_path() << \"] \";\n std::cerr << \" \" << block->tabs() << std::endl;\n } else if (Cast(node)) {\n Import_Ptr block = Cast(node);\n std::cerr << ind << \"Import \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << block->tabs() << std::endl;\n // std::vector files_;\n for (auto imp : block->urls()) debug_ast(imp, ind + \"@: \", env);\n debug_ast(block->import_queries(), ind + \"@@ \");\n } else if (Cast(node)) {\n Assignment_Ptr block = Cast(node);\n std::cerr << ind << \"Assignment \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" <<\" << block->variable() << \">> \" << block->tabs() << std::endl;\n debug_ast(block->value(), ind + \"=\", env);\n } else if (Cast(node)) {\n Declaration_Ptr block = Cast(node);\n std::cerr << ind << \"Declaration \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [is_custom_property: \" << block->is_custom_property() << \"] \";\n std::cerr << \" \" << block->tabs() << std::endl;\n debug_ast(block->property(), ind + \" prop: \", env);\n debug_ast(block->value(), ind + \" value: \", env);\n debug_ast(block->block(), ind + \" \", env);\n } else if (Cast(node)) {\n Keyframe_Rule_Ptr has_block = Cast(node);\n std::cerr << ind << \"Keyframe_Rule \" << has_block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << has_block->tabs() << std::endl;\n if (has_block->name()) debug_ast(has_block->name(), ind + \"@\");\n if (has_block->block()) for(const Statement_Obj& i : has_block->block()->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Directive_Ptr block = Cast(node);\n std::cerr << ind << \"Directive \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << block->keyword() << \"] \" << block->tabs() << std::endl;\n debug_ast(block->selector(), ind + \"~\", env);\n debug_ast(block->value(), ind + \"+\", env);\n if (block->block()) for(const Statement_Obj& i : block->block()->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Each_Ptr block = Cast(node);\n std::cerr << ind << \"Each \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << 
block->tabs() << std::endl;\n if (block->block()) for(const Statement_Obj& i : block->block()->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n For_Ptr block = Cast(node);\n std::cerr << ind << \"For \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << block->tabs() << std::endl;\n if (block->block()) for(const Statement_Obj& i : block->block()->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n While_Ptr block = Cast(node);\n std::cerr << ind << \"While \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << block->tabs() << std::endl;\n if (block->block()) for(const Statement_Obj& i : block->block()->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Definition_Ptr block = Cast(node);\n std::cerr << ind << \"Definition \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [name: \" << block->name() << \"] \";\n std::cerr << \" [type: \" << (block->type() == Sass::Definition::Type::MIXIN ? \"Mixin \" : \"Function \") << \"] \";\n // this seems to lead to segfaults some times?\n // std::cerr << \" [signature: \" << block->signature() << \"] \";\n std::cerr << \" [native: \" << block->native_function() << \"] \";\n std::cerr << \" \" << block->tabs() << std::endl;\n debug_ast(block->parameters(), ind + \" params: \", env);\n if (block->block()) debug_ast(block->block(), ind + \" \", env);\n } else if (Cast(node)) {\n Mixin_Call_Ptr block = Cast(node);\n std::cerr << ind << \"Mixin_Call \" << block << \" \" << block->tabs();\n std::cerr << \" (\" << pstate_source_position(block) << \")\";\n std::cerr << \" [\" << block->name() << \"]\";\n std::cerr << \" [has_content: \" << block->has_content() << \"] \" << std::endl;\n debug_ast(block->arguments(), ind + \" args: \");\n if (block->block()) debug_ast(block->block(), ind + \" \", env);\n } else if (Ruleset_Ptr ruleset = Cast(node)) {\n std::cerr << ind << \"Ruleset \" << ruleset;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [indent: \" << ruleset->tabs() << \"]\";\n std::cerr << (ruleset->is_invisible() ? \" [INVISIBLE]\" : \"\");\n std::cerr << (ruleset->is_root() ? \" [root]\" : \"\");\n std::cerr << std::endl;\n debug_ast(ruleset->selector(), ind + \">\");\n debug_ast(ruleset->block(), ind + \" \");\n } else if (Cast(node)) {\n Block_Ptr block = Cast(node);\n std::cerr << ind << \"Block \" << block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << (block->is_invisible() ? 
\" [INVISIBLE]\" : \"\");\n std::cerr << \" [indent: \" << block->tabs() << \"]\" << std::endl;\n for(const Statement_Obj& i : block->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Variable_Ptr expression = Cast(node);\n std::cerr << ind << \"Variable \" << expression;\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << expression->name() << \"]\" << std::endl;\n std::string name(expression->name());\n if (env && env->has(name)) debug_ast(Cast((*env)[name]), ind + \" -> \", env);\n } else if (Cast(node)) {\n Function_Call_Schema_Ptr expression = Cast(node);\n std::cerr << ind << \"Function_Call_Schema \" << expression;\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \"\" << std::endl;\n debug_ast(expression->name(), ind + \"name: \", env);\n debug_ast(expression->arguments(), ind + \" args: \", env);\n } else if (Cast(node)) {\n Function_Call_Ptr expression = Cast(node);\n std::cerr << ind << \"Function_Call \" << expression;\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << expression->name() << \"]\";\n if (expression->is_delayed()) std::cerr << \" [delayed]\";\n if (expression->is_interpolant()) std::cerr << \" [interpolant]\";\n if (expression->is_css()) std::cerr << \" [css]\";\n std::cerr << std::endl;\n debug_ast(expression->arguments(), ind + \" args: \", env);\n debug_ast(expression->func(), ind + \" func: \", env);\n } else if (Cast(node)) {\n Function_Ptr expression = Cast(node);\n std::cerr << ind << \"Function \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n if (expression->is_css()) std::cerr << \" [css]\";\n std::cerr << std::endl;\n debug_ast(expression->definition(), ind + \" definition: \", env);\n } else if (Cast(node)) {\n Arguments_Ptr expression = Cast(node);\n std::cerr << ind << \"Arguments \" << expression;\n if (expression->is_delayed()) std::cerr << \" [delayed]\";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n if (expression->has_named_arguments()) std::cerr << \" [has_named_arguments]\";\n if (expression->has_rest_argument()) std::cerr << \" [has_rest_argument]\";\n if (expression->has_keyword_argument()) std::cerr << \" [has_keyword_argument]\";\n std::cerr << std::endl;\n for(const Argument_Obj& i : expression->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Argument_Ptr expression = Cast(node);\n std::cerr << ind << \"Argument \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << expression->value().ptr() << \"]\";\n std::cerr << \" [name: \" << expression->name() << \"] \";\n std::cerr << \" [rest: \" << expression->is_rest_argument() << \"] \";\n std::cerr << \" [keyword: \" << expression->is_keyword_argument() << \"] \" << std::endl;\n debug_ast(expression->value(), ind + \" value: \", env);\n } else if (Cast(node)) {\n Parameters_Ptr expression = Cast(node);\n std::cerr << ind << \"Parameters \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [has_optional: \" << expression->has_optional_parameters() << \"] \";\n std::cerr << \" [has_rest: \" << expression->has_rest_parameter() << \"] \";\n std::cerr << 
std::endl;\n for(const Parameter_Obj& i : expression->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Parameter_Ptr expression = Cast(node);\n std::cerr << ind << \"Parameter \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [name: \" << expression->name() << \"] \";\n std::cerr << \" [default: \" << expression->default_value().ptr() << \"] \";\n std::cerr << \" [rest: \" << expression->is_rest_parameter() << \"] \" << std::endl;\n } else if (Cast(node)) {\n Unary_Expression_Ptr expression = Cast(node);\n std::cerr << ind << \"Unary_Expression \" << expression;\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" [delayed: \" << expression->is_delayed() << \"] \";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << expression->type() << \"]\" << std::endl;\n debug_ast(expression->operand(), ind + \" operand: \", env);\n } else if (Cast(node)) {\n Binary_Expression_Ptr expression = Cast(node);\n std::cerr << ind << \"Binary_Expression \" << expression;\n if (expression->is_interpolant()) std::cerr << \" [is interpolant] \";\n if (expression->is_left_interpolant()) std::cerr << \" [left interpolant] \";\n if (expression->is_right_interpolant()) std::cerr << \" [right interpolant] \";\n std::cerr << \" [delayed: \" << expression->is_delayed() << \"] \";\n std::cerr << \" [ws_before: \" << expression->op().ws_before << \"] \";\n std::cerr << \" [ws_after: \" << expression->op().ws_after << \"] \";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << expression->type_name() << \"]\" << std::endl;\n debug_ast(expression->left(), ind + \" left: \", env);\n debug_ast(expression->right(), ind + \" right: \", env);\n } else if (Cast(node)) {\n Map_Ptr expression = Cast(node);\n std::cerr << ind << \"Map \" << expression;\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [Hashed]\" << std::endl;\n for (const auto& i : expression->elements()) {\n debug_ast(i.first, ind + \" key: \");\n debug_ast(i.second, ind + \" val: \");\n }\n } else if (Cast(node)) {\n List_Ptr expression = Cast(node);\n std::cerr << ind << \"List \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" (\" << expression->length() << \") \" <<\n (expression->separator() == SASS_COMMA ? \"Comma \" : expression->separator() == SASS_HASH ? 
\"Map \" : \"Space \") <<\n \" [delayed: \" << expression->is_delayed() << \"] \" <<\n \" [interpolant: \" << expression->is_interpolant() << \"] \" <<\n \" [listized: \" << expression->from_selector() << \"] \" <<\n \" [arglist: \" << expression->is_arglist() << \"] \" <<\n \" [bracketed: \" << expression->is_bracketed() << \"] \" <<\n \" [expanded: \" << expression->is_expanded() << \"] \" <<\n \" [hash: \" << expression->hash() << \"] \" <<\n std::endl;\n for(const auto& i : expression->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Content_Ptr expression = Cast(node);\n std::cerr << ind << \"Content \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [@media:\" << expression->media_block() << \"]\";\n std::cerr << \" [Statement]\" << std::endl;\n } else if (Cast(node)) {\n Boolean_Ptr expression = Cast(node);\n std::cerr << ind << \"Boolean \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" [\" << expression->value() << \"]\" << std::endl;\n } else if (Cast(node)) {\n Color_Ptr expression = Cast(node);\n std::cerr << ind << \"Color \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [delayed: \" << expression->is_delayed() << \"] \";\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" [\" << expression->r() << \":\" << expression->g() << \":\" << expression->b() << \"@\" << expression->a() << \"]\" << std::endl;\n } else if (Cast(node)) {\n Number_Ptr expression = Cast(node);\n std::cerr << ind << \"Number \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [delayed: \" << expression->is_delayed() << \"] \";\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \";\n std::cerr << \" [\" << expression->value() << expression->unit() << \"]\" <<\n \" [hash: \" << expression->hash() << \"] \" <<\n std::endl;\n } else if (Cast(node)) {\n Null_Ptr expression = Cast(node);\n std::cerr << ind << \"Null \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [interpolant: \" << expression->is_interpolant() << \"] \"\n // \" [hash: \" << expression->hash() << \"] \"\n << std::endl;\n } else if (Cast(node)) {\n String_Quoted_Ptr expression = Cast(node);\n std::cerr << ind << \"String_Quoted \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << prettyprint(expression->value()) << \"]\";\n if (expression->is_delayed()) std::cerr << \" [delayed]\";\n if (expression->is_interpolant()) std::cerr << \" [interpolant]\";\n if (expression->quote_mark()) std::cerr << \" [quote_mark: \" << expression->quote_mark() << \"]\";\n std::cerr << \" <\" << prettyprint(expression->pstate().token.ws_before()) << \">\" << std::endl;\n } else if (Cast(node)) {\n String_Constant_Ptr expression = Cast(node);\n std::cerr << ind << \"String_Constant \" << expression;\n if (expression->concrete_type()) {\n std::cerr << \" \" << expression->concrete_type();\n }\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" [\" << prettyprint(expression->value()) << \"]\";\n if (expression->is_delayed()) std::cerr << \" [delayed]\";\n if (expression->is_interpolant()) std::cerr << \" [interpolant]\";\n std::cerr << \" <\" << 
prettyprint(expression->pstate().token.ws_before()) << \">\" << std::endl;\n } else if (Cast(node)) {\n String_Schema_Ptr expression = Cast(node);\n std::cerr << ind << \"String_Schema \" << expression;\n std::cerr << \" (\" << pstate_source_position(expression) << \")\";\n std::cerr << \" \" << expression->concrete_type();\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n if (expression->css()) std::cerr << \" [css]\";\n if (expression->is_delayed()) std::cerr << \" [delayed]\";\n if (expression->is_interpolant()) std::cerr << \" [is interpolant]\";\n if (expression->has_interpolant()) std::cerr << \" [has interpolant]\";\n if (expression->is_left_interpolant()) std::cerr << \" [left interpolant] \";\n if (expression->is_right_interpolant()) std::cerr << \" [right interpolant] \";\n std::cerr << \" <\" << prettyprint(expression->pstate().token.ws_before()) << \">\" << std::endl;\n for(const auto& i : expression->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n String_Ptr expression = Cast(node);\n std::cerr << ind << \"String \" << expression;\n std::cerr << \" \" << expression->concrete_type();\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n if (expression->is_interpolant()) std::cerr << \" [interpolant]\";\n std::cerr << \" <\" << prettyprint(expression->pstate().token.ws_before()) << \">\" << std::endl;\n } else if (Cast(node)) {\n Expression_Ptr expression = Cast(node);\n std::cerr << ind << \"Expression \" << expression;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n switch (expression->concrete_type()) {\n case Expression::Concrete_Type::NONE: std::cerr << \" [NONE]\"; break;\n case Expression::Concrete_Type::BOOLEAN: std::cerr << \" [BOOLEAN]\"; break;\n case Expression::Concrete_Type::NUMBER: std::cerr << \" [NUMBER]\"; break;\n case Expression::Concrete_Type::COLOR: std::cerr << \" [COLOR]\"; break;\n case Expression::Concrete_Type::STRING: std::cerr << \" [STRING]\"; break;\n case Expression::Concrete_Type::LIST: std::cerr << \" [LIST]\"; break;\n case Expression::Concrete_Type::MAP: std::cerr << \" [MAP]\"; break;\n case Expression::Concrete_Type::SELECTOR: std::cerr << \" [SELECTOR]\"; break;\n case Expression::Concrete_Type::NULL_VAL: std::cerr << \" [NULL_VAL]\"; break;\n case Expression::Concrete_Type::C_WARNING: std::cerr << \" [C_WARNING]\"; break;\n case Expression::Concrete_Type::C_ERROR: std::cerr << \" [C_ERROR]\"; break;\n case Expression::Concrete_Type::FUNCTION: std::cerr << \" [FUNCTION]\"; break;\n case Expression::Concrete_Type::NUM_TYPES: std::cerr << \" [NUM_TYPES]\"; break;\n case Expression::Concrete_Type::VARIABLE: std::cerr << \" [VARIABLE]\"; break;\n case Expression::Concrete_Type::FUNCTION_VAL: std::cerr << \" [FUNCTION_VAL]\"; break;\n }\n std::cerr << std::endl;\n } else if (Cast(node)) {\n Has_Block_Ptr has_block = Cast(node);\n std::cerr << ind << \"Has_Block \" << has_block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << has_block->tabs() << std::endl;\n if (has_block->block()) for(const Statement_Obj& i : has_block->block()->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Statement_Ptr statement = Cast(node);\n std::cerr << ind << \"Statement \" << statement;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << statement->tabs() << std::endl;\n }\n\n if (ind == \"\") std::cerr << 
\"####################################################################\\n\";\n}\n\ninline void debug_node(Node* node, std::string ind = \"\")\n{\n if (ind == \"\") std::cerr << \"#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\\n\";\n if (node->isCombinator()) {\n std::cerr << ind;\n std::cerr << \"Combinator \";\n std::cerr << node << \" \";\n if (node->got_line_feed) std::cerr << \"[LF] \";\n switch (node->combinator()) {\n case Complex_Selector::ADJACENT_TO: std::cerr << \"{+} \"; break;\n case Complex_Selector::PARENT_OF: std::cerr << \"{>} \"; break;\n case Complex_Selector::PRECEDES: std::cerr << \"{~} \"; break;\n case Complex_Selector::REFERENCE: std::cerr << \"{@} \"; break;\n case Complex_Selector::ANCESTOR_OF: std::cerr << \"{ } \"; break;\n }\n std::cerr << std::endl;\n // debug_ast(node->combinator(), ind + \" \");\n } else if (node->isSelector()) {\n std::cerr << ind;\n std::cerr << \"Selector \";\n std::cerr << node << \" \";\n if (node->got_line_feed) std::cerr << \"[LF] \";\n std::cerr << std::endl;\n debug_ast(node->selector(), ind + \" \");\n } else if (node->isCollection()) {\n std::cerr << ind;\n std::cerr << \"Collection \";\n std::cerr << node << \" \";\n if (node->got_line_feed) std::cerr << \"[LF] \";\n std::cerr << std::endl;\n for(auto n : (*node->collection())) {\n debug_node(&n, ind + \" \");\n }\n } else if (node->isNil()) {\n std::cerr << ind;\n std::cerr << \"Nil \";\n std::cerr << node << \" \";\n if (node->got_line_feed) std::cerr << \"[LF] \";\n std::cerr << std::endl;\n } else {\n std::cerr << ind;\n std::cerr << \"OTHER \";\n std::cerr << node << \" \";\n if (node->got_line_feed) std::cerr << \"[LF] \";\n std::cerr << std::endl;\n }\n if (ind == \"\") std::cerr << \"#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\\n\";\n}\n\n/*\ninline void debug_ast(const AST_Node_Ptr node, std::string ind = \"\", Env* env = 0)\n{\n debug_ast(const_cast(node), ind, env);\n}\n*/\ninline void debug_node(const Node* node, std::string ind = \"\")\n{\n debug_node(const_cast(node), ind);\n}\n\n\ninline void debug_subset_map(Sass::Subset_Map& map, std::string ind = \"\")\n{\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n for(auto const &it : map.values()) {\n debug_ast(it.first, ind + \"first: \");\n debug_ast(it.second, ind + \"second: \");\n }\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n}\n\ninline void debug_subset_entries(SubSetMapPairs* entries, std::string ind = \"\")\n{\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n for(auto const &pair : *entries) {\n debug_ast(pair.first, ind + \"first: \");\n debug_ast(pair.second, ind + \"second: \");\n }\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n}\n\n#endif // SASS_DEBUGGER\n\n// Path: src/libsass/src/memory/SharedPtr.cpp\n#include \"../sass.hpp\"\n#include \n#include \n\n#include \"SharedPtr.hpp\"\n#include \"../ast_fwd_decl.hpp\"\n\n#ifdef DEBUG_SHARED_PTR\n#include \"../debugger.hpp\"\n#endif\n\nnamespace Sass {\n\n #ifdef DEBUG_SHARED_PTR\n void SharedObj::dumpMemLeaks() {\n if (!all.empty()) {\n std::cerr << \"###################################\\n\";\n std::cerr << \"# REPORTING MISSING DEALLOCATIONS #\\n\";\n std::cerr << \"###################################\\n\";\n for (SharedObj* var : all) {\n if (AST_Node_Ptr ast = 
dynamic_cast(var)) {\n debug_ast(ast);\n } else {\n std::cerr << \"LEAKED \" << var << \"\\n\";\n }\n }\n }\n }\n std::vector SharedObj::all;\n #endif\n\n bool SharedObj::taint = false;\n\n SharedObj::SharedObj()\n : detached(false)\n #ifdef DEBUG_SHARED_PTR\n , dbg(false)\n #endif\n {\n refcounter = 0;\n #ifdef DEBUG_SHARED_PTR\n if (taint) all.push_back(this);\n #endif\n };\n\n SharedObj::~SharedObj() {\n #ifdef DEBUG_SHARED_PTR\n if (dbg) std::cerr << \"Destruct \" << this << \"\\n\";\n if(!all.empty()) { // check needed for MSVC (no clue why?)\n all.erase(std::remove(all.begin(), all.end(), this), all.end());\n }\n #endif\n };\n\n void SharedPtr::decRefCount() {\n if (node) {\n -- node->refcounter;\n #ifdef DEBUG_SHARED_PTR\n if (node->dbg) std::cerr << \"- \" << node << \" X \" << node->refcounter << \" (\" << this << \") \" << \"\\n\";\n #endif\n if (node->refcounter == 0) {\n #ifdef DEBUG_SHARED_PTR\n // AST_Node_Ptr ast = dynamic_cast(node);\n if (node->dbg) std::cerr << \"DELETE NODE \" << node << \"\\n\";\n #endif\n if (!node->detached) {\n delete(node);\n }\n }\n }\n }\n\n void SharedPtr::incRefCount() {\n if (node) {\n ++ node->refcounter;\n node->detached = false;\n #ifdef DEBUG_SHARED_PTR\n if (node->dbg) {\n std::cerr << \"+ \" << node << \" X \" << node->refcounter << \" (\" << this << \") \" << \"\\n\";\n }\n #endif\n }\n }\n\n SharedPtr::~SharedPtr() {\n decRefCount();\n }\n\n\n // the create constructor\n SharedPtr::SharedPtr(SharedObj* ptr)\n : node(ptr) {\n incRefCount();\n }\n // copy assignment operator\n SharedPtr& SharedPtr::operator=(const SharedPtr& rhs) {\n void* cur_ptr = (void*) node;\n void* rhs_ptr = (void*) rhs.node;\n if (cur_ptr == rhs_ptr) {\n return *this;\n }\n decRefCount();\n node = rhs.node;\n incRefCount();\n return *this;\n }\n\n // the copy constructor\n SharedPtr::SharedPtr(const SharedPtr& obj)\n : node(obj.node) {\n incRefCount();\n }\n\n}\n// Path: src/libsass/contrib/plugin.cpp\n#include \n#include \n#include \n#include \n\n// gcc: g++ -shared plugin.cpp -o plugin.so -fPIC -Llib -lsass\n// mingw: g++ -shared plugin.cpp -o plugin.dll -Llib -lsass\n\nextern \"C\" const char* ADDCALL libsass_get_version() {\n return libsass_version();\n}\n\nunion Sass_Value* custom_function(const union Sass_Value* s_args, Sass_Function_Entry cb, struct Sass_Compiler* comp)\n{\n // get context/option struct associated with this compiler\n struct Sass_Context* ctx = sass_compiler_get_context(comp);\n struct Sass_Options* opts = sass_compiler_get_options(comp);\n // get the cookie from function descriptor\n void* cookie = sass_function_get_cookie(cb);\n // we actually abuse the void* to store an \"int\"\n return sass_make_number((intptr_t)cookie, \"px\");\n}\n\nextern \"C\" Sass_Function_List ADDCALL libsass_load_functions()\n{\n // allocate a custom function caller\n Sass_Function_Entry c_func =\n sass_make_function(\"foo()\", custom_function, (void*)42);\n // create list of all custom functions\n Sass_Function_List fn_list = sass_make_function_list(1);\n // put the only function in this plugin to the list\n sass_function_set_list_entry(fn_list, 0, c_func);\n // return the list\n return fn_list;\n}\n\nSass_Import_List custom_importer(const char* cur_path, Sass_Importer_Entry cb, struct Sass_Compiler* comp)\n{\n // get the cookie from importer descriptor\n void* cookie = sass_importer_get_cookie(cb);\n // create a list to hold our import entries\n Sass_Import_List incs = sass_make_import_list(1);\n // create our only import entry (route path back)\n incs[0] = 
sass_make_import_entry(cur_path, 0, 0);\n // return imports\n return incs;\n}\n\nextern \"C\" Sass_Importer_List ADDCALL libsass_load_importers()\n{\n // allocate a custom function caller\n Sass_Importer_Entry c_imp =\n sass_make_importer(custom_importer, - 99, (void*)42);\n // create list of all custom functions\n Sass_Importer_List imp_list = sass_make_importer_list(1);\n // put the only function in this plugin to the list\n sass_importer_set_list_entry(imp_list, 0, c_imp);\n // return the list\n return imp_list;\n}\n\n// Path: src/libsass/test/test_node.cpp\n#include \n#include \n\n#include \"node.hpp\"\n#include \"parser.hpp\"\n\n\n#define STATIC_ARRAY_SIZE(array) (sizeof((array))/sizeof((array[0])))\n\n\nnamespace Sass {\n\n Context ctx = Context::Data();\n\n const char* const ROUNDTRIP_TESTS[] = {\n NULL,\n \"~\",\n \"CMPD\",\n \"~ CMPD\",\n \"CMPD >\",\n \"> > CMPD\",\n \"CMPD ~ ~\",\n \"> + CMPD1.CMPD2 > ~\",\n \"> + CMPD1.CMPD2 CMPD3.CMPD4 > ~\",\n \"+ CMPD1 CMPD2 ~ CMPD3 + CMPD4 > CMPD5 > ~\"\n };\n\n\n\n static Complex_Selector* createComplexSelector(std::string src) {\n std::string temp(src);\n temp += \";\";\n return (*Parser::from_c_str(temp.c_str(), ctx, \"\", Position()).parse_selector_list())[0];\n }\n\n\n void roundtripTest(const char* toTest) {\n\n // Create the initial selector\n\n Complex_Selector* pOrigSelector = NULL;\n if (toTest) {\n pOrigSelector = createComplexSelector(toTest);\n }\n\n std::string expected(pOrigSelector ? pOrigSelector->to_string() : \"NULL\");\n\n\n // Roundtrip the selector into a node and back\n\n Node node = complexSelectorToNode(pOrigSelector, ctx);\n\n std::stringstream nodeStringStream;\n nodeStringStream << node;\n std::string nodeString = nodeStringStream.str();\n cout << \"ASNODE: \" << node << endl;\n\n Complex_Selector* pNewSelector = nodeToComplexSelector(node, ctx);\n\n // Show the result\n\n std::string result(pNewSelector ? pNewSelector->to_string() : \"NULL\");\n\n cout << \"SELECTOR: \" << expected << endl;\n cout << \"NEW SELECTOR: \" << result << endl;\n\n\n // Test that they are equal using the equality operator\n\n assert( (!pOrigSelector && !pNewSelector ) || (pOrigSelector && pNewSelector) );\n if (pOrigSelector) {\n assert( *pOrigSelector == *pNewSelector );\n }\n\n\n // Test that they are equal by comparing the string versions of the selectors\n\n assert(expected == result);\n\n }\n\n\n int main() {\n for (int index = 0; index < STATIC_ARRAY_SIZE(ROUNDTRIP_TESTS); index++) {\n const char* const toTest = ROUNDTRIP_TESTS[index];\n cout << \"\\nINPUT STRING: \" << (toTest ? 
toTest : \"NULL\") << endl;\n roundtripTest(toTest);\n }\n\n cout << \"\\nTesting Done.\\n\";\n }\n\n\n}\n\n// Path: src/libsass/test/test_selector_difference.cpp\n#include \"../ast.hpp\"\n#include \"../context.hpp\"\n#include \"../parser.hpp\"\n#include \n#include \n\nusing namespace Sass;\n\nContext ctx = Context::Data();\n\nCompound_Selector* selector(std::string src)\n{ return Parser::from_c_str(src.c_str(), ctx, \"\", Position()).parse_compound_selector(); }\n\nvoid diff(std::string s, std::string t)\n{\n std::cout << s << \" - \" << t << \" = \" << selector(s + \";\")->minus(selector(t + \";\"), ctx)->to_string() << std::endl;\n}\n\nint main()\n{\n diff(\".a.b.c\", \".c.b\");\n diff(\".a.b.c\", \".fludge.b\");\n\n return 0;\n}\n\n// Path: src/libsass/test/test_superselector.cpp\n#include \"../ast.hpp\"\n#include \"../context.hpp\"\n#include \"../parser.hpp\"\n#include \n\nusing namespace Sass;\n\nContext ctx = Context(Context::Data());\n\nCompound_Selector* compound_selector(std::string src)\n{ return Parser::from_c_str(src.c_str(), ctx, \"\", Position()).parse_compound_selector(); }\n\nComplex_Selector* complex_selector(std::string src)\n{ return Parser::from_c_str(src.c_str(), ctx, \"\", Position()).parse_complex_selector(false); }\n\nvoid check_compound(std::string s1, std::string s2)\n{\n std::cout << \"Is \"\n << s1\n << \" a superselector of \"\n << s2\n << \"?\\t\"\n << compound_selector(s1 + \";\")->is_superselector_of(compound_selector(s2 + \";\"))\n << std::endl;\n}\n\nvoid check_complex(std::string s1, std::string s2)\n{\n std::cout << \"Is \"\n << s1\n << \" a superselector of \"\n << s2\n << \"?\\t\"\n << complex_selector(s1 + \";\")->is_superselector_of(complex_selector(s2 + \";\"))\n << std::endl;\n}\n\nint main()\n{\n check_compound(\".foo\", \".foo.bar\");\n check_compound(\".foo.bar\", \".foo\");\n check_compound(\".foo.bar\", \"div.foo\");\n check_compound(\".foo\", \"div.foo\");\n check_compound(\"div.foo\", \".foo\");\n check_compound(\"div.foo\", \"div.bar.foo\");\n check_compound(\"p.foo\", \"div.bar.foo\");\n check_compound(\".hux\", \".mumble\");\n\n std::cout << std::endl;\n\n check_complex(\".foo ~ .bar\", \".foo + .bar\");\n check_complex(\".foo .bar\", \".foo + .bar\");\n check_complex(\".foo .bar\", \".foo > .bar\");\n check_complex(\".foo .bar > .hux\", \".foo.a .bar.b > .hux\");\n check_complex(\".foo ~ .bar .hux\", \".foo.a + .bar.b > .hux\");\n check_complex(\".foo\", \".bar .foo\");\n check_complex(\".foo\", \".foo.a\");\n check_complex(\".foo.bar\", \".foo\");\n check_complex(\".foo .bar .hux\", \".bar .hux\");\n check_complex(\".foo ~ .bar .hux.x\", \".foo.a + .bar.b > .hux.y\");\n check_complex(\".foo ~ .bar .hux\", \".foo.a + .bar.b > .mumble\");\n check_complex(\".foo + .bar\", \".foo ~ .bar\");\n check_complex(\"a c e\", \"a b c d e\");\n check_complex(\"c a e\", \"a b c d e\");\n\n return 0;\n}\n\n\n\n// Path: src/libsass/test/test_paths.cpp\n#include \n#include \"../paths.hpp\"\n\nusing namespace Sass;\n\ntemplate\nstd::vector& operator<<(std::vector& v, const T& e)\n{\n v.push_back(e);\n return v;\n}\n\nint main()\n{\n std::vector v1, v2, v3;\n v1 << 1 << 2;\n v2 << 3;\n v3 << 4 << 5 << 6;\n\n std::vector > ss;\n ss << v1 << v2 << v3;\n\n std::vector > ps = paths(ss);\n for (size_t i = 0, S = ps.size(); i < S; ++i) {\n std::cout << vector_to_string(ps[i]) << std::endl;\n }\n return 0;\n}\n\n// Path: src/libsass/test/test_unification.cpp\n#include \"../ast.hpp\"\n#include \"../context.hpp\"\n#include \"../parser.hpp\"\n#include \n\nusing 
namespace Sass;\n\nContext ctx = Context(Context::Data());\n\nCompound_Selector* selector(std::string src)\n{ return Parser::from_c_str(src.c_str(), ctx, \"\", Position()).parse_compound_selector(); }\n\nvoid unify(std::string lhs, std::string rhs)\n{\n Compound_Selector* unified = selector(lhs + \";\")->unify_with(selector(rhs + \";\"), ctx);\n std::cout << lhs << \" UNIFIED WITH \" << rhs << \" =\\t\" << (unified ? unified->to_string() : \"NOTHING\") << std::endl;\n}\n\nint main()\n{\n unify(\".foo\", \".foo.bar\");\n unify(\"div:nth-of-type(odd)\", \"div:first-child\");\n unify(\"div\", \"span:whatever\");\n unify(\"div\", \"span\");\n unify(\"foo:bar::after\", \"foo:bar::first-letter\");\n unify(\".foo#bar.hux\", \".hux.foo#bar\");\n unify(\".foo#bar.hux\", \".hux.foo#baz\");\n unify(\"*:blah:fudge\", \"p:fudge:blah\");\n\n return 0;\n}\n\n// Path: src/libsass/test/test_subset_map.cpp\n#include \n#include \n#include \n#include \"../subset_map.hpp\"\n\nSubset_Map ssm;\n\nstring toString(std::vector v);\nstring toString(std::vector>> v);\nvoid assertEqual(string std::sExpected, std::string sResult);\n\nvoid setup() {\n ssm.clear();\n\n //@ssm[Set[1, 2]] = \"Foo\"\n std::vector s1;\n s1.push_back(\"1\");\n s1.push_back(\"2\");\n ssm.put(s1, \"Foo\");\n\n //@ssm[Set[\"fizz\", \"fazz\"]] = \"Bar\"\n std::vector s2;\n s2.push_back(\"fizz\");\n s2.push_back(\"fazz\");\n ssm.put(s2, \"Bar\");\n\n //@ssm[Set[:foo, :bar]] = \"Baz\"\n std::vector s3;\n s3.push_back(\":foo\");\n s3.push_back(\":bar\");\n ssm.put(s3, \"Baz\");\n\n //@ssm[Set[:foo, :bar, :baz]] = \"Bang\"\n std::vector s4;\n s4.push_back(\":foo\");\n s4.push_back(\":bar\");\n s4.push_back(\":baz\");\n ssm.put(s4, \"Bang\");\n\n //@ssm[Set[:bip, :bop, :blip]] = \"Qux\"\n std::vector s5;\n s5.push_back(\":bip\");\n s5.push_back(\":bop\");\n s5.push_back(\":blip\");\n ssm.put(s5, \"Qux\");\n\n //@ssm[Set[:bip, :bop]] = \"Thram\"\n std::vector s6;\n s6.push_back(\":bip\");\n s6.push_back(\":bop\");\n ssm.put(s6, \"Thram\");\n}\n\nvoid testEqualKeys() {\n std::cout << \"testEqualKeys\" << std::endl;\n\n //assert_equal [[\"Foo\", Set[1, 2]]], @ssm.get(Set[1, 2])\n std::vector k1;\n k1.push_back(\"1\");\n k1.push_back(\"2\");\n assertEqual(\"[[Foo, Set[1, 2]]]\", toString(ssm.get_kv(k1)));\n\n //assert_equal [[\"Bar\", Set[\"fizz\", \"fazz\"]]], @ssm.get(Set[\"fizz\", \"fazz\"])\n std::vector k2;\n k2.push_back(\"fizz\");\n k2.push_back(\"fazz\");\n assertEqual(\"[[Bar, Set[fizz, fazz]]]\", toString(ssm.get_kv(k2)));\n\n std::cout << std::endl;\n}\n\nvoid testSubsetKeys() {\n std::cout << \"testSubsetKeys\" << std::endl;\n\n //assert_equal [[\"Foo\", Set[1, 2]]], @ssm.get(Set[1, 2, \"fuzz\"])\n std::vector k1;\n k1.push_back(\"1\");\n k1.push_back(\"2\");\n k1.push_back(\"fuzz\");\n assertEqual(\"[[Foo, Set[1, 2]]]\", toString(ssm.get_kv(k1)));\n\n //assert_equal [[\"Bar\", Set[\"fizz\", \"fazz\"]]], @ssm.get(Set[\"fizz\", \"fazz\", 3])\n std::vector k2;\n k2.push_back(\"fizz\");\n k2.push_back(\"fazz\");\n k2.push_back(\"3\");\n assertEqual(\"[[Bar, Set[fizz, fazz]]]\", toString(ssm.get_kv(k2)));\n\n std::cout << std::endl;\n}\n\nvoid testSupersetKeys() {\n std::cout << \"testSupersetKeys\" << std::endl;\n\n //assert_equal [], @ssm.get(Set[1])\n std::vector k1;\n k1.push_back(\"1\");\n assertEqual(\"[]\", toString(ssm.get_kv(k1)));\n\n //assert_equal [], @ssm.get(Set[2])\n std::vector k2;\n k2.push_back(\"2\");\n assertEqual(\"[]\", toString(ssm.get_kv(k2)));\n\n //assert_equal [], @ssm.get(Set[\"fizz\"])\n std::vector k3;\n 
k3.push_back(\"fizz\");\n assertEqual(\"[]\", toString(ssm.get_kv(k3)));\n\n //assert_equal [], @ssm.get(Set[\"fazz\"])\n std::vector k4;\n k4.push_back(\"fazz\");\n assertEqual(\"[]\", toString(ssm.get_kv(k4)));\n\n std::cout << std::endl;\n}\n\nvoid testDisjointKeys() {\n std::cout << \"testDisjointKeys\" << std::endl;\n\n //assert_equal [], @ssm.get(Set[3, 4])\n std::vector k1;\n k1.push_back(\"3\");\n k1.push_back(\"4\");\n assertEqual(\"[]\", toString(ssm.get_kv(k1)));\n\n //assert_equal [], @ssm.get(Set[\"fuzz\", \"frizz\"])\n std::vector k2;\n k2.push_back(\"fuzz\");\n k2.push_back(\"frizz\");\n assertEqual(\"[]\", toString(ssm.get_kv(k2)));\n\n //assert_equal [], @ssm.get(Set[\"gran\", 15])\n std::vector k3;\n k3.push_back(\"gran\");\n k3.push_back(\"15\");\n assertEqual(\"[]\", toString(ssm.get_kv(k3)));\n\n std::cout << std::endl;\n}\n\nvoid testSemiDisjointKeys() {\n std::cout << \"testSemiDisjointKeys\" << std::endl;\n\n //assert_equal [], @ssm.get(Set[2, 3])\n std::vector k1;\n k1.push_back(\"2\");\n k1.push_back(\"3\");\n assertEqual(\"[]\", toString(ssm.get_kv(k1)));\n\n //assert_equal [], @ssm.get(Set[\"fizz\", \"fuzz\"])\n std::vector k2;\n k2.push_back(\"fizz\");\n k2.push_back(\"fuzz\");\n assertEqual(\"[]\", toString(ssm.get_kv(k2)));\n\n //assert_equal [], @ssm.get(Set[1, \"fazz\"])\n std::vector k3;\n k3.push_back(\"1\");\n k3.push_back(\"fazz\");\n assertEqual(\"[]\", toString(ssm.get_kv(k3)));\n\n std::cout << std::endl;\n}\n\nvoid testEmptyKeySet() {\n std::cout << \"testEmptyKeySet\" << std::endl;\n\n //assert_raises(ArgumentError) {@ssm[Set[]] = \"Fail\"}\n std::vector s1;\n try {\n ssm.put(s1, \"Fail\");\n }\n catch (const char* &e) {\n assertEqual(\"internal error: subset map keys may not be empty\", e);\n }\n}\n\nvoid testEmptyKeyGet() {\n std::cout << \"testEmptyKeyGet\" << std::endl;\n\n //assert_equal [], @ssm.get(Set[])\n std::vector k1;\n assertEqual(\"[]\", toString(ssm.get_kv(k1)));\n\n std::cout << std::endl;\n}\nvoid testMultipleSubsets() {\n std::cout << \"testMultipleSubsets\" << std::endl;\n\n //assert_equal [[\"Foo\", Set[1, 2]], [\"Bar\", Set[\"fizz\", \"fazz\"]]], @ssm.get(Set[1, 2, \"fizz\", \"fazz\"])\n std::vector k1;\n k1.push_back(\"1\");\n k1.push_back(\"2\");\n k1.push_back(\"fizz\");\n k1.push_back(\"fazz\");\n assertEqual(\"[[Foo, Set[1, 2]], [Bar, Set[fizz, fazz]]]\", toString(ssm.get_kv(k1)));\n\n //assert_equal [[\"Foo\", Set[1, 2]], [\"Bar\", Set[\"fizz\", \"fazz\"]]], @ssm.get(Set[1, 2, 3, \"fizz\", \"fazz\", \"fuzz\"])\n std::vector k2;\n k2.push_back(\"1\");\n k2.push_back(\"2\");\n k2.push_back(\"3\");\n k2.push_back(\"fizz\");\n k2.push_back(\"fazz\");\n k2.push_back(\"fuzz\");\n assertEqual(\"[[Foo, Set[1, 2]], [Bar, Set[fizz, fazz]]]\", toString(ssm.get_kv(k2)));\n\n //assert_equal [[\"Baz\", Set[:foo, :bar]]], @ssm.get(Set[:foo, :bar])\n std::vector k3;\n k3.push_back(\":foo\");\n k3.push_back(\":bar\");\n assertEqual(\"[[Baz, Set[:foo, :bar]]]\", toString(ssm.get_kv(k3)));\n\n //assert_equal [[\"Baz\", Set[:foo, :bar]], [\"Bang\", Set[:foo, :bar, :baz]]], @ssm.get(Set[:foo, :bar, :baz])\n std::vector k4;\n k4.push_back(\":foo\");\n k4.push_back(\":bar\");\n k4.push_back(\":baz\");\n assertEqual(\"[[Baz, Set[:foo, :bar]], [Bang, Set[:foo, :bar, :baz]]]\", toString(ssm.get_kv(k4)));\n\n std::cout << std::endl;\n}\nvoid testBracketBracket() {\n std::cout << \"testBracketBracket\" << std::endl;\n\n //assert_equal [\"Foo\"], @ssm[Set[1, 2, \"fuzz\"]]\n std::vector k1;\n k1.push_back(\"1\");\n k1.push_back(\"2\");\n 
k1.push_back(\"fuzz\");\n assertEqual(\"[Foo]\", toString(ssm.get_v(k1)));\n\n //assert_equal [\"Baz\", \"Bang\"], @ssm[Set[:foo, :bar, :baz]]\n std::vector k2;\n k2.push_back(\":foo\");\n k2.push_back(\":bar\");\n k2.push_back(\":baz\");\n assertEqual(\"[Baz, Bang]\", toString(ssm.get_v(k2)));\n\n std::cout << std::endl;\n}\n\nvoid testKeyOrder() {\n std::cout << \"testEqualKeys\" << std::endl;\n\n //assert_equal [[\"Foo\", Set[1, 2]]], @ssm.get(Set[2, 1])\n std::vector k1;\n k1.push_back(\"2\");\n k1.push_back(\"1\");\n assertEqual(\"[[Foo, Set[1, 2]]]\", toString(ssm.get_kv(k1)));\n\n std::cout << std::endl;\n}\n\nvoid testOrderPreserved() {\n std::cout << \"testOrderPreserved\" << std::endl;\n //@ssm[Set[10, 11, 12]] = 1\n std::vector s1;\n s1.push_back(\"10\");\n s1.push_back(\"11\");\n s1.push_back(\"12\");\n ssm.put(s1, \"1\");\n\n //@ssm[Set[10, 11]] = 2\n std::vector s2;\n s2.push_back(\"10\");\n s2.push_back(\"11\");\n ssm.put(s2, \"2\");\n\n //@ssm[Set[11]] = 3\n std::vector s3;\n s3.push_back(\"11\");\n ssm.put(s3, \"3\");\n\n //@ssm[Set[11, 12]] = 4\n std::vector s4;\n s4.push_back(\"11\");\n s4.push_back(\"12\");\n ssm.put(s4, \"4\");\n\n //@ssm[Set[9, 10, 11, 12, 13]] = 5\n std::vector s5;\n s5.push_back(\"9\");\n s5.push_back(\"10\");\n s5.push_back(\"11\");\n s5.push_back(\"12\");\n s5.push_back(\"13\");\n ssm.put(s5, \"5\");\n\n //@ssm[Set[10, 13]] = 6\n std::vector s6;\n s6.push_back(\"10\");\n s6.push_back(\"13\");\n ssm.put(s6, \"6\");\n\n //assert_equal([[1, Set[10, 11, 12]], [2, Set[10, 11]], [3, Set[11]], [4, Set[11, 12]], [5, Set[9, 10, 11, 12, 13]], [6, Set[10, 13]]], @ssm.get(Set[9, 10, 11, 12, 13]))\n std::vector k1;\n k1.push_back(\"9\");\n k1.push_back(\"10\");\n k1.push_back(\"11\");\n k1.push_back(\"12\");\n k1.push_back(\"13\");\n assertEqual(\"[[1, Set[10, 11, 12]], [2, Set[10, 11]], [3, Set[11]], [4, Set[11, 12]], [5, Set[9, 10, 11, 12, 13]], [6, Set[10, 13]]]\", toString(ssm.get_kv(k1)));\n\n std::cout << std::endl;\n}\nvoid testMultipleEqualValues() {\n std::cout << \"testMultipleEqualValues\" << std::endl;\n //@ssm[Set[11, 12]] = 1\n std::vector s1;\n s1.push_back(\"11\");\n s1.push_back(\"12\");\n ssm.put(s1, \"1\");\n\n //@ssm[Set[12, 13]] = 2\n std::vector s2;\n s2.push_back(\"12\");\n s2.push_back(\"13\");\n ssm.put(s2, \"2\");\n\n //@ssm[Set[13, 14]] = 1\n std::vector s3;\n s3.push_back(\"13\");\n s3.push_back(\"14\");\n ssm.put(s3, \"1\");\n\n //@ssm[Set[14, 15]] = 1\n std::vector s4;\n s4.push_back(\"14\");\n s4.push_back(\"15\");\n ssm.put(s4, \"1\");\n\n //assert_equal([[1, Set[11, 12]], [2, Set[12, 13]], [1, Set[13, 14]], [1, Set[14, 15]]], @ssm.get(Set[11, 12, 13, 14, 15]))\n std::vector k1;\n k1.push_back(\"11\");\n k1.push_back(\"12\");\n k1.push_back(\"13\");\n k1.push_back(\"14\");\n k1.push_back(\"15\");\n assertEqual(\"[[1, Set[11, 12]], [2, Set[12, 13]], [1, Set[13, 14]], [1, Set[14, 15]]]\", toString(ssm.get_kv(k1)));\n\n std::cout << std::endl;\n}\n\n...\n", "needle_token_start": 7299, "needle_token_end": 7428, "code_context_ntokens": 16222, "output": ["```cpp\ninline void debug_subset_map(Sass::Subset_Map& map, std::string ind = \"\")\n{\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n for(auto const &it : map.values()) {\n debug_ast(it.first, ind + \"first: \");\n debug_ast(it.second, ind + \"second: \");\n }\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n}\n```"]} +{"repo": "sass/node-sass", "name": 
"get_unit_type", "language": "cpp", "path": "src/libsass/src/units.cpp", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To determine the category of a given unit based on its type.\n2. **Input**: An integer representing a specific unit type.\n3. **Output**: An enumeration value representing the category of the unit (e.g., length, angle, time).\n4. **Procedure**: The function checks the upper byte of the input integer to match it against predefined categories. It returns the corresponding category if a match is found, otherwise, it returns a default category indicating the unit is not commensurable with known types.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/color_maps.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"color_maps.hpp\"\n\nnamespace Sass {\n\n namespace ColorNames\n {\n const char aliceblue [] = \"aliceblue\";\n const char antiquewhite [] = \"antiquewhite\";\n const char cyan [] = \"cyan\";\n const char aqua [] = \"aqua\";\n const char aquamarine [] = \"aquamarine\";\n const char azure [] = \"azure\";\n const char beige [] = \"beige\";\n const char bisque [] = \"bisque\";\n const char black [] = \"black\";\n const char blanchedalmond [] = \"blanchedalmond\";\n const char blue [] = \"blue\";\n const char blueviolet [] = \"blueviolet\";\n const char brown [] = \"brown\";\n const char burlywood [] = \"burlywood\";\n const char cadetblue [] = \"cadetblue\";\n const char chartreuse [] = \"chartreuse\";\n const char chocolate [] = \"chocolate\";\n const char coral [] = \"coral\";\n const char cornflowerblue [] = \"cornflowerblue\";\n const char cornsilk [] = \"cornsilk\";\n const char crimson [] = \"crimson\";\n const char darkblue [] = \"darkblue\";\n const char darkcyan [] = \"darkcyan\";\n const char darkgoldenrod [] = \"darkgoldenrod\";\n const char darkgray [] = \"darkgray\";\n const char darkgrey [] = \"darkgrey\";\n const char darkgreen [] = \"darkgreen\";\n const char darkkhaki [] = \"darkkhaki\";\n const char darkmagenta [] = \"darkmagenta\";\n const char darkolivegreen [] = \"darkolivegreen\";\n const char darkorange [] = \"darkorange\";\n const char darkorchid [] = \"darkorchid\";\n const char darkred [] = \"darkred\";\n const char darksalmon [] = \"darksalmon\";\n const char darkseagreen [] = \"darkseagreen\";\n const char darkslateblue [] = \"darkslateblue\";\n const char darkslategray [] = \"darkslategray\";\n const char darkslategrey [] = \"darkslategrey\";\n const char darkturquoise [] = \"darkturquoise\";\n const char darkviolet [] = \"darkviolet\";\n const char deeppink [] = \"deeppink\";\n const char deepskyblue [] = \"deepskyblue\";\n const char dimgray [] = \"dimgray\";\n const char dimgrey [] = \"dimgrey\";\n const char dodgerblue [] = \"dodgerblue\";\n const char firebrick [] = \"firebrick\";\n const char floralwhite [] = \"floralwhite\";\n const char forestgreen [] = \"forestgreen\";\n const char magenta [] = \"magenta\";\n const char fuchsia [] = \"fuchsia\";\n const char gainsboro [] = \"gainsboro\";\n const char ghostwhite [] = \"ghostwhite\";\n const char gold [] = \"gold\";\n const char goldenrod [] = \"goldenrod\";\n const char gray [] = \"gray\";\n const char grey [] = \"grey\";\n const char green [] = \"green\";\n const char greenyellow [] = 
\"greenyellow\";\n const char honeydew [] = \"honeydew\";\n const char hotpink [] = \"hotpink\";\n const char indianred [] = \"indianred\";\n const char indigo [] = \"indigo\";\n const char ivory [] = \"ivory\";\n const char khaki [] = \"khaki\";\n const char lavender [] = \"lavender\";\n const char lavenderblush [] = \"lavenderblush\";\n const char lawngreen [] = \"lawngreen\";\n const char lemonchiffon [] = \"lemonchiffon\";\n const char lightblue [] = \"lightblue\";\n const char lightcoral [] = \"lightcoral\";\n const char lightcyan [] = \"lightcyan\";\n const char lightgoldenrodyellow [] = \"lightgoldenrodyellow\";\n const char lightgray [] = \"lightgray\";\n const char lightgrey [] = \"lightgrey\";\n const char lightgreen [] = \"lightgreen\";\n const char lightpink [] = \"lightpink\";\n const char lightsalmon [] = \"lightsalmon\";\n const char lightseagreen [] = \"lightseagreen\";\n const char lightskyblue [] = \"lightskyblue\";\n const char lightslategray [] = \"lightslategray\";\n const char lightslategrey [] = \"lightslategrey\";\n const char lightsteelblue [] = \"lightsteelblue\";\n const char lightyellow [] = \"lightyellow\";\n const char lime [] = \"lime\";\n const char limegreen [] = \"limegreen\";\n const char linen [] = \"linen\";\n const char maroon [] = \"maroon\";\n const char mediumaquamarine [] = \"mediumaquamarine\";\n const char mediumblue [] = \"mediumblue\";\n const char mediumorchid [] = \"mediumorchid\";\n const char mediumpurple [] = \"mediumpurple\";\n const char mediumseagreen [] = \"mediumseagreen\";\n const char mediumslateblue [] = \"mediumslateblue\";\n const char mediumspringgreen [] = \"mediumspringgreen\";\n const char mediumturquoise [] = \"mediumturquoise\";\n const char mediumvioletred [] = \"mediumvioletred\";\n const char midnightblue [] = \"midnightblue\";\n const char mintcream [] = \"mintcream\";\n const char mistyrose [] = \"mistyrose\";\n const char moccasin [] = \"moccasin\";\n const char navajowhite [] = \"navajowhite\";\n const char navy [] = \"navy\";\n const char oldlace [] = \"oldlace\";\n const char olive [] = \"olive\";\n const char olivedrab [] = \"olivedrab\";\n const char orange [] = \"orange\";\n const char orangered [] = \"orangered\";\n const char orchid [] = \"orchid\";\n const char palegoldenrod [] = \"palegoldenrod\";\n const char palegreen [] = \"palegreen\";\n const char paleturquoise [] = \"paleturquoise\";\n const char palevioletred [] = \"palevioletred\";\n const char papayawhip [] = \"papayawhip\";\n const char peachpuff [] = \"peachpuff\";\n const char peru [] = \"peru\";\n const char pink [] = \"pink\";\n const char plum [] = \"plum\";\n const char powderblue [] = \"powderblue\";\n const char purple [] = \"purple\";\n const char red [] = \"red\";\n const char rosybrown [] = \"rosybrown\";\n const char royalblue [] = \"royalblue\";\n const char saddlebrown [] = \"saddlebrown\";\n const char salmon [] = \"salmon\";\n const char sandybrown [] = \"sandybrown\";\n const char seagreen [] = \"seagreen\";\n const char seashell [] = \"seashell\";\n const char sienna [] = \"sienna\";\n const char silver [] = \"silver\";\n const char skyblue [] = \"skyblue\";\n const char slateblue [] = \"slateblue\";\n const char slategray [] = \"slategray\";\n const char slategrey [] = \"slategrey\";\n const char snow [] = \"snow\";\n const char springgreen [] = \"springgreen\";\n const char steelblue [] = \"steelblue\";\n const char tan [] = \"tan\";\n const char teal [] = \"teal\";\n const char thistle [] = \"thistle\";\n const char tomato [] = 
\"tomato\";\n const char turquoise [] = \"turquoise\";\n const char violet [] = \"violet\";\n const char wheat [] = \"wheat\";\n const char white [] = \"white\";\n const char whitesmoke [] = \"whitesmoke\";\n const char yellow [] = \"yellow\";\n const char yellowgreen [] = \"yellowgreen\";\n const char rebeccapurple [] = \"rebeccapurple\";\n const char transparent [] = \"transparent\";\n }\n\n namespace Colors {\n const ParserState color_table(\"[COLOR TABLE]\");\n const Color aliceblue(color_table, 240, 248, 255, 1);\n const Color antiquewhite(color_table, 250, 235, 215, 1);\n const Color cyan(color_table, 0, 255, 255, 1);\n const Color aqua(color_table, 0, 255, 255, 1);\n const Color aquamarine(color_table, 127, 255, 212, 1);\n const Color azure(color_table, 240, 255, 255, 1);\n const Color beige(color_table, 245, 245, 220, 1);\n const Color bisque(color_table, 255, 228, 196, 1);\n const Color black(color_table, 0, 0, 0, 1);\n const Color blanchedalmond(color_table, 255, 235, 205, 1);\n const Color blue(color_table, 0, 0, 255, 1);\n const Color blueviolet(color_table, 138, 43, 226, 1);\n const Color brown(color_table, 165, 42, 42, 1);\n const Color burlywood(color_table, 222, 184, 135, 1);\n const Color cadetblue(color_table, 95, 158, 160, 1);\n const Color chartreuse(color_table, 127, 255, 0, 1);\n const Color chocolate(color_table, 210, 105, 30, 1);\n const Color coral(color_table, 255, 127, 80, 1);\n const Color cornflowerblue(color_table, 100, 149, 237, 1);\n const Color cornsilk(color_table, 255, 248, 220, 1);\n const Color crimson(color_table, 220, 20, 60, 1);\n const Color darkblue(color_table, 0, 0, 139, 1);\n const Color darkcyan(color_table, 0, 139, 139, 1);\n const Color darkgoldenrod(color_table, 184, 134, 11, 1);\n const Color darkgray(color_table, 169, 169, 169, 1);\n const Color darkgrey(color_table, 169, 169, 169, 1);\n const Color darkgreen(color_table, 0, 100, 0, 1);\n const Color darkkhaki(color_table, 189, 183, 107, 1);\n const Color darkmagenta(color_table, 139, 0, 139, 1);\n const Color darkolivegreen(color_table, 85, 107, 47, 1);\n const Color darkorange(color_table, 255, 140, 0, 1);\n const Color darkorchid(color_table, 153, 50, 204, 1);\n const Color darkred(color_table, 139, 0, 0, 1);\n const Color darksalmon(color_table, 233, 150, 122, 1);\n const Color darkseagreen(color_table, 143, 188, 143, 1);\n const Color darkslateblue(color_table, 72, 61, 139, 1);\n const Color darkslategray(color_table, 47, 79, 79, 1);\n const Color darkslategrey(color_table, 47, 79, 79, 1);\n const Color darkturquoise(color_table, 0, 206, 209, 1);\n const Color darkviolet(color_table, 148, 0, 211, 1);\n const Color deeppink(color_table, 255, 20, 147, 1);\n const Color deepskyblue(color_table, 0, 191, 255, 1);\n const Color dimgray(color_table, 105, 105, 105, 1);\n const Color dimgrey(color_table, 105, 105, 105, 1);\n...\n// Path: src/libsass/src/values.cpp\n#include \"sass.hpp\"\n#include \"sass.h\"\n#include \"values.hpp\"\n\n#include \n\nnamespace Sass {\n\n // convert value from C++ side to C-API\n union Sass_Value* ast_node_to_sass_value (const Expression_Ptr val)\n {\n if (val->concrete_type() == Expression::NUMBER)\n {\n Number_Ptr_Const res = Cast(val);\n return sass_make_number(res->value(), res->unit().c_str());\n }\n else if (val->concrete_type() == Expression::COLOR)\n {\n Color_Ptr_Const col = Cast(val);\n return sass_make_color(col->r(), col->g(), col->b(), col->a());\n }\n else if (val->concrete_type() == Expression::LIST)\n {\n List_Ptr_Const l = Cast(val);\n union 
Sass_Value* list = sass_make_list(l->size(), l->separator(), l->is_bracketed());\n for (size_t i = 0, L = l->length(); i < L; ++i) {\n Expression_Obj obj = l->at(i);\n auto val = ast_node_to_sass_value(obj);\n sass_list_set_value(list, i, val);\n }\n return list;\n }\n else if (val->concrete_type() == Expression::MAP)\n {\n Map_Ptr_Const m = Cast(val);\n union Sass_Value* map = sass_make_map(m->length());\n size_t i = 0; for (Expression_Obj key : m->keys()) {\n sass_map_set_key(map, i, ast_node_to_sass_value(key));\n sass_map_set_value(map, i, ast_node_to_sass_value(m->at(key)));\n ++ i;\n }\n return map;\n }\n else if (val->concrete_type() == Expression::NULL_VAL)\n {\n return sass_make_null();\n }\n else if (val->concrete_type() == Expression::BOOLEAN)\n {\n Boolean_Ptr_Const res = Cast(val);\n return sass_make_boolean(res->value());\n }\n else if (val->concrete_type() == Expression::STRING)\n {\n if (String_Quoted_Ptr_Const qstr = Cast(val))\n {\n return sass_make_qstring(qstr->value().c_str());\n }\n else if (String_Constant_Ptr_Const cstr = Cast(val))\n {\n return sass_make_string(cstr->value().c_str());\n }\n }\n return sass_make_error(\"unknown sass value type\");\n }\n\n // convert value from C-API to C++ side\n Value_Ptr sass_value_to_ast_node (const union Sass_Value* val)\n {\n switch (sass_value_get_tag(val)) {\n case SASS_NUMBER:\n return SASS_MEMORY_NEW(Number,\n ParserState(\"[C-VALUE]\"),\n sass_number_get_value(val),\n sass_number_get_unit(val));\n case SASS_BOOLEAN:\n return SASS_MEMORY_NEW(Boolean,\n ParserState(\"[C-VALUE]\"),\n sass_boolean_get_value(val));\n case SASS_COLOR:\n return SASS_MEMORY_NEW(Color,\n ParserState(\"[C-VALUE]\"),\n sass_color_get_r(val),\n sass_color_get_g(val),\n sass_color_get_b(val),\n sass_color_get_a(val));\n case SASS_STRING:\n if (sass_string_is_quoted(val)) {\n return SASS_MEMORY_NEW(String_Quoted,\n ParserState(\"[C-VALUE]\"),\n sass_string_get_value(val));\n }\n return SASS_MEMORY_NEW(String_Constant,\n ParserState(\"[C-VALUE]\"),\n sass_string_get_value(val));\n case SASS_LIST: {\n List_Ptr l = SASS_MEMORY_NEW(List,\n ParserState(\"[C-VALUE]\"),\n sass_list_get_length(val),\n sass_list_get_separator(val));\n for (size_t i = 0, L = sass_list_get_length(val); i < L; ++i) {\n l->append(sass_value_to_ast_node(sass_list_get_value(val, i)));\n }\n l->is_bracketed(sass_list_get_is_bracketed(val));\n return l;\n }\n case SASS_MAP: {\n Map_Ptr m = SASS_MEMORY_NEW(Map, ParserState(\"[C-VALUE]\"));\n for (size_t i = 0, L = sass_map_get_length(val); i < L; ++i) {\n *m << std::make_pair(\n sass_value_to_ast_node(sass_map_get_key(val, i)),\n sass_value_to_ast_node(sass_map_get_value(val, i)));\n }\n return m;\n }\n case SASS_NULL:\n return SASS_MEMORY_NEW(Null, ParserState(\"[C-VALUE]\"));\n case SASS_ERROR:\n return SASS_MEMORY_NEW(Custom_Error,\n ParserState(\"[C-VALUE]\"),\n sass_error_get_message(val));\n case SASS_WARNING:\n return SASS_MEMORY_NEW(Custom_Warning,\n ParserState(\"[C-VALUE]\"),\n sass_warning_get_message(val));\n default: break;\n }\n return 0;\n }\n\n}\n\n// Path: src/libsass/src/error_handling.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"prelexer.hpp\"\n#include \"backtrace.hpp\"\n#include \"error_handling.hpp\"\n\n#include \n\nnamespace Sass {\n\n namespace Exception {\n\n Base::Base(ParserState pstate, std::string msg, Backtraces traces)\n : std::runtime_error(msg), msg(msg),\n prefix(\"Error\"), pstate(pstate), traces(traces)\n { }\n\n InvalidSass::InvalidSass(ParserState pstate, Backtraces traces, 
std::string msg)\n : Base(pstate, msg, traces)\n { }\n\n\n InvalidParent::InvalidParent(Selector_Ptr parent, Backtraces traces, Selector_Ptr selector)\n : Base(selector->pstate(), def_msg, traces), parent(parent), selector(selector)\n {\n msg = \"Invalid parent selector for \\\"\";\n msg += selector->to_string(Sass_Inspect_Options());\n msg += \"\\\": \\\"\";\n msg += parent->to_string(Sass_Inspect_Options());\n msg += \"\\\"\";\n }\n\n InvalidVarKwdType::InvalidVarKwdType(ParserState pstate, Backtraces traces, std::string name, const Argument_Ptr arg)\n : Base(pstate, def_msg, traces), name(name), arg(arg)\n {\n msg = \"Variable keyword argument map must have string keys.\\n\";\n msg += name + \" is not a string in \" + arg->to_string() + \".\";\n }\n\n InvalidArgumentType::InvalidArgumentType(ParserState pstate, Backtraces traces, std::string fn, std::string arg, std::string type, const Value_Ptr value)\n : Base(pstate, def_msg, traces), fn(fn), arg(arg), type(type), value(value)\n {\n msg = arg + \": \\\"\";\n if (value) msg += value->to_string(Sass_Inspect_Options());\n msg += \"\\\" is not a \" + type;\n msg += \" for `\" + fn + \"'\";\n }\n\n MissingArgument::MissingArgument(ParserState pstate, Backtraces traces, std::string fn, std::string arg, std::string fntype)\n : Base(pstate, def_msg, traces), fn(fn), arg(arg), fntype(fntype)\n {\n msg = fntype + \" \" + fn;\n msg += \" is missing argument \";\n msg += arg + \".\";\n }\n\n InvalidSyntax::InvalidSyntax(ParserState pstate, Backtraces traces, std::string msg)\n : Base(pstate, msg, traces)\n { }\n\n NestingLimitError::NestingLimitError(ParserState pstate, Backtraces traces, std::string msg)\n : Base(pstate, msg, traces)\n { }\n\n DuplicateKeyError::DuplicateKeyError(Backtraces traces, const Map& dup, const Expression& org)\n : Base(org.pstate(), def_msg, traces), dup(dup), org(org)\n {\n msg = \"Duplicate key \";\n msg += dup.get_duplicate_key()->inspect();\n msg += \" in map (\";\n msg += org.inspect();\n msg += \").\";\n }\n\n TypeMismatch::TypeMismatch(Backtraces traces, const Expression& var, const std::string type)\n : Base(var.pstate(), def_msg, traces), var(var), type(type)\n {\n msg = var.to_string();\n msg += \" is not an \";\n msg += type;\n msg += \".\";\n }\n\n InvalidValue::InvalidValue(Backtraces traces, const Expression& val)\n : Base(val.pstate(), def_msg, traces), val(val)\n {\n msg = val.to_string();\n msg += \" isn't a valid CSS value.\";\n }\n\n StackError::StackError(Backtraces traces, const AST_Node& node)\n : Base(node.pstate(), def_msg, traces), node(node)\n {\n msg = \"stack level too deep\";\n }\n\n IncompatibleUnits::IncompatibleUnits(const Units& lhs, const Units& rhs)\n {\n msg = \"Incompatible units: '\";\n msg += rhs.unit();\n msg += \"' and '\";\n msg += lhs.unit();\n msg += \"'.\";\n }\n\n IncompatibleUnits::IncompatibleUnits(const UnitType lhs, const UnitType rhs)\n {\n msg = \"Incompatible units: '\";\n msg += unit_to_string(rhs);\n msg += \"' and '\";\n msg += unit_to_string(lhs);\n msg += \"'.\";\n }\n\n AlphaChannelsNotEqual::AlphaChannelsNotEqual(Expression_Ptr_Const lhs, Expression_Ptr_Const rhs, enum Sass_OP op)\n : OperationError(), lhs(lhs), rhs(rhs), op(op)\n {\n msg = \"Alpha channels must be equal: \";\n msg += lhs->to_string({ NESTED, 5 });\n msg += \" \" + sass_op_to_name(op) + \" \";\n msg += rhs->to_string({ NESTED, 5 });\n msg += \".\";\n }\n\n ZeroDivisionError::ZeroDivisionError(const Expression& lhs, const Expression& rhs)\n : OperationError(), lhs(lhs), rhs(rhs)\n {\n msg = 
\"divided by 0\";\n }\n\n UndefinedOperation::UndefinedOperation(Expression_Ptr_Const lhs, Expression_Ptr_Const rhs, enum Sass_OP op)\n : OperationError(), lhs(lhs), rhs(rhs), op(op)\n {\n msg = def_op_msg + \": \\\"\";\n msg += lhs->to_string({ NESTED, 5 });\n msg += \" \" + sass_op_to_name(op) + \" \";\n msg += rhs->to_string({ TO_SASS, 5 });\n msg += \"\\\".\";\n }\n\n InvalidNullOperation::InvalidNullOperation(Expression_Ptr_Const lhs, Expression_Ptr_Const rhs, enum Sass_OP op)\n : UndefinedOperation(lhs, rhs, op)\n {\n msg = def_op_null_msg + \": \\\"\";\n msg += lhs->inspect();\n msg += \" \" + sass_op_to_name(op) + \" \";\n msg += rhs->inspect();\n msg += \"\\\".\";\n }\n\n SassValueError::SassValueError(Backtraces traces, ParserState pstate, OperationError& err)\n : Base(pstate, err.what(), traces)\n {\n msg = err.what();\n prefix = err.errtype();\n }\n\n }\n\n\n void warn(std::string msg, ParserState pstate)\n {\n std::cerr << \"Warning: \" << msg << std::endl;\n }\n\n void warning(std::string msg, ParserState pstate)\n {\n std::string cwd(Sass::File::get_cwd());\n std::string abs_path(Sass::File::rel2abs(pstate.path, cwd, cwd));\n std::string rel_path(Sass::File::abs2rel(pstate.path, cwd, cwd));\n std::string output_path(Sass::File::path_for_console(rel_path, abs_path, pstate.path));\n\n std::cerr << \"WARNING on line \" << pstate.line+1 << \", column \" << pstate.column+1 << \" of \" << output_path << \":\" << std::endl;\n std::cerr << msg << std::endl << std::endl;\n }\n\n void warn(std::string msg, ParserState pstate, Backtrace* bt)\n {\n warn(msg, pstate);\n }\n\n void deprecated_function(std::string msg, ParserState pstate)\n {\n std::string cwd(Sass::File::get_cwd());\n std::string abs_path(Sass::File::rel2abs(pstate.path, cwd, cwd));\n std::string rel_path(Sass::File::abs2rel(pstate.path, cwd, cwd));\n std::string output_path(Sass::File::path_for_console(rel_path, abs_path, pstate.path));\n\n std::cerr << \"DEPRECATION WARNING: \" << msg << std::endl;\n std::cerr << \"will be an error in future versions of Sass.\" << std::endl;\n std::cerr << \" on line \" << pstate.line+1 << \" of \" << output_path << std::endl;\n }\n\n void deprecated(std::string msg, std::string msg2, bool with_column, ParserState pstate)\n {\n std::string cwd(Sass::File::get_cwd());\n std::string abs_path(Sass::File::rel2abs(pstate.path, cwd, cwd));\n std::string rel_path(Sass::File::abs2rel(pstate.path, cwd, cwd));\n std::string output_path(Sass::File::path_for_console(rel_path, pstate.path, pstate.path));\n\n std::cerr << \"DEPRECATION WARNING on line \" << pstate.line + 1;\n if (with_column) std::cerr << \", column \" << pstate.column + pstate.offset.column + 1;\n if (output_path.length()) std::cerr << \" of \" << output_path;\n std::cerr << \":\" << std::endl;\n std::cerr << msg << std::endl;\n if (msg2.length()) std::cerr << msg2 << std::endl;\n std::cerr << std::endl;\n }\n\n void deprecated_bind(std::string msg, ParserState pstate)\n {\n std::string cwd(Sass::File::get_cwd());\n std::string abs_path(Sass::File::rel2abs(pstate.path, cwd, cwd));\n std::string rel_path(Sass::File::abs2rel(pstate.path, cwd, cwd));\n std::string output_path(Sass::File::path_for_console(rel_path, abs_path, pstate.path));\n\n std::cerr << \"WARNING: \" << msg << std::endl;\n std::cerr << \" on line \" << pstate.line+1 << \" of \" << output_path << std::endl;\n std::cerr << \"This will be an error in future versions of Sass.\" << std::endl;\n }\n\n // should be replaced with error with backtraces\n void 
coreError(std::string msg, ParserState pstate)\n {\n Backtraces traces;\n throw Exception::InvalidSyntax(pstate, traces, msg);\n }\n\n void error(std::string msg, ParserState pstate, Backtraces& traces)\n {\n traces.push_back(Backtrace(pstate));\n throw Exception::InvalidSyntax(pstate, traces, msg);\n }\n\n}\n\n// Path: src/libsass/src/units.cpp\n#include \"sass.hpp\"\n#include \n#include \"units.hpp\"\n#include \"error_handling.hpp\"\n\nnamespace Sass {\n\n /* the conversion matrix can be readed the following way */\n /* if you go down, the factor is for the numerator (multiply) */\n /* if you go right, the factor is for the denominator (divide) */\n /* and yes, we actually use both, not sure why, but why not!? */\n\n const double size_conversion_factors[6][6] =\n {\n /* in cm pc mm pt px */\n /* in */ { 1, 2.54, 6, 25.4, 72, 96, },\n /* cm */ { 1.0/2.54, 1, 6.0/2.54, 10, 72.0/2.54, 96.0/2.54 },\n /* pc */ { 1.0/6.0, 2.54/6.0, 1, 25.4/6.0, 72.0/6.0, 96.0/6.0 },\n /* mm */ { 1.0/25.4, 1.0/10.0, 6.0/25.4, 1, 72.0/25.4, 96.0/25.4 },\n /* pt */ { 1.0/72.0, 2.54/72.0, 6.0/72.0, 25.4/72.0, 1, 96.0/72.0 },\n /* px */ { 1.0/96.0, 2.54/96.0, 6.0/96.0, 25.4/96.0, 72.0/96.0, 1, }\n };\n\n const double angle_conversion_factors[4][4] =\n {\n /* deg grad rad turn */\n /* deg */ { 1, 40.0/36.0, PI/180.0, 1.0/360.0 },\n /* grad */ { 36.0/40.0, 1, PI/200.0, 1.0/400.0 },\n /* rad */ { 180.0/PI, 200.0/PI, 1, 0.5/PI },\n /* turn */ { 360.0, 400.0, 2.0*PI, 1 }\n };\n\n const double time_conversion_factors[2][2] =\n {\n /* s ms */\n /* s */ { 1, 1000.0 },\n /* ms */ { 1/1000.0, 1 }\n };\n const double frequency_conversion_factors[2][2] =\n {\n /* Hz kHz */\n /* Hz */ { 1, 1/1000.0 },\n /* kHz */ { 1000.0, 1 }\n };\n const double resolution_conversion_factors[3][3] =\n {\n /* dpi dpcm dppx */\n /* dpi */ { 1, 1/2.54, 1/96.0 },\n /* dpcm */ { 2.54, 1, 2.54/96 },\n /* dppx */ { 96, 96/2.54, 1 }\n };\n\n \nUnitClass get_unit_type(UnitType unit)\n {\n switch (unit & 0xFF00)\n {\n case UnitClass::LENGTH: return UnitClass::LENGTH;\n case UnitClass::ANGLE: return UnitClass::ANGLE;\n case UnitClass::TIME: return UnitClass::TIME;\n case UnitClass::FREQUENCY: return UnitClass::FREQUENCY;\n case UnitClass::RESOLUTION: return UnitClass::RESOLUTION;\n default: return UnitClass::INCOMMENSURABLE;\n }\n };\n\n std::string get_unit_class(UnitType unit)\n {\n switch (unit & 0xFF00)\n {\n case UnitClass::LENGTH: return \"LENGTH\";\n case UnitClass::ANGLE: return \"ANGLE\";\n case UnitClass::TIME: return \"TIME\";\n case UnitClass::FREQUENCY: return \"FREQUENCY\";\n case UnitClass::RESOLUTION: return \"RESOLUTION\";\n default: return \"INCOMMENSURABLE\";\n }\n };\n\n UnitType get_main_unit(const UnitClass unit)\n {\n switch (unit)\n {\n case UnitClass::LENGTH: return UnitType::PX;\n case UnitClass::ANGLE: return UnitType::DEG;\n case UnitClass::TIME: return UnitType::SEC;\n case UnitClass::FREQUENCY: return UnitType::HERTZ;\n case UnitClass::RESOLUTION: return UnitType::DPI;\n default: return UnitType::UNKNOWN;\n }\n };\n\n UnitType string_to_unit(const std::string& s)\n {\n // size units\n if (s == \"px\") return UnitType::PX;\n else if (s == \"pt\") return UnitType::PT;\n else if (s == \"pc\") return UnitType::PC;\n else if (s == \"mm\") return UnitType::MM;\n else if (s == \"cm\") return UnitType::CM;\n else if (s == \"in\") return UnitType::IN;\n // angle units\n else if (s == \"deg\") return UnitType::DEG;\n else if (s == \"grad\") return UnitType::GRAD;\n else if (s == \"rad\") return UnitType::RAD;\n else if (s == \"turn\") 
return UnitType::TURN;\n // time units\n else if (s == \"s\") return UnitType::SEC;\n else if (s == \"ms\") return UnitType::MSEC;\n // frequency units\n else if (s == \"Hz\") return UnitType::HERTZ;\n else if (s == \"kHz\") return UnitType::KHERTZ;\n // resolutions units\n else if (s == \"dpi\") return UnitType::DPI;\n else if (s == \"dpcm\") return UnitType::DPCM;\n else if (s == \"dppx\") return UnitType::DPPX;\n // for unknown units\n else return UnitType::UNKNOWN;\n }\n\n const char* unit_to_string(UnitType unit)\n {\n switch (unit) {\n // size units\n case UnitType::PX: return \"px\";\n case UnitType::PT: return \"pt\";\n case UnitType::PC: return \"pc\";\n case UnitType::MM: return \"mm\";\n case UnitType::CM: return \"cm\";\n case UnitType::IN: return \"in\";\n // angle units\n case UnitType::DEG: return \"deg\";\n case UnitType::GRAD: return \"grad\";\n case UnitType::RAD: return \"rad\";\n case UnitType::TURN: return \"turn\";\n // time units\n case UnitType::SEC: return \"s\";\n case UnitType::MSEC: return \"ms\";\n // frequency units\n case UnitType::HERTZ: return \"Hz\";\n case UnitType::KHERTZ: return \"kHz\";\n // resolutions units\n case UnitType::DPI: return \"dpi\";\n case UnitType::DPCM: return \"dpcm\";\n case UnitType::DPPX: return \"dppx\";\n // for unknown units\n default: return \"\";\n }\n }\n\n std::string unit_to_class(const std::string& s)\n {\n if (s == \"px\") return \"LENGTH\";\n else if (s == \"pt\") return \"LENGTH\";\n else if (s == \"pc\") return \"LENGTH\";\n else if (s == \"mm\") return \"LENGTH\";\n else if (s == \"cm\") return \"LENGTH\";\n else if (s == \"in\") return \"LENGTH\";\n // angle units\n else if (s == \"deg\") return \"ANGLE\";\n else if (s == \"grad\") return \"ANGLE\";\n else if (s == \"rad\") return \"ANGLE\";\n else if (s == \"turn\") return \"ANGLE\";\n // time units\n else if (s == \"s\") return \"TIME\";\n else if (s == \"ms\") return \"TIME\";\n // frequency units\n else if (s == \"Hz\") return \"FREQUENCY\";\n else if (s == \"kHz\") return \"FREQUENCY\";\n // resolutions units\n else if (s == \"dpi\") return \"RESOLUTION\";\n else if (s == \"dpcm\") return \"RESOLUTION\";\n else if (s == \"dppx\") return \"RESOLUTION\";\n // for unknown units\n return \"CUSTOM:\" + s;\n }\n\n // throws incompatibleUnits exceptions\n double conversion_factor(const std::string& s1, const std::string& s2)\n {\n // assert for same units\n if (s1 == s2) return 1;\n // get unit enum from string\n UnitType u1 = string_to_unit(s1);\n UnitType u2 = string_to_unit(s2);\n // query unit group types\n UnitClass t1 = get_unit_type(u1);\n UnitClass t2 = get_unit_type(u2);\n // return the conversion factor\n return conversion_factor(u1, u2, t1, t2);\n }\n\n // throws incompatibleUnits exceptions\n double conversion_factor(UnitType u1, UnitType u2, UnitClass t1, UnitClass t2)\n {\n // can't convert between groups\n if (t1 != t2) return 0;\n // get absolute offset\n // used for array acces\n size_t i1 = u1 - t1;\n size_t i2 = u2 - t2;\n // process known units\n switch (t1) {\n case LENGTH:\n return size_conversion_factors[i1][i2];\n case ANGLE:\n return angle_conversion_factors[i1][i2];\n case TIME:\n return time_conversion_factors[i1][i2];\n case FREQUENCY:\n return frequency_conversion_factors[i1][i2];\n case RESOLUTION:\n return resolution_conversion_factors[i1][i2];\n case INCOMMENSURABLE:\n return 0;\n }\n // fallback\n return 0;\n }\n\n double convert_units(const std::string& lhs, const std::string& rhs, int& lhsexp, int& rhsexp)\n {\n double f = 0;\n // do 
not convert same ones\n if (lhs == rhs) return 0;\n // skip already canceled out unit\n if (lhsexp == 0) return 0;\n if (rhsexp == 0) return 0;\n // check if it can be converted\n UnitType ulhs = string_to_unit(lhs);\n UnitType urhs = string_to_unit(rhs);\n // skip units we cannot convert\n if (ulhs == UNKNOWN) return 0;\n if (urhs == UNKNOWN) return 0;\n // query unit group types\n UnitClass clhs = get_unit_type(ulhs);\n UnitClass crhs = get_unit_type(urhs);\n // skip units we cannot convert\n if (clhs != crhs) return 0;\n // if right denominator is bigger than lhs, we want to keep it in rhs unit\n if (rhsexp < 0 && lhsexp > 0 && - rhsexp > lhsexp) {\n // get the conversion factor for units\n f = conversion_factor(urhs, ulhs, clhs, crhs);\n // left hand side has been consumned\n f = std::pow(f, lhsexp);\n rhsexp += lhsexp;\n lhsexp = 0;\n }\n else {\n // get the conversion factor for units\n f = conversion_factor(ulhs, urhs, clhs, crhs);\n // right hand side has been consumned\n f = std::pow(f, rhsexp);\n lhsexp += rhsexp;\n rhsexp = 0;\n }\n return f;\n }\n\n bool Units::operator< (const Units& rhs) const\n {\n return (numerators < rhs.numerators) &&\n (denominators < rhs.denominators);\n }\n bool Units::operator== (const Units& rhs) const\n {\n return (numerators == rhs.numerators) &&\n (denominators == rhs.denominators);\n }\n\n double Units::normalize()\n {\n\n size_t iL = numerators.size();\n size_t nL = denominators.size();\n\n // the final conversion factor\n double factor = 1;\n\n for (size_t i = 0; i < iL; i++) {\n std::string &lhs = numerators[i];\n UnitType ulhs = string_to_unit(lhs);\n if (ulhs == UNKNOWN) continue;\n UnitClass clhs = get_unit_type(ulhs);\n UnitType umain = get_main_unit(clhs);\n if (ulhs == umain) continue;\n double f(conversion_factor(umain, ulhs, clhs, clhs));\n if (f == 0) throw std::runtime_error(\"INVALID\");\n numerators[i] = unit_to_string(umain);\n factor /= f;\n }\n\n for (size_t n = 0; n < nL; n++) {\n std::string &rhs = denominators[n];\n UnitType urhs = string_to_unit(rhs);\n if (urhs == UNKNOWN) continue;\n UnitClass crhs = get_unit_type(urhs);\n UnitType umain = get_main_unit(crhs);\n if (urhs == umain) continue;\n double f(conversion_factor(umain, urhs, crhs, crhs));\n if (f == 0) throw std::runtime_error(\"INVALID\");\n denominators[n] = unit_to_string(umain);\n factor /= f;\n }\n\n std::sort (numerators.begin(), numerators.end());\n std::sort (denominators.begin(), denominators.end());\n\n // return for conversion\n return factor;\n }\n\n double Units::reduce()\n {\n\n size_t iL = numerators.size();\n size_t nL = denominators.size();\n\n // have less than two units?\n if (iL + nL < 2) return 1;\n\n // first make sure same units cancel each other out\n // it seems that a map table will fit nicely to do this\n // we basically construct exponents for each unit\n // has the advantage that they will be pre-sorted\n std::map exponents;\n\n // initialize by summing up occurences in unit vectors\n // this will already cancel out equivalent units (e.q. 
px/px)\n for (size_t i = 0; i < iL; i ++) exponents[numerators[i]] += 1;\n for (size_t n = 0; n < nL; n ++) exponents[denominators[n]] -= 1;\n\n // the final conversion factor\n double factor = 1;\n\n // convert between compatible units\n for (size_t i = 0; i < iL; i++) {\n for (size_t n = 0; n < nL; n++) {\n std::string &lhs = numerators[i], &rhs = denominators[n];\n int &lhsexp = exponents[lhs], &rhsexp = exponents[rhs];\n double f(convert_units(lhs, rhs, lhsexp, rhsexp));\n if (f == 0) continue;\n factor /= f;\n }\n }\n\n // now we can build up the new unit arrays\n numerators.clear();\n denominators.clear();\n\n // recreate sorted units vectors\n for (auto exp : exponents) {\n int &exponent = exp.second;\n while (exponent > 0 && exponent --)\n numerators.push_back(exp.first);\n while (exponent < 0 && exponent ++)\n denominators.push_back(exp.first);\n }\n\n // return for conversion\n return factor;\n\n }\n\n std::string Units::unit() const\n {\n std::string u;\n size_t iL = numerators.size();\n size_t nL = denominators.size();\n for (size_t i = 0; i < iL; i += 1) {\n if (i) u += '*';\n u += numerators[i];\n }\n if (nL != 0) u += '/';\n for (size_t n = 0; n < nL; n += 1) {\n if (n) u += '*';\n u += denominators[n];\n }\n return u;\n }\n\n bool Units::is_unitless() const\n {\n return numerators.empty() &&\n denominators.empty();\n }\n\n bool Units::is_valid_css_unit() const\n {\n return numerators.size() <= 1 &&\n denominators.size() == 0;\n }\n\n // this does not cover all cases (multiple prefered units)\n double Units::convert_factor(const Units& r) const\n {\n\n std::vector miss_nums(0);\n std::vector miss_dens(0);\n // create copy since we need these for state keeping\n std::vector r_nums(r.numerators);\n std::vector r_dens(r.denominators);\n\n auto l_num_it = numerators.begin();\n auto l_num_end = numerators.end();\n\n bool l_unitless = is_unitless();\n auto r_unitless = r.is_unitless();\n\n // overall conversion\n double factor = 1;\n\n // process all left numerators\n while (l_num_it != l_num_end)\n {\n // get and increment afterwards\n const std::string l_num = *(l_num_it ++);\n\n auto r_num_it = r_nums.begin(), r_num_end = r_nums.end();\n\n bool found = false;\n // search for compatible numerator\n while (r_num_it != r_num_end)\n {\n // get and increment afterwards\n const std::string r_num = *(r_num_it);\n // get possible conversion factor for units\n double conversion = conversion_factor(l_num, r_num);\n // skip incompatible numerator\n if (conversion == 0) {\n ++ r_num_it;\n continue;\n }\n // apply to global factor\n factor *= conversion;\n // remove item from vector\n r_nums.erase(r_num_it);\n // found numerator\n found = true;\n break;\n }\n // maybe we did not find any\n // left numerator is leftover\n if (!found) miss_nums.push_back(l_num);\n }\n\n auto l_den_it = denominators.begin();\n auto l_den_end = denominators.end();\n\n // process all left denominators\n while (l_den_it != l_den_end)\n {\n // get and increment afterwards\n const std::string l_den = *(l_den_it ++);\n\n auto r_den_it = r_dens.begin();\n auto r_den_end = r_dens.end();\n\n bool found = false;\n // search for compatible denominator\n while (r_den_it != r_den_end)\n {\n // get and increment afterwards\n const std::string r_den = *(r_den_it);\n // get possible converstion factor for units\n double conversion = conversion_factor(l_den, r_den);\n // skip incompatible denominator\n if (conversion == 0) {\n ++ r_den_it;\n continue;\n }\n // apply to global factor\n factor /= conversion;\n // remove item 
from vector\n r_dens.erase(r_den_it);\n // found denominator\n found = true;\n break;\n }\n // maybe we did not find any\n // left denominator is leftover\n if (!found) miss_dens.push_back(l_den);\n }\n\n // check left-overs (ToDo: might cancel out?)\n if (miss_nums.size() > 0 && !r_unitless) {\n throw Exception::IncompatibleUnits(r, *this);\n }\n else if (miss_dens.size() > 0 && !r_unitless) {\n throw Exception::IncompatibleUnits(r, *this);\n }\n else if (r_nums.size() > 0 && !l_unitless) {\n throw Exception::IncompatibleUnits(r, *this);\n }\n else if (r_dens.size() > 0 && !l_unitless) {\n throw Exception::IncompatibleUnits(r, *this);\n }\n\n return factor;\n }\n\n}\n\n// Path: src/libsass/src/remove_placeholders.cpp\n#include \"sass.hpp\"\n#include \"remove_placeholders.hpp\"\n#include \"context.hpp\"\n#include \"inspect.hpp\"\n#include \n\nnamespace Sass {\n\n Remove_Placeholders::Remove_Placeholders()\n { }\n\n void Remove_Placeholders::operator()(Block_Ptr b) {\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Ptr st = b->at(i);\n st->perform(this);\n }\n }\n\n Selector_List_Ptr Remove_Placeholders::remove_placeholders(Selector_List_Ptr sl)\n {\n Selector_List_Ptr new_sl = SASS_MEMORY_NEW(Selector_List, sl->pstate());\n\n for (size_t i = 0, L = sl->length(); i < L; ++i) {\n if (!sl->at(i)->contains_placeholder()) {\n new_sl->append(sl->at(i));\n }\n }\n\n return new_sl;\n\n }\n\n\n void Remove_Placeholders::operator()(Ruleset_Ptr r) {\n // Create a new selector group without placeholders\n Selector_List_Obj sl = Cast(r->selector());\n\n if (sl) {\n // Set the new placeholder selector list\n r->selector(remove_placeholders(sl));\n // Remove placeholders in wrapped selectors\n for (Complex_Selector_Obj cs : sl->elements()) {\n while (cs) {\n if (cs->head()) {\n for (Simple_Selector_Obj& ss : cs->head()->elements()) {\n if (Wrapped_Selector_Ptr ws = Cast(ss)) {\n if (Selector_List_Ptr wsl = Cast(ws->selector())) {\n Selector_List_Ptr clean = remove_placeholders(wsl);\n // also clean superflous parent selectors\n // probably not really the correct place\n clean->remove_parent_selectors();\n ws->selector(clean);\n }\n }\n }\n }\n cs = cs->tail();\n }\n }\n }\n\n // Iterate into child blocks\n Block_Obj b = r->block();\n\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n if (b->at(i)) {\n Statement_Obj st = b->at(i);\n st->perform(this);\n }\n }\n }\n\n void Remove_Placeholders::operator()(Media_Block_Ptr m) {\n operator()(m->block());\n }\n void Remove_Placeholders::operator()(Supports_Block_Ptr m) {\n operator()(m->block());\n }\n\n void Remove_Placeholders::operator()(Directive_Ptr a) {\n if (a->block()) a->block()->perform(this);\n }\n\n}\n\n// Path: src/libsass/src/operators.cpp\n#include \"sass.hpp\"\n#include \"operators.hpp\"\n\nnamespace Sass {\n\n namespace Operators {\n\n inline double add(double x, double y) { return x + y; }\n inline double sub(double x, double y) { return x - y; }\n inline double mul(double x, double y) { return x * y; }\n inline double div(double x, double y) { return x / y; } // x/0 checked by caller\n\n inline double mod(double x, double y) { // x/0 checked by caller\n if ((x > 0 && y < 0) || (x < 0 && y > 0)) {\n double ret = std::fmod(x, y);\n return ret ? 
ret + y : ret;\n } else {\n return std::fmod(x, y);\n }\n }\n\n typedef double (*bop)(double, double);\n bop ops[Sass_OP::NUM_OPS] = {\n 0, 0, // and, or\n 0, 0, 0, 0, 0, 0, // eq, neq, gt, gte, lt, lte\n add, sub, mul, div, mod\n };\n\n /* static function, has no pstate or traces */\n bool eq(Expression_Obj lhs, Expression_Obj rhs)\n {\n // operation is undefined if one is not a number\n if (!lhs || !rhs) throw Exception::UndefinedOperation(lhs, rhs, Sass_OP::EQ);\n // use compare operator from ast node\n return *lhs == *rhs;\n }\n\n /* static function, throws OperationError, has no pstate or traces */\n bool cmp(Expression_Obj lhs, Expression_Obj rhs, const Sass_OP op)\n {\n // can only compare numbers!?\n Number_Obj l = Cast(lhs);\n Number_Obj r = Cast(rhs);\n // operation is undefined if one is not a number\n if (!l || !r) throw Exception::UndefinedOperation(lhs, rhs, op);\n // use compare operator from ast node\n return *l < *r;\n }\n\n /* static functions, throws OperationError, has no pstate or traces */\n bool lt(Expression_Obj lhs, Expression_Obj rhs) { return cmp(lhs, rhs, Sass_OP::LT); }\n bool neq(Expression_Obj lhs, Expression_Obj rhs) { return eq(lhs, rhs) == false; }\n bool gt(Expression_Obj lhs, Expression_Obj rhs) { return !cmp(lhs, rhs, Sass_OP::GT) && neq(lhs, rhs); }\n bool lte(Expression_Obj lhs, Expression_Obj rhs) { return cmp(lhs, rhs, Sass_OP::LTE) || eq(lhs, rhs); }\n bool gte(Expression_Obj lhs, Expression_Obj rhs) { return !cmp(lhs, rhs, Sass_OP::GTE) || eq(lhs, rhs); }\n\n /* colour math deprecation warning */\n void op_color_deprecation(enum Sass_OP op, std::string lsh, std::string rhs, const ParserState& pstate)\n {\n std::string op_str(\n op == Sass_OP::ADD ? \"plus\" :\n op == Sass_OP::DIV ? \"div\" :\n op == Sass_OP::SUB ? \"minus\" :\n op == Sass_OP::MUL ? \"times\" : \"\"\n );\n\n std::string msg(\"The operation `\" + lsh + \" \" + op_str + \" \" + rhs + \"` is deprecated and will be an error in future versions.\");\n std::string tail(\"Consider using Sass's color functions instead.\\nhttp://sass-lang.com/documentation/Sass/Script/Functions.html#other_color_functions\");\n\n deprecated(msg, tail, false, pstate);\n }\n\n /* static function, throws OperationError, has no traces but optional pstate for returned value */\n Value_Ptr op_strings(Sass::Operand operand, Value& lhs, Value& rhs, struct Sass_Inspect_Options opt, const ParserState& pstate, bool delayed)\n {\n enum Sass_OP op = operand.operand;\n\n String_Quoted_Ptr lqstr = Cast(&lhs);\n String_Quoted_Ptr rqstr = Cast(&rhs);\n\n std::string lstr(lqstr ? lqstr->value() : lhs.to_string(opt));\n std::string rstr(rqstr ? 
rqstr->value() : rhs.to_string(opt));\n\n if (Cast(&lhs)) throw Exception::InvalidNullOperation(&lhs, &rhs, op);\n if (Cast(&rhs)) throw Exception::InvalidNullOperation(&lhs, &rhs, op);\n\n std::string sep;\n switch (op) {\n case Sass_OP::ADD: sep = \"\"; break;\n case Sass_OP::SUB: sep = \"-\"; break;\n case Sass_OP::DIV: sep = \"/\"; break;\n case Sass_OP::EQ: sep = \"==\"; break;\n case Sass_OP::NEQ: sep = \"!=\"; break;\n case Sass_OP::LT: sep = \"<\"; break;\n case Sass_OP::GT: sep = \">\"; break;\n case Sass_OP::LTE: sep = \"<=\"; break;\n case Sass_OP::GTE: sep = \">=\"; break;\n default:\n throw Exception::UndefinedOperation(&lhs, &rhs, op);\n break;\n }\n\n if (op == Sass_OP::ADD) {\n // create string that might be quoted on output (but do not unquote what we pass)\n return SASS_MEMORY_NEW(String_Quoted, pstate, lstr + rstr, 0, false, true);\n }\n\n // add whitespace around operator\n // but only if result is not delayed\n if (sep != \"\" && delayed == false) {\n if (operand.ws_before) sep = \" \" + sep;\n if (operand.ws_after) sep = sep + \" \";\n }\n\n if (op == Sass_OP::SUB || op == Sass_OP::DIV) {\n if (lqstr && lqstr->quote_mark()) lstr = quote(lstr);\n if (rqstr && rqstr->quote_mark()) rstr = quote(rstr);\n }\n\n return SASS_MEMORY_NEW(String_Constant, pstate, lstr + sep + rstr);\n }\n\n /* static function, throws OperationError, has no traces but optional pstate for returned value */\n Value_Ptr op_colors(enum Sass_OP op, const Color& lhs, const Color& rhs, struct Sass_Inspect_Options opt, const ParserState& pstate, bool delayed)\n {\n\n if (lhs.a() != rhs.a()) {\n throw Exception::AlphaChannelsNotEqual(&lhs, &rhs, op);\n }\n if ((op == Sass_OP::DIV || op == Sass_OP::MOD) && (!rhs.r() || !rhs.g() || !rhs.b())) {\n throw Exception::ZeroDivisionError(lhs, rhs);\n...\n", "needle_token_start": 8862, "needle_token_end": 9010, "code_context_ntokens": 16225, "output": ["```cpp\nUnitClass get_unit_type(UnitType unit)\n{\n switch (unit & 0xFF00)\n {\n case UnitClass::LENGTH: return UnitClass::LENGTH;\n case UnitClass::ANGLE: return UnitClass::ANGLE;\n case UnitClass::TIME: return UnitClass::TIME;\n case UnitClass::FREQUENCY: return UnitClass::FREQUENCY;\n case UnitClass::RESOLUTION: return UnitClass::RESOLUTION;\n default: return UnitClass::INCOMMENSURABLE;\n }\n};\n```"]} +{"repo": "sass/node-sass", "name": "ParentSuperselectorChunker", "language": "cpp", "path": "src/libsass/src/extend.cpp", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this class is to manage and process a collection of CSS selectors to determine if one selector is a superselector of another, which means it encompasses all the styles of the other selector but may include additional constraints or styles.\n2. **Input**: The input is a single node representing a collection of CSS selectors.\n3. **Output**: The output is a processed collection where selectors that are less specific or redundant due to more specific superselectors are identified or filtered.\n4. **Procedure**: The procedure involves iterating through a collection of CSS selectors, comparing each selector's specificity and checking if one selector is a superselector of another. 
This involves creating temporary selectors to facilitate the comparison and updating the collection based on these comparisons to ensure it only contains the most specific and relevant selectors.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/debug.hpp\n#ifndef SASS_DEBUG_H\n#define SASS_DEBUG_H\n\n#include \n\n#ifndef UINT32_MAX\n #define UINT32_MAX 0xffffffffU\n#endif\n\nenum dbg_lvl_t : uint32_t {\n NONE = 0,\n TRIM = 1,\n CHUNKS = 2,\n SUBWEAVE = 4,\n WEAVE = 8,\n EXTEND_COMPOUND = 16,\n EXTEND_COMPLEX = 32,\n LCS = 64,\n EXTEND_OBJECT = 128,\n ALL = UINT32_MAX\n};\n\n#ifdef DEBUG\n\n#ifndef DEBUG_LVL\nconst uint32_t debug_lvl = UINT32_MAX;\n#else\nconst uint32_t debug_lvl = (DEBUG_LVL);\n#endif // DEBUG_LVL\n\n#define DEBUG_PRINT(lvl, x) if((lvl) & debug_lvl) { std::cerr << x; }\n#define DEBUG_PRINTLN(lvl, x) if((lvl) & debug_lvl) { std::cerr << x << std::endl; }\n#define DEBUG_EXEC(lvl, x) if((lvl) & debug_lvl) { x; }\n\n#else // DEBUG\n\n#define DEBUG_PRINT(lvl, x)\n#define DEBUG_PRINTLN(lvl, x)\n#define DEBUG_EXEC(lvl, x)\n\n#endif // DEBUG\n\n...\n// Path: src/libsass/src/sass_util.hpp\n#ifndef SASS_SASS_UTIL_H\n#define SASS_SASS_UTIL_H\n\n#include \"ast.hpp\"\n#include \"node.hpp\"\n#include \"debug.hpp\"\n\nnamespace Sass {\n\n\n\n\n /*\n This is for ports of functions in the Sass:Util module.\n */\n\n\n /*\n # Return a Node collection of all possible paths through the given Node collection of Node collections.\n #\n # @param arrs [NodeCollection>]\n # @return [NodeCollection>]\n #\n # @example\n # paths([[1, 2], [3, 4], [5]]) #=>\n # # [[1, 3, 5],\n # # [2, 3, 5],\n # # [1, 4, 5],\n # # [2, 4, 5]]\n */\n Node paths(const Node& arrs);\n\n\n /*\n This class is a default implementation of a Node comparator that can be passed to the lcs function below.\n It uses operator== for equality comparision. 
It then returns one if the Nodes are equal.\n */\n class DefaultLcsComparator {\n public:\n bool operator()(const Node& one, const Node& two, Node& out) const {\n // TODO: Is this the correct C++ interpretation?\n // block ||= proc {|a, b| a == b && a}\n if (one == two) {\n out = one;\n return true;\n }\n\n return false;\n }\n };\n\n\n typedef std::vector > LCSTable;\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_backtrace.\n\n # Computes a single longest common subsequence for arrays x and y.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Reading_out_an_LCS\n */\n template\n Node lcs_backtrace(const LCSTable& c, const Node& x, const Node& y, int i, int j, const ComparatorType& comparator) {\n DEBUG_PRINTLN(LCS, \"LCSBACK: X=\" << x << \" Y=\" << y << \" I=\" << i << \" J=\" << j)\n\n if (i == 0 || j == 0) {\n DEBUG_PRINTLN(LCS, \"RETURNING EMPTY\")\n return Node::createCollection();\n }\n\n NodeDeque& xChildren = *(x.collection());\n NodeDeque& yChildren = *(y.collection());\n\n Node compareOut = Node::createNil();\n if (comparator(xChildren[i], yChildren[j], compareOut)) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER ELEM COMPARE\")\n Node result = lcs_backtrace(c, x, y, i - 1, j - 1, comparator);\n result.collection()->push_back(compareOut);\n return result;\n }\n\n if (c[i][j - 1] > c[i - 1][j]) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER TABLE COMPARE\")\n return lcs_backtrace(c, x, y, i, j - 1, comparator);\n }\n\n DEBUG_PRINTLN(LCS, \"FINAL RETURN\")\n return lcs_backtrace(c, x, y, i - 1, j, comparator);\n }\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_table.\n\n # Calculates the memoization table for the Least Common Subsequence algorithm.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Computing_the_length_of_the_LCS\n */\n template\n void lcs_table(const Node& x, const Node& y, const ComparatorType& comparator, LCSTable& out) {\n DEBUG_PRINTLN(LCS, \"LCSTABLE: X=\" << x << \" Y=\" << y)\n\n NodeDeque& xChildren = *(x.collection());\n NodeDeque& yChildren = *(y.collection());\n\n LCSTable c(xChildren.size(), std::vector(yChildren.size()));\n\n // These shouldn't be necessary since the vector will be initialized to 0 already.\n // x.size.times {|i| c[i][0] = 0}\n // y.size.times {|j| c[0][j] = 0}\n\n for (size_t i = 1; i < xChildren.size(); i++) {\n for (size_t j = 1; j < yChildren.size(); j++) {\n Node compareOut = Node::createNil();\n\n if (comparator(xChildren[i], yChildren[j], compareOut)) {\n c[i][j] = c[i - 1][j - 1] + 1;\n } else {\n c[i][j] = std::max(c[i][j - 1], c[i - 1][j]);\n }\n }\n }\n\n out = c;\n }\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs.\n\n # Computes a single longest common subsequence for `x` and `y`.\n # If there are more than one longest common subsequences,\n # the one returned is that which starts first in `x`.\n\n # @param x [NodeCollection]\n # @param y [NodeCollection]\n # @comparator An equality check between elements of `x` and `y`.\n # @return [NodeCollection] The LCS\n\n http://en.wikipedia.org/wiki/Longest_common_subsequence_problem\n */\n template\n Node lcs(Node& x, Node& y, const ComparatorType& comparator) {\n DEBUG_PRINTLN(LCS, \"LCS: X=\" << x << \" Y=\" << y)\n\n Node newX = Node::createCollection();\n newX.collection()->push_back(Node::createNil());\n newX.plus(x);\n\n Node newY = Node::createCollection();\n newY.collection()->push_back(Node::createNil());\n newY.plus(y);\n\n LCSTable table;\n lcs_table(newX, newY, comparator, 
table);\n\n return lcs_backtrace(table, newX, newY, static_cast(newX.collection()->size()) - 1, static_cast(newY.collection()->size()) - 1, comparator);\n }\n\n\n /*\n This is the equivalent of ruby sass' Sass::Util.flatten and [].flatten.\n Sass::Util.flatten requires the number of levels to flatten, while\n [].flatten doesn't and will flatten the entire array. This function\n supports both.\n\n # Flattens the first `n` nested arrays. If n == -1, all arrays will be flattened\n #\n # @param arr [NodeCollection] The array to flatten\n # @param n [int] The number of levels to flatten\n # @return [NodeCollection] The flattened array\n */\n Node flatten(Node& arr, int n = -1);\n\n\n /*\n This is the equivalent of ruby's Sass::Util.group_by_to_a.\n\n # Performs the equivalent of `enum.group_by.to_a`, but with a guaranteed\n # order. Unlike [#hash_to_a], the resulting order isn't sorted key order;\n # instead, it's the same order as `#group_by` has under Ruby 1.9 (key\n # appearance order).\n #\n # @param enum [Enumerable]\n # @return [Array<[Object, Array]>] An array of pairs.\n\n TODO: update @param and @return once I know what those are.\n\n The following is the modified version of the ruby code that was more portable to C++. You\n should be able to drop it into ruby 3.2.19 and get the same results from ruby sass.\n\n def group_by_to_a(enum, &block)\n order = {}\n\n arr = []\n\n grouped = {}\n\n for e in enum do\n key = block[e]\n unless order.include?(key)\n order[key] = order.size\n end\n\n if not grouped.has_key?(key) then\n grouped[key] = [e]\n else\n grouped[key].push(e)\n end\n end\n\n grouped.each do |key, vals|\n arr[order[key]] = [key, vals]\n end\n\n arr\n end\n\n */\n template\n void group_by_to_a(std::vector& enumeration, KeyFunctorType& keyFunc, std::vector > >& arr /*out*/) {\n\n std::map order;\n\n std::map > grouped;\n\n for (typename std::vector::iterator enumIter = enumeration.begin(), enumIterEnd = enumeration.end(); enumIter != enumIterEnd; enumIter++) {\n EnumType& e = *enumIter;\n\n KeyType key = keyFunc(e);\n\n if (grouped.find(key->hash()) == grouped.end()) {\n order.insert(std::make_pair((unsigned int)order.size(), key));\n\n std::vector newCollection;\n newCollection.push_back(e);\n grouped.insert(std::make_pair(key->hash(), newCollection));\n } else {\n std::vector& collection = grouped.at(key->hash());\n collection.push_back(e);\n }\n }\n\n for (unsigned int index = 0; index < order.size(); index++) {\n KeyType& key = order.at(index);\n std::vector& values = grouped.at(key->hash());\n\n std::pair > grouping = std::make_pair(key, values);\n\n arr.push_back(grouping);\n }\n }\n\n\n}\n\n#endif\n\n// Path: src/libsass/src/extend.cpp\n#include \"sass.hpp\"\n#include \"extend.hpp\"\n#include \"context.hpp\"\n#include \"backtrace.hpp\"\n#include \"paths.hpp\"\n#include \"parser.hpp\"\n#include \"expand.hpp\"\n#include \"node.hpp\"\n#include \"sass_util.hpp\"\n#include \"remove_placeholders.hpp\"\n#include \"debug.hpp\"\n#include \n#include \n#include \n\n/*\n NOTES:\n\n - The print* functions print to cerr. This allows our testing frameworks (like sass-spec) to ignore the output, which\n is very helpful when debugging. 
The format of the output is mainly to wrap things in square brackets to match what\n ruby already outputs (to make comparisons easier).\n\n - For the direct porting effort, we're trying to port method-for-method until we get all the tests passing.\n Where applicable, I've tried to include the ruby code above the function for reference until all our tests pass.\n The ruby code isn't always directly portable, so I've tried to include any modified ruby code that was actually\n used for the porting.\n\n - DO NOT try to optimize yet. We get a tremendous benefit out of comparing the output of each stage of the extend to the ruby\n output at the same stage. This makes it much easier to determine where problems are. Try to keep as close to\n the ruby code as you can until we have all the sass-spec tests passing. Then, we should optimize. However, if you see\n something that could probably be optimized, let's not forget it. Add a // TODO: or // IMPROVEMENT: comment.\n\n - Coding conventions in this file (these may need to be changed before merging back into master)\n - Very basic hungarian notation:\n p prefix for pointers (pSelector)\n no prefix for value types and references (selector)\n - Use STL iterators where possible\n - prefer verbose naming over terse naming\n - use typedefs for STL container types for make maintenance easier\n\n - You may see a lot of comments that say \"// TODO: is this the correct combinator?\". See the comment referring to combinators\n in extendCompoundSelector for a more extensive explanation of my confusion. I think our divergence in data model from ruby\n sass causes this to be necessary.\n\n\n GLOBAL TODOS:\n\n - wrap the contents of the print functions in DEBUG preprocesser conditionals so they will be optimized away in non-debug mode.\n\n - consider making the extend* functions member functions to avoid passing around ctx and subset_map map around. 
This has the\n drawback that the implementation details of the operator are then exposed to the outside world, which is not ideal and\n can cause additional compile time dependencies.\n\n - mark the helper methods in this file static to given them compilation unit linkage.\n\n - implement parent directive matching\n\n - fix compilation warnings for unused Extend members if we really don't need those references anymore.\n */\n\n\nnamespace Sass {\n\n\n\n#ifdef DEBUG\n\n // TODO: move the ast specific ostream operators into ast.hpp/ast.cpp\n std::ostream& operator<<(std::ostream& os, const Complex_Selector::Combinator combinator) {\n switch (combinator) {\n case Complex_Selector::ANCESTOR_OF: os << \"\\\" \\\"\"; break;\n case Complex_Selector::PARENT_OF: os << \"\\\">\\\"\"; break;\n case Complex_Selector::PRECEDES: os << \"\\\"~\\\"\"; break;\n case Complex_Selector::ADJACENT_TO: os << \"\\\"+\\\"\"; break;\n case Complex_Selector::REFERENCE: os << \"\\\"/\\\"\"; break;\n }\n\n return os;\n }\n\n\n std::ostream& operator<<(std::ostream& os, Compound_Selector& compoundSelector) {\n for (size_t i = 0, L = compoundSelector.length(); i < L; ++i) {\n if (i > 0) os << \", \";\n os << compoundSelector[i]->to_string();\n }\n return os;\n }\n\n std::ostream& operator<<(std::ostream& os, Simple_Selector& simpleSelector) {\n os << simpleSelector.to_string();\n return os;\n }\n\n // Print a string representation of a Compound_Selector\n static void printSimpleSelector(Simple_Selector* pSimpleSelector, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n if (pSimpleSelector) {\n std::cerr << \"[\" << *pSimpleSelector << \"]\";\n } else {\n std::cerr << \"NULL\";\n }\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n // Print a string representation of a Compound_Selector\n static void printCompoundSelector(Compound_Selector_Ptr pCompoundSelector, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n if (pCompoundSelector) {\n std::cerr << \"[\" << *pCompoundSelector << \"]\";\n } else {\n std::cerr << \"NULL\";\n }\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n\n std::ostream& operator<<(std::ostream& os, Complex_Selector& complexSelector) {\n\n os << \"[\";\n Complex_Selector_Ptr pIter = &complexSelector;\n bool first = true;\n while (pIter) {\n if (pIter->combinator() != Complex_Selector::ANCESTOR_OF) {\n if (!first) {\n os << \", \";\n }\n first = false;\n os << pIter->combinator();\n }\n\n if (!first) {\n os << \", \";\n }\n first = false;\n\n if (pIter->head()) {\n os << pIter->head()->to_string();\n } else {\n os << \"NULL_HEAD\";\n }\n\n pIter = pIter->tail();\n }\n os << \"]\";\n\n return os;\n }\n\n\n // Print a string representation of a Complex_Selector\n static void printComplexSelector(Complex_Selector_Ptr pComplexSelector, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n if (pComplexSelector) {\n std::cerr << *pComplexSelector;\n } else {\n std::cerr << \"NULL\";\n }\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n static void printSelsNewSeqPairCollection(SubSetMapLookups& collection, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n bool first = true;\n std::cerr << \"[\";\n for(SubSetMapLookup& pair : collection) {\n if (first) {\n first = false;\n } else {\n std::cerr << \", \";\n }\n std::cerr << \"[\";\n Compound_Selector_Ptr pSels = pair.first;\n Complex_Selector_Ptr 
pNewSelector = pair.second;\n std::cerr << \"[\" << *pSels << \"], \";\n printComplexSelector(pNewSelector, NULL, false);\n }\n std::cerr << \"]\";\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n // Print a string representation of a ComplexSelectorSet\n static void printSourcesSet(ComplexSelectorSet& sources, const char* message=NULL, bool newline=true) {\n\n if (message) {\n std::cerr << message;\n }\n\n // Convert to a deque of strings so we can sort since order doesn't matter in a set. This should cut down on\n // the differences we see when debug printing.\n typedef std::deque SourceStrings;\n SourceStrings sourceStrings;\n for (ComplexSelectorSet::iterator iterator = sources.begin(), iteratorEnd = sources.end(); iterator != iteratorEnd; ++iterator) {\n Complex_Selector_Ptr pSource = *iterator;\n std::stringstream sstream;\n sstream << complexSelectorToNode(pSource);\n sourceStrings.push_back(sstream.str());\n }\n\n // Sort to get consistent output\n std::sort(sourceStrings.begin(), sourceStrings.end());\n\n std::cerr << \"ComplexSelectorSet[\";\n for (SourceStrings::iterator iterator = sourceStrings.begin(), iteratorEnd = sourceStrings.end(); iterator != iteratorEnd; ++iterator) {\n std::string source = *iterator;\n if (iterator != sourceStrings.begin()) {\n std::cerr << \", \";\n }\n std::cerr << source;\n }\n std::cerr << \"]\";\n\n if (newline) {\n std::cerr << std::endl;\n }\n }\n\n\n std::ostream& operator<<(std::ostream& os, SubSetMapPairs& entries) {\n os << \"SUBSET_MAP_ENTRIES[\";\n\n for (SubSetMapPairs::iterator iterator = entries.begin(), endIterator = entries.end(); iterator != endIterator; ++iterator) {\n Complex_Selector_Obj pExtComplexSelector = iterator->first; // The selector up to where the @extend is (ie, the thing to merge)\n Compound_Selector_Obj pExtCompoundSelector = iterator->second; // The stuff after the @extend\n\n if (iterator != entries.begin()) {\n os << \", \";\n }\n\n os << \"(\";\n\n if (pExtComplexSelector) {\n std::cerr << *pExtComplexSelector;\n } else {\n std::cerr << \"NULL\";\n }\n\n os << \" -> \";\n\n if (pExtCompoundSelector) {\n std::cerr << *pExtCompoundSelector;\n } else {\n std::cerr << \"NULL\";\n }\n\n os << \")\";\n\n }\n\n os << \"]\";\n\n return os;\n }\n#endif\n\n static bool parentSuperselector(Complex_Selector_Ptr pOne, Complex_Selector_Ptr pTwo) {\n // TODO: figure out a better way to create a Complex_Selector from scratch\n // TODO: There's got to be a better way. 
This got ugly quick...\n Element_Selector_Obj fakeParent = SASS_MEMORY_NEW(Element_Selector, ParserState(\"[FAKE]\"), \"temp\");\n Compound_Selector_Obj fakeHead = SASS_MEMORY_NEW(Compound_Selector, ParserState(\"[FAKE]\"), 1 /*size*/);\n fakeHead->elements().push_back(fakeParent);\n Complex_Selector_Obj fakeParentContainer = SASS_MEMORY_NEW(Complex_Selector, ParserState(\"[FAKE]\"), Complex_Selector::ANCESTOR_OF, fakeHead /*head*/, NULL /*tail*/);\n\n pOne->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n pTwo->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n\n bool isSuperselector = pOne->is_superselector_of(pTwo);\n\n pOne->clear_innermost();\n pTwo->clear_innermost();\n\n return isSuperselector;\n }\n\n void nodeToComplexSelectorDeque(const Node& node, ComplexSelectorDeque& out) {\n for (NodeDeque::iterator iter = node.collection()->begin(), iterEnd = node.collection()->end(); iter != iterEnd; iter++) {\n Node& child = *iter;\n out.push_back(nodeToComplexSelector(child));\n }\n }\n\n Node complexSelectorDequeToNode(const ComplexSelectorDeque& deque) {\n Node result = Node::createCollection();\n\n for (ComplexSelectorDeque::const_iterator iter = deque.begin(), iterEnd = deque.end(); iter != iterEnd; iter++) {\n Complex_Selector_Obj pChild = *iter;\n result.collection()->push_back(complexSelectorToNode(pChild));\n }\n\n return result;\n }\n\n class LcsCollectionComparator {\n public:\n LcsCollectionComparator() {}\n\n bool operator()(Complex_Selector_Obj pOne, Complex_Selector_Obj pTwo, Complex_Selector_Obj& pOut) const {\n /*\n This code is based on the following block from ruby sass' subweave\n do |s1, s2|\n next s1 if s1 == s2\n next unless s1.first.is_a?(SimpleSequence) && s2.first.is_a?(SimpleSequence)\n next s2 if parent_superselector?(s1, s2)\n next s1 if parent_superselector?(s2, s1)\n end\n */\n\n if (*pOne == *pTwo) {\n pOut = pOne;\n return true;\n }\n\n if (pOne->combinator() != Complex_Selector::ANCESTOR_OF || pTwo->combinator() != Complex_Selector::ANCESTOR_OF) {\n return false;\n }\n\n if (parentSuperselector(pOne, pTwo)) {\n pOut = pTwo;\n return true;\n }\n\n if (parentSuperselector(pTwo, pOne)) {\n pOut = pOne;\n return true;\n }\n\n return false;\n }\n };\n\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_backtrace.\n\n # Computes a single longest common subsequence for arrays x and y.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Reading_out_an_LCS\n */\n void lcs_backtrace(const LCSTable& c, ComplexSelectorDeque& x, ComplexSelectorDeque& y, int i, int j, const LcsCollectionComparator& comparator, ComplexSelectorDeque& out) {\n //DEBUG_PRINTLN(LCS, \"LCSBACK: X=\" << x << \" Y=\" << y << \" I=\" << i << \" J=\" << j)\n // TODO: make printComplexSelectorDeque and use DEBUG_EXEC AND DEBUG_PRINTLN HERE to get equivalent output\n\n if (i == 0 || j == 0) {\n DEBUG_PRINTLN(LCS, \"RETURNING EMPTY\")\n return;\n }\n\n\n Complex_Selector_Obj pCompareOut;\n if (comparator(x[i], y[j], pCompareOut)) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER ELEM COMPARE\")\n lcs_backtrace(c, x, y, i - 1, j - 1, comparator, out);\n out.push_back(pCompareOut);\n return;\n }\n\n if (c[i][j - 1] > c[i - 1][j]) {\n DEBUG_PRINTLN(LCS, \"RETURNING AFTER TABLE COMPARE\")\n lcs_backtrace(c, x, y, i, j - 1, comparator, out);\n return;\n }\n\n DEBUG_PRINTLN(LCS, \"FINAL RETURN\")\n lcs_backtrace(c, x, y, i - 1, j, comparator, out);\n return;\n }\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs_table.\n\n # Calculates 
the memoization table for the Least Common Subsequence algorithm.\n # Algorithm from http://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Computing_the_length_of_the_LCS\n */\n void lcs_table(const ComplexSelectorDeque& x, const ComplexSelectorDeque& y, const LcsCollectionComparator& comparator, LCSTable& out) {\n //DEBUG_PRINTLN(LCS, \"LCSTABLE: X=\" << x << \" Y=\" << y)\n // TODO: make printComplexSelectorDeque and use DEBUG_EXEC AND DEBUG_PRINTLN HERE to get equivalent output\n\n LCSTable c(x.size(), std::vector(y.size()));\n\n // These shouldn't be necessary since the vector will be initialized to 0 already.\n // x.size.times {|i| c[i][0] = 0}\n // y.size.times {|j| c[0][j] = 0}\n\n for (size_t i = 1; i < x.size(); i++) {\n for (size_t j = 1; j < y.size(); j++) {\n Complex_Selector_Obj pCompareOut;\n\n if (comparator(x[i], y[j], pCompareOut)) {\n c[i][j] = c[i - 1][j - 1] + 1;\n } else {\n c[i][j] = std::max(c[i][j - 1], c[i - 1][j]);\n }\n }\n }\n\n out = c;\n }\n\n /*\n This is the equivalent of ruby's Sass::Util.lcs.\n\n # Computes a single longest common subsequence for `x` and `y`.\n # If there are more than one longest common subsequences,\n # the one returned is that which starts first in `x`.\n\n # @param x [NodeCollection]\n # @param y [NodeCollection]\n # @comparator An equality check between elements of `x` and `y`.\n # @return [NodeCollection] The LCS\n\n http://en.wikipedia.org/wiki/Longest_common_subsequence_problem\n */\n void lcs(ComplexSelectorDeque& x, ComplexSelectorDeque& y, const LcsCollectionComparator& comparator, ComplexSelectorDeque& out) {\n //DEBUG_PRINTLN(LCS, \"LCS: X=\" << x << \" Y=\" << y)\n // TODO: make printComplexSelectorDeque and use DEBUG_EXEC AND DEBUG_PRINTLN HERE to get equivalent output\n\n x.push_front(NULL);\n y.push_front(NULL);\n\n LCSTable table;\n lcs_table(x, y, comparator, table);\n\n return lcs_backtrace(table, x, y, static_cast(x.size()) - 1, static_cast(y.size()) - 1, comparator, out);\n }\n\n\n /*\n This is the equivalent of ruby's Sequence.trim.\n\n The following is the modified version of the ruby code that was more portable to C++. You\n should be able to drop it into ruby 3.2.19 and get the same results from ruby sass.\n\n # Avoid truly horrific quadratic behavior. TODO: I think there\n # may be a way to get perfect trimming without going quadratic.\n return seqses if seqses.size > 100\n\n # Keep the results in a separate array so we can be sure we aren't\n # comparing against an already-trimmed selector. 
This ensures that two\n # identical selectors don't mutually trim one another.\n result = seqses.dup\n\n # This is n^2 on the sequences, but only comparing between\n # separate sequences should limit the quadratic behavior.\n seqses.each_with_index do |seqs1, i|\n tempResult = []\n\n for seq1 in seqs1 do\n max_spec = 0\n for seq in _sources(seq1) do\n max_spec = [max_spec, seq.specificity].max\n end\n\n\n isMoreSpecificOuter = false\n for seqs2 in result do\n if seqs1.equal?(seqs2) then\n next\n end\n\n # Second Law of Extend: the specificity of a generated selector\n # should never be less than the specificity of the extending\n # selector.\n #\n # See https://github.com/nex3/sass/issues/324.\n isMoreSpecificInner = false\n for seq2 in seqs2 do\n isMoreSpecificInner = _specificity(seq2) >= max_spec && _superselector?(seq2, seq1)\n if isMoreSpecificInner then\n break\n end\n end\n\n if isMoreSpecificInner then\n isMoreSpecificOuter = true\n break\n end\n end\n\n if !isMoreSpecificOuter then\n tempResult.push(seq1)\n end\n end\n\n result[i] = tempResult\n\n end\n\n result\n */\n /*\n - IMPROVEMENT: We could probably work directly in the output trimmed deque.\n */\n Node Extend::trim(Node& seqses, bool isReplace) {\n // See the comments in the above ruby code before embarking on understanding this function.\n\n // Avoid poor performance in extreme cases.\n if (seqses.collection()->size() > 100) {\n return seqses;\n }\n\n\n DEBUG_PRINTLN(TRIM, \"TRIM: \" << seqses)\n\n\n Node result = Node::createCollection();\n result.plus(seqses);\n\n DEBUG_PRINTLN(TRIM, \"RESULT INITIAL: \" << result)\n\n // Normally we use the standard STL iterators, but in this case, we need to access the result collection by index since we're\n // iterating the input collection, computing a value, and then setting the result in the output collection. We have to keep track\n // of the index manually.\n int toTrimIndex = 0;\n\n for (NodeDeque::iterator seqsesIter = seqses.collection()->begin(), seqsesIterEnd = seqses.collection()->end(); seqsesIter != seqsesIterEnd; ++seqsesIter) {\n Node& seqs1 = *seqsesIter;\n\n DEBUG_PRINTLN(TRIM, \"SEQS1: \" << seqs1 << \" \" << toTrimIndex)\n\n Node tempResult = Node::createCollection();\n tempResult.got_line_feed = seqs1.got_line_feed;\n\n for (NodeDeque::iterator seqs1Iter = seqs1.collection()->begin(), seqs1EndIter = seqs1.collection()->end(); seqs1Iter != seqs1EndIter; ++seqs1Iter) {\n Node& seq1 = *seqs1Iter;\n\n Complex_Selector_Obj pSeq1 = nodeToComplexSelector(seq1);\n\n // Compute the maximum specificity. This requires looking at the \"sources\" of the sequence. See SimpleSequence.sources in the ruby code\n // for a good description of sources.\n //\n // TODO: I'm pretty sure there's a bug in the sources code. It was implemented for sass-spec's 182_test_nested_extend_loop test.\n // While the test passes, I compared the state of each trim call to verify correctness. The last trim call had incorrect sources. We\n // had an extra source that the ruby version did not have. Without a failing test case, this is going to be extra hard to find. My\n // best guess at this point is that we're cloning an object somewhere and maintaining the sources when we shouldn't be. This is purely\n // a guess though.\n unsigned long maxSpecificity = isReplace ? 
pSeq1->specificity() : 0;\n ComplexSelectorSet sources = pSeq1->sources();\n\n DEBUG_PRINTLN(TRIM, \"TRIM SEQ1: \" << seq1)\n DEBUG_EXEC(TRIM, printSourcesSet(sources, \"TRIM SOURCES: \"))\n\n for (ComplexSelectorSet::iterator sourcesSetIterator = sources.begin(), sourcesSetIteratorEnd = sources.end(); sourcesSetIterator != sourcesSetIteratorEnd; ++sourcesSetIterator) {\n const Complex_Selector_Obj& pCurrentSelector = *sourcesSetIterator;\n maxSpecificity = std::max(maxSpecificity, pCurrentSelector->specificity());\n }\n\n DEBUG_PRINTLN(TRIM, \"MAX SPECIFICITY: \" << maxSpecificity)\n\n bool isMoreSpecificOuter = false;\n\n int resultIndex = 0;\n\n for (NodeDeque::iterator resultIter = result.collection()->begin(), resultIterEnd = result.collection()->end(); resultIter != resultIterEnd; ++resultIter) {\n Node& seqs2 = *resultIter;\n\n DEBUG_PRINTLN(TRIM, \"SEQS1: \" << seqs1)\n DEBUG_PRINTLN(TRIM, \"SEQS2: \" << seqs2)\n\n // Do not compare the same sequence to itself. The ruby call we're trying to\n // emulate is: seqs1.equal?(seqs2). equal? is an object comparison, not an equivalency comparision.\n // Since we have the same pointers in seqes and results, we can do a pointer comparision. seqs1 is\n // derived from seqses and seqs2 is derived from result.\n if (seqs1.collection() == seqs2.collection()) {\n DEBUG_PRINTLN(TRIM, \"CONTINUE\")\n continue;\n }\n\n bool isMoreSpecificInner = false;\n\n for (NodeDeque::iterator seqs2Iter = seqs2.collection()->begin(), seqs2IterEnd = seqs2.collection()->end(); seqs2Iter != seqs2IterEnd; ++seqs2Iter) {\n Node& seq2 = *seqs2Iter;\n\n Complex_Selector_Obj pSeq2 = nodeToComplexSelector(seq2);\n\n DEBUG_PRINTLN(TRIM, \"SEQ2 SPEC: \" << pSeq2->specificity())\n DEBUG_PRINTLN(TRIM, \"IS SPEC: \" << pSeq2->specificity() << \" >= \" << maxSpecificity << \" \" << (pSeq2->specificity() >= maxSpecificity ? \"true\" : \"false\"))\n DEBUG_PRINTLN(TRIM, \"IS SUPER: \" << (pSeq2->is_superselector_of(pSeq1) ? \"true\" : \"false\"))\n\n isMoreSpecificInner = pSeq2->specificity() >= maxSpecificity && pSeq2->is_superselector_of(pSeq1);\n\n if (isMoreSpecificInner) {\n DEBUG_PRINTLN(TRIM, \"FOUND MORE SPECIFIC\")\n break;\n }\n }\n\n // If we found something more specific, we're done. Let the outer loop know and stop iterating.\n if (isMoreSpecificInner) {\n isMoreSpecificOuter = true;\n break;\n }\n\n resultIndex++;\n }\n\n if (!isMoreSpecificOuter) {\n DEBUG_PRINTLN(TRIM, \"PUSHING: \" << seq1)\n tempResult.collection()->push_back(seq1);\n }\n\n }\n\n DEBUG_PRINTLN(TRIM, \"RESULT BEFORE ASSIGN: \" << result)\n DEBUG_PRINTLN(TRIM, \"TEMP RESULT: \" << toTrimIndex << \" \" << tempResult)\n (*result.collection())[toTrimIndex] = tempResult;\n\n toTrimIndex++;\n\n DEBUG_PRINTLN(TRIM, \"RESULT: \" << result)\n }\n\n return result;\n }\n\n\n\n static bool parentSuperselector(const Node& one, const Node& two) {\n // TODO: figure out a better way to create a Complex_Selector from scratch\n // TODO: There's got to be a better way. 
This got ugly quick...\n Element_Selector_Obj fakeParent = SASS_MEMORY_NEW(Element_Selector, ParserState(\"[FAKE]\"), \"temp\");\n Compound_Selector_Obj fakeHead = SASS_MEMORY_NEW(Compound_Selector, ParserState(\"[FAKE]\"), 1 /*size*/);\n fakeHead->elements().push_back(fakeParent);\n Complex_Selector_Obj fakeParentContainer = SASS_MEMORY_NEW(Complex_Selector, ParserState(\"[FAKE]\"), Complex_Selector::ANCESTOR_OF, fakeHead /*head*/, NULL /*tail*/);\n\n Complex_Selector_Obj pOneWithFakeParent = nodeToComplexSelector(one);\n pOneWithFakeParent->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n Complex_Selector_Obj pTwoWithFakeParent = nodeToComplexSelector(two);\n pTwoWithFakeParent->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n\n return pOneWithFakeParent->is_superselector_of(pTwoWithFakeParent);\n }\n\n\n class ParentSuperselectorChunker {\n public:\n \nParentSuperselectorChunker(Node& lcs) : mLcs(lcs) {}\n Node& mLcs;\n\n bool operator()(const Node& seq) const {\n // {|s| parent_superselector?(s.first, lcs.first)}\n if (seq.collection()->size() == 0) return false;\n return parentSuperselector(seq.collection()->front(), mLcs.collection()->front());\n }\n };\n\n class SubweaveEmptyChunker {\n public:\n bool operator()(const Node& seq) const {\n // {|s| s.empty?}\n\n return seq.collection()->empty();\n }\n };\n\n /*\n # Takes initial subsequences of `seq1` and `seq2` and returns all\n # orderings of those subsequences. The initial subsequences are determined\n # by a block.\n #\n # Destructively removes the initial subsequences of `seq1` and `seq2`.\n #\n # For example, given `(A B C | D E)` and `(1 2 | 3 4 5)` (with `|`\n # denoting the boundary of the initial subsequence), this would return\n # `[(A B C 1 2), (1 2 A B C)]`. The sequences would then be `(D E)` and\n # `(3 4 5)`.\n #\n # @param seq1 [Array]\n # @param seq2 [Array]\n # @yield [a] Used to determine when to cut off the initial subsequences.\n # Called repeatedly for each sequence until it returns true.\n # @yieldparam a [Array] A final subsequence of one input sequence after\n # cutting off some initial subsequence.\n # @yieldreturn [Boolean] Whether or not to cut off the initial subsequence\n # here.\n # @return [Array] All possible orderings of the initial subsequences.\n def chunks(seq1, seq2)\n chunk1 = []\n chunk1 << seq1.shift until yield seq1\n chunk2 = []\n chunk2 << seq2.shift until yield seq2\n return [] if chunk1.empty? 
&& chunk2.empty?\n return [chunk2] if chunk1.empty?\n return [chunk1] if chunk2.empty?\n [chunk1 + chunk2, chunk2 + chunk1]\n end\n */\n template\n static Node chunks(Node& seq1, Node& seq2, const ChunkerType& chunker) {\n Node chunk1 = Node::createCollection();\n while (seq1.collection()->size() && !chunker(seq1)) {\n chunk1.collection()->push_back(seq1.collection()->front());\n seq1.collection()->pop_front();\n }\n\n Node chunk2 = Node::createCollection();\n while (!seq2.collection()->empty() && !chunker(seq2)) {\n chunk2.collection()->push_back(seq2.collection()->front());\n seq2.collection()->pop_front();\n }\n\n if (chunk1.collection()->empty() && chunk2.collection()->empty()) {\n DEBUG_PRINTLN(CHUNKS, \"RETURNING BOTH EMPTY\")\n return Node::createCollection();\n }\n\n if (chunk1.collection()->empty()) {\n Node chunk2Wrapper = Node::createCollection();\n chunk2Wrapper.collection()->push_back(chunk2);\n DEBUG_PRINTLN(CHUNKS, \"RETURNING ONE EMPTY\")\n return chunk2Wrapper;\n }\n\n if (chunk2.collection()->empty()) {\n Node chunk1Wrapper = Node::createCollection();\n chunk1Wrapper.collection()->push_back(chunk1);\n DEBUG_PRINTLN(CHUNKS, \"RETURNING TWO EMPTY\")\n return chunk1Wrapper;\n }\n\n Node perms = Node::createCollection();\n\n Node firstPermutation = Node::createCollection();\n firstPermutation.collection()->insert(firstPermutation.collection()->end(), chunk1.collection()->begin(), chunk1.collection()->end());\n firstPermutation.collection()->insert(firstPermutation.collection()->end(), chunk2.collection()->begin(), chunk2.collection()->end());\n perms.collection()->push_back(firstPermutation);\n\n Node secondPermutation = Node::createCollection();\n secondPermutation.collection()->insert(secondPermutation.collection()->end(), chunk2.collection()->begin(), chunk2.collection()->end());\n secondPermutation.collection()->insert(secondPermutation.collection()->end(), chunk1.collection()->begin(), chunk1.collection()->end());\n perms.collection()->push_back(secondPermutation);\n\n DEBUG_PRINTLN(CHUNKS, \"RETURNING PERM\")\n\n return perms;\n }\n\n\n static Node groupSelectors(Node& seq) {\n Node newSeq = Node::createCollection();\n\n Node tail = Node::createCollection();\n tail.plus(seq);\n\n while (!tail.collection()->empty()) {\n Node head = Node::createCollection();\n\n do {\n head.collection()->push_back(tail.collection()->front());\n tail.collection()->pop_front();\n } while (!tail.collection()->empty() && (head.collection()->back().isCombinator() || tail.collection()->front().isCombinator()));\n\n newSeq.collection()->push_back(head);\n }\n\n return newSeq;\n }\n\n\n static void getAndRemoveInitialOps(Node& seq, Node& ops) {\n NodeDeque& seqCollection = *(seq.collection());\n NodeDeque& opsCollection = *(ops.collection());\n\n while (seqCollection.size() > 0 && seqCollection.front().isCombinator()) {\n opsCollection.push_back(seqCollection.front());\n seqCollection.pop_front();\n }\n }\n\n\n static void getAndRemoveFinalOps(Node& seq, Node& ops) {\n NodeDeque& seqCollection = *(seq.collection());\n NodeDeque& opsCollection = *(ops.collection());\n\n while (seqCollection.size() > 0 && seqCollection.back().isCombinator()) {\n opsCollection.push_back(seqCollection.back()); // Purposefully reversed to match ruby code\n seqCollection.pop_back();\n }\n }\n\n\n /*\n def merge_initial_ops(seq1, seq2)\n ops1, ops2 = [], []\n ops1 << seq1.shift while seq1.first.is_a?(String)\n ops2 << seq2.shift while seq2.first.is_a?(String)\n\n newline = false\n newline ||= !!ops1.shift if ops1.first 
== \"\\n\"\n newline ||= !!ops2.shift if ops2.first == \"\\n\"\n\n # If neither sequence is a subsequence of the other, they cannot be\n # merged successfully\n lcs = Sass::Util.lcs(ops1, ops2)\n return unless lcs == ops1 || lcs == ops2\n return (newline ? [\"\\n\"] : []) + (ops1.size > ops2.size ? ops1 : ops2)\n end\n */\n static Node mergeInitialOps(Node& seq1, Node& seq2) {\n Node ops1 = Node::createCollection();\n Node ops2 = Node::createCollection();\n\n getAndRemoveInitialOps(seq1, ops1);\n getAndRemoveInitialOps(seq2, ops2);\n\n // TODO: Do we have this information available to us?\n // newline = false\n // newline ||= !!ops1.shift if ops1.first == \"\\n\"\n // newline ||= !!ops2.shift if ops2.first == \"\\n\"\n\n // If neither sequence is a subsequence of the other, they cannot be merged successfully\n DefaultLcsComparator lcsDefaultComparator;\n Node opsLcs = lcs(ops1, ops2, lcsDefaultComparator);\n\n if (!(opsLcs == ops1 || opsLcs == ops2)) {\n return Node::createNil();\n }\n\n // TODO: more newline logic\n // return (newline ? [\"\\n\"] : []) + (ops1.size > ops2.size ? ops1 : ops2)\n\n return (ops1.collection()->size() > ops2.collection()->size() ? ops1 : ops2);\n }\n\n\n /*\n def merge_final_ops(seq1, seq2, res = [])\n\n\n # This code looks complicated, but it's actually just a bunch of special\n # cases for interactions between different combinators.\n op1, op2 = ops1.first, ops2.first\n if op1 && op2\n sel1 = seq1.pop\n sel2 = seq2.pop\n if op1 == '~' && op2 == '~'\n if sel1.superselector?(sel2)\n res.unshift sel2, '~'\n elsif sel2.superselector?(sel1)\n res.unshift sel1, '~'\n else\n merged = sel1.unify(sel2.members, sel2.subject?)\n res.unshift [\n [sel1, '~', sel2, '~'],\n [sel2, '~', sel1, '~'],\n ([merged, '~'] if merged)\n ].compact\n end\n elsif (op1 == '~' && op2 == '+') || (op1 == '+' && op2 == '~')\n if op1 == '~'\n tilde_sel, plus_sel = sel1, sel2\n else\n tilde_sel, plus_sel = sel2, sel1\n end\n\n if tilde_sel.superselector?(plus_sel)\n res.unshift plus_sel, '+'\n else\n merged = plus_sel.unify(tilde_sel.members, tilde_sel.subject?)\n res.unshift [\n [tilde_sel, '~', plus_sel, '+'],\n ([merged, '+'] if merged)\n ].compact\n end\n elsif op1 == '>' && %w[~ +].include?(op2)\n res.unshift sel2, op2\n seq1.push sel1, op1\n elsif op2 == '>' && %w[~ +].include?(op1)\n res.unshift sel1, op1\n seq2.push sel2, op2\n elsif op1 == op2\n return unless merged = sel1.unify(sel2.members, sel2.subject?)\n res.unshift merged, op1\n else\n # Unknown selector combinators can't be unified\n return\n end\n return merge_final_ops(seq1, seq2, res)\n elsif op1\n seq2.pop if op1 == '>' && seq2.last && seq2.last.superselector?(seq1.last)\n res.unshift seq1.pop, op1\n return merge_final_ops(seq1, seq2, res)\n else # op2\n seq1.pop if op2 == '>' && seq1.last && seq1.last.superselector?(seq2.last)\n res.unshift seq2.pop, op2\n return merge_final_ops(seq1, seq2, res)\n end\n end\n */\n static Node mergeFinalOps(Node& seq1, Node& seq2, Node& res) {\n\n Node ops1 = Node::createCollection();\n Node ops2 = Node::createCollection();\n\n getAndRemoveFinalOps(seq1, ops1);\n getAndRemoveFinalOps(seq2, ops2);\n\n // TODO: do we have newlines to remove?\n // ops1.reject! {|o| o == \"\\n\"}\n // ops2.reject! 
{|o| o == \"\\n\"}\n\n if (ops1.collection()->empty() && ops2.collection()->empty()) {\n return res;\n }\n\n if (ops1.collection()->size() > 1 || ops2.collection()->size() > 1) {\n DefaultLcsComparator lcsDefaultComparator;\n Node opsLcs = lcs(ops1, ops2, lcsDefaultComparator);\n\n // If there are multiple operators, something hacky's going on. If one is a supersequence of the other, use that, otherwise give up.\n\n if (!(opsLcs == ops1 || opsLcs == ops2)) {\n return Node::createNil();\n }\n\n if (ops1.collection()->size() > ops2.collection()->size()) {\n res.collection()->insert(res.collection()->begin(), ops1.collection()->rbegin(), ops1.collection()->rend());\n } else {\n res.collection()->insert(res.collection()->begin(), ops2.collection()->rbegin(), ops2.collection()->rend());\n }\n\n return res;\n }\n\n if (!ops1.collection()->empty() && !ops2.collection()->empty()) {\n\n Node op1 = ops1.collection()->front();\n Node op2 = ops2.collection()->front();\n\n Node sel1 = seq1.collection()->back();\n seq1.collection()->pop_back();\n\n Node sel2 = seq2.collection()->back();\n seq2.collection()->pop_back();\n\n if (op1.combinator() == Complex_Selector::PRECEDES && op2.combinator() == Complex_Selector::PRECEDES) {\n\n if (sel1.selector()->is_superselector_of(sel2.selector())) {\n\n res.collection()->push_front(op1 /*PRECEDES - could have been op2 as well*/);\n res.collection()->push_front(sel2);\n\n } else if (sel2.selector()->is_superselector_of(sel1.selector())) {\n\n res.collection()->push_front(op1 /*PRECEDES - could have been op2 as well*/);\n res.collection()->push_front(sel1);\n\n } else {\n\n DEBUG_PRINTLN(ALL, \"sel1: \" << sel1)\n DEBUG_PRINTLN(ALL, \"sel2: \" << sel2)\n\n Complex_Selector_Obj pMergedWrapper = SASS_MEMORY_CLONE(sel1.selector()); // Clone the Complex_Selector to get back to something we can transform to a node once we replace the head with the unification result\n // TODO: does subject matter? 
Ruby: return unless merged = sel1.unify(sel2.members, sel2.subject?)\n Compound_Selector_Ptr pMerged = sel1.selector()->head()->unify_with(sel2.selector()->head());\n pMergedWrapper->head(pMerged);\n\n DEBUG_EXEC(ALL, printCompoundSelector(pMerged, \"MERGED: \"))\n\n Node newRes = Node::createCollection();\n\n Node firstPerm = Node::createCollection();\n firstPerm.collection()->push_back(sel1);\n firstPerm.collection()->push_back(Node::createCombinator(Complex_Selector::PRECEDES));\n firstPerm.collection()->push_back(sel2);\n firstPerm.collection()->push_back(Node::createCombinator(Complex_Selector::PRECEDES));\n newRes.collection()->push_back(firstPerm);\n\n Node secondPerm = Node::createCollection();\n secondPerm.collection()->push_back(sel2);\n secondPerm.collection()->push_back(Node::createCombinator(Complex_Selector::PRECEDES));\n secondPerm.collection()->push_back(sel1);\n secondPerm.collection()->push_back(Node::createCombinator(Complex_Selector::PRECEDES));\n newRes.collection()->push_back(secondPerm);\n\n if (pMerged) {\n Node mergedPerm = Node::createCollection();\n mergedPerm.collection()->push_back(Node::createSelector(pMergedWrapper));\n mergedPerm.collection()->push_back(Node::createCombinator(Complex_Selector::PRECEDES));\n newRes.collection()->push_back(mergedPerm);\n }\n\n res.collection()->push_front(newRes);\n\n DEBUG_PRINTLN(ALL, \"RESULT: \" << res)\n\n }\n\n } else if (((op1.combinator() == Complex_Selector::PRECEDES && op2.combinator() == Complex_Selector::ADJACENT_TO)) || ((op1.combinator() == Complex_Selector::ADJACENT_TO && op2.combinator() == Complex_Selector::PRECEDES))) {\n\n Node tildeSel = sel1;\n Node plusSel = sel2;\n Node plusOp = op2;\n if (op1.combinator() != Complex_Selector::PRECEDES) {\n tildeSel = sel2;\n plusSel = sel1;\n plusOp = op1;\n }\n\n if (tildeSel.selector()->is_superselector_of(plusSel.selector())) {\n\n res.collection()->push_front(plusOp);\n res.collection()->push_front(plusSel);\n\n } else {\n\n DEBUG_PRINTLN(ALL, \"PLUS SEL: \" << plusSel)\n DEBUG_PRINTLN(ALL, \"TILDE SEL: \" << tildeSel)\n\n Complex_Selector_Obj pMergedWrapper = SASS_MEMORY_CLONE(plusSel.selector()); // Clone the Complex_Selector to get back to something we can transform to a node once we replace the head with the unification result\n // TODO: does subject matter? 
Ruby: merged = plus_sel.unify(tilde_sel.members, tilde_sel.subject?)\n Compound_Selector_Ptr pMerged = plusSel.selector()->head()->unify_with(tildeSel.selector()->head());\n pMergedWrapper->head(pMerged);\n\n DEBUG_EXEC(ALL, printCompoundSelector(pMerged, \"MERGED: \"))\n\n Node newRes = Node::createCollection();\n\n Node firstPerm = Node::createCollection();\n firstPerm.collection()->push_back(tildeSel);\n firstPerm.collection()->push_back(Node::createCombinator(Complex_Selector::PRECEDES));\n firstPerm.collection()->push_back(plusSel);\n firstPerm.collection()->push_back(Node::createCombinator(Complex_Selector::ADJACENT_TO));\n newRes.collection()->push_back(firstPerm);\n\n if (pMerged) {\n Node mergedPerm = Node::createCollection();\n mergedPerm.collection()->push_back(Node::createSelector(pMergedWrapper));\n mergedPerm.collection()->push_back(Node::createCombinator(Complex_Selector::ADJACENT_TO));\n newRes.collection()->push_back(mergedPerm);\n }\n\n res.collection()->push_front(newRes);\n\n DEBUG_PRINTLN(ALL, \"RESULT: \" << res)\n\n }\n } else if (op1.combinator() == Complex_Selector::PARENT_OF && (op2.combinator() == Complex_Selector::PRECEDES || op2.combinator() == Complex_Selector::ADJACENT_TO)) {\n\n res.collection()->push_front(op2);\n res.collection()->push_front(sel2);\n\n seq1.collection()->push_back(sel1);\n seq1.collection()->push_back(op1);\n\n } else if (op2.combinator() == Complex_Selector::PARENT_OF && (op1.combinator() == Complex_Selector::PRECEDES || op1.combinator() == Complex_Selector::ADJACENT_TO)) {\n\n res.collection()->push_front(op1);\n res.collection()->push_front(sel1);\n\n seq2.collection()->push_back(sel2);\n seq2.collection()->push_back(op2);\n\n } else if (op1.combinator() == op2.combinator()) {\n\n DEBUG_PRINTLN(ALL, \"sel1: \" << sel1)\n DEBUG_PRINTLN(ALL, \"sel2: \" << sel2)\n\n Complex_Selector_Obj pMergedWrapper = SASS_MEMORY_CLONE(sel1.selector()); // Clone the Complex_Selector to get back to something we can transform to a node once we replace the head with the unification result\n // TODO: does subject matter? 
Ruby: return unless merged = sel1.unify(sel2.members, sel2.subject?)\n Compound_Selector_Ptr pMerged = sel1.selector()->head()->unify_with(sel2.selector()->head());\n pMergedWrapper->head(pMerged);\n\n DEBUG_EXEC(ALL, printCompoundSelector(pMerged, \"MERGED: \"))\n\n if (!pMerged) {\n return Node::createNil();\n }\n\n res.collection()->push_front(op1);\n res.collection()->push_front(Node::createSelector(pMergedWrapper));\n\n DEBUG_PRINTLN(ALL, \"RESULT: \" << res)\n\n } else {\n return Node::createNil();\n }\n\n return mergeFinalOps(seq1, seq2, res);\n\n } else if (!ops1.collection()->empty()) {\n\n Node op1 = ops1.collection()->front();\n\n if (op1.combinator() == Complex_Selector::PARENT_OF && !seq2.collection()->empty() && seq2.collection()->back().selector()->is_superselector_of(seq1.collection()->back().selector())) {\n seq2.collection()->pop_back();\n }\n\n // TODO: consider unshift(NodeCollection, Node)\n res.collection()->push_front(op1);\n res.collection()->push_front(seq1.collection()->back());\n seq1.collection()->pop_back();\n\n return mergeFinalOps(seq1, seq2, res);\n\n } else { // !ops2.collection()->empty()\n\n Node op2 = ops2.collection()->front();\n\n if (op2.combinator() == Complex_Selector::PARENT_OF && !seq1.collection()->empty() && seq1.collection()->back().selector()->is_superselector_of(seq2.collection()->back().selector())) {\n seq1.collection()->pop_back();\n }\n\n res.collection()->push_front(op2);\n...\n", "needle_token_start": 10389, "needle_token_end": 10409, "code_context_ntokens": 16016, "output": ["```cpp\nstatic bool parentSuperselector(const Node& one, const Node& two) {\n // TODO: figure out a better way to create a Complex_Selector from scratch\n // TODO: There's got to be a better way. This got ugly quick...\n Element_Selector_Obj fakeParent = SASS_MEMORY_NEW(Element_Selector, ParserState(\"[FAKE]\"), \"temp\");\n Compound_Selector_Obj fakeHead = SASS_MEMORY_NEW(Compound_Selector, ParserState(\"[FAKE]\"), 1 /*size*/);\n fakeHead->elements().push_back(fakeParent);\n Complex_Selector_Obj fakeParentContainer = SASS_MEMORY_NEW(Complex_Selector, ParserState(\"[FAKE]\"), Complex_Selector::ANCESTOR_OF, fakeHead /*head*/, NULL /*tail*/);\n\n Complex_Selector_Obj pOneWithFakeParent = nodeToComplexSelector(one);\n pOneWithFakeParent->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n Complex_Selector_Obj pTwoWithFakeParent = nodeToComplexSelector(two);\n pTwoWithFakeParent->set_innermost(fakeParentContainer, Complex_Selector::ANCESTOR_OF);\n\n return pOneWithFakeParent->is_superselector_of(pTwoWithFakeParent);\n}\n```"]} +{"repo": "sass/node-sass", "name": "Operand", "language": "cpp", "path": "src/libsass/src/ast.hpp", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this class is to represent an operation in a stylesheet processing context, encapsulating details about the operation and its formatting preferences.\n2. **Input**: It takes an operation type, and two boolean flags indicating whether there should be whitespace before and after the operation.\n3. **Output**: It does not produce a direct output but initializes an object that holds the operation type and whitespace preferences.\n4. 
**Procedure**: The procedure involves initializing the class with the specified operation type and whitespace flags, storing these values in the object's properties for later use in processing or formatting operations in stylesheets.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/utf8/checked.h\n// Copyright 2006 Nemanja Trifunovic\n\n/*\nPermission is hereby granted, free of charge, to any person or organization\nobtaining a copy of the software and accompanying documentation covered by\nthis license (the \"Software\") to use, reproduce, display, distribute,\nexecute, and transmit the Software, and to prepare derivative works of the\nSoftware, and to permit third-parties to whom the Software is furnished to\ndo so, all subject to the following:\n\n...\n// Path: src/libsass/src/utf8/unchecked.h\n// Copyright 2006 Nemanja Trifunovic\n\n/*\nPermission is hereby granted, free of charge, to any person or organization\nobtaining a copy of the software and accompanying documentation covered by\nthis license (the \"Software\") to use, reproduce, display, distribute,\nexecute, and transmit the Software, and to prepare derivative works of the\nSoftware, and to permit third-parties to whom the Software is furnished to\ndo so, all subject to the following:\n\nThe copyright notices in the Software and this entire statement, including\nthe above license grant, this restriction and the following disclaimer,\nmust be included in all copies of the Software, in whole or in part, and\nall derivative works of the Software, unless such copies or derivative\nworks are solely in the form of machine-executable object code generated by\na source language processor.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. 
IN NO EVENT\nSHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE\nFOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,\nARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\n*/\n\n\n#ifndef UTF8_FOR_CPP_UNCHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731\n#define UTF8_FOR_CPP_UNCHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731\n\n#include \"core.h\"\n\nnamespace utf8\n{\n namespace unchecked\n {\n template \n octet_iterator append(uint32_t cp, octet_iterator result)\n {\n if (cp < 0x80) // one octet\n *(result++) = static_cast(cp);\n else if (cp < 0x800) { // two octets\n *(result++) = static_cast((cp >> 6) | 0xc0);\n *(result++) = static_cast((cp & 0x3f) | 0x80);\n }\n else if (cp < 0x10000) { // three octets\n *(result++) = static_cast((cp >> 12) | 0xe0);\n *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80);\n *(result++) = static_cast((cp & 0x3f) | 0x80);\n }\n else { // four octets\n *(result++) = static_cast((cp >> 18) | 0xf0);\n *(result++) = static_cast(((cp >> 12) & 0x3f)| 0x80);\n *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80);\n *(result++) = static_cast((cp & 0x3f) | 0x80);\n }\n return result;\n }\n\n template \n uint32_t next(octet_iterator& it)\n {\n uint32_t cp = utf8::internal::mask8(*it);\n typename std::iterator_traits::difference_type length = utf8::internal::sequence_length(it);\n switch (length) {\n case 1:\n break;\n case 2:\n it++;\n cp = ((cp << 6) & 0x7ff) + ((*it) & 0x3f);\n break;\n case 3:\n ++it;\n cp = ((cp << 12) & 0xffff) + ((utf8::internal::mask8(*it) << 6) & 0xfff);\n ++it;\n cp += (*it) & 0x3f;\n break;\n case 4:\n ++it;\n cp = ((cp << 18) & 0x1fffff) + ((utf8::internal::mask8(*it) << 12) & 0x3ffff);\n ++it;\n cp += (utf8::internal::mask8(*it) << 6) & 0xfff;\n ++it;\n cp += (*it) & 0x3f;\n break;\n }\n ++it;\n return cp;\n }\n\n template \n uint32_t peek_next(octet_iterator it)\n {\n return utf8::unchecked::next(it);\n }\n\n template \n uint32_t prior(octet_iterator& it)\n {\n while (utf8::internal::is_trail(*(--it))) ;\n octet_iterator temp = it;\n return utf8::unchecked::next(temp);\n }\n\n // Deprecated in versions that include prior, but only for the sake of consistency (see utf8::previous)\n template \n inline uint32_t previous(octet_iterator& it)\n {\n return utf8::unchecked::prior(it);\n }\n\n template \n void advance (octet_iterator& it, distance_type n)\n {\n for (distance_type i = 0; i < n; ++i)\n utf8::unchecked::next(it);\n }\n\n template \n void retreat (octet_iterator& it, distance_type n)\n {\n for (distance_type i = 0; i < n; ++i)\n utf8::unchecked::prior(it);\n }\n\n template \n typename std::iterator_traits::difference_type\n distance (octet_iterator first, octet_iterator last)\n {\n typename std::iterator_traits::difference_type dist;\n for (dist = 0; first < last; ++dist)\n utf8::unchecked::next(first);\n return dist;\n }\n\n template \n octet_iterator utf16to8 (u16bit_iterator start, u16bit_iterator end, octet_iterator result)\n {\n while (start != end) {\n uint32_t cp = utf8::internal::mask16(*start++);\n // Take care of surrogate pairs first\n if (utf8::internal::is_lead_surrogate(cp)) {\n uint32_t trail_surrogate = utf8::internal::mask16(*start++);\n cp = (cp << 10) + trail_surrogate + internal::SURROGATE_OFFSET;\n }\n result = utf8::unchecked::append(cp, result);\n }\n return result;\n }\n\n template \n u16bit_iterator utf8to16 (octet_iterator start, octet_iterator end, u16bit_iterator result)\n {\n while (start < end) {\n uint32_t cp 
= utf8::unchecked::next(start);\n if (cp > 0xffff) { //make a surrogate pair\n *result++ = static_cast((cp >> 10) + internal::LEAD_OFFSET);\n *result++ = static_cast((cp & 0x3ff) + internal::TRAIL_SURROGATE_MIN);\n }\n else\n *result++ = static_cast(cp);\n }\n return result;\n }\n\n template \n octet_iterator utf32to8 (u32bit_iterator start, u32bit_iterator end, octet_iterator result)\n {\n while (start != end)\n result = utf8::unchecked::append(*(start++), result);\n\n return result;\n }\n\n template \n u32bit_iterator utf8to32 (octet_iterator start, octet_iterator end, u32bit_iterator result)\n {\n while (start < end)\n (*result++) = utf8::unchecked::next(start);\n\n return result;\n }\n\n // The iterator class\n template \n class iterator : public std::iterator {\n octet_iterator it;\n public:\n iterator () {}\n explicit iterator (const octet_iterator& octet_it): it(octet_it) {}\n // the default \"big three\" are OK\n octet_iterator base () const { return it; }\n uint32_t operator * () const\n {\n octet_iterator temp = it;\n return utf8::unchecked::next(temp);\n }\n bool operator == (const iterator& rhs) const\n {\n return (it == rhs.it);\n }\n bool operator != (const iterator& rhs) const\n {\n return !(operator == (rhs));\n }\n iterator& operator ++ ()\n {\n ::std::advance(it, utf8::internal::sequence_length(it));\n return *this;\n }\n iterator operator ++ (int)\n {\n iterator temp = *this;\n ::std::advance(it, utf8::internal::sequence_length(it));\n return temp;\n }\n iterator& operator -- ()\n {\n utf8::unchecked::prior(it);\n return *this;\n }\n iterator operator -- (int)\n {\n iterator temp = *this;\n utf8::unchecked::prior(it);\n return temp;\n }\n }; // class iterator\n\n } // namespace utf8::unchecked\n} // namespace utf8\n\n\n#endif // header guard\n\n\n// Path: src/libsass/src/utf8.h\n// Copyright 2006 Nemanja Trifunovic\n\n/*\nPermission is hereby granted, free of charge, to any person or organization\nobtaining a copy of the software and accompanying documentation covered by\nthis license (the \"Software\") to use, reproduce, display, distribute,\nexecute, and transmit the Software, and to prepare derivative works of the\nSoftware, and to permit third-parties to whom the Software is furnished to\ndo so, all subject to the following:\n\nThe copyright notices in the Software and this entire statement, including\nthe above license grant, this restriction and the following disclaimer,\nmust be included in all copies of the Software, in whole or in part, and\nall derivative works of the Software, unless such copies or derivative\nworks are solely in the form of machine-executable object code generated by\na source language processor.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. 
IN NO EVENT\nSHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE\nFOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,\nARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\n*/\n\n\n#ifndef UTF8_FOR_CPP_2675DCD0_9480_4c0c_B92A_CC14C027B731\n#define UTF8_FOR_CPP_2675DCD0_9480_4c0c_B92A_CC14C027B731\n\n#include \"utf8/checked.h\"\n#include \"utf8/unchecked.h\"\n\n#endif // header guard\n\n// Path: src/libsass/src/utf8_string.hpp\n#ifndef SASS_UTF8_STRING_H\n#define SASS_UTF8_STRING_H\n\n#include \n#include \"utf8.h\"\n\nnamespace Sass {\n namespace UTF_8 {\n\n // naming conventions:\n // offset: raw byte offset (0 based)\n // position: code point offset (0 based)\n // index: code point offset (1 based or negative)\n\n // function that will count the number of code points (utf-8 characters) from the beginning to the given end\n size_t code_point_count(const std::string& str, size_t start, size_t end);\n size_t code_point_count(const std::string& str);\n\n // function that will return the byte offset of a code point in a\n size_t offset_at_position(const std::string& str, size_t position);\n\n // function that returns number of bytes in a character in a string\n size_t code_point_size_at_offset(const std::string& str, size_t offset);\n\n // function that will return a normalized index, given a crazy one\n size_t normalize_index(int index, size_t len);\n\n #ifdef _WIN32\n // functions to handle unicode paths on windows\n std::string convert_from_utf16(const std::wstring& wstr);\n std::wstring convert_to_utf16(const std::string& str);\n #endif\n\n }\n}\n\n#endif\n\n// Path: src/libsass/src/plugins.hpp\n#ifndef SASS_PLUGINS_H\n#define SASS_PLUGINS_H\n\n#include \n#include \n#include \"utf8_string.hpp\"\n#include \"sass/functions.h\"\n\n#ifdef _WIN32\n\n #define LOAD_LIB(var, path) HMODULE var = LoadLibraryW(UTF_8::convert_to_utf16(path).c_str())\n #define LOAD_LIB_WCHR(var, path_wide_str) HMODULE var = LoadLibraryW(path_wide_str.c_str())\n #define LOAD_LIB_FN(type, var, name) type var = (type) GetProcAddress(plugin, name)\n #define CLOSE_LIB(var) FreeLibrary(var)\n\n #ifndef dlerror\n #define dlerror() 0\n #endif\n\n#else\n\n #define LOAD_LIB(var, path) void* var = dlopen(path.c_str(), RTLD_LAZY)\n #define LOAD_LIB_FN(type, var, name) type var = (type) dlsym(plugin, name)\n #define CLOSE_LIB(var) dlclose(var)\n\n#endif\n\nnamespace Sass {\n\n\n class Plugins {\n\n public: // c-tor\n Plugins(void);\n ~Plugins(void);\n\n public: // methods\n // load one specific plugin\n bool load_plugin(const std::string& path);\n // load all plugins from a directory\n size_t load_plugins(const std::string& path);\n\n public: // public accessors\n const std::vector get_headers(void) { return headers; }\n const std::vector get_importers(void) { return importers; }\n const std::vector get_functions(void) { return functions; }\n\n private: // private vars\n std::vector headers;\n std::vector importers;\n std::vector functions;\n\n };\n\n}\n\n#endif\n\n// Path: src/libsass/src/context.hpp\n#ifndef SASS_CONTEXT_H\n#define SASS_CONTEXT_H\n\n#include \n#include \n#include \n\n#define BUFFERSIZE 255\n#include \"b64/encode.h\"\n\n#include \"ast_fwd_decl.hpp\"\n#include \"kwd_arg_macros.hpp\"\n#include \"ast_fwd_decl.hpp\"\n#include \"sass_context.hpp\"\n#include \"environment.hpp\"\n#include \"source_map.hpp\"\n#include \"subset_map.hpp\"\n#include \"backtrace.hpp\"\n#include \"output.hpp\"\n#include \"plugins.hpp\"\n#include 
\"file.hpp\"\n\n\nstruct Sass_Function;\n\nnamespace Sass {\n\n class Context {\n public:\n void import_url (Import_Ptr imp, std::string load_path, const std::string& ctx_path);\n bool call_headers(const std::string& load_path, const char* ctx_path, ParserState& pstate, Import_Ptr imp)\n { return call_loader(load_path, ctx_path, pstate, imp, c_headers, false); };\n bool call_importers(const std::string& load_path, const char* ctx_path, ParserState& pstate, Import_Ptr imp)\n { return call_loader(load_path, ctx_path, pstate, imp, c_importers, true); };\n\n private:\n bool call_loader(const std::string& load_path, const char* ctx_path, ParserState& pstate, Import_Ptr imp, std::vector importers, bool only_one = true);\n\n public:\n const std::string CWD;\n struct Sass_Options& c_options;\n std::string entry_path;\n size_t head_imports;\n Plugins plugins;\n Output emitter;\n\n // generic ast node garbage container\n // used to avoid possible circular refs\n std::vector ast_gc;\n // resources add under our control\n // these are guaranteed to be freed\n std::vector strings;\n std::vector resources;\n std::map sheets;\n Subset_Map subset_map;\n std::vector import_stack;\n std::vector callee_stack;\n std::vector traces;\n\n struct Sass_Compiler* c_compiler;\n\n // absolute paths to includes\n std::vector included_files;\n // relative includes for sourcemap\n std::vector srcmap_links;\n // vectors above have same size\n\n std::vector plugin_paths; // relative paths to load plugins\n std::vector include_paths; // lookup paths for includes\n\n\n\n\n\n void apply_custom_headers(Block_Obj root, const char* path, ParserState pstate);\n\n std::vector c_headers;\n std::vector c_importers;\n std::vector c_functions;\n\n void add_c_header(Sass_Importer_Entry header);\n void add_c_importer(Sass_Importer_Entry importer);\n void add_c_function(Sass_Function_Entry function);\n\n const std::string indent; // String to be used for indentation\n const std::string linefeed; // String to be used for line feeds\n const std::string input_path; // for relative paths in src-map\n const std::string output_path; // for relative paths to the output\n const std::string source_map_file; // path to source map file (enables feature)\n const std::string source_map_root; // path for sourceRoot property (pass-through)\n\n virtual ~Context();\n Context(struct Sass_Context&);\n virtual Block_Obj parse() = 0;\n virtual Block_Obj compile();\n virtual char* render(Block_Obj root);\n virtual char* render_srcmap();\n\n void register_resource(const Include&, const Resource&);\n void register_resource(const Include&, const Resource&, ParserState&);\n std::vector find_includes(const Importer& import);\n Include load_import(const Importer&, ParserState pstate);\n\n Sass_Output_Style output_style() { return c_options.output_style; };\n std::vector get_included_files(bool skip = false, size_t headers = 0);\n\n private:\n void collect_plugin_paths(const char* paths_str);\n void collect_plugin_paths(string_list* paths_array);\n void collect_include_paths(const char* paths_str);\n void collect_include_paths(string_list* paths_array);\n std::string format_embedded_source_map();\n std::string format_source_mapping_url(const std::string& out_path);\n\n\n // void register_built_in_functions(Env* env);\n // void register_function(Signature sig, Native_Function f, Env* env);\n // void register_function(Signature sig, Native_Function f, size_t arity, Env* env);\n // void register_overload_stub(std::string name, Env* env);\n\n public:\n const 
std::string& cwd() { return CWD; };\n };\n\n class File_Context : public Context {\n public:\n File_Context(struct Sass_File_Context& ctx)\n : Context(ctx)\n { }\n virtual ~File_Context();\n virtual Block_Obj parse();\n };\n\n class Data_Context : public Context {\n public:\n char* source_c_str;\n char* srcmap_c_str;\n Data_Context(struct Sass_Data_Context& ctx)\n : Context(ctx)\n {\n source_c_str = ctx.source_string;\n srcmap_c_str = ctx.srcmap_string;\n ctx.source_string = 0; // passed away\n ctx.srcmap_string = 0; // passed away\n }\n virtual ~Data_Context();\n virtual Block_Obj parse();\n };\n\n}\n\n#endif\n\n// Path: src/libsass/src/constants.hpp\n#ifndef SASS_CONSTANTS_H\n#define SASS_CONSTANTS_H\n\nnamespace Sass {\n namespace Constants {\n\n // The maximum call stack that can be created\n extern const unsigned long MaxCallStack;\n\n // https://developer.mozilla.org/en-US/docs/Web/CSS/Specificity\n // The following list of selectors is by increasing specificity:\n extern const unsigned long Specificity_Star;\n extern const unsigned long Specificity_Universal;\n extern const unsigned long Specificity_Element;\n extern const unsigned long Specificity_Base;\n extern const unsigned long Specificity_Class;\n extern const unsigned long Specificity_Attr;\n extern const unsigned long Specificity_Pseudo;\n extern const unsigned long Specificity_ID;\n\n // sass keywords\n extern const char at_root_kwd[];\n extern const char import_kwd[];\n extern const char mixin_kwd[];\n extern const char function_kwd[];\n extern const char return_kwd[];\n extern const char include_kwd[];\n extern const char content_kwd[];\n extern const char extend_kwd[];\n extern const char if_kwd[];\n extern const char else_kwd[];\n extern const char if_after_else_kwd[];\n extern const char for_kwd[];\n extern const char from_kwd[];\n extern const char to_kwd[];\n extern const char through_kwd[];\n extern const char each_kwd[];\n extern const char in_kwd[];\n extern const char while_kwd[];\n extern const char warn_kwd[];\n extern const char error_kwd[];\n extern const char debug_kwd[];\n extern const char default_kwd[];\n extern const char global_kwd[];\n extern const char null_kwd[];\n extern const char optional_kwd[];\n extern const char with_kwd[];\n extern const char without_kwd[];\n extern const char all_kwd[];\n extern const char rule_kwd[];\n\n // css standard units\n extern const char em_kwd[];\n extern const char ex_kwd[];\n extern const char px_kwd[];\n extern const char cm_kwd[];\n extern const char mm_kwd[];\n extern const char pt_kwd[];\n extern const char pc_kwd[];\n extern const char deg_kwd[];\n extern const char rad_kwd[];\n extern const char grad_kwd[];\n extern const char turn_kwd[];\n extern const char ms_kwd[];\n extern const char s_kwd[];\n extern const char Hz_kwd[];\n extern const char kHz_kwd[];\n\n // vendor prefixes\n extern const char vendor_opera_kwd[];\n extern const char vendor_webkit_kwd[];\n extern const char vendor_mozilla_kwd[];\n extern const char vendor_ms_kwd[];\n extern const char vendor_khtml_kwd[];\n\n // css functions and keywords\n extern const char charset_kwd[];\n extern const char media_kwd[];\n extern const char supports_kwd[];\n extern const char keyframes_kwd[];\n extern const char only_kwd[];\n extern const char rgb_fn_kwd[];\n extern const char url_fn_kwd[];\n extern const char url_kwd[];\n // extern const char url_prefix_fn_kwd[];\n extern const char important_kwd[];\n extern const char pseudo_not_fn_kwd[];\n extern const char even_kwd[];\n extern const char odd_kwd[];\n 
extern const char progid_kwd[];\n extern const char expression_kwd[];\n extern const char calc_fn_kwd[];\n\n // char classes for \"regular expressions\"\n extern const char almost_any_value_class[];\n\n // css selector keywords\n extern const char sel_deep_kwd[];\n\n // css attribute-matching operators\n extern const char tilde_equal[];\n extern const char pipe_equal[];\n extern const char caret_equal[];\n extern const char dollar_equal[];\n extern const char star_equal[];\n\n // relational & logical operators and constants\n extern const char and_kwd[];\n extern const char or_kwd[];\n extern const char not_kwd[];\n extern const char gt[];\n extern const char gte[];\n extern const char lt[];\n extern const char lte[];\n extern const char eq[];\n extern const char neq[];\n extern const char true_kwd[];\n extern const char false_kwd[];\n\n // miscellaneous punctuation and delimiters\n extern const char percent_str[];\n extern const char empty_str[];\n extern const char slash_slash[];\n extern const char slash_star[];\n extern const char star_slash[];\n extern const char hash_lbrace[];\n extern const char rbrace[];\n extern const char rparen[];\n extern const char sign_chars[];\n extern const char op_chars[];\n extern const char hyphen[];\n extern const char ellipsis[];\n // extern const char url_space_chars[];\n\n // type names\n extern const char numeric_name[];\n extern const char number_name[];\n extern const char percentage_name[];\n extern const char dimension_name[];\n extern const char string_name[];\n extern const char bool_name[];\n extern const char color_name[];\n extern const char list_name[];\n extern const char map_name[];\n extern const char arglist_name[];\n\n // constants for uri parsing (RFC 3986 Appendix A.)\n extern const char uri_chars[];\n extern const char real_uri_chars[];\n\n // some specific constant character classes\n // they must be static to be useable by lexer\n extern const char static_ops[];\n extern const char selector_list_delims[];\n extern const char complex_selector_delims[];\n extern const char selector_combinator_ops[];\n extern const char attribute_compare_modifiers[];\n extern const char selector_lookahead_ops[];\n\n // byte order marks\n // (taken from http://en.wikipedia.org/wiki/Byte_order_mark)\n extern const unsigned char utf_8_bom[];\n extern const unsigned char utf_16_bom_be[];\n extern const unsigned char utf_16_bom_le[];\n extern const unsigned char utf_32_bom_be[];\n extern const unsigned char utf_32_bom_le[];\n extern const unsigned char utf_7_bom_1[];\n extern const unsigned char utf_7_bom_2[];\n extern const unsigned char utf_7_bom_3[];\n extern const unsigned char utf_7_bom_4[];\n extern const unsigned char utf_7_bom_5[];\n extern const unsigned char utf_1_bom[];\n extern const unsigned char utf_ebcdic_bom[];\n extern const unsigned char scsu_bom[];\n extern const unsigned char bocu_1_bom[];\n extern const unsigned char gb_18030_bom[];\n\n }\n}\n\n#endif\n\n// Path: src/libsass/src/error_handling.hpp\n#ifndef SASS_ERROR_HANDLING_H\n#define SASS_ERROR_HANDLING_H\n\n#include \n#include \n#include \n#include \"position.hpp\"\n#include \"backtrace.hpp\"\n#include \"ast_fwd_decl.hpp\"\n#include \"sass/functions.h\"\n\nnamespace Sass {\n\n struct Backtrace;\n\n namespace Exception {\n\n const std::string def_msg = \"Invalid sass detected\";\n const std::string def_op_msg = \"Undefined operation\";\n const std::string def_op_null_msg = \"Invalid null operation\";\n const std::string def_nesting_limit = \"Code too deeply neested\";\n\n class 
Base : public std::runtime_error {\n protected:\n std::string msg;\n std::string prefix;\n public:\n ParserState pstate;\n Backtraces traces;\n public:\n Base(ParserState pstate, std::string msg, Backtraces traces);\n virtual const char* errtype() const { return prefix.c_str(); }\n virtual const char* what() const throw() { return msg.c_str(); }\n virtual ~Base() throw() {};\n };\n\n class InvalidSass : public Base {\n public:\n InvalidSass(ParserState pstate, Backtraces traces, std::string msg);\n virtual ~InvalidSass() throw() {};\n };\n\n class InvalidParent : public Base {\n protected:\n Selector_Ptr parent;\n Selector_Ptr selector;\n public:\n InvalidParent(Selector_Ptr parent, Backtraces traces, Selector_Ptr selector);\n virtual ~InvalidParent() throw() {};\n };\n\n class MissingArgument : public Base {\n protected:\n std::string fn;\n std::string arg;\n std::string fntype;\n public:\n MissingArgument(ParserState pstate, Backtraces traces, std::string fn, std::string arg, std::string fntype);\n virtual ~MissingArgument() throw() {};\n };\n\n class InvalidArgumentType : public Base {\n protected:\n std::string fn;\n std::string arg;\n std::string type;\n const Value_Ptr value;\n public:\n InvalidArgumentType(ParserState pstate, Backtraces traces, std::string fn, std::string arg, std::string type, const Value_Ptr value = 0);\n virtual ~InvalidArgumentType() throw() {};\n };\n\n class InvalidVarKwdType : public Base {\n protected:\n std::string name;\n const Argument_Ptr arg;\n public:\n InvalidVarKwdType(ParserState pstate, Backtraces traces, std::string name, const Argument_Ptr arg = 0);\n virtual ~InvalidVarKwdType() throw() {};\n };\n\n class InvalidSyntax : public Base {\n public:\n InvalidSyntax(ParserState pstate, Backtraces traces, std::string msg);\n virtual ~InvalidSyntax() throw() {};\n };\n\n class NestingLimitError : public Base {\n public:\n NestingLimitError(ParserState pstate, Backtraces traces, std::string msg = def_nesting_limit);\n virtual ~NestingLimitError() throw() {};\n };\n\n class DuplicateKeyError : public Base {\n protected:\n const Map& dup;\n const Expression& org;\n public:\n DuplicateKeyError(Backtraces traces, const Map& dup, const Expression& org);\n virtual const char* errtype() const { return \"Error\"; }\n virtual ~DuplicateKeyError() throw() {};\n };\n\n class TypeMismatch : public Base {\n protected:\n const Expression& var;\n const std::string type;\n public:\n TypeMismatch(Backtraces traces, const Expression& var, const std::string type);\n virtual const char* errtype() const { return \"Error\"; }\n virtual ~TypeMismatch() throw() {};\n };\n\n class InvalidValue : public Base {\n protected:\n const Expression& val;\n public:\n InvalidValue(Backtraces traces, const Expression& val);\n virtual const char* errtype() const { return \"Error\"; }\n virtual ~InvalidValue() throw() {};\n };\n\n class StackError : public Base {\n protected:\n const AST_Node& node;\n public:\n StackError(Backtraces traces, const AST_Node& node);\n virtual const char* errtype() const { return \"SystemStackError\"; }\n virtual ~StackError() throw() {};\n };\n\n /* common virtual base class (has no pstate or trace) */\n class OperationError : public std::runtime_error {\n protected:\n std::string msg;\n public:\n OperationError(std::string msg = def_op_msg)\n : std::runtime_error(msg), msg(msg)\n {};\n public:\n virtual const char* errtype() const { return \"Error\"; }\n virtual const char* what() const throw() { return msg.c_str(); }\n virtual ~OperationError() throw() {};\n 
};\n\n class ZeroDivisionError : public OperationError {\n protected:\n const Expression& lhs;\n const Expression& rhs;\n public:\n ZeroDivisionError(const Expression& lhs, const Expression& rhs);\n virtual const char* errtype() const { return \"ZeroDivisionError\"; }\n virtual ~ZeroDivisionError() throw() {};\n };\n\n class IncompatibleUnits : public OperationError {\n protected:\n // const Sass::UnitType lhs;\n // const Sass::UnitType rhs;\n public:\n IncompatibleUnits(const Units& lhs, const Units& rhs);\n IncompatibleUnits(const UnitType lhs, const UnitType rhs);\n virtual ~IncompatibleUnits() throw() {};\n };\n\n class UndefinedOperation : public OperationError {\n protected:\n Expression_Ptr_Const lhs;\n Expression_Ptr_Const rhs;\n const Sass_OP op;\n public:\n UndefinedOperation(Expression_Ptr_Const lhs, Expression_Ptr_Const rhs, enum Sass_OP op);\n // virtual const char* errtype() const { return \"Error\"; }\n virtual ~UndefinedOperation() throw() {};\n };\n\n class InvalidNullOperation : public UndefinedOperation {\n public:\n InvalidNullOperation(Expression_Ptr_Const lhs, Expression_Ptr_Const rhs, enum Sass_OP op);\n virtual ~InvalidNullOperation() throw() {};\n };\n\n class AlphaChannelsNotEqual : public OperationError {\n protected:\n Expression_Ptr_Const lhs;\n Expression_Ptr_Const rhs;\n const Sass_OP op;\n public:\n AlphaChannelsNotEqual(Expression_Ptr_Const lhs, Expression_Ptr_Const rhs, enum Sass_OP op);\n // virtual const char* errtype() const { return \"Error\"; }\n virtual ~AlphaChannelsNotEqual() throw() {};\n };\n\n class SassValueError : public Base {\n public:\n SassValueError(Backtraces traces, ParserState pstate, OperationError& err);\n virtual ~SassValueError() throw() {};\n };\n\n }\n\n void warn(std::string msg, ParserState pstate);\n void warn(std::string msg, ParserState pstate, Backtrace* bt);\n void warning(std::string msg, ParserState pstate);\n\n void deprecated_function(std::string msg, ParserState pstate);\n void deprecated(std::string msg, std::string msg2, bool with_column, ParserState pstate);\n void deprecated_bind(std::string msg, ParserState pstate);\n // void deprecated(std::string msg, ParserState pstate, Backtrace* bt);\n\n void coreError(std::string msg, ParserState pstate);\n void error(std::string msg, ParserState pstate, Backtraces& traces);\n\n}\n\n#endif\n\n// Path: src/libsass/include/sass/version.h\n#ifndef SASS_VERSION_H\n#define SASS_VERSION_H\n\n#ifndef LIBSASS_VERSION\n#define LIBSASS_VERSION \"[NA]\"\n#endif\n\n#ifndef LIBSASS_LANGUAGE_VERSION\n#define LIBSASS_LANGUAGE_VERSION \"3.5\"\n#endif\n\n#endif\n\n// Path: src/libsass/include/sass2scss.h\n/**\n * sass2scss\n * Licensed under the MIT License\n * Copyright (c) Marcel Greter\n */\n\n#ifndef SASS2SCSS_H\n#define SASS2SCSS_H\n\n#ifdef _WIN32\n\n /* You should define ADD_EXPORTS *only* when building the DLL. */\n #ifdef ADD_EXPORTS\n #define ADDAPI __declspec(dllexport)\n\t#define ADDCALL __cdecl\n #else\n #define ADDAPI\n\t#define ADDCALL\n #endif\n\n#else /* _WIN32 not defined. */\n\n /* Define with no value on non-Windows OSes. 
*/\n #define ADDAPI\n #define ADDCALL\n\n#endif\n\n#ifdef __cplusplus\n\n#include \n#include \n#include \n#include \n#include \n\n#ifndef SASS2SCSS_VERSION\n// Hardcode once the file is copied from\n// https://github.com/mgreter/sass2scss\n#define SASS2SCSS_VERSION \"1.1.1\"\n#endif\n\n// add namespace for c++\nnamespace Sass\n{\n\n\t// pretty print options\n\tconst int SASS2SCSS_PRETTIFY_0 = 0;\n\tconst int SASS2SCSS_PRETTIFY_1 = 1;\n\tconst int SASS2SCSS_PRETTIFY_2 = 2;\n\tconst int SASS2SCSS_PRETTIFY_3 = 3;\n\n\t// remove one-line comment\n\tconst int SASS2SCSS_KEEP_COMMENT = 32;\n\t// remove multi-line comments\n\tconst int SASS2SCSS_STRIP_COMMENT = 64;\n\t// convert one-line to multi-line\n\tconst int SASS2SCSS_CONVERT_COMMENT = 128;\n\n\t// String for finding something interesting\n\tconst std::string SASS2SCSS_FIND_WHITESPACE = \" \\t\\n\\v\\f\\r\";\n\n\t// converter struct\n\t// holding all states\n\tstruct converter\n\t{\n\t\t// bit options\n\t\tint options;\n\t\t// is selector\n\t\tbool selector;\n\t\t// concat lists\n\t\tbool comma;\n\t\t// has property\n\t\tbool property;\n\t\t// has semicolon\n\t\tbool semicolon;\n\t\t// comment context\n\t\tstd::string comment;\n\t\t// flag end of file\n\t\tbool end_of_file;\n\t\t// whitespace buffer\n\t\tstd::string whitespace;\n\t\t// context/block stack\n\t\tstd::stack indents;\n\t};\n\n\t// function only available in c++ code\n\tchar* sass2scss (const std::string& sass, const int options);\n\n}\n// EO namespace\n\n// declare for c\nextern \"C\" {\n#endif\n\n\t// prettyfy print options\n\t#define SASS2SCSS_PRETTIFY_0 0\n\t#define SASS2SCSS_PRETTIFY_1 1\n\t#define SASS2SCSS_PRETTIFY_2 2\n\t#define SASS2SCSS_PRETTIFY_3 3\n\n\t// keep one-line comments\n\t#define SASS2SCSS_KEEP_COMMENT 32\n\t// remove multi-line comments\n\t#define SASS2SCSS_STRIP_COMMENT 64\n\t// convert one-line to multi-line\n\t#define SASS2SCSS_CONVERT_COMMENT 128\n\n\t// available to c and c++ code\n\tADDAPI char* ADDCALL sass2scss (const char* sass, const int options);\n\n\t// Get compiled sass2scss version\n\tADDAPI const char* ADDCALL sass2scss_version(void);\n\n#ifdef __cplusplus\n} // __cplusplus defined.\n#endif\n\n#endif\n// Path: src/libsass/include/sass.h\n#ifndef SASS_H\n#define SASS_H\n\n// #define DEBUG 1\n\n// include API headers\n#include \n#include \n#include \n#include \n#include \n#include \n\n#endif\n\n\n// Path: src/libsass/src/ast.hpp\n#ifndef SASS_AST_H\n#define SASS_AST_H\n\n#include \"sass.hpp\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"sass/base.h\"\n#include \"ast_fwd_decl.hpp\"\n\n#ifdef DEBUG_SHARED_PTR\n\n#define ATTACH_VIRTUAL_AST_OPERATIONS(klass) \\\n virtual klass##_Ptr copy(std::string, size_t) const = 0; \\\n virtual klass##_Ptr clone(std::string, size_t) const = 0; \\\n\n#define ATTACH_AST_OPERATIONS(klass) \\\n virtual klass##_Ptr copy(std::string, size_t) const; \\\n virtual klass##_Ptr clone(std::string, size_t) const; \\\n\n#else\n\n#define ATTACH_VIRTUAL_AST_OPERATIONS(klass) \\\n virtual klass##_Ptr copy() const = 0; \\\n virtual klass##_Ptr clone() const = 0; \\\n\n#define ATTACH_AST_OPERATIONS(klass) \\\n virtual klass##_Ptr copy() const; \\\n virtual klass##_Ptr clone() const; \\\n\n#endif\n\n#ifdef __clang__\n\n/*\n * There are some overloads used here that trigger the clang overload\n * hiding warning. 
Specifically:\n *\n * Type type() which hides string type() from Expression\n *\n */\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Woverloaded-virtual\"\n\n#endif\n\n#include \"util.hpp\"\n#include \"units.hpp\"\n#include \"context.hpp\"\n#include \"position.hpp\"\n#include \"constants.hpp\"\n#include \"operation.hpp\"\n#include \"position.hpp\"\n#include \"inspect.hpp\"\n#include \"source_map.hpp\"\n#include \"environment.hpp\"\n#include \"error_handling.hpp\"\n#include \"ast_def_macros.hpp\"\n#include \"ast_fwd_decl.hpp\"\n#include \"source_map.hpp\"\n\n#include \"sass.h\"\n\nnamespace Sass {\n\n // easier to search with name\n const bool DELAYED = true;\n\n // ToDo: should this really be hardcoded\n // Note: most methods follow precision option\n const double NUMBER_EPSILON = 0.00000000000001;\n\n // macro to test if numbers are equal within a small error margin\n #define NEAR_EQUAL(lhs, rhs) std::fabs(lhs - rhs) < NUMBER_EPSILON\n\n // ToDo: where does this fit best?\n // We don't share this with C-API?\n class Operand {\n public:\n \nOperand(Sass_OP operand, bool ws_before = false, bool ws_after = false)\n : operand(operand), ws_before(ws_before), ws_after(ws_after)\n { }\n public:\n enum Sass_OP operand;\n bool ws_before;\n bool ws_after;\n };\n\n //////////////////////////////////////////////////////////\n // `hash_combine` comes from boost (functional/hash):\n // http://www.boost.org/doc/libs/1_35_0/doc/html/hash/combine.html\n // Boost Software License - Version 1.0\n // http://www.boost.org/users/license.html\n template \n void hash_combine (std::size_t& seed, const T& val)\n {\n seed ^= std::hash()(val) + 0x9e3779b9\n + (seed<<6) + (seed>>2);\n }\n //////////////////////////////////////////////////////////\n\n //////////////////////////////////////////////////////////\n // Abstract base class for all abstract syntax tree nodes.\n //////////////////////////////////////////////////////////\n class AST_Node : public SharedObj {\n ADD_PROPERTY(ParserState, pstate)\n public:\n AST_Node(ParserState pstate)\n : pstate_(pstate)\n { }\n AST_Node(const AST_Node* ptr)\n : pstate_(ptr->pstate_)\n { }\n\n // AST_Node(AST_Node& ptr) = delete;\n\n virtual ~AST_Node() = 0;\n virtual size_t hash() { return 0; }\n ATTACH_VIRTUAL_AST_OPERATIONS(AST_Node);\n virtual std::string inspect() const { return to_string({ INSPECT, 5 }); }\n virtual std::string to_sass() const { return to_string({ TO_SASS, 5 }); }\n virtual const std::string to_string(Sass_Inspect_Options opt) const;\n virtual const std::string to_string() const;\n virtual void cloneChildren() {};\n // generic find function (not fully implemented yet)\n // ToDo: add specific implementions to all children\n virtual bool find ( bool (*f)(AST_Node_Obj) ) { return f(this); };\n public:\n void update_pstate(const ParserState& pstate);\n public:\n Offset off() { return pstate(); }\n Position pos() { return pstate(); }\n ATTACH_OPERATIONS()\n };\n inline AST_Node::~AST_Node() { }\n\n //////////////////////////////////////////////////////////////////////\n // define cast template now (need complete type)\n //////////////////////////////////////////////////////////////////////\n\n template\n T* Cast(AST_Node* ptr) {\n return ptr && typeid(T) == typeid(*ptr) ?\n static_cast(ptr) : NULL;\n };\n\n template\n const T* Cast(const AST_Node* ptr) {\n return ptr && typeid(T) == typeid(*ptr) ?\n static_cast(ptr) : NULL;\n };\n\n //////////////////////////////////////////////////////////////////////\n // Abstract base class for 
expressions. This side of the AST hierarchy\n // represents elements in value contexts, which exist primarily to be\n // evaluated and returned.\n //////////////////////////////////////////////////////////////////////\n class Expression : public AST_Node {\n public:\n enum Concrete_Type {\n NONE,\n BOOLEAN,\n NUMBER,\n COLOR,\n STRING,\n LIST,\n MAP,\n SELECTOR,\n NULL_VAL,\n FUNCTION_VAL,\n C_WARNING,\n C_ERROR,\n FUNCTION,\n VARIABLE,\n NUM_TYPES\n };\n enum Simple_Type {\n SIMPLE,\n ATTR_SEL,\n PSEUDO_SEL,\n WRAPPED_SEL,\n };\n private:\n // expressions in some contexts shouldn't be evaluated\n ADD_PROPERTY(bool, is_delayed)\n ADD_PROPERTY(bool, is_expanded)\n ADD_PROPERTY(bool, is_interpolant)\n ADD_PROPERTY(Concrete_Type, concrete_type)\n public:\n Expression(ParserState pstate,\n bool d = false, bool e = false, bool i = false, Concrete_Type ct = NONE)\n : AST_Node(pstate),\n is_delayed_(d),\n is_expanded_(e),\n is_interpolant_(i),\n concrete_type_(ct)\n { }\n Expression(const Expression* ptr)\n : AST_Node(ptr),\n is_delayed_(ptr->is_delayed_),\n is_expanded_(ptr->is_expanded_),\n is_interpolant_(ptr->is_interpolant_),\n concrete_type_(ptr->concrete_type_)\n { }\n virtual operator bool() { return true; }\n virtual ~Expression() { }\n virtual std::string type() const { return \"\"; /* TODO: raise an error? */ }\n virtual bool is_invisible() const { return false; }\n static std::string type_name() { return \"\"; }\n virtual bool is_false() { return false; }\n // virtual bool is_true() { return !is_false(); }\n virtual bool operator== (const Expression& rhs) const { return false; }\n virtual bool eq(const Expression& rhs) const { return *this == rhs; };\n virtual void set_delayed(bool delayed) { is_delayed(delayed); }\n virtual bool has_interpolant() const { return is_interpolant(); }\n virtual bool is_left_interpolant() const { return is_interpolant(); }\n virtual bool is_right_interpolant() const { return is_interpolant(); }\n virtual std::string inspect() const { return to_string({ INSPECT, 5 }); }\n virtual std::string to_sass() const { return to_string({ TO_SASS, 5 }); }\n ATTACH_VIRTUAL_AST_OPERATIONS(Expression);\n virtual size_t hash() { return 0; }\n };\n\n //////////////////////////////////////////////////////////////////////\n // Still just an expression, but with a to_string method\n //////////////////////////////////////////////////////////////////////\n class PreValue : public Expression {\n public:\n PreValue(ParserState pstate,\n bool d = false, bool e = false, bool i = false, Concrete_Type ct = NONE)\n : Expression(pstate, d, e, i, ct)\n { }\n PreValue(const PreValue* ptr)\n : Expression(ptr)\n { }\n ATTACH_VIRTUAL_AST_OPERATIONS(PreValue);\n virtual ~PreValue() { }\n };\n\n //////////////////////////////////////////////////////////////////////\n // base class for values that support operations\n //////////////////////////////////////////////////////////////////////\n class Value : public Expression {\n public:\n Value(ParserState pstate,\n bool d = false, bool e = false, bool i = false, Concrete_Type ct = NONE)\n : Expression(pstate, d, e, i, ct)\n { }\n Value(const Value* ptr)\n : Expression(ptr)\n { }\n ATTACH_VIRTUAL_AST_OPERATIONS(Value);\n virtual bool operator== (const Expression& rhs) const = 0;\n };\n}\n\n/////////////////////////////////////////////////////////////////////////////////////\n// Hash method specializations for std::unordered_map to work with 
Sass::Expression\n/////////////////////////////////////////////////////////////////////////////////////\n\nnamespace std {\n template<>\n struct hash\n {\n size_t operator()(Sass::Expression_Obj s) const\n {\n return s->hash();\n }\n };\n template<>\n struct equal_to\n {\n bool operator()( Sass::Expression_Obj lhs, Sass::Expression_Obj rhs) const\n {\n return lhs->hash() == rhs->hash();\n }\n };\n}\n\nnamespace Sass {\n\n /////////////////////////////////////////////////////////////////////////////\n // Mixin class for AST nodes that should behave like vectors. Uses the\n // \"Template Method\" design pattern to allow subclasses to adjust their flags\n // when certain objects are pushed.\n /////////////////////////////////////////////////////////////////////////////\n template \n class Vectorized {\n std::vector elements_;\n protected:\n size_t hash_;\n void reset_hash() { hash_ = 0; }\n virtual void adjust_after_pushing(T element) { }\n public:\n Vectorized(size_t s = 0) : elements_(std::vector()), hash_(0)\n { elements_.reserve(s); }\n virtual ~Vectorized() = 0;\n size_t length() const { return elements_.size(); }\n bool empty() const { return elements_.empty(); }\n void clear() { return elements_.clear(); }\n T last() const { return elements_.back(); }\n T first() const { return elements_.front(); }\n T& operator[](size_t i) { return elements_[i]; }\n virtual const T& at(size_t i) const { return elements_.at(i); }\n virtual T& at(size_t i) { return elements_.at(i); }\n const T& operator[](size_t i) const { return elements_[i]; }\n virtual void append(T element)\n {\n if (element) {\n reset_hash();\n elements_.push_back(element);\n adjust_after_pushing(element);\n }\n }\n virtual void concat(Vectorized* v)\n {\n for (size_t i = 0, L = v->length(); i < L; ++i) this->append((*v)[i]);\n }\n Vectorized& unshift(T element)\n {\n elements_.insert(elements_.begin(), element);\n return *this;\n }\n std::vector& elements() { return elements_; }\n const std::vector& elements() const { return elements_; }\n std::vector& elements(std::vector& e) { elements_ = e; return elements_; }\n\n virtual size_t hash()\n {\n if (hash_ == 0) {\n for (T& el : elements_) {\n hash_combine(hash_, el->hash());\n }\n }\n return hash_;\n }\n\n typename std::vector::iterator end() { return elements_.end(); }\n typename std::vector::iterator begin() { return elements_.begin(); }\n typename std::vector::const_iterator end() const { return elements_.end(); }\n typename std::vector::const_iterator begin() const { return elements_.begin(); }\n typename std::vector::iterator erase(typename std::vector::iterator el) { return elements_.erase(el); }\n typename std::vector::const_iterator erase(typename std::vector::const_iterator el) { return elements_.erase(el); }\n\n };\n template \n inline Vectorized::~Vectorized() { }\n\n /////////////////////////////////////////////////////////////////////////////\n // Mixin class for AST nodes that should behave like a hash table. 
Uses an\n // extra internally to maintain insertion order for interation.\n /////////////////////////////////////////////////////////////////////////////\n class Hashed {\n private:\n ExpressionMap elements_;\n std::vector list_;\n protected:\n size_t hash_;\n Expression_Obj duplicate_key_;\n void reset_hash() { hash_ = 0; }\n void reset_duplicate_key() { duplicate_key_ = 0; }\n virtual void adjust_after_pushing(std::pair p) { }\n public:\n Hashed(size_t s = 0)\n : elements_(ExpressionMap(s)),\n list_(std::vector()),\n hash_(0), duplicate_key_(NULL)\n { elements_.reserve(s); list_.reserve(s); }\n virtual ~Hashed();\n size_t length() const { return list_.size(); }\n bool empty() const { return list_.empty(); }\n bool has(Expression_Obj k) const { return elements_.count(k) == 1; }\n Expression_Obj at(Expression_Obj k) const;\n bool has_duplicate_key() const { return duplicate_key_ != 0; }\n Expression_Obj get_duplicate_key() const { return duplicate_key_; }\n const ExpressionMap elements() { return elements_; }\n Hashed& operator<<(std::pair p)\n {\n reset_hash();\n\n if (!has(p.first)) list_.push_back(p.first);\n else if (!duplicate_key_) duplicate_key_ = p.first;\n\n elements_[p.first] = p.second;\n\n adjust_after_pushing(p);\n return *this;\n }\n Hashed& operator+=(Hashed* h)\n {\n if (length() == 0) {\n this->elements_ = h->elements_;\n this->list_ = h->list_;\n return *this;\n }\n\n for (auto key : h->keys()) {\n *this << std::make_pair(key, h->at(key));\n }\n\n reset_duplicate_key();\n return *this;\n }\n const ExpressionMap& pairs() const { return elements_; }\n const std::vector& keys() const { return list_; }\n\n// std::unordered_map::iterator end() { return elements_.end(); }\n// std::unordered_map::iterator begin() { return elements_.begin(); }\n// std::unordered_map::const_iterator end() const { return elements_.end(); }\n// std::unordered_map::const_iterator begin() const { return elements_.begin(); }\n\n };\n inline Hashed::~Hashed() { }\n\n\n /////////////////////////////////////////////////////////////////////////\n // Abstract base class for statements. 
This side of the AST hierarchy\n // represents elements in expansion contexts, which exist primarily to be\n // rewritten and macro-expanded.\n /////////////////////////////////////////////////////////////////////////\n class Statement : public AST_Node {\n public:\n enum Statement_Type {\n NONE,\n RULESET,\n MEDIA,\n DIRECTIVE,\n SUPPORTS,\n ATROOT,\n BUBBLE,\n CONTENT,\n KEYFRAMERULE,\n DECLARATION,\n ASSIGNMENT,\n IMPORT_STUB,\n IMPORT,\n COMMENT,\n WARNING,\n RETURN,\n EXTEND,\n ERROR,\n DEBUGSTMT,\n WHILE,\n EACH,\n FOR,\n IF\n };\n private:\n ADD_PROPERTY(Statement_Type, statement_type)\n ADD_PROPERTY(size_t, tabs)\n ADD_PROPERTY(bool, group_end)\n public:\n Statement(ParserState pstate, Statement_Type st = NONE, size_t t = 0)\n : AST_Node(pstate), statement_type_(st), tabs_(t), group_end_(false)\n { }\n Statement(const Statement* ptr)\n : AST_Node(ptr),\n statement_type_(ptr->statement_type_),\n tabs_(ptr->tabs_),\n group_end_(ptr->group_end_)\n { }\n virtual ~Statement() = 0;\n // needed for rearranging nested rulesets during CSS emission\n virtual bool is_invisible() const { return false; }\n virtual bool bubbles() { return false; }\n virtual bool has_content()\n {\n return statement_type_ == CONTENT;\n }\n };\n...\n", "needle_token_start": 11993, "needle_token_end": 12046, "code_context_ntokens": 16081, "output": ["```cpp\nclass Operand {\n public:\n \nOperand(Sass_OP operand, bool ws_before = false, bool ws_after = false)\n : operand(operand), ws_before(ws_before), ws_after(ws_after)\n { }\n public:\n enum Sass_OP operand;\n bool ws_before;\n bool ws_after;\n};\n```"]} +{"repo": "sass/node-sass", "name": "debug_sources_set", "language": "cpp", "path": "src/libsass/src/debugger.hpp", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to assist in debugging by printing a visual representation of a set of complex selectors to the standard error output, helping developers understand the current state and structure of the data.\n2. **Input**: It takes a set of complex selectors and an optional string indentation parameter for formatting the output.\n3. **Output**: There is no return value; the function outputs directly to the standard error stream.\n4. **Procedure**: The function begins by checking if the indentation parameter is empty and, if so, prints a starting delimiter. It then iterates over each element in the set, using another function to print details of each complex selector with the provided indentation. 
After iterating through the set, if the indentation was initially empty, it prints an ending delimiter to the standard error.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/output.cpp\n#include \"sass.hpp\"\n#include \"ast.hpp\"\n#include \"output.hpp\"\n\nnamespace Sass {\n\n Output::Output(Sass_Output_Options& opt)\n : Inspect(Emitter(opt)),\n charset(\"\"),\n top_nodes(0)\n {}\n\n Output::~Output() { }\n\n void Output::fallback_impl(AST_Node_Ptr n)\n {\n return n->perform(this);\n }\n\n void Output::operator()(Number_Ptr n)\n {\n // check for a valid unit here\n // includes result for reporting\n if (!n->is_valid_css_unit()) {\n // should be handle in check_expression\n throw Exception::InvalidValue({}, *n);\n }\n // use values to_string facility\n std::string res = n->to_string(opt);\n // output the final token\n append_token(res, n);\n }\n\n void Output::operator()(Import_Ptr imp)\n {\n top_nodes.push_back(imp);\n }\n\n void Output::operator()(Map_Ptr m)\n {\n // should be handle in check_expression\n throw Exception::InvalidValue({}, *m);\n }\n\n OutputBuffer Output::get_buffer(void)\n {\n\n Emitter emitter(opt);\n Inspect inspect(emitter);\n\n size_t size_nodes = top_nodes.size();\n for (size_t i = 0; i < size_nodes; i++) {\n top_nodes[i]->perform(&inspect);\n inspect.append_mandatory_linefeed();\n }\n\n // flush scheduled outputs\n // maybe omit semicolon if possible\n inspect.finalize(wbuf.buffer.size() == 0);\n // prepend buffer on top\n prepend_output(inspect.output());\n // make sure we end with a linefeed\n if (!ends_with(wbuf.buffer, opt.linefeed)) {\n // if the output is not completely empty\n if (!wbuf.buffer.empty()) append_string(opt.linefeed);\n }\n\n // search for unicode char\n for(const char& chr : wbuf.buffer) {\n // skip all ascii chars\n // static cast to unsigned to handle `char` being signed / unsigned\n if (static_cast(chr) < 128) continue;\n // declare the charset\n if (output_style() != COMPRESSED)\n charset = \"@charset \\\"UTF-8\\\";\"\n + std::string(opt.linefeed);\n else charset = \"\\xEF\\xBB\\xBF\";\n // abort search\n break;\n }\n\n // add charset as first line, before comments and imports\n if (!charset.empty()) prepend_string(charset);\n\n return wbuf;\n\n }\n\n void Output::operator()(Comment_Ptr c)\n {\n std::string txt = c->text()->to_string(opt);\n // if (indentation && txt == \"/**/\") return;\n bool important = c->is_important();\n if (output_style() != COMPRESSED || important) {\n if (buffer().size() == 0) {\n top_nodes.push_back(c);\n } else {\n in_comment = true;\n append_indentation();\n c->text()->perform(this);\n in_comment = false;\n if (indentation == 0) {\n append_mandatory_linefeed();\n } else {\n append_optional_linefeed();\n }\n }\n }\n }\n\n void Output::operator()(Ruleset_Ptr r)\n {\n Selector_Obj s = r->selector();\n Block_Obj b = r->block();\n\n // Filter out rulesets that aren't printable (process its children though)\n if (!Util::isPrintable(r, output_style())) {\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n const Statement_Obj& stm = b->at(i);\n if (Cast(stm)) {\n if (!Cast(stm)) {\n stm->perform(this);\n }\n }\n }\n return;\n }\n\n if (output_style() == NESTED) indentation += r->tabs();\n if (opt.source_comments) {\n std::stringstream ss;\n append_indentation();\n 
std::string path(File::abs2rel(r->pstate().path));\n ss << \"/* line \" << r->pstate().line + 1 << \", \" << path << \" */\";\n append_string(ss.str());\n append_optional_linefeed();\n }\n scheduled_crutch = s;\n if (s) s->perform(this);\n append_scope_opener(b);\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Obj stm = b->at(i);\n bool bPrintExpression = true;\n // Check print conditions\n if (Declaration_Ptr dec = Cast(stm)) {\n if (String_Constant_Ptr valConst = Cast(dec->value())) {\n std::string val(valConst->value());\n if (String_Quoted_Ptr qstr = Cast(valConst)) {\n if (!qstr->quote_mark() && val.empty()) {\n bPrintExpression = false;\n }\n }\n }\n else if (List_Ptr list = Cast(dec->value())) {\n bool all_invisible = true;\n for (size_t list_i = 0, list_L = list->length(); list_i < list_L; ++list_i) {\n Expression_Ptr item = list->at(list_i);\n if (!item->is_invisible()) all_invisible = false;\n }\n if (all_invisible && !list->is_bracketed()) bPrintExpression = false;\n }\n }\n // Print if OK\n if (bPrintExpression) {\n stm->perform(this);\n }\n }\n if (output_style() == NESTED) indentation -= r->tabs();\n append_scope_closer(b);\n\n }\n void Output::operator()(Keyframe_Rule_Ptr r)\n {\n Block_Obj b = r->block();\n Selector_Obj v = r->name();\n\n if (!v.isNull()) {\n v->perform(this);\n }\n\n if (!b) {\n append_colon_separator();\n return;\n }\n\n append_scope_opener();\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Obj stm = b->at(i);\n stm->perform(this);\n if (i < L - 1) append_special_linefeed();\n }\n append_scope_closer();\n }\n\n void Output::operator()(Supports_Block_Ptr f)\n {\n if (f->is_invisible()) return;\n\n Supports_Condition_Obj c = f->condition();\n Block_Obj b = f->block();\n\n // Filter out feature blocks that aren't printable (process its children though)\n if (!Util::isPrintable(f, output_style())) {\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Obj stm = b->at(i);\n if (Cast(stm)) {\n stm->perform(this);\n }\n }\n return;\n }\n\n if (output_style() == NESTED) indentation += f->tabs();\n append_indentation();\n append_token(\"@supports\", f);\n append_mandatory_space();\n c->perform(this);\n append_scope_opener();\n\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Obj stm = b->at(i);\n stm->perform(this);\n if (i < L - 1) append_special_linefeed();\n }\n\n if (output_style() == NESTED) indentation -= f->tabs();\n\n append_scope_closer();\n\n }\n\n void Output::operator()(Media_Block_Ptr m)\n {\n if (m->is_invisible()) return;\n\n Block_Obj b = m->block();\n\n // Filter out media blocks that aren't printable (process its children though)\n if (!Util::isPrintable(m, output_style())) {\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Obj stm = b->at(i);\n if (Cast(stm)) {\n stm->perform(this);\n }\n }\n return;\n }\n if (output_style() == NESTED) indentation += m->tabs();\n append_indentation();\n append_token(\"@media\", m);\n append_mandatory_space();\n in_media_block = true;\n m->media_queries()->perform(this);\n in_media_block = false;\n append_scope_opener();\n\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n if (b->at(i)) {\n Statement_Obj stm = b->at(i);\n stm->perform(this);\n }\n if (i < L - 1) append_special_linefeed();\n }\n\n if (output_style() == NESTED) indentation -= m->tabs();\n append_scope_closer();\n }\n\n void Output::operator()(Directive_Ptr a)\n {\n std::string kwd = a->keyword();\n Selector_Obj s = a->selector();\n Expression_Obj v = a->value();\n Block_Obj 
b = a->block();\n\n append_indentation();\n append_token(kwd, a);\n if (s) {\n append_mandatory_space();\n in_wrapped = true;\n s->perform(this);\n in_wrapped = false;\n }\n if (v) {\n append_mandatory_space();\n // ruby sass bug? should use options?\n append_token(v->to_string(/* opt */), v);\n }\n if (!b) {\n append_delimiter();\n return;\n }\n\n if (b->is_invisible() || b->length() == 0) {\n append_optional_space();\n return append_string(\"{}\");\n }\n\n append_scope_opener();\n\n bool format = kwd != \"@font-face\";;\n\n for (size_t i = 0, L = b->length(); i < L; ++i) {\n Statement_Obj stm = b->at(i);\n stm->perform(this);\n if (i < L - 1 && format) append_special_linefeed();\n }\n\n append_scope_closer();\n }\n\n void Output::operator()(String_Quoted_Ptr s)\n {\n if (s->quote_mark()) {\n append_token(quote(s->value(), s->quote_mark()), s);\n } else if (!in_comment) {\n...\n// Path: src/libsass/src/sass.cpp\n#include \"sass.hpp\"\n#include \n#include \n#include \n#include \n\n#include \"sass.h\"\n#include \"file.hpp\"\n#include \"util.hpp\"\n#include \"sass_context.hpp\"\n#include \"sass_functions.hpp\"\n\nnamespace Sass {\n\n // helper to convert string list to vector\n std::vector list2vec(struct string_list* cur)\n {\n std::vector list;\n while (cur) {\n list.push_back(cur->string);\n cur = cur->next;\n }\n return list;\n }\n\n}\n\nextern \"C\" {\n using namespace Sass;\n\n // Allocate libsass heap memory\n // Don't forget string termination!\n void* ADDCALL sass_alloc_memory(size_t size)\n {\n void* ptr = malloc(size);\n if (ptr == NULL) {\n std::cerr << \"Out of memory.\\n\";\n exit(EXIT_FAILURE);\n }\n return ptr;\n }\n\n char* ADDCALL sass_copy_c_string(const char* str)\n {\n size_t len = strlen(str) + 1;\n char* cpy = (char*) sass_alloc_memory(len);\n std::memcpy(cpy, str, len);\n return cpy;\n }\n\n // Deallocate libsass heap memory\n void ADDCALL sass_free_memory(void* ptr)\n {\n if (ptr) free (ptr);\n }\n\n // caller must free the returned memory\n char* ADDCALL sass_string_quote (const char *str, const char quote_mark)\n {\n std::string quoted = quote(str, quote_mark);\n return sass_copy_c_string(quoted.c_str());\n }\n\n // caller must free the returned memory\n char* ADDCALL sass_string_unquote (const char *str)\n {\n std::string unquoted = unquote(str);\n return sass_copy_c_string(unquoted.c_str());\n }\n\n char* ADDCALL sass_compiler_find_include (const char* file, struct Sass_Compiler* compiler)\n {\n // get the last import entry to get current base directory\n Sass_Import_Entry import = sass_compiler_get_last_import(compiler);\n const std::vector& incs = compiler->cpp_ctx->include_paths;\n // create the vector with paths to lookup\n std::vector paths(1 + incs.size());\n paths.push_back(File::dir_name(import->abs_path));\n paths.insert( paths.end(), incs.begin(), incs.end() );\n // now resolve the file path relative to lookup paths\n std::string resolved(File::find_include(file, paths));\n return sass_copy_c_string(resolved.c_str());\n }\n\n char* ADDCALL sass_compiler_find_file (const char* file, struct Sass_Compiler* compiler)\n {\n // get the last import entry to get current base directory\n Sass_Import_Entry import = sass_compiler_get_last_import(compiler);\n const std::vector& incs = compiler->cpp_ctx->include_paths;\n // create the vector with paths to lookup\n std::vector paths(1 + incs.size());\n paths.push_back(File::dir_name(import->abs_path));\n paths.insert( paths.end(), incs.begin(), incs.end() );\n // now resolve the file path relative to lookup paths\n 
std::string resolved(File::find_file(file, paths));\n return sass_copy_c_string(resolved.c_str());\n }\n\n // Make sure to free the returned value!\n // Incs array has to be null terminated!\n // this has the original resolve logic for sass include\n char* ADDCALL sass_find_include (const char* file, struct Sass_Options* opt)\n {\n std::vector vec(list2vec(opt->include_paths));\n std::string resolved(File::find_include(file, vec));\n return sass_copy_c_string(resolved.c_str());\n }\n\n // Make sure to free the returned value!\n // Incs array has to be null terminated!\n char* ADDCALL sass_find_file (const char* file, struct Sass_Options* opt)\n {\n std::vector vec(list2vec(opt->include_paths));\n std::string resolved(File::find_file(file, vec));\n return sass_copy_c_string(resolved.c_str());\n }\n\n // Get compiled libsass version\n const char* ADDCALL libsass_version(void)\n {\n return LIBSASS_VERSION;\n }\n\n // Get compiled libsass version\n const char* ADDCALL libsass_language_version(void)\n {\n return LIBSASS_LANGUAGE_VERSION;\n }\n\n}\n\nnamespace Sass {\n\n // helper to aid dreaded MSVC debug mode\n char* sass_copy_string(std::string str)\n {\n // In MSVC the following can lead to segfault:\n // sass_copy_c_string(stream.str().c_str());\n // Reason is that the string returned by str() is disposed before\n // sass_copy_c_string is invoked. The string is actually a stack\n // object, so indeed nobody is holding on to it. So it seems\n // perfectly fair to release it right away. So the const char*\n // by c_str will point to invalid memory. I'm not sure if this is\n // the behavior for all compiler, but I'm pretty sure we would\n // have gotten more issues reported if that would be the case.\n // Wrapping it in a functions seems the cleanest approach as the\n // function must hold on to the stack variable until it's done.\n return sass_copy_c_string(str.c_str());\n }\n\n}\n// Path: src/libsass/src/emitter.cpp\n#include \"sass.hpp\"\n#include \"util.hpp\"\n#include \"context.hpp\"\n#include \"output.hpp\"\n#include \"emitter.hpp\"\n#include \"utf8_string.hpp\"\n\nnamespace Sass {\n\n Emitter::Emitter(struct Sass_Output_Options& opt)\n : wbuf(),\n opt(opt),\n indentation(0),\n scheduled_space(0),\n scheduled_linefeed(0),\n scheduled_delimiter(false),\n scheduled_crutch(0),\n scheduled_mapping(0),\n in_custom_property(false),\n in_comment(false),\n in_wrapped(false),\n in_media_block(false),\n in_declaration(false),\n in_space_array(false),\n in_comma_array(false)\n { }\n\n // return buffer as string\n std::string Emitter::get_buffer(void)\n {\n return wbuf.buffer;\n }\n\n Sass_Output_Style Emitter::output_style(void) const\n {\n return opt.output_style;\n }\n\n // PROXY METHODS FOR SOURCE MAPS\n\n void Emitter::add_source_index(size_t idx)\n { wbuf.smap.source_index.push_back(idx); }\n\n std::string Emitter::render_srcmap(Context &ctx)\n { return wbuf.smap.render_srcmap(ctx); }\n\n void Emitter::set_filename(const std::string& str)\n { wbuf.smap.file = str; }\n\n void Emitter::schedule_mapping(const AST_Node_Ptr node)\n { scheduled_mapping = node; }\n void Emitter::add_open_mapping(const AST_Node_Ptr node)\n { wbuf.smap.add_open_mapping(node); }\n void Emitter::add_close_mapping(const AST_Node_Ptr node)\n { wbuf.smap.add_close_mapping(node); }\n ParserState Emitter::remap(const ParserState& pstate)\n { return wbuf.smap.remap(pstate); }\n\n // MAIN BUFFER MANIPULATION\n\n // add outstanding delimiter\n void Emitter::finalize(bool final)\n {\n scheduled_space = 0;\n if (output_style() == 
SASS_STYLE_COMPRESSED)\n if (final) scheduled_delimiter = false;\n if (scheduled_linefeed)\n scheduled_linefeed = 1;\n flush_schedules();\n }\n\n // flush scheduled space/linefeed\n void Emitter::flush_schedules(void)\n {\n // check the schedule\n if (scheduled_linefeed) {\n std::string linefeeds = \"\";\n\n for (size_t i = 0; i < scheduled_linefeed; i++)\n linefeeds += opt.linefeed;\n scheduled_space = 0;\n scheduled_linefeed = 0;\n append_string(linefeeds);\n\n } else if (scheduled_space) {\n std::string spaces(scheduled_space, ' ');\n scheduled_space = 0;\n append_string(spaces);\n }\n if (scheduled_delimiter) {\n scheduled_delimiter = false;\n append_string(\";\");\n }\n }\n\n // prepend some text or token to the buffer\n void Emitter::prepend_output(const OutputBuffer& output)\n {\n wbuf.smap.prepend(output);\n wbuf.buffer = output.buffer + wbuf.buffer;\n }\n\n // prepend some text or token to the buffer\n void Emitter::prepend_string(const std::string& text)\n {\n // do not adjust mappings for utf8 bom\n // seems they are not counted in any UA\n if (text.compare(\"\\xEF\\xBB\\xBF\") != 0) {\n wbuf.smap.prepend(Offset(text));\n }\n wbuf.buffer = text + wbuf.buffer;\n }\n\n char Emitter::last_char()\n {\n return wbuf.buffer.back();\n }\n\n // append a single char to the buffer\n void Emitter::append_char(const char chr)\n {\n // write space/lf\n flush_schedules();\n // add to buffer\n wbuf.buffer += chr;\n // account for data in source-maps\n wbuf.smap.append(Offset(chr));\n }\n\n // append some text or token to the buffer\n void Emitter::append_string(const std::string& text)\n {\n\n // write space/lf\n flush_schedules();\n\n if (in_comment && output_style() == COMPACT) {\n // unescape comment nodes\n std::string out = comment_to_string(text);\n // add to buffer\n wbuf.buffer += out;\n // account for data in source-maps\n wbuf.smap.append(Offset(out));\n } else {\n // add to buffer\n wbuf.buffer += text;\n // account for data in source-maps\n wbuf.smap.append(Offset(text));\n }\n }\n\n // append some white-space only text\n void Emitter::append_wspace(const std::string& text)\n {\n if (text.empty()) return;\n if (peek_linefeed(text.c_str())) {\n scheduled_space = 0;\n append_mandatory_linefeed();\n }\n }\n\n // append some text or token to the buffer\n // this adds source-mappings for node start and end\n void Emitter::append_token(const std::string& text, const AST_Node_Ptr node)\n {\n flush_schedules();\n add_open_mapping(node);\n // hotfix for browser issues\n // this is pretty ugly indeed\n if (scheduled_crutch) {\n add_open_mapping(scheduled_crutch);\n scheduled_crutch = 0;\n }\n append_string(text);\n add_close_mapping(node);\n }\n\n // HELPER METHODS\n\n void Emitter::append_indentation()\n {\n if (output_style() == COMPRESSED) return;\n if (output_style() == COMPACT) return;\n if (in_declaration && in_comma_array) return;\n if (scheduled_linefeed && indentation)\n scheduled_linefeed = 1;\n std::string indent = \"\";\n for (size_t i = 0; i < indentation; i++)\n indent += opt.indent;\n append_string(indent);\n }\n\n void Emitter::append_delimiter()\n {\n scheduled_delimiter = true;\n if (output_style() == COMPACT) {\n if (indentation == 0) {\n append_mandatory_linefeed();\n } else {\n append_mandatory_space();\n }\n } else if (output_style() != COMPRESSED) {\n append_optional_linefeed();\n }\n }\n\n void Emitter::append_comma_separator()\n {\n // scheduled_space = 0;\n append_string(\",\");\n append_optional_space();\n }\n\n void Emitter::append_colon_separator()\n {\n 
scheduled_space = 0;\n append_string(\":\");\n if (!in_custom_property) append_optional_space();\n }\n\n void Emitter::append_mandatory_space()\n {\n scheduled_space = 1;\n }\n\n void Emitter::append_optional_space()\n {\n if ((output_style() != COMPRESSED) && buffer().size()) {\n unsigned char lst = buffer().at(buffer().length() - 1);\n if (!isspace(lst) || scheduled_delimiter) {\n if (last_char() != '(') {\n append_mandatory_space();\n }\n }\n }\n }\n\n void Emitter::append_special_linefeed()\n {\n if (output_style() == COMPACT) {\n append_mandatory_linefeed();\n for (size_t p = 0; p < indentation; p++)\n append_string(opt.indent);\n }\n }\n\n void Emitter::append_optional_linefeed()\n {\n if (in_declaration && in_comma_array) return;\n if (output_style() == COMPACT) {\n append_mandatory_space();\n } else {\n append_mandatory_linefeed();\n }\n }\n\n void Emitter::append_mandatory_linefeed()\n {\n if (output_style() != COMPRESSED) {\n scheduled_linefeed = 1;\n scheduled_space = 0;\n // flush_schedules();\n }\n }\n\n void Emitter::append_scope_opener(AST_Node_Ptr node)\n {\n scheduled_linefeed = 0;\n append_optional_space();\n flush_schedules();\n if (node) add_open_mapping(node);\n append_string(\"{\");\n append_optional_linefeed();\n // append_optional_space();\n ++ indentation;\n }\n void Emitter::append_scope_closer(AST_Node_Ptr node)\n {\n -- indentation;\n scheduled_linefeed = 0;\n if (output_style() == COMPRESSED)\n scheduled_delimiter = false;\n if (output_style() == EXPANDED) {\n append_optional_linefeed();\n append_indentation();\n } else {\n append_optional_space();\n }\n append_string(\"}\");\n if (node) add_close_mapping(node);\n append_optional_linefeed();\n if (indentation != 0) return;\n if (output_style() != COMPRESSED)\n scheduled_linefeed = 2;\n }\n\n}\n\n// Path: src/libsass/src/listize.cpp\n#include \"sass.hpp\"\n#include \n#include \n#include \n\n#include \"listize.hpp\"\n#include \"context.hpp\"\n#include \"backtrace.hpp\"\n#include \"error_handling.hpp\"\n\nnamespace Sass {\n\n Listize::Listize()\n { }\n\n Expression_Ptr Listize::operator()(Selector_List_Ptr sel)\n {\n List_Obj l = SASS_MEMORY_NEW(List, sel->pstate(), sel->length(), SASS_COMMA);\n l->from_selector(true);\n for (size_t i = 0, L = sel->length(); i < L; ++i) {\n if (!sel->at(i)) continue;\n l->append(sel->at(i)->perform(this));\n }\n if (l->length()) return l.detach();\n return SASS_MEMORY_NEW(Null, l->pstate());\n }\n\n Expression_Ptr Listize::operator()(Compound_Selector_Ptr sel)\n {\n std::string str;\n for (size_t i = 0, L = sel->length(); i < L; ++i) {\n Expression_Ptr e = (*sel)[i]->perform(this);\n if (e) str += e->to_string();\n }\n return SASS_MEMORY_NEW(String_Quoted, sel->pstate(), str);\n }\n\n Expression_Ptr Listize::operator()(Complex_Selector_Ptr sel)\n {\n List_Obj l = SASS_MEMORY_NEW(List, sel->pstate(), 2);\n l->from_selector(true);\n Compound_Selector_Obj head = sel->head();\n if (head && !head->is_empty_reference())\n {\n Expression_Ptr hh = head->perform(this);\n if (hh) l->append(hh);\n }\n\n std::string reference = ! sel->reference() ? 
\"\"\n : sel->reference()->to_string();\n switch(sel->combinator())\n {\n case Complex_Selector::PARENT_OF:\n l->append(SASS_MEMORY_NEW(String_Quoted, sel->pstate(), \">\"));\n break;\n case Complex_Selector::ADJACENT_TO:\n l->append(SASS_MEMORY_NEW(String_Quoted, sel->pstate(), \"+\"));\n break;\n case Complex_Selector::REFERENCE:\n l->append(SASS_MEMORY_NEW(String_Quoted, sel->pstate(), \"/\" + reference + \"/\"));\n break;\n case Complex_Selector::PRECEDES:\n l->append(SASS_MEMORY_NEW(String_Quoted, sel->pstate(), \"~\"));\n break;\n case Complex_Selector::ANCESTOR_OF:\n break;\n default: break;\n }\n\n Complex_Selector_Obj tail = sel->tail();\n if (tail)\n {\n Expression_Obj tt = tail->perform(this);\n if (List_Ptr ls = Cast(tt))\n { l->concat(ls); }\n }\n if (l->length() == 0) return 0;\n return l.detach();\n }\n\n Expression_Ptr Listize::fallback_impl(AST_Node_Ptr n)\n {\n return Cast(n);\n }\n\n}\n\n// Path: src/libsass/src/source_map.cpp\n#include \"sass.hpp\"\n#include \n#include \n#include \n#include \n\n#include \"ast.hpp\"\n#include \"json.hpp\"\n#include \"context.hpp\"\n#include \"position.hpp\"\n#include \"source_map.hpp\"\n\nnamespace Sass {\n SourceMap::SourceMap() : current_position(0, 0, 0), file(\"stdin\") { }\n SourceMap::SourceMap(const std::string& file) : current_position(0, 0, 0), file(file) { }\n\n std::string SourceMap::render_srcmap(Context &ctx) {\n\n const bool include_sources = ctx.c_options.source_map_contents;\n const std::vector links = ctx.srcmap_links;\n const std::vector& sources(ctx.resources);\n\n JsonNode* json_srcmap = json_mkobject();\n\n json_append_member(json_srcmap, \"version\", json_mknumber(3));\n\n const char *file_name = file.c_str();\n JsonNode *json_file_name = json_mkstring(file_name);\n json_append_member(json_srcmap, \"file\", json_file_name);\n\n // pass-through sourceRoot option\n if (!ctx.source_map_root.empty()) {\n JsonNode* root = json_mkstring(ctx.source_map_root.c_str());\n json_append_member(json_srcmap, \"sourceRoot\", root);\n }\n\n JsonNode *json_sources = json_mkarray();\n for (size_t i = 0; i < source_index.size(); ++i) {\n std::string source(links[source_index[i]]);\n if (ctx.c_options.source_map_file_urls) {\n source = File::rel2abs(source);\n // check for windows abs path\n if (source[0] == '/') {\n // ends up with three slashes\n source = \"file://\" + source;\n } else {\n // needs an additional slash\n source = \"file:///\" + source;\n }\n }\n const char* source_name = source.c_str();\n JsonNode *json_source_name = json_mkstring(source_name);\n json_append_element(json_sources, json_source_name);\n }\n json_append_member(json_srcmap, \"sources\", json_sources);\n\n if (include_sources && source_index.size()) {\n JsonNode *json_contents = json_mkarray();\n for (size_t i = 0; i < source_index.size(); ++i) {\n const Resource& resource(sources[source_index[i]]);\n JsonNode *json_content = json_mkstring(resource.contents);\n json_append_element(json_contents, json_content);\n }\n json_append_member(json_srcmap, \"sourcesContent\", json_contents);\n }\n\n JsonNode *json_names = json_mkarray();\n // so far we have no implementation for names\n // no problem as we do not alter any identifiers\n json_append_member(json_srcmap, \"names\", json_names);\n\n std::string mappings = serialize_mappings();\n JsonNode *json_mappings = json_mkstring(mappings.c_str());\n json_append_member(json_srcmap, \"mappings\", json_mappings);\n\n char *str = json_stringify(json_srcmap, \"\\t\");\n std::string result = std::string(str);\n 
free(str);\n json_delete(json_srcmap);\n return result;\n }\n\n std::string SourceMap::serialize_mappings() {\n std::string result = \"\";\n\n size_t previous_generated_line = 0;\n size_t previous_generated_column = 0;\n size_t previous_original_line = 0;\n size_t previous_original_column = 0;\n size_t previous_original_file = 0;\n for (size_t i = 0; i < mappings.size(); ++i) {\n const size_t generated_line = mappings[i].generated_position.line;\n const size_t generated_column = mappings[i].generated_position.column;\n const size_t original_line = mappings[i].original_position.line;\n const size_t original_column = mappings[i].original_position.column;\n const size_t original_file = mappings[i].original_position.file;\n\n if (generated_line != previous_generated_line) {\n previous_generated_column = 0;\n if (generated_line > previous_generated_line) {\n result += std::string(generated_line - previous_generated_line, ';');\n previous_generated_line = generated_line;\n }\n }\n else if (i > 0) {\n result += \",\";\n }\n\n // generated column\n result += base64vlq.encode(static_cast(generated_column) - static_cast(previous_generated_column));\n previous_generated_column = generated_column;\n // file\n result += base64vlq.encode(static_cast(original_file) - static_cast(previous_original_file));\n previous_original_file = original_file;\n // source line\n result += base64vlq.encode(static_cast(original_line) - static_cast(previous_original_line));\n previous_original_line = original_line;\n // source column\n result += base64vlq.encode(static_cast(original_column) - static_cast(previous_original_column));\n previous_original_column = original_column;\n }\n\n return result;\n }\n\n void SourceMap::prepend(const OutputBuffer& out)\n {\n Offset size(out.smap.current_position);\n for (Mapping mapping : out.smap.mappings) {\n if (mapping.generated_position.line > size.line) {\n throw(std::runtime_error(\"prepend sourcemap has illegal line\"));\n }\n if (mapping.generated_position.line == size.line) {\n if (mapping.generated_position.column > size.column) {\n throw(std::runtime_error(\"prepend sourcemap has illegal column\"));\n }\n }\n }\n // adjust the buffer offset\n prepend(Offset(out.buffer));\n // now add the new mappings\n VECTOR_UNSHIFT(mappings, out.smap.mappings);\n }\n\n void SourceMap::append(const OutputBuffer& out)\n {\n append(Offset(out.buffer));\n }\n\n void SourceMap::prepend(const Offset& offset)\n {\n if (offset.line != 0 || offset.column != 0) {\n for (Mapping& mapping : mappings) {\n // move stuff on the first old line\n if (mapping.generated_position.line == 0) {\n mapping.generated_position.column += offset.column;\n }\n // make place for the new lines\n mapping.generated_position.line += offset.line;\n }\n }\n if (current_position.line == 0) {\n current_position.column += offset.column;\n }\n current_position.line += offset.line;\n }\n\n void SourceMap::append(const Offset& offset)\n {\n current_position += offset;\n }\n\n void SourceMap::add_open_mapping(const AST_Node_Ptr node)\n {\n mappings.push_back(Mapping(node->pstate(), current_position));\n }\n\n void SourceMap::add_close_mapping(const AST_Node_Ptr node)\n {\n mappings.push_back(Mapping(node->pstate() + node->pstate().offset, current_position));\n }\n\n ParserState SourceMap::remap(const ParserState& pstate) {\n for (size_t i = 0; i < mappings.size(); ++i) {\n if (\n mappings[i].generated_position.file == pstate.file &&\n mappings[i].generated_position.line == pstate.line &&\n mappings[i].generated_position.column == 
pstate.column\n ) return ParserState(pstate.path, pstate.src, mappings[i].original_position, pstate.offset);\n }\n return ParserState(pstate.path, pstate.src, Position(-1, -1, -1), Offset(0, 0));\n\n }\n\n}\n\n// Path: src/libsass/src/bind.cpp\n#include \"sass.hpp\"\n#include \"bind.hpp\"\n#include \"ast.hpp\"\n#include \"context.hpp\"\n#include \"expand.hpp\"\n#include \"eval.hpp\"\n#include \n#include \n#include \n\nnamespace Sass {\n\n void bind(std::string type, std::string name, Parameters_Obj ps, Arguments_Obj as, Context* ctx, Env* env, Eval* eval)\n {\n std::string callee(type + \" \" + name);\n\n std::map param_map;\n List_Obj varargs = SASS_MEMORY_NEW(List, as->pstate());\n varargs->is_arglist(true); // enable keyword size handling\n\n for (size_t i = 0, L = as->length(); i < L; ++i) {\n if (auto str = Cast((*as)[i]->value())) {\n // force optional quotes (only if needed)\n if (str->quote_mark()) {\n str->quote_mark('*');\n }\n }\n }\n\n // Set up a map to ensure named arguments refer to actual parameters. Also\n // eval each default value left-to-right, wrt env, populating env as we go.\n for (size_t i = 0, L = ps->length(); i < L; ++i) {\n Parameter_Obj p = ps->at(i);\n param_map[p->name()] = p;\n // if (p->default_value()) {\n // env->local_frame()[p->name()] = p->default_value()->perform(eval->with(env));\n // }\n }\n\n // plug in all args; if we have leftover params, deal with it later\n size_t ip = 0, LP = ps->length();\n size_t ia = 0, LA = as->length();\n while (ia < LA) {\n Argument_Obj a = as->at(ia);\n if (ip >= LP) {\n // skip empty rest arguments\n if (a->is_rest_argument()) {\n if (List_Obj l = Cast(a->value())) {\n if (l->length() == 0) {\n ++ ia; continue;\n }\n }\n }\n std::stringstream msg;\n msg << \"wrong number of arguments (\" << LA << \" for \" << LP << \")\";\n msg << \" for `\" << name << \"'\";\n return error(msg.str(), as->pstate(), eval->exp.traces);\n }\n Parameter_Obj p = ps->at(ip);\n\n // If the current parameter is the rest parameter, process and break the loop\n if (p->is_rest_parameter()) {\n // The next argument by coincidence provides a rest argument\n if (a->is_rest_argument()) {\n\n // We should always get a list for rest arguments\n if (List_Obj rest = Cast(a->value())) {\n // create a new list object for wrapped items\n List_Ptr arglist = SASS_MEMORY_NEW(List,\n p->pstate(),\n 0,\n rest->separator(),\n true);\n // wrap each item from list as an argument\n for (Expression_Obj item : rest->elements()) {\n if (Argument_Obj arg = Cast(item)) {\n arglist->append(SASS_MEMORY_COPY(arg)); // copy\n } else {\n arglist->append(SASS_MEMORY_NEW(Argument,\n item->pstate(),\n item,\n \"\",\n false,\n false));\n }\n }\n // assign new arglist to environment\n env->local_frame()[p->name()] = arglist;\n }\n // invalid state\n else {\n throw std::runtime_error(\"invalid state\");\n }\n } else if (a->is_keyword_argument()) {\n\n // expand keyword arguments into their parameters\n List_Ptr arglist = SASS_MEMORY_NEW(List, p->pstate(), 0, SASS_COMMA, true);\n env->local_frame()[p->name()] = arglist;\n Map_Obj argmap = Cast(a->value());\n for (auto key : argmap->keys()) {\n if (String_Constant_Obj str = Cast(key)) {\n std::string param = unquote(str->value());\n arglist->append(SASS_MEMORY_NEW(Argument,\n key->pstate(),\n argmap->at(key),\n \"$\" + param,\n false,\n false));\n } else {\n eval->exp.traces.push_back(Backtrace(key->pstate()));\n throw Exception::InvalidVarKwdType(key->pstate(), eval->exp.traces, key->inspect(), a);\n }\n }\n\n } else {\n\n // create 
a new list object for wrapped items\n List_Obj arglist = SASS_MEMORY_NEW(List,\n p->pstate(),\n 0,\n SASS_COMMA,\n true);\n // consume the next args\n while (ia < LA) {\n // get and post inc\n a = (*as)[ia++];\n // maybe we have another list as argument\n List_Obj ls = Cast(a->value());\n // skip any list completely if empty\n if (ls && ls->empty() && a->is_rest_argument()) continue;\n\n Expression_Obj value = a->value();\n if (Argument_Obj arg = Cast(value)) {\n arglist->append(arg);\n }\n // check if we have rest argument\n else if (a->is_rest_argument()) {\n // preserve the list separator from rest args\n if (List_Obj rest = Cast(a->value())) {\n arglist->separator(rest->separator());\n\n for (size_t i = 0, L = rest->length(); i < L; ++i) {\n Expression_Obj obj = rest->value_at_index(i);\n arglist->append(SASS_MEMORY_NEW(Argument,\n obj->pstate(),\n obj,\n \"\",\n false,\n false));\n }\n }\n // no more arguments\n break;\n }\n // wrap all other value types into Argument\n else {\n arglist->append(SASS_MEMORY_NEW(Argument,\n a->pstate(),\n a->value(),\n a->name(),\n false,\n false));\n }\n }\n // assign new arglist to environment\n env->local_frame()[p->name()] = arglist;\n }\n // consumed parameter\n ++ip;\n // no more paramaters\n break;\n }\n\n // If the current argument is the rest argument, extract a value for processing\n else if (a->is_rest_argument()) {\n // normal param and rest arg\n List_Obj arglist = Cast(a->value());\n if (!arglist) {\n if (Expression_Obj arg = Cast(a->value())) {\n arglist = SASS_MEMORY_NEW(List, a->pstate(), 1);\n arglist->append(arg);\n }\n }\n\n // empty rest arg - treat all args as default values\n if (!arglist || !arglist->length()) {\n break;\n } else {\n if (arglist->length() > LP - ip && !ps->has_rest_parameter()) {\n size_t arg_count = (arglist->length() + LA - 1);\n std::stringstream msg;\n msg << callee << \" takes \" << LP;\n msg << (LP == 1 ? \" argument\" : \" arguments\");\n msg << \" but \" << arg_count;\n msg << (arg_count == 1 ? 
\" was passed\" : \" were passed.\");\n deprecated_bind(msg.str(), as->pstate());\n\n while (arglist->length() > LP - ip) {\n arglist->elements().erase(arglist->elements().end() - 1);\n }\n }\n }\n // otherwise move one of the rest args into the param, converting to argument if necessary\n Expression_Obj obj = arglist->at(0);\n if (!(a = Cast(obj))) {\n Expression_Ptr a_to_convert = obj;\n a = SASS_MEMORY_NEW(Argument,\n a_to_convert->pstate(),\n a_to_convert,\n \"\",\n false,\n false);\n }\n arglist->elements().erase(arglist->elements().begin());\n if (!arglist->length() || (!arglist->is_arglist() && ip + 1 == LP)) {\n ++ia;\n }\n\n } else if (a->is_keyword_argument()) {\n Map_Obj argmap = Cast(a->value());\n\n for (auto key : argmap->keys()) {\n String_Constant_Ptr val = Cast(key);\n if (val == NULL) {\n eval->exp.traces.push_back(Backtrace(key->pstate()));\n throw Exception::InvalidVarKwdType(key->pstate(), eval->exp.traces, key->inspect(), a);\n }\n std::string param = \"$\" + unquote(val->value());\n\n if (!param_map.count(param)) {\n std::stringstream msg;\n msg << callee << \" has no parameter named \" << param;\n error(msg.str(), a->pstate(), eval->exp.traces);\n }\n env->local_frame()[param] = argmap->at(key);\n }\n ++ia;\n continue;\n } else {\n ++ia;\n }\n\n if (a->name().empty()) {\n if (env->has_local(p->name())) {\n std::stringstream msg;\n msg << \"parameter \" << p->name()\n << \" provided more than once in call to \" << callee;\n error(msg.str(), a->pstate(), eval->exp.traces);\n }\n // ordinal arg -- bind it to the next param\n env->local_frame()[p->name()] = a->value();\n ++ip;\n }\n else {\n // named arg -- bind it to the appropriately named param\n if (!param_map.count(a->name())) {\n if (ps->has_rest_parameter()) {\n varargs->append(a);\n } else {\n std::stringstream msg;\n msg << callee << \" has no parameter named \" << a->name();\n error(msg.str(), a->pstate(), eval->exp.traces);\n }\n }\n if (param_map[a->name()]) {\n if (param_map[a->name()]->is_rest_parameter()) {\n std::stringstream msg;\n msg << \"argument \" << a->name() << \" of \" << callee\n << \"cannot be used as named argument\";\n error(msg.str(), a->pstate(), eval->exp.traces);\n }\n }\n if (env->has_local(a->name())) {\n std::stringstream msg;\n msg << \"parameter \" << p->name()\n << \"provided more than once in call to \" << callee;\n error(msg.str(), a->pstate(), eval->exp.traces);\n }\n env->local_frame()[a->name()] = a->value();\n }\n }\n // EO while ia\n\n // If we make it here, we're out of args but may have leftover params.\n // That's only okay if they have default values, or were already bound by\n // named arguments, or if it's a single rest-param.\n for (size_t i = ip; i < LP; ++i) {\n Parameter_Obj leftover = ps->at(i);\n // cerr << \"env for default params:\" << endl;\n // env->print();\n // cerr << \"********\" << endl;\n if (!env->has_local(leftover->name())) {\n if (leftover->is_rest_parameter()) {\n env->local_frame()[leftover->name()] = varargs;\n }\n else if (leftover->default_value()) {\n Expression_Ptr dv = leftover->default_value()->perform(eval);\n env->local_frame()[leftover->name()] = dv;\n }\n else {\n // param is unbound and has no default value -- error\n throw Exception::MissingArgument(as->pstate(), eval->exp.traces, name, leftover->name(), type);\n }\n }\n }\n\n return;\n }\n\n\n}\n\n// Path: src/libsass/src/debugger.hpp\n#ifndef SASS_DEBUGGER_H\n#define SASS_DEBUGGER_H\n\n#include \n#include \n#include \"node.hpp\"\n#include \"ast_fwd_decl.hpp\"\n\nusing namespace 
Sass;\n\ninline void debug_ast(AST_Node_Ptr node, std::string ind = \"\", Env* env = 0);\n\ninline void debug_ast(const AST_Node* node, std::string ind = \"\", Env* env = 0) {\n debug_ast(const_cast(node), ind, env);\n}\n\n\ninline void debug_sources_set(ComplexSelectorSet& set, std::string ind = \"\")\n{\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n for(auto const &pair : set) {\n debug_ast(pair, ind + \"\");\n // debug_ast(set[pair], ind + \"first: \");\n }\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n}\n\ninline std::string str_replace(std::string str, const std::string& oldStr, const std::string& newStr)\n{\n size_t pos = 0;\n while((pos = str.find(oldStr, pos)) != std::string::npos)\n {\n str.replace(pos, oldStr.length(), newStr);\n pos += newStr.length();\n }\n return str;\n}\n\ninline std::string prettyprint(const std::string& str) {\n std::string clean = str_replace(str, \"\\n\", \"\\\\n\");\n clean = str_replace(clean, \"\t\", \"\\\\t\");\n clean = str_replace(clean, \"\\r\", \"\\\\r\");\n return clean;\n}\n\ninline std::string longToHex(long long t) {\n std::stringstream is;\n is << std::hex << t;\n return is.str();\n}\n\ninline std::string pstate_source_position(AST_Node_Ptr node)\n{\n std::stringstream str;\n Position start(node->pstate());\n Position end(start + node->pstate().offset);\n str << (start.file == std::string::npos ? -1 : start.file)\n << \"@[\" << start.line << \":\" << start.column << \"]\"\n << \"-[\" << end.line << \":\" << end.column << \"]\";\n#ifdef DEBUG_SHARED_PTR\n str << \"x\" << node->getRefCount() << \"\"\n << \" \" << node->getDbgFile()\n << \"@\" << node->getDbgLine();\n#endif\n return str.str();\n}\n\ninline void debug_ast(AST_Node_Ptr node, std::string ind, Env* env)\n{\n if (node == 0) return;\n if (ind == \"\") std::cerr << \"####################################################################\\n\";\n if (Cast(node)) {\n Bubble_Ptr bubble = Cast(node);\n std::cerr << ind << \"Bubble \" << bubble;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << bubble->tabs();\n std::cerr << std::endl;\n debug_ast(bubble->node(), ind + \" \", env);\n } else if (Cast(node)) {\n Trace_Ptr trace = Cast(node);\n std::cerr << ind << \"Trace \" << trace;\n std::cerr << \" (\" << pstate_source_position(node) << \")\"\n << \" [name:\" << trace->name() << \", type: \" << trace->type() << \"]\"\n << std::endl;\n debug_ast(trace->block(), ind + \" \", env);\n } else if (Cast(node)) {\n At_Root_Block_Ptr root_block = Cast(node);\n std::cerr << ind << \"At_Root_Block \" << root_block;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" \" << root_block->tabs();\n std::cerr << std::endl;\n debug_ast(root_block->expression(), ind + \":\", env);\n debug_ast(root_block->block(), ind + \" \", env);\n } else if (Cast(node)) {\n Selector_List_Ptr selector = Cast(node);\n std::cerr << ind << \"Selector_List \" << selector;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" <\" << selector->hash() << \">\";\n std::cerr << \" [@media:\" << selector->media_block() << \"]\";\n std::cerr << (selector->is_invisible() ? \" [INVISIBLE]\": \" -\");\n std::cerr << (selector->has_placeholder() ? \" [PLACEHOLDER]\": \" -\");\n std::cerr << (selector->is_optional() ? \" [is_optional]\": \" -\");\n std::cerr << (selector->has_parent_ref() ? 
\" [has-parent]\": \" -\");\n std::cerr << (selector->has_line_break() ? \" [line-break]\": \" -\");\n std::cerr << (selector->has_line_feed() ? \" [line-feed]\": \" -\");\n std::cerr << std::endl;\n debug_ast(selector->schema(), ind + \"#{} \");\n\n for(const Complex_Selector_Obj& i : selector->elements()) { debug_ast(i, ind + \" \", env); }\n\n// } else if (Cast(node)) {\n// Expression_Ptr expression = Cast(node);\n// std::cerr << ind << \"Expression \" << expression << \" \" << expression->concrete_type() << std::endl;\n\n } else if (Cast(node)) {\n Parent_Selector_Ptr selector = Cast(node);\n std::cerr << ind << \"Parent_Selector \" << selector;\n// if (selector->not_selector()) cerr << \" [in_declaration]\";\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" <\" << selector->hash() << \">\";\n std::cerr << \" [\" << (selector->is_real_parent_ref() ? \"REAL\" : \"FAKE\") << \"]\";\n std::cerr << \" <\" << prettyprint(selector->pstate().token.ws_before()) << \">\" << std::endl;\n// debug_ast(selector->selector(), ind + \"->\", env);\n\n } else if (Cast(node)) {\n Complex_Selector_Ptr selector = Cast(node);\n std::cerr << ind << \"Complex_Selector \" << selector\n << \" (\" << pstate_source_position(node) << \")\"\n << \" <\" << selector->hash() << \">\"\n << \" [length:\" << longToHex(selector->length()) << \"]\"\n << \" [weight:\" << longToHex(selector->specificity()) << \"]\"\n << \" [@media:\" << selector->media_block() << \"]\"\n << (selector->is_invisible() ? \" [INVISIBLE]\": \" -\")\n << (selector->has_placeholder() ? \" [PLACEHOLDER]\": \" -\")\n << (selector->is_optional() ? \" [is_optional]\": \" -\")\n << (selector->has_parent_ref() ? \" [has parent]\": \" -\")\n << (selector->has_line_feed() ? \" [line-feed]\": \" -\")\n << (selector->has_line_break() ? \" [line-break]\": \" -\")\n << \" -- \";\n std::string del;\n switch (selector->combinator()) {\n case Complex_Selector::PARENT_OF: del = \">\"; break;\n case Complex_Selector::PRECEDES: del = \"~\"; break;\n case Complex_Selector::ADJACENT_TO: del = \"+\"; break;\n case Complex_Selector::ANCESTOR_OF: del = \" \"; break;\n case Complex_Selector::REFERENCE: del = \"//\"; break;\n }\n // if (del = \"/\") del += selector->reference()->perform(&to_string) + \"/\";\n std::cerr << \" <\" << prettyprint(selector->pstate().token.ws_before()) << \">\" << std::endl;\n debug_ast(selector->head(), ind + \" \" /* + \"[\" + del + \"]\" */, env);\n if (selector->tail()) {\n debug_ast(selector->tail(), ind + \"{\" + del + \"}\", env);\n } else if(del != \" \") {\n std::cerr << ind << \" |\" << del << \"| {trailing op}\" << std::endl;\n }\n ComplexSelectorSet set = selector->sources();\n // debug_sources_set(set, ind + \" @--> \");\n } else if (Cast(node)) {\n Compound_Selector_Ptr selector = Cast(node);\n std::cerr << ind << \"Compound_Selector \" << selector;\n std::cerr << \" (\" << pstate_source_position(node) << \")\";\n std::cerr << \" <\" << selector->hash() << \">\";\n std::cerr << \" [weight:\" << longToHex(selector->specificity()) << \"]\";\n std::cerr << \" [@media:\" << selector->media_block() << \"]\";\n std::cerr << (selector->extended() ? \" [extended]\": \" -\");\n std::cerr << (selector->is_optional() ? \" [is_optional]\": \" -\");\n std::cerr << (selector->has_parent_ref() ? \" [has-parent]\": \" -\");\n std::cerr << (selector->has_line_break() ? \" [line-break]\": \" -\");\n std::cerr << (selector->has_line_feed() ? 
\" [line-feed]\": \" -\");\n std::cerr << \" <\" << prettyprint(selector->pstate().token.ws_before()) << \">\" << std::endl;\n for(const Simple_Selector_Obj& i : selector->elements()) { debug_ast(i, ind + \" \", env); }\n } else if (Cast(node)) {\n Wrapped_Selector_Ptr selector = Cast(node);\n...\n", "needle_token_start": 13682, "needle_token_end": 13804, "code_context_ntokens": 16192, "output": ["```cpp\ninline void debug_sources_set(ComplexSelectorSet& set, std::string ind = \"\")\n{\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n for(auto const &pair : set) {\n debug_ast(pair, ind + \"\");\n // debug_ast(set[pair], ind + \"first: \");\n }\n if (ind == \"\") std::cerr << \"#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\\n\";\n}\n```"]} +{"repo": "sass/node-sass", "name": "is_alnum", "language": "cpp", "path": "src/libsass/src/lexer.cpp", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To determine if a character is either a letter or a digit.\n2. **Input**: A single character.\n3. **Output**: A boolean value; `true` if the character is a letter or a digit, otherwise `false`.\n4. **Procedure**: The function checks if the character is a letter by calling another function designed to identify letters, or if it is a digit by calling a function designed to identify digits. If either condition is met, it returns `true`; otherwise, it returns `false`.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/libsass/src/lexer.hpp\n#ifndef SASS_LEXER_H\n#define SASS_LEXER_H\n\n#include \n\nnamespace Sass {\n namespace Prelexer {\n\n //####################################\n // BASIC CHARACTER MATCHERS\n //####################################\n\n // Match standard control chars\n const char* kwd_at(const char* src);\n const char* kwd_dot(const char* src);\n const char* kwd_comma(const char* src);\n const char* kwd_colon(const char* src);\n const char* kwd_star(const char* src);\n const char* kwd_plus(const char* src);\n const char* kwd_minus(const char* src);\n const char* kwd_slash(const char* src);\n\n //####################################\n // BASIC CLASS MATCHERS\n //####################################\n\n // These are locale independant\n bool is_space(const char& src);\n bool is_alpha(const char& src);\n bool is_punct(const char& src);\n...\n// Path: src/libsass/src/prelexer.hpp\n#ifndef SASS_PRELEXER_H\n#define SASS_PRELEXER_H\n\n#include \n#include \"lexer.hpp\"\n\nnamespace Sass {\n // using namespace Lexer;\n namespace Prelexer {\n\n //####################################\n // KEYWORD \"REGEX\" MATCHERS\n //####################################\n\n // Match Sass boolean keywords.\n const char* kwd_true(const char* src);\n const char* kwd_false(const char* src);\n const char* kwd_only(const char* src);\n const char* kwd_and(const char* src);\n const char* kwd_or(const char* src);\n const char* kwd_not(const char* src);\n const char* kwd_eq(const char* src);\n const char* kwd_neq(const char* src);\n const char* kwd_gt(const char* src);\n const char* kwd_gte(const char* src);\n const char* kwd_lt(const char* src);\n const char* kwd_lte(const char* src);\n\n // Match standard control chars\n const char* kwd_at(const char* src);\n const char* 
kwd_dot(const char* src);\n const char* kwd_comma(const char* src);\n const char* kwd_colon(const char* src);\n const char* kwd_slash(const char* src);\n const char* kwd_star(const char* src);\n const char* kwd_plus(const char* src);\n const char* kwd_minus(const char* src);\n\n //####################################\n // SPECIAL \"REGEX\" CONSTRUCTS\n //####################################\n\n // Match a sequence of characters delimited by the supplied chars.\n template \n const char* delimited_by(const char* src) {\n src = exactly(src);\n if (!src) return 0;\n const char* stop;\n while (true) {\n if (!*src) return 0;\n stop = exactly(src);\n if (stop && (!esc || *(src - 1) != '\\\\')) return stop;\n src = stop ? stop : src + 1;\n }\n }\n\n // skip to delimiter (mx) inside given range\n // this will savely skip over all quoted strings\n // recursive skip stuff delimited by start/stop\n // first start/opener must be consumed already!\n template\n const char* skip_over_scopes(const char* src, const char* end) {\n\n size_t level = 0;\n bool in_squote = false;\n bool in_dquote = false;\n // bool in_braces = false;\n\n while (*src) {\n\n // check for abort condition\n if (end && src >= end) break;\n\n // has escaped sequence?\n if (*src == '\\\\') {\n ++ src; // skip this (and next)\n }\n else if (*src == '\"') {\n in_dquote = ! in_dquote;\n }\n else if (*src == '\\'') {\n in_squote = ! in_squote;\n }\n else if (in_dquote || in_squote) {\n // take everything literally\n }\n\n // find another opener inside?\n else if (const char* pos = start(src)) {\n ++ level; // increase counter\n src = pos - 1; // advance position\n }\n\n // look for the closer (maybe final, maybe not)\n else if (const char* final = stop(src)) {\n // only close one level?\n if (level > 0) -- level;\n // return position at end of stop\n // delimiter may be multiple chars\n else return final;\n // advance position\n src = final - 1;\n }\n\n // next\n ++ src;\n }\n\n return 0;\n }\n\n // skip to a skip delimited by parentheses\n // uses smart `skip_over_scopes` internally\n const char* parenthese_scope(const char* src);\n\n // skip to delimiter (mx) inside given range\n // this will savely skip over all quoted strings\n // recursive skip stuff delimited by start/stop\n // first start/opener must be consumed already!\n template\n const char* skip_over_scopes(const char* src) {\n return skip_over_scopes(src, 0);\n }\n\n // Match a sequence of characters delimited by the supplied chars.\n template \n const char* recursive_scopes(const char* src) {\n // parse opener\n src = start(src);\n // abort if not found\n if (!src) return 0;\n // parse the rest until final closer\n return skip_over_scopes(src);\n }\n\n // Match a sequence of characters delimited by the supplied strings.\n template \n const char* delimited_by(const char* src) {\n src = exactly(src);\n if (!src) return 0;\n const char* stop;\n while (true) {\n if (!*src) return 0;\n stop = exactly(src);\n if (stop && (!esc || *(src - 1) != '\\\\')) return stop;\n src = stop ? 
stop : src + 1;\n }\n }\n\n // Tries to match a certain number of times (between the supplied interval).\n template\n const char* between(const char* src) {\n for (size_t i = 0; i < lo; ++i) {\n src = mx(src);\n if (!src) return 0;\n }\n for (size_t i = lo; i <= hi; ++i) {\n const char* new_src = mx(src);\n if (!new_src) return src;\n src = new_src;\n }\n return src;\n }\n\n // equivalent of STRING_REGULAR_EXPRESSIONS\n const char* re_string_double_open(const char* src);\n const char* re_string_double_close(const char* src);\n const char* re_string_single_open(const char* src);\n const char* re_string_single_close(const char* src);\n const char* re_string_uri_open(const char* src);\n const char* re_string_uri_close(const char* src);\n\n // Match a line comment.\n const char* line_comment(const char* src);\n\n // Match a block comment.\n const char* block_comment(const char* src);\n // Match either.\n const char* comment(const char* src);\n // Match double- and single-quoted strings.\n const char* double_quoted_string(const char* src);\n const char* single_quoted_string(const char* src);\n const char* quoted_string(const char* src);\n // Match interpolants.\n const char* interpolant(const char* src);\n // Match number prefix ([\\+\\-]+)\n const char* number_prefix(const char* src);\n\n // Match zero plus white-space or line_comments\n const char* optional_css_whitespace(const char* src);\n const char* css_whitespace(const char* src);\n // Match optional_css_whitepace plus block_comments\n const char* optional_css_comments(const char* src);\n const char* css_comments(const char* src);\n\n // Match one backslash escaped char\n const char* escape_seq(const char* src);\n\n // Match CSS css variables.\n const char* custom_property_name(const char* src);\n // Match a CSS identifier.\n const char* identifier(const char* src);\n const char* identifier_alpha(const char* src);\n const char* identifier_alnum(const char* src);\n const char* strict_identifier(const char* src);\n const char* strict_identifier_alpha(const char* src);\n const char* strict_identifier_alnum(const char* src);\n // Match a CSS unit identifier.\n const char* one_unit(const char* src);\n const char* multiple_units(const char* src);\n const char* unit_identifier(const char* src);\n // const char* strict_identifier_alnums(const char* src);\n // Match reference selector.\n const char* re_reference_combinator(const char* src);\n const char* static_reference_combinator(const char* src);\n const char* schema_reference_combinator(const char* src);\n\n // Match interpolant schemas\n const char* identifier_schema(const char* src);\n const char* value_schema(const char* src);\n const char* sass_value(const char* src);\n // const char* filename(const char* src);\n // const char* filename_schema(const char* src);\n // const char* url_schema(const char* src);\n // const char* url_value(const char* src);\n const char* vendor_prefix(const char* src);\n\n const char* re_special_directive(const char* src);\n const char* re_prefixed_directive(const char* src);\n const char* re_almost_any_value_token(const char* src);\n\n // Match CSS '@' keywords.\n const char* at_keyword(const char* src);\n const char* kwd_import(const char* src);\n const char* kwd_at_root(const char* src);\n const char* kwd_with_directive(const char* src);\n const char* kwd_without_directive(const char* src);\n const char* kwd_media(const char* src);\n const char* kwd_supports_directive(const char* src);\n // const char* keyframes(const char* src);\n // const char* keyf(const 
char* src);\n const char* kwd_mixin(const char* src);\n const char* kwd_function(const char* src);\n const char* kwd_return_directive(const char* src);\n const char* kwd_include_directive(const char* src);\n const char* kwd_content_directive(const char* src);\n const char* kwd_charset_directive(const char* src);\n const char* kwd_extend(const char* src);\n\n const char* unicode_seq(const char* src);\n\n const char* kwd_if_directive(const char* src);\n const char* kwd_else_directive(const char* src);\n const char* elseif_directive(const char* src);\n\n const char* kwd_for_directive(const char* src);\n const char* kwd_from(const char* src);\n const char* kwd_to(const char* src);\n const char* kwd_through(const char* src);\n\n const char* kwd_each_directive(const char* src);\n const char* kwd_in(const char* src);\n\n const char* kwd_while_directive(const char* src);\n\n const char* re_nothing(const char* src);\n\n const char* re_special_fun(const char* src);\n\n const char* kwd_warn(const char* src);\n const char* kwd_err(const char* src);\n const char* kwd_dbg(const char* src);\n\n const char* kwd_null(const char* src);\n\n const char* re_selector_list(const char* src);\n const char* re_type_selector(const char* src);\n const char* re_static_expression(const char* src);\n\n // identifier that can start with hyphens\n const char* css_identifier(const char* src);\n const char* css_ip_identifier(const char* src);\n\n // Match CSS type selectors\n const char* namespace_schema(const char* src);\n const char* namespace_prefix(const char* src);\n const char* type_selector(const char* src);\n const char* hyphens_and_identifier(const char* src);\n const char* hyphens_and_name(const char* src);\n const char* universal(const char* src);\n // Match CSS id names.\n const char* id_name(const char* src);\n // Match CSS class names.\n const char* class_name(const char* src);\n // Attribute name in an attribute selector\n const char* attribute_name(const char* src);\n // Match placeholder selectors.\n const char* placeholder(const char* src);\n // Match CSS numeric constants.\n const char* op(const char* src);\n const char* sign(const char* src);\n const char* unsigned_number(const char* src);\n const char* number(const char* src);\n const char* coefficient(const char* src);\n const char* binomial(const char* src);\n const char* percentage(const char* src);\n const char* ampersand(const char* src);\n const char* dimension(const char* src);\n const char* hex(const char* src);\n const char* hexa(const char* src);\n const char* hex0(const char* src);\n // const char* rgb_prefix(const char* src);\n // Match CSS uri specifiers.\n const char* uri_prefix(const char* src);\n // Match CSS \"!important\" keyword.\n const char* kwd_important(const char* src);\n // Match CSS \"!optional\" keyword.\n const char* kwd_optional(const char* src);\n // Match Sass \"!default\" keyword.\n const char* default_flag(const char* src);\n const char* global_flag(const char* src);\n // Match CSS pseudo-class/element prefixes\n const char* pseudo_prefix(const char* src);\n // Match CSS function call openers.\n const char* re_functional(const char* src);\n const char* re_pseudo_selector(const char* src);\n const char* functional_schema(const char* src);\n const char* pseudo_not(const char* src);\n // Match CSS 'odd' and 'even' keywords for functional pseudo-classes.\n const char* even(const char* src);\n const char* odd(const char* src);\n // Match CSS attribute-matching operators.\n const char* exact_match(const char* src);\n const 
char* class_match(const char* src);\n const char* dash_match(const char* src);\n const char* prefix_match(const char* src);\n const char* suffix_match(const char* src);\n const char* substring_match(const char* src);\n // Match CSS combinators.\n // const char* adjacent_to(const char* src);\n // const char* precedes(const char* src);\n // const char* parent_of(const char* src);\n // const char* ancestor_of(const char* src);\n\n // Match SCSS variable names.\n const char* variable(const char* src);\n const char* calc_fn_call(const char* src);\n\n // IE stuff\n const char* ie_progid(const char* src);\n const char* ie_expression(const char* src);\n const char* ie_property(const char* src);\n const char* ie_keyword_arg(const char* src);\n const char* ie_keyword_arg_value(const char* src);\n const char* ie_keyword_arg_property(const char* src);\n\n // characters that terminate parsing of a list\n const char* list_terminator(const char* src);\n const char* space_list_terminator(const char* src);\n\n // match url()\n const char* H(const char* src);\n const char* W(const char* src);\n // `UNICODE` makes VS sad\n const char* UUNICODE(const char* src);\n const char* NONASCII(const char* src);\n const char* ESCAPE(const char* src);\n const char* real_uri(const char* src);\n const char* real_uri_suffix(const char* src);\n // const char* real_uri_prefix(const char* src);\n const char* real_uri_value(const char* src);\n\n // Path matching functions.\n // const char* folder(const char* src);\n // const char* folders(const char* src);\n\n\n const char* static_string(const char* src);\n const char* static_component(const char* src);\n const char* static_property(const char* src);\n const char* static_value(const char* src);\n\n const char* css_variable_value(const char* src);\n const char* css_variable_top_level_value(const char* src);\n\n // Utility functions for finding and counting characters in a string.\n template\n const char* find_first(const char* src) {\n while (*src && *src != c) ++src;\n return *src ? src : 0;\n }\n template\n const char* find_first(const char* src) {\n while (*src && !mx(src)) ++src;\n return *src ? src : 0;\n }\n template\n const char* find_first_in_interval(const char* beg, const char* end) {\n bool esc = false;\n while ((beg < end) && *beg) {\n if (esc) esc = false;\n else if (*beg == '\\\\') esc = true;\n else if (mx(beg)) return beg;\n ++beg;\n }\n return 0;\n }\n template\n const char* find_first_in_interval(const char* beg, const char* end) {\n bool esc = false;\n while ((beg < end) && *beg) {\n if (esc) esc = false;\n else if (*beg == '\\\\') esc = true;\n else if (const char* pos = skip(beg)) beg = pos;\n else if (mx(beg)) return beg;\n ++beg;\n }\n return 0;\n }\n template \n unsigned int count_interval(const char* beg, const char* end) {\n unsigned int counter = 0;\n bool esc = false;\n while (beg < end && *beg) {\n const char* p;\n if (esc) {\n esc = false;\n ++beg;\n } else if (*beg == '\\\\') {\n esc = true;\n ++beg;\n } else if ((p = mx(beg))) {\n ++counter;\n beg = p;\n }\n else {\n ++beg;\n }\n }\n return counter;\n }\n\n template \n const char* padded_token(const char* src)\n {\n size_t got = 0;\n const char* pos = src;\n while (got < size) {\n if (!mx(pos)) break;\n ++ pos; ++ got;\n }\n while (got < size) {\n if (!pad(pos)) break;\n ++ pos; ++ got;\n }\n return got ? 
pos : 0;\n }\n\n template \n const char* minmax_range(const char* src)\n {\n size_t got = 0;\n const char* pos = src;\n while (got < max) {\n if (!mx(pos)) break;\n ++ pos; ++ got;\n }\n if (got < min) return 0;\n if (got > max) return 0;\n return pos;\n }\n\n template \n const char* char_range(const char* src)\n {\n if (*src < min) return 0;\n if (*src > max) return 0;\n return src + 1;\n }\n\n }\n}\n\n#endif\n\n// Path: src/libsass/src/parser.hpp\n#ifndef SASS_PARSER_H\n#define SASS_PARSER_H\n\n#include \n#include \n\n#include \"ast.hpp\"\n#include \"position.hpp\"\n#include \"context.hpp\"\n#include \"position.hpp\"\n#include \"prelexer.hpp\"\n\n#ifndef MAX_NESTING\n// Note that this limit is not an exact science\n// it depends on various factors, which some are\n// not under our control (compile time or even OS\n// dependent settings on the available stack size)\n// It should fix most common segfault cases though.\n#define MAX_NESTING 512\n#endif\n\nstruct Lookahead {\n const char* found;\n const char* error;\n const char* position;\n bool parsable;\n bool has_interpolants;\n bool is_custom_property;\n};\n\nnamespace Sass {\n\n class Parser : public ParserState {\n public:\n\n enum Scope { Root, Mixin, Function, Media, Control, Properties, Rules, AtRoot };\n\n Context& ctx;\n std::vector block_stack;\n std::vector stack;\n Media_Block_Ptr last_media_block;\n const char* source;\n const char* position;\n const char* end;\n Position before_token;\n Position after_token;\n ParserState pstate;\n Backtraces traces;\n size_t indentation;\n size_t nestings;\n\n Token lexed;\n\n Parser(Context& ctx, const ParserState& pstate, Backtraces traces)\n : ParserState(pstate), ctx(ctx), block_stack(), stack(0), last_media_block(),\n source(0), position(0), end(0), before_token(pstate), after_token(pstate),\n pstate(pstate), traces(traces), indentation(0), nestings(0)\n { \n stack.push_back(Scope::Root);\n }\n\n // static Parser from_string(const std::string& src, Context& ctx, ParserState pstate = ParserState(\"[STRING]\"));\n static Parser from_c_str(const char* src, Context& ctx, Backtraces, ParserState pstate = ParserState(\"[CSTRING]\"), const char* source = 0);\n static Parser from_c_str(const char* beg, const char* end, Context& ctx, Backtraces, ParserState pstate = ParserState(\"[CSTRING]\"), const char* source = 0);\n static Parser from_token(Token t, Context& ctx, Backtraces, ParserState pstate = ParserState(\"[TOKEN]\"), const char* source = 0);\n // special static parsers to convert strings into certain selectors\n static Selector_List_Obj parse_selector(const char* src, Context& ctx, Backtraces, ParserState pstate = ParserState(\"[SELECTOR]\"), const char* source = 0);\n\n#ifdef __clang__\n\n // lex and peak uses the template parameter to branch on the action, which\n // triggers clangs tautological comparison on the single-comparison\n // branches. This is not a bug, just a merging of behaviour into\n // one function\n\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wtautological-compare\"\n\n#endif\n\n\n // skip current token and next whitespace\n // moves ParserState right before next token\n void advanceToNextToken();\n\n bool peek_newline(const char* start = 0);\n\n // skip over spaces, tabs and line comments\n template \n const char* sneak(const char* start = 0)\n {\n using namespace Prelexer;\n\n // maybe use optional start position from arguments?\n const char* it_position = start ? 
start : position;\n\n // skip white-space?\n if (mx == spaces ||\n mx == no_spaces ||\n mx == css_comments ||\n mx == css_whitespace ||\n mx == optional_spaces ||\n mx == optional_css_comments ||\n mx == optional_css_whitespace\n ) {\n return it_position;\n }\n\n // skip over spaces, tabs and sass line comments\n const char* pos = optional_css_whitespace(it_position);\n // always return a valid position\n return pos ? pos : it_position;\n\n }\n\n // match will not skip over space, tabs and line comment\n // return the position where the lexer match will occur\n template \n const char* match(const char* start = 0)\n {\n // match the given prelexer\n return mx(position);\n }\n\n // peek will only skip over space, tabs and line comment\n // return the position where the lexer match will occur\n template \n const char* peek(const char* start = 0)\n {\n\n // sneak up to the actual token we want to lex\n // this should skip over white-space if desired\n const char* it_before_token = sneak < mx >(start);\n\n // match the given prelexer\n const char* match = mx(it_before_token);\n\n // check if match is in valid range\n return match <= end ? match : 0;\n\n }\n\n // white-space handling is built into the lexer\n // this way you do not need to parse it yourself\n // some matchers don't accept certain white-space\n // we do not support start arg, since we manipulate\n // sourcemap offset and we modify the position pointer!\n // lex will only skip over space, tabs and line comment\n template \n const char* lex(bool lazy = true, bool force = false)\n {\n\n if (*position == 0) return 0;\n\n // position considered before lexed token\n // we can skip whitespace or comments for\n // lazy developers (but we need control)\n const char* it_before_token = position;\n\n // sneak up to the actual token we want to lex\n // this should skip over white-space if desired\n if (lazy) it_before_token = sneak < mx >(position);\n\n // now call matcher to get position after token\n const char* it_after_token = mx(it_before_token);\n\n // check if match is in valid range\n if (it_after_token > end) return 0;\n\n // maybe we want to update the parser state anyway?\n if (force == false) {\n // assertion that we got a valid match\n if (it_after_token == 0) return 0;\n // assertion that we actually lexed something\n if (it_after_token == it_before_token) return 0;\n }\n\n // create new lexed token object (holds the parse results)\n lexed = Token(position, it_before_token, it_after_token);\n\n // advance position (add whitespace before current token)\n before_token = after_token.add(position, it_before_token);\n\n // update after_token position for current token\n after_token.add(it_before_token, it_after_token);\n\n // ToDo: could probably do this incremetal on original object (API wants offset?)\n pstate = ParserState(path, source, lexed, before_token, after_token - before_token);\n\n // advance internal char iterator\n return position = it_after_token;\n\n }\n\n // lex_css skips over space, tabs, line and block comment\n // all block comments will be consumed and thrown away\n // source-map position will point to token after the comment\n template \n const char* lex_css()\n {\n // copy old token\n Token prev = lexed;\n // store previous pointer\n const char* oldpos = position;\n Position bt = before_token;\n Position at = after_token;\n ParserState op = pstate;\n // throw away comments\n // update srcmap position\n lex < Prelexer::css_comments >();\n // now lex a new token\n const char* pos = lex< mx >();\n // maybe restore 
prev state\n if (pos == 0) {\n pstate = op;\n lexed = prev;\n position = oldpos;\n after_token = at;\n before_token = bt;\n }\n // return match\n return pos;\n }\n\n // all block comments will be skipped and thrown away\n template \n const char* peek_css(const char* start = 0)\n {\n // now peek a token (skip comments first)\n return peek< mx >(peek < Prelexer::css_comments >(start));\n }\n\n#ifdef __clang__\n\n#pragma clang diagnostic pop\n\n#endif\n\n void error(std::string msg);\n void error(std::string msg, Position pos);\n // generate message with given and expected sample\n // text before and in the middle are configurable\n void css_error(const std::string& msg,\n const std::string& prefix = \" after \",\n const std::string& middle = \", was: \",\n const bool trim = true);\n void read_bom();\n\n Block_Obj parse();\n Import_Obj parse_import();\n Definition_Obj parse_definition(Definition::Type which_type);\n Parameters_Obj parse_parameters();\n Parameter_Obj parse_parameter();\n Mixin_Call_Obj parse_include_directive();\n Arguments_Obj parse_arguments();\n Argument_Obj parse_argument();\n Assignment_Obj parse_assignment();\n Ruleset_Obj parse_ruleset(Lookahead lookahead);\n Selector_List_Obj parse_selector_list(bool chroot);\n Complex_Selector_Obj parse_complex_selector(bool chroot);\n Selector_Schema_Obj parse_selector_schema(const char* end_of_selector, bool chroot);\n Compound_Selector_Obj parse_compound_selector();\n Simple_Selector_Obj parse_simple_selector();\n Wrapped_Selector_Obj parse_negated_selector();\n Simple_Selector_Obj parse_pseudo_selector();\n Attribute_Selector_Obj parse_attribute_selector();\n Block_Obj parse_block(bool is_root = false);\n Block_Obj parse_css_block(bool is_root = false);\n bool parse_block_nodes(bool is_root = false);\n bool parse_block_node(bool is_root = false);\n\n bool parse_number_prefix();\n Declaration_Obj parse_declaration();\n Expression_Obj parse_map();\n Expression_Obj parse_bracket_list();\n Expression_Obj parse_list(bool delayed = false);\n Expression_Obj parse_comma_list(bool delayed = false);\n Expression_Obj parse_space_list();\n Expression_Obj parse_disjunction();\n Expression_Obj parse_conjunction();\n Expression_Obj parse_relation();\n Expression_Obj parse_expression();\n Expression_Obj parse_operators();\n Expression_Obj parse_factor();\n Expression_Obj parse_value();\n Function_Call_Obj parse_calc_function();\n Function_Call_Obj parse_function_call();\n Function_Call_Schema_Obj parse_function_call_schema();\n String_Obj parse_url_function_string();\n String_Obj parse_url_function_argument();\n String_Obj parse_interpolated_chunk(Token, bool constant = false, bool css = true);\n String_Obj parse_string();\n Value_Obj parse_static_value();\n String_Schema_Obj parse_css_variable_value(bool top_level = true);\n String_Schema_Obj parse_css_variable_value_token(bool top_level = true);\n String_Obj parse_ie_property();\n String_Obj parse_ie_keyword_arg();\n String_Schema_Obj parse_value_schema(const char* stop);\n String_Obj parse_identifier_schema();\n If_Obj parse_if_directive(bool else_if = false);\n For_Obj parse_for_directive();\n Each_Obj parse_each_directive();\n While_Obj parse_while_directive();\n Return_Obj parse_return_directive();\n Content_Obj parse_content_directive();\n void parse_charset_directive();\n Media_Block_Obj parse_media_block();\n List_Obj parse_media_queries();\n Media_Query_Obj parse_media_query();\n Media_Query_Expression_Obj parse_media_expression();\n Supports_Block_Obj parse_supports_directive();\n 
Supports_Condition_Obj parse_supports_condition();\n Supports_Condition_Obj parse_supports_negation();\n Supports_Condition_Obj parse_supports_operator();\n Supports_Condition_Obj parse_supports_interpolation();\n Supports_Condition_Obj parse_supports_declaration();\n Supports_Condition_Obj parse_supports_condition_in_parens();\n At_Root_Block_Obj parse_at_root_block();\n At_Root_Query_Obj parse_at_root_query();\n String_Schema_Obj parse_almost_any_value();\n Directive_Obj parse_special_directive();\n Directive_Obj parse_prefixed_directive();\n Directive_Obj parse_directive();\n Warning_Obj parse_warning();\n Error_Obj parse_error();\n Debug_Obj parse_debug();\n\n Value_Ptr color_or_string(const std::string& lexed) const;\n\n // be more like ruby sass\n Expression_Obj lex_almost_any_value_token();\n Expression_Obj lex_almost_any_value_chars();\n Expression_Obj lex_interp_string();\n Expression_Obj lex_interp_uri();\n Expression_Obj lex_interpolation();\n\n // these will throw errors\n Token lex_variable();\n Token lex_identifier();\n\n void parse_block_comments();\n\n Lookahead lookahead_for_value(const char* start = 0);\n Lookahead lookahead_for_selector(const char* start = 0);\n Lookahead lookahead_for_include(const char* start = 0);\n\n Expression_Obj fold_operands(Expression_Obj base, std::vector& operands, Operand op);\n Expression_Obj fold_operands(Expression_Obj base, std::vector& operands, std::vector& ops, size_t i = 0);\n\n void throw_syntax_error(std::string message, size_t ln = 0);\n void throw_read_error(std::string message, size_t ln = 0);\n\n\n template \n Expression_Obj lex_interp()\n {\n if (lex < open >(false)) {\n String_Schema_Obj schema = SASS_MEMORY_NEW(String_Schema, pstate);\n // std::cerr << \"LEX [[\" << std::string(lexed) << \"]]\\n\";\n schema->append(SASS_MEMORY_NEW(String_Constant, pstate, lexed));\n if (position[0] == '#' && position[1] == '{') {\n Expression_Obj itpl = lex_interpolation();\n if (!itpl.isNull()) schema->append(itpl);\n while (lex < close >(false)) {\n // std::cerr << \"LEX [[\" << std::string(lexed) << \"]]\\n\";\n schema->append(SASS_MEMORY_NEW(String_Constant, pstate, lexed));\n if (position[0] == '#' && position[1] == '{') {\n Expression_Obj itpl = lex_interpolation();\n if (!itpl.isNull()) schema->append(itpl);\n } else {\n return schema;\n }\n }\n } else {\n return SASS_MEMORY_NEW(String_Constant, pstate, lexed);\n }\n }\n return 0;\n }\n\n public:\n static Number_Ptr lexed_number(const ParserState& pstate, const std::string& parsed);\n static Number_Ptr lexed_dimension(const ParserState& pstate, const std::string& parsed);\n static Number_Ptr lexed_percentage(const ParserState& pstate, const std::string& parsed);\n static Value_Ptr lexed_hex_color(const ParserState& pstate, const std::string& parsed);\n private:\n Number_Ptr lexed_number(const std::string& parsed) { return lexed_number(pstate, parsed); };\n Number_Ptr lexed_dimension(const std::string& parsed) { return lexed_dimension(pstate, parsed); };\n Number_Ptr lexed_percentage(const std::string& parsed) { return lexed_percentage(pstate, parsed); };\n Value_Ptr lexed_hex_color(const std::string& parsed) { return lexed_hex_color(pstate, parsed); };\n\n static const char* re_attr_sensitive_close(const char* src);\n static const char* re_attr_insensitive_close(const char* src);\n\n };\n\n size_t check_bom_chars(const char* src, const char *end, const unsigned char* bom, size_t len);\n}\n\n#endif\n\n// Path: src/libsass/src/node.cpp\n#include \"sass.hpp\"\n#include \n\n#include 
\"node.hpp\"\n#include \"context.hpp\"\n#include \"parser.hpp\"\n\nnamespace Sass {\n\n\n Node Node::createCombinator(const Complex_Selector::Combinator& combinator) {\n NodeDequePtr null;\n return Node(COMBINATOR, combinator, NULL /*pSelector*/, null /*pCollection*/);\n }\n\n\n Node Node::createSelector(const Complex_Selector& pSelector) {\n NodeDequePtr null;\n\n Complex_Selector_Ptr pStripped = SASS_MEMORY_COPY(&pSelector);\n pStripped->tail(NULL);\n pStripped->combinator(Complex_Selector::ANCESTOR_OF);\n\n Node n(SELECTOR, Complex_Selector::ANCESTOR_OF, pStripped, null /*pCollection*/);\n n.got_line_feed = pSelector.has_line_feed();\n return n;\n }\n\n\n Node Node::createCollection() {\n NodeDequePtr pEmptyCollection = std::make_shared();\n return Node(COLLECTION, Complex_Selector::ANCESTOR_OF, NULL /*pSelector*/, pEmptyCollection);\n }\n\n\n Node Node::createCollection(const NodeDeque& values) {\n NodeDequePtr pShallowCopiedCollection = std::make_shared(values);\n return Node(COLLECTION, Complex_Selector::ANCESTOR_OF, NULL /*pSelector*/, pShallowCopiedCollection);\n }\n\n\n Node Node::createNil() {\n NodeDequePtr null;\n return Node(NIL, Complex_Selector::ANCESTOR_OF, NULL /*pSelector*/, null /*pCollection*/);\n }\n\n\n Node::Node(const TYPE& type, Complex_Selector::Combinator combinator, Complex_Selector_Ptr pSelector, NodeDequePtr& pCollection)\n : got_line_feed(false), mType(type), mCombinator(combinator), mpSelector(pSelector), mpCollection(pCollection)\n { if (pSelector) got_line_feed = pSelector->has_line_feed(); }\n\n\n Node Node::klone() const {\n NodeDequePtr pNewCollection = std::make_shared();\n if (mpCollection) {\n for (NodeDeque::iterator iter = mpCollection->begin(), iterEnd = mpCollection->end(); iter != iterEnd; iter++) {\n Node& toClone = *iter;\n pNewCollection->push_back(toClone.klone());\n }\n }\n\n Node n(mType, mCombinator, mpSelector ? SASS_MEMORY_COPY(mpSelector) : NULL, pNewCollection);\n n.got_line_feed = got_line_feed;\n return n;\n }\n\n\n bool Node::contains(const Node& potentialChild) const {\n bool found = false;\n\n for (NodeDeque::iterator iter = mpCollection->begin(), iterEnd = mpCollection->end(); iter != iterEnd; iter++) {\n Node& toTest = *iter;\n\n if (toTest == potentialChild) {\n found = true;\n break;\n }\n }\n\n return found;\n }\n\n\n bool Node::operator==(const Node& rhs) const {\n if (this->type() != rhs.type()) {\n return false;\n }\n\n if (this->isCombinator()) {\n\n return this->combinator() == rhs.combinator();\n\n } else if (this->isNil()) {\n\n return true; // no state to check\n\n } else if (this->isSelector()){\n\n return *this->selector() == *rhs.selector();\n\n } else if (this->isCollection()) {\n\n if (this->collection()->size() != rhs.collection()->size()) {\n return false;\n }\n\n for (NodeDeque::iterator lhsIter = this->collection()->begin(), lhsIterEnd = this->collection()->end(),\n rhsIter = rhs.collection()->begin(); lhsIter != lhsIterEnd; lhsIter++, rhsIter++) {\n\n if (*lhsIter != *rhsIter) {\n return false;\n }\n\n }\n\n return true;\n\n }\n\n // We shouldn't get here.\n throw \"Comparing unknown node types. 
A new type was probably added and this method wasn't implemented for it.\";\n }\n\n\n void Node::plus(Node& rhs) {\n if (!this->isCollection() || !rhs.isCollection()) {\n throw \"Both the current node and rhs must be collections.\";\n }\n this->collection()->insert(this->collection()->end(), rhs.collection()->begin(), rhs.collection()->end());\n }\n\n#ifdef DEBUG\n std::ostream& operator<<(std::ostream& os, const Node& node) {\n\n if (node.isCombinator()) {\n\n switch (node.combinator()) {\n case Complex_Selector::ANCESTOR_OF: os << \"\\\" \\\"\"; break;\n case Complex_Selector::PARENT_OF: os << \"\\\">\\\"\"; break;\n case Complex_Selector::PRECEDES: os << \"\\\"~\\\"\"; break;\n case Complex_Selector::ADJACENT_TO: os << \"\\\"+\\\"\"; break;\n case Complex_Selector::REFERENCE: os << \"\\\"/\\\"\"; break;\n }\n\n } else if (node.isNil()) {\n\n os << \"nil\";\n\n } else if (node.isSelector()){\n\n os << node.selector()->head()->to_string();\n\n } else if (node.isCollection()) {\n\n os << \"[\";\n\n for (NodeDeque::iterator iter = node.collection()->begin(), iterBegin = node.collection()->begin(), iterEnd = node.collection()->end(); iter != iterEnd; iter++) {\n if (iter != iterBegin) {\n os << \", \";\n }\n\n os << (*iter);\n }\n\n os << \"]\";\n\n }\n\n return os;\n\n }\n#endif\n\n\n Node complexSelectorToNode(Complex_Selector_Ptr pToConvert) {\n if (pToConvert == NULL) {\n return Node::createNil();\n }\n Node node = Node::createCollection();\n node.got_line_feed = pToConvert->has_line_feed();\n bool has_lf = pToConvert->has_line_feed();\n\n // unwrap the selector from parent ref\n if (pToConvert->head() && pToConvert->head()->has_parent_ref()) {\n Complex_Selector_Obj tail = pToConvert->tail();\n if (tail) tail->has_line_feed(pToConvert->has_line_feed());\n pToConvert = tail;\n }\n\n while (pToConvert) {\n\n bool empty_parent_ref = pToConvert->head() && pToConvert->head()->is_empty_reference();\n\n // the first Complex_Selector may contain a dummy head pointer, skip it.\n if (pToConvert->head() && !empty_parent_ref) {\n node.collection()->push_back(Node::createSelector(*pToConvert));\n if (has_lf) node.collection()->back().got_line_feed = has_lf;\n if (pToConvert->head() || empty_parent_ref) {\n if (pToConvert->tail()) {\n pToConvert->tail()->has_line_feed(pToConvert->has_line_feed());\n }\n }\n has_lf = false;\n }\n\n if (pToConvert->combinator() != Complex_Selector::ANCESTOR_OF) {\n node.collection()->push_back(Node::createCombinator(pToConvert->combinator()));\n if (has_lf) node.collection()->back().got_line_feed = has_lf;\n has_lf = false;\n }\n\n if (pToConvert && empty_parent_ref && pToConvert->tail()) {\n // pToConvert->tail()->has_line_feed(pToConvert->has_line_feed());\n }\n\n pToConvert = pToConvert->tail();\n }\n\n return node;\n }\n\n\n Complex_Selector_Ptr nodeToComplexSelector(const Node& toConvert) {\n if (toConvert.isNil()) {\n return NULL;\n }\n\n\n if (!toConvert.isCollection()) {\n throw \"The node to convert to a Complex_Selector_Ptr must be a collection type or nil.\";\n }\n\n\n NodeDeque& childNodes = *toConvert.collection();\n\n std::string noPath(\"\");\n Complex_Selector_Obj pFirst = SASS_MEMORY_NEW(Complex_Selector, ParserState(\"[NODE]\"), Complex_Selector::ANCESTOR_OF, NULL, NULL);\n\n Complex_Selector_Obj pCurrent = pFirst;\n\n if (toConvert.isSelector()) pFirst->has_line_feed(toConvert.got_line_feed);\n if (toConvert.isCombinator()) pFirst->has_line_feed(toConvert.got_line_feed);\n\n for (NodeDeque::iterator childIter = childNodes.begin(), childIterEnd = 
childNodes.end(); childIter != childIterEnd; childIter++) {\n\n Node& child = *childIter;\n\n if (child.isSelector()) {\n // JMA - need to clone the selector, because they can end up getting shared across Node\n // collections, and can result in an infinite loop during the call to parentSuperselector()\n pCurrent->tail(SASS_MEMORY_COPY(child.selector()));\n // if (child.got_line_feed) pCurrent->has_line_feed(child.got_line_feed);\n pCurrent = pCurrent->tail();\n } else if (child.isCombinator()) {\n pCurrent->combinator(child.combinator());\n if (child.got_line_feed) pCurrent->has_line_feed(child.got_line_feed);\n\n // if the next node is also a combinator, create another Complex_Selector to hold it so it doesn't replace the current combinator\n if (childIter+1 != childIterEnd) {\n Node& nextNode = *(childIter+1);\n if (nextNode.isCombinator()) {\n pCurrent->tail(SASS_MEMORY_NEW(Complex_Selector, ParserState(\"[NODE]\"), Complex_Selector::ANCESTOR_OF, NULL, NULL));\n if (nextNode.got_line_feed) pCurrent->tail()->has_line_feed(nextNode.got_line_feed);\n pCurrent = pCurrent->tail();\n }\n }\n } else {\n throw \"The node to convert's children must be only combinators or selectors.\";\n }\n }\n\n // Put the dummy Compound_Selector in the first position, for consistency with the rest of libsass\n Compound_Selector_Ptr fakeHead = SASS_MEMORY_NEW(Compound_Selector, ParserState(\"[NODE]\"), 1);\n Parent_Selector_Ptr selectorRef = SASS_MEMORY_NEW(Parent_Selector, ParserState(\"[NODE]\"));\n fakeHead->elements().push_back(selectorRef);\n if (toConvert.got_line_feed) pFirst->has_line_feed(toConvert.got_line_feed);\n // pFirst->has_line_feed(pFirst->has_line_feed() || pFirst->tail()->has_line_feed() || toConvert.got_line_feed);\n pFirst->head(fakeHead);\n return SASS_MEMORY_COPY(pFirst);\n }\n\n // A very naive trim function, which removes duplicates in a node\n // This is only used in Complex_Selector::unify_with for now, may need modifications to fit other needs\n Node Node::naiveTrim(Node& seqses) {\n\n std::vector res;\n std::vector known;\n\n NodeDeque::reverse_iterator seqsesIter = seqses.collection()->rbegin(),\n seqsesIterEnd = seqses.collection()->rend();\n\n for (; seqsesIter != seqsesIterEnd; ++seqsesIter)\n {\n Node& seqs1 = *seqsesIter;\n if( seqs1.isSelector() ) {\n Complex_Selector_Obj sel = seqs1.selector();\n std::vector::iterator it;\n bool found = false;\n for (it = known.begin(); it != known.end(); ++it) {\n if (**it == *sel) { found = true; break; }\n }\n if( !found ) {\n known.push_back(seqs1.selector());\n res.push_back(&seqs1);\n }\n } else {\n res.push_back(&seqs1);\n }\n }\n\n Node result = Node::createCollection();\n\n for (size_t i = res.size() - 1; i != std::string::npos; --i) {\n result.collection()->push_back(*res[i]);\n }\n\n return result;\n }\n}\n\n// Path: src/libsass/src/sass_util.cpp\n#include \"sass.hpp\"\n#include \"node.hpp\"\n\nnamespace Sass {\n\n\n /*\n # This is the equivalent of ruby's Sass::Util.paths.\n #\n # Return an array of all possible paths through the given arrays.\n #\n # @param arrs [NodeCollection>]\n # @return [NodeCollection>]\n #\n # @example\n # paths([[1, 2], [3, 4], [5]]) #=>\n # # [[1, 3, 5],\n # # [2, 3, 5],\n # # [1, 4, 5],\n # # [2, 4, 5]]\n\n The following is the modified version of the ruby code that was more portable to C++. 
You\n should be able to drop it into ruby 3.2.19 and get the same results from ruby sass.\n\n def paths(arrs)\n // I changed the inject and maps to an iterative approach to make it easier to implement in C++\n loopStart = [[]]\n\n for arr in arrs do\n permutations = []\n for e in arr do\n for path in loopStart do\n permutations.push(path + [e])\n end\n end\n loopStart = permutations\n end\n end\n */\n Node paths(const Node& arrs) {\n\n Node loopStart = Node::createCollection();\n loopStart.collection()->push_back(Node::createCollection());\n\n for (NodeDeque::iterator arrsIter = arrs.collection()->begin(), arrsEndIter = arrs.collection()->end();\n \tarrsIter != arrsEndIter; ++arrsIter) {\n\n Node& arr = *arrsIter;\n\n Node permutations = Node::createCollection();\n\n for (NodeDeque::iterator arrIter = arr.collection()->begin(), arrIterEnd = arr.collection()->end();\n \tarrIter != arrIterEnd; ++arrIter) {\n\n Node& e = *arrIter;\n\n for (NodeDeque::iterator loopStartIter = loopStart.collection()->begin(), loopStartIterEnd = loopStart.collection()->end();\n loopStartIter != loopStartIterEnd; ++loopStartIter) {\n\n Node& path = *loopStartIter;\n\n Node newPermutation = Node::createCollection();\n newPermutation.got_line_feed = arr.got_line_feed;\n newPermutation.plus(path);\n newPermutation.collection()->push_back(e);\n\n permutations.collection()->push_back(newPermutation);\n }\n }\n\n loopStart = permutations;\n }\n\n return loopStart;\n }\n\n\n /*\n This is the equivalent of ruby sass' Sass::Util.flatten and [].flatten.\n Sass::Util.flatten requires the number of levels to flatten, while\n [].flatten doesn't and will flatten the entire array. This function\n supports both.\n\n # Flattens the first `n` nested arrays. If n == -1, all arrays will be flattened\n #\n # @param arr [NodeCollection] The array to flatten\n # @param n [int] The number of levels to flatten\n # @return [NodeCollection] The flattened array\n\n The following is the modified version of the ruby code that was more portable to C++. 
You\n should be able to drop it into ruby 3.2.19 and get the same results from ruby sass.\n\n def flatten(arr, n = -1)\n if n != -1 and n == 0 then\n return arr\n end\n\n flattened = []\n\n for e in arr do\n if e.is_a?(Array) then\n flattened.concat(flatten(e, n - 1))\n else\n flattened << e\n end\n end\n\n return flattened\n end\n */\n Node flatten(Node& arr, int n) {\n if (n != -1 && n == 0) {\n return arr;\n }\n\n Node flattened = Node::createCollection();\n if (arr.got_line_feed) flattened.got_line_feed = true;\n\n for (NodeDeque::iterator iter = arr.collection()->begin(), iterEnd = arr.collection()->end();\n \titer != iterEnd; iter++) {\n \tNode& e = *iter;\n\n // e has the lf set\n if (e.isCollection()) {\n\n \t// e.collection().got_line_feed = e.got_line_feed;\n \tNode recurseFlattened = flatten(e, n - 1);\n\n \tif(e.got_line_feed) {\n \t\t flattened.got_line_feed = e.got_line_feed;\n \t recurseFlattened.got_line_feed = e.got_line_feed;\n \t}\n\n \tfor(auto i : (*recurseFlattened.collection())) {\n if (recurseFlattened.got_line_feed) {\n\n i.got_line_feed = true;\n }\n flattened.collection()->push_back(i);\n \t}\n\n } else {\n \tflattened.collection()->push_back(e);\n }\n }\n\n return flattened;\n }\n}\n\n// Path: src/libsass/src/lexer.cpp\n#include \"sass.hpp\"\n#include \n#include \n#include \n#include \"lexer.hpp\"\n#include \"constants.hpp\"\n\n\nnamespace Sass {\n using namespace Constants;\n\n namespace Prelexer {\n\n //####################################\n // BASIC CHARACTER MATCHERS\n //####################################\n\n // Match standard control chars\n const char* kwd_at(const char* src) { return exactly<'@'>(src); }\n const char* kwd_dot(const char* src) { return exactly<'.'>(src); }\n const char* kwd_comma(const char* src) { return exactly<','>(src); };\n const char* kwd_colon(const char* src) { return exactly<':'>(src); };\n const char* kwd_star(const char* src) { return exactly<'*'>(src); };\n const char* kwd_plus(const char* src) { return exactly<'+'>(src); };\n const char* kwd_minus(const char* src) { return exactly<'-'>(src); };\n const char* kwd_slash(const char* src) { return exactly<'/'>(src); };\n\n //####################################\n // implement some function that do exist in the standard\n // but those are locale aware which brought some trouble\n // this even seems to improve performance by quite a bit\n //####################################\n\n bool is_alpha(const char& chr)\n {\n return unsigned(chr - 'A') <= 'Z' - 'A' ||\n unsigned(chr - 'a') <= 'z' - 'a';\n }\n\n bool is_space(const char& chr)\n {\n // adapted the technique from is_alpha\n return chr == ' ' || unsigned(chr - '\\t') <= '\\r' - '\\t';\n }\n\n bool is_digit(const char& chr)\n {\n // adapted the technique from is_alpha\n return unsigned(chr - '0') <= '9' - '0';\n }\n\n bool is_number(const char& chr)\n {\n // adapted the technique from is_alpha\n return is_digit(chr) || chr == '-' || chr == '+';\n }\n\n bool is_xdigit(const char& chr)\n {\n // adapted the technique from is_alpha\n return unsigned(chr - '0') <= '9' - '0' ||\n unsigned(chr - 'a') <= 'f' - 'a' ||\n unsigned(chr - 'A') <= 'F' - 'A';\n }\n\n bool is_punct(const char& chr)\n {\n // locale independent\n return chr == '.';\n }\n\n \nbool is_alnum(const char& chr)\n {\n return is_alpha(chr) || is_digit(chr);\n }\n\n // check if char is outside ascii range\n bool is_unicode(const char& chr)\n {\n // check for unicode range\n return unsigned(chr) > 127;\n }\n\n // check if char is outside ascii range\n // but with specific 
ranges (copied from Ruby Sass)\n bool is_nonascii(const char& chr)\n {\n unsigned int cmp = unsigned(chr);\n return (\n (cmp >= 128 && cmp <= 15572911) ||\n (cmp >= 15630464 && cmp <= 15712189) ||\n (cmp >= 4036001920)\n );\n }\n\n // check if char is within a reduced ascii range\n // valid in a uri (copied from Ruby Sass)\n bool is_uri_character(const char& chr)\n {\n unsigned int cmp = unsigned(chr);\n return (cmp > 41 && cmp < 127) ||\n cmp == ':' || cmp == '/';\n }\n\n // check if char is within a reduced ascii range\n // valid for escaping (copied from Ruby Sass)\n bool is_escapable_character(const char& chr)\n {\n unsigned int cmp = unsigned(chr);\n return cmp > 31 && cmp < 127;\n }\n\n // Match word character (look ahead)\n bool is_character(const char& chr)\n {\n // valid alpha, numeric or unicode char (plus hyphen)\n return is_alnum(chr) || is_unicode(chr) || chr == '-';\n }\n\n //####################################\n // BASIC CLASS MATCHERS\n //####################################\n\n // create matchers that advance the position\n const char* space(const char* src) { return is_space(*src) ? src + 1 : 0; }\n const char* alpha(const char* src) { return is_alpha(*src) ? src + 1 : 0; }\n const char* unicode(const char* src) { return is_unicode(*src) ? src + 1 : 0; }\n const char* nonascii(const char* src) { return is_nonascii(*src) ? src + 1 : 0; }\n const char* digit(const char* src) { return is_digit(*src) ? src + 1 : 0; }\n const char* xdigit(const char* src) { return is_xdigit(*src) ? src + 1 : 0; }\n const char* alnum(const char* src) { return is_alnum(*src) ? src + 1 : 0; }\n const char* punct(const char* src) { return is_punct(*src) ? src + 1 : 0; }\n const char* hyphen(const char* src) { return *src && *src == '-' ? src + 1 : 0; }\n const char* character(const char* src) { return is_character(*src) ? src + 1 : 0; }\n...\n", "needle_token_start": 15278, "needle_token_end": 15311, "code_context_ntokens": 16090, "output": ["```cpp\nbool is_alnum(const char& chr)\n{\n return is_alpha(chr) || is_digit(chr);\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "vmap_binary_op", "language": "cpp", "path": "mlx/primitives.cpp", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to align two input arrays based on specified axes for vectorized binary operations, ensuring they are compatible for subsequent operations.\n2. **Input**: \n - A pair of arrays.\n - A pair of integers representing the axes along which the arrays should be aligned.\n - A stream object for managing operations.\n3. **Output**: A tuple containing:\n - The first modified array.\n - The second modified array, potentially transposed to match the alignment of the first.\n - The axis along which the alignment was performed.\n4. 
**Procedure**: \n - Validates that there are exactly two input arrays and two axes provided.\n - Determines the new dimensionality needed to accommodate both arrays.\n - Expands the dimensions of both arrays to match this new dimensionality.\n - Transposes the second array if necessary to align it with the first array along the specified axes.\n - Returns the modified arrays and the axis of alignment.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/random.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n#include \n\n#include \"mlx/linalg.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/random.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::random {\n\nKeySequence::KeySequence(uint64_t seed) : key_(key(seed)) {}\n\nvoid KeySequence::seed(uint64_t seed) {\n key_ = key((seed));\n...\n// Path: mlx/device.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \"mlx/device.h\"\n#include \"mlx/backend/metal/metal.h\"\n\nnamespace mlx::core {\n\nstatic Device default_device_{\n metal::is_available() ? Device::gpu : Device::cpu};\n\nconst Device& default_device() {\n return default_device_;\n}\n\nvoid set_default_device(const Device& d) {\n if (!metal::is_available() && d == Device::gpu) {\n throw std::invalid_argument(\n \"[set_default_device] Cannot set gpu device without gpu backend.\");\n }\n default_device_ = d;\n}\n\nbool operator==(const Device& lhs, const Device& rhs) {\n return lhs.type == rhs.type && lhs.index == rhs.index;\n}\n\nbool operator!=(const Device& lhs, const Device& rhs) {\n return !(lhs == rhs);\n}\n\n} // namespace mlx::core\n\n// Path: mlx/primitives.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/fft.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ns\ntd::tuple vmap_binary_op(\n const std::vector& inputs,\n const std::vector& axes,\n const Stream& stream) {\n assert(inputs.size() == 2);\n assert(axes.size() == 2);\n\n if (axes[0] == -1 && axes[1] == -1) {\n return {inputs[0], inputs[1], -1};\n }\n\n auto a = inputs[0];\n auto b = inputs[1];\n int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));\n\n auto expand_dims = [stream, ndim](auto in) {\n auto shape = in.shape();\n shape.insert(shape.begin(), ndim - shape.size(), 1);\n return reshape(in, shape, stream);\n };\n\n int to_ax = (ndim - a.ndim()) + axes[0];\n int from_ax = (ndim - b.ndim()) + axes[1];\n a = expand_dims(a);\n b = expand_dims(b);\n\n if (from_ax != to_ax) {\n std::vector tdims(b.ndim());\n std::iota(tdims.begin(), tdims.end(), 0);\n tdims.erase(tdims.begin() + from_ax);\n tdims.insert(tdims.begin() + to_ax, from_ax);\n b = transpose(b, tdims, stream);\n }\n return {a, b, to_ax};\n}\n\nstd::tuple vmap_ternary_op(\n const std::vector& inputs,\n const std::vector& axes,\n const Stream& stream) {\n assert(inputs.size() == 3);\n assert(axes.size() == 3);\n\n if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {\n return {inputs[0], inputs[1], inputs[2], -1};\n }\n\n auto a = inputs[0];\n auto b = inputs[1];\n auto c = inputs[2];\n int ndim = std::max(\n {a.ndim() + (axes[0] == -1),\n b.ndim() + (axes[1] == -1),\n 
c.ndim() + (axes[2] == -1)});\n\n auto expand_dims = [stream, ndim](auto in) {\n auto shape = in.shape();\n shape.insert(shape.begin(), ndim - shape.size(), 1);\n return reshape(in, shape, stream);\n };\n\n int to_ax = (ndim - a.ndim()) + axes[0];\n int from_ax1 = (ndim - b.ndim()) + axes[1];\n int from_ax2 = (ndim - c.ndim()) + axes[2];\n a = expand_dims(a);\n b = expand_dims(b);\n c = expand_dims(c);\n\n auto find_tdims = [](auto x, int to_ax, int from_ax) {\n std::vector tdims(x.ndim());\n std::iota(tdims.begin(), tdims.end(), 0);\n tdims.erase(tdims.begin() + from_ax);\n tdims.insert(tdims.begin() + to_ax, from_ax);\n return tdims;\n };\n\n if (to_ax != from_ax1) {\n std::vector tdims = find_tdims(b, to_ax, from_ax1);\n b = transpose(b, tdims, stream);\n }\n\n if (to_ax != from_ax2) {\n std::vector tdims = find_tdims(c, to_ax, from_ax2);\n c = transpose(c, tdims, stream);\n }\n return {a, b, c, to_ax};\n}\n\n} // namespace\n\nstd::vector Primitive::jvp(\n const std::vector&,\n const std::vector&,\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::jvp] Not implemented for \";\n print(msg);\n msg << \".\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::vector Primitive::vjp(\n const std::vector&,\n const std::vector&,\n const std::vector&,\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::vip] Not implemented for \";\n print(msg);\n msg << \".\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::pair, std::vector> Primitive::vmap(\n const std::vector&,\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::vmap] Not implemented for \";\n print(msg);\n msg << \".\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::vector> Primitive::output_shapes(\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::output_shapes] \";\n this->print(msg);\n msg << \" cannot infer output shapes.\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::vector Abs::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Abs::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(tangents[0], sign(primals[0], stream()), stream())};\n}\n\nstd::pair, std::vector> Abs::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{abs(inputs[0], stream())}, axes};\n}\n\nstd::vector Add::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n return {\n tangents.size() > 1 ? 
add(tangents[0], tangents[1], stream())\n : tangents[0]};\n}\n\nstd::vector Add::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n if (argnums.size() == 1) {\n return cotangents;\n } else {\n return {cotangents[0], cotangents[0]};\n }\n}\n\nstd::pair, std::vector> Add::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{add(a, b, stream())}, {to_ax}};\n}\n\nstd::vector AddMM::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n auto& cotan = cotangents[0];\n std::vector reorder(cotan.ndim());\n std::iota(reorder.begin(), reorder.end(), 0);\n std::iter_swap(reorder.end() - 1, reorder.end() - 2);\n for (auto arg : argnums) {\n if (arg == 0) {\n // M X N * (K X N).T -> M X K\n auto cotan_scaled = cotan;\n if (alpha_ != 1.) {\n auto alpha_arr = array(alpha_, cotan.dtype());\n cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));\n }\n vjps.push_back(matmul(\n cotan_scaled, transpose(primals[1], reorder, stream()), stream()));\n } else if (arg == 1) {\n // (M X K).T * M X N -> K X N\n auto cotan_scaled = cotan;\n if (alpha_ != 1.) {\n auto alpha_arr = array(alpha_, cotan.dtype());\n cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));\n }\n vjps.push_back(matmul(\n transpose(primals[0], reorder, stream()), cotan_scaled, stream()));\n } else {\n auto cotan_scaled = cotan;\n if (beta_ != 1.) {\n auto beta_arr = array(beta_, cotan.dtype());\n cotan_scaled = (multiply(beta_arr, cotan_scaled, stream()));\n }\n vjps.push_back(cotan_scaled);\n }\n }\n return vjps;\n}\n\nbool AddMM::is_equivalent(const Primitive& other) const {\n const AddMM& a_other = static_cast(other);\n return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);\n}\n\nstd::pair, std::vector> AddMM::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto maybe_move_ax = [this](auto& arr, auto ax) {\n return ax > 0 ? 
moveaxis(arr, ax, 0, stream()) : arr;\n };\n auto a = maybe_move_ax(inputs[0], axes[0]);\n auto b = maybe_move_ax(inputs[1], axes[1]);\n auto c = maybe_move_ax(inputs[2], axes[2]);\n return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};\n}\n\nbool Arange::is_equivalent(const Primitive& other) const {\n const Arange& a_other = static_cast(other);\n return (\n start_ == a_other.start_ && stop_ == a_other.stop_ &&\n step_ == a_other.step_);\n}\n\nstd::vector ArcCos::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcCos::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(one, square(primals[0], stream()), stream());\n array denom = negative(rsqrt(t, stream()), stream());\n return {multiply(tangents[0], denom, stream())};\n}\n\nstd::pair, std::vector> ArcCos::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arccos(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcCosh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcCosh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(square(primals[0], stream()), one, stream());\n return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair, std::vector> ArcCosh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arccosh(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcSin::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcSin::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(one, square(primals[0], stream()), stream());\n return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair, std::vector> ArcSin::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arcsin(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcSinh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcSinh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = add(square(primals[0], stream()), one, stream());\n return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair, std::vector> ArcSinh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arcsinh(inputs[0], stream())}, 
axes};\n}\n\nstd::vector ArcTan::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcTan::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = add(one, square(primals[0], stream()), stream());\n return {divide(tangents[0], t, stream())};\n}\n\nstd::pair, std::vector> ArcTan::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arctan(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcTanh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcTanh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(one, square(primals[0], stream()), stream());\n return {divide(tangents[0], t, stream())};\n}\n\nstd::pair, std::vector> ArcTanh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arctanh(inputs[0], stream())}, axes};\n}\n\nstd::pair, std::vector> ArgPartition::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n\n int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nbool ArgPartition::is_equivalent(const Primitive& other) const {\n const ArgPartition& r_other = static_cast(other);\n return axis_ == r_other.axis_ && kth_ == r_other.kth_;\n}\n\nbool ArgReduce::is_equivalent(const Primitive& other) const {\n const ArgReduce& r_other = static_cast(other);\n return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;\n}\n\nstd::pair, std::vector> ArgReduce::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);\n auto& in = inputs[0];\n std::vector out;\n if (reduce_type_ == ArgReduce::ArgMin) {\n out.push_back(argmin(in, reduce_ax, true, stream()));\n } else {\n out.push_back(argmax(in, reduce_ax, true, stream()));\n }\n return {out, axes};\n}\n\nstd::pair, std::vector> ArgSort::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n\n int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nstd::vector> ArgReduce::output_shapes(\n const std::vector& inputs) {\n auto out_shape = inputs[0].shape();\n out_shape[axis_] = 1;\n return {out_shape};\n}\n\nbool ArgSort::is_equivalent(const Primitive& other) const {\n const ArgSort& r_other = static_cast(other);\n return axis_ == r_other.axis_;\n}\n\nstd::vector AsType::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n if (cotangents[0].dtype() != dtype_) {\n throw std::invalid_argument(\n \"[astype] Type of cotangentsgent does not much primal output type.\");\n }\n return {astype(cotangents[0], primals[0].dtype(), stream())};\n}\n\nstd::vector AsType::jvp(\n const 
std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n return {astype(tangents[0], dtype_, stream())};\n}\n\nstd::pair, std::vector> AsType::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n return {{astype(inputs[0], dtype_, stream())}, axes};\n}\n\nbool AsType::is_equivalent(const Primitive& other) const {\n const AsType& a_other = static_cast(other);\n return dtype_ == a_other.dtype_;\n}\n\nstd::vector AsStrided::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(argnums.size() == 1);\n\n // Extract the sizes and cast them to ints\n int grad_size = primals[0].size();\n int cotangents_size = cotangents[0].size();\n\n // Make a flat container to hold the gradients\n auto grad = zeros_like(primals[0], stream());\n grad = reshape(grad, {grad_size}, stream());\n\n // Create the indices that map output to input\n auto idx = arange(grad_size, stream());\n idx = as_strided(idx, shape_, strides_, offset_, stream());\n idx = reshape(idx, {cotangents_size}, stream());\n\n // Reshape the cotangentsgent for use with scatter\n auto flat_cotangents = reshape(cotangents[0], {cotangents_size, 1}, stream());\n\n // Finally accumulate the gradients and reshape them to look like the input\n grad = scatter_add(grad, idx, flat_cotangents, 0, stream());\n grad = reshape(grad, primals[0].shape(), stream());\n\n return {grad};\n}\n\nstd::vector AsStrided::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n\n return {as_strided(tangents[0], shape_, strides_, offset_, stream())};\n}\n\nbool AsStrided::is_equivalent(const Primitive& other) const {\n const AsStrided& a_other = static_cast(other);\n return shape_ == a_other.shape_ && strides_ == a_other.strides_ &&\n offset_ == a_other.offset_;\n}\n\nstd::vector Broadcast::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(argnums.size() == 1);\n\n // Reduce cotangents to the shape of the primal\n auto& shape = primals[0].shape();\n auto& cotan = cotangents[0];\n int diff = cotan.ndim() - shape.size();\n std::vector reduce_axes;\n for (int i = 0; i < cotan.ndim(); ++i) {\n if (i < diff) {\n reduce_axes.push_back(i);\n } else if (shape[i - diff] != cotan.shape(i)) {\n reduce_axes.push_back(i);\n }\n }\n return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};\n}\n\nstd::vector Broadcast::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(argnums.size() == 1);\n return {broadcast_to(tangents[0], shape_, stream())};\n}\n\nstd::pair, std::vector> Broadcast::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n auto ax = axes[0];\n auto in = inputs[0];\n if (ax >= 0) {\n auto in_shape = in.shape();\n int diff = shape_.size() - in.ndim() + 1;\n assert(diff >= 0);\n in_shape.insert(in_shape.begin(), diff, 1);\n ax += diff;\n shape_.insert(shape_.begin() + ax, in_shape[ax]);\n in = reshape(in, in_shape, stream());\n }\n return {{broadcast_to(in, shape_, stream())}, {ax}};\n}\n\nbool Broadcast::is_equivalent(const Primitive& other) const {\n const Broadcast& b_other = static_cast(other);\n return shape_ == b_other.shape_;\n}\n\nstd::vector Ceil::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const 
std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Ceil::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {zeros_like(primals[0], stream())};\n}\n\nstd::pair, std::vector> Ceil::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{ceil(inputs[0], stream())}, axes};\n}\n\nstd::vector Concatenate::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n auto& cotan = cotangents[0];\n std::vector start(cotan.ndim(), 0);\n std::vector stop = cotan.shape();\n\n std::vector sizes;\n sizes.push_back(0);\n for (auto& p : primals) {\n sizes.push_back(p.shape(axis_));\n }\n std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n std::vector grads;\n for (auto i : argnums) {\n start[axis_] = sizes[i];\n stop[axis_] = sizes[i + 1];\n grads.push_back(slice(cotan, start, stop, stream()));\n }\n return grads;\n}\n\nstd::vector Concatenate::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n std::vector argidx(argnums.size());\n std::iota(argidx.begin(), argidx.end(), 0);\n std::sort(argidx.begin(), argidx.end(), [&argnums](int a, int b) {\n return argnums[a] < argnums[b];\n });\n\n std::vector vals;\n for (int i = 0, j = 0; i < primals.size(); ++i) {\n if (j < argnums.size() && argnums[argidx[j]] == i) {\n vals.push_back(tangents[argidx[j++]]);\n } else {\n vals.push_back(zeros_like(primals[i], stream()));\n }\n }\n return {concatenate(vals, axis_, stream())};\n}\n\nstd::pair, std::vector> Concatenate::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n std::vector t_inputs;\n int out_ax = -1;\n // Find the first vmapped input\n int i = 0;\n for (; i < axes.size(); i++) {\n t_inputs.push_back(inputs[i]);\n if (axes[i] >= 0) {\n out_ax = axes[i];\n break;\n }\n }\n if (out_ax >= 0) {\n // Advance to the next input\n i++;\n\n // Move vmap axes to the same spot.\n for (; i < axes.size(); ++i) {\n if (out_ax != axes[i] && axes[i] >= 0) {\n t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));\n } else {\n t_inputs.push_back(inputs[i]);\n }\n }\n }\n auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);\n return {{concatenate(t_inputs, axis, stream())}, {out_ax}};\n}\n\nbool Concatenate::is_equivalent(const Primitive& other) const {\n const Concatenate& c_other = static_cast(other);\n return axis_ == c_other.axis_;\n}\n\narray conv_weight_backward_patches(\n const array& in,\n const array& wt,\n const array& cotan,\n const std::vector& kernel_strides,\n const std::vector& padding,\n StreamOrDevice s) {\n // Resolve Padded input shapes and strides\n std::vector padding_starts(in.ndim(), 0);\n std::vector padding_ends = in.shape();\n std::vector in_padded_shape = in.shape();\n\n // padded shape\n for (int i = 1; i < in.ndim() - 1; i++) {\n in_padded_shape[i] += 2 * padding[i - 1];\n padding_ends[i] += padding[i - 1];\n padding_starts[i] += padding[i - 1];\n }\n\n // padded strides (contiguous)\n std::vector in_padded_strides(in.ndim(), 1);\n for (int i = in.ndim() - 2; i >= 0; --i) {\n in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];\n }\n\n // Pad input\n std::vector padded_axes(in.ndim() - 2, 0);\n std::iota(padded_axes.begin(), padded_axes.end(), 1);\n auto in_padded =\n pad(in, padded_axes, padding, padding, 
array(0, in.dtype()), s);\n\n // Resolve strided patches\n\n // patches are shaped as\n // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels)\n std::vector patches_shape{\n cotan.shape().begin(), cotan.shape().end() - 1};\n patches_shape.insert(\n patches_shape.end(), wt.shape().begin() + 1, wt.shape().end());\n\n // Resolve patch strides\n int n_spatial_dim = in.ndim() - 2;\n std::vector patches_strides(patches_shape.size(), 1);\n patches_strides[0] = in_padded_strides[0];\n for (int i = 1; i < n_spatial_dim + 1; i++) {\n patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];\n }\n for (int i = 1; i < in.ndim(); i++) {\n patches_strides[n_spatial_dim + i] = in_padded_strides[i];\n }\n\n // Make patches from in\n auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, s);\n\n // Prepare for matmul\n int O = wt.shape(0);\n auto cotan_mat = reshape(cotan, {-1, O}, s);\n in_patches = reshape(in_patches, {cotan_mat.shape(0), -1}, s);\n\n auto grad = matmul(transpose(cotan_mat, {1, 0}, s), in_patches, s);\n grad = reshape(grad, wt.shape(), s);\n return grad;\n}\n\nstd::vector Convolution::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(primals.size() == 2);\n std::vector grads;\n\n // Collect info\n auto& in = primals[0];\n auto& wt = primals[1];\n auto& cotan = cotangents[0];\n\n for (int a : argnums) {\n // Grads for input\n if (a == 0) {\n std::vector padding_lo = padding_;\n std::vector padding_hi = padding_;\n\n for (int i = 0; i < padding_lo.size(); ++i) {\n int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);\n padding_lo[i] = wt_size - padding_[i] - 1;\n\n int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);\n int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);\n padding_hi[i] = in_size - out_size + padding_[i];\n }\n\n auto wt_trans = swapaxes(wt, 0, -1, stream());\n\n auto grad = conv_general(\n /* const array& input = */ cotan,\n /* const array& weight = */ wt_trans,\n /* std::vector stride = */ input_dilation_,\n /* std::vector padding_lo = */ padding_lo,\n /* std::vector padding_hi = */ padding_hi,\n /* std::vector kernel_dilation = */ kernel_dilation_,\n /* std::vector input_dilation = */ kernel_strides_,\n /* int groups = */ 1,\n /* bool flip = */ !flip_,\n stream());\n\n grads.push_back(grad);\n }\n // Grads for weight\n else if (a == 1) {\n bool no_dilation = true;\n\n for (int i = 0; i < input_dilation_.size(); i++) {\n no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);\n }\n\n if (no_dilation) {\n auto grad = conv_weight_backward_patches(\n in, wt, cotan, kernel_strides_, padding_, stream());\n grads.push_back(grad);\n } else {\n std::vector padding_lo = padding_;\n std::vector padding_hi = padding_;\n\n for (int i = 0; i < padding_hi.size(); ++i) {\n int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);\n int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);\n int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);\n padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;\n }\n\n auto in_trans = swapaxes(in, 0, -1, stream());\n auto cotan_trans = swapaxes(cotan, 0, -1, stream());\n auto grad_trans = conv_general(\n /* const array& input = */ in_trans,\n /* const array& weight = */ cotan_trans,\n /* std::vector stride = */ kernel_dilation_,\n /* std::vector padding_lo = */ padding_lo,\n /* std::vector padding_hi = */ padding_hi,\n /* std::vector kernel_dilation = 
*/ kernel_strides_,\n /* std::vector input_dilation = */ input_dilation_,\n /* int groups = */ 1,\n /* bool flip = */ flip_,\n stream());\n auto grad = swapaxes(grad_trans, 0, -1, stream());\n grads.push_back(grad);\n }\n }\n }\n\n return grads;\n}\n\nbool Convolution::is_equivalent(const Primitive& other) const {\n const Convolution& c_other = static_cast(other);\n return padding_ == c_other.padding_ &&\n kernel_strides_ == c_other.kernel_strides_ &&\n kernel_dilation_ == c_other.kernel_dilation_ &&\n input_dilation_ == c_other.input_dilation_ &&\n groups_ == c_other.groups_ && flip_ == c_other.flip_;\n}\n\nstd::vector Copy::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return cotangents;\n}\n\nstd::vector Copy::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return tangents;\n}\n\nstd::pair, std::vector> Copy::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{copy(inputs[0], stream())}, axes};\n}\n\nstd::vector Cos::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return {jvp(primals, cotangents, argnums)};\n}\n\nstd::vector Cos::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(\n tangents[0], negative(sin(primals[0], stream()), stream()), stream())};\n}\n\nstd::pair, std::vector> Cos::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{cos(inputs[0], stream())}, axes};\n}\n\nstd::vector Cosh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Cosh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(tangents[0], sinh(primals[0], stream()), stream())};\n}\n\nstd::pair, std::vector> Cosh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{cosh(inputs[0], stream())}, axes};\n}\n\nstd::vector CustomVJP::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n std::vector inputs(primals.begin(), primals.end() - outputs.size());\n auto all_vjps = vjp_fun_(inputs, cotangents, outputs);\n for (const auto& cot : cotangents) {\n all_vjps.emplace_back(cot);\n }\n\n std::vector vjps;\n vjps.reserve(argnums.size());\n for (auto arg : argnums) {\n vjps.push_back(all_vjps[arg]);\n }\n\n return vjps;\n}\n\nstd::vector Depends::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n std::vector vjps;\n\n for (auto arg : argnums) {\n if (arg < cotangents.size()) {\n vjps.push_back(cotangents[arg]);\n } else {\n vjps.push_back(zeros_like(primals[arg]));\n }\n }\n return vjps;\n}\n\nstd::vector Divide::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const 
std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n if (arg == 0) {\n vjps.push_back(divide(cotangents[0], primals[1], stream()));\n } else {\n vjps.push_back(negative(\n divide(\n multiply(cotangents[0], primals[0], stream()),\n square(primals[1], stream()),\n stream()),\n stream()));\n }\n }\n return vjps;\n}\n\nstd::vector DivMod::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n vjps.push_back(zeros_like(primals[arg], stream()));\n }\n return vjps;\n}\n\nstd::vector DivMod::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n return {zeros_like(primals[0], stream())};\n}\n\nstd::pair, std::vector> DivMod::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {divmod(a, b, stream()), {to_ax}};\n}\n\nstd::vector Divide::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n auto jvp_fun = [&](int i) {\n int arg = argnums[i];\n if (arg == 0) {\n return divide(tangents[i], primals[1], stream());\n } else {\n return negative(\n divide(\n multiply(tangents[i], primals[0], stream()),\n square(primals[1], stream()),\n stream()),\n stream());\n }\n };\n auto out = jvp_fun(0);\n if (argnums.size() > 1) {\n out = add(out, jvp_fun(1), stream());\n }\n return {out};\n}\n\nstd::pair, std::vector> Divide::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{divide(a, b, stream())}, {to_ax}};\n}\n\nstd::vector Remainder::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n if (arg == 0) {\n vjps.push_back(cotangents[0]);\n } else {\n auto x_over_y = divide(primals[0], primals[1], stream());\n x_over_y = floor(x_over_y, stream());\n vjps.push_back(\n negative(multiply(x_over_y, cotangents[0], stream()), stream()));\n }\n }\n return vjps;\n}\n\nstd::vector Remainder::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n auto jvp_fun = [&](int i) {\n int arg = argnums[i];\n if (arg == 0) {\n return tangents[i];\n } else {\n auto x_over_y = divide(primals[0], primals[1], stream());\n x_over_y = floor(x_over_y, stream());\n return negative(multiply(x_over_y, tangents[i], stream()), stream());\n }\n };\n auto out = jvp_fun(0);\n if (argnums.size() > 1) {\n out = add(out, jvp_fun(1), stream());\n }\n return {out};\n}\n\nstd::pair, std::vector> Remainder::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{remainder(a, b, stream())}, {to_ax}};\n}\n\nstd::pair, std::vector> Equal::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{equal(a, b, stream())}, {to_ax}};\n}\n\nstd::vector Equal::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n vjps.push_back(zeros_like(primals[arg], stream()));\n }\n return vjps;\n}\n\nstd::vector Equal::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n auto shape = 
broadcast_shapes(primals[0].shape(), primals[1].shape());\n return {zeros(shape, bool_, stream())};\n}\n\nstd::vector Erf::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Erf::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n auto dtype = primals[0].dtype();\n auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream());\n return {multiply(\n scale,\n exp(negative(square(primals[0], stream()), stream()), stream()),\n stream())};\n}\n\nstd::pair, std::vector> Erf::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{erf(inputs[0], stream())}, axes};\n}\n\nstd::vector ErfInv::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n auto dtype = primals[0].dtype();\n auto scale =\n multiply(array(1.0 / M_2_SQRTPI, dtype), cotangents[0], stream());\n return {\n multiply(scale, exp(square(outputs[0], stream()), stream()), stream())};\n}\n\nstd::vector ErfInv::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n auto dtype = primals[0].dtype();\n auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream());\n return {multiply(\n scale,\n exp(square(erfinv(primals[0], stream()), stream()), stream()),\n stream())};\n}\n\nstd::pair, std::vector> ErfInv::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{erfinv(inputs[0], stream())}, axes};\n}\n\nstd::vector Exp::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n return {multiply(cotangents[0], outputs[0], stream())};\n}\n\nstd::vector Exp::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(tangents[0], exp(primals[0], stream()), stream())};\n}\n\nstd::pair, std::vector> Exp::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{exp(inputs[0], stream())}, axes};\n}\n\nstd::vector Expm1::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n return {multiply(\n cotangents[0],\n add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),\n stream())};\n}\n\nstd::vector Expm1::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(tangents[0], exp(primals[0], stream()), stream())};\n}\n\nstd::pair, std::vector> Expm1::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{expm1(inputs[0], stream())}, axes};\n}\n\nbool FFT::is_equivalent(const Primitive& other) const {\n const FFT& r_other = static_cast(other);\n return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&\n real_ == r_other.real_;\n}\n\nstd::pair, std::vector> FFT::vmap(\n const std::vector& inputs,\n const 
std::vector& axes) {\n auto& in = inputs[0];\n int ax = axes[0];\n auto fft_axes = axes_;\n auto out_shape = in.shape();\n if (ax >= 0) {\n for (auto& fft_ax : fft_axes) {\n if (fft_ax >= ax) {\n fft_ax++;\n }\n if (real_) {\n auto n = out_shape[fft_ax];\n out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;\n }\n }\n }\n return {\n {array(\n out_shape,\n real_ && inverse_ ? float32 : complex64,\n std::make_shared(stream(), fft_axes, inverse_, real_),\n {in})},\n {ax}};\n}\n\nstd::vector FFT::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n auto& in = primals[0];\n std::vector axes(axes_.begin(), axes_.end());\n if (real_ && inverse_) {\n auto out = fft::fftn(cotangents[0], axes, stream());\n auto start = std::vector(out.ndim(), 0);\n auto stop = in.shape();\n out = slice(out, start, stop, stream());\n auto mask_shape = out.shape();\n mask_shape[axes_.back()] -= 2;\n auto mask = full(mask_shape, 2.0f, stream());\n auto pad_shape = out.shape();\n pad_shape[axes_.back()] = 1;\n auto pad = full(pad_shape, 1.0f, stream());\n mask = concatenate({pad, mask, pad}, axes_.back(), stream());\n return {multiply(mask, out, stream())};\n } else if (real_) {\n std::vector n;\n for (auto ax : axes_) {\n n.push_back(in.shape()[ax]);\n }\n return {astype(\n fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())};\n } else if (inverse_) {\n return {fft::ifftn(cotangents[0], axes, stream())};\n } else {\n return {fft::fftn(cotangents[0], axes, stream())};\n }\n}\n\nstd::vector FFT::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n auto& tan = tangents[0];\n if (real_ & inverse_) {\n return {fft::irfftn(tan, stream())};\n } else if (real_) {\n return {fft::rfftn(tan, stream())};\n } else if (inverse_) {\n return {fft::ifftn(tan, stream())};\n } else {\n return {fft::fftn(tan, stream())};\n }\n}\n\nstd::vector Floor::vjp(\n const std::vector& primals,\n...\n", "needle_token_start": 585, "needle_token_end": 1018, "code_context_ntokens": 16213, "output": ["```cpp\nstd::tuple vmap_binary_op(\n const std::vector& inputs,\n const std::vector& axes,\n const Stream& stream) {\n assert(inputs.size() == 2);\n assert(axes.size() == 2);\n\n if (axes[0] == -1 && axes[1] == -1) {\n return {inputs[0], inputs[1], -1};\n }\n\n auto a = inputs[0];\n auto b = inputs[1];\n int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));\n\n auto expand_dims = [stream, ndim](auto in) {\n auto shape = in.shape();\n shape.insert(shape.begin(), ndim - shape.size(), 1);\n return reshape(in, shape, stream);\n };\n\n int to_ax = (ndim - a.ndim()) + axes[0];\n int from_ax = (ndim - b.ndim()) + axes[1];\n a = expand_dims(a);\n b = expand_dims(b);\n\n if (from_ax != to_ax) {\n std::vector tdims(b.ndim());\n std::iota(tdims.begin(), tdims.end(), 0);\n tdims.erase(tdims.begin() + from_ax);\n tdims.insert(tdims.begin() + to_ax, from_ax);\n b = transpose(b, tdims, stream);\n }\n return {a, b, to_ax};\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "vmap_trace", "language": "cpp", "path": "mlx/transforms.cpp", "position_ratio": 0.15, "description": "\nFunction Description:\n1. 
**Purpose**: The function is designed to trace a vectorized computation over a batch of inputs, allowing for operations to be applied in parallel across elements of the input arrays specified by axes.\n2. **Input**: This function takes three parameters: a function that operates on vectors of arrays, a vector of arrays representing the inputs to the function, and a vector of integers specifying the axes along which to vectorize each input array.\n3. **Output**: It returns a pair of vectors of arrays. The first vector contains the modified inputs with tracing enabled, and the second vector contains the results of applying the input function to these traced inputs.\n4. **Procedure**: The function first checks that the number of specified axes matches the number of input arrays and validates the specified axes against the dimensions of the input arrays. It then determines the size of the vectorization axis and ensures consistency across all inputs. For each input, if it is to be vectorized, it modifies the shape by removing the vectorization dimension and creates a placeholder input with tracing enabled. These modified inputs are then used to invoke the provided function, capturing the computational graph of the operation.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "}\n\nstd::pair vjp(\n const std::function& fun,\n const array& primal,\n const array& cotan) {\n auto vec_fun = [fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0])};\n };\n auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});\n return {outputs[0], vjps[0]};\n}\n\nstd::pair, std::vector> jvp(\n const std::function(const std::vector&)>& fun,\n const std::vector& primals,\n const std::vector& tangents) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n if (primals.size() != tangents.size()) {\n throw std::invalid_argument(\n \"[jvp] Number of inputs does not match number of tangents.\");\n }\n for (int i = 0; i < primals.size(); ++i) {\n if (primals[i].shape() != tangents[i].shape()) {\n throw std::invalid_argument(\n \"[jvp] Input shape does not match shape of tangent.\");\n }\n }\n\n std::vector primals_;\n for (auto& p : primals) {\n auto s = p.has_primitive() ? 
p.primitive().stream()\n : default_stream(default_device());\n primals_.push_back(copy(p, s)); // Does not do a deep copy\n primals_.back().set_tracer(true);\n }\n auto outputs = fun(primals_);\n\n // Topologically sort the compute graph, record outputs\n // in the tape if a gradient is needed.\n std::unordered_set cache;\n std::unordered_set calc_grad;\n for (auto& primal : primals_) {\n primal.set_tracer(false);\n calc_grad.insert(primal.id());\n cache.insert(primal.id());\n }\n\n std::vector tape;\n\n std::function recurse;\n recurse = [&](auto& a) {\n // Check if visited and add to cache if not\n if (auto inserted = cache.insert(a.id()); !inserted.second) {\n return;\n }\n a.set_tracer(false);\n for (auto s : a.siblings()) {\n s.set_tracer(false);\n cache.insert(s.id());\n }\n\n for (auto input : a.inputs()) {\n recurse(input);\n }\n\n // Stop grad\n if (a.has_primitive()) {\n if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {\n return;\n }\n }\n\n // Calculate gradient if any inputs require gradient\n for (auto& input : a.inputs()) {\n if (calc_grad.find(input.id()) != calc_grad.end()) {\n tape.push_back(a);\n calc_grad.insert(a.id());\n for (auto& s : a.siblings()) {\n calc_grad.insert(s.id());\n }\n break;\n }\n }\n };\n\n for (auto out : outputs) {\n recurse(out);\n }\n\n std::unordered_map tan_map;\n for (int i = 0; i < primals_.size(); ++i) {\n tan_map.insert({primals_[i].id(), tangents[i]});\n }\n\n for (auto& a : tape) {\n // Get the arguments used in the jvp\n std::vector argnums;\n std::vector tangents;\n for (int i = 0; i < a.inputs().size(); ++i) {\n if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {\n argnums.push_back(i);\n tangents.push_back(it->second);\n }\n }\n\n auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);\n auto outputs = a.outputs();\n for (int i = 0; i < jvps.size(); ++i) {\n tan_map.insert({outputs[i].id(), jvps[i]});\n }\n }\n\n std::vector jvps;\n for (auto& out : outputs) {\n if (auto it = tan_map.find(out.id()); it != tan_map.end()) {\n jvps.push_back(it->second);\n } else {\n auto s = out.has_primitive() ? out.primitive().stream()\n : default_stream(default_device());\n jvps.push_back(zeros_like(out, s));\n }\n }\n return {outputs, jvps};\n}\n\nstd::pair jvp(\n const std::function& fun,\n const array& primal,\n const array& tangent) {\n auto vec_fun = [fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0])};\n };\n auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});\n return {outputs[0], jvps[0]};\n}\n\nValueAndGradFn value_and_grad(\n const std::function(const std::vector&)>& fun,\n const std::vector& argnums) {\n if (argnums.empty()) {\n throw std::invalid_argument(\"[grad] Must specify at least one argument.\");\n }\n return [fun, argnums](const std::vector& inputs) {\n std::set args;\n for (auto& arg : argnums) {\n args.insert(arg < 0 ? 
arg + inputs.size() : arg);\n }\n if (args.size() != argnums.size()) {\n throw std::invalid_argument(\n \"[grad] Repeat argument number not allowed in grad.\");\n }\n if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {\n std::ostringstream msg;\n msg << \"[grad] Invalid argument number for function with \"\n << inputs.size() << \" inputs.\";\n throw std::invalid_argument(msg.str());\n }\n\n auto gfun = [&fun, &inputs, &args](const std::vector& ginputs) {\n std::vector inputs_(inputs);\n auto argit = args.begin();\n for (int i = 0; i < ginputs.size(); ++i) {\n inputs_[*argit] = ginputs[i];\n ++argit;\n }\n auto outputs = fun(inputs_);\n for (int i = 1; i < outputs.size(); i++) {\n auto& out = outputs[i];\n auto s = out.has_primitive() ? out.primitive().stream()\n : default_stream(default_device());\n outputs[i] = stop_gradient(out, s);\n }\n return outputs;\n };\n\n std::vector ginputs;\n for (auto arg : args) {\n ginputs.push_back(inputs[arg]);\n }\n // Set the incoming gradient to int32, vjp will cast it to the output type\n auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});\n return std::make_pair(outputs, grads);\n };\n}\n\nnamespace detail {\n\ns\ntd::pair, std::vector> vmap_trace(\n const std::function(const std::vector&)>& fun,\n const std::vector& inputs,\n const std::vector& in_axes) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n if (in_axes.size() != inputs.size()) {\n throw std::invalid_argument(\n \"[vmap] The number of in axes must match the number of inputs.\");\n }\n\n // Some error checking and get the vmap axis size\n size_t vmap_ax_size;\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n if (inputs[i].ndim() == 0) {\n throw std::invalid_argument(\n \"[vmap] Cannot vmap an input with zero dimensions.\");\n }\n if (in_axes[i] > inputs[i].ndim()) {\n std::ostringstream msg;\n msg << \"[vmap] Axis \" << in_axes[i] << \" invalid for input with \"\n << inputs[i].ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n vmap_ax_size = inputs[i].shape(in_axes[i]);\n }\n }\n // Check that all vmapped axes have the same size\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {\n std::ostringstream msg;\n msg << \"[vmap] Inconsistent axis sizes: \" << in_ax << \" and \"\n << vmap_ax_size << \".\";\n throw std::invalid_argument(msg.str());\n }\n }\n }\n\n // Run the function on placeholder inputs\n // to get the original graph\n std::vector s_inputs;\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n std::vector shape = inputs[i].shape();\n shape.erase(shape.begin() + in_axes[i]);\n array in(shape, inputs[i].dtype(), nullptr, {});\n s_inputs.push_back(in);\n s_inputs.back().set_tracer(true);\n } else {\n s_inputs.push_back(inputs[i]);\n }\n }\n return {s_inputs, fun(s_inputs)};\n}\n\nstd::vector vmap_replace(\n const std::vector& inputs,\n const std::vector& s_inputs,\n const std::vector& s_outputs,\n const std::vector& in_axes,\n const std::vector& out_axes) {\n if (out_axes.size() != s_outputs.size()) {\n throw std::invalid_argument(\n \"[vmap] The number of out axes must match the number of outputs.\");\n }\n\n std::unordered_map> tmap;\n std::unordered_set needs_vmap;\n std::unordered_set cache;\n for (int i = 0; i < s_inputs.size(); ++i) {\n auto in = s_inputs[i];\n if (in_axes[i] != -1) {\n tmap.insert({in.id(), {inputs[i], in_axes[i]}});\n needs_vmap.insert(in.id());\n 
in.set_tracer(false);\n }\n cache.insert(in.id());\n }\n\n // Topologically sort the graph\n std::vector tape;\n\n std::function recurse;\n\n recurse = [&](const array& a) {\n auto id = a.id();\n if (cache.find(id) != cache.end()) {\n return;\n }\n cache.insert(id);\n for (auto& s : a.siblings()) {\n cache.insert(s.id());\n }\n\n // Recurse on inputs\n for (auto& input : a.inputs()) {\n recurse(input);\n }\n // If any input needs a vmap, then the outputs also need\n // a vmap\n for (auto& input : a.inputs()) {\n if (needs_vmap.find(input.id()) != needs_vmap.end()) {\n tape.push_back(a);\n tape.back().set_tracer(false);\n needs_vmap.insert(a.id());\n for (auto s : a.siblings()) {\n needs_vmap.insert(s.id());\n s.set_tracer(false);\n }\n break;\n }\n }\n };\n\n for (auto& out : s_outputs) {\n if (out.has_primitive()) {\n recurse(out);\n }\n }\n\n // Transform each primitive in the graph with\n // its vmap implementation\n for (auto& a : tape) {\n std::vector v_inputs;\n std::vector v_axes;\n for (auto& in : a.inputs()) {\n auto map_it = tmap.find(in.id());\n if (map_it != tmap.end()) {\n v_inputs.push_back(map_it->second.first);\n v_axes.push_back(map_it->second.second);\n } else {\n v_inputs.push_back(in);\n v_axes.push_back(-1);\n }\n }\n\n auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);\n\n // For each primitive's outputs add its id, the vout id and the vax\n auto outputs = a.outputs();\n for (int i = 0; i < v_outputs.size(); ++i) {\n tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});\n }\n }\n\n // Populate the outputs and make sure all the output axes are\n // in the right place\n std::vector outputs;\n for (int i = 0; i < s_outputs.size(); ++i) {\n if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {\n auto& [out, vdim] = map_it->second;\n if (vdim != out_axes[i]) {\n if (out_axes[i] >= out.ndim()) {\n std::ostringstream msg;\n msg << \"[vmap] Axis \" << out_axes[i] << \" invalid for output with \"\n << out.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n out = moveaxis(out, vdim, out_axes[i]);\n }\n outputs.push_back(out);\n } else {\n outputs.push_back(s_outputs[i]);\n }\n }\n return outputs;\n}\n\n} // namespace detail\n\nstd::function(const std::vector&)> vmap(\n const std::function(const std::vector&)>& fun,\n const std::vector& in_axes /* = {} */,\n const std::vector& out_axes /* = {} */) {\n auto infer_axes = [](auto axes) {\n return !axes.empty() &&\n std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });\n };\n if (infer_axes(in_axes) != infer_axes(out_axes)) {\n throw std::invalid_argument(\n \"[vmap] Input (or output) axes must be \"\n \"specified if output (or input) axes are.\");\n }\n auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](\n const std::vector& inputs) mutable {\n if (in_axes.size() == 0) {\n in_axes.resize(inputs.size(), 0);\n }\n\n auto [trace_inputs, trace_outputs] =\n detail::vmap_trace(fun, inputs, in_axes);\n\n if (out_axes.size() == 0) {\n out_axes.resize(trace_outputs.size(), 0);\n }\n\n return detail::vmap_replace(\n inputs, trace_inputs, trace_outputs, in_axes, out_axes);\n };\n\n return vfun;\n}\n\nstd::function vmap(\n const std::function& fun,\n int in_axis_a /* = 0 */,\n int in_axis_b /* = 0 */,\n int out_axis /* = 0 */) {\n auto vfun = vmap(\n [in_axis_a, in_axis_b, out_axis, fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0], inputs[1])};\n },\n {in_axis_a, in_axis_b},\n {out_axis});\n return [vfun](const array& a, const array& 
b) { return vfun({a, b})[0]; };\n}\n\nstd::function vmap(\n const std::function& fun,\n int in_axis /* = 0 */,\n int out_axis /* = 0 */) {\n auto vfun = vmap(\n [in_axis, out_axis, fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0])};\n },\n {in_axis},\n {out_axis});\n return [vfun](const array& a) { return vfun({a})[0]; };\n}\n\nstd::function(const std::vector&)> custom_vjp(\n std::function(const std::vector&)> fun,\n std::function(\n const std::vector&,\n const std::vector&,\n const std::vector&)> fun_vjp) {\n return [fun = std::move(fun),\n fun_vjp = std::move(fun_vjp)](const std::vector& args) {\n // Compute the outputs\n auto outputs = fun(args);\n for (auto& out : outputs) {\n out = stop_gradient(out);\n }\n\n // Prepare the inputs to the primitive\n // We also add the outputs to the primitive so that it can \"run\" the forward\n // pass.\n std::vector inputs = args;\n inputs.insert(inputs.end(), outputs.begin(), outputs.end());\n\n // Compute the stream. Maybe do it in a smarter way at some point in the\n // future.\n Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()\n : default_stream(default_device());\n\n // Make the output info\n std::vector> shapes;\n std::vector dtypes;\n for (const auto& out : outputs) {\n shapes.emplace_back(out.shape());\n dtypes.emplace_back(out.dtype());\n }\n\n return array::make_arrays(\n std::move(shapes),\n dtypes,\n std::make_shared(to_stream(s), fun_vjp),\n inputs);\n };\n}\n\nstd::function(const std::vector&)> checkpoint(\n std::function(const std::vector&)> fun) {\n auto vjp_fun = [fun](\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& outputs) -> std::vector {\n auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);\n return vjps;\n };\n\n return custom_vjp(fun, vjp_fun);\n}\n\n} // namespace mlx::core\n\n// Path: mlx/linalg.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n#include \n#include \n\n#include \"mlx/linalg.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::linalg {\n\nDtype at_least_float(const Dtype& d) {\n return issubdtype(d, inexact) ? 
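The `checkpoint` wrapper above stops gradients on the forward outputs and recomputes the forward pass inside the custom vjp, trading compute for memory. The following standalone sketch (illustrative names, plain doubles, not MLX code) shows that recompute-in-backward idea on a tiny two-step function, assuming only that the forward pass keeps its input.

```cpp
#include <iostream>

// Checkpointed computation: the forward pass keeps only its input; the
// backward pass re-runs the forward computation to rebuild intermediates
// before applying the chain rule.
struct CheckpointedSquareChain {
  double saved_input;  // the only state retained after forward

  double forward(double x) {
    saved_input = x;
    double h = x * x;  // intermediate, deliberately not stored
    return h * h;      // f(x) = x^4
  }

  double backward(double cotangent) {
    // Recompute the intermediate from the saved input, then chain rule:
    // d(h*h)/dh = 2h and dh/dx = 2x.
    double x = saved_input;
    double h = x * x;
    return cotangent * (2.0 * h) * (2.0 * x);
  }
};

int main() {
  CheckpointedSquareChain layer;
  double y = layer.forward(3.0);    // 81
  double dx = layer.backward(1.0);  // 4 * 3^3 = 108
  std::cout << y << " " << dx << "\n";
  return 0;
}
```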
d : promote_types(d, float32);\n}\n\ninline array l2_norm(\n const array& a,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n if (issubdtype(a.dtype(), complexfloating)) {\n return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);\n } else {\n return sqrt(sum(square(a, s), axis, keepdims, s), s);\n }\n}\n\ninline array vector_norm(\n const array& a,\n const double ord,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n auto dtype = at_least_float(a.dtype());\n if (ord == 0.0) {\n return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);\n } else if (ord == 1.0) {\n return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);\n } else if (ord == 2.0) {\n return l2_norm(a, axis, keepdims, s);\n } else if (ord == std::numeric_limits::infinity()) {\n return astype(max(abs(a, s), axis, keepdims, s), dtype, s);\n } else if (ord == -std::numeric_limits::infinity()) {\n return astype(min(abs(a, s), axis, keepdims, s), dtype, s);\n } else {\n return power(\n sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),\n array(1.0 / ord, dtype),\n s);\n }\n}\n\ninline array matrix_norm(\n const array& a,\n const double ord,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n auto dtype = at_least_float(a.dtype());\n auto row_axis = axis[0];\n auto col_axis = axis[1];\n if (ord == -1.0) {\n col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);\n return astype(\n min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),\n dtype,\n s);\n } else if (ord == 1.0) {\n col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);\n return astype(\n max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),\n dtype,\n s);\n } else if (ord == std::numeric_limits::infinity()) {\n row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);\n return astype(\n max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),\n dtype,\n s);\n } else if (ord == -std::numeric_limits::infinity()) {\n row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);\n return astype(\n min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),\n dtype,\n s);\n } else if (ord == 2.0 || ord == -2.0) {\n throw std::runtime_error(\n \"[linalg::norm] Singular value norms are not implemented.\");\n } else {\n std::ostringstream msg;\n msg << \"[linalg::norm] Invalid ord \" << ord << \" for matrix norm.\";\n throw std::invalid_argument(msg.str());\n }\n}\n\ninline array matrix_norm(\n const array& a,\n const std::string& ord,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n if (ord == \"f\" || ord == \"fro\") {\n return l2_norm(a, axis, keepdims, s);\n } else if (ord == \"nuc\") {\n throw std::runtime_error(\n \"[linalg::norm] Nuclear norm not yet implemented.\");\n } else {\n std::ostringstream msg;\n msg << \"[linalg::norm] Invalid ord value '\" << ord << \"' for matrix norm.\";\n throw std::invalid_argument(msg.str());\n }\n}\n\narray norm(\n const array& a,\n const std::optional>& axis /* = std::nullopt */,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n if (!axis) {\n return norm(flatten(a, s), std::vector{0}, keepdims, s);\n }\n\n if (axis.value().size() > 2) {\n throw std::invalid_argument(\n \"[linalg::norm] Received too many axes for norm.\");\n }\n return l2_norm(a, axis.value(), keepdims, s);\n}\n\narray norm(\n const array& a,\n const double ord,\n const std::optional>& axis /* = std::nullopt */,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n 
std::vector ax;\n if (!axis) {\n ax.resize(a.ndim());\n std::iota(ax.begin(), ax.end(), 0);\n } else {\n ax = axis.value();\n }\n if (ax.size() == 1) {\n return vector_norm(a, ord, ax, keepdims, s);\n } else if (ax.size() == 2) {\n return matrix_norm(a, ord, ax, keepdims, s);\n } else {\n throw std::invalid_argument(\n \"[linalg::norm] Received too many axes for norm.\");\n }\n}\n\narray norm(\n const array& a,\n const std::string& ord,\n const std::optional>& axis /* = std::nullopt */,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n std::vector ax;\n if (!axis) {\n ax.resize(a.ndim());\n std::iota(ax.begin(), ax.end(), 0);\n } else {\n ax = axis.value();\n }\n if (ax.size() != 2) {\n std::ostringstream msg;\n msg << \"[linalg::norm] Norm '\" << ord << \"' only supported for matrices,\"\n << \" but received \" << ax.size() << \" axis/axes.\";\n throw std::invalid_argument(msg.str());\n }\n return matrix_norm(a, ord, ax, keepdims, s);\n}\n\nstd::pair qr(const array& a, StreamOrDevice s /* = {} */) {\n if (a.dtype() != float32) {\n std::ostringstream msg;\n msg << \"[linalg::qr] Arrays must type float32. Received array \"\n << \"with type \" << a.dtype() << \".\";\n throw std::invalid_argument(msg.str());\n }\n if (a.ndim() < 2) {\n std::ostringstream msg;\n msg << \"[linalg::qr] Arrays must have >= 2 dimensions. Received array \"\n \"with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n if (a.shape(-1) != a.shape(-2)) {\n throw std::invalid_argument(\n \"[linalg::qr] Support for non-square matrices NYI.\");\n }\n\n auto out = array::make_arrays(\n {a.shape(), a.shape()},\n {a.dtype(), a.dtype()},\n std::make_shared(to_stream(s)),\n {astype(a, a.dtype(), s)});\n return std::make_pair(out[0], out[1]);\n}\n\nstd::vector svd(const array& a, StreamOrDevice s /* = {} */) {\n if (a.dtype() != float32) {\n std::ostringstream msg;\n msg << \"[linalg::svd] Input array must have type float32. Received array \"\n << \"with type \" << a.dtype() << \".\";\n throw std::invalid_argument(msg.str());\n }\n if (a.ndim() < 2) {\n std::ostringstream msg;\n msg << \"[linalg::svd] Input array must have >= 2 dimensions. Received array \"\n \"with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n const auto m = a.shape(-2);\n const auto n = a.shape(-1);\n const auto rank = a.ndim();\n\n std::vector u_shape = a.shape();\n u_shape[rank - 2] = m;\n u_shape[rank - 1] = m;\n\n std::vector s_shape = a.shape();\n s_shape.pop_back();\n s_shape[rank - 2] = std::min(m, n);\n\n std::vector vt_shape = a.shape();\n vt_shape[rank - 2] = n;\n vt_shape[rank - 1] = n;\n\n return array::make_arrays(\n {u_shape, s_shape, vt_shape},\n {a.dtype(), a.dtype(), a.dtype()},\n std::make_shared(to_stream(s)),\n {a});\n}\n\narray inv(const array& a, StreamOrDevice s /* = {} */) {\n if (a.dtype() != float32) {\n std::ostringstream msg;\n msg << \"[linalg::inv] Arrays must type float32. Received array \"\n << \"with type \" << a.dtype() << \".\";\n throw std::invalid_argument(msg.str());\n }\n if (a.ndim() < 2) {\n std::ostringstream msg;\n msg << \"[linalg::inv] Arrays must have >= 2 dimensions. 
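As the `norm` overloads above show, a single reduction axis dispatches to `vector_norm` and two axes to `matrix_norm`. Below is a minimal plain-C++ mirror of that dispatch with an illustrative p-norm and the ord == 1 maximum-absolute-column-sum matrix norm; the helper names and values are assumptions for illustration only.

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Vector p-norm: (sum |x|^p)^(1/p), the general branch of vector_norm.
double vector_p_norm(const std::vector<double>& v, double ord) {
  double acc = 0.0;
  for (double x : v) acc += std::pow(std::abs(x), ord);
  return std::pow(acc, 1.0 / ord);
}

// ord == 1 matrix norm: maximum absolute column sum, matching the
// sum-over-rows / max-over-columns pattern in matrix_norm.
double matrix_one_norm(const std::vector<std::vector<double>>& m) {
  size_t cols = m.empty() ? 0 : m[0].size();
  double best = 0.0;
  for (size_t j = 0; j < cols; ++j) {
    double col_sum = 0.0;
    for (const auto& row : m) col_sum += std::abs(row[j]);
    best = std::max(best, col_sum);
  }
  return best;
}

int main() {
  std::cout << vector_p_norm({3.0, 4.0}, 2.0) << "\n";      // 5
  std::cout << matrix_one_norm({{1, -2}, {3, 4}}) << "\n";  // 6
  return 0;
}
```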
Received array \"\n \"with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n if (a.shape(-1) != a.shape(-2)) {\n throw std::invalid_argument(\n \"[linalg::inv] Inverses are only defined for square matrices.\");\n }\n\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\n\n} // namespace mlx::core::linalg\n\n// Path: mlx/compile.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n#include \n#include \n#include \n\n#include \"mlx/allocator.h\"\n#include \"mlx/compile.h\"\n#include \"mlx/compile_impl.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n\nnamespace mlx::core {\n\nconstexpr int max_compile_depth = 11;\n\nbool is_unary(const Primitive& p) {\n return (\n typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||\n typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||\n typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||\n typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||\n typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||\n typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||\n typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||\n typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||\n typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||\n typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||\n typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||\n typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||\n typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||\n typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||\n typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));\n}\n\nbool is_binary(const Primitive& p) {\n return (\n typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||\n typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||\n typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||\n typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||\n typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||\n typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||\n typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||\n typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||\n typeid(p) == typeid(Subtract));\n}\n\nbool is_ternary(const Primitive& p) {\n return typeid(p) == typeid(Select);\n}\n\nbool is_broadcast(const Primitive& p) {\n return typeid(p) == typeid(Broadcast);\n}\n\nbool is_noop(const Primitive& p) {\n return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);\n}\n\nbool is_reduction(const Primitive& p) {\n return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);\n}\n\nbool is_fusable(const Primitive& p) {\n return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||\n is_noop(p);\n}\n\nbool allows_shapeless(const Primitive& p) {\n return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||\n is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||\n typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||\n typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||\n typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);\n}\n\nCompiled::Compiled(\n Stream stream,\n std::vector inputs,\n std::vector outputs,\n std::vector tape,\n std::unordered_set constant_ids)\n : Primitive(stream),\n inputs_(std::move(inputs)),\n outputs_(std::move(outputs)),\n tape_(std::move(tape)),\n 
constant_ids_(std::move(constant_ids)) {}\n\nstd::vector Compiled::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n throw std::runtime_error(\"[Compiled] Cannot vjp primitive.\");\n}\n\nstd::vector Compiled::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n throw std::runtime_error(\"[Compiled] Cannot jvp primitive.\");\n}\n\nstd::pair, std::vector> Compiled::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n throw std::runtime_error(\"[Compiled] Cannot vmap primitive.\");\n}\n\nbool Compiled::is_equivalent(const Primitive& other) const {\n const Compiled& a_other = static_cast(other);\n return std::equal(\n tape_.begin(),\n tape_.end(),\n a_other.tape_.begin(),\n a_other.tape_.end(),\n [](const array& a1, const array& a2) {\n auto& p1 = a1.primitive();\n auto& p2 = a2.primitive();\n return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);\n });\n}\n\nvoid Compiled::print(std::ostream& os) {\n os << \"Compiled\";\n for (auto& a : tape_) {\n a.primitive().print(os);\n }\n}\n\nstd::vector> Compiled::output_shapes(\n const std::vector& inputs) {\n size_t nd = 0;\n for (auto& in : inputs) {\n nd = std::max(nd, in.ndim());\n }\n std::vector out_shape(nd, 0);\n for (auto& in : inputs) {\n auto dd = nd - in.ndim();\n for (auto i = dd; i < nd; ++i) {\n out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);\n }\n }\n // All outputs have the same shape\n return std::vector>(outputs_.size(), out_shape);\n}\n\nnamespace detail {\n\nCompileMode& compile_mode() {\n auto get_val = []() {\n if (const char* buff_str = std::getenv(\"MLX_DISABLE_COMPILE\")) {\n return CompileMode::disabled;\n } else {\n return CompileMode::enabled;\n }\n };\n static CompileMode compile_mode_ = get_val();\n return compile_mode_;\n}\n\nusing ParentsMap =\n std::unordered_map>>;\n\n// Helper like below but only merges the two provided arrays. If the src has\n// siblings then these won't be merged to the dst.\nvoid merge_one(array& dst, array& src, ParentsMap& parents_map) {\n auto src_parents = parents_map.find(src.id());\n if (src_parents == parents_map.end()) {\n return;\n }\n auto& pairs = parents_map[dst.id()];\n for (auto& parent : src_parents->second) {\n parent.first.inputs()[parent.second] = dst;\n pairs.push_back(parent);\n }\n // Remove the source from the map to avoid fusing with it again\n parents_map.erase(src_parents);\n};\n\n// Helper that merges two arrays in the graph by setting the parents of the\n// source to point to the destination. 
The arrays are assumed to be coming from\n// equivalent primitives so their siblings are merged as well.\nvoid merge(array& dst, array& src, ParentsMap& parents_map) {\n // Canonicalize the order of the primitives outputs\n auto sources = src.outputs();\n auto dests = dst.outputs();\n // For each src parent, point it to the corresponding dst\n for (int i = 0; i < sources.size(); ++i) {\n merge_one(dests[i], sources[i], parents_map);\n }\n};\n\ntemplate \nstd::uintptr_t get_function_address(const std::function& fun) {\n using FunType = T (*)(U...);\n const FunType* fun_ptr = fun.template target();\n if (fun_ptr == nullptr) {\n throw std::invalid_argument(\n \"[compile] Cannot compile a non-addressable function.\");\n }\n return reinterpret_cast(*fun_ptr);\n}\n\nclass CompilerCache {\n public:\n struct CacheEntry {\n std::vector inputs;\n std::vector outputs;\n std::vector tape;\n bool empty{true};\n std::vector constants;\n };\n\n // Returns a reference to a CacheEntry which can be updated\n // by the caller to avoid copying large tapes / inputs / outputs\n CacheEntry& find(\n std::uintptr_t fun_id,\n const std::vector& inputs,\n bool shapeless,\n const std::vector& constants) {\n // Find the cache entries for |fun_id|.\n std::vector& entries = cache_[fun_id];\n // Compare if 2 arrays have same shape and dtype.\n auto has_same_shape_and_dtype = [shapeless](\n const std::vector& in1,\n const std::vector& in2) {\n if (in1.size() != in2.size()) {\n return false;\n }\n for (size_t i = 0; i < in1.size(); ++i) {\n if (in1[i].ndim() != in2[i].ndim()) {\n return false;\n }\n if (!shapeless && in1[i].shape() != in2[i].shape()) {\n return false;\n }\n if (in1[i].dtype() != in2[i].dtype()) {\n return false;\n }\n }\n return true;\n };\n // Loop over entries and check inputs match i.e. shapes and types must be\n // equal. Note this could get really slow if one compiles the same\n // function with many different shapes. 
May want to store entries in a\n // more easily searchable structure.\n for (CacheEntry& entry : entries) {\n // Check the inputs match and return if so\n if (has_same_shape_and_dtype(inputs, entry.inputs) &&\n constants == entry.constants) {\n return entry;\n }\n }\n // Otherwise append a new cache entry\n entries.push_back(CacheEntry{});\n return entries.back();\n };\n\n void erase(std::uintptr_t fun_id) {\n cache_.erase(fun_id);\n }\n\n private:\n CompilerCache() {\n // Make sure the allocator is fully\n // initialized before the compiler cache\n allocator::allocator();\n }\n\n friend CompilerCache& compiler_cache();\n std::unordered_map> cache_;\n};\n\nCompilerCache& compiler_cache() {\n static CompilerCache compiler_cache_;\n return compiler_cache_;\n}\n\nstd::pair, std::vector> compile_trace(\n const std::function(const std::vector&)>& fun,\n const std::vector& inputs) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n // Run the function on placeholder inputs\n // to get compute graph\n std::vector tracer_inputs;\n for (int i = 0; i < inputs.size(); ++i) {\n array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});\n in.set_tracer(true);\n tracer_inputs.push_back(std::move(in));\n }\n return {tracer_inputs, fun(tracer_inputs)};\n}\n\n// Traverses the graph to build a tape and a map of array ids to their parents\nstd::pair, ParentsMap> compile_dfs(\n const std::vector& inputs,\n const std::vector& outputs) {\n std::function recurse;\n std::vector tape;\n std::unordered_set input_set;\n std::unordered_map>>\n parents_map;\n for (int i = 0; i < inputs.size(); ++i) {\n auto in = inputs[i];\n input_set.insert(in.id());\n }\n\n // DFS the graph to build the tape, and log parents and scalars\n std::unordered_set cache;\n recurse = [&](const array& a) {\n auto id = a.id();\n if (cache.find(id) != cache.end()) {\n return;\n }\n for (int i = 0; i < a.inputs().size(); i++) {\n auto& in = a.inputs()[i];\n parents_map[in.id()].push_back({a, i});\n for (auto& s : a.siblings()) {\n parents_map[in.id()].push_back({s, i});\n }\n // Don't recurse on inputs (but add them to the tape for the purpose\n // of future optimizations)\n if (input_set.find(a.id()) == input_set.end()) {\n recurse(in);\n }\n }\n cache.insert(id);\n for (auto& s : a.siblings()) {\n cache.insert(s.id());\n }\n tape.push_back(a);\n };\n for (auto& a : outputs) {\n recurse(a);\n }\n return {tape, parents_map};\n}\n\n// Simplify the tape. 
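`compile_dfs` above, like the `recurse` lambdas in the vjp/jvp/vmap transforms, builds the tape by visiting a node's inputs before the node itself while keeping a cache of visited ids, so parents always precede children in the tape. A self-contained sketch of that pattern on a toy DAG, with an assumed `Node` type standing in for `array`:

```cpp
#include <functional>
#include <iostream>
#include <unordered_set>
#include <vector>

// Minimal stand-in for the DFS used by compile_dfs: visit inputs first,
// then append the node, so the resulting tape is a topological order.
struct Node {
  int id;
  std::vector<const Node*> inputs;
};

std::vector<const Node*> build_tape(const std::vector<const Node*>& outputs) {
  std::vector<const Node*> tape;
  std::unordered_set<int> visited;
  std::function<void(const Node*)> recurse = [&](const Node* n) {
    if (!visited.insert(n->id).second) {
      return;  // already handled via another output or parent
    }
    for (const Node* in : n->inputs) {
      recurse(in);  // inputs land on the tape before the node itself
    }
    tape.push_back(n);
  };
  for (const Node* out : outputs) {
    recurse(out);
  }
  return tape;
}

int main() {
  Node a{0, {}}, b{1, {}}, c{2, {&a, &b}}, d{3, {&c, &a}};
  for (const Node* n : build_tape({&d})) {
    std::cout << n->id << " ";  // prints: 0 1 2 3
  }
  std::cout << "\n";
  return 0;
}
```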
Note, this function modifies in-place both the tape and\n// the parents map to remove orphaned arrays\nvoid compile_simplify(\n std::vector& tape,\n ParentsMap& parents_map,\n const std::vector& outputs,\n int passes) {\n // Helpers to identify identical scalars\n std::map, array> scalars;\n auto is_scalar = [](const array& a) {\n return a.is_evaled() && a.ndim() == 0;\n };\n auto get_scalar_rep = [](const array& a) {\n uint64_t v = 0;\n int dtype;\n switch (a.dtype().size) {\n case 1:\n v = *a.data();\n break;\n case 2:\n v = *a.data();\n break;\n case 4:\n v = *a.data();\n break;\n case 8:\n v = *a.data();\n break;\n }\n return std::make_pair(v, a.dtype().val);\n };\n\n for (auto& a : tape) {\n if (is_scalar(a)) {\n scalars.insert({get_scalar_rep(a), a});\n }\n }\n\n // Depth-1 array equivalence check.\n auto array_equivalent = [](const array& a, const array& b) {\n if (!a.has_primitive() || !b.has_primitive()) {\n return false;\n }\n if (a.primitive_id() == b.primitive_id()) {\n return false;\n }\n const auto& pa = a.primitive();\n const auto& pb = b.primitive();\n if (typeid(pa) != typeid(pb)) {\n return false;\n }\n\n if (a.inputs().size() != b.inputs().size()) {\n return false;\n }\n\n for (int i = 0; i < a.inputs().size(); i++) {\n if (a.inputs()[i].id() != b.inputs()[i].id()) {\n return false;\n }\n }\n\n return pa.is_equivalent(pb);\n };\n\n // Merge scalars\n std::vector new_tape;\n for (auto& arr : tape) {\n // Check if we can merge scalars\n if (is_scalar(arr)) {\n auto scalar = scalars.find(get_scalar_rep(arr));\n if (scalar->second.id() != arr.id()) {\n merge(scalar->second, arr, parents_map);\n // Don't keep orphaned scalars in the tape\n continue;\n }\n }\n new_tape.push_back(std::move(arr));\n }\n tape = std::move(new_tape);\n\n std::unordered_map tape_order;\n for (uint32_t i = 0; i < tape.size(); ++i) {\n tape_order.insert({tape[i].id(), i});\n }\n\n std::unordered_set output_set;\n for (auto& o : outputs) {\n output_set.insert(o.id());\n }\n // Multi-pass merge only keeping non-orphaned arrays in the tape\n for (int pass = 0; pass < passes; ++pass) {\n for (auto& arr : tape) {\n // Helper to check if we can merge the parents of the\n // given array\n auto maybe_merge_parents = [&](auto& a) {\n auto parents = parents_map.find(a.id());\n if (parents != parents_map.end()) {\n auto N = parents->second.size();\n std::vector mask(N, false);\n for (int i = 0; i < N; i++) {\n if (mask[i]) {\n continue;\n }\n for (int j = i + 1; j < N; j++) {\n if (mask[j]) {\n continue;\n }\n auto src_idx = j;\n auto dst_idx = i;\n if (tape_order[parents->second[src_idx].first.id()] <\n tape_order[parents->second[dst_idx].first.id()]) {\n std::swap(src_idx, dst_idx);\n }\n auto& src = parents->second[src_idx].first;\n auto& dst = parents->second[dst_idx].first;\n if (src.id() != dst.id() && array_equivalent(src, dst) &&\n output_set.find(src.id()) == output_set.end()) {\n merge(dst, src, parents_map);\n mask[src_idx] = true;\n }\n }\n }\n // Erase orphaned parents so we don't keep fusing with them\n for (int i = N - 1; i >= 0; --i) {\n if (mask[i]) {\n parents->second.erase(parents->second.begin() + i);\n }\n }\n return false;\n } else {\n return output_set.find(a.id()) == output_set.end();\n }\n };\n bool discard = maybe_merge_parents(arr);\n for (auto& s : arr.siblings()) {\n discard &= maybe_merge_parents(s);\n }\n // If an array and its siblings have no parents, and none of them are\n // outputs, it is safe to remove it from the tape\n if (!discard) {\n new_tape.push_back(std::move(arr));\n 
}\n }\n tape = std::move(new_tape);\n }\n}\n\n// Extract sub-graphs of the graph that can be compiled\n// and replace them with a Compiled Primitive.\nvoid compile_fuse(\n std::vector& tape,\n ParentsMap& parents_map,\n const std::vector& inputs,\n std::vector& outputs) {\n // Track outputs to replace with new compiled outputs\n std::unordered_map output_map;\n for (auto& o : outputs) {\n output_map.insert({o.id(), o});\n }\n\n // Set of inputs to distinguish constants\n std::unordered_set input_ids;\n for (auto& in : inputs) {\n input_ids.insert(in.id());\n }\n\n // Go through the tape in reverse order and check for fusable sub-graphs\n std::vector new_tape;\n std::unordered_set global_cache;\n for (int i = tape.size() - 1; i >= 0; --i) {\n auto& arr = tape[i];\n\n // Already compiled\n if (global_cache.find(arr.id()) != global_cache.end()) {\n continue;\n }\n\n // Two pass recursion:\n // First pass:\n // - Collect all the primitives which we can fuse with\n // - Keeps a cache of fusable primitives which may be added out of\n // DAG order. We have to determine if all of a fused primitive's\n // outputs are also in the fused section, and this may not be the\n // case the first time we visit it.\n // Second pass:\n // - Collect inputs to the new compiled primitive\n // - Add fusable primitives to a tape in the correct order\n\n std::function&)>\n recurse;\n std::unordered_set cache;\n recurse = [&](const array& a,\n int depth,\n const Stream& s,\n const std::vector& shape) {\n if (cache.find(a.id()) != cache.end()) {\n return;\n }\n\n // Stop fusing if:\n // - Depth limit exceeded\n // - Constant input\n // - Stream mismatch\n // - Non fusable primitive\n // - Is global output but has a different shape\n if (depth >= max_compile_depth || !a.has_primitive() ||\n a.primitive().stream() != s || !is_fusable(a.primitive()) ||\n (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {\n return;\n }\n\n bool all_parents_in = true;\n if (depth > 0) {\n // Guaranteed to have a parent since nested in the\n // recursion.\n auto& parents = parents_map.at(a.id());\n for (auto& [p, idx] : parents) {\n auto in_cache = cache.find(p.id()) != cache.end();\n if (!in_cache) {\n all_parents_in = false;\n break;\n }\n }\n }\n\n // Arrays with a mix of parents outside the compilable section\n // are not fusable\n if (!all_parents_in) {\n return;\n }\n\n cache.insert({a.id()});\n\n for (auto& in : a.inputs()) {\n recurse(in, depth + 1, s, shape);\n }\n };\n\n if (arr.has_primitive()) {\n Stream s = arr.primitive().stream();\n recurse(arr, 0, s, arr.shape());\n }\n\n // Not worth fusing a single primitive\n if (cache.size() <= 1) {\n new_tape.push_back(arr);\n continue;\n }\n\n // Recurse a second time to build the tape in the right\n // order and collect the inputs\n std::unordered_set input_set;\n std::vector inputs;\n std::vector fused_tape;\n std::unordered_set tape_set;\n std::function recurse_tape;\n recurse_tape = [&](const array& a) {\n if (cache.find(a.id()) == cache.end()) {\n if (input_set.find(a.id()) == input_set.end()) {\n input_set.insert(a.id());\n inputs.push_back(a);\n }\n return;\n }\n if (tape_set.find(a.id()) != tape_set.end()) {\n return;\n }\n tape_set.insert(a.id());\n for (auto& in : a.inputs()) {\n recurse_tape(in);\n }\n fused_tape.push_back(a);\n };\n recurse_tape(arr);\n\n std::vector old_outputs;\n // Add to global cache and add any global outputs to outputs\n // of new primitive\n for (int j = 0; j < fused_tape.size() - 1; ++j) {\n auto& f = fused_tape[j];\n if 
(output_map.find(f.id()) != output_map.end()) {\n old_outputs.push_back(f);\n // Parents are now siblings, update the parent map\n auto& pairs = parents_map[f.id()];\n pairs.erase(\n std::remove_if(\n pairs.begin(),\n pairs.end(),\n [&](auto& p) {\n return cache.find(p.first.id()) != cache.end();\n }),\n pairs.end());\n } else {\n // Remove inner fused arrays parents from the parents map\n // to keep the parents map in a valid state\n parents_map.erase(f.id());\n }\n global_cache.insert({f.id()});\n }\n old_outputs.push_back(arr);\n\n std::vector> shapes;\n std::vector types;\n for (auto& o : old_outputs) {\n if (o.shape() != old_outputs.back().shape()) {\n throw std::runtime_error(\n \"[compile] Compilation failed. Tried to fuse operations with different output shapes\");\n }\n shapes.push_back(o.shape());\n types.push_back(o.dtype());\n }\n std::unordered_set constant_ids;\n for (auto& in : inputs) {\n // Scalar constant\n if (in.size() == 1 && !in.has_primitive() &&\n input_ids.find(in.id()) == input_ids.end()) {\n constant_ids.insert(in.id());\n }\n }\n auto compiled_outputs = array::make_arrays(\n std::move(shapes),\n types,\n std::make_shared(\n old_outputs.back().primitive().stream(),\n inputs,\n old_outputs,\n std::move(fused_tape),\n std::move(constant_ids)),\n inputs);\n\n // One output per primitive\n new_tape.push_back(compiled_outputs.back());\n\n // Replace inputs old parents with compiled_outputs\n for (int i = 0; i < inputs.size(); ++i) {\n auto& pairs = parents_map[inputs[i].id()];\n pairs.erase(\n std::remove_if(\n pairs.begin(),\n pairs.end(),\n [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),\n pairs.end());\n for (auto& o : compiled_outputs) {\n pairs.push_back({o, i});\n }\n }\n\n // - Update outputs parents to point to compiled outputs\n // - Update any overall graph outputs to be compiled outputs\n for (int o = 0; o < old_outputs.size(); ++o) {\n merge_one(compiled_outputs[o], old_outputs[o], parents_map);\n if (auto it = output_map.find(old_outputs[o].id());\n it != output_map.end()) {\n it->second = compiled_outputs[o];\n }\n }\n }\n\n...\n", "needle_token_start": 2074, "needle_token_end": 2788, "code_context_ntokens": 16226, "output": ["```cpp\nstd::pair, std::vector> vmap_trace(\n const std::function(const std::vector&)>& fun,\n const std::vector& inputs,\n const std::vector& in_axes) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n if (in_axes.size() != inputs.size()) {\n throw std::invalid_argument(\n \"[vmap] The number of in axes must match the number of inputs.\");\n }\n\n // Some error checking and get the vmap axis size\n size_t vmap_ax_size;\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n if (inputs[i].ndim() == 0) {\n throw std::invalid_argument(\n \"[vmap] Cannot vmap an input with zero dimensions.\");\n }\n if (in_axes[i] > inputs[i].ndim()) {\n std::ostringstream msg;\n msg << \"[vmap] Axis \" << in_axes[i] << \" invalid for input with \"\n << inputs[i].ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n vmap_ax_size = inputs[i].shape(in_axes[i]);\n }\n }\n // Check that all vmapped axes have the same size\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {\n std::ostringstream msg;\n msg << \"[vmap] Inconsistent axis sizes: \" << in_ax << \" and \"\n << vmap_ax_size << \".\";\n throw std::invalid_argument(msg.str());\n }\n }\n }\n\n // Run the function on placeholder 
inputs\n // to get the original graph\n std::vector s_inputs;\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n std::vector shape = inputs[i].shape();\n shape.erase(shape.begin() + in_axes[i]);\n array in(shape, inputs[i].dtype(), nullptr, {});\n s_inputs.push_back(in);\n s_inputs.back().set_tracer(true);\n } else {\n s_inputs.push_back(inputs[i]);\n }\n }\n return {s_inputs, fun(s_inputs)};\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "in_tracing", "language": "cpp", "path": "mlx/array.cpp", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: To determine if the system is currently in a state of performing transformations on functions specifically for maintaining computational graphs during the evaluation of tracer arrays.\n2. **Input**: None; the function does not take any parameters.\n3. **Output**: Returns a boolean value; true if the system is in the process of function transformation for graph retention, otherwise false.\n4. **Procedure**: The function internally accesses a state-checking utility from a detailed namespace to ascertain whether the current operational mode involves function transformation aimed at graph preservation.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/array.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n\n#include \"mlx/array.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\n/** Return true if we are currently performing a function transformation in\n * order to keep the graph when evaluating tracer arrays. 
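The `in_tracing()` helper described above simply consults the `detail::InTracing` guard; the guard itself is declared in `mlx/transforms_impl.h`, which is not included in this context (only its static `tracing_counter` initialization appears further below). A plausible standalone sketch of such an RAII counter-based guard, written purely as an assumption about its shape rather than the actual MLX definition:

```cpp
#include <iostream>

// Assumed shape of the tracing guard: a static counter bumped for the
// lifetime of each function transformation, so in_tracing() is true
// whenever at least one transform (vjp, jvp, vmap_trace, ...) is active.
struct InTracing {
  InTracing() { ++tracing_counter; }
  ~InTracing() { --tracing_counter; }
  static bool in_tracing() { return tracing_counter > 0; }

 private:
  static int tracing_counter;
};
int InTracing::tracing_counter{0};

int main() {
  std::cout << InTracing::in_tracing() << "\n";  // 0: not tracing
  {
    InTracing guard;  // e.g. constructed at the top of a transform
    std::cout << InTracing::in_tracing() << "\n";  // 1: tracing
  }
  std::cout << InTracing::in_tracing() << "\n";  // 0 again after the guard
  return 0;
}
```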
*/\nb\nool in_tracing() {\n return detail::InTracing::in_tracing();\n}\n\n} // namespace\n\narray::array(const std::complex& val, Dtype dtype /* = complex64 */)\n : array_desc_(std::make_shared(std::vector{}, dtype)) {\n auto cval = static_cast(val);\n init(&cval);\n}\n\narray::array(\n std::vector shape,\n Dtype dtype,\n std::shared_ptr primitive,\n std::vector inputs)\n : array_desc_(std::make_shared(\n std::move(shape),\n dtype,\n std::move(primitive),\n std::move(inputs))) {}\n\nstd::vector array::make_arrays(\n std::vector> shapes,\n const std::vector& dtypes,\n const std::shared_ptr& primitive,\n const std::vector& inputs) {\n std::vector outputs;\n for (size_t i = 0; i < shapes.size(); ++i) {\n outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);\n }\n // For each node in |outputs|, its siblings are the other nodes.\n for (size_t i = 0; i < outputs.size(); ++i) {\n auto siblings = outputs;\n siblings.erase(siblings.begin() + i);\n outputs[i].set_siblings(std::move(siblings), i);\n }\n return outputs;\n}\n\narray::array(std::initializer_list data)\n : array_desc_(std::make_shared(\n std::vector{static_cast(data.size())},\n float32)) {\n init(data.begin());\n}\n\narray::array(std::initializer_list data, Dtype dtype)\n : array_desc_(std::make_shared(\n std::vector{static_cast(data.size())},\n dtype)) {\n init(data.begin());\n}\n\n/* Build an array from a shared buffer */\narray::array(\n allocator::Buffer data,\n std::vector shape,\n Dtype dtype,\n deleter_t deleter)\n : array_desc_(std::make_shared(std::move(shape), dtype)) {\n set_data(data, deleter);\n}\n\nvoid array::detach() {\n for (auto& s : array_desc_->siblings) {\n s.array_desc_->inputs.clear();\n s.array_desc_->siblings.clear();\n s.array_desc_->position = 0;\n s.array_desc_->primitive = nullptr;\n }\n array_desc_->inputs.clear();\n array_desc_->siblings.clear();\n array_desc_->position = 0;\n array_desc_->primitive = nullptr;\n}\n\nvoid array::eval() {\n if (!is_evaled()) {\n mlx::core::eval({*this});\n }\n}\n\nbool array::is_tracer() const {\n return array_desc_->is_tracer && in_tracing();\n}\n\nvoid array::set_data(allocator::Buffer buffer, deleter_t d) {\n array_desc_->data = std::make_shared(buffer, d);\n array_desc_->data_ptr = buffer.raw_ptr();\n array_desc_->data_size = size();\n array_desc_->flags.contiguous = true;\n array_desc_->flags.row_contiguous = true;\n auto max_dim = std::max_element(shape().begin(), shape().end());\n array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;\n}\n\nvoid array::set_data(\n allocator::Buffer buffer,\n size_t data_size,\n std::vector strides,\n Flags flags,\n deleter_t d) {\n array_desc_->data = std::make_shared(buffer, d);\n array_desc_->data_ptr = buffer.raw_ptr();\n array_desc_->data_size = data_size;\n array_desc_->strides = std::move(strides);\n array_desc_->flags = flags;\n}\n\nvoid array::copy_shared_buffer(\n const array& other,\n const std::vector& strides,\n Flags flags,\n size_t data_size,\n size_t offset /* = 0 */) {\n array_desc_->data = other.array_desc_->data;\n array_desc_->strides = strides;\n array_desc_->flags = flags;\n array_desc_->data_size = data_size;\n auto char_offset = sizeof(char) * itemsize() * offset;\n array_desc_->data_ptr = static_cast(\n static_cast(other.array_desc_->data_ptr) + char_offset);\n}\n\nvoid array::copy_shared_buffer(const array& other) {\n copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());\n}\n\nvoid array::move_shared_buffer(\n array other,\n const std::vector& 
strides,\n Flags flags,\n size_t data_size,\n size_t offset /* = 0 */) {\n array_desc_->data = std::move(other.array_desc_->data);\n array_desc_->strides = strides;\n array_desc_->flags = flags;\n array_desc_->data_size = data_size;\n auto char_offset = sizeof(char) * itemsize() * offset;\n array_desc_->data_ptr = static_cast(\n static_cast(other.array_desc_->data_ptr) + char_offset);\n}\n\nvoid array::move_shared_buffer(array other) {\n move_shared_buffer(other, other.strides(), other.flags(), other.data_size());\n}\n\nvoid array::ArrayDesc::init() {\n strides.resize(shape.size());\n size = 1;\n for (int i = shape.size() - 1; i >= 0; --i) {\n strides[i] = size;\n size *= shape[i];\n }\n for (auto& in : inputs) {\n is_tracer |= in.is_tracer();\n }\n}\n\narray::ArrayDesc::ArrayDesc(std::vector shape, Dtype dtype)\n : shape(std::move(shape)), dtype(dtype) {\n init();\n}\n\narray::ArrayDesc::ArrayDesc(\n std::vector shape,\n Dtype dtype,\n std::shared_ptr primitive,\n std::vector inputs)\n : shape(std::move(shape)),\n dtype(dtype),\n primitive(std::move(primitive)),\n inputs(std::move(inputs)) {\n init();\n}\n\narray::ArrayDesc::~ArrayDesc() {\n // When an array description is destroyed it will delete a bunch of arrays\n // that may also destory their corresponding descriptions and so on and so\n // forth.\n //\n // This calls recursively the destructor and can result in stack overflow, we\n // instead put them in a vector and destroy them one at a time resulting in a\n // max stack depth of 2.\n std::vector> for_deletion;\n\n for (array& a : inputs) {\n if (a.array_desc_.use_count() == 1) {\n for_deletion.push_back(std::move(a.array_desc_));\n }\n }\n\n while (!for_deletion.empty()) {\n // top is going to be deleted at the end of the block *after* the arrays\n // with inputs have been moved into the vector\n auto top = std::move(for_deletion.back());\n for_deletion.pop_back();\n\n for (array& a : top->inputs) {\n if (a.array_desc_.use_count() == 1) {\n for_deletion.push_back(std::move(a.array_desc_));\n }\n }\n }\n}\n\narray::ArrayIterator::ArrayIterator(const array& arr, int idx)\n : arr(arr), idx(idx) {\n if (arr.ndim() == 0) {\n throw std::invalid_argument(\"Cannot iterate over 0-d array.\");\n }\n}\n\narray::ArrayIterator::reference array::ArrayIterator::operator*() const {\n auto start = std::vector(arr.ndim(), 0);\n auto end = arr.shape();\n auto shape = arr.shape();\n shape.erase(shape.begin());\n start[0] = idx;\n end[0] = idx + 1;\n return reshape(slice(arr, start, end), shape);\n};\n\n} // namespace mlx::core\n\n// Path: mlx/transforms.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \"mlx/backend/metal/metal_impl.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\n/* This class is only meant to be used in eval\n * for synchronizing with the main thread. 
*/\nclass Synchronizer : public Primitive {\n public:\n explicit Synchronizer(Stream stream) : Primitive(stream){};\n\n void eval_cpu(const std::vector&, std::vector&) override {};\n void eval_gpu(const std::vector&, std::vector&) override {};\n\n DEFINE_PRINT(Synchronize);\n};\n\n// Initialize the static tracing counter from transforms_impl.h .\n//\n// This is used to implement the in_tracing() function the returns true if we\n// are currently under a function transformation.\nint detail::InTracing::tracing_counter{0};\n\nstd::shared_future async_eval(std::vector outputs) {\n static std::shared_future global_synchronizer;\n // Catch up with previous async eval if needed\n if (global_synchronizer.valid()) {\n global_synchronizer.wait();\n }\n std::queue tape;\n std::unordered_set cache;\n std::unordered_map> deps;\n\n // Make an effort to choose a good output stream\n Stream stream = default_stream(default_device());\n for (auto& o : outputs) {\n if (!o.is_evaled() && o.has_primitive()) {\n stream = o.primitive().stream();\n break;\n }\n }\n\n auto synchronizer = array(\n {}, bool_, std::make_shared(stream), std::move(outputs));\n\n {\n std::stack, int>> dfs;\n dfs.emplace(synchronizer, 0);\n while (!dfs.empty()) {\n auto& [a_ref, idx] = dfs.top();\n auto& a = a_ref.get();\n if (idx < a.inputs().size()) {\n // Add an input, and continue\n auto& in = a.inputs()[idx++];\n if (!in.is_evaled()) {\n if (!in.has_primitive()) {\n throw std::invalid_argument(\n \"[eval] Attempting to eval an array without a primitive.\");\n }\n\n // If the input is being computed on a different stream, we need to\n // manage the dependency.\n if (a.primitive().stream() != in.primitive().stream()) {\n deps.insert({in.output(0).id(), std::shared_future{}});\n }\n }\n\n if (cache.find(in.id()) == cache.end()) {\n dfs.emplace(in, 0);\n cache.insert(in.id());\n for (auto& s : in.siblings()) {\n cache.insert(s.id());\n }\n }\n continue;\n }\n\n // All inputs are done being processed, process this array\n if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {\n tape.push(a);\n }\n dfs.pop();\n }\n }\n deps.insert({synchronizer.id(), std::shared_future{}});\n\n std::vector>> ps;\n while (!tape.empty()) {\n auto arr = std::move(tape.front());\n tape.pop();\n if (arr.is_evaled()) {\n if (!arr.is_tracer() && arr.has_primitive()) {\n arr.detach();\n }\n continue;\n }\n\n auto stream = arr.primitive().stream();\n std::vector> arr_deps;\n for (auto& in : arr.inputs()) {\n if (auto it = deps.find(in.output(0).id()); it != deps.end()) {\n arr_deps.push_back(it->second);\n }\n }\n std::shared_ptr> p;\n if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {\n p = std::make_shared>();\n ps.push_back(p);\n it->second = p->get_future().share();\n }\n\n if (arr.primitive().device() == Device::gpu) {\n if (!metal::is_available()) {\n throw std::runtime_error(\"Metal GPU is not available.\");\n }\n scheduler::enqueue(\n stream, metal::make_task(arr, std::move(arr_deps), std::move(p)));\n } else {\n auto task = [arr,\n stream,\n deps = std::move(arr_deps),\n p = std::move(p)]() mutable {\n for (auto& d : deps) {\n d.wait();\n }\n scheduler::notify_new_task(stream);\n auto outputs = arr.outputs();\n arr.primitive().eval_cpu(arr.inputs(), outputs);\n if (!arr.is_tracer()) {\n arr.detach();\n }\n if (p) {\n p->set_value();\n }\n scheduler::notify_task_completion(stream);\n };\n scheduler::enqueue(stream, std::move(task));\n }\n }\n global_synchronizer = std::move(deps[synchronizer.id()]);\n return 
global_synchronizer;\n}\n\nvoid eval(std::vector outputs) {\n async_eval(std::move(outputs)).wait();\n}\n\nstd::pair, std::vector> vjp(\n const std::function(const std::vector&)>& fun,\n const std::vector& primals,\n const std::vector& cotans) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n // Make tracers from given primals\n std::vector primals_;\n for (auto& p : primals) {\n auto s = p.has_primitive() ? p.primitive().stream()\n : default_stream(default_device());\n primals_.push_back(copy(p, s)); // Does not do a deep copy\n primals_.back().set_tracer(true);\n }\n\n // Pass tracer primals through the function\n // Any variables that depend on the primals are marked as tracers\n auto outputs = fun(primals_);\n\n // Map outputs to passed cotans while ignoring the outputs\n // that have stop_gradient called on them\n int cotan_index = 0;\n std::vector> output_cotan_pairs;\n for (int i = 0; i < outputs.size(); ++i) {\n auto& out = outputs[i];\n if (out.has_primitive()) {\n if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {\n continue;\n }\n }\n if (cotan_index >= cotans.size()) {\n std::ostringstream msg;\n msg << \"[vjp] Number of outputs to compute gradients for (\"\n << outputs.size() << \") does not match number of cotangents (\"\n << cotans.size() << \").\";\n throw std::invalid_argument(msg.str());\n }\n if (out.shape() != cotans[cotan_index].shape()) {\n std::ostringstream msg;\n msg << \"[vjp] Output shape \" << out.shape()\n << \" does not match cotangent shape \" << cotans[cotan_index].shape()\n << \".\";\n if (outputs.size() == 1 && out.size() == 1) {\n msg << \" If you are using grad your function must return a scalar.\";\n }\n throw std::invalid_argument(msg.str());\n }\n output_cotan_pairs.emplace_back(i, cotan_index++);\n }\n\n // Topologically sort the compute graph, add graph nodes\n // to the tape which need a gradient.\n std::unordered_set cache;\n std::unordered_set calc_grad;\n for (auto& primal : primals_) {\n primal.set_tracer(false);\n calc_grad.insert(primal.id());\n cache.insert(primal.id());\n }\n\n std::vector tape;\n\n std::function recurse;\n recurse = [&](auto& a) {\n // Check if visited and add to cache if not\n if (auto inserted = cache.insert(a.id()); !inserted.second) {\n return;\n }\n a.set_tracer(false);\n for (auto s : a.siblings()) {\n s.set_tracer(false);\n cache.insert(s.id());\n }\n\n for (auto& input : a.inputs()) {\n recurse(input);\n }\n\n // Stop grad\n if (a.has_primitive()) {\n if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {\n return;\n }\n }\n\n // Calculate gradient if any inputs require gradient\n for (auto& input : a.inputs()) {\n if (calc_grad.find(input.id()) != calc_grad.end()) {\n tape.push_back(a);\n calc_grad.insert(a.id());\n for (auto& s : a.siblings()) {\n calc_grad.insert(s.id());\n }\n break;\n }\n }\n };\n\n for (auto out : outputs) {\n recurse(out);\n }\n\n // Run the tape backwards, computing vector-jacobian\n // products for each primitive\n std::unordered_map cotan_map;\n for (auto [out_idx, cotan_idx] : output_cotan_pairs) {\n auto& o = outputs[out_idx];\n auto s = o.has_primitive() ? 
o.primitive().stream()\n : default_stream(default_device());\n cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});\n }\n for (auto it = tape.rbegin(); it != tape.rend(); ++it) {\n auto& a = *it;\n\n // Get the arguments whose gradients are needed\n std::vector argnums;\n for (int i = 0; i < a.inputs().size(); ++i) {\n if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {\n argnums.push_back(i);\n }\n }\n\n // Check if any of the array or its siblings have cotangents,\n // if not, we can skip this primitive\n auto outputs = a.outputs();\n bool has_cotans =\n std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) {\n return cotan_map.find(s.id()) != cotan_map.end();\n });\n if (!has_cotans) {\n continue;\n }\n\n auto s = a.primitive().stream();\n std::vector cotangents{};\n for (auto& o : outputs) {\n if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) {\n cotangents.push_back(cotan_map.extract(cotan_it).mapped());\n } else {\n cotangents.push_back(zeros_like(o, s));\n }\n }\n\n auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);\n // Accumulate the vector-jacobian products for each input\n for (int i = 0; i < argnums.size(); ++i) {\n auto in_id = a.inputs()[argnums[i]].id();\n if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {\n cotan_it->second = add(cotan_it->second, vjps[i], s);\n } else {\n cotan_map.insert({in_id, vjps[i]});\n }\n }\n }\n\n std::vector vjps;\n for (auto& primal : primals_) {\n if (auto cotan_it = cotan_map.find(primal.id());\n cotan_it != cotan_map.end()) {\n vjps.push_back(cotan_it->second);\n } else {\n auto s = primal.has_primitive() ? primal.primitive().stream()\n : default_stream(default_device());\n vjps.push_back(zeros_like(primal, s));\n }\n }\n return {outputs, vjps};\n}\n\nstd::pair vjp(\n const std::function& fun,\n const array& primal,\n const array& cotan) {\n auto vec_fun = [fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0])};\n };\n auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});\n return {outputs[0], vjps[0]};\n}\n\nstd::pair, std::vector> jvp(\n const std::function(const std::vector&)>& fun,\n const std::vector& primals,\n const std::vector& tangents) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n if (primals.size() != tangents.size()) {\n throw std::invalid_argument(\n \"[jvp] Number of inputs does not match number of tangents.\");\n }\n for (int i = 0; i < primals.size(); ++i) {\n if (primals[i].shape() != tangents[i].shape()) {\n throw std::invalid_argument(\n \"[jvp] Input shape does not match shape of tangent.\");\n }\n }\n\n std::vector primals_;\n for (auto& p : primals) {\n auto s = p.has_primitive() ? 
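The backward loop above walks the tape in reverse, seeding `cotan_map` with the supplied cotangents and accumulating each primitive's vjp contribution into the adjoints of its inputs. The same bookkeeping on a fixed two-node expression, in standalone C++ with plain doubles standing in for arrays:

```cpp
#include <iostream>
#include <unordered_map>

// Tiny reverse-mode walk mirroring the cotan_map bookkeeping: seed the
// output adjoint, then visit the tape backwards, adding each local vjp
// contribution into the adjoint of the corresponding input.
int main() {
  // Forward graph (ids in evaluation order): x(0), y(1), u = x*y (2),
  // z = u + x (3), evaluated at x = 3, y = 4.
  double x = 3.0, y = 4.0;
  double u = x * y;
  double z = u + x;

  std::unordered_map<int, double> adjoint;
  adjoint[3] = 1.0;  // seed dz/dz = 1 (the incoming cotangent)

  // Node 3: z = u + x  ->  pass the adjoint through to both inputs.
  adjoint[2] += adjoint[3];
  adjoint[0] += adjoint[3];

  // Node 2: u = x * y  ->  multiply the adjoint by the other factor.
  adjoint[0] += adjoint[2] * y;
  adjoint[1] += adjoint[2] * x;

  // dz/dx = y + 1 = 5 and dz/dy = x = 3.
  std::cout << z << " " << adjoint[0] << " " << adjoint[1] << "\n";
  return 0;
}
```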
p.primitive().stream()\n : default_stream(default_device());\n primals_.push_back(copy(p, s)); // Does not do a deep copy\n primals_.back().set_tracer(true);\n }\n auto outputs = fun(primals_);\n\n // Topologically sort the compute graph, record outputs\n // in the tape if a gradient is needed.\n std::unordered_set cache;\n std::unordered_set calc_grad;\n for (auto& primal : primals_) {\n primal.set_tracer(false);\n calc_grad.insert(primal.id());\n cache.insert(primal.id());\n }\n\n std::vector tape;\n\n std::function recurse;\n recurse = [&](auto& a) {\n // Check if visited and add to cache if not\n if (auto inserted = cache.insert(a.id()); !inserted.second) {\n return;\n }\n a.set_tracer(false);\n for (auto s : a.siblings()) {\n s.set_tracer(false);\n cache.insert(s.id());\n }\n\n for (auto input : a.inputs()) {\n recurse(input);\n }\n\n // Stop grad\n if (a.has_primitive()) {\n if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {\n return;\n }\n }\n\n // Calculate gradient if any inputs require gradient\n for (auto& input : a.inputs()) {\n if (calc_grad.find(input.id()) != calc_grad.end()) {\n tape.push_back(a);\n calc_grad.insert(a.id());\n for (auto& s : a.siblings()) {\n calc_grad.insert(s.id());\n }\n break;\n }\n }\n };\n\n for (auto out : outputs) {\n recurse(out);\n }\n\n std::unordered_map tan_map;\n for (int i = 0; i < primals_.size(); ++i) {\n tan_map.insert({primals_[i].id(), tangents[i]});\n }\n\n for (auto& a : tape) {\n // Get the arguments used in the jvp\n std::vector argnums;\n std::vector tangents;\n for (int i = 0; i < a.inputs().size(); ++i) {\n if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {\n argnums.push_back(i);\n tangents.push_back(it->second);\n }\n }\n\n auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);\n auto outputs = a.outputs();\n for (int i = 0; i < jvps.size(); ++i) {\n tan_map.insert({outputs[i].id(), jvps[i]});\n }\n }\n\n std::vector jvps;\n for (auto& out : outputs) {\n if (auto it = tan_map.find(out.id()); it != tan_map.end()) {\n jvps.push_back(it->second);\n } else {\n auto s = out.has_primitive() ? out.primitive().stream()\n : default_stream(default_device());\n jvps.push_back(zeros_like(out, s));\n }\n }\n return {outputs, jvps};\n}\n\nstd::pair jvp(\n const std::function& fun,\n const array& primal,\n const array& tangent) {\n auto vec_fun = [fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0])};\n };\n auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});\n return {outputs[0], jvps[0]};\n}\n\nValueAndGradFn value_and_grad(\n const std::function(const std::vector&)>& fun,\n const std::vector& argnums) {\n if (argnums.empty()) {\n throw std::invalid_argument(\"[grad] Must specify at least one argument.\");\n }\n return [fun, argnums](const std::vector& inputs) {\n std::set args;\n for (auto& arg : argnums) {\n args.insert(arg < 0 ? 
arg + inputs.size() : arg);\n }\n if (args.size() != argnums.size()) {\n throw std::invalid_argument(\n \"[grad] Repeat argument number not allowed in grad.\");\n }\n if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {\n std::ostringstream msg;\n msg << \"[grad] Invalid argument number for function with \"\n << inputs.size() << \" inputs.\";\n throw std::invalid_argument(msg.str());\n }\n\n auto gfun = [&fun, &inputs, &args](const std::vector& ginputs) {\n std::vector inputs_(inputs);\n auto argit = args.begin();\n for (int i = 0; i < ginputs.size(); ++i) {\n inputs_[*argit] = ginputs[i];\n ++argit;\n }\n auto outputs = fun(inputs_);\n for (int i = 1; i < outputs.size(); i++) {\n auto& out = outputs[i];\n auto s = out.has_primitive() ? out.primitive().stream()\n : default_stream(default_device());\n outputs[i] = stop_gradient(out, s);\n }\n return outputs;\n };\n\n std::vector ginputs;\n for (auto arg : args) {\n ginputs.push_back(inputs[arg]);\n }\n // Set the incoming gradient to int32, vjp will cast it to the output type\n auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});\n return std::make_pair(outputs, grads);\n };\n}\n\nnamespace detail {\n\nstd::pair, std::vector> vmap_trace(\n const std::function(const std::vector&)>& fun,\n const std::vector& inputs,\n const std::vector& in_axes) {\n // Set the global tracing flag.\n detail::InTracing in_tracing;\n\n if (in_axes.size() != inputs.size()) {\n throw std::invalid_argument(\n \"[vmap] The number of in axes must match the number of inputs.\");\n }\n\n // Some error checking and get the vmap axis size\n size_t vmap_ax_size;\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n if (inputs[i].ndim() == 0) {\n throw std::invalid_argument(\n \"[vmap] Cannot vmap an input with zero dimensions.\");\n }\n if (in_axes[i] > inputs[i].ndim()) {\n std::ostringstream msg;\n msg << \"[vmap] Axis \" << in_axes[i] << \" invalid for input with \"\n << inputs[i].ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n vmap_ax_size = inputs[i].shape(in_axes[i]);\n }\n }\n // Check that all vmapped axes have the same size\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {\n std::ostringstream msg;\n msg << \"[vmap] Inconsistent axis sizes: \" << in_ax << \" and \"\n << vmap_ax_size << \".\";\n throw std::invalid_argument(msg.str());\n }\n }\n }\n\n // Run the function on placeholder inputs\n // to get the original graph\n std::vector s_inputs;\n for (int i = 0; i < inputs.size(); ++i) {\n if (in_axes[i] != -1) {\n std::vector shape = inputs[i].shape();\n shape.erase(shape.begin() + in_axes[i]);\n array in(shape, inputs[i].dtype(), nullptr, {});\n s_inputs.push_back(in);\n s_inputs.back().set_tracer(true);\n } else {\n s_inputs.push_back(inputs[i]);\n }\n }\n return {s_inputs, fun(s_inputs)};\n}\n\nstd::vector vmap_replace(\n const std::vector& inputs,\n const std::vector& s_inputs,\n const std::vector& s_outputs,\n const std::vector& in_axes,\n const std::vector& out_axes) {\n if (out_axes.size() != s_outputs.size()) {\n throw std::invalid_argument(\n \"[vmap] The number of out axes must match the number of outputs.\");\n }\n\n std::unordered_map> tmap;\n std::unordered_set needs_vmap;\n std::unordered_set cache;\n for (int i = 0; i < s_inputs.size(); ++i) {\n auto in = s_inputs[i];\n if (in_axes[i] != -1) {\n tmap.insert({in.id(), {inputs[i], in_axes[i]}});\n needs_vmap.insert(in.id());\n 
in.set_tracer(false);\n }\n cache.insert(in.id());\n }\n\n // Topologically sort the graph\n std::vector tape;\n\n std::function recurse;\n\n recurse = [&](const array& a) {\n auto id = a.id();\n if (cache.find(id) != cache.end()) {\n return;\n }\n cache.insert(id);\n for (auto& s : a.siblings()) {\n cache.insert(s.id());\n }\n\n // Recurse on inputs\n for (auto& input : a.inputs()) {\n recurse(input);\n }\n // If any input needs a vmap, then the outputs also need\n // a vmap\n for (auto& input : a.inputs()) {\n if (needs_vmap.find(input.id()) != needs_vmap.end()) {\n tape.push_back(a);\n tape.back().set_tracer(false);\n needs_vmap.insert(a.id());\n for (auto s : a.siblings()) {\n needs_vmap.insert(s.id());\n s.set_tracer(false);\n }\n break;\n }\n }\n };\n\n for (auto& out : s_outputs) {\n if (out.has_primitive()) {\n recurse(out);\n }\n }\n\n // Transform each primitive in the graph with\n // its vmap implementation\n for (auto& a : tape) {\n std::vector v_inputs;\n std::vector v_axes;\n for (auto& in : a.inputs()) {\n auto map_it = tmap.find(in.id());\n if (map_it != tmap.end()) {\n v_inputs.push_back(map_it->second.first);\n v_axes.push_back(map_it->second.second);\n } else {\n v_inputs.push_back(in);\n v_axes.push_back(-1);\n }\n }\n\n auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);\n\n // For each primitive's outputs add its id, the vout id and the vax\n auto outputs = a.outputs();\n for (int i = 0; i < v_outputs.size(); ++i) {\n tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});\n }\n }\n\n // Populate the outputs and make sure all the output axes are\n // in the right place\n std::vector outputs;\n for (int i = 0; i < s_outputs.size(); ++i) {\n if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {\n auto& [out, vdim] = map_it->second;\n if (vdim != out_axes[i]) {\n if (out_axes[i] >= out.ndim()) {\n std::ostringstream msg;\n msg << \"[vmap] Axis \" << out_axes[i] << \" invalid for output with \"\n << out.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n out = moveaxis(out, vdim, out_axes[i]);\n }\n outputs.push_back(out);\n } else {\n outputs.push_back(s_outputs[i]);\n }\n }\n return outputs;\n}\n\n} // namespace detail\n\nstd::function(const std::vector&)> vmap(\n const std::function(const std::vector&)>& fun,\n const std::vector& in_axes /* = {} */,\n const std::vector& out_axes /* = {} */) {\n auto infer_axes = [](auto axes) {\n return !axes.empty() &&\n std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });\n };\n if (infer_axes(in_axes) != infer_axes(out_axes)) {\n throw std::invalid_argument(\n \"[vmap] Input (or output) axes must be \"\n \"specified if output (or input) axes are.\");\n }\n auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](\n const std::vector& inputs) mutable {\n if (in_axes.size() == 0) {\n in_axes.resize(inputs.size(), 0);\n }\n\n auto [trace_inputs, trace_outputs] =\n detail::vmap_trace(fun, inputs, in_axes);\n\n if (out_axes.size() == 0) {\n out_axes.resize(trace_outputs.size(), 0);\n }\n\n return detail::vmap_replace(\n inputs, trace_inputs, trace_outputs, in_axes, out_axes);\n };\n\n return vfun;\n}\n\nstd::function vmap(\n const std::function& fun,\n int in_axis_a /* = 0 */,\n int in_axis_b /* = 0 */,\n int out_axis /* = 0 */) {\n auto vfun = vmap(\n [in_axis_a, in_axis_b, out_axis, fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0], inputs[1])};\n },\n {in_axis_a, in_axis_b},\n {out_axis});\n return [vfun](const array& a, const array& 
b) { return vfun({a, b})[0]; };\n}\n\nstd::function vmap(\n const std::function& fun,\n int in_axis /* = 0 */,\n int out_axis /* = 0 */) {\n auto vfun = vmap(\n [in_axis, out_axis, fun](const std::vector& inputs) {\n return std::vector{fun(inputs[0])};\n },\n {in_axis},\n {out_axis});\n return [vfun](const array& a) { return vfun({a})[0]; };\n}\n\nstd::function(const std::vector&)> custom_vjp(\n std::function(const std::vector&)> fun,\n std::function(\n const std::vector&,\n const std::vector&,\n const std::vector&)> fun_vjp) {\n return [fun = std::move(fun),\n fun_vjp = std::move(fun_vjp)](const std::vector& args) {\n // Compute the outputs\n auto outputs = fun(args);\n for (auto& out : outputs) {\n out = stop_gradient(out);\n }\n\n // Prepare the inputs to the primitive\n // We also add the outputs to the primitive so that it can \"run\" the forward\n // pass.\n std::vector inputs = args;\n inputs.insert(inputs.end(), outputs.begin(), outputs.end());\n\n // Compute the stream. Maybe do it in a smarter way at some point in the\n // future.\n Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()\n : default_stream(default_device());\n\n // Make the output info\n std::vector> shapes;\n std::vector dtypes;\n for (const auto& out : outputs) {\n shapes.emplace_back(out.shape());\n dtypes.emplace_back(out.dtype());\n }\n\n return array::make_arrays(\n std::move(shapes),\n dtypes,\n std::make_shared(to_stream(s), fun_vjp),\n inputs);\n };\n}\n\nstd::function(const std::vector&)> checkpoint(\n std::function(const std::vector&)> fun) {\n auto vjp_fun = [fun](\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& outputs) -> std::vector {\n auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);\n return vjps;\n };\n\n return custom_vjp(fun, vjp_fun);\n}\n\n} // namespace mlx::core\n\n// Path: mlx/linalg.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n#include \n#include \n\n#include \"mlx/linalg.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::linalg {\n\nDtype at_least_float(const Dtype& d) {\n return issubdtype(d, inexact) ? 
d : promote_types(d, float32);\n}\n\ninline array l2_norm(\n const array& a,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n if (issubdtype(a.dtype(), complexfloating)) {\n return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);\n } else {\n return sqrt(sum(square(a, s), axis, keepdims, s), s);\n }\n}\n\ninline array vector_norm(\n const array& a,\n const double ord,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n auto dtype = at_least_float(a.dtype());\n if (ord == 0.0) {\n return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);\n } else if (ord == 1.0) {\n return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);\n } else if (ord == 2.0) {\n return l2_norm(a, axis, keepdims, s);\n } else if (ord == std::numeric_limits::infinity()) {\n return astype(max(abs(a, s), axis, keepdims, s), dtype, s);\n } else if (ord == -std::numeric_limits::infinity()) {\n return astype(min(abs(a, s), axis, keepdims, s), dtype, s);\n } else {\n return power(\n sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),\n array(1.0 / ord, dtype),\n s);\n }\n}\n\ninline array matrix_norm(\n const array& a,\n const double ord,\n const std::vector& axis,\n bool keepdims,\n StreamOrDevice s) {\n auto dtype = at_least_float(a.dtype());\n auto row_axis = axis[0];\n auto col_axis = axis[1];\n if (ord == -1.0) {\n col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);\n return astype(\n min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),\n dtype,\n s);\n...\n", "needle_token_start": 144, "needle_token_end": 172, "code_context_ntokens": 12323, "output": ["```cpp\nbool in_tracing() {\n return detail::InTracing::in_tracing();\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "InTracing", "language": "cpp", "path": "mlx/transforms_impl.h", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this function is to manage and indicate the state of tracing operations within a computational graph, ensuring that during these operations, evaluations do not discard the graph.\n2. **Input**: This function does not take any explicit parameters as input.\n3. **Output**: There is no direct output returned by this function; instead, it affects the state of a counter used to track the tracing status.\n4. **Procedure**: Upon instantiation, the function increments a global counter that signifies an ongoing tracing operation. This increment helps other parts of the codebase recognize that a tracing process is active, thereby altering their behavior to preserve computational graphs during evaluations.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/ops.h\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \n\n#include \"mlx/array.h\"\n#include \"mlx/device.h\"\n#include \"mlx/stream.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\n/** Creation operations */\n\n/**\n * A 1D array of numbers starting at `start` (optional),\n * stopping at stop, stepping by `step` (optional). 
*/\narray arange(\n double start,\n double stop,\n double step,\n Dtype dtype,\n StreamOrDevice s = {});\narray arange(double start, double stop, double step, StreamOrDevice s = {});\narray arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});\narray arange(double start, double stop, StreamOrDevice s = {});\narray arange(double stop, Dtype dtype, StreamOrDevice s = {});\narray arange(double stop, StreamOrDevice s = {});\n\narray arange(int start, int stop, int step, StreamOrDevice s = {});\narray arange(int start, int stop, StreamOrDevice s = {});\narray arange(int stop, StreamOrDevice s = {});\n\n/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */\narray linspace(\n double start,\n double stop,\n int num = 50,\n Dtype dtype = float32,\n StreamOrDevice s = {});\n\n/** Convert an array to the given data type. */\narray astype(array a, Dtype dtype, StreamOrDevice s = {});\n\n/** Create a view of an array with the given shape and strides. */\narray as_strided(\n array a,\n std::vector shape,\n std::vector strides,\n size_t offset,\n StreamOrDevice s = {});\n\n/** Copy another array. */\narray copy(array a, StreamOrDevice s = {});\n\n/** Fill an array of the given shape with the given value(s). */\narray full(\n std::vector shape,\n array vals,\n Dtype dtype,\n StreamOrDevice s = {});\narray full(std::vector shape, array vals, StreamOrDevice s = {});\ntemplate \narray full(std::vector shape, T val, Dtype dtype, StreamOrDevice s = {}) {\n return full(std::move(shape), array(val, dtype), to_stream(s));\n}\ntemplate \narray full(std::vector shape, T val, StreamOrDevice s = {}) {\n return full(std::move(shape), array(val), to_stream(s));\n}\n\n/** Fill an array of the given shape with zeros. */\narray zeros(const std::vector& shape, Dtype dtype, StreamOrDevice s = {});\ninline array zeros(const std::vector& shape, StreamOrDevice s = {}) {\n return zeros(shape, float32, s);\n}\narray zeros_like(const array& a, StreamOrDevice s = {});\n\n/** Fill an array of the given shape with ones. */\narray ones(const std::vector& shape, Dtype dtype, StreamOrDevice s = {});\ninline array ones(const std::vector& shape, StreamOrDevice s = {}) {\n return ones(shape, float32, s);\n}\narray ones_like(const array& a, StreamOrDevice s = {});\n\n/** Fill an array of the given shape (n,m) with ones in the specified diagonal\n * k, and zeros everywhere else. */\narray eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});\ninline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {\n return eye(n, n, 0, dtype, s);\n}\ninline array eye(int n, int m, StreamOrDevice s = {}) {\n return eye(n, m, 0, float32, s);\n}\ninline array eye(int n, int m, int k, StreamOrDevice s = {}) {\n return eye(n, m, k, float32, s);\n}\ninline array eye(int n, StreamOrDevice s = {}) {\n return eye(n, n, 0, float32, s);\n}\n\n/** Create a square matrix of shape (n,n) of zeros, and ones in the major\n * diagonal. */\narray identity(int n, Dtype dtype, StreamOrDevice s = {});\ninline array identity(int n, StreamOrDevice s = {}) {\n return identity(n, float32, s);\n}\n\narray tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});\ninline array tri(int n, Dtype type, StreamOrDevice s = {}) {\n return tri(n, n, 0, type, s);\n}\n\narray tril(array x, int k = 0, StreamOrDevice s = {});\narray triu(array x, int k = 0, StreamOrDevice s = {});\n\n/** array manipulation */\n\n/** Reshape an array to the given shape. 
*/\narray reshape(const array& a, std::vector shape, StreamOrDevice s = {});\n\n/** Flatten the dimensions in the range `[start_axis, end_axis]` . */\narray flatten(\n const array& a,\n int start_axis,\n int end_axis = -1,\n StreamOrDevice s = {});\n\n/** Flatten the array to 1D. */\narray flatten(const array& a, StreamOrDevice s = {});\n\n/** Remove singleton dimensions at the given axes. */\narray squeeze(\n const array& a,\n const std::vector& axes,\n StreamOrDevice s = {});\n\n/** Remove singleton dimensions at the given axis. */\ninline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {\n return squeeze(a, std::vector{axis}, s);\n}\n\n/** Remove all singleton dimensions. */\narray squeeze(const array& a, StreamOrDevice s = {});\n\n/** Add a singleton dimension at the given axes. */\narray expand_dims(\n const array& a,\n const std::vector& axes,\n StreamOrDevice s = {});\n\n/** Add a singleton dimension at the given axis. */\narray expand_dims(const array& a, int axis, StreamOrDevice s = {});\n\n/** Slice an array. */\narray slice(\n const array& a,\n std::vector start,\n std::vector stop,\n std::vector strides,\n StreamOrDevice s = {});\n\n/** Slice an array with a stride of 1 in each dimension. */\narray slice(\n const array& a,\n const std::vector& start,\n const std::vector& stop,\n StreamOrDevice s = {});\n\n/** Update a slice from the source array */\narray slice_update(\n const array& src,\n const array& update,\n std::vector start,\n std::vector stop,\n std::vector strides,\n StreamOrDevice s = {});\n\n/** Update a slice from the source array with stride 1 in each dimension */\narray slice_update(\n const array& src,\n const array& update,\n std::vector start,\n std::vector stop,\n StreamOrDevice s = {});\n\n/** Split an array into sub-arrays along a given axis. */\nstd::vector\nsplit(const array& a, int num_splits, int axis, StreamOrDevice s = {});\nstd::vector split(const array& a, int num_splits, StreamOrDevice s = {});\nstd::vector split(\n const array& a,\n const std::vector& indices,\n int axis,\n StreamOrDevice s = {});\nstd::vector\nsplit(const array& a, const std::vector& indices, StreamOrDevice s = {});\n\n/** A vector of coordinate arrays from coordinate vectors. */\nstd::vector meshgrid(\n const std::vector& arrays,\n bool sparse = false,\n std::string indexing = \"xy\",\n StreamOrDevice s = {});\n\n/**\n * Clip (limit) the values in an array.\n */\narray clip(\n const array& a,\n const std::optional& a_min = std::nullopt,\n const std::optional& a_max = std::nullopt,\n StreamOrDevice s = {});\n\n/** Concatenate arrays along a given axis. */\narray concatenate(\n const std::vector& arrays,\n int axis,\n StreamOrDevice s = {});\narray concatenate(const std::vector& arrays, StreamOrDevice s = {});\n\n/** Stack arrays along a new axis. */\narray stack(const std::vector& arrays, int axis, StreamOrDevice s = {});\narray stack(const std::vector& arrays, StreamOrDevice s = {});\n\n/** Repeat an array along an axis. */\narray repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});\narray repeat(const array& arr, int repeats, StreamOrDevice s = {});\n\narray tile(const array& arr, std::vector reps, StreamOrDevice s = {});\n\n/** Permutes the dimensions according to the given axes. */\narray transpose(const array& a, std::vector axes, StreamOrDevice s = {});\ninline array transpose(\n const array& a,\n std::initializer_list axes,\n StreamOrDevice s = {}) {\n return transpose(a, std::vector(axes), s);\n}\n\n/** Swap two axes of an array. 
*/\narray swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});\n\n/** Move an axis of an array. */\narray moveaxis(\n const array& a,\n int source,\n int destination,\n StreamOrDevice s = {});\n\n/** Pad an array with a constant value */\narray pad(\n const array& a,\n const std::vector& axes,\n const std::vector& low_pad_size,\n const std::vector& high_pad_size,\n const array& pad_value = array(0),\n StreamOrDevice s = {});\n\n/** Pad an array with a constant value along all axes */\narray pad(\n const array& a,\n const std::vector>& pad_width,\n const array& pad_value = array(0),\n StreamOrDevice s = {});\narray pad(\n const array& a,\n const std::pair& pad_width,\n const array& pad_value = array(0),\n StreamOrDevice s = {});\narray pad(\n const array& a,\n int pad_width,\n const array& pad_value = array(0),\n StreamOrDevice s = {});\n\n/** Permutes the dimensions in reverse order. */\narray transpose(const array& a, StreamOrDevice s = {});\n\n/** Broadcast an array to a given shape. */\narray broadcast_to(\n const array& a,\n const std::vector& shape,\n StreamOrDevice s = {});\n\n/** Broadcast a vector of arrays against one another. */\nstd::vector broadcast_arrays(\n const std::vector& inputs,\n StreamOrDevice s = {});\n\n/** Comparison operations */\n\n/** Returns the bool array with (a == b) element-wise. */\narray equal(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator==(const array& a, const array& b) {\n return equal(a, b);\n}\ntemplate \narray operator==(T a, const array& b) {\n return equal(array(a), b);\n}\ntemplate \narray operator==(const array& a, T b) {\n return equal(a, array(b));\n}\n\n/** Returns the bool array with (a != b) element-wise. */\narray not_equal(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator!=(const array& a, const array& b) {\n return not_equal(a, b);\n}\ntemplate \narray operator!=(T a, const array& b) {\n return not_equal(array(a), b);\n}\ntemplate \narray operator!=(const array& a, T b) {\n return not_equal(a, array(b));\n}\n\n/** Returns bool array with (a > b) element-wise. */\narray greater(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator>(const array& a, const array& b) {\n...\n// Path: mlx/random.h\n// Copyright \u00a9 2023 Apple Inc.\n\n#pragma once\n\n#include \n#include \n\n#include \"mlx/array.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core::random {\n\nclass KeySequence {\n public:\n explicit KeySequence(uint64_t seed);\n\n void seed(uint64_t seed);\n array next();\n\n // static default\n static KeySequence& default_() {\n static KeySequence ks(get_current_time_seed());\n return ks;\n }\n\n private:\n array key_;\n static uint64_t get_current_time_seed() {\n auto now = std::chrono::system_clock::now();\n return std::chrono::duration_cast(\n now.time_since_epoch())\n .count();\n }\n};\n\n/** Get a PRNG key from a seed. */\narray key(uint64_t seed);\n\n/** Seed the default PRNG key. */\nvoid seed(uint64_t seed);\n\n/** Generate an array with type uint32 filled with random bits. */\narray bits(\n const std::vector& shape,\n int width,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\ninline array bits(\n const std::vector& shape,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return bits(shape, 4, key, s);\n}\n\n/** Split the rng key into a pair of keys. */\nstd::pair split(const array& key, StreamOrDevice s = {});\n\n/** Split the rng key into `num` keys. 
*/\narray split(const array& key, int num, StreamOrDevice s = {});\n\n/** Generate uniform random numbers between low and high. */\narray uniform(\n const array& low,\n const array& high,\n const std::vector& shape,\n Dtype dtype = float32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\ntemplate \narray uniform(\n T low,\n U high,\n const std::vector& shape,\n Dtype dtype = float32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return uniform(array(low), array(high), shape, dtype, key, to_stream(s));\n}\n\n/** Generate uniform random numbers between 0 and 1. */\narray uniform(\n const std::vector& shape,\n Dtype dtype,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\ninline array uniform(\n const std::vector& shape,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return uniform(shape, float32, key);\n}\n\n/** Generate samples from the standard normal distribution. */\narray normal(\n const std::vector& shape,\n Dtype dtype,\n const float loc,\n const float scale,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\ninline array normal(\n const std::vector& shape,\n const float loc,\n const float scale,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return normal(shape, float32, loc, scale, key, s);\n}\ninline array normal(\n const std::vector& shape,\n const Dtype dtype,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return normal(shape, dtype, 0.0, 1.0, key, s);\n}\ninline array normal(\n const std::vector& shape,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return normal(shape, float32, 0.0, 1.0, key, s);\n}\n\n/** Generate samples from a multivariate normal distribution. 
**/\narray multivariate_normal(\n const array& mean,\n const array& cov,\n const std::vector& shape,\n Dtype dtype,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\n/** Generate integer samples uniformly at random */\narray randint(\n const array& low,\n const array& high,\n const std::vector& shape,\n Dtype dtype = int32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\ntemplate \narray randint(\n T low,\n U high,\n const std::vector& shape,\n Dtype dtype = int32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return randint(array(low), array(high), shape, dtype, key, to_stream(s));\n};\n\n/** Generate binary variables with probability to be true equal to p */\narray bernoulli(\n const array& p,\n const std::vector& shape,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\narray bernoulli(\n const array& p,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\ntemplate \narray bernoulli(\n T p,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return bernoulli(array(p), key, s);\n};\n\ntemplate \narray bernoulli(\n T p,\n const std::vector& shape,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {}) {\n return bernoulli(array(p), shape, key, s);\n};\n\narray bernoulli(\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\narray truncated_normal(\n const array& lower,\n const array& upper,\n const std::vector& shape,\n Dtype dtype = float32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\narray truncated_normal(\n const array& lower,\n const array& upper,\n Dtype dtype = float32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\narray gumbel(\n const std::vector& shape,\n Dtype dtype = float32,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\narray categorical(\n const array& logits,\n int axis,\n const std::vector& shape,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\narray categorical(\n const array& logits_,\n int axis,\n int num_samples,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\narray categorical(\n const array& logits,\n int axis = -1,\n const std::optional& key = std::nullopt,\n StreamOrDevice s = {});\n\n} // namespace mlx::core::random\n\n// Path: mlx/transforms_impl.h\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#pragma once\n\nnamespace mlx::core::detail {\n\nstd::pair, std::vector> vmap_trace(\n const std::function(const std::vector&)>& fun,\n const std::vector& inputs,\n const std::vector& in_axes);\n\nstd::vector vmap_replace(\n const std::vector& inputs,\n const std::vector& s_inputs,\n const std::vector& s_outputs,\n const std::vector& in_axes,\n const std::vector& out_axes);\n\n// This is not part of the general C++ API as calling with a bad id is a bad\n// idea.\nstd::function(const std::vector&)> compile(\n const std::function(const std::vector&)>& fun,\n std::uintptr_t fun_id,\n bool shapeless = false,\n std::vector constants = {});\n\n// Erase cached compile functions\nvoid compile_erase(std::uintptr_t fun_id);\n\n// Create an InTracing object during tracing operations to signify to the rest\n// of the codebase that we are during tracing so evals should not throw away\n// the graph.\nstruct InTracing {\n I\nnTracing() {\n tracing_counter++;\n }\n ~InTracing() {\n tracing_counter--;\n }\n\n static bool in_tracing() {\n return tracing_counter > 0;\n }\n\n private:\n static int 
tracing_counter;\n};\n\n} // namespace mlx::core::detail\n\n// Path: mlx/fast_primitives.h\n// Copyright \u00a9 2024 Apple Inc.\n\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core::fast {\n\n// Custom primitive accepts a fallback function which it uses for\n// transformations. Transformations are virtual so that derived classes may\n// override the default behavior.\nclass Custom : public Primitive {\n public:\n explicit Custom(\n Stream stream,\n std::function(std::vector)> fallback)\n : Primitive(stream), fallback_(fallback){};\n\n virtual std::pair, std::vector> vmap(\n const std::vector& inputs,\n const std::vector& axes) override;\n\n virtual std::vector jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) override;\n\n virtual std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n private:\n std::function(std::vector)> fallback_;\n};\n\nclass RMSNorm : public Custom {\n public:\n RMSNorm(\n Stream stream,\n std::function(std::vector)> fallback,\n float eps)\n : Custom(stream, fallback), eps_(eps){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override {\n throw std::runtime_error(\"NYI\");\n };\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_PRINT(RMSNorm)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::function(std::vector)> fallback_;\n float eps_;\n};\n\nclass RMSNormVJP : public Custom {\n public:\n RMSNormVJP(\n Stream stream,\n std::function(std::vector)> fallback,\n float eps)\n : Custom(stream, fallback), eps_(eps){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override {\n throw std::runtime_error(\"NYI\");\n };\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n DEFINE_PRINT(RMSNormVJP)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::function(std::vector)> fallback_;\n float eps_;\n};\n\nclass LayerNorm : public Custom {\n public:\n LayerNorm(\n Stream stream,\n std::function(std::vector)> fallback,\n float eps)\n : Custom(stream, fallback), eps_(eps){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override {\n throw std::runtime_error(\"NYI\");\n };\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_PRINT(LayerNorm)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::function(std::vector)> fallback_;\n float eps_;\n};\n\nclass LayerNormVJP : public Custom {\n public:\n LayerNormVJP(\n Stream stream,\n std::function(std::vector)> fallback,\n float eps)\n : Custom(stream, fallback), eps_(eps){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override {\n throw std::runtime_error(\"NYI\");\n };\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n DEFINE_PRINT(LayerNormVJP)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::function(std::vector)> fallback_;\n float eps_;\n};\n\nclass RoPE : public Custom {\n public:\n RoPE(\n Stream stream,\n std::function(std::vector)> fallback,\n 
int dims,\n bool traditional,\n float base,\n float scale,\n int offset,\n bool forward)\n : Custom(stream, fallback),\n dims_(dims),\n traditional_(traditional),\n base_(base),\n scale_(scale),\n offset_(offset),\n forward_(forward){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override {\n throw std::runtime_error(\"NYI\");\n };\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_PRINT(RoPE)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::function(std::vector)> fallback_;\n int dims_;\n bool traditional_;\n float base_;\n float scale_;\n int offset_;\n bool forward_;\n};\n\nclass ScaledDotProductAttention : public Custom {\n public:\n explicit ScaledDotProductAttention(\n Stream stream,\n std::function(std::vector)> fallback,\n const float scale,\n const bool needs_mask)\n : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override {\n throw std::runtime_error(\"NYI\");\n };\n\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override {\n eval_gpu(inputs, outputs[0]);\n };\n\n void eval_gpu(const std::vector& inputs, array& out);\n bool is_equivalent(const Primitive& other) const override;\n\n DEFINE_PRINT(ScaledDotProductAttention)\n\n private:\n std::function(std::vector)> fallback_;\n float scale_;\n bool needs_mask_;\n};\n\n} // namespace mlx::core::fast\n\n// Path: mlx/compile_impl.h\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/device.h\"\n\nnamespace mlx::core::detail {\n\nbool compile_available_for_device(const Device& device);\n\n}\n\n// Path: mlx/allocator.h\n// Copyright \u00a9 2023 Apple Inc.\n\n#pragma once\n\n#include \n\nnamespace mlx::core::allocator {\n\n// Simple wrapper around buffer pointers\n// WARNING: Only Buffer objects constructed from and those that wrap\n// raw pointers from mlx::allocator are supported.\nclass Buffer {\n private:\n void* ptr_;\n\n public:\n Buffer(void* ptr) : ptr_(ptr){};\n\n // Get the raw data pointer from the buffer\n void* raw_ptr();\n\n // Get the buffer pointer from the buffer\n const void* ptr() const {\n return ptr_;\n };\n void* ptr() {\n return ptr_;\n };\n};\n\nBuffer malloc(size_t size);\n\nvoid free(Buffer buffer);\n\n// Wait for running tasks to finish and free up memory\n// if allocation fails\nBuffer malloc_or_wait(size_t size);\n\nclass Allocator {\n /** Abstract base class for a memory allocator. */\n public:\n virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;\n virtual void free(Buffer buffer) = 0;\n\n Allocator() = default;\n Allocator(const Allocator& other) = delete;\n Allocator(Allocator&& other) = delete;\n Allocator& operator=(const Allocator& other) = delete;\n Allocator& operator=(Allocator&& other) = delete;\n virtual ~Allocator() = default;\n};\n\nAllocator& allocator();\n\nclass CommonAllocator : public Allocator {\n /** A general CPU allocator. 
*/\n public:\n virtual Buffer malloc(size_t size, bool allow_swap = false) override;\n virtual void free(Buffer buffer) override;\n\n private:\n CommonAllocator() = default;\n friend Allocator& allocator();\n};\n\n} // namespace mlx::core::allocator\n\n// Path: mlx/io/load.h\n// Copyright \u00a9 2023 Apple Inc.\n\n#pragma once\n\n#include \n#include \n#include \n\nnamespace mlx::core {\n\nnamespace io {\n\nclass Reader {\n public:\n virtual bool is_open() const = 0;\n virtual bool good() const = 0;\n virtual size_t tell() = 0; // tellp is non-const in iostream\n virtual void seek(\n int64_t off,\n std::ios_base::seekdir way = std::ios_base::beg) = 0;\n virtual void read(char* data, size_t n) = 0;\n virtual std::string label() const = 0;\n};\n\nclass Writer {\n public:\n virtual bool is_open() const = 0;\n virtual bool good() const = 0;\n virtual size_t tell() = 0;\n virtual void seek(\n int64_t off,\n std::ios_base::seekdir way = std::ios_base::beg) = 0;\n virtual void write(const char* data, size_t n) = 0;\n virtual std::string label() const = 0;\n};\n\nclass FileReader : public Reader {\n public:\n explicit FileReader(std::ifstream is)\n : is_(std::move(is)), label_(\"stream\") {}\n explicit FileReader(std::string file_path)\n : is_(std::ifstream(file_path, std::ios::binary)),\n label_(std::move(file_path)) {}\n\n bool is_open() const override {\n return is_.is_open();\n }\n\n bool good() const override {\n return is_.good();\n }\n\n size_t tell() override {\n return is_.tellg();\n }\n\n void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)\n override {\n is_.seekg(off, way);\n }\n\n void read(char* data, size_t n) override {\n is_.read(data, n);\n }\n\n std::string label() const override {\n return \"file \" + label_;\n }\n\n private:\n std::ifstream is_;\n std::string label_;\n};\n\nclass FileWriter : public Writer {\n public:\n explicit FileWriter(std::ofstream os)\n : os_(std::move(os)), label_(\"stream\") {}\n explicit FileWriter(std::string file_path)\n : os_(std::ofstream(file_path, std::ios::binary)),\n label_(std::move(file_path)) {}\n\n bool is_open() const override {\n return os_.is_open();\n }\n\n bool good() const override {\n return os_.good();\n }\n\n size_t tell() override {\n return os_.tellp();\n }\n\n void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)\n override {\n os_.seekp(off, way);\n }\n\n void write(const char* data, size_t n) override {\n os_.write(data, n);\n }\n\n std::string label() const override {\n return \"file \" + label_;\n }\n\n private:\n std::ofstream os_;\n std::string label_;\n};\n\n} // namespace io\n} // namespace mlx::core\n\n// Path: mlx/io/gguf.h\n// Copyright \u00a9 2023-2024 Apple Inc.\n#pragma once\n\n#include \"mlx/io.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/utils.h\"\n\nextern \"C\" {\n#include \n}\n\nnamespace mlx::core {\n\nstd::vector get_shape(const gguf_tensor& tensor);\nvoid gguf_load_quantized(\n std::unordered_map& a,\n const gguf_tensor& tensor);\n\n} // namespace mlx::core\n\n// Path: mlx/types/bf16.h\n// Copyright \u00a9 2023 Apple Inc.\n\n#pragma once\n\n#include \n#include \n#include \n#include \n\n#define __MLX_BFLOAT_NAN__ 0x7FC0\n\nnamespace mlx::core {\n\nnamespace {\nunion float_bits_bf16 {\n float f;\n uint32_t u;\n};\n} // namespace\n\nstruct _MLX_BFloat16 {\n uint16_t bits_;\n\n // Default constructor\n _MLX_BFloat16() = default;\n\n // Default copy constructor\n _MLX_BFloat16(_MLX_BFloat16 const&) = default;\n\n // Appease std::vector for 
being special\n _MLX_BFloat16& operator=(std::vector::reference x) {\n bits_ = x;\n return *this;\n }\n\n _MLX_BFloat16& operator=(const float& x) {\n return (*this = _MLX_BFloat16(x));\n }\n\n // From float32\n _MLX_BFloat16(const float& x) {\n if (std::isnan(x)) {\n bits_ = __MLX_BFLOAT_NAN__;\n } else {\n // Union\n float_bits_bf16 in;\n\n // Take bits\n in.f = x;\n\n // Round to nearest even\n in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);\n\n // Take upper 16 bits\n bits_ = in.u >> 16;\n }\n }\n\n // To float32\n operator float() const {\n // Union\n float_bits_bf16 out;\n\n // Upper 16 bits are the data and lower 16 bits are 0s\n out.u = ((uint32_t)bits_) << 16;\n\n return out.f;\n }\n};\n\n#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \\\n inline otype __operator__(atype lhs, btype rhs) { \\\n return static_cast(lhs) __op__ static_cast(rhs); \\\n }\n\n#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \\\n inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \\\n return static_cast(lhs) __op__ static_cast(rhs); \\\n } \\\n inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \\\n return static_cast(lhs) __op__ static_cast(rhs); \\\n }\n\n// Operators\n#define bfloat_binop(_op_, _operator_) \\\n bfloat_binop_base( \\\n _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \\\n bfloat_binop_helper(_op_, _operator_, float, float, float); \\\n bfloat_binop_helper(_op_, _operator_, double, double, double); \\\n bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \\\n bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \\\n bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \\\n bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \\\n bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);\n\nbfloat_binop(+, operator+);\nbfloat_binop(-, operator-);\nbfloat_binop(*, operator*);\nbfloat_binop(/, operator/);\n\n#undef bfloat_binop\n\n// Comparison ops\n#define bfloat_compop(__op__, __operator__) \\\n bfloat_binop_base( \\\n __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \\\n bfloat_binop_helper(__op__, __operator__, bool, float, float); \\\n bfloat_binop_helper(__op__, __operator__, bool, double, double); \\\n bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \\\n bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \\\n bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \\\n bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);\n\nbfloat_compop(>, operator>);\nbfloat_compop(<, operator<);\nbfloat_compop(>=, operator>=);\nbfloat_compop(<=, operator<=);\nbfloat_compop(==, operator==);\nbfloat_compop(!=, operator!=);\n\n#undef bfloat_compop\n\n// Negative\ninline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {\n return -static_cast(lhs);\n}\n\n// Inplace ops\n#define bfloat_inplace_op(__op__, __operator__) \\\n inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \\\n lhs = lhs __op__ rhs; \\\n return lhs; \\\n } \\\n inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \\\n lhs = lhs __op__ rhs; \\\n return lhs; \\\n }\n\nbfloat_inplace_op(+, operator+=);\nbfloat_inplace_op(-, operator-=);\nbfloat_inplace_op(*, operator*=);\nbfloat_inplace_op(/, operator/=);\n\n#undef bfloat_inplace_op\n\n// Bitwise ops\n\n#define bfloat_bitop(__op__, __operator__) \\\n inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \\\n 
_MLX_BFloat16 out; \\\n out.bits_ = lhs.bits_ __op__ rhs.bits_; \\\n return out; \\\n } \\\n inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \\\n _MLX_BFloat16 out; \\\n out.bits_ = lhs.bits_ __op__ rhs; \\\n return out; \\\n } \\\n inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \\\n _MLX_BFloat16 out; \\\n out.bits_ = lhs __op__ rhs.bits_; \\\n return out; \\\n }\n\nbfloat_bitop(|, operator|);\nbfloat_bitop(&, operator&);\nbfloat_bitop(^, operator^);\n\n#undef bfloat_bitop\n\n#define bfloat_inplace_bitop(__op__, __operator__) \\\n inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \\\n lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \\\n return lhs; \\\n } \\\n inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \\\n lhs.bits_ = lhs.bits_ __op__ rhs; \\\n return lhs; \\\n }\n\nbfloat_inplace_bitop(|, operator|=);\nbfloat_inplace_bitop(&, operator&=);\nbfloat_inplace_bitop(^, operator^=);\n\n#undef bfloat_inplace_bitop\n\n} // namespace mlx::core\n\n// Path: mlx/types/complex.h\n// Copyright \u00a9 2023 Apple Inc.\n\n#pragma once\n#include \n#include \"mlx/types/half_types.h\"\n\nnamespace mlx::core {\n\nstruct complex64_t;\nstruct complex128_t;\n\ntemplate \ninline constexpr bool can_convert_to_complex128 =\n !std::is_same_v && std::is_convertible_v;\n\nstruct complex128_t : public std::complex {\n complex128_t(double v, double u) : std::complex(v, u){};\n complex128_t(std::complex v) : std::complex(v){};\n\n template <\n typename T,\n typename = typename std::enable_if>::type>\n complex128_t(T x) : std::complex(x){};\n\n operator float() const {\n return real();\n };\n};\n\ntemplate \ninline constexpr bool can_convert_to_complex64 =\n !std::is_same_v && std::is_convertible_v;\n\nstruct complex64_t : public std::complex {\n complex64_t(float v, float u) : std::complex(v, u){};\n complex64_t(std::complex v) : std::complex(v){};\n\n template <\n typename T,\n typename = typename std::enable_if>::type>\n complex64_t(T x) : std::complex(x){};\n\n operator float() const {\n return real();\n };\n};\n\ninline bool operator>=(const complex64_t& a, const complex64_t& b) {\n return (a.real() > b.real()) ||\n (a.real() == b.real() && a.imag() >= b.imag());\n}\n\ninline bool operator>(const complex64_t& a, const complex64_t& b) {\n return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());\n}\n\ninline complex64_t operator%(complex64_t a, complex64_t b) {\n auto real = a.real() - (b.real() * static_cast(a.real() / b.real()));\n auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag()));\n if (real != 0 && ((real < 0) != (b.real() < 0)))\n real += b.real();\n if (imag != 0 && ((imag < 0) != (b.imag() < 0)))\n imag += b.imag();\n return {real, imag};\n}\n\ninline bool operator<=(const complex64_t& a, const complex64_t& b) {\n return operator>=(b, a);\n}\n\ninline bool operator<(const complex64_t& a, const complex64_t& b) {\n return operator>(b, a);\n}\n\ninline complex64_t operator-(const complex64_t& v) {\n return -static_cast>(v);\n}\n\n// clang-format off\n#define complex_binop_helper(_op_, _operator_, itype) \\\n inline complex64_t _operator_(itype x, const complex64_t& y) { \\\n return static_cast(x) _op_ y; \\\n } \\\n inline complex64_t _operator_(const complex64_t& x, itype y) { \\\n return x _op_ static_cast(y); \\\n }\n\n#define complex_binop(_op_, _operator_) \\\n inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \\\n return x _op_ static_cast>(y); \\\n } 
\\\n inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \\\n return static_cast>(x) _op_ y; \\\n } \\\n inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \\\n return static_cast>(x) \\\n _op_ static_cast>(y); \\\n } \\\n complex_binop_helper(_op_, _operator_, bool) \\\n complex_binop_helper(_op_, _operator_, uint32_t) \\\n complex_binop_helper(_op_, _operator_, uint64_t) \\\n complex_binop_helper(_op_, _operator_, int32_t) \\\n complex_binop_helper(_op_, _operator_, int64_t) \\\n complex_binop_helper(_op_, _operator_, float16_t) \\\n complex_binop_helper(_op_, _operator_, bfloat16_t) \\\n complex_binop_helper(_op_, _operator_, float)\n// clang-format on\n\ncomplex_binop(+, operator+)\n\n} // namespace mlx::core\n\n// Path: mlx/types/fp16.h\n// Copyright \u00a9 2023 Apple Inc.\n\n#pragma once\n\n#include \n#include \n#include \n#include \n\n#define __MLX_HALF_NAN__ 0x7D00\n\nnamespace mlx::core {\n\nnamespace {\nunion float_bits_fp16 {\n float f;\n uint32_t u;\n};\n} // namespace\n\nstruct _MLX_Float16 {\n uint16_t bits_;\n\n // Default constructor\n _MLX_Float16() = default;\n\n // Default copy constructor\n _MLX_Float16(_MLX_Float16 const&) = default;\n\n // Appease std::vector for being special\n _MLX_Float16& operator=(std::vector::reference x) {\n bits_ = x;\n return *this;\n }\n\n _MLX_Float16& operator=(const float& x) {\n return (*this = _MLX_Float16(x));\n }\n\n // From float32\n _MLX_Float16(const float& x) : bits_(0) {\n // Conversion following\n // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h\n\n // Union\n float_bits_fp16 in;\n\n // Take fp32 bits\n in.f = x;\n\n // Find and take sign bit\n uint32_t x_sign_32 = in.u & uint32_t(0x80000000);\n uint16_t x_sign_16 = (x_sign_32 >> 16);\n\n if (std::isnan(x)) {\n bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);\n } else {\n // Union\n float_bits_fp16 inf_scale, zero_scale, magic_bits;\n\n // Find exponent bits and take the max supported by half\n uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);\n uint32_t max_expo_32 = uint32_t(0x38800000);\n x_expo_32 = x_expo_32 < max_expo_32 ? 
max_expo_32 : x_expo_32;\n x_expo_32 += uint32_t(15) << 23;\n\n // Handle scaling to inf as needed\n inf_scale.u = uint32_t(0x77800000);\n zero_scale.u = uint32_t(0x08800000);\n\n // Combine with magic and let addition do rounding\n magic_bits.u = x_expo_32;\n magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;\n\n // Take the lower 5 bits of the exponent\n uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));\n\n // Collect the lower 12 bits which have the mantissa\n uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);\n\n // Combine sign, exp and mantissa\n bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));\n }\n }\n\n // To float32\n operator float() const {\n // Conversion following\n // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h\n\n // Union\n float_bits_fp16 out;\n\n uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);\n uint32_t base = (bits_ << 16);\n uint32_t two_base = base + base;\n\n uint32_t denorm_max = 1u << 27;\n if (two_base < denorm_max) {\n out.u = uint32_t(126) << 23; // magic mask\n out.u |= (two_base >> 17); // Bits from fp16\n out.f -= 0.5f; // magic bias\n } else {\n out.u = uint32_t(0xE0) << 23; // exponent offset\n out.u += (two_base >> 4); // Bits from fp16\n float out_unscaled = out.f; // Store value\n out.u = uint32_t(0x7800000); // exponent scale\n out.f *= out_unscaled;\n }\n\n // Add sign\n out.u |= x_sign_32;\n\n return out.f;\n }\n};\n\n#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \\\n inline otype __operator__(atype lhs, btype rhs) { \\\n return static_cast(lhs) __op__ static_cast(rhs); \\\n }\n\n#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \\\n inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \\\n return static_cast(lhs) __op__ static_cast(rhs); \\\n } \\\n inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \\\n return static_cast(lhs) __op__ static_cast(rhs); \\\n }\n\n// Operators\n#define half_binop(__op__, __operator__) \\\n half_binop_base( \\\n __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \\\n half_binop_helper(__op__, __operator__, float, float, float); \\\n half_binop_helper(__op__, __operator__, double, double, double); \\\n half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \\\n half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \\\n half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \\\n half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \\\n half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);\n\nhalf_binop(+, operator+);\nhalf_binop(-, operator-);\nhalf_binop(*, operator*);\nhalf_binop(/, operator/);\n\n#undef half_binop\n\n// Comparison ops\n#define half_compop(__op__, __operator__) \\\n half_binop_base( \\\n __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \\\n half_binop_helper(__op__, __operator__, bool, float, float); \\\n half_binop_helper(__op__, __operator__, bool, double, double); \\\n half_binop_helper(__op__, __operator__, bool, int32_t, float); \\\n half_binop_helper(__op__, __operator__, bool, uint32_t, float); \\\n half_binop_helper(__op__, __operator__, bool, int64_t, float); \\\n half_binop_helper(__op__, __operator__, bool, uint64_t, float);\n\nhalf_compop(>, operator>);\nhalf_compop(<, operator<);\nhalf_compop(>=, operator>=);\nhalf_compop(<=, operator<=);\nhalf_compop(==, operator==);\nhalf_compop(!=, operator!=);\n\n#undef half_compop\n\n// Negative\ninline 
_MLX_Float16 operator-(_MLX_Float16 lhs) {\n return -static_cast(lhs);\n}\n\n// Inplace ops\n#define half_inplace_op(__op__, __operator__) \\\n inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \\\n lhs = lhs __op__ rhs; \\\n return lhs; \\\n } \\\n inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \\\n lhs = lhs __op__ rhs; \\\n return lhs; \\\n }\n\nhalf_inplace_op(+, operator+=);\nhalf_inplace_op(-, operator-=);\nhalf_inplace_op(*, operator*=);\nhalf_inplace_op(/, operator/=);\n\n#undef half_inplace_op\n\n// Bitwise ops\n\n#define half_bitop(__op__, __operator__) \\\n inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \\\n _MLX_Float16 out; \\\n out.bits_ = lhs.bits_ __op__ rhs.bits_; \\\n return out; \\\n } \\\n inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \\\n _MLX_Float16 out; \\\n out.bits_ = lhs.bits_ __op__ rhs; \\\n return out; \\\n } \\\n inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \\\n _MLX_Float16 out; \\\n out.bits_ = lhs __op__ rhs.bits_; \\\n return out; \\\n }\n\nhalf_bitop(|, operator|);\nhalf_bitop(&, operator&);\nhalf_bitop(^, operator^);\n\n#undef half_bitop\n\n#define half_inplace_bitop(__op__, __operator__) \\\n inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \\\n lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \\\n return lhs; \\\n } \\\n inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \\\n lhs.bits_ = lhs.bits_ __op__ rhs; \\\n return lhs; \\\n...\n", "needle_token_start": 5628, "needle_token_end": 5645, "code_context_ntokens": 16065, "output": ["```cpp\nnamespace mlx::core::detail {\n\nstruct InTracing {\n InTracing() {\n tracing_counter++;\n }\n ~InTracing() {\n tracing_counter--;\n }\n\n static bool in_tracing() {\n return tracing_counter > 0;\n }\n\n private:\n static int tracing_counter;\n};\n\n} // namespace mlx::core::detail\n```"]} +{"repo": "ml-explore/mlx", "name": "copy_inplace_dispatch", "language": "cpp", "path": "mlx/backend/common/copy.cpp", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to handle the copying of data from one array to another, ensuring that the data type of the source array is appropriately managed during the copy process to match the destination array's data type.\n2. **Input**: This function takes a source array, a destination array, a type of copy operation, and additional arguments that may be needed for specific copy operations.\n3. **Output**: There is no direct output returned by the function; however, it modifies the destination array by copying data from the source array into it according to the specified copy type.\n4. **Procedure**: The function first identifies the data type of the source array. Based on this data type, it selects the appropriate templated copy function to execute. The selected function handles the copying process, which may vary depending on the copy type (e.g., scalar, vector, general, or general-general copying). 
This ensures that the data is copied correctly while respecting the data types and the intended memory layout and transformation specified by the copy type.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/backend/common/quantized.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n\n#include \"mlx/backend/metal/copy.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate \nvoid _qmm(\n T* result,\n const T* x,\n const uint32_t* w,\n const T* scales,\n const T* biases,\n int M,\n int N,\n int K) {\n constexpr int bitmask = (1 << bits) - 1;\n constexpr int pack_factor = 32 / bits;\n constexpr int packs_in_group = group_size / pack_factor;\n const int Ng = N / group_size;\n const int Nw = N / pack_factor;\n\n for (int m = 0; m < M; m++) {\n const uint32_t* w_local = w;\n const T* scales_local = scales;\n const T* biases_local = biases;\n\n std::fill(result, result + N, 0);\n\n for (int k = 0; k < K; k++) {\n T* result_local = result;\n T xi = *x++;\n\n for (int n = 0; n < N; n += group_size) {\n T scale = *scales_local++;\n T bias = *biases_local++;\n for (int ng = 0; ng < packs_in_group; ng++) {\n uint32_t wi = *w_local++;\n\n#pragma clang loop unroll(full)\n for (int p = 0; p < pack_factor; p++) {\n (*result_local++) +=\n xi * (scale * static_cast(wi & bitmask) + bias);\n wi >>= bits;\n }\n }\n }\n }\n\n result += N;\n }\n}\n\ntemplate \nvoid _qmm_t(\n T* result,\n const T* x,\n const uint32_t* w,\n const T* scales,\n const T* biases,\n int M,\n int N,\n int K) {\n constexpr int bitmask = (1 << bits) - 1;\n constexpr int pack_factor = 32 / bits;\n constexpr int packs_in_group = group_size / pack_factor;\n const int Kg = K / group_size;\n const int Kw = K / pack_factor;\n\n for (int m = 0; m < M; m++) {\n const uint32_t* w_local = w;\n const T* scales_local = scales;\n const T* biases_local = biases;\n\n for (int n = 0; n < N; n++) {\n const T* x_local = x;\n T sum = 0;\n for (int k = 0; k < K; k += group_size) {\n T scale = *scales_local++;\n T bias = *biases_local++;\n\n for (int kw = 0; kw < packs_in_group; kw++) {\n uint32_t wi = *w_local++;\n\n#pragma clang loop unroll(full)\n for (int p = 0; p < pack_factor; p++) {\n sum += (*x_local++) * (scale * static_cast(wi & bitmask) + bias);\n wi >>= bits;\n }\n }\n }\n *result = sum;\n result++;\n }\n\n x += K;\n }\n}\n\ntemplate \nvoid _qmm_dispatch_typed(\n T* result,\n const T* x,\n const uint32_t* w,\n const T* scales,\n const T* biases,\n int M,\n int N,\n int K,\n int group_size,\n int bits,\n bool transposed_w) {\n switch (bits) {\n case 2: {\n switch (group_size) {\n case 32:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n case 64:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n case 128:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n }\n }\n case 4: {\n switch (group_size) {\n case 32:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n case 64:\n if (transposed_w) {\n return 
_qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n case 128:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n }\n }\n case 8: {\n switch (group_size) {\n case 32:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n case 64:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n case 128:\n if (transposed_w) {\n return _qmm_t(result, x, w, scales, biases, M, N, K);\n } else {\n return _qmm(result, x, w, scales, biases, M, N, K);\n }\n }\n }\n }\n std::ostringstream msg;\n msg << \"Quantization type not supported. Provided bits=\" << bits\n << \" and group_size=\" << group_size\n << \". The supported options are bits in \"\n << \"{2, 4, 8} and group_size in {64, 128}.\";\n throw std::invalid_argument(msg.str());\n}\n\nvoid _qmm_dispatch(\n array out,\n const array& x,\n const array& w,\n const array& scales,\n const array& biases,\n int bits,\n int group_size,\n bool transposed_w) {\n int K = x.shape(-1);\n int M = x.size() / K;\n int N = out.shape(-1);\n\n switch (x.dtype()) {\n case float32:\n _qmm_dispatch_typed(\n out.data(),\n x.data(),\n w.data(),\n scales.data(),\n biases.data(),\n M,\n N,\n K,\n bits,\n group_size,\n transposed_w);\n break;\n case float16:\n _qmm_dispatch_typed(\n out.data(),\n x.data(),\n w.data(),\n scales.data(),\n biases.data(),\n M,\n N,\n K,\n bits,\n group_size,\n transposed_w);\n break;\n case bfloat16:\n _qmm_dispatch_typed(\n out.data(),\n x.data(),\n w.data(),\n scales.data(),\n biases.data(),\n M,\n N,\n K,\n bits,\n group_size,\n transposed_w);\n break;\n default:\n throw std::invalid_argument(\n \"[quantized_matmul] only floating types are supported\");\n }\n}\n\n} // namespace\n\nvoid QuantizedMatmul::eval(const std::vector& inputs, array& out) {\n assert(inputs.size() == 4);\n\n auto& x_pre = inputs[0];\n auto& w_pre = inputs[1];\n auto& scales_pre = inputs[2];\n auto& biases_pre = inputs[3];\n\n auto ensure_row_contiguous = [](const array& arr) {\n...\n// Path: mlx/backend/common/copy.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/copy.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate \nvoid copy_single(const array& src, array& dst) {\n auto val = static_cast(src.data()[0]);\n auto dst_ptr = dst.data();\n for (int i = 0; i < dst.size(); ++i) {\n dst_ptr[i] = val;\n }\n}\n\ntemplate \nvoid copy_vector(const array& src, array& dst) {\n auto src_ptr = src.data();\n auto dst_ptr = dst.data();\n std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);\n}\n\ntemplate \nvoid copy_general_dim1(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n int64_t i_offset) {\n const SrcT* src_ptr = src.data();\n DstT* dst_ptr = dst.data();\n stride_t src_idx = i_offset;\n stride_t dst_idx = 0;\n for (int i = 0; i < data_shape[0]; ++i) {\n dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]);\n src_idx += i_strides[0];\n }\n}\n\ntemplate \ninline void copy_general_dim1(const array& src, array& dst) {\n return copy_general_dim1(\n src, dst, src.shape(), src.strides(), 0);\n}\n\ntemplate \nvoid copy_general_dim2(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const 
std::vector& i_strides,\n int64_t i_offset) {\n const SrcT* src_ptr = src.data();\n DstT* dst_ptr = dst.data();\n stride_t src_idx = i_offset;\n stride_t dst_idx = 0;\n for (int i = 0; i < data_shape[0]; ++i) {\n for (int j = 0; j < data_shape[1]; ++j) {\n dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]);\n src_idx += i_strides[1];\n }\n src_idx += i_strides[0] - i_strides[1] * data_shape[1];\n }\n}\n\ntemplate \ninline void copy_general_dim2(const array& src, array& dst) {\n return copy_general_dim2(\n src, dst, src.shape(), src.strides(), 0);\n}\n\ntemplate \nvoid copy_general_dim3(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n int64_t i_offset) {\n const SrcT* src_ptr = src.data();\n DstT* dst_ptr = dst.data();\n stride_t src_idx = i_offset;\n stride_t dst_idx = 0;\n for (int i = 0; i < data_shape[0]; ++i) {\n for (int j = 0; j < data_shape[1]; ++j) {\n for (int k = 0; k < data_shape[2]; ++k) {\n dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]);\n src_idx += i_strides[2];\n }\n src_idx += i_strides[1] - i_strides[2] * data_shape[2];\n }\n src_idx += i_strides[0] - i_strides[1] * data_shape[1];\n }\n}\n\ntemplate \ninline void copy_general_dim3(const array& src, array& dst) {\n return copy_general_dim3(\n src, dst, src.shape(), src.strides(), 0);\n}\n\ntemplate \nvoid copy_general_dim4(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n int64_t i_offset) {\n const SrcT* src_ptr = src.data();\n DstT* dst_ptr = dst.data();\n stride_t src_idx = i_offset;\n stride_t dst_idx = 0;\n for (int i = 0; i < data_shape[0]; ++i) {\n for (int j = 0; j < data_shape[1]; ++j) {\n for (int k = 0; k < data_shape[2]; ++k) {\n for (int ii = 0; ii < data_shape[3]; ++ii) {\n dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]);\n src_idx += i_strides[3];\n }\n src_idx += i_strides[2] - i_strides[3] * data_shape[3];\n }\n src_idx += i_strides[1] - i_strides[2] * data_shape[2];\n }\n src_idx += i_strides[0] - i_strides[1] * data_shape[1];\n }\n}\n\ntemplate \ninline void copy_general_dim4(const array& src, array& dst) {\n return copy_general_dim4(\n src, dst, src.shape(), src.strides(), 0);\n}\n\ntemplate \nvoid copy_general(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n int64_t i_offset) {\n switch (src.ndim()) {\n case 1:\n copy_general_dim1(\n src, dst, data_shape, i_strides, i_offset);\n return;\n case 2:\n copy_general_dim2(\n src, dst, data_shape, i_strides, i_offset);\n return;\n case 3:\n copy_general_dim3(\n src, dst, data_shape, i_strides, i_offset);\n return;\n case 4:\n copy_general_dim4(\n src, dst, data_shape, i_strides, i_offset);\n return;\n }\n\n auto src_ptr = src.data() + i_offset;\n auto dst_ptr = dst.data();\n for (size_t i = 0; i < dst.size(); ++i) {\n stride_t src_elem = elem_to_loc(i, data_shape, i_strides);\n dst_ptr[i] = static_cast(src_ptr[src_elem]);\n }\n}\n\ntemplate \ninline void copy_general(const array& src, array& dst) {\n return copy_general(\n src, dst, src.shape(), src.strides(), 0);\n}\n\ntemplate \ninline void copy_general(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n const std::vector& o_strides,\n int64_t i_offset,\n int64_t o_offset) {\n return copy_general(\n src, dst, data_shape, i_strides, i_offset);\n}\n\ntemplate \ninline void copy_general_general_dims(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n 
const std::vector& o_strides,\n stride_t i_offset,\n stride_t o_offset) {\n if constexpr (D > 1) {\n int axis = src.ndim() - D;\n auto stride_src = i_strides[axis];\n auto stride_dst = o_strides[axis];\n auto N = data_shape[axis];\n for (int i = 0; i < N; i++) {\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);\n i_offset += stride_src;\n o_offset += stride_dst;\n }\n } else {\n int axis = src.ndim() - 1;\n auto stride_src = i_strides[axis];\n auto stride_dst = o_strides[axis];\n auto N = data_shape[axis];\n const SrcT* src_ptr = src.data() + i_offset;\n DstT* dst_ptr = dst.data() + o_offset;\n for (int i = 0; i < N; i++) {\n *dst_ptr = static_cast(*src_ptr);\n src_ptr += stride_src;\n dst_ptr += stride_dst;\n }\n }\n}\n\ntemplate \nvoid copy_general_general(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n const std::vector& o_strides,\n stride_t i_offset,\n stride_t o_offset) {\n switch (src.ndim()) {\n case 1:\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);\n return;\n case 2:\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);\n return;\n case 3:\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);\n return;\n case 4:\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);\n return;\n case 5:\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);\n return;\n }\n\n int size = std::accumulate(\n data_shape.begin() - 5, data_shape.end(), 1, std::multiplies());\n for (int i = 0; i < src.size(); i += size) {\n stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);\n stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);\n copy_general_general_dims(\n src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);\n }\n}\n\ntemplate \ninline void copy_general_general(const array& src, array& dst) {\n return copy_general_general(\n src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);\n}\n\ntemplate \nvoid copy(const array& src, array& dst, CopyType ctype, Args&&... args) {\n switch (ctype) {\n case CopyType::Scalar:\n copy_single(src, dst);\n return;\n case CopyType::Vector:\n copy_vector(src, dst);\n return;\n case CopyType::General:\n copy_general(src, dst, std::forward(args)...);\n return;\n case CopyType::GeneralGeneral:\n copy_general_general(src, dst, std::forward(args)...);\n }\n}\n\ntemplate \nvoid copy(const array& src, array& dst, CopyType ctype, Args&&... 
args) {\n switch (dst.dtype()) {\n case bool_:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint8:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int8:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case float16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case float32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case bfloat16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case complex64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n }\n}\n\ntemplate \ni\nnline void copy_inplace_dispatch(\n const array& src,\n array& dst,\n CopyType ctype,\n Args&&... args) {\n switch (src.dtype()) {\n case bool_:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint8:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int8:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case float16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case float32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case bfloat16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case complex64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n }\n}\n\n} // namespace\n\nvoid copy_inplace(const array& src, array& dst, CopyType ctype) {\n return copy_inplace_dispatch(src, dst, ctype);\n}\n\nvoid copy(const array& src, array& dst, CopyType ctype) {\n // Allocate the output\n switch (ctype) {\n case CopyType::Vector:\n if (src.is_donatable() && src.itemsize() == dst.itemsize()) {\n dst.copy_shared_buffer(src);\n } else {\n auto size = src.data_size();\n dst.set_data(\n allocator::malloc_or_wait(size * dst.itemsize()),\n size,\n src.strides(),\n src.flags());\n }\n break;\n case CopyType::Scalar:\n case CopyType::General:\n case CopyType::GeneralGeneral:\n dst.set_data(allocator::malloc_or_wait(dst.nbytes()));\n break;\n }\n if (ctype == CopyType::GeneralGeneral) {\n ctype = CopyType::General;\n }\n copy_inplace(src, dst, ctype);\n}\n\ntemplate \nvoid copy_inplace(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n const std::vector& o_strides,\n int64_t i_offset,\n int64_t o_offset,\n CopyType ctype) {\n switch (ctype) {\n case CopyType::General:\n case CopyType::GeneralGeneral:\n return copy_inplace_dispatch(\n src,\n dst,\n ctype,\n data_shape,\n i_strides,\n o_strides,\n i_offset,\n o_offset);\n\n case CopyType::Scalar:\n case CopyType::Vector:\n return copy_inplace_dispatch(src, dst, ctype);\n }\n}\n\ntemplate <>\nvoid copy_inplace(\n const array& src,\n array& dst,\n const std::vector& data_shape,\n const std::vector& i_strides,\n const std::vector& 
o_strides,\n int64_t i_offset,\n int64_t o_offset,\n CopyType ctype) {\n switch (ctype) {\n case CopyType::General:\n case CopyType::GeneralGeneral:\n return copy_inplace_dispatch(\n src,\n dst,\n ctype,\n data_shape,\n i_strides,\n o_strides,\n i_offset,\n o_offset);\n\n case CopyType::Scalar:\n case CopyType::Vector:\n return copy_inplace_dispatch(src, dst, ctype);\n }\n}\n\n} // namespace mlx::core\n\n// Path: mlx/backend/common/compiled_nocpu.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \"mlx/backend/common/compiled.h\"\n\nnamespace mlx::core {\n\n// GPU compile is always available if the GPU is available and since we are in\n// this file CPU compile is not available so check if the device is a GPU\n// device.\nnamespace detail {\nbool compile_available_for_device(const Device& device) {\n return device == Device::gpu;\n}\n} // namespace detail\n\nvoid Compiled::eval_cpu(\n const std::vector& inputs,\n std::vector& outputs) {\n throw std::runtime_error(\n \"[Compiled::eval_cpu] CPU compialtion not supported on the platform.\");\n}\n\n} // namespace mlx::core\n\n// Path: mlx/backend/common/compiled.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/graph_utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nvoid print_constant(std::ostream& os, const array& x) {\n switch (x.dtype()) {\n case float32:\n return print_float_constant(os, x);\n case float16:\n return print_float_constant(os, x);\n case bfloat16:\n return print_float_constant(os, x);\n case complex64:\n return print_complex_constant(os, x);\n case int8:\n return print_int_constant(os, x);\n case int16:\n return print_int_constant(os, x);\n case int32:\n return print_int_constant(os, x);\n case int64:\n return print_int_constant(os, x);\n case uint8:\n return print_int_constant(os, x);\n case uint16:\n return print_int_constant(os, x);\n case uint32:\n return print_int_constant(os, x);\n case uint64:\n return print_int_constant(os, x);\n case bool_:\n os << std::boolalpha << x.item();\n return;\n default:\n throw std::runtime_error(\"Unsupported constant type\");\n }\n}\n\nstd::string get_type_string(Dtype d) {\n switch (d) {\n case float32:\n return \"float\";\n case float16:\n return \"float16_t\";\n case bfloat16:\n return \"bfloat16_t\";\n case complex64:\n return \"complex64_t\";\n case bool_:\n return \"bool\";\n case int8:\n return \"int8_t\";\n case int16:\n return \"int16_t\";\n case int32:\n return \"int32_t\";\n case int64:\n return \"int64_t\";\n case uint8:\n return \"uint8_t\";\n case uint16:\n return \"uint16_t\";\n case uint32:\n return \"uint32_t\";\n case uint64:\n return \"uint64_t\";\n default: {\n std::ostringstream msg;\n msg << \"Unsupported compilation type \" << d;\n throw std::runtime_error(msg.str());\n }\n }\n}\n\nstd::string build_lib_name(\n const std::vector& inputs,\n const std::vector& outputs,\n const std::vector& tape,\n const std::unordered_set& constant_ids) {\n NodeNamer namer;\n std::ostringstream os;\n std::ostringstream constant_hasher;\n\n // Fill the input names. This is not really necessary, I just like having A,\n // B, C, ... as the inputs.\n for (auto& x : inputs) {\n namer.get_name(x);\n }\n\n // The primitives describing the tape. 
For unary and binary primitives this\n // must be enough to describe the full computation.\n for (auto& a : tape) {\n // name and type of output\n os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();\n // computation performed\n a.primitive().print(os);\n // name of inputs to the function\n for (auto& inp : a.inputs()) {\n os << namer.get_name(inp);\n }\n }\n os << \"_\";\n\n for (auto& x : inputs) {\n if (constant_ids.find(x.id()) != constant_ids.end()) {\n os << \"C\";\n print_constant(constant_hasher, x);\n } else {\n os << (is_scalar(x) ? \"S\" : \"V\");\n }\n }\n os << \"_\";\n for (auto& x : inputs) {\n if (constant_ids.find(x.id()) != constant_ids.end()) {\n continue;\n }\n os << kindof(x.dtype()) << x.itemsize();\n }\n os << \"_\" << std::hash{}(constant_hasher.str());\n\n return os.str();\n}\n\nbool compiled_check_contiguity(\n const std::vector& inputs,\n const std::vector& shape) {\n bool contiguous = true;\n bool all_contig = true;\n bool all_row_contig = true;\n bool all_col_contig = true;\n int non_scalar_inputs = 0;\n for (const auto& x : inputs) {\n if (is_scalar(x)) {\n continue;\n }\n non_scalar_inputs++;\n bool shape_eq = x.shape() == shape;\n all_contig &= (x.flags().contiguous && shape_eq);\n all_row_contig &= (x.flags().row_contiguous && shape_eq);\n all_col_contig &= (x.flags().col_contiguous && shape_eq);\n }\n if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {\n contiguous = false;\n } else if (non_scalar_inputs == 1 && !all_contig) {\n contiguous = false;\n } else if (non_scalar_inputs == 0 && !shape.empty()) {\n contiguous = false;\n }\n return contiguous;\n}\n\nvoid compiled_allocate_outputs(\n const std::vector& inputs,\n std::vector& outputs,\n const std::vector& inputs_,\n const std::unordered_set& constant_ids_,\n bool contiguous,\n bool move_buffers /* = false */) {\n if (contiguous) {\n int o = 0;\n std::vector strides;\n size_t data_size;\n array::Flags flags;\n for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {\n auto& in = inputs[i];\n // Conditions for donation\n // - Correct size\n // - Not a scalar\n // - Donatable\n // - Not a constant\n if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&\n in.is_donatable() &&\n constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {\n if (move_buffers) {\n outputs[o++].move_shared_buffer(in);\n } else {\n outputs[o++].copy_shared_buffer(in);\n }\n }\n // Get representative input flags to properly set non-donated outputs\n if (strides.empty() && in.size() == outputs[0].size()) {\n strides = in.strides();\n flags = in.flags();\n data_size = in.data_size();\n }\n }\n for (; o < outputs.size(); ++o) {\n outputs[o].set_data(\n allocator::malloc_or_wait(data_size * outputs[o].itemsize()),\n data_size,\n strides,\n flags);\n }\n } else {\n int o = 0;\n for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {\n auto& in = inputs[i];\n // Conditions for donation\n // - Row contiguous\n // - Donatable\n // - Correct size\n // - Not a constant\n if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&\n in.is_donatable() &&\n constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {\n if (move_buffers) {\n outputs[o].move_shared_buffer(\n in, outputs[o].strides(), in.flags(), in.data_size());\n } else {\n outputs[o].copy_shared_buffer(\n in, outputs[o].strides(), in.flags(), in.data_size());\n }\n o++;\n }\n }\n for (; o < outputs.size(); ++o) {\n outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));\n }\n }\n}\n\n} // namespace 
mlx::core\n\n// Path: mlx/backend/common/select.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n\n#include \"mlx/backend/common/ternary.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate \nvoid select_op(\n const array& a,\n const array& b,\n const array& c,\n array& out,\n Op op) {\n switch (out.dtype()) {\n case bool_:\n ternary_op(a, b, c, out, op);\n break;\n case uint8:\n ternary_op(a, b, c, out, op);\n break;\n case uint16:\n ternary_op(a, b, c, out, op);\n break;\n case uint32:\n ternary_op(a, b, c, out, op);\n break;\n case uint64:\n ternary_op(a, b, c, out, op);\n break;\n case int8:\n ternary_op(a, b, c, out, op);\n break;\n case int16:\n ternary_op(a, b, c, out, op);\n break;\n case int32:\n ternary_op(a, b, c, out, op);\n break;\n case int64:\n ternary_op(a, b, c, out, op);\n break;\n case float16:\n ternary_op(a, b, c, out, op);\n break;\n case float32:\n ternary_op(a, b, c, out, op);\n break;\n case bfloat16:\n ternary_op(a, b, c, out, op);\n break;\n case complex64:\n ternary_op(a, b, c, out, op);\n break;\n }\n}\n\n} // namespace\n\nvoid Select::eval(const std::vector& inputs, array& out) {\n assert(inputs.size() == 3);\n const auto& condition = inputs[0];\n const auto& a = inputs[1];\n const auto& b = inputs[2];\n select_op(condition, a, b, out, detail::Select());\n}\n\n} // namespace mlx::core\n\n// Path: mlx/backend/common/threefry.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \"mlx/backend/common/threefry.h\"\n\nnamespace mlx::core::random {\n\nstd::pair threefry2x32_hash(\n const std::pair& key,\n std::pair count) {\n constexpr static uint32_t rotations[2][4] = {\n {13, 15, 26, 6}, {17, 29, 16, 24}};\n\n uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};\n\n count.first += ks[0];\n count.second += ks[1];\n\n for (int i = 0; i < 5; ++i) {\n for (auto r : rotations[i % 2]) {\n count.first += count.second;\n count.second = (count.second << r) | (count.second >> (32 - r));\n count.second ^= count.first;\n }\n count.first += ks[(i + 1) % 3];\n count.second += ks[(i + 2) % 3] + i + 1;\n }\n\n return count;\n}\n\n} // namespace mlx::core::random\n\n// Path: mlx/backend/common/fft.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n\n#include \"mlx/3rdparty/pocketfft.h\"\n#include \"mlx/allocator.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid FFT::eval(const std::vector& inputs, array& out) {\n auto& in = inputs[0];\n std::vector strides_in(\n in.strides().begin(), in.strides().end());\n for (auto& s : strides_in) {\n s *= in.itemsize();\n }\n std::vector strides_out(\n out.strides().begin(), out.strides().end());\n for (auto& s : strides_out) {\n s *= out.itemsize();\n }\n\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n\n std::vector shape;\n if (out.dtype() == float32) {\n shape.insert(shape.end(), out.shape().begin(), out.shape().end());\n } else {\n shape.insert(shape.end(), in.shape().begin(), in.shape().end());\n }\n\n float scale = 1.0f;\n if (inverse_) {\n size_t nelem = std::accumulate(\n axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {\n return x * shape[y];\n });\n scale /= nelem;\n }\n if (in.dtype() == complex64 && out.dtype() == complex64) {\n auto in_ptr =\n reinterpret_cast*>(in.data());\n auto out_ptr =\n reinterpret_cast*>(out.data());\n pocketfft::c2c(\n shape,\n strides_in,\n strides_out,\n axes_,\n !inverse_,\n in_ptr,\n out_ptr,\n scale);\n } else if (in.dtype() == float32 && out.dtype() == complex64) {\n auto in_ptr = in.data();\n 
auto out_ptr =\n reinterpret_cast*>(out.data());\n pocketfft::r2c(\n shape,\n strides_in,\n strides_out,\n axes_,\n !inverse_,\n in_ptr,\n out_ptr,\n scale);\n } else if (in.dtype() == complex64 && out.dtype() == float32) {\n auto in_ptr =\n reinterpret_cast*>(in.data());\n auto out_ptr = out.data();\n pocketfft::c2r(\n shape,\n strides_in,\n strides_out,\n axes_,\n !inverse_,\n in_ptr,\n out_ptr,\n scale);\n } else {\n throw std::runtime_error(\n \"[FFT] Received unexpected input and output type combination.\");\n }\n}\n\n} // namespace mlx::core\n\n// Path: mlx/backend/common/sort.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n#include \n#include \n#include \n\n#include \"mlx/backend/common/copy.h\"\n#include \"mlx/backend/common/utils.h\"\n\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate \nstruct StridedIterator {\n using iterator_category = std::random_access_iterator_tag;\n using difference_type = IdxT;\n using value_type = T;\n using reference = value_type&;\n using pointer = value_type*;\n\n // Constructors\n StridedIterator() = default;\n\n explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)\n : ptr_(ptr + offset * stride), stride_(stride) {}\n\n explicit StridedIterator(array& arr, int axis, difference_type offset = 0)\n : StridedIterator(arr.data(), arr.strides()[axis], offset) {}\n\n // Accessors\n reference operator*() const {\n return ptr_[0];\n }\n\n reference operator[](difference_type idx) const {\n return ptr_[idx * stride_];\n }\n\n // Comparisons\n bool operator==(const StridedIterator& other) const {\n return ptr_ == other.ptr_ && stride_ == other.stride_;\n }\n\n bool operator!=(const StridedIterator& other) const {\n return ptr_ != other.ptr_;\n }\n\n bool operator<(const StridedIterator& other) const {\n return ptr_ < other.ptr_;\n }\n\n bool operator>(const StridedIterator& other) const {\n return ptr_ > other.ptr_;\n }\n\n bool operator<=(const StridedIterator& other) const {\n return ptr_ <= other.ptr_;\n }\n\n bool operator>=(const StridedIterator& other) const {\n return ptr_ >= other.ptr_;\n }\n\n difference_type operator-(const StridedIterator& other) const {\n return (ptr_ - other.ptr_) / stride_;\n }\n\n // Moving\n StridedIterator& operator++() {\n ptr_ += stride_;\n return *this;\n }\n\n StridedIterator& operator--() {\n ptr_ -= stride_;\n return *this;\n }\n\n StridedIterator& operator+=(difference_type diff) {\n ptr_ += diff * stride_;\n return *this;\n }\n\n StridedIterator& operator-=(difference_type diff) {\n ptr_ -= diff * stride_;\n return *this;\n }\n\n StridedIterator operator+(difference_type diff) {\n return StridedIterator(ptr_, stride_, diff);\n }\n\n StridedIterator operator-(difference_type diff) {\n return StridedIterator(ptr_, stride_, -diff);\n }\n\n private:\n size_t stride_;\n T* ptr_;\n};\n\ntemplate \nvoid sort(const array& in, array& out, int axis) {\n // Copy input to output\n CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;\n copy(in, out, ctype);\n\n // Get axis, shape and stride info\n axis = axis < 0 ? 
axis + in.ndim() : axis;\n size_t n_rows = in.size() / in.shape(axis);\n\n auto remaining_shape = in.shape();\n remaining_shape.erase(remaining_shape.begin() + axis);\n\n auto remaining_strides = in.strides();\n remaining_strides.erase(remaining_strides.begin() + axis);\n\n size_t axis_stride = in.strides()[axis];\n int axis_size = in.shape(axis);\n\n // Perform sorting in place\n for (int i = 0; i < n_rows; i++) {\n size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);\n T* data_ptr = out.data() + loc;\n\n StridedIterator st(data_ptr, axis_stride, 0);\n StridedIterator ed(data_ptr, axis_stride, axis_size);\n\n std::stable_sort(st, ed);\n }\n}\n\ntemplate \nvoid argsort(const array& in, array& out, int axis) {\n // Allocate output\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n\n // Get axis, shape and stride info\n axis = axis < 0 ? axis + in.ndim() : axis;\n size_t n_rows = in.size() / in.shape(axis);\n\n auto remaining_shape = in.shape();\n remaining_shape.erase(remaining_shape.begin() + axis);\n\n auto remaining_strides = in.strides();\n remaining_strides.erase(remaining_strides.begin() + axis);\n\n size_t axis_stride = in.strides()[axis];\n int axis_size = in.shape(axis);\n\n // Perform sorting\n for (int i = 0; i < n_rows; i++) {\n size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);\n const T* data_ptr = in.data() + loc;\n IdxT* idx_ptr = out.data() + loc;\n\n StridedIterator st_(idx_ptr, axis_stride, 0);\n StridedIterator ed_(idx_ptr, axis_stride, axis_size);\n\n // Initialize with iota\n std::iota(st_, ed_, IdxT(0));\n\n // Sort according to vals\n StridedIterator st(idx_ptr, axis_stride, 0);\n StridedIterator ed(idx_ptr, axis_stride, axis_size);\n\n std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {\n auto v1 = data_ptr[a * axis_stride];\n auto v2 = data_ptr[b * axis_stride];\n return v1 < v2 || (v1 == v2 && a < b);\n });\n }\n}\n\ntemplate \nvoid partition(const array& in, array& out, int axis, int kth) {\n // Copy input to output\n CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;\n copy(in, out, ctype);\n\n // Get axis, shape and stride info\n axis = axis < 0 ? axis + in.ndim() : axis;\n size_t n_rows = in.size() / in.shape(axis);\n\n auto remaining_shape = in.shape();\n remaining_shape.erase(remaining_shape.begin() + axis);\n\n auto remaining_strides = in.strides();\n remaining_strides.erase(remaining_strides.begin() + axis);\n\n size_t axis_stride = in.strides()[axis];\n int axis_size = in.shape(axis);\n\n kth = kth < 0 ? kth + axis_size : kth;\n\n // Perform partition in place\n for (int i = 0; i < n_rows; i++) {\n size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);\n T* data_ptr = out.data() + loc;\n\n StridedIterator st(data_ptr, axis_stride, 0);\n StridedIterator md(data_ptr, axis_stride, kth);\n StridedIterator ed(data_ptr, axis_stride, axis_size);\n\n std::nth_element(st, md, ed);\n }\n}\n\ntemplate \nvoid argpartition(const array& in, array& out, int axis, int kth) {\n // Allocate output\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n\n // Get axis, shape and stride info\n axis = axis < 0 ? axis + in.ndim() : axis;\n size_t n_rows = in.size() / in.shape(axis);\n\n auto remaining_shape = in.shape();\n remaining_shape.erase(remaining_shape.begin() + axis);\n\n auto remaining_strides = in.strides();\n remaining_strides.erase(remaining_strides.begin() + axis);\n\n size_t axis_stride = in.strides()[axis];\n int axis_size = in.shape(axis);\n\n kth = kth < 0 ? 
kth + axis_size : kth;\n\n // Perform partition\n for (int i = 0; i < n_rows; i++) {\n size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);\n const T* data_ptr = in.data() + loc;\n IdxT* idx_ptr = out.data() + loc;\n\n StridedIterator st_(idx_ptr, axis_stride, 0);\n StridedIterator ed_(idx_ptr, axis_stride, axis_size);\n\n // Initialize with iota\n std::iota(st_, ed_, IdxT(0));\n\n // Sort according to vals\n StridedIterator st(idx_ptr, axis_stride, 0);\n StridedIterator md(idx_ptr, axis_stride, kth);\n StridedIterator ed(idx_ptr, axis_stride, axis_size);\n\n std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {\n auto v1 = data_ptr[a * axis_stride];\n auto v2 = data_ptr[b * axis_stride];\n return v1 < v2 || (v1 == v2 && a < b);\n });\n }\n}\n\n} // namespace\n\nvoid ArgSort::eval(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n auto& in = inputs[0];\n\n switch (in.dtype()) {\n case bool_:\n return argsort(in, out, axis_);\n case uint8:\n return argsort(in, out, axis_);\n case uint16:\n return argsort(in, out, axis_);\n case uint32:\n return argsort(in, out, axis_);\n case uint64:\n return argsort(in, out, axis_);\n case int8:\n return argsort(in, out, axis_);\n case int16:\n return argsort(in, out, axis_);\n case int32:\n return argsort(in, out, axis_);\n case int64:\n return argsort(in, out, axis_);\n case float32:\n return argsort(in, out, axis_);\n case float16:\n return argsort(in, out, axis_);\n case bfloat16:\n return argsort(in, out, axis_);\n case complex64:\n return argsort(in, out, axis_);\n }\n}\n\nvoid Sort::eval(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n auto& in = inputs[0];\n\n switch (in.dtype()) {\n case bool_:\n return sort(in, out, axis_);\n case uint8:\n return sort(in, out, axis_);\n case uint16:\n return sort(in, out, axis_);\n case uint32:\n return sort(in, out, axis_);\n case uint64:\n...\n", "needle_token_start": 7054, "needle_token_end": 7592, "code_context_ntokens": 16205, "output": ["```cpp\ntemplate \nvoid copy(const array& src, array& dst, CopyType ctype, Args&&... args) {\n switch (dst.dtype()) {\n case bool_:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint8:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case uint64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int8:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case int64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case float16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case float32:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case bfloat16:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n case complex64:\n copy(src, dst, ctype, std::forward(args)...);\n break;\n }\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "arange_set_scalars", "language": "cpp", "path": "mlx/backend/metal/primitives.cpp", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The function configures a command encoder with scalar values necessary for generating a sequence of numbers starting from a specified value and increasing by a constant step.\n2. 
**Input**: The function accepts two scalar values of a generic type, representing the starting value of the sequence and the next value in the sequence.\n3. **Output**: There is no direct output returned by the function; however, it sets scalar values in a command encoder which are used in subsequent operations.\n4. **Procedure**: The function first sets the starting value of the sequence in the command encoder. It then calculates the step size by subtracting the starting value from the next value. This step size is also set in the command encoder.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/backend/metal/device.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n#include \n#include \n#include \n\n#define NS_PRIVATE_IMPLEMENTATION\n#define CA_PRIVATE_IMPLEMENTATION\n#define MTL_PRIVATE_IMPLEMENTATION\n\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/backend/metal/metal_impl.h\"\n#include \"mlx/backend/metal/mps/gemm.h\"\n#include \"mlx/backend/metal/utils.h\"\n\nnamespace fs = std::filesystem;\n\nnamespace mlx::core::metal {\n\nnamespace {\n\n// TODO nicer way to set this or possibly expose as an environment variable\nconstexpr int MAX_BUFFERS_PER_QUEUE = 12;\n\nconstexpr const char* default_mtllib_path = METAL_PATH;\n\nauto load_device() {\n auto devices = MTL::CopyAllDevices();\n auto device = static_cast(devices->object(0))\n ?: MTL::CreateSystemDefaultDevice();\n if (!device) {\n throw std::runtime_error(\"Failed to load device\");\n }\n return device;\n}\n\nstd::pair load_library_from_path(\n MTL::Device* device,\n const char* path) {\n auto library = NS::String::string(path, NS::UTF8StringEncoding);\n NS::Error* error;\n auto lib = device->newLibrary(library, &error);\n\n return std::make_pair(lib, error);\n}\n\n#ifdef SWIFTPM_BUNDLE\nMTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {\n std::string bundle_path = std::string(url->fileSystemRepresentation()) + \"/\" +\n SWIFTPM_BUNDLE + \".bundle\";\n auto bundle = NS::Bundle::alloc()->init(\n NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding));\n if (bundle != nullptr) {\n std::string resource_path =\n std::string(bundle->resourceURL()->fileSystemRepresentation()) + \"/\" +\n \"default.metallib\";\n auto [lib, error] = load_library_from_path(device, resource_path.c_str());\n if (lib) {\n return lib;\n }\n }\n return nullptr;\n}\n#endif\n\nMTL::Library* load_library(\n MTL::Device* device,\n const std::string& lib_name = \"mlx\",\n const char* lib_path = default_mtllib_path) {\n // Firstly, search for the metallib in the same path as this binary\n std::string first_path = get_colocated_mtllib_path(lib_name);\n if (first_path.size() != 0) {\n auto [lib, error] = load_library_from_path(device, first_path.c_str());\n if (lib) {\n return lib;\n }\n }\n\n#ifdef SWIFTPM_BUNDLE\n // try to load from a swiftpm resource bundle -- scan the available bundles to\n // find one that contains the named bundle\n {\n MTL::Library* library =\n try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());\n if (library != nullptr) {\n return library;\n }\n auto bundles = NS::Bundle::allBundles();\n for (int i = 0, c = (int)bundles->count(); i < c; i++) {\n auto bundle = reinterpret_cast(bundles->object(i));\n library = 
try_load_bundle(device, bundle->resourceURL());\n if (library != nullptr) {\n return library;\n }\n }\n }\n#endif\n\n // Couldn't find it so let's load it from default_mtllib_path\n {\n auto [lib, error] = load_library_from_path(device, lib_path);\n if (!lib) {\n std::ostringstream msg;\n msg << error->localizedDescription()->utf8String() << \"\\n\"\n << \"Failed to load device library from <\" << lib_path << \">\"\n << \" or <\" << first_path << \">.\";\n throw std::runtime_error(msg.str());\n }\n return lib;\n }\n}\n\n} // namespace\n\nDevice::Device() {\n auto pool = new_scoped_memory_pool();\n device_ = load_device();\n library_map_ = {{\"mlx\", load_library(device_)}};\n}\n\nDevice::~Device() {\n auto pool = new_scoped_memory_pool();\n for (auto& q : queue_map_) {\n q.second->release();\n }\n for (auto& b : buffer_map_) {\n b.second.second->release();\n }\n for (auto& e : encoder_map_) {\n e.second->release();\n }\n for (auto& k : kernel_map_) {\n k.second->release();\n }\n for (auto& l : library_map_) {\n l.second->release();\n }\n device_->release();\n}\n\nvoid Device::new_queue(int index) {\n auto thread_pool = metal::new_scoped_memory_pool();\n\n // Multiple threads can ask the device for queues\n // We lock this as a critical section for safety\n const std::lock_guard lock(mtx_);\n auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);\n debug_set_stream_queue_label(q, index);\n if (!q) {\n throw std::runtime_error(\n \"[metal::Device] Failed to make new command queue.\");\n }\n queue_map_.insert({index, q});\n}\n\nint Device::get_command_buffer_ops(int index) {\n auto bit = buffer_map_.find(index);\n return bit->second.first;\n}\n\nvoid Device::increment_command_buffer_ops(int index) {\n auto bit = buffer_map_.find(index);\n bit->second.first++;\n}\n\nMTL::CommandBuffer* Device::get_command_buffer(int index) {\n auto bit = buffer_map_.find(index);\n return (bit == buffer_map_.end()) ? 
nullptr : bit->second.second;\n}\n\nMTL::CommandBuffer* Device::new_command_buffer(int index) {\n auto qit = queue_map_.find(index);\n if (qit == queue_map_.end()) {\n throw std::runtime_error(\n \"[metal::Device] Attempting to get command buffer for invalid queue.\");\n }\n\n auto cb = qit->second->commandBufferWithUnretainedReferences();\n\n if (!cb) {\n throw std::runtime_error(\n \"[metal::Device] Unable to create new command buffer\");\n }\n\n // Increment ref count so the buffer is not garbage collected\n cb->retain();\n\n return buffer_map_.insert({index, {0, cb}}).first->second.second;\n}\n\nvoid Device::commit_command_buffer(int index) {\n auto bit = buffer_map_.find(index);\n bit->second.second->commit();\n bit->second.second->release();\n buffer_map_.erase(bit);\n}\n\nvoid Device::end_encoding(int index) {\n auto eit = encoder_map_.find(index);\n if (eit != encoder_map_.end()) {\n eit->second->endEncoding();\n eit->second->release();\n encoder_map_.erase(eit);\n }\n}\n\nCommandEncoder& Device::get_command_encoder(int index) {\n auto eit = encoder_map_.find(index);\n if (eit == encoder_map_.end()) {\n auto cb = get_command_buffer(index);\n auto compute_encoder =\n cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);\n // Increment ref count so the buffer is not garbage collected\n compute_encoder->retain();\n eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;\n }\n return eit->second;\n}\n\nvoid Device::register_library(\n const std::string& lib_name,\n const std::string& lib_path) {\n if (auto it = library_map_.find(lib_name); it == library_map_.end()) {\n auto new_lib = load_library(device_, lib_name, lib_path.c_str());\n library_map_.insert({lib_name, new_lib});\n }\n}\n\nvoid Device::register_library(\n const std::string& lib_name,\n const std::function& lib_path_func) {\n if (auto it = library_map_.find(lib_name); it == library_map_.end()) {\n std::string new_lib_path = lib_path_func(lib_name);\n auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());\n library_map_.insert({lib_name, new_lib});\n }\n}\n\nMTL::Library* Device::get_library_cache_(const std::string& lib_name) {\n // Search for cached metal lib\n MTL::Library* mtl_lib;\n if (auto it = library_map_.find(lib_name); it != library_map_.end()) {\n mtl_lib = it->second;\n } else { // Look for metallib alongside library\n register_library(lib_name);\n mtl_lib = library_map_[lib_name];\n }\n\n return mtl_lib;\n}\n\nMTL::Library* Device::get_library_(const std::string& source_string) {\n auto pool = new_scoped_memory_pool();\n\n auto ns_code =\n NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);\n\n NS::Error* error = nullptr;\n auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);\n\n // Throw error if unable to compile library\n if (!mtl_lib) {\n std::ostringstream msg;\n msg << \"[metal::Device] Unable to load build metal library from source\"\n << \"\\n\";\n if (error) {\n msg << error->localizedDescription()->utf8String() << \"\\n\";\n }\n throw std::runtime_error(msg.str());\n }\n\n return mtl_lib;\n}\n\nMTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {\n auto pool = new_scoped_memory_pool();\n\n NS::Error* error = nullptr;\n auto mtl_lib = device_->newLibrary(desc, &error);\n\n // Throw error if unable to compile library\n if (!mtl_lib) {\n std::ostringstream msg;\n msg << \"[metal::Device] Unable to load build stitched metal library\"\n << \"\\n\";\n if (error) {\n msg << error->localizedDescription()->utf8String() << 
\"\\n\";\n }\n throw std::runtime_error(msg.str());\n }\n\n return mtl_lib;\n}\n\nMTL::Function* Device::get_function_(\n const std::string& name,\n MTL::Library* mtl_lib) {\n // Pull kernel from library\n auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);\n auto mtl_function = mtl_lib->newFunction(ns_name);\n\n return mtl_function;\n}\n\nMTL::Function* Device::get_function_(\n const std::string& name,\n const std::string& specialized_name,\n const MTLFCList& func_consts,\n MTL::Library* mtl_lib) {\n if (func_consts.empty() && (specialized_name == name)) {\n return get_function_(name, mtl_lib);\n }\n\n // Prepare function constants\n auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();\n\n for (auto [value, type, index] : func_consts) {\n mtl_func_consts->setConstantValue(value, type, index);\n }\n\n // Prepare function desc\n auto desc = MTL::FunctionDescriptor::functionDescriptor();\n desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));\n desc->setSpecializedName(\n NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));\n desc->setConstantValues(mtl_func_consts);\n\n // Pull kernel from library\n NS::Error* error = nullptr;\n auto mtl_function = mtl_lib->newFunction(desc, &error);\n\n // Throw error if unable to build metal function\n if (!mtl_function) {\n std::ostringstream msg;\n msg << \"[metal::Device] Unable to load function \" << name << \"\\n\";\n if (error) {\n msg << error->localizedDescription()->utf8String() << \"\\n\";\n }\n throw std::runtime_error(msg.str());\n }\n\n mtl_func_consts->release();\n desc->release();\n\n return mtl_function;\n}\n\nMTL::ComputePipelineState* Device::get_kernel_(\n const std::string& name,\n const MTL::Function* mtl_function) {\n // Compile kernel to compute pipeline\n NS::Error* error = nullptr;\n MTL::ComputePipelineState* kernel;\n\n if (mtl_function) {\n kernel = device_->newComputePipelineState(mtl_function, &error);\n }\n\n // Throw error if unable to compile metal function\n if (!mtl_function || !kernel) {\n std::ostringstream msg;\n msg << \"[metal::Device] Unable to load kernel \" << name << \"\\n\";\n if (error) {\n msg << error->localizedDescription()->utf8String() << \"\\n\";\n }\n throw std::runtime_error(msg.str());\n }\n\n return kernel;\n}\n\nMTL::ComputePipelineState* Device::get_kernel_(\n const std::string& name,\n const MTL::Function* mtl_function,\n const MTL::LinkedFunctions* linked_functions) {\n // Check inputs\n if (!linked_functions) {\n return get_kernel_(name, mtl_function);\n }\n\n if (!mtl_function) {\n std::ostringstream msg;\n msg << \"[metal::Device] Unable to load kernel \" << name << \"\\n\";\n throw std::runtime_error(msg.str());\n }\n\n // Prepare compute pipeline state descriptor\n auto desc = MTL::ComputePipelineDescriptor::alloc()->init();\n desc->setComputeFunction(mtl_function);\n desc->setLinkedFunctions(linked_functions);\n\n // Compile kernel to compute pipeline\n NS::Error* error = nullptr;\n auto kernel = device_->newComputePipelineState(\n desc, MTL::PipelineOptionNone, nullptr, &error);\n\n // Throw error if unable to compile metal function\n if (!kernel) {\n std::ostringstream msg;\n msg << \"[metal::Device] Unable to load kernel \" << name << \"\\n\";\n if (error) {\n msg << error->localizedDescription()->utf8String() << \"\\n\";\n }\n throw std::runtime_error(msg.str());\n }\n\n return kernel;\n}\n\n...\n// Path: mlx/backend/metal/primitives.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n#include \n#include \n#include 
\n#include \n\n#include \"mlx/backend/common/binary.h\"\n#include \"mlx/backend/common/ternary.h\"\n#include \"mlx/backend/metal/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr int METAL_MAX_INDEX_ARRAYS = 10;\n\nvoid binary_op(\n const std::vector& inputs,\n std::vector& outputs,\n const std::string op) {\n assert(inputs.size() == 2);\n auto& a = inputs[0];\n auto& b = inputs[1];\n auto bopt = get_binary_op_type(a, b);\n set_binary_op_output_data(a, b, outputs[0], bopt, true);\n set_binary_op_output_data(a, b, outputs[1], bopt, true);\n\n auto& out = outputs[0];\n if (out.size() == 0) {\n return;\n }\n\n // Try to collapse contiguous dims\n auto [shape, strides] = collapse_contiguous_dims(a, b, out);\n auto& strides_a = strides[0];\n auto& strides_b = strides[1];\n auto& strides_out = strides[2];\n\n std::ostringstream kname;\n switch (bopt) {\n case BinaryOpType::ScalarScalar:\n kname << \"ss\";\n break;\n case BinaryOpType::ScalarVector:\n kname << \"sv\";\n break;\n case BinaryOpType::VectorScalar:\n kname << \"vs\";\n break;\n case BinaryOpType::VectorVector:\n kname << \"vv\";\n break;\n case BinaryOpType::General:\n kname << \"g\";\n break;\n }\n kname << op << type_to_name(a);\n if (bopt == BinaryOpType::General &&\n shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {\n kname << \"_\" << shape.size();\n }\n\n auto& s = out.primitive().stream();\n auto& d = metal::device(s.device);\n auto kernel = d.get_kernel(kname.str());\n auto& compute_encoder = d.get_command_encoder(s.index);\n compute_encoder->setComputePipelineState(kernel);\n // - If a is donated it goes to the first output\n // - If b is donated it goes to the first output if a was not donated\n // otherwise it goes to the second output\n bool donate_a = a.data_shared_ptr() == nullptr;\n bool donate_b = b.data_shared_ptr() == nullptr;\n compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);\n compute_encoder.set_input_array(\n donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);\n compute_encoder.set_output_array(outputs[0], 2);\n compute_encoder.set_output_array(outputs[1], 3);\n\n if (bopt == BinaryOpType::General) {\n auto ndim = shape.size();\n if (ndim > 3) {\n compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);\n compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);\n compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);\n } else {\n // The shape is implicit in the grid for <= 3D\n compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);\n compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);\n }\n\n if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {\n compute_encoder->setBytes(&ndim, sizeof(int), 7);\n }\n\n // Launch up to 3D grid of threads\n size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1;\n size_t rest = out.size() / (dim0 * dim1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size != 1024) {\n throw std::runtime_error(\"[Metal::binary] Must use 1024 sized block\");\n }\n auto group_dims = get_block_dims(dim0, dim1, rest);\n MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n } else {\n // Launch a 1D grid of threads\n size_t nthreads = out.data_size();\n MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size > nthreads) {\n thread_group_size = nthreads;\n }\n MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n }\n}\n\nvoid binary_op(\n const std::vector& inputs,\n array& out,\n const std::string op) {\n assert(inputs.size() == 2);\n auto& a = inputs[0];\n auto& b = inputs[1];\n auto bopt = get_binary_op_type(a, b);\n set_binary_op_output_data(a, b, out, bopt, true);\n if (out.size() == 0) {\n return;\n }\n\n // Try to collapse contiguous dims\n auto [shape, strides] = collapse_contiguous_dims(a, b, out);\n auto& strides_a = strides[0];\n auto& strides_b = strides[1];\n auto& strides_out = strides[2];\n\n std::ostringstream kname;\n switch (bopt) {\n case BinaryOpType::ScalarScalar:\n kname << \"ss\";\n break;\n case BinaryOpType::ScalarVector:\n kname << \"sv\";\n break;\n case BinaryOpType::VectorScalar:\n kname << \"vs\";\n break;\n case BinaryOpType::VectorVector:\n kname << \"vv\";\n break;\n case BinaryOpType::General:\n kname << \"g\";\n break;\n }\n kname << op << type_to_name(a);\n if (bopt == BinaryOpType::General &&\n shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {\n kname << \"_\" << shape.size();\n }\n\n auto& s = out.primitive().stream();\n auto& d = metal::device(s.device);\n auto kernel = d.get_kernel(kname.str());\n auto& compute_encoder = d.get_command_encoder(s.index);\n compute_encoder->setComputePipelineState(kernel);\n bool donate_a = a.data_shared_ptr() == nullptr;\n bool donate_b = b.data_shared_ptr() == nullptr;\n compute_encoder.set_input_array(donate_a ? out : a, 0);\n compute_encoder.set_input_array(donate_b ? out : b, 1);\n compute_encoder.set_output_array(out, 2);\n\n if (bopt == BinaryOpType::General) {\n auto ndim = shape.size();\n if (ndim > 3) {\n compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);\n compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);\n compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);\n } else {\n // The shape is implicit in the grid for <= 3D\n compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);\n compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);\n }\n\n if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {\n compute_encoder->setBytes(&ndim, sizeof(int), 6);\n }\n\n // Launch up to 3D grid of threads\n size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1;\n size_t rest = out.size() / (dim0 * dim1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size != 1024) {\n throw std::runtime_error(\"[Metal::binary] Must use 1024 sized block\");\n }\n auto group_dims = get_block_dims(dim0, dim1, rest);\n MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n } else {\n // Launch a 1D grid of threads\n size_t nthreads =\n bopt == BinaryOpType::General ? out.size() : out.data_size();\n MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size > nthreads) {\n thread_group_size = nthreads;\n }\n MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n }\n}\n\nvoid ternary_op(\n const std::vector& inputs,\n array& out,\n const std::string op) {\n assert(inputs.size() == 3);\n auto& a = inputs[0];\n auto& b = inputs[1];\n auto& c = inputs[2];\n TernaryOpType topt = get_ternary_op_type(a, b, c);\n set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);\n\n if (out.size() == 0) {\n return;\n }\n\n // Try to collapse contiguous dims\n auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);\n auto& strides_a = strides[0];\n auto& strides_b = strides[1];\n auto& strides_c = strides[2];\n auto& strides_out = strides[3];\n\n std::ostringstream kname;\n if (topt == TernaryOpType::General) {\n kname << \"g\";\n kname << op << type_to_name(b);\n if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {\n kname << \"_\" << shape.size();\n }\n } else {\n kname << \"v\";\n kname << op << type_to_name(b);\n }\n\n auto& s = out.primitive().stream();\n auto& d = metal::device(s.device);\n auto kernel = d.get_kernel(kname.str());\n auto& compute_encoder = d.get_command_encoder(s.index);\n compute_encoder->setComputePipelineState(kernel);\n compute_encoder.set_input_array(a, 0);\n compute_encoder.set_input_array(b, 1);\n compute_encoder.set_input_array(c, 2);\n compute_encoder.set_output_array(out, 3);\n\n if (topt == TernaryOpType::General) {\n auto ndim = shape.size();\n if (ndim > 3) {\n compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);\n compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);\n compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);\n compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);\n\n if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {\n compute_encoder->setBytes(&ndim, sizeof(int), 8);\n }\n } else {\n // The shape is implicit in the grid for <= 3D\n compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);\n compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);\n compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);\n }\n\n // Launch up to 3D grid of threads\n size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1;\n size_t rest = out.size() / (dim0 * dim1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size != 1024) {\n throw std::runtime_error(\"[Metal::binary] Must use 1024 sized block\");\n }\n MTL::Size group_dims = get_block_dims(dim0, dim1, rest);\n MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n } else {\n // Launch a 1D grid of threads\n size_t nthreads = out.data_size();\n MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size > nthreads) {\n thread_group_size = nthreads;\n }\n MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n }\n}\n\nvoid unary_op(\n const std::vector& inputs,\n array& out,\n const std::string op) {\n auto& in = inputs[0];\n bool contig = in.flags().contiguous;\n if (contig) {\n if (in.is_donatable() && in.itemsize() == out.itemsize()) {\n out.move_shared_buffer(in);\n } else {\n out.set_data(\n allocator::malloc_or_wait(in.data_size() * out.itemsize()),\n in.data_size(),\n in.strides(),\n in.flags());\n }\n } else {\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n }\n if (in.size() == 0) {\n return;\n }\n\n auto& s = out.primitive().stream();\n auto& d = metal::device(s.device);\n std::string tname = type_to_name(in);\n std::string opt_name = contig ? \"v\" : \"g\";\n auto kernel = d.get_kernel(opt_name + op + tname);\n\n size_t nthreads = contig ? in.data_size() : in.size();\n MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n if (thread_group_size > nthreads) {\n thread_group_size = nthreads;\n }\n MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n\n auto& compute_encoder = d.get_command_encoder(s.index);\n compute_encoder->setComputePipelineState(kernel);\n compute_encoder.set_input_array(\n in.data_shared_ptr() == nullptr ? 
out : in, 0);\n compute_encoder.set_output_array(out, 1);\n if (!contig) {\n compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);\n compute_encoder->setBytes(\n in.strides().data(), in.ndim() * sizeof(size_t), 3);\n int ndim = in.ndim();\n compute_encoder->setBytes(&ndim, sizeof(int), 4);\n }\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n}\n\n} // namespace\n\nvoid Abs::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"abs\");\n}\n\nvoid Add::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"add\");\n}\n\ntemplate \nv\noid arange_set_scalars(T start, T next, CommandEncoder& enc) {\n enc->setBytes(&start, sizeof(T), 0);\n T step = next - start;\n enc->setBytes(&step, sizeof(T), 1);\n}\n\nvoid Arange::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 0);\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n if (out.size() == 0) {\n return;\n }\n auto& s = stream();\n auto& d = metal::device(s.device);\n auto kernel = d.get_kernel(\"arange\" + type_to_name(out));\n size_t nthreads = out.size();\n MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n MTL::Size group_dims = MTL::Size(\n std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);\n auto& compute_encoder = d.get_command_encoder(s.index);\n compute_encoder->setComputePipelineState(kernel);\n\n switch (out.dtype()) {\n case bool_: // unsupported\n throw std::runtime_error(\"[Arange::eval_gpu] Does not support bool\");\n case uint8:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case uint16:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case uint32:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case uint64:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case int8:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case int16:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case int32:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case int64:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case float16:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case float32:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case bfloat16:\n arange_set_scalars(start_, start_ + step_, compute_encoder);\n break;\n case complex64:\n throw std::runtime_error(\"[Arange::eval_gpu] Does not support complex64\");\n }\n\n compute_encoder.set_output_array(out, 2);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n}\n\nvoid ArcCos::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"arccos\");\n}\n\nvoid ArcCosh::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"arccosh\");\n}\n\nvoid ArcSin::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"arcsin\");\n}\n\nvoid ArcSinh::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"arcsinh\");\n}\n\nvoid ArcTan::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"arctan\");\n}\n\nvoid ArcTanh::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"arctanh\");\n}\n\nvoid ArgReduce::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n auto& in = inputs[0];\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n auto& s = stream();\n auto& d = 
metal::device(s.device);\n std::string op_name;\n switch (reduce_type_) {\n case ArgReduce::ArgMin:\n op_name = \"argmin_\";\n break;\n case ArgReduce::ArgMax:\n op_name = \"argmax_\";\n break;\n }\n\n // Prepare the shapes, strides and axis arguments.\n std::vector in_strides = in.strides();\n std::vector shape = in.shape();\n std::vector out_strides = out.strides();\n size_t axis_stride = in_strides[axis_];\n size_t axis_size = shape[axis_];\n if (out_strides.size() == in_strides.size()) {\n out_strides.erase(out_strides.begin() + axis_);\n }\n in_strides.erase(in_strides.begin() + axis_);\n shape.erase(shape.begin() + axis_);\n size_t ndim = shape.size();\n\n // ArgReduce\n int simd_size = 32;\n int n_reads = 4;\n auto& compute_encoder = d.get_command_encoder(s.index);\n {\n auto kernel = d.get_kernel(op_name + type_to_name(in));\n NS::UInteger thread_group_size = std::min(\n (axis_size + n_reads - 1) / n_reads,\n kernel->maxTotalThreadsPerThreadgroup());\n // round up to the closest number divisible by simd_size\n thread_group_size =\n (thread_group_size + simd_size - 1) / simd_size * simd_size;\n assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());\n\n size_t n_threads = out.size() * thread_group_size;\n MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);\n MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n compute_encoder->setComputePipelineState(kernel);\n compute_encoder.set_input_array(in, 0);\n compute_encoder.set_output_array(out, 1);\n if (ndim == 0) {\n // Pass place holders so metal doesn't complain\n int shape_ = 0;\n size_t stride_ = 0;\n compute_encoder->setBytes(&shape_, sizeof(int), 2);\n compute_encoder->setBytes(&stride_, sizeof(size_t), 3);\n compute_encoder->setBytes(&stride_, sizeof(size_t), 4);\n } else {\n compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);\n compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);\n compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);\n }\n compute_encoder->setBytes(&ndim, sizeof(size_t), 5);\n compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);\n compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n }\n}\n\nvoid AsType::eval_gpu(const std::vector& inputs, array& out) {\n CopyType ctype =\n inputs[0].flags().contiguous ? 
CopyType::Vector : CopyType::General;\n copy_gpu(inputs[0], out, ctype);\n}\n\nvoid AsStrided::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid Broadcast::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid Concatenate::eval_gpu(const std::vector& inputs, array& out) {\n std::vector sizes;\n sizes.push_back(0);\n for (auto& p : inputs) {\n sizes.push_back(p.shape(axis_));\n }\n std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n\n auto strides = out.strides();\n auto flags = out.flags();\n flags.row_contiguous = false;\n flags.col_contiguous = false;\n flags.contiguous = false;\n auto& d = metal::device(stream().device);\n auto& compute_encoder = d.get_command_encoder(stream().index);\n auto concurrent_ctx = compute_encoder.start_concurrent();\n for (int i = 0; i < inputs.size(); i++) {\n array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});\n size_t data_offset = strides[axis_] * sizes[i];\n out_slice.copy_shared_buffer(\n out, strides, flags, out_slice.size(), data_offset);\n copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());\n }\n}\n\nvoid Copy::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid Cos::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"cos\");\n}\n\nvoid Cosh::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"cosh\");\n}\n\nvoid CustomVJP::eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) {\n eval(inputs, outputs);\n}\n\nvoid Depends::eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) {\n eval(inputs, outputs);\n}\n\nvoid Divide::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"div\");\n}\n\nvoid DivMod::eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) {\n binary_op(inputs, outputs, \"divmod\");\n}\n\nvoid Remainder::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"rem\");\n}\n\nvoid Equal::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, equal_nan_ ? 
\"naneq\" : \"eq\");\n}\n\nvoid Erf::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"erf\");\n}\n\nvoid ErfInv::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"erfinv\");\n}\n\nvoid Exp::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"exp\");\n}\n\nvoid Expm1::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"expm1\");\n}\n\nvoid Full::eval_gpu(const std::vector& inputs, array& out) {\n auto in = inputs[0];\n CopyType ctype;\n if (in.data_size() == 1) {\n ctype = CopyType::Scalar;\n } else if (in.flags().contiguous) {\n ctype = CopyType::Vector;\n } else {\n ctype = CopyType::General;\n }\n copy_gpu(in, out, ctype);\n}\n\nvoid Greater::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"ge\");\n}\n\nvoid GreaterEqual::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"geq\");\n}\n\nvoid Less::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"le\");\n}\n\nvoid LessEqual::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"leq\");\n}\n\nvoid Load::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid Log::eval_gpu(const std::vector& inputs, array& out) {\n switch (base_) {\n case Base::e:\n unary_op(inputs, out, \"log\");\n break;\n case Base::two:\n unary_op(inputs, out, \"log2\");\n break;\n case Base::ten:\n unary_op(inputs, out, \"log10\");\n break;\n }\n}\n\nvoid Log1p::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"log1p\");\n}\n\nvoid LogicalNot::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"lnot\");\n}\n\nvoid LogicalAnd::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(\n inputs,\n out,\n \"land\"); // Assume \"land\" is the operation identifier for logical AND\n}\n\nvoid LogicalOr::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(\n inputs,\n out,\n \"lor\"); // Assume \"lor\" is the operation identifier for logical OR\n}\n\nvoid LogAddExp::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"lae\");\n}\n\nvoid Maximum::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"max\");\n}\n\nvoid Minimum::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"min\");\n}\n\nvoid NumberOfElements::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid Floor::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"floor\");\n}\n\nvoid Ceil::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"ceil\");\n}\n\nvoid Multiply::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"mul\");\n}\n\nvoid Select::eval_gpu(const std::vector& inputs, array& out) {\n ternary_op(inputs, out, \"select\");\n}\n\nvoid Negative::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"neg\");\n}\n\nvoid NotEqual::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"neq\");\n}\n\nvoid Pad::eval_gpu(const std::vector& inputs, array& out) {\n // Inputs must be base input array and scalar val array\n assert(inputs.size() == 2);\n auto& in = inputs[0];\n auto& val = inputs[1];\n\n // Padding value must be a scalar\n assert(val.size() == 1);\n\n // Padding value, input and output must be of the same type\n assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());\n\n // Fill output with 
val\n copy_gpu(val, out, CopyType::Scalar, stream());\n\n // Find offset for start of input values\n size_t data_offset = 0;\n for (int i = 0; i < axes_.size(); i++) {\n auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];\n data_offset += out.strides()[ax] * low_pad_size_[i];\n }\n\n // Extract slice from output where input will be pasted\n array out_slice(in.shape(), out.dtype(), nullptr, {});\n out_slice.copy_shared_buffer(\n out, out.strides(), out.flags(), out_slice.size(), data_offset);\n\n // Copy input values into the slice\n copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());\n}\n\nvoid Power::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"pow\");\n}\n\nvoid RandomBits::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n\n // keys has shape (N1, ..., NK, 2)\n // out has shape (N1, ..., NK, M1, M2, ...)\n auto& keys = inputs[0];\n size_t num_keys = keys.size() / 2;\n\n size_t elems_per_key = out.size() / num_keys;\n size_t bytes_per_key = out.itemsize() * elems_per_key;\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n if (out.size() == 0) {\n return;\n }\n\n size_t out_per_key = (bytes_per_key + 4 - 1) / 4;\n size_t half_size = out_per_key / 2;\n bool odd = out_per_key % 2;\n\n auto& s = stream();\n auto& d = metal::device(s.device);\n std::string kname = keys.flags().row_contiguous ? \"rbitsc\" : \"rbits\";\n auto kernel = d.get_kernel(kname);\n\n // organize into grid nkeys x elem_per_key\n MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);\n NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n auto& compute_encoder = d.get_command_encoder(s.index);\n compute_encoder->setComputePipelineState(kernel);\n compute_encoder.set_input_array(keys, 0);\n compute_encoder.set_output_array(out, 1);\n compute_encoder->setBytes(&odd, sizeof(bool), 2);\n compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);\n\n if (!keys.flags().row_contiguous) {\n int ndim = keys.ndim();\n compute_encoder->setBytes(&ndim, sizeof(int), 4);\n compute_encoder->setBytes(\n keys.shape().data(), keys.ndim() * sizeof(int), 5);\n compute_encoder->setBytes(\n keys.strides().data(), keys.ndim() * sizeof(size_t), 6);\n }\n\n compute_encoder->dispatchThreads(grid_dims, group_dims);\n}\n\nvoid Reshape::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n const auto& in = inputs[0];\n\n auto [copy_necessary, out_strides] = prepare_reshape(in, out);\n\n if (copy_necessary) {\n copy_gpu(in, out, CopyType::General);\n } else {\n shared_buffer_reshape(in, out_strides, out);\n }\n}\n\nvoid Round::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n const auto& in = inputs[0];\n if (issubdtype(in.dtype(), inexact)) {\n unary_op(inputs, out, \"round\");\n } else {\n // No-op integer types\n out.copy_shared_buffer(in);\n }\n}\n\nvoid Sigmoid::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"sigmoid\");\n}\n\nvoid Sign::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"sign\");\n}\n\nvoid Sin::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"sin\");\n}\n\nvoid Sinh::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"sinh\");\n}\n\nvoid Split::eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) {\n eval(inputs, outputs);\n}\n\nvoid Square::eval_gpu(const std::vector& 
inputs, array& out) {\n unary_op(inputs, out, \"square\");\n}\n\nvoid Sqrt::eval_gpu(const std::vector& inputs, array& out) {\n if (recip_) {\n unary_op(inputs, out, \"rsqrt\");\n } else {\n unary_op(inputs, out, \"sqrt\");\n }\n}\n\nvoid Slice::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 1);\n if (out.size() == 0) {\n out.set_data(nullptr);\n return;\n }\n\n auto& in = inputs[0];\n\n // Calculate out strides, initial offset and if copy needs to be made\n auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);\n\n // Do copy if needed\n if (copy_needed) {\n out.set_data(allocator::malloc_or_wait(out.nbytes()));\n std::vector ostrides{out.strides().begin(), out.strides().end()};\n copy_gpu_inplace(\n /* const array& in = */ in,\n /* array& out = */ out,\n /* const std::vector& data_shape = */ out.shape(),\n /* const std::vector& i_strides = */ inp_strides,\n /* const std::vector& o_strides = */ ostrides,\n /* int64_t i_offset = */ data_offset,\n /* int64_t o_offset = */ 0,\n /* CopyType ctype = */ CopyType::General,\n /* const Stream& s = */ stream());\n } else {\n std::vector ostrides{inp_strides.begin(), inp_strides.end()};\n shared_buffer_slice(in, ostrides, data_offset, out);\n }\n}\n\nvoid SliceUpdate::eval_gpu(const std::vector& inputs, array& out) {\n assert(inputs.size() == 2);\n if (out.size() == 0) {\n out.set_data(nullptr);\n return;\n }\n\n auto& in = inputs[0];\n auto& upd = inputs[1];\n\n if (upd.size() == 0) {\n out.copy_shared_buffer(in);\n return;\n }\n\n // Check if materialization is needed\n auto ctype = in.flags().contiguous && in.size() == in.data_size()\n ? CopyType::Vector\n : CopyType::General;\n copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());\n\n // Calculate out strides, initial offset and if copy needs to be made\n auto [data_offset, out_strides] = prepare_slice(out);\n\n // Do copy\n std::vector upd_strides{upd.strides().begin(), upd.strides().end()};\n copy_gpu_inplace(\n /* const array& src = */ upd,\n /* array& dst = */ out,\n /* const std::vector& data_shape = */ upd.shape(),\n /* const std::vector& i_strides = */ upd_strides,\n /* const std::vector& o_strides = */ out_strides,\n /* int64_t i_offset = */ 0,\n /* int64_t o_offset = */ data_offset,\n /* CopyType ctype = */ CopyType::GeneralGeneral,\n /* const Stream& s = */ stream());\n}\n\nvoid StopGradient::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid Subtract::eval_gpu(const std::vector& inputs, array& out) {\n binary_op(inputs, out, \"sub\");\n}\n\nvoid Tan::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"tan\");\n}\n\nvoid Tanh::eval_gpu(const std::vector& inputs, array& out) {\n unary_op(inputs, out, \"tanh\");\n}\n\nvoid Transpose::eval_gpu(const std::vector& inputs, array& out) {\n eval(inputs, out);\n}\n\nvoid QRF::eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) {\n throw std::runtime_error(\"[QRF::eval_gpu] Metal QR factorization NYI.\");\n}\n\nvoid SVD::eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) {\n throw std::runtime_error(\"[SVD::eval_gpu] Metal SVD NYI.\");\n}\n\nvoid Inverse::eval_gpu(const std::vector& inputs, array& output) {\n throw std::runtime_error(\"[Inverse::eval_gpu] Metal inversion NYI.\");\n}\n\n} // namespace mlx::core\n\n// Path: mlx/backend/metal/indexing.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n#include \n#include \n#include \n#include \n\n#include \"mlx/backend/common/binary.h\"\n#include 
\"mlx/backend/metal/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n...\n", "needle_token_start": 8875, "needle_token_end": 8942, "code_context_ntokens": 16174, "output": ["```cpp\nvoid arange_set_scalars(T start, T next, CommandEncoder& enc) {\n enc->setBytes(&start, sizeof(T), 0);\n T step = next - start;\n enc->setBytes(&step, sizeof(T), 1);\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "unpack_32_4", "language": "cpp", "path": "mlx/io/gguf_quants.cpp", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to decode and distribute packed bit data into a more accessible byte format, specifically targeting the transformation of compressed data into a format suitable for further processing or analysis.\n2. **Input**: The function takes two parameters: a pointer to an array of 8-bit unsigned integers (the source data array), and a pointer to an array of 8-bit signed integers (the destination array where the unpacked data will be stored).\n3. **Output**: There is no return value; however, the function modifies the destination array by filling it with unpacked and processed values derived from the source data array.\n4. **Procedure**: The function performs the following steps:\n - Iterates over the first 16 bytes of the source array starting from the third byte, extracting the lower 4 bits of each byte. Depending on the position (odd or even index), these bits are either shifted left by 4 bits or left as is, and then added to the corresponding position in the destination array.\n - For the next 16 bytes (also starting from the third byte of the source), it extracts the higher 4 bits, applies a similar shifting logic based on the byte's position, and adds these to the next set of positions in the destination array.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/utils.h\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \n\n#include \"array.h\"\n#include \"device.h\"\n#include \"dtype.h\"\n#include \"stream.h\"\n\nnamespace mlx::core {\n\nusing StreamOrDevice = std::variant;\nStream to_stream(StreamOrDevice s);\n\nstruct StreamContext {\n public:\n StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {\n...\n// Path: mlx/utils.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n#include \n\n#include \"utils.h\"\n\nnamespace mlx::core {\n\nStream to_stream(StreamOrDevice s) {\n if (std::holds_alternative(s)) {\n return default_stream(default_device());\n } else if (std::holds_alternative(s)) {\n return default_stream(std::get(s));\n } else {\n return std::get(s);\n }\n}\n\nvoid PrintFormatter::print(std::ostream& os, bool val) {\n if (capitalize_bool) {\n os << (val ? 
\"True\" : \"False\");\n } else {\n os << val;\n }\n}\ninline void PrintFormatter::print(std::ostream& os, int16_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, uint16_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, int32_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, uint32_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, int64_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, uint64_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, float16_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, float val) {\n os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, complex64_t val) {\n os << val;\n}\n\nPrintFormatter global_formatter;\n\nDtype result_type(const std::vector& arrays) {\n std::vector dtypes(1, bool_);\n for (auto& arr : arrays) {\n dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));\n }\n return dtypes.back();\n}\n\nstd::vector broadcast_shapes(\n const std::vector& s1,\n const std::vector& s2) {\n // Use the same broadcasting rules as numpy\n // https://numpy.org/doc/1.20/user/theory.broadcasting.html\n // \"The size of the trailing axes for both arrays in an operation must\n // either be the same size or one of them must be one.\"\n int ndim1 = s1.size();\n int ndim2 = s2.size();\n int ndim = std::max(ndim1, ndim2);\n int diff = std::abs(ndim1 - ndim2);\n const auto& big = ndim1 > ndim2 ? s1 : s2;\n const auto& small = ndim1 > ndim2 ? s2 : s1;\n std::vector out_shape(ndim);\n for (int i = ndim - 1; i >= diff; --i) {\n int a = big[i];\n int b = small[i - diff];\n if (b == a) {\n out_shape[i] = a;\n } else if (a == 1 || b == 1) {\n // 0 if a or b is 0 otherwise max(a, b)\n out_shape[i] = a * b;\n } else {\n std::ostringstream msg;\n msg << \"Shapes \" << s1 << \" and \" << s2 << \" cannot be broadcast.\";\n throw std::invalid_argument(msg.str());\n }\n }\n for (int i = diff - 1; i >= 0; --i) {\n out_shape[i] = big[i];\n }\n return out_shape;\n}\n\nbool is_same_shape(const std::vector& arrays) {\n if (arrays.empty()) {\n return true;\n }\n return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {\n return (a.shape() == arrays[0].shape());\n });\n}\n\nint normalize_axis(int axis, int ndim) {\n if (ndim <= 0) {\n throw std::invalid_argument(\"Number of dimensions must be positive.\");\n }\n if (axis < -ndim || axis >= ndim) {\n std::ostringstream msg;\n msg << \"Axis \" << axis << \" is out of bounds for array with \" << ndim\n << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n if (axis < 0) {\n axis += ndim;\n }\n return axis;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Device& d) {\n os << \"Device(\";\n switch (d.type) {\n case Device::cpu:\n os << \"cpu\";\n break;\n case Device::gpu:\n os << \"gpu\";\n break;\n }\n os << \", \" << d.index << \")\";\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Stream& s) {\n os << \"Stream(\";\n os << s.device;\n os << \", \" << s.index << \")\";\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, int8_t x) {\n os << static_cast(x);\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, uint8_t x) {\n os << static_cast(x);\n return os;\n}\n\nnamespace {\n\ninline size_t elem_to_loc(\n int elem,\n const std::vector& shape,\n const std::vector& strides) {\n 
size_t loc = 0;\n for (int i = shape.size() - 1; i >= 0; --i) {\n auto q_and_r = ldiv(elem, shape[i]);\n loc += q_and_r.rem * strides[i];\n elem = q_and_r.quot;\n }\n return loc;\n}\n\ntemplate \nvoid print_subarray(std::ostream& os, const array& a, size_t index, int dim) {\n int num_print = 3;\n int n = a.shape(dim);\n size_t s = a.strides()[dim];\n bool is_last = dim == a.ndim() - 1;\n auto prefix = is_last ? \"\" : std::string(7 + dim, ' ');\n auto postfix = is_last ? \", \" : \",\\n\";\n os << \"[\";\n for (int i = 0; i < n; ++i) {\n os << (i == 0 ? \"\" : prefix);\n if (i == num_print && n > 2 * num_print) {\n os << \"...\";\n i = n - num_print - 1;\n index += s * (n - 2 * num_print - 1);\n } else if (is_last) {\n global_formatter.print(os, a.data()[index]);\n } else {\n print_subarray(os, a, index, dim + 1);\n }\n os << (i == n - 1 ? \"\" : postfix);\n index += s;\n }\n os << \"]\";\n}\n\ntemplate \nvoid print_array(std::ostream& os, const array& a) {\n std::vector indices(a.ndim(), 0);\n os << std::boolalpha;\n os << \"array(\";\n if (a.ndim() == 0) {\n auto data = a.data();\n global_formatter.print(os, data[0]);\n } else {\n print_subarray(os, a, 0, 0);\n }\n os << \", dtype=\" << a.dtype() << \")\";\n os << std::noboolalpha;\n}\n\n} // namespace\n\nstd::ostream& operator<<(std::ostream& os, const Dtype& dtype) {\n switch (dtype) {\n case bool_:\n return os << \"bool\";\n case uint8:\n return os << \"uint8\";\n case uint16:\n return os << \"uint16\";\n case uint32:\n return os << \"uint32\";\n case uint64:\n return os << \"uint64\";\n case int8:\n return os << \"int8\";\n case int16:\n return os << \"int16\";\n case int32:\n return os << \"int32\";\n case int64:\n return os << \"int64\";\n case float16:\n return os << \"float16\";\n case float32:\n return os << \"float32\";\n case bfloat16:\n return os << \"bfloat16\";\n case complex64:\n return os << \"complex64\";\n }\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {\n switch (k) {\n case Dtype::Kind::b:\n return os << \"b\";\n case Dtype::Kind::i:\n return os << \"i\";\n case Dtype::Kind::u:\n return os << \"u\";\n case Dtype::Kind::f:\n return os << \"f\";\n case Dtype::Kind::c:\n return os << \"c\";\n case Dtype::Kind::V:\n return os << \"V\";\n }\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, array a) {\n if (!a.is_evaled()) {\n a.eval();\n }\n switch (a.dtype()) {\n case bool_:\n print_array(os, a);\n break;\n case uint8:\n print_array(os, a);\n break;\n case uint16:\n print_array(os, a);\n break;\n case uint32:\n print_array(os, a);\n break;\n case uint64:\n print_array(os, a);\n break;\n case int8:\n print_array(os, a);\n break;\n case int16:\n print_array(os, a);\n break;\n case int32:\n print_array(os, a);\n break;\n case int64:\n print_array(os, a);\n break;\n case float16:\n print_array(os, a);\n break;\n case bfloat16:\n print_array(os, a);\n break;\n case float32:\n print_array(os, a);\n break;\n case complex64:\n print_array(os, a);\n break;\n }\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const std::vector& v) {\n os << \"(\";\n for (int i = 0; i < v.size(); ++i) {\n os << v[i] << ((i == v.size() - 1) ? \"\" : \",\");\n }\n os << \")\";\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const std::vector& v) {\n os << \"(\";\n for (int i = 0; i < v.size(); ++i) {\n os << v[i] << ((i == v.size() - 1) ? 
\"\" : \",\");\n }\n os << \")\";\n return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const std::vector& v) {\n os << \"(\";\n for (int i = 0; i < v.size(); ++i) {\n os << v[i] << ((i == v.size() - 1) ? \"\" : \",\");\n }\n os << \")\";\n return os;\n}\n\n} // namespace mlx::core\n\n// Path: mlx/fft.cpp\n// Copyright \u00a9 2023 Apple Inc.\n\n#include \n#include \n\n#include \"mlx/fft.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::fft {\n\narray fft_impl(\n const array& a,\n std::vector n,\n const std::vector& axes,\n bool real,\n bool inverse,\n StreamOrDevice s) {\n if (a.ndim() < 1) {\n throw std::invalid_argument(\n \"[fftn] Requires array with at least one dimension.\");\n }\n if (n.size() != axes.size()) {\n throw std::invalid_argument(\"[fftn] Shape and axes have different sizes.\");\n }\n if (axes.empty()) {\n return a;\n }\n\n std::vector valid_axes;\n for (int ax : axes) {\n valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);\n }\n std::set unique_axes(valid_axes.begin(), valid_axes.end());\n if (unique_axes.size() != axes.size()) {\n std::ostringstream msg;\n msg << \"[fftn] Duplicated axis received \" << axes;\n throw std::invalid_argument(msg.str());\n }\n if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= a.ndim()) {\n std::ostringstream msg;\n msg << \"[fftn] Invalid axis received for array with \" << a.ndim()\n << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n // In the following shape manipulations there are three cases to consider:\n // 1. In a complex to complex transform (fftn / ifftn) the output\n // and input shapes are the same.\n // 2. In a real to complex transform (rfftn) n specifies the input dims\n // and the output dims are n[i] / 2 + 1\n // 3 In a complex to real transform (irfftn) n specifies the output dims\n // and the input dims are n[i] / 2 + 1\n\n if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) {\n std::ostringstream msg;\n msg << \"[fftn] Invalid FFT output size requested \" << n;\n throw std::invalid_argument(msg.str());\n }\n\n std::vector in_shape = a.shape();\n for (int i = 0; i < valid_axes.size(); ++i) {\n in_shape[valid_axes[i]] = n[i];\n }\n if (real && inverse) {\n in_shape[valid_axes.back()] = n.back() / 2 + 1;\n }\n\n bool any_greater = false;\n bool any_less = false;\n for (int i = 0; i < in_shape.size(); ++i) {\n any_greater |= in_shape[i] > a.shape()[i];\n any_less |= in_shape[i] < a.shape()[i];\n }\n\n auto in = a;\n if (any_less) {\n in = slice(in, std::vector(in.ndim(), 0), in_shape, s);\n }\n if (any_greater) {\n // Pad with zeros\n auto tmp = zeros(in_shape, a.dtype(), s);\n in = scatter(tmp, std::vector{}, in, std::vector{}, s);\n }\n\n auto out_shape = in_shape;\n if (real) {\n auto ax = valid_axes.back();\n out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1;\n }\n\n auto in_type = real && !inverse ? float32 : complex64;\n auto out_type = real && inverse ? 
float32 : complex64;\n return array(\n out_shape,\n out_type,\n std::make_shared(to_stream(s), valid_axes, inverse, real),\n {astype(in, in_type, s)});\n}\n\narray fft_impl(\n const array& a,\n const std::vector& axes,\n bool real,\n bool inverse,\n StreamOrDevice s) {\n std::vector n;\n for (auto ax : axes) {\n n.push_back(a.shape(ax));\n }\n if (real && inverse) {\n n.back() = (n.back() - 1) * 2;\n }\n return fft_impl(a, n, axes, real, inverse, s);\n}\n\narray fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return fft_impl(a, axes, real, inverse, s);\n}\n\narray fftn(\n const array& a,\n const std::vector& n,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, n, axes, false, false, s);\n}\narray fftn(\n const array& a,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, axes, false, false, s);\n}\narray fftn(const array& a, StreamOrDevice s /* = {} */) {\n return fft_impl(a, false, false, s);\n}\n\narray ifftn(\n const array& a,\n const std::vector& n,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, n, axes, false, true, s);\n}\narray ifftn(\n const array& a,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, axes, false, true, s);\n}\narray ifftn(const array& a, StreamOrDevice s /* = {} */) {\n return fft_impl(a, false, true, s);\n}\n\narray rfftn(\n const array& a,\n const std::vector& n,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, n, axes, true, false, s);\n}\narray rfftn(\n const array& a,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, axes, true, false, s);\n}\narray rfftn(const array& a, StreamOrDevice s /* = {} */) {\n return fft_impl(a, true, false, s);\n}\n\narray irfftn(\n const array& a,\n const std::vector& n,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, n, axes, true, true, s);\n}\narray irfftn(\n const array& a,\n const std::vector& axes,\n StreamOrDevice s /* = {} */) {\n return fft_impl(a, axes, true, true, s);\n}\narray irfftn(const array& a, StreamOrDevice s /* = {} */) {\n return fft_impl(a, true, true, s);\n}\n\n} // namespace mlx::core::fft\n\n// Path: mlx/io/safetensor.cpp\n// Copyright \u00a9 2023 Apple Inc.\n//\n#include \n#include \n\n#include \"mlx/io.h\"\n#include \"mlx/io/load.h\"\n#include \"mlx/primitives.h\"\n\nusing json = nlohmann::json;\n\n#define ST_F16 \"F16\"\n#define ST_BF16 \"BF16\"\n#define ST_F32 \"F32\"\n\n#define ST_BOOL \"BOOL\"\n#define ST_I8 \"I8\"\n#define ST_I16 \"I16\"\n#define ST_I32 \"I32\"\n#define ST_I64 \"I64\"\n#define ST_U8 \"U8\"\n#define ST_U16 \"U16\"\n#define ST_U32 \"U32\"\n#define ST_U64 \"U64\"\n\n// Note: Complex numbers aren't in the spec yet so this could change -\n// https://github.com/huggingface/safetensors/issues/389\n#define ST_C64 \"C64\"\n\nnamespace mlx::core {\n\nstd::string dtype_to_safetensor_str(Dtype t) {\n switch (t) {\n case float32:\n return ST_F32;\n case bfloat16:\n return ST_BF16;\n case float16:\n return ST_F16;\n case int64:\n return ST_I64;\n case int32:\n return ST_I32;\n case int16:\n return ST_I16;\n case int8:\n return ST_I8;\n case uint64:\n return ST_U64;\n case uint32:\n return ST_U32;\n case uint16:\n return ST_U16;\n case uint8:\n return ST_U8;\n case bool_:\n return ST_BOOL;\n case complex64:\n return ST_C64;\n }\n}\n\nDtype dtype_from_safetensor_str(std::string_view str) 
{\n if (str == ST_F32) {\n return float32;\n } else if (str == ST_F16) {\n return float16;\n } else if (str == ST_BF16) {\n return bfloat16;\n } else if (str == ST_I64) {\n return int64;\n } else if (str == ST_I32) {\n return int32;\n } else if (str == ST_I16) {\n return int16;\n } else if (str == ST_I8) {\n return int8;\n } else if (str == ST_U64) {\n return uint64;\n } else if (str == ST_U32) {\n return uint32;\n } else if (str == ST_U16) {\n return uint16;\n } else if (str == ST_U8) {\n return uint8;\n } else if (str == ST_BOOL) {\n return bool_;\n } else if (str == ST_C64) {\n return complex64;\n } else {\n throw std::runtime_error(\n \"[safetensor] unsupported dtype \" + std::string(str));\n }\n}\n\n/** Load array from reader in safetensor format */\nSafetensorsLoad load_safetensors(\n std::shared_ptr in_stream,\n StreamOrDevice s) {\n ////////////////////////////////////////////////////////\n // Open and check file\n if (!in_stream->good() || !in_stream->is_open()) {\n throw std::runtime_error(\n \"[load_safetensors] Failed to open \" + in_stream->label());\n }\n\n uint64_t jsonHeaderLength = 0;\n in_stream->read(reinterpret_cast(&jsonHeaderLength), 8);\n if (jsonHeaderLength <= 0) {\n throw std::runtime_error(\n \"[load_safetensors] Invalid json header length \" + in_stream->label());\n }\n // Load the json metadata\n char rawJson[jsonHeaderLength];\n in_stream->read(rawJson, jsonHeaderLength);\n auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength);\n // Should always be an object on the top-level\n if (!metadata.is_object()) {\n throw std::runtime_error(\n \"[load_safetensors] Invalid json metadata \" + in_stream->label());\n }\n size_t offset = jsonHeaderLength + 8;\n // Load the arrays using metadata\n std::unordered_map res;\n std::unordered_map metadata_map;\n for (const auto& item : metadata.items()) {\n if (item.key() == \"__metadata__\") {\n for (const auto& meta_item : item.value().items()) {\n metadata_map.insert({meta_item.key(), meta_item.value()});\n }\n continue;\n }\n const std::string& dtype = item.value().at(\"dtype\");\n const std::vector& shape = item.value().at(\"shape\");\n const std::vector& data_offsets = item.value().at(\"data_offsets\");\n Dtype type = dtype_from_safetensor_str(dtype);\n auto loaded_array = array(\n shape,\n type,\n std::make_shared(\n to_stream(s), in_stream, offset + data_offsets.at(0), false),\n std::vector{});\n res.insert({item.key(), loaded_array});\n }\n return {res, metadata_map};\n}\n\nSafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {\n return load_safetensors(std::make_shared(file), s);\n}\n\n/** Save array to out stream in .npy format */\nvoid save_safetensors(\n std::shared_ptr out_stream,\n std::unordered_map a,\n std::unordered_map metadata /* = {} */) {\n ////////////////////////////////////////////////////////\n // Check file\n if (!out_stream->good() || !out_stream->is_open()) {\n throw std::runtime_error(\n \"[save_safetensors] Failed to open \" + out_stream->label());\n }\n\n ////////////////////////////////////////////////////////\n // Check array map\n json parent;\n json _metadata;\n for (auto& [key, value] : metadata) {\n _metadata[key] = value;\n }\n parent[\"__metadata__\"] = _metadata;\n size_t offset = 0;\n for (auto& [key, arr] : a) {\n arr.eval();\n if (arr.nbytes() == 0) {\n throw std::invalid_argument(\n \"[save_safetensors] cannot serialize an empty array key: \" + key);\n }\n\n // Try to make it row contiguous\n if (!arr.flags().row_contiguous) {\n arr = 
reshape(flatten(arr), arr.shape());\n arr.eval();\n }\n\n // Has to be row-major now but, check one more time in case\n // any of the above change in the future\n if (!arr.flags().row_contiguous) {\n throw std::invalid_argument(\n \"[save_safetensors] can only serialize row-major arrays\");\n }\n\n json child;\n child[\"dtype\"] = dtype_to_safetensor_str(arr.dtype());\n child[\"shape\"] = arr.shape();\n child[\"data_offsets\"] = std::vector{offset, offset + arr.nbytes()};\n parent[key] = child;\n offset += arr.nbytes();\n }\n\n auto header = parent.dump();\n uint64_t header_len = header.length();\n out_stream->write(reinterpret_cast(&header_len), 8);\n out_stream->write(header.c_str(), header_len);\n for (auto& [key, arr] : a) {\n out_stream->write(arr.data(), arr.nbytes());\n }\n}\n\nvoid save_safetensors(\n std::string file,\n std::unordered_map a,\n std::unordered_map metadata /* = {} */) {\n // Add .safetensors to file name if it is not there\n if (file.length() < 12 ||\n file.substr(file.length() - 12, 12) != \".safetensors\")\n file += \".safetensors\";\n\n // Serialize array\n save_safetensors(\n std::make_shared(std::move(file)), a, metadata);\n}\n\n} // namespace mlx::core\n\n// Path: mlx/io/load.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n#include \n#include \n#include \n#include \n#include \n\n#include \"mlx/io/load.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\n// Adapted from\n// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr uint8_t MAGIC[] = {\n 0x93,\n 0x4e,\n 0x55,\n 0x4d,\n 0x50,\n 0x59,\n};\n\n} // namespace\n\n/** Save array to out stream in .npy format */\nvoid save(std::shared_ptr out_stream, array a) {\n ////////////////////////////////////////////////////////\n // Check array\n\n a.eval();\n\n if (a.nbytes() == 0) {\n throw std::invalid_argument(\"[save] cannot serialize an empty array\");\n }\n\n if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {\n a = reshape(flatten(a), a.shape());\n a.eval();\n }\n // Check once more in-case the above ops change\n if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {\n throw std::invalid_argument(\n \"[save] can only serialize row or col contiguous arrays\");\n }\n\n ////////////////////////////////////////////////////////\n // Check file\n if (!out_stream->good() || !out_stream->is_open()) {\n throw std::runtime_error(\"[save] Failed to open \" + out_stream->label());\n }\n\n ////////////////////////////////////////////////////////\n // Prepare header\n std::ostringstream magic_ver_len;\n magic_ver_len.write(reinterpret_cast(MAGIC), 6);\n\n std::string fortran_order = a.flags().col_contiguous ? 
\"True\" : \"False\";\n std::ostringstream header;\n header << \"{'descr': '\" << dtype_to_array_protocol(a.dtype()) << \"',\"\n << \" 'fortran_order': \" << fortran_order << \",\" << \" 'shape': (\";\n for (auto i : a.shape()) {\n header << i << \", \";\n }\n header << \")}\";\n\n size_t header_len = static_cast(header.tellp());\n bool is_v1 = header_len + 15 < std::numeric_limits::max();\n\n // Pad out magic + version + header_len + header + \\n to be divisible by 16\n size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16;\n\n header << std::string(padding, ' ') << '\\n';\n\n if (is_v1) {\n magic_ver_len << (char)0x01 << (char)0x00;\n\n uint16_t v1_header_len = header.tellp();\n const char* len_bytes = reinterpret_cast(&v1_header_len);\n\n if (!is_big_endian()) {\n magic_ver_len.write(len_bytes, 2);\n } else {\n magic_ver_len.write(len_bytes + 1, 1);\n magic_ver_len.write(len_bytes, 1);\n }\n } else {\n magic_ver_len << (char)0x02 << (char)0x00;\n\n uint32_t v2_header_len = header.tellp();\n const char* len_bytes = reinterpret_cast(&v2_header_len);\n\n if (!is_big_endian()) {\n magic_ver_len.write(len_bytes, 4);\n } else {\n magic_ver_len.write(len_bytes + 3, 1);\n magic_ver_len.write(len_bytes + 2, 1);\n magic_ver_len.write(len_bytes + 1, 1);\n magic_ver_len.write(len_bytes, 1);\n }\n }\n ////////////////////////////////////////////////////////\n // Serialize array\n\n out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());\n out_stream->write(header.str().c_str(), header.str().length());\n out_stream->write(a.data(), a.nbytes());\n}\n\n/** Save array to file in .npy format */\nvoid save(std::string file, array a) {\n // Add .npy to file name if it is not there\n if (file.length() < 4 || file.substr(file.length() - 4, 4) != \".npy\")\n file += \".npy\";\n\n // Serialize array\n save(std::make_shared(std::move(file)), a);\n}\n\n/** Load array from reader in .npy format */\narray load(std::shared_ptr in_stream, StreamOrDevice s) {\n ////////////////////////////////////////////////////////\n // Open and check file\n if (!in_stream->good() || !in_stream->is_open()) {\n throw std::runtime_error(\"[load] Failed to open \" + in_stream->label());\n }\n\n ////////////////////////////////////////////////////////\n // Read header and prepare array details\n\n // Read and check magic\n char read_magic_and_ver[8];\n in_stream->read(read_magic_and_ver, 8);\n if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {\n throw std::runtime_error(\"[load] Invalid header in \" + in_stream->label());\n }\n\n // Read and check version\n if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {\n throw std::runtime_error(\n \"[load] Unsupported npy format version in \" + in_stream->label());\n }\n\n // Read header len and header\n int header_len_size = read_magic_and_ver[6] == 1 ? 
2 : 4;\n size_t header_len;\n\n if (header_len_size == 2) {\n uint16_t v1_header_len;\n in_stream->read(reinterpret_cast(&v1_header_len), header_len_size);\n header_len = v1_header_len;\n } else {\n uint32_t v2_header_len;\n in_stream->read(reinterpret_cast(&v2_header_len), header_len_size);\n header_len = v2_header_len;\n }\n\n // Read the header\n std::vector buffer(header_len + 1);\n in_stream->read(&buffer[0], header_len);\n buffer[header_len] = 0;\n std::string header(&buffer[0]);\n\n // Read data type from header\n std::string dtype_str = header.substr(11, 3);\n bool read_is_big_endian = dtype_str[0] == '>';\n Dtype dtype = dtype_from_array_protocol(dtype_str);\n\n // Read contiguity order\n bool col_contiguous = header[34] == 'T';\n\n // Read array shape from header\n std::vector shape;\n\n size_t st = header.find_last_of('(') + 1;\n size_t ed = header.find_last_of(')');\n std::string shape_str = header.substr(st, ed - st);\n\n while (!shape_str.empty()) {\n // Read current number and get position of comma\n size_t pos;\n int dim = std::stoi(shape_str, &pos);\n shape.push_back(dim);\n\n // Skip the comma and space and read the next number\n if (pos + 2 <= shape_str.length())\n shape_str = shape_str.substr(pos + 2);\n else {\n shape_str = shape_str.substr(pos);\n if (!shape_str.empty() && shape_str != \" \" && shape_str != \",\") {\n throw std::runtime_error(\n \"[load] Unknown error while parsing header in \" +\n in_stream->label());\n }\n shape_str = \"\";\n }\n }\n\n ////////////////////////////////////////////////////////\n // Build primitive\n\n size_t offset = 8 + header_len_size + header.length();\n bool swap_endianness = read_is_big_endian != is_big_endian();\n\n if (col_contiguous) {\n std::reverse(shape.begin(), shape.end());\n }\n auto loaded_array = array(\n shape,\n dtype,\n std::make_shared(to_stream(s), in_stream, offset, swap_endianness),\n std::vector{});\n if (col_contiguous) {\n loaded_array = transpose(loaded_array, s);\n }\n\n return loaded_array;\n}\n\n/** Load array from file in .npy format */\narray load(std::string file, StreamOrDevice s) {\n return load(std::make_shared(std::move(file)), s);\n}\n\n} // namespace mlx::core\n\n// Path: mlx/io/gguf_quants.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n#include \n\n#include \n\nnamespace mlx::core {\n\nv\noid unpack_32_4(uint8_t* data, int8_t* dst) {\n for (int64_t j = 0; j < 16; ++j) {\n uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.\n if (j % 2 != 0) {\n x <<= 4;\n }\n dst[j / 2] += x;\n }\n // Last 16 weights are in the higher bits\n for (int64_t j = 0; j < 16; ++j) {\n uint8_t x = (data[j + 2] >> 4);\n if (j % 2 != 0) {\n x <<= 4;\n }\n dst[8 + j / 2] += x;\n }\n}\n\n// Extracts (weight, scales, biases) from Q4_0 tensors.\n// Data layout is: |16 bit scale|32 x 4bit weights|.\nvoid extract_q4_0_data(\n const gguf_tensor& tensor,\n array& weights_arr,\n array& scales_arr,\n array& biases_arr) {\n const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights\n auto data = static_cast(tensor.weights_data);\n auto weights = weights_arr.data();\n auto scales = scales_arr.data();\n auto biases = biases_arr.data();\n for (int64_t i = 0; i < scales_arr.size(); i++) {\n scales[i] = *((float16_t*)data);\n biases[i] = -8 * scales[i];\n unpack_32_4(data, weights);\n weights += 16;\n data += bytes_per_block;\n }\n}\n\n// Extracts (weight, scales, biases) from Q4_1 tensors.\n// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.\nvoid extract_q4_1_data(\n const 
gguf_tensor& tensor,\n array& weights_arr,\n array& scales_arr,\n array& biases_arr) {\n const uint64_t bytes_per_block =\n 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights\n auto data = static_cast(tensor.weights_data);\n auto weights = weights_arr.data();\n auto scales = scales_arr.data();\n auto biases = biases_arr.data();\n for (int64_t i = 0; i < scales_arr.size(); i++) {\n scales[i] = *((float16_t*)data);\n biases[i] = *((float16_t*)(data) + 1);\n unpack_32_4(data, weights);\n weights += 16;\n data += bytes_per_block;\n }\n}\n\n// Extracts (weight, scales, biases) from Q8_0 tensors.\n// Data layout is: |16 bit scale|32 x 8bit weights|.\nvoid extract_q8_0_data(\n const gguf_tensor& tensor,\n array& weights_arr,\n array& scales_arr,\n array& biases_arr) {\n const uint64_t weights_per_block = 32;\n const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights\n auto data = static_cast(tensor.weights_data);\n auto weights = weights_arr.data();\n auto scales = scales_arr.data();\n auto biases = biases_arr.data();\n for (int64_t i = 0; i < scales_arr.size(); i++) {\n uint8_t* block_data = data + i * bytes_per_block;\n scales[i] = *((float16_t*)block_data);\n biases[i] = -128 * scales[i];\n for (int64_t j = 0; j < weights_per_block; ++j) {\n uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.\n // Original data is in int8_t, so we add a bias of -128 and invert the\n // first bit.\n x ^= 1 << 7;\n weights[i * weights_per_block + j] = x;\n }\n }\n}\n\nvoid gguf_load_quantized(\n std::unordered_map& a,\n const gguf_tensor& tensor) {\n uint64_t weights_per_byte;\n if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) {\n weights_per_byte = 2;\n } else { // tensor.type == GGUF_TYPE_Q8_0\n weights_per_byte = 1;\n }\n\n std::string name(tensor.name, tensor.namelen);\n\n std::vector shape = get_shape(tensor);\n const uint64_t weights_per_block = 32;\n if (shape[shape.size() - 1] % weights_per_block != 0) {\n std::ostringstream msg;\n msg << \"[load_gguf] tensor \" << name\n << \"has incompatible last dim shape: \" << shape[shape.size() - 1];\n throw std::runtime_error(msg.str());\n }\n\n std::vector weights_shape = shape;\n weights_shape.back() /= (weights_per_byte * 4);\n\n array weights(std::move(weights_shape), uint32, nullptr, {});\n weights.set_data(allocator::malloc(weights.nbytes()));\n\n // For scales and bias\n shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;\n array scales(shape, float16, nullptr, {});\n array biases(std::move(shape), float16, nullptr, {});\n scales.set_data(allocator::malloc(scales.nbytes()));\n biases.set_data(allocator::malloc(biases.nbytes()));\n\n if (tensor.type == GGUF_TYPE_Q4_0) {\n extract_q4_0_data(tensor, weights, scales, biases);\n } else if (tensor.type == GGUF_TYPE_Q4_1) {\n extract_q4_1_data(tensor, weights, scales, biases);\n } else if (tensor.type == GGUF_TYPE_Q8_0) {\n extract_q8_0_data(tensor, weights, scales, biases);\n }\n\n a.emplace(name, std::move(weights));\n\n auto check_insert = [](const auto& inserted) {\n if (!inserted.second) {\n std::ostringstream msg;\n msg << \"[load_gguf] Duplicate parameter name \" << inserted.first->second\n << \" this can happend when loading quantized tensors.\";\n throw std::runtime_error(msg.str());\n }\n };\n\n constexpr std::string_view weight_suffix = \".weight\";\n const std::string name_prefix =\n name.substr(0, name.length() - weight_suffix.length());\n check_insert(a.emplace(name_prefix + \".scales\", std::move(scales)));\n 
check_insert(a.emplace(name_prefix + \".biases\", std::move(biases)));\n}\n\n} // namespace mlx::core\n\n// Path: mlx/io/gguf.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n#include \n#include \n\n#include \n\nnamespace mlx::core {\n\n// https://github.com/antirez/gguf-tools/blob/af7d88d808a7608a33723fba067036202910acb3/gguflib.h#L102-L108\nconstexpr int gguf_array_header_size = 12;\n\nstd::optional dtype_to_gguf_tensor_type(const Dtype& dtype) {\n switch (dtype) {\n case float32:\n return GGUF_TYPE_F32;\n case float16:\n return GGUF_TYPE_F16;\n case int8:\n return GGUF_TYPE_I8;\n case int16:\n return GGUF_TYPE_I16;\n case int32:\n return GGUF_TYPE_I32;\n default:\n return {};\n }\n}\n\nstd::optional gguf_type_to_dtype(const uint32_t& gguf_type) {\n switch (gguf_type) {\n case GGUF_TYPE_F32:\n return float32;\n case GGUF_TYPE_F16:\n return float16;\n case GGUF_TYPE_I8:\n return int8;\n case GGUF_TYPE_I16:\n return int16;\n case GGUF_TYPE_I32:\n return int32;\n default:\n return {};\n }\n}\n\nstd::vector get_shape(const gguf_tensor& tensor) {\n std::vector shape;\n // The dimension order in GGML is the reverse of the order used in MLX.\n for (int i = tensor.ndim - 1; i >= 0; i--) {\n shape.push_back(tensor.dim[i]);\n }\n return shape;\n}\n\nstd::tuple extract_tensor_data(gguf_tensor* tensor) {\n std::optional equivalent_dtype = gguf_type_to_dtype(tensor->type);\n // If there's an equivalent type, we can simply copy.\n if (equivalent_dtype.has_value()) {\n allocator::Buffer buffer = allocator::malloc(tensor->bsize);\n memcpy(\n buffer.raw_ptr(),\n tensor->weights_data,\n tensor->num_weights * equivalent_dtype.value().size);\n return {buffer, equivalent_dtype.value()};\n }\n // Otherwise, we convert to float16.\n // TODO: Add other dequantization options.\n int16_t* data = gguf_tensor_to_f16(tensor);\n if (data == NULL) {\n throw std::runtime_error(\"[load_gguf] gguf_tensor_to_f16 failed\");\n }\n const size_t new_size = tensor->num_weights * sizeof(int16_t);\n allocator::Buffer buffer = allocator::malloc(new_size);\n memcpy(buffer.raw_ptr(), data, new_size);\n free(data);\n return {buffer, float16};\n}\n\nvoid set_mx_value_from_gguf(\n gguf_ctx* ctx,\n uint32_t type,\n gguf_value* val,\n GGUFMetaData& value) {\n switch (type) {\n case GGUF_VALUE_TYPE_UINT8:\n value = array(val->uint8, uint8);\n break;\n case GGUF_VALUE_TYPE_INT8:\n value = array(val->int8, int8);\n break;\n case GGUF_VALUE_TYPE_UINT16:\n value = array(val->uint16, uint16);\n break;\n case GGUF_VALUE_TYPE_INT16:\n value = array(val->int16, int16);\n break;\n case GGUF_VALUE_TYPE_UINT32:\n value = array(val->uint32, uint32);\n break;\n case GGUF_VALUE_TYPE_INT32:\n value = array(val->int32, int32);\n break;\n case GGUF_VALUE_TYPE_UINT64:\n value = array(val->uint64, uint64);\n break;\n case GGUF_VALUE_TYPE_INT64:\n value = array(val->int64, int64);\n break;\n case GGUF_VALUE_TYPE_FLOAT32:\n value = array(val->float32, float32);\n break;\n case GGUF_VALUE_TYPE_BOOL:\n value = array(val->boolval, bool_);\n break;\n case GGUF_VALUE_TYPE_STRING:\n value =\n std::string(val->string.string, static_cast(val->string.len));\n break;\n case GGUF_VALUE_TYPE_FLOAT64:\n value = array(val->float64, float32);\n break;\n case GGUF_VALUE_TYPE_ARRAY: {\n ctx->off += gguf_array_header_size; // Skip header\n char* data = reinterpret_cast(val) + gguf_array_header_size;\n auto size = static_cast(val->array.len);\n if (val->array.type == GGUF_VALUE_TYPE_ARRAY) {\n throw std::invalid_argument(\n \"[load_gguf] Only supports loading 
1-layer of nested arrays.\");\n }\n switch (val->array.type) {\n case GGUF_VALUE_TYPE_UINT8:\n value = array(reinterpret_cast(data), {size}, uint8);\n break;\n case GGUF_VALUE_TYPE_INT8:\n value = array(reinterpret_cast(data), {size}, int8);\n break;\n case GGUF_VALUE_TYPE_UINT16:\n value = array(reinterpret_cast(data), {size}, uint16);\n break;\n case GGUF_VALUE_TYPE_INT16:\n value = array(reinterpret_cast(data), {size}, int16);\n break;\n case GGUF_VALUE_TYPE_UINT32:\n value = array(reinterpret_cast(data), {size}, uint32);\n break;\n case GGUF_VALUE_TYPE_INT32:\n value = array(reinterpret_cast(data), {size}, int32);\n break;\n case GGUF_VALUE_TYPE_UINT64:\n value = array(reinterpret_cast(data), {size}, uint64);\n break;\n case GGUF_VALUE_TYPE_INT64:\n value = array(reinterpret_cast(data), {size}, int64);\n break;\n case GGUF_VALUE_TYPE_FLOAT32:\n value = array(reinterpret_cast(data), {size}, float32);\n break;\n case GGUF_VALUE_TYPE_BOOL:\n value = array(reinterpret_cast(data), {size}, bool_);\n break;\n case GGUF_VALUE_TYPE_STRING: {\n std::vector strs(size);\n for (auto& str : strs) {\n auto str_val = reinterpret_cast(data);\n data += (str_val->len + sizeof(gguf_string));\n str = std::string(str_val->string, static_cast(str_val->len));\n ctx->off += (str_val->len + sizeof(gguf_string));\n }\n value = std::move(strs);\n break;\n }\n case GGUF_VALUE_TYPE_FLOAT64:\n value = array(reinterpret_cast(data), {size}, float32);\n break;\n default:\n throw std::runtime_error(\n \"[load_gguf] Multiple levels of nested arrays are not supported.\");\n }\n break;\n }\n default:\n throw std::runtime_error(\"[load_gguf] Received unexpected type.\");\n break;\n }\n if (type == GGUF_VALUE_TYPE_STRING) {\n ctx->off += (sizeof(gguf_string) + std::get(value).size());\n } else if (auto pv = std::get_if(&value); pv) {\n ctx->off += pv->nbytes();\n }\n}\n\nstd::unordered_map load_metadata(gguf_ctx* ctx) {\n std::unordered_map metadata;\n gguf_key key;\n while (gguf_get_key(ctx, &key)) {\n std::string key_name = std::string(key.name, key.namelen);\n auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;\n set_mx_value_from_gguf(ctx, key.type, key.val, val);\n }\n return metadata;\n}\n\nstd::unordered_map load_arrays(gguf_ctx* ctx) {\n std::unordered_map array_map;\n gguf_tensor tensor;\n\n auto check_insert = [](const auto& inserted) {\n if (!inserted.second) {\n std::ostringstream msg;\n msg << \"[load_gguf] Duplicate parameter name \" << inserted.first->second\n << \" this can happend when loading quantized tensors.\";\n throw std::runtime_error(msg.str());\n }\n };\n\n while (gguf_get_tensor(ctx, &tensor)) {\n std::string name(tensor.name, tensor.namelen);\n if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||\n tensor.type == GGUF_TYPE_Q8_0) {\n gguf_load_quantized(array_map, tensor);\n } else {\n std::string name = std::string(tensor.name, tensor.namelen);\n\n const auto& [data, dtype] = extract_tensor_data(&tensor);\n array loaded_array = array(data, get_shape(tensor), dtype);\n check_insert(array_map.insert({name, loaded_array}));\n }\n }\n return array_map;\n}\n\nGGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {\n gguf_ctx* ctx = gguf_open(file.data());\n if (!ctx) {\n throw std::runtime_error(\"[load_gguf] gguf_init failed\");\n }\n auto metadata = load_metadata(ctx);\n auto arrays = load_arrays(ctx);\n gguf_close(ctx);\n return {arrays, metadata};\n}\n\nvoid append_kv_array(\n gguf_ctx* ctx,\n const std::string& key,\n array& val,\n uint32_t gguf_type) 
{\n if (val.ndim() == 1) {\n size_t gguf_size = val.nbytes() + gguf_array_header_size;\n std::vector val_vec(gguf_size);\n gguf_value* gguf_val = reinterpret_cast(val_vec.data());\n gguf_val->array.type = gguf_type;\n gguf_val->array.len = val.size();\n memcpy(\n val_vec.data() + gguf_array_header_size,\n val.data(),\n val.nbytes());\n gguf_append_kv(\n ctx,\n key.c_str(),\n key.length(),\n GGUF_VALUE_TYPE_ARRAY,\n reinterpret_cast(val_vec.data()),\n gguf_size);\n } else {\n gguf_append_kv(\n ctx,\n key.c_str(),\n key.length(),\n gguf_type,\n reinterpret_cast(val.data()),\n val.nbytes());\n }\n}\n\nvoid save_gguf(\n std::string file,\n std::unordered_map array_map,\n std::unordered_map metadata /* = {} */) {\n // Add .gguf to file name if it is not there\n if (file.length() < 5 || file.substr(file.length() - 5, 5) != \".gguf\") {\n file += \".gguf\";\n }\n\n gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);\n if (!ctx) {\n throw std::runtime_error(\"[save_gguf] gguf_create failed\");\n }\n\n auto string_to_gguf = [](char* dst, const std::string& src) {\n gguf_string* val = reinterpret_cast(dst);\n val->len = src.length();\n memcpy(val->string, src.c_str(), src.length());\n };\n\n // Save any meta data\n for (auto& [key, value] : metadata) {\n if (auto pv = std::get_if(&value); pv) {\n const std::string& str = *pv;\n size_t size = sizeof(gguf_string) + str.length();\n std::vector val_vec(size);\n string_to_gguf(val_vec.data(), str);\n gguf_append_kv(\n ctx,\n key.c_str(),\n key.length(),\n GGUF_VALUE_TYPE_STRING,\n...\n", "needle_token_start": 10409, "needle_token_end": 10630, "code_context_ntokens": 16217, "output": ["```cpp\nvoid unpack_32_4(uint8_t* data, int8_t* dst) {\n for (int64_t j = 0; j < 16; ++j) {\n uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.\n if (j % 2 != 0) {\n x <<= 4;\n }\n dst[j / 2] += x;\n }\n // Last 16 weights are in the higher bits\n for (int64_t j = 0; j < 16; ++j) {\n uint8_t x = (data[j + 2] >> 4);\n if (j % 2 != 0) {\n x <<= 4;\n }\n dst[8 + j / 2] += x;\n }\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "QuantizedMatmul", "language": "cpp", "path": "mlx/primitives.h", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: This function is designed to perform matrix multiplication on quantized data, which is a method to reduce the computational precision of the operands to lower bit-widths. This is particularly useful in scenarios where memory bandwidth and power consumption are limited, such as in embedded systems or mobile devices.\n2. **Input**: The inputs include two matrices to be multiplied. These matrices are expected to be quantized, meaning their numerical precision is reduced to a specified number of bits.\n3. **Output**: The output is a single matrix that represents the product of the input matrices, also in a quantized format.\n4. **Procedure**: The function first checks if the matrices need to be transposed based on the specified parameter. It then groups the elements of the matrices according to the provided group size, performs the multiplication in the reduced bit-width format, and finally aggregates the results to form the output matrix.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " const std::vector& inputs,\n const std::vector& axes);\n\n /** Print the primitive. 
*/\n virtual void print(std::ostream& os) = 0;\n\n /** Equivalence check defaults to false unless overridden by the primitive */\n virtual bool is_equivalent(const Primitive& other) const {\n return false;\n }\n\n /** Get the output shapes of the primitive. This is not required to be\n * implemented by derived classes, in which case it will throw. */\n virtual std::vector> output_shapes(\n const std::vector& inputs);\n\n virtual ~Primitive() = default;\n Primitive(const Primitive& other) = delete;\n Primitive(Primitive&& other) = delete;\n Primitive& operator=(const Primitive& other) = delete;\n Primitive& operator=(Primitive&& other) = delete;\n\n private:\n // Every primitive stores the stream it should run in\n Stream stream_;\n};\n\nclass UnaryPrimitive : public Primitive {\n /**\n * An abstract base class for a primitive with a single output.\n */\n public:\n explicit UnaryPrimitive(Stream stream) : Primitive(stream) {}\n\n virtual void eval_cpu(const std::vector& inputs, array& output) = 0;\n virtual void eval_gpu(const std::vector& inputs, array& output) = 0;\n\n inline void eval_cpu(\n const std::vector& inputs,\n std::vector& outputs) override {\n eval_cpu(inputs, outputs[0]);\n }\n inline void eval_gpu(\n const std::vector& inputs,\n std::vector& outputs) override {\n eval_gpu(inputs, outputs[0]);\n }\n\n virtual ~UnaryPrimitive() = default;\n UnaryPrimitive(const UnaryPrimitive& other) = delete;\n UnaryPrimitive(UnaryPrimitive&& other) = delete;\n UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;\n UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;\n};\n\nclass Abs : public UnaryPrimitive {\n public:\n explicit Abs(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Abs)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Add : public UnaryPrimitive {\n public:\n explicit Add(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Add)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass AddMM : public UnaryPrimitive {\n public:\n explicit AddMM(Stream stream, float alpha, float beta)\n : UnaryPrimitive(stream), alpha_(alpha), beta_(beta){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(AddMM)\n\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n const float alpha_;\n const float beta_;\n};\n\nclass Arange : public UnaryPrimitive {\n public:\n explicit Arange(Stream stream, double start, double stop, double step)\n : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_PRINT(Arange)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n double start_;\n double stop_;\n double 
step_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArcCos : public UnaryPrimitive {\n public:\n explicit ArcCos(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ArcCos)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArcCosh : public UnaryPrimitive {\n public:\n explicit ArcCosh(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ArcCosh)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArcSin : public UnaryPrimitive {\n public:\n explicit ArcSin(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ArcSin)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArcSinh : public UnaryPrimitive {\n public:\n explicit ArcSinh(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ArcSinh)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArcTan : public UnaryPrimitive {\n public:\n explicit ArcTan(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ArcTan)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArcTanh : public UnaryPrimitive {\n public:\n explicit ArcTanh(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ArcTanh)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArgPartition : public UnaryPrimitive {\n public:\n explicit ArgPartition(Stream stream, int kth, int axis)\n : UnaryPrimitive(stream), kth_(kth), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(ArgPartition)\n DEFINE_INPUT_OUTPUT_SHAPE()\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int kth_;\n int axis_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArgReduce : public UnaryPrimitive {\n public:\n enum ReduceType {\n ArgMin,\n ArgMax,\n };\n\n explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)\n : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const 
std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(ArgReduce)\n bool is_equivalent(const Primitive& other) const override;\n std::vector> output_shapes(\n const std::vector& inputs) override;\n\n private:\n ReduceType reduce_type_;\n int axis_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ArgSort : public UnaryPrimitive {\n public:\n explicit ArgSort(Stream stream, int axis)\n : UnaryPrimitive(stream), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(ArgSort)\n DEFINE_INPUT_OUTPUT_SHAPE()\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int axis_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass AsType : public UnaryPrimitive {\n public:\n explicit AsType(Stream stream, Dtype dtype)\n : UnaryPrimitive(stream), dtype_(dtype){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(AsType)\n DEFINE_INPUT_OUTPUT_SHAPE()\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n Dtype dtype_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass AsStrided : public UnaryPrimitive {\n public:\n explicit AsStrided(\n Stream stream,\n std::vector shape,\n std::vector strides,\n size_t offset)\n : UnaryPrimitive(stream),\n shape_(std::move(shape)),\n strides_(std::move(strides)),\n offset_(offset){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_GRADS()\n DEFINE_PRINT(AsStrided)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector shape_;\n std::vector strides_;\n size_t offset_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Broadcast : public UnaryPrimitive {\n public:\n explicit Broadcast(Stream stream, const std::vector& shape)\n : UnaryPrimitive(stream), shape_(shape){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Broadcast)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector shape_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Ceil : public UnaryPrimitive {\n public:\n explicit Ceil(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Ceil)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Compiled : public Primitive {\n public:\n /*\n * The inputs, outputs and tape are either tracers or constants.\n * - The tape should not contain the inputs, but it should contain the\n * outputs.\n * - The tape should also have only one array per primitive for multi-output\n * primitives.\n * - The constant_ids contains ids of arrays in the input list that are safe\n * to treat as scalar constants.\n */\n explicit Compiled(\n Stream stream,\n std::vector inputs,\n std::vector outputs,\n std::vector tape,\n std::unordered_set constant_ids);\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override;\n void 
eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n std::vector> output_shapes(\n const std::vector& inputs) override;\n void print(std::ostream& os) override;\n bool is_equivalent(const Primitive& other) const override;\n\n std::string lib_name() const {\n return kernel_lib_;\n }\n\n private:\n const std::vector inputs_;\n const std::vector outputs_;\n const std::vector tape_;\n const std::unordered_set constant_ids_;\n\n std::string kernel_lib_;\n};\n\nclass Concatenate : public UnaryPrimitive {\n public:\n explicit Concatenate(Stream stream, int axis)\n : UnaryPrimitive(stream), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Concatenate)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int axis_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Convolution : public UnaryPrimitive {\n public:\n explicit Convolution(\n Stream stream,\n const std::vector& kernel_strides,\n const std::vector& padding,\n const std::vector& kernel_dilation,\n const std::vector& input_dilation,\n const int groups = 1,\n const bool flip = false)\n : UnaryPrimitive(stream),\n padding_(padding),\n kernel_strides_(kernel_strides),\n kernel_dilation_(kernel_dilation),\n input_dilation_(input_dilation),\n groups_(groups),\n flip_(flip){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_PRINT(Convolution)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector padding_;\n std::vector kernel_strides_;\n std::vector kernel_dilation_;\n std::vector input_dilation_;\n int groups_;\n bool flip_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Copy : public UnaryPrimitive {\n public:\n explicit Copy(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Copy)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Cos : public UnaryPrimitive {\n public:\n explicit Cos(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Cos)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Cosh : public UnaryPrimitive {\n public:\n explicit Cosh(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Cosh)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass CustomVJP : public Primitive {\n public:\n explicit CustomVJP(\n Stream stream,\n std::function(\n const std::vector&,\n const std::vector&,\n const std::vector&)> fun)\n : 
Primitive(stream), vjp_fun_(std::move(fun)) {}\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override;\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotan,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_PRINT(CustomVJP);\n\n private:\n void eval(const std::vector& inputs, std::vector& outputs);\n\n std::function(\n const std::vector&,\n const std::vector&,\n const std::vector&)>\n vjp_fun_;\n};\n\nclass Depends : public Primitive {\n public:\n explicit Depends(Stream stream) : Primitive(stream) {}\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override;\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotan,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_PRINT(Depends);\n\n private:\n void eval(const std::vector& inputs, std::vector& outputs);\n};\n\nclass Divide : public UnaryPrimitive {\n public:\n explicit Divide(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Divide)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass DivMod : public Primitive {\n public:\n explicit DivMod(Stream stream) : Primitive(stream){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override;\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(DivMod)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n std::vector> output_shapes(\n const std::vector& inputs) override {\n return std::vector{inputs[0].shape(), inputs[0].shape()};\n };\n\n private:\n void eval(const std::vector& inputs, std::vector& outputs);\n};\n\nclass Select : public UnaryPrimitive {\n public:\n explicit Select(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Select)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Remainder : public UnaryPrimitive {\n public:\n explicit Remainder(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Remainder)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Equal : public UnaryPrimitive {\n public:\n explicit Equal(Stream stream, bool equal_nan = false)\n : UnaryPrimitive(stream), equal_nan_(equal_nan){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n void print(std::ostream& os) override {\n if (equal_nan_) {\n os << \"NanEqual\";\n } else {\n os << \"Equal\";\n }\n }\n\n private:\n void eval(const std::vector& inputs, array& out);\n bool 
equal_nan_;\n};\n\nclass Erf : public UnaryPrimitive {\n public:\n explicit Erf(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Erf)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass ErfInv : public UnaryPrimitive {\n public:\n explicit ErfInv(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(ErfInv)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Exp : public UnaryPrimitive {\n public:\n explicit Exp(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Exp)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Expm1 : public UnaryPrimitive {\n public:\n explicit Expm1(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Expm1)\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass FFT : public UnaryPrimitive {\n public:\n explicit FFT(\n Stream stream,\n const std::vector& axes,\n bool inverse,\n bool real)\n : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(FFT)\n\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector axes_;\n bool inverse_;\n bool real_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Floor : public UnaryPrimitive {\n public:\n explicit Floor(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Floor)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Full : public UnaryPrimitive {\n public:\n explicit Full(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Full)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Gather : public UnaryPrimitive {\n public:\n explicit Gather(\n Stream stream,\n const std::vector& axes,\n const std::vector& slice_sizes)\n : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Gather)\n bool is_equivalent(const 
Primitive& other) const override;\n\n private:\n void eval(const std::vector& inputs, array& out);\n std::vector axes_;\n std::vector slice_sizes_;\n};\n\nclass Greater : public UnaryPrimitive {\n public:\n explicit Greater(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Greater)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass GreaterEqual : public UnaryPrimitive {\n public:\n explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(GreaterEqual)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Less : public UnaryPrimitive {\n public:\n explicit Less(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Less)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass LessEqual : public UnaryPrimitive {\n public:\n explicit LessEqual(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(LessEqual)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Load : public UnaryPrimitive {\n public:\n explicit Load(\n Stream stream,\n std::shared_ptr reader,\n size_t offset,\n bool swap_endianness = false)\n : UnaryPrimitive(stream),\n reader_(reader),\n offset_(offset),\n swap_endianness_(swap_endianness){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_PRINT(Load)\n\n private:\n void eval(const std::vector& inputs, array& out);\n std::shared_ptr reader_;\n size_t offset_;\n bool swap_endianness_;\n};\n\nclass Log : public UnaryPrimitive {\n public:\n enum Base { two, ten, e };\n\n explicit Log(Stream stream, Base base)\n : UnaryPrimitive(stream), base_(base){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n void print(std::ostream& os) override {\n switch (base_) {\n case e:\n os << \"Log\";\n break;\n case two:\n os << \"Log2\";\n break;\n case ten:\n os << \"Log10\";\n break;\n }\n }\n\n private:\n Base base_;\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Log1p : public UnaryPrimitive {\n public:\n explicit Log1p(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Log1p)\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass 
LogicalNot : public UnaryPrimitive {\n public:\n explicit LogicalNot(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(LogicalNot)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass LogicalAnd : public UnaryPrimitive {\n public:\n explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(LogicalAnd)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass LogicalOr : public UnaryPrimitive {\n public:\n explicit LogicalOr(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(LogicalOr)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass LogAddExp : public UnaryPrimitive {\n public:\n explicit LogAddExp(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(LogAddExp)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Matmul : public UnaryPrimitive {\n public:\n explicit Matmul(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(Matmul)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n};\n\nclass Maximum : public UnaryPrimitive {\n public:\n explicit Maximum(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Maximum)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Minimum : public UnaryPrimitive {\n public:\n explicit Minimum(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Minimum)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Multiply : public UnaryPrimitive {\n public:\n explicit Multiply(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Multiply)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, 
array& out);\n};\n\nclass Negative : public UnaryPrimitive {\n public:\n explicit Negative(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Negative)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass NotEqual : public UnaryPrimitive {\n public:\n explicit NotEqual(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(NotEqual)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass NumberOfElements : public UnaryPrimitive {\n public:\n explicit NumberOfElements(\n Stream stream,\n std::vector axes,\n bool inverted,\n Dtype dtype)\n : UnaryPrimitive(stream),\n axes_(std::move(axes)),\n inverted_(inverted),\n dtype_(dtype) {}\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(NumberOfElements)\n bool is_equivalent(const Primitive& other) const override;\n std::vector> output_shapes(\n const std::vector& inputs) override {\n return {{}};\n }\n\n private:\n std::vector axes_;\n bool inverted_;\n Dtype dtype_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Pad : public UnaryPrimitive {\n public:\n explicit Pad(\n Stream stream,\n const std::vector& axes,\n const std::vector& low_pad_size,\n const std::vector& high_pad_size)\n : UnaryPrimitive(stream),\n axes_(axes),\n low_pad_size_(low_pad_size),\n high_pad_size_(high_pad_size){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Pad)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector axes_;\n std::vector low_pad_size_;\n std::vector high_pad_size_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Partition : public UnaryPrimitive {\n public:\n explicit Partition(Stream stream, int kth, int axis)\n : UnaryPrimitive(stream), kth_(kth), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Partition)\n DEFINE_INPUT_OUTPUT_SHAPE()\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int kth_;\n int axis_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Power : public UnaryPrimitive {\n public:\n explicit Power(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Power)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass QuantizedMatmul : public UnaryPrimitive {\n public:\n e\nxplicit QuantizedMatmul(\n Stream stream,\n int group_size,\n int bits,\n bool transpose)\n : UnaryPrimitive(stream),\n group_size_(group_size),\n bits_(bits),\n transpose_(transpose){};\n\n void eval_cpu(const 
std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(QuantizedMatmul)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int group_size_;\n int bits_;\n bool transpose_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass RandomBits : public UnaryPrimitive {\n public:\n explicit RandomBits(Stream stream, const std::vector& shape, int width)\n : UnaryPrimitive(stream), shape_(shape), width_(width){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_PRINT(RandomBits)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector shape_;\n int width_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Reshape : public UnaryPrimitive {\n public:\n explicit Reshape(Stream stream, const std::vector& shape)\n : UnaryPrimitive(stream), shape_(shape){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Reshape)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector shape_;\n\n void eval(const std::vector& inputs, array& out);\n\n std::pair> prepare_reshape(\n const array& in,\n const array& out);\n void shared_buffer_reshape(\n const array& in,\n const std::vector& out_strides,\n array& out);\n};\n\nclass Reduce : public UnaryPrimitive {\n public:\n enum ReduceType { And, Or, Sum, Prod, Min, Max };\n\n explicit Reduce(\n Stream stream,\n ReduceType reduce_type,\n const std::vector& axes)\n : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n\n std::vector vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) override;\n\n std::vector> output_shapes(\n const std::vector& inputs) override;\n\n void print(std::ostream& os) override {\n switch (reduce_type_) {\n case And:\n os << \"And\";\n case Or:\n os << \"And\";\n break;\n case Sum:\n os << \"Sum\";\n break;\n case Prod:\n os << \"Prod\";\n break;\n case Min:\n os << \"Min\";\n break;\n case Max:\n os << \"Max\";\n break;\n }\n os << \" Reduce\";\n }\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n ReduceType reduce_type_;\n std::vector axes_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Round : public UnaryPrimitive {\n public:\n explicit Round(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Round)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Scan : public UnaryPrimitive {\n public:\n enum ReduceType { Max, Min, Sum, Prod };\n\n explicit Scan(\n Stream stream,\n ReduceType reduce_type,\n int axis,\n bool reverse,\n bool inclusive)\n : UnaryPrimitive(stream),\n reduce_type_(reduce_type),\n axis_(axis),\n reverse_(reverse),\n inclusive_(inclusive){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const 
std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS();\n\n void print(std::ostream& os) override {\n os << \"Cum\";\n switch (reduce_type_) {\n case Sum:\n os << \"Sum\";\n break;\n case Prod:\n os << \"Prod\";\n break;\n case Min:\n os << \"Min\";\n break;\n case Max:\n os << \"Max\";\n break;\n }\n os << \" Reduce\";\n }\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n ReduceType reduce_type_;\n int axis_;\n bool reverse_;\n bool inclusive_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Scatter : public UnaryPrimitive {\n public:\n enum ReduceType { Max, Min, Sum, Prod, None };\n\n explicit Scatter(\n Stream stream,\n ReduceType reduce_type,\n const std::vector& axes)\n : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_GRADS();\n void print(std::ostream& os) override {\n os << \"Scatter\";\n switch (reduce_type_) {\n case Sum:\n os << \" Sum\";\n break;\n case Prod:\n os << \" Prod\";\n break;\n case Min:\n os << \" Min\";\n break;\n case Max:\n os << \" Max\";\n break;\n case None:\n break;\n }\n }\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n void eval(const std::vector& inputs, array& out);\n ReduceType reduce_type_;\n std::vector axes_;\n};\n\nclass Sigmoid : public UnaryPrimitive {\n public:\n explicit Sigmoid(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Sigmoid)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Sign : public UnaryPrimitive {\n public:\n explicit Sign(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Sign)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Sin : public UnaryPrimitive {\n public:\n explicit Sin(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Sin)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Sinh : public UnaryPrimitive {\n public:\n explicit Sinh(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Sinh)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Slice : public UnaryPrimitive {\n public:\n explicit Slice(\n Stream stream,\n const std::vector& start_indices,\n const std::vector& end_indices,\n const std::vector& strides)\n : UnaryPrimitive(stream),\n start_indices_(start_indices),\n end_indices_(end_indices),\n strides_(strides){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& 
inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Slice)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector start_indices_;\n std::vector end_indices_;\n std::vector strides_;\n\n void eval(const std::vector& inputs, array& out);\n\n std::tuple> prepare_slice(\n const array& in);\n void shared_buffer_slice(\n const array& in,\n const std::vector& out_strides,\n size_t data_offset,\n array& out);\n};\n\nclass SliceUpdate : public UnaryPrimitive {\n public:\n explicit SliceUpdate(\n Stream stream,\n const std::vector& start_indices,\n const std::vector& end_indices,\n const std::vector& strides)\n : UnaryPrimitive(stream),\n start_indices_(start_indices),\n end_indices_(end_indices),\n strides_(strides){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(SliceUpdate)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n std::vector start_indices_;\n std::vector end_indices_;\n std::vector strides_;\n\n void eval(const std::vector& inputs, array& out);\n\n std::tuple> prepare_slice(const array& in);\n};\n\nclass Softmax : public UnaryPrimitive {\n public:\n explicit Softmax(Stream stream, bool precise)\n : UnaryPrimitive(stream), precise_(precise){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Softmax)\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n void eval(const std::vector& inputs, array& out);\n bool precise_;\n};\n\nclass Sort : public UnaryPrimitive {\n public:\n explicit Sort(Stream stream, int axis)\n : UnaryPrimitive(stream), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Sort)\n DEFINE_INPUT_OUTPUT_SHAPE()\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int axis_;\n\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Split : public Primitive {\n public:\n explicit Split(Stream stream, const std::vector& indices, int axis)\n : Primitive(stream), indices_(indices), axis_(axis){};\n\n void eval_cpu(const std::vector& inputs, std::vector& outputs)\n override;\n void eval_gpu(const std::vector& inputs, std::vector& outputs)\n override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Split)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n void eval(const std::vector& inputs, std::vector& outputs);\n\n std::vector indices_;\n int axis_;\n};\n\nclass Square : public UnaryPrimitive {\n public:\n explicit Square(Stream stream) : UnaryPrimitive(stream){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(Square)\n DEFINE_DEFAULT_IS_EQUIVALENT()\n DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n void eval(const std::vector& inputs, array& out);\n};\n\nclass Sqrt : public UnaryPrimitive {\n public:\n explicit Sqrt(Stream stream, bool recip = false)\n...\n", "needle_token_start": 11860, "needle_token_end": 11926, "code_context_ntokens": 15856, "output": ["```cpp\nclass QuantizedMatmul : public UnaryPrimitive {\n public:\n explicit 
QuantizedMatmul(\n Stream stream,\n int group_size,\n int bits,\n bool transpose)\n : UnaryPrimitive(stream),\n group_size_(group_size),\n bits_(bits),\n transpose_(transpose){};\n\n void eval_cpu(const std::vector& inputs, array& out) override;\n void eval_gpu(const std::vector& inputs, array& out) override;\n\n DEFINE_VMAP()\n DEFINE_GRADS()\n DEFINE_PRINT(QuantizedMatmul)\n bool is_equivalent(const Primitive& other) const override;\n\n private:\n int group_size_;\n int bits_;\n bool transpose_;\n\n void eval(const std::vector& inputs, array& out);\n};\n```"]} +{"repo": "ml-explore/mlx", "name": "sqrt", "language": "cpp", "path": "mlx/ops.cpp", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: Computes the square root of each element in a numerical array.\n2. **Input**: A numerical array and an optional specification of the computational context (like CPU or GPU).\n3. **Output**: A new array of the same shape as the input, containing the square roots of the original elements, ensuring the data type is at least a floating-point type to accommodate non-integer results.\n4. **Procedure**: The function first ensures the data type of the input array can handle floating-point results. It then applies a square root operation to each element of the input array, producing an output array with the results in the same shape as the input.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " const std::vector& axes,\n bool keepdims /* = false */,\n int ddof /* = 0*/,\n StreamOrDevice s /* = {}*/) {\n auto dtype = at_least_float(a.dtype());\n auto mu2 = square(mean(a, axes, keepdims, s), s);\n auto a2 = mean(square(a, s), axes, keepdims, s);\n auto v = subtract(a2, mu2, s);\n\n if (ddof != 0) {\n auto nelements = number_of_elements(a, axes, false, dtype, s);\n auto factor = divide(\n nelements,\n maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),\n s);\n v = multiply(v, factor, s);\n }\n\n return v;\n}\n\narray var(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n int ddof /* = 0*/,\n StreamOrDevice s /* = {} */) {\n return var(a, std::vector{axis}, keepdims, ddof, to_stream(s));\n}\n\narray std(\n const array& a,\n bool keepdims,\n int ddof /* = 0*/,\n StreamOrDevice s /* = {}*/) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return std(a, axes, keepdims, ddof, to_stream(s));\n}\n\narray std(\n const array& a,\n const std::vector& axes,\n bool keepdims /* = false */,\n int ddof /* = 0*/,\n StreamOrDevice s /* = {}*/) {\n return sqrt(var(a, axes, keepdims, ddof, s), s);\n}\n\narray std(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n int ddof /* = 0*/,\n StreamOrDevice s /* = {} */) {\n return std(a, std::vector{axis}, keepdims, ddof, to_stream(s));\n}\n\narray prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return prod(a, axes, keepdims, s);\n}\n\narray prod(\n const array& a,\n const std::vector& axes,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {}*/) {\n if (axes.empty()) {\n return a;\n }\n auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());\n auto out = array(\n out_shape,\n a.dtype(),\n std::make_shared(to_stream(s), Reduce::Prod, sorted_axes),\n 
{a});\n if (!keepdims) {\n out = squeeze(out, sorted_axes, s);\n }\n return out;\n}\n\narray prod(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n return prod(a, std::vector{axis}, keepdims, s);\n}\n\narray max(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return max(a, axes, keepdims, s);\n}\n\narray max(\n const array& a,\n const std::vector& axes,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {}*/) {\n if (a.size() == 0) {\n throw std::invalid_argument(\"[max] Cannot max reduce zero size array.\");\n }\n if (axes.empty()) {\n return a;\n }\n auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());\n auto out = array(\n out_shape,\n a.dtype(),\n std::make_shared(to_stream(s), Reduce::Max, sorted_axes),\n {a});\n if (!keepdims) {\n out = squeeze(out, sorted_axes, s);\n }\n return out;\n}\n\narray max(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n return max(a, std::vector{axis}, keepdims, s);\n}\n\narray min(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return min(a, axes, keepdims, s);\n}\n\narray min(\n const array& a,\n const std::vector& axes,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {}*/) {\n if (a.size() == 0) {\n throw std::invalid_argument(\"[min] Cannot min reduce zero size array.\");\n }\n if (axes.empty()) {\n return a;\n }\n auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());\n auto out = array(\n out_shape,\n a.dtype(),\n std::make_shared(to_stream(s), Reduce::Min, sorted_axes),\n {a});\n if (!keepdims) {\n out = squeeze(out, sorted_axes, s);\n }\n return out;\n}\n\narray min(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n return min(a, std::vector{axis}, keepdims, s);\n}\n\narray argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {\n int size = a.size();\n auto result = argmin(reshape(a, {size}, s), 0, true, s);\n if (keepdims) {\n result = reshape(result, std::vector(a.shape().size(), 1), s);\n } else {\n result = squeeze(result, s);\n }\n return result;\n}\n\narray argmin(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n if (a.size() == 0) {\n throw std::invalid_argument(\n \"[argmin] Cannot argmin reduce zero size array.\");\n }\n auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());\n auto out = array(\n out_shape,\n uint32,\n std::make_shared(\n to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),\n {a});\n if (!keepdims) {\n out = squeeze(out, sorted_axes, s);\n }\n return out;\n}\n\narray argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {\n int size = a.size();\n auto result = argmax(reshape(a, {size}, s), 0, true, s);\n if (keepdims) {\n result = reshape(result, std::vector(a.shape().size(), 1), s);\n } else {\n result = squeeze(result, s);\n }\n return result;\n}\n\narray argmax(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n if (a.size() == 0) {\n throw std::invalid_argument(\n \"[argmax] Cannot argmax reduce zero size array.\");\n }\n auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());\n auto out = array(\n out_shape,\n uint32,\n std::make_shared(\n to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),\n {a});\n if (!keepdims) {\n 
out = squeeze(out, sorted_axes, s);\n }\n return out;\n}\n\n/** Returns a sorted copy of the flattened array. */\narray sort(const array& a, StreamOrDevice s /* = {} */) {\n int size = a.size();\n return sort(reshape(a, {size}, s), 0, s);\n}\n\n/** Returns a sorted copy of the array along a given axis. */\narray sort(const array& a, int axis, StreamOrDevice s /* = {} */) {\n // Check for valid axis\n if (axis + static_cast(a.ndim()) < 0 ||\n axis >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[sort] Received invalid axis \" << axis << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n // TODO: Fix GPU kernel\n if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {\n std::ostringstream msg;\n msg << \"[sort] GPU sort cannot handle sort axis of >= 2M elements,\"\n << \" got array with sort axis size \" << a.shape(axis) << \".\"\n << \" Please place this operation on the CPU instead.\";\n throw std::runtime_error(msg.str());\n }\n\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s), axis), {a});\n}\n\n/** Returns indices that sort the flattened array. */\narray argsort(const array& a, StreamOrDevice s /* = {} */) {\n int size = a.size();\n return argsort(reshape(a, {size}, s), 0, s);\n}\n\n/** Returns indices that sort the array along a given axis. */\narray argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {\n // Check for valid axis\n if (axis + static_cast(a.ndim()) < 0 ||\n axis >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[argsort] Received invalid axis \" << axis << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n // TODO: Fix GPU kernel\n if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {\n std::ostringstream msg;\n msg << \"[argsort] GPU sort cannot handle sort axis of >= 2M elements,\"\n << \" got array with sort axis size \" << a.shape(axis) << \".\"\n << \" Please place this operation on the CPU instead.\";\n throw std::runtime_error(msg.str());\n }\n\n return array(\n a.shape(), uint32, std::make_shared(to_stream(s), axis), {a});\n}\n\n/**\n * Returns a partitioned copy of the flattened array\n * such that the smaller kth elements are first.\n **/\narray partition(const array& a, int kth, StreamOrDevice s /* = {} */) {\n int size = a.size();\n return partition(reshape(a, {size}, s), kth, 0, s);\n}\n\n/**\n * Returns a partitioned copy of the array along a given axis\n * such that the smaller kth elements are first.\n **/\narray partition(\n const array& a,\n int kth,\n int axis,\n StreamOrDevice s /* = {} */) {\n // Check for valid axis\n if (axis + static_cast(a.ndim()) < 0 ||\n axis >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[partition] Received invalid axis \" << axis << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n int axis_ = axis < 0 ? axis + a.ndim() : axis;\n int kth_ = kth < 0 ? 
kth + a.shape(axis) : kth;\n if (kth_ < 0 || kth_ >= a.shape(axis_)) {\n std::ostringstream msg;\n msg << \"[partition] Received invalid kth \" << kth << \"along axis \" << axis\n << \" for array with shape: \" << a.shape();\n throw std::invalid_argument(msg.str());\n }\n return array(\n a.shape(),\n a.dtype(),\n std::make_shared(to_stream(s), kth_, axis_),\n {a});\n}\n\n/**\n * Returns indices that partition the flattened array\n * such that the smaller kth elements are first.\n **/\narray argpartition(const array& a, int kth, StreamOrDevice s /* = {} */) {\n int size = a.size();\n return argpartition(reshape(a, {size}, s), kth, 0, s);\n}\n\n/**\n * Returns indices that partition the array along a given axis\n * such that the smaller kth elements are first.\n **/\narray argpartition(\n const array& a,\n int kth,\n int axis,\n StreamOrDevice s /* = {} */) {\n // Check for valid axis\n if (axis + static_cast(a.ndim()) < 0 ||\n axis >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[argpartition] Received invalid axis \" << axis << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n int axis_ = axis < 0 ? axis + a.ndim() : axis;\n int kth_ = kth < 0 ? kth + a.shape(axis) : kth;\n if (kth_ < 0 || kth_ >= a.shape(axis_)) {\n std::ostringstream msg;\n msg << \"[argpartition] Received invalid kth \" << kth << \" along axis \"\n << axis << \" for array with shape: \" << a.shape();\n throw std::invalid_argument(msg.str());\n }\n return array(\n a.shape(),\n uint32,\n std::make_shared(to_stream(s), kth_, axis_),\n {a});\n}\n\n/** Returns topk elements of the flattened array. */\narray topk(const array& a, int k, StreamOrDevice s /* = {}*/) {\n int size = a.size();\n return topk(reshape(a, {size}, s), k, 0, s);\n}\n\n/** Returns topk elements of the array along a given axis. */\narray topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {\n // Check for valid axis\n int axis_ = axis < 0 ? 
axis + a.ndim() : axis;\n if (axis_ < 0 || axis_ >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[topk] Received invalid axis \" << axis << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n if (k < 0 || k > a.shape(axis_)) {\n std::ostringstream msg;\n msg << \"[topk] Received invalid k=\" << k << \" along axis \" << axis\n << \" for array with shape: \" << a.shape();\n throw std::invalid_argument(msg.str());\n }\n\n // Return early if the whole input was requested.\n if (k == a.shape(axis_)) {\n return a;\n }\n\n array a_partitioned = partition(a, -k, axis_, s);\n std::vector slice_starts(a.ndim(), 0);\n std::vector slice_ends = a.shape();\n slice_starts[axis_] = a.shape(axis_) - k;\n return slice(a_partitioned, slice_starts, slice_ends, s);\n}\n\narray logsumexp(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return logsumexp(a, axes, keepdims, s);\n}\n\narray logsumexp(\n const array& a,\n const std::vector& axes,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {}*/) {\n auto maxval = stop_gradient(max(a, axes, true, s), s);\n auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);\n out = add(out, reshape(maxval, out.shape(), s), s);\n if (!keepdims) {\n maxval = squeeze(maxval, axes, s);\n }\n return where(isinf(maxval, s), maxval, out, s);\n}\n\narray logsumexp(\n const array& a,\n int axis,\n bool keepdims /* = false */,\n StreamOrDevice s /* = {} */) {\n return logsumexp(a, std::vector{axis}, keepdims, s);\n}\n\narray abs(const array& a, StreamOrDevice s /* = {} */) {\n auto out =\n array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n if (a.dtype() == complex64) {\n out = astype(out, float32, s);\n }\n return out;\n}\n\narray negative(const array& a, StreamOrDevice s /* = {} */) {\n if (a.dtype() == bool_) {\n auto msg = \"[negative] Not supported for bool, use logical_not instead.\";\n throw std::invalid_argument(msg);\n }\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\narray operator-(const array& a) {\n return negative(a);\n}\n\narray sign(const array& a, StreamOrDevice s /* = {} */) {\n return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\n\narray logical_not(const array& a, StreamOrDevice s /* = {} */) {\n return array(\n a.shape(),\n bool_,\n std::make_shared(to_stream(s)),\n {astype(a, bool_, s)});\n}\n\narray logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n // Broadcast arrays to a common shape\n auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n bool_,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\narray operator&&(const array& a, const array& b) {\n return logical_and(a, b);\n}\n\narray logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n // Broadcast arrays to a common shape\n auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n bool_,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\narray operator||(const array& a, const array& b) {\n return logical_or(a, b);\n}\n\narray reciprocal(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n return divide(array(1.0f, dtype), a, to_stream(s));\n}\n\narray add(const array& a, const array& b, 
StreamOrDevice s /* = {} */) {\n auto out_type = promote_types(a.dtype(), b.dtype());\n auto inputs =\n broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape, out_type, std::make_shared(to_stream(s)), std::move(inputs));\n}\n\narray operator+(const array& a, const array& b) {\n return add(a, b);\n}\n\narray subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto out_type = promote_types(a.dtype(), b.dtype());\n auto inputs =\n broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n out_type,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\n\narray operator-(const array& a, const array& b) {\n return subtract(a, b);\n}\n\narray multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto out_type = promote_types(a.dtype(), b.dtype());\n auto inputs =\n broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n out_type,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\n\narray operator*(const array& a, const array& b) {\n return multiply(a, b);\n}\n\narray divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));\n auto inputs =\n broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);\n auto& shape = inputs[0].shape();\n return array(\n shape, dtype, std::make_shared(to_stream(s)), std::move(inputs));\n}\narray operator/(const array& a, const array& b) {\n return divide(a, b);\n}\narray operator/(double a, const array& b) {\n return divide(array(a), b);\n}\narray operator/(const array& a, double b) {\n return divide(a, array(b));\n}\n\narray floor_divide(\n const array& a,\n const array& b,\n StreamOrDevice s /* = {} */) {\n auto dtype = promote_types(a.dtype(), b.dtype());\n if (issubdtype(dtype, inexact)) {\n return floor(divide(a, b, s), s);\n }\n\n auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape, dtype, std::make_shared(to_stream(s)), std::move(inputs));\n}\n\narray remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto dtype = promote_types(a.dtype(), b.dtype());\n auto inputs =\n broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n dtype,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\narray operator%(const array& a, const array& b) {\n return remainder(a, b);\n}\n\nstd::vector\ndivmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto dtype = promote_types(a.dtype(), b.dtype());\n if (issubdtype(dtype, complexfloating)) {\n throw std::invalid_argument(\"[divmod] Complex type not supported.\");\n }\n auto inputs =\n broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);\n return array::make_arrays(\n {inputs[0].shape(), inputs[0].shape()},\n {inputs[0].dtype(), inputs[0].dtype()},\n std::make_shared(to_stream(s)),\n inputs);\n}\n\narray maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto out_type = promote_types(a.dtype(), b.dtype());\n auto inputs =\n broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n out_type,\n std::make_shared(to_stream(s)),\n 
std::move(inputs));\n}\n\narray minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto out_type = promote_types(a.dtype(), b.dtype());\n auto inputs =\n broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n out_type,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\n\narray floor(const array& a, StreamOrDevice s /* = {} */) {\n if (a.dtype() == complex64) {\n throw std::invalid_argument(\"[floor] Not supported for complex64.\");\n }\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\n\narray ceil(const array& a, StreamOrDevice s /* = {} */) {\n if (a.dtype() == complex64) {\n throw std::invalid_argument(\"[floor] Not supported for complex64.\");\n }\n return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\n\narray square(const array& a, StreamOrDevice s /* = {} */) {\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\n\narray exp(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray expm1(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray sin(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray cos(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray tan(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray arcsin(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray arccos(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray arctan(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray sinh(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray cosh(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray tanh(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray arcsinh(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, 
dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray arccosh(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray arctanh(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray log(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s), Log::Base::e),\n {input});\n}\n\narray log2(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s), Log::Base::two),\n {input});\n}\n\narray log10(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s), Log::Base::ten),\n {input});\n}\n\narray log1p(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n // Make sure out type is floating point\n auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));\n auto inputs =\n broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);\n auto& shape = inputs[0].shape();\n return array(\n shape,\n out_type,\n std::make_shared(to_stream(s)),\n std::move(inputs));\n}\n\narray sigmoid(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n auto input = astype(a, dtype, s);\n return array(\n a.shape(), dtype, std::make_shared(to_stream(s)), {input});\n}\n\narray erf(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s)),\n {astype(a, dtype, s)});\n}\n\narray erfinv(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s)),\n {astype(a, dtype, s)});\n}\n\narray stop_gradient(const array& a, StreamOrDevice s /* = {} */) {\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n}\n\narray round(const array& a, int decimals, StreamOrDevice s /* = {} */) {\n if (decimals == 0) {\n return array(\n a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a});\n }\n\n auto dtype = at_least_float(a.dtype());\n float scale = std::pow(10, decimals);\n auto result = multiply(a, array(scale, dtype), s);\n result = round(result, 0, s);\n result = multiply(result, array(1 / scale, dtype), s);\n\n return astype(result, a.dtype(), s);\n}\n\narray matmul(\n const array& in_a,\n const array& in_b,\n StreamOrDevice s /* = {} */) {\n auto a = in_a;\n auto b = in_b;\n if (a.ndim() == 0 || b.ndim() == 0) {\n throw std::invalid_argument(\n \"[matmul] Got 0 dimension input. 
Inputs must \"\n \"have at least one dimension.\");\n }\n if (a.ndim() == 1) {\n // Insert a singleton dim in the beginning\n a = reshape(a, {1, -1}, s);\n }\n if (b.ndim() == 1) {\n // Insert a singleton dim at the end\n b = reshape(b, {-1, 1}, s);\n }\n if (a.shape(-1) != b.shape(-2)) {\n std::ostringstream msg;\n msg << \"[matmul] Last dimension of first input with shape \" << a.shape()\n << \" must match second to last dimension of\"\n << \" second input with shape \" << b.shape() << \".\";\n throw std::invalid_argument(msg.str());\n }\n // Type promotion\n auto out_type = promote_types(a.dtype(), b.dtype());\n if (!issubdtype(out_type, floating)) {\n std::ostringstream msg;\n msg << \"[matmul] Only real floating point types are supported but \"\n << a.dtype() << \" and \" << b.dtype() << \" were provided which results\"\n << \" in \" << out_type << \", which is not a real floating point type.\";\n throw std::invalid_argument(msg.str());\n }\n if (a.dtype() != out_type) {\n a = astype(a, out_type, s);\n }\n if (b.dtype() != out_type) {\n b = astype(b, out_type, s);\n }\n\n // We can batch the multiplication by reshaping a\n if (a.ndim() > 2 && b.ndim() == 2) {\n std::vector out_shape = a.shape();\n a = reshape(a, {-1, out_shape.back()}, s);\n out_shape.back() = b.shape(-1);\n if (in_b.ndim() == 1) {\n out_shape.pop_back();\n }\n auto out = array(\n {a.shape(0), b.shape(1)},\n out_type,\n std::make_shared(to_stream(s)),\n {a, b});\n return reshape(out, out_shape, s);\n }\n\n if (a.ndim() > 2 || b.ndim() > 2) {\n std::vector bsx_a(a.shape().begin(), a.shape().end() - 2);\n std::vector bsx_b(b.shape().begin(), b.shape().end() - 2);\n auto inner_shape = broadcast_shapes(bsx_a, bsx_b);\n\n // Broadcast a\n inner_shape.push_back(a.shape(-2));\n inner_shape.push_back(a.shape(-1));\n a = broadcast_to(a, inner_shape, s);\n\n // Broadcast b\n *(inner_shape.end() - 2) = b.shape(-2);\n *(inner_shape.end() - 1) = b.shape(-1);\n b = broadcast_to(b, inner_shape, s);\n }\n\n auto out_shape = a.shape();\n out_shape.back() = b.shape(-1);\n\n auto p = std::make_shared(to_stream(s));\n\n // Remove the possibly inserted singleton dimensions\n if (in_a.ndim() == 1 || in_b.ndim() == 1) {\n auto out = array(out_shape, out_type, std::move(p), {a, b});\n out_shape.erase(\n out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),\n out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));\n return reshape(out, std::move(out_shape), s);\n }\n return array(std::move(out_shape), out_type, std::move(p), {a, b});\n}\n\narray gather(\n const array& a,\n const std::vector& indices,\n const std::vector& axes,\n const std::vector& slice_sizes,\n StreamOrDevice s /* = {} */) {\n // Checks that indices, dimensions, and slice_sizes are all valid\n if (indices.size() > a.ndim()) {\n std::ostringstream msg;\n msg << \"[gather] Too many index arrays. 
Got \" << indices.size()\n << \" index arrays for input with \" << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n std::set dims(axes.begin(), axes.end());\n if (dims.size() != axes.size()) {\n throw std::invalid_argument(\"[gather] Repeat axes not allowed in gather.\");\n }\n if (!dims.empty() && (*dims.begin() < 0 || *dims.rbegin() >= a.ndim())) {\n throw std::invalid_argument(\"[gather] Axes don't match array dimensions.\");\n }\n if (indices.size() != axes.size()) {\n throw std::invalid_argument(\n \"[gather] Number of index arrays does not match number of axes.\");\n }\n for (auto& x : indices) {\n if (x.dtype() == bool_) {\n throw(\"[Gather] Boolean indices not supported.\");\n }\n }\n\n if (slice_sizes.size() != a.ndim()) {\n std::ostringstream msg;\n msg << \"[gather] Got slice_sizes with size \" << slice_sizes.size()\n << \" for array with \" << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n for (int i = 0; i < a.ndim(); ++i) {\n if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {\n std::ostringstream msg;\n msg << \"[gather] Slice sizes must be in [0, a.shape(i)]. Got \"\n << slice_sizes << \" for array with shape \" << a.shape() << \".\";\n throw std::invalid_argument(msg.str());\n }\n }\n\n // Promote indices to the same type\n auto dtype = result_type(indices);\n if (issubdtype(dtype, inexact)) {\n throw std::invalid_argument(\n \"[gather] Got indices with invalid dtype. Indices must be integral.\");\n }\n\n // Broadcast and cast indices if necessary\n auto inputs = broadcast_arrays(indices);\n for (auto& idx : inputs) {\n idx = astype(idx, dtype, s);\n }\n\n std::vector out_shape;\n if (!inputs.empty()) {\n out_shape = inputs[0].shape();\n }\n out_shape.insert(out_shape.end(), slice_sizes.begin(), slice_sizes.end());\n\n inputs.insert(inputs.begin(), a);\n return array(\n out_shape,\n a.dtype(),\n std::make_shared(to_stream(s), axes, slice_sizes),\n inputs);\n}\n\narray take(\n const array& a,\n const array& indices,\n int axis,\n StreamOrDevice s /* = {} */) {\n // Check for valid axis\n if (axis + static_cast(a.ndim()) < 0 ||\n axis >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[take] Received invalid axis \" << axis << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n // Check for valid take\n if (a.size() == 0 && indices.size() != 0) {\n throw std::invalid_argument(\n \"[take] Cannot do a non-empty take from an array with zero elements.\");\n }\n\n // Handle negative axis\n axis = axis < 0 ? a.ndim() + axis : axis;\n\n // Make slice sizes to pass to gather\n std::vector slice_sizes = a.shape();\n slice_sizes[axis] = indices.size() > 0 ? 
1 : 0;\n\n auto out = gather(a, indices, axis, slice_sizes, s);\n\n // Transpose indices dimensions to axis dimension\n if (axis != 0) {\n std::vector t_axes(out.ndim());\n std::iota(t_axes.begin(), t_axes.begin() + axis, indices.ndim());\n std::iota(t_axes.begin() + axis, t_axes.begin() + axis + indices.ndim(), 0);\n std::iota(\n t_axes.begin() + axis + indices.ndim(),\n t_axes.end(),\n indices.ndim() + axis);\n out = transpose(out, t_axes, s);\n }\n\n // Squeeze the axis we take over\n std::vector out_shape = out.shape();\n out_shape.erase(out_shape.begin() + indices.ndim() + axis);\n return reshape(out, out_shape, s);\n}\n\narray take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {\n return take(reshape(a, {-1}, s), indices, 0, s);\n}\n\narray take_along_axis(\n const array& a,\n const array& indices,\n int axis,\n StreamOrDevice s /* = {} */) {\n if (axis + a.ndim() < 0 || axis >= static_cast(a.ndim())) {\n std::ostringstream msg;\n msg << \"[take_along_axis] Received invalid axis \" << \" for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n\n if (indices.ndim() != a.ndim()) {\n std::ostringstream msg;\n msg << \"[take_along_axis] Indices of dimension \" << indices.ndim()\n << \" does not match array of dimension \" << a.ndim() << \".\";\n throw std::invalid_argument(msg.str());\n }\n\n // Allow negative axis\n axis = axis < 0 ? a.ndim() + axis : axis;\n\n std::vector nd_indices;\n std::vector index_shape(a.ndim(), 1);\n for (int i = 0; i < a.ndim(); ++i) {\n if (i == axis) {\n nd_indices.push_back(indices);\n } else {\n // Reshape so they can be broadcast\n index_shape[i] = a.shape(i);\n nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));\n index_shape[i] = 1;\n }\n }\n std::vector dims(a.ndim());\n std::iota(dims.begin(), dims.end(), 0);\n std::vector slice_sizes(a.ndim(), a.size() > 0);\n auto out = gather(a, nd_indices, dims, slice_sizes, s);\n\n // Squeeze out the slice shape\n std::vector out_shape(\n out.shape().begin(), out.shape().begin() + a.ndim());\n return reshape(out, out_shape, s);\n}\n\n/** Scatter updates to given indices */\narray scatter(\n const array& a,\n const std::vector& indices,\n const array& updates,\n const std::vector& axes,\n Scatter::ReduceType mode /*= Scatter::ReduceType::None*/,\n StreamOrDevice s /*= {}*/) {\n // Checks that indices, dimensions, and slice_sizes are all valid\n if (indices.size() > a.ndim()) {\n std::ostringstream msg;\n msg << \"[scatter] Too many index arrays. 
Got \" << indices.size()\n << \" index arrays for input with \" << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n for (auto& x : indices) {\n if (x.dtype() == bool_) {\n throw(\"[scatter] Boolean indices not supported.\");\n }\n }\n\n std::set dims(axes.begin(), axes.end());\n if (dims.size() != axes.size()) {\n throw std::invalid_argument(\n \"[scatter] Repeat axes not allowed in scatter.\");\n }\n if (!dims.empty() && (*dims.begin() < 0 || *dims.rbegin() >= a.ndim())) {\n throw std::invalid_argument(\"[scatter] Axes don't match array dimensions.\");\n }\n if (indices.size() != axes.size()) {\n throw std::invalid_argument(\n \"[scatter] Number of index arrays does not match number of axes.\");\n }\n\n // Broadcast and cast indices if necessary\n auto inputs = broadcast_arrays(indices);\n\n std::vector idx_shape;\n if (!inputs.empty()) {\n idx_shape = inputs[0].shape();\n }\n\n if (updates.ndim() != (a.ndim() + idx_shape.size())) {\n std::ostringstream msg;\n msg << \"[scatter] Updates with \" << updates.ndim()\n << \" dimensions does not match the sum of the array and indices \"\n \"dimensions \"\n << a.ndim() + idx_shape.size() << \".\";\n throw std::invalid_argument(msg.str());\n }\n for (int i = 0; i < idx_shape.size(); ++i) {\n if (updates.shape(i) != idx_shape[i]) {\n std::ostringstream msg;\n msg << \"[scatter] Update shape \" << updates.shape()\n << \" is not valid for broadcasted index shape \" << idx_shape << \".\";\n throw std::invalid_argument(msg.str());\n }\n }\n for (int i = 0; i < a.ndim(); ++i) {\n auto up_shape = updates.shape(i + idx_shape.size());\n if (up_shape > a.shape(i)) {\n std::ostringstream msg;\n msg << \"[scatter] Updates with shape \" << updates.shape()\n << \" are too large for array with shape \" << a.shape() << \".\";\n throw std::invalid_argument(msg.str());\n }\n }\n\n // Promote indices to the same type\n auto dtype = result_type(indices);\n if (issubdtype(dtype, inexact)) {\n throw std::invalid_argument(\n \"[scatter] Got indices with invalid dtype. 
Indices must be integral.\");\n }\n for (auto& idx : inputs) {\n idx = astype(idx, dtype, s);\n }\n\n inputs.insert(inputs.begin(), a);\n // TODO promote or cast?\n inputs.push_back(astype(updates, a.dtype(), s));\n return array(\n a.shape(),\n a.dtype(),\n std::make_shared(to_stream(s), mode, axes),\n inputs);\n}\n\narray scatter(\n const array& a,\n const std::vector& indices,\n const array& updates,\n const std::vector& axes,\n StreamOrDevice s /*= {}*/) {\n return scatter(a, indices, updates, axes, Scatter::None, s);\n}\n\narray scatter_add(\n const array& a,\n const std::vector& indices,\n const array& updates,\n const std::vector& axes,\n StreamOrDevice s /*= {}*/) {\n return scatter(a, indices, updates, axes, Scatter::Sum, s);\n}\n\narray scatter_prod(\n const array& a,\n const std::vector& indices,\n const array& updates,\n const std::vector& axes,\n StreamOrDevice s /*= {}*/) {\n return scatter(a, indices, updates, axes, Scatter::Prod, s);\n}\n\narray scatter_max(\n const array& a,\n const std::vector& indices,\n const array& updates,\n const std::vector& axes,\n StreamOrDevice s /*= {}*/) {\n return scatter(a, indices, updates, axes, Scatter::Max, s);\n}\n\narray scatter_min(\n const array& a,\n const std::vector& indices,\n const array& updates,\n const std::vector& axes,\n StreamOrDevice s /*= {}*/) {\n return scatter(a, indices, updates, axes, Scatter::Min, s);\n}\n\na\nrray sqrt(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s)),\n {astype(a, dtype, s)});\n}\n\narray rsqrt(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s), true),\n {astype(a, dtype, s)});\n}\n\narray softmax(\n const array& a,\n const std::vector& axes,\n bool precise /* = false */,\n StreamOrDevice s /* = {}*/) {\n if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {\n auto dtype = at_least_float(a.dtype());\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s), precise),\n {astype(a, dtype, s)});\n } else {\n auto in = a;\n if (precise) {\n in = astype(a, float32, s);\n }\n auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);\n auto ex = exp(subtract(in, a_max, s), s);\n return astype(\n divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);\n }\n}\n\narray softmax(\n const array& a,\n bool precise /* = false */,\n StreamOrDevice s /* = {}*/) {\n std::vector axes(a.ndim());\n std::iota(axes.begin(), axes.end(), 0);\n return softmax(a, axes, precise, s);\n}\n\narray power(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n auto dtype = promote_types(a.dtype(), b.dtype());\n std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)};\n if (a.shape() != b.shape()) {\n inputs = broadcast_arrays(inputs, s);\n }\n return array(\n inputs[0].shape(), dtype, std::make_shared(to_stream(s)), inputs);\n}\n\narray cumsum(\n const array& a,\n int axis,\n bool reverse /* = false*/,\n bool inclusive /* = true*/,\n StreamOrDevice s /* = {}*/) {\n int ndim = a.ndim();\n if (axis >= ndim || axis < -ndim) {\n std::ostringstream msg;\n msg << \"[cumsum] Axis \" << axis << \" is out of bounds for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n axis = (axis + a.ndim()) % a.ndim();\n auto out_type = a.dtype() == bool_ ? 
int32 : a.dtype();\n return array(\n a.shape(),\n out_type,\n std::make_shared(\n to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),\n {a});\n}\n\narray cumprod(\n const array& a,\n int axis,\n bool reverse /* = false*/,\n bool inclusive /* = true*/,\n StreamOrDevice s /* = {}*/) {\n int ndim = a.ndim();\n if (axis >= ndim || axis < -ndim) {\n std::ostringstream msg;\n msg << \"[cumprod] Axis \" << axis << \" is out of bounds for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n axis = (axis + a.ndim()) % a.ndim();\n return array(\n a.shape(),\n a.dtype(),\n std::make_shared(\n to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),\n {a});\n}\n\narray cummax(\n const array& a,\n int axis,\n bool reverse /* = false*/,\n bool inclusive /* = true*/,\n StreamOrDevice s /* = {}*/) {\n int ndim = a.ndim();\n if (axis >= ndim || axis < -ndim) {\n std::ostringstream msg;\n msg << \"[cummax] Axis \" << axis << \" is out of bounds for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n axis = (axis + a.ndim()) % a.ndim();\n return array(\n a.shape(),\n a.dtype(),\n std::make_shared(\n to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),\n {a});\n}\n\narray cummin(\n const array& a,\n int axis,\n bool reverse /* = false*/,\n bool inclusive /* = true*/,\n StreamOrDevice s /* = {}*/) {\n int ndim = a.ndim();\n if (axis >= ndim || axis < -ndim) {\n std::ostringstream msg;\n msg << \"[cummin] Axis \" << axis << \" is out of bounds for array with \"\n << a.ndim() << \" dimensions.\";\n throw std::invalid_argument(msg.str());\n }\n axis = (axis + a.ndim()) % a.ndim();\n return array(\n a.shape(),\n a.dtype(),\n std::make_shared(\n to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),\n {a});\n}\n\n/** Convolution operations */\n\nnamespace {\n\n// Conv helpers\ninline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {\n return ((in_dim + padding - wt_dim) / stride) + 1;\n}\n\n// Conv helpers\ninline int dilate_size(int dim, int dil) {\n return 1 + dil * (dim - 1);\n}\n\ninline std::vector conv_out_shape(\n const std::vector& in_shape,\n const std::vector& wt_shape,\n const std::vector& strides,\n const std::vector& pads_lo,\n const std::vector& pads_hi,\n const std::vector& kernel_dilation,\n const std::vector& input_dilation) {\n int N = in_shape[0];\n int O = wt_shape[0];\n std::vector out_shape(in_shape.size());\n int i = 0;\n out_shape[i++] = N;\n\n int spatial_dims = in_shape.size() - 2;\n\n if (strides.size() != spatial_dims) {\n std::ostringstream msg;\n msg << \"[conv] Invalid strides \" << strides << \"for \" << spatial_dims\n << \"D convolution.\";\n throw std::invalid_argument(msg.str());\n }\n\n if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {\n std::ostringstream msg;\n msg << \"[conv] Invalid pading \" << pads_lo << \" | \" << pads_hi << \"for \"\n << spatial_dims << \"D convolution.\";\n throw std::invalid_argument(msg.str());\n }\n\n if (kernel_dilation.size() != spatial_dims) {\n std::ostringstream msg;\n msg << \"[conv] Invalid kernel dilation \" << kernel_dilation << \"for \"\n << spatial_dims << \"D convolution.\";\n throw std::invalid_argument(msg.str());\n }\n\n if (input_dilation.size() != spatial_dims) {\n std::ostringstream msg;\n msg << \"[conv] Invalid input dilation \" << input_dilation << \"for \"\n << spatial_dims << \"D convolution.\";\n throw std::invalid_argument(msg.str());\n }\n\n for (; i < 
in_shape.size() - 1; i++) {\n if (kernel_dilation[i - 1] <= 0) {\n std::ostringstream msg;\n msg << \"[conv] Kernel dilation sizes must be positive.\"\n << \" Got kernel dilation \" << kernel_dilation << \".\";\n throw std::invalid_argument(msg.str());\n }\n\n if (input_dilation[i - 1] <= 0) {\n std::ostringstream msg;\n msg << \"[conv] Input dilation sizes must be positive.\"\n << \" Got input dilation \" << input_dilation << \".\";\n throw std::invalid_argument(msg.str());\n }\n\n if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {\n std::ostringstream msg;\n msg << \"[conv] Padding sizes must be non-negative.\" << \" Got padding \"\n << pads_lo << \" | \" << pads_hi << \".\";\n throw std::invalid_argument(msg.str());\n }\n\n...\n", "needle_token_start": 13741, "needle_token_end": 13827, "code_context_ntokens": 16222, "output": ["```cpp\narray sqrt(const array& a, StreamOrDevice s /* = {} */) {\n auto dtype = at_least_float(a.dtype());\n return array(\n a.shape(),\n dtype,\n std::make_shared(to_stream(s)),\n {astype(a, dtype, s)});\n}\n```"]} +{"repo": "ml-explore/mlx", "name": "promote_types", "language": "cpp", "path": "mlx/dtype.cpp", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To determine the most appropriate data type for operations involving two different data types, ensuring that no precision is lost.\n2. **Input**: Two data type identifiers representing the types of two operands.\n3. **Output**: A single data type identifier that represents the promoted data type suitable for handling operations between the two input types without loss of information.\n4. **Procedure**: The function uses a predefined matrix that maps pairs of input data types to a resulting data type. It retrieves the appropriate data type from this matrix using the indices derived from the input data types.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: mlx/primitives.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/fft.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nstd::tuple vmap_binary_op(\n const std::vector& inputs,\n const std::vector& axes,\n const Stream& stream) {\n assert(inputs.size() == 2);\n assert(axes.size() == 2);\n\n if (axes[0] == -1 && axes[1] == -1) {\n return {inputs[0], inputs[1], -1};\n }\n\n auto a = inputs[0];\n auto b = inputs[1];\n int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));\n\n auto expand_dims = [stream, ndim](auto in) {\n auto shape = in.shape();\n shape.insert(shape.begin(), ndim - shape.size(), 1);\n return reshape(in, shape, stream);\n };\n\n int to_ax = (ndim - a.ndim()) + axes[0];\n int from_ax = (ndim - b.ndim()) + axes[1];\n a = expand_dims(a);\n b = expand_dims(b);\n\n if (from_ax != to_ax) {\n std::vector tdims(b.ndim());\n std::iota(tdims.begin(), tdims.end(), 0);\n tdims.erase(tdims.begin() + from_ax);\n tdims.insert(tdims.begin() + to_ax, from_ax);\n b = transpose(b, tdims, stream);\n }\n return {a, b, to_ax};\n}\n\nstd::tuple vmap_ternary_op(\n const std::vector& inputs,\n const std::vector& axes,\n const Stream& stream) {\n assert(inputs.size() == 3);\n assert(axes.size() == 
3);\n\n if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {\n return {inputs[0], inputs[1], inputs[2], -1};\n }\n\n auto a = inputs[0];\n auto b = inputs[1];\n auto c = inputs[2];\n int ndim = std::max(\n {a.ndim() + (axes[0] == -1),\n b.ndim() + (axes[1] == -1),\n c.ndim() + (axes[2] == -1)});\n\n auto expand_dims = [stream, ndim](auto in) {\n auto shape = in.shape();\n shape.insert(shape.begin(), ndim - shape.size(), 1);\n return reshape(in, shape, stream);\n };\n\n int to_ax = (ndim - a.ndim()) + axes[0];\n int from_ax1 = (ndim - b.ndim()) + axes[1];\n int from_ax2 = (ndim - c.ndim()) + axes[2];\n a = expand_dims(a);\n b = expand_dims(b);\n c = expand_dims(c);\n\n auto find_tdims = [](auto x, int to_ax, int from_ax) {\n std::vector tdims(x.ndim());\n std::iota(tdims.begin(), tdims.end(), 0);\n tdims.erase(tdims.begin() + from_ax);\n tdims.insert(tdims.begin() + to_ax, from_ax);\n return tdims;\n };\n\n if (to_ax != from_ax1) {\n std::vector tdims = find_tdims(b, to_ax, from_ax1);\n b = transpose(b, tdims, stream);\n }\n\n if (to_ax != from_ax2) {\n std::vector tdims = find_tdims(c, to_ax, from_ax2);\n c = transpose(c, tdims, stream);\n }\n return {a, b, c, to_ax};\n}\n\n} // namespace\n\nstd::vector Primitive::jvp(\n const std::vector&,\n const std::vector&,\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::jvp] Not implemented for \";\n print(msg);\n msg << \".\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::vector Primitive::vjp(\n const std::vector&,\n const std::vector&,\n const std::vector&,\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::vip] Not implemented for \";\n print(msg);\n msg << \".\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::pair, std::vector> Primitive::vmap(\n const std::vector&,\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::vmap] Not implemented for \";\n print(msg);\n msg << \".\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::vector> Primitive::output_shapes(\n const std::vector&) {\n std::ostringstream msg;\n msg << \"[Primitive::output_shapes] \";\n this->print(msg);\n msg << \" cannot infer output shapes.\";\n throw std::invalid_argument(msg.str());\n};\n\nstd::vector Abs::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Abs::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(tangents[0], sign(primals[0], stream()), stream())};\n}\n\nstd::pair, std::vector> Abs::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{abs(inputs[0], stream())}, axes};\n}\n\nstd::vector Add::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n return {\n tangents.size() > 1 ? 
add(tangents[0], tangents[1], stream())\n : tangents[0]};\n}\n\nstd::vector Add::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n if (argnums.size() == 1) {\n return cotangents;\n } else {\n return {cotangents[0], cotangents[0]};\n }\n}\n\nstd::pair, std::vector> Add::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{add(a, b, stream())}, {to_ax}};\n}\n\nstd::vector AddMM::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n auto& cotan = cotangents[0];\n std::vector reorder(cotan.ndim());\n std::iota(reorder.begin(), reorder.end(), 0);\n std::iter_swap(reorder.end() - 1, reorder.end() - 2);\n for (auto arg : argnums) {\n if (arg == 0) {\n // M X N * (K X N).T -> M X K\n auto cotan_scaled = cotan;\n if (alpha_ != 1.) {\n auto alpha_arr = array(alpha_, cotan.dtype());\n cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));\n }\n vjps.push_back(matmul(\n cotan_scaled, transpose(primals[1], reorder, stream()), stream()));\n } else if (arg == 1) {\n // (M X K).T * M X N -> K X N\n auto cotan_scaled = cotan;\n if (alpha_ != 1.) {\n auto alpha_arr = array(alpha_, cotan.dtype());\n cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));\n }\n vjps.push_back(matmul(\n transpose(primals[0], reorder, stream()), cotan_scaled, stream()));\n } else {\n auto cotan_scaled = cotan;\n if (beta_ != 1.) {\n auto beta_arr = array(beta_, cotan.dtype());\n cotan_scaled = (multiply(beta_arr, cotan_scaled, stream()));\n }\n vjps.push_back(cotan_scaled);\n }\n }\n return vjps;\n}\n\nbool AddMM::is_equivalent(const Primitive& other) const {\n const AddMM& a_other = static_cast(other);\n return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);\n}\n\nstd::pair, std::vector> AddMM::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto maybe_move_ax = [this](auto& arr, auto ax) {\n return ax > 0 ? 
moveaxis(arr, ax, 0, stream()) : arr;\n };\n auto a = maybe_move_ax(inputs[0], axes[0]);\n auto b = maybe_move_ax(inputs[1], axes[1]);\n auto c = maybe_move_ax(inputs[2], axes[2]);\n return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};\n}\n\nbool Arange::is_equivalent(const Primitive& other) const {\n const Arange& a_other = static_cast(other);\n return (\n start_ == a_other.start_ && stop_ == a_other.stop_ &&\n step_ == a_other.step_);\n}\n\nstd::vector ArcCos::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcCos::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(one, square(primals[0], stream()), stream());\n array denom = negative(rsqrt(t, stream()), stream());\n return {multiply(tangents[0], denom, stream())};\n}\n\nstd::pair, std::vector> ArcCos::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arccos(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcCosh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcCosh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(square(primals[0], stream()), one, stream());\n return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair, std::vector> ArcCosh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arccosh(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcSin::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcSin::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(one, square(primals[0], stream()), stream());\n return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair, std::vector> ArcSin::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arcsin(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcSinh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcSinh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = add(square(primals[0], stream()), one, stream());\n return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair, std::vector> ArcSinh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arcsinh(inputs[0], stream())}, 
axes};\n}\n\nstd::vector ArcTan::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcTan::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = add(one, square(primals[0], stream()), stream());\n return {divide(tangents[0], t, stream())};\n}\n\nstd::pair, std::vector> ArcTan::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arctan(inputs[0], stream())}, axes};\n}\n\nstd::vector ArcTanh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector ArcTanh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n array one = array(1., primals[0].dtype());\n array t = subtract(one, square(primals[0], stream()), stream());\n return {divide(tangents[0], t, stream())};\n}\n\nstd::pair, std::vector> ArcTanh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{arctanh(inputs[0], stream())}, axes};\n}\n\nstd::pair, std::vector> ArgPartition::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n\n int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nbool ArgPartition::is_equivalent(const Primitive& other) const {\n const ArgPartition& r_other = static_cast(other);\n return axis_ == r_other.axis_ && kth_ == r_other.kth_;\n}\n\nbool ArgReduce::is_equivalent(const Primitive& other) const {\n const ArgReduce& r_other = static_cast(other);\n return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;\n}\n\nstd::pair, std::vector> ArgReduce::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);\n auto& in = inputs[0];\n std::vector out;\n if (reduce_type_ == ArgReduce::ArgMin) {\n out.push_back(argmin(in, reduce_ax, true, stream()));\n } else {\n out.push_back(argmax(in, reduce_ax, true, stream()));\n }\n return {out, axes};\n}\n\nstd::pair, std::vector> ArgSort::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n\n int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nstd::vector> ArgReduce::output_shapes(\n const std::vector& inputs) {\n auto out_shape = inputs[0].shape();\n out_shape[axis_] = 1;\n return {out_shape};\n}\n\nbool ArgSort::is_equivalent(const Primitive& other) const {\n const ArgSort& r_other = static_cast(other);\n return axis_ == r_other.axis_;\n}\n\nstd::vector AsType::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n if (cotangents[0].dtype() != dtype_) {\n throw std::invalid_argument(\n \"[astype] Type of cotangentsgent does not much primal output type.\");\n }\n return {astype(cotangents[0], primals[0].dtype(), stream())};\n}\n\nstd::vector AsType::jvp(\n const 
std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n return {astype(tangents[0], dtype_, stream())};\n}\n\nstd::pair, std::vector> AsType::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n return {{astype(inputs[0], dtype_, stream())}, axes};\n}\n\nbool AsType::is_equivalent(const Primitive& other) const {\n const AsType& a_other = static_cast(other);\n return dtype_ == a_other.dtype_;\n}\n\nstd::vector AsStrided::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(argnums.size() == 1);\n\n // Extract the sizes and cast them to ints\n int grad_size = primals[0].size();\n int cotangents_size = cotangents[0].size();\n\n // Make a flat container to hold the gradients\n auto grad = zeros_like(primals[0], stream());\n grad = reshape(grad, {grad_size}, stream());\n\n // Create the indices that map output to input\n auto idx = arange(grad_size, stream());\n idx = as_strided(idx, shape_, strides_, offset_, stream());\n idx = reshape(idx, {cotangents_size}, stream());\n\n // Reshape the cotangentsgent for use with scatter\n auto flat_cotangents = reshape(cotangents[0], {cotangents_size, 1}, stream());\n\n // Finally accumulate the gradients and reshape them to look like the input\n grad = scatter_add(grad, idx, flat_cotangents, 0, stream());\n grad = reshape(grad, primals[0].shape(), stream());\n\n return {grad};\n}\n\nstd::vector AsStrided::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n\n return {as_strided(tangents[0], shape_, strides_, offset_, stream())};\n}\n\nbool AsStrided::is_equivalent(const Primitive& other) const {\n const AsStrided& a_other = static_cast(other);\n return shape_ == a_other.shape_ && strides_ == a_other.strides_ &&\n offset_ == a_other.offset_;\n}\n\nstd::vector Broadcast::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(argnums.size() == 1);\n\n // Reduce cotangents to the shape of the primal\n auto& shape = primals[0].shape();\n auto& cotan = cotangents[0];\n int diff = cotan.ndim() - shape.size();\n std::vector reduce_axes;\n for (int i = 0; i < cotan.ndim(); ++i) {\n if (i < diff) {\n reduce_axes.push_back(i);\n } else if (shape[i - diff] != cotan.shape(i)) {\n reduce_axes.push_back(i);\n }\n }\n return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};\n}\n\nstd::vector Broadcast::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(argnums.size() == 1);\n return {broadcast_to(tangents[0], shape_, stream())};\n}\n\nstd::pair, std::vector> Broadcast::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n auto ax = axes[0];\n auto in = inputs[0];\n if (ax >= 0) {\n auto in_shape = in.shape();\n int diff = shape_.size() - in.ndim() + 1;\n assert(diff >= 0);\n in_shape.insert(in_shape.begin(), diff, 1);\n ax += diff;\n shape_.insert(shape_.begin() + ax, in_shape[ax]);\n in = reshape(in, in_shape, stream());\n }\n return {{broadcast_to(in, shape_, stream())}, {ax}};\n}\n\nbool Broadcast::is_equivalent(const Primitive& other) const {\n const Broadcast& b_other = static_cast(other);\n return shape_ == b_other.shape_;\n}\n\nstd::vector Ceil::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const 
std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Ceil::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {zeros_like(primals[0], stream())};\n}\n\nstd::pair, std::vector> Ceil::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{ceil(inputs[0], stream())}, axes};\n}\n\nstd::vector Concatenate::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n auto& cotan = cotangents[0];\n std::vector start(cotan.ndim(), 0);\n std::vector stop = cotan.shape();\n\n std::vector sizes;\n sizes.push_back(0);\n for (auto& p : primals) {\n sizes.push_back(p.shape(axis_));\n }\n std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n std::vector grads;\n for (auto i : argnums) {\n start[axis_] = sizes[i];\n stop[axis_] = sizes[i + 1];\n grads.push_back(slice(cotan, start, stop, stream()));\n }\n return grads;\n}\n\nstd::vector Concatenate::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n std::vector argidx(argnums.size());\n std::iota(argidx.begin(), argidx.end(), 0);\n std::sort(argidx.begin(), argidx.end(), [&argnums](int a, int b) {\n return argnums[a] < argnums[b];\n });\n\n std::vector vals;\n for (int i = 0, j = 0; i < primals.size(); ++i) {\n if (j < argnums.size() && argnums[argidx[j]] == i) {\n vals.push_back(tangents[argidx[j++]]);\n } else {\n vals.push_back(zeros_like(primals[i], stream()));\n }\n }\n return {concatenate(vals, axis_, stream())};\n}\n\nstd::pair, std::vector> Concatenate::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n std::vector t_inputs;\n int out_ax = -1;\n // Find the first vmapped input\n int i = 0;\n for (; i < axes.size(); i++) {\n t_inputs.push_back(inputs[i]);\n if (axes[i] >= 0) {\n out_ax = axes[i];\n break;\n }\n }\n if (out_ax >= 0) {\n // Advance to the next input\n i++;\n\n // Move vmap axes to the same spot.\n for (; i < axes.size(); ++i) {\n if (out_ax != axes[i] && axes[i] >= 0) {\n t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));\n } else {\n t_inputs.push_back(inputs[i]);\n }\n }\n }\n auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);\n return {{concatenate(t_inputs, axis, stream())}, {out_ax}};\n}\n\nbool Concatenate::is_equivalent(const Primitive& other) const {\n const Concatenate& c_other = static_cast(other);\n return axis_ == c_other.axis_;\n}\n\narray conv_weight_backward_patches(\n const array& in,\n const array& wt,\n const array& cotan,\n const std::vector& kernel_strides,\n const std::vector& padding,\n StreamOrDevice s) {\n // Resolve Padded input shapes and strides\n std::vector padding_starts(in.ndim(), 0);\n std::vector padding_ends = in.shape();\n std::vector in_padded_shape = in.shape();\n\n // padded shape\n for (int i = 1; i < in.ndim() - 1; i++) {\n in_padded_shape[i] += 2 * padding[i - 1];\n padding_ends[i] += padding[i - 1];\n padding_starts[i] += padding[i - 1];\n }\n\n // padded strides (contiguous)\n std::vector in_padded_strides(in.ndim(), 1);\n for (int i = in.ndim() - 2; i >= 0; --i) {\n in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];\n }\n\n // Pad input\n std::vector padded_axes(in.ndim() - 2, 0);\n std::iota(padded_axes.begin(), padded_axes.end(), 1);\n auto in_padded =\n pad(in, padded_axes, padding, padding, 
array(0, in.dtype()), s);\n\n // Resolve strided patches\n\n // patches are shaped as\n // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels)\n std::vector patches_shape{\n cotan.shape().begin(), cotan.shape().end() - 1};\n patches_shape.insert(\n patches_shape.end(), wt.shape().begin() + 1, wt.shape().end());\n\n // Resolve patch strides\n int n_spatial_dim = in.ndim() - 2;\n std::vector patches_strides(patches_shape.size(), 1);\n patches_strides[0] = in_padded_strides[0];\n for (int i = 1; i < n_spatial_dim + 1; i++) {\n patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];\n }\n for (int i = 1; i < in.ndim(); i++) {\n patches_strides[n_spatial_dim + i] = in_padded_strides[i];\n }\n\n // Make patches from in\n auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, s);\n\n // Prepare for matmul\n int O = wt.shape(0);\n auto cotan_mat = reshape(cotan, {-1, O}, s);\n in_patches = reshape(in_patches, {cotan_mat.shape(0), -1}, s);\n\n auto grad = matmul(transpose(cotan_mat, {1, 0}, s), in_patches, s);\n grad = reshape(grad, wt.shape(), s);\n return grad;\n}\n\nstd::vector Convolution::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(primals.size() == 2);\n std::vector grads;\n\n // Collect info\n auto& in = primals[0];\n auto& wt = primals[1];\n auto& cotan = cotangents[0];\n\n for (int a : argnums) {\n // Grads for input\n if (a == 0) {\n std::vector padding_lo = padding_;\n std::vector padding_hi = padding_;\n\n for (int i = 0; i < padding_lo.size(); ++i) {\n int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);\n padding_lo[i] = wt_size - padding_[i] - 1;\n\n int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);\n int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);\n padding_hi[i] = in_size - out_size + padding_[i];\n }\n\n auto wt_trans = swapaxes(wt, 0, -1, stream());\n\n auto grad = conv_general(\n /* const array& input = */ cotan,\n /* const array& weight = */ wt_trans,\n /* std::vector stride = */ input_dilation_,\n /* std::vector padding_lo = */ padding_lo,\n /* std::vector padding_hi = */ padding_hi,\n /* std::vector kernel_dilation = */ kernel_dilation_,\n /* std::vector input_dilation = */ kernel_strides_,\n /* int groups = */ 1,\n /* bool flip = */ !flip_,\n stream());\n\n grads.push_back(grad);\n }\n // Grads for weight\n else if (a == 1) {\n bool no_dilation = true;\n\n for (int i = 0; i < input_dilation_.size(); i++) {\n no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);\n }\n\n if (no_dilation) {\n auto grad = conv_weight_backward_patches(\n in, wt, cotan, kernel_strides_, padding_, stream());\n grads.push_back(grad);\n } else {\n std::vector padding_lo = padding_;\n std::vector padding_hi = padding_;\n\n for (int i = 0; i < padding_hi.size(); ++i) {\n int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);\n int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);\n int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);\n padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;\n }\n\n auto in_trans = swapaxes(in, 0, -1, stream());\n auto cotan_trans = swapaxes(cotan, 0, -1, stream());\n auto grad_trans = conv_general(\n /* const array& input = */ in_trans,\n /* const array& weight = */ cotan_trans,\n /* std::vector stride = */ kernel_dilation_,\n /* std::vector padding_lo = */ padding_lo,\n /* std::vector padding_hi = */ padding_hi,\n /* std::vector kernel_dilation = 
*/ kernel_strides_,\n /* std::vector input_dilation = */ input_dilation_,\n /* int groups = */ 1,\n /* bool flip = */ flip_,\n stream());\n auto grad = swapaxes(grad_trans, 0, -1, stream());\n grads.push_back(grad);\n }\n }\n }\n\n return grads;\n}\n\nbool Convolution::is_equivalent(const Primitive& other) const {\n const Convolution& c_other = static_cast(other);\n return padding_ == c_other.padding_ &&\n kernel_strides_ == c_other.kernel_strides_ &&\n kernel_dilation_ == c_other.kernel_dilation_ &&\n input_dilation_ == c_other.input_dilation_ &&\n groups_ == c_other.groups_ && flip_ == c_other.flip_;\n}\n\nstd::vector Copy::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return cotangents;\n}\n\nstd::vector Copy::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return tangents;\n}\n\nstd::pair, std::vector> Copy::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{copy(inputs[0], stream())}, axes};\n}\n\nstd::vector Cos::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return {jvp(primals, cotangents, argnums)};\n}\n\nstd::vector Cos::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(\n tangents[0], negative(sin(primals[0], stream()), stream()), stream())};\n}\n\nstd::pair, std::vector> Cos::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{cos(inputs[0], stream())}, axes};\n}\n\nstd::vector Cosh::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Cosh::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n assert(primals.size() == 1);\n assert(argnums.size() == 1);\n return {multiply(tangents[0], sinh(primals[0], stream()), stream())};\n}\n\nstd::pair, std::vector> Cosh::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n assert(inputs.size() == 1);\n assert(axes.size() == 1);\n return {{cosh(inputs[0], stream())}, axes};\n}\n\nstd::vector CustomVJP::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n std::vector inputs(primals.begin(), primals.end() - outputs.size());\n auto all_vjps = vjp_fun_(inputs, cotangents, outputs);\n for (const auto& cot : cotangents) {\n all_vjps.emplace_back(cot);\n }\n\n std::vector vjps;\n vjps.reserve(argnums.size());\n for (auto arg : argnums) {\n vjps.push_back(all_vjps[arg]);\n }\n\n return vjps;\n}\n\nstd::vector Depends::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector& outputs) {\n std::vector vjps;\n\n for (auto arg : argnums) {\n if (arg < cotangents.size()) {\n vjps.push_back(cotangents[arg]);\n } else {\n vjps.push_back(zeros_like(primals[arg]));\n }\n }\n return vjps;\n}\n\nstd::vector Divide::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const 
std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n if (arg == 0) {\n vjps.push_back(divide(cotangents[0], primals[1], stream()));\n } else {\n vjps.push_back(negative(\n divide(\n multiply(cotangents[0], primals[0], stream()),\n square(primals[1], stream()),\n stream()),\n stream()));\n }\n }\n return vjps;\n}\n\nstd::vector DivMod::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n vjps.push_back(zeros_like(primals[arg], stream()));\n }\n return vjps;\n}\n\nstd::vector DivMod::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n return {zeros_like(primals[0], stream())};\n}\n\nstd::pair, std::vector> DivMod::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {divmod(a, b, stream()), {to_ax}};\n}\n\nstd::vector Divide::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n auto jvp_fun = [&](int i) {\n int arg = argnums[i];\n if (arg == 0) {\n return divide(tangents[i], primals[1], stream());\n } else {\n return negative(\n divide(\n multiply(tangents[i], primals[0], stream()),\n square(primals[1], stream()),\n stream()),\n stream());\n }\n };\n auto out = jvp_fun(0);\n if (argnums.size() > 1) {\n out = add(out, jvp_fun(1), stream());\n }\n return {out};\n}\n\nstd::pair, std::vector> Divide::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{divide(a, b, stream())}, {to_ax}};\n}\n\nstd::vector Remainder::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n if (arg == 0) {\n vjps.push_back(cotangents[0]);\n } else {\n auto x_over_y = divide(primals[0], primals[1], stream());\n x_over_y = floor(x_over_y, stream());\n vjps.push_back(\n negative(multiply(x_over_y, cotangents[0], stream()), stream()));\n }\n }\n return vjps;\n}\n\nstd::vector Remainder::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n auto jvp_fun = [&](int i) {\n int arg = argnums[i];\n if (arg == 0) {\n return tangents[i];\n } else {\n auto x_over_y = divide(primals[0], primals[1], stream());\n x_over_y = floor(x_over_y, stream());\n return negative(multiply(x_over_y, tangents[i], stream()), stream());\n }\n };\n auto out = jvp_fun(0);\n if (argnums.size() > 1) {\n out = add(out, jvp_fun(1), stream());\n }\n return {out};\n}\n\nstd::pair, std::vector> Remainder::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{remainder(a, b, stream())}, {to_ax}};\n}\n\nstd::pair, std::vector> Equal::vmap(\n const std::vector& inputs,\n const std::vector& axes) {\n auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n return {{equal(a, b, stream())}, {to_ax}};\n}\n\nstd::vector Equal::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n std::vector vjps;\n for (auto arg : argnums) {\n vjps.push_back(zeros_like(primals[arg], stream()));\n }\n return vjps;\n}\n\nstd::vector Equal::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n const std::vector& argnums) {\n auto shape = 
broadcast_shapes(primals[0].shape(), primals[1].shape());\n return {zeros(shape, bool_, stream())};\n}\n\nstd::vector Erf::vjp(\n const std::vector& primals,\n const std::vector& cotangents,\n const std::vector& argnums,\n const std::vector&) {\n return jvp(primals, cotangents, argnums);\n}\n\nstd::vector Erf::jvp(\n const std::vector& primals,\n const std::vector& tangents,\n...\n// Path: mlx/dtype.cpp\n// Copyright \u00a9 2023-2024 Apple Inc.\n\n#include \n#include \n#include \n\n#include \"mlx/dtype.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr int num_types = 13;\nconstexpr int num_cats = 8;\n\nconstexpr Dtype::Kind type_kinds[num_types] = {\n Dtype::Kind::b, // bool_,\n Dtype::Kind::u, // uint8,\n Dtype::Kind::u, // uint16,\n Dtype::Kind::u, // uint32,\n Dtype::Kind::u, // uint64,\n Dtype::Kind::i, // int8,\n Dtype::Kind::i, // int16,\n Dtype::Kind::i, // int32,\n Dtype::Kind::i, // int64,\n Dtype::Kind::f, // float16,\n Dtype::Kind::f, // float32,\n Dtype::Kind::V, // bfloat16,\n Dtype::Kind::c // complex64,\n};\n\n// Following Jax type promotion rules:\n// https://jax.readthedocs.io/en/latest/type_promotion.html\n// clang-format off\nconstexpr Dtype type_rules[num_types][num_types] = {\n// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64\n {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool\n {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8\n {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16\n {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32\n {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64\n {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8\n {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16\n {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32\n {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64\n {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16\n {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32\n {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16\n {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64\n};\n\n\nconstexpr bool subcategory_to_category[num_cats][num_cats] = {\n// complexfloating floating inexact signedinteger unsignedinteger integer number generic\n {true, false, true, false, false, false, true, true}, // complexfloating\n {false, true, true, false, false, false, true, true}, // floating\n {false, false, true, false, false, false, true, true}, // inexact\n {false, false, false, true, false, true, true, true}, // signedinteger\n {false, false, false, false, true, true, true, true}, // unsignedinteger\n {false, false, false, false, false, true, 
true, true}, // integer\n {false, false, false, false, false, false, true, true}, // number\n {false, false, false, false, false, false, false, true}, // generic\n};\n\nconstexpr Dtype::Category type_to_category[num_types] = {\n Dtype::Category::generic, // bool_,\n Dtype::Category::unsignedinteger, // uint8,\n Dtype::Category::unsignedinteger, // uint16,\n Dtype::Category::unsignedinteger, // uint32,\n Dtype::Category::unsignedinteger, // uint64,\n Dtype::Category::signedinteger, // int8,\n Dtype::Category::signedinteger, // int16,\n Dtype::Category::signedinteger, // int32,\n Dtype::Category::signedinteger, // int64,\n Dtype::Category::floating, // float16,\n Dtype::Category::floating, // float32,\n Dtype::Category::floating, // bfloat16,\n Dtype::Category::complexfloating, // complex64,\n};\n\n// clang-format on\n\n} // namespace\n\nD\ntype promote_types(const Dtype& t1, const Dtype& t2) {\n return Dtype(type_rules[static_cast(t1.val)][static_cast(t2.val)]);\n}\n\nDtype::Kind kindof(const Dtype& t) {\n return type_kinds[static_cast(t.val)];\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return bool_;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return uint8;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return uint16;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return uint32;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return uint64;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return int8;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return int16;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return int32;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return int64;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return float16;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return float32;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return float32;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return bfloat16;\n}\n\ntemplate <>\nTypeToDtype::operator Dtype() {\n return complex64;\n}\n\nbool issubdtype(const Dtype& a, const Dtype& b) {\n return a == b;\n}\n\nbool issubdtype(const Dtype::Category& cat, const Dtype& type) {\n return false;\n}\n\nbool issubdtype(const Dtype& type, const Dtype::Category& cat) {\n return issubdtype(type_to_category[static_cast(type.val)], cat);\n}\n\nbool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {\n return subcategory_to_category[static_cast(a)]\n [static_cast(b)];\n}\n\n// Array protocol typestring for Dtype\nstd::string dtype_to_array_protocol(const Dtype& t) {\n std::ostringstream r;\n if (size_of(t) > 1)\n r << (is_big_endian() ? \">\" : \"<\");\n else\n r << \"|\";\n r << kindof(t) << (int)size_of(t);\n return r.str();\n}\n\n// Dtype from array protocol type string\nDtype dtype_from_array_protocol(std::string_view t) {\n...\n", "needle_token_start": 15381, "needle_token_end": 15438, "code_context_ntokens": 16200, "output": ["```cpp\nDtype promote_types(const Dtype& t1, const Dtype& t2) {\n return Dtype(type_rules[static_cast(t1.val)][static_cast(t2.val)]);\n}\n```"]} +{"repo": "scylladb/seastar", "name": "send", "language": "cpp", "path": "src/net/dhcp.cc", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to transmit a DHCP packet over a UDP socket to a predefined server port, with the packet containing specific DHCP fields set for communication.\n2. **Input**: The function accepts a DHCP packet which has been prepared for transmission.\n3. 
**Output**: The function returns a future object that resolves when the packet sending operation is complete, although the actual sending operation's result is not tracked due to the future being discarded.\n4. **Procedure**: \n - The function first sets the transaction ID and client hardware address in the DHCP packet.\n - It then converts the packet from host byte order to network byte order.\n - The packet is sent to the broadcast address on the server port using a UDP socket.\n - A future object is returned immediately, indicating the operation is complete, but without tracking the success of the send operation.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " _state = state::DONE;\n _result.set_value(info);\n break;\n default:\n break;\n }\n return make_ready_future<>();\n }\n\n future<> handle(packet& p, ip_hdr* iph, ethernet_address from, bool & handled) override {\n if (_state == state::NONE || p.len() < sizeof(dhcp_packet_base)) {\n return make_ready_future<>();\n }\n\n auto ipl = iph->ihl * 4;\n auto udp = p.get_header(ipl);\n auto dhp = p.get_header(ipl + sizeof(*udp));\n\n const auto opt_off = ipl + sizeof(*udp) + sizeof(dhcp_payload);\n\n if (udp == nullptr || dhp == nullptr\n || iph->ip_proto != uint8_t(ip_protocol_num::udp)\n || (::ntohs)(udp->dst_port) != client_port\n || iph->len < (opt_off + sizeof(option_mark))\n || dhp->magic != options_magic) {\n return make_ready_future<>();\n }\n handled = true;\n auto src_cpu = this_shard_id();\n if (src_cpu == 0) {\n return process_packet(std::move(p), dhp, opt_off);\n }\n // FIXME: future is discarded\n (void)smp::submit_to(0, [this, p = std::move(p), src_cpu, dhp, opt_off]() mutable {\n return process_packet(p.free_on_cpu(src_cpu), dhp, opt_off);\n });\n return make_ready_future<>();\n }\n\n future> run(const lease & l,\n const steady_clock_type::duration & timeout) {\n\n _state = state::NONE;\n _timer.set_callback([this]() {\n _state = state::FAIL;\n log() << \"timeout\" << std::endl;\n _retry_timer.cancel();\n _result.set_value(std::nullopt);\n });\n\n log() << \"sending discover\" << std::endl;\n (void)send_discover(l.ip); // FIXME: ignoring return\n if (timeout.count()) {\n _timer.arm(timeout);\n }\n _retry_timer.set_callback([this, l] {\n // FIXME: ignoring return\n (void)send_discover(l.ip);\n });\n _retry_timer.arm_periodic(1s);\n return _result.get_future();\n }\n\n template\n \nfuture<> send(T && pkt) {\n pkt.dhp.bootp.xid = _xid;\n auto ipf = _stack.netif();\n auto mac = ipf->hw_address().mac;\n std::copy(mac.begin(), mac.end(), std::begin(pkt.dhp.bootp.chaddr));\n\n pkt = hton(pkt);\n\n // FIXME: future is discarded\n (void)_sock.send({0xffffffff, server_port}, packet(reinterpret_cast(&pkt), sizeof(pkt)));\n\n return make_ready_future<>();\n }\n\n future<> send_discover(const ipv4_address & ip = ipv4_address()) {\n struct discover : public dhcp_packet_base {\n type_option type = type_option(msg_type::DISCOVER);\n ip_option requested_ip;\n requested_option req;\n option_mark end;\n } __attribute__((packed));\n\n discover d;\n\n d.requested_ip = ip_option(opt_type::REQUESTED_ADDRESS, ip);\n\n std::random_device rd;\n std::default_random_engine e1(rd());\n std::uniform_int_distribution xid_dist{};\n\n _xid = xid_dist(e1);\n _state = state::DISCOVER;\n return send(d);\n }\n\n future<> send_request(const 
lease & info) {\n struct request : public dhcp_packet_base {\n type_option type = type_option(msg_type::REQUEST);\n ip_option dhcp_server;\n ip_option requested_ip;\n requested_option req;\n option_mark end;\n } __attribute__((packed));\n\n request d;\n\n d.dhcp_server = ip_option(opt_type::DHCP_SERVER, info.dhcp_server);\n d.requested_ip = ip_option(opt_type::REQUESTED_ADDRESS, info.ip);\n\n log() << \"sending request for \" << info.ip << std::endl;\n _state = state::REQUEST;\n return send(d);\n }\n\nprivate:\n promise> _result;\n state _state = state::NONE;\n timer<> _timer;\n timer<> _retry_timer;\n ipv4 & _stack;\n udp_channel _sock;\n uint32_t _xid = 0;\n};\n\nconst net::dhcp::impl::req_opt_type net::dhcp::impl::requested_options = { {\n opt_type::SUBNET_MASK, opt_type::ROUTER, opt_type::DOMAIN_NAME_SERVERS,\n opt_type::INTERFACE_MTU, opt_type::BROADCAST_ADDRESS } };\n\nconst net::dhcp::impl::magic_tag net::dhcp::impl::options_magic = { { 0x63, 0x82, 0x53,\n 0x63 } };\n\nconst uint16_t net::dhcp::impl::client_port;\nconst uint16_t net::dhcp::impl::server_port;\n\nconst steady_clock_type::duration net::dhcp::default_timeout = std::chrono::duration_cast(std::chrono::seconds(30));\n\nnet::dhcp::dhcp(ipv4 & ip)\n: _impl(std::make_unique(ip))\n{}\n\nnet::dhcp::dhcp(dhcp && v) noexcept = default;\n\nnet::dhcp::~dhcp()\n{}\n\nnet::dhcp::result_type net::dhcp::discover(const steady_clock_type::duration & timeout) {\n return _impl->run(lease(), timeout);\n}\n\nnet::dhcp::result_type net::dhcp::renew(const lease & l, const steady_clock_type::duration & timeout) {\n return _impl->run(l, timeout);\n}\n\nnet::ip_packet_filter* net::dhcp::get_ipv4_filter() {\n return _impl.get();\n}\n\n}\n\n// Path: src/net/ip.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n *\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nnamespace net {\n\nipv4_address::ipv4_address(const std::string& addr) {\n boost::system::error_code ec;\n auto ipv4 = boost::asio::ip::address_v4::from_string(addr, ec);\n if (ec) {\n throw std::runtime_error(\n fmt::format(\"Wrong format for IPv4 address {}. 
Please ensure it's in dotted-decimal format\", addr));\n }\n ip = static_cast(std::move(ipv4).to_ulong());\n}\n\nipv4::ipv4(interface* netif)\n : _netif(netif)\n , _global_arp(netif)\n , _arp(_global_arp)\n , _host_address(0)\n , _gw_address(0)\n , _netmask(0)\n , _l3(netif, eth_protocol_num::ipv4, [this] { return get_packet(); })\n , _tcp(*this)\n , _icmp(*this)\n , _udp(*this)\n , _l4({ { uint8_t(ip_protocol_num::tcp), &_tcp }, { uint8_t(ip_protocol_num::icmp), &_icmp }, { uint8_t(ip_protocol_num::udp), &_udp }})\n{\n namespace sm = seastar::metrics;\n // FIXME: ignored future\n (void)_l3.receive(\n [this](packet p, ethernet_address ea) {\n return handle_received_packet(std::move(p), ea);\n },\n [this](forward_hash& out_hash_data, packet& p, size_t off) {\n return forward(out_hash_data, p, off);\n });\n\n _metrics.add_group(\"ipv4\", {\n //\n // Linearized events: DERIVE:0:u\n //\n sm::make_counter(\"linearizations\", [] { return ipv4_packet_merger::linearizations(); },\n sm::description(\"Counts a number of times a buffer linearization was invoked during buffers merge process. \"\n \"Divide it by a total IPv4 receive packet rate to get an average number of lineraizations per packet.\"))\n });\n _frag_timer.set_callback([this] { frag_timeout(); });\n}\n\nbool ipv4::forward(forward_hash& out_hash_data, packet& p, size_t off)\n{\n auto iph = p.get_header(off);\n\n out_hash_data.push_back(iph->src_ip.ip);\n out_hash_data.push_back(iph->dst_ip.ip);\n\n auto h = ntoh(*iph);\n auto l4 = _l4[h.ip_proto];\n if (l4) {\n if (h.mf() == false && h.offset() == 0) {\n // This IP datagram is atomic, forward according to tcp or udp connection hash\n l4->forward(out_hash_data, p, off + sizeof(ip_hdr));\n }\n // else forward according to ip fields only\n }\n return true;\n}\n\nbool ipv4::in_my_netmask(ipv4_address a) const {\n return !((a.ip ^ _host_address.ip) & _netmask.ip);\n}\n\nbool ipv4::needs_frag(packet& p, ip_protocol_num prot_num, net::hw_features hw_features) {\n if (p.len() + ipv4_hdr_len_min <= hw_features.mtu) {\n return false;\n }\n\n if ((prot_num == ip_protocol_num::tcp && hw_features.tx_tso) ||\n (prot_num == ip_protocol_num::udp && hw_features.tx_ufo)) {\n return false;\n }\n\n return true;\n}\n\nfuture<>\nipv4::handle_received_packet(packet p, ethernet_address from) {\n auto iph = p.get_header(0);\n if (!iph) {\n return make_ready_future<>();\n }\n\n // Skip checking csum of reassembled IP datagram\n if (!hw_features().rx_csum_offload && !p.offload_info_ref().reassembled) {\n checksummer csum;\n csum.sum(reinterpret_cast(iph), sizeof(*iph));\n if (csum.get() != 0) {\n return make_ready_future<>();\n }\n }\n\n auto h = ntoh(*iph);\n unsigned ip_len = h.len;\n unsigned ip_hdr_len = h.ihl * 4;\n unsigned pkt_len = p.len();\n auto offset = h.offset();\n if (pkt_len > ip_len) {\n // Trim extra data in the packet beyond IP total length\n p.trim_back(pkt_len - ip_len);\n } else if (pkt_len < ip_len) {\n // Drop if it contains less than IP total length\n return make_ready_future<>();\n }\n // Drop if the reassembled datagram will be larger than maximum IP size\n if (offset + p.len() > net::ip_packet_len_max) {\n return make_ready_future<>();\n }\n\n // FIXME: process options\n if (in_my_netmask(h.src_ip) && h.src_ip != _host_address) {\n _arp.learn(from, h.src_ip);\n }\n\n if (_packet_filter) {\n bool handled = false;\n auto r = _packet_filter->handle(p, &h, from, handled);\n if (handled) {\n return r;\n }\n }\n\n if (h.dst_ip != _host_address) {\n // FIXME: forward\n return 
make_ready_future<>();\n }\n\n // Does this IP datagram need reassembly\n auto mf = h.mf();\n if (mf == true || offset != 0) {\n frag_limit_mem();\n auto frag_id = ipv4_frag_id{h.src_ip, h.dst_ip, h.id, h.ip_proto};\n auto& frag = _frags[frag_id];\n if (mf == false) {\n frag.last_frag_received = true;\n }\n // This is a newly created frag_id\n if (frag.mem_size == 0) {\n _frags_age.push_back(frag_id);\n frag.rx_time = clock_type::now();\n }\n auto added_size = frag.merge(h, offset, std::move(p));\n _frag_mem += added_size;\n if (frag.is_complete()) {\n // All the fragments are received\n auto dropped_size = frag.mem_size;\n auto& ip_data = frag.data.map.begin()->second;\n // Choose a cpu to forward this packet\n auto cpu_id = this_shard_id();\n auto l4 = _l4[h.ip_proto];\n if (l4) {\n if (smp::count == 1) {\n l4->received(std::move(ip_data), h.src_ip, h.dst_ip);\n } else {\n size_t l4_offset = 0;\n forward_hash hash_data;\n hash_data.push_back(hton(h.src_ip.ip));\n hash_data.push_back(hton(h.dst_ip.ip));\n auto forwarded = l4->forward(hash_data, ip_data, l4_offset);\n if (forwarded) {\n cpu_id = _netif->hash2cpu(toeplitz_hash(_netif->rss_key(), hash_data));\n // No need to forward if the dst cpu is the current cpu\n if (cpu_id == this_shard_id()) {\n l4->received(std::move(ip_data), h.src_ip, h.dst_ip);\n } else {\n auto to = _netif->hw_address();\n auto pkt = frag.get_assembled_packet(from, to);\n _netif->forward(cpu_id, std::move(pkt));\n }\n }\n }\n }\n\n // Delete this frag from _frags and _frags_age\n frag_drop(frag_id, dropped_size);\n _frags_age.remove(frag_id);\n } else {\n // Some of the fragments are missing\n if (!_frag_timer.armed()) {\n frag_arm();\n }\n }\n return make_ready_future<>();\n }\n\n auto l4 = _l4[h.ip_proto];\n if (l4) {\n // Trim IP header and pass to upper layer\n p.trim_front(ip_hdr_len);\n l4->received(std::move(p), h.src_ip, h.dst_ip);\n }\n return make_ready_future<>();\n}\n\nfuture ipv4::get_l2_dst_address(ipv4_address to) {\n // Figure out where to send the packet to. 
If it is a directly connected\n // host, send to it directly, otherwise send to the default gateway.\n ipv4_address dst;\n if (in_my_netmask(to)) {\n dst = to;\n } else {\n dst = _gw_address;\n }\n\n return _arp.lookup(dst);\n}\n\nvoid ipv4::send(ipv4_address to, ip_protocol_num proto_num, packet p, ethernet_address e_dst) {\n auto needs_frag = this->needs_frag(p, proto_num, hw_features());\n\n auto send_pkt = [this, to, proto_num, needs_frag, e_dst] (packet& pkt, uint16_t remaining, uint16_t offset) mutable {\n auto iph = pkt.prepend_header();\n iph->ihl = sizeof(*iph) / 4;\n iph->ver = 4;\n iph->dscp = 0;\n iph->ecn = 0;\n iph->len = pkt.len();\n // FIXME: a proper id\n iph->id = 0;\n if (needs_frag) {\n uint16_t mf = remaining > 0;\n // The fragment offset is measured in units of 8 octets (64 bits)\n auto off = offset / 8;\n iph->frag = (mf << uint8_t(ip_hdr::frag_bits::mf)) | off;\n } else {\n iph->frag = 0;\n }\n iph->ttl = 64;\n iph->ip_proto = (uint8_t)proto_num;\n iph->csum = 0;\n iph->src_ip = _host_address;\n iph->dst_ip = to;\n *iph = hton(*iph);\n\n if (hw_features().tx_csum_ip_offload) {\n iph->csum = 0;\n pkt.offload_info_ref().needs_ip_csum = true;\n } else {\n checksummer csum;\n csum.sum(reinterpret_cast(iph), sizeof(*iph));\n iph->csum = csum.get();\n }\n\n _packetq.push_back(l3_protocol::l3packet{eth_protocol_num::ipv4, e_dst, std::move(pkt)});\n };\n\n if (needs_frag) {\n uint16_t offset = 0;\n uint16_t remaining = p.len();\n auto mtu = hw_features().mtu;\n\n while (remaining) {\n auto can_send = std::min(uint16_t(mtu - net::ipv4_hdr_len_min), remaining);\n remaining -= can_send;\n auto pkt = p.share(offset, can_send);\n send_pkt(pkt, remaining, offset);\n offset += can_send;\n }\n } else {\n // The whole packet can be send in one shot\n send_pkt(p, 0, 0);\n }\n}\n\nstd::optional ipv4::get_packet() {\n // _packetq will be mostly empty here unless it hold remnants of previously\n // fragmented packet\n if (_packetq.empty()) {\n for (size_t i = 0; i < _pkt_providers.size(); i++) {\n auto l4p = _pkt_providers[_pkt_provider_idx++]();\n if (_pkt_provider_idx == _pkt_providers.size()) {\n _pkt_provider_idx = 0;\n }\n if (l4p) {\n auto l4pv = std::move(l4p.value());\n send(l4pv.to, l4pv.proto_num, std::move(l4pv.p), l4pv.e_dst);\n break;\n }\n }\n }\n\n std::optional p;\n if (!_packetq.empty()) {\n p = std::move(_packetq.front());\n _packetq.pop_front();\n }\n return p;\n}\n\nvoid ipv4::set_host_address(ipv4_address ip) {\n _host_address = ip;\n _arp.set_self_addr(ip);\n}\n\nipv4_address ipv4::host_address() const {\n return _host_address;\n}\n\nvoid ipv4::set_gw_address(ipv4_address ip) {\n _gw_address = ip;\n}\n\nipv4_address ipv4::gw_address() const {\n return _gw_address;\n}\n\nvoid ipv4::set_netmask_address(ipv4_address ip) {\n _netmask = ip;\n}\n\nipv4_address ipv4::netmask_address() const {\n return _netmask;\n}\n\nvoid ipv4::set_packet_filter(ip_packet_filter * f) {\n _packet_filter = f;\n}\n\nip_packet_filter * ipv4::packet_filter() const {\n return _packet_filter;\n}\n\nvoid ipv4::frag_limit_mem() {\n if (_frag_mem <= _frag_high_thresh) {\n return;\n }\n auto drop = _frag_mem - _frag_low_thresh;\n while (drop) {\n if (_frags_age.empty()) {\n return;\n }\n // Drop the oldest frag (first element) from _frags_age\n auto frag_id = _frags_age.front();\n _frags_age.pop_front();\n\n // Drop from _frags as well\n auto& frag = _frags[frag_id];\n auto dropped_size = frag.mem_size;\n frag_drop(frag_id, dropped_size);\n\n drop -= std::min(drop, dropped_size);\n }\n}\n\nvoid 
ipv4::frag_timeout() {\n if (_frags.empty()) {\n return;\n }\n auto now = clock_type::now();\n for (auto it = _frags_age.begin(); it != _frags_age.end();) {\n auto frag_id = *it;\n auto& frag = _frags[frag_id];\n if (now > frag.rx_time + _frag_timeout) {\n auto dropped_size = frag.mem_size;\n // Drop from _frags\n frag_drop(frag_id, dropped_size);\n // Drop from _frags_age\n it = _frags_age.erase(it);\n } else {\n // The further items can only be younger\n break;\n }\n }\n if (_frags.size() != 0) {\n frag_arm(now);\n } else {\n _frag_mem = 0;\n }\n}\n\nvoid ipv4::frag_drop(ipv4_frag_id frag_id, uint32_t dropped_size) {\n _frags.erase(frag_id);\n _frag_mem -= dropped_size;\n}\n\nint32_t ipv4::frag::merge(ip_hdr &h, uint16_t offset, packet p) {\n uint32_t old = mem_size;\n unsigned ip_hdr_len = h.ihl * 4;\n // Store IP header\n if (offset == 0) {\n header = p.share(0, ip_hdr_len);\n }\n // Sotre IP payload\n p.trim_front(ip_hdr_len);\n data.merge(offset, std::move(p));\n // Update mem size\n mem_size = header.memory();\n for (const auto& x : data.map) {\n mem_size += x.second.memory();\n }\n auto added_size = mem_size - old;\n return added_size;\n}\n\nbool ipv4::frag::is_complete() {\n // If all the fragments are received, ipv4::frag::merge() should merge all\n // the fragments into a single packet\n auto offset = data.map.begin()->first;\n auto nr_packet = data.map.size();\n return last_frag_received && nr_packet == 1 && offset == 0;\n}\n\npacket ipv4::frag::get_assembled_packet(ethernet_address from, ethernet_address to) {\n auto& ip_header = header;\n auto& ip_data = data.map.begin()->second;\n // Append a ethernet header, needed for forwarding\n auto eh = ip_header.prepend_header();\n eh->src_mac = from;\n eh->dst_mac = to;\n eh->eth_proto = uint16_t(eth_protocol_num::ipv4);\n *eh = hton(*eh);\n // Prepare a packet contains both ethernet header, ip header and ip data\n ip_header.append(std::move(ip_data));\n auto pkt = std::move(ip_header);\n auto iph = pkt.get_header(sizeof(eth_hdr));\n // len is the sum of each fragment\n iph->len = hton(uint16_t(pkt.len() - sizeof(eth_hdr)));\n // No fragmentation for the assembled datagram\n iph->frag = 0;\n // Since each fragment's csum is checked, no need to csum\n // again for the assembled datagram\n offload_info oi;\n oi.reassembled = true;\n pkt.set_offload_info(oi);\n return pkt;\n}\n\nvoid icmp::received(packet p, ipaddr from, ipaddr to) {\n auto hdr = p.get_header(0);\n if (!hdr || hdr->type != icmp_hdr::msg_type::echo_request) {\n return;\n }\n hdr->type = icmp_hdr::msg_type::echo_reply;\n hdr->code = 0;\n hdr->csum = 0;\n checksummer csum;\n csum.sum(reinterpret_cast(hdr), p.len());\n hdr->csum = csum.get();\n\n if (_queue_space.try_wait(p.len())) { // drop packets that do not fit the queue\n // FIXME: future is discarded\n (void)_inet.get_l2_dst_address(from).then([this, from, p = std::move(p)] (ethernet_address e_dst) mutable {\n _packetq.emplace_back(ipv4_traits::l4packet{from, std::move(p), e_dst, ip_protocol_num::icmp});\n });\n }\n}\n\n}\n\n}\n\n// Path: src/net/tcp.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. 
You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \"net/native-stack-impl.hh\"\n#endif\n\nnamespace seastar {\n\nnamespace net {\n\nvoid tcp_option::parse(uint8_t* beg1, uint8_t* end1) {\n const char* beg = reinterpret_cast(beg1);\n const char* end = reinterpret_cast(end1);\n while (beg < end) {\n auto kind = option_kind(*beg);\n if (kind != option_kind::nop && kind != option_kind::eol) {\n // Make sure there is enough room for this option\n auto len = uint8_t(beg[1]);\n if (beg + len > end) {\n return;\n }\n }\n switch (kind) {\n case option_kind::mss:\n _mss_received = true;\n _remote_mss = mss::read(beg).mss;\n beg += option_len::mss;\n break;\n case option_kind::win_scale:\n _win_scale_received = true;\n _remote_win_scale = win_scale::read(beg).shift;\n // We can turn on win_scale option, 7 is Linux's default win scale size\n _local_win_scale = 7;\n beg += option_len::win_scale;\n break;\n case option_kind::sack:\n _sack_received = true;\n beg += option_len::sack;\n break;\n case option_kind::nop:\n beg += option_len::nop;\n break;\n case option_kind::eol:\n return;\n default:\n // Ignore options we do not understand\n uint8_t len = *(beg + 1);\n beg += len;\n // Prevent infinite loop\n if (len == 0) {\n return;\n }\n break;\n }\n }\n}\n\nuint8_t tcp_option::fill(void* h, const tcp_hdr* th, uint8_t options_size) {\n auto hdr = reinterpret_cast(h);\n auto off = hdr + tcp_hdr::len;\n uint8_t size = 0;\n bool syn_on = th->f_syn;\n bool ack_on = th->f_ack;\n\n if (syn_on) {\n if (_mss_received || !ack_on) {\n auto mss = tcp_option::mss();\n mss.mss = _local_mss;\n mss.write(off);\n off += mss.len;\n size += mss.len;\n }\n if (_win_scale_received || !ack_on) {\n auto win_scale = tcp_option::win_scale();\n win_scale.shift = _local_win_scale;\n win_scale.write(off);\n off += win_scale.len;\n size += win_scale.len;\n }\n }\n if (size > 0) {\n // Insert NOP option\n auto size_max = align_up(uint8_t(size + 1), tcp_option::align);\n while (size < size_max - uint8_t(option_len::eol)) {\n auto nop = tcp_option::nop();\n nop.write(off);\n off += option_len::nop;\n size += option_len::nop;\n }\n auto eol = tcp_option::eol();\n eol.write(off);\n size += option_len::eol;\n }\n assert(size == options_size);\n\n return size;\n}\n\nuint8_t tcp_option::get_size(bool syn_on, bool ack_on) {\n uint8_t size = 0;\n if (syn_on) {\n if (_mss_received || !ack_on) {\n size += option_len::mss;\n }\n if (_win_scale_received || !ack_on) {\n size += option_len::win_scale;\n }\n }\n if (size > 0) {\n size += option_len::eol;\n // Insert NOP option to align on 32-bit\n size = align_up(size, tcp_option::align);\n }\n return size;\n}\n\nipv4_tcp::ipv4_tcp(ipv4& inet)\n\t: _inet_l4(inet), _tcp(std::make_unique>(_inet_l4)) {\n}\n\nipv4_tcp::~ipv4_tcp() {\n}\n\nvoid ipv4_tcp::received(packet p, ipv4_address from, ipv4_address to) {\n 
_tcp->received(std::move(p), from, to);\n}\n\nbool ipv4_tcp::forward(forward_hash& out_hash_data, packet& p, size_t off) {\n\n return _tcp->forward(out_hash_data, p, off);\n}\n\nserver_socket\ntcpv4_listen(tcp& tcpv4, uint16_t port, listen_options opts) {\n\treturn server_socket(std::make_unique>>(\n\t\t\ttcpv4, port, opts));\n}\n\n::seastar::socket\ntcpv4_socket(tcp& tcpv4) {\n return ::seastar::socket(std::make_unique>>(\n tcpv4));\n}\n\n}\n\n}\n\n// Path: src/net/ip_checksum.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#endif\n\nnamespace seastar {\n\nnamespace net {\n\nvoid checksummer::sum(const char* data, size_t len) {\n auto orig_len = len;\n if (odd) {\n csum += uint8_t(*data++);\n --len;\n }\n auto p64 = reinterpret_cast*>(data);\n while (len >= 8) {\n csum += ntohq(*p64++);\n len -= 8;\n }\n auto p16 = reinterpret_cast*>(p64);\n while (len >= 2) {\n csum += ntohs(*p16++);\n len -= 2;\n }\n auto p8 = reinterpret_cast(p16);\n if (len) {\n csum += *p8++ << 8;\n len -= 1;\n }\n odd ^= orig_len & 1;\n}\n\nuint16_t checksummer::get() const {\n __int128 csum1 = (csum & 0xffff'ffff'ffff'ffff) + (csum >> 64);\n uint64_t csum = (csum1 & 0xffff'ffff'ffff'ffff) + (csum1 >> 64);\n csum = (csum & 0xffff) + ((csum >> 16) & 0xffff) + ((csum >> 32) & 0xffff) + (csum >> 48);\n csum = (csum & 0xffff) + (csum >> 16);\n csum = (csum & 0xffff) + (csum >> 16);\n return htons(~csum);\n}\n\nvoid checksummer::sum(const packet& p) {\n for (auto&& f : p.fragments()) {\n sum(f.base, f.size);\n }\n}\n\nuint16_t ip_checksum(const void* data, size_t len) {\n checksummer cksum;\n cksum.sum(reinterpret_cast(data), len);\n return cksum.get();\n}\n\n\n}\n\n}\n\n// Path: src/net/udp.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nusing namespace net;\n\nnamespace net {\nnamespace ipv4_udp_impl {\n\nstatic inline\nipv4_addr\nto_ipv4_addr(ipv4_address a, uint16_t port) {\n return {a.ip, port};\n}\n\nclass native_datagram : public datagram_impl {\nprivate:\n ipv4_addr _src;\n ipv4_addr _dst;\n packet _p;\npublic:\n native_datagram(ipv4_address src, ipv4_address dst, packet p)\n : _p(std::move(p)) {\n udp_hdr* hdr = _p.get_header();\n auto h = ntoh(*hdr);\n _p.trim_front(sizeof(*hdr));\n _src = to_ipv4_addr(src, h.src_port);\n _dst = to_ipv4_addr(dst, h.dst_port);\n }\n\n virtual socket_address get_src() override {\n return _src;\n };\n\n virtual socket_address get_dst() override {\n return _dst;\n };\n\n virtual uint16_t get_dst_port() override {\n return _dst.port;\n }\n\n virtual packet& get_data() override {\n return _p;\n }\n};\n\nclass native_channel : public datagram_channel_impl {\nprivate:\n ipv4_udp& _proto;\n ipv4_udp::registration _reg;\n bool _closed;\n lw_shared_ptr _state;\n\npublic:\n native_channel(ipv4_udp &proto, ipv4_udp::registration reg, lw_shared_ptr state)\n : _proto(proto)\n , _reg(reg)\n , _closed(false)\n , _state(state)\n {\n }\n\n ~native_channel()\n {\n if (!_closed)\n close();\n }\n\n socket_address local_address() const override {\n return socket_address(_proto.inet().host_address(), _reg.port());\n }\n\n virtual future receive() override {\n return _state->_queue.pop_eventually();\n }\n\n virtual future<> send(const socket_address& dst, const char* msg) override {\n return send(dst, packet::from_static_data(msg, strlen(msg)));\n }\n\n virtual future<> send(const socket_address& dst, packet p) override {\n auto len = p.len();\n return _state->wait_for_send_buffer(len).then([this, dst, p = std::move(p), len] () mutable {\n p = packet(std::move(p), make_deleter([s = _state, len] { s->complete_send(len); }));\n _proto.send(_reg.port(), dst, std::move(p));\n });\n }\n\n virtual bool is_closed() const override {\n return _closed;\n }\n\n virtual void shutdown_input() override {\n _state->_queue.abort(std::make_exception_ptr(std::system_error(EBADF, std::system_category())));\n }\n\n virtual void shutdown_output() override {\n _state->_queue.abort(std::make_exception_ptr(std::system_error(EPIPE, std::system_category())));\n }\n\n virtual void close() override {\n _reg.unregister();\n _closed = true;\n }\n};\n\n} /* namespace ipv4_udp_impl */\n\nusing namespace net::ipv4_udp_impl;\n\nconst int ipv4_udp::default_queue_size = 1024;\n\nipv4_udp::ipv4_udp(ipv4& inet)\n : _inet(inet)\n{\n _inet.register_packet_provider([this] {\n std::optional l4p;\n if (!_packetq.empty()) {\n l4p = std::move(_packetq.front());\n _packetq.pop_front();\n }\n return l4p;\n });\n}\n\nbool ipv4_udp::forward(forward_hash& out_hash_data, packet& p, size_t off)\n{\n auto uh = p.get_header(off);\n\n if (uh) {\n out_hash_data.push_back(uh->src_port);\n out_hash_data.push_back(uh->dst_port);\n }\n return true;\n}\n\nvoid ipv4_udp::received(packet p, ipv4_address from, ipv4_address to)\n{\n datagram dgram(std::make_unique(from, to, std::move(p)));\n\n auto chan_it = _channels.find(dgram.get_dst_port());\n if (chan_it != _channels.end()) {\n auto chan = 
chan_it->second;\n chan->_queue.push(std::move(dgram));\n }\n}\n\nvoid ipv4_udp::send(uint16_t src_port, ipv4_addr dst, packet &&p)\n{\n auto src = _inet.host_address();\n auto hdr = p.prepend_header();\n hdr->src_port = src_port;\n hdr->dst_port = dst.port;\n hdr->len = p.len();\n *hdr = hton(*hdr);\n\n offload_info oi;\n checksummer csum;\n ipv4_traits::udp_pseudo_header_checksum(csum, src, dst, p.len());\n bool needs_frag = ipv4::needs_frag(p, ip_protocol_num::udp, _inet.hw_features());\n if (_inet.hw_features().tx_csum_l4_offload && !needs_frag) {\n hdr->cksum = ~csum.get();\n oi.needs_csum = true;\n } else {\n csum.sum(p);\n hdr->cksum = csum.get();\n oi.needs_csum = false;\n }\n oi.protocol = ip_protocol_num::udp;\n p.set_offload_info(oi);\n\n // FIXME: future is discarded\n (void)_inet.get_l2_dst_address(dst).then([this, dst, p = std::move(p)] (ethernet_address e_dst) mutable {\n _packetq.emplace_back(ipv4_traits::l4packet{dst, std::move(p), e_dst, ip_protocol_num::udp});\n });\n}\n\nuint16_t ipv4_udp::next_port(uint16_t port) {\n return (port + 1) == 0 ? min_anonymous_port : port + 1;\n}\n\nudp_channel\nipv4_udp::make_channel(ipv4_addr addr) {\n if (!is_ip_unspecified(addr)) {\n throw std::runtime_error(\"Binding to specific IP not supported yet\");\n }\n\n uint16_t bind_port;\n\n if (!is_port_unspecified(addr)) {\n if (_channels.count(addr.port)) {\n throw std::runtime_error(\"Address already in use\");\n }\n bind_port = addr.port;\n } else {\n auto starting_port = _next_anonymous_port;\n while (_channels.count(_next_anonymous_port)) {\n _next_anonymous_port = next_port(_next_anonymous_port);\n if (starting_port == _next_anonymous_port) {\n throw std::runtime_error(\"No free port\");\n }\n }\n\n bind_port = _next_anonymous_port;\n _next_anonymous_port = next_port(_next_anonymous_port);\n }\n\n auto chan_state = make_lw_shared(_queue_size);\n _channels[bind_port] = chan_state;\n return udp_channel(std::make_unique(*this, registration(*this, bind_port), chan_state));\n}\n\n} /* namespace net */\n\n}\n\n\n// Path: src/net/stack.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2015 Cloudius Systems\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#endif\n\nnamespace seastar {\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nnet::datagram_channel::datagram_channel() noexcept\n{}\n\nnet::datagram_channel::datagram_channel(std::unique_ptr impl) noexcept : _impl(std::move(impl))\n{}\n\nnet::datagram_channel::~datagram_channel()\n{}\n\nnet::datagram_channel::datagram_channel(datagram_channel&&) noexcept = default;\nnet::datagram_channel& net::datagram_channel::operator=(datagram_channel&&) noexcept = default;\n\nsocket_address net::datagram_channel::local_address() const {\n if (_impl) {\n return _impl->local_address();\n } else {\n return {};\n }\n}\n\nfuture net::datagram_channel::receive() {\n return _impl->receive();\n}\n\nfuture<> net::datagram_channel::send(const socket_address& dst, const char* msg) {\n return _impl->send(dst, msg);\n}\n\nfuture<> net::datagram_channel::send(const socket_address& dst, packet p) {\n return _impl->send(dst, std::move(p));\n}\n\nbool net::datagram_channel::is_closed() const {\n return _impl->is_closed();\n}\n\nvoid net::datagram_channel::shutdown_input() {\n _impl->shutdown_input();\n}\n\nvoid net::datagram_channel::shutdown_output() {\n _impl->shutdown_output();\n}\n\n\nvoid net::datagram_channel::close() {\n return _impl->close();\n}\n\nconnected_socket::connected_socket() noexcept\n{}\n\nconnected_socket::connected_socket(\n std::unique_ptr csi) noexcept\n : _csi(std::move(csi)) {\n}\n\nconnected_socket::connected_socket(connected_socket&& cs) noexcept = default;\nconnected_socket& connected_socket::operator=(connected_socket&& cs) noexcept = default;\n\nconnected_socket::~connected_socket()\n{}\n\ninput_stream connected_socket::input(connected_socket_input_stream_config csisc) {\n return input_stream(_csi->source(csisc));\n}\n\noutput_stream connected_socket::output(size_t buffer_size) {\n output_stream_options opts;\n opts.batch_flushes = true;\n // TODO: allow user to determine buffer size etc\n return output_stream(_csi->sink(), buffer_size, opts);\n}\n\nvoid connected_socket::set_nodelay(bool nodelay) {\n _csi->set_nodelay(nodelay);\n}\n\nbool connected_socket::get_nodelay() const {\n return _csi->get_nodelay();\n}\nvoid connected_socket::set_keepalive(bool keepalive) {\n _csi->set_keepalive(keepalive);\n}\nbool connected_socket::get_keepalive() const {\n return _csi->get_keepalive();\n}\nvoid connected_socket::set_keepalive_parameters(const net::keepalive_params& p) {\n _csi->set_keepalive_parameters(p);\n}\nnet::keepalive_params connected_socket::get_keepalive_parameters() const {\n return _csi->get_keepalive_parameters();\n}\nvoid connected_socket::set_sockopt(int level, int optname, const void* data, size_t len) {\n _csi->set_sockopt(level, optname, data, len);\n}\nint connected_socket::get_sockopt(int level, int optname, void* data, size_t len) const {\n return _csi->get_sockopt(level, optname, data, len);\n}\n\nsocket_address connected_socket::local_address() const noexcept {\n return 
_csi->local_address();\n}\n\nsocket_address connected_socket::remote_address() const noexcept {\n return _csi->remote_address();\n}\n\nvoid connected_socket::shutdown_output() {\n _csi->shutdown_output();\n}\n\nvoid connected_socket::shutdown_input() {\n _csi->shutdown_input();\n}\n\nfuture<> connected_socket::wait_input_shutdown() {\n return _csi->wait_input_shutdown();\n}\n\ndata_source\nnet::connected_socket_impl::source(connected_socket_input_stream_config csisc) {\n // Default implementation falls back to non-parameterized data_source\n return source();\n}\n\nsocket::~socket()\n{}\n\nsocket::socket(\n std::unique_ptr si) noexcept\n : _si(std::move(si)) {\n}\n\nsocket::socket(socket&&) noexcept = default;\nsocket& socket::operator=(socket&&) noexcept = default;\n\nfuture socket::connect(socket_address sa, socket_address local, transport proto) {\n return _si->connect(sa, local, proto);\n}\n\nvoid socket::set_reuseaddr(bool reuseaddr) {\n _si->set_reuseaddr(reuseaddr);\n}\n\nbool socket::get_reuseaddr() const {\n return _si->get_reuseaddr();\n}\n\nvoid socket::shutdown() {\n _si->shutdown();\n}\n\nserver_socket::server_socket() noexcept {\n}\n\nserver_socket::server_socket(std::unique_ptr ssi) noexcept\n : _ssi(std::move(ssi)) {\n}\nserver_socket::server_socket(server_socket&& ss) noexcept = default;\nserver_socket& server_socket::operator=(server_socket&& cs) noexcept = default;\n\nserver_socket::~server_socket() {\n}\n\nfuture server_socket::accept() {\n if (_aborted) {\n return make_exception_future(std::system_error(ECONNABORTED, std::system_category()));\n }\n return _ssi->accept();\n}\n\nvoid server_socket::abort_accept() {\n _ssi->abort_accept();\n _aborted = true;\n}\n\nsocket_address server_socket::local_address() const noexcept {\n return _ssi->local_address();\n}\n\nnetwork_interface::network_interface(shared_ptr impl) noexcept\n : _impl(std::move(impl))\n{}\n\nnetwork_interface::network_interface(network_interface&&) noexcept = default;\nnetwork_interface& network_interface::operator=(network_interface&&) noexcept = default;\n \nuint32_t network_interface::index() const {\n return _impl->index();\n}\n\nuint32_t network_interface::mtu() const {\n return _impl->mtu();\n}\n\nconst sstring& network_interface::name() const {\n return _impl->name();\n}\n\nconst sstring& network_interface::display_name() const {\n return _impl->display_name();\n}\n\nconst std::vector& network_interface::addresses() const {\n return _impl->addresses();\n}\n\nconst std::vector network_interface::hardware_address() const {\n return _impl->hardware_address();\n}\n\nbool network_interface::is_loopback() const {\n return _impl->is_loopback();\n}\n\nbool network_interface::is_virtual() const {\n return _impl->is_virtual();\n}\n\nbool network_interface::is_up() const {\n return _impl->is_up();\n}\n\nbool network_interface::supports_ipv6() const {\n return _impl->supports_ipv6();\n}\n\n\nfuture\nnetwork_stack::connect(socket_address sa, socket_address local, transport proto) {\n return do_with(socket(), [sa, local, proto](::seastar::socket& s) {\n return s.connect(sa, local, proto);\n });\n}\n\nstd::vector network_stack::network_interfaces() {\n return {};\n}\n\n}\n\n// Path: src/net/config.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. 
You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2017 Marek Waszkiewicz ( marek.waszkiewicz77@gmail.com )\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#endif\n\nusing namespace boost::algorithm;\n\nnamespace seastar {\nnamespace net {\n\n // list of supported config keys\n std::string config_keys[]{ \"pci-address\", \"port-index\", \"ip\", \"gateway\", \"netmask\", \"dhcp\", \"lro\", \"tso\", \"ufo\", \"hw-fc\", \"event-index\", \"csum-offload\",\"ring-size\" };\n\n std::unordered_map\n parse_config(std::istream& input) {\n std::unordered_map device_configs;\n\n YAML::Node doc = YAML::Load(input);\n for (auto&& item : doc) {\n device_configs[item.first.as()] = item.second.as();\n }\n\n bool port_index_used = false;\n bool pci_address_used = false;\n\n for (auto&& item : device_configs) {\n\n if (item.second.hw_cfg.port_index) {\n port_index_used = true;\n }\n\n if (!item.second.hw_cfg.pci_address.empty()) {\n pci_address_used = true;\n }\n\n if (port_index_used && pci_address_used) {\n throw config_exception(\"port_index and pci_address cannot be used together\");\n }\n }\n\n...\n", "needle_token_start": 717, "needle_token_end": 877, "code_context_ntokens": 16030, "output": ["```cpp\nfuture<> send(T && pkt) {\n pkt.dhp.bootp.xid = _xid;\n auto ipf = _stack.netif();\n auto mac = ipf->hw_address().mac;\n std::copy(mac.begin(), mac.end(), std::begin(pkt.dhp.bootp.chaddr));\n\n pkt = hton(pkt);\n\n // FIXME: future is discarded\n (void)_sock.send({0xffffffff, server_port}, packet(reinterpret_cast(&pkt), sizeof(pkt)));\n\n return make_ready_future<>();\n}\n```"]} +{"repo": "scylladb/seastar", "name": "sharded_parallel_for_each", "language": "cpp", "path": "src/core/sharded.cc", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To execute a given function concurrently across multiple shards, facilitating parallel processing in a distributed system.\n2. **Input**: The number of shards to operate on and a function that defines the operation to be performed on each shard.\n3. **Output**: A future object that completes when all shard operations have finished.\n4. **Procedure**: The function iterates over a range from 0 to the specified number of shards, applying the provided function concurrently to each shard using a parallel execution mechanism.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/core/fair_queue.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. 
You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2019 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \"fmt/format.h\"\n#include \"fmt/ostream.h\"\n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nstatic_assert(sizeof(fair_queue_ticket) == sizeof(uint64_t), \"unexpected fair_queue_ticket size\");\nstatic_assert(sizeof(fair_queue_entry) <= 3 * sizeof(void*), \"unexpected fair_queue_entry::_hook size\");\nstatic_assert(sizeof(fair_queue_entry::container_list_t) == 2 * sizeof(void*), \"unexpected priority_class::_queue size\");\n\nfair_queue_ticket::fair_queue_ticket(uint32_t weight, uint32_t size) noexcept\n : _weight(weight)\n , _size(size)\n{}\n\nfloat fair_queue_ticket::normalize(fair_queue_ticket denominator) const noexcept {\n return float(_weight) / denominator._weight + float(_size) / denominator._size;\n}\n\nfair_queue_ticket fair_queue_ticket::operator+(fair_queue_ticket desc) const noexcept {\n return fair_queue_ticket(_weight + desc._weight, _size + desc._size);\n}\n\n...\n// Path: src/core/dpdk_rte.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n#ifdef SEASTAR_HAVE_DPDK\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace dpdk {\n\nbool eal::initialized = false;\n\nvoid eal::init(cpuset cpus, const std::string& argv0, const std::optional& hugepages_path, bool dpdk_pmd)\n{\n if (initialized) {\n return;\n }\n\n size_t cpu_count = cpus.count();\n std::stringstream mask;\n cpuset nibble = 0xF;\n while (cpus.any()) {\n mask << std::hex << (cpus & nibble).to_ulong();\n cpus >>= 4;\n }\n\n std::string mask_str = mask.str();\n std::reverse(mask_str.begin(), mask_str.end());\n\n std::vector> args {\n string2vector(argv0),\n string2vector(\"-c\"), string2vector(mask_str),\n string2vector(\"-n\"), string2vector(\"1\")\n };\n\n // If \"hugepages\" is not provided and DPDK PMD drivers mode is requested -\n // use the default DPDK huge tables configuration.\n if (hugepages_path) {\n args.push_back(string2vector(\"--huge-dir\"));\n args.push_back(string2vector(hugepages_path.value()));\n\n //\n // We don't know what is going to be our networking configuration so we\n // assume there is going to be a queue per-CPU. Plus we'll give a DPDK\n // 64MB for \"other stuff\".\n //\n size_t size_MB = mem_size(cpu_count) >> 20;\n std::stringstream size_MB_str;\n size_MB_str << size_MB;\n\n args.push_back(string2vector(\"-m\"));\n args.push_back(string2vector(size_MB_str.str()));\n } else if (!dpdk_pmd) {\n args.push_back(string2vector(\"--no-huge\"));\n }\n#ifdef HAVE_OSV\n args.push_back(string2vector(\"--no-shconf\"));\n#endif\n\n std::vector cargs;\n\n for (auto&& a: args) {\n cargs.push_back(a.data());\n }\n /* initialise the EAL for all */\n int ret = rte_eal_init(cargs.size(), cargs.data());\n if (ret < 0) {\n rte_exit(EXIT_FAILURE, \"Cannot init EAL\\n\");\n }\n\n initialized = true;\n}\n\nuint32_t __attribute__((weak)) qp_mempool_obj_size(bool hugetlbfs_membackend)\n{\n return 0;\n}\n\nsize_t eal::mem_size(int num_cpus, bool hugetlbfs_membackend)\n{\n size_t memsize = 0;\n //\n // PMD mempool memory:\n //\n // We don't know what is going to be our networking configuration so we\n // assume there is going to be a queue per-CPU.\n //\n memsize += num_cpus * qp_mempool_obj_size(hugetlbfs_membackend);\n\n // Plus we'll give a DPDK 64MB for \"other stuff\".\n memsize += (64UL << 20);\n\n return memsize;\n}\n\n} // namespace dpdk\n\n}\n\n#endif // SEASTAR_HAVE_DPDK\n\n// Path: src/core/sharded.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2018 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#endif\n\nnamespace seastar {\n\nnamespace internal {\n\n\n\nfuture<>\nsharded_parallel_for_each(unsigned nr_shards, on_each_shard_func on_each_shard) noexcept(std::is_nothrow_move_constructible_v) {\n return parallel_for_each(boost::irange(0, nr_shards), std::move(on_each_shard));\n}\n\n}\n\n}\n\n// Path: src/core/uname.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*\n * Copyright (C) 2019 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#endif\n\nnamespace seastar {\n\nnamespace internal {\n\nint uname_t::component_count() const {\n if (distro_patch) {\n return 5;\n }\n if (subsublevel) {\n return 4;\n }\n if (sublevel) {\n return 3;\n }\n return 2;\n}\n\nbool uname_t::has_distro_extra(std::string extra) const {\n return distro_extra.find(extra) != std::string::npos;\n}\n\n// Can't use optional compares, C++17 only\nstatic int cmp(const std::optional& u1, const std::optional& u2) {\n return int(u1.value_or(0) - u2.value_or(0));\n}\n\nbool uname_t::same_as_or_descendant_of(const uname_t& x) const {\n if (version < x.version) {\n return false; // 4.2 vs. 
5.1\n }\n if (version == x.version && patchlevel < x.patchlevel) {\n return false; // 4.0 vs 4.1\n }\n if (!has_distro_extra(x.distro_extra)) {\n return false;\n }\n switch (x.component_count()) {\n case 5:\n return version == x.version\n && patchlevel == x.patchlevel\n && cmp(sublevel, x.sublevel) == 0\n && cmp(subsublevel, x.subsublevel) == 0\n && cmp(distro_patch, x.distro_patch) >= 0;\n case 4:\n return version == x.version\n && patchlevel == x.patchlevel\n && cmp(sublevel, x.sublevel) == 0\n && cmp(subsublevel, x.subsublevel) >= 0;\n case 3:\n return version == x.version\n && patchlevel == x.patchlevel\n && cmp(sublevel, x.sublevel) >= 0;\n case 2:\n return true;\n default:\n return false;\n }\n}\n\nuname_t parse_uname(const char* u) {\n static std::regex re(R\"XX((\\d+)\\.(\\d+)(?:\\.(\\d+)(?:\\.(\\d+))?)?(?:-(\\d*)(.+))?)XX\");\n std::cmatch m;\n if (std::regex_match(u, m, re)) {\n auto num = [] (std::csub_match sm) -> std::optional {\n if (sm.length() > 0) {\n return std::atoi(sm.str().c_str());\n } else {\n return std::nullopt;\n }\n };\n return uname_t{*num(m[1]), *num(m[2]), num(m[3]), num(m[4]), num(m[5]), m[6].str()};\n } else {\n return uname_t{0, 0, std::nullopt, std::nullopt, std::nullopt, \"\"};\n }\n}\n\n\nbool uname_t::whitelisted(std::initializer_list wl) const {\n return boost::algorithm::any_of(wl, [this] (const char* v) {\n return same_as_or_descendant_of(parse_uname(v));\n });\n}\n\nstd::ostream& operator<<(std::ostream& os, const uname_t& u) {\n os << u.version << \".\" << u.patchlevel;\n if (u.sublevel) {\n os << \".\" << *u.sublevel;\n }\n if (u.subsublevel) {\n os << \".\" << *u.subsublevel;\n }\n if (u.distro_patch || !u.distro_extra.empty()) {\n os << \"-\";\n }\n if (u.distro_patch) {\n os << *u.distro_patch;\n }\n os << u.distro_extra;\n return os;\n}\n\n\nuname_t kernel_uname() {\n struct ::utsname buf;\n ::uname(&buf);\n return parse_uname(buf.release);\n}\n\n}\n}\n\n// Path: src/core/app-template.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"program_options.hh\"\n#endif\n\nnamespace seastar {\n\nnamespace bpo = boost::program_options;\n\nusing namespace std::chrono_literals;\n\nstatic\napp_template::seastar_options\nseastar_options_from_config(app_template::config cfg) {\n app_template::seastar_options opts;\n opts.name = std::move(cfg.name);\n opts.description = std::move(cfg.description);\n opts.auto_handle_sigint_sigterm = std::move(cfg.auto_handle_sigint_sigterm);\n opts.reactor_opts.task_quota_ms.set_default_value(cfg.default_task_quota / 1ms);\n opts.reactor_opts.max_networking_io_control_blocks.set_default_value(cfg.max_networking_aio_io_control_blocks);\n opts.smp_opts.reserve_additional_memory_per_shard = cfg.reserve_additional_memory_per_shard;\n return opts;\n}\n\napp_template::seastar_options::seastar_options()\n : program_options::option_group(nullptr, \"seastar\")\n , reactor_opts(this)\n , metrics_opts(this)\n , smp_opts(this)\n , scollectd_opts(this)\n , log_opts(this)\n{\n}\n\napp_template::app_template(app_template::seastar_options opts)\n : _alien(std::make_unique())\n , _smp(std::make_shared(*_alien))\n , _opts(std::move(opts))\n , _app_opts(_opts.name + \" options\")\n , _conf_reader(get_default_configuration_reader()) {\n\n if (!alien::internal::default_instance) {\n alien::internal::default_instance = _alien.get();\n }\n _app_opts.add_options()\n (\"help,h\", \"show help message\")\n ;\n _app_opts.add_options()\n (\"help-seastar\", \"show help message about seastar options\")\n ;\n _app_opts.add_options()\n (\"help-loggers\", \"print a list of logger names and exit\")\n ;\n\n {\n program_options::options_description_building_visitor visitor;\n _opts.describe(visitor);\n _opts_conf_file.add(std::move(visitor).get_options_description());\n }\n\n _seastar_opts.add(_opts_conf_file);\n}\n\napp_template::app_template(app_template::config cfg)\n : app_template(seastar_options_from_config(std::move(cfg)))\n{\n}\n\napp_template::~app_template() = default;\n\nconst app_template::seastar_options& app_template::options() const {\n return _opts;\n}\n\napp_template::configuration_reader app_template::get_default_configuration_reader() {\n return [this] (bpo::variables_map& configuration) {\n auto home = std::getenv(\"HOME\");\n if (home) {\n std::ifstream ifs(std::string(home) + \"/.config/seastar/seastar.conf\");\n if (ifs) {\n bpo::store(bpo::parse_config_file(ifs, _opts_conf_file), configuration);\n }\n std::ifstream ifs_io(std::string(home) + \"/.config/seastar/io.conf\");\n if (ifs_io) {\n bpo::store(bpo::parse_config_file(ifs_io, _opts_conf_file), configuration);\n }\n }\n };\n}\n\nvoid app_template::set_configuration_reader(configuration_reader conf_reader) {\n _conf_reader = conf_reader;\n}\n\nboost::program_options::options_description& app_template::get_options_description() {\n return _app_opts;\n}\n\nboost::program_options::options_description& app_template::get_conf_file_options_description() {\n return _opts_conf_file;\n}\n\nboost::program_options::options_description_easy_init\napp_template::add_options() {\n return 
_app_opts.add_options();\n}\n\nvoid\napp_template::add_positional_options(std::initializer_list options) {\n for (auto&& o : options) {\n _app_opts.add(boost::make_shared(o.name, o.value_semantic, o.help));\n _pos_opts.add(o.name, o.max_count);\n }\n}\n\n\nbpo::variables_map&\napp_template::configuration() {\n return *_configuration;\n}\n\nint\napp_template::run(int ac, char ** av, std::function ()>&& func) noexcept {\n return run_deprecated(ac, av, [func = std::move(func)] () mutable {\n auto func_done = make_lw_shared>();\n engine().at_exit([func_done] { return func_done->get_future(); });\n // No need to wait for this future.\n // func's returned exit_code is communicated via engine().exit()\n (void)futurize_invoke(func).finally([func_done] {\n func_done->set_value();\n }).then([] (int exit_code) {\n return engine().exit(exit_code);\n }).or_terminate();\n });\n}\n\nint\napp_template::run(int ac, char ** av, std::function ()>&& func) noexcept {\n return run(ac, av, [func = std::move(func)] {\n return func().then([] () {\n return 0;\n });\n });\n}\n\nint\napp_template::run_deprecated(int ac, char ** av, std::function&& func) noexcept {\n#ifdef SEASTAR_DEBUG\n fmt::print(std::cerr, \"WARNING: debug mode. Not for benchmarking or production\\n\");\n#endif\n boost::program_options::options_description all_opts;\n all_opts.add(_app_opts);\n all_opts.add(_seastar_opts);\n\n bpo::variables_map configuration;\n try {\n bpo::store(bpo::command_line_parser(ac, av)\n .options(all_opts)\n .positional(_pos_opts)\n .run()\n , configuration);\n _conf_reader(configuration);\n } catch (bpo::error& e) {\n fmt::print(\"error: {}\\n\\nTry --help.\\n\", e.what());\n return 2;\n }\n if (configuration.count(\"help\")) {\n if (!_opts.description.empty()) {\n std::cout << _opts.description << \"\\n\";\n }\n std::cout << _app_opts << \"\\n\";\n return 1;\n }\n if (configuration.count(\"help-seastar\")) {\n std::cout << _seastar_opts << \"\\n\";\n return 1;\n }\n if (configuration.count(\"help-loggers\")) {\n log_cli::print_available_loggers(std::cout);\n return 1;\n }\n\n try {\n bpo::notify(configuration);\n } catch (const bpo::error& ex) {\n std::cout << ex.what() << std::endl;\n return 1;\n }\n\n {\n program_options::variables_map_extracting_visitor visitor(configuration);\n _opts.mutate(visitor);\n }\n _opts.reactor_opts._argv0 = std::string(av[0]);\n _opts.reactor_opts._auto_handle_sigint_sigterm = _opts.auto_handle_sigint_sigterm;\n if (auto* native_stack = dynamic_cast(_opts.reactor_opts.network_stack.get_selected_candidate_opts())) {\n native_stack->_hugepages = _opts.smp_opts.hugepages;\n }\n\n // Needs to be before `smp::configure()`.\n try {\n apply_logging_settings(log_cli::extract_settings(_opts.log_opts));\n } catch (const std::runtime_error& exn) {\n std::cout << \"logging configuration error: \" << exn.what() << '\\n';\n return 1;\n }\n\n try {\n _smp->configure(_opts.smp_opts, _opts.reactor_opts);\n } catch (...) 
{\n std::cerr << \"Could not initialize seastar: \" << std::current_exception() << std::endl;\n return 1;\n }\n _configuration = {std::move(configuration)};\n // No need to wait for this future.\n // func is waited on via engine().run()\n (void)engine().when_started().then([this] {\n return seastar::metrics::configure(_opts.metrics_opts).then([this] {\n // set scollectd use the metrics configuration, so the later\n // need to be set first\n scollectd::configure( _opts.scollectd_opts);\n });\n }).then(\n std::move(func)\n ).then_wrapped([] (auto&& f) {\n try {\n f.get();\n } catch (std::exception& ex) {\n std::cout << \"program failed with uncaught exception: \" << ex.what() << \"\\n\";\n engine().exit(1);\n }\n });\n auto exit_code = engine().run();\n _smp->cleanup();\n return exit_code;\n}\n\n}\n\n// Path: src/core/program_options.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2021 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \"core/program_options.hh\"\n\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar::program_options {\n\nnamespace {\n\nconst char* to_string(memory::alloc_failure_kind val) {\n switch (val) {\n case memory::alloc_failure_kind::none: return \"none\";\n case memory::alloc_failure_kind::critical: return \"critical\";\n case memory::alloc_failure_kind::all: return \"all\";\n }\n std::abort();\n}\n\nconst char* to_string(log_level val) {\n switch (val) {\n case log_level::error: return \"error\";\n case log_level::warn: return \"warn\";\n case log_level::info: return \"info\";\n case log_level::debug: return \"debug\";\n case log_level::trace: return \"trace\";\n }\n std::abort();\n}\n\nconst char* to_string(logger_timestamp_style val) {\n switch (val) {\n case logger_timestamp_style::none: return \"none\";\n case logger_timestamp_style::boot: return \"boot\";\n case logger_timestamp_style::real: return \"real\";\n }\n std::abort();\n}\n\nconst char* to_string(logger_ostream_type val) {\n switch (val) {\n case logger_ostream_type::none: return \"none\";\n case logger_ostream_type::cout: return \"stdout\";\n case logger_ostream_type::cerr: return \"stderr\";\n }\n std::abort();\n}\n\nmemory::alloc_failure_kind from_string(const std::string& val, boost::type) {\n if (val == \"none\") {\n return memory::alloc_failure_kind::none;\n } else if (val == \"critical\") {\n return memory::alloc_failure_kind::critical;\n } else if (val == \"all\") {\n return memory::alloc_failure_kind::all;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum memory::alloc_failure_kind: {}\", val));\n}\n\nlog_level from_string(const std::string& val, boost::type) {\n if (val == \"error\") {\n return log_level::error;\n } 
else if (val == \"warn\") {\n return log_level::warn;\n } else if (val == \"info\") {\n return log_level::info;\n } else if (val == \"debug\") {\n return log_level::debug;\n } else if (val == \"trace\") {\n return log_level::trace;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum log_level: {}\", val));\n}\n\nlogger_timestamp_style from_string(const std::string& val, boost::type) {\n if (val == \"none\") {\n return logger_timestamp_style::none;\n } else if (val == \"boot\") {\n return logger_timestamp_style::boot;\n } else if (val == \"real\") {\n return logger_timestamp_style::real;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum logger_timestamp_style: {}\", val));\n}\n\nlogger_ostream_type from_string(const std::string& val, boost::type) {\n if (val == \"none\") {\n return logger_ostream_type::none;\n } else if (val == \"stdout\") {\n return logger_ostream_type::cout;\n } else if (val == \"stderr\") {\n return logger_ostream_type::cerr;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum logger_ostream_type: {}\", val));\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const std::string& name, const std::string& description, const Type& default_value) {\n opts.add_options()(name.c_str(), boost::program_options::value()->default_value(default_value), description.c_str());\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Type& default_value) {\n describe_value(opts, d.name, d.description, default_value);\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const std::string& name, const std::string& description) {\n opts.add_options()(name.c_str(), boost::program_options::value(), description.c_str());\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d) {\n describe_value(opts, d.name, d.description);\n}\n\ntemplate \nvoid describe_value_maybe_default(bpo::options_description& opts, const std::string& name, const std::string& description, const Type* default_value) {\n if (default_value) {\n describe_value(opts, name, description, *default_value);\n } else {\n describe_value(opts, name, description);\n }\n}\n\ntemplate \nvoid describe_value_maybe_default(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Type* default_value) {\n describe_value_maybe_default(opts, d.name, d.description, default_value);\n}\n\ntemplate \nvoid describe_enum_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Enum* default_value) {\n if (default_value) {\n opts.add_options()(d.name.c_str(), boost::program_options::value()->default_value(to_string(*default_value)), d.description.c_str());\n } else {\n opts.add_options()(d.name.c_str(), boost::program_options::value(), d.description.c_str());\n }\n}\n\ntemplate \nbool extract_value(const bpo::variables_map& values, const std::string& current_name, T& val) {\n auto it = values.find(current_name);\n if (it == values.end() || it->second.defaulted()) {\n return false;\n }\n val = it->second.as();\n return true;\n}\n\ntemplate \nbool extract_enum_value(const bpo::variables_map& values, const std::string& current_name, T& val) {\n auto it = values.find(current_name);\n if (it == values.end() || it->second.defaulted()) {\n return false;\n }\n val = from_string(it->second.as(), boost::type{});\n return 
true;\n}\n\n} // anonymous namespace\n\nbool options_description_building_visitor::visit_group_start(const std::string& name, bool used) {\n _groups.push({name, bpo::options_description(name.c_str()), used});\n return used;\n}\nvoid options_description_building_visitor::visit_group_end() {\n if (_groups.size() == 1) {\n return;\n }\n auto grp = std::move(_groups.top());\n _groups.pop();\n if (grp.used && grp.values) {\n _groups.top().description.add(std::move(grp.description));\n }\n}\n\nbool options_description_building_visitor::visit_value_metadata(const std::string& name, const std::string& description, bool used) {\n if (!used) {\n return false;\n }\n ++_groups.top().values;\n _current_metadata.emplace(value_metadata{name, description});\n return true;\n}\n\nvoid options_description_building_visitor::visit_value() {\n _groups.top().description.add_options()(_current_metadata->name.c_str(), _current_metadata->description.c_str());\n}\n\nvoid options_description_building_visitor::visit_value(const bool* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const int* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const unsigned* default_value) {\n auto name = _current_metadata->name;\n if (_current_metadata->name == \"smp\") {\n name = \"smp,c\";\n }\n describe_value_maybe_default(_groups.top().description, name, _current_metadata->description, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const float* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const double* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const std::string* default_value) {\n auto name = _current_metadata->name;\n if (_current_metadata->name == \"memory\") {\n name = \"memory,m\";\n }\n describe_value_maybe_default(_groups.top().description, name, _current_metadata->description, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const std::set*) {\n describe_value(_groups.top().description, *_current_metadata);\n}\n\nvoid options_description_building_visitor::visit_value(const memory::alloc_failure_kind* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const log_level* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const logger_timestamp_style* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const logger_ostream_type* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const std::unordered_map*) {\n describe_value>(_groups.top().description, *_current_metadata);\n}\n\nvoid options_description_building_visitor::visit_selection_value(const std::vector& candidates, const std::size_t* selected_candidate) {\n if (selected_candidate) {\n 
describe_value(_groups.top().description, *_current_metadata, candidates.at(*selected_candidate));\n } else {\n describe_value(_groups.top().description, *_current_metadata);\n }\n}\n\nvariables_map_extracting_visitor::variables_map_extracting_visitor(const bpo::variables_map& values) : _values(values) {\n}\n\nbool variables_map_extracting_visitor::visit_group_start(const std::string& name, bool used) {\n return used;\n}\n\nvoid variables_map_extracting_visitor::visit_group_end() {\n}\n\nbool variables_map_extracting_visitor::visit_value_metadata(const std::string& name, bool used) {\n if (used) {\n _current_name = &name;\n return true;\n } else {\n _current_name = nullptr;\n return false;\n }\n}\n\nbool variables_map_extracting_visitor::visit_value() {\n return _values.count(*_current_name);\n}\n\nbool variables_map_extracting_visitor::visit_value(bool& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(int& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(unsigned& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(float& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(double& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(std::string& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(std::set& val) {\n std::string raw_val;\n if (!extract_value(_values, *_current_name, raw_val)) {\n return false;\n }\n if (auto parsed_cpu_set = resource::parse_cpuset(raw_val)) {\n val = std::move(*parsed_cpu_set);\n return true;\n }\n throw std::invalid_argument(fmt::format(\"invalid value for option {}: failed to parse cpuset: {}\", *_current_name, raw_val));\n}\n\nbool variables_map_extracting_visitor::visit_value(log_level& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(logger_timestamp_style& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(logger_ostream_type& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(memory::alloc_failure_kind& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(std::unordered_map& val) {\n std::vector raw_val;\n if (!extract_value(_values, *_current_name, raw_val)) {\n return false;\n }\n for (const auto& e : raw_val) {\n log_cli::parse_map_associations(e, [&val] (std::string k, std::string v) { val[std::move(k)] = log_cli::parse_log_level(v); });\n }\n return !val.empty();\n}\n\nbool variables_map_extracting_visitor::visit_selection_value(const std::vector& candidates, std::size_t& selected_candidate) {\n std::string candidate_name;\n if (!extract_value(_values, *_current_name, candidate_name)) {\n return false;\n }\n auto it = std::find(candidates.begin(), candidates.end(), candidate_name);\n if (it == candidates.end()) {\n throw std::invalid_argument(fmt::format(\"invalid value for option {}: selected candidate doesn't exist: {}\", *_current_name, candidate_name));\n }\n selected_candidate = it - candidates.begin();\n return true;\n}\n\n} // namespace 
seastar::program_options\n\n// Path: src/core/future-util.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2017 ScyllaDB\n */\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nparallel_for_each_state::parallel_for_each_state(size_t n) {\n _incomplete.reserve(n);\n}\n\nfuture<> parallel_for_each_state::get_future() {\n auto ret = _result.get_future();\n wait_for_one();\n return ret;\n}\n\nvoid parallel_for_each_state::add_future(future<>&& f) {\n _incomplete.push_back(std::move(f));\n}\n\nvoid parallel_for_each_state::wait_for_one() noexcept {\n // Process from back to front, on the assumption that the front\n // futures are likely to complete earlier than the back futures.\n // If that's indeed the case, then the front futures will be\n // available and we won't have to wait for them.\n\n // Skip over futures that happen to be complete already.\n while (!_incomplete.empty() && _incomplete.back().available()) {\n if (_incomplete.back().failed()) {\n _ex = _incomplete.back().get_exception();\n }\n _incomplete.pop_back();\n }\n\n // If there's an incompelete future, wait for it.\n if (!_incomplete.empty()) {\n internal::set_callback(std::move(_incomplete.back()), static_cast*>(this));\n // This future's state will be collected in run_and_dispose(), so we can drop it.\n _incomplete.pop_back();\n return;\n }\n\n // Everything completed, report a result.\n if (__builtin_expect(bool(_ex), false)) {\n _result.set_exception(std::move(_ex));\n } else {\n _result.set_value();\n }\n delete this;\n}\n\nvoid parallel_for_each_state::run_and_dispose() noexcept {\n if (_state.failed()) {\n _ex = std::move(_state).get_exception();\n }\n _state = {};\n wait_for_one();\n}\n\ntemplate \nfuture<> sleep_abortable(typename Clock::duration dur) {\n return engine().wait_for_stop(dur).then([] {\n throw sleep_aborted();\n }).handle_exception([] (std::exception_ptr ep) {\n try {\n std::rethrow_exception(ep);\n } catch(condition_variable_timed_out&) {};\n });\n}\n\ntemplate future<> sleep_abortable(typename steady_clock_type::duration);\ntemplate future<> sleep_abortable(typename lowres_clock::duration);\n\ntemplate \nfuture<> sleep_abortable(typename Clock::duration dur, abort_source& as) {\n struct sleeper {\n promise<> done;\n timer tmr;\n abort_source::subscription sc;\n\n sleeper(typename Clock::duration dur, abort_source& as)\n : tmr([this] { done.set_value(); }) {\n auto sc_opt = as.subscribe([this] (const std::optional& opt_ex) noexcept {\n if (tmr.cancel()) {\n done.set_exception(opt_ex.value_or(std::make_exception_ptr(sleep_aborted())));\n }\n });\n if (sc_opt) {\n sc = std::move(*sc_opt);\n tmr.arm(dur);\n } else {\n 
done.set_exception(sleep_aborted());\n }\n }\n };\n //FIXME: Use do_with() after #373\n auto s = std::make_unique(dur, as);\n auto fut = s->done.get_future();\n return fut.finally([s = std::move(s)] { });\n}\n\ntemplate future<> sleep_abortable(typename steady_clock_type::duration, abort_source&);\ntemplate future<> sleep_abortable(typename lowres_clock::duration, abort_source&);\ntemplate future<> sleep_abortable(typename manual_clock::duration, abort_source&);\n\n}\n\n// Path: src/core/future.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2020 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\n// We can't test future_state_base directly because its private\n// destructor is protected.\nstatic_assert(std::is_nothrow_move_constructible_v>>,\n \"future_state's move constructor must not throw\");\n\nstatic_assert(sizeof(future_state>) <= 8, \"future_state> is too large\");\nstatic_assert(sizeof(future_state>) <= 16, \"future_state> is too large\");\nstatic_assert(future_state>::has_trivial_move_and_destroy, \"future_state> not trivial\");\nstatic_assert(future_state::has_trivial_move_and_destroy, \"future_state not trivial\");\n\n// We need to be able to move and copy std::exception_ptr in and out\n// of future/promise/continuations without that producing a new\n// exception.\nstatic_assert(std::is_nothrow_copy_constructible_v,\n \"std::exception_ptr's copy constructor must not throw\");\nstatic_assert(std::is_nothrow_move_constructible_v,\n \"std::exception_ptr's move constructor must not throw\");\n\nnamespace internal {\n\nstatic_assert(std::is_empty_v>>, \"This should still be empty\");\n\nvoid promise_base::move_it(promise_base&& x) noexcept {\n // Don't use std::exchange to make sure x's values are nulled even\n // if &x == this.\n _task = x._task;\n x._task = nullptr;\n#ifdef SEASTAR_DEBUG_PROMISE\n _task_shard = x._task_shard;\n#endif\n _state = x._state;\n x._state = nullptr;\n _future = x._future;\n if (auto* fut = _future) {\n fut->detach_promise();\n fut->_promise = this;\n }\n}\n\nstatic void set_to_broken_promise(future_state_base& state) noexcept {\n try {\n // Constructing broken_promise may throw (std::logic_error ctor is not noexcept).\n state.set_exception(std::make_exception_ptr(broken_promise{}));\n } catch (...) 
{\n state.set_exception(std::current_exception());\n }\n}\n\npromise_base::promise_base(promise_base&& x) noexcept {\n move_it(std::move(x));\n}\n\nvoid promise_base::clear() noexcept {\n if (__builtin_expect(bool(_task), false)) {\n assert(_state && !_state->available());\n set_to_broken_promise(*_state);\n ::seastar::schedule(std::exchange(_task, nullptr));\n }\n if (_future) {\n assert(_state);\n if (!_state->available()) {\n set_to_broken_promise(*_state);\n }\n _future->detach_promise();\n }\n}\n\npromise_base& promise_base::operator=(promise_base&& x) noexcept {\n clear();\n move_it(std::move(x));\n return *this;\n}\n\nvoid promise_base::set_to_current_exception() noexcept {\n set_exception(std::current_exception());\n}\n\n#ifdef SEASTAR_DEBUG_PROMISE\n\nvoid promise_base::assert_task_shard() const noexcept {\n if (_task_shard >= 0 && static_cast(_task_shard) != this_shard_id()) {\n on_fatal_internal_error(seastar_logger, format(\"Promise task was set on shard {} but made ready on shard {}\", _task_shard, this_shard_id()));\n }\n}\n\n#endif\n\ntemplate \nvoid promise_base::make_ready() noexcept {\n if (_task) {\n assert_task_shard();\n if (Urgent == urgent::yes) {\n ::seastar::schedule_urgent(std::exchange(_task, nullptr));\n } else {\n ::seastar::schedule(std::exchange(_task, nullptr));\n }\n }\n}\n\ntemplate void promise_base::make_ready() noexcept;\ntemplate void promise_base::make_ready() noexcept;\n}\n\ntemplate\nfuture current_exception_as_future() noexcept;\n\n/**\n * engine_exit() exits the reactor. It should be given a pointer to the\n * exception which prompted this exit - or a null pointer if the exit\n * request was not caused by any exception.\n */\nvoid engine_exit(std::exception_ptr eptr) {\n if (!eptr) {\n engine().exit(0);\n return;\n }\n report_exception(\"Exiting on unhandled exception\", eptr);\n engine().exit(1);\n}\n\nbroken_promise::broken_promise() : logic_error(\"broken promise\") { }\n\nfuture_state_base::future_state_base(current_exception_future_marker) noexcept\n : future_state_base(std::current_exception()) { }\n\nvoid future_state_base::ignore() noexcept {\n switch (_u.st) {\n case state::invalid:\n case state::future:\n case state::result_unavailable:\n assert(0 && \"invalid state for ignore\");\n case state::result:\n _u.st = state::result_unavailable;\n break;\n default:\n // Ignore the exception\n _u.take_exception();\n }\n}\n\nnested_exception::nested_exception(std::exception_ptr inner, std::exception_ptr outer) noexcept\n : inner(std::move(inner)), outer(std::move(outer)) {}\n\nnested_exception::nested_exception(nested_exception&&) noexcept = default;\n\nnested_exception::nested_exception(const nested_exception&) noexcept = default;\n\nconst char* nested_exception::what() const noexcept {\n return \"seastar::nested_exception\";\n}\n\n[[noreturn]] void nested_exception::rethrow_nested() const {\n std::rethrow_exception(outer);\n}\n\nstatic std::exception_ptr make_nested(std::exception_ptr&& inner, future_state_base&& old) noexcept {\n std::exception_ptr outer = std::move(old).get_exception();\n nested_exception nested{std::move(inner), std::move(outer)};\n return std::make_exception_ptr(std::move(nested));\n}\n\nfuture_state_base::future_state_base(nested_exception_marker, future_state_base&& n, future_state_base&& old) noexcept {\n std::exception_ptr inner = std::move(n).get_exception();\n if (!old.failed()) {\n new (this) future_state_base(std::move(inner));\n } else {\n new (this) future_state_base(make_nested(std::move(inner), 
std::move(old)));\n }\n}\n\nfuture_state_base::future_state_base(nested_exception_marker, future_state_base&& old) noexcept {\n if (!old.failed()) {\n new (this) future_state_base(current_exception_future_marker());\n return;\n } else {\n new (this) future_state_base(make_nested(std::current_exception(), std::move(old)));\n }\n}\n\nvoid future_state_base::rethrow_exception() && {\n // Move ex out so future::~future() knows we've handled it\n std::rethrow_exception(std::move(*this).get_exception());\n}\n\nvoid future_state_base::rethrow_exception() const& {\n std::rethrow_exception(_u.ex);\n}\n\nvoid report_failed_future(const std::exception_ptr& eptr) noexcept {\n ++engine()._abandoned_failed_futures;\n seastar_logger.warn(\"Exceptional future ignored: {}, backtrace: {}\", eptr, current_backtrace());\n}\n\nvoid report_failed_future(const future_state_base& state) noexcept {\n report_failed_future(state._u.ex);\n}\n\nvoid report_failed_future(future_state_base::any&& state) noexcept {\n report_failed_future(std::move(state).take_exception());\n}\n\nvoid reactor::test::with_allow_abandoned_failed_futures(unsigned count, noncopyable_function func) {\n auto before = engine()._abandoned_failed_futures;\n auto old_level = seastar_logger.level();\n seastar_logger.set_level(log_level::error);\n func();\n auto after = engine()._abandoned_failed_futures;\n assert(after - before == count);\n engine()._abandoned_failed_futures = before;\n seastar_logger.set_level(old_level);\n}\n\nnamespace {\nclass thread_wake_task final : public task {\n thread_context* _thread;\npublic:\n thread_wake_task(thread_context* thread) noexcept : _thread(thread) {}\n virtual void run_and_dispose() noexcept override {\n thread_impl::switch_in(_thread);\n // no need to delete, since this is always allocated on\n // _thread's stack.\n }\n /// Returns the task which is waiting for this thread to be done, or nullptr.\n virtual task* waiting_task() noexcept override {\n return _thread->waiting_task();\n }\n};\n}\n\nvoid internal::future_base::do_wait() noexcept {\n auto thread = thread_impl::get();\n assert(thread);\n thread_wake_task wake_task{thread};\n wake_task.make_backtrace();\n _promise->set_task(&wake_task);\n thread_impl::switch_out(thread);\n}\n\n#ifdef SEASTAR_COROUTINES_ENABLED\nvoid internal::future_base::set_coroutine(task& coroutine) noexcept {\n assert(_promise);\n _promise->set_task(&coroutine);\n}\n#endif\n\n}\n\n// Path: src/core/resource.cc\n\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n...\n", "needle_token_start": 2313, "needle_token_end": 2410, "code_context_ntokens": 16017, "output": ["```cpp\nfuture<>\nsharded_parallel_for_each(unsigned nr_shards, on_each_shard_func on_each_shard) noexcept(std::is_nothrow_move_constructible_v) {\n return parallel_for_each(boost::irange(0, nr_shards), std::move(on_each_shard));\n}\n```"]} +{"repo": "scylladb/seastar", "name": "enable_abort_on_allocation_failure", "language": "cpp", "path": "src/core/memory.cc", "position_ratio": 0.25, "description": "\nFunction Description:\n1. 
**Purpose**: To configure the system to terminate the process when memory allocation fails.\n2. **Input**: None.\n3. **Output**: None.\n4. **Procedure**: This function calls another function that sets a flag to true, indicating that the system should abort when it cannot allocate memory.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/core/file.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2019 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#ifdef SEASTAR_COROUTINES_ENABLED\n#include \n#endif\n#include \n#include \n#include \n#include \n#include \n#include \n\n#define __user /* empty */ // for xfs includes, below\n\n#include \n#include \n#include // for xfs, below\n#include // BLKBSZGET\n#include \n#include \n#include \n#include \n#include \n/*\n * With package xfsprogs-devel >= 5.14.1, `fallthrough` has defined to\n * fix compilation warning in header ,\n * (see: https://git.kernel.org/pub/scm/fs/xfs/xfsprogs-dev.git/commit/?id=df9c7d8d8f3ed0785ed83e7fd0c7ddc92cbfbe15)\n * There is a confliction with c++ keyword `fallthrough`, so undefine fallthrough here.\n */\n#undef fallthrough \n#define min min /* prevent xfs.h from defining min() as a macro */\n#include \n#undef min\n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"core/file-impl.hh\"\n#include \"core/syscall_result.hh\"\n#include \"core/thread_pool.hh\"\n#endif\n\nnamespace seastar {\n\n#if SEASTAR_API_LEVEL < 7\nstatic_assert(std::is_nothrow_copy_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n#endif\n\nnamespace internal {\n\nstruct fs_info {\n uint32_t block_size;\n bool append_challenged;\n unsigned append_concurrency;\n bool fsync_is_exclusive;\n bool nowait_works;\n std::optional dioinfo;\n};\n\n};\n\nusing namespace internal::linux_abi;\n\nfile_handle::file_handle(const file_handle& x)\n : _impl(x._impl ? 
x._impl->clone() : std::unique_ptr()) {\n}\n\nfile_handle::file_handle(file_handle&& x) noexcept = default;\n\nfile_handle&\nfile_handle::operator=(const file_handle& x) {\n return operator=(file_handle(x));\n}\n\nfile_handle&\nfile_handle::operator=(file_handle&&) noexcept = default;\n\nfile\nfile_handle::to_file() const & {\n return file_handle(*this).to_file();\n}\n\nfile\nfile_handle::to_file() && {\n return file(std::move(*_impl).to_file());\n}\n\nposix_file_impl::posix_file_impl(int fd, open_flags f, file_open_options options, dev_t device_id, bool nowait_works)\n : _device_id(device_id)\n , _nowait_works(nowait_works)\n , _io_queue(engine().get_io_queue(_device_id))\n , _open_flags(f)\n , _fd(fd)\n{\n configure_io_lengths();\n}\n\nposix_file_impl::posix_file_impl(int fd, open_flags f, file_open_options options, dev_t device_id, const internal::fs_info& fsi)\n : posix_file_impl(fd, f, options, device_id, fsi.nowait_works)\n{\n configure_dma_alignment(fsi);\n}\n\nposix_file_impl::~posix_file_impl() {\n if (_refcount && _refcount->fetch_add(-1, std::memory_order_relaxed) != 1) {\n return;\n }\n delete _refcount;\n if (_fd != -1) {\n // Note: close() can be a blocking operation on NFS\n ::close(_fd);\n }\n}\n\nvoid\nposix_file_impl::configure_dma_alignment(const internal::fs_info& fsi) {\n if (fsi.dioinfo) {\n const dioattr& da = *fsi.dioinfo;\n _memory_dma_alignment = da.d_mem;\n _disk_read_dma_alignment = da.d_miniosz;\n // xfs wants at least the block size for writes\n // FIXME: really read the block size\n _disk_write_dma_alignment = std::max(da.d_miniosz, fsi.block_size);\n static bool xfs_with_relaxed_overwrite_alignment = internal::kernel_uname().whitelisted({\"5.12\"});\n _disk_overwrite_dma_alignment = xfs_with_relaxed_overwrite_alignment ? da.d_miniosz : _disk_write_dma_alignment;\n }\n}\n\nvoid posix_file_impl::configure_io_lengths() noexcept {\n auto limits = _io_queue.get_request_limits();\n _read_max_length = std::min(_read_max_length, limits.max_read);\n _write_max_length = std::min(_write_max_length, limits.max_write);\n}\n\nstd::unique_ptr\nposix_file_impl::dup() {\n if (!_refcount) {\n _refcount = new std::atomic(1u);\n }\n auto ret = std::make_unique(_fd, _open_flags, _refcount, _device_id,\n _memory_dma_alignment, _disk_read_dma_alignment, _disk_write_dma_alignment, _disk_overwrite_dma_alignment,\n _nowait_works);\n _refcount->fetch_add(1, std::memory_order_relaxed);\n return ret;\n}\n\nposix_file_impl::posix_file_impl(int fd, open_flags f, std::atomic* refcount, dev_t device_id,\n uint32_t memory_dma_alignment,\n uint32_t disk_read_dma_alignment,\n uint32_t disk_write_dma_alignment,\n uint32_t disk_overwrite_dma_alignment,\n bool nowait_works)\n : _refcount(refcount)\n , _device_id(device_id)\n , _nowait_works(nowait_works)\n , _io_queue(engine().get_io_queue(_device_id))\n , _open_flags(f)\n , _fd(fd) {\n...\n// Path: src/core/memory.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n\n/// \\cond internal\n\n//\n// Seastar memory allocator\n//\n// This is a share-nothing allocator (memory allocated on one cpu must\n// be freed on the same cpu).\n//\n// Inspired by gperftools' tcmalloc.\n//\n// Memory map:\n//\n// 0x0000'sccc'vvvv'vvvv\n//\n// 0000 - required by architecture (only 48 bits of address space)\n// s - chosen to satisfy system allocator (1-7)\n// ccc - cpu number (0-12 bits allocated vary according to system)\n// v - virtual address within cpu (32-44 bits, according to how much ccc\n// leaves us\n//\n// Each page has a page structure that describes it. Within a cpu's\n// memory pool, the page array starts at offset 0, describing all pages\n// within that pool. Page 0 does not describe a valid page.\n//\n// Each pool can contain at most 2^32 pages (or 44 address bits), so we can\n// use a 32-bit integer to identify a page.\n//\n// Runs of pages are organized into spans. Free spans are organized into lists,\n// by size. When spans are broken up or coalesced, they may move into new lists.\n// Spans have a size that is a power-of-two and are naturally aligned (aka buddy\n// allocator)\n//\n// If compiled with SEASTAR_HEAPPROF seastar features a sampling memory\n// profiler. Allocations are sampled at random (see `sampler` for the sampling\n// logic) and tracked. The sampled live set can be retrieved with\n// `sampled_memory_profile()`. Sampled allocations carry an extra\n// allocation_site pointer with them which is used on free to remove them from\n// the sampled live set.\n//\n// Large allocations are tracked via a pointer to the allocation_site which is\n// stored on the page structure. To check whether an allocation was sampled or\n// not this pointer is being looked at on free.\n//\n// Small allocations store an extra 8 bytes at the end of their allocation.\n// Sampled allocations are allocated in a separate set of small pools. 
Hence, to\n// check whether an allocation was sampled or not one only has to look at the\n// tag in pool.\n//\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \n#include \n\n#include \n\n#include \n\n#ifndef SEASTAR_DEFAULT_ALLOCATOR\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_HAVE_NUMA\n#include \n#endif\n#endif // !defined(SEASTAR_DEFAULT_ALLOCATOR)\n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#ifndef SEASTAR_DEFAULT_ALLOCATOR\n#include \n#include \n#include \n#include \n#include \n#endif\n#endif\n\n#ifdef SEASTAR_DEBUG\n#define dassert(expr) assert(expr)\n#else\n#define dassert(expr) do {} while(false)\n#endif\n\nnamespace seastar {\n\nextern seastar::logger seastar_logger;\n\nvoid* internal::allocate_aligned_buffer_impl(size_t size, size_t align) {\n void *ret;\n auto r = posix_memalign(&ret, align, size);\n if (r == ENOMEM) {\n throw std::bad_alloc();\n } else if (r == EINVAL) {\n throw std::runtime_error(format(\"Invalid alignment of {:d}; allocating {:d} bytes\", align, size));\n } else {\n assert(r == 0);\n return ret;\n }\n}\n\nnamespace memory {\n\n// We always create the logger object for memory disagnostics, even in\n// in SEASTAR_DEFAULT_ALLOCATOR builds, though it only logs when the\n// seastar allocator is enabled.\nseastar::logger seastar_memory_logger(\"seastar_memory\");\n\nstatic thread_local int abort_on_alloc_failure_suppressed = 0;\n\ndisable_abort_on_alloc_failure_temporarily::disable_abort_on_alloc_failure_temporarily() {\n ++abort_on_alloc_failure_suppressed;\n}\n\ndisable_abort_on_alloc_failure_temporarily::~disable_abort_on_alloc_failure_temporarily() noexcept {\n --abort_on_alloc_failure_suppressed;\n}\n\n\nvoid enable_abort_on_allocation_failure() {\n set_abort_on_allocation_failure(true);\n}\n\nstatic std::pmr::polymorphic_allocator static_malloc_allocator{std::pmr::get_default_resource()};;\nstd::pmr::polymorphic_allocator* malloc_allocator{&static_malloc_allocator};\n\nnamespace internal {\n\n#ifdef __cpp_constinit\n#define SEASTAR_CONSTINIT constinit\n#else\n#define SEASTAR_CONSTINIT\n#endif\n\n#ifdef SEASTAR_ENABLE_ALLOC_FAILURE_INJECTION\n\n#ifdef __cpp_constinit\nthread_local constinit volatile int critical_alloc_section = 0;\n#else\n__thread volatile int critical_alloc_section = 0;\n#endif\n\n#endif // SEASTAR_ENABLE_ALLOC_FAILURE_INJECTION\n\nnuma_layout\nmerge(numa_layout one, numa_layout two) {\n // There's no chance to merge, so just concatenate\n one.ranges.insert(one.ranges.end(), two.ranges.begin(), two.ranges.end());\n return one;\n}\n\n} // namespace internal\n\n}\n\n}\n\n#ifndef SEASTAR_DEFAULT_ALLOCATOR\n\n#if FMT_VERSION >= 90000\nnamespace seastar::memory {\nstruct human_readable_value;\n}\ntemplate <> struct fmt::formatter : fmt::ostream_formatter {};\n#endif\n\nnamespace seastar {\n\nusing allocation_site_ptr = const memory::allocation_site*;\n\nnamespace memory {\n\n[[gnu::unused]]\nstatic allocation_site_ptr get_allocation_site();\n\n[[gnu::noinline]]\nstatic void on_allocation_failure(size_t size);\n\nstatic constexpr unsigned cpu_id_shift = 36; // FIXME: make dynamic\nstatic constexpr unsigned max_cpus = 256;\nstatic constexpr uintptr_t cpu_id_and_mem_base_mask = ~((uintptr_t(1) << cpu_id_shift) - 1);\n\nusing pageidx = uint32_t;\n\nstruct 
page;\nclass page_list;\n\nstatic std::atomic live_cpus[max_cpus];\n\nusing std::optional;\n\n// is_reactor_thread gets set to true when memory::configure() gets called\n// it is used to identify seastar threads and hence use system memory allocator\n// for those threads\nstatic thread_local bool is_reactor_thread = false;\n\n// We default transparent hugepages to true since we prefer to transiently\n// use a transparent hugepage and then break it, to having the kernel\n// work to rearrange a broken transparent hugepage.\nstd::atomic use_transparent_hugepages = true;\n\nnamespace alloc_stats {\n\nenum class types { allocs, frees, cross_cpu_frees, reclaims, large_allocs, failed_allocs,\n foreign_mallocs, foreign_frees, foreign_cross_frees, enum_size };\n\nusing stats_array = std::array(types::enum_size)>;\nusing stats_atomic_array = std::array(types::enum_size)>;\n\nstatic thread_local SEASTAR_CONSTINIT stats_array stats{};\nstd::array alien_stats{};\n\nstatic void increment_local(types stat_type, uint64_t size = 1) {\n stats[static_cast(stat_type)] += size;\n}\n\nstatic void increment(types stat_type, uint64_t size=1)\n{\n // fast path, reactor threads takes thread local statistics\n if (is_reactor_thread) {\n increment_local(stat_type, size);\n } else {\n auto hash = std::hash()(std::this_thread::get_id());\n auto i = static_cast(stat_type);\n alien_stats[hash % alien_stats.size()][i].fetch_add(size, std::memory_order_relaxed);\n }\n}\n\nstatic uint64_t get(types stat_type)\n{\n auto i = static_cast(stat_type);\n // fast path, reactor threads takes thread local statistics\n if (is_reactor_thread) {\n return stats[i];\n } else {\n auto hash = std::hash()(std::this_thread::get_id());\n return alien_stats[hash % alien_stats.size()][i].load();\n }\n}\n\n}\n\n// original memory allocator support\n// note: allocations before calling the constructor would use seastar allocator\nusing malloc_func_type = void * (*)(size_t);\nusing free_func_type = void * (*)(void *);\nusing realloc_func_type = void * (*)(void *, size_t);\nusing aligned_alloc_type = void * (*)(size_t alignment, size_t size);\nusing malloc_trim_type = int (*)(size_t);\nusing malloc_usable_size_type = size_t (*)(void *);\n\nmalloc_func_type original_malloc_func = reinterpret_cast(dlsym(RTLD_NEXT, \"malloc\"));\nfree_func_type original_free_func = reinterpret_cast(dlsym(RTLD_NEXT, \"free\"));\nrealloc_func_type original_realloc_func = reinterpret_cast(dlsym(RTLD_NEXT, \"realloc\"));\naligned_alloc_type original_aligned_alloc_func = reinterpret_cast(dlsym(RTLD_NEXT, \"aligned_alloc\"));\nmalloc_trim_type original_malloc_trim_func = reinterpret_cast(dlsym(RTLD_NEXT, \"malloc_trim\"));\nmalloc_usable_size_type original_malloc_usable_size_func = reinterpret_cast(dlsym(RTLD_NEXT, \"malloc_usable_size\"));\n\nusing allocate_system_memory_fn\n = std::function;\n\nnamespace bi = boost::intrusive;\n\nstatic thread_local uintptr_t local_expected_cpu_id = std::numeric_limits::max();\n\ninline\nunsigned object_cpu_id(const void* ptr) {\n return (reinterpret_cast(ptr) >> cpu_id_shift) & 0xff;\n}\n\nclass page_list_link {\n uint32_t _prev;\n uint32_t _next;\n friend class page_list;\n friend seastar::internal::log_buf::inserter_iterator do_dump_memory_diagnostics(seastar::internal::log_buf::inserter_iterator);\n};\n\nconstexpr size_t mem_base_alloc = size_t(1) << 44;\n\nstatic char* mem_base() {\n static char* known;\n static std::once_flag flag;\n std::call_once(flag, [] {\n auto r = ::mmap(NULL, 2 * mem_base_alloc,\n PROT_NONE,\n MAP_PRIVATE | 
MAP_ANONYMOUS | MAP_NORESERVE,\n -1, 0);\n if (r == MAP_FAILED) {\n abort();\n }\n ::madvise(r, 2 * mem_base_alloc, MADV_DONTDUMP);\n auto cr = reinterpret_cast(r);\n known = align_up(cr, mem_base_alloc);\n ::munmap(cr, known - cr);\n ::munmap(known + mem_base_alloc, cr + 2 * mem_base_alloc - (known + mem_base_alloc));\n // extremely unlikely for mmap to return a mapping at 0, but our detection of free(null)\n // depends on it not doing that so check it\n assert(known != nullptr);\n assert(reinterpret_cast(known) != 0);\n });\n return known;\n}\n\nbool is_seastar_memory(void * ptr)\n{\n auto begin = mem_base();\n auto end = begin + mem_base_alloc;\n return ptr >= begin && ptr < end;\n}\n\nconstexpr bool is_page_aligned(size_t size) {\n return (size & (page_size - 1)) == 0;\n}\n\nconstexpr size_t next_page_aligned(size_t size) {\n return (size + (page_size - 1)) & ~(page_size - 1);\n}\n\nclass small_pool;\n\nstruct free_object {\n free_object* next;\n};\n\nstruct page {\n bool free;\n uint8_t offset_in_span;\n uint16_t nr_small_alloc;\n uint32_t span_size; // in pages, if we're the head or the tail\n page_list_link link;\n small_pool* pool; // if used in a small_pool\n free_object* freelist;\n#ifdef SEASTAR_HEAPPROF\n allocation_site_ptr alloc_site; // for objects whose size is multiple of page size, valid for head only\n#endif\n};\n\nclass page_list {\n uint32_t _front = 0;\n uint32_t _back = 0;\npublic:\n page& front(page* ary) { return ary[_front]; }\n page& back(page* ary) { return ary[_back]; }\n bool empty() const { return !_front; }\n void erase(page* ary, page& span) {\n if (span.link._next) {\n ary[span.link._next].link._prev = span.link._prev;\n } else {\n _back = span.link._prev;\n }\n if (span.link._prev) {\n ary[span.link._prev].link._next = span.link._next;\n } else {\n _front = span.link._next;\n }\n }\n void push_front(page* ary, page& span) {\n auto idx = &span - ary;\n if (_front) {\n ary[_front].link._prev = idx;\n } else {\n _back = idx;\n }\n span.link._next = _front;\n span.link._prev = 0;\n _front = idx;\n }\n void pop_front(page* ary) {\n if (ary[_front].link._next) {\n ary[ary[_front].link._next].link._prev = 0;\n } else {\n _back = 0;\n }\n _front = ary[_front].link._next;\n }\n friend seastar::internal::log_buf::inserter_iterator do_dump_memory_diagnostics(seastar::internal::log_buf::inserter_iterator);\n};\n\nclass small_pool {\n struct span_sizes {\n uint8_t preferred;\n uint8_t fallback;\n };\n free_object* _free = nullptr;\n unsigned _object_size;\n span_sizes _span_sizes;\n unsigned _free_count = 0;\n unsigned _min_free;\n unsigned _max_free;\n unsigned _pages_in_use = 0;\n // Flag to indicate whether this pool stores sampled allocations.\n // When freeing small allocations this flag is checked to see whether an\n // allocation site pointer is part of the object and the allocation needs\n // removal from the allocation_site tracking\n#ifdef SEASTAR_HEAPPROF\n bool _sampled_pool = false;\n#endif\n page_list _span_list;\n static constexpr unsigned idx_frac_bits = 2;\npublic:\n explicit small_pool(unsigned object_size, bool is_sampled) noexcept;\n ~small_pool();\n inline void* allocate();\n void deallocate(void* object);\n unsigned object_size() const { return _object_size; }\n /// See _sampled_pool\n bool is_sampled_pool() const {\n#ifdef SEASTAR_HEAPPROF\n return _sampled_pool;\n#else\n return false;\n#endif\n }\n bool objects_page_aligned() const { return is_page_aligned(_object_size); }\n static constexpr unsigned size_to_idx(unsigned size);\n static 
constexpr unsigned idx_to_size(unsigned idx);\n allocation_site_ptr& alloc_site_holder(void* ptr);\nprivate:\n inline void* pop_free();\n [[gnu::noinline]] void* add_more_objects();\n void trim_free_list();\n friend seastar::internal::log_buf::inserter_iterator do_dump_memory_diagnostics(seastar::internal::log_buf::inserter_iterator);\n};\n\n// index 0b0001'1100 -> size (1 << 4) + 0b11 << (4 - 2)\n\nconstexpr unsigned\nsmall_pool::idx_to_size(unsigned idx) {\n size_t s = (((1 << idx_frac_bits) | (idx & ((1 << idx_frac_bits) - 1)))\n << (idx >> idx_frac_bits))\n >> idx_frac_bits;\n // If size is larger than max_align_t, force it to be a multiple of\n // max_align_t. Clang relies in this property to use aligned mov\n // instructions (e.g. movaps)\n //\n // Note this function is used at initialization time only, so it doesn't\n // need to be especially fast.\n if (s > alignof(std::max_align_t)) {\n\ts = align_up(s, alignof(std::max_align_t));\n }\n return s;\n}\n\nconstexpr unsigned\nsmall_pool::size_to_idx(unsigned size) {\n return ((log2floor(size) << idx_frac_bits) - ((1 << idx_frac_bits) - 1))\n + ((size - 1) >> (log2floor(size) - idx_frac_bits));\n}\n\ntemplate // tag the pools in this array as sampled, see small_pool._sampled_pool\nclass small_pool_array {\npublic:\n static constexpr unsigned nr_small_pools = small_pool::size_to_idx(4 * page_size) + 1;\nprivate:\n union u {\n small_pool a[nr_small_pools];\n u() {\n for (unsigned i = 0; i < nr_small_pools; ++i) {\n new (&a[i]) small_pool(small_pool::idx_to_size(i), sampled);\n }\n }\n ~u() {\n // cannot really call destructor, since other\n // objects may be freed after we are gone.\n }\n } _u;\npublic:\n small_pool& operator[](unsigned idx) { return _u.a[idx]; }\n};\n\nstatic constexpr size_t max_small_allocation\n = small_pool::idx_to_size(small_pool_array::nr_small_pools - 1);\n\n#ifdef SEASTAR_HEAPPROF\nconstexpr size_t object_size_with_alloc_site(size_t size) {\n // For page-aligned sizes, allocation_site* lives in page::alloc_site, not with the object.\n static_assert(is_page_aligned(max_small_allocation), \"assuming that max_small_allocation is page aligned so that we\"\n \" don't need to add allocation_site_ptr to objects of size close to it\");\n size_t next_page_aligned_size = next_page_aligned(size);\n if (next_page_aligned_size - size > sizeof(allocation_site_ptr)) {\n size += sizeof(allocation_site_ptr);\n } else {\n return next_page_aligned_size;\n }\n return size;\n}\n\n// Ensure that object_size_with_alloc_site() does not exceed max_small_allocation\nstatic_assert(object_size_with_alloc_site(max_small_allocation) == max_small_allocation, \"\");\nstatic_assert(object_size_with_alloc_site(max_small_allocation - 1) == max_small_allocation, \"\");\nstatic_assert(object_size_with_alloc_site(max_small_allocation - sizeof(allocation_site_ptr) + 1) == max_small_allocation, \"\");\nstatic_assert(object_size_with_alloc_site(max_small_allocation - sizeof(allocation_site_ptr)) == max_small_allocation, \"\");\nstatic_assert(object_size_with_alloc_site(max_small_allocation - sizeof(allocation_site_ptr) - 1) == max_small_allocation - 1, \"\");\nstatic_assert(object_size_with_alloc_site(max_small_allocation - sizeof(allocation_site_ptr) - 2) == max_small_allocation - 2, \"\");\n#endif\n\nstruct cross_cpu_free_item {\n cross_cpu_free_item* next;\n};\n\nstruct cpu_pages {\n small_pool_array small_pools;\n uint32_t min_free_pages = 20000000 / page_size;\n char* memory;\n page* pages;\n uint32_t nr_pages;\n uint32_t nr_free_pages;\n uint32_t 
current_min_free_pages = 0;\n size_t large_allocation_warning_threshold = std::numeric_limits::max();\n unsigned cpu_id = -1U;\n std::function)> reclaim_hook;\n std::vector reclaimers;\n static constexpr unsigned nr_span_lists = 32;\n page_list free_spans[nr_span_lists]; // contains aligned spans with span_size == 2^idx\n alignas(seastar::cache_line_size) std::atomic xcpu_freelist;\n static std::atomic cpu_id_gen;\n static cpu_pages* all_cpus[max_cpus];\n union asu {\n using alloc_sites_type = std::unordered_set;\n asu() : alloc_sites{} {\n }\n ~asu() {} // alloc_sites live forever\n alloc_sites_type alloc_sites;\n } asu;\n allocation_site_ptr alloc_site_list_head = nullptr; // For easy traversal of asu.alloc_sites from scylla-gdb.py\n sampler heap_prof_sampler;\n small_pool_array sampled_small_pools;\n\n char* mem() { return memory; }\n\n void link(page_list& list, page* span);\n void unlink(page_list& list, page* span);\n struct trim {\n unsigned offset;\n unsigned nr_pages;\n };\n void maybe_reclaim();\n void* allocate_large_and_trim(unsigned nr_pages, bool should_sample);\n void* allocate_large(unsigned nr_pages, bool should_sample);\n void* allocate_large_aligned(unsigned align_pages, unsigned nr_pages, bool should_sample);\n page* find_and_unlink_span(unsigned nr_pages);\n page* find_and_unlink_span_reclaiming(unsigned n_pages);\n void free_large(void* ptr);\n bool grow_span(pageidx& start, uint32_t& nr_pages, unsigned idx);\n void free_span(pageidx start, uint32_t nr_pages);\n void free_span_no_merge(pageidx start, uint32_t nr_pages);\n void free_span_unaligned(pageidx start, uint32_t nr_pages);\n void free(void* ptr);\n void free(void* ptr, size_t size);\n static bool try_free_fastpath(void* ptr);\n static bool is_local_pointer(void* ptr);\n static void do_foreign_free(void* ptr);\n void shrink(void* ptr, size_t new_size);\n static void free_cross_cpu(unsigned cpu_id, void* ptr);\n bool drain_cross_cpu_freelist();\n size_t object_size(void* ptr);\n\n page* to_page(void* p) {\n size_t page_idx = ((uintptr_t)p) << (64 - cpu_id_shift) >> (64 - cpu_id_shift + page_bits);\n return &pages[page_idx];\n }\n\n bool is_initialized() const;\n bool initialize();\n reclaiming_result run_reclaimers(reclaimer_scope, size_t pages_to_reclaim);\n void schedule_reclaim();\n void set_reclaim_hook(std::function)> hook);\n void set_min_free_pages(size_t pages);\n void resize(size_t new_size, allocate_system_memory_fn alloc_sys_mem);\n void do_resize(size_t new_size, allocate_system_memory_fn alloc_sys_mem);\n void replace_memory_backing(allocate_system_memory_fn alloc_sys_mem);\n void check_large_allocation(size_t size);\n void warn_large_allocation(size_t size);\n allocation_site_ptr add_alloc_site(size_t allocated_size);\n void remove_alloc_site(allocation_site_ptr alloc_site, size_t deallocated_size);\n bool maybe_sample(size_t size);\n bool definitely_sample(size_t size);\n memory::memory_layout memory_layout();\n ~cpu_pages();\n};\n\nstatic thread_local cpu_pages cpu_mem;\nstd::atomic cpu_pages::cpu_id_gen;\ncpu_pages* cpu_pages::all_cpus[max_cpus];\n\nstatic cpu_pages& get_cpu_mem();\n\n#ifdef SEASTAR_HEAPPROF\n\nvoid set_heap_profiling_sampling_rate(size_t sample_rate) {\n bool current_sample_rate = get_cpu_mem().heap_prof_sampler.sampling_interval();\n if (sample_rate) {\n if (!current_sample_rate) {\n seastar_logger.info(\"Enabling heap profiler - using {} bytes sampling rate\", sample_rate);\n } else {\n seastar_logger.warn(\"Ignoring change to heap profiler sample rate as heap profiling is 
already turned on\");\n return;\n }\n } else {\n if (current_sample_rate) {\n seastar_logger.info(\"Disabling heap profiler\");\n }\n }\n get_cpu_mem().heap_prof_sampler.set_sampling_interval(sample_rate);\n}\n\nsize_t get_heap_profiling_sample_rate() {\n return get_cpu_mem().heap_prof_sampler.sampling_interval();\n}\n\nstatic thread_local int64_t scoped_heap_profiling_embed_count = 0;\n\nscoped_heap_profiling::scoped_heap_profiling(size_t sample_rate) noexcept {\n ++scoped_heap_profiling_embed_count;\n set_heap_profiling_sampling_rate(sample_rate);\n}\n\nscoped_heap_profiling::~scoped_heap_profiling() {\n if (!--scoped_heap_profiling_embed_count) {\n set_heap_profiling_sampling_rate(0);\n }\n}\n\n#else\n\nvoid set_heap_profiling_sampling_rate(size_t enable) {\n seastar_logger.warn(\"Seastar compiled without heap profiling support, heap profiler not supported;\"\n \" compile with the Seastar_HEAP_PROFILING=ON CMake option to add heap profiling support\");\n}\n\nsize_t get_heap_profiling_sample_rate() {\n // don't log here, called on all paths\n return 0;\n}\n\nscoped_heap_profiling::scoped_heap_profiling(size_t sample_rate) noexcept {\n set_heap_profiling_sampling_rate(sample_rate); // let it print the warning\n}\n\nscoped_heap_profiling::~scoped_heap_profiling() {\n}\n\n#endif\n\n// Smallest index i such that all spans stored in the index are >= pages.\nstatic inline\nunsigned index_of(unsigned pages) {\n if (pages == 1) {\n return 0;\n }\n return std::numeric_limits::digits - count_leading_zeros(pages - 1);\n}\n\nvoid\ncpu_pages::unlink(page_list& list, page* span) {\n list.erase(pages, *span);\n}\n\nvoid\ncpu_pages::link(page_list& list, page* span) {\n list.push_front(pages, *span);\n}\n\nvoid cpu_pages::free_span_no_merge(uint32_t span_start, uint32_t nr_pages) {\n assert(nr_pages);\n nr_free_pages += nr_pages;\n auto span = &pages[span_start];\n auto span_end = &pages[span_start + nr_pages - 1];\n span->free = span_end->free = true;\n span->span_size = span_end->span_size = nr_pages;\n auto idx = index_of(nr_pages);\n link(free_spans[idx], span);\n}\n\nbool cpu_pages::grow_span(uint32_t& span_start, uint32_t& nr_pages, unsigned idx) {\n auto which = (span_start >> idx) & 1; // 0=lower, 1=upper\n // locate first page of upper buddy or last page of lower buddy\n // examples: span_start = 0x10 nr_pages = 0x08 -> buddy = 0x18 (which = 0)\n // span_start = 0x18 nr_pages = 0x08 -> buddy = 0x17 (which = 1)\n // delta = which ? -1u : nr_pages\n auto delta = ((which ^ 1) << idx) | -which;\n auto buddy = span_start + delta;\n if (pages[buddy].free && pages[buddy].span_size == nr_pages) {\n unlink(free_spans[idx], &pages[span_start ^ nr_pages]);\n nr_free_pages -= nr_pages; // free_span_no_merge() will restore\n span_start &= ~nr_pages;\n nr_pages *= 2;\n return true;\n }\n return false;\n}\n\nvoid cpu_pages::free_span(uint32_t span_start, uint32_t nr_pages) {\n auto idx = index_of(nr_pages);\n while (grow_span(span_start, nr_pages, idx)) {\n ++idx;\n }\n free_span_no_merge(span_start, nr_pages);\n}\n\n// Internal, used during startup. Span is not aligned so needs to be broken up\nvoid cpu_pages::free_span_unaligned(uint32_t span_start, uint32_t nr_pages) {\n while (nr_pages) {\n auto start_nr_bits = span_start ? 
count_trailing_zeros(span_start) : 32;\n auto size_nr_bits = count_trailing_zeros(nr_pages);\n auto now = 1u << std::min(start_nr_bits, size_nr_bits);\n free_span(span_start, now);\n span_start += now;\n nr_pages -= now;\n }\n}\n\npage*\ncpu_pages::find_and_unlink_span(unsigned n_pages) {\n auto idx = index_of(n_pages);\n if (n_pages >= (2u << idx)) {\n return nullptr;\n }\n while (idx < nr_span_lists && free_spans[idx].empty()) {\n ++idx;\n }\n if (idx == nr_span_lists) {\n if (initialize()) {\n return find_and_unlink_span(n_pages);\n }\n return nullptr;\n }\n auto& list = free_spans[idx];\n page* span = &list.front(pages);\n unlink(list, span);\n return span;\n}\n\npage*\ncpu_pages::find_and_unlink_span_reclaiming(unsigned n_pages) {\n while (true) {\n auto span = find_and_unlink_span(n_pages);\n if (span) {\n return span;\n }\n if (run_reclaimers(reclaimer_scope::sync, n_pages) == reclaiming_result::reclaimed_nothing) {\n return nullptr;\n }\n }\n}\n\nvoid cpu_pages::maybe_reclaim() {\n if (nr_free_pages < current_min_free_pages) {\n drain_cross_cpu_freelist();\n if (nr_free_pages < current_min_free_pages) {\n run_reclaimers(reclaimer_scope::sync, current_min_free_pages - nr_free_pages);\n }\n if (nr_free_pages < current_min_free_pages) {\n schedule_reclaim();\n }\n }\n}\n\nvoid*\ncpu_pages::allocate_large_and_trim(unsigned n_pages, bool should_sample) {\n // Avoid exercising the reclaimers for requests we'll not be able to satisfy\n // nr_pages might be zero during startup, so check for that too\n if (nr_pages && n_pages >= nr_pages) {\n return nullptr;\n }\n page* span = find_and_unlink_span_reclaiming(n_pages);\n if (!span) {\n return nullptr;\n }\n auto span_size = span->span_size;\n auto span_idx = span - pages;\n nr_free_pages -= span->span_size;\n while (span_size >= n_pages * 2) {\n span_size /= 2;\n auto other_span_idx = span_idx + span_size;\n free_span_no_merge(other_span_idx, span_size);\n }\n auto span_end = &pages[span_idx + span_size - 1];\n span->free = span_end->free = false;\n span->span_size = span_end->span_size = span_size;\n span->pool = nullptr;\n#ifdef SEASTAR_HEAPPROF\n if (should_sample) {\n auto alloc_site = add_alloc_site(span->span_size * page_size);\n span->alloc_site = alloc_site;\n }\n else {\n span->alloc_site = nullptr;\n }\n#endif\n maybe_reclaim();\n return mem() + span_idx * page_size;\n}\n\nvoid\ncpu_pages::warn_large_allocation(size_t size) {\n alloc_stats::increment_local(alloc_stats::types::large_allocs);\n seastar_memory_logger.warn(\"oversized allocation: {} bytes. This is non-fatal, but could lead to latency and/or fragmentation issues. Please report: at {}\", size, current_backtrace());\n large_allocation_warning_threshold *= 1.618; // prevent spam\n}\n\nallocation_site_ptr\ncpu_pages::add_alloc_site(size_t allocated_size) {\n allocation_site_ptr alloc_site = get_allocation_site();\n if (alloc_site) {\n ++alloc_site->count;\n alloc_site->size += heap_prof_sampler.sample_size(allocated_size);\n }\n\n return alloc_site;\n}\n\nvoid\ncpu_pages::remove_alloc_site(allocation_site_ptr alloc_site, size_t deallocated_size) {\n if (alloc_site) {\n --alloc_site->count;\n auto sample_size = heap_prof_sampler.sample_size(deallocated_size);\n // prevent underflow in case sample rate changed\n alloc_site->size -= alloc_site->size < sample_size ? 
alloc_site->size : sample_size;\n if (alloc_site->count == 0) {\n if (alloc_site->prev) {\n alloc_site->prev->next = alloc_site->next;\n }\n if (alloc_site->next) {\n alloc_site->next->prev = alloc_site->prev;\n }\n if (alloc_site_list_head == alloc_site) {\n alloc_site_list_head = alloc_site->next;\n }\n\n asu.alloc_sites.erase(*alloc_site);\n }\n }\n}\n\n[[gnu::always_inline]]\ninline bool\ncpu_pages::maybe_sample(size_t size) {\n#ifdef SEASTAR_HEAPPROF\n return heap_prof_sampler.maybe_sample(size);\n#else\n return false;\n#endif\n}\n\n[[gnu::always_inline]]\ninline bool\ncpu_pages::definitely_sample(size_t size) {\n#ifdef SEASTAR_HEAPPROF\n return heap_prof_sampler.definitely_sample(size);\n#else\n return false;\n#endif\n}\n\n[[gnu::always_inline]]\nvoid\ninline\ncpu_pages::check_large_allocation(size_t size) {\n if (size >= large_allocation_warning_threshold) {\n warn_large_allocation(size);\n }\n}\n\n[[gnu::always_inline]]\ninline void*\ncpu_pages::allocate_large(unsigned n_pages, bool should_sample) {\n check_large_allocation(n_pages * page_size);\n return allocate_large_and_trim(n_pages, should_sample);\n}\n\nvoid*\ncpu_pages::allocate_large_aligned(unsigned align_pages, unsigned n_pages, bool should_sample) {\n check_large_allocation(n_pages * page_size);\n // buddy allocation is always aligned\n return allocate_large_and_trim(n_pages, should_sample);\n}\n\ndisable_backtrace_temporarily::disable_backtrace_temporarily()\n : _disable_sampling(cpu_mem.heap_prof_sampler.pause_sampling()) {\n}\n\nstatic\nsimple_backtrace get_backtrace() noexcept {\n disable_backtrace_temporarily dbt;\n return current_backtrace_tasklocal();\n}\n\nstatic\nallocation_site_ptr get_allocation_site() {\n if (!cpu_mem.is_initialized() || !cpu_mem.heap_prof_sampler.sampling_interval()) {\n return nullptr;\n }\n disable_backtrace_temporarily dbt;\n allocation_site new_alloc_site;\n new_alloc_site.backtrace = get_backtrace();\n if (cpu_mem.asu.alloc_sites.size() >= 1000\n && cpu_mem.asu.alloc_sites.find(new_alloc_site) == cpu_mem.asu.alloc_sites.end()) {\n // Drop sample for now. 
Could do something smarter like dropping a\n // current one at random but needs more work in remove_alloc_site as we\n // might then have allocations for which the allocsite is no longer\n // alive\n return nullptr;\n }\n auto insert_result = cpu_mem.asu.alloc_sites.insert(std::move(new_alloc_site));\n allocation_site_ptr alloc_site = &*insert_result.first;\n if (insert_result.second) {\n alloc_site->next = cpu_mem.alloc_site_list_head;\n if (cpu_mem.alloc_site_list_head) {\n cpu_mem.alloc_site_list_head->prev = alloc_site;\n }\n cpu_mem.alloc_site_list_head = alloc_site;\n }\n return alloc_site;\n}\n\n#ifdef SEASTAR_HEAPPROF\n\nallocation_site_ptr&\nsmall_pool::alloc_site_holder(void* ptr) {\n if (objects_page_aligned()) {\n return get_cpu_mem().to_page(ptr)->alloc_site;\n } else {\n return *reinterpret_cast(reinterpret_cast(ptr) + _object_size - sizeof(allocation_site_ptr));\n }\n}\n\n#endif\n\nstatic\nvoid\nmaybe_enable_transparent_hugepages(void* addr, size_t len) {\n if (use_transparent_hugepages.load(std::memory_order_relaxed)) {\n ::madvise(addr, len, MADV_HUGEPAGE);\n }\n}\n\nstatic\nvoid\nmaybe_disable_transparent_hugepages(void* addr, size_t len) {\n if (!use_transparent_hugepages.load(std::memory_order_relaxed)) {\n ::madvise(addr, len, MADV_NOHUGEPAGE);\n }\n}\n\nvoid cpu_pages::free_large(void* ptr) {\n pageidx idx = (reinterpret_cast(ptr) - mem()) / page_size;\n page* span = &pages[idx];\n#ifdef SEASTAR_HEAPPROF\n if (span->alloc_site) {\n auto alloc_site = span->alloc_site;\n remove_alloc_site(alloc_site, span->span_size * page_size);\n }\n#endif\n free_span(idx, span->span_size);\n}\n\nsize_t cpu_pages::object_size(void* ptr) {\n page* span = to_page(ptr);\n if (span->pool) {\n auto s = span->pool->object_size();\n#ifdef SEASTAR_HEAPPROF\n if (span->pool->is_sampled_pool()) {\n // We must not allow the object to be extended onto the allocation_site_ptr field.\n if (!span->pool->objects_page_aligned()) {\n s -= sizeof(allocation_site_ptr);\n }\n }\n#endif\n return s;\n } else {\n return size_t(span->span_size) * page_size;\n }\n}\n\nvoid cpu_pages::free_cross_cpu(unsigned cpu_id, void* ptr) {\n if (!live_cpus[cpu_id].load(std::memory_order_relaxed)) {\n // Thread was destroyed; leak object\n // should only happen for boost unit-tests.\n return;\n }\n auto p = reinterpret_cast(ptr);\n auto& list = all_cpus[cpu_id]->xcpu_freelist;\n auto old = list.load(std::memory_order_relaxed);\n do {\n p->next = old;\n } while (!list.compare_exchange_weak(old, p, std::memory_order_release, std::memory_order_relaxed));\n alloc_stats::increment(alloc_stats::types::cross_cpu_frees);\n}\n\nbool cpu_pages::drain_cross_cpu_freelist() {\n if (!xcpu_freelist.load(std::memory_order_relaxed)) {\n return false;\n }\n auto p = xcpu_freelist.exchange(nullptr, std::memory_order_acquire);\n while (p) {\n auto n = p->next;\n alloc_stats::increment_local(alloc_stats::types::frees);\n free(p);\n p = n;\n }\n return true;\n}\n\n[[gnu::always_inline]]\ninline void cpu_pages::free(void* ptr) {\n page* span = to_page(ptr);\n if (span->pool) {\n small_pool& pool = *span->pool;\n#ifdef SEASTAR_HEAPPROF\n if (pool.is_sampled_pool()) {\n allocation_site_ptr alloc_site = pool.alloc_site_holder(ptr);\n remove_alloc_site(alloc_site, pool.object_size());\n }\n#endif\n pool.deallocate(ptr);\n } else {\n free_large(ptr);\n }\n}\n\nvoid cpu_pages::free(void* ptr, size_t size) {\n#ifdef SEASTAR_HEAPPROF\n // sized free can avoid accessing the `page` structure as an optimization.\n // With memory sampling we always need to 
check the pool though to see\n // whether this allocation was sampled. Hence, just defer to the non-sized\n // implementation\n (void) size;\n free(ptr);\n#else\n // match action on allocate() so hit the right pool\n if (size <= sizeof(free_object)) {\n size = sizeof(free_object);\n }\n if (size <= max_small_allocation) {\n auto pool = &small_pools[small_pool::size_to_idx(size)];\n pool->deallocate(ptr);\n } else {\n free_large(ptr);\n }\n#endif\n}\n\n// Is the passed pointer a local pointer, i.e., allocated on the current shard from the \n// per-shard allocator.\n[[gnu::always_inline]]\ninline bool\ncpu_pages::is_local_pointer(void* ptr) {\n return (reinterpret_cast(ptr) & cpu_id_and_mem_base_mask) == local_expected_cpu_id;\n}\n\n// Try to execute free on the fast path, which succeeds if:\n//\n// 1) The pointer is local to this shard\n// 2) The pointer is from a small pool\n// 3) The small pool is not sampled\n//\n// In this case, complete the de-allocation and return true.\n// Otherwise, modify nothing and return false.\n[[gnu::always_inline]]\ninline bool\ncpu_pages::try_free_fastpath(void* ptr) {\n if (__builtin_expect(is_local_pointer(ptr), true)) {\n auto pool = get_cpu_mem().to_page(ptr)->pool;\n if (__builtin_expect(pool && !pool->is_sampled_pool(), true)) {\n alloc_stats::increment_local(alloc_stats::types::frees);\n pool->deallocate(ptr);\n return true;\n }\n }\n return false;\n}\n\n/// Helper to allow a single implementation for sized and non-sized functions.\n/// Indicator to allow a single implementation for sized and non-sized functions.\n/// The size parameter will be either no_size tag type or size_t, and most\n/// of the implementation can be shared, using constexpr if or other dispatch\n/// in the places where there should be a difference of behavior.\nstruct no_size {};\n\ntemplate \nrequires std::same_as || std::same_as\n[[gnu::noinline]]\nstatic void free_slowpath(void* obj, S size) {\n if (cpu_pages::is_local_pointer(obj)) {\n alloc_stats::increment_local(alloc_stats::types::frees);\n if constexpr (std::is_same_v) {\n get_cpu_mem().free(obj);\n } else {\n get_cpu_mem().free(obj, size);\n }\n } else {\n cpu_pages::do_foreign_free(obj);\n }\n}\n\n[[gnu::noinline]]\nvoid\ncpu_pages::do_foreign_free(void* ptr) {\n // handles:\n // 1) non-seastar pointers\n // 2) cross-shard frees\n // 3) null pointer\n\n if (!ptr) {\n return;\n }\n\n if (!is_seastar_memory(ptr)) {\n if (is_reactor_thread) {\n alloc_stats::increment_local(alloc_stats::types::foreign_cross_frees);\n } else {\n alloc_stats::increment(alloc_stats::types::foreign_frees);\n }\n original_free_func(ptr);\n return;\n }\n free_cross_cpu(object_cpu_id(ptr), ptr);\n}\n\nvoid cpu_pages::shrink(void* ptr, size_t new_size) {\n assert(object_cpu_id(ptr) == cpu_id);\n page* span = to_page(ptr);\n if (span->pool) {\n return;\n }\n auto old_size_pages = span->span_size;\n size_t new_size_pages = old_size_pages;\n while (new_size_pages / 2 * page_size >= new_size) {\n new_size_pages /= 2;\n }\n if (new_size_pages == old_size_pages) {\n return;\n }\n#ifdef SEASTAR_HEAPPROF\n auto alloc_site = span->alloc_site;\n if (alloc_site) {\n alloc_site->size -= span->span_size * page_size;\n alloc_site->size += new_size_pages * page_size;\n }\n#endif\n span->span_size = new_size_pages;\n span[new_size_pages - 1].free = false;\n span[new_size_pages - 1].span_size = new_size_pages;\n pageidx idx = span - pages;\n free_span_unaligned(idx + new_size_pages, old_size_pages - new_size_pages);\n}\n\ncpu_pages::~cpu_pages() {\n if 
(is_initialized()) {\n live_cpus[cpu_id].store(false, std::memory_order_relaxed);\n }\n}\n\nbool cpu_pages::is_initialized() const {\n return bool(nr_pages);\n}\n\nbool cpu_pages::initialize() {\n if (is_initialized()) {\n return false;\n }\n cpu_id = cpu_id_gen.fetch_add(1, std::memory_order_relaxed);\n local_expected_cpu_id = (static_cast(cpu_id) << cpu_id_shift)\n\t | reinterpret_cast(mem_base());\n assert(cpu_id < max_cpus);\n all_cpus[cpu_id] = this;\n auto base = mem_base() + (size_t(cpu_id) << cpu_id_shift);\n auto size = 32 << 20; // Small size for bootstrap\n auto r = ::mmap(base, size,\n PROT_READ | PROT_WRITE,\n MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED,\n -1, 0);\n if (r == MAP_FAILED) {\n abort();\n }\n maybe_enable_transparent_hugepages(base, size);\n pages = reinterpret_cast(base);\n memory = base;\n nr_pages = size / page_size;\n // we reserve the end page so we don't have to special case\n...\n", "needle_token_start": 3993, "needle_token_end": 4027, "code_context_ntokens": 16177, "output": ["```cpp\nvoid enable_abort_on_allocation_failure() {\n set_abort_on_allocation_failure(true);\n}\n```"]} +{"repo": "scylladb/seastar", "name": "make_shard_local_buffer_copy", "language": "cpp", "path": "src/rpc/rpc.cc", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to create a local copy of a buffer that may originally reside on a different processing shard, ensuring that the buffer's data is accessible locally without the need for inter-shard communication when accessed.\n2. **Input**: It takes a pointer to a unique buffer object that may belong to a different shard.\n3. **Output**: Returns a new buffer object that contains either a direct move of the original buffer if it's already local, or a new buffer with shared ownership of the original data to avoid unnecessary data copying.\n4. **Procedure**: \n - First, the function checks if the original buffer is owned by the current shard. If so, it directly moves the buffer to the output.\n - If the buffer is from a different shard, it initializes a new buffer of the same size.\n - It then checks the type of the original buffer (single or multiple chunks).\n - For a single chunk, it creates a new buffer that points to the original data and sets up a deleter that will handle the original buffer's deletion once no longer needed.\n - For multiple chunks, it creates a new vector of buffers, each pointing to a chunk of the original data, with a shared deleter across all new chunks to manage the original data's lifecycle.\n - The newly created buffer or vector of buffers is then returned, ensuring local access with minimal overhead.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/http/json_path.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. 
You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2015 Cloudius Systems\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \nmodule seastar;\n#else\n#include \n#endif\n\nnamespace seastar {\n\nnamespace httpd {\n\nusing namespace std;\n\nvoid path_description::set(routes& _routes, handler_base* handler) const {\n for (auto& i : mandatory_queryparams) {\n handler->mandatory(i);\n }\n\n if (params.size() == 0)\n _routes.put(operations.method, path, handler);\n else {\n match_rule* rule = new match_rule(handler);\n rule->add_str(path);\n for (auto&& i : params) {\n if (i.type == url_component_type::FIXED_STRING) {\n rule->add_str(i.name);\n } else {\n rule->add_param(i.name, i.type == url_component_type::PARAM_UNTIL_END_OF_PATH);\n }\n }\n _cookie = _routes.add_cookie(rule, operations.method);\n }\n}\n\nvoid path_description::set(routes& _routes,\n const json_request_function& f) const {\n set(_routes, new function_handler(f));\n}\n\nvoid path_description::set(routes& _routes, const future_json_function& f) const {\n set(_routes, new function_handler(f));\n}\n\nvoid path_description::unset(routes& _routes) const {\n if (params.size() == 0) {\n...\n// Path: src/testing/test_runner.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2015 Cloudius Systems, Ltd.\n */\n\n#include \n\n#include \n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace testing {\n\nstatic test_runner instance;\n\nstruct stop_execution : public std::exception {};\n\ntest_runner::~test_runner() {\n finalize();\n}\n\nbool\ntest_runner::start(int ac, char** av) {\n bool expected = false;\n if (!_started.compare_exchange_strong(expected, true, std::memory_order_acquire)) {\n return true;\n }\n\n // Don't interfere with seastar signal handling\n sigset_t mask;\n sigfillset(&mask); \n for (auto sig : { SIGSEGV }) {\n sigdelset(&mask, sig);\n }\n auto r = ::pthread_sigmask(SIG_BLOCK, &mask, NULL);\n if (r) {\n std::cerr << \"Error blocking signals. 
Aborting.\" << std::endl;\n abort();\n }\n\n _st_args = std::make_unique(ac, av);\n return true;\n}\n\nint test_runner::start_thread(int ac, char** av) {\n auto init_outcome = std::make_shared>();\n\n namespace bpo = boost::program_options;\n _thread = std::make_unique([this, ac, av, init_outcome]() mutable {\n app_template app;\n app.add_options()\n (\"random-seed\", bpo::value(), \"Random number generator seed\")\n (\"fail-on-abandoned-failed-futures\", bpo::value()->default_value(true), \"Fail the test if there are any abandoned failed futures\");\n // We guarantee that only one thread is running.\n // We only read this after that one thread is joined, so this is safe.\n _exit_code = app.run(ac, av, [this, &app, init_outcome = init_outcome.get()] {\n init_outcome->give(0);\n auto init = [&app] {\n auto conf_seed = app.configuration()[\"random-seed\"];\n auto seed = conf_seed.empty() ? std::random_device()(): conf_seed.as();\n std::cout << \"random-seed=\" << seed << std::endl;\n return smp::invoke_on_all([seed] {\n auto local_seed = seed + this_shard_id();\n local_random_engine.seed(local_seed);\n });\n };\n\n return init().then([this] {\n return do_until([this] { return _done; }, [this] {\n // this will block the reactor briefly, but we don't care\n try {\n auto func = _task.take();\n return func();\n } catch (const stop_execution&) {\n _done = true;\n return make_ready_future<>();\n }\n }).or_terminate();\n }).then([&app] {\n if (engine().abandoned_failed_futures()) {\n std::cerr << \"*** \" << engine().abandoned_failed_futures() << \" abandoned failed future(s) detected\" << std::endl;\n if (app.configuration()[\"fail-on-abandoned-failed-futures\"].as()) {\n std::cerr << \"Failing the test because fail was requested by --fail-on-abandoned-failed-futures\" << std::endl;\n return 3;\n }\n }\n return 0;\n });\n });\n init_outcome->give(_exit_code);\n });\n\n return init_outcome->take();\n}\n\nvoid\ntest_runner::run_sync(std::function()> task) {\n if (_st_args) {\n start_thread_args sa = *_st_args;\n _st_args.reset();\n if (int start_status = start_thread(sa.ac, sa.av); start_status != 0) {\n // something bad happened when starting the reactor or app, and\n // the _thread has exited before taking any task. but we need to\n // move on. let's report this bad news with exit code\n _exit_code = start_status;\n }\n }\n if (_exit_code != 0) {\n // we failed to start the worker reactor, so we cannot send the task to\n // it.\n return;\n }\n\n exchanger e;\n _task.give([task = std::move(task), &e] {\n assert(engine_is_ready());\n try {\n return task().then_wrapped([&e](auto&& f) {\n try {\n f.get();\n e.give({});\n } catch (...) {\n e.give(std::current_exception());\n }\n });\n } catch (...) {\n e.give(std::current_exception());\n return make_ready_future<>();\n }\n });\n auto maybe_exception = e.take();\n if (maybe_exception) {\n std::rethrow_exception(maybe_exception);\n }\n}\n\nint test_runner::finalize() {\n if (_thread) {\n _task.interrupt(stop_execution());\n _thread->join();\n _thread = nullptr;\n }\n return _exit_code;\n}\n\ntest_runner& global_test_runner() {\n return instance;\n}\n\n}\n\n}\n\n// Path: src/testing/entry_point.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. 
You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*\n * Copyright (C) 2018 ScyllaDB Ltd.\n */\n\n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace testing {\n\nstatic bool init_unit_test_suite() {\n auto&& ts = boost::unit_test::framework::master_test_suite();\n return global_test_runner().start(ts.argc, ts.argv);\n}\n\nstatic void dummy_handler(int) {\n // This handler should have been replaced.\n _exit(1);\n}\n\nstatic void install_dummy_handler(int sig) {\n struct sigaction sa {};\n sa.sa_handler = dummy_handler;\n sigaction(sig, &sa, nullptr);\n}\n\nint entry_point(int argc, char** argv) {\n#ifndef SEASTAR_ASAN_ENABLED\n // Before we call into boost, install some dummy signal\n // handlers. This seems to be the only way to stop boost from\n // installing its own handlers, which disables our backtrace\n // printer. The real handler will be installed when the reactor is\n // constructed.\n // If we are using ASAN, it has already installed a signal handler\n // that does its own stack printing.\n for (int sig : {SIGSEGV, SIGABRT}) {\n install_dummy_handler(sig);\n }\n#else\n (void)install_dummy_handler;\n#endif\n\n const int boost_exit_code = ::boost::unit_test::unit_test_main(&init_unit_test_suite, argc, argv);\n const int seastar_exit_code = seastar::testing::global_test_runner().finalize();\n if (boost_exit_code) {\n return boost_exit_code;\n }\n return seastar_exit_code;\n}\n\n}\n\n}\n\n// Path: src/testing/seastar_test.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*\n * Copyright (C) 2015 Cloudius Systems, Ltd.\n */\n\n#include \n#include \n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace testing {\n\nexchanger_base::exchanger_base() { }\nexchanger_base::~exchanger_base() { }\n\nvoid seastar_test::run() {\n // HACK: please see https://github.com/cloudius-systems/seastar/issues/10\n BOOST_REQUIRE(true);\n\n // HACK: please see https://github.com/cloudius-systems/seastar/issues/10\n boost::program_options::variables_map()[\"dummy\"];\n\n set_abort_on_internal_error(true);\n\n global_test_runner().run_sync([this] {\n return run_test_case();\n });\n}\n\nseastar_test::seastar_test(const char* test_name, const char* test_file, int test_line)\n : seastar_test(test_name, test_file, test_line, boost::unit_test::decorator::collector_t::instance()) {}\n\nseastar_test::seastar_test(const char* test_name, const char* test_file, int test_line,\n boost::unit_test::decorator::collector_t& decorators)\n : _test_file{test_file} {\n auto test = boost::unit_test::make_test_case([this] { run(); }, test_name, test_file, test_line);\n decorators.store_in(*test);\n decorators.reset();\n boost::unit_test::framework::current_auto_test_suite().add(test);\n}\n\nconst std::string& seastar_test::get_name() {\n const auto& current_test = boost::unit_test::framework::current_test_unit();\n return current_test.p_name.get();\n}\n\nnamespace exception_predicate {\n\nstd::function message_equals(std::string_view expected_message) {\n return [expected_message] (const std::exception& e) {\n std::string error = e.what();\n if (error == expected_message) {\n return true;\n } else {\n std::cerr << \"Expected \\\"\" << expected_message << \"\\\" but got \\\"\" << error << '\"' << std::endl;\n return false;\n }\n };\n}\n\nstd::function message_contains(std::string_view expected_message) {\n return [expected_message] (const std::exception& e) {\n std::string error = e.what();\n if (error.find(expected_message.data()) != std::string::npos) {\n return true;\n } else {\n std::cerr << \"Expected \\\"\" << expected_message << \"\\\" but got \\\"\" << error << '\"' << std::endl;\n return false;\n }\n };\n}\n\n} // exception_predicate\n\nscoped_no_abort_on_internal_error::scoped_no_abort_on_internal_error() noexcept\n : _prev(set_abort_on_internal_error(false))\n{\n}\n\nscoped_no_abort_on_internal_error::~scoped_no_abort_on_internal_error() {\n set_abort_on_internal_error(_prev);\n}\n\n}\n\n}\n\n// Path: src/testing/random.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2020 Cloudius Systems, Ltd.\n */\n\n#include \n\nnamespace seastar {\n\nnamespace testing {\n\nthread_local std::default_random_engine local_random_engine;\n\n} // namespace testing\n\n} // namespace seastar\n\n// Path: src/rpc/rpc.cc\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if FMT_VERSION >= 90000\ntemplate <> struct fmt::formatter : fmt::ostream_formatter {};\n#endif\n\nnamespace seastar {\n\nnamespace rpc {\n\n void logger::operator()(const client_info& info, id_type msg_id, const sstring& str) const {\n log(format(\"client {} msg_id {}: {}\", info.addr, msg_id, str));\n }\n\n void logger::operator()(const client_info& info, id_type msg_id, log_level level, std::string_view str) const {\n log(level, \"client {} msg_id {}: {}\", info.addr, msg_id, str);\n }\n\n void logger::operator()(const client_info& info, const sstring& str) const {\n (*this)(info.addr, str);\n }\n\n void logger::operator()(const client_info& info, log_level level, std::string_view str) const {\n (*this)(info.addr, level, str);\n }\n\n void logger::operator()(const socket_address& addr, const sstring& str) const {\n log(format(\"client {}: {}\", addr, str));\n }\n\n void logger::operator()(const socket_address& addr, log_level level, std::string_view str) const {\n log(level, \"client {}: {}\", addr, str);\n }\n\n no_wait_type no_wait;\n\n snd_buf::snd_buf(size_t size_) : size(size_) {\n if (size <= chunk_size) {\n bufs = temporary_buffer(size);\n } else {\n std::vector> v;\n v.reserve(align_up(size_t(size), chunk_size) / chunk_size);\n while (size_) {\n v.push_back(temporary_buffer(std::min(chunk_size, size_)));\n size_ -= v.back().size();\n }\n bufs = std::move(v);\n }\n }\n\n snd_buf::snd_buf(snd_buf&&) noexcept = default;\n snd_buf& snd_buf::operator=(snd_buf&&) noexcept = default;\n\n temporary_buffer& snd_buf::front() {\n auto* one = std::get_if>(&bufs);\n if (one) {\n return *one;\n } else {\n return std::get>>(bufs).front();\n }\n }\n\n // Make a copy of a remote buffer. No data is actually copied, only pointers and\n // a deleter of a new buffer takes care of deleting the original buffer\n template // T is either snd_buf or rcv_buf\n \nT make_shard_local_buffer_copy(foreign_ptr> org) {\n if (org.get_owner_shard() == this_shard_id()) {\n return std::move(*org);\n }\n T buf(org->size);\n auto* one = std::get_if>(&org->bufs);\n\n if (one) {\n buf.bufs = temporary_buffer(one->get_write(), one->size(), make_object_deleter(std::move(org)));\n } else {\n auto& orgbufs = std::get>>(org->bufs);\n std::vector> newbufs;\n newbufs.reserve(orgbufs.size());\n deleter d = make_object_deleter(std::move(org));\n for (auto&& b : orgbufs) {\n newbufs.push_back(temporary_buffer(b.get_write(), b.size(), d.share()));\n }\n buf.bufs = std::move(newbufs);\n }\n\n return buf;\n }\n\n template snd_buf make_shard_local_buffer_copy(foreign_ptr>);\n template rcv_buf make_shard_local_buffer_copy(foreign_ptr>);\n\n static void log_exception(connection& c, log_level level, const char* log, std::exception_ptr eptr) {\n const char* s;\n try {\n std::rethrow_exception(eptr);\n } catch (std::exception& ex) {\n s = ex.what();\n } catch (...) 
{\n s = \"unknown exception\";\n }\n auto formatted = format(\"{}: {}\", log, s);\n c.get_logger()(c.peer_address(), level, std::string_view(formatted.data(), formatted.size()));\n }\n\n snd_buf connection::compress(snd_buf buf) {\n if (_compressor) {\n buf = _compressor->compress(4, std::move(buf));\n static_assert(snd_buf::chunk_size >= 4, \"send buffer chunk size is too small\");\n write_le(buf.front().get_write(), buf.size - 4);\n return buf;\n }\n return buf;\n }\n\n future<> connection::send_buffer(snd_buf buf) {\n auto* b = std::get_if>(&buf.bufs);\n if (b) {\n return _write_buf.write(std::move(*b));\n } else {\n return do_with(std::move(std::get>>(buf.bufs)),\n [this] (std::vector>& ar) {\n return do_for_each(ar.begin(), ar.end(), [this] (auto& b) {\n return _write_buf.write(std::move(b));\n });\n });\n }\n }\n\n future<> connection::send_entry(outgoing_entry& d) noexcept {\n return futurize_invoke([this, &d] {\n if (d.buf.size && _propagate_timeout) {\n static_assert(snd_buf::chunk_size >= sizeof(uint64_t), \"send buffer chunk size is too small\");\n if (_timeout_negotiated) {\n auto expire = d.t.get_timeout();\n uint64_t left = 0;\n if (expire != typename timer::time_point()) {\n left = std::chrono::duration_cast(expire - timer::clock::now()).count();\n }\n write_le(d.buf.front().get_write(), left);\n } else {\n d.buf.front().trim_front(sizeof(uint64_t));\n d.buf.size -= sizeof(uint64_t);\n }\n }\n auto buf = compress(std::move(d.buf));\n return send_buffer(std::move(buf)).then([this] {\n _stats.sent_messages++;\n return _write_buf.flush();\n });\n });\n }\n\n void connection::set_negotiated() noexcept {\n _negotiated->set_value();\n _negotiated = std::nullopt;\n }\n\n future<> connection::stop_send_loop(std::exception_ptr ex) {\n _error = true;\n if (_connected) {\n _fd.shutdown_output();\n }\n if (ex == nullptr) {\n ex = std::make_exception_ptr(closed_error());\n }\n while (!_outgoing_queue.empty()) {\n auto it = std::prev(_outgoing_queue.end());\n // Cancel all but front entry normally. The front entry is sitting in the\n // send_entry() and cannot be withdrawn, except when _negotiated is still\n // engaged. In the latter case when it will be aborted below the entry's\n // continuation will not be called and its done promise will not resolve\n // the _outgoing_queue_ready, so do it here\n if (it != _outgoing_queue.begin()) {\n withdraw(it, ex);\n } else {\n if (_negotiated) {\n it->done.set_exception(ex);\n }\n break;\n }\n }\n if (_negotiated) {\n _negotiated->set_exception(ex);\n }\n return when_all(std::move(_outgoing_queue_ready), std::move(_sink_closed_future)).then([this] (std::tuple, future> res){\n // _outgoing_queue_ready might be exceptional if queue drain or\n // _negotiated abortion set it such\n std::get<0>(res).ignore_ready_future();\n // _sink_closed_future is never exceptional\n bool sink_closed = std::get<1>(res).get();\n return _connected && !sink_closed ? 
_write_buf.close() : make_ready_future();\n });\n }\n\n void connection::set_socket(connected_socket&& fd) {\n if (_connected) {\n throw std::runtime_error(\"already connected\");\n }\n _fd = std::move(fd);\n _read_buf =_fd.input();\n _write_buf = _fd.output();\n _connected = true;\n }\n\n future<> connection::send_negotiation_frame(feature_map features) {\n auto negotiation_frame_feature_record_size = [] (const feature_map::value_type& e) {\n return 8 + e.second.size();\n };\n auto extra_len = boost::accumulate(\n features | boost::adaptors::transformed(negotiation_frame_feature_record_size),\n uint32_t(0));\n temporary_buffer reply(sizeof(negotiation_frame) + extra_len);\n auto p = reply.get_write();\n p = std::copy_n(rpc_magic, 8, p);\n write_le(p, extra_len);\n p += 4;\n for (auto&& e : features) {\n write_le(p, static_cast(e.first));\n p += 4;\n write_le(p, e.second.size());\n p += 4;\n p = std::copy_n(e.second.begin(), e.second.size(), p);\n }\n return _write_buf.write(std::move(reply)).then([this] {\n _stats.sent_messages++;\n return _write_buf.flush();\n });\n }\n\n void connection::withdraw(outgoing_entry::container_t::iterator it, std::exception_ptr ex) {\n assert(it != _outgoing_queue.end());\n\n auto pit = std::prev(it);\n // Previous entry's (pit's) done future will schedule current entry (it)\n // continuation. Similarly, it.done will schedule next entry continuation\n // or will resolve _outgoing_queue_ready future.\n //\n // To withdraw \"it\" we need to do two things:\n // - make pit.done resolve it->next (some time later)\n // - resolve \"it\"'s continuation right now\n //\n // The latter is achieved by resolving pit.done immediatelly, the former\n // by moving it.done into pit.done. For simplicity (verging on obscurity?)\n // both done's are just swapped and \"it\" resolves its new promise\n\n std::swap(it->done, pit->done);\n it->uncancellable();\n it->unlink();\n if (ex == nullptr) {\n it->done.set_value();\n } else {\n it->done.set_exception(ex);\n }\n }\n\n future<> connection::send(snd_buf buf, std::optional timeout, cancellable* cancel) {\n if (!_error) {\n if (timeout && *timeout <= rpc_clock_type::now()) {\n return make_ready_future<>();\n }\n\n auto p = std::make_unique(std::move(buf));\n auto& d = *p;\n _outgoing_queue.push_back(d);\n _outgoing_queue_size++;\n auto deleter = [this, it = _outgoing_queue.iterator_to(d)] {\n // Front entry is most likely (unless _negotiated is unresolved, check enqueue_zero_frame()) sitting\n // inside send_entry() continuations and thus it cannot be cancelled.\n if (it != _outgoing_queue.begin()) {\n withdraw(it);\n }\n };\n\n if (timeout) {\n auto& t = d.t;\n t.set_callback(deleter);\n t.arm(timeout.value());\n }\n if (cancel) {\n cancel->cancel_send = std::move(deleter);\n cancel->send_back_pointer = &d.pcancel;\n d.pcancel = cancel;\n }\n\n // New entry should continue (do its .then() lambda) after _outgoing_queue_ready\n // resolves. 
Next entry will need to do the same after this entry's done resolves.\n // Thus -- replace _outgoing_queue_ready with d's future and chain its continuation\n // on ..._ready's old value.\n return std::exchange(_outgoing_queue_ready, d.done.get_future()).then([this, p = std::move(p)] () mutable {\n _outgoing_queue_size--;\n if (__builtin_expect(!p->is_linked(), false)) {\n // If withdrawn the entry is unlinked and this lambda is fired right at once\n return make_ready_future<>();\n }\n\n p->uncancellable();\n return send_entry(*p).then_wrapped([this, p = std::move(p)] (auto f) mutable {\n if (f.failed()) {\n f.ignore_ready_future();\n abort();\n }\n p->done.set_value();\n });\n });\n } else {\n return make_exception_future<>(closed_error());\n }\n }\n\n void connection::abort() {\n if (!_error) {\n _error = true;\n _fd.shutdown_input();\n }\n }\n\n future<> connection::stop() noexcept {\n try {\n abort();\n } catch (...) {\n log_exception(*this, log_level::error, \"fail to shutdown connection while stopping\", std::current_exception());\n }\n return _stopped.get_future();\n }\n\n template\n static bool verify_frame(Connection& c, temporary_buffer& buf, size_t expected, const char* log) {\n if (buf.size() != expected) {\n if (buf.size() != 0) {\n c.get_logger()(c.peer_address(), log);\n }\n return false;\n }\n return true;\n }\n\n template\n static\n future\n receive_negotiation_frame(Connection& c, input_stream& in) {\n return in.read_exactly(sizeof(negotiation_frame)).then([&c, &in] (temporary_buffer neg) {\n if (!verify_frame(c, neg, sizeof(negotiation_frame), \"unexpected eof during negotiation frame\")) {\n return make_exception_future(closed_error());\n }\n negotiation_frame frame;\n std::copy_n(neg.get_write(), sizeof(frame.magic), frame.magic);\n frame.len = read_le(neg.get_write() + 8);\n if (std::memcmp(frame.magic, rpc_magic, sizeof(frame.magic)) != 0) {\n c.get_logger()(c.peer_address(), format(\"wrong protocol magic: {:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}\",\n frame.magic[0], frame.magic[1], frame.magic[2], frame.magic[3], frame.magic[4], frame.magic[5], frame.magic[6], frame.magic[7]));\n return make_exception_future(closed_error());\n }\n auto len = frame.len;\n return in.read_exactly(len).then([&c, len] (temporary_buffer extra) {\n if (extra.size() != len) {\n c.get_logger()(c.peer_address(), \"unexpected eof during negotiation frame\");\n return make_exception_future(closed_error());\n }\n feature_map map;\n auto p = extra.get();\n auto end = p + extra.size();\n while (p != end) {\n if (end - p < 8) {\n c.get_logger()(c.peer_address(), \"bad feature data format in negotiation frame\");\n return make_exception_future(closed_error());\n }\n auto feature = static_cast(read_le(p));\n auto f_len = read_le(p + 4);\n p += 8;\n if (f_len > end - p) {\n c.get_logger()(c.peer_address(), \"buffer underflow in feature data in negotiation frame\");\n return make_exception_future(closed_error());\n }\n auto data = sstring(p, f_len);\n p += f_len;\n map.emplace(feature, std::move(data));\n }\n return make_ready_future(std::move(map));\n });\n });\n }\n\n inline future\n read_rcv_buf(input_stream& in, uint32_t size) {\n return in.read_up_to(size).then([&, size] (temporary_buffer data) mutable {\n rcv_buf rb(size);\n if (data.size() == 0) {\n return make_ready_future(rcv_buf());\n } else if (data.size() == size) {\n rb.bufs = std::move(data);\n return make_ready_future(std::move(rb));\n } else {\n size -= data.size();\n std::vector> v;\n v.push_back(std::move(data));\n rb.bufs = 
std::move(v);\n return do_with(std::move(rb), std::move(size), [&in] (rcv_buf& rb, uint32_t& left) {\n return repeat([&] () {\n return in.read_up_to(left).then([&] (temporary_buffer data) {\n if (!data.size()) {\n rb.size -= left;\n return stop_iteration::yes;\n } else {\n left -= data.size();\n std::get>>(rb.bufs).push_back(std::move(data));\n return left ? stop_iteration::no : stop_iteration::yes;\n }\n });\n }).then([&rb] {\n return std::move(rb);\n });\n });\n }\n });\n }\n\n template\n future\n connection::read_frame(socket_address info, input_stream& in) {\n auto header_size = FrameType::header_size();\n return in.read_exactly(header_size).then([this, header_size, info, &in] (temporary_buffer header) {\n if (header.size() != header_size) {\n if (header.size() != 0) {\n _logger(info, format(\"unexpected eof on a {} while reading header: expected {:d} got {:d}\", FrameType::role(), header_size, header.size()));\n }\n return make_ready_future(FrameType::empty_value());\n }\n auto [size, h] = FrameType::decode_header(header.get());\n if (!size) {\n return make_ready_future(FrameType::make_value(h, rcv_buf()));\n } else {\n return read_rcv_buf(in, size).then([this, info, h = std::move(h), size] (rcv_buf rb) {\n if (rb.size != size) {\n _logger(info, format(\"unexpected eof on a {} while reading data: expected {:d} got {:d}\", FrameType::role(), size, rb.size));\n return make_ready_future(FrameType::empty_value());\n } else {\n return make_ready_future(FrameType::make_value(h, std::move(rb)));\n }\n });\n }\n });\n }\n\n template\n future\n connection::read_frame_compressed(socket_address info, std::unique_ptr& compressor, input_stream& in) {\n if (compressor) {\n return in.read_exactly(4).then([this, info, &in, &compressor] (temporary_buffer compress_header) {\n if (compress_header.size() != 4) {\n if (compress_header.size() != 0) {\n _logger(info, format(\"unexpected eof on a {} while reading compression header: expected 4 got {:d}\", FrameType::role(), compress_header.size()));\n }\n return make_ready_future(FrameType::empty_value());\n }\n auto ptr = compress_header.get();\n auto size = read_le(ptr);\n return read_rcv_buf(in, size).then([this, size, &compressor, info, &in] (rcv_buf compressed_data) {\n if (compressed_data.size != size) {\n _logger(info, format(\"unexpected eof on a {} while reading compressed data: expected {:d} got {:d}\", FrameType::role(), size, compressed_data.size));\n return make_ready_future(FrameType::empty_value());\n }\n auto eb = compressor->decompress(std::move(compressed_data));\n if (eb.size == 0) {\n // Empty frames might be sent as means of communication between the compressors, and should be skipped by the RPC layer.\n // We skip the empty frame here. 
We recursively restart the function, as if the empty frame didn't happen.\n // The yield() is here to limit the stack depth of the recursion to 1.\n return yield().then([this, info, &in, &compressor] { return read_frame_compressed(info, compressor, in); });\n }\n net::packet p;\n auto* one = std::get_if>(&eb.bufs);\n if (one) {\n p = net::packet(std::move(p), std::move(*one));\n } else {\n auto&& bufs = std::get>>(eb.bufs);\n p.reserve(bufs.size());\n for (auto&& b : bufs) {\n p = net::packet(std::move(p), std::move(b));\n }\n }\n return do_with(as_input_stream(std::move(p)), [this, info] (input_stream& in) {\n return read_frame(info, in);\n });\n });\n });\n } else {\n return read_frame(info, in);\n }\n }\n\n struct stream_frame {\n using opt_buf_type = std::optional;\n using return_type = opt_buf_type;\n struct header_type {\n bool eos;\n };\n static size_t header_size() {\n return 4;\n }\n static const char* role() {\n return \"stream\";\n }\n static auto empty_value() {\n return std::nullopt;\n }\n static std::pair decode_header(const char* ptr) {\n auto size = read_le(ptr);\n return size != -1U ? std::make_pair(size, header_type{false}) : std::make_pair(0U, header_type{true});\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n if (t.eos) {\n data.size = -1U;\n }\n return data;\n }\n };\n\n future>\n connection::read_stream_frame_compressed(input_stream& in) {\n return read_frame_compressed(peer_address(), _compressor, in);\n }\n\n future<> connection::stream_close() {\n auto f = make_ready_future<>();\n if (!error()) {\n promise p;\n _sink_closed_future = p.get_future();\n // stop_send_loop(), which also calls _write_buf.close(), and this code can run in parallel.\n // Use _sink_closed_future to serialize them and skip second call to close()\n f = _write_buf.close().finally([p = std::move(p)] () mutable { p.set_value(true);});\n }\n return f.finally([this] () mutable { return stop(); });\n }\n\n future<> connection::stream_process_incoming(rcv_buf&& buf) {\n // we do not want to dead lock on huge packets, so let them in\n // but only one at a time\n auto size = std::min(size_t(buf.size), max_stream_buffers_memory);\n return get_units(_stream_sem, size).then([this, buf = std::move(buf)] (semaphore_units<>&& su) mutable {\n buf.su = std::move(su);\n return _stream_queue.push_eventually(std::move(buf));\n });\n }\n\n future<> connection::handle_stream_frame() {\n return read_stream_frame_compressed(_read_buf).then([this] (std::optional data) {\n if (!data) {\n _error = true;\n return make_ready_future<>();\n }\n return stream_process_incoming(std::move(*data));\n });\n }\n\n future<> connection::stream_receive(circular_buffer>>& bufs) {\n return _stream_queue.not_empty().then([this, &bufs] {\n bool eof = !_stream_queue.consume([&bufs] (rcv_buf&& b) {\n if (b.size == -1U) { // max fragment length marks an end of a stream\n return false;\n } else {\n bufs.push_back(make_foreign(std::make_unique(std::move(b))));\n return true;\n }\n });\n if (eof && !bufs.empty()) {\n assert(_stream_queue.empty());\n _stream_queue.push(rcv_buf(-1U)); // push eof marker back for next read to notice it\n }\n });\n }\n\n void connection::register_stream(connection_id id, xshard_connection_ptr c) {\n _streams.emplace(id, std::move(c));\n }\n\n xshard_connection_ptr connection::get_stream(connection_id id) const {\n auto it = _streams.find(id);\n if (it == _streams.end()) {\n throw std::logic_error(format(\"rpc stream id {} not found\", id).c_str());\n }\n return it->second;\n }\n\n // The 
request frame is\n // le64 optional timeout (see request_frame_with_timeout below)\n // le64 message type a.k.a. verb ID\n // le64 message ID\n // le32 payload length\n // ... payload\n struct request_frame {\n using opt_buf_type = std::optional;\n using return_type = std::tuple, uint64_t, int64_t, opt_buf_type>;\n using header_type = std::tuple, uint64_t, int64_t>;\n static constexpr size_t raw_header_size = sizeof(uint64_t) + sizeof(int64_t) + sizeof(uint32_t);\n static size_t header_size() {\n static_assert(request_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static const char* role() {\n return \"server\";\n }\n static auto empty_value() {\n return std::make_tuple(std::nullopt, uint64_t(0), 0, std::nullopt);\n }\n static std::pair decode_header(const char* ptr) {\n auto type = read_le(ptr);\n auto msgid = read_le(ptr + 8);\n auto size = read_le(ptr + 16);\n return std::make_pair(size, std::make_tuple(std::nullopt, type, msgid));\n }\n static void encode_header(uint64_t type, int64_t msg_id, snd_buf& buf, size_t off) {\n auto p = buf.front().get_write() + off;\n write_le(p, type);\n write_le(p + 8, msg_id);\n write_le(p + 16, buf.size - raw_header_size - off);\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n return std::make_tuple(std::get<0>(t), std::get<1>(t), std::get<2>(t), std::move(data));\n }\n };\n\n // This frame is used if protocol_features.TIMEOUT was negotiated\n struct request_frame_with_timeout : request_frame {\n using super = request_frame;\n static constexpr size_t raw_header_size = sizeof(uint64_t) + request_frame::raw_header_size;\n static size_t header_size() {\n static_assert(request_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static std::pair decode_header(const char* ptr) {\n auto h = super::decode_header(ptr + 8);\n std::get<0>(h.second) = read_le(ptr);\n return h;\n }\n static void encode_header(uint64_t type, int64_t msg_id, snd_buf& buf) {\n static_assert(snd_buf::chunk_size >= raw_header_size, \"send buffer chunk size is too small\");\n // expiration timer is encoded later\n request_frame::encode_header(type, msg_id, buf, 8);\n }\n };\n\n future<> client::request(uint64_t type, int64_t msg_id, snd_buf buf, std::optional timeout, cancellable* cancel) {\n request_frame_with_timeout::encode_header(type, msg_id, buf);\n return send(std::move(buf), timeout, cancel);\n }\n\n void\n client::negotiate(feature_map provided) {\n // record features returned here\n for (auto&& e : provided) {\n auto id = e.first;\n switch (id) {\n // supported features go here\n case protocol_features::COMPRESS:\n if (_options.compressor_factory) {\n _compressor = _options.compressor_factory->negotiate(e.second, false, [this] { return send({}); });\n }\n if (!_compressor) {\n throw std::runtime_error(format(\"RPC server responded with compression {} - unsupported\", e.second));\n }\n break;\n case protocol_features::TIMEOUT:\n _timeout_negotiated = true;\n break;\n case protocol_features::CONNECTION_ID: {\n _id = deserialize_connection_id(e.second);\n break;\n }\n default:\n // nothing to do\n ;\n }\n }\n }\n\n future<> client::negotiate_protocol(feature_map features) {\n return send_negotiation_frame(std::move(features)).then([this] {\n return receive_negotiation_frame(*this, _read_buf).then([this] (feature_map features) {\n return negotiate(std::move(features));\n });\n });\n }\n\n // The response frame is\n // le64 message ID\n // le32 payload size\n // ... 
payload\n struct response_frame {\n using opt_buf_type = std::optional;\n using return_type = std::tuple;\n using header_type = std::tuple;\n static constexpr size_t raw_header_size = sizeof(int64_t) + sizeof(uint32_t);\n static size_t header_size() {\n static_assert(response_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static const char* role() {\n return \"client\";\n }\n static auto empty_value() {\n return std::make_tuple(0, std::nullopt);\n }\n static std::pair decode_header(const char* ptr) {\n auto msgid = read_le(ptr);\n auto size = read_le(ptr + 8);\n return std::make_pair(size, std::make_tuple(msgid));\n }\n static void encode_header(int64_t msg_id, snd_buf& data) {\n static_assert(snd_buf::chunk_size >= raw_header_size, \"send buffer chunk size is too small\");\n auto p = data.front().get_write();\n write_le(p, msg_id);\n write_le(p + 8, data.size - raw_header_size);\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n return std::make_tuple(std::get<0>(t), std::move(data));\n }\n };\n\n\n future\n client::read_response_frame_compressed(input_stream& in) {\n return read_frame_compressed(_server_addr, _compressor, in);\n }\n\n stats client::get_stats() const {\n stats res = _stats;\n res.wait_reply = incoming_queue_length();\n res.pending = outgoing_queue_length();\n return res;\n }\n\n void client::wait_for_reply(id_type id, std::unique_ptr&& h, std::optional timeout, cancellable* cancel) {\n if (timeout) {\n h->t.set_callback(std::bind(std::mem_fn(&client::wait_timed_out), this, id));\n h->t.arm(timeout.value());\n }\n if (cancel) {\n cancel->cancel_wait = [this, id] {\n _outstanding[id]->cancel();\n _outstanding.erase(id);\n };\n h->pcancel = cancel;\n cancel->wait_back_pointer = &h->pcancel;\n }\n _outstanding.emplace(id, std::move(h));\n }\n void client::wait_timed_out(id_type id) {\n _stats.timeout++;\n _outstanding[id]->timeout();\n _outstanding.erase(id);\n }\n\n future<> client::stop() noexcept {\n _error = true;\n try {\n _socket.shutdown();\n } catch(...) {\n log_exception(*this, log_level::error, \"fail to shutdown connection while stopping\", std::current_exception());\n }\n return _stopped.get_future();\n }\n\n void client::abort_all_streams() {\n while (!_streams.empty()) {\n auto&& s = _streams.begin();\n assert(s->second->get_owner_shard() == this_shard_id()); // abort can be called only locally\n s->second->get()->abort();\n _streams.erase(s);\n }\n }\n\n void client::deregister_this_stream() {\n if (_parent) {\n _parent->_streams.erase(_id);\n }\n }\n\n // This is the enlightened copy of the connection::send() method. Its intention is to\n // keep a dummy entry in front of the queue while connect+negotiate is happenning so\n // that all subsequent entries could abort on timeout or explicit cancellation.\n void client::enqueue_zero_frame() {\n if (_error) {\n return;\n }\n\n auto p = std::make_unique(snd_buf(0));\n auto& d = *p;\n _outgoing_queue.push_back(d);\n\n // Make it in the background. 
Even if the client is stopped it will pick\n // up all the entries hanging around\n (void)std::exchange(_outgoing_queue_ready, d.done.get_future()).then_wrapped([p = std::move(p)] (auto f) mutable {\n if (f.failed()) {\n f.ignore_ready_future();\n } else {\n p->done.set_value();\n }\n });\n }\n\n struct client::metrics::domain {\n metrics::domain_list_t list;\n stats dead;\n seastar::metrics::metric_groups metric_groups;\n\n static thread_local std::unordered_map all;\n static domain& find_or_create(sstring name);\n\n stats::counter_type count_all(stats::counter_type stats::* field) noexcept {\n stats::counter_type res = dead.*field;\n for (const auto& m : list) {\n res += m._c._stats.*field;\n }\n return res;\n }\n\n size_t count_all_fn(size_t (client::*fn)(void) const) noexcept {\n size_t res = 0;\n for (const auto& m : list) {\n res += (m._c.*fn)();\n }\n return res;\n }\n\n domain(sstring name)\n {\n namespace sm = seastar::metrics;\n auto domain_l = sm::label(\"domain\")(name);\n\n metric_groups.add_group(\"rpc_client\", {\n sm::make_gauge(\"count\", [this] { return list.size(); },\n sm::description(\"Total number of clients\"), { domain_l }),\n sm::make_counter(\"sent_messages\", std::bind(&domain::count_all, this, &stats::sent_messages),\n sm::description(\"Total number of messages sent\"), { domain_l }),\n sm::make_counter(\"replied\", std::bind(&domain::count_all, this, &stats::replied),\n sm::description(\"Total number of responses received\"), { domain_l }),\n sm::make_counter(\"exception_received\", std::bind(&domain::count_all, this, &stats::exception_received),\n sm::description(\"Total number of exceptional responses received\"), { domain_l }).set_skip_when_empty(),\n sm::make_counter(\"timeout\", std::bind(&domain::count_all, this, &stats::timeout),\n sm::description(\"Total number of timeout responses\"), { domain_l }).set_skip_when_empty(),\n sm::make_gauge(\"pending\", std::bind(&domain::count_all_fn, this, &client::outgoing_queue_length),\n sm::description(\"Number of queued outbound messages\"), { domain_l }),\n sm::make_gauge(\"wait_reply\", std::bind(&domain::count_all_fn, this, &client::incoming_queue_length),\n sm::description(\"Number of replies waiting for\"), { domain_l }),\n });\n }\n };\n\n thread_local std::unordered_map client::metrics::domain::all;\n\n client::metrics::domain& client::metrics::domain::find_or_create(sstring name) {\n auto i = all.try_emplace(name, name);\n return i.first->second;\n }\n\n client::metrics::metrics(const client& c)\n : _c(c)\n , _domain(domain::find_or_create(_c._options.metrics_domain))\n {\n _domain.list.push_back(*this);\n }\n\n client::metrics::~metrics() {\n _domain.dead.replied += _c._stats.replied;\n _domain.dead.exception_received += _c._stats.exception_received;\n _domain.dead.sent_messages += _c._stats.sent_messages;\n _domain.dead.timeout += _c._stats.timeout;\n }\n\n client::client(const logger& l, void* s, client_options ops, socket socket, const socket_address& addr, const socket_address& local)\n : rpc::connection(l, s), _socket(std::move(socket)), _server_addr(addr), _local_addr(local), _options(ops), _metrics(*this)\n {\n _socket.set_reuseaddr(ops.reuseaddr);\n // Run client in the background.\n // Communicate result via _stopped.\n // The caller has to call client::stop() to synchronize.\n (void)_socket.connect(addr, local).then([this, ops = std::move(ops)] (connected_socket fd) {\n fd.set_nodelay(ops.tcp_nodelay);\n if (ops.keepalive) {\n fd.set_keepalive(true);\n 
fd.set_keepalive_parameters(ops.keepalive.value());\n }\n set_socket(std::move(fd));\n\n feature_map features;\n if (_options.compressor_factory) {\n features[protocol_features::COMPRESS] = _options.compressor_factory->supported();\n }\n if (_options.send_timeout_data) {\n features[protocol_features::TIMEOUT] = \"\";\n }\n if (_options.stream_parent) {\n features[protocol_features::STREAM_PARENT] = serialize_connection_id(_options.stream_parent);\n }\n if (!_options.isolation_cookie.empty()) {\n features[protocol_features::ISOLATION] = _options.isolation_cookie;\n }\n\n return negotiate_protocol(std::move(features)).then([this] {\n...\n", "needle_token_start": 5395, "needle_token_end": 5702, "code_context_ntokens": 16135, "output": ["```cpp\ntemplate // T is either snd_buf or rcv_buf\nT make_shard_local_buffer_copy(foreign_ptr> org) {\n if (org.get_owner_shard() == this_shard_id()) {\n return std::move(*org);\n }\n T buf(org->size);\n auto* one = std::get_if>(&org->bufs);\n\n if (one) {\n buf.bufs = temporary_buffer(one->get_write(), one->size(), make_object_deleter(std::move(org)));\n } else {\n auto& orgbufs = std::get>>(org->bufs);\n std::vector> newbufs;\n newbufs.reserve(orgbufs.size());\n deleter d = make_object_deleter(std::move(org));\n for (auto&& b : orgbufs) {\n newbufs.push_back(temporary_buffer(b.get_write(), b.size(), d.share()));\n }\n buf.bufs = std::move(newbufs);\n }\n\n return buf;\n}\n```"]} +{"repo": "scylladb/seastar", "name": "describe_enum_value", "language": "cpp", "path": "src/core/program_options.cc", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: To add an enumeration option to a program's options description, optionally setting a default value if provided.\n2. **Input**: Takes a structured description of the option (including its name and description), the enumeration type, and optionally a pointer to a default value of the enumeration type.\n3. **Output**: There is no direct output; however, the function modifies the options description to include the new enumeration option.\n4. **Procedure**: The function checks if a default value is provided. If so, it converts this enumeration value to its corresponding string representation and adds it as an option with a default value to the options description. If no default value is provided, it simply adds the option without a default value.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/core/dpdk_rte.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n#ifdef SEASTAR_HAVE_DPDK\n\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace dpdk {\n\nbool eal::initialized = false;\n\nvoid eal::init(cpuset cpus, const std::string& argv0, const std::optional& hugepages_path, bool dpdk_pmd)\n{\n if (initialized) {\n return;\n }\n\n size_t cpu_count = cpus.count();\n std::stringstream mask;\n cpuset nibble = 0xF;\n while (cpus.any()) {\n...\n// Path: src/core/sharded.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2018 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#endif\n\nnamespace seastar {\n\nnamespace internal {\n\n\nfuture<>\nsharded_parallel_for_each(unsigned nr_shards, on_each_shard_func on_each_shard) noexcept(std::is_nothrow_move_constructible_v) {\n return parallel_for_each(boost::irange(0, nr_shards), std::move(on_each_shard));\n}\n\n}\n\n}\n\n// Path: src/core/uname.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*\n * Copyright (C) 2019 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#endif\n\nnamespace seastar {\n\nnamespace internal {\n\nint uname_t::component_count() const {\n if (distro_patch) {\n return 5;\n }\n if (subsublevel) {\n return 4;\n }\n if (sublevel) {\n return 3;\n }\n return 2;\n}\n\nbool uname_t::has_distro_extra(std::string extra) const {\n return distro_extra.find(extra) != std::string::npos;\n}\n\n// Can't use optional compares, C++17 only\nstatic int cmp(const std::optional& u1, const std::optional& u2) {\n return int(u1.value_or(0) - u2.value_or(0));\n}\n\nbool uname_t::same_as_or_descendant_of(const uname_t& x) const {\n if (version < x.version) {\n return false; // 4.2 vs. 
5.1\n }\n if (version == x.version && patchlevel < x.patchlevel) {\n return false; // 4.0 vs 4.1\n }\n if (!has_distro_extra(x.distro_extra)) {\n return false;\n }\n switch (x.component_count()) {\n case 5:\n return version == x.version\n && patchlevel == x.patchlevel\n && cmp(sublevel, x.sublevel) == 0\n && cmp(subsublevel, x.subsublevel) == 0\n && cmp(distro_patch, x.distro_patch) >= 0;\n case 4:\n return version == x.version\n && patchlevel == x.patchlevel\n && cmp(sublevel, x.sublevel) == 0\n && cmp(subsublevel, x.subsublevel) >= 0;\n case 3:\n return version == x.version\n && patchlevel == x.patchlevel\n && cmp(sublevel, x.sublevel) >= 0;\n case 2:\n return true;\n default:\n return false;\n }\n}\n\nuname_t parse_uname(const char* u) {\n static std::regex re(R\"XX((\\d+)\\.(\\d+)(?:\\.(\\d+)(?:\\.(\\d+))?)?(?:-(\\d*)(.+))?)XX\");\n std::cmatch m;\n if (std::regex_match(u, m, re)) {\n auto num = [] (std::csub_match sm) -> std::optional {\n if (sm.length() > 0) {\n return std::atoi(sm.str().c_str());\n } else {\n return std::nullopt;\n }\n };\n return uname_t{*num(m[1]), *num(m[2]), num(m[3]), num(m[4]), num(m[5]), m[6].str()};\n } else {\n return uname_t{0, 0, std::nullopt, std::nullopt, std::nullopt, \"\"};\n }\n}\n\n\nbool uname_t::whitelisted(std::initializer_list wl) const {\n return boost::algorithm::any_of(wl, [this] (const char* v) {\n return same_as_or_descendant_of(parse_uname(v));\n });\n}\n\nstd::ostream& operator<<(std::ostream& os, const uname_t& u) {\n os << u.version << \".\" << u.patchlevel;\n if (u.sublevel) {\n os << \".\" << *u.sublevel;\n }\n if (u.subsublevel) {\n os << \".\" << *u.subsublevel;\n }\n if (u.distro_patch || !u.distro_extra.empty()) {\n os << \"-\";\n }\n if (u.distro_patch) {\n os << *u.distro_patch;\n }\n os << u.distro_extra;\n return os;\n}\n\n\nuname_t kernel_uname() {\n struct ::utsname buf;\n ::uname(&buf);\n return parse_uname(buf.release);\n}\n\n}\n}\n\n// Path: src/core/app-template.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"program_options.hh\"\n#endif\n\nnamespace seastar {\n\nnamespace bpo = boost::program_options;\n\nusing namespace std::chrono_literals;\n\nstatic\napp_template::seastar_options\nseastar_options_from_config(app_template::config cfg) {\n app_template::seastar_options opts;\n opts.name = std::move(cfg.name);\n opts.description = std::move(cfg.description);\n opts.auto_handle_sigint_sigterm = std::move(cfg.auto_handle_sigint_sigterm);\n opts.reactor_opts.task_quota_ms.set_default_value(cfg.default_task_quota / 1ms);\n opts.reactor_opts.max_networking_io_control_blocks.set_default_value(cfg.max_networking_aio_io_control_blocks);\n opts.smp_opts.reserve_additional_memory_per_shard = cfg.reserve_additional_memory_per_shard;\n return opts;\n}\n\napp_template::seastar_options::seastar_options()\n : program_options::option_group(nullptr, \"seastar\")\n , reactor_opts(this)\n , metrics_opts(this)\n , smp_opts(this)\n , scollectd_opts(this)\n , log_opts(this)\n{\n}\n\napp_template::app_template(app_template::seastar_options opts)\n : _alien(std::make_unique())\n , _smp(std::make_shared(*_alien))\n , _opts(std::move(opts))\n , _app_opts(_opts.name + \" options\")\n , _conf_reader(get_default_configuration_reader()) {\n\n if (!alien::internal::default_instance) {\n alien::internal::default_instance = _alien.get();\n }\n _app_opts.add_options()\n (\"help,h\", \"show help message\")\n ;\n _app_opts.add_options()\n (\"help-seastar\", \"show help message about seastar options\")\n ;\n _app_opts.add_options()\n (\"help-loggers\", \"print a list of logger names and exit\")\n ;\n\n {\n program_options::options_description_building_visitor visitor;\n _opts.describe(visitor);\n _opts_conf_file.add(std::move(visitor).get_options_description());\n }\n\n _seastar_opts.add(_opts_conf_file);\n}\n\napp_template::app_template(app_template::config cfg)\n : app_template(seastar_options_from_config(std::move(cfg)))\n{\n}\n\napp_template::~app_template() = default;\n\nconst app_template::seastar_options& app_template::options() const {\n return _opts;\n}\n\napp_template::configuration_reader app_template::get_default_configuration_reader() {\n return [this] (bpo::variables_map& configuration) {\n auto home = std::getenv(\"HOME\");\n if (home) {\n std::ifstream ifs(std::string(home) + \"/.config/seastar/seastar.conf\");\n if (ifs) {\n bpo::store(bpo::parse_config_file(ifs, _opts_conf_file), configuration);\n }\n std::ifstream ifs_io(std::string(home) + \"/.config/seastar/io.conf\");\n if (ifs_io) {\n bpo::store(bpo::parse_config_file(ifs_io, _opts_conf_file), configuration);\n }\n }\n };\n}\n\nvoid app_template::set_configuration_reader(configuration_reader conf_reader) {\n _conf_reader = conf_reader;\n}\n\nboost::program_options::options_description& app_template::get_options_description() {\n return _app_opts;\n}\n\nboost::program_options::options_description& app_template::get_conf_file_options_description() {\n return _opts_conf_file;\n}\n\nboost::program_options::options_description_easy_init\napp_template::add_options() {\n return 
_app_opts.add_options();\n}\n\nvoid\napp_template::add_positional_options(std::initializer_list options) {\n for (auto&& o : options) {\n _app_opts.add(boost::make_shared(o.name, o.value_semantic, o.help));\n _pos_opts.add(o.name, o.max_count);\n }\n}\n\n\nbpo::variables_map&\napp_template::configuration() {\n return *_configuration;\n}\n\nint\napp_template::run(int ac, char ** av, std::function ()>&& func) noexcept {\n return run_deprecated(ac, av, [func = std::move(func)] () mutable {\n auto func_done = make_lw_shared>();\n engine().at_exit([func_done] { return func_done->get_future(); });\n // No need to wait for this future.\n // func's returned exit_code is communicated via engine().exit()\n (void)futurize_invoke(func).finally([func_done] {\n func_done->set_value();\n }).then([] (int exit_code) {\n return engine().exit(exit_code);\n }).or_terminate();\n });\n}\n\nint\napp_template::run(int ac, char ** av, std::function ()>&& func) noexcept {\n return run(ac, av, [func = std::move(func)] {\n return func().then([] () {\n return 0;\n });\n });\n}\n\nint\napp_template::run_deprecated(int ac, char ** av, std::function&& func) noexcept {\n#ifdef SEASTAR_DEBUG\n fmt::print(std::cerr, \"WARNING: debug mode. Not for benchmarking or production\\n\");\n#endif\n boost::program_options::options_description all_opts;\n all_opts.add(_app_opts);\n all_opts.add(_seastar_opts);\n\n bpo::variables_map configuration;\n try {\n bpo::store(bpo::command_line_parser(ac, av)\n .options(all_opts)\n .positional(_pos_opts)\n .run()\n , configuration);\n _conf_reader(configuration);\n } catch (bpo::error& e) {\n fmt::print(\"error: {}\\n\\nTry --help.\\n\", e.what());\n return 2;\n }\n if (configuration.count(\"help\")) {\n if (!_opts.description.empty()) {\n std::cout << _opts.description << \"\\n\";\n }\n std::cout << _app_opts << \"\\n\";\n return 1;\n }\n if (configuration.count(\"help-seastar\")) {\n std::cout << _seastar_opts << \"\\n\";\n return 1;\n }\n if (configuration.count(\"help-loggers\")) {\n log_cli::print_available_loggers(std::cout);\n return 1;\n }\n\n try {\n bpo::notify(configuration);\n } catch (const bpo::error& ex) {\n std::cout << ex.what() << std::endl;\n return 1;\n }\n\n {\n program_options::variables_map_extracting_visitor visitor(configuration);\n _opts.mutate(visitor);\n }\n _opts.reactor_opts._argv0 = std::string(av[0]);\n _opts.reactor_opts._auto_handle_sigint_sigterm = _opts.auto_handle_sigint_sigterm;\n if (auto* native_stack = dynamic_cast(_opts.reactor_opts.network_stack.get_selected_candidate_opts())) {\n native_stack->_hugepages = _opts.smp_opts.hugepages;\n }\n\n // Needs to be before `smp::configure()`.\n try {\n apply_logging_settings(log_cli::extract_settings(_opts.log_opts));\n } catch (const std::runtime_error& exn) {\n std::cout << \"logging configuration error: \" << exn.what() << '\\n';\n return 1;\n }\n\n try {\n _smp->configure(_opts.smp_opts, _opts.reactor_opts);\n } catch (...) 
{\n std::cerr << \"Could not initialize seastar: \" << std::current_exception() << std::endl;\n return 1;\n }\n _configuration = {std::move(configuration)};\n // No need to wait for this future.\n // func is waited on via engine().run()\n (void)engine().when_started().then([this] {\n return seastar::metrics::configure(_opts.metrics_opts).then([this] {\n // set scollectd use the metrics configuration, so the later\n // need to be set first\n scollectd::configure( _opts.scollectd_opts);\n });\n }).then(\n std::move(func)\n ).then_wrapped([] (auto&& f) {\n try {\n f.get();\n } catch (std::exception& ex) {\n std::cout << \"program failed with uncaught exception: \" << ex.what() << \"\\n\";\n engine().exit(1);\n }\n });\n auto exit_code = engine().run();\n _smp->cleanup();\n return exit_code;\n}\n\n}\n\n// Path: src/core/program_options.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2021 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \"core/program_options.hh\"\n\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar::program_options {\n\nnamespace {\n\nconst char* to_string(memory::alloc_failure_kind val) {\n switch (val) {\n case memory::alloc_failure_kind::none: return \"none\";\n case memory::alloc_failure_kind::critical: return \"critical\";\n case memory::alloc_failure_kind::all: return \"all\";\n }\n std::abort();\n}\n\nconst char* to_string(log_level val) {\n switch (val) {\n case log_level::error: return \"error\";\n case log_level::warn: return \"warn\";\n case log_level::info: return \"info\";\n case log_level::debug: return \"debug\";\n case log_level::trace: return \"trace\";\n }\n std::abort();\n}\n\nconst char* to_string(logger_timestamp_style val) {\n switch (val) {\n case logger_timestamp_style::none: return \"none\";\n case logger_timestamp_style::boot: return \"boot\";\n case logger_timestamp_style::real: return \"real\";\n }\n std::abort();\n}\n\nconst char* to_string(logger_ostream_type val) {\n switch (val) {\n case logger_ostream_type::none: return \"none\";\n case logger_ostream_type::cout: return \"stdout\";\n case logger_ostream_type::cerr: return \"stderr\";\n }\n std::abort();\n}\n\nmemory::alloc_failure_kind from_string(const std::string& val, boost::type) {\n if (val == \"none\") {\n return memory::alloc_failure_kind::none;\n } else if (val == \"critical\") {\n return memory::alloc_failure_kind::critical;\n } else if (val == \"all\") {\n return memory::alloc_failure_kind::all;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum memory::alloc_failure_kind: {}\", val));\n}\n\nlog_level from_string(const std::string& val, boost::type) {\n if (val == \"error\") {\n return log_level::error;\n } 
else if (val == \"warn\") {\n return log_level::warn;\n } else if (val == \"info\") {\n return log_level::info;\n } else if (val == \"debug\") {\n return log_level::debug;\n } else if (val == \"trace\") {\n return log_level::trace;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum log_level: {}\", val));\n}\n\nlogger_timestamp_style from_string(const std::string& val, boost::type) {\n if (val == \"none\") {\n return logger_timestamp_style::none;\n } else if (val == \"boot\") {\n return logger_timestamp_style::boot;\n } else if (val == \"real\") {\n return logger_timestamp_style::real;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum logger_timestamp_style: {}\", val));\n}\n\nlogger_ostream_type from_string(const std::string& val, boost::type) {\n if (val == \"none\") {\n return logger_ostream_type::none;\n } else if (val == \"stdout\") {\n return logger_ostream_type::cout;\n } else if (val == \"stderr\") {\n return logger_ostream_type::cerr;\n }\n throw std::runtime_error(fmt::format(\"Invalid value for enum logger_ostream_type: {}\", val));\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const std::string& name, const std::string& description, const Type& default_value) {\n opts.add_options()(name.c_str(), boost::program_options::value()->default_value(default_value), description.c_str());\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Type& default_value) {\n describe_value(opts, d.name, d.description, default_value);\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const std::string& name, const std::string& description) {\n opts.add_options()(name.c_str(), boost::program_options::value(), description.c_str());\n}\n\ntemplate \nvoid describe_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d) {\n describe_value(opts, d.name, d.description);\n}\n\ntemplate \nvoid describe_value_maybe_default(bpo::options_description& opts, const std::string& name, const std::string& description, const Type* default_value) {\n if (default_value) {\n describe_value(opts, name, description, *default_value);\n } else {\n describe_value(opts, name, description);\n }\n}\n\ntemplate \nvoid describe_value_maybe_default(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Type* default_value) {\n describe_value_maybe_default(opts, d.name, d.description, default_value);\n}\n\ntemplate \n\nvoid describe_enum_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Enum* default_value) {\n if (default_value) {\n opts.add_options()(d.name.c_str(), boost::program_options::value()->default_value(to_string(*default_value)), d.description.c_str());\n } else {\n opts.add_options()(d.name.c_str(), boost::program_options::value(), d.description.c_str());\n }\n}\n\ntemplate \nbool extract_value(const bpo::variables_map& values, const std::string& current_name, T& val) {\n auto it = values.find(current_name);\n if (it == values.end() || it->second.defaulted()) {\n return false;\n }\n val = it->second.as();\n return true;\n}\n\ntemplate \nbool extract_enum_value(const bpo::variables_map& values, const std::string& current_name, T& val) {\n auto it = values.find(current_name);\n if (it == values.end() || it->second.defaulted()) {\n return false;\n }\n val = from_string(it->second.as(), boost::type{});\n return 
true;\n}\n\n} // anonymous namespace\n\nbool options_description_building_visitor::visit_group_start(const std::string& name, bool used) {\n _groups.push({name, bpo::options_description(name.c_str()), used});\n return used;\n}\nvoid options_description_building_visitor::visit_group_end() {\n if (_groups.size() == 1) {\n return;\n }\n auto grp = std::move(_groups.top());\n _groups.pop();\n if (grp.used && grp.values) {\n _groups.top().description.add(std::move(grp.description));\n }\n}\n\nbool options_description_building_visitor::visit_value_metadata(const std::string& name, const std::string& description, bool used) {\n if (!used) {\n return false;\n }\n ++_groups.top().values;\n _current_metadata.emplace(value_metadata{name, description});\n return true;\n}\n\nvoid options_description_building_visitor::visit_value() {\n _groups.top().description.add_options()(_current_metadata->name.c_str(), _current_metadata->description.c_str());\n}\n\nvoid options_description_building_visitor::visit_value(const bool* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const int* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const unsigned* default_value) {\n auto name = _current_metadata->name;\n if (_current_metadata->name == \"smp\") {\n name = \"smp,c\";\n }\n describe_value_maybe_default(_groups.top().description, name, _current_metadata->description, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const float* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const double* default_value) {\n describe_value_maybe_default(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const std::string* default_value) {\n auto name = _current_metadata->name;\n if (_current_metadata->name == \"memory\") {\n name = \"memory,m\";\n }\n describe_value_maybe_default(_groups.top().description, name, _current_metadata->description, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const std::set*) {\n describe_value(_groups.top().description, *_current_metadata);\n}\n\nvoid options_description_building_visitor::visit_value(const memory::alloc_failure_kind* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const log_level* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const logger_timestamp_style* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const logger_ostream_type* default_value) {\n describe_enum_value(_groups.top().description, *_current_metadata, default_value);\n}\n\nvoid options_description_building_visitor::visit_value(const std::unordered_map*) {\n describe_value>(_groups.top().description, *_current_metadata);\n}\n\nvoid options_description_building_visitor::visit_selection_value(const std::vector& candidates, const std::size_t* selected_candidate) {\n if (selected_candidate) {\n 
describe_value(_groups.top().description, *_current_metadata, candidates.at(*selected_candidate));\n } else {\n describe_value(_groups.top().description, *_current_metadata);\n }\n}\n\nvariables_map_extracting_visitor::variables_map_extracting_visitor(const bpo::variables_map& values) : _values(values) {\n}\n\nbool variables_map_extracting_visitor::visit_group_start(const std::string& name, bool used) {\n return used;\n}\n\nvoid variables_map_extracting_visitor::visit_group_end() {\n}\n\nbool variables_map_extracting_visitor::visit_value_metadata(const std::string& name, bool used) {\n if (used) {\n _current_name = &name;\n return true;\n } else {\n _current_name = nullptr;\n return false;\n }\n}\n\nbool variables_map_extracting_visitor::visit_value() {\n return _values.count(*_current_name);\n}\n\nbool variables_map_extracting_visitor::visit_value(bool& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(int& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(unsigned& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(float& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(double& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(std::string& val) {\n return extract_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(std::set& val) {\n std::string raw_val;\n if (!extract_value(_values, *_current_name, raw_val)) {\n return false;\n }\n if (auto parsed_cpu_set = resource::parse_cpuset(raw_val)) {\n val = std::move(*parsed_cpu_set);\n return true;\n }\n throw std::invalid_argument(fmt::format(\"invalid value for option {}: failed to parse cpuset: {}\", *_current_name, raw_val));\n}\n\nbool variables_map_extracting_visitor::visit_value(log_level& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(logger_timestamp_style& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(logger_ostream_type& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(memory::alloc_failure_kind& val) {\n return extract_enum_value(_values, *_current_name, val);\n}\n\nbool variables_map_extracting_visitor::visit_value(std::unordered_map& val) {\n std::vector raw_val;\n if (!extract_value(_values, *_current_name, raw_val)) {\n return false;\n }\n for (const auto& e : raw_val) {\n log_cli::parse_map_associations(e, [&val] (std::string k, std::string v) { val[std::move(k)] = log_cli::parse_log_level(v); });\n }\n return !val.empty();\n}\n\nbool variables_map_extracting_visitor::visit_selection_value(const std::vector& candidates, std::size_t& selected_candidate) {\n std::string candidate_name;\n if (!extract_value(_values, *_current_name, candidate_name)) {\n return false;\n }\n auto it = std::find(candidates.begin(), candidates.end(), candidate_name);\n if (it == candidates.end()) {\n throw std::invalid_argument(fmt::format(\"invalid value for option {}: selected candidate doesn't exist: {}\", *_current_name, candidate_name));\n }\n selected_candidate = it - candidates.begin();\n return true;\n}\n\n} // namespace 
seastar::program_options\n\n// Path: src/core/future-util.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2017 ScyllaDB\n */\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nparallel_for_each_state::parallel_for_each_state(size_t n) {\n _incomplete.reserve(n);\n}\n\nfuture<> parallel_for_each_state::get_future() {\n auto ret = _result.get_future();\n wait_for_one();\n return ret;\n}\n\nvoid parallel_for_each_state::add_future(future<>&& f) {\n _incomplete.push_back(std::move(f));\n}\n\nvoid parallel_for_each_state::wait_for_one() noexcept {\n // Process from back to front, on the assumption that the front\n // futures are likely to complete earlier than the back futures.\n // If that's indeed the case, then the front futures will be\n // available and we won't have to wait for them.\n\n // Skip over futures that happen to be complete already.\n while (!_incomplete.empty() && _incomplete.back().available()) {\n if (_incomplete.back().failed()) {\n _ex = _incomplete.back().get_exception();\n }\n _incomplete.pop_back();\n }\n\n // If there's an incompelete future, wait for it.\n if (!_incomplete.empty()) {\n internal::set_callback(std::move(_incomplete.back()), static_cast*>(this));\n // This future's state will be collected in run_and_dispose(), so we can drop it.\n _incomplete.pop_back();\n return;\n }\n\n // Everything completed, report a result.\n if (__builtin_expect(bool(_ex), false)) {\n _result.set_exception(std::move(_ex));\n } else {\n _result.set_value();\n }\n delete this;\n}\n\nvoid parallel_for_each_state::run_and_dispose() noexcept {\n if (_state.failed()) {\n _ex = std::move(_state).get_exception();\n }\n _state = {};\n wait_for_one();\n}\n\ntemplate \nfuture<> sleep_abortable(typename Clock::duration dur) {\n return engine().wait_for_stop(dur).then([] {\n throw sleep_aborted();\n }).handle_exception([] (std::exception_ptr ep) {\n try {\n std::rethrow_exception(ep);\n } catch(condition_variable_timed_out&) {};\n });\n}\n\ntemplate future<> sleep_abortable(typename steady_clock_type::duration);\ntemplate future<> sleep_abortable(typename lowres_clock::duration);\n\ntemplate \nfuture<> sleep_abortable(typename Clock::duration dur, abort_source& as) {\n struct sleeper {\n promise<> done;\n timer tmr;\n abort_source::subscription sc;\n\n sleeper(typename Clock::duration dur, abort_source& as)\n : tmr([this] { done.set_value(); }) {\n auto sc_opt = as.subscribe([this] (const std::optional& opt_ex) noexcept {\n if (tmr.cancel()) {\n done.set_exception(opt_ex.value_or(std::make_exception_ptr(sleep_aborted())));\n }\n });\n if (sc_opt) {\n sc = std::move(*sc_opt);\n tmr.arm(dur);\n } else {\n 
done.set_exception(sleep_aborted());\n }\n }\n };\n //FIXME: Use do_with() after #373\n auto s = std::make_unique(dur, as);\n auto fut = s->done.get_future();\n return fut.finally([s = std::move(s)] { });\n}\n\ntemplate future<> sleep_abortable(typename steady_clock_type::duration, abort_source&);\ntemplate future<> sleep_abortable(typename lowres_clock::duration, abort_source&);\ntemplate future<> sleep_abortable(typename manual_clock::duration, abort_source&);\n\n}\n\n// Path: src/core/future.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2020 ScyllaDB\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\n// We can't test future_state_base directly because its private\n// destructor is protected.\nstatic_assert(std::is_nothrow_move_constructible_v>>,\n \"future_state's move constructor must not throw\");\n\nstatic_assert(sizeof(future_state>) <= 8, \"future_state> is too large\");\nstatic_assert(sizeof(future_state>) <= 16, \"future_state> is too large\");\nstatic_assert(future_state>::has_trivial_move_and_destroy, \"future_state> not trivial\");\nstatic_assert(future_state::has_trivial_move_and_destroy, \"future_state not trivial\");\n\n// We need to be able to move and copy std::exception_ptr in and out\n// of future/promise/continuations without that producing a new\n// exception.\nstatic_assert(std::is_nothrow_copy_constructible_v,\n \"std::exception_ptr's copy constructor must not throw\");\nstatic_assert(std::is_nothrow_move_constructible_v,\n \"std::exception_ptr's move constructor must not throw\");\n\nnamespace internal {\n\nstatic_assert(std::is_empty_v>>, \"This should still be empty\");\n\nvoid promise_base::move_it(promise_base&& x) noexcept {\n // Don't use std::exchange to make sure x's values are nulled even\n // if &x == this.\n _task = x._task;\n x._task = nullptr;\n#ifdef SEASTAR_DEBUG_PROMISE\n _task_shard = x._task_shard;\n#endif\n _state = x._state;\n x._state = nullptr;\n _future = x._future;\n if (auto* fut = _future) {\n fut->detach_promise();\n fut->_promise = this;\n }\n}\n\nstatic void set_to_broken_promise(future_state_base& state) noexcept {\n try {\n // Constructing broken_promise may throw (std::logic_error ctor is not noexcept).\n state.set_exception(std::make_exception_ptr(broken_promise{}));\n } catch (...) 
{\n state.set_exception(std::current_exception());\n }\n}\n\npromise_base::promise_base(promise_base&& x) noexcept {\n move_it(std::move(x));\n}\n\nvoid promise_base::clear() noexcept {\n if (__builtin_expect(bool(_task), false)) {\n assert(_state && !_state->available());\n set_to_broken_promise(*_state);\n ::seastar::schedule(std::exchange(_task, nullptr));\n }\n if (_future) {\n assert(_state);\n if (!_state->available()) {\n set_to_broken_promise(*_state);\n }\n _future->detach_promise();\n }\n}\n\npromise_base& promise_base::operator=(promise_base&& x) noexcept {\n clear();\n move_it(std::move(x));\n return *this;\n}\n\nvoid promise_base::set_to_current_exception() noexcept {\n set_exception(std::current_exception());\n}\n\n#ifdef SEASTAR_DEBUG_PROMISE\n\nvoid promise_base::assert_task_shard() const noexcept {\n if (_task_shard >= 0 && static_cast(_task_shard) != this_shard_id()) {\n on_fatal_internal_error(seastar_logger, format(\"Promise task was set on shard {} but made ready on shard {}\", _task_shard, this_shard_id()));\n }\n}\n\n#endif\n\ntemplate \nvoid promise_base::make_ready() noexcept {\n if (_task) {\n assert_task_shard();\n if (Urgent == urgent::yes) {\n ::seastar::schedule_urgent(std::exchange(_task, nullptr));\n } else {\n ::seastar::schedule(std::exchange(_task, nullptr));\n }\n }\n}\n\ntemplate void promise_base::make_ready() noexcept;\ntemplate void promise_base::make_ready() noexcept;\n}\n\ntemplate\nfuture current_exception_as_future() noexcept;\n\n/**\n * engine_exit() exits the reactor. It should be given a pointer to the\n * exception which prompted this exit - or a null pointer if the exit\n * request was not caused by any exception.\n */\nvoid engine_exit(std::exception_ptr eptr) {\n if (!eptr) {\n engine().exit(0);\n return;\n }\n report_exception(\"Exiting on unhandled exception\", eptr);\n engine().exit(1);\n}\n\nbroken_promise::broken_promise() : logic_error(\"broken promise\") { }\n\nfuture_state_base::future_state_base(current_exception_future_marker) noexcept\n : future_state_base(std::current_exception()) { }\n\nvoid future_state_base::ignore() noexcept {\n switch (_u.st) {\n case state::invalid:\n case state::future:\n case state::result_unavailable:\n assert(0 && \"invalid state for ignore\");\n case state::result:\n _u.st = state::result_unavailable;\n break;\n default:\n // Ignore the exception\n _u.take_exception();\n }\n}\n\nnested_exception::nested_exception(std::exception_ptr inner, std::exception_ptr outer) noexcept\n : inner(std::move(inner)), outer(std::move(outer)) {}\n\nnested_exception::nested_exception(nested_exception&&) noexcept = default;\n\nnested_exception::nested_exception(const nested_exception&) noexcept = default;\n\nconst char* nested_exception::what() const noexcept {\n return \"seastar::nested_exception\";\n}\n\n[[noreturn]] void nested_exception::rethrow_nested() const {\n std::rethrow_exception(outer);\n}\n\nstatic std::exception_ptr make_nested(std::exception_ptr&& inner, future_state_base&& old) noexcept {\n std::exception_ptr outer = std::move(old).get_exception();\n nested_exception nested{std::move(inner), std::move(outer)};\n return std::make_exception_ptr(std::move(nested));\n}\n\nfuture_state_base::future_state_base(nested_exception_marker, future_state_base&& n, future_state_base&& old) noexcept {\n std::exception_ptr inner = std::move(n).get_exception();\n if (!old.failed()) {\n new (this) future_state_base(std::move(inner));\n } else {\n new (this) future_state_base(make_nested(std::move(inner), 
std::move(old)));\n }\n}\n\nfuture_state_base::future_state_base(nested_exception_marker, future_state_base&& old) noexcept {\n if (!old.failed()) {\n new (this) future_state_base(current_exception_future_marker());\n return;\n } else {\n new (this) future_state_base(make_nested(std::current_exception(), std::move(old)));\n }\n}\n\nvoid future_state_base::rethrow_exception() && {\n // Move ex out so future::~future() knows we've handled it\n std::rethrow_exception(std::move(*this).get_exception());\n}\n\nvoid future_state_base::rethrow_exception() const& {\n std::rethrow_exception(_u.ex);\n}\n\nvoid report_failed_future(const std::exception_ptr& eptr) noexcept {\n ++engine()._abandoned_failed_futures;\n seastar_logger.warn(\"Exceptional future ignored: {}, backtrace: {}\", eptr, current_backtrace());\n}\n\nvoid report_failed_future(const future_state_base& state) noexcept {\n report_failed_future(state._u.ex);\n}\n\nvoid report_failed_future(future_state_base::any&& state) noexcept {\n report_failed_future(std::move(state).take_exception());\n}\n\nvoid reactor::test::with_allow_abandoned_failed_futures(unsigned count, noncopyable_function func) {\n auto before = engine()._abandoned_failed_futures;\n auto old_level = seastar_logger.level();\n seastar_logger.set_level(log_level::error);\n func();\n auto after = engine()._abandoned_failed_futures;\n assert(after - before == count);\n engine()._abandoned_failed_futures = before;\n seastar_logger.set_level(old_level);\n}\n\nnamespace {\nclass thread_wake_task final : public task {\n thread_context* _thread;\npublic:\n thread_wake_task(thread_context* thread) noexcept : _thread(thread) {}\n virtual void run_and_dispose() noexcept override {\n thread_impl::switch_in(_thread);\n // no need to delete, since this is always allocated on\n // _thread's stack.\n }\n /// Returns the task which is waiting for this thread to be done, or nullptr.\n virtual task* waiting_task() noexcept override {\n return _thread->waiting_task();\n }\n};\n}\n\nvoid internal::future_base::do_wait() noexcept {\n auto thread = thread_impl::get();\n assert(thread);\n thread_wake_task wake_task{thread};\n wake_task.make_backtrace();\n _promise->set_task(&wake_task);\n thread_impl::switch_out(thread);\n}\n\n#ifdef SEASTAR_COROUTINES_ENABLED\nvoid internal::future_base::set_coroutine(task& coroutine) noexcept {\n assert(_promise);\n _promise->set_task(&coroutine);\n}\n#endif\n\n}\n\n// Path: src/core/resource.cc\n\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \"cgroup.hh\"\n\n#if SEASTAR_HAVE_HWLOC\n#include \n#endif\n\n#endif\n\nnamespace seastar {\n\nextern logger seastar_logger;\n\nnamespace resource {\n\n// This function was made optional because of validate. It needs to\n// throw an error when a non parseable input is given.\nstd::optional parse_cpuset(std::string value) {\n static std::regex r(\"(\\\\d+-)?(\\\\d+)(,(\\\\d+-)?(\\\\d+))*\");\n\n std::smatch match;\n if (std::regex_match(value, match, r)) {\n std::vector ranges;\n boost::split(ranges, value, boost::is_any_of(\",\"));\n resource::cpuset ret;\n for (auto&& range: ranges) {\n std::string beg = range;\n std::string end = range;\n auto dash = range.find('-');\n if (dash != range.npos) {\n beg = range.substr(0, dash);\n end = range.substr(dash + 1);\n }\n auto b = boost::lexical_cast(beg);\n auto e = boost::lexical_cast(end);\n\n if (b > e) {\n return std::nullopt;\n }\n\n for (auto i = b; i <= e; ++i) {\n ret.insert(i);\n }\n }\n return ret;\n }\n return std::nullopt;\n}\n\n}\n\nnamespace cgroup {\n\nnamespace fs = std::filesystem;\n\noptional cpu_set() {\n auto cpuset = read_setting_V1V2_as(\n \"cpuset/cpuset.cpus\",\n \"cpuset.cpus.effective\");\n if (cpuset) {\n return seastar::resource::parse_cpuset(*cpuset);\n }\n\n seastar_logger.warn(\"Unable to parse cgroup's cpuset. Ignoring.\");\n return std::nullopt;\n}\n\nsize_t memory_limit() {\n return read_setting_V1V2_as(\n \"memory/memory.limit_in_bytes\",\n \"memory.max\")\n .value_or(std::numeric_limits::max());\n}\n\ntemplate \noptional read_setting_as(std::string path) {\n try {\n auto line = read_first_line(path);\n return boost::lexical_cast(line);\n } catch (...) {\n seastar_logger.warn(\"Couldn't read cgroup file {}.\", path);\n }\n\n return std::nullopt;\n}\n\n/*\n * what cgroup do we belong to?\n *\n * For cgroups V2, /proc/self/cgroup should read \"0::\"\n * Note: true only for V2-only systems, but there is no reason to support\n * a hybrid configuration.\n */\nstatic optional cgroup2_path_my_pid() {\n seastar::sstring cline;\n try {\n cline = read_first_line(fs::path{\"/proc/self/cgroup\"});\n } catch (...) {\n // '/proc/self/cgroup' must be there. 
If not - there is an issue\n // with the system configuration.\n throw std::runtime_error(\"no cgroup data for our process\");\n }\n\n // for a V2-only system, we expect exactly one line:\n // 0::\n if (cline.at(0) != '0') {\n // This is either a v1 system, or system configured with a hybrid of v1 & v2.\n // We do not support such combinations of v1 and v2 at this point.\n seastar_logger.debug(\"Not a cgroups-v2-only system\");\n return std::nullopt;\n }\n\n // the path is guaranteed to start with '0::/'\n return fs::path{\"/sys/fs/cgroup/\" + cline.substr(4)};\n}\n\n/*\n * traverse the cgroups V2 hierarchy bottom-up, starting from our process'\n * specific cgroup up to /sys/fs/cgroup, looking for the named file.\n */\nstatic optional locate_lowest_cgroup2(fs::path lowest_subdir, std::string filename) {\n // locate the lowest subgroup containing the named file (i.e.\n // handles the requested control by itself)\n...\n", "needle_token_start": 7115, "needle_token_end": 7263, "code_context_ntokens": 16022, "output": ["```cpp\ntemplate \nvoid describe_enum_value(bpo::options_description& opts, const options_description_building_visitor::value_metadata& d, const Enum* default_value) {\n if (default_value) {\n opts.add_options()(d.name.c_str(), boost::program_options::value()->default_value(to_string(*default_value)), d.description.c_str());\n } else {\n opts.add_options()(d.name.c_str(), boost::program_options::value(), d.description.c_str());\n }\n}\n```"]} +{"repo": "scylladb/seastar", "name": "posix_datagram", "language": "cpp", "path": "src/net/posix-stack.cc", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The purpose is to encapsulate a network message, including its source and destination addresses along with the data packet, for transmission over a network.\n2. **Input**: The input includes the source address, destination address, and the data packet to be transmitted.\n3. **Output**: There is no direct output as this is a constructor for setting up the network message object. The object itself represents the encapsulated data ready for network operations.\n4. **Procedure**: The procedure involves initializing the object with the provided source and destination addresses and the data packet. 
The data packet is moved into the object to manage ownership and ensure efficient data handling.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " std::pmr::polymorphic_allocator* _allocator;\nprivate:\n explicit posix_connected_socket_impl(sa_family_t family, int protocol, pollable_fd fd, std::pmr::polymorphic_allocator* allocator=memory::malloc_allocator) :\n _fd(std::move(fd)), _ops(get_posix_connected_socket_ops(family, protocol)), _allocator(allocator) {}\n explicit posix_connected_socket_impl(sa_family_t family, int protocol, pollable_fd fd, conntrack::handle&& handle,\n std::pmr::polymorphic_allocator* allocator=memory::malloc_allocator) : _fd(std::move(fd))\n , _ops(get_posix_connected_socket_ops(family, protocol)), _handle(std::move(handle)), _allocator(allocator) {}\npublic:\n virtual data_source source() override {\n return source(connected_socket_input_stream_config());\n }\n virtual data_source source(connected_socket_input_stream_config csisc) override {\n return data_source(std::make_unique(_fd, csisc, _allocator));\n }\n virtual data_sink sink() override {\n return data_sink(std::make_unique< posix_data_sink_impl>(_fd));\n }\n virtual void shutdown_input() override {\n shutdown_socket_fd(_fd, SHUT_RD);\n }\n virtual void shutdown_output() override {\n shutdown_socket_fd(_fd, SHUT_WR);\n }\n virtual void set_nodelay(bool nodelay) override {\n return _ops->set_nodelay(_fd.get_file_desc(), nodelay);\n }\n virtual bool get_nodelay() const override {\n return _ops->get_nodelay(_fd.get_file_desc());\n }\n void set_keepalive(bool keepalive) override {\n return _ops->set_keepalive(_fd.get_file_desc(), keepalive);\n }\n bool get_keepalive() const override {\n return _ops->get_keepalive(_fd.get_file_desc());\n }\n void set_keepalive_parameters(const keepalive_params& p) override {\n return _ops->set_keepalive_parameters(_fd.get_file_desc(), p);\n }\n keepalive_params get_keepalive_parameters() const override {\n return _ops->get_keepalive_parameters(_fd.get_file_desc());\n }\n void set_sockopt(int level, int optname, const void* data, size_t len) override {\n return _ops->set_sockopt(_fd.get_file_desc(), level, optname, data, len);\n }\n int get_sockopt(int level, int optname, void* data, size_t len) const override {\n return _ops->get_sockopt(_fd.get_file_desc(), level, optname, data, len);\n }\n socket_address local_address() const noexcept override {\n return _ops->local_address(_fd.get_file_desc());\n }\n socket_address remote_address() const noexcept override {\n return _ops->remote_address(_fd.get_file_desc());\n }\n future<> wait_input_shutdown() override {\n return _fd.poll_rdhup();\n }\n\n friend class posix_server_socket_impl;\n friend class posix_ap_server_socket_impl;\n friend class posix_reuseport_server_socket_impl;\n friend class posix_network_stack;\n friend class posix_ap_network_stack;\n friend class posix_socket_impl;\n};\n\nstatic void resolve_outgoing_address(socket_address& a) {\n if (a.family() != AF_INET6\n || a.as_posix_sockaddr_in6().sin6_scope_id != inet_address::invalid_scope\n || !IN6_IS_ADDR_LINKLOCAL(&a.as_posix_sockaddr_in6().sin6_addr)\n ) {\n return;\n }\n\n FILE *f;\n\n if (!(f = fopen(\"/proc/net/ipv6_route\", \"r\"))) {\n throw std::system_error(errno, std::system_category(), \"resolve_address\");\n }\n\n auto holder = 
std::unique_ptr(f, &::fclose);\n\n /**\n Here all configured IPv6 routes are shown in a special format. The example displays for loopback interface only. The meaning is shown below (see net/ipv6/route.c for more).\n\n # cat /proc/net/ipv6_route\n 00000000000000000000000000000000 00 00000000000000000000000000000000 00 00000000000000000000000000000000 ffffffff 00000001 00000001 00200200 lo\n +------------------------------+ ++ +------------------------------+ ++ +------------------------------+ +------+ +------+ +------+ +------+ ++\n | | | | | | | | | |\n 1 2 3 4 5 6 7 8 9 10\n\n 1: IPv6 destination network displayed in 32 hexadecimal chars without colons as separator\n\n 2: IPv6 destination prefix length in hexadecimal\n\n 3: IPv6 source network displayed in 32 hexadecimal chars without colons as separator\n\n 4: IPv6 source prefix length in hexadecimal\n\n 5: IPv6 next hop displayed in 32 hexadecimal chars without colons as separator\n\n 6: Metric in hexadecimal\n\n 7: Reference counter\n\n 8: Use counter\n\n 9: Flags\n\n 10: Device name\n\n */\n\n uint32_t prefix_len, src_prefix_len;\n unsigned long flags;\n char device[16];\n char dest_str[40];\n\n for (;;) {\n auto n = fscanf(f, \"%4s%4s%4s%4s%4s%4s%4s%4s %02x \"\n \"%*4s%*4s%*4s%*4s%*4s%*4s%*4s%*4s %02x \"\n \"%*4s%*4s%*4s%*4s%*4s%*4s%*4s%*4s \"\n \"%*08x %*08x %*08x %08lx %8s\",\n &dest_str[0], &dest_str[5], &dest_str[10], &dest_str[15],\n &dest_str[20], &dest_str[25], &dest_str[30], &dest_str[35],\n &prefix_len,\n &src_prefix_len,\n &flags, device);\n if (n != 12) {\n break;\n }\n\n if ((prefix_len > 128) || (src_prefix_len != 0)\n || (flags & (RTF_POLICY | RTF_FLOW))\n || ((flags & RTF_REJECT) && prefix_len == 0) /* reject all */) {\n continue;\n }\n\n dest_str[4] = dest_str[9] = dest_str[14] = dest_str[19] = dest_str[24] = dest_str[29] = dest_str[34] = ':';\n dest_str[39] = '\\0';\n\n struct in6_addr addr;\n if (inet_pton(AF_INET6, dest_str, &addr) < 0) {\n /* not an Ipv6 address */\n continue;\n }\n\n auto bytes = prefix_len / 8;\n auto bits = prefix_len % 8;\n\n auto& src = a.as_posix_sockaddr_in6().sin6_addr;\n\n if (bytes > 0 && memcmp(&src, &addr, bytes)) {\n continue;\n }\n if (bits > 0) {\n auto c1 = src.s6_addr[bytes];\n auto c2 = addr.s6_addr[bytes];\n auto mask = 0xffu << (8 - bits);\n if ((c1 & mask) != (c2 & mask)) {\n continue;\n }\n }\n\n // found the route.\n for (auto& nif : engine().net().network_interfaces()) {\n if (nif.name() == device || nif.display_name() == device) {\n a.as_posix_sockaddr_in6().sin6_scope_id = nif.index();\n return;\n }\n }\n }\n}\n\nclass posix_socket_impl final : public socket_impl {\n pollable_fd _fd;\n std::pmr::polymorphic_allocator* _allocator;\n bool _reuseaddr = false;\n\n future<> find_port_and_connect(socket_address sa, socket_address local, transport proto = transport::TCP) {\n static thread_local std::default_random_engine random_engine{std::random_device{}()};\n static thread_local std::uniform_int_distribution u(49152/smp::count + 1, 65535/smp::count - 1);\n // If no explicit local address, set to dest address family wildcard. 
\n if (local.is_unspecified()) {\n local = net::inet_address(sa.addr().in_family());\n }\n resolve_outgoing_address(sa);\n return repeat([this, sa, local, proto, attempts = 0, requested_port = ntoh(local.as_posix_sockaddr_in().sin_port)] () mutable {\n _fd = engine().make_pollable_fd(sa, int(proto));\n _fd.get_file_desc().setsockopt(SOL_SOCKET, SO_REUSEADDR, int(_reuseaddr));\n uint16_t port = attempts++ < 5 && requested_port == 0 && proto == transport::TCP ? u(random_engine) * smp::count + this_shard_id() : requested_port;\n local.as_posix_sockaddr_in().sin_port = hton(port);\n return futurize_invoke([this, sa, local] { return engine().posix_connect(_fd, sa, local); }).then_wrapped([port, requested_port] (future<> f) {\n try {\n f.get();\n return stop_iteration::yes;\n } catch (std::system_error& err) {\n if (port != requested_port && (err.code().value() == EADDRINUSE || err.code().value() == EADDRNOTAVAIL)) {\n return stop_iteration::no;\n }\n throw;\n }\n });\n });\n }\n\n /// an aux function to handle unix-domain-specific requests\n future connect_unix_domain(socket_address sa, socket_address local) {\n // note that if the 'local' address was not set by the client, it was created as an undefined address\n if (local.is_unspecified()) {\n local = socket_address{unix_domain_addr{std::string{}}};\n }\n\n _fd = engine().make_pollable_fd(sa, 0);\n return engine().posix_connect(_fd, sa, local).then(\n [fd = _fd, allocator = _allocator](){\n // a problem with 'private' interaction with 'unique_ptr'\n std::unique_ptr csi;\n csi.reset(new posix_connected_socket_impl{AF_UNIX, 0, std::move(fd), allocator});\n return make_ready_future(connected_socket(std::move(csi)));\n }\n );\n }\n\npublic:\n explicit posix_socket_impl(std::pmr::polymorphic_allocator* allocator=memory::malloc_allocator) : _allocator(allocator) {}\n\n virtual future connect(socket_address sa, socket_address local, transport proto = transport::TCP) override {\n if (sa.is_af_unix()) {\n return connect_unix_domain(sa, local);\n }\n return find_port_and_connect(sa, local, proto).then([this, sa, proto, allocator = _allocator] () mutable {\n std::unique_ptr csi;\n csi.reset(new posix_connected_socket_impl(sa.family(), static_cast(proto), _fd, allocator));\n return make_ready_future(connected_socket(std::move(csi)));\n });\n }\n\n void set_reuseaddr(bool reuseaddr) override {\n _reuseaddr = reuseaddr;\n if (_fd) {\n _fd.get_file_desc().setsockopt(SOL_SOCKET, SO_REUSEADDR, int(reuseaddr));\n }\n }\n\n bool get_reuseaddr() const override {\n if(_fd) {\n return _fd.get_file_desc().getsockopt(SOL_SOCKET, SO_REUSEADDR);\n } else {\n return _reuseaddr;\n }\n }\n\n virtual void shutdown() override {\n if (_fd) {\n try {\n _fd.shutdown(SHUT_RDWR);\n } catch (std::system_error& e) {\n if (e.code().value() != ENOTCONN) {\n throw;\n }\n }\n }\n }\n};\n\nfuture\nposix_server_socket_impl::accept() {\n return _lfd.accept().then([this] (std::tuple fd_sa) {\n auto& fd = std::get<0>(fd_sa);\n auto& sa = std::get<1>(fd_sa);\n auto cth = [this, &sa] {\n switch(_lba) {\n case server_socket::load_balancing_algorithm::connection_distribution:\n return _conntrack.get_handle();\n case server_socket::load_balancing_algorithm::port:\n return _conntrack.get_handle(ntoh(sa.as_posix_sockaddr_in().sin_port) % smp::count);\n case server_socket::load_balancing_algorithm::fixed:\n return _conntrack.get_handle(_fixed_cpu);\n default: abort();\n }\n } ();\n auto cpu = cth.cpu();\n if (cpu == this_shard_id()) {\n std::unique_ptr csi(\n new 
posix_connected_socket_impl(sa.family(), _protocol, std::move(fd), std::move(cth), _allocator));\n return make_ready_future(\n accept_result{connected_socket(std::move(csi)), sa});\n } else {\n // FIXME: future is discarded\n (void)smp::submit_to(cpu, [protocol = _protocol, ssa = _sa, fd = std::move(fd.get_file_desc()), sa, cth = std::move(cth), allocator = _allocator] () mutable {\n posix_ap_server_socket_impl::move_connected_socket(protocol, ssa, pollable_fd(std::move(fd)), sa, std::move(cth), allocator);\n });\n return accept();\n }\n });\n}\n\nvoid\nposix_server_socket_impl::abort_accept() {\n _lfd.shutdown(SHUT_RD, pollable_fd::shutdown_kernel_only::no);\n}\n\nsocket_address posix_server_socket_impl::local_address() const {\n return _lfd.get_file_desc().get_address();\n}\n\nposix_ap_server_socket_impl::posix_ap_server_socket_impl(int protocol, socket_address sa, std::pmr::polymorphic_allocator* allocator)\n : _protocol(protocol), _sa(sa), _allocator(allocator)\n{\n auto it = ports.emplace(std::make_tuple(_protocol, _sa));\n if (!it.second) {\n throw std::system_error(EADDRINUSE, std::system_category());\n }\n}\n\nposix_ap_server_socket_impl::~posix_ap_server_socket_impl() {\n ports.erase(std::make_tuple(_protocol, _sa));\n}\n\nfuture posix_ap_server_socket_impl::accept() {\n auto t_sa = std::make_tuple(_protocol, _sa);\n auto conni = conn_q.find(t_sa);\n if (conni != conn_q.end()) {\n connection c = std::move(conni->second);\n conn_q.erase(conni);\n try {\n std::unique_ptr csi(\n new posix_connected_socket_impl(_sa.family(), _protocol, std::move(c.fd), std::move(c.connection_tracking_handle), _allocator));\n return make_ready_future(accept_result{connected_socket(std::move(csi)), std::move(c.addr)});\n } catch (...) {\n return make_exception_future(std::current_exception());\n }\n } else {\n try {\n auto i = sockets.emplace(std::piecewise_construct, std::make_tuple(t_sa), std::make_tuple());\n assert(i.second);\n return i.first->second.get_future();\n } catch (...) 
{\n return make_exception_future(std::current_exception());\n }\n }\n}\n\nvoid\nposix_ap_server_socket_impl::abort_accept() {\n auto t_sa = std::make_tuple(_protocol, _sa);\n conn_q.erase(t_sa);\n auto i = sockets.find(t_sa);\n if (i != sockets.end()) {\n i->second.set_exception(std::system_error(ECONNABORTED, std::system_category()));\n sockets.erase(i);\n }\n}\n\nfuture\nposix_reuseport_server_socket_impl::accept() {\n return _lfd.accept().then([allocator = _allocator, protocol = _protocol] (std::tuple fd_sa) {\n auto& fd = std::get<0>(fd_sa);\n auto& sa = std::get<1>(fd_sa);\n std::unique_ptr csi(\n new posix_connected_socket_impl(sa.family(), protocol, std::move(fd), allocator));\n return make_ready_future(\n accept_result{connected_socket(std::move(csi)), sa});\n });\n}\n\nvoid\nposix_reuseport_server_socket_impl::abort_accept() {\n _lfd.shutdown(SHUT_RD, pollable_fd::shutdown_kernel_only::no);\n}\n\nsocket_address posix_reuseport_server_socket_impl::local_address() const {\n return _lfd.get_file_desc().get_address();\n}\n\nvoid\nposix_ap_server_socket_impl::move_connected_socket(int protocol, socket_address sa, pollable_fd fd, socket_address addr, conntrack::handle cth, std::pmr::polymorphic_allocator* allocator) {\n auto t_sa = std::make_tuple(protocol, sa);\n auto i = sockets.find(t_sa);\n if (i != sockets.end()) {\n try {\n std::unique_ptr csi(new posix_connected_socket_impl(sa.family(), protocol, std::move(fd), std::move(cth), allocator));\n i->second.set_value(accept_result{connected_socket(std::move(csi)), std::move(addr)});\n } catch (...) {\n i->second.set_exception(std::current_exception());\n }\n sockets.erase(i);\n } else {\n conn_q.emplace(std::piecewise_construct, std::make_tuple(t_sa), std::make_tuple(std::move(fd), std::move(addr), std::move(cth)));\n }\n}\n\nfuture>\nposix_data_source_impl::get() {\n return _fd.recv_some(static_cast(this)).then([this] (temporary_buffer b) {\n if (b.size() >= _config.buffer_size) {\n _config.buffer_size *= 2;\n _config.buffer_size = std::min(_config.buffer_size, _config.max_buffer_size);\n } else if (b.size() <= _config.buffer_size / 4) {\n _config.buffer_size /= 2;\n _config.buffer_size = std::max(_config.buffer_size, _config.min_buffer_size);\n }\n return b;\n });\n}\n\ntemporary_buffer\nposix_data_source_impl::allocate_buffer() {\n return make_temporary_buffer(_buffer_allocator, _config.buffer_size);\n}\n\nfuture<> posix_data_source_impl::close() {\n _fd.shutdown(SHUT_RD);\n return make_ready_future<>();\n}\n\nstd::vector to_iovec(const packet& p) {\n std::vector v;\n v.reserve(p.nr_frags());\n for (auto&& f : p.fragments()) {\n v.push_back({.iov_base = f.base, .iov_len = f.size});\n }\n return v;\n}\n\nstd::vector to_iovec(std::vector>& buf_vec) {\n std::vector v;\n v.reserve(buf_vec.size());\n for (auto& buf : buf_vec) {\n v.push_back({.iov_base = buf.get_write(), .iov_len = buf.size()});\n }\n return v;\n}\n\nfuture<>\nposix_data_sink_impl::put(temporary_buffer buf) {\n return _fd.write_all(buf.get(), buf.size()).then([d = buf.release()] {});\n}\n\nfuture<>\nposix_data_sink_impl::put(packet p) {\n _p = std::move(p);\n return _fd.write_all(_p).then([this] { _p.reset(); });\n}\n\nfuture<>\nposix_data_sink_impl::close() {\n _fd.shutdown(SHUT_WR);\n return make_ready_future<>();\n}\n\nvoid posix_data_sink_impl::on_batch_flush_error() noexcept {\n shutdown_socket_fd(_fd, SHUT_RD);\n}\n\nposix_network_stack::posix_network_stack(const program_options::option_group& opts, std::pmr::polymorphic_allocator* allocator)\n : 
_reuseport(engine().posix_reuseport_available()), _allocator(allocator) {\n}\n\nserver_socket\nposix_network_stack::listen(socket_address sa, listen_options opt) {\n using server_socket = seastar::server_socket;\n // allow unspecified bind address -> default to ipv4 wildcard\n if (sa.is_unspecified()) {\n sa = inet_address(inet_address::family::INET);\n }\n if (sa.is_af_unix()) {\n return server_socket(std::make_unique(0, sa, engine().posix_listen(sa, opt), opt.lba, opt.fixed_cpu, _allocator));\n }\n auto protocol = static_cast(opt.proto);\n return _reuseport ?\n server_socket(std::make_unique(protocol, sa, engine().posix_listen(sa, opt), _allocator))\n :\n server_socket(std::make_unique(protocol, sa, engine().posix_listen(sa, opt), opt.lba, opt.fixed_cpu, _allocator));\n}\n\n::seastar::socket posix_network_stack::socket() {\n return ::seastar::socket(std::make_unique(_allocator));\n}\n\nposix_ap_network_stack::posix_ap_network_stack(const program_options::option_group& opts, std::pmr::polymorphic_allocator* allocator)\n : posix_network_stack(opts, allocator), _reuseport(engine().posix_reuseport_available()) {\n}\n\nserver_socket\nposix_ap_network_stack::listen(socket_address sa, listen_options opt) {\n using server_socket = seastar::server_socket;\n // allow unspecified bind address -> default to ipv4 wildcard\n if (sa.is_unspecified()) {\n sa = inet_address(inet_address::family::INET);\n }\n if (sa.is_af_unix()) {\n return server_socket(std::make_unique(0, sa, _allocator));\n }\n auto protocol = static_cast(opt.proto);\n return _reuseport ?\n server_socket(std::make_unique(protocol, sa, engine().posix_listen(sa, opt), _allocator))\n :\n server_socket(std::make_unique(protocol, sa, _allocator));\n}\n\nstruct cmsg_with_pktinfo {\n struct cmsghdrcmh;\n union {\n struct in_pktinfo pktinfo;\n struct in6_pktinfo pkt6info;\n };\n};\n\nclass posix_datagram_channel : public datagram_channel_impl {\nprivate:\n static constexpr int MAX_DATAGRAM_SIZE = 65507;\n struct recv_ctx {\n struct msghdr _hdr;\n struct iovec _iov;\n socket_address _src_addr;\n char* _buffer;\n cmsg_with_pktinfo _cmsg;\n\n recv_ctx(bool use_pktinfo) {\n memset(&_hdr, 0, sizeof(_hdr));\n _hdr.msg_iov = &_iov;\n _hdr.msg_iovlen = 1;\n _hdr.msg_name = &_src_addr.u.sa;\n _hdr.msg_namelen = sizeof(_src_addr.u.sas);\n\n if (use_pktinfo) {\n memset(&_cmsg, 0, sizeof(_cmsg));\n _hdr.msg_control = &_cmsg;\n _hdr.msg_controllen = sizeof(_cmsg);\n } else {\n _hdr.msg_control = nullptr;\n _hdr.msg_controllen = 0;\n }\n }\n\n recv_ctx(const recv_ctx&) = delete;\n recv_ctx(recv_ctx&&) = delete;\n\n void prepare() {\n _buffer = new char[MAX_DATAGRAM_SIZE];\n _iov.iov_base = _buffer;\n _iov.iov_len = MAX_DATAGRAM_SIZE;\n }\n };\n struct send_ctx {\n struct msghdr _hdr;\n std::vector _iovecs;\n socket_address _dst;\n packet _p;\n\n send_ctx() {\n memset(&_hdr, 0, sizeof(_hdr));\n _hdr.msg_name = &_dst.u.sa;\n _hdr.msg_namelen = _dst.addr_length;\n }\n\n send_ctx(const send_ctx&) = delete;\n send_ctx(send_ctx&&) = delete;\n\n void prepare(const socket_address& dst, packet p) {\n _dst = dst;\n _hdr.msg_namelen = _dst.addr_length;\n _p = std::move(p);\n _iovecs = to_iovec(_p);\n _hdr.msg_iov = _iovecs.data();\n _hdr.msg_iovlen = _iovecs.size();\n resolve_outgoing_address(_dst);\n }\n };\n\n static bool is_inet(sa_family_t family) {\n return family == AF_INET || family == AF_INET6;\n }\n\n static file_desc create_socket(sa_family_t family) {\n file_desc fd = file_desc::socket(family, SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);\n\n if 
(is_inet(family)) {\n fd.setsockopt(SOL_IP, IP_PKTINFO, true);\n if (engine().posix_reuseport_available()) {\n fd.setsockopt(SOL_SOCKET, SO_REUSEPORT, 1);\n }\n }\n\n return fd;\n }\n\n pollable_fd _fd;\n socket_address _address;\n recv_ctx _recv;\n send_ctx _send;\n bool _closed;\npublic:\n /// Creates a channel that is not bound to any socket address. The channel\n /// can be used to communicate with adressess that belong to the \\param\n /// family.\n posix_datagram_channel(sa_family_t family)\n : _recv(is_inet(family)), _closed(false) {\n auto fd = create_socket(family);\n\n _address = fd.get_address();\n _fd = std::move(fd);\n }\n\n /// Creates a channel that is bound to the specified local address. It can be used to\n /// communicate with addresses that belong to the family of \\param local.\n posix_datagram_channel(socket_address local)\n : _recv(is_inet(local.family())), _closed(false) {\n auto fd = create_socket(local.family());\n fd.bind(local.u.sa, local.addr_length);\n\n _address = fd.get_address();\n _fd = std::move(fd);\n }\n\n virtual ~posix_datagram_channel() { if (!_closed) close(); };\n virtual future receive() override;\n virtual future<> send(const socket_address& dst, const char *msg) override;\n virtual future<> send(const socket_address& dst, packet p) override;\n virtual void shutdown_input() override {\n _fd.shutdown(SHUT_RD, pollable_fd::shutdown_kernel_only::no);\n }\n virtual void shutdown_output() override {\n _fd.shutdown(SHUT_WR, pollable_fd::shutdown_kernel_only::no);\n }\n virtual void close() override {\n _closed = true;\n _fd = {};\n }\n virtual bool is_closed() const override { return _closed; }\n socket_address local_address() const override {\n assert(_address.u.sas.ss_family != AF_INET6 || (_address.addr_length > 20));\n return _address;\n }\n};\n\nfuture<> posix_datagram_channel::send(const socket_address& dst, const char *message) {\n auto len = strlen(message);\n auto a = dst;\n resolve_outgoing_address(a);\n return _fd.sendto(a, message, len)\n .then([len] (size_t size) { assert(size == len); });\n}\n\nfuture<> posix_datagram_channel::send(const socket_address& dst, packet p) {\n auto len = p.len();\n _send.prepare(dst, std::move(p));\n return _fd.sendmsg(&_send._hdr)\n .then([len] (size_t size) { assert(size == len); });\n}\n\nudp_channel\nposix_network_stack::make_udp_channel(const socket_address& addr) {\n if (!addr.is_unspecified()) {\n return make_bound_datagram_channel(addr);\n } else {\n // Preserve the default behavior of make_udp_channel({}) which is to\n // create an unbound channel that can be used to send UDP datagrams.\n return make_unbound_datagram_channel(AF_INET);\n }\n}\n\ndatagram_channel\nposix_network_stack::make_unbound_datagram_channel(sa_family_t family) {\n return datagram_channel(std::make_unique(family));\n}\n\ndatagram_channel\nposix_network_stack::make_bound_datagram_channel(const socket_address& local) {\n return datagram_channel(std::make_unique(local));\n}\n\nbool\nposix_network_stack::supports_ipv6() const {\n static bool has_ipv6 = [] {\n try {\n posix_datagram_channel c(ipv6_addr{\"::1\"});\n c.close();\n return true;\n } catch (...) 
{}\n return false;\n }();\n\n return has_ipv6;\n}\n\nclass posix_datagram : public datagram_impl {\nprivate:\n socket_address _src;\n socket_address _dst;\n packet _p;\npublic:\n \nposix_datagram(const socket_address& src, const socket_address& dst, packet p) : _src(src), _dst(dst), _p(std::move(p)) {}\n virtual socket_address get_src() override { return _src; }\n virtual socket_address get_dst() override { return _dst; }\n virtual uint16_t get_dst_port() override {\n if (_dst.family() != AF_INET && _dst.family() != AF_INET6) {\n throw std::runtime_error(format(\"get_dst_port() called on non-IP address: {}\", _dst));\n }\n return _dst.port();\n }\n virtual packet& get_data() override { return _p; }\n};\n\nfuture\nposix_datagram_channel::receive() {\n _recv.prepare();\n return _fd.recvmsg(&_recv._hdr).then([this] (size_t size) {\n std::optional dst;\n for (auto* cmsg = CMSG_FIRSTHDR(&_recv._hdr); cmsg != nullptr; cmsg = CMSG_NXTHDR(&_recv._hdr, cmsg)) {\n if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {\n dst = ipv4_addr(copy_reinterpret_cast(CMSG_DATA(cmsg)).ipi_addr, _address.port());\n break;\n } else if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) {\n dst = ipv6_addr(copy_reinterpret_cast(CMSG_DATA(cmsg)).ipi6_addr, _address.port());\n break;\n }\n }\n return make_ready_future(datagram(std::make_unique(\n _recv._src_addr, dst ? *dst : _address, packet(fragment{_recv._buffer, size}, make_deleter([buf = _recv._buffer] { delete[] buf; })))));\n }).handle_exception([p = _recv._buffer](auto ep) {\n delete[] p;\n return make_exception_future(std::move(ep));\n });\n}\n\nnetwork_stack_entry register_posix_stack() {\n return network_stack_entry{\n \"posix\", std::make_unique(nullptr, \"Posix\"),\n [](const program_options::option_group& ops) {\n return smp::main_thread() ? posix_network_stack::create(ops)\n : posix_ap_network_stack::create(ops);\n },\n true};\n}\n\n// nw interface stuff\n\nstd::vector posix_network_stack::network_interfaces() {\n class posix_network_interface_impl final : public network_interface_impl {\n public:\n uint32_t _index = 0, _mtu = 0;\n sstring _name, _display_name;\n std::vector _addresses;\n std::vector _hardware_address;\n bool _loopback = false, _virtual = false, _up = false;\n\n uint32_t index() const override {\n return _index;\n }\n uint32_t mtu() const override {\n return _mtu;\n }\n const sstring& name() const override {\n return _name; \n }\n const sstring& display_name() const override {\n return _display_name.empty() ? 
name() : _display_name;\n }\n const std::vector& addresses() const override {\n return _addresses; \n }\n const std::vector hardware_address() const override {\n return _hardware_address;\n }\n bool is_loopback() const override {\n return _loopback; \n }\n bool is_virtual() const override {\n return _virtual;\n }\n bool is_up() const override {\n // TODO: should be checked on query?\n return _up;\n }\n bool supports_ipv6() const override {\n // TODO: this is not 100% correct.\n return std::any_of(_addresses.begin(), _addresses.end(), std::mem_fn(&inet_address::is_ipv6));\n }\n };\n\n // For now, keep an immutable set of interfaces created on start, shared across \n // shards\n static const std::vector global_interfaces = [] {\n auto fd = ::socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);\n throw_system_error_on(fd < 0, \"could not open netlink socket\");\n\n std::unique_ptr fd_guard(&fd, [](int* p) { ::close(*p); });\n\n auto pid = ::getpid();\n\n sockaddr_nl local = {\n .nl_family = AF_NETLINK,\n .nl_pid = static_cast(pid),\n .nl_groups = RTMGRP_IPV6_IFADDR|RTMGRP_IPV4_IFADDR,\n };\n\n throw_system_error_on(bind(fd, (struct sockaddr *) &local, sizeof(local)) < 0, \"could not bind netlink socket\");\n\n /* RTNL socket is ready for use, prepare and send requests */\n\n std::vector res;\n\n for (auto msg : { RTM_GETLINK, RTM_GETADDR}) {\n struct nl_req {\n nlmsghdr hdr;\n union {\n rtgenmsg gen;\n ifaddrmsg addr; \n }; \n } req = {};\n\n req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct rtgenmsg));\n req.hdr.nlmsg_type = msg;\n req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ROOT; \n req.hdr.nlmsg_seq = 1;\n req.hdr.nlmsg_pid = pid;\n\n if (msg == RTM_GETLINK) {\n req.gen.rtgen_family = AF_PACKET; /* no preferred AF, we will get *all* interfaces */\n } else {\n req.addr.ifa_family = AF_UNSPEC;\n }\n\n sockaddr_nl kernel = {\n .nl_family = AF_NETLINK, /* fill-in kernel address (destination of our message) */\n };\n iovec io = {\n .iov_base = &req,\n .iov_len = req.hdr.nlmsg_len,\n };\n msghdr rtnl_msg = {\n .msg_name = &kernel,\n .msg_namelen = sizeof(kernel),\n .msg_iov = &io,\n .msg_iovlen = 1,\n };\n\n throw_system_error_on(::sendmsg(fd, (struct msghdr *) &rtnl_msg, 0) < 0, \"could not send netlink request\");\n /* parse reply */\n\n constexpr size_t reply_buffer_size = 8192;\n char reply[reply_buffer_size]; \n\n bool done = false;\n\n while (!done) {\n iovec io_reply = {\n .iov_base = reply,\n .iov_len = reply_buffer_size,\n };\n msghdr rtnl_reply = {\n .msg_name = &kernel,\n .msg_namelen = sizeof(kernel),\n .msg_iov = &io_reply,\n .msg_iovlen = 1,\n };\n\n auto len = ::recvmsg(fd, &rtnl_reply, 0); /* read as much data as fits in the receive buffer */\n if (len <= 0) {\n return res;\n }\n\n for (auto* msg_ptr = (struct nlmsghdr *) reply; NLMSG_OK(msg_ptr, len); msg_ptr = NLMSG_NEXT(msg_ptr, len)) {\n switch(msg_ptr->nlmsg_type) {\n case NLMSG_DONE: // that is all\n done = true;\n break; \n case RTM_NEWLINK: \n {\n auto* iface = reinterpret_cast(NLMSG_DATA(msg_ptr));\n auto ilen = msg_ptr->nlmsg_len - NLMSG_LENGTH(sizeof(ifinfomsg));\n\n // todo: filter any non-network interfaces (family)\n\n posix_network_interface_impl nwif;\n \n nwif._index = iface->ifi_index;\n nwif._loopback = (iface->ifi_flags & IFF_LOOPBACK) != 0;\n nwif._up = (iface->ifi_flags & IFF_UP) != 0;\n #if defined(IFF_802_1Q_VLAN) && defined(IFF_EBRIDGE) && defined(IFF_SLAVE_INACTIVE)\n nwif._virtual = (iface->ifi_flags & (IFF_802_1Q_VLAN|IFF_EBRIDGE|IFF_SLAVE_INACTIVE)) != 0;\n #endif \n for (auto* attribute = IFLA_RTA(iface); 
RTA_OK(attribute, ilen); attribute = RTA_NEXT(attribute, ilen)) {\n switch(attribute->rta_type) {\n case IFLA_IFNAME:\n nwif._name = reinterpret_cast(RTA_DATA(attribute));\n break;\n case IFLA_MTU:\n nwif._mtu = *reinterpret_cast(RTA_DATA(attribute)); \n break;\n case IFLA_ADDRESS:\n nwif._hardware_address.assign(reinterpret_cast(RTA_DATA(attribute)), reinterpret_cast(RTA_DATA(attribute)) + RTA_PAYLOAD(attribute));\n break;\n default:\n break;\n }\n }\n\n res.emplace_back(std::move(nwif));\n\n break;\n }\n case RTM_NEWADDR:\n {\n auto* addr = reinterpret_cast(NLMSG_DATA(msg_ptr));\n auto ilen = msg_ptr->nlmsg_len - NLMSG_LENGTH(sizeof(ifaddrmsg));\n \n for (auto& nwif : res) {\n if (nwif._index == addr->ifa_index) {\n for (auto* attribute = IFA_RTA(addr); RTA_OK(attribute, ilen); attribute = RTA_NEXT(attribute, ilen)) {\n std::optional ia;\n \n switch(attribute->rta_type) {\n case IFA_LOCAL:\n case IFA_ADDRESS: // ipv6 addresses are reported only as \"ADDRESS\"\n\n if (RTA_PAYLOAD(attribute) == sizeof(::in_addr)) {\n ia.emplace(*reinterpret_cast(RTA_DATA(attribute)));\n } else if (RTA_PAYLOAD(attribute) == sizeof(::in6_addr)) {\n ia.emplace(*reinterpret_cast(RTA_DATA(attribute)), nwif.index());\n }\n \n if (ia && std::find(nwif._addresses.begin(), nwif._addresses.end(), *ia) == nwif._addresses.end()) {\n nwif._addresses.emplace_back(*ia);\n }\n\n break;\n default:\n break;\n }\n }\n\n break;\n }\n }\n\n break;\n }\n default:\n break;\n }\n } \n }\n }\n\n return res;\n }();\n\n // And a similarly immutable set of shared_ptr to network_interface_impl per shard, ready \n // to be handed out to callers with minimal overhead\n static const thread_local std::vector> thread_local_interfaces = [] {\n std::vector> res;\n res.reserve(global_interfaces.size());\n std::transform(global_interfaces.begin(), global_interfaces.end(), std::back_inserter(res), [](const posix_network_interface_impl& impl) {\n return make_shared(impl);\n });\n return res;\n }();\n\n return std::vector(thread_local_interfaces.begin(), thread_local_interfaces.end());\n}\n\n}\n\n}\n\n// Path: src/net/arp.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#endif\n\nnamespace seastar {\n\nnamespace net {\n\narp_for_protocol::arp_for_protocol(arp& a, uint16_t proto_num)\n : _arp(a), _proto_num(proto_num) {\n _arp.add(proto_num, this);\n}\n\narp_for_protocol::~arp_for_protocol() {\n _arp.del(_proto_num);\n}\n\narp::arp(interface* netif) : _netif(netif), _proto(netif, eth_protocol_num::arp, [this] { return get_packet(); })\n{\n // FIXME: ignored future\n (void)_proto.receive(\n [this](packet p, ethernet_address ea) {\n return process_packet(std::move(p), ea);\n },\n [this](forward_hash& out_hash_data, packet& p, size_t off) {\n return forward(out_hash_data, p, off);\n });\n}\n\nstd::optional arp::get_packet() {\n std::optional p;\n if (!_packetq.empty()) {\n p = std::move(_packetq.front());\n _packetq.pop_front();\n }\n return p;\n}\n\nbool arp::forward(forward_hash& out_hash_data, packet& p, size_t off) {\n auto ah = p.get_header(off);\n auto i = _arp_for_protocol.find(ntoh(ah->ptype));\n if (i != _arp_for_protocol.end()) {\n return i->second->forward(out_hash_data, p, off);\n }\n return false;\n}\n\nvoid arp::add(uint16_t proto_num, arp_for_protocol* afp) {\n _arp_for_protocol[proto_num] = afp;\n}\n\nvoid arp::del(uint16_t proto_num) {\n _arp_for_protocol.erase(proto_num);\n}\n\nfuture<>\narp::process_packet(packet p, ethernet_address from) {\n auto h = p.get_header(0, arp_hdr::size());\n if (!h) {\n return make_ready_future<>();\n }\n auto ah = arp_hdr::read(h);\n auto i = _arp_for_protocol.find(ah.ptype);\n if (i != _arp_for_protocol.end()) {\n return i->second->received(std::move(p));\n }\n return make_ready_future<>();\n}\n\n}\n\n}\n\n// Path: src/core/condition-variable.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2020 ScyllaDB, Ltd.\n */\n#ifdef SEASTAR_MODULE\nmodule;\n#include \n#include \n#include \nmodule seastar;\n#else\n#include \n#endif\n\nnamespace seastar {\n\nconst char* broken_condition_variable::what() const noexcept {\n return \"Condition variable is broken\";\n}\n\nconst char* condition_variable_timed_out::what() const noexcept {\n return \"Condition variable timed out\";\n}\n\ncondition_variable::~condition_variable() {\n broken();\n}\n\nvoid condition_variable::add_waiter(waiter& w) noexcept {\n assert(!_signalled); // should not have snuck between\n if (_ex) {\n w.set_exception(_ex);\n return;\n }\n _waiters.push_back(w);\n}\n\nvoid condition_variable::waiter::timeout() noexcept {\n this->unlink();\n this->set_exception(std::make_exception_ptr(condition_variable_timed_out()));\n}\n\nbool condition_variable::wakeup_first() noexcept {\n if (_waiters.empty()) {\n return false;\n }\n auto& w = _waiters.front();\n _waiters.pop_front();\n if (_ex) {\n w.set_exception(_ex);\n } else {\n w.signal();\n }\n return true;\n}\n\nbool condition_variable::check_and_consume_signal() noexcept {\n return std::exchange(_signalled, false);\n}\n\nvoid condition_variable::signal() noexcept {\n if (!wakeup_first()) {\n _signalled = true;\n }\n}\n\n/// Notify variable and wake up all waiter\nvoid condition_variable::broadcast() noexcept {\n auto tmp(std::move(_waiters));\n while (!tmp.empty()) {\n auto& w = tmp.front();\n tmp.pop_front();\n if (_ex) {\n w.set_exception(_ex);\n } else {\n w.signal();\n }\n }\n}\n\n/// Signal to waiters that an error occurred. \\ref wait() will see\n/// an exceptional future<> containing the provided exception parameter.\n/// The future is made available immediately.\nvoid condition_variable::broken() noexcept {\n broken(std::make_exception_ptr(broken_condition_variable()));\n}\n\nvoid condition_variable::broken(std::exception_ptr ep) noexcept {\n _ex = ep;\n broadcast();\n}\n\n} // namespace seastar\n\n// Path: src/core/smp.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2019 ScyllaDB\n */\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \"prefault.hh\"\n#endif\n\nnamespace seastar {\n\nextern logger seastar_logger;\n\n#ifdef SEASTAR_BUILD_SHARED_LIBS\nshard_id* internal::this_shard_id_ptr() noexcept {\n static thread_local shard_id g_this_shard_id;\n return &g_this_shard_id;\n}\n#endif\n\nvoid smp_message_queue::work_item::process() {\n schedule(this);\n}\n\nstruct smp_service_group_impl {\n std::vector clients; // one client per server shard\n#ifdef SEASTAR_DEBUG\n unsigned version = 0;\n#endif\n};\n\nstatic thread_local smp_service_group_semaphore smp_service_group_management_sem{1, named_semaphore_exception_factory{\"smp_service_group_management_sem\"}};\nstatic thread_local std::vector smp_service_groups;\n\nstatic named_semaphore_exception_factory make_service_group_semaphore_exception_factory(unsigned id, shard_id client_cpu, shard_id this_cpu, std::optional smp_group_name) {\n if (smp_group_name) {\n return named_semaphore_exception_factory{format(\"smp_service_group:'{}' (#{}) {}->{} semaphore\", *smp_group_name, id, client_cpu, this_cpu)};\n } else {\n return named_semaphore_exception_factory{format(\"smp_service_group#{} {}->{} semaphore\", id, client_cpu, this_cpu)};\n }\n\n}\n\nstatic_assert(std::is_nothrow_copy_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_copy_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nfuture create_smp_service_group(smp_service_group_config ssgc) noexcept {\n ssgc.max_nonlocal_requests = std::max(ssgc.max_nonlocal_requests, smp::count - 1);\n return smp::submit_to(0, [ssgc] {\n return with_semaphore(smp_service_group_management_sem, 1, [ssgc] {\n auto it = boost::range::find_if(smp_service_groups, [&] (smp_service_group_impl& ssgi) { return ssgi.clients.empty(); });\n size_t id = it - smp_service_groups.begin();\n return parallel_for_each(smp::all_cpus(), [ssgc, id] (unsigned cpu) {\n return smp::submit_to(cpu, [ssgc, id, cpu] {\n if (id >= smp_service_groups.size()) {\n smp_service_groups.resize(id + 1); // may throw\n }\n smp_service_groups[id].clients.reserve(smp::count); // may throw\n auto per_client = smp::count > 1 ? 
ssgc.max_nonlocal_requests / (smp::count - 1) : 0u;\n for (unsigned i = 0; i != smp::count; ++i) {\n smp_service_groups[id].clients.emplace_back(per_client, make_service_group_semaphore_exception_factory(id, i, cpu, ssgc.group_name));\n }\n });\n }).handle_exception([id] (std::exception_ptr e) {\n // rollback\n return smp::invoke_on_all([id] {\n if (smp_service_groups.size() > id) {\n smp_service_groups[id].clients.clear();\n }\n }).then([e = std::move(e)] () mutable {\n std::rethrow_exception(std::move(e));\n });\n }).then([id] {\n auto ret = smp_service_group(id);\n#ifdef SEASTAR_DEBUG\n ret._version = smp_service_groups[id].version;\n#endif\n return ret;\n });\n });\n });\n}\n\nfuture<> destroy_smp_service_group(smp_service_group ssg) noexcept {\n return smp::submit_to(0, [ssg] {\n return with_semaphore(smp_service_group_management_sem, 1, [ssg] {\n auto id = internal::smp_service_group_id(ssg);\n if (id >= smp_service_groups.size()) {\n on_fatal_internal_error(seastar_logger, format(\"destroy_smp_service_group id={}: out of range\", id));\n }\n#ifdef SEASTAR_DEBUG\n if (ssg._version != smp_service_groups[id].version) {\n on_fatal_internal_error(seastar_logger, format(\"destroy_smp_service_group id={}: stale version={}: current_version={}\", id, ssg._version, smp_service_groups[id].version));\n }\n#endif\n return smp::invoke_on_all([id] {\n smp_service_groups[id].clients.clear();\n#ifdef SEASTAR_DEBUG\n ++smp_service_groups[id].version;\n#endif\n });\n });\n });\n}\n\nvoid init_default_smp_service_group(shard_id cpu) {\n...\n", "needle_token_start": 8908, "needle_token_end": 8952, "code_context_ntokens": 16161, "output": ["```cpp\nclass posix_datagram : public datagram_impl {\nprivate:\n socket_address _src;\n socket_address _dst;\n packet _p;\npublic:\n posix_datagram(const socket_address& src, const socket_address& dst, packet p) : _src(src), _dst(dst), _p(std::move(p)) {}\n virtual socket_address get_src() override { return _src; }\n virtual socket_address get_dst() override { return _dst; }\n virtual uint16_t get_dst_port() override {\n if (_dst.family() != AF_INET && _dst.family() != AF_INET6) {\n throw std::runtime_error(format(\"get_dst_port() called on non-IP address: {}\", _dst));\n }\n return _dst.port();\n }\n virtual packet& get_data() override { return _p; }\n};\n```"]} +{"repo": "scylladb/seastar", "name": "arp_learn", "language": "cpp", "path": "src/net/native-stack.cc", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to update the Address Resolution Protocol (ARP) table across all processing units in a distributed system, associating a network layer address with a data link layer address.\n2. **Input**: It accepts two parameters: a data link layer address (such as a MAC address) and a network layer address (such as an IP address).\n3. **Output**: There is no direct output returned to the caller as the function operates asynchronously in the background.\n4. **Procedure**: The function broadcasts the task of updating the ARP table with the provided addresses to all processing units in the system. 
Each unit executes this update independently.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/net/virtio.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#ifdef HAVE_OSV\n#include \n#endif\n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \"core/vla.hh\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nusing namespace net;\n\nnamespace virtio {\n\nusing phys = uint64_t;\n\n#ifndef HAVE_OSV\n\nphys virt_to_phys(void* p) {\n return reinterpret_cast(p);\n}\n\n#else\n\nphys virt_to_phys(void* p) {\n return osv::assigned_virtio::virt_to_phys(p);\n}\n\n#endif\n\nclass device : public net::device {\nprivate:\n net::hw_features _hw_features;\n uint64_t _features;\n\nprivate:\n uint64_t setup_features(const net::virtio_options& opts, const program_options::value& lro) {\n int64_t seastar_supported_features = VIRTIO_RING_F_INDIRECT_DESC | VIRTIO_NET_F_MRG_RXBUF;\n\n if (!(opts.event_index && opts.event_index.get_value() == \"off\")) {\n seastar_supported_features |= VIRTIO_RING_F_EVENT_IDX;\n }\n if (!(opts.csum_offload && opts.csum_offload.get_value() == \"off\")) {\n seastar_supported_features |= VIRTIO_NET_F_CSUM | VIRTIO_NET_F_GUEST_CSUM;\n _hw_features.tx_csum_l4_offload = true;\n _hw_features.rx_csum_offload = true;\n } else {\n _hw_features.tx_csum_l4_offload = false;\n _hw_features.rx_csum_offload = false;\n }\n if (!(opts.tso && opts.tso.get_value() == \"off\")) {\n seastar_supported_features |= VIRTIO_NET_F_HOST_TSO4;\n _hw_features.tx_tso = true;\n } else {\n _hw_features.tx_tso = false;\n }\n\n if (!(lro && lro.get_value() == \"off\")) {\n seastar_supported_features |= VIRTIO_NET_F_GUEST_TSO4;\n _hw_features.rx_lro = true;\n } else {\n _hw_features.rx_lro = false;\n }\n\n if (!(opts.ufo && opts.ufo.get_value() == \"off\")) {\n seastar_supported_features |= VIRTIO_NET_F_HOST_UFO;\n seastar_supported_features |= VIRTIO_NET_F_GUEST_UFO;\n _hw_features.tx_ufo = true;\n } else {\n _hw_features.tx_ufo = false;\n }\n\n seastar_supported_features |= VIRTIO_NET_F_MAC;\n return seastar_supported_features;\n }\n\npublic:\n device(const virtio_options& opts, const program_options::value& lro)\n : _features(setup_features(opts, lro))\n {}\n ethernet_address hw_address() override {\n return { 
0x12, 0x23, 0x34, 0x56, 0x67, 0x78 };\n }\n\n net::hw_features hw_features() override {\n return _hw_features;\n }\n\n uint64_t features() {\n return _features;\n }\n\n virtual std::unique_ptr init_local_queue(const program_options::option_group& opts, uint16_t qid) override;\n};\n\n/* The virtio_notifier class determines how to do host-to-guest and guest-to-\n * host notifications. We have two different implementations - one for vhost\n * (where both notifications occur through eventfds) and one for an assigned\n * virtio device from OSv.\n */\nclass notifier {\npublic:\n // Notify the host\n virtual void notify() = 0;\n // Do whatever it takes to wake wait(). A notifier does not need to\n // implement this function if wait() waits for an external even which is\n // generated by an external process (e.g., virtio_notifier_host doesn't\n // need to implement this).\n virtual void wake_wait() {\n abort();\n }\n virtual ~notifier() {\n }\n};\n\nclass notifier_vhost : public notifier {\nprivate:\n writeable_eventfd _kick;\npublic:\n virtual void notify() override {\n _kick.signal(1);\n }\n notifier_vhost(writeable_eventfd &&kick)\n : _kick(std::move(kick)) {}\n};\n\n#ifdef HAVE_OSV\nclass notifier_osv : public notifier {\nprivate:\n uint16_t _q_index;\n...\n// Path: src/net/inet_address.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2016 ScyllaDB.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#endif\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_copy_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_copy_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nstatic_assert(std::is_nothrow_default_constructible_v);\nstatic_assert(std::is_nothrow_copy_constructible_v);\nstatic_assert(std::is_nothrow_move_constructible_v);\n\nseastar::net::inet_address::inet_address() noexcept\n : inet_address(::in6_addr{})\n{}\n\nseastar::net::inet_address::inet_address(family f) noexcept\n : _in_family(f)\n{\n memset(&_in6, 0, sizeof(_in6));\n}\n\nseastar::net::inet_address::inet_address(::in_addr i) noexcept\n : _in_family(family::INET), _in(i) {\n}\n\nseastar::net::inet_address::inet_address(::in6_addr i, uint32_t scope) noexcept\n : _in_family(family::INET6), _in6(i), _scope(scope) {\n}\n\nstd::optional \nseastar::net::inet_address::parse_numerical(const sstring& addr) {\n inet_address in;\n if (::inet_pton(AF_INET, addr.c_str(), &in._in)) {\n in._in_family = family::INET;\n return in;\n }\n auto i = addr.find_last_of('%');\n if (i != sstring::npos) {\n auto ext = addr.substr(i + 1);\n auto src = addr.substr(0, i);\n auto res = parse_numerical(src);\n\n if (res) {\n uint32_t index = std::numeric_limits::max();\n try {\n index = std::stoul(ext);\n } catch (...) 
{\n }\n for (auto& nwif : engine().net().network_interfaces()) {\n if (nwif.index() == index || nwif.name() == ext || nwif.display_name() == ext) {\n res->_scope = nwif.index();\n break;\n }\n }\n return *res;\n }\n }\n if (::inet_pton(AF_INET6, addr.c_str(), &in._in6)) {\n in._in_family = family::INET6;\n return in;\n }\n return {};\n}\n\nseastar::net::inet_address::inet_address(const sstring& addr)\n : inet_address([&addr] {\n auto res = parse_numerical(addr); \n if (res) {\n return std::move(*res);\n }\n throw std::invalid_argument(addr);\n}())\n{}\n\nseastar::net::inet_address::inet_address(const ipv4_address& in) noexcept\n : inet_address(::in_addr{hton(in.ip)})\n{}\n\nseastar::net::inet_address::inet_address(const ipv6_address& in, uint32_t scope) noexcept\n : inet_address([&] {\n ::in6_addr tmp;\n std::copy(in.bytes().begin(), in.bytes().end(), tmp.s6_addr);\n return tmp;\n }(), scope)\n{}\n\nseastar::net::ipv4_address seastar::net::inet_address::as_ipv4_address() const {\n in_addr in = *this;\n return ipv4_address(ntoh(in.s_addr));\n}\n\nseastar::net::ipv6_address seastar::net::inet_address::as_ipv6_address() const noexcept {\n in6_addr in6 = *this;\n return ipv6_address{in6};\n}\n\nbool seastar::net::inet_address::operator==(const inet_address& o) const noexcept {\n if (o._in_family != _in_family) {\n return false;\n }\n switch (_in_family) {\n case family::INET:\n return _in.s_addr == o._in.s_addr;\n case family::INET6:\n return std::equal(std::begin(_in6.s6_addr), std::end(_in6.s6_addr), std::begin(o._in6.s6_addr));\n default:\n return false;\n }\n}\n\nseastar::net::inet_address::operator ::in_addr() const {\n if (_in_family != family::INET) {\n if (IN6_IS_ADDR_V4MAPPED(&_in6)) {\n ::in_addr in;\n in.s_addr = _in6.s6_addr32[3];\n return in;\n }\n throw std::invalid_argument(\"Not an IPv4 address\");\n }\n return _in;\n}\n\nseastar::net::inet_address::operator ::in6_addr() const noexcept {\n if (_in_family == family::INET) {\n in6_addr in6 = IN6ADDR_ANY_INIT;\n in6.s6_addr32[2] = htonl(0xffff);\n in6.s6_addr32[3] = _in.s_addr;\n return in6;\n }\n return _in6;\n}\n\nseastar::net::inet_address::operator seastar::net::ipv6_address() const noexcept {\n return as_ipv6_address();\n}\n\nsize_t seastar::net::inet_address::size() const noexcept {\n switch (_in_family) {\n case family::INET:\n return sizeof(::in_addr);\n case family::INET6:\n return sizeof(::in6_addr);\n default:\n return 0;\n }\n}\n\nconst void * seastar::net::inet_address::data() const noexcept {\n return &_in;\n}\n\nbool seastar::net::inet_address::is_loopback() const noexcept {\n switch (_in_family) {\n case family::INET:\n return (net::ntoh(_in.s_addr) & 0xff000000) == 0x7f000000;\n case family::INET6:\n return std::equal(std::begin(_in6.s6_addr), std::end(_in6.s6_addr), std::begin(::in6addr_loopback.s6_addr));\n default:\n return false;\n }\n}\n\nbool seastar::net::inet_address::is_addr_any() const noexcept {\n switch (_in_family) {\n case family::INET:\n return _in.s_addr == INADDR_ANY;\n case family::INET6:\n return std::equal(std::begin(_in6.s6_addr), std::end(_in6.s6_addr), std::begin(::in6addr_any.s6_addr));\n default:\n return false;\n }\n}\n\nseastar::net::ipv6_address::ipv6_address(const ::in6_addr& in) noexcept {\n std::copy(std::begin(in.s6_addr), std::end(in.s6_addr), ip.begin());\n}\n\nseastar::net::ipv6_address::ipv6_address(const ipv6_bytes& in) noexcept\n : ip(in)\n{}\n\nseastar::net::ipv6_address::ipv6_address(const ipv6_addr& addr) noexcept\n : 
ipv6_address(addr.ip)\n{}\n\nseastar::net::ipv6_address::ipv6_address() noexcept\n : ipv6_address(::in6addr_any)\n{}\n\nseastar::net::ipv6_address::ipv6_address(const std::string& addr) {\n if (!::inet_pton(AF_INET6, addr.c_str(), ip.data())) {\n throw std::runtime_error(fmt::format(\"Wrong format for IPv6 address {}. Please ensure it's in colon-hex format\",\n addr));\n }\n}\n\nseastar::net::ipv6_address seastar::net::ipv6_address::read(const char* s) noexcept {\n auto* b = reinterpret_cast(s);\n ipv6_address in;\n std::copy(b, b + ipv6_address::size(), in.ip.begin());\n return in;\n}\n\nseastar::net::ipv6_address seastar::net::ipv6_address::consume(const char*& p) noexcept {\n auto res = read(p);\n p += size();\n return res;\n}\n\nvoid seastar::net::ipv6_address::write(char* p) const noexcept {\n std::copy(ip.begin(), ip.end(), p);\n}\n\nvoid seastar::net::ipv6_address::produce(char*& p) const noexcept {\n write(p);\n p += size();\n}\n\nbool seastar::net::ipv6_address::is_unspecified() const noexcept {\n return std::all_of(ip.begin(), ip.end(), [](uint8_t b) { return b == 0; });\n}\n\nstd::ostream& seastar::net::operator<<(std::ostream& os, const ipv4_address& a) {\n auto ip = a.ip;\n fmt::print(os, \"{:d}.{:d}.{:d}.{:d}\",\n (ip >> 24) & 0xff,\n (ip >> 16) & 0xff,\n (ip >> 8) & 0xff,\n (ip >> 0) & 0xff);\n return os;\n}\n\nstd::ostream& seastar::net::operator<<(std::ostream& os, const ipv6_address& a) {\n char buffer[64];\n return os << ::inet_ntop(AF_INET6, a.ip.data(), buffer, sizeof(buffer));\n}\n\nseastar::ipv6_addr::ipv6_addr(const ipv6_bytes& b, uint16_t p) noexcept\n : ip(b), port(p)\n{}\n\nseastar::ipv6_addr::ipv6_addr(uint16_t p) noexcept\n : ipv6_addr(net::inet_address(), p)\n{}\n\nseastar::ipv6_addr::ipv6_addr(const ::in6_addr& in6, uint16_t p) noexcept\n : ipv6_addr(net::ipv6_address(in6).bytes(), p)\n{}\n\nseastar::ipv6_addr::ipv6_addr(const std::string& s)\n : ipv6_addr([&] {\n auto lc = s.find_last_of(']');\n auto cp = s.find_first_of(':', lc);\n auto port = cp != std::string::npos ? std::stoul(s.substr(cp + 1)) : 0;\n auto ss = lc != std::string::npos ? 
s.substr(1, lc - 1) : s;\n return ipv6_addr(net::ipv6_address(ss).bytes(), uint16_t(port));\n }())\n{}\n\nseastar::ipv6_addr::ipv6_addr(const std::string& s, uint16_t p)\n : ipv6_addr(net::ipv6_address(s).bytes(), p)\n{}\n\nseastar::ipv6_addr::ipv6_addr(const net::inet_address& i, uint16_t p) noexcept\n : ipv6_addr(i.as_ipv6_address().bytes(), p)\n{}\n\nseastar::ipv6_addr::ipv6_addr(const ::sockaddr_in6& s) noexcept\n : ipv6_addr(s.sin6_addr, net::ntoh(s.sin6_port))\n{}\n\nseastar::ipv6_addr::ipv6_addr(const socket_address& s) noexcept\n : ipv6_addr(s.as_posix_sockaddr_in6())\n{}\n\nbool seastar::ipv6_addr::is_ip_unspecified() const noexcept {\n return std::all_of(ip.begin(), ip.end(), [](uint8_t b) { return b == 0; });\n}\n\n\nseastar::net::inet_address seastar::socket_address::addr() const noexcept {\n switch (family()) {\n case AF_INET:\n return net::inet_address(as_posix_sockaddr_in().sin_addr);\n case AF_INET6:\n return net::inet_address(as_posix_sockaddr_in6().sin6_addr, as_posix_sockaddr_in6().sin6_scope_id);\n default:\n return net::inet_address();\n }\n}\n\n::in_port_t seastar::socket_address::port() const noexcept {\n return net::ntoh(u.in.sin_port);\n}\n\nbool seastar::socket_address::is_wildcard() const noexcept {\n switch (family()) {\n case AF_INET: {\n ipv4_addr addr(*this);\n return addr.is_ip_unspecified() && addr.is_port_unspecified();\n }\n default:\n case AF_INET6: {\n ipv6_addr addr(*this);\n return addr.is_ip_unspecified() && addr.is_port_unspecified();\n }\n case AF_UNIX:\n return length() <= sizeof(::sa_family_t);\n }\n}\n\nstd::ostream& seastar::net::operator<<(std::ostream& os, const inet_address& addr) {\n char buffer[64];\n os << inet_ntop(int(addr.in_family()), addr.data(), buffer, sizeof(buffer));\n if (addr.scope() != inet_address::invalid_scope) {\n os << \"%\" << addr.scope();\n }\n return os;\n}\n\nstd::ostream& seastar::net::operator<<(std::ostream& os, const inet_address::family& f) {\n switch (f) {\n case inet_address::family::INET:\n os << \"INET\";\n break;\n case inet_address::family::INET6:\n os << \"INET6\";\n break;\n default:\n break;\n }\n return os;\n}\n\nstd::ostream& seastar::operator<<(std::ostream& os, const ipv4_addr& a) {\n return os << seastar::socket_address(a);\n}\n\nstd::ostream& seastar::operator<<(std::ostream& os, const ipv6_addr& a) {\n return os << seastar::socket_address(a);\n}\n\nsize_t std::hash::operator()(const seastar::net::inet_address& a) const {\n switch (a.in_family()) {\n case seastar::net::inet_address::family::INET:\n return std::hash()(a.as_ipv4_address());\n case seastar::net::inet_address::family::INET6:\n return std::hash()(a.as_ipv6_address());\n default:\n return 0;\n }\n}\n\nsize_t std::hash::operator()(const seastar::net::ipv6_address& a) const {\n return boost::hash_range(a.ip.begin(), a.ip.end());\n}\n\nsize_t std::hash::operator()(const seastar::ipv4_addr& x) const {\n size_t h = x.ip;\n boost::hash_combine(h, x.port);\n return h;\n}\n\n// Path: src/net/native-stack.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. 
You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2014 Cloudius Systems, Ltd.\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \n#include \n#include \n#include \n\n#ifdef HAVE_OSV\n#include \n#include \n#endif\n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \"net/native-stack-impl.hh\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nnamespace net {\n\nusing namespace seastar;\n\nvoid create_native_net_device(const native_stack_options& opts) {\n\n    bool deprecated_config_used = true;\n\n    std::stringstream net_config;\n\n    if ( opts.net_config) {\n        deprecated_config_used = false;\n        net_config << opts.net_config.get_value();\n    }\n    if ( opts.net_config_file) {\n        deprecated_config_used = false;\n        std::fstream fs(opts.net_config_file.get_value());\n        net_config << fs.rdbuf();\n    }\n\n    std::unique_ptr dev;\n\n    if ( deprecated_config_used) {\n#ifdef SEASTAR_HAVE_DPDK\n        if ( opts.dpdk_pmd) {\n            dev = create_dpdk_net_device(opts.dpdk_opts.dpdk_port_index.get_value(), smp::count,\n                !(opts.lro && opts.lro.get_value() == \"off\"),\n                !(opts.dpdk_opts.hw_fc && opts.dpdk_opts.hw_fc.get_value() == \"off\"));\n        } else \n#endif \n        dev = create_virtio_net_device(opts.virtio_opts, opts.lro);\n    }\n    else {\n        auto device_configs = parse_config(net_config);\n\n        if ( device_configs.size() > 1) {\n            throw std::runtime_error(\"only one network interface is supported\");\n        }\n\n        for ( auto&& device_config : device_configs) {\n            auto& hw_config = device_config.second.hw_cfg; \n#ifdef SEASTAR_HAVE_DPDK\n            if ( hw_config.port_index || !hw_config.pci_address.empty() ) {\n\t        dev = create_dpdk_net_device(hw_config);\n\t    } else \n#endif \n            {\n                (void)hw_config; \n                throw std::runtime_error(\"only DPDK supports new configuration format\"); \n            }\n        }\n    }\n\n    auto sem = std::make_shared(0);\n    std::shared_ptr sdev(dev.release());\n    // set_local_queue on all shard in the background,\n    // signal when done.\n    // FIXME: handle exceptions\n    for (unsigned i = 0; i < smp::count; i++) {\n        (void)smp::submit_to(i, [&opts, sdev] {\n            uint16_t qid = this_shard_id();\n            if (qid < sdev->hw_queues_count()) {\n                auto qp = sdev->init_local_queue(opts, qid);\n                std::map cpu_weights;\n                for (unsigned i = sdev->hw_queues_count() + qid % sdev->hw_queues_count(); i < smp::count; i+= sdev->hw_queues_count()) {\n                    cpu_weights[i] = 1;\n                }\n                cpu_weights[qid] = opts.hw_queue_weight.get_value();\n                qp->configure_proxies(cpu_weights);\n                sdev->set_local_queue(std::move(qp));\n            } else {\n                auto master = qid % sdev->hw_queues_count();\n                sdev->set_local_queue(create_proxy_net_device(master, sdev.get()));\n            }\n        }).then([sem] {\n            sem->signal();\n        });\n    }\n    // wait for all shards to set their local queue,\n    // then when link is ready, communicate the native_stack to the caller\n    // via `create_native_stack` (that sets the ready_promise value)\n    (void)sem->wait(smp::count).then([&opts, sdev] {\n        // 
FIXME: future is discarded\n (void)sdev->link_ready().then([&opts, sdev] {\n for (unsigned i = 0; i < smp::count; i++) {\n // FIXME: future is discarded\n (void)smp::submit_to(i, [&opts, sdev] {\n create_native_stack(opts, sdev);\n });\n }\n });\n });\n}\n\n// native_network_stack\nclass native_network_stack : public network_stack {\npublic:\n static thread_local promise> ready_promise;\nprivate:\n interface _netif;\n ipv4 _inet;\n bool _dhcp = false;\n promise<> _config;\n timer<> _timer;\n\n future<> run_dhcp(bool is_renew = false, const dhcp::lease & res = dhcp::lease());\n void on_dhcp(std::optional lease, bool is_renew);\n void set_ipv4_packet_filter(ip_packet_filter* filter) {\n _inet.set_packet_filter(filter);\n }\n using tcp4 = tcp;\npublic:\n explicit native_network_stack(const native_stack_options& opts, std::shared_ptr dev);\n virtual server_socket listen(socket_address sa, listen_options opt) override;\n virtual ::seastar::socket socket() override;\n virtual udp_channel make_udp_channel(const socket_address& addr) override;\n virtual net::datagram_channel make_unbound_datagram_channel(sa_family_t) override;\n virtual net::datagram_channel make_bound_datagram_channel(const socket_address& local) override;\n virtual future<> initialize() override;\n static future> create(const program_options::option_group& opts) {\n auto ns_opts = dynamic_cast(&opts);\n assert(ns_opts);\n if (this_shard_id() == 0) {\n create_native_net_device(*ns_opts);\n }\n return ready_promise.get_future();\n }\n virtual bool has_per_core_namespace() override { return true; };\n void arp_learn(ethernet_address l2, ipv4_address l3) {\n _inet.learn(l2, l3);\n }\n friend class native_server_socket_impl;\n\n class native_network_interface;\n friend class native_network_interface;\n\n std::vector network_interfaces() override;\n};\n\nthread_local promise> native_network_stack::ready_promise;\n\nudp_channel\nnative_network_stack::make_udp_channel(const socket_address& addr) {\n return _inet.get_udp().make_channel(addr);\n}\n\nnet::datagram_channel native_network_stack::make_unbound_datagram_channel(sa_family_t family) {\n if (family != AF_INET) {\n throw std::runtime_error(\"Unsupported address family\");\n }\n\n return _inet.get_udp().make_channel({});\n}\n\nnet::datagram_channel native_network_stack::make_bound_datagram_channel(const socket_address& local) {\n return _inet.get_udp().make_channel(local);\n}\n\nnative_network_stack::native_network_stack(const native_stack_options& opts, std::shared_ptr dev)\n : _netif(std::move(dev))\n , _inet(&_netif) {\n _inet.get_udp().set_queue_size(opts.udpv4_queue_size.get_value());\n _dhcp = opts.host_ipv4_addr.defaulted()\n && opts.gw_ipv4_addr.defaulted()\n && opts.netmask_ipv4_addr.defaulted() && opts.dhcp.get_value();\n if (!_dhcp) {\n _inet.set_host_address(ipv4_address(opts.host_ipv4_addr.get_value()));\n _inet.set_gw_address(ipv4_address(opts.gw_ipv4_addr.get_value()));\n _inet.set_netmask_address(ipv4_address(opts.netmask_ipv4_addr.get_value()));\n }\n}\n\nserver_socket\nnative_network_stack::listen(socket_address sa, listen_options opts) {\n assert(sa.family() == AF_INET || sa.is_unspecified());\n return tcpv4_listen(_inet.get_tcp(), ntohs(sa.as_posix_sockaddr_in().sin_port), opts);\n}\n\nseastar::socket native_network_stack::socket() {\n return tcpv4_socket(_inet.get_tcp());\n}\n\nusing namespace std::chrono_literals;\n\nfuture<> native_network_stack::run_dhcp(bool is_renew, const dhcp::lease& res) {\n dhcp d(_inet);\n // Hijack the ip-stack.\n auto f = 
d.get_ipv4_filter();\n return smp::invoke_on_all([f] {\n auto & ns = static_cast(engine().net());\n ns.set_ipv4_packet_filter(f);\n }).then([this, d = std::move(d), is_renew, res = res]() mutable {\n net::dhcp::result_type fut = is_renew ? d.renew(res) : d.discover();\n return fut.then([this, is_renew](std::optional lease) {\n return smp::invoke_on_all([] {\n auto & ns = static_cast(engine().net());\n ns.set_ipv4_packet_filter(nullptr);\n }).then(std::bind(&net::native_network_stack::on_dhcp, this, lease, is_renew));\n }).finally([d = std::move(d)] {});\n });\n}\n\nvoid native_network_stack::on_dhcp(std::optional lease, bool is_renew) {\n if (lease) {\n auto& res = *lease;\n _inet.set_host_address(res.ip);\n _inet.set_gw_address(res.gateway);\n _inet.set_netmask_address(res.netmask);\n }\n // Signal waiters.\n if (!is_renew) {\n _config.set_value();\n }\n\n if (this_shard_id() == 0) {\n // And the other cpus, which, in the case of initial discovery,\n // will be waiting for us.\n for (unsigned i = 1; i < smp::count; i++) {\n (void)smp::submit_to(i, [lease, is_renew]() {\n auto & ns = static_cast(engine().net());\n ns.on_dhcp(lease, is_renew);\n });\n }\n if (lease) {\n // And set up to renew the lease later on.\n auto& res = *lease;\n _timer.set_callback(\n [this, res]() {\n _config = promise<>();\n // callback ignores future result\n (void)run_dhcp(true, res);\n });\n _timer.arm(\n std::chrono::duration_cast(\n res.lease_time));\n }\n }\n}\n\nfuture<> native_network_stack::initialize() {\n return network_stack::initialize().then([this]() {\n if (!_dhcp) {\n return make_ready_future();\n }\n\n // Only run actual discover on main cpu.\n // All other cpus must simply for main thread to complete and signal them.\n if (this_shard_id() == 0) {\n // FIXME: future is discarded\n (void)run_dhcp();\n }\n return _config.get_future();\n });\n}\n\n\nvoid arp_learn(ethernet_address l2, ipv4_address l3)\n{\n // Run arp_learn on all shard in the background\n (void)smp::invoke_on_all([l2, l3] {\n auto & ns = static_cast(engine().net());\n ns.arp_learn(l2, l3);\n });\n}\n\nvoid create_native_stack(const native_stack_options& opts, std::shared_ptr dev) {\n native_network_stack::ready_promise.set_value(std::unique_ptr(std::make_unique(opts, std::move(dev))));\n}\n\nnative_stack_options::native_stack_options()\n : program_options::option_group(nullptr, \"Native networking stack options\")\n // these two are ghost options\n , net_config(*this, \"net-config\", program_options::unused{})\n , net_config_file(*this, \"net-config-file\", program_options::unused{})\n , tap_device(*this, \"tap-device\",\n \"tap0\",\n \"tap device to connect to\")\n , host_ipv4_addr(*this, \"host-ipv4-addr\",\n \"192.168.122.2\",\n \"static IPv4 address to use\")\n , gw_ipv4_addr(*this, \"gw-ipv4-addr\",\n \"192.168.122.1\",\n \"static IPv4 gateway to use\")\n , netmask_ipv4_addr(*this, \"netmask-ipv4-addr\",\n \"255.255.255.0\",\n \"static IPv4 netmask to use\")\n , udpv4_queue_size(*this, \"udpv4-queue-size\",\n ipv4_udp::default_queue_size,\n \"Default size of the UDPv4 per-channel packet queue\")\n , dhcp(*this, \"dhcp\",\n true,\n \"Use DHCP discovery\")\n , hw_queue_weight(*this, \"hw-queue-weight\",\n 1.0f,\n \"Weighing of a hardware network queue relative to a software queue (0=no work, 1=equal share)\")\n#ifdef SEASTAR_HAVE_DPDK\n , dpdk_pmd(*this, \"dpdk-pmd\", \"Use DPDK PMD drivers\")\n#else\n , dpdk_pmd(*this, \"dpdk-pmd\", program_options::unused{})\n#endif\n , lro(*this, \"lro\",\n \"on\",\n \"Enable LRO\")\n , 
virtio_opts(this)\n , dpdk_opts(this)\n{\n}\n\nnetwork_stack_entry register_native_stack() {\n return network_stack_entry{\"native\", std::make_unique(), native_network_stack::create, false};\n}\n\nclass native_network_stack::native_network_interface : public net::network_interface_impl {\n const native_network_stack& _stack;\n std::vector _addresses;\n std::vector _hardware_address;\npublic:\n native_network_interface(const native_network_stack& stack)\n : _stack(stack)\n , _addresses(1, _stack._inet.host_address())\n {\n const auto mac = _stack._inet.netif()->hw_address().mac;\n _hardware_address = std::vector{mac.cbegin(), mac.cend()};\n }\n native_network_interface(const native_network_interface&) = default;\n\n uint32_t index() const override {\n return 0;\n }\n uint32_t mtu() const override {\n return _stack._inet.netif()->hw_features().mtu;\n }\n const sstring& name() const override {\n static const sstring name = \"if0\";\n return name;\n }\n const sstring& display_name() const override {\n return name();\n }\n const std::vector& addresses() const override {\n return _addresses; \n }\n const std::vector hardware_address() const override {\n return _hardware_address;\n }\n bool is_loopback() const override {\n return false; \n }\n bool is_virtual() const override {\n return false;\n }\n bool is_up() const override {\n return true;\n }\n bool supports_ipv6() const override {\n return false;\n }\n};\n\nstd::vector native_network_stack::network_interfaces() {\n if (!_inet.netif()) {\n return {};\n }\n\n static const native_network_interface nwif(*this);\n\n std::vector res;\n res.emplace_back(make_shared(nwif));\n return res;\n}\n\n}\n\n}\n\n// Path: src/net/tls.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2015 Cloudius Systems\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#include \n#include \n#include \n#include \n\n#include \n#include \n#include \n\n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#endif\n\nnamespace seastar {\n\nclass net::get_impl {\npublic:\n static std::unique_ptr get(connected_socket s) {\n return std::move(s._csi);\n }\n\n static connected_socket_impl* maybe_get_ptr(connected_socket& s) {\n if (s._csi) {\n return s._csi.get();\n }\n return nullptr;\n }\n};\n\nclass blob_wrapper: public gnutls_datum_t {\npublic:\n blob_wrapper(const tls::blob& in)\n : gnutls_datum_t {\n reinterpret_cast(const_cast(in.data())),\n unsigned(in.size()) } {\n }\n};\n\nclass gnutlsinit {\npublic:\n gnutlsinit() {\n gnutls_global_init();\n }\n ~gnutlsinit() {\n gnutls_global_deinit();\n }\n};\n\n// Helper to ensure gnutls legacy init\n// is handled properly with regards to\n// object life spans. Could be better,\n// this version will not destroy the\n// gnutls stack until process exit.\nclass gnutlsobj {\npublic:\n gnutlsobj() {\n static gnutlsinit init;\n }\n};\n\n// Helper\nstruct file_info {\n sstring filename;\n std::chrono::system_clock::time_point modified;\n};\n\nstruct file_result {\n temporary_buffer buf;\n file_info file;\n operator temporary_buffer&&() && {\n return std::move(buf);\n }\n};\n\nstatic future read_fully(const sstring& name, const sstring& what) {\n return open_file_dma(name, open_flags::ro).then([name = name](file f) mutable {\n return do_with(std::move(f), [name = std::move(name)](file& f) mutable {\n return f.stat().then([&f, name = std::move(name)](struct stat s) mutable {\n return f.dma_read_bulk(0, s.st_size).then([s, name = std::move(name)](temporary_buffer buf) mutable {\n return file_result{ std::move(buf), file_info{ \n std::move(name), std::chrono::system_clock::from_time_t(s.st_mtim.tv_sec) +\n std::chrono::duration_cast(std::chrono::nanoseconds(s.st_mtim.tv_nsec))\n } };\n });\n }).finally([&f]() {\n return f.close();\n });\n });\n }).handle_exception([name = name, what = what](std::exception_ptr ep) -> future {\n try {\n std::rethrow_exception(std::move(ep));\n } catch (...) {\n std::throw_with_nested(std::runtime_error(sstring(\"Could not read \") + what + \" \" + name));\n }\n });\n}\n\n// Note: we are not using gnutls++ interfaces, mainly because we\n// want to keep _our_ interface reasonably non-gnutls (well...)\n// and once we get to this level, their abstractions don't help\n// that much anyway. 
And they are sooo c++98...\nclass gnutls_error_category : public std::error_category {\npublic:\n constexpr gnutls_error_category() noexcept : std::error_category{} {}\n const char * name() const noexcept override {\n return \"GnuTLS\";\n }\n std::string message(int error) const override {\n return gnutls_strerror(error);\n }\n};\n\nconst std::error_category& tls::error_category() {\n static const gnutls_error_category ec;\n return ec;\n}\n\n// Checks a gnutls return value.\n// < 0 -> error.\nstatic void gtls_chk(int res) {\n if (res < 0) {\n throw std::system_error(res, tls::error_category());\n }\n}\n\nnamespace {\n\n// helper for gnutls-functions for receiving a string\n// arguments\n// func - the gnutls function that is returning a string (e.g. gnutls_x509_crt_get_issuer_dn)\n// args - the arguments to func that come before the char array's ptr and size args\n// returns\n// pair - [gnutls error code, extracted string],\n// in case of no errors, the error code is zero\nstatic auto get_gtls_string = [](auto func, auto... args) noexcept {\n size_t size = 0;\n int ret = func(args..., nullptr, &size);\n\n // by construction, we expect the SHORT_MEMORY_BUFFER error code here\n if (ret != GNUTLS_E_SHORT_MEMORY_BUFFER) {\n return std::make_pair(ret, sstring{});\n }\n assert(size != 0);\n sstring res(sstring::initialized_later{}, size - 1);\n ret = func(args..., res.data(), &size);\n return std::make_pair(ret, res);\n};\n\n}\n\nclass tls::dh_params::impl : gnutlsobj {\n static gnutls_sec_param_t to_gnutls_level(level l) {\n switch (l) {\n case level::LEGACY: return GNUTLS_SEC_PARAM_LEGACY;\n#if GNUTLS_VERSION_NUMBER >= 0x030300\n case level::MEDIUM: return GNUTLS_SEC_PARAM_MEDIUM;\n#else\n case level::MEDIUM: return GNUTLS_SEC_PARAM_NORMAL;\n#endif\n case level::HIGH: return GNUTLS_SEC_PARAM_HIGH;\n case level::ULTRA: return GNUTLS_SEC_PARAM_ULTRA;\n default:\n throw std::runtime_error(format(\"Unknown value of dh_params::level: {:d}\", static_cast>(l)));\n }\n }\n using dh_ptr = std::unique_ptr, void(*)(gnutls_dh_params_t)>;\n\n static dh_ptr new_dh_params() {\n gnutls_dh_params_t params;\n gtls_chk(gnutls_dh_params_init(¶ms));\n return dh_ptr(params, &gnutls_dh_params_deinit);\n }\npublic:\n impl(dh_ptr p) \n : _params(std::move(p)) \n {}\n impl(level lvl)\n#if GNUTLS_VERSION_NUMBER >= 0x030506\n : _params(nullptr, &gnutls_dh_params_deinit)\n , _sec_param(to_gnutls_level(lvl))\n#else \n : impl([&] {\n auto bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, to_gnutls_level(lvl));\n auto ptr = new_dh_params();\n gtls_chk(gnutls_dh_params_generate2(ptr.get(), bits));\n return ptr;\n }())\n#endif\n {}\n impl(const blob& pkcs3, x509_crt_format fmt)\n : impl([&] {\n auto ptr = new_dh_params();\n blob_wrapper w(pkcs3);\n gtls_chk(gnutls_dh_params_import_pkcs3(ptr.get(), &w, gnutls_x509_crt_fmt_t(fmt)));\n return ptr;\n }()) \n {}\n impl(const impl& v)\n : impl([&v] {\n auto ptr = new_dh_params();\n gtls_chk(gnutls_dh_params_cpy(ptr.get(), v));\n return ptr;\n }()) \n {}\n ~impl() = default;\n\n operator gnutls_dh_params_t() const {\n return _params.get();\n }\n#if GNUTLS_VERSION_NUMBER >= 0x030506\n std::optional sec_param() const {\n return _sec_param;\n }\n#endif\nprivate:\n dh_ptr _params;\n#if GNUTLS_VERSION_NUMBER >= 0x030506\n std::optional _sec_param;\n#endif\n};\n\ntls::dh_params::dh_params(level lvl) : _impl(std::make_unique(lvl))\n{}\n\ntls::dh_params::dh_params(const blob& b, x509_crt_format fmt)\n : _impl(std::make_unique(b, fmt)) {\n}\n\ntls::dh_params::~dh_params() 
{\n}\n\ntls::dh_params::dh_params(dh_params&&) noexcept = default;\ntls::dh_params& tls::dh_params::operator=(dh_params&&) noexcept = default;\n\nfuture tls::dh_params::from_file(\n const sstring& filename, x509_crt_format fmt) {\n return read_fully(filename, \"dh parameters\").then([fmt](temporary_buffer buf) {\n return make_ready_future(dh_params(blob(buf.get()), fmt));\n });\n}\n\nclass tls::x509_cert::impl : gnutlsobj {\npublic:\n impl()\n : _cert([] {\n gnutls_x509_crt_t cert;\n gtls_chk(gnutls_x509_crt_init(&cert));\n return cert;\n }()) {\n }\n impl(const blob& b, x509_crt_format fmt)\n : impl()\n {\n blob_wrapper w(b);\n gtls_chk(gnutls_x509_crt_import(*this, &w, gnutls_x509_crt_fmt_t(fmt)));\n }\n ~impl() {\n if (_cert != nullptr) {\n gnutls_x509_crt_deinit(_cert);\n }\n }\n operator gnutls_x509_crt_t() const {\n return _cert;\n }\n\nprivate:\n gnutls_x509_crt_t _cert;\n};\n\ntls::x509_cert::x509_cert(shared_ptr impl)\n : _impl(std::move(impl)) {\n}\n\ntls::x509_cert::x509_cert(const blob& b, x509_crt_format fmt)\n : x509_cert(::seastar::make_shared(b, fmt)) {\n}\n\nfuture tls::x509_cert::from_file(\n const sstring& filename, x509_crt_format fmt) {\n return read_fully(filename, \"x509 certificate\").then([fmt](temporary_buffer buf) {\n return make_ready_future(x509_cert(blob(buf.get()), fmt));\n });\n}\n\nclass tls::certificate_credentials::impl: public gnutlsobj {\npublic:\n impl()\n : _creds([] {\n gnutls_certificate_credentials_t xcred;\n gnutls_certificate_allocate_credentials(&xcred);\n if (xcred == nullptr) {\n throw std::bad_alloc();\n }\n return xcred;\n }()), _priority(nullptr, &gnutls_priority_deinit)\n {}\n ~impl() {\n if (_creds != nullptr) {\n gnutls_certificate_free_credentials (_creds);\n }\n }\n\n operator gnutls_certificate_credentials_t() const {\n return _creds;\n }\n\n void set_x509_trust(const blob& b, x509_crt_format fmt) {\n blob_wrapper w(b);\n gtls_chk(\n gnutls_certificate_set_x509_trust_mem(_creds, &w,\n gnutls_x509_crt_fmt_t(fmt)));\n }\n void set_x509_crl(const blob& b, x509_crt_format fmt) {\n blob_wrapper w(b);\n gtls_chk(\n gnutls_certificate_set_x509_crl_mem(_creds, &w,\n gnutls_x509_crt_fmt_t(fmt)));\n }\n void set_x509_key(const blob& cert, const blob& key, x509_crt_format fmt) {\n blob_wrapper w1(cert);\n blob_wrapper w2(key);\n gtls_chk(\n gnutls_certificate_set_x509_key_mem(_creds, &w1, &w2,\n gnutls_x509_crt_fmt_t(fmt)));\n }\n void set_simple_pkcs12(const blob& b, x509_crt_format fmt,\n const sstring& password) {\n blob_wrapper w(b);\n gtls_chk(\n gnutls_certificate_set_x509_simple_pkcs12_mem(_creds, &w,\n gnutls_x509_crt_fmt_t(fmt), password.c_str()));\n }\n void dh_params(const tls::dh_params& dh) {\n#if GNUTLS_VERSION_NUMBER >= 0x030506\n auto sec_param = dh._impl->sec_param();\n if (sec_param) {\n gnutls_certificate_set_known_dh_params(*this, *sec_param);\n return;\n }\n#endif\n...\n", "needle_token_start": 10413, "needle_token_end": 10513, "code_context_ntokens": 16118, "output": ["```cpp\nvoid arp_learn(ethernet_address l2, ipv4_address l3)\n{\n // Run arp_learn on all shard in the background\n (void)smp::invoke_on_all([l2, l3] {\n auto & ns = static_cast(engine().net());\n ns.arp_learn(l2, l3);\n });\n}\n```"]} +{"repo": "scylladb/seastar", "name": "RTE_LCORE_FOREACH_WORKER", "language": "cpp", "path": "src/core/reactor.cc", "position_ratio": 0.75, "description": "\nFunction Description:\n1. 
**Purpose**: The function iterates over each worker logical core (excluding the main core) available in a multi-core system, typically used in environments managed by the Data Plane Development Kit (DPDK).\n2. **Input**: There are no direct inputs as it functions as a macro iterating over available worker cores.\n3. **Output**: There is no direct output; however, it controls the flow to execute specific operations on each worker core.\n4. **Procedure**: This macro retrieves the list of all worker cores, excluding the main core, and iterates over them. During each iteration, it allows the execution of specified operations on each worker core, such as launching threads or processes tailored to run on these specific cores.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " _metrics.add_group(\"smp\", {\n // queue_length value:GAUGE:0:U\n // Absolute value of num packets in last tx batch.\n sm::make_queue_length(\"send_batch_queue_length\", _last_snt_batch, sm::description(\"Current send batch queue length\"), {sm::shard_label(instance)})(sm::metric_disabled),\n sm::make_queue_length(\"receive_batch_queue_length\", _last_rcv_batch, sm::description(\"Current receive batch queue length\"), {sm::shard_label(instance)})(sm::metric_disabled),\n sm::make_queue_length(\"complete_batch_queue_length\", _last_cmpl_batch, sm::description(\"Current complete batch queue length\"), {sm::shard_label(instance)})(sm::metric_disabled),\n sm::make_queue_length(\"send_queue_length\", _current_queue_length, sm::description(\"Current send queue length\"), {sm::shard_label(instance)})(sm::metric_disabled),\n // total_operations value:DERIVE:0:U\n sm::make_counter(\"total_received_messages\", _received, sm::description(\"Total number of received messages\"), {sm::shard_label(instance)})(sm::metric_disabled),\n // total_operations value:DERIVE:0:U\n sm::make_counter(\"total_sent_messages\", _sent, sm::description(\"Total number of sent messages\"), {sm::shard_label(instance)})(sm::metric_disabled),\n // total_operations value:DERIVE:0:U\n sm::make_counter(\"total_completed_messages\", _compl, sm::description(\"Total number of messages completed\"), {sm::shard_label(instance)})(sm::metric_disabled)\n });\n}\n\nreadable_eventfd writeable_eventfd::read_side() {\n return readable_eventfd(_fd.dup());\n}\n\nfile_desc writeable_eventfd::try_create_eventfd(size_t initial) {\n assert(size_t(int(initial)) == initial);\n return file_desc::eventfd(initial, EFD_CLOEXEC);\n}\n\nvoid writeable_eventfd::signal(size_t count) {\n uint64_t c = count;\n auto r = _fd.write(&c, sizeof(c));\n assert(r == sizeof(c));\n}\n\nwriteable_eventfd readable_eventfd::write_side() {\n return writeable_eventfd(_fd.get_file_desc().dup());\n}\n\nfile_desc readable_eventfd::try_create_eventfd(size_t initial) {\n assert(size_t(int(initial)) == initial);\n return file_desc::eventfd(initial, EFD_CLOEXEC | EFD_NONBLOCK);\n}\n\nfuture readable_eventfd::wait() {\n return engine().readable(*_fd._s).then([this] {\n uint64_t count;\n int r = ::read(_fd.get_fd(), &count, sizeof(count));\n assert(r == sizeof(count));\n return make_ready_future(count);\n });\n}\n\nvoid schedule(task* t) noexcept {\n engine().add_task(t);\n}\n\nvoid schedule_checked(task* t) noexcept {\n if (t->group().is_at_exit()) {\n // trying to schedule a task in at_destroy. 
Not allowed\n on_internal_error(seastar_logger, \"Cannot schedule tasks in at_destroy queue. Use reactor::at_destroy.\");\n }\n engine().add_task(t);\n}\n\nvoid schedule_urgent(task* t) noexcept {\n engine().add_urgent_task(t);\n}\n\n}\n\nbool operator==(const ::sockaddr_in a, const ::sockaddr_in b) {\n return (a.sin_addr.s_addr == b.sin_addr.s_addr) && (a.sin_port == b.sin_port);\n}\n\nnamespace seastar {\n\nstatic bool kernel_supports_aio_fsync() {\n return internal::kernel_uname().whitelisted({\"4.18\"});\n}\n\nstatic program_options::selection_value create_network_stacks_option(reactor_options& zis) {\n using value_type = program_options::selection_value;\n value_type::candidates candidates;\n std::vector net_stack_names;\n\n auto deleter = [] (network_stack_factory* p) { delete p; };\n\n std::string default_stack;\n for (auto reg_func : {register_native_stack, register_posix_stack}) {\n auto s = reg_func();\n if (s.is_default) {\n default_stack = s.name;\n }\n candidates.push_back({s.name, {new network_stack_factory(std::move(s.factory)), deleter}, std::move(s.opts)});\n net_stack_names.emplace_back(s.name);\n }\n\n return program_options::selection_value(zis, \"network-stack\", std::move(candidates), default_stack,\n fmt::format(\"select network stack (valid values: {})\", fmt::join(net_stack_names, \", \")));\n}\n\nstatic program_options::selection_value::candidates backend_selector_candidates() {\n using value_type = program_options::selection_value;\n value_type::candidates candidates;\n\n auto deleter = [] (reactor_backend_selector* p) { delete p; };\n\n for (auto&& be : reactor_backend_selector::available()) {\n auto name = be.name();\n candidates.push_back({std::move(name), {new reactor_backend_selector(std::move(be)), deleter}, {}});\n }\n return candidates;\n}\n\nreactor_options::reactor_options(program_options::option_group* parent_group)\n : program_options::option_group(parent_group, \"Core options\")\n , network_stack(create_network_stacks_option(*this))\n , poll_mode(*this, \"poll-mode\", \"poll continuously (100% cpu use)\")\n , idle_poll_time_us(*this, \"idle-poll-time-us\", reactor::calculate_poll_time() / 1us,\n \"idle polling time in microseconds (reduce for overprovisioned environments or laptops)\")\n , poll_aio(*this, \"poll-aio\", true,\n \"busy-poll for disk I/O (reduces latency and increases throughput)\")\n , task_quota_ms(*this, \"task-quota-ms\", 0.5, \"Max time (ms) between polls\")\n , io_latency_goal_ms(*this, \"io-latency-goal-ms\", {}, \"Max time (ms) io operations must take (1.5 * task-quota-ms if not set)\")\n , io_flow_ratio_threshold(*this, \"io-flow-rate-threshold\", 1.1, \"Dispatch rate to completion rate threshold\")\n , max_task_backlog(*this, \"max-task-backlog\", 1000, \"Maximum number of task backlog to allow; above this we ignore I/O\")\n , blocked_reactor_notify_ms(*this, \"blocked-reactor-notify-ms\", 25, \"threshold in miliseconds over which the reactor is considered blocked if no progress is made\")\n , blocked_reactor_reports_per_minute(*this, \"blocked-reactor-reports-per-minute\", 5, \"Maximum number of backtraces reported by stall detector per minute\")\n , blocked_reactor_report_format_oneline(*this, \"blocked-reactor-report-format-oneline\", true, \"Print a simplified backtrace on a single line\")\n , relaxed_dma(*this, \"relaxed-dma\", \"allow using buffered I/O if DMA is not available (reduces performance)\")\n , linux_aio_nowait(*this, \"linux-aio-nowait\", aio_nowait_supported,\n \"use the Linux NOWAIT AIO feature, which 
reduces reactor stalls due to aio (autodetected)\")\n , unsafe_bypass_fsync(*this, \"unsafe-bypass-fsync\", false, \"Bypass fsync(), may result in data loss. Use for testing on consumer drives\")\n , kernel_page_cache(*this, \"kernel-page-cache\", false,\n \"Use the kernel page cache. This disables DMA (O_DIRECT).\"\n \" Useful for short-lived functional tests with a small data set.\")\n , overprovisioned(*this, \"overprovisioned\", \"run in an overprovisioned environment (such as docker or a laptop); equivalent to --idle-poll-time-us 0 --thread-affinity 0 --poll-aio 0\")\n , abort_on_seastar_bad_alloc(*this, \"abort-on-seastar-bad-alloc\", \"abort when seastar allocator cannot allocate memory\")\n , force_aio_syscalls(*this, \"force-aio-syscalls\", false,\n \"Force io_getevents(2) to issue a system call, instead of bypassing the kernel when possible.\"\n \" This makes strace output more useful, but slows down the application\")\n , dump_memory_diagnostics_on_alloc_failure_kind(*this, \"dump-memory-diagnostics-on-alloc-failure-kind\", memory::alloc_failure_kind::critical,\n \"Dump diagnostics of the seastar allocator state on allocation failure.\"\n \" Accepted values: none, critical (default), all. When set to critical, only allocations marked as critical will trigger diagnostics dump.\"\n \" The diagnostics will be written to the seastar_memory logger, with error level.\"\n \" Note that if the seastar_memory logger is set to debug or trace level, the diagnostics will be logged irrespective of this setting.\")\n , reactor_backend(*this, \"reactor-backend\", backend_selector_candidates(), reactor_backend_selector::default_backend().name(),\n fmt::format(\"Internal reactor implementation ({})\", reactor_backend_selector::available()))\n , aio_fsync(*this, \"aio-fsync\", kernel_supports_aio_fsync(),\n \"Use Linux aio for fsync() calls. This reduces latency; requires Linux 4.18 or later.\")\n , max_networking_io_control_blocks(*this, \"max-networking-io-control-blocks\", 10000,\n \"Maximum number of I/O control blocks (IOCBs) to allocate per shard. This translates to the number of sockets supported per shard.\"\n \" Requires tuning /proc/sys/fs/aio-max-nr. Only valid for the linux-aio reactor backend (see --reactor-backend).\")\n#ifdef SEASTAR_HEAPPROF\n , heapprof(*this, \"heapprof\", 0, \"Enable seastar heap profiling. Sample every ARG bytes. 0 means off\")\n#else\n , heapprof(*this, \"heapprof\", program_options::unused{})\n#endif\n , no_handle_interrupt(*this, \"no-handle-interrupt\", \"ignore SIGINT (for gdb)\")\n{\n}\n\nsmp_options::smp_options(program_options::option_group* parent_group)\n : program_options::option_group(parent_group, \"SMP options\")\n , smp(*this, \"smp\", {}, \"number of threads (default: one per CPU)\")\n , cpuset(*this, \"cpuset\", {}, \"CPUs to use (in cpuset(7) list format (ex: 0,1-3,7); default: all))\")\n , memory(*this, \"memory\", std::nullopt, \"memory to use, in bytes (ex: 4G) (default: all)\")\n , reserve_memory(*this, \"reserve-memory\", {}, \"memory reserved to OS (if --memory not specified)\")\n , hugepages(*this, \"hugepages\", {}, \"path to accessible hugetlbfs mount (typically /dev/hugepages/something)\")\n , lock_memory(*this, \"lock-memory\", {}, \"lock all memory (prevents swapping)\")\n , thread_affinity(*this, \"thread-affinity\", true, \"pin threads to their cpus (disable for overprovisioning)\")\n#ifdef SEASTAR_HAVE_HWLOC\n , num_io_groups(*this, \"num-io-groups\", {}, \"Number of IO groups. 
Each IO group will be responsible for a fraction of the IO requests. Defaults to the number of NUMA nodes\")\n#else\n , num_io_groups(*this, \"num-io-groups\", program_options::unused{})\n#endif\n , io_properties_file(*this, \"io-properties-file\", {}, \"path to a YAML file describing the characteristics of the I/O Subsystem\")\n , io_properties(*this, \"io-properties\", {}, \"a YAML string describing the characteristics of the I/O Subsystem\")\n , mbind(*this, \"mbind\", true, \"enable mbind\")\n#ifndef SEASTAR_NO_EXCEPTION_HACK\n , enable_glibc_exception_scaling_workaround(*this, \"enable-glibc-exception-scaling-workaround\", true, \"enable workaround for glibc/gcc c++ exception scalablity problem\")\n#else\n , enable_glibc_exception_scaling_workaround(*this, \"enable-glibc-exception-scaling-workaround\", program_options::unused{})\n#endif\n#ifdef SEASTAR_HAVE_HWLOC\n , allow_cpus_in_remote_numa_nodes(*this, \"allow-cpus-in-remote-numa-nodes\", true, \"if some CPUs are found not to have any local NUMA nodes, allow assigning them to remote ones\")\n#else\n , allow_cpus_in_remote_numa_nodes(*this, \"allow-cpus-in-remote-numa-nodes\", program_options::unused{})\n#endif\n{\n}\n\nstruct reactor_deleter {\n void operator()(reactor* p) {\n p->~reactor();\n free(p);\n }\n};\n\nthread_local std::unique_ptr reactor_holder;\n\nthread_local smp_message_queue** smp::_qs;\nthread_local std::thread::id smp::_tmain;\nunsigned smp::count = 0;\n\nvoid smp::start_all_queues()\n{\n for (unsigned c = 0; c < count; c++) {\n if (c != this_shard_id()) {\n _qs[c][this_shard_id()].start(c);\n }\n }\n _alien._qs[this_shard_id()].start();\n}\n\n#ifdef SEASTAR_HAVE_DPDK\n\nint dpdk_thread_adaptor(void* f)\n{\n (*static_cast*>(f))();\n return 0;\n}\n\n#endif\n\nvoid smp::join_all()\n{\n#ifdef SEASTAR_HAVE_DPDK\n if (_using_dpdk) {\n rte_eal_mp_wait_lcore();\n return;\n }\n#endif\n for (auto&& t: smp::_threads) {\n t.join();\n }\n}\n\nvoid smp::pin(unsigned cpu_id) {\n if (_using_dpdk) {\n // dpdk does its own pinning\n return;\n }\n pin_this_thread(cpu_id);\n}\n\nvoid smp::arrive_at_event_loop_end() {\n if (_all_event_loops_done) {\n _all_event_loops_done->wait();\n }\n}\n\nvoid smp::allocate_reactor(unsigned id, reactor_backend_selector rbs, reactor_config cfg) {\n assert(!reactor_holder);\n\n // we cannot just write \"local_engin = new reactor\" since reactor's constructor\n // uses local_engine\n void *buf;\n int r = posix_memalign(&buf, cache_line_size, sizeof(reactor));\n assert(r == 0);\n *internal::this_shard_id_ptr() = id;\n local_engine = new (buf) reactor(this->shared_from_this(), _alien, id, std::move(rbs), cfg);\n reactor_holder.reset(local_engine);\n}\n\nvoid smp::cleanup() noexcept {\n smp::_threads = std::vector();\n _thread_loops.clear();\n reactor_holder.reset();\n local_engine = nullptr;\n}\n\nvoid smp::cleanup_cpu() {\n size_t cpuid = this_shard_id();\n\n if (_qs) {\n for(unsigned i = 0; i < smp::count; i++) {\n _qs[i][cpuid].stop();\n }\n }\n if (_alien._qs) {\n _alien._qs[cpuid].stop();\n }\n}\n\nvoid smp::create_thread(std::function thread_loop) {\n if (_using_dpdk) {\n _thread_loops.push_back(std::move(thread_loop));\n } else {\n _threads.emplace_back(std::move(thread_loop));\n }\n}\n\n// Installs handler for Signal which ensures that Func is invoked only once\n// in the whole program and that after it is invoked the default handler is restored.\ntemplate\nvoid install_oneshot_signal_handler() {\n static bool handled = false;\n static util::spinlock lock;\n\n struct sigaction sa;\n 
sa.sa_sigaction = [](int sig, siginfo_t *info, void *p) {\n std::lock_guard g(lock);\n if (!handled) {\n handled = true;\n Func();\n signal(sig, SIG_DFL);\n }\n };\n sigfillset(&sa.sa_mask);\n sa.sa_flags = SA_SIGINFO | SA_RESTART;\n if (Signal == SIGSEGV) {\n sa.sa_flags |= SA_ONSTACK;\n }\n auto r = ::sigaction(Signal, &sa, nullptr);\n throw_system_error_on(r == -1);\n}\n\nstatic void reraise_signal(int signo) {\n signal(signo, SIG_DFL);\n pthread_kill(pthread_self(), signo);\n}\n\nstatic void sigsegv_action() noexcept {\n print_with_backtrace(\"Segmentation fault\");\n reraise_signal(SIGSEGV);\n}\n\nstatic void sigabrt_action() noexcept {\n print_with_backtrace(\"Aborting\");\n reraise_signal(SIGABRT);\n}\n\n// We don't need to handle SIGSEGV when asan is enabled.\n#ifdef SEASTAR_ASAN_ENABLED\ntemplate<>\nvoid install_oneshot_signal_handler() {\n (void)sigsegv_action;\n}\n#endif\n\nvoid smp::qs_deleter::operator()(smp_message_queue** qs) const {\n for (unsigned i = 0; i < smp::count; i++) {\n for (unsigned j = 0; j < smp::count; j++) {\n qs[i][j].~smp_message_queue();\n }\n ::operator delete[](qs[i], std::align_val_t(alignof(smp_message_queue))\n );\n }\n delete[](qs);\n}\n\nclass disk_config_params {\nprivate:\n const unsigned _max_queues;\n unsigned _num_io_groups = 0;\n std::unordered_map _mountpoints;\n std::chrono::duration _latency_goal;\n double _flow_ratio_backpressure_threshold;\n\npublic:\n explicit disk_config_params(unsigned max_queues) noexcept\n : _max_queues(max_queues)\n {}\n\n uint64_t per_io_group(uint64_t qty, unsigned nr_groups) const noexcept {\n return std::max(qty / nr_groups, 1ul);\n }\n\n unsigned num_io_groups() const noexcept { return _num_io_groups; }\n\n std::chrono::duration latency_goal() const {\n return _latency_goal;\n }\n\n double latency_goal_opt(const reactor_options& opts) const {\n return opts.io_latency_goal_ms ?\n opts.io_latency_goal_ms.get_value() :\n opts.task_quota_ms.get_value() * 1.5;\n }\n\n void parse_config(const smp_options& smp_opts, const reactor_options& reactor_opts) {\n seastar_logger.debug(\"smp::count: {}\", smp::count);\n _latency_goal = std::chrono::duration_cast>(latency_goal_opt(reactor_opts) * 1ms);\n seastar_logger.debug(\"latency_goal: {}\", latency_goal().count());\n _flow_ratio_backpressure_threshold = reactor_opts.io_flow_ratio_threshold.get_value();\n seastar_logger.debug(\"flow-ratio threshold: {}\", _flow_ratio_backpressure_threshold);\n\n if (smp_opts.num_io_groups) {\n _num_io_groups = smp_opts.num_io_groups.get_value();\n if (!_num_io_groups) {\n throw std::runtime_error(\"num-io-groups must be greater than zero\");\n }\n }\n if (smp_opts.io_properties_file && smp_opts.io_properties) {\n throw std::runtime_error(\"Both io-properties and io-properties-file specified. 
Don't know which to trust!\");\n }\n\n std::optional doc;\n if (smp_opts.io_properties_file) {\n doc = YAML::LoadFile(smp_opts.io_properties_file.get_value());\n } else if (smp_opts.io_properties) {\n doc = YAML::Load(smp_opts.io_properties.get_value());\n }\n\n if (doc) {\n if (!doc->IsMap()) {\n throw std::runtime_error(\"Bogus io-properties (did you mix up --io-properties and --io-properties-file?)\");\n }\n for (auto&& section : *doc) {\n auto sec_name = section.first.as();\n if (sec_name != \"disks\") {\n throw std::runtime_error(fmt::format(\"While parsing I/O options: section {} currently unsupported.\", sec_name));\n }\n auto disks = section.second.as>();\n for (auto& d : disks) {\n struct ::stat buf;\n auto ret = stat(d.mountpoint.c_str(), &buf);\n if (ret < 0) {\n throw std::runtime_error(fmt::format(\"Couldn't stat {}\", d.mountpoint));\n }\n\n auto st_dev = S_ISBLK(buf.st_mode) ? buf.st_rdev : buf.st_dev;\n if (_mountpoints.count(st_dev)) {\n throw std::runtime_error(fmt::format(\"Mountpoint {} already configured\", d.mountpoint));\n }\n if (_mountpoints.size() >= _max_queues) {\n throw std::runtime_error(fmt::format(\"Configured number of queues {} is larger than the maximum {}\",\n _mountpoints.size(), _max_queues));\n }\n\n d.read_bytes_rate *= d.rate_factor;\n d.write_bytes_rate *= d.rate_factor;\n d.read_req_rate *= d.rate_factor;\n d.write_req_rate *= d.rate_factor;\n\n if (d.read_bytes_rate == 0 || d.write_bytes_rate == 0 ||\n d.read_req_rate == 0 || d.write_req_rate == 0) {\n throw std::runtime_error(fmt::format(\"R/W bytes and req rates must not be zero\"));\n }\n\n seastar_logger.debug(\"dev_id: {} mountpoint: {}\", st_dev, d.mountpoint);\n _mountpoints.emplace(st_dev, d);\n }\n }\n }\n\n // Placeholder for unconfigured disks.\n mountpoint_params d = {};\n _mountpoints.emplace(0, d);\n }\n\n struct io_queue::config generate_config(dev_t devid, unsigned nr_groups) const {\n seastar_logger.debug(\"generate_config dev_id: {}\", devid);\n const mountpoint_params& p = _mountpoints.at(devid);\n struct io_queue::config cfg;\n\n cfg.devid = devid;\n\n if (p.read_bytes_rate != std::numeric_limits::max()) {\n cfg.blocks_count_rate = (io_queue::read_request_base_count * (unsigned long)per_io_group(p.read_bytes_rate, nr_groups)) >> io_queue::block_size_shift;\n cfg.disk_blocks_write_to_read_multiplier = (io_queue::read_request_base_count * p.read_bytes_rate) / p.write_bytes_rate;\n }\n if (p.read_req_rate != std::numeric_limits::max()) {\n cfg.req_count_rate = io_queue::read_request_base_count * (unsigned long)per_io_group(p.read_req_rate, nr_groups);\n cfg.disk_req_write_to_read_multiplier = (io_queue::read_request_base_count * p.read_req_rate) / p.write_req_rate;\n }\n if (p.read_saturation_length != std::numeric_limits::max()) {\n cfg.disk_read_saturation_length = p.read_saturation_length;\n }\n if (p.write_saturation_length != std::numeric_limits::max()) {\n cfg.disk_write_saturation_length = p.write_saturation_length;\n }\n cfg.mountpoint = p.mountpoint;\n cfg.duplex = p.duplex;\n cfg.rate_limit_duration = latency_goal();\n cfg.flow_ratio_backpressure_threshold = _flow_ratio_backpressure_threshold;\n // Block count limit should not be less than the minimal IO size on the device\n // On the other hand, even this is not good enough -- in the worst case the\n // scheduler will self-tune to allow for the single 64k request, while it would\n // be better to sacrifice some IO latency, but allow for larger concurrency\n cfg.block_count_limit_min = (64 << 10) >> 
io_queue::block_size_shift;\n\n return cfg;\n }\n\n auto device_ids() {\n return boost::adaptors::keys(_mountpoints);\n }\n};\n\nunsigned smp::adjust_max_networking_aio_io_control_blocks(unsigned network_iocbs)\n{\n static unsigned constexpr storage_iocbs = reactor::max_aio;\n static unsigned constexpr preempt_iocbs = 2;\n\n auto aio_max_nr = read_first_line_as(\"/proc/sys/fs/aio-max-nr\");\n auto aio_nr = read_first_line_as(\"/proc/sys/fs/aio-nr\");\n auto available_aio = aio_max_nr - aio_nr;\n auto requested_aio_network = network_iocbs * smp::count;\n auto requested_aio_other = (storage_iocbs + preempt_iocbs) * smp::count;\n auto requested_aio = requested_aio_network + requested_aio_other;\n auto network_iocbs_old = network_iocbs;\n\n if (available_aio < requested_aio) {\n seastar_logger.warn(\"Requested AIO slots too large, please increase request capacity in /proc/sys/fs/aio-max-nr. configured:{} available:{} requested:{}\", aio_max_nr, available_aio, requested_aio);\n if (available_aio >= requested_aio_other + smp::count) { // at least one queue for each shard\n network_iocbs = (available_aio - requested_aio_other) / smp::count;\n seastar_logger.warn(\"max-networking-io-control-blocks adjusted from {} to {}, since AIO slots are unavailable\", network_iocbs_old, network_iocbs);\n } else {\n throw std::runtime_error(\"Could not setup Async I/O: Not enough request capacity in /proc/sys/fs/aio-max-nr. Try increasing that number or reducing the amount of logical CPUs available for your application\");\n }\n }\n\n return network_iocbs;\n}\n\nvoid smp::configure(const smp_options& smp_opts, const reactor_options& reactor_opts)\n{\n bool use_transparent_hugepages = !reactor_opts.overprovisioned;\n\n#ifndef SEASTAR_NO_EXCEPTION_HACK\n if (smp_opts.enable_glibc_exception_scaling_workaround.get_value()) {\n init_phdr_cache();\n }\n#endif\n\n // Mask most, to prevent threads (esp. dpdk helper threads)\n // from servicing a signal. 
Individual reactors will unmask signals\n // as they become prepared to handle them.\n //\n // We leave some signals unmasked since we don't handle them ourself.\n sigset_t sigs;\n sigfillset(&sigs);\n for (auto sig : {SIGHUP, SIGQUIT, SIGILL, SIGABRT, SIGFPE, SIGSEGV,\n SIGALRM, SIGCONT, SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU}) {\n sigdelset(&sigs, sig);\n }\n if (!reactor_opts._auto_handle_sigint_sigterm) {\n sigdelset(&sigs, SIGINT);\n sigdelset(&sigs, SIGTERM);\n }\n pthread_sigmask(SIG_BLOCK, &sigs, nullptr);\n\n install_oneshot_signal_handler();\n install_oneshot_signal_handler();\n\n#ifdef SEASTAR_HAVE_DPDK\n const auto* native_stack = dynamic_cast(reactor_opts.network_stack.get_selected_candidate_opts());\n _using_dpdk = native_stack && native_stack->dpdk_pmd;\n#endif\n auto thread_affinity = smp_opts.thread_affinity.get_value();\n if (reactor_opts.overprovisioned\n && smp_opts.thread_affinity.defaulted()) {\n thread_affinity = false;\n }\n if (!thread_affinity && _using_dpdk) {\n fmt::print(\"warning: --thread-affinity 0 ignored in dpdk mode\\n\");\n }\n auto mbind = smp_opts.mbind.get_value();\n if (!thread_affinity) {\n mbind = false;\n }\n\n resource::configuration rc;\n\n smp::_tmain = std::this_thread::get_id();\n resource::cpuset cpu_set = get_current_cpuset();\n\n if (smp_opts.cpuset) {\n auto opts_cpuset = smp_opts.cpuset.get_value();\n // CPUs that are not available are those pinned by\n // --cpuset but not present in current task set\n std::set not_available_cpus;\n std::set_difference(opts_cpuset.begin(), opts_cpuset.end(),\n cpu_set.begin(), cpu_set.end(),\n std::inserter(not_available_cpus, not_available_cpus.end()));\n\n if (!not_available_cpus.empty()) {\n std::ostringstream not_available_cpus_list;\n for (auto cpu_id : not_available_cpus) {\n not_available_cpus_list << \" \" << cpu_id;\n }\n seastar_logger.error(\"Bad value for --cpuset:{} not allowed. Shutting down.\", not_available_cpus_list.str());\n exit(1);\n }\n cpu_set = opts_cpuset;\n }\n\n if (smp_opts.smp) {\n smp::count = smp_opts.smp.get_value();\n } else {\n smp::count = cpu_set.size();\n }\n std::vector reactors(smp::count);\n if (smp_opts.memory) {\n#ifdef SEASTAR_DEFAULT_ALLOCATOR\n seastar_logger.warn(\"Seastar compiled with default allocator, --memory option won't take effect\");\n#endif\n rc.total_memory = parse_memory_size(smp_opts.memory.get_value());\n#ifdef SEASTAR_HAVE_DPDK\n if (smp_opts.hugepages &&\n !reactor_opts.network_stack.get_selected_candidate_name().compare(\"native\") &&\n _using_dpdk) {\n size_t dpdk_memory = dpdk::eal::mem_size(smp::count);\n\n if (dpdk_memory >= rc.total_memory) {\n seastar_logger.error(\"Can't run with the given amount of memory: {}. \"\n \"Consider giving more.\",\n smp_opts.memory.get_value());\n exit(1);\n }\n\n //\n // Subtract the memory we are about to give to DPDK from the total\n // amount of memory we are allowed to use.\n //\n rc.total_memory.value() -= dpdk_memory;\n }\n#endif\n }\n if (smp_opts.reserve_memory) {\n rc.reserve_memory = parse_memory_size(smp_opts.reserve_memory.get_value());\n }\n rc.reserve_additional_memory_per_shard = smp_opts.reserve_additional_memory_per_shard;\n std::optional hugepages_path;\n if (smp_opts.hugepages) {\n hugepages_path = smp_opts.hugepages.get_value();\n }\n auto mlock = false;\n if (smp_opts.lock_memory) {\n mlock = smp_opts.lock_memory.get_value();\n }\n if (mlock) {\n auto extra_flags = 0;\n#ifdef MCL_ONFAULT\n // Linux will serialize faulting in anonymous memory, and also\n // serialize marking them as locked. 
This can take many minutes on\n // terabyte class machines, so fault them in the future to spread\n // out the cost. This isn't good since we'll see contention if\n // multiple shards fault in memory at once, but if that work can be\n // in parallel to regular reactor work on other shards.\n extra_flags |= MCL_ONFAULT; // Linux 4.4+\n#endif\n auto r = mlockall(MCL_CURRENT | MCL_FUTURE | extra_flags);\n if (r) {\n // Don't hard fail for now, it's hard to get the configuration right\n fmt::print(\"warning: failed to mlockall: {}\\n\", strerror(errno));\n }\n }\n\n rc.cpus = smp::count;\n rc.cpu_set = std::move(cpu_set);\n\n disk_config_params disk_config(reactor::max_queues);\n disk_config.parse_config(smp_opts, reactor_opts);\n for (auto& id : disk_config.device_ids()) {\n rc.devices.push_back(id);\n }\n rc.num_io_groups = disk_config.num_io_groups();\n\n#ifdef SEASTAR_HAVE_HWLOC\n if (smp_opts.allow_cpus_in_remote_numa_nodes.get_value()) {\n rc.assign_orphan_cpus = true;\n }\n#endif\n\n auto resources = resource::allocate(rc);\n logger::set_shard_field_width(std::ceil(std::log10(smp::count)));\n std::vector allocations = std::move(resources.cpus);\n if (thread_affinity) {\n smp::pin(allocations[0].cpu_id);\n }\n std::optional layout;\n if (smp_opts.memory_allocator == memory_allocator::seastar) {\n layout = memory::configure(allocations[0].mem, mbind, use_transparent_hugepages, hugepages_path);\n } else {\n // #2148 - if running seastar allocator but options that contradict this, we still need to\n // init memory at least minimally, otherwise a bunch of stuff breaks.\n // Previously, we got away wth this by accident due to #2137.\n memory::configure_minimal();\n }\n\n if (reactor_opts.abort_on_seastar_bad_alloc) {\n memory::set_abort_on_allocation_failure(true);\n }\n\n if (reactor_opts.dump_memory_diagnostics_on_alloc_failure_kind) {\n memory::set_dump_memory_diagnostics_on_alloc_failure_kind(reactor_opts.dump_memory_diagnostics_on_alloc_failure_kind.get_value());\n }\n\n reactor_config reactor_cfg;\n reactor_cfg.auto_handle_sigint_sigterm = reactor_opts._auto_handle_sigint_sigterm;\n reactor_cfg.max_networking_aio_io_control_blocks = adjust_max_networking_aio_io_control_blocks(reactor_opts.max_networking_io_control_blocks.get_value());\n\n std::mutex mtx;\n\n#ifdef SEASTAR_HEAPPROF\n size_t heapprof_sampling_rate = reactor_opts.heapprof.get_value();\n if (heapprof_sampling_rate) {\n memory::set_heap_profiling_sampling_rate(heapprof_sampling_rate);\n }\n#else\n size_t heapprof_sampling_rate = 0;\n#endif\n\n#ifdef SEASTAR_HAVE_DPDK\n if (_using_dpdk) {\n dpdk::eal::cpuset cpus;\n for (auto&& a : allocations) {\n cpus[a.cpu_id] = true;\n }\n dpdk::eal::init(cpus, reactor_opts._argv0, hugepages_path, native_stack ? bool(native_stack->dpdk_pmd) : false);\n }\n#endif\n\n // Better to put it into the smp class, but at smp construction time\n // correct smp::count is not known.\n boost::barrier reactors_registered(smp::count);\n boost::barrier smp_queues_constructed(smp::count);\n // We use shared_ptr since this thread can exit while other threads are still unlocking\n auto inited = std::make_shared(smp::count);\n\n auto ioq_topology = std::move(resources.ioq_topology);\n\n // ATTN: The ioq_topology value is referenced by below lambdas which are\n // then copied to other shard's threads, so each shard has a copy of the\n // ioq_topology on stack, but (!) still references and uses the value\n // from shard-0. This access is race-free because\n // 1. The .shard_to_group is not modified\n // 2. 
The .queues is pre-resize()-d in advance, so the vector itself\n // doesn't change; existing slots are accessed by owning shards only\n // without interference\n // 3. The .groups manipulations are guarded by the .lock lock (but it's\n // also pre-resize()-d in advance)\n\n auto alloc_io_queues = [&ioq_topology, &disk_config] (shard_id shard) {\n for (auto& [dev, io_info] : ioq_topology) {\n auto group_idx = io_info.shard_to_group[shard];\n std::shared_ptr group;\n\n {\n std::lock_guard _(io_info.lock);\n auto& iog = io_info.groups[group_idx];\n if (!iog) {\n struct io_queue::config qcfg = disk_config.generate_config(dev, io_info.groups.size());\n iog = std::make_shared(std::move(qcfg), io_info.shards_in_group[group_idx]);\n seastar_logger.debug(\"allocate {} IO group with {} queues, dev {}\", group_idx, io_info.shards_in_group[group_idx], dev);\n }\n group = iog;\n }\n\n io_info.queues[shard] = std::make_unique(std::move(group), engine()._io_sink);\n seastar_logger.debug(\"attached {} queue to {} IO group, dev {}\", shard, group_idx, dev);\n }\n };\n\n auto assign_io_queues = [&ioq_topology] (shard_id shard) {\n for (auto& [dev, io_info] : ioq_topology) {\n auto queue = std::move(io_info.queues[shard]);\n assert(queue);\n engine()._io_queues.emplace(dev, std::move(queue));\n\n auto num_io_groups = io_info.groups.size();\n if (engine()._num_io_groups == 0) {\n engine()._num_io_groups = num_io_groups;\n } else if (engine()._num_io_groups != num_io_groups) {\n throw std::logic_error(format(\"Number of IO-groups mismatch, {} != {}\", engine()._num_io_groups, num_io_groups));\n }\n }\n };\n\n _all_event_loops_done.emplace(smp::count);\n\n auto backend_selector = reactor_opts.reactor_backend.get_selected_candidate();\n seastar_logger.info(\"Reactor backend: {}\", backend_selector);\n\n unsigned i;\n auto smp_tmain = smp::_tmain;\n for (i = 1; i < smp::count; i++) {\n auto allocation = allocations[i];\n create_thread([this, smp_tmain, inited, &reactors_registered, &smp_queues_constructed, &smp_opts, &reactor_opts, &reactors, hugepages_path, i, allocation, assign_io_queues, alloc_io_queues, thread_affinity, heapprof_sampling_rate, mbind, backend_selector, reactor_cfg, &mtx, &layout, use_transparent_hugepages] {\n try {\n // initialize thread_locals that are equal across all reacto threads of this smp instance\n smp::_tmain = smp_tmain;\n auto thread_name = fmt::format(\"reactor-{}\", i);\n pthread_setname_np(pthread_self(), thread_name.c_str());\n if (thread_affinity) {\n smp::pin(allocation.cpu_id);\n }\n if (smp_opts.memory_allocator == memory_allocator::seastar) {\n auto another_layout = memory::configure(allocation.mem, mbind, use_transparent_hugepages, hugepages_path);\n auto guard = std::lock_guard(mtx);\n *layout = memory::internal::merge(std::move(*layout), std::move(another_layout));\n } else {\n // See comment above (shard 0)\n memory::configure_minimal();\n }\n if (heapprof_sampling_rate) {\n memory::set_heap_profiling_sampling_rate(heapprof_sampling_rate);\n }\n sigset_t mask;\n sigfillset(&mask);\n for (auto sig : { SIGSEGV }) {\n sigdelset(&mask, sig);\n }\n auto r = ::pthread_sigmask(SIG_BLOCK, &mask, NULL);\n throw_pthread_error(r);\n init_default_smp_service_group(i);\n lowres_clock::update();\n allocate_reactor(i, backend_selector, reactor_cfg);\n reactors[i] = &engine();\n alloc_io_queues(i);\n reactors_registered.wait();\n smp_queues_constructed.wait();\n // _qs_owner is only initialized here\n _qs = _qs_owner.get();\n start_all_queues();\n assign_io_queues(i);\n 
inited->wait();\n engine().configure(reactor_opts);\n engine().do_run();\n } catch (const std::exception& e) {\n seastar_logger.error(\"{}\", e.what());\n _exit(1);\n }\n });\n }\n\n init_default_smp_service_group(0);\n lowres_clock::update();\n try {\n allocate_reactor(0, backend_selector, reactor_cfg);\n } catch (const std::exception& e) {\n seastar_logger.error(\"{}\", e.what());\n _exit(1);\n }\n\n reactors[0] = &engine();\n alloc_io_queues(0);\n\n#ifdef SEASTAR_HAVE_DPDK\n if (_using_dpdk) {\n auto it = _thread_loops.begin();\n \nRTE_LCORE_FOREACH_WORKER(i) {\n rte_eal_remote_launch(dpdk_thread_adaptor, static_cast(&*(it++)), i);\n }\n }\n#endif\n\n reactors_registered.wait();\n _qs_owner = decltype(smp::_qs_owner){new smp_message_queue* [smp::count], qs_deleter{}};\n _qs = _qs_owner.get();\n for(unsigned i = 0; i < smp::count; i++) {\n smp::_qs_owner[i] = reinterpret_cast(operator new[] (sizeof(smp_message_queue) * smp::count\n // smp_message_queue has members with hefty alignment requirements.\n // if we are reactor thread, or not running with dpdk, doing this\n // new default aligned seemingly works, as does reordering\n // dlinit dependencies (ugh). But we should enforce calling out to\n // aligned_alloc, instead of pure malloc, if possible.\n , std::align_val_t(alignof(smp_message_queue))\n ));\n for (unsigned j = 0; j < smp::count; ++j) {\n new (&smp::_qs_owner[i][j]) smp_message_queue(reactors[j], reactors[i]);\n }\n }\n _alien._qs = alien::instance::create_qs(reactors);\n smp_queues_constructed.wait();\n start_all_queues();\n assign_io_queues(0);\n inited->wait();\n\n engine().configure(reactor_opts);\n\n if (smp_opts.lock_memory && smp_opts.lock_memory.get_value() && layout && !layout->ranges.empty()) {\n smp::setup_prefaulter(resources, std::move(*layout));\n }\n}\n\nbool smp::poll_queues() {\n size_t got = 0;\n for (unsigned i = 0; i < count; i++) {\n if (this_shard_id() != i) {\n auto& rxq = _qs[this_shard_id()][i];\n rxq.flush_response_batch();\n got += rxq.has_unflushed_responses();\n got += rxq.process_incoming();\n auto& txq = _qs[i][this_shard_id()];\n txq.flush_request_batch();\n got += txq.process_completions(i);\n }\n }\n return got != 0;\n}\n\nbool smp::pure_poll_queues() {\n for (unsigned i = 0; i < count; i++) {\n if (this_shard_id() != i) {\n auto& rxq = _qs[this_shard_id()][i];\n rxq.flush_response_batch();\n auto& txq = _qs[i][this_shard_id()];\n txq.flush_request_batch();\n if (rxq.pure_poll_rx() || txq.pure_poll_tx() || rxq.has_unflushed_responses()) {\n return true;\n }\n }\n }\n return false;\n}\n\n__thread reactor* local_engine;\n\nvoid report_exception(std::string_view message, std::exception_ptr eptr) noexcept {\n seastar_logger.error(\"{}: {}\", message, eptr);\n}\n\nfuture<> check_direct_io_support(std::string_view path) noexcept {\n struct w {\n sstring path;\n open_flags flags;\n std::function()> cleanup;\n\n static w parse(sstring path, std::optional type) {\n if (!type) {\n throw std::invalid_argument(format(\"Could not open file at {}. Make sure it exists\", path));\n }\n\n if (type == directory_entry_type::directory) {\n auto fpath = path + \"/.o_direct_test\";\n return w{fpath, open_flags::wo | open_flags::create | open_flags::truncate, [fpath] { return remove_file(fpath); }};\n } else if ((type == directory_entry_type::regular) || (type == directory_entry_type::link)) {\n return w{path, open_flags::ro, [] { return make_ready_future<>(); }};\n } else {\n throw std::invalid_argument(format(\"{} neither a directory nor file. 
Can't be opened with O_DIRECT\", path));\n }\n };\n };\n\n // Allocating memory for a sstring can throw, hence the futurize_invoke\n return futurize_invoke([path] {\n return engine().file_type(path).then([path = sstring(path)] (auto type) {\n auto w = w::parse(path, type);\n return open_file_dma(w.path, w.flags).then_wrapped([path = w.path, cleanup = std::move(w.cleanup)] (future f) {\n try {\n auto fd = f.get();\n return cleanup().finally([fd = std::move(fd)] () mutable {\n return fd.close();\n });\n } catch (std::system_error& e) {\n if (e.code() == std::error_code(EINVAL, std::system_category())) {\n report_exception(format(\"Could not open file at {}. Does your filesystem support O_DIRECT?\", path), std::current_exception());\n }\n throw;\n }\n });\n });\n });\n}\n\nserver_socket listen(socket_address sa) {\n return engine().listen(sa);\n}\n\nserver_socket listen(socket_address sa, listen_options opts) {\n return engine().listen(sa, opts);\n}\n\nfuture connect(socket_address sa) {\n return engine().connect(sa);\n}\n\nfuture connect(socket_address sa, socket_address local, transport proto = transport::TCP) {\n return engine().connect(sa, local, proto);\n}\n\nsocket make_socket() {\n return engine().net().socket();\n}\n\nnet::udp_channel make_udp_channel() {\n return make_unbound_datagram_channel(AF_INET);\n}\n\nnet::udp_channel make_udp_channel(const socket_address& local) {\n return make_bound_datagram_channel(local);\n}\n\nnet::datagram_channel make_unbound_datagram_channel(sa_family_t family) {\n return engine().net().make_unbound_datagram_channel(family);\n}\n\nnet::datagram_channel make_bound_datagram_channel(const socket_address& local) {\n return engine().net().make_bound_datagram_channel(local);\n}\n\nvoid reactor::add_high_priority_task(task* t) noexcept {\n add_urgent_task(t);\n // break .then() chains\n request_preemption();\n}\n\n\nvoid set_idle_cpu_handler(idle_cpu_handler&& handler) {\n engine().set_idle_cpu_handler(std::move(handler));\n}\n\nnamespace experimental {\nfuture> make_pipe() {\n return engine().make_pipe();\n}\n\nfuture spawn_process(const std::filesystem::path& pathname,\n spawn_parameters params) {\n return process::spawn(pathname, std::move(params));\n}\n\nfuture spawn_process(const std::filesystem::path& pathname) {\n return process::spawn(pathname);\n}\n}\n\nstatic\nbool\nvirtualized() {\n return fs::exists(\"/sys/hypervisor/type\");\n}\n\nstd::chrono::nanoseconds\nreactor::calculate_poll_time() {\n // In a non-virtualized environment, select a poll time\n // that is competitive with halt/unhalt.\n //\n // In a virutalized environment, IPIs are slow and dominate\n // sleep/wake (mprotect/tgkill), so increase poll time to reduce\n // so we don't sleep in a request/reply workload\n return virtualized() ? 
2000us : 200us;\n}\n\nfuture<>\nyield() noexcept {\n memory::scoped_critical_alloc_section _;\n auto tsk = make_task([] {});\n schedule(tsk);\n return tsk->get_future();\n}\n\nfuture<> check_for_io_immediately() noexcept {\n memory::scoped_critical_alloc_section _;\n engine().force_poll();\n auto tsk = make_task(default_scheduling_group(), [] {});\n schedule(tsk);\n return tsk->get_future();\n}\n\nfuture<> later() noexcept {\n return check_for_io_immediately();\n}\n\nsteady_clock_type::duration reactor::total_idle_time() {\n return _total_idle;\n}\n\nsteady_clock_type::duration reactor::total_busy_time() {\n return now() - _start_time - _total_idle;\n}\n\nstd::chrono::nanoseconds reactor::total_steal_time() {\n // Steal time: this mimics the concept some Hypervisors have about Steal time.\n // That is the time in which a VM has something to run, but is not running because some other\n // process (another VM or the hypervisor itself) is in control.\n //\n // For us, we notice that during the time in which we were not sleeping (either running or busy\n // polling while idle), we should be accumulating thread runtime. If we are not, that's because\n // someone stole it from us.\n //\n // Because this is totally in userspace we can miss some events. For instance, if the seastar\n // process is ready to run but the kernel hasn't scheduled us yet, that would be technically\n // steal time but we have no ways to account it.\n //\n // But what we have here should be good enough and at least has a well defined meaning.\n return std::chrono::duration_cast(now() - _start_time - _total_sleep) -\n std::chrono::duration_cast(thread_cputime_clock::now().time_since_epoch());\n}\n\nstatic std::atomic s_used_scheduling_group_ids_bitmap{3}; // 0=main, 1=atexit\nstatic std::atomic s_next_scheduling_group_specific_key{0};\n\nstatic\nint\nallocate_scheduling_group_id() noexcept {\n static_assert(max_scheduling_groups() <= std::numeric_limits::digits, \"more scheduling groups than available bits\");\n auto b = s_used_scheduling_group_ids_bitmap.load(std::memory_order_relaxed);\n auto nb = b;\n unsigned i = 0;\n do {\n if (__builtin_popcountl(b) == max_scheduling_groups()) {\n return -1;\n }\n i = count_trailing_zeros(~b);\n nb = b | (1ul << i);\n } while (!s_used_scheduling_group_ids_bitmap.compare_exchange_weak(b, nb, std::memory_order_relaxed));\n return i;\n}\n\nstatic\nunsigned long\nallocate_scheduling_group_specific_key() noexcept {\n return s_next_scheduling_group_specific_key.fetch_add(1, std::memory_order_relaxed);\n}\n\nstatic\nvoid\ndeallocate_scheduling_group_id(unsigned id) noexcept {\n s_used_scheduling_group_ids_bitmap.fetch_and(~(1ul << id), std::memory_order_relaxed);\n}\n\nvoid\nreactor::allocate_scheduling_group_specific_data(scheduling_group sg, scheduling_group_key key) {\n auto& sg_data = _scheduling_group_specific_data;\n auto& this_sg = sg_data.per_scheduling_group_data[sg._id];\n this_sg.specific_vals.resize(std::max(this_sg.specific_vals.size(), key.id()+1));\n this_sg.specific_vals[key.id()] =\n aligned_alloc(sg_data.scheduling_group_key_configs[key.id()].alignment,\n sg_data.scheduling_group_key_configs[key.id()].allocation_size);\n if (!this_sg.specific_vals[key.id()]) {\n std::abort();\n }\n if (sg_data.scheduling_group_key_configs[key.id()].constructor) {\n sg_data.scheduling_group_key_configs[key.id()].constructor(this_sg.specific_vals[key.id()]);\n }\n}\n\nfuture<>\nreactor::rename_scheduling_group_specific_data(scheduling_group sg) {\n return with_scheduling_group(sg, [this, sg] 
{\n auto& sg_data = _scheduling_group_specific_data;\n auto& this_sg = sg_data.per_scheduling_group_data[sg._id];\n for (size_t i = 0; i < sg_data.scheduling_group_key_configs.size(); ++i) {\n auto &c = sg_data.scheduling_group_key_configs[i];\n if (c.rename) {\n (c.rename)(this_sg.specific_vals[i]);\n }\n }\n });\n}\n\nfuture<>\nreactor::init_scheduling_group(seastar::scheduling_group sg, sstring name, sstring shortname, float shares) {\n auto& sg_data = _scheduling_group_specific_data;\n auto& this_sg = sg_data.per_scheduling_group_data[sg._id];\n this_sg.queue_is_initialized = true;\n _task_queues.resize(std::max(_task_queues.size(), sg._id + 1));\n _task_queues[sg._id] = std::make_unique(sg._id, name, shortname, shares);\n unsigned long num_keys = s_next_scheduling_group_specific_key.load(std::memory_order_relaxed);\n\n return with_scheduling_group(sg, [this, num_keys, sg] () {\n for (unsigned long key_id = 0; key_id < num_keys; key_id++) {\n allocate_scheduling_group_specific_data(sg, scheduling_group_key(key_id));\n }\n });\n}\n\nfuture<>\nreactor::init_new_scheduling_group_key(scheduling_group_key key, scheduling_group_key_config cfg) {\n auto& sg_data = _scheduling_group_specific_data;\n sg_data.scheduling_group_key_configs.resize(std::max(sg_data.scheduling_group_key_configs.size(), key.id() + 1));\n sg_data.scheduling_group_key_configs[key.id()] = cfg;\n return parallel_for_each(_task_queues, [this, cfg, key] (std::unique_ptr& tq) {\n...\n", "needle_token_start": 12125, "needle_token_end": 12182, "code_context_ntokens": 16228, "output": ["```cpp\n#ifdef SEASTAR_HAVE_DPDK\n\nint dpdk_thread_adaptor(void* f)\n{\n (*static_cast*>(f))();\n return 0;\n}\n\n#endif\n```"]} +{"repo": "scylladb/seastar", "name": "default_isolate_connection", "language": "cpp", "path": "src/rpc/rpc.cc", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to provide a default configuration for isolating a connection based on a specific identifier.\n2. **Input**: A string representing an identifier for the isolation configuration.\n3. **Output**: An object representing the configuration settings for isolating a connection.\n4. **Procedure**: The function receives an identifier as input and returns a new configuration object with default settings, without modifying any properties based on the input.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " //\n // The latter is achieved by resolving pit.done immediatelly, the former\n // by moving it.done into pit.done. 
For simplicity (verging on obscurity?)\n // both done's are just swapped and \"it\" resolves its new promise\n\n std::swap(it->done, pit->done);\n it->uncancellable();\n it->unlink();\n if (ex == nullptr) {\n it->done.set_value();\n } else {\n it->done.set_exception(ex);\n }\n }\n\n future<> connection::send(snd_buf buf, std::optional timeout, cancellable* cancel) {\n if (!_error) {\n if (timeout && *timeout <= rpc_clock_type::now()) {\n return make_ready_future<>();\n }\n\n auto p = std::make_unique(std::move(buf));\n auto& d = *p;\n _outgoing_queue.push_back(d);\n _outgoing_queue_size++;\n auto deleter = [this, it = _outgoing_queue.iterator_to(d)] {\n // Front entry is most likely (unless _negotiated is unresolved, check enqueue_zero_frame()) sitting\n // inside send_entry() continuations and thus it cannot be cancelled.\n if (it != _outgoing_queue.begin()) {\n withdraw(it);\n }\n };\n\n if (timeout) {\n auto& t = d.t;\n t.set_callback(deleter);\n t.arm(timeout.value());\n }\n if (cancel) {\n cancel->cancel_send = std::move(deleter);\n cancel->send_back_pointer = &d.pcancel;\n d.pcancel = cancel;\n }\n\n // New entry should continue (do its .then() lambda) after _outgoing_queue_ready\n // resolves. Next entry will need to do the same after this entry's done resolves.\n // Thus -- replace _outgoing_queue_ready with d's future and chain its continuation\n // on ..._ready's old value.\n return std::exchange(_outgoing_queue_ready, d.done.get_future()).then([this, p = std::move(p)] () mutable {\n _outgoing_queue_size--;\n if (__builtin_expect(!p->is_linked(), false)) {\n // If withdrawn the entry is unlinked and this lambda is fired right at once\n return make_ready_future<>();\n }\n\n p->uncancellable();\n return send_entry(*p).then_wrapped([this, p = std::move(p)] (auto f) mutable {\n if (f.failed()) {\n f.ignore_ready_future();\n abort();\n }\n p->done.set_value();\n });\n });\n } else {\n return make_exception_future<>(closed_error());\n }\n }\n\n void connection::abort() {\n if (!_error) {\n _error = true;\n _fd.shutdown_input();\n }\n }\n\n future<> connection::stop() noexcept {\n try {\n abort();\n } catch (...) 
{\n log_exception(*this, log_level::error, \"fail to shutdown connection while stopping\", std::current_exception());\n }\n return _stopped.get_future();\n }\n\n template\n static bool verify_frame(Connection& c, temporary_buffer& buf, size_t expected, const char* log) {\n if (buf.size() != expected) {\n if (buf.size() != 0) {\n c.get_logger()(c.peer_address(), log);\n }\n return false;\n }\n return true;\n }\n\n template\n static\n future\n receive_negotiation_frame(Connection& c, input_stream& in) {\n return in.read_exactly(sizeof(negotiation_frame)).then([&c, &in] (temporary_buffer neg) {\n if (!verify_frame(c, neg, sizeof(negotiation_frame), \"unexpected eof during negotiation frame\")) {\n return make_exception_future(closed_error());\n }\n negotiation_frame frame;\n std::copy_n(neg.get_write(), sizeof(frame.magic), frame.magic);\n frame.len = read_le(neg.get_write() + 8);\n if (std::memcmp(frame.magic, rpc_magic, sizeof(frame.magic)) != 0) {\n c.get_logger()(c.peer_address(), format(\"wrong protocol magic: {:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}\",\n frame.magic[0], frame.magic[1], frame.magic[2], frame.magic[3], frame.magic[4], frame.magic[5], frame.magic[6], frame.magic[7]));\n return make_exception_future(closed_error());\n }\n auto len = frame.len;\n return in.read_exactly(len).then([&c, len] (temporary_buffer extra) {\n if (extra.size() != len) {\n c.get_logger()(c.peer_address(), \"unexpected eof during negotiation frame\");\n return make_exception_future(closed_error());\n }\n feature_map map;\n auto p = extra.get();\n auto end = p + extra.size();\n while (p != end) {\n if (end - p < 8) {\n c.get_logger()(c.peer_address(), \"bad feature data format in negotiation frame\");\n return make_exception_future(closed_error());\n }\n auto feature = static_cast(read_le(p));\n auto f_len = read_le(p + 4);\n p += 8;\n if (f_len > end - p) {\n c.get_logger()(c.peer_address(), \"buffer underflow in feature data in negotiation frame\");\n return make_exception_future(closed_error());\n }\n auto data = sstring(p, f_len);\n p += f_len;\n map.emplace(feature, std::move(data));\n }\n return make_ready_future(std::move(map));\n });\n });\n }\n\n inline future\n read_rcv_buf(input_stream& in, uint32_t size) {\n return in.read_up_to(size).then([&, size] (temporary_buffer data) mutable {\n rcv_buf rb(size);\n if (data.size() == 0) {\n return make_ready_future(rcv_buf());\n } else if (data.size() == size) {\n rb.bufs = std::move(data);\n return make_ready_future(std::move(rb));\n } else {\n size -= data.size();\n std::vector> v;\n v.push_back(std::move(data));\n rb.bufs = std::move(v);\n return do_with(std::move(rb), std::move(size), [&in] (rcv_buf& rb, uint32_t& left) {\n return repeat([&] () {\n return in.read_up_to(left).then([&] (temporary_buffer data) {\n if (!data.size()) {\n rb.size -= left;\n return stop_iteration::yes;\n } else {\n left -= data.size();\n std::get>>(rb.bufs).push_back(std::move(data));\n return left ? 
stop_iteration::no : stop_iteration::yes;\n }\n });\n }).then([&rb] {\n return std::move(rb);\n });\n });\n }\n });\n }\n\n template\n future\n connection::read_frame(socket_address info, input_stream& in) {\n auto header_size = FrameType::header_size();\n return in.read_exactly(header_size).then([this, header_size, info, &in] (temporary_buffer header) {\n if (header.size() != header_size) {\n if (header.size() != 0) {\n _logger(info, format(\"unexpected eof on a {} while reading header: expected {:d} got {:d}\", FrameType::role(), header_size, header.size()));\n }\n return make_ready_future(FrameType::empty_value());\n }\n auto [size, h] = FrameType::decode_header(header.get());\n if (!size) {\n return make_ready_future(FrameType::make_value(h, rcv_buf()));\n } else {\n return read_rcv_buf(in, size).then([this, info, h = std::move(h), size] (rcv_buf rb) {\n if (rb.size != size) {\n _logger(info, format(\"unexpected eof on a {} while reading data: expected {:d} got {:d}\", FrameType::role(), size, rb.size));\n return make_ready_future(FrameType::empty_value());\n } else {\n return make_ready_future(FrameType::make_value(h, std::move(rb)));\n }\n });\n }\n });\n }\n\n template\n future\n connection::read_frame_compressed(socket_address info, std::unique_ptr& compressor, input_stream& in) {\n if (compressor) {\n return in.read_exactly(4).then([this, info, &in, &compressor] (temporary_buffer compress_header) {\n if (compress_header.size() != 4) {\n if (compress_header.size() != 0) {\n _logger(info, format(\"unexpected eof on a {} while reading compression header: expected 4 got {:d}\", FrameType::role(), compress_header.size()));\n }\n return make_ready_future(FrameType::empty_value());\n }\n auto ptr = compress_header.get();\n auto size = read_le(ptr);\n return read_rcv_buf(in, size).then([this, size, &compressor, info, &in] (rcv_buf compressed_data) {\n if (compressed_data.size != size) {\n _logger(info, format(\"unexpected eof on a {} while reading compressed data: expected {:d} got {:d}\", FrameType::role(), size, compressed_data.size));\n return make_ready_future(FrameType::empty_value());\n }\n auto eb = compressor->decompress(std::move(compressed_data));\n if (eb.size == 0) {\n // Empty frames might be sent as means of communication between the compressors, and should be skipped by the RPC layer.\n // We skip the empty frame here. We recursively restart the function, as if the empty frame didn't happen.\n // The yield() is here to limit the stack depth of the recursion to 1.\n return yield().then([this, info, &in, &compressor] { return read_frame_compressed(info, compressor, in); });\n }\n net::packet p;\n auto* one = std::get_if>(&eb.bufs);\n if (one) {\n p = net::packet(std::move(p), std::move(*one));\n } else {\n auto&& bufs = std::get>>(eb.bufs);\n p.reserve(bufs.size());\n for (auto&& b : bufs) {\n p = net::packet(std::move(p), std::move(b));\n }\n }\n return do_with(as_input_stream(std::move(p)), [this, info] (input_stream& in) {\n return read_frame(info, in);\n });\n });\n });\n } else {\n return read_frame(info, in);\n }\n }\n\n struct stream_frame {\n using opt_buf_type = std::optional;\n using return_type = opt_buf_type;\n struct header_type {\n bool eos;\n };\n static size_t header_size() {\n return 4;\n }\n static const char* role() {\n return \"stream\";\n }\n static auto empty_value() {\n return std::nullopt;\n }\n static std::pair decode_header(const char* ptr) {\n auto size = read_le(ptr);\n return size != -1U ? 
std::make_pair(size, header_type{false}) : std::make_pair(0U, header_type{true});\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n if (t.eos) {\n data.size = -1U;\n }\n return data;\n }\n };\n\n future>\n connection::read_stream_frame_compressed(input_stream& in) {\n return read_frame_compressed(peer_address(), _compressor, in);\n }\n\n future<> connection::stream_close() {\n auto f = make_ready_future<>();\n if (!error()) {\n promise p;\n _sink_closed_future = p.get_future();\n // stop_send_loop(), which also calls _write_buf.close(), and this code can run in parallel.\n // Use _sink_closed_future to serialize them and skip second call to close()\n f = _write_buf.close().finally([p = std::move(p)] () mutable { p.set_value(true);});\n }\n return f.finally([this] () mutable { return stop(); });\n }\n\n future<> connection::stream_process_incoming(rcv_buf&& buf) {\n // we do not want to dead lock on huge packets, so let them in\n // but only one at a time\n auto size = std::min(size_t(buf.size), max_stream_buffers_memory);\n return get_units(_stream_sem, size).then([this, buf = std::move(buf)] (semaphore_units<>&& su) mutable {\n buf.su = std::move(su);\n return _stream_queue.push_eventually(std::move(buf));\n });\n }\n\n future<> connection::handle_stream_frame() {\n return read_stream_frame_compressed(_read_buf).then([this] (std::optional data) {\n if (!data) {\n _error = true;\n return make_ready_future<>();\n }\n return stream_process_incoming(std::move(*data));\n });\n }\n\n future<> connection::stream_receive(circular_buffer>>& bufs) {\n return _stream_queue.not_empty().then([this, &bufs] {\n bool eof = !_stream_queue.consume([&bufs] (rcv_buf&& b) {\n if (b.size == -1U) { // max fragment length marks an end of a stream\n return false;\n } else {\n bufs.push_back(make_foreign(std::make_unique(std::move(b))));\n return true;\n }\n });\n if (eof && !bufs.empty()) {\n assert(_stream_queue.empty());\n _stream_queue.push(rcv_buf(-1U)); // push eof marker back for next read to notice it\n }\n });\n }\n\n void connection::register_stream(connection_id id, xshard_connection_ptr c) {\n _streams.emplace(id, std::move(c));\n }\n\n xshard_connection_ptr connection::get_stream(connection_id id) const {\n auto it = _streams.find(id);\n if (it == _streams.end()) {\n throw std::logic_error(format(\"rpc stream id {} not found\", id).c_str());\n }\n return it->second;\n }\n\n // The request frame is\n // le64 optional timeout (see request_frame_with_timeout below)\n // le64 message type a.k.a. verb ID\n // le64 message ID\n // le32 payload length\n // ... 
payload\n struct request_frame {\n using opt_buf_type = std::optional;\n using return_type = std::tuple, uint64_t, int64_t, opt_buf_type>;\n using header_type = std::tuple, uint64_t, int64_t>;\n static constexpr size_t raw_header_size = sizeof(uint64_t) + sizeof(int64_t) + sizeof(uint32_t);\n static size_t header_size() {\n static_assert(request_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static const char* role() {\n return \"server\";\n }\n static auto empty_value() {\n return std::make_tuple(std::nullopt, uint64_t(0), 0, std::nullopt);\n }\n static std::pair decode_header(const char* ptr) {\n auto type = read_le(ptr);\n auto msgid = read_le(ptr + 8);\n auto size = read_le(ptr + 16);\n return std::make_pair(size, std::make_tuple(std::nullopt, type, msgid));\n }\n static void encode_header(uint64_t type, int64_t msg_id, snd_buf& buf, size_t off) {\n auto p = buf.front().get_write() + off;\n write_le(p, type);\n write_le(p + 8, msg_id);\n write_le(p + 16, buf.size - raw_header_size - off);\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n return std::make_tuple(std::get<0>(t), std::get<1>(t), std::get<2>(t), std::move(data));\n }\n };\n\n // This frame is used if protocol_features.TIMEOUT was negotiated\n struct request_frame_with_timeout : request_frame {\n using super = request_frame;\n static constexpr size_t raw_header_size = sizeof(uint64_t) + request_frame::raw_header_size;\n static size_t header_size() {\n static_assert(request_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static std::pair decode_header(const char* ptr) {\n auto h = super::decode_header(ptr + 8);\n std::get<0>(h.second) = read_le(ptr);\n return h;\n }\n static void encode_header(uint64_t type, int64_t msg_id, snd_buf& buf) {\n static_assert(snd_buf::chunk_size >= raw_header_size, \"send buffer chunk size is too small\");\n // expiration timer is encoded later\n request_frame::encode_header(type, msg_id, buf, 8);\n }\n };\n\n future<> client::request(uint64_t type, int64_t msg_id, snd_buf buf, std::optional timeout, cancellable* cancel) {\n request_frame_with_timeout::encode_header(type, msg_id, buf);\n return send(std::move(buf), timeout, cancel);\n }\n\n void\n client::negotiate(feature_map provided) {\n // record features returned here\n for (auto&& e : provided) {\n auto id = e.first;\n switch (id) {\n // supported features go here\n case protocol_features::COMPRESS:\n if (_options.compressor_factory) {\n _compressor = _options.compressor_factory->negotiate(e.second, false, [this] { return send({}); });\n }\n if (!_compressor) {\n throw std::runtime_error(format(\"RPC server responded with compression {} - unsupported\", e.second));\n }\n break;\n case protocol_features::TIMEOUT:\n _timeout_negotiated = true;\n break;\n case protocol_features::CONNECTION_ID: {\n _id = deserialize_connection_id(e.second);\n break;\n }\n default:\n // nothing to do\n ;\n }\n }\n }\n\n future<> client::negotiate_protocol(feature_map features) {\n return send_negotiation_frame(std::move(features)).then([this] {\n return receive_negotiation_frame(*this, _read_buf).then([this] (feature_map features) {\n return negotiate(std::move(features));\n });\n });\n }\n\n // The response frame is\n // le64 message ID\n // le32 payload size\n // ... 
payload\n struct response_frame {\n using opt_buf_type = std::optional;\n using return_type = std::tuple;\n using header_type = std::tuple;\n static constexpr size_t raw_header_size = sizeof(int64_t) + sizeof(uint32_t);\n static size_t header_size() {\n static_assert(response_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static const char* role() {\n return \"client\";\n }\n static auto empty_value() {\n return std::make_tuple(0, std::nullopt);\n }\n static std::pair decode_header(const char* ptr) {\n auto msgid = read_le(ptr);\n auto size = read_le(ptr + 8);\n return std::make_pair(size, std::make_tuple(msgid));\n }\n static void encode_header(int64_t msg_id, snd_buf& data) {\n static_assert(snd_buf::chunk_size >= raw_header_size, \"send buffer chunk size is too small\");\n auto p = data.front().get_write();\n write_le(p, msg_id);\n write_le(p + 8, data.size - raw_header_size);\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n return std::make_tuple(std::get<0>(t), std::move(data));\n }\n };\n\n\n future\n client::read_response_frame_compressed(input_stream& in) {\n return read_frame_compressed(_server_addr, _compressor, in);\n }\n\n stats client::get_stats() const {\n stats res = _stats;\n res.wait_reply = incoming_queue_length();\n res.pending = outgoing_queue_length();\n return res;\n }\n\n void client::wait_for_reply(id_type id, std::unique_ptr&& h, std::optional timeout, cancellable* cancel) {\n if (timeout) {\n h->t.set_callback(std::bind(std::mem_fn(&client::wait_timed_out), this, id));\n h->t.arm(timeout.value());\n }\n if (cancel) {\n cancel->cancel_wait = [this, id] {\n _outstanding[id]->cancel();\n _outstanding.erase(id);\n };\n h->pcancel = cancel;\n cancel->wait_back_pointer = &h->pcancel;\n }\n _outstanding.emplace(id, std::move(h));\n }\n void client::wait_timed_out(id_type id) {\n _stats.timeout++;\n _outstanding[id]->timeout();\n _outstanding.erase(id);\n }\n\n future<> client::stop() noexcept {\n _error = true;\n try {\n _socket.shutdown();\n } catch(...) {\n log_exception(*this, log_level::error, \"fail to shutdown connection while stopping\", std::current_exception());\n }\n return _stopped.get_future();\n }\n\n void client::abort_all_streams() {\n while (!_streams.empty()) {\n auto&& s = _streams.begin();\n assert(s->second->get_owner_shard() == this_shard_id()); // abort can be called only locally\n s->second->get()->abort();\n _streams.erase(s);\n }\n }\n\n void client::deregister_this_stream() {\n if (_parent) {\n _parent->_streams.erase(_id);\n }\n }\n\n // This is the enlightened copy of the connection::send() method. Its intention is to\n // keep a dummy entry in front of the queue while connect+negotiate is happenning so\n // that all subsequent entries could abort on timeout or explicit cancellation.\n void client::enqueue_zero_frame() {\n if (_error) {\n return;\n }\n\n auto p = std::make_unique(snd_buf(0));\n auto& d = *p;\n _outgoing_queue.push_back(d);\n\n // Make it in the background. 
Even if the client is stopped it will pick\n // up all the entries hanging around\n (void)std::exchange(_outgoing_queue_ready, d.done.get_future()).then_wrapped([p = std::move(p)] (auto f) mutable {\n if (f.failed()) {\n f.ignore_ready_future();\n } else {\n p->done.set_value();\n }\n });\n }\n\n struct client::metrics::domain {\n metrics::domain_list_t list;\n stats dead;\n seastar::metrics::metric_groups metric_groups;\n\n static thread_local std::unordered_map all;\n static domain& find_or_create(sstring name);\n\n stats::counter_type count_all(stats::counter_type stats::* field) noexcept {\n stats::counter_type res = dead.*field;\n for (const auto& m : list) {\n res += m._c._stats.*field;\n }\n return res;\n }\n\n size_t count_all_fn(size_t (client::*fn)(void) const) noexcept {\n size_t res = 0;\n for (const auto& m : list) {\n res += (m._c.*fn)();\n }\n return res;\n }\n\n domain(sstring name)\n {\n namespace sm = seastar::metrics;\n auto domain_l = sm::label(\"domain\")(name);\n\n metric_groups.add_group(\"rpc_client\", {\n sm::make_gauge(\"count\", [this] { return list.size(); },\n sm::description(\"Total number of clients\"), { domain_l }),\n sm::make_counter(\"sent_messages\", std::bind(&domain::count_all, this, &stats::sent_messages),\n sm::description(\"Total number of messages sent\"), { domain_l }),\n sm::make_counter(\"replied\", std::bind(&domain::count_all, this, &stats::replied),\n sm::description(\"Total number of responses received\"), { domain_l }),\n sm::make_counter(\"exception_received\", std::bind(&domain::count_all, this, &stats::exception_received),\n sm::description(\"Total number of exceptional responses received\"), { domain_l }).set_skip_when_empty(),\n sm::make_counter(\"timeout\", std::bind(&domain::count_all, this, &stats::timeout),\n sm::description(\"Total number of timeout responses\"), { domain_l }).set_skip_when_empty(),\n sm::make_gauge(\"pending\", std::bind(&domain::count_all_fn, this, &client::outgoing_queue_length),\n sm::description(\"Number of queued outbound messages\"), { domain_l }),\n sm::make_gauge(\"wait_reply\", std::bind(&domain::count_all_fn, this, &client::incoming_queue_length),\n sm::description(\"Number of replies waiting for\"), { domain_l }),\n });\n }\n };\n\n thread_local std::unordered_map client::metrics::domain::all;\n\n client::metrics::domain& client::metrics::domain::find_or_create(sstring name) {\n auto i = all.try_emplace(name, name);\n return i.first->second;\n }\n\n client::metrics::metrics(const client& c)\n : _c(c)\n , _domain(domain::find_or_create(_c._options.metrics_domain))\n {\n _domain.list.push_back(*this);\n }\n\n client::metrics::~metrics() {\n _domain.dead.replied += _c._stats.replied;\n _domain.dead.exception_received += _c._stats.exception_received;\n _domain.dead.sent_messages += _c._stats.sent_messages;\n _domain.dead.timeout += _c._stats.timeout;\n }\n\n client::client(const logger& l, void* s, client_options ops, socket socket, const socket_address& addr, const socket_address& local)\n : rpc::connection(l, s), _socket(std::move(socket)), _server_addr(addr), _local_addr(local), _options(ops), _metrics(*this)\n {\n _socket.set_reuseaddr(ops.reuseaddr);\n // Run client in the background.\n // Communicate result via _stopped.\n // The caller has to call client::stop() to synchronize.\n (void)_socket.connect(addr, local).then([this, ops = std::move(ops)] (connected_socket fd) {\n fd.set_nodelay(ops.tcp_nodelay);\n if (ops.keepalive) {\n fd.set_keepalive(true);\n 
fd.set_keepalive_parameters(ops.keepalive.value());\n }\n set_socket(std::move(fd));\n\n feature_map features;\n if (_options.compressor_factory) {\n features[protocol_features::COMPRESS] = _options.compressor_factory->supported();\n }\n if (_options.send_timeout_data) {\n features[protocol_features::TIMEOUT] = \"\";\n }\n if (_options.stream_parent) {\n features[protocol_features::STREAM_PARENT] = serialize_connection_id(_options.stream_parent);\n }\n if (!_options.isolation_cookie.empty()) {\n features[protocol_features::ISOLATION] = _options.isolation_cookie;\n }\n\n return negotiate_protocol(std::move(features)).then([this] {\n _propagate_timeout = !is_stream();\n set_negotiated();\n return do_until([this] { return _read_buf.eof() || _error; }, [this] () mutable {\n if (is_stream()) {\n return handle_stream_frame();\n }\n return read_response_frame_compressed(_read_buf).then([this] (response_frame::return_type msg_id_and_data) {\n auto& msg_id = std::get<0>(msg_id_and_data);\n auto& data = std::get<1>(msg_id_and_data);\n auto it = _outstanding.find(std::abs(msg_id));\n if (!data) {\n _error = true;\n } else if (it != _outstanding.end()) {\n auto handler = std::move(it->second);\n _outstanding.erase(it);\n (*handler)(*this, msg_id, std::move(data.value()));\n } else if (msg_id < 0) {\n try {\n std::rethrow_exception(unmarshal_exception(data.value()));\n } catch(const unknown_verb_error& ex) {\n // if this is unknown verb exception with unknown id ignore it\n // can happen if unknown verb was used by no_wait client\n get_logger()(peer_address(), format(\"unknown verb exception {:d} ignored\", ex.type));\n } catch(...) {\n // We've got error response but handler is no longer waiting, could be timed out.\n log_exception(*this, log_level::info, \"ignoring error response\", std::current_exception());\n }\n } else {\n // we get a reply for a message id not in _outstanding\n // this can happened if the message id is timed out already\n get_logger()(peer_address(), log_level::debug, \"got a reply for an expired message id\");\n }\n });\n });\n });\n }).then_wrapped([this] (future<> f) {\n std::exception_ptr ep;\n if (f.failed()) {\n ep = f.get_exception();\n if (_connected) {\n if (is_stream()) {\n log_exception(*this, log_level::error, \"client stream connection dropped\", ep);\n } else {\n log_exception(*this, log_level::error, \"client connection dropped\", ep);\n }\n } else {\n if (is_stream()) {\n log_exception(*this, log_level::debug, \"stream fail to connect\", ep);\n } else {\n log_exception(*this, log_level::debug, \"fail to connect\", ep);\n }\n }\n }\n _error = true;\n _stream_queue.abort(std::make_exception_ptr(stream_closed()));\n return stop_send_loop(ep).then_wrapped([this] (future<> f) {\n f.ignore_ready_future();\n _outstanding.clear();\n if (is_stream()) {\n deregister_this_stream();\n } else {\n abort_all_streams();\n }\n }).finally([this] {\n return _compressor ? 
_compressor->close() : make_ready_future();\n }).finally([this]{\n _stopped.set_value();\n });\n });\n enqueue_zero_frame();\n }\n\n client::client(const logger& l, void* s, const socket_address& addr, const socket_address& local)\n : client(l, s, client_options{}, make_socket(), addr, local)\n {}\n\n client::client(const logger& l, void* s, client_options options, const socket_address& addr, const socket_address& local)\n : client(l, s, options, make_socket(), addr, local)\n {}\n\n client::client(const logger& l, void* s, socket socket, const socket_address& addr, const socket_address& local)\n : client(l, s, client_options{}, std::move(socket), addr, local)\n {}\n\n\n future\n server::connection::negotiate(feature_map requested) {\n feature_map ret;\n future<> f = make_ready_future<>();\n for (auto&& e : requested) {\n auto id = e.first;\n switch (id) {\n // supported features go here\n case protocol_features::COMPRESS: {\n if (get_server()._options.compressor_factory) {\n _compressor = get_server()._options.compressor_factory->negotiate(e.second, true, [this] { return send({}); });\n if (_compressor) {\n ret[protocol_features::COMPRESS] = _compressor->name();\n }\n }\n }\n break;\n case protocol_features::TIMEOUT:\n _timeout_negotiated = true;\n ret[protocol_features::TIMEOUT] = \"\";\n break;\n case protocol_features::STREAM_PARENT: {\n if (!get_server()._options.streaming_domain) {\n f = f.then([] {\n return make_exception_future<>(std::runtime_error(\"streaming is not configured for the server\"));\n });\n } else {\n _parent_id = deserialize_connection_id(e.second);\n _is_stream = true;\n // remove stream connection from rpc connection list\n get_server()._conns.erase(get_connection_id());\n f = f.then([this, c = shared_from_this()] () mutable {\n return smp::submit_to(_parent_id.shard(), [this, c = make_foreign(static_pointer_cast(c))] () mutable {\n auto sit = _servers.find(*get_server()._options.streaming_domain);\n if (sit == _servers.end()) {\n throw std::logic_error(format(\"Shard {:d} does not have server with streaming domain {}\", this_shard_id(), *get_server()._options.streaming_domain).c_str());\n }\n auto s = sit->second;\n auto it = s->_conns.find(_parent_id);\n if (it == s->_conns.end()) {\n throw std::logic_error(format(\"Unknown parent connection {} on shard {:d}\", _parent_id, this_shard_id()).c_str());\n }\n if (it->second->_error) {\n throw std::runtime_error(format(\"Parent connection {} is aborting on shard {:d}\", _parent_id, this_shard_id()).c_str());\n }\n auto id = c->get_connection_id();\n it->second->register_stream(id, make_lw_shared(std::move(c)));\n });\n });\n }\n break;\n }\n case protocol_features::ISOLATION: {\n auto&& isolation_cookie = e.second;\n struct isolation_function_visitor {\n isolation_function_visitor(const sstring& isolation_cookie) :\n _isolation_cookie(isolation_cookie) { }\n future operator() (resource_limits::syncronous_isolation_function f) const {\n return futurize_invoke(f, _isolation_cookie);\n }\n future operator() (resource_limits::asyncronous_isolation_function f) const {\n return f(_isolation_cookie);\n }\n private:\n sstring _isolation_cookie;\n };\n\n auto visitor = isolation_function_visitor(isolation_cookie);\n f = f.then([visitor = std::move(visitor), this] () mutable {\n return std::visit(visitor, get_server()._limits.isolate_connection).then([this] (isolation_config conf) {\n _isolation_config = conf;\n });\n });\n ret.emplace(e);\n break;\n }\n default:\n // nothing to do\n ;\n }\n }\n if 
(get_server()._options.streaming_domain) {\n ret[protocol_features::CONNECTION_ID] = serialize_connection_id(_id);\n }\n return f.then([ret = std::move(ret)] {\n return ret;\n });\n }\n\n future<>\n server::connection::negotiate_protocol() {\n return receive_negotiation_frame(*this, _read_buf).then([this] (feature_map requested_features) {\n return negotiate(std::move(requested_features)).then([this] (feature_map returned_features) {\n return send_negotiation_frame(std::move(returned_features));\n });\n });\n }\n\n future\n server::connection::read_request_frame_compressed(input_stream& in) {\n if (_timeout_negotiated) {\n return read_frame_compressed(_info.addr, _compressor, in);\n } else {\n return read_frame_compressed(_info.addr, _compressor, in);\n }\n }\n\n future<>\n server::connection::respond(int64_t msg_id, snd_buf&& data, std::optional timeout) {\n response_frame::encode_header(msg_id, data);\n return send(std::move(data), timeout);\n }\n\nfuture<> server::connection::send_unknown_verb_reply(std::optional timeout, int64_t msg_id, uint64_t type) {\n return wait_for_resources(28, timeout).then([this, timeout, msg_id, type] (auto permit) {\n // send unknown_verb exception back\n snd_buf data(28);\n static_assert(snd_buf::chunk_size >= 28, \"send buffer chunk size is too small\");\n auto p = data.front().get_write() + 12;\n write_le(p, uint32_t(exception_type::UNKNOWN_VERB));\n write_le(p + 4, uint32_t(8));\n write_le(p + 8, type);\n try {\n // Send asynchronously.\n // This is safe since connection::stop() will wait for background work.\n (void)with_gate(get_server()._reply_gate, [this, timeout, msg_id, data = std::move(data), permit = std::move(permit)] () mutable {\n // workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83268\n auto c = shared_from_this();\n return respond(-msg_id, std::move(data), timeout).then([c = std::move(c), permit = std::move(permit)] {});\n });\n } catch(gate_closed_exception&) {/* ignore */}\n });\n}\n\n future<> server::connection::process() {\n return negotiate_protocol().then([this] () mutable {\n auto sg = _isolation_config ? _isolation_config->sched_group : current_scheduling_group();\n return with_scheduling_group(sg, [this] {\n set_negotiated();\n return do_until([this] { return _read_buf.eof() || _error; }, [this] () mutable {\n if (is_stream()) {\n return handle_stream_frame();\n }\n return read_request_frame_compressed(_read_buf).then([this] (request_frame::return_type header_and_buffer) {\n auto& expire = std::get<0>(header_and_buffer);\n auto& type = std::get<1>(header_and_buffer);\n auto& msg_id = std::get<2>(header_and_buffer);\n auto& data = std::get<3>(header_and_buffer);\n if (!data) {\n _error = true;\n return make_ready_future<>();\n } else {\n std::optional timeout;\n if (expire && *expire) {\n timeout = relative_timeout_to_absolute(std::chrono::milliseconds(*expire));\n }\n auto h = get_server()._proto.get_handler(type);\n if (!h) {\n return send_unknown_verb_reply(timeout, msg_id, type);\n }\n\n // If the new method of per-connection scheduling group was used, honor it.\n // Otherwise, use the old per-handler scheduling group.\n auto sg = _isolation_config ? 
_isolation_config->sched_group : h->handler.sg;\n return with_scheduling_group(sg, [this, timeout, msg_id, &h = h->handler, data = std::move(data.value()), guard = std::move(h->holder)] () mutable {\n return h.func(shared_from_this(), timeout, msg_id, std::move(data), std::move(guard));\n });\n }\n });\n });\n });\n }).then_wrapped([this] (future<> f) {\n std::exception_ptr ep;\n if (f.failed()) {\n ep = f.get_exception();\n log_exception(*this, log_level::error,\n format(\"server{} connection dropped\", is_stream() ? \" stream\" : \"\").c_str(), ep);\n }\n _fd.shutdown_input();\n _error = true;\n _stream_queue.abort(std::make_exception_ptr(stream_closed()));\n return stop_send_loop(ep).then_wrapped([this] (future<> f) {\n f.ignore_ready_future();\n get_server()._conns.erase(get_connection_id());\n if (is_stream()) {\n return deregister_this_stream();\n } else {\n return abort_all_streams();\n }\n }).finally([this] {\n return _compressor ? _compressor->close() : make_ready_future();\n }).finally([this] {\n _stopped.set_value();\n });\n }).finally([conn_ptr = shared_from_this()] {\n // hold onto connection pointer until do_until() exists\n });\n }\n\n server::connection::connection(server& s, connected_socket&& fd, socket_address&& addr, const logger& l, void* serializer, connection_id id)\n : rpc::connection(std::move(fd), l, serializer, id)\n , _info{.addr{std::move(addr)}, .server{s}, .conn_id{id}} {\n }\n\n future<> server::connection::deregister_this_stream() {\n if (!get_server()._options.streaming_domain) {\n return make_ready_future<>();\n }\n return smp::submit_to(_parent_id.shard(), [this] () mutable {\n auto sit = server::_servers.find(*get_server()._options.streaming_domain);\n if (sit != server::_servers.end()) {\n auto s = sit->second;\n auto it = s->_conns.find(_parent_id);\n if (it != s->_conns.end()) {\n it->second->_streams.erase(get_connection_id());\n }\n }\n });\n }\n\n future<> server::connection::abort_all_streams() {\n return parallel_for_each(_streams | boost::adaptors::map_values, [] (xshard_connection_ptr s) {\n return smp::submit_to(s->get_owner_shard(), [s] {\n s->get()->abort();\n });\n }).then([this] {\n _streams.clear();\n });\n }\n\n thread_local std::unordered_map server::_servers;\n\n server::server(protocol_base* proto, const socket_address& addr, resource_limits limits)\n : server(proto, seastar::listen(addr, listen_options{true}), limits, server_options{})\n {}\n\n server::server(protocol_base* proto, server_options opts, const socket_address& addr, resource_limits limits)\n : server(proto, seastar::listen(addr, listen_options{true, opts.load_balancing_algorithm}), limits, opts)\n {}\n\n server::server(protocol_base* proto, server_socket ss, resource_limits limits, server_options opts)\n : _proto(*proto), _ss(std::move(ss)), _limits(limits), _resources_available(limits.max_memory), _options(opts)\n {\n if (_options.streaming_domain) {\n if (_servers.find(*_options.streaming_domain) != _servers.end()) {\n throw std::runtime_error(format(\"An RPC server with the streaming domain {} is already exist\", *_options.streaming_domain));\n }\n _servers[*_options.streaming_domain] = this;\n }\n accept();\n }\n\n server::server(protocol_base* proto, server_options opts, server_socket ss, resource_limits limits)\n : server(proto, std::move(ss), limits, opts)\n {}\n\n void server::accept() {\n // Run asynchronously in background.\n // Communicate result via __ss_stopped.\n // The caller has to call server::stop() to synchronize.\n (void)keep_doing([this] () mutable 
{\n return _ss.accept().then([this] (accept_result ar) mutable {\n if (_options.filter_connection && !_options.filter_connection(ar.remote_address)) {\n return;\n }\n auto fd = std::move(ar.connection);\n auto addr = std::move(ar.remote_address);\n fd.set_nodelay(_options.tcp_nodelay);\n connection_id id = _options.streaming_domain ?\n connection_id::make_id(_next_client_id++, uint16_t(this_shard_id())) :\n connection_id::make_invalid_id(_next_client_id++);\n auto conn = _proto.make_server_connection(*this, std::move(fd), std::move(addr), id);\n auto r = _conns.emplace(id, conn);\n assert(r.second);\n // Process asynchronously in background.\n (void)conn->process();\n });\n }).then_wrapped([this] (future<>&& f){\n try {\n f.get();\n assert(false);\n } catch (...) {\n _ss_stopped.set_value();\n }\n });\n }\n\n future<> server::shutdown() {\n if (_shutdown) {\n return make_ready_future<>();\n }\n\n _ss.abort_accept();\n _resources_available.broken();\n if (_options.streaming_domain) {\n _servers.erase(*_options.streaming_domain);\n }\n return _ss_stopped.get_future().then([this] {\n return parallel_for_each(_conns | boost::adaptors::map_values, [] (shared_ptr conn) {\n return conn->stop();\n });\n }).finally([this] {\n _shutdown = true;\n });\n }\n\n future<> server::stop() {\n return when_all(\n shutdown(),\n _reply_gate.close()\n ).discard_result();\n }\n\n void server::abort_connection(connection_id id) {\n auto it = _conns.find(id);\n if (it == _conns.end()) {\n return;\n }\n try {\n it->second->abort();\n } catch (...) {\n log_exception(*it->second, log_level::error,\n \"fail to shutdown connection on user request\", std::current_exception());\n }\n }\n\n std::ostream& operator<<(std::ostream& os, const connection_id& id) {\n fmt::print(os, \"{:x}\", id.id());\n return os;\n }\n\n std::ostream& operator<<(std::ostream& os, const streaming_domain_type& domain) {\n fmt::print(os, \"{:d}\", domain._id);\n return os;\n }\n\n \nisolation_config default_isolate_connection(sstring isolation_cookie) {\n return isolation_config{};\n }\n\n}\n\n}\n\n// Path: src/rpc/lz4_compressor.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2016 Scylladb, Ltd.\n */\n\n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace rpc {\n\nsstring lz4_compressor::name() const {\n return factory{}.supported();\n}\n\nconst sstring& lz4_compressor::factory::supported() const {\n const static sstring name = \"LZ4\";\n return name;\n}\n\nstd::unique_ptr lz4_compressor::factory::negotiate(sstring feature, bool is_server) const {\n return feature == supported() ? 
std::make_unique() : nullptr;\n}\n\n// Reusable contiguous buffers needed for LZ4 compression and decompression functions.\nclass reusable_buffer {\n static constexpr size_t chunk_size = 128 * 1024;\n static_assert(snd_buf::chunk_size == chunk_size, \"snd_buf::chunk_size == chunk_size\");\n\n std::unique_ptr _data;\n size_t _size;\nprivate:\n void reserve(size_t n) {\n if (_size < n) {\n _data.reset();\n // Not using std::make_unique to avoid value-initialisation.\n _data = std::unique_ptr(new char[n]);\n _size = n;\n }\n }\npublic:\n // Returns a pointer to a contiguous buffer containing all data stored in input.\n // The pointer remains valid until next call to this.\n const char* prepare(const std::variant>, temporary_buffer>& input, size_t size) {\n if (const auto single = std::get_if>(&input)) {\n return single->get();\n }\n reserve(size);\n auto dst = _data.get();\n for (const auto& fragment : std::get>>(input)) {\n dst = std::copy_n(fragment.begin(), fragment.size(), dst);\n }\n return _data.get();\n }\n\n // Calls function fn passing to it a pointer to a temporary contigiuous max_size\n // buffer.\n // fn is supposed to return the actual size of the data.\n // with_reserved() returns an Output object (snd_buf or rcv_buf or compatible),\n // containing data that was written to the temporary buffer.\n // Output should be either snd_buf or rcv_buf.\n template\n requires requires (Function fn, char* ptr) {\n { fn(ptr) } -> std::convertible_to;\n } && (std::is_same_v || std::is_same_v)\n Output with_reserved(size_t max_size, Function&& fn) {\n if (max_size <= chunk_size) {\n auto dst = temporary_buffer(max_size);\n size_t dst_size = fn(dst.get_write());\n dst.trim(dst_size);\n return Output(std::move(dst));\n }\n\n reserve(max_size);\n size_t dst_size = fn(_data.get());\n if (dst_size <= chunk_size) {\n return Output(temporary_buffer(_data.get(), dst_size));\n }\n\n auto left = dst_size;\n auto pos = _data.get();\n std::vector> buffers;\n while (left) {\n auto this_size = std::min(left, chunk_size);\n buffers.emplace_back(this_size);\n std::copy_n(pos, this_size, buffers.back().get_write());\n pos += this_size;\n left -= this_size;\n }\n return Output(std::move(buffers), dst_size);\n }\n\n void clear() noexcept {\n _data.reset();\n _size = 0;\n }\n};\n\nstatic thread_local reusable_buffer reusable_buffer_compressed_data;\nstatic thread_local reusable_buffer reusable_buffer_decompressed_data;\nstatic thread_local size_t buffer_use_count = 0;\nstatic constexpr size_t drop_buffers_trigger = 100'000;\n\nstatic void after_buffer_use() noexcept {\n if (buffer_use_count++ == drop_buffers_trigger) {\n reusable_buffer_compressed_data.clear();\n reusable_buffer_decompressed_data.clear();\n buffer_use_count = 0;\n }\n}\n\nsnd_buf lz4_compressor::compress(size_t head_space, snd_buf data) {\n head_space += 4;\n auto dst_size = head_space + LZ4_compressBound(data.size);\n auto dst = reusable_buffer_compressed_data.with_reserved(dst_size, [&] (char* dst) {\n auto src_size = data.size;\n auto src = reusable_buffer_decompressed_data.prepare(data.bufs, data.size);\n\n auto size = LZ4_compress_default(src, dst + head_space, src_size, LZ4_compressBound(src_size));\n if (size == 0) {\n throw std::runtime_error(\"RPC frame LZ4 compression failure\");\n }\n write_le(dst + (head_space - 4), src_size);\n return size + head_space;\n });\n after_buffer_use();\n return dst;\n}\n\nrcv_buf lz4_compressor::decompress(rcv_buf data) {\n if (data.size < 4) {\n return rcv_buf();\n } else {\n auto src_size = data.size;\n 
auto src = reusable_buffer_decompressed_data.prepare(data.bufs, data.size);\n\n auto dst_size = read_le(src);\n if (!dst_size) {\n throw std::runtime_error(\"RPC frame LZ4 decompression failure: decompressed size cannot be zero\");\n }\n src += sizeof(uint32_t);\n src_size -= sizeof(uint32_t);\n\n auto dst = reusable_buffer_compressed_data.with_reserved(dst_size, [&] (char* dst) {\n if (LZ4_decompress_safe(src, dst, src_size, dst_size) < 0) {\n throw std::runtime_error(\"RPC frame LZ4 decompression failure\");\n }\n return dst_size;\n });\n after_buffer_use();\n return dst;\n }\n}\n\n}\n\n}\n\n// Path: src/rpc/lz4_fragmented_compressor.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2019 Scylladb, Ltd.\n */\n\n#include \n#include \n\n#include \n// LZ4_DECODER_RING_BUFFER_SIZE macro is introduced since v1.8.2\n// To work with previous lz4 release, copied the definition in lz4 here\n...\n", "needle_token_start": 13829, "needle_token_end": 13860, "code_context_ntokens": 16232, "output": ["```cpp\nisolation_config default_isolate_connection(sstring isolation_cookie) {\n return isolation_config{};\n}\n```"]} +{"repo": "scylladb/seastar", "name": "is_control_char", "language": "cpp", "path": "src/json/formatter.cc", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: The function checks if a given character is a control character based on its ASCII value.\n2. **Input**: A single character.\n3. **Output**: A boolean value indicating whether the character is a control character.\n4. 
**Procedure**: The function evaluates if the ASCII value of the character is between 0 and 31 (inclusive), which are recognized as control characters in the ASCII table.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/rpc/rpc.cc\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n#if FMT_VERSION >= 90000\ntemplate <> struct fmt::formatter : fmt::ostream_formatter {};\n#endif\n\nnamespace seastar {\n\nnamespace rpc {\n\n void logger::operator()(const client_info& info, id_type msg_id, const sstring& str) const {\n log(format(\"client {} msg_id {}: {}\", info.addr, msg_id, str));\n }\n\n void logger::operator()(const client_info& info, id_type msg_id, log_level level, std::string_view str) const {\n log(level, \"client {} msg_id {}: {}\", info.addr, msg_id, str);\n }\n\n void logger::operator()(const client_info& info, const sstring& str) const {\n (*this)(info.addr, str);\n }\n\n void logger::operator()(const client_info& info, log_level level, std::string_view str) const {\n (*this)(info.addr, level, str);\n }\n\n void logger::operator()(const socket_address& addr, const sstring& str) const {\n log(format(\"client {}: {}\", addr, str));\n }\n\n void logger::operator()(const socket_address& addr, log_level level, std::string_view str) const {\n log(level, \"client {}: {}\", addr, str);\n }\n\n no_wait_type no_wait;\n\n snd_buf::snd_buf(size_t size_) : size(size_) {\n if (size <= chunk_size) {\n bufs = temporary_buffer(size);\n } else {\n std::vector> v;\n v.reserve(align_up(size_t(size), chunk_size) / chunk_size);\n while (size_) {\n v.push_back(temporary_buffer(std::min(chunk_size, size_)));\n size_ -= v.back().size();\n }\n bufs = std::move(v);\n }\n }\n\n snd_buf::snd_buf(snd_buf&&) noexcept = default;\n snd_buf& snd_buf::operator=(snd_buf&&) noexcept = default;\n\n temporary_buffer& snd_buf::front() {\n auto* one = std::get_if>(&bufs);\n if (one) {\n return *one;\n } else {\n return std::get>>(bufs).front();\n }\n }\n\n // Make a copy of a remote buffer. No data is actually copied, only pointers and\n // a deleter of a new buffer takes care of deleting the original buffer\n template // T is either snd_buf or rcv_buf\n T make_shard_local_buffer_copy(foreign_ptr> org) {\n if (org.get_owner_shard() == this_shard_id()) {\n return std::move(*org);\n }\n T buf(org->size);\n auto* one = std::get_if>(&org->bufs);\n\n if (one) {\n buf.bufs = temporary_buffer(one->get_write(), one->size(), make_object_deleter(std::move(org)));\n } else {\n auto& orgbufs = std::get>>(org->bufs);\n std::vector> newbufs;\n newbufs.reserve(orgbufs.size());\n deleter d = make_object_deleter(std::move(org));\n for (auto&& b : orgbufs) {\n newbufs.push_back(temporary_buffer(b.get_write(), b.size(), d.share()));\n }\n buf.bufs = std::move(newbufs);\n }\n\n return buf;\n }\n\n template snd_buf make_shard_local_buffer_copy(foreign_ptr>);\n template rcv_buf make_shard_local_buffer_copy(foreign_ptr>);\n\n static void log_exception(connection& c, log_level level, const char* log, std::exception_ptr eptr) {\n const char* s;\n try {\n std::rethrow_exception(eptr);\n } catch (std::exception& ex) {\n s = ex.what();\n } catch (...) 
{\n s = \"unknown exception\";\n }\n auto formatted = format(\"{}: {}\", log, s);\n c.get_logger()(c.peer_address(), level, std::string_view(formatted.data(), formatted.size()));\n }\n\n snd_buf connection::compress(snd_buf buf) {\n if (_compressor) {\n buf = _compressor->compress(4, std::move(buf));\n static_assert(snd_buf::chunk_size >= 4, \"send buffer chunk size is too small\");\n write_le(buf.front().get_write(), buf.size - 4);\n return buf;\n }\n return buf;\n }\n\n future<> connection::send_buffer(snd_buf buf) {\n auto* b = std::get_if>(&buf.bufs);\n if (b) {\n return _write_buf.write(std::move(*b));\n } else {\n return do_with(std::move(std::get>>(buf.bufs)),\n [this] (std::vector>& ar) {\n return do_for_each(ar.begin(), ar.end(), [this] (auto& b) {\n return _write_buf.write(std::move(b));\n });\n });\n }\n }\n\n future<> connection::send_entry(outgoing_entry& d) noexcept {\n return futurize_invoke([this, &d] {\n if (d.buf.size && _propagate_timeout) {\n static_assert(snd_buf::chunk_size >= sizeof(uint64_t), \"send buffer chunk size is too small\");\n if (_timeout_negotiated) {\n auto expire = d.t.get_timeout();\n uint64_t left = 0;\n if (expire != typename timer::time_point()) {\n left = std::chrono::duration_cast(expire - timer::clock::now()).count();\n }\n write_le(d.buf.front().get_write(), left);\n } else {\n d.buf.front().trim_front(sizeof(uint64_t));\n d.buf.size -= sizeof(uint64_t);\n }\n }\n auto buf = compress(std::move(d.buf));\n return send_buffer(std::move(buf)).then([this] {\n _stats.sent_messages++;\n return _write_buf.flush();\n });\n });\n }\n\n void connection::set_negotiated() noexcept {\n _negotiated->set_value();\n _negotiated = std::nullopt;\n }\n\n future<> connection::stop_send_loop(std::exception_ptr ex) {\n _error = true;\n if (_connected) {\n _fd.shutdown_output();\n }\n if (ex == nullptr) {\n ex = std::make_exception_ptr(closed_error());\n }\n while (!_outgoing_queue.empty()) {\n auto it = std::prev(_outgoing_queue.end());\n // Cancel all but front entry normally. The front entry is sitting in the\n // send_entry() and cannot be withdrawn, except when _negotiated is still\n // engaged. In the latter case when it will be aborted below the entry's\n // continuation will not be called and its done promise will not resolve\n // the _outgoing_queue_ready, so do it here\n if (it != _outgoing_queue.begin()) {\n withdraw(it, ex);\n } else {\n if (_negotiated) {\n it->done.set_exception(ex);\n }\n break;\n }\n }\n if (_negotiated) {\n _negotiated->set_exception(ex);\n }\n return when_all(std::move(_outgoing_queue_ready), std::move(_sink_closed_future)).then([this] (std::tuple, future> res){\n // _outgoing_queue_ready might be exceptional if queue drain or\n // _negotiated abortion set it such\n std::get<0>(res).ignore_ready_future();\n // _sink_closed_future is never exceptional\n bool sink_closed = std::get<1>(res).get();\n return _connected && !sink_closed ? 
_write_buf.close() : make_ready_future();\n });\n }\n\n void connection::set_socket(connected_socket&& fd) {\n if (_connected) {\n throw std::runtime_error(\"already connected\");\n }\n _fd = std::move(fd);\n _read_buf =_fd.input();\n _write_buf = _fd.output();\n _connected = true;\n }\n\n future<> connection::send_negotiation_frame(feature_map features) {\n auto negotiation_frame_feature_record_size = [] (const feature_map::value_type& e) {\n return 8 + e.second.size();\n };\n auto extra_len = boost::accumulate(\n features | boost::adaptors::transformed(negotiation_frame_feature_record_size),\n uint32_t(0));\n temporary_buffer reply(sizeof(negotiation_frame) + extra_len);\n auto p = reply.get_write();\n p = std::copy_n(rpc_magic, 8, p);\n write_le(p, extra_len);\n p += 4;\n for (auto&& e : features) {\n write_le(p, static_cast(e.first));\n p += 4;\n write_le(p, e.second.size());\n p += 4;\n p = std::copy_n(e.second.begin(), e.second.size(), p);\n }\n return _write_buf.write(std::move(reply)).then([this] {\n _stats.sent_messages++;\n return _write_buf.flush();\n });\n }\n\n void connection::withdraw(outgoing_entry::container_t::iterator it, std::exception_ptr ex) {\n assert(it != _outgoing_queue.end());\n\n auto pit = std::prev(it);\n // Previous entry's (pit's) done future will schedule current entry (it)\n // continuation. Similarly, it.done will schedule next entry continuation\n // or will resolve _outgoing_queue_ready future.\n //\n // To withdraw \"it\" we need to do two things:\n // - make pit.done resolve it->next (some time later)\n // - resolve \"it\"'s continuation right now\n //\n // The latter is achieved by resolving pit.done immediatelly, the former\n // by moving it.done into pit.done. For simplicity (verging on obscurity?)\n // both done's are just swapped and \"it\" resolves its new promise\n\n std::swap(it->done, pit->done);\n it->uncancellable();\n it->unlink();\n if (ex == nullptr) {\n it->done.set_value();\n } else {\n it->done.set_exception(ex);\n }\n }\n\n future<> connection::send(snd_buf buf, std::optional timeout, cancellable* cancel) {\n if (!_error) {\n if (timeout && *timeout <= rpc_clock_type::now()) {\n return make_ready_future<>();\n }\n\n auto p = std::make_unique(std::move(buf));\n auto& d = *p;\n _outgoing_queue.push_back(d);\n _outgoing_queue_size++;\n auto deleter = [this, it = _outgoing_queue.iterator_to(d)] {\n // Front entry is most likely (unless _negotiated is unresolved, check enqueue_zero_frame()) sitting\n // inside send_entry() continuations and thus it cannot be cancelled.\n if (it != _outgoing_queue.begin()) {\n withdraw(it);\n }\n };\n\n if (timeout) {\n auto& t = d.t;\n t.set_callback(deleter);\n t.arm(timeout.value());\n }\n if (cancel) {\n cancel->cancel_send = std::move(deleter);\n cancel->send_back_pointer = &d.pcancel;\n d.pcancel = cancel;\n }\n\n // New entry should continue (do its .then() lambda) after _outgoing_queue_ready\n // resolves. 
Next entry will need to do the same after this entry's done resolves.\n // Thus -- replace _outgoing_queue_ready with d's future and chain its continuation\n // on ..._ready's old value.\n return std::exchange(_outgoing_queue_ready, d.done.get_future()).then([this, p = std::move(p)] () mutable {\n _outgoing_queue_size--;\n if (__builtin_expect(!p->is_linked(), false)) {\n // If withdrawn the entry is unlinked and this lambda is fired right at once\n return make_ready_future<>();\n }\n\n p->uncancellable();\n return send_entry(*p).then_wrapped([this, p = std::move(p)] (auto f) mutable {\n if (f.failed()) {\n f.ignore_ready_future();\n abort();\n }\n p->done.set_value();\n });\n });\n } else {\n return make_exception_future<>(closed_error());\n }\n }\n\n void connection::abort() {\n if (!_error) {\n _error = true;\n _fd.shutdown_input();\n }\n }\n\n future<> connection::stop() noexcept {\n try {\n abort();\n } catch (...) {\n log_exception(*this, log_level::error, \"fail to shutdown connection while stopping\", std::current_exception());\n }\n return _stopped.get_future();\n }\n\n template\n static bool verify_frame(Connection& c, temporary_buffer& buf, size_t expected, const char* log) {\n if (buf.size() != expected) {\n if (buf.size() != 0) {\n c.get_logger()(c.peer_address(), log);\n }\n return false;\n }\n return true;\n }\n\n template\n static\n future\n receive_negotiation_frame(Connection& c, input_stream& in) {\n return in.read_exactly(sizeof(negotiation_frame)).then([&c, &in] (temporary_buffer neg) {\n if (!verify_frame(c, neg, sizeof(negotiation_frame), \"unexpected eof during negotiation frame\")) {\n return make_exception_future(closed_error());\n }\n negotiation_frame frame;\n std::copy_n(neg.get_write(), sizeof(frame.magic), frame.magic);\n frame.len = read_le(neg.get_write() + 8);\n if (std::memcmp(frame.magic, rpc_magic, sizeof(frame.magic)) != 0) {\n c.get_logger()(c.peer_address(), format(\"wrong protocol magic: {:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}\",\n frame.magic[0], frame.magic[1], frame.magic[2], frame.magic[3], frame.magic[4], frame.magic[5], frame.magic[6], frame.magic[7]));\n return make_exception_future(closed_error());\n }\n auto len = frame.len;\n return in.read_exactly(len).then([&c, len] (temporary_buffer extra) {\n if (extra.size() != len) {\n c.get_logger()(c.peer_address(), \"unexpected eof during negotiation frame\");\n return make_exception_future(closed_error());\n }\n feature_map map;\n auto p = extra.get();\n auto end = p + extra.size();\n while (p != end) {\n if (end - p < 8) {\n c.get_logger()(c.peer_address(), \"bad feature data format in negotiation frame\");\n return make_exception_future(closed_error());\n }\n auto feature = static_cast(read_le(p));\n auto f_len = read_le(p + 4);\n p += 8;\n if (f_len > end - p) {\n c.get_logger()(c.peer_address(), \"buffer underflow in feature data in negotiation frame\");\n return make_exception_future(closed_error());\n }\n auto data = sstring(p, f_len);\n p += f_len;\n map.emplace(feature, std::move(data));\n }\n return make_ready_future(std::move(map));\n });\n });\n }\n\n inline future\n read_rcv_buf(input_stream& in, uint32_t size) {\n return in.read_up_to(size).then([&, size] (temporary_buffer data) mutable {\n rcv_buf rb(size);\n if (data.size() == 0) {\n return make_ready_future(rcv_buf());\n } else if (data.size() == size) {\n rb.bufs = std::move(data);\n return make_ready_future(std::move(rb));\n } else {\n size -= data.size();\n std::vector> v;\n v.push_back(std::move(data));\n rb.bufs = 
std::move(v);\n return do_with(std::move(rb), std::move(size), [&in] (rcv_buf& rb, uint32_t& left) {\n return repeat([&] () {\n return in.read_up_to(left).then([&] (temporary_buffer data) {\n if (!data.size()) {\n rb.size -= left;\n return stop_iteration::yes;\n } else {\n left -= data.size();\n std::get>>(rb.bufs).push_back(std::move(data));\n return left ? stop_iteration::no : stop_iteration::yes;\n }\n });\n }).then([&rb] {\n return std::move(rb);\n });\n });\n }\n });\n }\n\n template\n future\n connection::read_frame(socket_address info, input_stream& in) {\n auto header_size = FrameType::header_size();\n return in.read_exactly(header_size).then([this, header_size, info, &in] (temporary_buffer header) {\n if (header.size() != header_size) {\n if (header.size() != 0) {\n _logger(info, format(\"unexpected eof on a {} while reading header: expected {:d} got {:d}\", FrameType::role(), header_size, header.size()));\n }\n return make_ready_future(FrameType::empty_value());\n }\n auto [size, h] = FrameType::decode_header(header.get());\n if (!size) {\n return make_ready_future(FrameType::make_value(h, rcv_buf()));\n } else {\n return read_rcv_buf(in, size).then([this, info, h = std::move(h), size] (rcv_buf rb) {\n if (rb.size != size) {\n _logger(info, format(\"unexpected eof on a {} while reading data: expected {:d} got {:d}\", FrameType::role(), size, rb.size));\n return make_ready_future(FrameType::empty_value());\n } else {\n return make_ready_future(FrameType::make_value(h, std::move(rb)));\n }\n });\n }\n });\n }\n\n template\n future\n connection::read_frame_compressed(socket_address info, std::unique_ptr& compressor, input_stream& in) {\n if (compressor) {\n return in.read_exactly(4).then([this, info, &in, &compressor] (temporary_buffer compress_header) {\n if (compress_header.size() != 4) {\n if (compress_header.size() != 0) {\n _logger(info, format(\"unexpected eof on a {} while reading compression header: expected 4 got {:d}\", FrameType::role(), compress_header.size()));\n }\n return make_ready_future(FrameType::empty_value());\n }\n auto ptr = compress_header.get();\n auto size = read_le(ptr);\n return read_rcv_buf(in, size).then([this, size, &compressor, info, &in] (rcv_buf compressed_data) {\n if (compressed_data.size != size) {\n _logger(info, format(\"unexpected eof on a {} while reading compressed data: expected {:d} got {:d}\", FrameType::role(), size, compressed_data.size));\n return make_ready_future(FrameType::empty_value());\n }\n auto eb = compressor->decompress(std::move(compressed_data));\n if (eb.size == 0) {\n // Empty frames might be sent as means of communication between the compressors, and should be skipped by the RPC layer.\n // We skip the empty frame here. 
We recursively restart the function, as if the empty frame didn't happen.\n // The yield() is here to limit the stack depth of the recursion to 1.\n return yield().then([this, info, &in, &compressor] { return read_frame_compressed(info, compressor, in); });\n }\n net::packet p;\n auto* one = std::get_if>(&eb.bufs);\n if (one) {\n p = net::packet(std::move(p), std::move(*one));\n } else {\n auto&& bufs = std::get>>(eb.bufs);\n p.reserve(bufs.size());\n for (auto&& b : bufs) {\n p = net::packet(std::move(p), std::move(b));\n }\n }\n return do_with(as_input_stream(std::move(p)), [this, info] (input_stream& in) {\n return read_frame(info, in);\n });\n });\n });\n } else {\n return read_frame(info, in);\n }\n }\n\n struct stream_frame {\n using opt_buf_type = std::optional;\n using return_type = opt_buf_type;\n struct header_type {\n bool eos;\n };\n static size_t header_size() {\n return 4;\n }\n static const char* role() {\n return \"stream\";\n }\n static auto empty_value() {\n return std::nullopt;\n }\n static std::pair decode_header(const char* ptr) {\n auto size = read_le(ptr);\n return size != -1U ? std::make_pair(size, header_type{false}) : std::make_pair(0U, header_type{true});\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n if (t.eos) {\n data.size = -1U;\n }\n return data;\n }\n };\n\n future>\n connection::read_stream_frame_compressed(input_stream& in) {\n return read_frame_compressed(peer_address(), _compressor, in);\n }\n\n future<> connection::stream_close() {\n auto f = make_ready_future<>();\n if (!error()) {\n promise p;\n _sink_closed_future = p.get_future();\n // stop_send_loop(), which also calls _write_buf.close(), and this code can run in parallel.\n // Use _sink_closed_future to serialize them and skip second call to close()\n f = _write_buf.close().finally([p = std::move(p)] () mutable { p.set_value(true);});\n }\n return f.finally([this] () mutable { return stop(); });\n }\n\n future<> connection::stream_process_incoming(rcv_buf&& buf) {\n // we do not want to dead lock on huge packets, so let them in\n // but only one at a time\n auto size = std::min(size_t(buf.size), max_stream_buffers_memory);\n return get_units(_stream_sem, size).then([this, buf = std::move(buf)] (semaphore_units<>&& su) mutable {\n buf.su = std::move(su);\n return _stream_queue.push_eventually(std::move(buf));\n });\n }\n\n future<> connection::handle_stream_frame() {\n return read_stream_frame_compressed(_read_buf).then([this] (std::optional data) {\n if (!data) {\n _error = true;\n return make_ready_future<>();\n }\n return stream_process_incoming(std::move(*data));\n });\n }\n\n future<> connection::stream_receive(circular_buffer>>& bufs) {\n return _stream_queue.not_empty().then([this, &bufs] {\n bool eof = !_stream_queue.consume([&bufs] (rcv_buf&& b) {\n if (b.size == -1U) { // max fragment length marks an end of a stream\n return false;\n } else {\n bufs.push_back(make_foreign(std::make_unique(std::move(b))));\n return true;\n }\n });\n if (eof && !bufs.empty()) {\n assert(_stream_queue.empty());\n _stream_queue.push(rcv_buf(-1U)); // push eof marker back for next read to notice it\n }\n });\n }\n\n void connection::register_stream(connection_id id, xshard_connection_ptr c) {\n _streams.emplace(id, std::move(c));\n }\n\n xshard_connection_ptr connection::get_stream(connection_id id) const {\n auto it = _streams.find(id);\n if (it == _streams.end()) {\n throw std::logic_error(format(\"rpc stream id {} not found\", id).c_str());\n }\n return it->second;\n }\n\n // The 
request frame is\n // le64 optional timeout (see request_frame_with_timeout below)\n // le64 message type a.k.a. verb ID\n // le64 message ID\n // le32 payload length\n // ... payload\n struct request_frame {\n using opt_buf_type = std::optional;\n using return_type = std::tuple, uint64_t, int64_t, opt_buf_type>;\n using header_type = std::tuple, uint64_t, int64_t>;\n static constexpr size_t raw_header_size = sizeof(uint64_t) + sizeof(int64_t) + sizeof(uint32_t);\n static size_t header_size() {\n static_assert(request_frame_headroom >= raw_header_size);\n return raw_header_size;\n }\n static const char* role() {\n return \"server\";\n }\n static auto empty_value() {\n return std::make_tuple(std::nullopt, uint64_t(0), 0, std::nullopt);\n }\n static std::pair decode_header(const char* ptr) {\n auto type = read_le(ptr);\n auto msgid = read_le(ptr + 8);\n auto size = read_le(ptr + 16);\n return std::make_pair(size, std::make_tuple(std::nullopt, type, msgid));\n }\n static void encode_header(uint64_t type, int64_t msg_id, snd_buf& buf, size_t off) {\n auto p = buf.front().get_write() + off;\n write_le(p, type);\n write_le(p + 8, msg_id);\n write_le(p + 16, buf.size - raw_header_size - off);\n }\n static auto make_value(const header_type& t, rcv_buf data) {\n return std::make_tuple(std::get<0>(t), std::get<1>(t), std::get<2>(t), std::move(data));\n }\n };\n\n // This frame is used if protocol_features.TIMEOUT was negotiated\n...\n// Path: src/rpc/lz4_compressor.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2016 Scylladb, Ltd.\n */\n\n#include \n#include \n#include \n\nnamespace seastar {\n\nnamespace rpc {\n\nsstring lz4_compressor::name() const {\n return factory{}.supported();\n}\n\nconst sstring& lz4_compressor::factory::supported() const {\n const static sstring name = \"LZ4\";\n return name;\n}\n\nstd::unique_ptr lz4_compressor::factory::negotiate(sstring feature, bool is_server) const {\n return feature == supported() ? 
std::make_unique() : nullptr;\n}\n\n// Reusable contiguous buffers needed for LZ4 compression and decompression functions.\nclass reusable_buffer {\n static constexpr size_t chunk_size = 128 * 1024;\n static_assert(snd_buf::chunk_size == chunk_size, \"snd_buf::chunk_size == chunk_size\");\n\n std::unique_ptr _data;\n size_t _size;\nprivate:\n void reserve(size_t n) {\n if (_size < n) {\n _data.reset();\n // Not using std::make_unique to avoid value-initialisation.\n _data = std::unique_ptr(new char[n]);\n _size = n;\n }\n }\npublic:\n // Returns a pointer to a contiguous buffer containing all data stored in input.\n // The pointer remains valid until next call to this.\n const char* prepare(const std::variant>, temporary_buffer>& input, size_t size) {\n if (const auto single = std::get_if>(&input)) {\n return single->get();\n }\n reserve(size);\n auto dst = _data.get();\n for (const auto& fragment : std::get>>(input)) {\n dst = std::copy_n(fragment.begin(), fragment.size(), dst);\n }\n return _data.get();\n }\n\n // Calls function fn passing to it a pointer to a temporary contigiuous max_size\n // buffer.\n // fn is supposed to return the actual size of the data.\n // with_reserved() returns an Output object (snd_buf or rcv_buf or compatible),\n // containing data that was written to the temporary buffer.\n // Output should be either snd_buf or rcv_buf.\n template\n requires requires (Function fn, char* ptr) {\n { fn(ptr) } -> std::convertible_to;\n } && (std::is_same_v || std::is_same_v)\n Output with_reserved(size_t max_size, Function&& fn) {\n if (max_size <= chunk_size) {\n auto dst = temporary_buffer(max_size);\n size_t dst_size = fn(dst.get_write());\n dst.trim(dst_size);\n return Output(std::move(dst));\n }\n\n reserve(max_size);\n size_t dst_size = fn(_data.get());\n if (dst_size <= chunk_size) {\n return Output(temporary_buffer(_data.get(), dst_size));\n }\n\n auto left = dst_size;\n auto pos = _data.get();\n std::vector> buffers;\n while (left) {\n auto this_size = std::min(left, chunk_size);\n buffers.emplace_back(this_size);\n std::copy_n(pos, this_size, buffers.back().get_write());\n pos += this_size;\n left -= this_size;\n }\n return Output(std::move(buffers), dst_size);\n }\n\n void clear() noexcept {\n _data.reset();\n _size = 0;\n }\n};\n\nstatic thread_local reusable_buffer reusable_buffer_compressed_data;\nstatic thread_local reusable_buffer reusable_buffer_decompressed_data;\nstatic thread_local size_t buffer_use_count = 0;\nstatic constexpr size_t drop_buffers_trigger = 100'000;\n\nstatic void after_buffer_use() noexcept {\n if (buffer_use_count++ == drop_buffers_trigger) {\n reusable_buffer_compressed_data.clear();\n reusable_buffer_decompressed_data.clear();\n buffer_use_count = 0;\n }\n}\n\nsnd_buf lz4_compressor::compress(size_t head_space, snd_buf data) {\n head_space += 4;\n auto dst_size = head_space + LZ4_compressBound(data.size);\n auto dst = reusable_buffer_compressed_data.with_reserved(dst_size, [&] (char* dst) {\n auto src_size = data.size;\n auto src = reusable_buffer_decompressed_data.prepare(data.bufs, data.size);\n\n auto size = LZ4_compress_default(src, dst + head_space, src_size, LZ4_compressBound(src_size));\n if (size == 0) {\n throw std::runtime_error(\"RPC frame LZ4 compression failure\");\n }\n write_le(dst + (head_space - 4), src_size);\n return size + head_space;\n });\n after_buffer_use();\n return dst;\n}\n\nrcv_buf lz4_compressor::decompress(rcv_buf data) {\n if (data.size < 4) {\n return rcv_buf();\n } else {\n auto src_size = data.size;\n 
auto src = reusable_buffer_decompressed_data.prepare(data.bufs, data.size);\n\n auto dst_size = read_le(src);\n if (!dst_size) {\n throw std::runtime_error(\"RPC frame LZ4 decompression failure: decompressed size cannot be zero\");\n }\n src += sizeof(uint32_t);\n src_size -= sizeof(uint32_t);\n\n auto dst = reusable_buffer_compressed_data.with_reserved(dst_size, [&] (char* dst) {\n if (LZ4_decompress_safe(src, dst, src_size, dst_size) < 0) {\n throw std::runtime_error(\"RPC frame LZ4 decompression failure\");\n }\n return dst_size;\n });\n after_buffer_use();\n return dst;\n }\n}\n\n}\n\n}\n\n// Path: src/rpc/lz4_fragmented_compressor.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright (C) 2019 Scylladb, Ltd.\n */\n\n#include \n#include \n\n#include \n// LZ4_DECODER_RING_BUFFER_SIZE macro is introduced since v1.8.2\n// To work with previous lz4 release, copied the definition in lz4 here\n#ifndef LZ4_DECODER_RING_BUFFER_SIZE\n#define LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize) (65536 + 14 + (maxBlockSize))\n#endif\n\nnamespace seastar {\nnamespace rpc {\n\nsstring lz4_fragmented_compressor::name() const {\n return factory{}.supported();\n}\n\nconst sstring& lz4_fragmented_compressor::factory::supported() const {\n const static sstring name = \"LZ4_FRAGMENTED\";\n return name;\n}\n\nstd::unique_ptr lz4_fragmented_compressor::factory::negotiate(sstring feature, bool is_server) const {\n return feature == supported() ? std::make_unique() : nullptr;\n}\n\n// Compressed message format:\n// The message consists of one or more data chunks each preceeded by a 4 byte header.\n// The value of the header detrmines the size of the chunk:\n// - most significant bit is cleared: intermediate chunk, 31 least significant bits\n// contain the compressed size of the chunk (i.e. how it appears on wire), the\n// decompressed size is 32 kB.\n// - most significant bit is set: last chunk, 31 least significant bits contain the\n// decompressed size of the chunk, the compressed size is the remaining part of\n// the message.\n// Compression and decompression is done using LZ4 streaming interface. 
Each chunk\n// depends on the one that precedes it.\n// All metadata is little-endian.\n\nstatic constexpr uint32_t last_chunk_flag = uint32_t(1) << 31;\nstatic constexpr size_t chunk_header_size = sizeof(uint32_t);\nstatic constexpr size_t chunk_size = 32 * 1024;\n\nnamespace {\n\nstruct compression_stream_deleter {\n void operator()(LZ4_stream_t* stream) const noexcept {\n LZ4_freeStream(stream);\n }\n};\n\nstruct decompression_stream_deleter {\n void operator()(LZ4_streamDecode_t* stream) const noexcept {\n LZ4_freeStreamDecode(stream);\n }\n};\n\n}\n\nsnd_buf lz4_fragmented_compressor::compress(size_t head_space, snd_buf data) {\n static thread_local auto stream = std::unique_ptr(LZ4_createStream());\n static_assert(chunk_size <= snd_buf::chunk_size, \"chunk_size <= snd_buf::chunk_size\");\n\n LZ4_resetStream(stream.get());\n\n auto size_left = data.size;\n auto src = std::get_if>(&data.bufs);\n if (!src) {\n src = std::get>>(data.bufs).data();\n }\n\n auto single_chunk_size = LZ4_COMPRESSBOUND(size_left) + head_space + chunk_header_size;\n if (single_chunk_size <= chunk_size && size_left <= chunk_size && src->size() == size_left) {\n // faster path for small messages\n auto dst = temporary_buffer(single_chunk_size);\n auto header = dst.get_write() + head_space;\n auto compressed_data = header + chunk_header_size;\n auto compressed_size = LZ4_compress_fast_continue(stream.get(), src->get(), compressed_data, size_left, LZ4_COMPRESSBOUND(size_left), 0);\n write_le(header, last_chunk_flag | size_left);\n dst.trim(head_space + chunk_header_size + compressed_size);\n return snd_buf(std::move(dst));\n }\n\n static constexpr size_t chunk_compress_bound = LZ4_COMPRESSBOUND(chunk_size);\n static constexpr size_t chunk_maximum_compressed_size = chunk_compress_bound + chunk_header_size;\n static_assert(chunk_maximum_compressed_size < snd_buf::chunk_size, \"chunk_maximum_compressed_size is too large\");\n\n std::vector> dst_buffers;\n size_t dst_offset = head_space;\n\n dst_buffers.emplace_back(std::max(head_space, snd_buf::chunk_size));\n\n // Intermediate chunks\n size_t total_compressed_size = 0;\n auto src_left = data.size;\n size_t src_current_offset = 0;\n\n // Advance offset in the current source fragment, move to the next fragment if needed.\n auto advance_src = [&] (size_t n) {\n src_current_offset += n;\n if (src_current_offset >= src->size()) {\n ++src;\n src_current_offset = 0;\n }\n src_left -= n;\n };\n\n // Input chunks do not have to be multiplies of chunk_size.\n // We handle such cases by reassembling a chunk in this temporary buffer.\n // Note that this is similar to the ring buffer compression case in docs, \n // we need to ensure that a suitable amount of (maybe) previous data is \n // stable in this or input buffer, thus make the temp buffer \n // LZ4_DECODER_RING_BUFFER_SIZE(chunk_size) large, and treat it as a ring.\n static constexpr auto lin_buf_size = LZ4_DECODER_RING_BUFFER_SIZE(chunk_size);\n static thread_local char temporary_chunk_data[lin_buf_size];\n size_t lin_off = 0;\n\n auto maybe_linearize = [&](size_t size) {\n auto src_ptr = src->get() + src_current_offset;\n if (src->size() - src_current_offset < size) {\n auto left = size;\n assert(lin_buf_size > size);\n if (lin_buf_size - lin_off < size) {\n lin_off = 0;\n }\n auto tmp = temporary_chunk_data + std::exchange(lin_off, lin_off + size);\n src_ptr = tmp;\n while (left) {\n auto this_size = std::min(src->size() - src_current_offset, left);\n tmp = std::copy_n(src->get() + src_current_offset, this_size, tmp);\n 
left -= this_size;\n advance_src(this_size);\n }\n } else {\n advance_src(chunk_size);\n lin_off = 0;\n }\n return src_ptr;\n };\n\n while (src_left > chunk_size) {\n // Check if we can fit another chunk in the current destination fragment.\n // If not allocate a new one.\n if (dst_offset + chunk_maximum_compressed_size > dst_buffers.back().size()) {\n dst_buffers.back().trim(dst_offset);\n dst_buffers.emplace_back(snd_buf::chunk_size);\n dst_offset = 0;\n }\n\n // Check if there is at least a contiguous chunk_size of data in the current\n // source fragment. If not, linearise that into temporary_chunk_data.\n auto src_ptr = maybe_linearize(chunk_size);\n auto header = dst_buffers.back().get_write() + dst_offset;\n auto dst = header + chunk_header_size;\n\n auto compressed_size = LZ4_compress_fast_continue(stream.get(), src_ptr, dst, chunk_size, chunk_compress_bound, 0);\n total_compressed_size += compressed_size + chunk_header_size;\n\n dst_offset += compressed_size + chunk_header_size;\n write_le(header, compressed_size);\n }\n\n // Last chunk\n auto last_chunk_compress_bound = LZ4_COMPRESSBOUND(src_left);\n auto last_chunk_maximum_compressed_size = last_chunk_compress_bound + chunk_header_size;\n\n // Check if we can fit the last chunk in the current destination fragment. Allocate a new one if not.\n if (dst_offset + last_chunk_maximum_compressed_size > dst_buffers.back().size()) {\n dst_buffers.back().trim(dst_offset);\n dst_buffers.emplace_back(snd_buf::chunk_size);\n dst_offset = 0;\n }\n auto header = dst_buffers.back().get_write() + dst_offset;\n auto dst = header + chunk_header_size;\n\n // Check if all remaining source data is contiguous. If not linearise it into temporary_chunk_data.\n auto rem = src_left;\n auto src_ptr = maybe_linearize(src_left);\n\n auto compressed_size = LZ4_compress_fast_continue(stream.get(), src_ptr, dst, rem, last_chunk_compress_bound, 0);\n dst_offset += compressed_size + chunk_header_size;\n write_le(header, last_chunk_flag | rem);\n total_compressed_size += compressed_size + chunk_header_size + head_space;\n\n auto& last = dst_buffers.back();\n last.trim(dst_offset);\n\n if (dst_buffers.size() == 1) {\n return snd_buf(std::move(dst_buffers.front()));\n }\n return snd_buf(std::move(dst_buffers), total_compressed_size);\n}\n\nrcv_buf lz4_fragmented_compressor::decompress(rcv_buf data) {\n if (data.size < 4) {\n return rcv_buf();\n }\n\n static thread_local auto stream = std::unique_ptr(LZ4_createStreamDecode());\n\n if (!LZ4_setStreamDecode(stream.get(), nullptr, 0)) {\n throw std::runtime_error(\"RPC frame LZ4_FRAGMENTED decompression failed to reset state\");\n }\n\n auto src = std::get_if>(&data.bufs);\n size_t src_left = data.size;\n size_t src_offset = 0;\n\n // Prepare source data. 
Returns pointer to n contiguous bytes of source data.\n // Avoids copy if possible, otherwise uses dst as a temporary storage.\n auto copy_src = [&] (char* dst, size_t n) -> const char* {\n // Fast path, no need to copy anything.\n if (src->size() - src_offset >= n) {\n auto ptr = src->get() + src_offset;\n src_left -= n;\n src_offset += n;\n return ptr;\n }\n\n // Need to linearise source chunk into dst.\n auto ptr = dst;\n src_left -= n;\n while (n) {\n if (src_offset == src->size()) {\n ++src;\n src_offset = 0;\n }\n auto this_size = std::min(n, src->size() - src_offset);\n std::copy_n(src->get() + src_offset, this_size, dst);\n n -= this_size;\n dst += this_size;\n src_offset += this_size;\n }\n return ptr;\n };\n\n // Read, possibly fragmented, header.\n auto read_header = [&] {\n uint32_t header_value;\n auto ptr = copy_src(reinterpret_cast(&header_value), chunk_header_size);\n if (ptr != reinterpret_cast(&header_value)) {\n std::copy_n(ptr, sizeof(uint32_t), reinterpret_cast(&header_value));\n }\n return le_to_cpu(header_value);\n };\n\n if (src) {\n auto header = read_le(src->get());\n if (header & last_chunk_flag) {\n // faster path for small messages: single chunk in a single buffer\n header &= ~last_chunk_flag;\n src_offset += chunk_header_size;\n src_left -= chunk_header_size;\n auto dst = temporary_buffer(header);\n if (LZ4_decompress_safe_continue(stream.get(), src->get() + src_offset, dst.get_write(), src_left, header) < 0) {\n throw std::runtime_error(\"RPC frame LZ4_FRAGMENTED decompression failure (short)\");\n }\n return rcv_buf(std::move(dst));\n }\n // not eligible for fast path: multiple chunks in a single buffer\n } else {\n // not eligible for fast path: multiple buffers\n src = std::get>>(data.bufs).data();\n }\n\n // Let's be a bit paranoid and not assume that the remote has the same\n // LZ4_COMPRESSBOUND as we do and allow any compressed chunk size.\n static thread_local auto chunk_buffer = temporary_buffer(LZ4_COMPRESSBOUND(chunk_size));\n\n std::vector> dst_buffers;\n size_t total_size = 0;\n\n // Decompressing requires either dest to be fully split or\n // \"preserved\" in 64KB or larger, depending on how it was\n // compressed. If not, decompression will fail, typically\n // on text-like constructs. Making our dest buffers 64K\n // ensures we retain a suitable dictionary region for all\n // passes. 
\n constexpr auto buf_size = 64 * 1024;\n size_t dst_offset = 0;\n\n auto get_dest = [&](size_t size) {\n if (dst_buffers.empty()) {\n dst_buffers.emplace_back(buf_size);\n }\n if (dst_buffers.back().size() - dst_offset < size) {\n dst_buffers.back().trim(dst_offset);\n dst_buffers.emplace_back(buf_size);\n dst_offset = 0;\n }\n return dst_buffers.back().get_write() + std::exchange(dst_offset, dst_offset + size);\n };\n\n // Intermediate chunks\n uint32_t header_value = read_header();\n while (!(header_value & last_chunk_flag)) {\n total_size += chunk_size;\n if (chunk_buffer.size() < header_value) {\n chunk_buffer = temporary_buffer(header_value);\n }\n auto src_ptr = copy_src(chunk_buffer.get_write(), header_value);\n auto dst = get_dest(chunk_size);\n if (LZ4_decompress_safe_continue(stream.get(), src_ptr, /*dst_buffers.back().get_write()*/dst, header_value, chunk_size) < 0) {\n throw std::runtime_error(format(\"RPC frame LZ4_FRAGMENTED decompression failure (long, at {} bytes)\", total_size - chunk_size));\n }\n header_value = read_header();\n }\n\n // Last chunk\n header_value &= ~last_chunk_flag;\n total_size += header_value;\n auto dst = get_dest(header_value);\n if (chunk_buffer.size() < src_left) {\n chunk_buffer = temporary_buffer(src_left);\n }\n auto last_chunk_compressed_size = src_left;\n auto src_ptr = copy_src(chunk_buffer.get_write(), src_left);\n if (LZ4_decompress_safe_continue(stream.get(), src_ptr, /*dst_buffers.back().get_write()*/dst, last_chunk_compressed_size, header_value) < 0) {\n throw std::runtime_error(format(\"RPC frame LZ4_FRAGMENTED decompression failure (long, last frame, at {} bytes)\", total_size - header_value));\n }\n\n dst_buffers.back().trim(dst_offset);\n\n if (dst_buffers.size() == 1) {\n return rcv_buf(std::move(dst_buffers.front()));\n }\n return rcv_buf(std::move(dst_buffers), total_size);\n}\n\n}\n}\n\n// Path: src/json/formatter.cc\n/*\n * This file is open source software, licensed to you under the terms\n * of the Apache License, Version 2.0 (the \"License\"). See the NOTICE file\n * distributed with this work for additional information regarding copyright\n * ownership. You may not use this file except in compliance with the License.\n *\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied. 
See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n/*\n * Copyright 2015 Cloudius Systems\n */\n\n#ifdef SEASTAR_MODULE\nmodule;\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n\n#ifdef SEASTAR_MODULE\nmodule seastar;\n#else\n#include \n#include \n#endif\n\nnamespace seastar {\n\nusing namespace std;\n\nnamespace json {\n\nsstring formatter::begin(state s) {\n switch (s) {\n case state::array: return \"[\";\n case state::map: return \"{\";\n default: return {};\n }\n}\n\nsstring formatter::end(state s) {\n switch (s) {\n case state::array: return \"]\";\n case state::map: return \"}\";\n default: return {};\n }\n}\n\n\nstatic inline bool is_control_char(char c) {\n return c >= 0 && c <= 0x1F;\n}\n\nstatic bool needs_escaping(const string_view& str) {\n return std::any_of(str.begin(), str.end(), [] (char c) {\n return is_control_char(c) || c == '\"' || c == '\\\\';\n });\n}\n\nstatic sstring string_view_to_json(const string_view& str) {\n if (!needs_escaping(str)) {\n return format(\"\\\"{}\\\"\", str);\n }\n\n ostringstream oss;\n oss << std::hex << std::uppercase << std::setfill('0');\n oss.put('\"');\n for (char c : str) {\n switch (c) {\n case '\"':\n oss.put('\\\\').put('\"');\n break;\n case '\\\\':\n oss.put('\\\\').put('\\\\');\n break;\n case '\\b':\n oss.put('\\\\').put('b');\n break;\n case '\\f':\n oss.put('\\\\').put('f');\n break;\n case '\\n':\n oss.put('\\\\').put('n');\n break;\n case '\\r':\n oss.put('\\\\').put('r');\n break;\n case '\\t':\n oss.put('\\\\').put('t');\n break;\n default:\n if (is_control_char(c)) {\n oss.put('\\\\').put('u') << std::setw(4) << static_cast(c);\n } else {\n oss.put(c);\n }\n break;\n }\n }\n oss.put('\"');\n return oss.str();\n}\n\nsstring formatter::to_json(const sstring& str) {\n return string_view_to_json(str);\n}\n\nsstring formatter::to_json(const char* str) {\n return string_view_to_json(str);\n}\n\nsstring formatter::to_json(const char* str, size_t len) {\n return string_view_to_json(string_view{str, len});\n}\n\nsstring formatter::to_json(int n) {\n return to_string(n);\n}\n\nsstring formatter::to_json(unsigned n) {\n return to_string(n);\n}\n\nsstring formatter::to_json(long n) {\n return to_string(n);\n}\n\nsstring formatter::to_json(float f) {\n if (std::isinf(f)) {\n throw out_of_range(\"Infinite float value is not supported\");\n } else if (std::isnan(f)) {\n throw invalid_argument(\"Invalid float value\");\n }\n return to_sstring(f);\n}\n\nsstring formatter::to_json(double d) {\n if (std::isinf(d)) {\n throw out_of_range(\"Infinite double value is not supported\");\n } else if (std::isnan(d)) {\n throw invalid_argument(\"Invalid double value\");\n }\n return to_sstring(d);\n}\n\n...\n", "needle_token_start": 15350, "needle_token_end": 15381, "code_context_ntokens": 16167, "output": ["```cpp\nstatic inline bool is_control_char(char c) {\n return c >= 0 && c <= 0x1F;\n}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "packVal", "language": "cpp", "path": "lib/executor/engine/refInstr.cpp", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to conditionally modify a value based on its type, specifically for types that require packing into smaller bit widths.\n2. **Input**: It takes two parameters: a type descriptor indicating the data type and the value to be potentially modified.\n3. 
**Output**: It returns a modified value if the type requires packing, otherwise it returns the original value unchanged.\n4. **Procedure**: The function first checks if the type descriptor indicates a packing requirement. If so, it checks the specific type code and applies a bitwise operation to limit the value to the appropriate size (8 or 16 bits). If no packing is required, the original value is returned as is.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/executor/engine/engine.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect Executor::runExpression(Runtime::StackManager &StackMgr,\n AST::InstrView Instrs) {\n return execute(StackMgr, Instrs.begin(), Instrs.end());\n}\n\nExpect\nExecutor::runFunction(Runtime::StackManager &StackMgr,\n const Runtime::Instance::FunctionInstance &Func,\n Span Params) {\n // Set start time.\n if (Stat && Conf.getStatisticsConfigure().isTimeMeasuring()) {\n Stat->startRecordWasm();\n }\n\n // Reset and push a dummy frame into stack.\n StackMgr.pushFrame(nullptr, AST::InstrView::iterator(), 0, 0);\n\n // Push arguments.\n const auto &PTypes = Func.getFuncType().getParamTypes();\n for (uint32_t I = 0; I < Params.size(); I++) {\n // For the references, transform to non-null reference type if the value not\n // null.\n if (PTypes[I].isRefType() && Params[I].get().getPtr() &&\n Params[I].get().getType().isNullableRefType()) {\n auto Val = Params[I];\n Val.get().getType().toNonNullableRef();\n StackMgr.push(Val);\n } else {\n StackMgr.push(Params[I]);\n }\n }\n\n // Enter and execute function.\n AST::InstrView::iterator StartIt = {};\n Expect Res = {};\n if (auto GetIt = enterFunction(StackMgr, Func, Func.getInstrs().end())) {\n StartIt = *GetIt;\n } else {\n if (GetIt.error() == ErrCode::Value::Terminated) {\n // Handle the terminated case in entering AOT or host functions.\n // For the terminated case, not return now to print the statistics.\n Res = Unexpect(GetIt.error());\n } else {\n return Unexpect(GetIt);\n }\n }\n...\n// Path: lib/executor/engine/refInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nnamespace {\n\nValVariant packVal(const ValType &Type, const ValVariant &Val) {\n if (Type.isPackType()) {\n switch (Type.getCode()) {\n case TypeCode::I8:\n return ValVariant(Val.get() & 0xFFU);\n case TypeCode::I16:\n return ValVariant(Val.get() & 0xFFFFU);\n default:\n assumingUnreachable();\n }\n }\n return Val;\n}\n\nValVariant unpackVal(const ValType &Type, const ValVariant &Val,\n bool IsSigned = false) {\n if (Type.isPackType()) {\n uint32_t Num = Val.get();\n switch (Type.getCode()) {\n case TypeCode::I8:\n if (IsSigned) {\n return static_cast(static_cast(Num));\n } else {\n return static_cast(static_cast(Num));\n }\n case TypeCode::I16:\n if (IsSigned) {\n return static_cast(static_cast(Num));\n } else {\n return static_cast(static_cast(Num));\n }\n default:\n assumingUnreachable();\n }\n }\n return Val;\n}\n\nstd::vector packVals(const ValType &Type,\n std::vector &&Vals) {\n for (uint32_t I = 0; I < 
Vals.size(); I++) {\n Vals[I] = packVal(Type, Vals[I]);\n }\n return std::move(Vals);\n}\n} // namespace\n\nExpect Executor::runRefNullOp(Runtime::StackManager &StackMgr,\n const ValType &Type) const noexcept {\n // A null reference is typed with the least type in its respective hierarchy.\n StackMgr.push(RefVariant(toBottomType(StackMgr, Type)));\n return {};\n}\n\nExpect Executor::runRefIsNullOp(ValVariant &Val) const noexcept {\n Val.emplace(Val.get().isNull() ? 1U : 0U);\n return {};\n}\n\nExpect Executor::runRefFuncOp(Runtime::StackManager &StackMgr,\n uint32_t Idx) const noexcept {\n const auto *FuncInst = getFuncInstByIdx(StackMgr, Idx);\n StackMgr.push(RefVariant(FuncInst->getDefType(), FuncInst));\n return {};\n}\n\nExpect Executor::runRefEqOp(ValVariant &Val1,\n const ValVariant &Val2) const noexcept {\n Val1.emplace(Val1.get().getPtr() ==\n Val2.get().getPtr()\n ? 1U\n : 0U);\n return {};\n}\n\nExpect\nExecutor::runRefAsNonNullOp(RefVariant &Ref,\n const AST::Instruction &Instr) const noexcept {\n if (Ref.isNull()) {\n spdlog::error(ErrCode::Value::CastNullToNonNull);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::CastNullToNonNull);\n }\n Ref.getType().toNonNullableRef();\n return {};\n}\n\nExpect Executor::runStructNewOp(Runtime::StackManager &StackMgr,\n const uint32_t DefIndex,\n bool IsDefault) const noexcept {\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n const auto &CompType =\n getDefTypeByIdx(StackMgr, DefIndex)->getCompositeType();\n uint32_t N = static_cast(CompType.getFieldTypes().size());\n std::vector Vals;\n if (IsDefault) {\n Vals.resize(N);\n for (uint32_t I = 0; I < N; I++) {\n const auto &VType = CompType.getFieldTypes()[I].getStorageType();\n Vals[I] = VType.isRefType()\n ? 
ValVariant(RefVariant(toBottomType(StackMgr, VType)))\n : ValVariant(static_cast(0));\n }\n } else {\n Vals = StackMgr.pop(N);\n for (uint32_t I = 0; I < N; I++) {\n Vals[I] = packVal(CompType.getFieldTypes()[I].getStorageType(), Vals[I]);\n }\n }\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newStruct(DefIndex, std::move(Vals));\n StackMgr.push(RefVariant(Inst->getDefType(), Inst));\n\n return {};\n}\n\nExpect Executor::runStructGetOp(ValVariant &Val, const uint32_t Idx,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr,\n bool IsSigned) const noexcept {\n const auto *Inst =\n Val.get().getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullStruct);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullStruct);\n }\n const auto &SType = CompType.getFieldTypes()[Idx].getStorageType();\n Val = unpackVal(SType, Inst->getField(Idx), IsSigned);\n return {};\n}\n\nExpect\nExecutor::runStructSetOp(const ValVariant &Val, const RefVariant &InstRef,\n const AST::CompositeType &CompType, uint32_t Idx,\n const AST::Instruction &Instr) const noexcept {\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullStruct);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullStruct);\n }\n const auto &SType = CompType.getFieldTypes()[Idx].getStorageType();\n Inst->getField(Idx) = packVal(SType, Val);\n return {};\n}\n\nExpect Executor::runArrayNewOp(Runtime::StackManager &StackMgr,\n const uint32_t DefIndex, uint32_t InitCnt,\n uint32_t ValCnt) const noexcept {\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n assuming(InitCnt == 0 || InitCnt == 1 || InitCnt == ValCnt);\n const auto &CompType =\n getDefTypeByIdx(StackMgr, DefIndex)->getCompositeType();\n const auto &VType = CompType.getFieldTypes()[0].getStorageType();\n if (InitCnt == 0) {\n auto InitVal = VType.isRefType()\n ? ValVariant(RefVariant(toBottomType(StackMgr, VType)))\n : ValVariant(static_cast(0));\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(DefIndex, ValCnt, InitVal);\n StackMgr.push(RefVariant(Inst->getDefType(), Inst));\n } else if (InitCnt == 1) {\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(DefIndex, ValCnt, packVal(VType, StackMgr.getTop()));\n StackMgr.getTop().emplace(Inst->getDefType(), Inst);\n } else {\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(DefIndex, packVals(VType, StackMgr.pop(ValCnt)));\n StackMgr.push(RefVariant(Inst->getDefType(), Inst));\n }\n return {};\n}\n\nExpect\nExecutor::runArrayNewDataOp(Runtime::StackManager &StackMgr,\n const Runtime::Instance::DataInstance &DataInst,\n const AST::Instruction &Instr) const noexcept {\n const uint32_t N = StackMgr.pop().get();\n const uint32_t S = StackMgr.getTop().get();\n const auto &CompType =\n getDefTypeByIdx(StackMgr, Instr.getTargetIndex())->getCompositeType();\n const uint32_t BSize =\n CompType.getFieldTypes()[0].getStorageType().getBitWidth() / 8;\n if (static_cast(S) + static_cast(N) * BSize >\n DataInst.getData().size()) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N * BSize,\n DataInst.getData().size() > 0\n ? 
static_cast(DataInst.getData().size() - 1)\n : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(Instr.getTargetIndex(), N, 0U);\n for (uint32_t Idx = 0; Idx < N; Idx++) {\n // The value has been packed.\n Inst->getData(Idx) = DataInst.loadValue(S + Idx * BSize, BSize);\n }\n StackMgr.getTop().emplace(Inst->getDefType(), Inst);\n return {};\n}\n\nExpect\nExecutor::runArrayNewElemOp(Runtime::StackManager &StackMgr,\n const Runtime::Instance::ElementInstance &ElemInst,\n const AST::Instruction &Instr) const noexcept {\n const uint32_t N = StackMgr.pop().get();\n const uint32_t S = StackMgr.getTop().get();\n const auto &CompType =\n getDefTypeByIdx(StackMgr, Instr.getTargetIndex())->getCompositeType();\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n auto ElemSrc = ElemInst.getRefs();\n if (static_cast(S) + static_cast(N) > ElemSrc.size()) {\n spdlog::error(ErrCode::Value::TableOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N,\n ElemSrc.size() > 0 ? static_cast(ElemSrc.size() - 1) : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::TableOutOfBounds);\n }\n std::vector Refs(ElemSrc.begin() + S, ElemSrc.begin() + S + N);\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(Instr.getTargetIndex(), packVals(SType, std::move(Refs)));\n StackMgr.getTop().emplace(Inst->getDefType(), Inst);\n return {};\n}\n\nExpect\nExecutor::runArraySetOp(const ValVariant &Val, const uint32_t Idx,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr) const noexcept {\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (Idx >= Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(Idx, 1, Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n Inst->getData(Idx) = packVal(SType, Val);\n return {};\n}\n\nExpect Executor::runArrayGetOp(ValVariant &Val, const uint32_t Idx,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr,\n bool IsSigned) const noexcept {\n const auto *Inst =\n Val.get().getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (Idx >= Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(Idx, 1, Inst->getBoundIdx()));\n spdlog::error(\n 
ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n Val = unpackVal(SType, Inst->getData(Idx), IsSigned);\n return {};\n}\n\nExpect\nExecutor::runArrayLenOp(ValVariant &Val,\n const AST::Instruction &Instr) const noexcept {\n const auto *Inst =\n Val.get().getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n Val.emplace(Inst->getLength());\n return {};\n}\n\nExpect\nExecutor::runArrayFillOp(uint32_t N, const ValVariant &Val, uint32_t D,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr) const noexcept {\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(D) + static_cast(N) > Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n auto Arr = Inst->getArray();\n std::fill(Arr.begin() + D, Arr.begin() + D + N, packVal(SType, Val));\n return {};\n}\n\nExpect\nExecutor::runArrayCopyOp(uint32_t N, uint32_t S, const RefVariant &SrcInstRef,\n uint32_t D, const RefVariant &DstInstRef,\n const AST::CompositeType &SrcCompType,\n const AST::CompositeType &DstCompType,\n const AST::Instruction &Instr) const noexcept {\n auto *SrcInst = SrcInstRef.getPtr();\n auto *DstInst = DstInstRef.getPtr();\n if (SrcInst == nullptr || DstInst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(S) + static_cast(N) >\n SrcInst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(S), N,\n SrcInst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n if (static_cast(D) + static_cast(N) >\n DstInst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n DstInst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SrcSType = SrcCompType.getFieldTypes()[0].getStorageType();\n const auto &DstSType = DstCompType.getFieldTypes()[0].getStorageType();\n auto SrcArr = SrcInst->getArray();\n auto DstArr = DstInst->getArray();\n if (D <= S) {\n std::transform(SrcArr.begin() + S, SrcArr.begin() + S + N,\n DstArr.begin() + D, [&](const ValVariant &V) {\n return packVal(DstSType, unpackVal(SrcSType, V));\n });\n } else {\n std::transform(std::make_reverse_iterator(SrcArr.begin() + S + N),\n std::make_reverse_iterator(SrcArr.begin() + S),\n std::make_reverse_iterator(DstArr.begin() + D + N),\n [&](const ValVariant &V) {\n return 
packVal(DstSType, unpackVal(SrcSType, V));\n });\n }\n return {};\n}\n\nExpect\nExecutor::runArrayInitDataOp(uint32_t N, uint32_t S, uint32_t D,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const Runtime::Instance::DataInstance &DataInst,\n const AST::Instruction &Instr) const noexcept {\n const uint32_t BSize =\n CompType.getFieldTypes()[0].getStorageType().getBitWidth() / 8;\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(D) + static_cast(N) > Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n if (static_cast(S) + static_cast(N) * BSize >\n DataInst.getData().size()) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N * BSize,\n DataInst.getData().size() > 0\n ? static_cast(DataInst.getData().size() - 1)\n : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n for (uint32_t Off = 0; Off < N; Off++) {\n // The value has been packed.\n Inst->getData(D + Off) = DataInst.loadValue(S + Off * BSize, BSize);\n }\n return {};\n}\n\nExpect\nExecutor::runArrayInitElemOp(uint32_t N, uint32_t S, uint32_t D,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const Runtime::Instance::ElementInstance &ElemInst,\n const AST::Instruction &Instr) const noexcept {\n auto ElemSrc = ElemInst.getRefs();\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(D) + static_cast(N) > Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n if (static_cast(S) + static_cast(N) > ElemSrc.size()) {\n spdlog::error(ErrCode::Value::TableOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N,\n ElemSrc.size() > 0 ? 
static_cast(ElemSrc.size() - 1) : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::TableOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n\n auto Arr = Inst->getArray();\n // The value has been packed.\n std::transform(ElemSrc.begin() + S, ElemSrc.begin() + S + N, Arr.begin() + D,\n [&](const RefVariant &V) { return packVal(SType, V); });\n return {};\n}\n\nExpect\nExecutor::runRefTestOp(const Runtime::Instance::ModuleInstance *ModInst,\n ValVariant &Val, const AST::Instruction &Instr,\n bool IsCast) const noexcept {\n // Copy the value type here due to handling the externalized case.\n auto &VT = Val.get().getType();\n if (VT.isExternalized()) {\n VT = ValType(TypeCode::Ref, TypeCode::ExternRef);\n }\n Span GotTypeList = ModInst->getTypeList();\n if (!VT.isAbsHeapType()) {\n auto *Inst =\n Val.get().getPtr();\n // Reference must not be nullptr here because the null references are typed\n // with the least abstract heap type.\n if (Inst->getModule()) {\n GotTypeList = Inst->getModule()->getTypeList();\n }\n }\n\n if (AST::TypeMatcher::matchType(ModInst->getTypeList(), Instr.getValType(),\n GotTypeList, VT)) {\n if (!IsCast) {\n Val.emplace(1U);\n }\n } else {\n if (IsCast) {\n spdlog::error(ErrCode::Value::CastFailed);\n spdlog::error(ErrInfo::InfoMismatch(Instr.getValType(), VT));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::CastFailed);\n } else {\n Val.emplace(0U);\n }\n }\n return {};\n}\n\nExpect Executor::runRefConvOp(RefVariant &Ref,\n TypeCode TCode) const noexcept {\n\n if (TCode == TypeCode::AnyRef) {\n // Internalize.\n if (Ref.isNull()) {\n Ref = RefVariant(ValType(TypeCode::RefNull, TypeCode::NullRef));\n } else {\n Ref.getType().setInternalized();\n if (Ref.getType().isExternRefType()) {\n Ref.getType() = ValType(TypeCode::Ref, TypeCode::AnyRef);\n }\n }\n } else {\n // Externalize.\n if (Ref.isNull()) {\n Ref = RefVariant(ValType(TypeCode::RefNull, TypeCode::NullExternRef));\n } else {\n // Use the externalize flag because the value type information should be\n // reserved when a reference being externalized and internalized.\n Ref.getType().setExternalized();\n }\n }\n return {};\n}\n\nExpect Executor::runRefI31Op(ValVariant &Val) const noexcept {\n uint32_t RefNum = (Val.get() & 0x7FFFFFFFU) | 0x80000000U;\n Val = RefVariant(ValType(TypeCode::Ref, TypeCode::I31Ref),\n reinterpret_cast(static_cast(RefNum)));\n return {};\n}\n\nExpect Executor::runI31GetOp(ValVariant &Val,\n const AST::Instruction &Instr,\n bool IsSigned) const noexcept {\n uint32_t RefNum = static_cast(\n reinterpret_cast(Val.get().getPtr()));\n if ((RefNum & 0x80000000U) == 0) {\n spdlog::error(ErrCode::Value::AccessNullI31);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullI31);\n }\n RefNum &= 0x7FFFFFFFU;\n if (IsSigned) {\n RefNum |= ((RefNum & 0x40000000U) << 1);\n }\n Val.emplace(RefNum);\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/memoryInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect\nExecutor::runMemorySizeOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst) {\n // Push SZ = page size to stack.\n 
StackMgr.push(MemInst.getPageSize());\n return {};\n}\n\nExpect\nExecutor::runMemoryGrowOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst) {\n // Pop N for growing page size.\n uint32_t &N = StackMgr.getTop().get();\n\n // Grow page and push result.\n const uint32_t CurrPageSize = static_cast(MemInst.getPageSize());\n if (MemInst.growPage(N)) {\n N = CurrPageSize;\n } else {\n N = static_cast(-1);\n }\n return {};\n}\n\nExpect Executor::runMemoryInitOp(\n Runtime::StackManager &StackMgr, Runtime::Instance::MemoryInstance &MemInst,\n Runtime::Instance::DataInstance &DataInst, const AST::Instruction &Instr) {\n // Pop the length, source, and destination from stack.\n uint32_t Len = StackMgr.pop().get();\n uint32_t Src = StackMgr.pop().get();\n uint32_t Dst = StackMgr.pop().get();\n\n // Replace mem[Dst : Dst + Len] with data[Src : Src + Len].\n if (auto Res = MemInst.setBytes(DataInst.getData(), Dst, Src, Len)) {\n return {};\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n }\n}\n\nExpect\nExecutor::runDataDropOp(Runtime::Instance::DataInstance &DataInst) {\n // Clear data instance.\n DataInst.clear();\n return {};\n}\n\nExpect\nExecutor::runMemoryCopyOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInstDst,\n Runtime::Instance::MemoryInstance &MemInstSrc,\n const AST::Instruction &Instr) {\n // Pop the length, source, and destination from stack.\n uint32_t Len = StackMgr.pop().get();\n uint32_t Src = StackMgr.pop().get();\n uint32_t Dst = StackMgr.pop().get();\n\n // Replace mem[Dst : Dst + Len] with mem[Src : Src + Len].\n if (auto Data = MemInstSrc.getBytes(Src, Len)) {\n if (auto Res = MemInstDst.setBytes(*Data, Dst, 0, Len)) {\n return {};\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n }\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Data);\n }\n}\n\nExpect\nExecutor::runMemoryFillOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst,\n const AST::Instruction &Instr) {\n // Pop the length, value, and offset from stack.\n uint32_t Len = StackMgr.pop().get();\n uint8_t Val = static_cast(StackMgr.pop().get());\n uint32_t Off = StackMgr.pop().get();\n\n // Fill data with Val.\n if (auto Res = MemInst.fillBytes(Val, Off, Len)) {\n return {};\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n }\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/threadInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect\nExecutor::runAtomicNotifyOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst,\n const AST::Instruction &Instr) {\n ValVariant RawCount = StackMgr.pop();\n ValVariant &RawAddress = StackMgr.getTop();\n\n uint32_t Address = RawAddress.get();\n\n if (Address >\n std::numeric_limits::max() - Instr.getMemoryOffset()) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n Address + static_cast(Instr.getMemoryOffset()),\n sizeof(uint32_t), MemInst.getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return 
Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n Address += Instr.getMemoryOffset();\n\n if (Address % sizeof(uint32_t) != 0) {\n spdlog::error(ErrCode::Value::UnalignedAtomicAccess);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::UnalignedAtomicAccess);\n }\n\n uint32_t Count = RawCount.get();\n if (auto Res = atomicNotify(MemInst, Address, Count); unlikely(!Res)) {\n spdlog::error(Res.error());\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n } else {\n RawAddress.emplace(*Res);\n }\n return {};\n}\n\nExpect Executor::runMemoryFenceOp() {\n std::atomic_thread_fence(std::memory_order_release);\n return {};\n}\n\nExpect\nExecutor::atomicNotify(Runtime::Instance::MemoryInstance &MemInst,\n uint32_t Address, uint32_t Count) noexcept {\n // The error message should be handled by the caller, or the AOT mode will\n // produce the duplicated messages.\n if (auto *AtomicObj = MemInst.getPointer *>(Address);\n !AtomicObj) {\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n\n std::unique_lock Locker(WaiterMapMutex);\n uint32_t Total = 0;\n auto Range = WaiterMap.equal_range(Address);\n for (auto Iterator = Range.first; Total < Count && Iterator != Range.second;\n ++Iterator) {\n if (likely(&MemInst == Iterator->second.MemInst)) {\n Iterator->second.Cond.notify_all();\n ++Total;\n }\n }\n return Total;\n}\n\nvoid Executor::atomicNotifyAll() noexcept {\n std::unique_lock Locker(WaiterMapMutex);\n for (auto Iterator = WaiterMap.begin(); Iterator != WaiterMap.end();\n ++Iterator) {\n Iterator->second.Cond.notify_all();\n }\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/loader/loader.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n\n#include \"aot/version.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Loader {\n\n// Load data from file path. 
See \"include/loader/loader.h\".\nExpect>\nLoader::loadFile(const std::filesystem::path &FilePath) {\n std::error_code EC;\n size_t FileSize = std::filesystem::file_size(FilePath, EC);\n if (EC) {\n spdlog::error(ErrCode::Value::IllegalPath);\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n\n std::ifstream Fin(FilePath, std::ios::in | std::ios::binary);\n if (!Fin) {\n spdlog::error(ErrCode::Value::IllegalPath);\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n\n std::vector Buf(FileSize);\n size_t Index = 0;\n while (FileSize > 0) {\n const uint32_t BlockSize = static_cast(\n std::min(FileSize, std::numeric_limits::max()));\n Fin.read(reinterpret_cast(Buf.data()) + Index, BlockSize);\n const uint32_t ReadCount = static_cast(Fin.gcount());\n if (ReadCount != BlockSize) {\n if (Fin.eof()) {\n spdlog::error(ErrCode::Value::UnexpectedEnd);\n spdlog::error(ErrInfo::InfoLoading(ReadCount));\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::UnexpectedEnd);\n } else {\n spdlog::error(ErrCode::Value::ReadError);\n spdlog::error(ErrInfo::InfoLoading(ReadCount));\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::ReadError);\n }\n }\n Index += static_cast(BlockSize);\n FileSize -= static_cast(BlockSize);\n }\n return Buf;\n}\n\nExpect,\n std::unique_ptr>>\nLoader::parseWasmUnit(const std::filesystem::path &FilePath) {\n std::lock_guard Lock(Mutex);\n // Set path and check the header.\n if (auto Res = FMgr.setPath(FilePath); !Res) {\n spdlog::error(Res.error());\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n\n switch (FMgr.getHeaderType()) {\n // Filter out the Windows .dll, MacOS .dylib, or Linux .so AOT compiled\n // shared-library-WASM.\n case FileMgr::FileHeader::ELF:\n case FileMgr::FileHeader::DLL:\n case FileMgr::FileHeader::MachO_32:\n case FileMgr::FileHeader::MachO_64: {\n // AOT compiled shared-library-WASM cases. 
Use ldmgr to load the module.\n WASMType = InputType::SharedLibrary;\n FMgr.reset();\n std::shared_ptr Library = std::make_shared();\n if (auto Res = Library->load(FilePath); !Res) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n if (auto Res = Library->getVersion()) {\n if (*Res != AOT::kBinaryVersion) {\n spdlog::error(ErrInfo::InfoMismatch(AOT::kBinaryVersion, *Res));\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::MalformedVersion);\n }\n } else {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n\n std::unique_ptr Mod;\n if (auto Code = Library->getWasm()) {\n // Set the binary and load module.\n // Not to use parseModule() here to keep the `WASMType` value.\n if (auto Res = FMgr.setCode(*Code); !Res) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n if (auto Res = loadUnit()) {\n if (std::holds_alternative>(*Res)) {\n Mod = std::move(std::get>(*Res));\n }\n } else {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n } else {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Code);\n }\n if (!Conf.getRuntimeConfigure().isForceInterpreter()) {\n // If the configure is set to force interpreter mode, not to load the AOT\n // related data.\n if (auto Res = loadExecutable(*Mod, Library); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n }\n return Mod;\n }\n default:\n // Universal WASM, WASM, or other cases. Load and parse the module directly.\n WASMType = InputType::WASM;\n auto Unit = loadUnit();\n if (!Unit) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Unit);\n }\n switch (Unit->index()) {\n case 0: // component\n return Unit;\n case 1: // module\n default: {\n auto Mod = std::move(std::get>(*Unit));\n if (!Conf.getRuntimeConfigure().isForceInterpreter()) {\n // If the configure is set to force interpreter mode, not to set the\n // symbol.\n if (auto &Symbol = Mod->getSymbol()) {\n *Symbol = IntrinsicsTable;\n }\n }\n return Mod;\n }\n }\n }\n}\n\nExpect,\n std::unique_ptr>>\nLoader::parseWasmUnit(Span Code) {\n std::lock_guard Lock(Mutex);\n if (auto Res = FMgr.setCode(Code); !Res) {\n return Unexpect(Res);\n }\n switch (FMgr.getHeaderType()) {\n // Filter out the Windows .dll, MacOS .dylib, or Linux .so AOT compiled\n // shared-library-WASM.\n case FileMgr::FileHeader::ELF:\n case FileMgr::FileHeader::DLL:\n case FileMgr::FileHeader::MachO_32:\n case FileMgr::FileHeader::MachO_64:\n spdlog::error(\"Might an invalid wasm file\");\n spdlog::error(ErrCode::Value::MalformedMagic);\n spdlog::error(\n \" The AOT compiled WASM shared library is not supported for loading \"\n \"from memory. Please use the universal WASM binary or pure WASM, or \"\n \"load the AOT compiled WASM shared library from file.\");\n return Unexpect(ErrCode::Value::MalformedMagic);\n default:\n break;\n }\n // For malformed header checking, handle in the module loading.\n WASMType = InputType::WASM;\n return loadUnit();\n}\n\n// Parse module from file path. See \"include/loader/loader.h\".\nExpect>\nLoader::parseModule(const std::filesystem::path &FilePath) {\n if (auto R = parseWasmUnit(FilePath)) {\n if (std::holds_alternative>(*R)) {\n return std::move(std::get>(*R));\n }\n return Unexpect(ErrCode::Value::MalformedVersion);\n } else {\n return Unexpect(R);\n }\n}\n\n// Parse module from byte code. 
See \"include/loader/loader.h\".\nExpect>\nLoader::parseModule(Span Code) {\n if (auto R = parseWasmUnit(Code)) {\n if (std::holds_alternative>(*R)) {\n return std::move(std::get>(*R));\n }\n return Unexpect(ErrCode::Value::MalformedVersion);\n } else {\n return Unexpect(R);\n }\n}\n\n// Serialize module into byte code. See \"include/loader/loader.h\".\nExpect> Loader::serializeModule(const AST::Module &Mod) {\n return Ser.serializeModule(Mod);\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/aot_section.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/aot_section.h\"\n#include \"common/log.h\"\n#include \"system/allocator.h\"\n\n#if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS\nextern \"C\" {\nextern void __register_frame(void *);\nextern void __deregister_frame(void *);\n}\n#endif\n\nnamespace {\ninline constexpr uint64_t roundDownPageBoundary(const uint64_t Value) {\n// ARM64 Mac has a special page size\n#if WASMEDGE_OS_MACOS && defined(__aarch64__)\n return Value & ~UINT64_C(16383);\n#else\n return Value & ~UINT64_C(4095);\n#endif\n}\ninline constexpr uint64_t roundUpPageBoundary(const uint64_t Value) {\n// ARM64 Mac has a special page size\n#if WASMEDGE_OS_MACOS && defined(__aarch64__)\n return roundDownPageBoundary(Value + UINT64_C(16383));\n#else\n return roundDownPageBoundary(Value + UINT64_C(4095));\n#endif\n}\n} // namespace\n\nnamespace WasmEdge::Loader {\n\nExpect AOTSection::load(const AST::AOTSection &AOTSec) noexcept {\n BinarySize = 0;\n for (const auto &Section : AOTSec.getSections()) {\n const auto Offset = std::get<1>(Section);\n const auto Size = std::get<2>(Section);\n BinarySize = std::max(BinarySize, Offset + Size);\n }\n BinarySize = roundUpPageBoundary(BinarySize);\n\n Binary = Allocator::allocate_chunk(BinarySize);\n if (unlikely(!Binary)) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n\n std::vector> ExecutableRanges;\n for (const auto &Section : AOTSec.getSections()) {\n const auto Offset = std::get<1>(Section);\n const auto Size = std::get<2>(Section);\n const auto &Content = std::get<3>(Section);\n if (Size > BinarySize || Offset > BinarySize ||\n Offset + Size > BinarySize || Content.size() > Size) {\n return Unexpect(ErrCode::Value::IntegerTooLarge);\n }\n std::copy(Content.begin(), Content.end(), Binary + Offset);\n switch (std::get<0>(Section)) {\n case 1: { // Text\n const auto O = roundDownPageBoundary(Offset);\n const auto S = roundUpPageBoundary(Size + (Offset - O));\n ExecutableRanges.emplace_back(Binary + O, S);\n break;\n }\n case 2: // Data\n break;\n case 3: // BSS\n break;\n#if WASMEDGE_OS_LINUX\n case 4: // EHFrame\n EHFrameAddress = reinterpret_cast(Binary + Offset);\n break;\n#elif WASMEDGE_OS_MACOS\n case 4: // EHFrame\n EHFrameAddress = reinterpret_cast(Binary + Offset);\n EHFrameSize = Size;\n break;\n#elif WASMEDGE_OS_WINDOWS\n case 4: // PData\n PDataAddress = reinterpret_cast(Binary + Offset);\n PDataSize =\n static_cast(Size / sizeof(winapi::RUNTIME_FUNCTION_));\n break;\n#endif\n default:\n return Unexpect(ErrCode::Value::IntegerTooLarge);\n }\n }\n\n for (const auto &[Pointer, Size] : ExecutableRanges) {\n if (!Allocator::set_chunk_executable(Pointer, Size)) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(\" set_chunk_executable failed:{}\", std::strerror(errno));\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n }\n\n IntrinsicsAddress = 
AOTSec.getIntrinsicsAddress();\n TypesAddress = AOTSec.getTypesAddress();\n CodesAddress = AOTSec.getCodesAddress();\n\n#if WASMEDGE_OS_LINUX\n if (EHFrameAddress) {\n __register_frame(EHFrameAddress);\n }\n#elif WASMEDGE_OS_MACOS\n if (EHFrameAddress) {\n auto Iter = EHFrameAddress;\n const auto End = EHFrameAddress + EHFrameSize - 4;\n\n while (Iter < End) {\n if (Iter != EHFrameAddress) {\n __register_frame(Iter);\n }\n const uint32_t Length = *reinterpret_cast(Iter);\n Iter += Length + 4;\n }\n }\n#elif WASMEDGE_OS_WINDOWS\n if (PDataSize != 0) {\n winapi::RtlAddFunctionTable(\n static_cast(PDataAddress), PDataSize,\n reinterpret_cast(Binary));\n }\n#endif\n\n return {};\n}\n\nvoid AOTSection::unload() noexcept {\n if (Binary) {\n#if WASMEDGE_OS_LINUX\n if (EHFrameAddress) {\n __deregister_frame(EHFrameAddress);\n }\n#elif WASMEDGE_OS_MACOS\n if (EHFrameAddress) {\n auto Iter = EHFrameAddress;\n const auto End = EHFrameAddress + EHFrameSize - 4;\n\n while (Iter < End) {\n if (Iter != EHFrameAddress) {\n __deregister_frame(Iter);\n }\n const uint32_t Length = *reinterpret_cast(Iter);\n Iter += Length + 4;\n }\n }\n#elif WASMEDGE_OS_WINDOWS\n if (PDataSize != 0) {\n winapi::RtlDeleteFunctionTable(\n static_cast(PDataAddress));\n }\n#endif\n Allocator::set_chunk_readable_writable(Binary, BinarySize);\n Allocator::release_chunk(Binary, BinarySize);\n Binary = nullptr;\n }\n}\n\n} // namespace WasmEdge::Loader\n\n// Path: lib/loader/filemgr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/filemgr.h\"\n\n#include \n#include \n\n// Error logging of file manager need to be handled in caller.\n\nnamespace WasmEdge {\n\n// Set path to file manager. See \"include/loader/filemgr.h\".\nExpect FileMgr::setPath(const std::filesystem::path &FilePath) {\n reset();\n std::error_code ErrCode;\n Size = std::filesystem::file_size(FilePath, ErrCode);\n if (likely(!ErrCode)) {\n if (!MMap::supported()) {\n Size = 0;\n Status = ErrCode::Value::IllegalPath;\n return Unexpect(Status);\n }\n FileMap.emplace(FilePath);\n if (auto *Pointer = FileMap->address(); likely(Pointer)) {\n Data = reinterpret_cast(Pointer);\n Status = ErrCode::Value::Success;\n } else {\n // File size is 0, mmap failed.\n // Will get 'UnexpectedEnd' error while the first reading.\n FileMap.reset();\n }\n return {};\n }\n Size = 0;\n Status = ErrCode::Value::IllegalPath;\n return Unexpect(Status);\n}\n\n// Set code data. See \"include/loader/filemgr.h\".\nExpect FileMgr::setCode(Span CodeData) {\n reset();\n Data = CodeData.data();\n Size = CodeData.size();\n Status = ErrCode::Value::Success;\n return {};\n}\n\n// Set code data. See \"include/loader/filemgr.h\".\nExpect FileMgr::setCode(std::vector CodeData) {\n reset();\n DataHolder.emplace(std::move(CodeData));\n Data = DataHolder->data();\n Size = DataHolder->size();\n Status = ErrCode::Value::Success;\n return {};\n}\n\n// Read one byte. See \"include/loader/filemgr.h\".\nExpect FileMgr::readByte() {\n if (unlikely(Status != ErrCode::Value::Success)) {\n return Unexpect(Status);\n }\n // Set the flag to the start offset.\n LastPos = Pos;\n // Check if exceed the data boundary and section boundary.\n if (auto Res = testRead(1); unlikely(!Res)) {\n return Unexpect(Res);\n }\n return Data[Pos++];\n}\n\n// Read number of bytes. 
See \"include/loader/filemgr.h\".\nExpect> FileMgr::readBytes(size_t SizeToRead) {\n // Set the flag to the start offset.\n LastPos = Pos;\n // Read bytes into vector.\n std::vector Buf(SizeToRead);\n if (auto Res = readBytes(Buf); unlikely(!Res)) {\n return Unexpect(Res);\n }\n return Buf;\n}\n\n// Decode and read an unsigned int. See \"include/loader/filemgr.h\".\nExpect FileMgr::readU32() {\n if (unlikely(Status != ErrCode::Value::Success)) {\n return Unexpect(Status);\n }\n // Set the flag to the start offset.\n LastPos = Pos;\n\n // Read and decode U32.\n uint32_t Result = 0;\n uint32_t Offset = 0;\n Byte Byte = 0x80;\n while (Byte & 0x80) {\n if (unlikely(Offset >= 32)) {\n Status = ErrCode::Value::IntegerTooLong;\n return Unexpect(Status);\n }\n if (auto Res = testRead(1); unlikely(!Res)) {\n return Unexpect(Res);\n }\n Byte = Data[Pos++];\n Result |= (Byte & UINT32_C(0x7F)) << Offset;\n if (Offset == 28 && unlikely((Byte & UINT32_C(0x70)) != 0)) {\n Status = ErrCode::Value::IntegerTooLarge;\n return Unexpect(Status);\n }\n Offset += 7;\n }\n return Result;\n}\n\n// Decode and read an unsigned long long int. See \"include/loader/filemgr.h\".\nExpect FileMgr::readU64() {\n...\n", "needle_token_start": 737, "needle_token_end": 870, "code_context_ntokens": 16238, "output": ["```cpp\nValVariant packVal(const ValType &Type, const ValVariant &Val) {\n if (Type.isPackType()) {\n switch (Type.getCode()) {\n case TypeCode::I8:\n return ValVariant(Val.get() & 0xFFU);\n case TypeCode::I16:\n return ValVariant(Val.get() & 0xFFFFU);\n default:\n assumingUnreachable();\n }\n }\n return Val;\n}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "HostVersion", "language": "cpp", "path": "lib/loader/ast/section.cpp", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To retrieve the version number of the binary format supported by the host environment.\n2. **Input**: None, as it does not require any parameters.\n3. **Output**: Returns a constant integer representing the version number.\n4. **Procedure**: Simply accesses a predefined constant that holds the version number and returns this value.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " return {};\n });\n });\n}\n\n// Load vector of table section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::TableSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::TableSegment &TabSeg) { return loadSegment(TabSeg); });\n });\n}\n\n// Load vector of memory section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::MemorySection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::MemoryType &MemType) { return loadType(MemType); });\n });\n}\n\n// Load vector of global section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::GlobalSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(Sec, [this](AST::GlobalSegment &GlobSeg) {\n return loadSegment(GlobSeg);\n });\n });\n}\n\n// Load vector of export section. 
See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::ExportSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::ExportDesc &ExpDesc) { return loadDesc(ExpDesc); });\n });\n}\n\n// Load start function index. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::StartSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() -> Expect {\n // Read u32 of start function index.\n if (auto Res = FMgr.readU32()) {\n Sec.setContent(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Sec_Start);\n }\n return {};\n });\n}\n\n// Load vector of element section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::ElementSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(Sec, [this](AST::ElementSegment &ElemSeg) {\n return loadSegment(ElemSeg);\n });\n });\n}\n\n// Load vector of code section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::CodeSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(Sec, [this](AST::CodeSegment &CodeSeg) {\n return loadSegment(CodeSeg);\n });\n });\n}\n\n// Load vector of data section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::DataSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(Sec, [this](AST::DataSegment &DataSeg) {\n return loadSegment(DataSeg);\n });\n });\n}\n\n// Load content of data count section. See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::DataCountSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() -> Expect {\n // Read u32 of data count.\n if (auto Res = FMgr.readU32()) {\n Sec.setContent(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Sec_DataCount);\n }\n return {};\n });\n}\n\nExpect Loader::loadSection(AST::Component::ComponentSection &Sec) {\n auto ResPreamble = Loader::loadPreamble();\n if (!ResPreamble) {\n return Unexpect(ResPreamble);\n }\n auto WasmMagic = ResPreamble->first;\n auto Ver = ResPreamble->second;\n if (unlikely(Ver != ComponentVersion)) {\n return logLoadError(ErrCode::Value::MalformedVersion, FMgr.getLastOffset(),\n ASTNodeAttr::Component);\n }\n auto NestedComp = std::make_shared();\n NestedComp->getMagic() = WasmMagic;\n NestedComp->getVersion() = {Ver[0], Ver[1]};\n NestedComp->getLayer() = {Ver[2], Ver[3]};\n if (auto Res = loadComponent(*NestedComp); !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n Sec.getContent() = NestedComp;\n return {};\n}\n\nExpect Loader::loadSection(AST::CoreModuleSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() -> Expect {\n auto ExpectedSize = Sec.getContentSize();\n auto StartOffset = FMgr.getOffset();\n auto ResPreamble = Loader::loadPreamble();\n if (!ResPreamble) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Module));\n return Unexpect(ResPreamble);\n }\n auto WasmMagic = ResPreamble->first;\n auto Ver = ResPreamble->second;\n if (unlikely(Ver != ModuleVersion)) {\n return logLoadError(ErrCode::Value::MalformedVersion,\n FMgr.getLastOffset(), ASTNodeAttr::Module);\n }\n AST::Module CoreMod;\n CoreMod.getMagic() = WasmMagic;\n CoreMod.getVersion() = Ver;\n\n auto Offset = FMgr.getOffset();\n ExpectedSize -= (Offset - StartOffset);\n\n if (auto Res = loadModuleInBound(CoreMod, ExpectedSize); !Res) {\n 
spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Module));\n return Unexpect(Res);\n }\n Sec.getContent() = CoreMod;\n return {};\n });\n}\n\n// Load vector of component alias section.\n// See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::Component::AliasSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::Alias &Alias) { return loadAlias(Alias); });\n });\n}\n\n// Load vector of component core:instance section.\n// See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::Component::CoreInstanceSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::CoreInstanceExpr &InstanceExpr) {\n return loadCoreInstance(InstanceExpr);\n });\n });\n}\n\n// Load vector of core type section.\n// See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::Component::CoreTypeSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::CoreDefType &Ty) { return loadType(Ty); });\n });\n}\n\n// Load vector of component type section.\n// See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::Component::TypeSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::DefType &Ty) { return loadType(Ty); });\n });\n}\n\nExpect Loader::loadSection(AST::Component::StartSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() -> Expect {\n return loadStart(Sec.getContent());\n });\n}\n\nExpect Loader::loadSection(AST::Component::CanonSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::Canon &C) { return loadCanonical(C); });\n });\n}\n\nExpect Loader::loadSection(AST::Component::ImportSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::Import &C) { return loadImport(C); });\n });\n}\nExpect Loader::loadSection(AST::Component::ExportSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::Export &C) { return loadExport(C); });\n });\n}\n\n// Load vector of component instance section.\n// See \"include/loader/loader.h\".\nExpect Loader::loadSection(AST::Component::InstanceSection &Sec) {\n return loadSectionContent(Sec, [this, &Sec]() {\n return loadSectionContentVec(\n Sec, [this](AST::Component::InstanceExpr &InstanceExpr) {\n return loadInstance(InstanceExpr);\n });\n });\n}\n\nnamespace {\n\n\ninline constexpr uint32_t HostVersion() noexcept {\n return WasmEdge::AOT::kBinaryVersion;\n}\n\ninline constexpr uint8_t HostOSType() noexcept {\n#if WASMEDGE_OS_LINUX\n return UINT8_C(1);\n#elif WASMEDGE_OS_MACOS\n return UINT8_C(2);\n#elif WASMEDGE_OS_WINDOWS\n return UINT8_C(3);\n#else\n // Means WasmEdge is not yet supported on this OS.\n return UINT8_C(-1);\n#endif\n}\n\ninline constexpr uint8_t HostArchType() noexcept {\n#if defined(__x86_64__) || defined(_M_X64)\n return UINT8_C(1);\n#elif defined(__aarch64__)\n return UINT8_C(2);\n#elif defined(__riscv) && __riscv_xlen == 64\n return UINT8_C(3);\n#elif defined(__arm__) && __ARM_ARCH == 7\n return UINT8_C(4);\n#else\n // Means universal wasm binary is not yet supported on this arch.\n return UINT8_C(-1);\n#endif\n}\n\n} // namespace\n\n// If there is any loader error occurs in the loadSection, then fallback\n// to the 
interpreter mode with info level log.\nExpect Loader::loadSection(FileMgr &VecMgr, AST::AOTSection &Sec) {\n if (auto Res = VecMgr.readU32(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT binary version read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n Sec.setVersion(*Res);\n }\n if (unlikely(Sec.getVersion() != HostVersion())) {\n spdlog::info(ErrCode::Value::MalformedSection);\n spdlog::info(\" AOT binary version unmatched.\");\n return Unexpect(ErrCode::Value::MalformedSection);\n }\n\n if (auto Res = VecMgr.readByte(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT os type read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n Sec.setOSType(*Res);\n }\n if (unlikely(Sec.getOSType() != HostOSType())) {\n spdlog::info(ErrCode::Value::MalformedSection);\n spdlog::info(\" AOT OS type unmatched.\");\n return Unexpect(ErrCode::Value::MalformedSection);\n }\n\n if (auto Res = VecMgr.readByte(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT arch type read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n Sec.setArchType(*Res);\n }\n if (unlikely(Sec.getArchType() != HostArchType())) {\n spdlog::info(ErrCode::Value::MalformedSection);\n spdlog::info(\" AOT arch type unmatched.\");\n return Unexpect(ErrCode::Value::MalformedSection);\n }\n\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT version address read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n Sec.setVersionAddress(*Res);\n }\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT intrinsics address read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n Sec.setIntrinsicsAddress(*Res);\n }\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT types size read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n const uint64_t Size = *Res;\n if (Size > VecMgr.getRemainSize()) {\n spdlog::info(ErrCode::Value::IntegerTooLong);\n spdlog::info(\" AOT types size too large\");\n return Unexpect(ErrCode::Value::IntegerTooLong);\n }\n Sec.getTypesAddress().resize(Size);\n }\n for (size_t I = 0; I < Sec.getTypesAddress().size(); ++I) {\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT type address read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n Sec.getTypesAddress()[I] = *Res;\n }\n }\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT code size read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n const uint64_t Size = *Res;\n if (Size > VecMgr.getRemainSize()) {\n spdlog::info(ErrCode::Value::IntegerTooLong);\n spdlog::info(\" AOT code size too large\");\n return Unexpect(ErrCode::Value::IntegerTooLong);\n }\n Sec.getCodesAddress().resize(Size);\n }\n for (size_t I = 0; I < Sec.getCodesAddress().size(); ++I) {\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT code address read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n const uint64_t Address = *Res;\n Sec.getCodesAddress()[I] = Address;\n }\n }\n\n if (auto Res = VecMgr.readU32(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT section count read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n const uint32_t Size = *Res;\n if (Size > 
VecMgr.getRemainSize()) {\n spdlog::info(ErrCode::Value::IntegerTooLong);\n spdlog::info(\" AOT section count too large\");\n return Unexpect(ErrCode::Value::IntegerTooLong);\n }\n Sec.getSections().resize(Size);\n }\n\n for (auto &Section : Sec.getSections()) {\n if (auto Res = VecMgr.readByte(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT section type read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n std::get<0>(Section) = *Res;\n }\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT section offset read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n std::get<1>(Section) = *Res;\n }\n if (auto Res = VecMgr.readU64(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT section size read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n std::get<2>(Section) = *Res;\n }\n uint32_t ContentSize;\n if (auto Res = VecMgr.readU32(); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT section data size read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n ContentSize = *Res;\n if (ContentSize > VecMgr.getRemainSize()) {\n spdlog::info(ErrCode::Value::IntegerTooLong);\n spdlog::info(\" AOT section data size is too large\");\n return Unexpect(ErrCode::Value::IntegerTooLong);\n }\n if (std::get<2>(Section) < ContentSize) {\n spdlog::info(ErrCode::Value::IntegerTooLong);\n spdlog::info(\" AOT section data size is larger then section size\");\n return Unexpect(ErrCode::Value::IntegerTooLong);\n }\n }\n if (auto Res = VecMgr.readBytes(ContentSize); unlikely(!Res)) {\n spdlog::info(Res.error());\n spdlog::info(\" AOT section data read error:{}\", Res.error());\n return Unexpect(Res);\n } else {\n std::get<3>(Section) = std::move(*Res);\n }\n }\n return {};\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/ast/component.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n#include \"spdlog/common.h\"\n#include \"spdlog/spdlog.h\"\n\n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Loader {\n\nExpect, std::vector>> Loader::loadPreamble() {\n // component ::= s*:
* => (component flatten(s*))\n // preamble ::= \n // magic ::= 0x00 0x61 0x73 0x6D\n // version ::= 0x0a 0x00\n // layer ::= 0x01 0x00\n\n // The combination of version and layer is corresponding to the version of\n // core wasm.\n // The core module has same magic but the different version:\n // 0x01 0x00 0x00 0x00\n auto Magic = FMgr.readBytes(4);\n if (!Magic) {\n return logLoadError(Magic.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Component);\n }\n std::vector WasmMagic = {0x00, 0x61, 0x73, 0x6D};\n if (*Magic != WasmMagic) {\n spdlog::error(\"Might an invalid wasm file\");\n return logLoadError(ErrCode::Value::MalformedMagic, FMgr.getLastOffset(),\n ASTNodeAttr::Component);\n }\n auto Ver = FMgr.readBytes(4);\n if (!Ver) {\n return logLoadError(Ver.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Component);\n }\n return std::make_pair(*Magic, *Ver);\n}\n\nExpect,\n std::unique_ptr>>\nLoader::loadUnit() {\n auto ResPreamble = Loader::loadPreamble();\n if (!ResPreamble) {\n return Unexpect(ResPreamble);\n }\n auto WasmMagic = ResPreamble->first;\n auto Ver = ResPreamble->second;\n if (Ver == ModuleVersion) {\n auto Mod = std::make_unique();\n Mod->getMagic() = WasmMagic;\n Mod->getVersion() = Ver;\n if (!Conf.getRuntimeConfigure().isForceInterpreter()) {\n if (auto Res = loadModuleAOT(Mod->getAOTSection()); !Res) {\n return Unexpect(Res);\n }\n }\n // Seek to the position after the binary header.\n FMgr.seek(8);\n if (auto Res = loadModule(*Mod); !Res) {\n return Unexpect(Res);\n }\n\n // Load library from AOT Section for the universal WASM case.\n // For the force interpreter mode, skip this.\n if (!Conf.getRuntimeConfigure().isForceInterpreter() &&\n WASMType == InputType::UniversalWASM) {\n if (auto Res = loadUniversalWASM(*Mod); !Res) {\n return Unexpect(Res);\n }\n }\n return Mod;\n } else if (Ver == ComponentVersion) {\n if (!Conf.hasProposal(Proposal::Component)) {\n return logNeedProposal(ErrCode::Value::IllegalOpCode, Proposal::Component,\n FMgr.getLastOffset(), ASTNodeAttr::Component);\n }\n spdlog::warn(\"component model is an experimental proposal\");\n auto Comp = std::make_unique();\n Comp->getMagic() = WasmMagic;\n Comp->getVersion() = {Ver[0], Ver[1]};\n Comp->getLayer() = {Ver[2], Ver[3]};\n if (auto Res = loadComponent(*Comp); !Res) {\n return Unexpect(Res);\n }\n return Comp;\n } else {\n return logLoadError(ErrCode::Value::MalformedVersion, FMgr.getLastOffset(),\n ASTNodeAttr::Component);\n }\n}\n\nExpect Loader::loadComponent(AST::Component::Component &Comp) {\n using namespace AST::Component;\n\n while (auto ResSecId = FMgr.readByte()) {\n if (!ResSecId) {\n return logLoadError(ResSecId.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Component);\n }\n // keep going only if we have new section ID\n uint8_t NewSectionId = *ResSecId;\n\n switch (NewSectionId) {\n case 0x00:\n Comp.getSections().emplace_back();\n if (auto Res = loadSection(\n Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n case 0x01:\n Comp.getSections().emplace_back();\n if (auto Res = loadSection(\n Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n case 0x02: {\n Comp.getSections().emplace_back();\n if (auto Res = loadSection(\n Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x03: {\n 
Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x04:\n Comp.getSections().emplace_back();\n if (auto Res = loadSection(\n Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n case 0x05: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x06: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x07: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x08: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x09: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x0A: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n case 0x0B: {\n Comp.getSections().emplace_back();\n if (auto Res =\n loadSection(Comp.getSections().back().emplace());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Component));\n return Unexpect(Res);\n }\n break;\n }\n default:\n return logLoadError(ErrCode::Value::MalformedSection,\n FMgr.getLastOffset(), ASTNodeAttr::Component);\n }\n }\n\n return {};\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/ast/expression.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Loader {\n\n// Load to construct Expression node. See \"include/loader/loader.h\".\nExpect Loader::loadExpression(AST::Expression &Expr,\n std::optional SizeBound) {\n if (auto Res = loadInstrSeq(SizeBound)) {\n // For the section size mismatch case, check in caller.\n Expr.getInstrs() = std::move(*Res);\n } else {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Expression));\n return Unexpect(Res);\n }\n return {};\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/ast/description.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n\nnamespace WasmEdge {\nnamespace Loader {\n\n// Load binary of Import description. 
See \"include/loader/loader.h\".\nExpect Loader::loadDesc(AST::ImportDesc &ImpDesc) {\n // Read the module name.\n if (auto Res = FMgr.readName()) {\n ImpDesc.setModuleName(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Import);\n }\n\n // Read the external name.\n if (auto Res = FMgr.readName()) {\n ImpDesc.setExternalName(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Import);\n }\n\n // Read the external type.\n if (auto Res = FMgr.readByte()) {\n ImpDesc.setExternalType(static_cast(*Res));\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Import);\n }\n\n // Make content node according to external type.\n switch (ImpDesc.getExternalType()) {\n case ExternalType::Function: {\n // Read the function type index.\n if (auto Res = FMgr.readU32()) {\n ImpDesc.setExternalFuncTypeIdx(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Import);\n }\n break;\n }\n case ExternalType::Table: {\n // Read the table type node.\n return loadType(ImpDesc.getExternalTableType());\n }\n case ExternalType::Memory: {\n // Read the memory type node.\n return loadType(ImpDesc.getExternalMemoryType());\n }\n case ExternalType::Global: {\n // Read the global type node.\n if (auto Res = loadType(ImpDesc.getExternalGlobalType()); !Res) {\n return Unexpect(Res.error());\n }\n // Import the mutable globals are for ImportExportMutGlobals proposal.\n if (ImpDesc.getExternalGlobalType().getValMut() == ValMut::Var &&\n unlikely(!Conf.hasProposal(Proposal::ImportExportMutGlobals))) {\n return logNeedProposal(ErrCode::Value::InvalidMut,\n Proposal::ImportExportMutGlobals,\n FMgr.getLastOffset(), ASTNodeAttr::Desc_Import);\n }\n return {};\n }\n default:\n return logLoadError(ErrCode::Value::MalformedImportKind,\n FMgr.getLastOffset(), ASTNodeAttr::Desc_Import);\n }\n return {};\n}\n\n// Load binary of Export description. See \"include/loader/loader.h\".\nExpect Loader::loadDesc(AST::ExportDesc &ExpDesc) {\n // Read external name to export.\n if (auto Res = FMgr.readName()) {\n ExpDesc.setExternalName(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Export);\n }\n\n // Read external type.\n if (auto Res = FMgr.readByte()) {\n ExpDesc.setExternalType(static_cast(*Res));\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Export);\n }\n switch (ExpDesc.getExternalType()) {\n case ExternalType::Function:\n case ExternalType::Table:\n case ExternalType::Memory:\n case ExternalType::Global:\n break;\n default:\n return logLoadError(ErrCode::Value::MalformedExportKind,\n FMgr.getLastOffset(), ASTNodeAttr::Desc_Export);\n }\n\n // Read external index to export.\n if (auto Res = FMgr.readU32()) {\n ExpDesc.setExternalIndex(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Desc_Export);\n }\n return {};\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/ast/type.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Loader {\n\n// Load binary and decode HeapType. 
See \"include/loader/loader.h\".\nExpect Loader::loadHeapType(TypeCode TC, ASTNodeAttr From) {\n if (auto Res = FMgr.readS33()) {\n if (*Res < 0) {\n // FuncRef or ExternRef case.\n TypeCode HTCode =\n static_cast(static_cast((*Res) & INT64_C(0x7F)));\n switch (HTCode) {\n case TypeCode::ExternRef:\n // For the ref.func instruction, the immediate changed to store the heap\n // type directly instead of the reference type after applying the\n // typed function reference proposal. Therefore the reference-types\n // proposal should be checked here.\n if (!Conf.hasProposal(Proposal::ReferenceTypes)) {\n return logNeedProposal(ErrCode::Value::MalformedElemType,\n Proposal::ReferenceTypes, FMgr.getLastOffset(),\n From);\n }\n [[fallthrough]];\n case TypeCode::FuncRef:\n return ValType(TC, HTCode);\n case TypeCode::NullFuncRef:\n case TypeCode::NullExternRef:\n case TypeCode::NullRef:\n case TypeCode::AnyRef:\n case TypeCode::EqRef:\n case TypeCode::I31Ref:\n case TypeCode::StructRef:\n case TypeCode::ArrayRef:\n if (!Conf.hasProposal(Proposal::GC)) {\n return logNeedProposal(ErrCode::Value::MalformedRefType, Proposal::GC,\n FMgr.getLastOffset(), From);\n }\n return ValType(TC, HTCode);\n default:\n return logLoadError(ErrCode::Value::MalformedRefType,\n FMgr.getLastOffset(), From);\n }\n } else {\n // Type index case. Legal if the function reference proposal is enabled.\n if (!Conf.hasProposal(Proposal::FunctionReferences)) {\n return logNeedProposal(ErrCode::Value::MalformedRefType,\n Proposal::FunctionReferences,\n FMgr.getLastOffset(), From);\n }\n return ValType(TC, static_cast(*Res));\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(), From);\n }\n}\n\n// Load binary and decode RefType. See \"include/loader/loader.h\".\nExpect Loader::loadRefType(ASTNodeAttr From) {\n if (auto Res = FMgr.readByte()) {\n // The error code is different when the reference-types proposal turned off.\n ErrCode::Value FailCode = Conf.hasProposal(Proposal::ReferenceTypes)\n ? ErrCode::Value::MalformedRefType\n : ErrCode::Value::MalformedElemType;\n TypeCode Code = static_cast(*Res);\n switch (Code) {\n case TypeCode::ExternRef:\n if (!Conf.hasProposal(Proposal::ReferenceTypes)) {\n return logNeedProposal(FailCode, Proposal::ReferenceTypes,\n FMgr.getLastOffset(), From);\n }\n [[fallthrough]];\n case TypeCode::FuncRef:\n // The FuncRef (0x70) is always allowed in the RefType even if the\n // reference-types proposal not enabled.\n return ValType(Code);\n case TypeCode::NullFuncRef:\n case TypeCode::NullExternRef:\n case TypeCode::NullRef:\n case TypeCode::AnyRef:\n case TypeCode::EqRef:\n case TypeCode::I31Ref:\n case TypeCode::StructRef:\n case TypeCode::ArrayRef:\n if (!Conf.hasProposal(Proposal::GC)) {\n return logNeedProposal(FailCode, Proposal::GC, FMgr.getLastOffset(),\n From);\n }\n return ValType(Code);\n case TypeCode::Ref:\n case TypeCode::RefNull: {\n if (!Conf.hasProposal(Proposal::FunctionReferences)) {\n return logNeedProposal(FailCode, Proposal::FunctionReferences,\n FMgr.getLastOffset(), From);\n }\n return loadHeapType(Code, From);\n }\n default:\n return logLoadError(FailCode, FMgr.getLastOffset(), From);\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(), From);\n }\n}\n\n// Load binary and decode ValType. 
See \"include/loader/loader.h\".\nExpect Loader::loadValType(ASTNodeAttr From, bool IsStorageType) {\n if (auto Res = FMgr.readByte()) {\n TypeCode Code = static_cast(*Res);\n switch (Code) {\n case TypeCode::V128:\n if (!Conf.hasProposal(Proposal::SIMD)) {\n return logNeedProposal(ErrCode::Value::MalformedValType, Proposal::SIMD,\n FMgr.getLastOffset(), From);\n }\n [[fallthrough]];\n case TypeCode::I32:\n case TypeCode::I64:\n case TypeCode::F32:\n case TypeCode::F64:\n return ValType(Code);\n case TypeCode::I8:\n case TypeCode::I16:\n if (!IsStorageType) {\n break;\n }\n if (!Conf.hasProposal(Proposal::GC)) {\n return logNeedProposal(ErrCode::Value::MalformedValType, Proposal::GC,\n FMgr.getLastOffset(), From);\n }\n return ValType(Code);\n case TypeCode::FuncRef:\n if (!Conf.hasProposal(Proposal::ReferenceTypes) &&\n !Conf.hasProposal(Proposal::BulkMemoryOperations)) {\n return logNeedProposal(ErrCode::Value::MalformedElemType,\n Proposal::ReferenceTypes, FMgr.getLastOffset(),\n From);\n }\n return ValType(Code);\n case TypeCode::ExternRef:\n if (!Conf.hasProposal(Proposal::ReferenceTypes)) {\n return logNeedProposal(ErrCode::Value::MalformedElemType,\n Proposal::ReferenceTypes, FMgr.getLastOffset(),\n From);\n }\n return ValType(Code);\n case TypeCode::NullFuncRef:\n case TypeCode::NullExternRef:\n case TypeCode::NullRef:\n case TypeCode::AnyRef:\n case TypeCode::EqRef:\n case TypeCode::I31Ref:\n case TypeCode::StructRef:\n case TypeCode::ArrayRef:\n if (!Conf.hasProposal(Proposal::GC)) {\n return logNeedProposal(ErrCode::Value::MalformedValType, Proposal::GC,\n FMgr.getLastOffset(), From);\n }\n return ValType(Code);\n case TypeCode::Ref:\n case TypeCode::RefNull:\n if (!Conf.hasProposal(Proposal::FunctionReferences)) {\n return logNeedProposal(ErrCode::Value::MalformedValType,\n Proposal::FunctionReferences,\n FMgr.getLastOffset(), From);\n }\n return loadHeapType(Code, From);\n default:\n break;\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(), From);\n }\n return logLoadError(ErrCode::Value::MalformedValType, FMgr.getLastOffset(),\n From);\n}\n\nExpect Loader::loadMutability(ASTNodeAttr From) {\n if (auto Res = FMgr.readByte()) {\n switch (static_cast(*Res)) {\n case ValMut::Const:\n case ValMut::Var:\n return static_cast(*Res);\n default:\n return logLoadError(ErrCode::Value::InvalidMut, FMgr.getLastOffset(),\n From);\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(), From);\n }\n}\n\nExpect Loader::loadFieldType(AST::FieldType &FType) {\n if (auto Res = loadValType(ASTNodeAttr::Type_Rec, true)) {\n FType.setStorageType(*Res);\n } else {\n // The error code logging is handled.\n return Unexpect(Res);\n }\n if (auto Res = loadMutability(ASTNodeAttr::Type_Rec)) {\n FType.setValMut(*Res);\n } else {\n // The error code logging is handled.\n return Unexpect(Res);\n }\n return {};\n}\n\nExpect Loader::loadCompositeType(AST::CompositeType &CType) {\n if (auto CodeByte = FMgr.readByte()) {\n switch (static_cast(*CodeByte)) {\n case TypeCode::Array: {\n AST::FieldType FType;\n if (auto Res = loadFieldType(FType); unlikely(!Res)) {\n return Unexpect(Res);\n }\n CType.setArrayType(std::move(FType));\n return {};\n }\n case TypeCode::Struct: {\n std::vector FList;\n if (auto Res = loadVec(\n FList,\n [this](AST::FieldType &FType) -> Expect {\n // The error code logging is handled.\n return loadFieldType(FType);\n });\n !Res) {\n return Unexpect(Res);\n }\n CType.setStructType(std::move(FList));\n return {};\n }\n case TypeCode::Func: {\n 
AST::FunctionType FuncType;\n if (auto Res = loadType(FuncType); unlikely(!Res)) {\n return Unexpect(Res);\n }\n CType.setFunctionType(std::move(FuncType));\n return {};\n }\n default:\n return logLoadError(ErrCode::Value::IntegerTooLong, FMgr.getLastOffset(),\n ASTNodeAttr::Type_Rec);\n }\n } else {\n return logLoadError(CodeByte.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Type_Rec);\n }\n}\n\n// Load binary to construct Limit node. See \"include/loader/loader.h\".\nExpect Loader::loadLimit(AST::Limit &Lim) {\n // Read limit.\n if (auto Res = FMgr.readByte()) {\n switch (static_cast(*Res)) {\n case AST::Limit::LimitType::HasMin:\n Lim.setType(AST::Limit::LimitType::HasMin);\n break;\n case AST::Limit::LimitType::HasMinMax:\n Lim.setType(AST::Limit::LimitType::HasMinMax);\n break;\n case AST::Limit::LimitType::SharedNoMax:\n if (Conf.hasProposal(Proposal::Threads)) {\n return logLoadError(ErrCode::Value::SharedMemoryNoMax,\n FMgr.getLastOffset(), ASTNodeAttr::Type_Limit);\n } else {\n return logLoadError(ErrCode::Value::IntegerTooLarge,\n FMgr.getLastOffset(), ASTNodeAttr::Type_Limit);\n }\n case AST::Limit::LimitType::Shared:\n Lim.setType(AST::Limit::LimitType::Shared);\n break;\n default:\n if (*Res == 0x80 || *Res == 0x81) {\n // LEB128 cases will fail.\n return logLoadError(ErrCode::Value::IntegerTooLong,\n FMgr.getLastOffset(), ASTNodeAttr::Type_Limit);\n } else {\n return logLoadError(ErrCode::Value::IntegerTooLarge,\n FMgr.getLastOffset(), ASTNodeAttr::Type_Limit);\n }\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Type_Limit);\n }\n\n // Read min and max number.\n if (auto Res = FMgr.readU32()) {\n Lim.setMin(*Res);\n Lim.setMax(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Type_Limit);\n }\n if (Lim.hasMax()) {\n if (auto Res = FMgr.readU32()) {\n Lim.setMax(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Type_Limit);\n }\n }\n return {};\n}\n\n// Load binary to construct SubType node. See \"include/loader/loader.h\".\nExpect Loader::loadType(AST::SubType &SType) {\n if (auto CodeByte = FMgr.peekByte()) {\n switch (static_cast(*CodeByte)) {\n default:\n // Case: comptype.\n SType.setFinal(true);\n return loadCompositeType(SType.getCompositeType());\n case TypeCode::Sub:\n // Case: 0x50 vec(typeidx) comptype.\n SType.setFinal(false);\n break;\n case TypeCode::SubFinal:\n // Case: 0x4F vec(typeidx) comptype.\n SType.setFinal(true);\n break;\n }\n FMgr.readByte();\n if (auto Res = loadVec(\n SType.getSuperTypeIndices(),\n [this](uint32_t &Idx) -> Expect {\n if (auto Num = FMgr.readU32()) {\n Idx = *Num;\n } else {\n return logLoadError(Num.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Type_Sub);\n }\n return {};\n });\n !Res) {\n return Unexpect(Res);\n }\n return loadCompositeType(SType.getCompositeType());\n } else {\n return logLoadError(CodeByte.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Type_Rec);\n }\n}\n\n// Load binary to construct FunctionType node. See \"include/loader/loader.h\".\nExpect Loader::loadType(AST::FunctionType &FuncType) {\n // Read type of Func (0x60). 
Moved into the composite type.\n auto LoadValType = [this](ValType &VT) -> Expect {\n if (auto Res = loadValType(ASTNodeAttr::Type_Function)) {\n VT = *Res;\n } else {\n // The error code logging is handled.\n return Unexpect(Res);\n }\n return {};\n };\n // Read vector of parameter types.\n if (auto Res =\n loadVec(FuncType.getParamTypes(), LoadValType);\n !Res) {\n return Unexpect(Res);\n }\n\n // Read vector of result types.\n if (auto Res =\n loadVec(FuncType.getReturnTypes(), LoadValType);\n !Res) {\n return Unexpect(Res);\n }\n\n if (unlikely(!Conf.hasProposal(Proposal::MultiValue)) &&\n FuncType.getReturnTypes().size() > 1) {\n return logNeedProposal(ErrCode::Value::MalformedValType,\n Proposal::MultiValue, FMgr.getLastOffset(),\n ASTNodeAttr::Type_Function);\n }\n return {};\n}\n\n// Load binary to construct MemoryType node. See \"include/loader/loader.h\".\nExpect Loader::loadType(AST::MemoryType &MemType) {\n // Read limit.\n if (auto Res = loadLimit(MemType.getLimit()); !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Type_Memory));\n return Unexpect(Res);\n }\n return {};\n}\n\n// Load binary to construct TableType node. See \"include/loader/loader.h\".\nExpect Loader::loadType(AST::TableType &TabType) {\n // Read reference type.\n if (auto Res = loadRefType(ASTNodeAttr::Type_Table)) {\n TabType.setRefType(*Res);\n } else {\n // The AST node information is handled.\n return Unexpect(Res);\n }\n\n // Read limit.\n if (auto Res = loadLimit(TabType.getLimit()); !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Type_Table));\n return Unexpect(Res);\n }\n return {};\n}\n\n// Load binary to construct GlobalType node. See \"include/loader/loader.h\".\nExpect Loader::loadType(AST::GlobalType &GlobType) {\n // Read value type.\n if (auto Res = loadValType(ASTNodeAttr::Type_Global)) {\n GlobType.setValType(*Res);\n } else {\n // The AST node information is handled.\n return Unexpect(Res);\n }\n\n // Read mutability.\n if (auto Res = loadMutability(ASTNodeAttr::Type_Global)) {\n GlobType.setValMut(*Res);\n } else {\n // The AST node information is handled.\n return Unexpect(Res);\n }\n return {};\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/ast/segment.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n\n#include \n#include \n\nnamespace WasmEdge {\nnamespace Loader {\n\n// Load binary of TableSegment node. 
See \"include/loader/loader.h\".\nExpect Loader::loadSegment(AST::TableSegment &TabSeg) {\n // Check the first byte is the reftype in table type or not.\n if (auto CheckByte = FMgr.peekByte()) {\n if (*CheckByte == 0x40U) {\n // Table segment case is for FunctionReferences proposal.\n if (!Conf.hasProposal(Proposal::FunctionReferences)) {\n return logNeedProposal(ErrCode::Value::MalformedTable,\n Proposal::FunctionReferences,\n FMgr.getLastOffset(), ASTNodeAttr::Seg_Table);\n }\n FMgr.readByte();\n\n // Check the second byte.\n if (auto Res = FMgr.readByte()) {\n if (*Res != 0x00U) {\n return logLoadError(ErrCode::Value::MalformedTable,\n FMgr.getLastOffset(), ASTNodeAttr::Seg_Table);\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Table);\n }\n\n // Read the table type.\n if (auto Res = loadType(TabSeg.getTableType()); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Table));\n return Unexpect(Res);\n }\n\n // Read the expression.\n if (auto Res = loadExpression(TabSeg.getExpr()); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Global));\n return Unexpect(Res);\n }\n } else {\n // The table type case.\n if (auto Res = loadType(TabSeg.getTableType()); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Table));\n return Unexpect(Res);\n }\n }\n } else {\n return logLoadError(CheckByte.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Table);\n }\n\n return {};\n}\n\n// Load binary of GlobalSegment node. See \"include/loader/loader.h\".\nExpect Loader::loadSegment(AST::GlobalSegment &GlobSeg) {\n // Read global type node.\n if (auto Res = loadType(GlobSeg.getGlobalType()); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Global));\n return Unexpect(Res);\n }\n\n // Read the expression.\n if (auto Res = loadExpression(GlobSeg.getExpr()); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Global));\n return Unexpect(Res);\n }\n\n return {};\n}\n\n// Load binary of ElementSegment node. 
See \"include/loader/loader.h\".\nExpect Loader::loadSegment(AST::ElementSegment &ElemSeg) {\n // Element segment binary format:\n // ---------------------------------------------------------------------------\n // Mode | TableIdx | OffExpr | ElemKind | RefType | vec(FuncIdx) | vec(expr)\n // ------|----------|---------|----------|---------|--------------|-----------\n // 0 | | v | | | v |\n // 1 | | | v | | v |\n // 2 | v | v | v | | v |\n // 3 | | | v | | v |\n // 4 | | v | | | | v\n // 5 | | | | v | | v\n // 6 | v | v | | v | | v\n // 7 | | | | v | | v\n // ---------------------------------------------------------------------------\n // Mode: element initial integer, u32\n // TableIdx: target table index, u32\n // OffExpr: init offset expression, expr\n // ElemKind: byte 0x00, RefType::FuncRef\n // RefType: reference type, RefType\n // vec(FuncIdx): function index vector, vec(u32)\n // vec(expr): reference init list, vec(expr)\n\n // Read the checking byte.\n uint32_t Check;\n if (auto Res = FMgr.readU32()) {\n Check = *Res;\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n // Check > 0 cases are for BulkMemoryOperations or ReferenceTypes proposal.\n if (Check > 0 && !Conf.hasProposal(Proposal::BulkMemoryOperations) &&\n !Conf.hasProposal(Proposal::ReferenceTypes)) {\n return logNeedProposal(ErrCode::Value::ExpectedZeroByte,\n Proposal::BulkMemoryOperations, FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n\n // Check the prefix byte.\n switch (Check) {\n case 0x00:\n case 0x02:\n case 0x04:\n case 0x06:\n ElemSeg.setMode(AST::ElementSegment::ElemMode::Active);\n break;\n\n case 0x01:\n case 0x05:\n ElemSeg.setMode(AST::ElementSegment::ElemMode::Passive);\n break;\n\n case 0x03:\n case 0x07:\n ElemSeg.setMode(AST::ElementSegment::ElemMode::Declarative);\n break;\n\n default:\n // TODO: Correctness the error code once there's spec test.\n return logLoadError(ErrCode::Value::IllegalGrammar, FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n\n // Read the table index.\n ElemSeg.setIdx(0);\n switch (Check) {\n case 0x02:\n case 0x06:\n if (auto Res = FMgr.readU32()) {\n ElemSeg.setIdx(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n break;\n\n default:\n break;\n }\n\n // Read the expression.\n switch (Check) {\n case 0x00:\n case 0x02:\n case 0x04:\n case 0x06:\n if (auto Res = loadExpression(ElemSeg.getExpr()); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Element));\n return Unexpect(Res);\n }\n break;\n\n default:\n break;\n }\n\n // Read element kind and init function indices.\n ElemSeg.setRefType(TypeCode::FuncRef);\n switch (Check) {\n case 0x01:\n case 0x02:\n case 0x03:\n if (auto Res = FMgr.readByte()) {\n if (*Res != 0x00U) {\n return logLoadError(ErrCode::Value::ExpectedZeroByte,\n FMgr.getLastOffset(), ASTNodeAttr::Seg_Element);\n }\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n [[fallthrough]];\n\n case 0x00: {\n uint32_t VecCnt = 0;\n if (auto Res = FMgr.readU32()) {\n VecCnt = *Res;\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n for (uint32_t I = 0; I < VecCnt; ++I) {\n // For each element in vec(funcidx), make expr(ref.func idx end).\n ElemSeg.getInitExprs().emplace_back();\n AST::Instruction RefFunc(OpCode::Ref__func);\n AST::Instruction End(OpCode::End);\n if (auto Res = loadInstruction(RefFunc); unlikely(!Res)) {\n 
spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Element));\n return Unexpect(Res);\n }\n ElemSeg.getInitExprs().back().getInstrs().emplace_back(\n std::move(RefFunc));\n ElemSeg.getInitExprs().back().getInstrs().emplace_back(std::move(End));\n }\n break;\n }\n default:\n break;\n }\n\n // Read the reference type and init expressions.\n switch (Check) {\n case 0x05:\n case 0x06:\n case 0x07:\n if (auto Res = loadRefType(ASTNodeAttr::Seg_Element)) {\n ElemSeg.setRefType(*Res);\n } else {\n // The AST node information is handled.\n return Unexpect(Res);\n }\n [[fallthrough]];\n case 0x04: {\n return loadVec(\n ElemSeg.getInitExprs(), [this](AST::Expression &Expr) -> Expect {\n return loadExpression(Expr);\n });\n }\n\n default:\n break;\n }\n\n return {};\n}\n\n// Load binary of CodeSegment node. See \"include/loader/loader.h\".\nExpect Loader::loadSegment(AST::CodeSegment &CodeSeg) {\n // Read the code segment size.\n if (auto Res = FMgr.readU32()) {\n CodeSeg.setSegSize(*Res);\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Code);\n }\n auto ExprSizeBound = FMgr.getOffset() + CodeSeg.getSegSize();\n\n // Read the vector of local variable counts and types.\n uint32_t VecCnt = 0;\n if (auto Res = loadVecCnt()) {\n VecCnt = *Res;\n } else {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Element);\n }\n CodeSeg.getLocals().clear();\n CodeSeg.getLocals().reserve(VecCnt);\n uint32_t TotalLocalCnt = 0;\n for (uint32_t I = 0; I < VecCnt; ++I) {\n uint32_t LocalCnt = 0;\n ValType LocalType;\n if (auto Res = FMgr.readU32(); unlikely(!Res)) {\n return logLoadError(Res.error(), FMgr.getLastOffset(),\n ASTNodeAttr::Seg_Code);\n...\n", "needle_token_start": 2416, "needle_token_end": 2446, "code_context_ntokens": 16264, "output": ["```cpp\ninline constexpr uint32_t HostVersion() noexcept {\n return WasmEdge::AOT::kBinaryVersion;\n}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "toLLVMConstantZero", "language": "cpp", "path": "lib/llvm/compiler.cpp", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The function generates a zero-initialized constant value corresponding to a specified data type in the LLVM context.\n2. **Input**: It takes a context for LLVM operations and a value type that specifies the data type for which a zero value is needed.\n3. **Output**: It returns an LLVM value that represents a zero-initialized constant of the specified data type.\n4. **Procedure**: The function checks the code of the input value type and, based on this code, it creates and returns a zero-initialized constant using the appropriate LLVM data type. 
For integer types, vector types, and floating-point types, it uses specific LLVM functions to create a null (zero) value of the corresponding type.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " LLVM::Type Int64PtrTy;\n LLVM::Type Int128PtrTy;\n LLVM::Type ExecCtxTy;\n LLVM::Type ExecCtxPtrTy;\n LLVM::Type IntrinsicsTableTy;\n LLVM::Type IntrinsicsTablePtrTy;\n LLVM::Message SubtargetFeatures;\n\n#if defined(__x86_64__)\n#if defined(__XOP__)\n bool SupportXOP = true;\n#else\n bool SupportXOP = false;\n#endif\n\n#if defined(__SSE4_1__)\n bool SupportSSE4_1 = true;\n#else\n bool SupportSSE4_1 = false;\n#endif\n\n#if defined(__SSSE3__)\n bool SupportSSSE3 = true;\n#else\n bool SupportSSSE3 = false;\n#endif\n\n#if defined(__SSE2__)\n bool SupportSSE2 = true;\n#else\n bool SupportSSE2 = false;\n#endif\n#endif\n\n#if defined(__aarch64__)\n#if defined(__ARM_NEON__) || defined(__ARM_NEON) || defined(__ARM_NEON_FP)\n bool SupportNEON = true;\n#else\n bool SupportNEON = false;\n#endif\n#endif\n\n std::vector FunctionTypes;\n std::vector FunctionWrappers;\n std::vector>\n Functions;\n std::vector Globals;\n LLVM::Value IntrinsicsTable;\n LLVM::FunctionCallee Trap;\n CompileContext(LLVM::Context C, LLVM::Module &M,\n bool IsGenericBinary) noexcept\n : LLContext(C), LLModule(M),\n Cold(LLVM::Attribute::createEnum(C, LLVM::Core::Cold, 0)),\n NoAlias(LLVM::Attribute::createEnum(C, LLVM::Core::NoAlias, 0)),\n NoInline(LLVM::Attribute::createEnum(C, LLVM::Core::NoInline, 0)),\n NoReturn(LLVM::Attribute::createEnum(C, LLVM::Core::NoReturn, 0)),\n ReadOnly(LLVM::Attribute::createEnum(C, LLVM::Core::ReadOnly, 0)),\n StrictFP(LLVM::Attribute::createEnum(C, LLVM::Core::StrictFP, 0)),\n UWTable(LLVM::Attribute::createEnum(C, LLVM::Core::UWTable,\n LLVM::Core::UWTableDefault)),\n NoStackArgProbe(\n LLVM::Attribute::createString(C, \"no-stack-arg-probe\"sv, {})),\n VoidTy(LLContext.getVoidTy()), Int8Ty(LLContext.getInt8Ty()),\n Int16Ty(LLContext.getInt16Ty()), Int32Ty(LLContext.getInt32Ty()),\n Int64Ty(LLContext.getInt64Ty()), Int128Ty(LLContext.getInt128Ty()),\n FloatTy(LLContext.getFloatTy()), DoubleTy(LLContext.getDoubleTy()),\n Int8x16Ty(LLVM::Type::getVectorType(Int8Ty, 16)),\n Int16x8Ty(LLVM::Type::getVectorType(Int16Ty, 8)),\n Int32x4Ty(LLVM::Type::getVectorType(Int32Ty, 4)),\n Floatx4Ty(LLVM::Type::getVectorType(FloatTy, 4)),\n Int64x2Ty(LLVM::Type::getVectorType(Int64Ty, 2)),\n Doublex2Ty(LLVM::Type::getVectorType(DoubleTy, 2)),\n Int128x1Ty(LLVM::Type::getVectorType(Int128Ty, 1)),\n Int8PtrTy(Int8Ty.getPointerTo()), Int32PtrTy(Int32Ty.getPointerTo()),\n Int64PtrTy(Int64Ty.getPointerTo()),\n Int128PtrTy(Int128Ty.getPointerTo()),\n ExecCtxTy(LLVM::Type::getStructType(\n \"ExecCtx\",\n std::initializer_list{\n // Memory\n Int8PtrTy.getPointerTo(),\n // Globals\n Int128PtrTy.getPointerTo(),\n // InstrCount\n Int64PtrTy,\n // CostTable\n LLVM::Type::getArrayType(Int64Ty, UINT16_MAX + 1)\n .getPointerTo(),\n // Gas\n Int64PtrTy,\n // GasLimit\n Int64Ty,\n // StopToken\n Int32PtrTy,\n })),\n ExecCtxPtrTy(ExecCtxTy.getPointerTo()),\n IntrinsicsTableTy(LLVM::Type::getArrayType(\n Int8PtrTy,\n static_cast(Executable::Intrinsics::kIntrinsicMax))),\n IntrinsicsTablePtrTy(IntrinsicsTableTy.getPointerTo()),\n IntrinsicsTable(LLModule.addGlobal(IntrinsicsTablePtrTy, true,\n LLVMExternalLinkage, 
LLVM::Value(),\n \"intrinsics\")) {\n Trap.Ty = LLVM::Type::getFunctionType(VoidTy, {Int32Ty});\n Trap.Fn = LLModule.addFunction(Trap.Ty, LLVMPrivateLinkage, \"trap\");\n Trap.Fn.setDSOLocal(true);\n Trap.Fn.addFnAttr(NoStackArgProbe);\n Trap.Fn.addFnAttr(StrictFP);\n Trap.Fn.addFnAttr(UWTable);\n Trap.Fn.addFnAttr(NoReturn);\n Trap.Fn.addFnAttr(Cold);\n Trap.Fn.addFnAttr(NoInline);\n\n LLModule.addGlobal(Int32Ty, true, LLVMExternalLinkage,\n LLVM::Value::getConstInt(Int32Ty, AOT::kBinaryVersion),\n \"version\");\n\n if (!IsGenericBinary) {\n SubtargetFeatures = LLVM::getHostCPUFeatures();\n auto Features = SubtargetFeatures.string_view();\n while (!Features.empty()) {\n std::string_view Feature;\n if (auto Pos = Features.find(','); Pos != std::string_view::npos) {\n Feature = Features.substr(0, Pos);\n Features = Features.substr(Pos + 1);\n } else {\n Feature = std::exchange(Features, std::string_view());\n }\n if (Feature[0] != '+') {\n continue;\n }\n Feature = Feature.substr(1);\n\n#if defined(__x86_64__)\n if (!SupportXOP && Feature == \"xop\"sv) {\n SupportXOP = true;\n }\n if (!SupportSSE4_1 && Feature == \"sse4.1\"sv) {\n SupportSSE4_1 = true;\n }\n if (!SupportSSSE3 && Feature == \"ssse3\"sv) {\n SupportSSSE3 = true;\n }\n if (!SupportSSE2 && Feature == \"sse2\"sv) {\n SupportSSE2 = true;\n }\n#elif defined(__aarch64__)\n if (!SupportNEON && Feature == \"neon\"sv) {\n SupportNEON = true;\n }\n#endif\n }\n }\n\n {\n // create trap\n LLVM::Builder Builder(LLContext);\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, Trap.Fn, \"entry\"));\n auto FnTy = LLVM::Type::getFunctionType(VoidTy, {Int32Ty});\n auto CallTrap = Builder.createCall(\n getIntrinsic(Builder, Executable::Intrinsics::kTrap, FnTy),\n {Trap.Fn.getFirstParam()});\n CallTrap.addCallSiteAttribute(NoReturn);\n Builder.createUnreachable();\n }\n }\n LLVM::Value getMemory(LLVM::Builder &Builder, LLVM::Value ExecCtx,\n uint32_t Index) noexcept {\n auto Array = Builder.createExtractValue(ExecCtx, 0);\n auto VPtr = Builder.createLoad(\n Int8PtrTy, Builder.createInBoundsGEP1(Int8PtrTy, Array,\n LLContext.getInt64(Index)));\n VPtr.setMetadata(LLContext, LLVM::Core::InvariantGroup,\n LLVM::Metadata(LLContext, {}));\n return Builder.createBitCast(VPtr, Int8PtrTy);\n }\n std::pair getGlobal(LLVM::Builder &Builder,\n LLVM::Value ExecCtx,\n uint32_t Index) noexcept {\n auto Ty = Globals[Index];\n auto Array = Builder.createExtractValue(ExecCtx, 1);\n auto VPtr = Builder.createLoad(\n Int128PtrTy, Builder.createInBoundsGEP1(Int8PtrTy, Array,\n LLContext.getInt64(Index)));\n VPtr.setMetadata(LLContext, LLVM::Core::InvariantGroup,\n LLVM::Metadata(LLContext, {}));\n auto Ptr = Builder.createBitCast(VPtr, Ty.getPointerTo());\n return {Ty, Ptr};\n }\n LLVM::Value getInstrCount(LLVM::Builder &Builder,\n LLVM::Value ExecCtx) noexcept {\n return Builder.createExtractValue(ExecCtx, 2);\n }\n LLVM::Value getCostTable(LLVM::Builder &Builder,\n LLVM::Value ExecCtx) noexcept {\n return Builder.createExtractValue(ExecCtx, 3);\n }\n LLVM::Value getGas(LLVM::Builder &Builder, LLVM::Value ExecCtx) noexcept {\n return Builder.createExtractValue(ExecCtx, 4);\n }\n LLVM::Value getGasLimit(LLVM::Builder &Builder,\n LLVM::Value ExecCtx) noexcept {\n return Builder.createExtractValue(ExecCtx, 5);\n }\n LLVM::Value getStopToken(LLVM::Builder &Builder,\n LLVM::Value ExecCtx) noexcept {\n return Builder.createExtractValue(ExecCtx, 6);\n }\n LLVM::FunctionCallee getIntrinsic(LLVM::Builder &Builder,\n Executable::Intrinsics Index,\n LLVM::Type Ty) 
noexcept {\n const auto Value = static_cast(Index);\n auto PtrTy = Ty.getPointerTo();\n auto PtrPtrTy = PtrTy.getPointerTo();\n auto IT = Builder.createLoad(IntrinsicsTablePtrTy, IntrinsicsTable);\n IT.setMetadata(LLContext, LLVM::Core::InvariantGroup,\n LLVM::Metadata(LLContext, {}));\n auto VPtr =\n Builder.createInBoundsGEP2(IntrinsicsTableTy, IT, LLContext.getInt64(0),\n LLContext.getInt64(Value));\n auto Ptr = Builder.createBitCast(VPtr, PtrPtrTy);\n return {Ty, Builder.createLoad(PtrTy, Ptr)};\n }\n std::pair, std::vector>\n resolveBlockType(const BlockType &BType) const noexcept {\n using VecT = std::vector;\n using RetT = std::pair;\n if (BType.isEmpty()) {\n return RetT{};\n }\n if (BType.isValType()) {\n return RetT{{}, {BType.getValType()}};\n } else {\n // Type index case. t2* = type[index].returns\n const uint32_t TypeIdx = BType.getTypeIndex();\n const auto &FType = *FunctionTypes[TypeIdx];\n return RetT{\n VecT(FType.getParamTypes().begin(), FType.getParamTypes().end()),\n VecT(FType.getReturnTypes().begin(), FType.getReturnTypes().end())};\n }\n }\n};\n\nnamespace {\n\nusing namespace WasmEdge;\n\nstatic bool isVoidReturn(Span ValTypes) noexcept {\n return ValTypes.empty();\n}\n\nstatic LLVM::Type toLLVMType(LLVM::Context LLContext,\n const ValType &ValType) noexcept {\n switch (ValType.getCode()) {\n case TypeCode::I32:\n return LLContext.getInt32Ty();\n case TypeCode::I64:\n return LLContext.getInt64Ty();\n case TypeCode::Ref:\n case TypeCode::RefNull:\n case TypeCode::V128:\n return LLVM::Type::getVectorType(LLContext.getInt64Ty(), 2);\n case TypeCode::F32:\n return LLContext.getFloatTy();\n case TypeCode::F64:\n return LLContext.getDoubleTy();\n default:\n assumingUnreachable();\n }\n}\n\nstatic std::vector\ntoLLVMTypeVector(LLVM::Context LLContext,\n Span ValTypes) noexcept {\n std::vector Result;\n Result.reserve(ValTypes.size());\n for (const auto &Type : ValTypes) {\n Result.push_back(toLLVMType(LLContext, Type));\n }\n return Result;\n}\n\nstatic std::vector\ntoLLVMArgsType(LLVM::Context LLContext, LLVM::Type ExecCtxPtrTy,\n Span ValTypes) noexcept {\n auto Result = toLLVMTypeVector(LLContext, ValTypes);\n Result.insert(Result.begin(), ExecCtxPtrTy);\n return Result;\n}\n\nstatic LLVM::Type toLLVMRetsType(LLVM::Context LLContext,\n Span ValTypes) noexcept {\n if (isVoidReturn(ValTypes)) {\n return LLContext.getVoidTy();\n }\n if (ValTypes.size() == 1) {\n return toLLVMType(LLContext, ValTypes.front());\n }\n std::vector Result;\n Result.reserve(ValTypes.size());\n for (const auto &Type : ValTypes) {\n Result.push_back(toLLVMType(LLContext, Type));\n }\n return LLVM::Type::getStructType(Result);\n}\n\nstatic LLVM::Type toLLVMType(LLVM::Context LLContext, LLVM::Type ExecCtxPtrTy,\n const AST::FunctionType &FuncType) noexcept {\n auto ArgsTy =\n toLLVMArgsType(LLContext, ExecCtxPtrTy, FuncType.getParamTypes());\n auto RetTy = toLLVMRetsType(LLContext, FuncType.getReturnTypes());\n return LLVM::Type::getFunctionType(RetTy, ArgsTy);\n}\n\n\nstatic LLVM::Value toLLVMConstantZero(LLVM::Context LLContext,\n const ValType &ValType) noexcept {\n switch (ValType.getCode()) {\n case TypeCode::I32:\n return LLVM::Value::getConstNull(LLContext.getInt32Ty());\n case TypeCode::I64:\n return LLVM::Value::getConstNull(LLContext.getInt64Ty());\n case TypeCode::Ref:\n case TypeCode::RefNull:\n case TypeCode::V128:\n return LLVM::Value::getConstNull(\n LLVM::Type::getVectorType(LLContext.getInt64Ty(), 2));\n case TypeCode::F32:\n return 
LLVM::Value::getConstNull(LLContext.getFloatTy());\n case TypeCode::F64:\n return LLVM::Value::getConstNull(LLContext.getDoubleTy());\n default:\n assumingUnreachable();\n }\n}\n\nclass FunctionCompiler {\n struct Control;\n\npublic:\n FunctionCompiler(LLVM::Compiler::CompileContext &Context,\n LLVM::FunctionCallee F, Span Locals,\n bool Interruptible, bool InstructionCounting,\n bool GasMeasuring) noexcept\n : Context(Context), LLContext(Context.LLContext),\n Interruptible(Interruptible), F(F), Builder(LLContext) {\n if (F.Fn) {\n Builder.positionAtEnd(LLVM::BasicBlock::create(LLContext, F.Fn, \"entry\"));\n ExecCtx = Builder.createLoad(Context.ExecCtxTy, F.Fn.getFirstParam());\n\n if (InstructionCounting) {\n LocalInstrCount = Builder.createAlloca(Context.Int64Ty);\n Builder.createStore(LLContext.getInt64(0), LocalInstrCount);\n }\n\n if (GasMeasuring) {\n LocalGas = Builder.createAlloca(Context.Int64Ty);\n Builder.createStore(LLContext.getInt64(0), LocalGas);\n }\n\n for (LLVM::Value Arg = F.Fn.getFirstParam().getNextParam(); Arg;\n Arg = Arg.getNextParam()) {\n LLVM::Type Ty = Arg.getType();\n LLVM::Value ArgPtr = Builder.createAlloca(Ty);\n Builder.createStore(Arg, ArgPtr);\n Local.emplace_back(Ty, ArgPtr);\n }\n\n for (const auto &Type : Locals) {\n LLVM::Type Ty = toLLVMType(LLContext, Type);\n LLVM::Value ArgPtr = Builder.createAlloca(Ty);\n Builder.createStore(toLLVMConstantZero(LLContext, Type), ArgPtr);\n Local.emplace_back(Ty, ArgPtr);\n }\n }\n }\n\n LLVM::BasicBlock getTrapBB(ErrCode::Value Error) noexcept {\n if (auto Iter = TrapBB.find(Error); Iter != TrapBB.end()) {\n return Iter->second;\n }\n auto BB = LLVM::BasicBlock::create(LLContext, F.Fn, \"trap\");\n TrapBB.emplace(Error, BB);\n return BB;\n }\n\n void\n compile(const AST::CodeSegment &Code,\n std::pair, std::vector> Type) noexcept {\n auto RetBB = LLVM::BasicBlock::create(LLContext, F.Fn, \"ret\");\n Type.first.clear();\n enterBlock(RetBB, {}, {}, {}, std::move(Type));\n compile(Code.getExpr().getInstrs());\n assuming(ControlStack.empty());\n compileReturn();\n\n for (auto &[Error, BB] : TrapBB) {\n Builder.positionAtEnd(BB);\n updateInstrCount();\n updateGasAtTrap();\n auto CallTrap = Builder.createCall(\n Context.Trap, {LLContext.getInt32(static_cast(Error))});\n CallTrap.addCallSiteAttribute(Context.NoReturn);\n Builder.createUnreachable();\n }\n }\n\n void compile(AST::InstrView Instrs) noexcept {\n auto Dispatch = [this](const AST::Instruction &Instr) -> void {\n switch (Instr.getOpCode()) {\n case OpCode::Block: {\n auto Block = LLVM::BasicBlock::create(LLContext, F.Fn, \"block\");\n auto EndBlock = LLVM::BasicBlock::create(LLContext, F.Fn, \"block.end\");\n Builder.createBr(Block);\n\n Builder.positionAtEnd(Block);\n auto Type = Context.resolveBlockType(Instr.getBlockType());\n const auto Arity = Type.first.size();\n std::vector Args(Arity);\n if (isUnreachable()) {\n for (size_t I = 0; I < Arity; ++I) {\n auto Ty = toLLVMType(LLContext, Type.first[I]);\n Args[I] = LLVM::Value::getUndef(Ty);\n }\n } else {\n for (size_t I = 0; I < Arity; ++I) {\n const size_t J = Arity - 1 - I;\n Args[J] = stackPop();\n }\n }\n enterBlock(EndBlock, {}, {}, std::move(Args), std::move(Type));\n checkStop();\n updateGas();\n return;\n }\n case OpCode::Loop: {\n auto Curr = Builder.getInsertBlock();\n auto Loop = LLVM::BasicBlock::create(LLContext, F.Fn, \"loop\");\n auto EndLoop = LLVM::BasicBlock::create(LLContext, F.Fn, \"loop.end\");\n Builder.createBr(Loop);\n\n Builder.positionAtEnd(Loop);\n auto Type = 
Context.resolveBlockType(Instr.getBlockType());\n const auto Arity = Type.first.size();\n std::vector Args(Arity);\n if (isUnreachable()) {\n for (size_t I = 0; I < Arity; ++I) {\n auto Ty = toLLVMType(LLContext, Type.first[I]);\n auto Value = LLVM::Value::getUndef(Ty);\n auto PHINode = Builder.createPHI(Ty);\n PHINode.addIncoming(Value, Curr);\n Args[I] = PHINode;\n }\n } else {\n for (size_t I = 0; I < Arity; ++I) {\n const size_t J = Arity - 1 - I;\n auto Value = stackPop();\n auto PHINode = Builder.createPHI(Value.getType());\n PHINode.addIncoming(Value, Curr);\n Args[J] = PHINode;\n }\n }\n enterBlock(Loop, EndLoop, {}, std::move(Args), std::move(Type));\n checkStop();\n updateGas();\n return;\n }\n case OpCode::If: {\n auto Then = LLVM::BasicBlock::create(LLContext, F.Fn, \"then\");\n auto Else = LLVM::BasicBlock::create(LLContext, F.Fn, \"else\");\n auto EndIf = LLVM::BasicBlock::create(LLContext, F.Fn, \"if.end\");\n LLVM::Value Cond;\n if (isUnreachable()) {\n Cond = LLVM::Value::getUndef(LLContext.getInt1Ty());\n } else {\n Cond = Builder.createICmpNE(stackPop(), LLContext.getInt32(0));\n }\n Builder.createCondBr(Cond, Then, Else);\n\n Builder.positionAtEnd(Then);\n auto Type = Context.resolveBlockType(Instr.getBlockType());\n const auto Arity = Type.first.size();\n std::vector Args(Arity);\n if (isUnreachable()) {\n for (size_t I = 0; I < Arity; ++I) {\n auto Ty = toLLVMType(LLContext, Type.first[I]);\n Args[I] = LLVM::Value::getUndef(Ty);\n }\n } else {\n for (size_t I = 0; I < Arity; ++I) {\n const size_t J = Arity - 1 - I;\n Args[J] = stackPop();\n }\n }\n enterBlock(EndIf, {}, Else, std::move(Args), std::move(Type));\n return;\n }\n case OpCode::End: {\n auto Entry = leaveBlock();\n if (Entry.ElseBlock) {\n auto Block = Builder.getInsertBlock();\n Builder.positionAtEnd(Entry.ElseBlock);\n enterBlock(Block, {}, {}, std::move(Entry.Args),\n std::move(Entry.Type), std::move(Entry.ReturnPHI));\n Entry = leaveBlock();\n }\n buildPHI(Entry.Type.second, Entry.ReturnPHI);\n return;\n }\n case OpCode::Else: {\n auto Entry = leaveBlock();\n Builder.positionAtEnd(Entry.ElseBlock);\n enterBlock(Entry.JumpBlock, {}, {}, std::move(Entry.Args),\n std::move(Entry.Type), std::move(Entry.ReturnPHI));\n return;\n }\n default:\n break;\n }\n\n if (isUnreachable()) {\n return;\n }\n\n switch (Instr.getOpCode()) {\n case OpCode::Unreachable:\n Builder.createBr(getTrapBB(ErrCode::Value::Unreachable));\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"unreachable.end\"));\n break;\n case OpCode::Nop:\n break;\n case OpCode::Return:\n compileReturn();\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"ret.end\"));\n break;\n case OpCode::Br: {\n const auto Label = Instr.getJump().TargetIndex;\n setLableJumpPHI(Label);\n Builder.createBr(getLabel(Label));\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"br.end\"));\n break;\n }\n case OpCode::Br_if: {\n const auto Label = Instr.getJump().TargetIndex;\n auto Cond = Builder.createICmpNE(stackPop(), LLContext.getInt32(0));\n setLableJumpPHI(Label);\n auto Next = LLVM::BasicBlock::create(LLContext, F.Fn, \"br_if.end\");\n Builder.createCondBr(Cond, getLabel(Label), Next);\n Builder.positionAtEnd(Next);\n break;\n }\n case OpCode::Br_table: {\n auto LabelTable = Instr.getLabelList();\n assuming(LabelTable.size() <= std::numeric_limits::max());\n const auto LabelTableSize =\n static_cast(LabelTable.size() - 1);\n auto Value = 
stackPop();\n setLableJumpPHI(LabelTable[LabelTableSize].TargetIndex);\n auto Switch = Builder.createSwitch(\n Value, getLabel(LabelTable[LabelTableSize].TargetIndex),\n LabelTableSize);\n for (uint32_t I = 0; I < LabelTableSize; ++I) {\n setLableJumpPHI(LabelTable[I].TargetIndex);\n Switch.addCase(LLContext.getInt32(I),\n getLabel(LabelTable[I].TargetIndex));\n }\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"br_table.end\"));\n break;\n }\n case OpCode::Br_on_null: {\n const auto Label = Instr.getJump().TargetIndex;\n auto Value = Builder.createBitCast(stackPop(), Context.Int64x2Ty);\n auto Cond = Builder.createICmpEQ(\n Builder.createExtractElement(Value, LLContext.getInt64(1)),\n LLContext.getInt64(0));\n setLableJumpPHI(Label);\n auto Next = LLVM::BasicBlock::create(LLContext, F.Fn, \"br_on_null.end\");\n Builder.createCondBr(Cond, getLabel(Label), Next);\n Builder.positionAtEnd(Next);\n stackPush(Value);\n break;\n }\n case OpCode::Br_on_non_null: {\n const auto Label = Instr.getJump().TargetIndex;\n auto Cond = Builder.createICmpNE(\n Builder.createExtractElement(\n Builder.createBitCast(Stack.back(), Context.Int64x2Ty),\n LLContext.getInt64(1)),\n LLContext.getInt64(0));\n setLableJumpPHI(Label);\n auto Next =\n LLVM::BasicBlock::create(LLContext, F.Fn, \"br_on_non_null.end\");\n Builder.createCondBr(Cond, getLabel(Label), Next);\n Builder.positionAtEnd(Next);\n stackPop();\n break;\n }\n case OpCode::Call:\n updateInstrCount();\n updateGas();\n compileCallOp(Instr.getTargetIndex());\n break;\n case OpCode::Call_indirect:\n updateInstrCount();\n updateGas();\n compileIndirectCallOp(Instr.getSourceIndex(), Instr.getTargetIndex());\n break;\n case OpCode::Return_call:\n updateInstrCount();\n updateGas();\n compileReturnCallOp(Instr.getTargetIndex());\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"ret_call.end\"));\n break;\n case OpCode::Return_call_indirect:\n updateInstrCount();\n updateGas();\n compileReturnIndirectCallOp(Instr.getSourceIndex(),\n Instr.getTargetIndex());\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"ret_call_indir.end\"));\n break;\n case OpCode::Call_ref:\n updateInstrCount();\n updateGas();\n compileCallRefOp(Instr.getTargetIndex());\n break;\n case OpCode::Return_call_ref:\n updateInstrCount();\n updateGas();\n compileReturnCallRefOp(Instr.getTargetIndex());\n setUnreachable();\n Builder.positionAtEnd(\n LLVM::BasicBlock::create(LLContext, F.Fn, \"ret_call_ref.end\"));\n break;\n case OpCode::Ref__null: {\n std::array Val = {0};\n // For null references, the dynamic type down scaling is needed.\n ValType VType;\n if (Instr.getValType().isAbsHeapType()) {\n switch (Instr.getValType().getHeapTypeCode()) {\n case TypeCode::NullFuncRef:\n case TypeCode::FuncRef:\n VType = TypeCode::NullFuncRef;\n break;\n case TypeCode::NullExternRef:\n case TypeCode::ExternRef:\n VType = TypeCode::NullExternRef;\n break;\n case TypeCode::NullRef:\n case TypeCode::AnyRef:\n case TypeCode::EqRef:\n case TypeCode::I31Ref:\n case TypeCode::StructRef:\n case TypeCode::ArrayRef:\n VType = TypeCode::NullRef;\n break;\n default:\n assumingUnreachable();\n }\n } else {\n // TODO: GC - AOT: support other composite here.\n VType = TypeCode::NullFuncRef;\n }\n std::copy_n(VType.getRawData().cbegin(), 8, Val.begin());\n auto Vector = LLVM::Value::getConstVector8(LLContext, Val);\n stackPush(Builder.createBitCast(Vector, Context.Int64x2Ty));\n break;\n }\n case 
OpCode::Ref__is_null:\n stackPush(Builder.createZExt(\n Builder.createICmpEQ(\n Builder.createExtractElement(\n Builder.createBitCast(stackPop(), Context.Int64x2Ty),\n LLContext.getInt64(1)),\n LLContext.getInt64(0)),\n Context.Int32Ty));\n break;\n case OpCode::Ref__func:\n stackPush(Builder.createCall(\n Context.getIntrinsic(Builder, Executable::Intrinsics::kRefFunc,\n LLVM::Type::getFunctionType(Context.Int64x2Ty,\n {Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex())}));\n break;\n case OpCode::Ref__as_non_null: {\n auto Next =\n LLVM::BasicBlock::create(LLContext, F.Fn, \"ref_as_non_null.ok\");\n Stack.back() = Builder.createBitCast(Stack.back(), Context.Int64x2Ty);\n auto IsNotNull = Builder.createLikely(Builder.createICmpNE(\n Builder.createExtractElement(Stack.back(), LLContext.getInt64(1)),\n LLContext.getInt64(0)));\n Builder.createCondBr(IsNotNull, Next,\n getTrapBB(ErrCode::Value::CastNullToNonNull));\n Builder.positionAtEnd(Next);\n break;\n }\n case OpCode::Drop:\n stackPop();\n break;\n case OpCode::Select:\n case OpCode::Select_t: {\n auto Cond = Builder.createICmpNE(stackPop(), LLContext.getInt32(0));\n auto False = stackPop();\n auto True = stackPop();\n stackPush(Builder.createSelect(Cond, True, False));\n break;\n }\n case OpCode::Local__get: {\n const auto &L = Local[Instr.getTargetIndex()];\n stackPush(Builder.createLoad(L.first, L.second));\n break;\n }\n case OpCode::Local__set:\n Builder.createStore(stackPop(), Local[Instr.getTargetIndex()].second);\n break;\n case OpCode::Local__tee:\n Builder.createStore(Stack.back(), Local[Instr.getTargetIndex()].second);\n break;\n case OpCode::Global__get: {\n const auto G =\n Context.getGlobal(Builder, ExecCtx, Instr.getTargetIndex());\n stackPush(Builder.createLoad(G.first, G.second));\n break;\n }\n case OpCode::Global__set:\n Builder.createStore(\n stackPop(),\n Context.getGlobal(Builder, ExecCtx, Instr.getTargetIndex()).second);\n break;\n case OpCode::Table__get: {\n auto Idx = stackPop();\n stackPush(Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kTableGet,\n LLVM::Type::getFunctionType(Context.Int64x2Ty,\n {Context.Int32Ty, Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()), Idx}));\n break;\n }\n case OpCode::Table__set: {\n auto Ref = stackPop();\n auto Idx = stackPop();\n Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kTableSet,\n LLVM::Type::getFunctionType(\n Context.Int64Ty,\n {Context.Int32Ty, Context.Int32Ty, Context.Int64x2Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()), Idx, Ref});\n break;\n }\n case OpCode::Table__init: {\n auto Len = stackPop();\n auto Src = stackPop();\n auto Dst = stackPop();\n Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kTableInit,\n LLVM::Type::getFunctionType(Context.VoidTy,\n {Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()),\n LLContext.getInt32(Instr.getSourceIndex()), Dst, Src, Len});\n break;\n }\n case OpCode::Elem__drop: {\n Builder.createCall(\n Context.getIntrinsic(Builder, Executable::Intrinsics::kElemDrop,\n LLVM::Type::getFunctionType(\n Context.VoidTy, {Context.Int32Ty}, false)),\n {LLContext.getInt32(Instr.getTargetIndex())});\n break;\n }\n case OpCode::Table__copy: {\n auto Len = stackPop();\n auto Src = stackPop();\n auto Dst = stackPop();\n Builder.createCall(\n Context.getIntrinsic(\n Builder, 
Executable::Intrinsics::kTableCopy,\n LLVM::Type::getFunctionType(Context.VoidTy,\n {Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()),\n LLContext.getInt32(Instr.getSourceIndex()), Dst, Src, Len});\n break;\n }\n case OpCode::Table__grow: {\n auto NewSize = stackPop();\n auto Val = stackPop();\n stackPush(Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kTableGrow,\n LLVM::Type::getFunctionType(\n Context.Int32Ty,\n {Context.Int32Ty, Context.Int64x2Ty, Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()), Val, NewSize}));\n break;\n }\n case OpCode::Table__size: {\n stackPush(Builder.createCall(\n Context.getIntrinsic(Builder, Executable::Intrinsics::kTableSize,\n LLVM::Type::getFunctionType(Context.Int32Ty,\n {Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex())}));\n break;\n }\n case OpCode::Table__fill: {\n auto Len = stackPop();\n auto Val = stackPop();\n auto Off = stackPop();\n Builder.createCall(\n Context.getIntrinsic(Builder, Executable::Intrinsics::kTableFill,\n LLVM::Type::getFunctionType(\n Context.Int32Ty,\n {Context.Int32Ty, Context.Int32Ty,\n Context.Int64x2Ty, Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()), Off, Val, Len});\n break;\n }\n case OpCode::I32__load:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int32Ty);\n break;\n case OpCode::I64__load:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int64Ty);\n break;\n case OpCode::F32__load:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.FloatTy);\n break;\n case OpCode::F64__load:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.DoubleTy);\n break;\n case OpCode::I32__load8_s:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int8Ty, Context.Int32Ty,\n true);\n break;\n case OpCode::I32__load8_u:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int8Ty, Context.Int32Ty,\n false);\n break;\n case OpCode::I32__load16_s:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int16Ty, Context.Int32Ty,\n true);\n break;\n case OpCode::I32__load16_u:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int16Ty, Context.Int32Ty,\n false);\n break;\n case OpCode::I64__load8_s:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int8Ty, Context.Int64Ty,\n true);\n break;\n case OpCode::I64__load8_u:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int8Ty, Context.Int64Ty,\n false);\n break;\n case OpCode::I64__load16_s:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int16Ty, Context.Int64Ty,\n true);\n break;\n case OpCode::I64__load16_u:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int16Ty, Context.Int64Ty,\n false);\n break;\n case OpCode::I64__load32_s:\n compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int32Ty, Context.Int64Ty,\n true);\n break;\n case OpCode::I64__load32_u:\n 
compileLoadOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int32Ty, Context.Int64Ty,\n false);\n break;\n\n case OpCode::I32__store:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int32Ty);\n break;\n case OpCode::I64__store:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int64Ty);\n break;\n case OpCode::F32__store:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.FloatTy);\n break;\n case OpCode::F64__store:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.DoubleTy);\n break;\n case OpCode::I32__store8:\n case OpCode::I64__store8:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int8Ty, true);\n break;\n case OpCode::I32__store16:\n case OpCode::I64__store16:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int16Ty, true);\n break;\n case OpCode::I64__store32:\n compileStoreOp(Instr.getTargetIndex(), Instr.getMemoryOffset(),\n Instr.getMemoryAlign(), Context.Int32Ty, true);\n break;\n case OpCode::Memory__size:\n stackPush(Builder.createCall(\n Context.getIntrinsic(Builder, Executable::Intrinsics::kMemSize,\n LLVM::Type::getFunctionType(Context.Int32Ty,\n {Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex())}));\n break;\n case OpCode::Memory__grow: {\n auto Diff = stackPop();\n stackPush(Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kMemGrow,\n LLVM::Type::getFunctionType(Context.Int32Ty,\n {Context.Int32Ty, Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()), Diff}));\n break;\n }\n case OpCode::Memory__init: {\n auto Len = stackPop();\n auto Src = stackPop();\n auto Dst = stackPop();\n Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kMemInit,\n LLVM::Type::getFunctionType(Context.VoidTy,\n {Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()),\n LLContext.getInt32(Instr.getSourceIndex()), Dst, Src, Len});\n break;\n }\n case OpCode::Data__drop: {\n Builder.createCall(\n Context.getIntrinsic(Builder, Executable::Intrinsics::kDataDrop,\n LLVM::Type::getFunctionType(\n Context.VoidTy, {Context.Int32Ty}, false)),\n {LLContext.getInt32(Instr.getTargetIndex())});\n break;\n }\n case OpCode::Memory__copy: {\n auto Len = stackPop();\n auto Src = stackPop();\n auto Dst = stackPop();\n Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kMemCopy,\n LLVM::Type::getFunctionType(Context.VoidTy,\n {Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty, Context.Int32Ty,\n Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()),\n LLContext.getInt32(Instr.getSourceIndex()), Dst, Src, Len});\n break;\n }\n case OpCode::Memory__fill: {\n auto Len = stackPop();\n auto Val = Builder.createTrunc(stackPop(), Context.Int8Ty);\n auto Off = stackPop();\n Builder.createCall(\n Context.getIntrinsic(\n Builder, Executable::Intrinsics::kMemFill,\n LLVM::Type::getFunctionType(Context.VoidTy,\n {Context.Int32Ty, Context.Int32Ty,\n Context.Int8Ty, Context.Int32Ty},\n false)),\n {LLContext.getInt32(Instr.getTargetIndex()), Off, Val, Len});\n break;\n }\n case OpCode::I32__const:\n 
stackPush(LLContext.getInt32(Instr.getNum().get()));\n break;\n case OpCode::I64__const:\n stackPush(LLContext.getInt64(Instr.getNum().get()));\n break;\n case OpCode::F32__const:\n stackPush(LLContext.getFloat(Instr.getNum().get()));\n break;\n case OpCode::F64__const:\n stackPush(LLContext.getDouble(Instr.getNum().get()));\n break;\n case OpCode::I32__eqz:\n stackPush(Builder.createZExt(\n Builder.createICmpEQ(stackPop(), LLContext.getInt32(0)),\n Context.Int32Ty));\n break;\n case OpCode::I64__eqz:\n stackPush(Builder.createZExt(\n Builder.createICmpEQ(stackPop(), LLContext.getInt64(0)),\n Context.Int32Ty));\n break;\n case OpCode::I32__clz:\n assuming(LLVM::Core::Ctlz != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createIntrinsic(LLVM::Core::Ctlz, {Context.Int32Ty},\n {stackPop(), LLContext.getFalse()}));\n break;\n case OpCode::I64__clz:\n assuming(LLVM::Core::Ctlz != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createIntrinsic(LLVM::Core::Ctlz, {Context.Int64Ty},\n {stackPop(), LLContext.getFalse()}));\n break;\n case OpCode::I32__ctz:\n assuming(LLVM::Core::Cttz != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createIntrinsic(LLVM::Core::Cttz, {Context.Int32Ty},\n {stackPop(), LLContext.getFalse()}));\n break;\n case OpCode::I64__ctz:\n assuming(LLVM::Core::Cttz != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createIntrinsic(LLVM::Core::Cttz, {Context.Int64Ty},\n {stackPop(), LLContext.getFalse()}));\n break;\n case OpCode::I32__popcnt:\n case OpCode::I64__popcnt:\n assuming(LLVM::Core::Ctpop != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Ctpop, stackPop()));\n break;\n case OpCode::F32__abs:\n case OpCode::F64__abs:\n assuming(LLVM::Core::Fabs != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Fabs, stackPop()));\n break;\n case OpCode::F32__neg:\n case OpCode::F64__neg:\n stackPush(Builder.createFNeg(stackPop()));\n break;\n case OpCode::F32__ceil:\n case OpCode::F64__ceil:\n assuming(LLVM::Core::Ceil != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Ceil, stackPop()));\n break;\n case OpCode::F32__floor:\n case OpCode::F64__floor:\n assuming(LLVM::Core::Floor != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Floor, stackPop()));\n break;\n case OpCode::F32__trunc:\n case OpCode::F64__trunc:\n assuming(LLVM::Core::Trunc != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Trunc, stackPop()));\n break;\n case OpCode::F32__nearest:\n case OpCode::F64__nearest: {\n const bool IsFloat = Instr.getOpCode() == OpCode::F32__nearest;\n LLVM::Value Value = stackPop();\n\n#if LLVM_VERSION_MAJOR >= 12\n assuming(LLVM::Core::Roundeven != LLVM::Core::NotIntrinsic);\n if (LLVM::Core::Roundeven != LLVM::Core::NotIntrinsic) {\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Roundeven, Value));\n break;\n }\n#endif\n\n // The VectorSize is only used when SSE4_1 or NEON is supported.\n [[maybe_unused]] const uint32_t VectorSize = IsFloat ? 4 : 2;\n#if defined(__x86_64__)\n if (Context.SupportSSE4_1) {\n auto Zero = LLContext.getInt64(0);\n auto VectorTy =\n LLVM::Type::getVectorType(Value.getType(), VectorSize);\n LLVM::Value Ret = LLVM::Value::getUndef(VectorTy);\n Ret = Builder.createInsertElement(Ret, Value, Zero);\n auto ID = IsFloat ? 
LLVM::Core::X86SSE41RoundSs\n : LLVM::Core::X86SSE41RoundSd;\n assuming(ID != LLVM::Core::NotIntrinsic);\n Ret = Builder.createIntrinsic(ID, {},\n {Ret, Ret, LLContext.getInt32(8)});\n Ret = Builder.createExtractElement(Ret, Zero);\n stackPush(Ret);\n break;\n }\n#endif\n\n#if defined(__aarch64__)\n if (Context.SupportNEON &&\n LLVM::Core::AArch64NeonFRIntN != LLVM::Core::NotIntrinsic) {\n auto Zero = LLContext.getInt64(0);\n auto VectorTy =\n LLVM::Type::getVectorType(Value.getType(), VectorSize);\n LLVM::Value Ret = LLVM::Value::getUndef(VectorTy);\n Ret = Builder.createInsertElement(Ret, Value, Zero);\n Ret =\n Builder.createUnaryIntrinsic(LLVM::Core::AArch64NeonFRIntN, Ret);\n Ret = Builder.createExtractElement(Ret, Zero);\n stackPush(Ret);\n break;\n }\n#endif\n\n // Fallback case.\n // If the SSE4.1 is not supported on the x86_64 platform or\n // the NEON is not supported on the aarch64 platform,\n // then fallback to this.\n assuming(LLVM::Core::Nearbyint != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Nearbyint, Value));\n break;\n }\n case OpCode::F32__sqrt:\n case OpCode::F64__sqrt:\n assuming(LLVM::Core::Sqrt != LLVM::Core::NotIntrinsic);\n stackPush(Builder.createUnaryIntrinsic(LLVM::Core::Sqrt, stackPop()));\n break;\n case OpCode::I32__wrap_i64:\n stackPush(Builder.createTrunc(stackPop(), Context.Int32Ty));\n break;\n case OpCode::I32__trunc_f32_s:\n compileSignedTrunc(Context.Int32Ty);\n break;\n case OpCode::I32__trunc_f64_s:\n compileSignedTrunc(Context.Int32Ty);\n break;\n case OpCode::I32__trunc_f32_u:\n compileUnsignedTrunc(Context.Int32Ty);\n break;\n case OpCode::I32__trunc_f64_u:\n compileUnsignedTrunc(Context.Int32Ty);\n break;\n case OpCode::I64__extend_i32_s:\n stackPush(Builder.createSExt(stackPop(), Context.Int64Ty));\n break;\n case OpCode::I64__extend_i32_u:\n stackPush(Builder.createZExt(stackPop(), Context.Int64Ty));\n break;\n case OpCode::I64__trunc_f32_s:\n compileSignedTrunc(Context.Int64Ty);\n break;\n case OpCode::I64__trunc_f64_s:\n compileSignedTrunc(Context.Int64Ty);\n break;\n case OpCode::I64__trunc_f32_u:\n compileUnsignedTrunc(Context.Int64Ty);\n break;\n case OpCode::I64__trunc_f64_u:\n compileUnsignedTrunc(Context.Int64Ty);\n break;\n case OpCode::F32__convert_i32_s:\n case OpCode::F32__convert_i64_s:\n stackPush(Builder.createSIToFP(stackPop(), Context.FloatTy));\n break;\n case OpCode::F32__convert_i32_u:\n case OpCode::F32__convert_i64_u:\n stackPush(Builder.createUIToFP(stackPop(), Context.FloatTy));\n break;\n case OpCode::F64__convert_i32_s:\n case OpCode::F64__convert_i64_s:\n stackPush(Builder.createSIToFP(stackPop(), Context.DoubleTy));\n break;\n case OpCode::F64__convert_i32_u:\n case OpCode::F64__convert_i64_u:\n stackPush(Builder.createUIToFP(stackPop(), Context.DoubleTy));\n break;\n case OpCode::F32__demote_f64:\n stackPush(Builder.createFPTrunc(stackPop(), Context.FloatTy));\n break;\n case OpCode::F64__promote_f32:\n stackPush(Builder.createFPExt(stackPop(), Context.DoubleTy));\n break;\n case OpCode::I32__reinterpret_f32:\n stackPush(Builder.createBitCast(stackPop(), Context.Int32Ty));\n break;\n case OpCode::I64__reinterpret_f64:\n stackPush(Builder.createBitCast(stackPop(), Context.Int64Ty));\n break;\n case OpCode::F32__reinterpret_i32:\n stackPush(Builder.createBitCast(stackPop(), Context.FloatTy));\n break;\n case OpCode::F64__reinterpret_i64:\n stackPush(Builder.createBitCast(stackPop(), Context.DoubleTy));\n break;\n case OpCode::I32__extend8_s:\n stackPush(Builder.createSExt(\n 
Builder.createTrunc(stackPop(), Context.Int8Ty), Context.Int32Ty));\n break;\n case OpCode::I32__extend16_s:\n stackPush(Builder.createSExt(\n Builder.createTrunc(stackPop(), Context.Int16Ty), Context.Int32Ty));\n break;\n case OpCode::I64__extend8_s:\n stackPush(Builder.createSExt(\n Builder.createTrunc(stackPop(), Context.Int8Ty), Context.Int64Ty));\n break;\n case OpCode::I64__extend16_s:\n stackPush(Builder.createSExt(\n Builder.createTrunc(stackPop(), Context.Int16Ty), Context.Int64Ty));\n break;\n case OpCode::I64__extend32_s:\n stackPush(Builder.createSExt(\n Builder.createTrunc(stackPop(), Context.Int32Ty), Context.Int64Ty));\n break;\n case OpCode::I32__trunc_sat_f32_s:\n compileSignedTruncSat(Context.Int32Ty);\n break;\n case OpCode::I32__trunc_sat_f32_u:\n compileUnsignedTruncSat(Context.Int32Ty);\n break;\n case OpCode::I32__trunc_sat_f64_s:\n compileSignedTruncSat(Context.Int32Ty);\n break;\n case OpCode::I32__trunc_sat_f64_u:\n compileUnsignedTruncSat(Context.Int32Ty);\n break;\n case OpCode::I64__trunc_sat_f32_s:\n compileSignedTruncSat(Context.Int64Ty);\n break;\n case OpCode::I64__trunc_sat_f32_u:\n compileUnsignedTruncSat(Context.Int64Ty);\n break;\n case OpCode::I64__trunc_sat_f64_s:\n compileSignedTruncSat(Context.Int64Ty);\n break;\n case OpCode::I64__trunc_sat_f64_u:\n compileUnsignedTruncSat(Context.Int64Ty);\n break;\n case OpCode::I32__eq:\n case OpCode::I64__eq: {\n LLVM::Value RHS = stackPop();\n LLVM::Value LHS = stackPop();\n stackPush(Builder.createZExt(Builder.createICmpEQ(LHS, RHS),\n Context.Int32Ty));\n break;\n }\n case OpCode::I32__ne:\n case OpCode::I64__ne: {\n LLVM::Value RHS = stackPop();\n LLVM::Value LHS = stackPop();\n stackPush(Builder.createZExt(Builder.createICmpNE(LHS, RHS),\n Context.Int32Ty));\n break;\n }\n case OpCode::I32__lt_s:\n case OpCode::I64__lt_s: {\n LLVM::Value RHS = stackPop();\n LLVM::Value LHS = stackPop();\n stackPush(Builder.createZExt(Builder.createICmpSLT(LHS, RHS),\n Context.Int32Ty));\n break;\n }\n case OpCode::I32__lt_u:\n case OpCode::I64__lt_u: {\n LLVM::Value RHS = stackPop();\n LLVM::Value LHS = stackPop();\n stackPush(Builder.createZExt(Builder.createICmpULT(LHS, RHS),\n Context.Int32Ty));\n break;\n }\n case OpCode::I32__gt_s:\n case OpCode::I64__gt_s: {\n LLVM::Value RHS = stackPop();\n LLVM::Value LHS = stackPop();\n stackPush(Builder.createZExt(Builder.createICmpSGT(LHS, RHS),\n Context.Int32Ty));\n break;\n }\n case OpCode::I32__gt_u:\n case OpCode::I64__gt_u: {\n LLVM::Value RHS = stackPop();\n LLVM::Value LHS = stackPop();\n stackPush(Builder.createZExt(Builder.createICmpUGT(LHS, RHS),\n Context.Int32Ty));\n break;\n }\n case OpCode::I32__le_s:\n case OpCode::I64__le_s: {\n...\n", "needle_token_start": 3944, "needle_token_end": 4188, "code_context_ntokens": 16326, "output": ["```cpp\nstatic LLVM::Value toLLVMConstantZero(LLVM::Context LLContext,\n const ValType &ValType) noexcept {\n switch (ValType.getCode()) {\n case TypeCode::I32:\n return LLVM::Value::getConstNull(LLContext.getInt32Ty());\n case TypeCode::I64:\n return LLVM::Value::getConstNull(LLContext.getInt64Ty());\n case TypeCode::Ref:\n case TypeCode::RefNull:\n case TypeCode::V128:\n return LLVM::Value::getConstNull(\n LLVM::Type::getVectorType(LLContext.getInt64Ty(), 2));\n case TypeCode::F32:\n return LLVM::Value::getConstNull(LLContext.getFloatTy());\n case TypeCode::F64:\n return LLVM::Value::getConstNull(LLContext.getDoubleTy());\n default:\n assumingUnreachable();\n }\n}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": 
"LLVMOrcLLJITBuilderSetObjectLinkingLayerCreator", "language": "cpp", "path": "lib/llvm/llvm.h", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function configures a Just-In-Time (JIT) builder to use a custom object linking layer, which is responsible for managing memory and handling symbols during the JIT compilation process.\n2. **Input**: \n - A reference to the JIT builder.\n - A function pointer that defines how to create the object linking layer.\n - A context pointer that can be used to pass additional information to the layer creation function.\n3. **Output**: There is no direct output; the function modifies the state of the JIT builder by setting the custom object linking layer creator.\n4. **Procedure**: \n - The function receives the JIT builder reference and the function pointer for creating the object linking layer.\n - It sets this function as the method to create the object linking layer within the JIT compilation process.\n - This setup allows the JIT compiler to utilize a customized linking layer that can include specific behaviors, such as custom memory management or symbol resolution strategies, tailored to the needs of the application.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " OrcThreadSafeContext &RHS) noexcept {\n using std::swap;\n swap(LHS.Ref, RHS.Ref);\n }\n\n Context getContext() noexcept {\n return LLVMOrcThreadSafeContextGetContext(Ref);\n }\n\nprivate:\n LLVMOrcThreadSafeContextRef Ref = nullptr;\n};\n\nclass OrcThreadSafeModule {\npublic:\n constexpr OrcThreadSafeModule() noexcept = default;\n constexpr OrcThreadSafeModule(LLVMOrcThreadSafeModuleRef R) noexcept\n : Ref(R) {}\n OrcThreadSafeModule(const OrcThreadSafeModule &) = delete;\n OrcThreadSafeModule &operator=(const OrcThreadSafeModule &) = delete;\n OrcThreadSafeModule(OrcThreadSafeModule &&B) noexcept\n : OrcThreadSafeModule() {\n swap(*this, B);\n }\n OrcThreadSafeModule &operator=(OrcThreadSafeModule &&B) noexcept {\n swap(*this, B);\n return *this;\n }\n\n OrcThreadSafeModule(Module &&M, OrcThreadSafeContext &C) noexcept\n : Ref(LLVMOrcCreateNewThreadSafeModule(M.release(), C.unwrap())) {}\n ~OrcThreadSafeModule() noexcept { LLVMOrcDisposeThreadSafeModule(Ref); }\n\n constexpr operator bool() const noexcept { return Ref != nullptr; }\n constexpr auto &unwrap() const noexcept { return Ref; }\n constexpr auto &unwrap() noexcept { return Ref; }\n LLVMOrcThreadSafeModuleRef release() noexcept {\n return std::exchange(Ref, nullptr);\n }\n friend void swap(OrcThreadSafeModule &LHS,\n OrcThreadSafeModule &RHS) noexcept {\n using std::swap;\n swap(LHS.Ref, RHS.Ref);\n }\n Error withModuleDo(LLVMOrcGenericIRModuleOperationFunction F,\n void *Ctx) noexcept {\n return LLVMOrcThreadSafeModuleWithModuleDo(Ref, F, Ctx);\n }\n\nprivate:\n LLVMOrcThreadSafeModuleRef Ref = nullptr;\n};\n\nclass OrcJITDylib {\npublic:\n constexpr OrcJITDylib() noexcept = default;\n constexpr OrcJITDylib(LLVMOrcJITDylibRef R) noexcept : Ref(R) {}\n OrcJITDylib(const OrcJITDylib &) = delete;\n OrcJITDylib &operator=(const OrcJITDylib &) = delete;\n OrcJITDylib(OrcJITDylib &&B) noexcept : OrcJITDylib() { swap(*this, B); }\n OrcJITDylib &operator=(OrcJITDylib &&B) noexcept {\n swap(*this, B);\n return *this;\n }\n\n constexpr operator bool() const noexcept { return Ref != 
nullptr; }\n constexpr auto &unwrap() const noexcept { return Ref; }\n constexpr auto &unwrap() noexcept { return Ref; }\n friend void swap(OrcJITDylib &LHS, OrcJITDylib &RHS) noexcept {\n using std::swap;\n swap(LHS.Ref, RHS.Ref);\n }\n\nprivate:\n LLVMOrcJITDylibRef Ref = nullptr;\n};\n\nclass OrcIRTransformLayer {\npublic:\n constexpr OrcIRTransformLayer() noexcept = default;\n constexpr OrcIRTransformLayer(LLVMOrcIRTransformLayerRef R) noexcept\n : Ref(R) {}\n OrcIRTransformLayer(const OrcIRTransformLayer &) = delete;\n OrcIRTransformLayer &operator=(const OrcIRTransformLayer &) = delete;\n OrcIRTransformLayer(OrcIRTransformLayer &&B) noexcept\n : OrcIRTransformLayer() {\n swap(*this, B);\n }\n OrcIRTransformLayer &operator=(OrcIRTransformLayer &&B) noexcept {\n swap(*this, B);\n return *this;\n }\n\n constexpr operator bool() const noexcept { return Ref != nullptr; }\n constexpr auto &unwrap() const noexcept { return Ref; }\n constexpr auto &unwrap() noexcept { return Ref; }\n friend void swap(OrcIRTransformLayer &LHS,\n OrcIRTransformLayer &RHS) noexcept {\n using std::swap;\n swap(LHS.Ref, RHS.Ref);\n }\n\n void setTransform(LLVMOrcIRTransformLayerTransformFunction TransformFunction,\n void *Ctx) noexcept {\n LLVMOrcIRTransformLayerSetTransform(Ref, TransformFunction, Ctx);\n }\n\nprivate:\n LLVMOrcIRTransformLayerRef Ref = nullptr;\n};\n\nclass OrcLLJIT {\npublic:\n constexpr OrcLLJIT() noexcept = default;\n constexpr OrcLLJIT(LLVMOrcLLJITRef R) noexcept : Ref(R) {}\n OrcLLJIT(const OrcLLJIT &) = delete;\n OrcLLJIT &operator=(const OrcLLJIT &) = delete;\n OrcLLJIT(OrcLLJIT &&B) noexcept : OrcLLJIT() { swap(*this, B); }\n OrcLLJIT &operator=(OrcLLJIT &&B) noexcept {\n swap(*this, B);\n return *this;\n }\n\n ~OrcLLJIT() noexcept { LLVMOrcDisposeLLJIT(Ref); }\n\n constexpr operator bool() const noexcept { return Ref != nullptr; }\n constexpr auto &unwrap() const noexcept { return Ref; }\n constexpr auto &unwrap() noexcept { return Ref; }\n friend void swap(OrcLLJIT &LHS, OrcLLJIT &RHS) noexcept {\n using std::swap;\n swap(LHS.Ref, RHS.Ref);\n }\n\n static cxx20::expected create() noexcept {\n OrcLLJIT Result;\n if (auto Err = LLVMOrcCreateLLJIT(&Result.Ref, getBuilder())) {\n return cxx20::unexpected(Err);\n } else {\n return Result;\n }\n }\n\n OrcJITDylib getMainJITDylib() noexcept {\n return LLVMOrcLLJITGetMainJITDylib(Ref);\n }\n\n Error addLLVMIRModule(const OrcJITDylib &L, OrcThreadSafeModule M) noexcept {\n return LLVMOrcLLJITAddLLVMIRModule(Ref, L.unwrap(), M.release());\n }\n\n template \n cxx20::expected lookup(const char *Name) noexcept {\n LLVMOrcJITTargetAddress Addr;\n if (auto Err = LLVMOrcLLJITLookup(Ref, &Addr, Name)) {\n return cxx20::unexpected(Err);\n }\n return reinterpret_cast(Addr);\n }\n\n OrcIRTransformLayer getIRTransformLayer() noexcept {\n return LLVMOrcLLJITGetIRTransformLayer(Ref);\n }\n\nprivate:\n LLVMOrcLLJITRef Ref = nullptr;\n\n static inline LLVMOrcLLJITBuilderRef getBuilder() noexcept;\n};\n\n} // namespace WasmEdge::LLVM\n\n#include \n#include \n#include \n#if LLVM_VERSION_MAJOR < 12 || WASMEDGE_OS_WINDOWS\n#include \n#endif\n#if LLVM_VERSION_MAJOR < 13\n#include \n#include \n#include \n#endif\n\n#if WASMEDGE_OS_WINDOWS\n#include \n#include \n#include \n#include \n#endif\n\nnamespace llvm {\n#if WASMEDGE_OS_WINDOWS\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::ExecutionSession,\n LLVMOrcExecutionSessionRef)\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::ObjectLayer, LLVMOrcObjectLayerRef)\n#endif\n#if LLVM_VERSION_MAJOR < 
12\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::LLJITBuilder, LLVMOrcLLJITBuilderRef)\n#endif\n#if LLVM_VERSION_MAJOR < 13\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::ThreadSafeModule,\n LLVMOrcThreadSafeModuleRef)\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::IRTransformLayer,\n LLVMOrcIRTransformLayerRef)\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::MaterializationResponsibility,\n LLVMOrcMaterializationResponsibilityRef)\nDEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::LLJIT, LLVMOrcLLJITRef)\n#endif\n} // namespace llvm\n\nnamespace WasmEdge::LLVM {\n\nvoid Value::setDSOLocal(bool Local) noexcept {\n llvm::cast(reinterpret_cast(Ref))\n ->setDSOLocal(Local);\n}\n\nvoid Value::eliminateUnreachableBlocks() noexcept {\n llvm::EliminateUnreachableBlocks(\n *llvm::cast(reinterpret_cast(Ref)));\n}\n\nbool SectionIterator::isText() const noexcept {\n auto *S = reinterpret_cast(Ref);\n return (*S)->isText();\n}\n\nbool SectionIterator::isData() const noexcept {\n auto *S = reinterpret_cast(Ref);\n return (*S)->isData();\n}\n\nbool SectionIterator::isBSS() const noexcept {\n auto *S = reinterpret_cast(Ref);\n return (*S)->isBSS();\n}\n\nbool SectionIterator::isPData() const noexcept {\n#if WASMEDGE_OS_WINDOWS\n using namespace std::literals;\n return \".pdata\"sv == getName();\n#else\n return false;\n#endif\n}\n\nbool SectionIterator::isEHFrame() const noexcept {\n#if WASMEDGE_OS_LINUX\n using namespace std::literals;\n return \".eh_frame\"sv == getName();\n#elif WASMEDGE_OS_MACOS\n using namespace std::literals;\n return \"__eh_frame\"sv == getName();\n#else\n return false;\n#endif\n}\n\n#if WASMEDGE_OS_WINDOWS\nclass DefaultMMapper final : public llvm::SectionMemoryManager::MemoryMapper {\npublic:\n llvm::sys::MemoryBlock allocateMappedMemory(\n llvm::SectionMemoryManager::AllocationPurpose /*Purpose*/,\n size_t NumBytes, const llvm::sys::MemoryBlock *const NearBlock,\n unsigned Flags, std::error_code &EC) override {\n return llvm::sys::Memory::allocateMappedMemory(NumBytes, NearBlock, Flags,\n EC);\n }\n std::error_code protectMappedMemory(const llvm::sys::MemoryBlock &Block,\n unsigned Flags) override {\n return llvm::sys::Memory::protectMappedMemory(Block, Flags);\n }\n\n std::error_code releaseMappedMemory(llvm::sys::MemoryBlock &M) override {\n return llvm::sys::Memory::releaseMappedMemory(M);\n }\n};\n\nclass ContiguousSectionMemoryManager : public llvm::RTDyldMemoryManager {\npublic:\n explicit ContiguousSectionMemoryManager(\n llvm::SectionMemoryManager::MemoryMapper *UnownedMM = nullptr)\n : MMapper(UnownedMM), OwnedMMapper(nullptr) {\n if (!MMapper) {\n OwnedMMapper = std::make_unique();\n MMapper = OwnedMMapper.get();\n }\n }\n\n ~ContiguousSectionMemoryManager() noexcept override {\n using namespace std::literals;\n if (Preallocated.allocatedSize() != 0) {\n auto EC = MMapper->releaseMappedMemory(Preallocated);\n if (EC) {\n spdlog::error(\"releaseMappedMemory failed with error: {}\"sv,\n EC.message());\n }\n }\n }\n\n bool needsToReserveAllocationSpace() override { return true; }\n\n void reserveAllocationSpace(uintptr_t CodeSize, llvm::Align CodeAlign,\n uintptr_t RODataSize, llvm::Align RODataAlign,\n uintptr_t RWDataSize,\n llvm::Align RWDataAlign) override {\n using namespace std::literals;\n assuming(Preallocated.allocatedSize() == 0);\n\n static const size_t PageSize = llvm::sys::Process::getPageSizeEstimate();\n assuming(CodeAlign.value() <= PageSize);\n assuming(RODataAlign.value() <= PageSize);\n assuming(RWDataAlign.value() <= PageSize);\n CodeSize = roundUpTo(CodeSize + CodeAlign.value(), 
PageSize);\n RODataSize = roundUpTo(RODataSize + RODataAlign.value(), PageSize);\n RWDataSize = roundUpTo(RWDataSize + RWDataAlign.value(), PageSize);\n const uintptr_t TotalSize =\n CodeSize + RODataSize + RWDataSize + PageSize * 3;\n\n std::error_code EC;\n Preallocated = MMapper->allocateMappedMemory(\n llvm::SectionMemoryManager::AllocationPurpose::Code, TotalSize, nullptr,\n llvm::sys::Memory::MF_READ | llvm::sys::Memory::MF_WRITE, EC);\n if (EC) {\n spdlog::error(\"allocateMappedMemory failed with error: {}\"sv,\n EC.message());\n return;\n }\n\n auto base = reinterpret_cast(Preallocated.base());\n CodeMem = CodeFree =\n llvm::sys::MemoryBlock(reinterpret_cast(base), CodeSize);\n base += CodeSize;\n RODataMem = RODataFree =\n llvm::sys::MemoryBlock(reinterpret_cast(base), RODataSize);\n base += RODataSize;\n RWDataMem = RWDataFree =\n llvm::sys::MemoryBlock(reinterpret_cast(base), RWDataSize);\n }\n\n uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment,\n unsigned /*SectionID*/,\n llvm::StringRef /*SectionName*/,\n bool IsReadOnly) override {\n if (IsReadOnly) {\n return Allocate(RODataFree, Size, Alignment);\n } else {\n return Allocate(RWDataFree, Size, Alignment);\n }\n }\n\n uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment,\n unsigned /*SectionID*/,\n llvm::StringRef /*SectionName*/) override {\n return Allocate(CodeFree, Size, Alignment);\n }\n\n bool finalizeMemory(std::string *ErrMsg) override {\n std::error_code EC;\n\n EC = MMapper->protectMappedMemory(CodeMem, llvm::sys::Memory::MF_READ |\n llvm::sys::Memory::MF_EXEC);\n if (EC) {\n if (ErrMsg) {\n *ErrMsg = EC.message();\n }\n return true;\n }\n EC = MMapper->protectMappedMemory(RODataMem, llvm::sys::Memory::MF_READ);\n if (EC) {\n if (ErrMsg) {\n *ErrMsg = EC.message();\n }\n return true;\n }\n\n llvm::sys::Memory::InvalidateInstructionCache(CodeMem.base(),\n CodeMem.allocatedSize());\n return false;\n }\n\nprivate:\n llvm::sys::MemoryBlock Preallocated;\n\n // Sections must be in the order code < rodata < rwdata.\n llvm::sys::MemoryBlock CodeMem;\n llvm::sys::MemoryBlock RODataMem;\n llvm::sys::MemoryBlock RWDataMem;\n\n llvm::sys::MemoryBlock CodeFree;\n llvm::sys::MemoryBlock RODataFree;\n llvm::sys::MemoryBlock RWDataFree;\n\n llvm::SectionMemoryManager::MemoryMapper *MMapper;\n std::unique_ptr OwnedMMapper;\n\n uint8_t *Allocate(llvm::sys::MemoryBlock &FreeBlock, std::uintptr_t Size,\n unsigned alignment) {\n using namespace std::literals;\n const auto Base = reinterpret_cast(FreeBlock.base());\n const auto Start = roundUpTo(Base, alignment);\n const uintptr_t PaddedSize = (Start - Base) + Size;\n if (PaddedSize > FreeBlock.allocatedSize()) {\n spdlog::error(\"Failed to satisfy suballocation request for {}\"sv, Size);\n return nullptr;\n }\n FreeBlock =\n llvm::sys::MemoryBlock(reinterpret_cast(Base + PaddedSize),\n FreeBlock.allocatedSize() - PaddedSize);\n return reinterpret_cast(Start);\n }\n\n static uintptr_t roundUpTo(uintptr_t Value, uintptr_t Divisor) noexcept {\n return ((Value + (Divisor - 1)) / Divisor) * Divisor;\n }\n};\n\n// Register stack unwind info for JIT functions\nclass Win64EHManager : public ContiguousSectionMemoryManager {\n using Base = ContiguousSectionMemoryManager;\n uint64_t CodeAddress = 0;\n\npublic:\n ~Win64EHManager() noexcept override {}\n\n uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment,\n unsigned SectionID,\n llvm::StringRef SectionName) override {\n using namespace std::literals;\n const auto Allocated =\n Base::allocateCodeSection(Size, 
Alignment, SectionID, SectionName);\n if (SectionName == llvm::StringRef(\".text\"sv)) {\n CodeAddress = reinterpret_cast(Allocated);\n }\n return Allocated;\n }\n\n void registerEHFrames(uint8_t *Addr, uint64_t /*LoadAddr*/,\n size_t Size) noexcept override {\n using namespace std::literals;\n winapi::RUNTIME_FUNCTION_ *const FunctionTable =\n reinterpret_cast(Addr);\n const uint32_t EntryCount =\n static_cast(Size / sizeof(winapi::RUNTIME_FUNCTION_));\n if (EntryCount == 0)\n return;\n // Calculate object image base address by assuming that address of the first\n // function is equal to the address of the code section\n const auto ImageBase = CodeAddress - FunctionTable[0].BeginAddress;\n winapi::RtlAddFunctionTable(FunctionTable, EntryCount, ImageBase);\n EHFrames.push_back({Addr, Size});\n }\n void deregisterEHFrames() noexcept override {\n using namespace std::literals;\n for (auto &Frame : EHFrames) {\n winapi::RtlDeleteFunctionTable(\n reinterpret_cast(Frame.Addr));\n }\n EHFrames.clear();\n }\n};\n\nLLVMOrcLLJITBuilderRef OrcLLJIT::getBuilder() noexcept {\n using llvm::unwrap;\n using llvm::wrap;\n const LLVMOrcLLJITBuilderRef Builder = LLVMOrcCreateLLJITBuilder();\n LLVMOrcLLJITBuilderSetObjectLinkingLayerCreator(\n Builder,\n [](void *, LLVMOrcExecutionSessionRef ES, const char *) noexcept {\n auto Layer = std::make_unique(\n *unwrap(ES), []() { return std::make_unique(); });\n Layer->setOverrideObjectFlagsWithResponsibilityFlags(true);\n Layer->setAutoClaimResponsibilityForObjectSymbols(true);\n return wrap(static_cast(Layer.release()));\n },\n nullptr);\n return Builder;\n}\n#else\nLLVMOrcLLJITBuilderRef OrcLLJIT::getBuilder() noexcept { return nullptr; }\n#endif\n\n} // namespace WasmEdge::LLVM\n\n#if LLVM_VERSION_MAJOR < 12 && WASMEDGE_OS_WINDOWS\n\nvoid LLVMOrcLLJITBuilderSetObjectLinkingLayerCreator(\n LLVMOrcLLJITBuilderRef Builder,\n LLVMOrcLLJITBuilderObjectLinkingLayerCreatorFunction F,\n void *Ctx) noexcept {\n using llvm::unwrap;\n using llvm::wrap;\n unwrap(Builder)->setObjectLinkingLayerCreator(\n [=](llvm::orc::ExecutionSession &ES, const llvm::Triple &TT) {\n auto TTStr = TT.str();\n return std::unique_ptr(\n unwrap(F(Ctx, wrap(&ES), TTStr.c_str())));\n });\n}\n#endif\n#if LLVM_VERSION_MAJOR < 13\nLLVMOrcIRTransformLayerRef\nLLVMOrcLLJITGetIRTransformLayer(LLVMOrcLLJITRef J) noexcept {\n using llvm::unwrap;\n using llvm::wrap;\n return wrap(&(unwrap(J)->getIRTransformLayer()));\n}\nvoid LLVMOrcIRTransformLayerSetTransform(\n LLVMOrcIRTransformLayerRef IRTransformLayer,\n LLVMOrcIRTransformLayerTransformFunction TransformFunction,\n void *Ctx) noexcept {\n using llvm::unwrap;\n using llvm::wrap;\n unwrap(IRTransformLayer)\n ->setTransform([=](llvm::orc::ThreadSafeModule TSM,\n llvm::orc::MaterializationResponsibility &R)\n -> llvm::Expected {\n LLVMOrcThreadSafeModuleRef TSMRef =\n wrap(new llvm::orc::ThreadSafeModule(std::move(TSM)));\n if (LLVMErrorRef Err = TransformFunction(Ctx, &TSMRef, wrap(&R))) {\n return unwrap(Err);\n }\n return std::move(*unwrap(TSMRef));\n });\n}\n\nLLVMErrorRef\nLLVMOrcThreadSafeModuleWithModuleDo(LLVMOrcThreadSafeModuleRef TSM,\n LLVMOrcGenericIRModuleOperationFunction F,\n void *Ctx) noexcept {\n using llvm::unwrap;\n using llvm::wrap;\n return wrap(unwrap(TSM)->withModuleDo(\n [&](llvm::Module &M) { return unwrap(F(Ctx, wrap(&M))); }));\n}\n#endif\n\n// Path: lib/llvm/data.h\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n#pragma once\n\n#include \"llvm.h\"\n#include 
\"llvm/data.h\"\n\nstruct WasmEdge::LLVM::Data::DataContext {\n LLVM::OrcThreadSafeContext TSContext;\n LLVM::Module LLModule;\n LLVM::TargetMachine TM;\n DataContext() noexcept : TSContext(), LLModule(LLContext(), \"wasm\") {}\n LLVM::Context LLContext() noexcept { return TSContext.getContext(); }\n};\n\n// Path: lib/llvm/data.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"llvm/data.h\"\n#include \"data.h\"\n#include \"llvm.h\"\n\nnamespace LLVM = WasmEdge::LLVM;\n\nLLVM::Data::Data() noexcept : Context(std::make_unique()) {}\n\nLLVM::Data::~Data() noexcept {}\n\nLLVM::Data::Data(LLVM::Data &&RHS) noexcept : Context(std::move(RHS.Context)) {}\nLLVM::Data &LLVM::Data::operator=(LLVM::Data &&RHS) noexcept {\n using std::swap;\n swap(Context, RHS.Context);\n return *this;\n}\n\n// Path: lib/llvm/jit.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"llvm/jit.h\"\n#include \"common/log.h\"\n\n#include \"data.h\"\n#include \"llvm.h\"\n\nnamespace LLVM = WasmEdge::LLVM;\nusing namespace std::literals;\n\nnamespace WasmEdge::LLVM {\n\nJITLibrary::JITLibrary(OrcLLJIT JIT) noexcept\n : J(std::make_unique(std::move(JIT)).release()) {}\n\nJITLibrary::~JITLibrary() noexcept {\n std::unique_ptr JIT(std::exchange(J, nullptr));\n}\n\nSymbol\nJITLibrary::getIntrinsics() noexcept {\n if (auto Symbol = J->lookup(\"intrinsics\")) {\n return createSymbol(*Symbol);\n } else {\n spdlog::error(\"{}\"sv, Symbol.error().message().string_view());\n return {};\n }\n}\n\nstd::vector>\nJITLibrary::getTypes(size_t Size) noexcept {\n std::vector> Result;\n Result.reserve(Size);\n for (size_t I = 0; I < Size; ++I) {\n const std::string Name = fmt::format(\"t{}\"sv, I);\n if (auto Symbol = J->lookup(Name.c_str())) {\n Result.push_back(createSymbol(*Symbol));\n } else {\n spdlog::error(\"{}\"sv, Symbol.error().message().string_view());\n Result.emplace_back();\n }\n }\n\n return Result;\n}\n\nstd::vector> JITLibrary::getCodes(size_t Offset,\n size_t Size) noexcept {\n std::vector> Result;\n Result.reserve(Size);\n for (size_t I = 0; I < Size; ++I) {\n const std::string Name = fmt::format(\"f{}\"sv, I + Offset);\n if (auto Symbol = J->lookup(Name.c_str())) {\n Result.push_back(createSymbol(*Symbol));\n } else {\n spdlog::error(\"{}\"sv, Symbol.error().message().string_view());\n Result.emplace_back();\n }\n }\n\n return Result;\n}\n\nExpect> JIT::load(Data D) noexcept {\n OrcLLJIT J;\n if (auto Res = OrcLLJIT::create(); !Res) {\n spdlog::error(\"{}\"sv, Res.error().message().string_view());\n return Unexpect(ErrCode::Value::HostFuncError);\n } else {\n J = std::move(*Res);\n }\n\n auto &LLModule = D.extract().LLModule;\n\n if (Conf.getCompilerConfigure().isDumpIR()) {\n if (auto ErrorMessage = LLModule.printModuleToFile(\"wasm-jit.ll\")) {\n spdlog::error(\"printModuleToFile failed\");\n }\n }\n\n auto MainJD = J.getMainJITDylib();\n if (auto Err = J.addLLVMIRModule(\n MainJD,\n OrcThreadSafeModule(LLModule.release(), D.extract().TSContext))) {\n spdlog::error(\"{}\"sv, Err.message().string_view());\n return Unexpect(ErrCode::Value::HostFuncError);\n }\n\n return std::make_shared(std::move(J));\n}\n} // namespace WasmEdge::LLVM\n\n// Path: lib/llvm/codegen.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"llvm/codegen.h\"\n\n#include \"aot/version.h\"\n#include \"common/defines.h\"\n#include \"data.h\"\n#include \"llvm.h\"\n\n#include 
\n#include \n#include \n#include \n#include \n\n#if LLVM_VERSION_MAJOR >= 14\n#include \n#endif\n#if LLVM_VERSION_MAJOR >= 17\n#if WASMEDGE_OS_MACOS\nLLD_HAS_DRIVER(macho)\n#elif WASMEDGE_OS_LINUX\nLLD_HAS_DRIVER(elf)\n#elif WASMEDGE_OS_WINDOWS\nLLD_HAS_DRIVER(coff)\n#endif\n#endif\n\n#if WASMEDGE_OS_MACOS\n#include \n#include \n#endif\n#if WASMEDGE_OS_WINDOWS\n#include \n#endif\n\n#if WASMEDGE_OS_LINUX\n#define SYMBOL(X) X\n#elif WASMEDGE_OS_MACOS\n#define SYMBOL(X) \"_\" X\n#elif WASMEDGE_OS_WINDOWS\n#define SYMBOL(X) X\n#endif\n\nnamespace LLVM = WasmEdge::LLVM;\nusing namespace std::literals;\n\nnamespace {\n\nusing namespace WasmEdge;\n\n#if WASMEDGE_OS_MACOS\n// Get current OS version\nstd::string getOSVersion() noexcept {\n struct utsname Info;\n if (::uname(&Info)) {\n // default os version\n return \"13.0.0\"s;\n }\n std::string_view Release = Info.release;\n auto GetNum = [](std::string_view &String) noexcept {\n uint64_t Result = 0;\n while (!String.empty() && std::isdigit(String[0])) {\n Result = Result * 10 + (String[0] - '0');\n String = String.substr(1);\n }\n return Result;\n };\n auto SkipDot = [](std::string_view &String) noexcept {\n if (!String.empty() && String[0] == '.')\n String = String.substr(1);\n };\n uint64_t Major = GetNum(Release);\n SkipDot(Release);\n uint64_t Minor = GetNum(Release);\n SkipDot(Release);\n uint64_t Micro = GetNum(Release);\n\n if (Major == 0) {\n Major = 8;\n }\n if (Major <= 19) {\n Micro = 0;\n Minor = Major - 4;\n Major = 10;\n } else {\n Micro = 0;\n Minor = 0;\n Major = 11 + Major - 20;\n }\n\n return fmt::format(\"{}.{}.{}\"sv, Major, Minor, Micro);\n}\n// Get current SDK version\nstd::string getSDKVersion() noexcept {\n // TODO: parse SDKSettings.json to get real version\n return \"12.1\"s;\n}\n// Get current SDK version in pair\nstd::pair getSDKVersionPair() noexcept {\n // TODO: parse SDKSettings.json to get real version\n return {UINT32_C(12), UINT32_C(1)};\n}\n#endif\n\nExpect WriteByte(std::ostream &OS, uint8_t Data) noexcept {\n OS.put(static_cast(Data));\n return {};\n}\n\nExpect WriteU32(std::ostream &OS, uint32_t Data) noexcept {\n do {\n uint8_t Byte = static_cast(Data & UINT32_C(0x7f));\n Data >>= 7;\n if (Data > UINT32_C(0)) {\n Byte |= UINT8_C(0x80);\n }\n WriteByte(OS, Byte);\n } while (Data > UINT32_C(0));\n return {};\n}\n\nExpect WriteU64(std::ostream &OS, uint64_t Data) noexcept {\n do {\n uint8_t Byte = static_cast(Data & UINT64_C(0x7f));\n Data >>= 7;\n if (Data > UINT64_C(0)) {\n Byte |= UINT8_C(0x80);\n }\n WriteByte(OS, Byte);\n } while (Data > UINT64_C(0));\n return {};\n}\n\nExpect WriteName(std::ostream &OS, std::string_view Data) noexcept {\n WriteU32(OS, static_cast(Data.size()));\n for (const auto C : Data) {\n WriteByte(OS, static_cast(C));\n }\n return {};\n}\n\ninline constexpr bool startsWith(std::string_view Value,\n std::string_view Prefix) noexcept {\n return Value.size() >= Prefix.size() &&\n Value.substr(0, Prefix.size()) == Prefix;\n}\n\nstd::filesystem::path uniquePath(const std::filesystem::path Model) noexcept {\n using size_type = std::filesystem::path::string_type::size_type;\n using value_type = std::filesystem::path::value_type;\n static const auto Hex = \"0123456789abcdef\"sv;\n std::random_device Device;\n std::default_random_engine Engine(Device());\n std::uniform_int_distribution Distribution(0, Hex.size() - 1);\n auto String = Model.native();\n for (size_type N = String.size(), I = 0; I < N; ++I) {\n if (String[I] == static_cast('%')) {\n String[I] = 
static_cast(Hex[Distribution(Engine)]);\n }\n }\n return String;\n}\n\nstd::filesystem::path createTemp(const std::filesystem::path Model) noexcept {\n while (true) {\n auto Result = uniquePath(Model);\n std::error_code Error;\n if (!std::filesystem::exists(Result, Error)) {\n if (Error) {\n return {};\n }\n return Result;\n }\n }\n}\n\n// Write output object and link\nExpect outputNativeLibrary(const std::filesystem::path &OutputPath,\n const LLVM::MemoryBuffer &OSVec) noexcept {\n spdlog::info(\"output start\");\n std::filesystem::path ObjectName;\n {\n // tempfile\n std::filesystem::path OPath(OutputPath);\n#if WASMEDGE_OS_WINDOWS\n OPath.replace_extension(\"%%%%%%%%%%.obj\"sv);\n#else\n OPath.replace_extension(\"%%%%%%%%%%.o\"sv);\n#endif\n ObjectName = createTemp(OPath);\n if (ObjectName.empty()) {\n // TODO:return error\n spdlog::error(\"so file creation failed:{}\", OPath.u8string());\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n std::ofstream OS(ObjectName, std::ios_base::binary);\n OS.write(OSVec.data(), static_cast(OSVec.size()));\n OS.close();\n }\n\n // link\n bool LinkResult = false;\n#if WASMEDGE_OS_MACOS\n const auto OSVersion = getOSVersion();\n const auto SDKVersion = getSDKVersion();\n#if LLVM_VERSION_MAJOR >= 14\n // LLVM 14 replaces the older mach_o lld implementation with the new one.\n // So we need to change the namespace after LLVM 14.x released.\n // Reference: https://reviews.llvm.org/D114842\n LinkResult = lld::macho::link(\n#else\n LinkResult = lld::mach_o::link(\n#endif\n std::initializer_list {\n \"lld\", \"-arch\",\n#if defined(__x86_64__)\n \"x86_64\",\n#elif defined(__aarch64__)\n \"arm64\",\n#else\n#error Unsupported architecture on the MacOS!\n#endif\n#if LLVM_VERSION_MAJOR >= 14\n // LLVM 14 replaces the older mach_o lld implementation with the new\n // one. And it require -arch and -platform_version to always be\n // specified. 
Reference: https://reviews.llvm.org/D97799\n \"-platform_version\", \"macos\", OSVersion.c_str(), SDKVersion.c_str(),\n#else\n \"-sdk_version\", SDKVersion.c_str(),\n#endif\n \"-dylib\", \"-demangle\", \"-macosx_version_min\", OSVersion.c_str(),\n \"-syslibroot\",\n \"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk\",\n ObjectName.u8string().c_str(), \"-o\", OutputPath.u8string().c_str()\n },\n#elif WASMEDGE_OS_LINUX\n LinkResult = lld::elf::link(\n std::initializer_list{\"ld.lld\", \"--eh-frame-hdr\",\n \"--shared\", \"--gc-sections\",\n \"--discard-all\", ObjectName.c_str(),\n \"-o\", OutputPath.u8string().c_str()},\n#elif WASMEDGE_OS_WINDOWS\n LinkResult = lld::coff::link(\n std::initializer_list{\n \"lld-link\", \"-dll\", \"-base:0\", \"-nologo\",\n ObjectName.u8string().c_str(),\n (\"-out:\" + OutputPath.u8string()).c_str()},\n#endif\n\n#if LLVM_VERSION_MAJOR >= 14\n llvm::outs(), llvm::errs(), false, false\n#elif LLVM_VERSION_MAJOR >= 10\n false, llvm::outs(), llvm::errs()\n#else\n false, llvm::errs()\n#endif\n );\n\n#if LLVM_VERSION_MAJOR >= 14\n lld::CommonLinkerContext::destroy();\n#endif\n\n if (LinkResult) {\n std::error_code Error;\n std::filesystem::remove(ObjectName, Error);\n#if WASMEDGE_OS_WINDOWS\n std::filesystem::path LibPath(OutputPath);\n LibPath.replace_extension(\".lib\"sv);\n std::filesystem::remove(LibPath, Error);\n#endif\n\n spdlog::info(\"codegen done\");\n } else {\n spdlog::error(\"link error\");\n }\n\n#if WASMEDGE_OS_MACOS\n // codesign\n if (LinkResult) {\n pid_t PID = ::fork();\n if (PID == -1) {\n spdlog::error(\"codesign error on fork:{}\", std::strerror(errno));\n } else if (PID == 0) {\n execlp(\"/usr/bin/codesign\", \"codesign\", \"-s\", \"-\",\n OutputPath.u8string().c_str(), nullptr);\n std::exit(256);\n } else {\n int ChildStat;\n waitpid(PID, &ChildStat, 0);\n if (const int Status = WEXITSTATUS(ChildStat); Status != 0) {\n spdlog::error(\"codesign exited with status {}\", Status);\n }\n }\n }\n#endif\n\n return {};\n}\n\nExpect outputWasmLibrary(LLVM::Context LLContext,\n const std::filesystem::path &OutputPath,\n Span Data,\n const LLVM::MemoryBuffer &OSVec) noexcept {\n std::filesystem::path SharedObjectName;\n {\n // tempfile\n std::filesystem::path SOPath(OutputPath);\n SOPath.replace_extension(\"%%%%%%%%%%\" WASMEDGE_LIB_EXTENSION);\n SharedObjectName = createTemp(SOPath);\n if (SharedObjectName.empty()) {\n // TODO:return error\n spdlog::error(\"so file creation failed:{}\", SOPath.u8string());\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n std::ofstream OS(SharedObjectName, std::ios_base::binary);\n OS.write(OSVec.data(), static_cast(OSVec.size()));\n OS.close();\n }\n\n if (auto Res = outputNativeLibrary(SharedObjectName, OSVec); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n LLVM::MemoryBuffer SOFile;\n if (auto [Res, ErrorMessage] =\n LLVM::MemoryBuffer::getFile(SharedObjectName.u8string().c_str());\n unlikely(ErrorMessage)) {\n spdlog::error(\"object file open error:{}\", ErrorMessage.string_view());\n return Unexpect(ErrCode::Value::IllegalPath);\n } else {\n SOFile = std::move(Res);\n }\n\n LLVM::Binary ObjFile;\n if (auto [Res, ErrorMessage] = LLVM::Binary::create(SOFile, LLContext);\n unlikely(ErrorMessage)) {\n spdlog::error(\"object file parse error:{}\", ErrorMessage.string_view());\n return Unexpect(ErrCode::Value::IllegalPath);\n } else {\n ObjFile = std::move(Res);\n }\n\n std::string OSCustomSecVec;\n {\n std::ostringstream OS;\n WriteName(OS, \"wasmedge\"sv);\n WriteU32(OS, AOT::kBinaryVersion);\n\n#if 
WASMEDGE_OS_LINUX\n WriteByte(OS, UINT8_C(1));\n#elif WASMEDGE_OS_MACOS\n WriteByte(OS, UINT8_C(2));\n#elif WASMEDGE_OS_WINDOWS\n WriteByte(OS, UINT8_C(3));\n#else\n#error Unsupported operating system!\n#endif\n\n#if defined(__x86_64__)\n WriteByte(OS, UINT8_C(1));\n#elif defined(__aarch64__)\n WriteByte(OS, UINT8_C(2));\n#elif defined(__riscv) && __riscv_xlen == 64\n WriteByte(OS, UINT8_C(3));\n#elif defined(__arm__) && __ARM_ARCH == 7\n WriteByte(OS, UINT8_C(4));\n#else\n#error Unsupported hardware architecture!\n#endif\n\n std::vector> SymbolTable;\n#if !WASMEDGE_OS_WINDOWS\n for (auto Symbol = ObjFile.symbols();\n Symbol && !ObjFile.isSymbolEnd(Symbol); Symbol.next()) {\n SymbolTable.emplace_back(Symbol.getName(), Symbol.getAddress());\n }\n#else\n for (auto &Symbol :\n llvm::object::unwrap(ObjFile.unwrap())\n ->export_directories()) {\n llvm::StringRef Name;\n if (auto Error = Symbol.getSymbolName(Name); unlikely(!!Error)) {\n continue;\n } else if (Name.empty()) {\n continue;\n }\n uint32_t Offset = 0;\n if (auto Error = Symbol.getExportRVA(Offset); unlikely(!!Error)) {\n continue;\n }\n SymbolTable.emplace_back(Name.str(), Offset);\n }\n#endif\n uint64_t VersionAddress = 0, IntrinsicsAddress = 0;\n std::vector Types;\n std::vector Codes;\n uint64_t CodesMin = std::numeric_limits::max();\n for (const auto &[Name, Address] : SymbolTable) {\n if (Name == SYMBOL(\"version\"sv)) {\n VersionAddress = Address;\n } else if (Name == SYMBOL(\"intrinsics\"sv)) {\n IntrinsicsAddress = Address;\n } else if (startsWith(Name, SYMBOL(\"t\"sv))) {\n uint64_t Index = 0;\n std::from_chars(Name.data() + SYMBOL(\"t\"sv).size(),\n Name.data() + Name.size(), Index);\n if (Types.size() < Index + 1) {\n Types.resize(Index + 1);\n }\n Types[Index] = Address;\n } else if (startsWith(Name, SYMBOL(\"f\"sv))) {\n uint64_t Index = 0;\n std::from_chars(Name.data() + SYMBOL(\"f\"sv).size(),\n Name.data() + Name.size(), Index);\n if (Codes.size() < Index + 1) {\n Codes.resize(Index + 1);\n }\n CodesMin = std::min(CodesMin, Index);\n Codes[Index] = Address;\n }\n }\n if (CodesMin != std::numeric_limits::max()) {\n Codes.erase(Codes.begin(),\n Codes.begin() + static_cast(CodesMin));\n }\n WriteU64(OS, VersionAddress);\n WriteU64(OS, IntrinsicsAddress);\n WriteU64(OS, Types.size());\n for (const uint64_t TypeAddress : Types) {\n WriteU64(OS, TypeAddress);\n }\n WriteU64(OS, Codes.size());\n for (const uint64_t CodeAddress : Codes) {\n WriteU64(OS, CodeAddress);\n }\n\n uint32_t SectionCount = 0;\n for (auto Section = ObjFile.sections(); !ObjFile.isSectionEnd(Section);\n Section.next()) {\n if (Section.getSize() == 0) {\n continue;\n }\n if (!Section.isEHFrame() && !Section.isPData() && !Section.isText() &&\n !Section.isData() && !Section.isBSS()) {\n continue;\n }\n ++SectionCount;\n }\n WriteU32(OS, SectionCount);\n\n for (auto Section = ObjFile.sections(); !ObjFile.isSectionEnd(Section);\n Section.next()) {\n if (Section.getSize() == 0) {\n continue;\n }\n std::vector Content;\n if (auto Res = Section.getContents(); unlikely(Res.empty())) {\n assumingUnreachable();\n } else {\n Content.assign(Res.begin(), Res.end());\n }\n if (Section.isEHFrame() || Section.isPData()) {\n WriteByte(OS, UINT8_C(4));\n } else if (Section.isText()) {\n WriteByte(OS, UINT8_C(1));\n } else if (Section.isData()) {\n WriteByte(OS, UINT8_C(2));\n } else if (Section.isBSS()) {\n WriteByte(OS, UINT8_C(3));\n } else {\n continue;\n }\n\n WriteU64(OS, Section.getAddress());\n WriteU64(OS, Content.size());\n WriteName(OS, 
std::string_view(Content.data(), Content.size()));\n }\n OSCustomSecVec = OS.str();\n }\n\n spdlog::info(\"output start\");\n\n std::ofstream OS(OutputPath, std::ios_base::binary);\n if (!OS) {\n spdlog::error(\"output failed.\");\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n OS.write(reinterpret_cast(Data.data()),\n static_cast(Data.size()));\n // Custom section id\n WriteByte(OS, UINT8_C(0x00));\n WriteName(OS, std::string_view(OSCustomSecVec.data(), OSCustomSecVec.size()));\n\n std::error_code Error;\n std::filesystem::remove(SharedObjectName, Error);\n\n spdlog::info(\"output done\");\n return {};\n}\n\n} // namespace\n\nnamespace WasmEdge::LLVM {\n\nExpect CodeGen::codegen(Span WasmData, Data D,\n std::filesystem::path OutputPath) noexcept {\n auto LLContext = D.extract().LLContext();\n auto &LLModule = D.extract().LLModule;\n auto &TM = D.extract().TM;\n std::filesystem::path LLPath(OutputPath);\n LLPath.replace_extension(\"ll\"sv);\n\n#if WASMEDGE_OS_WINDOWS\n {\n // create dummy dllmain function\n auto FTy = LLVM::Type::getFunctionType(LLContext.getInt32Ty(), {});\n auto F =\n LLModule.addFunction(FTy, LLVMExternalLinkage, \"_DllMainCRTStartup\");\n F.setVisibility(LLVMProtectedVisibility);\n F.setDSOLocal(true);\n F.addFnAttr(\n LLVM::Attribute::createString(LLContext, \"no-stack-arg-probe\"sv, {}));\n F.addFnAttr(\n LLVM::Attribute::createEnum(LLContext, LLVM::Core::StrictFP, 0));\n F.addFnAttr(LLVM::Attribute::createEnum(LLContext, LLVM::Core::UWTable,\n LLVM::Core::UWTableDefault));\n F.addFnAttr(\n LLVM::Attribute::createEnum(LLContext, LLVM::Core::NoReturn, 0));\n LLVM::Builder Builder(LLContext);\n Builder.positionAtEnd(LLVM::BasicBlock::create(LLContext, F, \"entry\"));\n Builder.createRet(LLContext.getInt32(1u));\n\n auto A = LLModule.addAlias(F.getType(), F, \"_fltused\");\n A.setLinkage(LLVMExternalLinkage);\n A.setVisibility(LLVMProtectedVisibility);\n A.setDSOLocal(true);\n }\n#endif\n#if WASMEDGE_OS_MACOS\n {\n const auto [Major, Minor] = getSDKVersionPair();\n LLModule.addFlag(LLVMModuleFlagBehaviorError, \"SDK Version\"sv,\n LLVM::Value::getConstVector32(LLContext, {Major, Minor}));\n }\n#endif\n\n if (Conf.getCompilerConfigure().getOutputFormat() !=\n CompilerConfigure::OutputFormat::Wasm) {\n // create wasm.code and wasm.size\n auto Int32Ty = LLContext.getInt32Ty();\n auto Content = LLVM::Value::getConstString(\n LLContext,\n {reinterpret_cast(WasmData.data()), WasmData.size()},\n true);\n LLModule.addGlobal(Content.getType(), true, LLVMExternalLinkage, Content,\n \"wasm.code\");\n LLModule.addGlobal(Int32Ty, true, LLVMExternalLinkage,\n LLVM::Value::getConstInt(Int32Ty, WasmData.size()),\n \"wasm.size\");\n for (auto Fn = LLModule.getFirstFunction(); Fn; Fn = Fn.getNextFunction()) {\n if (Fn.getLinkage() == LLVMInternalLinkage) {\n Fn.setLinkage(LLVMExternalLinkage);\n Fn.setVisibility(LLVMProtectedVisibility);\n Fn.setDSOLocal(true);\n Fn.setDLLStorageClass(LLVMDLLExportStorageClass);\n }\n }\n } else {\n for (auto Fn = LLModule.getFirstFunction(); Fn; Fn = Fn.getNextFunction()) {\n if (Fn.getLinkage() == LLVMInternalLinkage) {\n Fn.setLinkage(LLVMPrivateLinkage);\n Fn.setDSOLocal(true);\n Fn.setDLLStorageClass(LLVMDefaultStorageClass);\n }\n }\n }\n\n // set dllexport\n for (auto GV = LLModule.getFirstGlobal(); GV; GV = GV.getNextGlobal()) {\n if (GV.getLinkage() == LLVMExternalLinkage) {\n GV.setVisibility(LLVMProtectedVisibility);\n GV.setDSOLocal(true);\n GV.setDLLStorageClass(LLVMDLLExportStorageClass);\n }\n }\n\n if 
(Conf.getCompilerConfigure().isDumpIR()) {\n if (auto ErrorMessage = LLModule.printModuleToFile(\"wasm.ll\");\n unlikely(ErrorMessage)) {\n spdlog::error(\"wasm.ll open error:{}\", ErrorMessage.string_view());\n return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::IllegalPath);\n }\n }\n\n spdlog::info(\"codegen start\");\n // codegen\n {\n if (Conf.getCompilerConfigure().isDumpIR()) {\n if (auto ErrorMessage = LLModule.printModuleToFile(\"wasm-opt.ll\")) {\n // TODO:return error\n spdlog::error(\"printModuleToFile failed\");\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n }\n\n auto [OSVec, ErrorMessage] =\n TM.emitToMemoryBuffer(LLModule, LLVMObjectFile);\n if (ErrorMessage) {\n // TODO:return error\n spdlog::error(\"addPassesToEmitFile failed\");\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n\n if (Conf.getCompilerConfigure().getOutputFormat() ==\n CompilerConfigure::OutputFormat::Wasm) {\n if (auto Res = outputWasmLibrary(LLContext, OutputPath, WasmData, OSVec);\n unlikely(!Res)) {\n return Unexpect(Res);\n }\n } else {\n if (auto Res = outputNativeLibrary(OutputPath, OSVec); unlikely(!Res)) {\n return Unexpect(Res);\n }\n }\n }\n\n return {};\n}\n\n} // namespace WasmEdge::LLVM\n\n// Path: lib/llvm/compiler.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"llvm/compiler.h\"\n\n#include \"aot/version.h\"\n#include \"common/defines.h\"\n#include \"common/filesystem.h\"\n#include \"common/log.h\"\n#include \"data.h\"\n#include \"llvm.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace LLVM = WasmEdge::LLVM;\nusing namespace std::literals;\n\nnamespace {\n\nstatic bool\nisVoidReturn(WasmEdge::Span ValTypes) noexcept;\nstatic LLVM::Type toLLVMType(LLVM::Context LLContext,\n const WasmEdge::ValType &ValType) noexcept;\nstatic std::vector\ntoLLVMArgsType(LLVM::Context LLContext, LLVM::Type ExecCtxPtrTy,\n WasmEdge::Span ValTypes) noexcept;\nstatic LLVM::Type\ntoLLVMRetsType(LLVM::Context LLContext,\n WasmEdge::Span ValTypes) noexcept;\nstatic LLVM::Type\ntoLLVMType(LLVM::Context LLContext, LLVM::Type ExecCtxPtrTy,\n const WasmEdge::AST::FunctionType &FuncType) noexcept;\nstatic LLVM::Value\ntoLLVMConstantZero(LLVM::Context LLContext,\n const WasmEdge::ValType &ValType) noexcept;\nstatic std::vector unpackStruct(LLVM::Builder &Builder,\n LLVM::Value Struct) noexcept;\nclass FunctionCompiler;\n\n// XXX: Misalignment handler not implemented yet, forcing unalignment\n// force unalignment load/store\nstatic inline constexpr const bool kForceUnalignment = true;\n\n// force checking div/rem on zero\nstatic inline constexpr const bool kForceDivCheck = true;\n\n// Size of a ValVariant\nstatic inline constexpr const uint32_t kValSize = sizeof(WasmEdge::ValVariant);\n\n// Translate Compiler::OptimizationLevel to llvm::PassBuilder version\n#if LLVM_VERSION_MAJOR >= 13\nstatic inline const char *\ntoLLVMLevel(WasmEdge::CompilerConfigure::OptimizationLevel Level) noexcept {\n using OL = WasmEdge::CompilerConfigure::OptimizationLevel;\n switch (Level) {\n case OL::O0:\n return \"default,function(tailcallelim)\";\n case OL::O1:\n return \"default,function(tailcallelim)\";\n case OL::O2:\n return \"default\";\n case OL::O3:\n return \"default\";\n case OL::Os:\n return \"default\";\n case OL::Oz:\n return \"default\";\n default:\n assumingUnreachable();\n }\n}\n#else\nstatic inline 
std::pair\ntoLLVMLevel(WasmEdge::CompilerConfigure::OptimizationLevel Level) noexcept {\n using OL = WasmEdge::CompilerConfigure::OptimizationLevel;\n switch (Level) {\n case OL::O0:\n return {0, 0};\n case OL::O1:\n return {1, 0};\n case OL::O2:\n return {2, 0};\n case OL::O3:\n return {3, 0};\n case OL::Os:\n return {2, 1};\n case OL::Oz:\n return {2, 2};\n default:\n assumingUnreachable();\n }\n}\n#endif\n\nstatic inline LLVMCodeGenOptLevel toLLVMCodeGenLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel Level) noexcept {\n using OL = WasmEdge::CompilerConfigure::OptimizationLevel;\n switch (Level) {\n case OL::O0:\n return LLVMCodeGenLevelNone;\n case OL::O1:\n return LLVMCodeGenLevelLess;\n case OL::O2:\n return LLVMCodeGenLevelDefault;\n case OL::O3:\n return LLVMCodeGenLevelAggressive;\n case OL::Os:\n return LLVMCodeGenLevelDefault;\n case OL::Oz:\n return LLVMCodeGenLevelDefault;\n default:\n assumingUnreachable();\n }\n}\n} // namespace\n\nstruct LLVM::Compiler::CompileContext {\n LLVM::Context LLContext;\n LLVM::Module &LLModule;\n LLVM::Attribute Cold;\n LLVM::Attribute NoAlias;\n LLVM::Attribute NoInline;\n LLVM::Attribute NoReturn;\n LLVM::Attribute ReadOnly;\n LLVM::Attribute StrictFP;\n LLVM::Attribute UWTable;\n LLVM::Attribute NoStackArgProbe;\n LLVM::Type VoidTy;\n LLVM::Type Int8Ty;\n LLVM::Type Int16Ty;\n LLVM::Type Int32Ty;\n LLVM::Type Int64Ty;\n LLVM::Type Int128Ty;\n LLVM::Type FloatTy;\n LLVM::Type DoubleTy;\n LLVM::Type Int8x16Ty;\n LLVM::Type Int16x8Ty;\n LLVM::Type Int32x4Ty;\n LLVM::Type Floatx4Ty;\n LLVM::Type Int64x2Ty;\n LLVM::Type Doublex2Ty;\n LLVM::Type Int128x1Ty;\n LLVM::Type Int8PtrTy;\n LLVM::Type Int32PtrTy;\n LLVM::Type Int64PtrTy;\n LLVM::Type Int128PtrTy;\n LLVM::Type ExecCtxTy;\n LLVM::Type ExecCtxPtrTy;\n LLVM::Type IntrinsicsTableTy;\n LLVM::Type IntrinsicsTablePtrTy;\n LLVM::Message SubtargetFeatures;\n\n#if defined(__x86_64__)\n#if defined(__XOP__)\n bool SupportXOP = true;\n#else\n bool SupportXOP = false;\n#endif\n\n#if defined(__SSE4_1__)\n...\n", "needle_token_start": 5571, "needle_token_end": 5746, "code_context_ntokens": 16207, "output": ["```cpp\nLLVMOrcLLJITBuilderSetObjectLinkingLayerCreator(\n LLVMOrcLLJITBuilderRef Builder,\n LLVMOrcLLJITBuilderObjectLinkingLayerCreatorFunction F,\n void *Ctx) noexcept {\n using llvm::unwrap;\n using llvm::wrap;\n unwrap(Builder)->setObjectLinkingLayerCreator(\n [=](llvm::orc::ExecutionSession &ES, const llvm::Triple &TT) {\n auto TTStr = TT.str();\n return std::unique_ptr(\n unwrap(F(Ctx, wrap(&ES), TTStr.c_str())));\n });\n}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "toAdvice", "language": "cpp", "path": "lib/host/wasi/linux.h", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function converts a WebAssembly System Interface (WASI) file advice constant to a POSIX file advice constant used to optimize file I/O operations.\n2. **Input**: A WASI file advice constant.\n3. **Output**: A POSIX file advice constant corresponding to the input WASI constant.\n4. **Procedure**: The function takes the WASI advice constant, checks it against known WASI constants using a switch statement, and returns the corresponding POSIX constant. 
If the input constant is not recognized, the function reaches an unreachable state, indicating an error or unexpected input.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/host/wasi/macos.h\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if !WASMEDGE_OS_MACOS\n#error\n#endif\n\n#define _DARWIN_C_SOURCE\n#include \"common/errcode.h\"\n#include \"wasi/api.hpp\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\ninline namespace detail {\n\ninline constexpr __wasi_errno_t fromErrNo(int ErrNo) noexcept {\n switch (ErrNo) {\n case 0:\n return __WASI_ERRNO_SUCCESS;\n case E2BIG:\n return __WASI_ERRNO_2BIG;\n case EACCES:\n return __WASI_ERRNO_ACCES;\n case EADDRINUSE:\n return __WASI_ERRNO_ADDRINUSE;\n case EADDRNOTAVAIL:\n return __WASI_ERRNO_ADDRNOTAVAIL;\n case EAFNOSUPPORT:\n return __WASI_ERRNO_AFNOSUPPORT;\n case EAGAIN:\n return __WASI_ERRNO_AGAIN;\n case EALREADY:\n return __WASI_ERRNO_ALREADY;\n case EBADF:\n return __WASI_ERRNO_BADF;\n case EBADMSG:\n return __WASI_ERRNO_BADMSG;\n case EBUSY:\n return __WASI_ERRNO_BUSY;\n case ECANCELED:\n return __WASI_ERRNO_CANCELED;\n case ECHILD:\n return __WASI_ERRNO_CHILD;\n case ECONNABORTED:\n return __WASI_ERRNO_CONNABORTED;\n case ECONNREFUSED:\n return __WASI_ERRNO_CONNREFUSED;\n case ECONNRESET:\n return __WASI_ERRNO_CONNRESET;\n case EDEADLK:\n return __WASI_ERRNO_DEADLK;\n case EDESTADDRREQ:\n return __WASI_ERRNO_DESTADDRREQ;\n case EDOM:\n return __WASI_ERRNO_DOM;\n case EDQUOT:\n return __WASI_ERRNO_DQUOT;\n case EEXIST:\n return __WASI_ERRNO_EXIST;\n case EFAULT:\n return __WASI_ERRNO_FAULT;\n case EFBIG:\n return __WASI_ERRNO_FBIG;\n case EHOSTUNREACH:\n return __WASI_ERRNO_HOSTUNREACH;\n case EIDRM:\n return __WASI_ERRNO_IDRM;\n case EILSEQ:\n return __WASI_ERRNO_ILSEQ;\n case EINPROGRESS:\n return __WASI_ERRNO_INPROGRESS;\n case EINTR:\n return __WASI_ERRNO_INTR;\n case EINVAL:\n return __WASI_ERRNO_INVAL;\n case EIO:\n return __WASI_ERRNO_IO;\n case EISCONN:\n return __WASI_ERRNO_ISCONN;\n case EISDIR:\n return __WASI_ERRNO_ISDIR;\n case ELOOP:\n return __WASI_ERRNO_LOOP;\n case EMFILE:\n return __WASI_ERRNO_MFILE;\n case EMLINK:\n return __WASI_ERRNO_MLINK;\n case EMSGSIZE:\n return __WASI_ERRNO_MSGSIZE;\n case EMULTIHOP:\n return __WASI_ERRNO_MULTIHOP;\n case ENAMETOOLONG:\n return __WASI_ERRNO_NAMETOOLONG;\n case ENETDOWN:\n return __WASI_ERRNO_NETDOWN;\n case ENETRESET:\n return __WASI_ERRNO_NETRESET;\n case ENETUNREACH:\n return __WASI_ERRNO_NETUNREACH;\n case ENFILE:\n return __WASI_ERRNO_NFILE;\n case ENOBUFS:\n return __WASI_ERRNO_NOBUFS;\n case ENODEV:\n return __WASI_ERRNO_NODEV;\n case ENOENT:\n return __WASI_ERRNO_NOENT;\n case ENOEXEC:\n return __WASI_ERRNO_NOEXEC;\n case ENOLCK:\n return __WASI_ERRNO_NOLCK;\n case ENOLINK:\n return __WASI_ERRNO_NOLINK;\n case ENOMEM:\n return __WASI_ERRNO_NOMEM;\n case ENOMSG:\n return __WASI_ERRNO_NOMSG;\n case ENOPROTOOPT:\n return __WASI_ERRNO_NOPROTOOPT;\n case ENOSPC:\n return __WASI_ERRNO_NOSPC;\n case ENOSYS:\n return __WASI_ERRNO_NOSYS;\n case ENOTCONN:\n 
return __WASI_ERRNO_NOTCONN;\n case ENOTDIR:\n return __WASI_ERRNO_NOTDIR;\n case ENOTEMPTY:\n return __WASI_ERRNO_NOTEMPTY;\n case ENOTRECOVERABLE:\n return __WASI_ERRNO_NOTRECOVERABLE;\n case ENOTSOCK:\n return __WASI_ERRNO_NOTSOCK;\n case ENOTSUP:\n return __WASI_ERRNO_NOTSUP;\n case ENOTTY:\n return __WASI_ERRNO_NOTTY;\n case ENXIO:\n return __WASI_ERRNO_NXIO;\n case EOVERFLOW:\n return __WASI_ERRNO_OVERFLOW;\n case EOWNERDEAD:\n return __WASI_ERRNO_OWNERDEAD;\n case EPERM:\n return __WASI_ERRNO_PERM;\n case EPIPE:\n return __WASI_ERRNO_PIPE;\n case EPROTO:\n return __WASI_ERRNO_PROTO;\n case EPROTONOSUPPORT:\n return __WASI_ERRNO_PROTONOSUPPORT;\n case EPROTOTYPE:\n return __WASI_ERRNO_PROTOTYPE;\n case ERANGE:\n return __WASI_ERRNO_RANGE;\n case EROFS:\n return __WASI_ERRNO_ROFS;\n case ESPIPE:\n return __WASI_ERRNO_SPIPE;\n case ESRCH:\n return __WASI_ERRNO_SRCH;\n case ESTALE:\n return __WASI_ERRNO_STALE;\n case ETIMEDOUT:\n return __WASI_ERRNO_TIMEDOUT;\n case ETXTBSY:\n return __WASI_ERRNO_TXTBSY;\n case EXDEV:\n return __WASI_ERRNO_XDEV;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_errno_t fromEAIErrNo(int ErrNo) noexcept {\n switch (ErrNo) {\n case EAI_ADDRFAMILY:\n return __WASI_ERRNO_AIADDRFAMILY;\n case EAI_AGAIN:\n return __WASI_ERRNO_AIAGAIN;\n case EAI_BADFLAGS:\n return __WASI_ERRNO_AIBADFLAG;\n case EAI_FAIL:\n return __WASI_ERRNO_AIFAIL;\n case EAI_FAMILY:\n return __WASI_ERRNO_AIFAMILY;\n case EAI_MEMORY:\n return __WASI_ERRNO_AIMEMORY;\n case EAI_NODATA:\n return __WASI_ERRNO_AINODATA;\n case EAI_NONAME:\n return __WASI_ERRNO_AINONAME;\n case EAI_SERVICE:\n return __WASI_ERRNO_AISERVICE;\n case EAI_SOCKTYPE:\n return __WASI_ERRNO_AISOCKTYPE;\n case EAI_SYSTEM:\n return __WASI_ERRNO_AISYSTEM;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr clockid_t toClockId(__wasi_clockid_t Clock) noexcept {\n switch (Clock) {\n case __WASI_CLOCKID_REALTIME:\n return CLOCK_REALTIME;\n case __WASI_CLOCKID_MONOTONIC:\n return CLOCK_MONOTONIC;\n case __WASI_CLOCKID_PROCESS_CPUTIME_ID:\n return CLOCK_PROCESS_CPUTIME_ID;\n case __WASI_CLOCKID_THREAD_CPUTIME_ID:\n return CLOCK_THREAD_CPUTIME_ID;\n default:\n assumingUnreachable();\n }\n}\n\n...\n// Path: lib/host/wasi/environ-macos.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if WASMEDGE_OS_MACOS\n\n#include \"common/errcode.h\"\n#include \"host/wasi/environ.h\"\n#include \"macos.h\"\n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\n\nWasiExpect Environ::procRaise(__wasi_signal_t Signal) const noexcept {\n int SysSignal;\n switch (Signal) {\n case __WASI_SIGNAL_NONE:\n SysSignal = 0;\n break;\n case __WASI_SIGNAL_HUP:\n SysSignal = SIGHUP;\n break;\n case __WASI_SIGNAL_INT:\n SysSignal = SIGINT;\n break;\n case __WASI_SIGNAL_QUIT:\n SysSignal = SIGQUIT;\n break;\n case __WASI_SIGNAL_ILL:\n SysSignal = SIGILL;\n break;\n case __WASI_SIGNAL_TRAP:\n SysSignal = SIGTRAP;\n break;\n case __WASI_SIGNAL_ABRT:\n SysSignal = SIGABRT;\n break;\n case __WASI_SIGNAL_BUS:\n SysSignal = SIGBUS;\n break;\n case __WASI_SIGNAL_FPE:\n SysSignal = SIGFPE;\n break;\n case __WASI_SIGNAL_KILL:\n SysSignal = SIGKILL;\n break;\n case __WASI_SIGNAL_USR1:\n SysSignal = SIGUSR1;\n break;\n case __WASI_SIGNAL_SEGV:\n SysSignal = SIGSEGV;\n break;\n case __WASI_SIGNAL_USR2:\n SysSignal = SIGUSR2;\n break;\n case __WASI_SIGNAL_PIPE:\n SysSignal = SIGPIPE;\n break;\n case __WASI_SIGNAL_ALRM:\n SysSignal = SIGALRM;\n break;\n case 
__WASI_SIGNAL_TERM:\n SysSignal = SIGTERM;\n break;\n case __WASI_SIGNAL_CHLD:\n SysSignal = SIGCHLD;\n break;\n case __WASI_SIGNAL_CONT:\n SysSignal = SIGCONT;\n break;\n case __WASI_SIGNAL_STOP:\n SysSignal = SIGSTOP;\n break;\n case __WASI_SIGNAL_TSTP:\n SysSignal = SIGTSTP;\n break;\n case __WASI_SIGNAL_TTIN:\n SysSignal = SIGTTIN;\n break;\n case __WASI_SIGNAL_TTOU:\n SysSignal = SIGTTOU;\n break;\n case __WASI_SIGNAL_URG:\n SysSignal = SIGURG;\n break;\n case __WASI_SIGNAL_XCPU:\n SysSignal = SIGXCPU;\n break;\n case __WASI_SIGNAL_XFSZ:\n SysSignal = SIGXFSZ;\n break;\n case __WASI_SIGNAL_VTALRM:\n SysSignal = SIGVTALRM;\n break;\n case __WASI_SIGNAL_PROF:\n SysSignal = SIGPROF;\n break;\n case __WASI_SIGNAL_WINCH:\n SysSignal = SIGWINCH;\n break;\n case __WASI_SIGNAL_SYS:\n SysSignal = SIGSYS;\n break;\n case __WASI_SIGNAL_POLL:\n case __WASI_SIGNAL_PWR:\n default:\n return WasiUnexpect(__WASI_ERRNO_NOTSUP);\n }\n if (auto Res = std::raise(SysSignal); Res != 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n return {};\n}\n\nWasiExpect Environ::schedYield() const noexcept {\n ::sched_yield();\n return {};\n}\n\n} // namespace WASI\n} // namespace Host\n} // namespace WasmEdge\n\n#endif\n\n// Path: lib/host/wasi/linux.h\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if !WASMEDGE_OS_LINUX\n#error\n#endif\n\n// Uncomment these flag to test CentOS 6\n// #undef __GLIBC_MINOR__\n// #define __GLIBC_MINOR__ 5\n\n#include \"common/errcode.h\"\n#include \"wasi/api.hpp\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n// socket include\n#include \n#include \n#include \n#include \n\n#include \n\n#if defined(__GLIBC_PREREQ)\n#if defined(_LIBCPP_GLIBC_PREREQ)\n#undef _LIBCPP_GLIBC_PREREQ\n#endif\n#define _LIBCPP_GLIBC_PREREQ(a, b) 0\n#else\n#if defined(_LIBCPP_GLIBC_PREREQ)\n#define __GLIBC_PREREQ(a, b) _LIBCPP_GLIBC_PREREQ(a, b)\n#else\n#define __GLIBC_PREREQ(a, b) 1\n#endif\n#endif\n\n#if __GLIBC_PREREQ(2, 8)\n#include \n#endif\n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\ninline namespace detail {\n\ninline constexpr __wasi_errno_t fromErrNo(int ErrNo) noexcept {\n switch (ErrNo) {\n case 0:\n return __WASI_ERRNO_SUCCESS;\n case E2BIG:\n return __WASI_ERRNO_2BIG;\n case EACCES:\n return __WASI_ERRNO_ACCES;\n case EADDRINUSE:\n return __WASI_ERRNO_ADDRINUSE;\n case EADDRNOTAVAIL:\n return __WASI_ERRNO_ADDRNOTAVAIL;\n case EAFNOSUPPORT:\n return __WASI_ERRNO_AFNOSUPPORT;\n case EAGAIN:\n return __WASI_ERRNO_AGAIN;\n case EALREADY:\n return __WASI_ERRNO_ALREADY;\n case EBADF:\n return __WASI_ERRNO_BADF;\n case EBADMSG:\n return __WASI_ERRNO_BADMSG;\n case EBUSY:\n return __WASI_ERRNO_BUSY;\n case ECANCELED:\n return __WASI_ERRNO_CANCELED;\n case ECHILD:\n return __WASI_ERRNO_CHILD;\n case ECONNABORTED:\n return __WASI_ERRNO_CONNABORTED;\n case ECONNREFUSED:\n return __WASI_ERRNO_CONNREFUSED;\n case ECONNRESET:\n return __WASI_ERRNO_CONNRESET;\n case EDEADLK:\n return __WASI_ERRNO_DEADLK;\n case EDESTADDRREQ:\n return __WASI_ERRNO_DESTADDRREQ;\n case EDOM:\n return __WASI_ERRNO_DOM;\n case EDQUOT:\n return __WASI_ERRNO_DQUOT;\n case EEXIST:\n return __WASI_ERRNO_EXIST;\n case EFAULT:\n return __WASI_ERRNO_FAULT;\n case EFBIG:\n return __WASI_ERRNO_FBIG;\n case EHOSTUNREACH:\n return __WASI_ERRNO_HOSTUNREACH;\n case EIDRM:\n return __WASI_ERRNO_IDRM;\n case EILSEQ:\n 
return __WASI_ERRNO_ILSEQ;\n case EINPROGRESS:\n return __WASI_ERRNO_INPROGRESS;\n case EINTR:\n return __WASI_ERRNO_INTR;\n case EINVAL:\n return __WASI_ERRNO_INVAL;\n case EIO:\n return __WASI_ERRNO_IO;\n case EISCONN:\n return __WASI_ERRNO_ISCONN;\n case EISDIR:\n return __WASI_ERRNO_ISDIR;\n case ELOOP:\n return __WASI_ERRNO_LOOP;\n case EMFILE:\n return __WASI_ERRNO_MFILE;\n case EMLINK:\n return __WASI_ERRNO_MLINK;\n case EMSGSIZE:\n return __WASI_ERRNO_MSGSIZE;\n case EMULTIHOP:\n return __WASI_ERRNO_MULTIHOP;\n case ENAMETOOLONG:\n return __WASI_ERRNO_NAMETOOLONG;\n case ENETDOWN:\n return __WASI_ERRNO_NETDOWN;\n case ENETRESET:\n return __WASI_ERRNO_NETRESET;\n case ENETUNREACH:\n return __WASI_ERRNO_NETUNREACH;\n case ENFILE:\n return __WASI_ERRNO_NFILE;\n case ENOBUFS:\n return __WASI_ERRNO_NOBUFS;\n case ENODEV:\n return __WASI_ERRNO_NODEV;\n case ENOENT:\n return __WASI_ERRNO_NOENT;\n case ENOEXEC:\n return __WASI_ERRNO_NOEXEC;\n case ENOLCK:\n return __WASI_ERRNO_NOLCK;\n case ENOLINK:\n return __WASI_ERRNO_NOLINK;\n case ENOMEM:\n return __WASI_ERRNO_NOMEM;\n case ENOMSG:\n return __WASI_ERRNO_NOMSG;\n case ENOPROTOOPT:\n return __WASI_ERRNO_NOPROTOOPT;\n case ENOSPC:\n return __WASI_ERRNO_NOSPC;\n case ENOSYS:\n return __WASI_ERRNO_NOSYS;\n case ENOTCONN:\n return __WASI_ERRNO_NOTCONN;\n case ENOTDIR:\n return __WASI_ERRNO_NOTDIR;\n case ENOTEMPTY:\n return __WASI_ERRNO_NOTEMPTY;\n case ENOTRECOVERABLE:\n return __WASI_ERRNO_NOTRECOVERABLE;\n case ENOTSOCK:\n return __WASI_ERRNO_NOTSOCK;\n case ENOTSUP:\n return __WASI_ERRNO_NOTSUP;\n case ENOTTY:\n return __WASI_ERRNO_NOTTY;\n case ENXIO:\n return __WASI_ERRNO_NXIO;\n case EOVERFLOW:\n return __WASI_ERRNO_OVERFLOW;\n case EOWNERDEAD:\n return __WASI_ERRNO_OWNERDEAD;\n case EPERM:\n return __WASI_ERRNO_PERM;\n case EPIPE:\n return __WASI_ERRNO_PIPE;\n case EPROTO:\n return __WASI_ERRNO_PROTO;\n case EPROTONOSUPPORT:\n return __WASI_ERRNO_PROTONOSUPPORT;\n case EPROTOTYPE:\n return __WASI_ERRNO_PROTOTYPE;\n case ERANGE:\n return __WASI_ERRNO_RANGE;\n case EROFS:\n return __WASI_ERRNO_ROFS;\n case ESPIPE:\n return __WASI_ERRNO_SPIPE;\n case ESRCH:\n return __WASI_ERRNO_SRCH;\n case ESTALE:\n return __WASI_ERRNO_STALE;\n case ETIMEDOUT:\n return __WASI_ERRNO_TIMEDOUT;\n case ETXTBSY:\n return __WASI_ERRNO_TXTBSY;\n case EXDEV:\n return __WASI_ERRNO_XDEV;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_errno_t fromEAIErrNo(int ErrNo) noexcept {\n switch (ErrNo) {\n case EAI_ADDRFAMILY:\n return __WASI_ERRNO_AIADDRFAMILY;\n case EAI_AGAIN:\n return __WASI_ERRNO_AIAGAIN;\n case EAI_BADFLAGS:\n return __WASI_ERRNO_AIBADFLAG;\n case EAI_FAIL:\n return __WASI_ERRNO_AIFAIL;\n case EAI_FAMILY:\n return __WASI_ERRNO_AIFAMILY;\n case EAI_MEMORY:\n return __WASI_ERRNO_AIMEMORY;\n case EAI_NODATA:\n return __WASI_ERRNO_AINODATA;\n case EAI_NONAME:\n return __WASI_ERRNO_AINONAME;\n case EAI_SERVICE:\n return __WASI_ERRNO_AISERVICE;\n case EAI_SOCKTYPE:\n return __WASI_ERRNO_AISOCKTYPE;\n case EAI_SYSTEM:\n return __WASI_ERRNO_AISYSTEM;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr clockid_t toClockId(__wasi_clockid_t Clock) noexcept {\n switch (Clock) {\n case __WASI_CLOCKID_REALTIME:\n return CLOCK_REALTIME;\n case __WASI_CLOCKID_MONOTONIC:\n return CLOCK_MONOTONIC;\n case __WASI_CLOCKID_PROCESS_CPUTIME_ID:\n return CLOCK_PROCESS_CPUTIME_ID;\n case __WASI_CLOCKID_THREAD_CPUTIME_ID:\n return CLOCK_THREAD_CPUTIME_ID;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr timespec 
toTimespec(__wasi_timestamp_t Timestamp) noexcept {\n using namespace std::chrono;\n const auto Total = duration(Timestamp);\n const auto Second = duration_cast(Total);\n const auto Nano = Total - Second;\n timespec Result{};\n Result.tv_sec = Second.count();\n Result.tv_nsec = Nano.count();\n return Result;\n}\n\ninline constexpr __wasi_timestamp_t\nfromTimespec(const timespec &Time) noexcept {\n using namespace std::chrono;\n const auto Result = seconds(Time.tv_sec) + nanoseconds(Time.tv_nsec);\n return Result.count();\n}\n\n#if !__GLIBC_PREREQ(2, 6)\ninline constexpr timeval toTimeval(__wasi_timestamp_t Timestamp) noexcept {\n using namespace std::chrono;\n const auto Total = duration_cast(nanoseconds(Timestamp));\n const auto Second = duration_cast(Total);\n const auto Micro = Total - Second;\n timeval Result{};\n Result.tv_sec = Second.count();\n Result.tv_usec = Micro.count();\n return Result;\n}\ninline constexpr timeval toTimeval(timespec Timespec) noexcept {\n using namespace std::chrono;\n timeval Result{};\n Result.tv_sec = Timespec.tv_sec;\n Result.tv_usec =\n duration_cast(nanoseconds(Timespec.tv_nsec)).count();\n return Result;\n}\n#endif\n\n\ninline constexpr int toAdvice(__wasi_advice_t Advice) noexcept {\n switch (Advice) {\n case __WASI_ADVICE_NORMAL:\n return POSIX_FADV_NORMAL;\n case __WASI_ADVICE_SEQUENTIAL:\n return POSIX_FADV_SEQUENTIAL;\n case __WASI_ADVICE_RANDOM:\n return POSIX_FADV_RANDOM;\n case __WASI_ADVICE_WILLNEED:\n return POSIX_FADV_WILLNEED;\n case __WASI_ADVICE_DONTNEED:\n return POSIX_FADV_DONTNEED;\n case __WASI_ADVICE_NOREUSE:\n return POSIX_FADV_NOREUSE;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_filetype_t fromFileType(mode_t Mode) noexcept {\n switch (Mode & S_IFMT) {\n case S_IFBLK:\n return __WASI_FILETYPE_BLOCK_DEVICE;\n case S_IFCHR:\n return __WASI_FILETYPE_CHARACTER_DEVICE;\n case S_IFDIR:\n return __WASI_FILETYPE_DIRECTORY;\n case S_IFREG:\n return __WASI_FILETYPE_REGULAR_FILE;\n case S_IFSOCK:\n return __WASI_FILETYPE_SOCKET_STREAM;\n case S_IFLNK:\n return __WASI_FILETYPE_SYMBOLIC_LINK;\n case S_IFIFO:\n default:\n return __WASI_FILETYPE_UNKNOWN;\n }\n}\n\ninline constexpr __wasi_filetype_t fromFileType(uint8_t Type) noexcept {\n switch (Type) {\n case DT_BLK:\n return __WASI_FILETYPE_BLOCK_DEVICE;\n case DT_CHR:\n return __WASI_FILETYPE_CHARACTER_DEVICE;\n case DT_DIR:\n return __WASI_FILETYPE_DIRECTORY;\n case DT_LNK:\n return __WASI_FILETYPE_SYMBOLIC_LINK;\n case DT_REG:\n return __WASI_FILETYPE_REGULAR_FILE;\n case DT_SOCK:\n return __WASI_FILETYPE_SOCKET_STREAM;\n case DT_FIFO:\n case DT_UNKNOWN:\n default:\n return __WASI_FILETYPE_UNKNOWN;\n }\n}\n\ninline constexpr int toWhence(__wasi_whence_t Whence) noexcept {\n switch (Whence) {\n case __WASI_WHENCE_CUR:\n return SEEK_CUR;\n case __WASI_WHENCE_END:\n return SEEK_END;\n case __WASI_WHENCE_SET:\n return SEEK_SET;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr int toSockOptLevel(__wasi_sock_opt_level_t Level) noexcept {\n switch (Level) {\n case __WASI_SOCK_OPT_LEVEL_SOL_SOCKET:\n return SOL_SOCKET;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr int toSockOptSoName(__wasi_sock_opt_so_t SoName) noexcept {\n switch (SoName) {\n case __WASI_SOCK_OPT_SO_REUSEADDR:\n return SO_REUSEADDR;\n case __WASI_SOCK_OPT_SO_TYPE:\n return SO_TYPE;\n case __WASI_SOCK_OPT_SO_ERROR:\n return SO_ERROR;\n case __WASI_SOCK_OPT_SO_DONTROUTE:\n return SO_DONTROUTE;\n case __WASI_SOCK_OPT_SO_BROADCAST:\n return SO_BROADCAST;\n case 
__WASI_SOCK_OPT_SO_SNDBUF:\n return SO_SNDBUF;\n case __WASI_SOCK_OPT_SO_RCVBUF:\n return SO_RCVBUF;\n case __WASI_SOCK_OPT_SO_KEEPALIVE:\n return SO_KEEPALIVE;\n case __WASI_SOCK_OPT_SO_OOBINLINE:\n return SO_OOBINLINE;\n case __WASI_SOCK_OPT_SO_LINGER:\n return SO_LINGER;\n case __WASI_SOCK_OPT_SO_RCVLOWAT:\n return SO_RCVLOWAT;\n case __WASI_SOCK_OPT_SO_RCVTIMEO:\n return SO_RCVTIMEO;\n case __WASI_SOCK_OPT_SO_SNDTIMEO:\n return SO_SNDTIMEO;\n case __WASI_SOCK_OPT_SO_ACCEPTCONN:\n return SO_ACCEPTCONN;\n case __WASI_SOCK_OPT_SO_BINDTODEVICE:\n return SO_BINDTODEVICE;\n\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_aiflags_t fromAIFlags(int AIFlags) noexcept {\n __wasi_aiflags_t Result = static_cast<__wasi_aiflags_t>(0);\n\n if (AIFlags & AI_PASSIVE) {\n Result |= __WASI_AIFLAGS_AI_PASSIVE;\n }\n if (AIFlags & AI_CANONNAME) {\n Result |= __WASI_AIFLAGS_AI_CANONNAME;\n }\n if (AIFlags & AI_NUMERICHOST) {\n Result |= __WASI_AIFLAGS_AI_NUMERICHOST;\n }\n if (AIFlags & AI_NUMERICSERV) {\n Result |= __WASI_AIFLAGS_AI_NUMERICSERV;\n }\n if (AIFlags & AI_V4MAPPED) {\n Result |= __WASI_AIFLAGS_AI_V4MAPPED;\n }\n if (AIFlags & AI_ALL) {\n Result |= __WASI_AIFLAGS_AI_ALL;\n }\n if (AIFlags & AI_ADDRCONFIG) {\n Result |= __WASI_AIFLAGS_AI_ADDRCONFIG;\n }\n\n return Result;\n}\n\ninline constexpr int toAIFlags(__wasi_aiflags_t AIFlags) noexcept {\n int Result = 0;\n\n if (AIFlags & __WASI_AIFLAGS_AI_PASSIVE) {\n Result |= AI_PASSIVE;\n }\n if (AIFlags & __WASI_AIFLAGS_AI_CANONNAME) {\n Result |= AI_CANONNAME;\n }\n if (AIFlags & __WASI_AIFLAGS_AI_NUMERICHOST) {\n Result |= AI_NUMERICHOST;\n }\n if (AIFlags & __WASI_AIFLAGS_AI_NUMERICSERV) {\n Result |= AI_NUMERICSERV;\n }\n if (AIFlags & __WASI_AIFLAGS_AI_V4MAPPED) {\n Result |= AI_V4MAPPED;\n }\n if (AIFlags & __WASI_AIFLAGS_AI_ALL) {\n Result |= AI_ALL;\n }\n if (AIFlags & __WASI_AIFLAGS_AI_ADDRCONFIG) {\n Result |= AI_ADDRCONFIG;\n }\n\n return Result;\n}\n\ninline constexpr __wasi_sock_type_t fromSockType(int SockType) noexcept {\n switch (SockType) {\n case 0:\n return __WASI_SOCK_TYPE_SOCK_ANY;\n case SOCK_DGRAM:\n return __WASI_SOCK_TYPE_SOCK_DGRAM;\n case SOCK_STREAM:\n return __WASI_SOCK_TYPE_SOCK_STREAM;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr int toSockType(__wasi_sock_type_t SockType) noexcept {\n switch (SockType) {\n case __WASI_SOCK_TYPE_SOCK_ANY:\n return 0;\n case __WASI_SOCK_TYPE_SOCK_DGRAM:\n return SOCK_DGRAM;\n case __WASI_SOCK_TYPE_SOCK_STREAM:\n return SOCK_STREAM;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_protocol_t fromProtocol(int Protocol) noexcept {\n switch (Protocol) {\n case IPPROTO_IP:\n return __WASI_PROTOCOL_IPPROTO_IP;\n case IPPROTO_TCP:\n return __WASI_PROTOCOL_IPPROTO_TCP;\n case IPPROTO_UDP:\n return __WASI_PROTOCOL_IPPROTO_UDP;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr int toProtocol(__wasi_protocol_t Protocol) noexcept {\n switch (Protocol) {\n case __WASI_PROTOCOL_IPPROTO_IP:\n return IPPROTO_IP;\n case __WASI_PROTOCOL_IPPROTO_TCP:\n return IPPROTO_TCP;\n case __WASI_PROTOCOL_IPPROTO_UDP:\n return IPPROTO_UDP;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_address_family_t\nfromAddressFamily(int AddressFamily) noexcept {\n switch (AddressFamily) {\n case PF_UNSPEC:\n return __WASI_ADDRESS_FAMILY_UNSPEC;\n case PF_INET:\n return __WASI_ADDRESS_FAMILY_INET4;\n case PF_INET6:\n return __WASI_ADDRESS_FAMILY_INET6;\n case PF_UNIX:\n return __WASI_ADDRESS_FAMILY_AF_UNIX;\n default:\n 
assumingUnreachable();\n }\n}\n\ninline constexpr int\ntoAddressFamily(__wasi_address_family_t AddressFamily) noexcept {\n switch (AddressFamily) {\n case __WASI_ADDRESS_FAMILY_UNSPEC:\n return PF_UNSPEC;\n case __WASI_ADDRESS_FAMILY_INET4:\n return PF_INET;\n case __WASI_ADDRESS_FAMILY_INET6:\n return PF_INET6;\n case __WASI_ADDRESS_FAMILY_AF_UNIX:\n return PF_UNIX;\n default:\n assumingUnreachable();\n }\n}\n\n} // namespace detail\n} // namespace WASI\n} // namespace Host\n} // namespace WasmEdge\n\n// Path: lib/host/wasi/inode-linux.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if WASMEDGE_OS_LINUX\n\n#include \"common/errcode.h\"\n#include \"common/variant.h\"\n#include \"host/wasi/environ.h\"\n#include \"host/wasi/inode.h\"\n#include \"host/wasi/vfs.h\"\n#include \"linux.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\n\nnamespace {\n\ninline constexpr bool isSpecialFd(int Fd) noexcept {\n switch (Fd) {\n case STDIN_FILENO:\n case STDOUT_FILENO:\n case STDERR_FILENO:\n return true;\n default:\n return false;\n }\n}\n\ninline constexpr __wasi_size_t\ncalculateAddrinfoLinkedListSize(struct addrinfo *const Addrinfo) {\n __wasi_size_t Length = 0;\n for (struct addrinfo *TmpPointer = Addrinfo; TmpPointer != nullptr;\n TmpPointer = TmpPointer->ai_next) {\n Length++;\n }\n return Length;\n};\n\nconstexpr int openFlags(__wasi_oflags_t OpenFlags, __wasi_fdflags_t FdFlags,\n VFS::Flags VFSFlags) noexcept {\n int Flags = O_NOFOLLOW;\n#ifdef O_CLOEXEC\n Flags |= O_CLOEXEC;\n#endif\n\n if (VFSFlags & VFS::Read) {\n if (VFSFlags & VFS::Write) {\n Flags |= O_RDWR;\n } else {\n Flags |= O_RDONLY;\n }\n } else if (VFSFlags & VFS::Write) {\n Flags |= O_WRONLY;\n } else {\n#ifdef O_PATH\n if (OpenFlags == __WASI_OFLAGS_DIRECTORY) {\n Flags |= O_PATH;\n } else {\n Flags |= O_RDONLY;\n }\n#else\n Flags |= O_RDONLY;\n#endif\n }\n\n if (OpenFlags & __WASI_OFLAGS_CREAT) {\n Flags |= O_CREAT;\n }\n if (OpenFlags & __WASI_OFLAGS_DIRECTORY) {\n Flags |= O_DIRECTORY;\n }\n if (OpenFlags & __WASI_OFLAGS_EXCL) {\n Flags |= O_EXCL;\n }\n if (OpenFlags & __WASI_OFLAGS_TRUNC) {\n Flags |= O_TRUNC;\n }\n\n // Convert file descriptor flags.\n if ((FdFlags & __WASI_FDFLAGS_APPEND) != 0) {\n Flags |= O_APPEND;\n }\n if ((FdFlags & __WASI_FDFLAGS_DSYNC) != 0) {\n#ifdef O_DSYNC\n Flags |= O_DSYNC;\n#else\n Flags |= O_SYNC;\n#endif\n }\n if ((FdFlags & __WASI_FDFLAGS_NONBLOCK) != 0) {\n Flags |= O_NONBLOCK;\n }\n if ((FdFlags & __WASI_FDFLAGS_RSYNC) != 0) {\n#ifdef O_RSYNC\n Flags |= O_RSYNC;\n#else\n Flags |= O_SYNC;\n#endif\n }\n if ((FdFlags & __WASI_FDFLAGS_SYNC) != 0) {\n Flags |= O_SYNC;\n }\n\n return Flags;\n}\n\nstd::pair>\ncreateNullTerminatedString(std::string_view View) noexcept {\n const char *CStr = nullptr;\n std::unique_ptr Buffer;\n if (!View.empty()) {\n if (const auto Pos = View.find_first_of('\\0');\n Pos != std::string_view::npos) {\n CStr = View.data();\n } else {\n Buffer = std::make_unique(View.size() + 1);\n std::copy(View.begin(), View.end(), Buffer.get());\n CStr = Buffer.get();\n }\n }\n return {CStr, std::move(Buffer)};\n}\n\n} // namespace\n\nvoid FdHolder::reset() noexcept {\n if (likely(ok())) {\n if (likely(!isSpecialFd(Fd))) {\n ::close(Fd);\n }\n Fd = -1;\n }\n}\n\nvoid TimerHolder::reset() noexcept {\n if (likely(Id.has_value())) {\n timer_delete(*Id);\n Id.reset();\n }\n}\n\nvoid DirHolder::reset() 
noexcept {\n if (likely(Dir != nullptr)) {\n closedir(Dir);\n Dir = nullptr;\n Cookie = 0;\n }\n}\n\nINode INode::stdIn() noexcept { return INode(STDIN_FILENO); }\n\nINode INode::stdOut() noexcept { return INode(STDOUT_FILENO); }\n\nINode INode::stdErr() noexcept { return INode(STDERR_FILENO); }\n\nWasiExpect INode::open(std::string Path, __wasi_oflags_t OpenFlags,\n __wasi_fdflags_t FdFlags,\n VFS::Flags VFSFlags) noexcept {\n const int Flags = openFlags(OpenFlags, FdFlags, VFSFlags);\n\n if (auto NewFd = ::open(Path.c_str(), Flags, 0644); unlikely(NewFd < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n INode New(NewFd);\n#ifndef O_CLOEXEC\n if (auto Res = ::fcntl(New.Fd, F_SETFD, FD_CLOEXEC); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#endif\n return New;\n }\n}\n\nWasiExpect INode::fdAdvise(__wasi_filesize_t Offset,\n __wasi_filesize_t Len,\n __wasi_advice_t Advice) const noexcept {\n if (auto Res = ::posix_fadvise(Fd, Offset, Len, toAdvice(Advice));\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdAllocate(__wasi_filesize_t Offset,\n __wasi_filesize_t Len) const noexcept {\n if (auto Res = ::posix_fallocate(Fd, Offset, Len); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdDatasync() const noexcept {\n if (auto Res = ::fdatasync(Fd); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdFdstatGet(__wasi_fdstat_t &FdStat) const noexcept {\n if (auto Res = updateStat(); unlikely(!Res)) {\n return WasiUnexpect(Res);\n }\n\n if (int FdFlags = ::fcntl(Fd, F_GETFL); unlikely(FdFlags < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n FdStat.fs_filetype = unsafeFiletype();\n\n FdStat.fs_flags = static_cast<__wasi_fdflags_t>(0);\n if (FdFlags & O_APPEND) {\n FdStat.fs_flags |= __WASI_FDFLAGS_APPEND;\n }\n if (FdFlags & O_DSYNC) {\n FdStat.fs_flags |= __WASI_FDFLAGS_DSYNC;\n }\n if (FdFlags & O_NONBLOCK) {\n FdStat.fs_flags |= __WASI_FDFLAGS_NONBLOCK;\n }\n if (FdFlags & O_SYNC) {\n FdStat.fs_flags |= __WASI_FDFLAGS_RSYNC | __WASI_FDFLAGS_SYNC;\n }\n }\n\n return {};\n}\n\nWasiExpect\nINode::fdFdstatSetFlags(__wasi_fdflags_t FdFlags) const noexcept {\n int SysFlag = 0;\n if (FdFlags & __WASI_FDFLAGS_NONBLOCK) {\n SysFlag |= O_NONBLOCK;\n }\n if (FdFlags & __WASI_FDFLAGS_APPEND) {\n SysFlag |= O_APPEND;\n }\n if (FdFlags & __WASI_FDFLAGS_DSYNC) {\n SysFlag |= O_DSYNC;\n }\n if (FdFlags & __WASI_FDFLAGS_RSYNC) {\n SysFlag |= O_RSYNC;\n }\n if (FdFlags & __WASI_FDFLAGS_SYNC) {\n SysFlag |= O_SYNC;\n }\n\n if (auto Res = ::fcntl(Fd, F_SETFL, SysFlag); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect\nINode::fdFilestatGet(__wasi_filestat_t &Filestat) const noexcept {\n if (auto Res = updateStat(); unlikely(!Res)) {\n return WasiUnexpect(Res);\n }\n\n // Zeroing out these values to prevent leaking information about the host\n // environment from special fd such as stdin, stdout and stderr.\n Filestat.dev = isSpecialFd(Fd) ? 0 : Stat->st_dev;\n Filestat.ino = isSpecialFd(Fd) ? 0 : Stat->st_ino;\n Filestat.filetype = unsafeFiletype();\n Filestat.nlink = isSpecialFd(Fd) ? 0 : Stat->st_nlink;\n Filestat.size = isSpecialFd(Fd) ? 0 : Stat->st_size;\n Filestat.atim = isSpecialFd(Fd) ? 0 : fromTimespec(Stat->st_atim);\n Filestat.mtim = isSpecialFd(Fd) ? 0 : fromTimespec(Stat->st_mtim);\n Filestat.ctim = isSpecialFd(Fd) ? 
0 : fromTimespec(Stat->st_ctim);\n\n return {};\n}\n\nWasiExpect\nINode::fdFilestatSetSize(__wasi_filesize_t Size) const noexcept {\n if (auto Res = ::ftruncate(Fd, Size); unlikely(Res == -1)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect\nINode::fdFilestatSetTimes(__wasi_timestamp_t ATim, __wasi_timestamp_t MTim,\n __wasi_fstflags_t FstFlags) const noexcept {\n#if __GLIBC_PREREQ(2, 6) || __BIONIC__\n timespec SysTimespec[2];\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n SysTimespec[0] = toTimespec(ATim);\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n SysTimespec[0].tv_nsec = UTIME_NOW;\n } else {\n SysTimespec[0].tv_nsec = UTIME_OMIT;\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n SysTimespec[1] = toTimespec(MTim);\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n SysTimespec[1].tv_nsec = UTIME_NOW;\n } else {\n SysTimespec[1].tv_nsec = UTIME_OMIT;\n }\n\n if (auto Res = ::futimens(Fd, SysTimespec); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#else\n bool NeedNow = false;\n bool NeedFile = false;\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n // Nothing to do.\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n NeedNow = true;\n } else {\n NeedFile = true;\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n // Nothing to do.\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n NeedNow = true;\n } else {\n NeedFile = true;\n }\n\n if (NeedFile) {\n if (auto Res = updateStat(); unlikely(!Res)) {\n return WasiUnexpect(Res);\n }\n }\n\n timespec Now;\n if (NeedNow) {\n if (auto Res = ::clock_gettime(CLOCK_REALTIME, &Now); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n }\n\n timeval SysTimeval[2];\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n SysTimeval[0] = toTimeval(ATim);\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n SysTimeval[0] = toTimeval(Now);\n } else {\n SysTimeval[0] = toTimeval(Stat->st_atim);\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n SysTimeval[1] = toTimeval(MTim);\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n SysTimeval[1] = toTimeval(Now);\n } else {\n SysTimeval[1] = toTimeval(Stat->st_mtim);\n }\n\n if (auto Res = ::futimes(Fd, SysTimeval); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::fdPread(Span> IOVs,\n __wasi_filesize_t Offset,\n __wasi_size_t &NRead) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n SysIOVs[SysIOVsSize].iov_base = IOV.data();\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n#if __GLIBC_PREREQ(2, 10)\n // Store read bytes length.\n if (auto Res = ::preadv(Fd, SysIOVs, SysIOVsSize, Offset);\n unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NRead = Res;\n }\n#else\n const auto OldOffset = ::lseek(Fd, 0, SEEK_CUR);\n if (OldOffset < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (::lseek(Fd, Offset, SEEK_SET) < 0 ||\n ::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (auto Res = ::readv(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n ::lseek(Fd, OldOffset, SEEK_SET);\n return WasiUnexpect(fromErrNo(errno));\n } else {\n if (::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n NRead = Res;\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::fdPwrite(Span> IOVs,\n __wasi_filesize_t Offset,\n __wasi_size_t &NWritten) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n 
SysIOVs[SysIOVsSize].iov_base = const_cast(IOV.data());\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n#if __GLIBC_PREREQ(2, 10)\n if (auto Res = ::pwritev(Fd, SysIOVs, SysIOVsSize, Offset);\n unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NWritten = Res;\n }\n#else\n const auto OldOffset = ::lseek(Fd, 0, SEEK_CUR);\n if (OldOffset < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (::lseek(Fd, Offset, SEEK_SET) < 0 ||\n ::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (auto Res = ::writev(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n ::lseek(Fd, OldOffset, SEEK_SET);\n return WasiUnexpect(fromErrNo(errno));\n } else {\n if (::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n NWritten = Res;\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::fdRead(Span> IOVs,\n __wasi_size_t &NRead) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n SysIOVs[SysIOVsSize].iov_base = IOV.data();\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n if (auto Res = ::readv(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NRead = Res;\n }\n\n return {};\n}\n\n// Due to the unfortunate design of wasi::fd_readdir, It's nearly impossible to\n// provide a correct implementation. The below implementation is just a\n// workaround for most usages and may not be correct in some edge cases. The\n// readdir entry API is going to be updated to use a stream type, so we don't\n// have to deal with it right now.\nWasiExpect INode::fdReaddir(Span Buffer,\n __wasi_dircookie_t Cookie,\n __wasi_size_t &Size) noexcept {\n if (unlikely(!Dir.ok())) {\n if (FdHolder NewFd(::dup(Fd)); unlikely(!NewFd.ok())) {\n return WasiUnexpect(fromErrNo(errno));\n } else if (DIR *D = ::fdopendir(NewFd.Fd); unlikely(!D)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NewFd.release();\n Dir.emplace(D);\n }\n }\n\n if (Cookie == 0) {\n ::rewinddir(Dir.Dir);\n } else if (unlikely(Cookie != Dir.Cookie)) {\n ::seekdir(Dir.Dir, Cookie);\n }\n\n Size = 0;\n do {\n if (!Dir.Buffer.empty()) {\n const auto NewDataSize =\n std::min(Buffer.size(), Dir.Buffer.size());\n std::copy(Dir.Buffer.begin(), Dir.Buffer.begin() + NewDataSize,\n Buffer.begin());\n Buffer = Buffer.subspan(NewDataSize);\n Size += NewDataSize;\n Dir.Buffer.clear();\n if (unlikely(Buffer.empty())) {\n break;\n }\n }\n errno = 0;\n dirent *SysDirent = ::readdir(Dir.Dir);\n if (SysDirent == nullptr) {\n if (errno != 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n // End of entries\n break;\n }\n Dir.Cookie = SysDirent->d_off;\n std::string_view Name = SysDirent->d_name;\n\n Dir.Buffer.resize(sizeof(__wasi_dirent_t) + Name.size());\n\n __wasi_dirent_t *const Dirent =\n reinterpret_cast<__wasi_dirent_t *>(Dir.Buffer.data());\n Dirent->d_next = Dir.Cookie;\n Dirent->d_ino = SysDirent->d_ino;\n Dirent->d_type = fromFileType(SysDirent->d_type);\n...\n", "needle_token_start": 7214, "needle_token_end": 7450, "code_context_ntokens": 16260, "output": ["```cpp\ninline constexpr int toAdvice(__wasi_advice_t Advice) noexcept {\n switch (Advice) {\n case __WASI_ADVICE_NORMAL:\n return POSIX_FADV_NORMAL;\n case __WASI_ADVICE_SEQUENTIAL:\n return POSIX_FADV_SEQUENTIAL;\n case __WASI_ADVICE_RANDOM:\n return POSIX_FADV_RANDOM;\n case __WASI_ADVICE_WILLNEED:\n return POSIX_FADV_WILLNEED;\n case __WASI_ADVICE_DONTNEED:\n return POSIX_FADV_DONTNEED;\n 
case __WASI_ADVICE_NOREUSE:\n return POSIX_FADV_NOREUSE;\n default:\n assumingUnreachable();\n }\n}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "WasmEdge_CompilerContext", "language": "cpp", "path": "lib/api/wasmedge.cpp", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this structure is to initialize and manage the compilation process for WebAssembly modules, integrating various components necessary for compilation, code generation, loading, and validation.\n2. **Input**: It takes a configuration object which contains settings that influence the behavior of the compilation process.\n3. **Output**: It does not directly produce an output but sets up an environment for WebAssembly module compilation by initializing several internal components based on the provided configuration.\n4. **Procedure**: The procedure involves initializing internal components responsible for compiling, generating machine code, loading, and validating WebAssembly modules. These components are configured based on the settings provided in the input configuration object.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/validator/formchecker.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"validator/formchecker.h\"\n\n#include \"common/errinfo.h\"\n#include \"common/log.h\"\n\n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Validator {\n\nnamespace {\n\n// Helper function for printing error log of index out of range.\nauto logOutOfRange(ErrCode Code, ErrInfo::IndexCategory Cate, uint32_t Idx,\n uint32_t Bound) {\n spdlog::error(Code);\n spdlog::error(ErrInfo::InfoForbidIndex(Cate, Idx, Bound));\n return Unexpect(Code);\n}\n\n} // namespace\n\nvoid FormChecker::reset(bool CleanGlobal) {\n ValStack.clear();\n CtrlStack.clear();\n Locals.clear();\n Returns.clear();\n\n if (CleanGlobal) {\n Types.clear();\n Funcs.clear();\n Tables.clear();\n Mems = 0;\n Globals.clear();\n Datas.clear();\n Elems.clear();\n Refs.clear();\n NumImportFuncs = 0;\n NumImportGlobals = 0;\n }\n}\n\nExpect FormChecker::validate(AST::InstrView Instrs,\n Span RetVals) {\n for (const ValType &Val : RetVals) {\n Returns.push_back(Val);\n }\n return checkExpr(Instrs);\n}\n\nExpect FormChecker::validate(const ValType &VT) const noexcept {\n // The value type should be validated for the type index case.\n if (VT.isRefType() && VT.getHeapTypeCode() == TypeCode::TypeIndex) {\n if (VT.getTypeIndex() >= Types.size()) {\n spdlog::error(ErrCode::Value::InvalidFuncTypeIdx);\n spdlog::error(ErrInfo::InfoForbidIndex(\n ErrInfo::IndexCategory::FunctionType, VT.getTypeIndex(),\n static_cast(Types.size())));\n return Unexpect(ErrCode::Value::InvalidFuncTypeIdx);\n }\n }\n return {};\n}\n\nvoid FormChecker::addType(const AST::SubType &Type) { Types.push_back(&Type); }\n\nvoid FormChecker::addFunc(const uint32_t TypeIdx, const bool IsImport) {\n if (Types.size() > TypeIdx) {\n Funcs.emplace_back(TypeIdx);\n }\n if (IsImport) {\n NumImportFuncs++;\n }\n}\n\nvoid FormChecker::addTable(const AST::TableType &Tab) {\n Tables.push_back(Tab.getRefType());\n}\n\nvoid FormChecker::addMemory(const AST::MemoryType &) { Mems++; }\n\nvoid FormChecker::addGlobal(const AST::GlobalType &Glob, const bool IsImport) 
{\n // Type in global is confirmed in loading phase.\n Globals.emplace_back(Glob.getValType(), Glob.getValMut());\n if (IsImport) {\n NumImportGlobals++;\n }\n}\n\nvoid FormChecker::addData(const AST::DataSegment &) {\n Datas.emplace_back(static_cast(Datas.size()));\n}\n\n...\n// Path: lib/aot/blake3.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"aot/blake3.h\"\n\n#include \"common/config.h\"\n#include \"common/defines.h\"\n\nnamespace WasmEdge {\nnamespace AOT {\n\nnamespace {}\n\nBlake3::Blake3() noexcept { blake3_hasher_init(&Hasher); }\n\nvoid Blake3::update(Span Data) noexcept {\n blake3_hasher_update(&Hasher, Data.data(), Data.size());\n}\n\nvoid Blake3::finalize(Span Output) noexcept {\n blake3_hasher_finalize(&Hasher, Output.data(), Output.size());\n}\n\n} // namespace AOT\n} // namespace WasmEdge\n\n// Path: lib/aot/cache.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"aot/cache.h\"\n\n#include \"aot/blake3.h\"\n#include \"common/config.h\"\n#include \"common/defines.h\"\n#include \"common/hexstr.h\"\n#include \"system/path.h\"\n\n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace AOT {\n\nnamespace {\nstd::filesystem::path getRoot(Cache::StorageScope Scope) {\n switch (Scope) {\n case Cache::StorageScope::Global:\n return std::filesystem::u8path(kCacheRoot);\n case Cache::StorageScope::Local: {\n if (const auto Home = Path::home(); !Home.empty()) {\n return Home / \"cache\"sv;\n }\n return {};\n }\n default:\n assumingUnreachable();\n }\n}\n} // namespace\n\nExpect Cache::getPath(Span Data,\n Cache::StorageScope Scope,\n std::string_view Key) {\n auto Root = getRoot(Scope);\n if (!Key.empty()) {\n Root /= std::filesystem::u8path(Key);\n }\n\n Blake3 Hasher;\n Hasher.update(Data);\n std::array Hash;\n Hasher.finalize(Hash);\n std::string HexStr;\n convertBytesToHexStr(Hash, HexStr);\n\n return Root / HexStr;\n}\n\nvoid Cache::clear(Cache::StorageScope Scope, std::string_view Key) {\n auto Root = getRoot(Scope);\n if (!Key.empty()) {\n Root /= std::filesystem::u8path(Key);\n }\n std::error_code ErrCode;\n std::filesystem::remove_all(Root, ErrCode);\n}\n\n} // namespace AOT\n} // namespace WasmEdge\n\n// Path: lib/vm/vm.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"vm/vm.h\"\n\n#include \"host/wasi/wasimodule.h\"\n#include \"plugin/plugin.h\"\n#include \"llvm/compiler.h\"\n#include \"llvm/jit.h\"\n\n#include \"host/mock/wasi_crypto_module.h\"\n#include \"host/mock/wasi_logging_module.h\"\n#include \"host/mock/wasi_nn_module.h\"\n#include \"host/mock/wasmedge_image_module.h\"\n#include \"host/mock/wasmedge_process_module.h\"\n#include \"host/mock/wasmedge_tensorflow_module.h\"\n#include \"host/mock/wasmedge_tensorflowlite_module.h\"\n#include \n\nnamespace WasmEdge {\nnamespace VM {\n\nnamespace {\ntemplate \nstd::unique_ptr\ncreatePluginModule(std::string_view PName, std::string_view MName) {\n using namespace std::literals::string_view_literals;\n if (const auto *Plugin = Plugin::Plugin::find(PName)) {\n if (const auto *Module = Plugin->findModule(MName)) {\n return Module->create();\n }\n }\n spdlog::debug(\"Plugin: {} , module name: {} not found. 
Mock instead.\"sv,\n PName, MName);\n return std::make_unique();\n}\n} // namespace\n\nVM::VM(const Configure &Conf)\n : Conf(Conf), Stage(VMStage::Inited),\n LoaderEngine(Conf, &Executor::Executor::Intrinsics),\n ValidatorEngine(Conf), ExecutorEngine(Conf, &Stat),\n Store(std::make_unique()), StoreRef(*Store.get()) {\n unsafeInitVM();\n}\n\nVM::VM(const Configure &Conf, Runtime::StoreManager &S)\n : Conf(Conf), Stage(VMStage::Inited),\n LoaderEngine(Conf, &Executor::Executor::Intrinsics),\n ValidatorEngine(Conf), ExecutorEngine(Conf, &Stat), StoreRef(S) {\n unsafeInitVM();\n}\n\nvoid VM::unsafeInitVM() {\n // Load the built-in modules and the plug-ins.\n unsafeLoadBuiltInHosts();\n unsafeLoadPlugInHosts();\n\n // Register all module instances.\n unsafeRegisterBuiltInHosts();\n unsafeRegisterPlugInHosts();\n}\n\nvoid VM::unsafeLoadBuiltInHosts() {\n // Load the built-in host modules from configuration.\n // TODO: This will be extended for the versionlized WASI in the future.\n BuiltInModInsts.clear();\n if (Conf.hasHostRegistration(HostRegistration::Wasi)) {\n std::unique_ptr WasiMod =\n std::make_unique();\n BuiltInModInsts.insert({HostRegistration::Wasi, std::move(WasiMod)});\n }\n}\n\nvoid VM::unsafeLoadPlugInHosts() {\n // Load the plugins and mock them if not found.\n using namespace std::literals::string_view_literals;\n PlugInModInsts.clear();\n\n PlugInModInsts.push_back(\n createPluginModule(\"wasi_nn\"sv, \"wasi_nn\"sv));\n PlugInModInsts.push_back(createPluginModule(\n \"wasi_crypto\"sv, \"wasi_crypto_common\"sv));\n PlugInModInsts.push_back(\n createPluginModule(\n \"wasi_crypto\"sv, \"wasi_crypto_asymmetric_common\"sv));\n PlugInModInsts.push_back(createPluginModule(\n \"wasi_crypto\"sv, \"wasi_crypto_kx\"sv));\n PlugInModInsts.push_back(\n createPluginModule(\n \"wasi_crypto\"sv, \"wasi_crypto_signatures\"sv));\n PlugInModInsts.push_back(\n createPluginModule(\n \"wasi_crypto\"sv, \"wasi_crypto_symmetric\"sv));\n PlugInModInsts.push_back(createPluginModule(\n \"wasmedge_process\"sv, \"wasmedge_process\"sv));\n PlugInModInsts.push_back(createPluginModule(\n \"wasi_logging\"sv, \"wasi:logging/logging\"sv));\n PlugInModInsts.push_back(\n createPluginModule(\n \"wasmedge_tensorflow\"sv, \"wasmedge_tensorflow\"sv));\n PlugInModInsts.push_back(\n createPluginModule(\n \"wasmedge_tensorflowlite\"sv, \"wasmedge_tensorflowlite\"sv));\n PlugInModInsts.push_back(createPluginModule(\n \"wasmedge_image\"sv, \"wasmedge_image\"sv));\n\n // Load the other non-official plugins.\n for (const auto &Plugin : Plugin::Plugin::plugins()) {\n if (Conf.isForbiddenPlugins(Plugin.name())) {\n continue;\n }\n // Skip wasi_crypto, wasi_nn, wasi_logging, WasmEdge_Process,\n // WasmEdge_Tensorflow, WasmEdge_TensorflowLite, and WasmEdge_Image.\n if (Plugin.name() == \"wasi_crypto\"sv || Plugin.name() == \"wasi_nn\"sv ||\n Plugin.name() == \"wasi_logging\"sv ||\n Plugin.name() == \"wasmedge_process\"sv ||\n Plugin.name() == \"wasmedge_tensorflow\"sv ||\n Plugin.name() == \"wasmedge_tensorflowlite\"sv ||\n Plugin.name() == \"wasmedge_image\"sv) {\n continue;\n }\n for (const auto &Module : Plugin.modules()) {\n PlugInModInsts.push_back(Module.create());\n }\n }\n}\n\nvoid VM::unsafeRegisterBuiltInHosts() {\n // Register all created WASI host modules.\n for (auto &It : BuiltInModInsts) {\n ExecutorEngine.registerModule(StoreRef, *(It.second.get()));\n }\n}\n\nvoid VM::unsafeRegisterPlugInHosts() {\n // Register all created module instances from plugins.\n for (auto &It : PlugInModInsts) {\n 
ExecutorEngine.registerModule(StoreRef, *(It.get()));\n }\n}\n\nExpect VM::unsafeRegisterModule(std::string_view Name,\n const std::filesystem::path &Path) {\n if (Stage == VMStage::Instantiated) {\n // When registering module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n // Load module.\n if (auto Res = LoaderEngine.parseModule(Path)) {\n return unsafeRegisterModule(Name, *(*Res).get());\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect VM::unsafeRegisterModule(std::string_view Name,\n Span Code) {\n if (Stage == VMStage::Instantiated) {\n // When registering module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n // Load module.\n if (auto Res = LoaderEngine.parseModule(Code)) {\n return unsafeRegisterModule(Name, *(*Res).get());\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect VM::unsafeRegisterModule(std::string_view Name,\n const AST::Module &Module) {\n if (Stage == VMStage::Instantiated) {\n // When registering module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n // Validate module.\n if (auto Res = ValidatorEngine.validate(Module); !Res) {\n return Unexpect(Res);\n }\n // Instantiate and register module.\n if (auto Res = ExecutorEngine.registerModule(StoreRef, Module, Name)) {\n RegModInsts.push_back(std::move(*Res));\n return {};\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect\nVM::unsafeRegisterModule(const Runtime::Instance::ModuleInstance &ModInst) {\n if (Stage == VMStage::Instantiated) {\n // When registering module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n return ExecutorEngine.registerModule(StoreRef, ModInst);\n}\n\nExpect>>\nVM::unsafeRunWasmFile(const std::filesystem::path &Path, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n if (Stage == VMStage::Instantiated) {\n // When running another module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n // Load wasm unit.\n if (auto Res = LoaderEngine.parseWasmUnit(Path)) {\n if (std::holds_alternative>(*Res)) {\n auto M = std::move(std::get>(*Res));\n return unsafeRunWasmFile(*M, Func, Params, ParamTypes);\n } else {\n return unsafeRunWasmFile(\n (*std::get>(*Res)), Func,\n Params, ParamTypes);\n }\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect>>\nVM::unsafeRunWasmFile(Span Code, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n if (Stage == VMStage::Instantiated) {\n // When running another module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n // Load wasm unit.\n if (auto Res = LoaderEngine.parseWasmUnit(Code)) {\n if (std::holds_alternative>(*Res)) {\n std::unique_ptr M =\n std::move(std::get>(*Res));\n return unsafeRunWasmFile(*M, Func, Params, ParamTypes);\n } else {\n std::unique_ptr C =\n std::move(std::get>(*Res));\n return unsafeRunWasmFile(*C, Func, Params, ParamTypes);\n }\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect>>\nVM::unsafeRunWasmFile(const AST::Component::Component &Component,\n std::string_view, Span,\n Span) {\n if (Stage == VMStage::Instantiated) {\n // When running another module, instantiated module in store will be reset.\n // Therefore the instantiation should 
restart.\n Stage = VMStage::Validated;\n }\n if (auto Res = ValidatorEngine.validate(Component); !Res) {\n return Unexpect(Res);\n }\n spdlog::error(\"component execution is not done yet.\");\n return Unexpect(ErrCode::Value::RuntimeError);\n}\n\nExpect>>\nVM::unsafeRunWasmFile(const AST::Module &Module, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n if (Stage == VMStage::Instantiated) {\n // When running another module, instantiated module in store will be reset.\n // Therefore the instantiation should restart.\n Stage = VMStage::Validated;\n }\n if (auto Res = ValidatorEngine.validate(Module); !Res) {\n return Unexpect(Res);\n }\n if (auto Res = ExecutorEngine.instantiateModule(StoreRef, Module)) {\n ActiveModInst = std::move(*Res);\n } else {\n return Unexpect(Res);\n }\n // Get module instance.\n if (ActiveModInst) {\n // Execute function and return values with the module instance.\n return unsafeExecute(ActiveModInst.get(), Func, Params, ParamTypes);\n } else {\n spdlog::error(ErrCode::Value::WrongInstanceAddress);\n spdlog::error(ErrInfo::InfoExecuting(\"\", Func));\n return Unexpect(ErrCode::Value::WrongInstanceAddress);\n }\n}\n\nAsync>>>\nVM::asyncRunWasmFile(const std::filesystem::path &Path, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n Expect>> (VM::*FPtr)(\n const std::filesystem::path &, std::string_view, Span,\n Span) = &VM::runWasmFile;\n return {FPtr,\n *this,\n std::filesystem::path(Path),\n std::string(Func),\n std::vector(Params.begin(), Params.end()),\n std::vector(ParamTypes.begin(), ParamTypes.end())};\n}\n\nAsync>>>\nVM::asyncRunWasmFile(Span Code, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n Expect>> (VM::*FPtr)(\n Span, std::string_view, Span,\n Span) = &VM::runWasmFile;\n return {FPtr,\n *this,\n Code,\n std::string(Func),\n std::vector(Params.begin(), Params.end()),\n std::vector(ParamTypes.begin(), ParamTypes.end())};\n}\n\nAsync>>>\nVM::asyncRunWasmFile(const AST::Module &Module, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n Expect>> (VM::*FPtr)(\n const AST::Module &, std::string_view, Span,\n Span) = &VM::runWasmFile;\n return {FPtr,\n *this,\n Module,\n std::string(Func),\n std::vector(Params.begin(), Params.end()),\n std::vector(ParamTypes.begin(), ParamTypes.end())};\n}\n\nExpect VM::unsafeLoadWasm(const std::filesystem::path &Path) {\n // If not load successfully, the previous status will be reserved.\n if (auto Res = LoaderEngine.parseWasmUnit(Path)) {\n if (std::holds_alternative>(*Res)) {\n Mod = std::move(std::get>(*Res));\n } else if (std::holds_alternative<\n std::unique_ptr>(*Res)) {\n spdlog::error(\"component execution is not done yet.\");\n } else {\n return Unexpect(Res);\n }\n Stage = VMStage::Loaded;\n } else {\n return Unexpect(Res);\n }\n return {};\n}\n\nExpect VM::unsafeLoadWasm(Span Code) {\n // If not load successfully, the previous status will be reserved.\n if (auto Res = LoaderEngine.parseWasmUnit(Code)) {\n if (std::holds_alternative>(*Res)) {\n Mod = std::move(std::get>(*Res));\n } else if (std::holds_alternative<\n std::unique_ptr>(*Res)) {\n spdlog::error(\"component execution is not done yet.\");\n } else {\n return Unexpect(Res);\n }\n Stage = VMStage::Loaded;\n } else {\n return Unexpect(Res);\n }\n return {};\n}\n\nExpect VM::unsafeLoadWasm(const AST::Module &Module) {\n Mod = std::make_unique(Module);\n Stage = VMStage::Loaded;\n return {};\n}\n\nExpect VM::unsafeValidate() {\n if (Stage < VMStage::Loaded) {\n // When module is not loaded, not validate.\n 
spdlog::error(ErrCode::Value::WrongVMWorkflow);\n return Unexpect(ErrCode::Value::WrongVMWorkflow);\n }\n if (auto Res = ValidatorEngine.validate(*Mod.get())) {\n Stage = VMStage::Validated;\n return {};\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect VM::unsafeInstantiate() {\n if (Stage < VMStage::Validated) {\n // When module is not validated, not instantiate.\n spdlog::error(ErrCode::Value::WrongVMWorkflow);\n return Unexpect(ErrCode::Value::WrongVMWorkflow);\n }\n\n if (Mod) {\n if (Conf.getRuntimeConfigure().isEnableJIT() && !Mod->getSymbol()) {\n#ifdef WASMEDGE_USE_LLVM\n LLVM::Compiler Compiler(Conf);\n LLVM::JIT JIT(Conf);\n if (auto Res = Compiler.compile(*Mod); !Res) {\n const auto Err = static_cast(Res.error());\n spdlog::error(\n \"Compilation failed. Error code: {}, use interpreter mode instead.\"sv,\n Err);\n } else if (auto Res2 = JIT.load(std::move(*Res)); !Res2) {\n const auto Err = static_cast(Res2.error());\n spdlog::warn(\n \"JIT failed. Error code: {}, use interpreter mode instead.\"sv, Err);\n } else {\n LoaderEngine.loadExecutable(*Mod, std::move(*Res2));\n }\n#else\n spdlog::error(\"LLVM disabled, JIT is unsupported!\");\n#endif\n }\n }\n\n if (auto Res = ExecutorEngine.instantiateModule(StoreRef, *Mod.get())) {\n Stage = VMStage::Instantiated;\n ActiveModInst = std::move(*Res);\n return {};\n } else {\n return Unexpect(Res);\n }\n}\n\nExpect>>\nVM::unsafeExecute(std::string_view Func, Span Params,\n Span ParamTypes) {\n if (ActiveModInst) {\n // Execute function and return values with the module instance.\n return unsafeExecute(ActiveModInst.get(), Func, Params, ParamTypes);\n } else {\n spdlog::error(ErrCode::Value::WrongInstanceAddress);\n spdlog::error(ErrInfo::InfoExecuting(\"\", Func));\n return Unexpect(ErrCode::Value::WrongInstanceAddress);\n }\n}\n\nExpect>>\nVM::unsafeExecute(std::string_view ModName, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n // Find module instance by name.\n const auto *FindModInst = StoreRef.findModule(ModName);\n if (FindModInst != nullptr) {\n // Execute function and return values with the module instance.\n return unsafeExecute(FindModInst, Func, Params, ParamTypes);\n } else {\n spdlog::error(ErrCode::Value::WrongInstanceAddress);\n spdlog::error(ErrInfo::InfoExecuting(ModName, Func));\n return Unexpect(ErrCode::Value::WrongInstanceAddress);\n }\n}\n\nExpect>>\nVM::unsafeExecute(const Runtime::Instance::ModuleInstance *ModInst,\n std::string_view Func, Span Params,\n Span ParamTypes) {\n // Find exported function by name.\n Runtime::Instance::FunctionInstance *FuncInst =\n ModInst->findFuncExports(Func);\n\n // Execute function.\n if (auto Res = ExecutorEngine.invoke(FuncInst, Params, ParamTypes);\n unlikely(!Res)) {\n if (Res.error() != ErrCode::Value::Terminated) {\n spdlog::error(ErrInfo::InfoExecuting(ModInst->getModuleName(), Func));\n }\n return Unexpect(Res);\n } else {\n return Res;\n }\n}\n\nAsync>>>\nVM::asyncExecute(std::string_view Func, Span Params,\n Span ParamTypes) {\n Expect>> (VM::*FPtr)(\n std::string_view, Span, Span) =\n &VM::execute;\n return {FPtr, *this, std::string(Func),\n std::vector(Params.begin(), Params.end()),\n std::vector(ParamTypes.begin(), ParamTypes.end())};\n}\n\nAsync>>>\nVM::asyncExecute(std::string_view ModName, std::string_view Func,\n Span Params,\n Span ParamTypes) {\n Expect>> (VM::*FPtr)(\n std::string_view, std::string_view, Span,\n Span) = &VM::execute;\n return {FPtr,\n *this,\n std::string(ModName),\n std::string(Func),\n std::vector(Params.begin(), 
Params.end()),\n std::vector(ParamTypes.begin(), ParamTypes.end())};\n}\n\nvoid VM::unsafeCleanup() {\n Mod.reset();\n ActiveModInst.reset();\n StoreRef.reset();\n RegModInsts.clear();\n Stat.clear();\n unsafeLoadBuiltInHosts();\n unsafeLoadPlugInHosts();\n unsafeRegisterBuiltInHosts();\n unsafeRegisterPlugInHosts();\n LoaderEngine.reset();\n Stage = VMStage::Inited;\n}\n\nstd::vector>\nVM::unsafeGetFunctionList() const {\n std::vector> Map;\n if (ActiveModInst) {\n ActiveModInst->getFuncExports([&](const auto &FuncExports) {\n Map.reserve(FuncExports.size());\n for (auto &&Func : FuncExports) {\n const auto &FuncType = (Func.second)->getFuncType();\n Map.emplace_back(Func.first, FuncType);\n }\n });\n }\n return Map;\n}\n\nRuntime::Instance::ModuleInstance *\nVM::unsafeGetImportModule(const HostRegistration Type) const {\n if (auto Iter = BuiltInModInsts.find(Type); Iter != BuiltInModInsts.cend()) {\n return Iter->second.get();\n }\n return nullptr;\n}\n\nconst Runtime::Instance::ModuleInstance *VM::unsafeGetActiveModule() const {\n if (ActiveModInst) {\n return ActiveModInst.get();\n }\n return nullptr;\n};\n\n} // namespace VM\n} // namespace WasmEdge\n\n// Path: lib/api/wasmedge.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"wasmedge/wasmedge.h\"\n\n#include \"common/defines.h\"\n#include \"driver/compiler.h\"\n#include \"driver/tool.h\"\n#include \"driver/unitool.h\"\n#include \"host/wasi/wasimodule.h\"\n#include \"plugin/plugin.h\"\n#include \"system/winapi.h\"\n#include \"vm/vm.h\"\n#include \"llvm/codegen.h\"\n#include \"llvm/compiler.h\"\n\n#ifdef WASMEDGE_BUILD_FUZZING\n#include \"driver/fuzzPO.h\"\n#include \"driver/fuzzTool.h\"\n#endif\n\n#ifdef WASMEDGE_BUILD_WASI_NN_RPC\n#include \"driver/wasiNNRPCServerTool.h\"\n#endif\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\n// WasmEdge_ConfigureContext implementation.\nstruct WasmEdge_ConfigureContext {\n WasmEdge::Configure Conf;\n};\n\n// WasmEdge_StatisticsContext implementation.\nstruct WasmEdge_StatisticsContext {};\n\n// WasmEdge_ASTModuleContext implementation.\nstruct WasmEdge_ASTModuleContext {};\n\n// WasmEdge_FunctionTypeContext implementation.\nstruct WasmEdge_FunctionTypeContext {};\n\n// WasmEdge_TableTypeContext implementation.\nstruct WasmEdge_TableTypeContext {};\n\n// WasmEdge_MemoryTypeContext implementation.\nstruct WasmEdge_MemoryTypeContext {};\n\n// WasmEdge_GlobalTypeContext implementation.\nstruct WasmEdge_GlobalTypeContext {};\n\n// WasmEdge_ImportTypeContext implementation.\nstruct WasmEdge_ImportTypeContext {};\n\n// WasmEdge_ExportTypeContext implementation.\nstruct WasmEdge_ExportTypeContext {};\n\n// WasmEdge_CompilerContext implementation.\nstruct WasmEdge_CompilerContext {\n#ifdef WASMEDGE_USE_LLVM\n \nWasmEdge_CompilerContext(const WasmEdge::Configure &Conf) noexcept\n : Compiler(Conf), CodeGen(Conf), Load(Conf), Valid(Conf) {}\n WasmEdge::LLVM::Compiler Compiler;\n WasmEdge::LLVM::CodeGen CodeGen;\n WasmEdge::Loader::Loader Load;\n WasmEdge::Validator::Validator Valid;\n#endif\n};\n\n// WasmEdge_LoaderContext implementation.\nstruct WasmEdge_LoaderContext {};\n\n// WasmEdge_ValidatorContext implementation.\nstruct WasmEdge_ValidatorContext {};\n\n// WasmEdge_ExecutorContext implementation.\nstruct WasmEdge_ExecutorContext {};\n\n// WasmEdge_StoreContext implementation.\nstruct WasmEdge_StoreContext {};\n\n// WasmEdge_ModuleInstanceContext 
implementation.\nstruct WasmEdge_ModuleInstanceContext {};\n\n// WasmEdge_FunctionInstanceContext implementation.\nstruct WasmEdge_FunctionInstanceContext {};\n\n// WasmEdge_TableInstanceContext implementation.\nstruct WasmEdge_TableInstanceContext {};\n\n// WasmEdge_MemoryInstanceContext implementation.\nstruct WasmEdge_MemoryInstanceContext {};\n\n// WasmEdge_GlobalInstanceContext implementation.\nstruct WasmEdge_GlobalInstanceContext {};\n\n// WasmEdge_CallingFrameContext implementation.\nstruct WasmEdge_CallingFrameContext {};\n\n// WasmEdge_Async implementation.\nstruct WasmEdge_Async {\n template \n WasmEdge_Async(Args &&...Vals) noexcept\n : Async(std::forward(Vals)...) {}\n WasmEdge::Async>>>\n Async;\n};\n\n// WasmEdge_VMContext implementation.\nstruct WasmEdge_VMContext {\n template \n WasmEdge_VMContext(Args &&...Vals) noexcept\n : VM(std::forward(Vals)...) {}\n WasmEdge::VM::VM VM;\n};\n\n// WasmEdge_PluginContext implementation.\nstruct WasmEdge_PluginContext {};\n\nnamespace {\n\nusing namespace WasmEdge;\n\n// Helper function for returning a WasmEdge_Result by error code.\ninline constexpr WasmEdge_Result\ngenWasmEdge_Result(const ErrCode::Value &Code) noexcept {\n return WasmEdge_Result{/* Code */ static_cast(Code) & 0x00FFFFFFU};\n}\ninline constexpr WasmEdge_Result\ngenWasmEdge_Result(const ErrCode &Code) noexcept {\n return WasmEdge_Result{/* Code */ Code.operator uint32_t()};\n}\n\n// Helper function for returning a struct uint128_t / int128_t\n// from class WasmEdge::uint128_t / WasmEdge::int128_t.\ntemplate \ninline constexpr ::uint128_t to_uint128_t(C Val) noexcept {\n#if defined(__x86_64__) || defined(__aarch64__) || \\\n (defined(__riscv) && __riscv_xlen == 64)\n return Val;\n#else\n return {/* Low */ Val.low(), /* High */ static_cast(Val.high())};\n#endif\n}\ntemplate inline constexpr ::int128_t to_int128_t(C Val) noexcept {\n#if defined(__x86_64__) || defined(__aarch64__) || \\\n (defined(__riscv) && __riscv_xlen == 64)\n return Val;\n#else\n return {/* Low */ Val.low(), /* High */ Val.high()};\n#endif\n}\n\n// Helper function for returning a class WasmEdge::uint128_t /\n// WasmEdge::int128_t from struct uint128_t / int128_t.\ntemplate \ninline constexpr C to_WasmEdge_128_t(T Val) noexcept {\n#if defined(__x86_64__) || defined(__aarch64__) || \\\n (defined(__riscv) && __riscv_xlen == 64)\n return Val;\n#else\n return C(Val.High, Val.Low);\n#endif\n}\n\n// Helper functions for returning a WasmEdge::ValType by WasmEdge_ValType.\ninline ValType genValType(const WasmEdge_ValType &T) noexcept {\n std::array R;\n std::copy_n(T.Data, 8, R.begin());\n return ValType(R);\n}\n\n// Helper functions for returning a WasmEdge_ValType by WasmEdge::ValType.\ninline WasmEdge_ValType genWasmEdge_ValType(const ValType &T) noexcept {\n WasmEdge_ValType VT;\n std::copy_n(T.getRawData().cbegin(), 8, VT.Data);\n return VT;\n}\n\n// Helper functions for returning a WasmEdge_Value by various values.\ntemplate \ninline WasmEdge_Value genWasmEdge_Value(const T &Val) noexcept {\n return WasmEdge_Value{\n /* Value */ to_uint128_t(ValVariant(Val).unwrap()),\n /* Type */ genWasmEdge_ValType(WasmEdge::ValTypeFromType())};\n}\ninline WasmEdge_Value genWasmEdge_Value(const ValVariant &Val,\n const ValType &T) noexcept {\n return WasmEdge_Value{/* Value */ to_uint128_t(Val.unwrap()),\n /* Type */ genWasmEdge_ValType(T)};\n}\n\n// Helper function for converting a WasmEdge_Value array to a ValVariant\n// vector.\ninline std::pair, std::vector>\ngenParamPair(const WasmEdge_Value *Val, const uint32_t 
Len) noexcept {\n // The nullable value in reference types checking is handled in executor.\n std::vector VVec;\n std::vector TVec;\n if (Val == nullptr) {\n return {VVec, TVec};\n }\n VVec.resize(Len);\n TVec.resize(Len);\n for (uint32_t I = 0; I < Len; I++) {\n TVec[I] = genValType(Val[I].Type);\n switch (TVec[I].getCode()) {\n case TypeCode::I32:\n VVec[I] = ValVariant::wrap(\n to_WasmEdge_128_t(Val[I].Value));\n break;\n case TypeCode::I64:\n VVec[I] = ValVariant::wrap(\n to_WasmEdge_128_t(Val[I].Value));\n break;\n case TypeCode::F32:\n VVec[I] = ValVariant::wrap(\n to_WasmEdge_128_t(Val[I].Value));\n break;\n case TypeCode::F64:\n VVec[I] = ValVariant::wrap(\n to_WasmEdge_128_t(Val[I].Value));\n break;\n case TypeCode::V128:\n VVec[I] = ValVariant::wrap(\n to_WasmEdge_128_t(Val[I].Value));\n break;\n case TypeCode::Ref:\n case TypeCode::RefNull: {\n VVec[I] = ValVariant::wrap(\n to_WasmEdge_128_t(Val[I].Value));\n break;\n }\n default:\n assumingUnreachable();\n }\n }\n return {VVec, TVec};\n}\n\n// Helper function for making a Span to a uint8_t array.\ntemplate \ninline constexpr Span genSpan(const T *Buf,\n const uint32_t Len) noexcept {\n if (Buf && Len > 0) {\n return Span(Buf, Len);\n }\n return Span();\n}\n\n// Helper functions for converting WasmEdge_String to std::String.\ninline std::string_view genStrView(const WasmEdge_String S) noexcept {\n return std::string_view(S.Buf, S.Length);\n}\n\n// Helper functions for converting a ValVariant vector to a WasmEdge_Value\n// array.\ninline constexpr void\nfillWasmEdge_ValueArr(Span> Vec,\n WasmEdge_Value *Val, const uint32_t Len) noexcept {\n if (Val == nullptr) {\n return;\n }\n for (uint32_t I = 0; I < Len && I < Vec.size(); I++) {\n Val[I] = genWasmEdge_Value(Vec[I].first, Vec[I].second);\n }\n}\n\n// Helper template to run and return result.\nauto EmptyThen = [](auto &&) noexcept {};\ntemplate inline bool isContext(T *Cxt) noexcept {\n return (Cxt != nullptr);\n}\ntemplate \ninline bool isContext(T *Cxt, Args *...Cxts) noexcept {\n return isContext(Cxt) && isContext(Cxts...);\n}\ntemplate \ninline WasmEdge_Result wrap(T &&Proc, U &&Then, CxtT *...Cxts) noexcept {\n if (isContext(Cxts...)) {\n if (auto Res = Proc()) {\n Then(Res);\n return genWasmEdge_Result(ErrCode::Value::Success);\n } else {\n return genWasmEdge_Result(Res.error());\n }\n } else {\n return genWasmEdge_Result(ErrCode::Value::WrongVMWorkflow);\n }\n}\n\n// Helper function of retrieving exported maps.\ntemplate \ninline uint32_t fillMap(const std::map> &Map,\n WasmEdge_String *Names, const uint32_t Len) noexcept {\n uint32_t I = 0;\n for (auto &&Pair : Map) {\n if (I >= Len) {\n break;\n }\n if (Names) {\n Names[I] = WasmEdge_String{\n /* Length */ static_cast(Pair.first.length()),\n /* Buf */ Pair.first.data()};\n }\n I++;\n }\n return static_cast(Map.size());\n}\n\n// Helper functions of context conversions.\n#define CONVTO(SIMP, INST, NAME, QUANT) \\\n inline QUANT auto *to##SIMP##Cxt(QUANT INST *Cxt) noexcept { \\\n return reinterpret_cast(Cxt); \\\n }\nCONVTO(Stat, Statistics::Statistics, Statistics, )\nCONVTO(ASTMod, AST::Module, ASTModule, )\nCONVTO(FuncType, AST::FunctionType, FunctionType, )\nCONVTO(FuncType, AST::FunctionType, FunctionType, const)\nCONVTO(TabType, AST::TableType, TableType, )\nCONVTO(TabType, AST::TableType, TableType, const)\nCONVTO(MemType, AST::MemoryType, MemoryType, )\nCONVTO(MemType, AST::MemoryType, MemoryType, const)\nCONVTO(GlobType, AST::GlobalType, GlobalType, )\nCONVTO(GlobType, AST::GlobalType, GlobalType, 
const)\nCONVTO(ImpType, AST::ImportDesc, ImportType, const)\nCONVTO(ExpType, AST::ExportDesc, ExportType, const)\nCONVTO(Store, Runtime::StoreManager, Store, )\nCONVTO(Loader, Loader::Loader, Loader, )\nCONVTO(Validator, Validator::Validator, Validator, )\nCONVTO(Executor, Executor::Executor, Executor, )\nCONVTO(Mod, Runtime::Instance::ModuleInstance, ModuleInstance, )\nCONVTO(Mod, Runtime::Instance::ModuleInstance, ModuleInstance, const)\nCONVTO(Func, Runtime::Instance::FunctionInstance, FunctionInstance, )\nCONVTO(Func, Runtime::Instance::FunctionInstance, FunctionInstance, const)\nCONVTO(Tab, Runtime::Instance::TableInstance, TableInstance, )\nCONVTO(Mem, Runtime::Instance::MemoryInstance, MemoryInstance, )\nCONVTO(Glob, Runtime::Instance::GlobalInstance, GlobalInstance, )\nCONVTO(CallFrame, Runtime::CallingFrame, CallingFrame, const)\nCONVTO(Plugin, Plugin::Plugin, Plugin, const)\n#undef CONVTO\n\n#define CONVFROM(SIMP, INST, NAME, QUANT) \\\n inline QUANT auto *from##SIMP##Cxt( \\\n QUANT WasmEdge_##NAME##Context *Cxt) noexcept { \\\n return reinterpret_cast(Cxt); \\\n }\nCONVFROM(Stat, Statistics::Statistics, Statistics, )\nCONVFROM(Stat, Statistics::Statistics, Statistics, const)\nCONVFROM(ASTMod, AST::Module, ASTModule, )\nCONVFROM(ASTMod, AST::Module, ASTModule, const)\nCONVFROM(FuncType, AST::FunctionType, FunctionType, )\nCONVFROM(FuncType, AST::FunctionType, FunctionType, const)\nCONVFROM(TabType, AST::TableType, TableType, )\nCONVFROM(TabType, AST::TableType, TableType, const)\nCONVFROM(MemType, AST::MemoryType, MemoryType, )\nCONVFROM(MemType, AST::MemoryType, MemoryType, const)\nCONVFROM(GlobType, AST::GlobalType, GlobalType, )\nCONVFROM(GlobType, AST::GlobalType, GlobalType, const)\nCONVFROM(ImpType, AST::ImportDesc, ImportType, const)\nCONVFROM(ExpType, AST::ExportDesc, ExportType, const)\nCONVFROM(Store, Runtime::StoreManager, Store, )\nCONVFROM(Store, Runtime::StoreManager, Store, const)\nCONVFROM(Loader, Loader::Loader, Loader, )\nCONVFROM(Validator, Validator::Validator, Validator, )\nCONVFROM(Executor, Executor::Executor, Executor, )\nCONVFROM(Mod, Runtime::Instance::ModuleInstance, ModuleInstance, )\nCONVFROM(Mod, Runtime::Instance::ModuleInstance, ModuleInstance, const)\nCONVFROM(Func, Runtime::Instance::FunctionInstance, FunctionInstance, )\nCONVFROM(Func, Runtime::Instance::FunctionInstance, FunctionInstance, const)\nCONVFROM(Tab, Runtime::Instance::TableInstance, TableInstance, )\nCONVFROM(Tab, Runtime::Instance::TableInstance, TableInstance, const)\nCONVFROM(Mem, Runtime::Instance::MemoryInstance, MemoryInstance, )\nCONVFROM(Mem, Runtime::Instance::MemoryInstance, MemoryInstance, const)\nCONVFROM(Glob, Runtime::Instance::GlobalInstance, GlobalInstance, )\nCONVFROM(Glob, Runtime::Instance::GlobalInstance, GlobalInstance, const)\nCONVFROM(CallFrame, Runtime::CallingFrame, CallingFrame, const)\nCONVFROM(Plugin, Plugin::Plugin, Plugin, const)\n#undef CONVFROM\n\n// C API Host function class\nclass CAPIHostFunc : public Runtime::HostFunctionBase {\npublic:\n CAPIHostFunc(const AST::FunctionType *Type, WasmEdge_HostFunc_t FuncPtr,\n void *ExtData, const uint64_t FuncCost = 0) noexcept\n : Runtime::HostFunctionBase(FuncCost), Func(FuncPtr), Wrap(nullptr),\n Binding(nullptr), Data(ExtData) {\n DefType.getCompositeType().getFuncType() = *Type;\n }\n CAPIHostFunc(const AST::FunctionType *Type, WasmEdge_WrapFunc_t WrapPtr,\n void *BindingPtr, void *ExtData,\n const uint64_t FuncCost = 0) noexcept\n : Runtime::HostFunctionBase(FuncCost), Func(nullptr), Wrap(WrapPtr),\n 
Binding(BindingPtr), Data(ExtData) {\n DefType.getCompositeType().getFuncType() = *Type;\n }\n ~CAPIHostFunc() noexcept override = default;\n\n Expect run(const Runtime::CallingFrame &CallFrame,\n Span Args,\n Span Rets) override {\n auto &FuncType = DefType.getCompositeType().getFuncType();\n std::vector Params(FuncType.getParamTypes().size()),\n Returns(FuncType.getReturnTypes().size());\n for (uint32_t I = 0; I < Args.size(); I++) {\n Params[I] = genWasmEdge_Value(Args[I], FuncType.getParamTypes()[I]);\n }\n WasmEdge_Value *PPtr = Params.size() ? (&Params[0]) : nullptr;\n WasmEdge_Value *RPtr = Returns.size() ? (&Returns[0]) : nullptr;\n auto *CallFrameCxt = toCallFrameCxt(&CallFrame);\n WasmEdge_Result Stat;\n if (Func) {\n Stat = Func(Data, CallFrameCxt, PPtr, RPtr);\n } else {\n Stat = Wrap(Binding, Data, CallFrameCxt, PPtr,\n static_cast(Params.size()), RPtr,\n static_cast(Returns.size()));\n }\n for (uint32_t I = 0; I < Rets.size(); I++) {\n Rets[I] = to_WasmEdge_128_t(Returns[I].Value);\n }\n if (WasmEdge_ResultOK(Stat)) {\n if (WasmEdge_ResultGetCode(Stat) == 0x01U) {\n return Unexpect(ErrCode::Value::Terminated);\n }\n } else {\n return Unexpect(\n static_cast(WasmEdge_ResultGetCategory(Stat)),\n WasmEdge_ResultGetCode(Stat));\n }\n return {};\n }\n void *getData() const noexcept { return Data; }\n\nprivate:\n WasmEdge_HostFunc_t Func;\n WasmEdge_WrapFunc_t Wrap;\n void *Binding;\n void *Data;\n};\n\n} // namespace\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// >>>>>>>> WasmEdge version functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n\nWASMEDGE_CAPI_EXPORT const char *WasmEdge_VersionGet(void) {\n return WASMEDGE_VERSION;\n}\n\nWASMEDGE_CAPI_EXPORT uint32_t WasmEdge_VersionGetMajor(void) {\n return WASMEDGE_VERSION_MAJOR;\n}\n\nWASMEDGE_CAPI_EXPORT uint32_t WasmEdge_VersionGetMinor(void) {\n return WASMEDGE_VERSION_MINOR;\n}\n\nWASMEDGE_CAPI_EXPORT uint32_t WasmEdge_VersionGetPatch(void) {\n return WASMEDGE_VERSION_PATCH;\n}\n\n// <<<<<<<< WasmEdge version functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n\n// >>>>>>>> WasmEdge logging functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n\nWASMEDGE_CAPI_EXPORT void WasmEdge_LogSetErrorLevel(void) {\n WasmEdge::Log::setErrorLoggingLevel();\n}\n\nWASMEDGE_CAPI_EXPORT void WasmEdge_LogSetDebugLevel(void) {\n WasmEdge::Log::setDebugLoggingLevel();\n}\n\nWASMEDGE_CAPI_EXPORT void WasmEdge_LogOff(void) { WasmEdge::Log::setLogOff(); }\n\n// <<<<<<<< WasmEdge logging functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n\n// >>>>>>>> WasmEdge valtype functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenI32(void) {\n return genWasmEdge_ValType(ValType(TypeCode::I32));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenI64(void) {\n return genWasmEdge_ValType(ValType(TypeCode::I64));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenF32(void) {\n return genWasmEdge_ValType(ValType(TypeCode::F32));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenF64(void) {\n return genWasmEdge_ValType(ValType(TypeCode::F64));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenV128(void) {\n return genWasmEdge_ValType(ValType(TypeCode::V128));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenFuncRef(void) {\n return genWasmEdge_ValType(ValType(TypeCode::FuncRef));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_ValType WasmEdge_ValTypeGenExternRef(void) {\n return 
genWasmEdge_ValType(ValType(TypeCode::ExternRef));\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsEqual(const WasmEdge_ValType ValType1,\n const WasmEdge_ValType ValType2) {\n return genValType(ValType1) == genValType(ValType2);\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsI32(const WasmEdge_ValType ValType) {\n return genValType(ValType).getCode() == WasmEdge::TypeCode::I32;\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsI64(const WasmEdge_ValType ValType) {\n return genValType(ValType).getCode() == WasmEdge::TypeCode::I64;\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsF32(const WasmEdge_ValType ValType) {\n return genValType(ValType).getCode() == WasmEdge::TypeCode::F32;\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsF64(const WasmEdge_ValType ValType) {\n return genValType(ValType).getCode() == WasmEdge::TypeCode::F64;\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsV128(const WasmEdge_ValType ValType) {\n return genValType(ValType).getCode() == WasmEdge::TypeCode::V128;\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsFuncRef(const WasmEdge_ValType ValType) {\n return genValType(ValType).isFuncRefType();\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsExternRef(const WasmEdge_ValType ValType) {\n return genValType(ValType).isExternRefType();\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsRef(const WasmEdge_ValType ValType) {\n return genValType(ValType).isRefType();\n}\n\nWASMEDGE_CAPI_EXPORT bool\nWasmEdge_ValTypeIsRefNull(const WasmEdge_ValType ValType) {\n return genValType(ValType).isNullableRefType();\n}\n\n// <<<<<<<< WasmEdge valtype functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n\n// >>>>>>>> WasmEdge value functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value WasmEdge_ValueGenI32(const int32_t Val) {\n return genWasmEdge_Value(Val);\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value WasmEdge_ValueGenI64(const int64_t Val) {\n return genWasmEdge_Value(Val);\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value WasmEdge_ValueGenF32(const float Val) {\n return genWasmEdge_Value(Val);\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value WasmEdge_ValueGenF64(const double Val) {\n return genWasmEdge_Value(Val);\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value\nWasmEdge_ValueGenV128(const ::int128_t Val) {\n return genWasmEdge_Value(to_WasmEdge_128_t(Val));\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value\nWasmEdge_ValueGenFuncRef(const WasmEdge_FunctionInstanceContext *Cxt) {\n return genWasmEdge_Value(WasmEdge::RefVariant(fromFuncCxt(Cxt)),\n TypeCode::FuncRef);\n}\n\nWASMEDGE_CAPI_EXPORT WasmEdge_Value WasmEdge_ValueGenExternRef(void *Ref) {\n return genWasmEdge_Value(WasmEdge::RefVariant(Ref), TypeCode::ExternRef);\n}\n\nWASMEDGE_CAPI_EXPORT int32_t WasmEdge_ValueGetI32(const WasmEdge_Value Val) {\n return WasmEdge::ValVariant::wrap(\n to_WasmEdge_128_t(Val.Value))\n .get();\n}\n\nWASMEDGE_CAPI_EXPORT int64_t WasmEdge_ValueGetI64(const WasmEdge_Value Val) {\n return WasmEdge::ValVariant::wrap(\n to_WasmEdge_128_t(Val.Value))\n .get();\n}\n\nWASMEDGE_CAPI_EXPORT float WasmEdge_ValueGetF32(const WasmEdge_Value Val) {\n return WasmEdge::ValVariant::wrap(\n to_WasmEdge_128_t(Val.Value))\n .get();\n}\n\nWASMEDGE_CAPI_EXPORT double WasmEdge_ValueGetF64(const WasmEdge_Value Val) {\n return WasmEdge::ValVariant::wrap(\n to_WasmEdge_128_t(Val.Value))\n .get();\n}\n\nWASMEDGE_CAPI_EXPORT ::int128_t\nWasmEdge_ValueGetV128(const WasmEdge_Value Val) {\n return to_int128_t(WasmEdge::ValVariant::wrap(\n...\n", "needle_token_start": 8876, "needle_token_end": 
8917, "code_context_ntokens": 16178, "output": ["```cpp\nWasmEdge_CompilerContext(const WasmEdge::Configure &Conf) noexcept\n : Compiler(Conf), CodeGen(Conf), Load(Conf), Valid(Conf) {}\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "proxy", "language": "cpp", "path": "lib/executor/engine/proxy.cpp", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this function is to serve as an intermediary that invokes another function, handling the result and potentially converting the output format, while also managing errors through a fault emission mechanism.\n2. **Input**: The inputs are various arguments required by the target function it intends to call. These arguments are passed through dynamically.\n3. **Output**: The output varies based on the return type of the function being proxied. It could be a direct value, a reference, or a data structure, potentially transformed for compatibility with specific compiler requirements.\n4. **Procedure**: The procedure involves:\n - Invoking a member function of a class using the current execution context and stack.\n - Checking if the function call was successful.\n - Emitting a fault if the function call resulted in an error.\n - Transforming the output data to a specific format if required, before returning it.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/executor/instantiate/import.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \"common/errinfo.h\"\n#include \"common/log.h\"\n\n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\nnamespace {\ntemplate \nauto logMatchError(std::string_view ModName, std::string_view ExtName,\n ExternalType ExtType, Args &&...Values) {\n spdlog::error(ErrCode::Value::IncompatibleImportType);\n spdlog::error(ErrInfo::InfoMismatch(std::forward(Values)...));\n spdlog::error(ErrInfo::InfoLinking(ModName, ExtName, ExtType));\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Desc_Import));\n return Unexpect(ErrCode::Value::IncompatibleImportType);\n}\n\nauto logUnknownError(std::string_view ModName, std::string_view ExtName,\n ExternalType ExtType) {\n spdlog::error(ErrCode::Value::UnknownImport);\n spdlog::error(ErrInfo::InfoLinking(ModName, ExtName, ExtType));\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Desc_Import));\n return Unexpect(ErrCode::Value::UnknownImport);\n}\n\nbool matchLimit(const AST::Limit &Exp, const AST::Limit &Got) {\n if (Exp.isShared() != Got.isShared()) {\n return false;\n }\n if ((Got.getMin() < Exp.getMin()) || (Exp.hasMax() && !Got.hasMax())) {\n return false;\n }\n if (Exp.hasMax() && Got.hasMax() && Got.getMax() > Exp.getMax()) {\n return false;\n }\n return true;\n}\n\nExpect\ncheckImportMatched(std::string_view ModName, std::string_view ExtName,\n const ExternalType ExtType,\n const Runtime::Instance::ModuleInstance &ModInst) {\n switch (ExtType) {\n case ExternalType::Function:\n if (auto Res = ModInst.findFuncExports(ExtName); likely(Res != nullptr)) {\n return {};\n }\n break;\n case ExternalType::Table:\n if (auto Res = ModInst.findTableExports(ExtName); likely(Res != nullptr)) {\n return {};\n }\n break;\n case ExternalType::Memory:\n if (auto Res = ModInst.findMemoryExports(ExtName); 
likely(Res != nullptr)) {\n return {};\n }\n break;\n case ExternalType::Global:\n if (auto Res = ModInst.findGlobalExports(ExtName); likely(Res != nullptr)) {\n return {};\n }\n break;\n default:\n return logUnknownError(ModName, ExtName, ExtType);\n }\n\n // Check is error external type or unknown imports.\n if (ModInst.findFuncExports(ExtName)) {\n return logMatchError(ModName, ExtName, ExtType, ExtType,\n ExternalType::Function);\n }\n if (ModInst.findTableExports(ExtName)) {\n return logMatchError(ModName, ExtName, ExtType, ExtType,\n ExternalType::Table);\n }\n if (ModInst.findMemoryExports(ExtName)) {\n return logMatchError(ModName, ExtName, ExtType, ExtType,\n ExternalType::Memory);\n }\n if (ModInst.findGlobalExports(ExtName)) {\n return logMatchError(ModName, ExtName, ExtType, ExtType,\n ExternalType::Global);\n }\n\n return logUnknownError(ModName, ExtName, ExtType);\n}\n} // namespace\n\n// Instantiate imports. See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::StoreManager &StoreMgr,\n Runtime::Instance::ModuleInstance &ModInst,\n const AST::ImportSection &ImportSec) {\n // Iterate and instantiate import descriptions.\n for (const auto &ImpDesc : ImportSec.getContent()) {\n // Get data from import description and find import module.\n auto ExtType = ImpDesc.getExternalType();\n auto ModName = ImpDesc.getModuleName();\n auto ExtName = ImpDesc.getExternalName();\n const auto *ImpModInst = StoreMgr.findModule(ModName);\n if (unlikely(ImpModInst == nullptr)) {\n auto Res = logUnknownError(ModName, ExtName, ExtType);\n if (ModName == \"wasi_snapshot_preview1\") {\n spdlog::error(\" This is a WASI related import. Please ensure that \"\n \"you've turned on the WASI configuration.\");\n } else if (ModName == \"wasi_nn\") {\n spdlog::error(\" This is a WASI-NN related import. Please ensure \"\n \"that you've turned on the WASI-NN configuration and \"\n \"installed the WASI-NN plug-in.\");\n } else if (ModName == \"wasi_crypto_common\" ||\n ModName == \"wasi_crypto_asymmetric_common\" ||\n ModName == \"wasi_crypto_kx\" ||\n ModName == \"wasi_crypto_signatures\" ||\n ModName == \"wasi_crypto_symmetric\") {\n spdlog::error(\" This is a WASI-Crypto related import. Please ensure \"\n \"that you've turned on the WASI-Crypto configuration and \"\n \"installed the WASI-Crypto plug-in.\");\n } else if (ModName == \"env\") {\n spdlog::error(\n \" This may be the import of host environment like JavaScript or \"\n \"Golang. Please check that you've registered the necessary host \"\n \"modules from the host programming language.\");\n }\n return Res;\n }\n if (auto Res = checkImportMatched(ModName, ExtName, ExtType, *ImpModInst);\n unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n // Add the imports into module instance.\n switch (ExtType) {\n case ExternalType::Function: {\n // Get function type index. 
External type checked in validation.\n uint32_t TypeIdx = ImpDesc.getExternalFuncTypeIdx();\n // Import matching.\n auto *ImpInst = ImpModInst->findFuncExports(ExtName);\n const auto &ExpDefType = **ModInst.getType(TypeIdx);\n // External function type should match the import function type in\n // description.\n\n if (!AST::TypeMatcher::matchType(\n ModInst.getTypeList(), *ExpDefType.getTypeIndex(),\n ImpModInst->getTypeList(), ImpInst->getTypeIndex())) {\n bool IsMatchV2 = false;\n const auto &ExpFuncType = ExpDefType.getCompositeType().getFuncType();\n const auto &ImpFuncType = ImpInst->getFuncType();\n if (ModName == \"wasi_snapshot_preview1\") {\n /*\n * The following functions should provide V1 and V2.\n \"sock_open_v2\",\n \"sock_bind_v2\",\n \"sock_connect_v2\",\n \"sock_listen_v2\",\n \"sock_accept_v2\",\n \"sock_recv_v2\",\n \"sock_recv_from_v2\",\n \"sock_send_v2\",\n \"sock_send_to_v2\",\n \"sock_getlocaladdr_v2\",\n \"sock_getpeeraddr_v2\"\n */\n std::vector CompatibleWASISocketAPI = {\n \"sock_open\", \"sock_bind\", \"sock_connect\",\n \"sock_listen\", \"sock_accept\", \"sock_recv\",\n \"sock_recv_from\", \"sock_send\", \"sock_send_to\",\n \"sock_getlocaladdr\", \"sock_getpeeraddr\"};\n for (auto Iter = CompatibleWASISocketAPI.begin();\n Iter != CompatibleWASISocketAPI.end(); Iter++) {\n if (ExtName == *Iter) {\n auto *ImpInstV2 = ImpModInst->findFuncExports(*Iter + \"_v2\");\n if (!AST::TypeMatcher::matchType(\n ModInst.getTypeList(), *ExpDefType.getTypeIndex(),\n ImpModInst->getTypeList(), ImpInst->getTypeIndex())) {\n // Try to match the new version\n ImpInst = ImpInstV2;\n IsMatchV2 = true;\n break;\n }\n }\n }\n }\n if (!IsMatchV2) {\n return logMatchError(\n ModName, ExtName, ExtType, ExpFuncType.getParamTypes(),\n ExpFuncType.getReturnTypes(), ImpFuncType.getParamTypes(),\n ImpFuncType.getReturnTypes());\n }\n }\n // Set the matched function address to module instance.\n ModInst.importFunction(ImpInst);\n break;\n }\n case ExternalType::Table: {\n // Get table type. External type checked in validation.\n const auto &TabType = ImpDesc.getExternalTableType();\n const auto &TabLim = TabType.getLimit();\n // Import matching. External table type should match the one in import\n // description.\n auto *ImpInst = ImpModInst->findTableExports(ExtName);\n const auto &ImpType = ImpInst->getTableType();\n const auto &ImpLim = ImpType.getLimit();\n // External table reference type should match the import table reference\n // type in description, and vice versa.\n if (!AST::TypeMatcher::matchType(\n ModInst.getTypeList(), TabType.getRefType(),\n ImpModInst->getTypeList(), ImpType.getRefType()) ||\n !AST::TypeMatcher::matchType(\n ImpModInst->getTypeList(), ImpType.getRefType(),\n ModInst.getTypeList(), TabType.getRefType()) ||\n !matchLimit(TabLim, ImpLim)) {\n return logMatchError(ModName, ExtName, ExtType, TabType.getRefType(),\n TabLim.hasMax(), TabLim.getMin(), TabLim.getMax(),\n ImpType.getRefType(), ImpLim.hasMax(),\n ImpLim.getMin(), ImpLim.getMax());\n }\n // Set the matched table address to module instance.\n ModInst.importTable(ImpInst);\n break;\n }\n case ExternalType::Memory: {\n // Get memory type. External type checked in validation.\n const auto &MemType = ImpDesc.getExternalMemoryType();\n const auto &MemLim = MemType.getLimit();\n // Import matching. 
External memory type should match the one in import\n // description.\n auto *ImpInst = ImpModInst->findMemoryExports(ExtName);\n const auto &ImpLim = ImpInst->getMemoryType().getLimit();\n if (!matchLimit(MemLim, ImpLim)) {\n return logMatchError(ModName, ExtName, ExtType, MemLim.hasMax(),\n MemLim.getMin(), MemLim.getMax(), ImpLim.hasMax(),\n ImpLim.getMin(), ImpLim.getMax());\n }\n // Set the matched memory address to module instance.\n ModInst.importMemory(ImpInst);\n break;\n }\n case ExternalType::Global: {\n // Get global type. External type checked in validation.\n const auto &GlobType = ImpDesc.getExternalGlobalType();\n // Import matching. External global type should match the one in\n // import description.\n auto *ImpInst = ImpModInst->findGlobalExports(ExtName);\n const auto &ImpType = ImpInst->getGlobalType();\n bool IsMatch = false;\n if (ImpType.getValMut() == GlobType.getValMut()) {\n // For both const or both var: external global value type should match\n // the import global value type in description.\n IsMatch = AST::TypeMatcher::matchType(\n ModInst.getTypeList(), GlobType.getValType(),\n ImpModInst->getTypeList(), ImpType.getValType());\n if (ImpType.getValMut() == ValMut::Var) {\n // If both var: import global value type in description should also\n // match the external global value type.\n IsMatch &= AST::TypeMatcher::matchType(\n ImpModInst->getTypeList(), ImpType.getValType(),\n ModInst.getTypeList(), GlobType.getValType());\n }\n }\n if (!IsMatch) {\n...\n// Path: lib/executor/instantiate/function.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\n// Instantiate function instance. See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::Instance::ModuleInstance &ModInst,\n const AST::FunctionSection &FuncSec,\n const AST::CodeSection &CodeSec) {\n\n // Get the function type indices.\n auto TypeIdxs = FuncSec.getContent();\n auto CodeSegs = CodeSec.getContent();\n\n if (CodeSegs.size() == 0) {\n return {};\n }\n // The module will always choose the `for` loop in `else` case under\n // interpreter mode. Instead, if we do branch in the `for` loop which might\n // cause meaningless branch misses. Therefore we should check the first item\n // and dispatch it into different cases to reduce branch misses.\n if (CodeSegs[0].getSymbol() != false) {\n for (uint32_t I = 0; I < CodeSegs.size(); ++I) {\n auto Symbol = CodeSegs[I].getSymbol();\n ModInst.addFunc(\n TypeIdxs[I],\n (*ModInst.getType(TypeIdxs[I]))->getCompositeType().getFuncType(),\n std::move(Symbol));\n }\n } else {\n // Iterate through the code segments to instantiate function instances.\n for (uint32_t I = 0; I < CodeSegs.size(); ++I) {\n // Create and add the function instance into the module instance.\n ModInst.addFunc(\n TypeIdxs[I],\n (*ModInst.getType(TypeIdxs[I]))->getCompositeType().getFuncType(),\n CodeSegs[I].getLocals(), CodeSegs[I].getExpr().getInstrs());\n }\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/instantiate/table.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\n// Instantiate table instance. 
See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::StackManager &StackMgr,\n Runtime::Instance::ModuleInstance &ModInst,\n const AST::TableSection &TabSec) {\n // A frame with temp. module is pushed into the stack in caller.\n\n // Iterate through the table segments to instantiate and initialize table\n // instances.\n for (const auto &TabSeg : TabSec.getContent()) {\n if (TabSeg.getExpr().getInstrs().size() > 0) {\n // Run initialize expression.\n if (auto Res = runExpression(StackMgr, TabSeg.getExpr().getInstrs());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Expression));\n return Unexpect(Res);\n }\n // Pop result from stack.\n RefVariant InitTabValue = StackMgr.pop().get();\n // Create and add the table instance into the module instance.\n ModInst.addTable(TabSeg.getTableType(), InitTabValue);\n } else {\n // No init expression case. Use the null reference to initialize.\n ModInst.addTable(TabSeg.getTableType());\n }\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/instantiate/memory.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\n// Instantiate memory instance. See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::Instance::ModuleInstance &ModInst,\n const AST::MemorySection &MemSec) {\n // Prepare pointers vector for compiled functions.\n ModInst.MemoryPtrs.resize(ModInst.getMemoryNum() +\n MemSec.getContent().size());\n\n // Iterate through the memory types to instantiate memory instances.\n for (const auto &MemType : MemSec.getContent()) {\n // Create and add the memory instance into the module instance.\n ModInst.addMemory(MemType, Conf.getRuntimeConfigure().getMaxMemoryPage());\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/instantiate/global.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\n// Instantiate global instance. See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::StackManager &StackMgr,\n Runtime::Instance::ModuleInstance &ModInst,\n const AST::GlobalSection &GlobSec) {\n // A frame with temp. 
module is pushed into the stack in caller.\n\n // Prepare pointers for compiled functions.\n ModInst.GlobalPtrs.resize(ModInst.getGlobalNum() +\n GlobSec.getContent().size());\n\n // Set the global pointers of imported globals.\n for (uint32_t I = 0; I < ModInst.getGlobalNum(); ++I) {\n ModInst.GlobalPtrs[I] = &((*ModInst.getGlobal(I))->getValue());\n }\n\n // Iterate through the global segments to instantiate and initialize global\n // instances.\n for (const auto &GlobSeg : GlobSec.getContent()) {\n // Run initialize expression.\n if (auto Res = runExpression(StackMgr, GlobSeg.getExpr().getInstrs());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Expression));\n return Unexpect(Res);\n }\n\n // Pop result from stack.\n ValVariant InitValue = StackMgr.pop();\n\n // Create and add the global instance into the module instance.\n ModInst.addGlobal(GlobSeg.getGlobalType(), InitValue);\n const auto Index = ModInst.getGlobalNum() - 1;\n Runtime::Instance::GlobalInstance *GlobInst = *ModInst.getGlobal(Index);\n\n // Set the global pointers of instantiated globals.\n ModInst.GlobalPtrs[Index] = &(GlobInst->getValue());\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/instantiate/export.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\n// Instantiate exports. See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::Instance::ModuleInstance &ModInst,\n const AST::ExportSection &ExportSec) {\n // Iterate through the export descriptions and instantiate the exports.\n for (const auto &ExpDesc : ExportSec.getContent()) {\n // Get data from the export description.\n const auto ExtType = ExpDesc.getExternalType();\n std::string_view ExtName = ExpDesc.getExternalName();\n const uint32_t ExtIdx = ExpDesc.getExternalIndex();\n\n // Export the instance with the name.\n switch (ExtType) {\n case ExternalType::Function:\n ModInst.exportFunction(ExtName, ExtIdx);\n break;\n case ExternalType::Global:\n ModInst.exportGlobal(ExtName, ExtIdx);\n break;\n case ExternalType::Memory:\n ModInst.exportMemory(ExtName, ExtIdx);\n break;\n case ExternalType::Table:\n ModInst.exportTable(ExtName, ExtIdx);\n break;\n default:\n break;\n }\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/instantiate/elem.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \"common/errinfo.h\"\n#include \"common/log.h\"\n#include \n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\n// Instantiate element instance. 
See \"include/executor/executor.h\".\nExpect Executor::instantiate(Runtime::StackManager &StackMgr,\n Runtime::Instance::ModuleInstance &ModInst,\n const AST::ElementSection &ElemSec) {\n // A frame with the current module has been pushed into the stack outside.\n\n // Iterate through the element segments to instantiate element instances.\n for (const auto &ElemSeg : ElemSec.getContent()) {\n std::vector InitVals;\n for (const auto &Expr : ElemSeg.getInitExprs()) {\n // Run init expr of every elements and get the result reference.\n if (auto Res = runExpression(StackMgr, Expr.getInstrs()); !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Expression));\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Element));\n return Unexpect(Res);\n }\n // Pop result from stack.\n InitVals.push_back(StackMgr.pop().get());\n }\n\n uint32_t Offset = 0;\n if (ElemSeg.getMode() == AST::ElementSegment::ElemMode::Active) {\n // Run initialize expression.\n if (auto Res = runExpression(StackMgr, ElemSeg.getExpr().getInstrs());\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Expression));\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Element));\n return Unexpect(Res);\n }\n Offset = StackMgr.pop().get();\n\n // Check boundary unless ReferenceTypes or BulkMemoryOperations proposal\n // enabled.\n if (!Conf.hasProposal(Proposal::ReferenceTypes) &&\n !Conf.hasProposal(Proposal::BulkMemoryOperations)) {\n // Table index should be 0. Checked in validation phase.\n auto *TabInst = getTabInstByIdx(StackMgr, ElemSeg.getIdx());\n // Check elements fits.\n assuming(TabInst);\n if (!TabInst->checkAccessBound(\n Offset, static_cast(InitVals.size()))) {\n spdlog::error(ErrCode::Value::ElemSegDoesNotFit);\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Element));\n return Unexpect(ErrCode::Value::ElemSegDoesNotFit);\n }\n }\n }\n\n // Create and add the element instance into the module instance.\n ModInst.addElem(Offset, ElemSeg.getRefType(), InitVals);\n }\n return {};\n}\n\n// Initialize table with Element section. 
See \"include/executor/executor.h\".\nExpect Executor::initTable(Runtime::StackManager &StackMgr,\n const AST::ElementSection &ElemSec) {\n // Initialize tables.\n uint32_t Idx = 0;\n for (const auto &ElemSeg : ElemSec.getContent()) {\n auto *ElemInst = getElemInstByIdx(StackMgr, Idx);\n assuming(ElemInst);\n if (ElemSeg.getMode() == AST::ElementSegment::ElemMode::Active) {\n // Table index is checked in validation phase.\n auto *TabInst = getTabInstByIdx(StackMgr, ElemSeg.getIdx());\n assuming(TabInst);\n const uint32_t Off = ElemInst->getOffset();\n\n // Replace table[Off : Off + n] with elem[0 : n].\n if (auto Res = TabInst->setRefs(\n ElemInst->getRefs(), Off, 0,\n static_cast(ElemInst->getRefs().size()));\n !Res) {\n spdlog::error(ErrInfo::InfoAST(ASTNodeAttr::Seg_Element));\n return Unexpect(Res);\n }\n\n // Drop the element instance.\n ElemInst->clear();\n\n // Operation above is equal to the following instruction sequence:\n // expr(init) -> i32.const off\n // i32.const 0\n // i32.const n\n // table.init idx\n // elem.drop idx\n } else if (ElemSeg.getMode() ==\n AST::ElementSegment::ElemMode::Declarative) {\n // Drop the element instance.\n ElemInst->clear();\n }\n Idx++;\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/variableInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect Executor::runLocalGetOp(Runtime::StackManager &StackMgr,\n uint32_t StackOffset) const noexcept {\n StackMgr.push(StackMgr.getTopN(StackOffset));\n return {};\n}\n\nExpect Executor::runLocalSetOp(Runtime::StackManager &StackMgr,\n uint32_t StackOffset) const noexcept {\n StackMgr.getTopN(StackOffset - 1) = StackMgr.pop();\n return {};\n}\n\nExpect Executor::runLocalTeeOp(Runtime::StackManager &StackMgr,\n uint32_t StackOffset) const noexcept {\n const ValVariant &Val = StackMgr.getTop();\n StackMgr.getTopN(StackOffset) = Val;\n return {};\n}\n\nExpect Executor::runGlobalGetOp(Runtime::StackManager &StackMgr,\n uint32_t Idx) const noexcept {\n auto *GlobInst = getGlobInstByIdx(StackMgr, Idx);\n assuming(GlobInst);\n StackMgr.push(GlobInst->getValue());\n return {};\n}\n\nExpect Executor::runGlobalSetOp(Runtime::StackManager &StackMgr,\n uint32_t Idx) const noexcept {\n auto *GlobInst = getGlobInstByIdx(StackMgr, Idx);\n assuming(GlobInst);\n GlobInst->setValue(StackMgr.pop());\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/controlInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect Executor::runIfElseOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC) noexcept {\n // Get condition.\n uint32_t Cond = StackMgr.pop().get();\n\n // If non-zero, run if-statement; else, run else-statement.\n if (Cond == 0) {\n if (Instr.getJumpElse() == Instr.getJumpEnd()) {\n // No else-statement case. Jump to right before End instruction.\n PC += (Instr.getJumpEnd() - 1);\n } else {\n if (Stat) {\n Stat->incInstrCount();\n if (unlikely(!Stat->addInstrCost(OpCode::Else))) {\n return Unexpect(ErrCode::Value::CostLimitExceeded);\n }\n }\n // Have else-statement case. 
Jump to Else instruction to continue.\n PC += Instr.getJumpElse();\n }\n }\n return {};\n}\n\nExpect Executor::runBrOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC) noexcept {\n return branchToLabel(StackMgr, Instr.getJump().StackEraseBegin,\n Instr.getJump().StackEraseEnd, Instr.getJump().PCOffset,\n PC);\n}\n\nExpect Executor::runBrIfOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC) noexcept {\n if (StackMgr.pop().get() != 0) {\n return runBrOp(StackMgr, Instr, PC);\n }\n return {};\n}\n\nExpect Executor::runBrOnNullOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC) noexcept {\n if (StackMgr.getTop().get().isNull()) {\n StackMgr.pop();\n return runBrOp(StackMgr, Instr, PC);\n }\n return {};\n}\n\nExpect Executor::runBrOnNonNullOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC) noexcept {\n if (!StackMgr.getTop().get().isNull()) {\n return runBrOp(StackMgr, Instr, PC);\n }\n StackMgr.pop();\n return {};\n}\n\nExpect Executor::runBrTableOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC) noexcept {\n // Get value on top of stack.\n uint32_t Value = StackMgr.pop().get();\n\n // Do branch.\n auto LabelTable = Instr.getLabelList();\n const auto LabelTableSize = static_cast(LabelTable.size() - 1);\n if (Value < LabelTableSize) {\n return branchToLabel(StackMgr, LabelTable[Value].StackEraseBegin,\n LabelTable[Value].StackEraseEnd,\n LabelTable[Value].PCOffset, PC);\n }\n return branchToLabel(StackMgr, LabelTable[LabelTableSize].StackEraseBegin,\n LabelTable[LabelTableSize].StackEraseEnd,\n LabelTable[LabelTableSize].PCOffset, PC);\n}\n\nExpect Executor::runBrOnCastOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC,\n bool IsReverse) noexcept {\n // Get value on top of stack.\n const auto *ModInst = StackMgr.getModule();\n const auto &Val = StackMgr.getTop().get();\n const auto &VT = Val.getType();\n Span GotTypeList = ModInst->getTypeList();\n if (!VT.isAbsHeapType()) {\n auto *Inst = Val.getPtr();\n // Reference must not be nullptr here because the null references are typed\n // with the least abstract heap type.\n if (Inst->getModule()) {\n GotTypeList = Inst->getModule()->getTypeList();\n }\n }\n\n if (AST::TypeMatcher::matchType(ModInst->getTypeList(),\n Instr.getBrCast().RType2, GotTypeList,\n VT) != IsReverse) {\n return branchToLabel(StackMgr, Instr.getBrCast().Jump.StackEraseBegin,\n Instr.getBrCast().Jump.StackEraseEnd,\n Instr.getBrCast().Jump.PCOffset, PC);\n }\n return {};\n}\n\nExpect Executor::runReturnOp(Runtime::StackManager &StackMgr,\n AST::InstrView::iterator &PC) noexcept {\n // Check stop token\n if (unlikely(StopToken.exchange(0, std::memory_order_relaxed))) {\n spdlog::error(ErrCode::Value::Interrupted);\n return Unexpect(ErrCode::Value::Interrupted);\n }\n PC = StackMgr.popFrame();\n return {};\n}\n\nExpect Executor::runCallOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC,\n bool IsTailCall) noexcept {\n // Get Function address.\n const auto *FuncInst = getFuncInstByIdx(StackMgr, Instr.getTargetIndex());\n if (auto Res = enterFunction(StackMgr, *FuncInst, PC + 1, IsTailCall); !Res) {\n return Unexpect(Res);\n } else {\n PC = (*Res) - 1;\n }\n return {};\n}\n\nExpect Executor::runCallRefOp(Runtime::StackManager &StackMgr,\n 
const AST::Instruction &Instr,\n AST::InstrView::iterator &PC,\n bool IsTailCall) noexcept {\n\n const auto Ref = StackMgr.pop().get();\n if (Ref.isNull()) {\n spdlog::error(ErrCode::Value::AccessNullFunc);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullFunc);\n }\n\n // Get Function address.\n const auto *FuncInst = retrieveFuncRef(Ref);\n if (auto Res = enterFunction(StackMgr, *FuncInst, PC + 1, IsTailCall); !Res) {\n return Unexpect(Res);\n } else {\n PC = (*Res) - 1;\n }\n return {};\n}\n\nExpect Executor::runCallIndirectOp(Runtime::StackManager &StackMgr,\n const AST::Instruction &Instr,\n AST::InstrView::iterator &PC,\n bool IsTailCall) noexcept {\n // Get Table Instance\n const auto *TabInst = getTabInstByIdx(StackMgr, Instr.getSourceIndex());\n\n // Get function type at index x.\n const auto *ModInst = StackMgr.getModule();\n const auto &ExpDefType = **ModInst->getType(Instr.getTargetIndex());\n\n // Pop the value i32.const i from the Stack.\n uint32_t Idx = StackMgr.pop().get();\n\n // If idx not small than tab.elem, trap.\n if (Idx >= TabInst->getSize()) {\n spdlog::error(ErrCode::Value::UndefinedElement);\n spdlog::error(ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset(),\n {Idx},\n {ValTypeFromType()}));\n return Unexpect(ErrCode::Value::UndefinedElement);\n }\n\n // Get function address. The bound is guaranteed.\n RefVariant Ref = *TabInst->getRefAddr(Idx);\n if (Ref.isNull()) {\n spdlog::error(ErrCode::Value::UninitializedElement);\n spdlog::error(ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset(),\n {Idx},\n {ValTypeFromType()}));\n return Unexpect(ErrCode::Value::UninitializedElement);\n }\n\n // Check function type.\n const auto *FuncInst = retrieveFuncRef(Ref);\n bool IsMatch = false;\n if (FuncInst->getModule()) {\n IsMatch = AST::TypeMatcher::matchType(\n ModInst->getTypeList(), *ExpDefType.getTypeIndex(),\n FuncInst->getModule()->getTypeList(), FuncInst->getTypeIndex());\n } else {\n // Independent host module instance case. 
Matching the composite type\n // directly.\n IsMatch = AST::TypeMatcher::matchType(\n ModInst->getTypeList(), ExpDefType.getCompositeType(),\n FuncInst->getHostFunc().getDefinedType().getCompositeType());\n }\n if (!IsMatch) {\n auto &ExpFuncType = ExpDefType.getCompositeType().getFuncType();\n auto &GotFuncType = FuncInst->getFuncType();\n spdlog::error(ErrCode::Value::IndirectCallTypeMismatch);\n spdlog::error(ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset(),\n {Idx},\n {ValTypeFromType()}));\n spdlog::error(ErrInfo::InfoMismatch(\n ExpFuncType.getParamTypes(), ExpFuncType.getReturnTypes(),\n GotFuncType.getParamTypes(), GotFuncType.getReturnTypes()));\n return Unexpect(ErrCode::Value::IndirectCallTypeMismatch);\n }\n\n // Enter the function.\n if (auto Res = enterFunction(StackMgr, *FuncInst, PC + 1, IsTailCall); !Res) {\n return Unexpect(Res);\n } else {\n PC = (*Res) - 1;\n }\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/proxy.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n#include \"system/fault.h\"\n\n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\nthread_local Executor *Executor::This = nullptr;\nthread_local Runtime::StackManager *Executor::CurrentStack = nullptr;\nthread_local Executor::ExecutionContextStruct Executor::ExecutionContext;\n\ntemplate \nstruct Executor::ProxyHelper (Executor::*)(Runtime::StackManager &,\n ArgsT...) noexcept> {\n template (Executor::*Func)(Runtime::StackManager &,\n ArgsT...) noexcept>\n \nstatic auto proxy(ArgsT... Args) {\n Expect Res = (This->*Func)(*CurrentStack, Args...);\n if (unlikely(!Res)) {\n Fault::emitFault(Res.error());\n }\n if constexpr (std::is_same_v) {\n#if defined(_MSC_VER) && !defined(__clang__) // MSVC\n return *reinterpret_cast<__m128 *>((*Res).getRawData().data());\n#else\n return (*Res).getRawData();\n#endif // MSVC\n } else if constexpr (!std::is_void_v) {\n return *Res;\n }\n }\n};\n\n#if defined(__clang_major__) && __clang_major__ >= 10\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wc99-designator\"\n#endif\n\n// Intrinsics table\nconst Executable::IntrinsicsTable Executor::Intrinsics = {\n#if defined(_MSC_VER) && !defined(__clang__)\n#define ENTRY(NAME, FUNC) \\\n reinterpret_cast(&Executor::ProxyHelper< \\\n decltype(&Executor::FUNC)>::proxy<&Executor::FUNC>)\n#else\n#define ENTRY(NAME, FUNC) \\\n [uint8_t(Executable::Intrinsics::NAME)] = reinterpret_cast( \\\n &Executor::ProxyHelper::proxy< \\\n &Executor::FUNC>)\n#endif\n ENTRY(kTrap, trap),\n ENTRY(kCall, call),\n ENTRY(kCallIndirect, callIndirect),\n ENTRY(kMemCopy, memCopy),\n ENTRY(kMemFill, memFill),\n ENTRY(kMemGrow, memGrow),\n ENTRY(kMemSize, memSize),\n ENTRY(kMemInit, memInit),\n ENTRY(kDataDrop, dataDrop),\n ENTRY(kTableGet, tableGet),\n ENTRY(kTableSet, tableSet),\n ENTRY(kTableCopy, tableCopy),\n ENTRY(kTableFill, tableFill),\n ENTRY(kTableGrow, tableGrow),\n ENTRY(kTableSize, tableSize),\n ENTRY(kTableInit, tableInit),\n ENTRY(kElemDrop, elemDrop),\n ENTRY(kRefFunc, refFunc),\n ENTRY(kTableGetFuncSymbol, tableGetFuncSymbol),\n ENTRY(kMemoryAtomicNotify, memoryAtomicNotify),\n ENTRY(kMemoryAtomicWait, memoryAtomicWait),\n ENTRY(kCallRef, callRef),\n ENTRY(kRefGetFuncSymbol, refGetFuncSymbol),\n#undef ENTRY\n};\n\n#if defined(__clang_major__) && __clang_major__ >= 10\n#pragma clang diagnostic pop\n#endif\n\nExpect Executor::trap(Runtime::StackManager &,\n const uint32_t 
Code) noexcept {\n return Unexpect(static_cast(Code >> 24), Code);\n}\n\nExpect Executor::call(Runtime::StackManager &StackMgr,\n const uint32_t FuncIdx, const ValVariant *Args,\n ValVariant *Rets) noexcept {\n const auto *FuncInst = getFuncInstByIdx(StackMgr, FuncIdx);\n const auto &FuncType = FuncInst->getFuncType();\n const uint32_t ParamsSize =\n static_cast(FuncType.getParamTypes().size());\n const uint32_t ReturnsSize =\n static_cast(FuncType.getReturnTypes().size());\n\n for (uint32_t I = 0; I < ParamsSize; ++I) {\n StackMgr.push(Args[I]);\n }\n\n auto Instrs = FuncInst->getInstrs();\n AST::InstrView::iterator StartIt;\n if (auto Res = enterFunction(StackMgr, *FuncInst, Instrs.end())) {\n StartIt = *Res;\n } else {\n return Unexpect(Res);\n }\n if (auto Res = execute(StackMgr, StartIt, Instrs.end()); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n for (uint32_t I = 0; I < ReturnsSize; ++I) {\n Rets[ReturnsSize - 1 - I] = StackMgr.pop();\n }\n\n return {};\n}\n\nExpect Executor::tableGetFuncSymbol(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx,\n const uint32_t FuncTypeIdx,\n const uint32_t FuncIdx) noexcept {\n const auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n\n if (unlikely(FuncIdx >= TabInst->getSize())) {\n return Unexpect(ErrCode::Value::UndefinedElement);\n }\n\n auto Ref = TabInst->getRefAddr(FuncIdx);\n assuming(Ref);\n if (unlikely(Ref->isNull())) {\n return Unexpect(ErrCode::Value::UninitializedElement);\n }\n\n const auto *ModInst = StackMgr.getModule();\n assuming(ModInst);\n const auto &ExpDefType = **ModInst->getType(FuncTypeIdx);\n const auto *FuncInst = retrieveFuncRef(*Ref);\n assuming(FuncInst);\n bool IsMatch = false;\n if (FuncInst->getModule()) {\n IsMatch = AST::TypeMatcher::matchType(\n ModInst->getTypeList(), *ExpDefType.getTypeIndex(),\n FuncInst->getModule()->getTypeList(), FuncInst->getTypeIndex());\n } else {\n // Independent host module instance case. Matching the composite type\n // directly.\n IsMatch = AST::TypeMatcher::matchType(\n ModInst->getTypeList(), ExpDefType.getCompositeType(),\n FuncInst->getHostFunc().getDefinedType().getCompositeType());\n }\n if (!IsMatch) {\n return Unexpect(ErrCode::Value::IndirectCallTypeMismatch);\n }\n\n if (unlikely(!FuncInst->isCompiledFunction())) {\n return nullptr;\n }\n\n return FuncInst->getSymbol().get();\n}\n\nExpect\nExecutor::callIndirect(Runtime::StackManager &StackMgr, const uint32_t TableIdx,\n const uint32_t FuncTypeIdx, const uint32_t FuncIdx,\n const ValVariant *Args, ValVariant *Rets) noexcept {\n const auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n\n if (unlikely(FuncIdx >= TabInst->getSize())) {\n return Unexpect(ErrCode::Value::UndefinedElement);\n }\n\n auto Ref = TabInst->getRefAddr(FuncIdx);\n assuming(Ref);\n if (unlikely(Ref->isNull())) {\n return Unexpect(ErrCode::Value::UninitializedElement);\n }\n\n const auto *ModInst = StackMgr.getModule();\n assuming(ModInst);\n const auto &ExpDefType = **ModInst->getType(FuncTypeIdx);\n const auto *FuncInst = retrieveFuncRef(*Ref);\n assuming(FuncInst);\n bool IsMatch = false;\n if (FuncInst->getModule()) {\n IsMatch = AST::TypeMatcher::matchType(\n ModInst->getTypeList(), *ExpDefType.getTypeIndex(),\n FuncInst->getModule()->getTypeList(), FuncInst->getTypeIndex());\n } else {\n // Independent host module instance case. 
Matching the composite type\n // directly.\n IsMatch = AST::TypeMatcher::matchType(\n ModInst->getTypeList(), ExpDefType.getCompositeType(),\n FuncInst->getHostFunc().getDefinedType().getCompositeType());\n }\n if (!IsMatch) {\n return Unexpect(ErrCode::Value::IndirectCallTypeMismatch);\n }\n\n const auto &FuncType = FuncInst->getFuncType();\n const uint32_t ParamsSize =\n static_cast(FuncType.getParamTypes().size());\n const uint32_t ReturnsSize =\n static_cast(FuncType.getReturnTypes().size());\n\n for (uint32_t I = 0; I < ParamsSize; ++I) {\n StackMgr.push(Args[I]);\n }\n\n auto Instrs = FuncInst->getInstrs();\n AST::InstrView::iterator StartIt;\n if (auto Res = enterFunction(StackMgr, *FuncInst, Instrs.end())) {\n StartIt = *Res;\n } else {\n return Unexpect(Res);\n }\n if (auto Res = execute(StackMgr, StartIt, Instrs.end()); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n for (uint32_t I = 0; I < ReturnsSize; ++I) {\n Rets[ReturnsSize - 1 - I] = StackMgr.pop();\n }\n\n return {};\n}\n\nExpect Executor::memGrow(Runtime::StackManager &StackMgr,\n const uint32_t MemIdx,\n const uint32_t NewSize) noexcept {\n auto *MemInst = getMemInstByIdx(StackMgr, MemIdx);\n assuming(MemInst);\n const uint32_t CurrPageSize = MemInst->getPageSize();\n if (MemInst->growPage(NewSize)) {\n return CurrPageSize;\n } else {\n return static_cast(-1);\n }\n}\n\nExpect Executor::memSize(Runtime::StackManager &StackMgr,\n const uint32_t MemIdx) noexcept {\n auto *MemInst = getMemInstByIdx(StackMgr, MemIdx);\n assuming(MemInst);\n return MemInst->getPageSize();\n}\n\nExpect Executor::memCopy(Runtime::StackManager &StackMgr,\n const uint32_t DstMemIdx,\n const uint32_t SrcMemIdx, const uint32_t DstOff,\n const uint32_t SrcOff,\n const uint32_t Len) noexcept {\n auto *MemInstDst = getMemInstByIdx(StackMgr, DstMemIdx);\n assuming(MemInstDst);\n auto *MemInstSrc = getMemInstByIdx(StackMgr, SrcMemIdx);\n assuming(MemInstSrc);\n\n if (auto Data = MemInstSrc->getBytes(SrcOff, Len); unlikely(!Data)) {\n return Unexpect(Data);\n } else {\n if (auto Res = MemInstDst->setBytes(*Data, DstOff, 0, Len);\n unlikely(!Res)) {\n return Unexpect(Res);\n }\n }\n\n return {};\n}\n\nExpect Executor::memFill(Runtime::StackManager &StackMgr,\n const uint32_t MemIdx, const uint32_t Off,\n const uint8_t Val, const uint32_t Len) noexcept {\n auto *MemInst = getMemInstByIdx(StackMgr, MemIdx);\n assuming(MemInst);\n if (auto Res = MemInst->fillBytes(Val, Off, Len); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n return {};\n}\n\nExpect Executor::memInit(Runtime::StackManager &StackMgr,\n const uint32_t MemIdx, const uint32_t DataIdx,\n const uint32_t DstOff, const uint32_t SrcOff,\n const uint32_t Len) noexcept {\n auto *MemInst = getMemInstByIdx(StackMgr, MemIdx);\n assuming(MemInst);\n auto *DataInst = getDataInstByIdx(StackMgr, DataIdx);\n assuming(DataInst);\n\n if (auto Res = MemInst->setBytes(DataInst->getData(), DstOff, SrcOff, Len);\n unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n return {};\n}\n\nExpect Executor::dataDrop(Runtime::StackManager &StackMgr,\n const uint32_t DataIdx) noexcept {\n auto *DataInst = getDataInstByIdx(StackMgr, DataIdx);\n assuming(DataInst);\n DataInst->clear();\n\n return {};\n}\n\nExpect Executor::tableGet(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx,\n const uint32_t Off) noexcept {\n auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n if (auto Res = TabInst->getRefAddr(Off); unlikely(!Res)) {\n return Unexpect(Res);\n } else {\n return *Res;\n }\n}\n\nExpect 
Executor::tableSet(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx, const uint32_t Off,\n const RefVariant Ref) noexcept {\n auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n if (auto Res = TabInst->setRefAddr(Off, Ref); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n return {};\n}\n\nExpect Executor::tableCopy(Runtime::StackManager &StackMgr,\n const uint32_t TableIdxDst,\n const uint32_t TableIdxSrc,\n const uint32_t DstOff, const uint32_t SrcOff,\n const uint32_t Len) noexcept {\n auto *TabInstDst = getTabInstByIdx(StackMgr, TableIdxDst);\n assuming(TabInstDst);\n auto *TabInstSrc = getTabInstByIdx(StackMgr, TableIdxSrc);\n assuming(TabInstSrc);\n\n if (auto Refs = TabInstSrc->getRefs(0, SrcOff + Len); unlikely(!Refs)) {\n return Unexpect(Refs);\n } else {\n if (auto Res = TabInstDst->setRefs(*Refs, DstOff, SrcOff, Len);\n unlikely(!Res)) {\n return Unexpect(Res);\n }\n }\n\n return {};\n}\n\nExpect Executor::tableGrow(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx,\n const RefVariant Val,\n const uint32_t NewSize) noexcept {\n auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n const uint32_t CurrTableSize = TabInst->getSize();\n if (likely(TabInst->growTable(NewSize, Val))) {\n return CurrTableSize;\n } else {\n return static_cast(-1);\n }\n}\n\nExpect Executor::tableSize(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx) noexcept {\n auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n return TabInst->getSize();\n}\n\nExpect Executor::tableFill(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx, const uint32_t Off,\n const RefVariant Ref,\n const uint32_t Len) noexcept {\n auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n if (auto Res = TabInst->fillRefs(Ref, Off, Len); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n return {};\n}\n\nExpect Executor::tableInit(Runtime::StackManager &StackMgr,\n const uint32_t TableIdx,\n const uint32_t ElemIdx, const uint32_t DstOff,\n const uint32_t SrcOff,\n const uint32_t Len) noexcept {\n auto *TabInst = getTabInstByIdx(StackMgr, TableIdx);\n assuming(TabInst);\n auto *ElemInst = getElemInstByIdx(StackMgr, ElemIdx);\n assuming(ElemInst);\n if (auto Res = TabInst->setRefs(ElemInst->getRefs(), DstOff, SrcOff, Len);\n unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n return {};\n}\n\nExpect Executor::elemDrop(Runtime::StackManager &StackMgr,\n const uint32_t ElemIdx) noexcept {\n auto *ElemInst = getElemInstByIdx(StackMgr, ElemIdx);\n assuming(ElemInst);\n ElemInst->clear();\n\n return {};\n}\n\nExpect Executor::refFunc(Runtime::StackManager &StackMgr,\n const uint32_t FuncIdx) noexcept {\n auto *FuncInst = getFuncInstByIdx(StackMgr, FuncIdx);\n assuming(FuncInst);\n return RefVariant(FuncInst);\n}\n\nExpect Executor::memoryAtomicNotify(Runtime::StackManager &StackMgr,\n const uint32_t MemIdx,\n const uint32_t Offset,\n const uint32_t Count) noexcept {\n auto *MemInst = getMemInstByIdx(StackMgr, MemIdx);\n assuming(MemInst);\n\n return atomicNotify(*MemInst, Offset, Count);\n}\n\nExpect Executor::memoryAtomicWait(Runtime::StackManager &StackMgr,\n const uint32_t MemIdx,\n const uint32_t Offset,\n const uint64_t Expected,\n const int64_t Timeout,\n const uint32_t BitWidth) noexcept {\n auto *MemInst = getMemInstByIdx(StackMgr, MemIdx);\n assuming(MemInst);\n\n if (BitWidth == 64) {\n return atomicWait(*MemInst, Offset, Expected, Timeout);\n } else if (BitWidth == 32) {\n return atomicWait(*MemInst, Offset,\n 
static_cast(Expected), Timeout);\n }\n\n assumingUnreachable();\n}\n\nExpect Executor::callRef(Runtime::StackManager &StackMgr,\n const RefVariant Ref, const ValVariant *Args,\n ValVariant *Rets) noexcept {\n const auto *FuncInst = retrieveFuncRef(Ref);\n const auto &FuncType = FuncInst->getFuncType();\n const uint32_t ParamsSize =\n static_cast(FuncType.getParamTypes().size());\n const uint32_t ReturnsSize =\n static_cast(FuncType.getReturnTypes().size());\n\n for (uint32_t I = 0; I < ParamsSize; ++I) {\n StackMgr.push(Args[I]);\n }\n\n auto Instrs = FuncInst->getInstrs();\n AST::InstrView::iterator StartIt;\n if (auto Res = enterFunction(StackMgr, *FuncInst, Instrs.end())) {\n StartIt = *Res;\n } else {\n return Unexpect(Res);\n }\n if (auto Res = execute(StackMgr, StartIt, Instrs.end()); unlikely(!Res)) {\n return Unexpect(Res);\n }\n\n for (uint32_t I = 0; I < ReturnsSize; ++I) {\n Rets[ReturnsSize - 1 - I] = StackMgr.pop();\n }\n\n return {};\n}\n\nExpect Executor::refGetFuncSymbol(Runtime::StackManager &,\n const RefVariant Ref) noexcept {\n const auto *FuncInst = retrieveFuncRef(Ref);\n assuming(FuncInst);\n if (unlikely(!FuncInst->isCompiledFunction())) {\n return nullptr;\n }\n return FuncInst->getSymbol().get();\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/tableInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect Executor::runTableGetOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::TableInstance &TabInst,\n const AST::Instruction &Instr) {\n // Pop Idx from Stack.\n uint32_t Idx = StackMgr.pop().get();\n\n // Get table[Idx] and push to Stack.\n if (auto Res = TabInst.getRefAddr(Idx)) {\n StackMgr.push(*Res);\n } else {\n spdlog::error(ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset(),\n {Idx},\n {ValTypeFromType()}));\n return Unexpect(Res);\n }\n return {};\n}\n\n...\n", "needle_token_start": 10437, "needle_token_end": 10621, "code_context_ntokens": 16183, "output": ["```cpp\ntemplate \nstruct Executor::ProxyHelper (Executor::*)(Runtime::StackManager &,\n ArgsT...) noexcept> {\n template (Executor::*Func)(Runtime::StackManager &,\n ArgsT...) noexcept>\n static auto proxy(ArgsT... Args) {\n Expect Res = (This->*Func)(*CurrentStack, Args...);\n if (unlikely(!Res)) {\n Fault::emitFault(Res.error());\n }\n if constexpr (std::is_same_v) {\n#if defined(_MSC_VER) && !defined(__clang__) // MSVC\n return *reinterpret_cast<__m128 *>((*Res).getRawData().data());\n#else\n return (*Res).getRawData();\n#endif // MSVC\n } else if constexpr (!std::is_void_v) {\n return *Res;\n }\n }\n};\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "loadPlugins", "language": "cpp", "path": "lib/driver/wasiNNRPCServerTool.cpp", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to initialize and register plugins from predefined paths, ensuring that the system is equipped with necessary extensions or functionalities before performing further operations.\n2. **Input**: There are no direct inputs from the user or function parameters; it operates based on internally defined plugin paths.\n3. **Output**: This function does not return any values but outputs log messages indicating the status of plugin loading, such as successful loads or failures.\n4. 
**Procedure**: \n - Iterates over a list of default plugin paths.\n - Attempts to load plugins from each path.\n - Logs the outcome of each attempt, whether successful or not.\n - After attempting to load plugins from all paths, it logs the names of all successfully loaded plugins.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/driver/uniTool.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"driver/unitool.h\"\n#include \"common/log.h\"\n#include \"driver/compiler.h\"\n#include \"driver/tool.h\"\n#include \"po/argument_parser.h\"\n\n#include \n#include \n\nnamespace WasmEdge {\nnamespace Driver {\n\nint UniTool(int Argc, const char *Argv[], const ToolType ToolSelect) noexcept {\n using namespace std::literals;\n\n std::ios::sync_with_stdio(false);\n Log::setInfoLoggingLevel();\n\n auto Parser = PO::ArgumentParser();\n\n PO::SubCommand ToolSubCommand(\n PO::Description(\"Wasmedge runtime tool subcommand\"sv));\n PO::SubCommand CompilerSubCommand(\n PO::Description(\"Wasmedge compiler subcommand\"sv));\n struct DriverToolOptions ToolOptions;\n struct DriverCompilerOptions CompilerOptions;\n\n // Construct Parser Subcommands and Options\n if (ToolSelect == ToolType::All) {\n ToolOptions.add_option(Parser);\n\n Parser.begin_subcommand(CompilerSubCommand, \"compile\"sv);\n CompilerOptions.add_option(Parser);\n Parser.end_subcommand();\n\n Parser.begin_subcommand(ToolSubCommand, \"run\"sv);\n ToolOptions.add_option(Parser);\n Parser.end_subcommand();\n } else if (ToolSelect == ToolType::Tool) {\n ToolOptions.add_option(Parser);\n } else if (ToolSelect == ToolType::Compiler) {\n CompilerOptions.add_option(Parser);\n } else {\n return EXIT_FAILURE;\n }\n\n // Parse\n if (!Parser.parse(stdout, Argc, Argv)) {\n return EXIT_FAILURE;\n }\n if (Parser.isVersion()) {\n std::cout << Argv[0] << \" version \"sv << kVersionString << '\\n';\n for (const auto &Plugin : Plugin::Plugin::plugins()) {\n auto PluginVersion = Plugin.version();\n std::cout << Plugin.path().string() << \" (plugin \\\"\"sv << Plugin.name()\n << \"\\\") version \"sv << PluginVersion.Major << '.'\n << PluginVersion.Minor << '.' 
<< PluginVersion.Patch << '.'\n << PluginVersion.Build << '\\n';\n }\n return EXIT_SUCCESS;\n }\n if (Parser.isHelp()) {\n return EXIT_SUCCESS;\n }\n\n // Forward Results\n if (ToolSubCommand.is_selected() || ToolSelect == ToolType::Tool) {\n return Tool(ToolOptions);\n } else if (CompilerSubCommand.is_selected() ||\n ToolSelect == ToolType::Compiler) {\n return Compiler(CompilerOptions);\n } else {\n return Tool(ToolOptions);\n }\n}\n} // namespace Driver\n} // namespace WasmEdge\n\n// Path: lib/driver/compilerTool.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/configure.h\"\n#include \"common/defines.h\"\n#include \"common/filesystem.h\"\n#include \"common/version.h\"\n#include \"driver/compiler.h\"\n#include \"loader/loader.h\"\n#include \"validator/validator.h\"\n#include \"llvm/codegen.h\"\n#include \"llvm/compiler.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Driver {\n\nint Compiler([[maybe_unused]] struct DriverCompilerOptions &Opt) noexcept {\n using namespace std::literals;\n\n std::ios::sync_with_stdio(false);\n Log::setInfoLoggingLevel();\n\n#ifdef WASMEDGE_USE_LLVM\n\n Configure Conf;\n if (Opt.PropMutGlobals.value()) {\n Conf.removeProposal(Proposal::ImportExportMutGlobals);\n }\n if (Opt.PropNonTrapF2IConvs.value()) {\n Conf.removeProposal(Proposal::NonTrapFloatToIntConversions);\n }\n if (Opt.PropSignExtendOps.value()) {\n Conf.removeProposal(Proposal::SignExtensionOperators);\n }\n if (Opt.PropMultiValue.value()) {\n Conf.removeProposal(Proposal::MultiValue);\n }\n if (Opt.PropBulkMemOps.value()) {\n Conf.removeProposal(Proposal::BulkMemoryOperations);\n }\n if (Opt.PropRefTypes.value()) {\n Conf.removeProposal(Proposal::ReferenceTypes);\n }\n if (Opt.PropSIMD.value()) {\n Conf.removeProposal(Proposal::SIMD);\n }\n if (Opt.PropMultiMem.value()) {\n Conf.addProposal(Proposal::MultiMemories);\n }\n if (Opt.PropTailCall.value()) {\n Conf.addProposal(Proposal::TailCall);\n }\n if (Opt.PropExtendConst.value()) {\n Conf.addProposal(Proposal::ExtendedConst);\n }\n if (Opt.PropThreads.value()) {\n Conf.addProposal(Proposal::Threads);\n }\n if (Opt.PropAll.value()) {\n Conf.addProposal(Proposal::MultiMemories);\n Conf.addProposal(Proposal::TailCall);\n Conf.addProposal(Proposal::ExtendedConst);\n Conf.addProposal(Proposal::Threads);\n }\n\n if (Opt.PropOptimizationLevel.value() == \"0\") {\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::O0);\n } else if (Opt.PropOptimizationLevel.value() == \"1\") {\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::O1);\n } else if (Opt.PropOptimizationLevel.value() == \"3\") {\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::O3);\n } else if (Opt.PropOptimizationLevel.value() == \"s\") {\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::Os);\n } else if (Opt.PropOptimizationLevel.value() == \"z\") {\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::Oz);\n } else {\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::O2);\n }\n\n // Set force interpreter here to load instructions of function body forcibly.\n Conf.getRuntimeConfigure().setForceInterpreter(true);\n\n std::filesystem::path InputPath =\n 
std::filesystem::absolute(std::filesystem::u8path(Opt.WasmName.value()));\n std::filesystem::path OutputPath =\n std::filesystem::absolute(std::filesystem::u8path(Opt.SoName.value()));\n Loader::Loader Loader(Conf);\n\n std::vector Data;\n if (auto Res = Loader.loadFile(InputPath)) {\n Data = std::move(*Res);\n } else {\n const auto Err = static_cast(Res.error());\n spdlog::error(\"Load failed. Error code: {}\", Err);\n return EXIT_FAILURE;\n }\n\n std::unique_ptr Module;\n if (auto Res = Loader.parseModule(Data)) {\n Module = std::move(*Res);\n } else {\n const auto Err = static_cast(Res.error());\n spdlog::error(\"Parse Module failed. Error code: {}\", Err);\n return EXIT_FAILURE;\n }\n\n {\n Validator::Validator ValidatorEngine(Conf);\n if (auto Res = ValidatorEngine.validate(*Module); !Res) {\n const auto Err = static_cast(Res.error());\n spdlog::error(\"Validate Module failed. Error code: {}\", Err);\n return EXIT_FAILURE;\n }\n }\n\n {\n if (Opt.ConfDumpIR.value()) {\n Conf.getCompilerConfigure().setDumpIR(true);\n }\n if (Opt.ConfInterruptible.value()) {\n Conf.getCompilerConfigure().setInterruptible(true);\n }\n if (Opt.ConfEnableAllStatistics.value()) {\n Conf.getStatisticsConfigure().setInstructionCounting(true);\n Conf.getStatisticsConfigure().setCostMeasuring(true);\n Conf.getStatisticsConfigure().setTimeMeasuring(true);\n } else {\n if (Opt.ConfEnableInstructionCounting.value()) {\n Conf.getStatisticsConfigure().setInstructionCounting(true);\n }\n if (Opt.ConfEnableGasMeasuring.value()) {\n Conf.getStatisticsConfigure().setCostMeasuring(true);\n }\n if (Opt.ConfEnableTimeMeasuring.value()) {\n Conf.getStatisticsConfigure().setTimeMeasuring(true);\n }\n }\n if (Opt.ConfGenericBinary.value()) {\n Conf.getCompilerConfigure().setGenericBinary(true);\n }\n if (OutputPath.extension().u8string() == WASMEDGE_LIB_EXTENSION) {\n Conf.getCompilerConfigure().setOutputFormat(\n CompilerConfigure::OutputFormat::Native);\n }\n LLVM::Compiler Compiler(Conf);\n LLVM::CodeGen CodeGen(Conf);\n if (auto Res = Compiler.compile(*Module); !Res) {\n const auto Err = static_cast(Res.error());\n spdlog::error(\"Compilation failed. Error code: {}\", Err);\n return EXIT_FAILURE;\n } else if (auto Res2 = CodeGen.codegen(Data, std::move(*Res), OutputPath);\n !Res2) {\n const auto Err = static_cast(Res2.error());\n spdlog::error(\"Code Generation failed. 
Error code: {}\", Err);\n return EXIT_FAILURE;\n }\n }\n\n return EXIT_SUCCESS;\n#else\n spdlog::error(\"Compilation is not supported!\");\n\n return EXIT_FAILURE;\n#endif\n}\n\n} // namespace Driver\n} // namespace WasmEdge\n\n// Path: lib/driver/runtimeTool.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/configure.h\"\n#include \"common/filesystem.h\"\n#include \"common/log.h\"\n#include \"common/types.h\"\n#include \"common/version.h\"\n#include \"driver/tool.h\"\n#include \"host/wasi/wasimodule.h\"\n#include \"vm/vm.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Driver {\n\nint Tool(struct DriverToolOptions &Opt) noexcept {\n using namespace std::literals;\n\n std::ios::sync_with_stdio(false);\n Log::setInfoLoggingLevel();\n\n Configure Conf;\n if (Opt.PropAFUNIX.value()) {\n Conf.getRuntimeConfigure().setAllowAFUNIX(true);\n }\n if (Opt.PropMutGlobals.value()) {\n Conf.removeProposal(Proposal::ImportExportMutGlobals);\n }\n if (Opt.PropNonTrapF2IConvs.value()) {\n Conf.removeProposal(Proposal::NonTrapFloatToIntConversions);\n }\n if (Opt.PropSignExtendOps.value()) {\n Conf.removeProposal(Proposal::SignExtensionOperators);\n }\n if (Opt.PropMultiValue.value()) {\n Conf.removeProposal(Proposal::MultiValue);\n }\n if (Opt.PropBulkMemOps.value()) {\n Conf.removeProposal(Proposal::BulkMemoryOperations);\n }\n if (Opt.PropRefTypes.value()) {\n Conf.removeProposal(Proposal::ReferenceTypes);\n }\n if (Opt.PropSIMD.value()) {\n Conf.removeProposal(Proposal::SIMD);\n }\n if (Opt.PropMultiMem.value()) {\n Conf.addProposal(Proposal::MultiMemories);\n }\n if (Opt.PropTailCall.value()) {\n Conf.addProposal(Proposal::TailCall);\n }\n if (Opt.PropExtendConst.value()) {\n Conf.addProposal(Proposal::ExtendedConst);\n }\n if (Opt.PropThreads.value()) {\n Conf.addProposal(Proposal::Threads);\n }\n if (Opt.PropFunctionReference.value()) {\n Conf.addProposal(Proposal::FunctionReferences);\n }\n if (Opt.PropGC.value()) {\n Conf.addProposal(Proposal::GC);\n spdlog::warn(\"GC proposal is enabled, this is experimental.\");\n }\n if (Opt.PropComponent.value()) {\n Conf.addProposal(Proposal::Component);\n spdlog::warn(\"component model is enabled, this is experimental.\");\n }\n if (Opt.PropAll.value()) {\n Conf.addProposal(Proposal::MultiMemories);\n Conf.addProposal(Proposal::TailCall);\n Conf.addProposal(Proposal::ExtendedConst);\n Conf.addProposal(Proposal::Threads);\n Conf.addProposal(Proposal::GC);\n Conf.addProposal(Proposal::Component);\n spdlog::warn(\"GC proposal is enabled, this is experimental.\");\n spdlog::warn(\"component model is enabled, this is experimental.\");\n }\n\n std::optional Timeout;\n if (Opt.TimeLim.value() > 0) {\n Timeout = std::chrono::system_clock::now() +\n std::chrono::milliseconds(Opt.TimeLim.value());\n }\n if (Opt.GasLim.value().size() > 0) {\n Conf.getStatisticsConfigure().setCostMeasuring(true);\n Conf.getStatisticsConfigure().setCostLimit(\n static_cast(Opt.GasLim.value().back()));\n }\n if (Opt.MemLim.value().size() > 0) {\n Conf.getRuntimeConfigure().setMaxMemoryPage(\n static_cast(Opt.MemLim.value().back()));\n }\n if (Opt.ConfEnableAllStatistics.value()) {\n Conf.getStatisticsConfigure().setInstructionCounting(true);\n Conf.getStatisticsConfigure().setCostMeasuring(true);\n Conf.getStatisticsConfigure().setTimeMeasuring(true);\n } else {\n if (Opt.ConfEnableInstructionCounting.value()) {\n 
Conf.getStatisticsConfigure().setInstructionCounting(true);\n }\n if (Opt.ConfEnableGasMeasuring.value()) {\n Conf.getStatisticsConfigure().setCostMeasuring(true);\n }\n if (Opt.ConfEnableTimeMeasuring.value()) {\n Conf.getStatisticsConfigure().setTimeMeasuring(true);\n }\n }\n if (Opt.ConfEnableJIT.value()) {\n Conf.getRuntimeConfigure().setEnableJIT(true);\n Conf.getCompilerConfigure().setOptimizationLevel(\n WasmEdge::CompilerConfigure::OptimizationLevel::O1);\n }\n if (Opt.ConfForceInterpreter.value()) {\n Conf.getRuntimeConfigure().setForceInterpreter(true);\n }\n\n for (const auto &Name : Opt.ForbiddenPlugins.value()) {\n Conf.addForbiddenPlugins(Name);\n }\n\n Conf.addHostRegistration(HostRegistration::Wasi);\n const auto InputPath =\n std::filesystem::absolute(std::filesystem::u8path(Opt.SoName.value()));\n VM::VM VM(Conf);\n\n Host::WasiModule *WasiMod = dynamic_cast(\n VM.getImportModule(HostRegistration::Wasi));\n\n if (auto Result = VM.loadWasm(InputPath.u8string()); !Result) {\n return EXIT_FAILURE;\n }\n if (auto Result = VM.validate(); !Result) {\n return EXIT_FAILURE;\n }\n if (auto Result = VM.instantiate(); !Result) {\n return EXIT_FAILURE;\n }\n\n auto HasValidCommandModStartFunc = [&]() {\n bool HasStart = false;\n bool Valid = false;\n\n auto Functions = VM.getFunctionList();\n for (auto &[FuncName, Type] : Functions) {\n if (FuncName == \"_start\") {\n HasStart = true;\n if (Type.getReturnTypes().size() == 0 &&\n Type.getParamTypes().size() == 0) {\n Valid = true;\n break;\n }\n }\n }\n\n // if HasStart but not Valid, insert _start to enter reactor mode\n if (HasStart && !Valid) {\n Opt.Args.value().insert(Opt.Args.value().begin(), \"_start\");\n }\n\n return HasStart && Valid;\n };\n\n bool EnterCommandMode = !Opt.Reactor.value() && HasValidCommandModStartFunc();\n\n WasiMod->getEnv().init(\n Opt.Dir.value(),\n InputPath.filename()\n .replace_extension(std::filesystem::u8path(\"wasm\"sv))\n .u8string(),\n Opt.Args.value(), Opt.Env.value());\n\n if (EnterCommandMode) {\n // command mode\n auto AsyncResult = VM.asyncExecute(\"_start\"sv);\n if (Timeout.has_value()) {\n if (!AsyncResult.waitUntil(*Timeout)) {\n AsyncResult.cancel();\n }\n }\n if (auto Result = AsyncResult.get();\n Result || Result.error() == ErrCode::Value::Terminated) {\n return static_cast(WasiMod->getEnv().getExitCode());\n } else {\n // It indicates that the execution of wasm has been aborted\n return 128 + SIGABRT;\n }\n } else {\n // reactor mode\n if (Opt.Args.value().empty()) {\n std::cerr\n << \"A function name is required when reactor mode is enabled.\\n\";\n return EXIT_FAILURE;\n }\n const auto &FuncName = Opt.Args.value().front();\n\n using namespace std::literals::string_literals;\n const auto InitFunc = \"_initialize\"s;\n\n bool HasInit = false;\n AST::FunctionType FuncType;\n\n for (const auto &Func : VM.getFunctionList()) {\n if (Func.first == InitFunc) {\n HasInit = true;\n } else if (Func.first == FuncName) {\n FuncType = Func.second;\n }\n }\n\n if (HasInit) {\n auto AsyncResult = VM.asyncExecute(InitFunc);\n if (Timeout.has_value()) {\n if (!AsyncResult.waitUntil(*Timeout)) {\n AsyncResult.cancel();\n }\n }\n if (auto Result = AsyncResult.get(); unlikely(!Result)) {\n // It indicates that the execution of wasm has been aborted\n return 128 + SIGABRT;\n }\n }\n\n std::vector FuncArgs;\n std::vector FuncArgTypes;\n for (size_t I = 0;\n I < FuncType.getParamTypes().size() && I + 1 < Opt.Args.value().size();\n ++I) {\n switch (FuncType.getParamTypes()[I].getCode()) {\n case 
TypeCode::I32: {\n const uint32_t Value =\n static_cast(std::stol(Opt.Args.value()[I + 1]));\n FuncArgs.emplace_back(Value);\n FuncArgTypes.emplace_back(TypeCode::I32);\n break;\n }\n case TypeCode::I64: {\n const uint64_t Value =\n static_cast(std::stoll(Opt.Args.value()[I + 1]));\n FuncArgs.emplace_back(Value);\n FuncArgTypes.emplace_back(TypeCode::I64);\n break;\n }\n case TypeCode::F32: {\n const float Value = std::stof(Opt.Args.value()[I + 1]);\n FuncArgs.emplace_back(Value);\n FuncArgTypes.emplace_back(TypeCode::F32);\n break;\n }\n case TypeCode::F64: {\n const double Value = std::stod(Opt.Args.value()[I + 1]);\n FuncArgs.emplace_back(Value);\n FuncArgTypes.emplace_back(TypeCode::F64);\n break;\n }\n /// TODO: FuncRef and ExternRef\n default:\n break;\n }\n }\n if (FuncType.getParamTypes().size() + 1 < Opt.Args.value().size()) {\n for (size_t I = FuncType.getParamTypes().size() + 1;\n I < Opt.Args.value().size(); ++I) {\n const uint64_t Value =\n static_cast(std::stoll(Opt.Args.value()[I]));\n FuncArgs.emplace_back(Value);\n FuncArgTypes.emplace_back(TypeCode::I64);\n }\n }\n\n auto AsyncResult = VM.asyncExecute(FuncName, FuncArgs, FuncArgTypes);\n if (Timeout.has_value()) {\n if (!AsyncResult.waitUntil(*Timeout)) {\n AsyncResult.cancel();\n }\n }\n if (auto Result = AsyncResult.get()) {\n /// Print results.\n for (size_t I = 0; I < Result->size(); ++I) {\n switch ((*Result)[I].second.getCode()) {\n case TypeCode::I32:\n std::cout << (*Result)[I].first.get() << '\\n';\n break;\n case TypeCode::I64:\n std::cout << (*Result)[I].first.get() << '\\n';\n break;\n case TypeCode::F32:\n std::cout << (*Result)[I].first.get() << '\\n';\n break;\n case TypeCode::F64:\n std::cout << (*Result)[I].first.get() << '\\n';\n break;\n case TypeCode::V128:\n std::cout << (*Result)[I].first.get() << '\\n';\n break;\n /// TODO: FuncRef and ExternRef\n default:\n break;\n }\n }\n return EXIT_SUCCESS;\n } else {\n // It indicates that the execution of wasm has been aborted\n return 128 + SIGABRT;\n }\n }\n}\n\n} // namespace Driver\n} // namespace WasmEdge\n\n// Path: lib/driver/wasiNNRPCServerTool.cpp\n#include \"common/log.h\"\n#include \"driver/wasi_nn_rpc/wasi_nn_rpcserver/wasi_nn_rpcserver.h\"\n#include \"plugin/plugin.h\"\n#include \"po/argument_parser.h\"\n\n#include \n#include \n\nusing namespace std::literals;\nusing namespace WasmEdge;\n\nnamespace WasmEdge {\nnamespace Driver {\n\n\nvoid loadPlugins(void) {\n for (const auto &Path : Plugin::Plugin::getDefaultPluginPaths()) {\n spdlog::info(\"Loading plugin path {}\"sv, Path);\n if (Plugin::Plugin::load(Path)) {\n spdlog::info(\"Loaded plugin path {}\"sv, Path);\n } else {\n spdlog::info(\"Nothing was loaded from plugin path {}\"sv, Path);\n }\n }\n for (const auto &Plugin : Plugin::Plugin::plugins()) {\n spdlog::info(\"Plugin: {}\", Plugin.name());\n }\n}\n\nRuntime::Instance::ModuleInstance *createWasiNNModule() {\n if (const auto *Plugin = Plugin::Plugin::find(\"wasi_nn\"sv)) {\n if (const auto *Module = Plugin->findModule(\"wasi_nn\"sv)) {\n return Module->create().release();\n }\n }\n return nullptr;\n}\n\nint WasiNNRPCServer(int Argc, const char *Argv[]) noexcept {\n std::ios::sync_with_stdio(false);\n Log::setInfoLoggingLevel();\n setenv(\"_WASI_NN_RPCSERVER\", \"1\", 1); // wasi_nn plugin checks this env var\n\n // Parse the args\n PO::Option NNRPCURI(\n PO::Description(\"Specify NN RPC URI to serve (\\\"unix://...\\\")\"sv),\n PO::MetaVar(\"URI\"sv), PO::DefaultValue(std::string(\"\")));\n auto Parser = PO::ArgumentParser();\n 
Parser.add_option(\"nn-rpc-uri\"sv, NNRPCURI);\n loadPlugins();\n Plugin::Plugin::addPluginOptions(Parser); // Register \"nn-preload\", etc.\n if (!Parser.parse(stdout, Argc, Argv)) {\n return EXIT_FAILURE;\n }\n if (Parser.isHelp()) {\n return EXIT_SUCCESS;\n }\n auto URI = NNRPCURI.value();\n if (URI.empty()) {\n spdlog::error(\"--nn-rpc-uri has to be specified\"sv);\n return EXIT_FAILURE;\n }\n\n // Create the wasi_nn module\n auto *NNMod = createWasiNNModule();\n if (NNMod == nullptr) {\n spdlog::error(\n \"Failed to get the wasi_nn module (Hint: set $WASMEDGE_PLUGIN_PATH to \"\n \"the directory where libwasmedgePluginWasiNN.* exists\"sv);\n return EXIT_FAILURE;\n }\n\n // Create the services\n WasiNNRPC::Server::ServiceSet ServiceSet(*NNMod);\n\n // Create the gRPC server\n grpc::ServerBuilder Builder;\n spdlog::info(\"Listening on \\\"{}\\\"\"sv, URI);\n std::string_view UnixPrefix = \"unix://\";\n if (URI.substr(0, UnixPrefix.length()) != UnixPrefix) {\n spdlog::warn(\"Expected \\\"unix://...\\\", got \\\"{}\\\"\"sv, URI);\n }\n auto Cred = grpc::InsecureServerCredentials(); // safe for unix://...\n Builder.AddListeningPort(URI, Cred);\n for (auto *Service : ServiceSet.services()) {\n Builder.RegisterService(Service);\n }\n\n // Start the gRPC server\n auto Server = Builder.BuildAndStart();\n if (Server == nullptr) {\n return EXIT_FAILURE;\n }\n Server->Wait();\n return EXIT_SUCCESS;\n}\n} // namespace Driver\n} // namespace WasmEdge\n\n// Path: lib/driver/fuzzPO.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#ifdef WASMEDGE_BUILD_FUZZING\n#include \"driver/fuzzPO.h\"\n#include \"common/log.h\"\n#include \"common/version.h\"\n#include \"po/argument_parser.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \nnamespace {\ntemplate \nclass SkipTable {\nprivate:\n using UnsignedKey = std::make_unsigned_t;\n std::array(std::numeric_limits::max()) +\n 1u>\n Table;\n\npublic:\n SkipTable(std::size_t, Value Default, Hash, BinaryPredicate) {\n std::fill_n(Table.begin(), Table.size(), Default);\n }\n\n void insert(const Key &K, Value V) { Table[static_cast(K)] = V; }\n\n const Value &at(const Key &K) const {\n return Table[static_cast(K)];\n }\n};\n\ntemplate ::value_type>,\n class BinaryPredicate = std::equal_to<>>\nclass BoyerMooreHorspoolSearcher {\nprivate:\n using Key = typename std::iterator_traits::value_type;\n using Value = typename std::iterator_traits::difference_type;\n static_assert(std::is_integral_v && sizeof(Key) == 1 &&\n std::is_same_v> &&\n std::is_same_v>);\n using SkipTableType = SkipTable;\n\npublic:\n BoyerMooreHorspoolSearcher(RandomIt1 First, RandomIt1 Last, Hash HF = Hash(),\n BinaryPredicate Pred = BinaryPredicate())\n : Pattern(First), PatternLength(std::distance(First, Last)), Pred(Pred),\n Table(PatternLength, PatternLength, HF, Pred) {\n if (First != Last) {\n --Last;\n for (Value I = 0; First != Last; ++First, ++I) {\n Table.insert(*First, PatternLength - 1 - I);\n }\n }\n }\n\n template \n std::pair operator()(RandomIt2 First,\n RandomIt2 Last) const {\n static_assert(\n std::is_same_v<\n std::remove_cv_t::value_type>>,\n std::remove_cv_t::value_type>>>,\n \"Corpus and Pattern iterators must point to the same type\");\n if (First == Last) {\n // empty corpus\n return {Last, Last};\n }\n if (PatternLength == 0) {\n // empty pattern\n return {First, First};\n }\n // the pattern is larger than the corpus\n if (PatternLength > std::distance(First, Last)) {\n return {Last, 
Last};\n }\n\n RandomIt2 Curr = First;\n const RandomIt2 End = Last - PatternLength;\n while (Curr <= End) {\n Value J = PatternLength;\n while (Pred(Pattern[J - 1], Curr[J - 1])) {\n --J;\n if (J == 0) {\n // found\n return {Curr, Curr + PatternLength};\n }\n }\n const auto K = Curr[PatternLength - 1];\n const auto D = Table.at(K);\n Curr += D;\n }\n return {Last, Last};\n }\n\nprivate:\n RandomIt1 Pattern;\n Value PatternLength;\n BinaryPredicate Pred;\n SkipTableType Table;\n};\n} // namespace\n\nnamespace WasmEdge {\nnamespace Driver {\n\nint FuzzPO(const uint8_t *Data, size_t Size) noexcept {\n using namespace std::literals;\n\n std::ios::sync_with_stdio(false);\n spdlog::set_level(spdlog::level::info);\n\n PO::Option SoName(PO::Description(\"Wasm or so file\"sv),\n PO::MetaVar(\"WASM_OR_SO\"sv));\n PO::List Args(PO::Description(\"Execution arguments\"sv),\n PO::MetaVar(\"ARG\"sv));\n\n PO::Option Reactor(PO::Description(\n \"Enable reactor mode. Reactor mode calls `_initialize` if exported.\"));\n\n PO::List Dir(\n PO::Description(\n \"Binding directories into WASI virtual filesystem. Each directories \"\n \"can specified as --dir `guest_path:host_path`, where `guest_path` \"\n \"specifies the path that will correspond to `host_path` for calls \"\n \"like `fopen` in the guest.\"sv),\n PO::MetaVar(\"PREOPEN_DIRS\"sv));\n\n PO::List Env(\n PO::Description(\n \"Environ variables. Each variable can be specified as --env `NAME=VALUE`.\"sv),\n PO::MetaVar(\"ENVS\"sv));\n\n PO::Option PropMutGlobals(\n PO::Description(\"Disable Import/Export of mutable globals proposal\"sv));\n PO::Option PropNonTrapF2IConvs(PO::Description(\n \"Disable Non-trapping float-to-int conversions proposal\"sv));\n PO::Option PropSignExtendOps(\n PO::Description(\"Disable Sign-extension operators proposal\"sv));\n PO::Option PropMultiValue(\n PO::Description(\"Disable Multi-value proposal\"sv));\n PO::Option PropBulkMemOps(\n PO::Description(\"Disable Bulk memory operations proposal\"sv));\n PO::Option PropRefTypes(\n PO::Description(\"Disable Reference types proposal\"sv));\n PO::Option PropSIMD(PO::Description(\"Disable SIMD proposal\"sv));\n PO::Option PropMultiMem(\n PO::Description(\"Enable Multiple memories proposal\"sv));\n PO::Option PropTailCall(\n PO::Description(\"Enable Tail-call proposal\"sv));\n PO::Option PropExtendConst(\n PO::Description(\"Enable Extended-const proposal\"sv));\n PO::Option PropThreads(\n PO::Description(\"Enable Threads proposal\"sv));\n PO::Option PropAll(PO::Description(\"Enable all features\"sv));\n\n PO::Option ConfEnableInstructionCounting(PO::Description(\n \"Enable generating code for counting Wasm instructions executed.\"sv));\n PO::Option ConfEnableGasMeasuring(PO::Description(\n \"Enable generating code for counting gas burned during execution.\"sv));\n PO::Option ConfEnableTimeMeasuring(PO::Description(\n \"Enable generating code for counting time during execution.\"sv));\n PO::Option ConfEnableAllStatistics(PO::Description(\n \"Enable generating code for all statistics options include instruction counting, gas measuring, and execution time\"sv));\n\n PO::Option TimeLim(\n PO::Description(\n \"Limitation of maximum time(in milliseconds) for execution, default value is 0 for no limitations\"sv),\n PO::MetaVar(\"TIMEOUT\"sv), PO::DefaultValue(0));\n\n PO::List GasLim(\n PO::Description(\n \"Limitation of execution gas. 
Upper bound can be specified as --gas-limit `GAS_LIMIT`.\"sv),\n PO::MetaVar(\"GAS_LIMIT\"sv));\n\n PO::List MemLim(\n PO::Description(\n \"Limitation of pages(as size of 64 KiB) in every memory instance. Upper bound can be specified as --memory-page-limit `PAGE_COUNT`.\"sv),\n PO::MetaVar(\"PAGE_COUNT\"sv));\n\n PO::List ForbiddenPlugins(\n PO::Description(\"List of plugins to ignore.\"sv), PO::MetaVar(\"NAMES\"sv));\n\n auto Parser = PO::ArgumentParser();\n Parser.add_option(SoName)\n .add_option(Args)\n .add_option(\"reactor\"sv, Reactor)\n .add_option(\"dir\"sv, Dir)\n .add_option(\"env\"sv, Env)\n .add_option(\"enable-instruction-count\"sv, ConfEnableInstructionCounting)\n .add_option(\"enable-gas-measuring\"sv, ConfEnableGasMeasuring)\n .add_option(\"enable-time-measuring\"sv, ConfEnableTimeMeasuring)\n .add_option(\"enable-all-statistics\"sv, ConfEnableAllStatistics)\n .add_option(\"disable-import-export-mut-globals\"sv, PropMutGlobals)\n .add_option(\"disable-non-trap-float-to-int\"sv, PropNonTrapF2IConvs)\n .add_option(\"disable-sign-extension-operators\"sv, PropSignExtendOps)\n .add_option(\"disable-multi-value\"sv, PropMultiValue)\n .add_option(\"disable-bulk-memory\"sv, PropBulkMemOps)\n .add_option(\"disable-reference-types\"sv, PropRefTypes)\n .add_option(\"disable-simd\"sv, PropSIMD)\n .add_option(\"enable-multi-memory\"sv, PropMultiMem)\n .add_option(\"enable-tail-call\"sv, PropTailCall)\n .add_option(\"enable-extended-const\"sv, PropExtendConst)\n .add_option(\"enable-threads\"sv, PropThreads)\n .add_option(\"enable-all\"sv, PropAll)\n .add_option(\"time-limit\"sv, TimeLim)\n .add_option(\"gas-limit\"sv, GasLim)\n .add_option(\"memory-page-limit\"sv, MemLim)\n .add_option(\"forbidden-plugin\"sv, ForbiddenPlugins);\n\n static constexpr const std::array Separator = {'\\xde', '\\xad',\n '\\xbe', '\\xef'};\n static const BoyerMooreHorspoolSearcher Searcher(Separator.begin(),\n Separator.end());\n Span RawArgs(reinterpret_cast(Data), Size);\n std::vector ArgvStr;\n std::vector Argv;\n while (!RawArgs.empty()) {\n const auto It = std::search(RawArgs.begin(), RawArgs.end(), Searcher);\n ArgvStr.emplace_back(RawArgs.begin(), It);\n RawArgs = RawArgs.subspan(std::min(\n std::distance(RawArgs.begin(), It) + 4, RawArgs.size()));\n }\n for (const auto &Arg : ArgvStr) {\n Argv.push_back(Arg.c_str());\n }\n\n std::unique_ptr Out{\n std::fopen(\"/dev/null\", \"w\"), std::fclose};\n if (!Parser.parse(Out.get(), Argv.size(), Argv.data())) {\n return EXIT_FAILURE;\n }\n if (Parser.isVersion()) {\n fmt::print(Out.get(), \"{} version {}\\n\"sv, Argv.empty() ? 
\"\" : Argv[0],\n kVersionString);\n return EXIT_SUCCESS;\n }\n\n return EXIT_SUCCESS;\n}\n\n} // namespace Driver\n} // namespace WasmEdge\n#endif\n\n// Path: lib/driver/fuzzTool.cpp\n...\n", "needle_token_start": 6402, "needle_token_end": 6548, "code_context_ntokens": 10528, "output": ["```cpp\n#include \"common/log.h\"\n#include \"plugin/plugin.h\"\n#include \"po/argument_parser.h\"\n\n#include \n#include \n\nusing namespace std::literals;\nusing namespace WasmEdge;\n\nnamespace WasmEdge {\nnamespace Driver {\n\nvoid loadPlugins(void) {\n for (const auto &Path : Plugin::Plugin::getDefaultPluginPaths()) {\n spdlog::info(\"Loading plugin path {}\"sv, Path);\n if (Plugin::Plugin::load(Path)) {\n spdlog::info(\"Loaded plugin path {}\"sv, Path);\n } else {\n spdlog::info(\"Nothing was loaded from plugin path {}\"sv, Path);\n }\n }\n for (const auto &Plugin : Plugin::Plugin::plugins()) {\n spdlog::info(\"Plugin: {}\", Plugin.name());\n }\n}\n\n} // namespace Driver\n} // namespace WasmEdge\n```"]} +{"repo": "WasmEdge/WasmEdge", "name": "fromLastError", "language": "cpp", "path": "lib/host/wasi/win.h", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The function translates system-specific error codes into standardized error codes defined by a WebAssembly system interface.\n2. **Input**: A system-specific error code (numerical value).\n3. **Output**: A standardized error code as defined by the WebAssembly system interface.\n4. **Procedure**: The function checks the input error code against a series of predefined system-specific error codes. Depending on the input error code, it returns the corresponding standardized error code. If the input error code does not match any predefined codes, it defaults to returning a standardized error code indicating that the operation is not implemented.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/host/wasi/inode-linux.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if WASMEDGE_OS_LINUX\n\n#include \"common/errcode.h\"\n#include \"common/variant.h\"\n#include \"host/wasi/environ.h\"\n#include \"host/wasi/inode.h\"\n#include \"host/wasi/vfs.h\"\n#include \"linux.h\"\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\n\nnamespace {\n\ninline constexpr bool isSpecialFd(int Fd) noexcept {\n switch (Fd) {\n case STDIN_FILENO:\n case STDOUT_FILENO:\n case STDERR_FILENO:\n return true;\n default:\n return false;\n }\n}\n\ninline constexpr __wasi_size_t\ncalculateAddrinfoLinkedListSize(struct addrinfo *const Addrinfo) {\n __wasi_size_t Length = 0;\n for (struct addrinfo *TmpPointer = Addrinfo; TmpPointer != nullptr;\n TmpPointer = TmpPointer->ai_next) {\n Length++;\n }\n return Length;\n};\n\nconstexpr int openFlags(__wasi_oflags_t OpenFlags, __wasi_fdflags_t FdFlags,\n VFS::Flags VFSFlags) noexcept {\n int Flags = O_NOFOLLOW;\n#ifdef O_CLOEXEC\n Flags |= O_CLOEXEC;\n#endif\n\n if (VFSFlags & VFS::Read) {\n if (VFSFlags & VFS::Write) {\n Flags |= O_RDWR;\n } else {\n Flags |= O_RDONLY;\n }\n } else if (VFSFlags & VFS::Write) {\n Flags |= O_WRONLY;\n } else {\n#ifdef O_PATH\n if (OpenFlags == __WASI_OFLAGS_DIRECTORY) {\n 
Flags |= O_PATH;\n } else {\n Flags |= O_RDONLY;\n }\n#else\n Flags |= O_RDONLY;\n#endif\n }\n\n if (OpenFlags & __WASI_OFLAGS_CREAT) {\n Flags |= O_CREAT;\n }\n if (OpenFlags & __WASI_OFLAGS_DIRECTORY) {\n Flags |= O_DIRECTORY;\n }\n if (OpenFlags & __WASI_OFLAGS_EXCL) {\n Flags |= O_EXCL;\n }\n if (OpenFlags & __WASI_OFLAGS_TRUNC) {\n Flags |= O_TRUNC;\n }\n\n // Convert file descriptor flags.\n if ((FdFlags & __WASI_FDFLAGS_APPEND) != 0) {\n Flags |= O_APPEND;\n }\n if ((FdFlags & __WASI_FDFLAGS_DSYNC) != 0) {\n#ifdef O_DSYNC\n Flags |= O_DSYNC;\n#else\n Flags |= O_SYNC;\n#endif\n }\n if ((FdFlags & __WASI_FDFLAGS_NONBLOCK) != 0) {\n Flags |= O_NONBLOCK;\n }\n if ((FdFlags & __WASI_FDFLAGS_RSYNC) != 0) {\n#ifdef O_RSYNC\n Flags |= O_RSYNC;\n#else\n Flags |= O_SYNC;\n#endif\n }\n if ((FdFlags & __WASI_FDFLAGS_SYNC) != 0) {\n Flags |= O_SYNC;\n }\n\n return Flags;\n}\n\nstd::pair>\ncreateNullTerminatedString(std::string_view View) noexcept {\n const char *CStr = nullptr;\n std::unique_ptr Buffer;\n if (!View.empty()) {\n if (const auto Pos = View.find_first_of('\\0');\n Pos != std::string_view::npos) {\n CStr = View.data();\n } else {\n Buffer = std::make_unique(View.size() + 1);\n std::copy(View.begin(), View.end(), Buffer.get());\n CStr = Buffer.get();\n }\n }\n return {CStr, std::move(Buffer)};\n}\n\n} // namespace\n\nvoid FdHolder::reset() noexcept {\n if (likely(ok())) {\n if (likely(!isSpecialFd(Fd))) {\n ::close(Fd);\n }\n Fd = -1;\n }\n}\n\nvoid TimerHolder::reset() noexcept {\n if (likely(Id.has_value())) {\n timer_delete(*Id);\n Id.reset();\n }\n}\n\nvoid DirHolder::reset() noexcept {\n if (likely(Dir != nullptr)) {\n closedir(Dir);\n Dir = nullptr;\n Cookie = 0;\n }\n}\n\nINode INode::stdIn() noexcept { return INode(STDIN_FILENO); }\n\nINode INode::stdOut() noexcept { return INode(STDOUT_FILENO); }\n\nINode INode::stdErr() noexcept { return INode(STDERR_FILENO); }\n\nWasiExpect INode::open(std::string Path, __wasi_oflags_t OpenFlags,\n __wasi_fdflags_t FdFlags,\n VFS::Flags VFSFlags) noexcept {\n const int Flags = openFlags(OpenFlags, FdFlags, VFSFlags);\n\n if (auto NewFd = ::open(Path.c_str(), Flags, 0644); unlikely(NewFd < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n INode New(NewFd);\n#ifndef O_CLOEXEC\n if (auto Res = ::fcntl(New.Fd, F_SETFD, FD_CLOEXEC); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#endif\n return New;\n }\n}\n\nWasiExpect INode::fdAdvise(__wasi_filesize_t Offset,\n __wasi_filesize_t Len,\n __wasi_advice_t Advice) const noexcept {\n if (auto Res = ::posix_fadvise(Fd, Offset, Len, toAdvice(Advice));\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdAllocate(__wasi_filesize_t Offset,\n __wasi_filesize_t Len) const noexcept {\n if (auto Res = ::posix_fallocate(Fd, Offset, Len); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdDatasync() const noexcept {\n if (auto Res = ::fdatasync(Fd); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdFdstatGet(__wasi_fdstat_t &FdStat) const noexcept {\n if (auto Res = updateStat(); unlikely(!Res)) {\n return WasiUnexpect(Res);\n }\n\n if (int FdFlags = ::fcntl(Fd, F_GETFL); unlikely(FdFlags < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n FdStat.fs_filetype = unsafeFiletype();\n\n FdStat.fs_flags = static_cast<__wasi_fdflags_t>(0);\n if (FdFlags & O_APPEND) {\n FdStat.fs_flags |= 
__WASI_FDFLAGS_APPEND;\n }\n if (FdFlags & O_DSYNC) {\n FdStat.fs_flags |= __WASI_FDFLAGS_DSYNC;\n }\n if (FdFlags & O_NONBLOCK) {\n FdStat.fs_flags |= __WASI_FDFLAGS_NONBLOCK;\n }\n if (FdFlags & O_SYNC) {\n FdStat.fs_flags |= __WASI_FDFLAGS_RSYNC | __WASI_FDFLAGS_SYNC;\n }\n }\n\n return {};\n}\n\nWasiExpect\nINode::fdFdstatSetFlags(__wasi_fdflags_t FdFlags) const noexcept {\n int SysFlag = 0;\n if (FdFlags & __WASI_FDFLAGS_NONBLOCK) {\n SysFlag |= O_NONBLOCK;\n }\n if (FdFlags & __WASI_FDFLAGS_APPEND) {\n SysFlag |= O_APPEND;\n }\n if (FdFlags & __WASI_FDFLAGS_DSYNC) {\n SysFlag |= O_DSYNC;\n }\n if (FdFlags & __WASI_FDFLAGS_RSYNC) {\n SysFlag |= O_RSYNC;\n }\n if (FdFlags & __WASI_FDFLAGS_SYNC) {\n SysFlag |= O_SYNC;\n }\n\n if (auto Res = ::fcntl(Fd, F_SETFL, SysFlag); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect\nINode::fdFilestatGet(__wasi_filestat_t &Filestat) const noexcept {\n if (auto Res = updateStat(); unlikely(!Res)) {\n return WasiUnexpect(Res);\n }\n\n // Zeroing out these values to prevent leaking information about the host\n // environment from special fd such as stdin, stdout and stderr.\n Filestat.dev = isSpecialFd(Fd) ? 0 : Stat->st_dev;\n Filestat.ino = isSpecialFd(Fd) ? 0 : Stat->st_ino;\n Filestat.filetype = unsafeFiletype();\n Filestat.nlink = isSpecialFd(Fd) ? 0 : Stat->st_nlink;\n Filestat.size = isSpecialFd(Fd) ? 0 : Stat->st_size;\n Filestat.atim = isSpecialFd(Fd) ? 0 : fromTimespec(Stat->st_atim);\n Filestat.mtim = isSpecialFd(Fd) ? 0 : fromTimespec(Stat->st_mtim);\n Filestat.ctim = isSpecialFd(Fd) ? 0 : fromTimespec(Stat->st_ctim);\n\n return {};\n}\n\nWasiExpect\nINode::fdFilestatSetSize(__wasi_filesize_t Size) const noexcept {\n if (auto Res = ::ftruncate(Fd, Size); unlikely(Res == -1)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect\nINode::fdFilestatSetTimes(__wasi_timestamp_t ATim, __wasi_timestamp_t MTim,\n __wasi_fstflags_t FstFlags) const noexcept {\n#if __GLIBC_PREREQ(2, 6) || __BIONIC__\n timespec SysTimespec[2];\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n SysTimespec[0] = toTimespec(ATim);\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n SysTimespec[0].tv_nsec = UTIME_NOW;\n } else {\n SysTimespec[0].tv_nsec = UTIME_OMIT;\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n SysTimespec[1] = toTimespec(MTim);\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n SysTimespec[1].tv_nsec = UTIME_NOW;\n } else {\n SysTimespec[1].tv_nsec = UTIME_OMIT;\n }\n\n if (auto Res = ::futimens(Fd, SysTimespec); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#else\n bool NeedNow = false;\n bool NeedFile = false;\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n // Nothing to do.\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n NeedNow = true;\n } else {\n NeedFile = true;\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n // Nothing to do.\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n NeedNow = true;\n } else {\n NeedFile = true;\n }\n\n if (NeedFile) {\n if (auto Res = updateStat(); unlikely(!Res)) {\n return WasiUnexpect(Res);\n }\n }\n\n timespec Now;\n if (NeedNow) {\n if (auto Res = ::clock_gettime(CLOCK_REALTIME, &Now); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n }\n\n timeval SysTimeval[2];\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n SysTimeval[0] = toTimeval(ATim);\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n SysTimeval[0] = toTimeval(Now);\n } else {\n SysTimeval[0] = toTimeval(Stat->st_atim);\n }\n if (FstFlags & 
__WASI_FSTFLAGS_MTIM) {\n SysTimeval[1] = toTimeval(MTim);\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n SysTimeval[1] = toTimeval(Now);\n } else {\n SysTimeval[1] = toTimeval(Stat->st_mtim);\n }\n\n if (auto Res = ::futimes(Fd, SysTimeval); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::fdPread(Span> IOVs,\n __wasi_filesize_t Offset,\n __wasi_size_t &NRead) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n SysIOVs[SysIOVsSize].iov_base = IOV.data();\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n#if __GLIBC_PREREQ(2, 10)\n // Store read bytes length.\n if (auto Res = ::preadv(Fd, SysIOVs, SysIOVsSize, Offset);\n unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NRead = Res;\n }\n#else\n const auto OldOffset = ::lseek(Fd, 0, SEEK_CUR);\n if (OldOffset < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (::lseek(Fd, Offset, SEEK_SET) < 0 ||\n ::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (auto Res = ::readv(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n ::lseek(Fd, OldOffset, SEEK_SET);\n return WasiUnexpect(fromErrNo(errno));\n } else {\n if (::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n NRead = Res;\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::fdPwrite(Span> IOVs,\n __wasi_filesize_t Offset,\n __wasi_size_t &NWritten) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n SysIOVs[SysIOVsSize].iov_base = const_cast(IOV.data());\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n#if __GLIBC_PREREQ(2, 10)\n if (auto Res = ::pwritev(Fd, SysIOVs, SysIOVsSize, Offset);\n unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NWritten = Res;\n }\n#else\n const auto OldOffset = ::lseek(Fd, 0, SEEK_CUR);\n if (OldOffset < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (::lseek(Fd, Offset, SEEK_SET) < 0 ||\n ::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n if (auto Res = ::writev(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n ::lseek(Fd, OldOffset, SEEK_SET);\n return WasiUnexpect(fromErrNo(errno));\n } else {\n if (::lseek(Fd, OldOffset, SEEK_SET) < 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n NWritten = Res;\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::fdRead(Span> IOVs,\n __wasi_size_t &NRead) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n SysIOVs[SysIOVsSize].iov_base = IOV.data();\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n if (auto Res = ::readv(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NRead = Res;\n }\n\n return {};\n}\n\n// Due to the unfortunate design of wasi::fd_readdir, It's nearly impossible to\n// provide a correct implementation. The below implementation is just a\n// workaround for most usages and may not be correct in some edge cases. 
The\n// readdir entry API is going to be updated to use a stream type, so we don't\n// have to deal with it right now.\nWasiExpect INode::fdReaddir(Span Buffer,\n __wasi_dircookie_t Cookie,\n __wasi_size_t &Size) noexcept {\n if (unlikely(!Dir.ok())) {\n if (FdHolder NewFd(::dup(Fd)); unlikely(!NewFd.ok())) {\n return WasiUnexpect(fromErrNo(errno));\n } else if (DIR *D = ::fdopendir(NewFd.Fd); unlikely(!D)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NewFd.release();\n Dir.emplace(D);\n }\n }\n\n if (Cookie == 0) {\n ::rewinddir(Dir.Dir);\n } else if (unlikely(Cookie != Dir.Cookie)) {\n ::seekdir(Dir.Dir, Cookie);\n }\n\n Size = 0;\n do {\n if (!Dir.Buffer.empty()) {\n const auto NewDataSize =\n std::min(Buffer.size(), Dir.Buffer.size());\n std::copy(Dir.Buffer.begin(), Dir.Buffer.begin() + NewDataSize,\n Buffer.begin());\n Buffer = Buffer.subspan(NewDataSize);\n Size += NewDataSize;\n Dir.Buffer.clear();\n if (unlikely(Buffer.empty())) {\n break;\n }\n }\n errno = 0;\n dirent *SysDirent = ::readdir(Dir.Dir);\n if (SysDirent == nullptr) {\n if (errno != 0) {\n return WasiUnexpect(fromErrNo(errno));\n }\n // End of entries\n break;\n }\n Dir.Cookie = SysDirent->d_off;\n std::string_view Name = SysDirent->d_name;\n\n Dir.Buffer.resize(sizeof(__wasi_dirent_t) + Name.size());\n\n __wasi_dirent_t *const Dirent =\n reinterpret_cast<__wasi_dirent_t *>(Dir.Buffer.data());\n Dirent->d_next = Dir.Cookie;\n Dirent->d_ino = SysDirent->d_ino;\n Dirent->d_type = fromFileType(SysDirent->d_type);\n Dirent->d_namlen = Name.size();\n std::copy(Name.cbegin(), Name.cend(),\n Dir.Buffer.begin() + sizeof(__wasi_dirent_t));\n } while (!Buffer.empty());\n\n return {};\n}\n\nWasiExpect INode::fdSeek(__wasi_filedelta_t Offset,\n __wasi_whence_t Whence,\n __wasi_filesize_t &Size) const noexcept {\n if (auto Res = ::lseek(Fd, Offset, toWhence(Whence)); unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n Size = Res;\n }\n\n return {};\n}\n\nWasiExpect INode::fdSync() const noexcept {\n if (auto Res = ::fsync(Fd); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::fdTell(__wasi_filesize_t &Size) const noexcept {\n if (auto Res = ::lseek(Fd, 0, SEEK_CUR); unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n Size = Res;\n }\n\n return {};\n}\n\nWasiExpect INode::fdWrite(Span> IOVs,\n __wasi_size_t &NWritten) const noexcept {\n iovec SysIOVs[kIOVMax];\n size_t SysIOVsSize = 0;\n for (auto &IOV : IOVs) {\n SysIOVs[SysIOVsSize].iov_base = const_cast(IOV.data());\n SysIOVs[SysIOVsSize].iov_len = IOV.size();\n ++SysIOVsSize;\n }\n\n if (auto Res = ::writev(Fd, SysIOVs, SysIOVsSize); unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NWritten = Res;\n }\n\n return {};\n}\n\nWasiExpect INode::getNativeHandler() const noexcept {\n return static_cast(Fd);\n}\n\nWasiExpect INode::pathCreateDirectory(std::string Path) const noexcept {\n if (auto Res = ::mkdirat(Fd, Path.c_str(), 0755); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect\nINode::pathFilestatGet(std::string Path,\n __wasi_filestat_t &Filestat) const noexcept {\n struct stat SysFStat;\n if (int Res = ::fstatat(Fd, Path.c_str(), &SysFStat, AT_SYMLINK_NOFOLLOW);\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n Filestat.dev = SysFStat.st_dev;\n Filestat.ino = SysFStat.st_ino;\n Filestat.filetype = fromFileType(static_cast(SysFStat.st_mode));\n Filestat.nlink = 
SysFStat.st_nlink;\n Filestat.size = SysFStat.st_size;\n Filestat.atim = fromTimespec(SysFStat.st_atim);\n Filestat.mtim = fromTimespec(SysFStat.st_mtim);\n Filestat.ctim = fromTimespec(SysFStat.st_ctim);\n\n return {};\n}\n\nWasiExpect\nINode::pathFilestatSetTimes(std::string Path, __wasi_timestamp_t ATim,\n __wasi_timestamp_t MTim,\n __wasi_fstflags_t FstFlags) const noexcept {\n#if __GLIBC_PREREQ(2, 6) || __BIONIC__\n timespec SysTimespec[2];\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n SysTimespec[0] = toTimespec(ATim);\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n SysTimespec[0].tv_nsec = UTIME_NOW;\n } else {\n SysTimespec[0].tv_nsec = UTIME_OMIT;\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n SysTimespec[1] = toTimespec(MTim);\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n SysTimespec[1].tv_nsec = UTIME_NOW;\n } else {\n SysTimespec[1].tv_nsec = UTIME_OMIT;\n }\n\n if (auto Res =\n ::utimensat(Fd, Path.c_str(), SysTimespec, AT_SYMLINK_NOFOLLOW);\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#else\n bool NeedNow = false;\n bool NeedFile = false;\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n // Nothing to do.\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n NeedNow = true;\n } else {\n NeedFile = true;\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n // Nothing to do.\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n NeedNow = true;\n } else {\n NeedFile = true;\n }\n\n#ifdef O_PATH\n const int OFlags = O_PATH | O_SYMLINK;\n#else\n const int OFlags = O_RDONLY | O_SYMLINK;\n#endif\n\n FdHolder Target(::openat(Fd, Path.c_str(), OFlags));\n if (unlikely(!Target.ok())) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n struct stat SysStat;\n if (NeedFile) {\n if (auto Res = ::fstat(Target.Fd, &SysStat); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n }\n\n timespec Now;\n if (NeedNow) {\n if (auto Res = ::clock_gettime(CLOCK_REALTIME, &Now); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n }\n\n timeval SysTimeval[2];\n if (FstFlags & __WASI_FSTFLAGS_ATIM) {\n SysTimeval[0] = toTimeval(ATim);\n } else if (FstFlags & __WASI_FSTFLAGS_ATIM_NOW) {\n SysTimeval[0] = toTimeval(Now);\n } else {\n SysTimeval[0] = toTimeval(SysStat.st_atim);\n }\n if (FstFlags & __WASI_FSTFLAGS_MTIM) {\n SysTimeval[1] = toTimeval(MTim);\n } else if (FstFlags & __WASI_FSTFLAGS_MTIM_NOW) {\n SysTimeval[1] = toTimeval(Now);\n } else {\n SysTimeval[1] = toTimeval(SysStat.st_mtim);\n }\n\n if (auto Res = ::futimes(Target.Fd, SysTimeval); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#endif\n\n return {};\n}\n\nWasiExpect INode::pathLink(const INode &Old, std::string OldPath,\n const INode &New,\n std::string NewPath) noexcept {\n if (auto Res = ::linkat(Old.Fd, OldPath.c_str(), New.Fd, NewPath.c_str(), 0);\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::pathOpen(std::string Path, __wasi_oflags_t OpenFlags,\n __wasi_fdflags_t FdFlags,\n VFS::Flags VFSFlags) const noexcept {\n const int Flags = openFlags(OpenFlags, FdFlags, VFSFlags);\n\n if (auto NewFd = ::openat(Fd, Path.c_str(), Flags, 0644);\n unlikely(NewFd < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n INode New(NewFd);\n#ifndef O_CLOEXEC\n if (auto Res = ::fcntl(New.Fd, F_SETFD, FD_CLOEXEC); unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n#endif\n return New;\n }\n}\n\nWasiExpect INode::pathReadlink(std::string Path, Span Buffer,\n __wasi_size_t &NRead) const noexcept 
{\n if (auto Res = ::readlinkat(Fd, Path.c_str(), Buffer.data(), Buffer.size());\n unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n NRead = Res;\n }\n\n return {};\n}\n\nWasiExpect INode::pathRemoveDirectory(std::string Path) const noexcept {\n if (auto Res = ::unlinkat(Fd, Path.c_str(), AT_REMOVEDIR);\n unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::pathRename(const INode &Old, std::string OldPath,\n const INode &New,\n std::string NewPath) noexcept {\n if (auto Res = ::renameat(Old.Fd, OldPath.c_str(), New.Fd, NewPath.c_str());\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::pathSymlink(std::string OldPath,\n std::string NewPath) const noexcept {\n if (auto Res = ::symlinkat(OldPath.c_str(), Fd, NewPath.c_str());\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::pathUnlinkFile(std::string Path) const noexcept {\n if (auto Res = ::unlinkat(Fd, Path.c_str(), 0); unlikely(Res < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n return {};\n}\n\nWasiExpect INode::getAddrinfo(std::string_view Node,\n std::string_view Service,\n const __wasi_addrinfo_t &Hint,\n uint32_t MaxResLength,\n Span<__wasi_addrinfo_t *> WasiAddrinfoArray,\n Span<__wasi_sockaddr_t *> WasiSockaddrArray,\n Span AiAddrSaDataArray,\n Span AiCanonnameArray,\n /*Out*/ __wasi_size_t &ResLength) noexcept {\n const auto [NodeCStr, NodeBuf] = createNullTerminatedString(Node);\n const auto [ServiceCStr, ServiceBuf] = createNullTerminatedString(Service);\n\n struct addrinfo SysHint;\n SysHint.ai_flags = toAIFlags(Hint.ai_flags);\n SysHint.ai_family = toAddressFamily(Hint.ai_family);\n SysHint.ai_socktype = toSockType(Hint.ai_socktype);\n SysHint.ai_protocol = toProtocol(Hint.ai_protocol);\n SysHint.ai_addrlen = Hint.ai_addrlen;\n SysHint.ai_addr = nullptr;\n SysHint.ai_canonname = nullptr;\n SysHint.ai_next = nullptr;\n\n struct addrinfo *SysResPtr = nullptr;\n if (auto Res = ::getaddrinfo(NodeCStr, ServiceCStr, &SysHint, &SysResPtr);\n unlikely(Res < 0)) {\n return WasiUnexpect(fromEAIErrNo(Res));\n }\n // calculate ResLength\n if (ResLength = calculateAddrinfoLinkedListSize(SysResPtr);\n ResLength > MaxResLength) {\n ResLength = MaxResLength;\n }\n\n struct addrinfo *SysResItem = SysResPtr;\n for (uint32_t Idx = 0; Idx < ResLength; Idx++) {\n auto &CurAddrinfo = WasiAddrinfoArray[Idx];\n CurAddrinfo->ai_flags = fromAIFlags(SysResItem->ai_flags);\n CurAddrinfo->ai_socktype = fromSockType(SysResItem->ai_socktype);\n CurAddrinfo->ai_protocol = fromProtocol(SysResItem->ai_protocol);\n CurAddrinfo->ai_family = fromAddressFamily(SysResItem->ai_family);\n CurAddrinfo->ai_addrlen = SysResItem->ai_addrlen;\n\n // process ai_canonname in addrinfo\n if (SysResItem->ai_canonname != nullptr) {\n CurAddrinfo->ai_canonname_len = std::strlen(SysResItem->ai_canonname);\n auto &CurAiCanonname = AiCanonnameArray[Idx];\n std::memcpy(CurAiCanonname, SysResItem->ai_canonname,\n CurAddrinfo->ai_canonname_len + 1);\n } else {\n CurAddrinfo->ai_canonname_len = 0;\n }\n\n // process socket address\n if (SysResItem->ai_addrlen > 0) {\n auto &CurSockaddr = WasiSockaddrArray[Idx];\n CurSockaddr->sa_family =\n fromAddressFamily(SysResItem->ai_addr->sa_family);\n\n // process sa_data in socket address\n size_t SaSize = 0;\n switch (CurSockaddr->sa_family) {\n case __WASI_ADDRESS_FAMILY_INET4:\n SaSize = sizeof(sockaddr_in) - offsetof(sockaddr_in, 
sin_port);\n break;\n case __WASI_ADDRESS_FAMILY_INET6:\n SaSize = sizeof(sockaddr_in6) - offsetof(sockaddr_in6, sin6_port);\n break;\n default:\n assumingUnreachable();\n }\n std::memcpy(AiAddrSaDataArray[Idx], SysResItem->ai_addr->sa_data, SaSize);\n CurSockaddr->sa_data_len = __wasi_size_t(SaSize);\n }\n // process ai_next in addrinfo\n SysResItem = SysResItem->ai_next;\n }\n ::freeaddrinfo(SysResPtr);\n\n return {};\n}\n\nWasiExpect INode::sockOpen(__wasi_address_family_t AddressFamily,\n __wasi_sock_type_t SockType) noexcept {\n int SysProtocol = IPPROTO_IP;\n int SysDomain = 0;\n int SysType = 0;\n\n switch (AddressFamily) {\n case __WASI_ADDRESS_FAMILY_INET4:\n SysDomain = AF_INET;\n break;\n case __WASI_ADDRESS_FAMILY_INET6:\n SysDomain = AF_INET6;\n break;\n case __WASI_ADDRESS_FAMILY_AF_UNIX:\n SysDomain = AF_UNIX;\n break;\n default:\n return WasiUnexpect(__WASI_ERRNO_INVAL);\n }\n\n switch (SockType) {\n case __WASI_SOCK_TYPE_SOCK_DGRAM:\n SysType = SOCK_DGRAM;\n break;\n case __WASI_SOCK_TYPE_SOCK_STREAM:\n SysType = SOCK_STREAM;\n break;\n default:\n return WasiUnexpect(__WASI_ERRNO_INVAL);\n }\n\n if (auto NewFd = ::socket(SysDomain, SysType, SysProtocol);\n unlikely(NewFd < 0)) {\n return WasiUnexpect(fromErrNo(errno));\n } else {\n INode New(NewFd);\n return New;\n }\n}\n\nstruct SockEmptyAddr {};\nusing VarAddrT = std::variant;\n\nstruct VarAddrBuf {\n template sockaddr *operator()(T &V) {\n return reinterpret_cast(&V);\n }\n sockaddr *operator()(SockEmptyAddr &) { return nullptr; }\n};\n\nstruct VarAddrSize {\n template int operator()(const T &) { return sizeof(T); }\n int operator()(const SockEmptyAddr &) { return 0; }\n};\n\nstatic VarAddrT sockAddressAssignHelper(__wasi_address_family_t AddrFamily,\n const Span &Address,\n uint16_t Port) {\n VarAddrT Addr;\n if (Address.size() == 0) {\n Addr.emplace();\n } else if (AddrFamily == __WASI_ADDRESS_FAMILY_INET4) {\n auto &ServerAddr4 = Addr.emplace();\n\n ServerAddr4.sin_family = AF_INET;\n ServerAddr4.sin_port = htons(Port);\n assuming(Address.size() >= sizeof(in_addr));\n std::memcpy(&ServerAddr4.sin_addr, Address.data(), sizeof(in_addr));\n } else if (AddrFamily == __WASI_ADDRESS_FAMILY_INET6) {\n auto &ServerAddr6 = Addr.emplace();\n\n ServerAddr6.sin6_family = AF_INET6;\n ServerAddr6.sin6_port = htons(Port);\n ServerAddr6.sin6_flowinfo = 0;\n assuming(Address.size() >= sizeof(in6_addr));\n std::memcpy(&ServerAddr6.sin6_addr, Address.data(), sizeof(in6_addr));\n } else if (AddrFamily == __WASI_ADDRESS_FAMILY_AF_UNIX) {\n...\n// Path: lib/host/wasi/clock-macos.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if WASMEDGE_OS_MACOS\n\n#include \"host/wasi/clock.h\"\n#include \"macos.h\"\n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\n\nWasiExpect Clock::clockResGet(__wasi_clockid_t Id,\n __wasi_timestamp_t &Resolution) noexcept {\n timespec SysTimespec;\n if (auto Res = ::clock_getres(toClockId(Id), &SysTimespec);\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n Resolution = fromTimespec(SysTimespec);\n return {};\n}\n\nWasiExpect Clock::clockTimeGet(__wasi_clockid_t Id, __wasi_timestamp_t,\n __wasi_timestamp_t &Time) noexcept {\n timespec SysTimespec;\n if (auto Res = ::clock_gettime(toClockId(Id), &SysTimespec);\n unlikely(Res != 0)) {\n return WasiUnexpect(fromErrNo(errno));\n }\n\n Time = fromTimespec(SysTimespec);\n return {};\n}\n\n} // namespace WASI\n} // namespace Host\n} // namespace 
WasmEdge\n\n#endif\n\n// Path: lib/host/wasi/win.h\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"common/defines.h\"\n#if !WASMEDGE_OS_WINDOWS\n#error\n#endif\n\n#include \"common/errcode.h\"\n#include \"system/winapi.h\"\n#include \"wasi/api.hpp\"\n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Host {\nnamespace WASI {\ninline namespace detail {\nusing namespace winapi;\n\ninline constexpr __wasi_errno_t fromErrNo(int ErrNo) noexcept {\n switch (ErrNo) {\n case 0:\n return __WASI_ERRNO_SUCCESS;\n case E2BIG:\n return __WASI_ERRNO_2BIG;\n case EACCES:\n return __WASI_ERRNO_ACCES;\n case EADDRINUSE:\n return __WASI_ERRNO_ADDRINUSE;\n case EADDRNOTAVAIL:\n return __WASI_ERRNO_ADDRNOTAVAIL;\n case EAFNOSUPPORT:\n return __WASI_ERRNO_AFNOSUPPORT;\n case EAGAIN:\n return __WASI_ERRNO_AGAIN;\n case EALREADY:\n return __WASI_ERRNO_ALREADY;\n case EBADF:\n return __WASI_ERRNO_BADF;\n case EBADMSG:\n return __WASI_ERRNO_BADMSG;\n case EBUSY:\n return __WASI_ERRNO_BUSY;\n case ECANCELED:\n return __WASI_ERRNO_CANCELED;\n case ECHILD:\n return __WASI_ERRNO_CHILD;\n case ECONNABORTED:\n return __WASI_ERRNO_CONNABORTED;\n case ECONNREFUSED:\n return __WASI_ERRNO_CONNREFUSED;\n case ECONNRESET:\n return __WASI_ERRNO_CONNRESET;\n case EDEADLK:\n return __WASI_ERRNO_DEADLK;\n case EDESTADDRREQ:\n return __WASI_ERRNO_DESTADDRREQ;\n case EDOM:\n return __WASI_ERRNO_DOM;\n case EEXIST:\n return __WASI_ERRNO_EXIST;\n case EFAULT:\n return __WASI_ERRNO_FAULT;\n case EFBIG:\n return __WASI_ERRNO_FBIG;\n case EHOSTUNREACH:\n return __WASI_ERRNO_HOSTUNREACH;\n case EIDRM:\n return __WASI_ERRNO_IDRM;\n case EILSEQ:\n return __WASI_ERRNO_ILSEQ;\n case EINPROGRESS:\n return __WASI_ERRNO_INPROGRESS;\n case EINTR:\n return __WASI_ERRNO_INTR;\n case EINVAL:\n return __WASI_ERRNO_INVAL;\n case EIO:\n return __WASI_ERRNO_IO;\n case EISCONN:\n return __WASI_ERRNO_ISCONN;\n case EISDIR:\n return __WASI_ERRNO_ISDIR;\n case ELOOP:\n return __WASI_ERRNO_LOOP;\n case EMFILE:\n return __WASI_ERRNO_MFILE;\n case EMLINK:\n return __WASI_ERRNO_MLINK;\n case EMSGSIZE:\n return __WASI_ERRNO_MSGSIZE;\n case ENAMETOOLONG:\n return __WASI_ERRNO_NAMETOOLONG;\n case ENETDOWN:\n return __WASI_ERRNO_NETDOWN;\n case ENETRESET:\n return __WASI_ERRNO_NETRESET;\n case ENETUNREACH:\n return __WASI_ERRNO_NETUNREACH;\n case ENFILE:\n return __WASI_ERRNO_NFILE;\n case ENOBUFS:\n return __WASI_ERRNO_NOBUFS;\n case ENODEV:\n return __WASI_ERRNO_NODEV;\n case ENOENT:\n return __WASI_ERRNO_NOENT;\n case ENOEXEC:\n return __WASI_ERRNO_NOEXEC;\n case ENOLCK:\n return __WASI_ERRNO_NOLCK;\n case ENOLINK:\n return __WASI_ERRNO_NOLINK;\n case ENOMEM:\n return __WASI_ERRNO_NOMEM;\n case ENOMSG:\n return __WASI_ERRNO_NOMSG;\n case ENOPROTOOPT:\n return __WASI_ERRNO_NOPROTOOPT;\n case ENOSPC:\n return __WASI_ERRNO_NOSPC;\n case ENOSYS:\n return __WASI_ERRNO_NOSYS;\n case ENOTCONN:\n return __WASI_ERRNO_NOTCONN;\n case ENOTDIR:\n return __WASI_ERRNO_NOTDIR;\n case ENOTEMPTY:\n return __WASI_ERRNO_NOTEMPTY;\n case ENOTRECOVERABLE:\n return __WASI_ERRNO_NOTRECOVERABLE;\n case ENOTSOCK:\n return __WASI_ERRNO_NOTSOCK;\n case ENOTSUP:\n return __WASI_ERRNO_NOTSUP;\n case ENOTTY:\n return __WASI_ERRNO_NOTTY;\n case ENXIO:\n return __WASI_ERRNO_NXIO;\n case EOVERFLOW:\n return __WASI_ERRNO_OVERFLOW;\n case EOWNERDEAD:\n return __WASI_ERRNO_OWNERDEAD;\n case EPERM:\n return __WASI_ERRNO_PERM;\n case EPIPE:\n return __WASI_ERRNO_PIPE;\n case EPROTO:\n 
return __WASI_ERRNO_PROTO;\n case EPROTONOSUPPORT:\n return __WASI_ERRNO_PROTONOSUPPORT;\n case EPROTOTYPE:\n return __WASI_ERRNO_PROTOTYPE;\n case ERANGE:\n return __WASI_ERRNO_RANGE;\n case EROFS:\n return __WASI_ERRNO_ROFS;\n case ESPIPE:\n return __WASI_ERRNO_SPIPE;\n case ESRCH:\n return __WASI_ERRNO_SRCH;\n case ETIMEDOUT:\n return __WASI_ERRNO_TIMEDOUT;\n case ETXTBSY:\n return __WASI_ERRNO_TXTBSY;\n case EXDEV:\n return __WASI_ERRNO_XDEV;\n default:\n assumingUnreachable();\n }\n}\n\n\ninline __wasi_errno_t fromLastError(DWORD_ Code) noexcept {\n switch (Code) {\n case ERROR_INVALID_PARAMETER_: // MultiByteToWideChar\n case ERROR_INVALID_HANDLE_: // GetFinalPathNameByHandleW\n return __WASI_ERRNO_INVAL;\n case ERROR_SHARING_VIOLATION_: // CreateFile2\n case ERROR_PIPE_BUSY_: // CreateFile2\n return __WASI_ERRNO_BUSY;\n case ERROR_ACCESS_DENIED_: // CreateFile2\n return __WASI_ERRNO_ACCES;\n case ERROR_ALREADY_EXISTS_: // CreateFile2\n case ERROR_FILE_EXISTS_: // CreateFile2\n return __WASI_ERRNO_EXIST;\n case ERROR_FILE_NOT_FOUND_: // CreateFile2\n return __WASI_ERRNO_NOENT;\n case ERROR_PRIVILEGE_NOT_HELD_: // CreateSymbolicLinkW\n return __WASI_ERRNO_PERM;\n\n case ERROR_IO_PENDING_: // ReadFileEx\n case ERROR_HANDLE_EOF_: // ReadFileEx\n case ERROR_INSUFFICIENT_BUFFER_: // MultiByteToWideChar\n case ERROR_INVALID_FLAGS_: // MultiByteToWideChar\n case ERROR_NO_UNICODE_TRANSLATION_: // MultiByteToWideChar\n default:\n return __WASI_ERRNO_NOSYS;\n }\n}\n\nusing FiletimeDuration = std::chrono::duration<\n uint64_t,\n std::ratio_multiply, std::chrono::nanoseconds::period>>;\n/// from 1601-01-01 to 1970-01-01, 134774 days\nstatic inline constexpr const FiletimeDuration NTToUnixEpoch =\n std::chrono::seconds{134774u * 86400u};\n\nstatic constexpr __wasi_timestamp_t fromFiletime(FILETIME_ FileTime) noexcept {\n using std::chrono::duration_cast;\n using std::chrono::nanoseconds;\n ULARGE_INTEGER_ Temp = {/* LowPart */ FileTime.dwLowDateTime,\n /* HighPart */ FileTime.dwHighDateTime};\n auto Duration = duration_cast(FiletimeDuration{Temp.QuadPart} -\n NTToUnixEpoch);\n return static_cast<__wasi_timestamp_t>(Duration.count());\n}\n\nstatic constexpr FILETIME_ toFiletime(__wasi_timestamp_t TimeStamp) noexcept {\n using std::chrono::duration_cast;\n using std::chrono::nanoseconds;\n auto Duration =\n duration_cast(nanoseconds{TimeStamp}) + NTToUnixEpoch;\n ULARGE_INTEGER_ Temp = ULARGE_INTEGER_(Duration.count());\n return FILETIME_{/* dwLowDateTime */ Temp.LowPart,\n /* dwHighDateTime */ Temp.HighPart};\n}\n\ninline __wasi_errno_t fromWSALastError() noexcept {\n switch (WSAGetLastError()) {\n case WSASYSNOTREADY_: // WSAStartup\n case WSAEWOULDBLOCK_: // closesocket\n return __WASI_ERRNO_AGAIN;\n case WSAVERNOTSUPPORTED_: // WSAStartup\n return __WASI_ERRNO_NOTSUP;\n case WSAEINPROGRESS_: // WSAStartup, socket, closesocket\n return __WASI_ERRNO_INPROGRESS;\n case WSAEPROCLIM_: // WSAStartup\n return __WASI_ERRNO_BUSY;\n case WSAEFAULT_: // WSAStartup\n return __WASI_ERRNO_FAULT;\n case WSAENETDOWN_: // socket, closesocket\n return __WASI_ERRNO_NETDOWN;\n case WSAENOTSOCK_: // closesocket\n return __WASI_ERRNO_NOTSOCK;\n case WSAEINTR_: // closesocket\n return __WASI_ERRNO_INTR;\n case WSAEAFNOSUPPORT_: // socket\n return __WASI_ERRNO_AIFAMILY;\n case WSAEMFILE_: // socket\n return __WASI_ERRNO_NFILE;\n case WSAEINVAL_: // socket\n return __WASI_ERRNO_INVAL;\n case WSAENOBUFS_: // socket\n return __WASI_ERRNO_NOBUFS;\n case WSAEPROTONOSUPPORT_: // socket\n return 
__WASI_ERRNO_PROTONOSUPPORT;\n case WSAEPROTOTYPE_: // socket\n return __WASI_ERRNO_PROTOTYPE;\n case WSAESOCKTNOSUPPORT_: // socket\n return __WASI_ERRNO_AISOCKTYPE;\n case WSAEINVALIDPROCTABLE_: // socket\n case WSAEINVALIDPROVIDER_: // socket\n case WSAEPROVIDERFAILEDINIT_: // socket\n case WSANOTINITIALISED_: // socket, closesocket\n default:\n return __WASI_ERRNO_NOSYS;\n }\n}\n\ninline constexpr __wasi_errno_t fromWSAError(int WSAError) noexcept {\n switch (WSAError) {\n case WSATRY_AGAIN_:\n return __WASI_ERRNO_AIAGAIN;\n case WSAEINVAL_:\n return __WASI_ERRNO_AIBADFLAG;\n case WSANO_RECOVERY_:\n return __WASI_ERRNO_AIFAIL;\n case WSAEAFNOSUPPORT_:\n return __WASI_ERRNO_AIFAMILY;\n case ERROR_NOT_ENOUGH_MEMORY_:\n return __WASI_ERRNO_AIMEMORY;\n case WSAHOST_NOT_FOUND_:\n return __WASI_ERRNO_AINONAME;\n case WSATYPE_NOT_FOUND_:\n return __WASI_ERRNO_AISERVICE;\n case WSAESOCKTNOSUPPORT_:\n return __WASI_ERRNO_AISOCKTYPE;\n default:\n assumingUnreachable();\n }\n}\n\ninline WasiExpect ensureWSAStartup() noexcept {\n static std::once_flag InitFlag;\n try {\n std::call_once(InitFlag, []() {\n WSADATA_ WSAData;\n if (unlikely(WSAStartup(0x0202, &WSAData) != 0)) {\n throw detail::fromWSALastError();\n }\n if (unlikely(WSAData.wVersion != 0x0202)) {\n throw __WASI_ERRNO_NOSYS;\n }\n });\n return {};\n } catch (__wasi_errno_t &Error) {\n return WasiUnexpect(Error);\n }\n}\n\ninline constexpr DWORD_ toWhence(__wasi_whence_t Whence) noexcept {\n switch (Whence) {\n case __WASI_WHENCE_SET:\n return FILE_BEGIN_;\n case __WASI_WHENCE_END:\n return FILE_END_;\n case __WASI_WHENCE_CUR:\n return FILE_CURRENT_;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr int toSockOptLevel(__wasi_sock_opt_level_t Level) noexcept {\n switch (Level) {\n case __WASI_SOCK_OPT_LEVEL_SOL_SOCKET:\n return SOL_SOCKET;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr int toSockOptSoName(__wasi_sock_opt_so_t SoName) noexcept {\n switch (SoName) {\n case __WASI_SOCK_OPT_SO_REUSEADDR:\n return SO_REUSEADDR;\n case __WASI_SOCK_OPT_SO_TYPE:\n return SO_TYPE;\n case __WASI_SOCK_OPT_SO_ERROR:\n return SO_ERROR;\n case __WASI_SOCK_OPT_SO_DONTROUTE:\n return SO_DONTROUTE;\n case __WASI_SOCK_OPT_SO_BROADCAST:\n return SO_BROADCAST;\n case __WASI_SOCK_OPT_SO_SNDBUF:\n return SO_SNDBUF;\n case __WASI_SOCK_OPT_SO_RCVBUF:\n return SO_RCVBUF;\n case __WASI_SOCK_OPT_SO_KEEPALIVE:\n return SO_KEEPALIVE;\n case __WASI_SOCK_OPT_SO_OOBINLINE:\n return SO_OOBINLINE;\n case __WASI_SOCK_OPT_SO_LINGER:\n return SO_LINGER;\n case __WASI_SOCK_OPT_SO_RCVLOWAT:\n return SO_RCVLOWAT;\n case __WASI_SOCK_OPT_SO_RCVTIMEO:\n return SO_RCVTIMEO;\n case __WASI_SOCK_OPT_SO_SNDTIMEO:\n return SO_SNDTIMEO;\n case __WASI_SOCK_OPT_SO_ACCEPTCONN:\n return SO_ACCEPTCONN;\n default:\n assumingUnreachable();\n }\n}\n\ninline constexpr __wasi_aiflags_t fromAIFlags(int AIFlags) noexcept {\n __wasi_aiflags_t Result = static_cast<__wasi_aiflags_t>(0);\n\n...\n", "needle_token_start": 13569, "needle_token_end": 13978, "code_context_ntokens": 16213, "output": ["```cpp\ninline constexpr __wasi_errno_t fromErrNo(int ErrNo) noexcept {\n switch (ErrNo) {\n case 0:\n return __WASI_ERRNO_SUCCESS;\n case E2BIG:\n return __WASI_ERRNO_2BIG;\n case EACCES:\n return __WASI_ERRNO_ACCES;\n case EADDRINUSE:\n return __WASI_ERRNO_ADDRINUSE;\n case EADDRNOTAVAIL:\n return __WASI_ERRNO_ADDRNOTAVAIL;\n case EAFNOSUPPORT:\n return __WASI_ERRNO_AFNOSUPPORT;\n case EAGAIN:\n return __WASI_ERRNO_AGAIN;\n case EALREADY:\n return __WASI_ERRNO_ALREADY;\n case 
EBADF:\n return __WASI_ERRNO_BADF;\n case EBADMSG:\n return __WASI_ERRNO_BADMSG;\n case EBUSY:\n return __WASI_ERRNO_BUSY;\n case ECANCELED:\n return __WASI_ERRNO_CANCELED;\n case ECHILD:\n return __WASI_ERRNO_CHILD;\n case ECONNABORTED:\n return __WASI_ERRNO_CONNABORTED;\n case ECONNREFUSED:\n return __WASI_ERRNO_CONNREFUSED;\n case ECONNRESET:\n return __WASI_ERRNO_CONNRESET;\n case EDEADLK:\n return __WASI_ERRNO_DEADLK;\n case EDESTADDRREQ:\n return __WASI_ERRNO_DESTADDRREQ;\n case EDOM:\n return __WASI_ERRNO_DOM;\n case EEXIST:\n return __WASI_ERRNO_EXIST;\n case EFAULT:\n return __WASI_ERRNO_FAULT;\n case EFBIG:\n return __WASI_ERRNO_FBIG;\n case EHOSTUNREACH:\n return __WASI_ERRNO_HOSTUNREACH;\n case EIDRM:\n return __WASI_ERRNO_IDRM;\n case EILSEQ:\n return __WASI_ERRNO_ILSEQ;\n case EINPROGRESS:\n return __WASI_ERRNO_INPROGRESS;\n case EINTR:\n return __WASI_ERRNO_INTR;\n case EINVAL:\n return __WASI_ERRNO_INVAL;\n case EIO:\n return __WASI_ERRNO_IO;\n case EISCONN:\n return __WASI_ERRNO_ISCONN;\n case EISDIR:\n return __WASI_ERRNO_ISDIR;\n case ELOOP:\n return __WASI_ERRNO_LOOP;\n case EMFILE:\n return __WASI_ERRNO_MFILE;\n case EMLINK:\n return __WASI_ERRNO_MLINK;\n case EMSGSIZE:\n return __WASI_ERRNO_MSGSIZE;\n case ENAMETOOLONG:\n return __WASI_ERRNO_NAMETOOLONG;\n case ENETDOWN:\n return __WASI_ERRNO_NETDOWN;\n case ENETRESET:\n return __WASI_ERRNO_NETRESET;\n case ENETUNREACH:\n return __WASI_ERRNO_NETUNREACH;\n case ENFILE:\n return __WASI_ERRNO_NFILE;\n case ENOBUFS:\n return __WASI_ERRNO_NOBUFS;\n case ENODEV:\n return __WASI_ERRNO_NODEV;\n case ENOENT:\n return __WASI_ERRNO_NOENT;\n case ENOEXEC:\n return __WASI_ERRNO_NOEXEC;\n case ENOLCK:\n return __WASI_ERRNO_NOLCK;\n case ENOLINK:\n return __WASI_ERRNO_NOLINK;\n case ENOMEM:\n return __WASI_ERRNO_NOMEM;\n case ENOMSG:\n return __WASI_ERRNO_NOMSG;\n case ENOPROTOOPT:\n return __WASI_ERRNO_NOPROTOOPT;\n case ENOSPC:\n return __WASI_ERRNO_NOSPC;\n case ENOSYS:\n return __WASI_ERRNO_NOSYS;\n case ENOTCONN:\n return __WASI_ERRNO_NOTCONN;\n case ENOTDIR:\n return __WASI_ERRNO_NOTDIR;\n case ENOTEMPTY:\n return __WASI_ERRNO_NOTEMPTY;\n case ENOTRECOVERABLE:\n return __WASI_ERRNO_NOTRECOVERABLE;\n case ENOTSOCK:\n return __WASI_ERRNO_NOTSOCK;\n case ENOTSUP:\n return __WASI_ERRNO_NOTSUP;\n case ENOTTY:\n return __WASI_ERRNO_NOTTY;\n case ENXIO:\n return __WASI_ERRNO_NXIO;\n case EOVERFLOW:\n return __WASI_ERRNO_OVERFLOW;\n case EOWNERDEAD:\n return __WASI_ERRNO_OWNER"]} +{"repo": "WasmEdge/WasmEdge", "name": "roundDownPageBoundary", "language": "cpp", "path": "lib/loader/aot_section.cpp", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To adjust a given numerical value downwards to the nearest boundary of a memory page size, which varies depending on the operating system and architecture.\n2. **Input**: A 64-bit unsigned integer representing an address or size in memory.\n3. **Output**: A 64-bit unsigned integer that has been adjusted down to the nearest page boundary.\n4. **Procedure**: The function checks the operating system and architecture to determine the appropriate page size (standard or specific to ARM64 on macOS). 
It then uses a bitwise AND operation with the complement of the page size minus one to clear the lower bits of the input value, effectively rounding it down to the nearest multiple of the page size.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: lib/executor/engine/engine.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect Executor::runExpression(Runtime::StackManager &StackMgr,\n AST::InstrView Instrs) {\n return execute(StackMgr, Instrs.begin(), Instrs.end());\n}\n\nExpect\nExecutor::runFunction(Runtime::StackManager &StackMgr,\n const Runtime::Instance::FunctionInstance &Func,\n Span Params) {\n // Set start time.\n if (Stat && Conf.getStatisticsConfigure().isTimeMeasuring()) {\n Stat->startRecordWasm();\n }\n\n // Reset and push a dummy frame into stack.\n StackMgr.pushFrame(nullptr, AST::InstrView::iterator(), 0, 0);\n\n // Push arguments.\n const auto &PTypes = Func.getFuncType().getParamTypes();\n for (uint32_t I = 0; I < Params.size(); I++) {\n // For the references, transform to non-null reference type if the value not\n // null.\n if (PTypes[I].isRefType() && Params[I].get().getPtr() &&\n Params[I].get().getType().isNullableRefType()) {\n auto Val = Params[I];\n Val.get().getType().toNonNullableRef();\n StackMgr.push(Val);\n } else {\n StackMgr.push(Params[I]);\n }\n }\n\n // Enter and execute function.\n AST::InstrView::iterator StartIt = {};\n Expect Res = {};\n if (auto GetIt = enterFunction(StackMgr, Func, Func.getInstrs().end())) {\n StartIt = *GetIt;\n } else {\n if (GetIt.error() == ErrCode::Value::Terminated) {\n // Handle the terminated case in entering AOT or host functions.\n // For the terminated case, not return now to print the statistics.\n Res = Unexpect(GetIt.error());\n } else {\n return Unexpect(GetIt);\n }\n }\n if (Res) {\n // If not terminated, execute the instructions in interpreter mode.\n // For the entering AOT or host functions, the `StartIt` is equal to the end\n // of instruction list, therefore the execution will return immediately.\n Res = execute(StackMgr, StartIt, Func.getInstrs().end());\n }\n\n if (Res) {\n spdlog::debug(\" Execution succeeded.\");\n } else if (Res.error() == ErrCode::Value::Terminated) {\n spdlog::debug(\" Terminated.\");\n }\n\n if (Stat && Conf.getStatisticsConfigure().isTimeMeasuring()) {\n Stat->stopRecordWasm();\n }\n\n // If Statistics is enabled, then dump it here.\n if (Stat) {\n Stat->dumpToLog(Conf);\n }\n\n if (Res) {\n return {};\n }\n if (Res.error() == ErrCode::Value::Terminated) {\n StackMgr.reset();\n }\n return Unexpect(Res);\n}\n\nExpect Executor::execute(Runtime::StackManager &StackMgr,\n const AST::InstrView::iterator Start,\n const AST::InstrView::iterator End) {\n AST::InstrView::iterator PC = Start;\n AST::InstrView::iterator PCEnd = End;\n\n auto Dispatch = [this, &PC, &StackMgr]() -> Expect {\n const AST::Instruction &Instr = *PC;\n\n auto GetDstCompType = [&StackMgr, &Instr, this]() {\n return getDefTypeByIdx(StackMgr, Instr.getTargetIndex())\n ->getCompositeType();\n };\n auto GetSrcCompType = [&StackMgr, &Instr, this]() {\n return getDefTypeByIdx(StackMgr, Instr.getSourceIndex())\n 
->getCompositeType();\n };\n\n switch (Instr.getOpCode()) {\n // Control instructions.\n case OpCode::Unreachable:\n spdlog::error(ErrCode::Value::Unreachable);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::Unreachable);\n case OpCode::Nop:\n return {};\n case OpCode::Block:\n return {};\n case OpCode::Loop:\n return {};\n case OpCode::If:\n return runIfElseOp(StackMgr, Instr, PC);\n case OpCode::Else:\n if (Stat && Conf.getStatisticsConfigure().isCostMeasuring()) {\n // Reach here means end of if-statement.\n if (unlikely(!Stat->subInstrCost(Instr.getOpCode()))) {\n spdlog::error(ErrCode::Value::CostLimitExceeded);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::CostLimitExceeded);\n }\n if (unlikely(!Stat->addInstrCost(OpCode::End))) {\n spdlog::error(ErrCode::Value::CostLimitExceeded);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::CostLimitExceeded);\n }\n }\n PC += PC->getJumpEnd();\n [[fallthrough]];\n case OpCode::End:\n PC = StackMgr.maybePopFrame(PC);\n return {};\n case OpCode::Br:\n return runBrOp(StackMgr, Instr, PC);\n case OpCode::Br_if:\n return runBrIfOp(StackMgr, Instr, PC);\n case OpCode::Br_table:\n return runBrTableOp(StackMgr, Instr, PC);\n case OpCode::Br_on_null:\n return runBrOnNullOp(StackMgr, Instr, PC);\n case OpCode::Br_on_non_null:\n return runBrOnNonNullOp(StackMgr, Instr, PC);\n case OpCode::Br_on_cast:\n return runBrOnCastOp(StackMgr, Instr, PC);\n case OpCode::Br_on_cast_fail:\n return runBrOnCastOp(StackMgr, Instr, PC, true);\n case OpCode::Return:\n return runReturnOp(StackMgr, PC);\n case OpCode::Call:\n return runCallOp(StackMgr, Instr, PC);\n case OpCode::Call_indirect:\n return runCallIndirectOp(StackMgr, Instr, PC);\n case OpCode::Return_call:\n return runCallOp(StackMgr, Instr, PC, true);\n case OpCode::Return_call_indirect:\n return runCallIndirectOp(StackMgr, Instr, PC, true);\n case OpCode::Call_ref:\n return runCallRefOp(StackMgr, Instr, PC);\n case OpCode::Return_call_ref:\n return runCallRefOp(StackMgr, Instr, PC, true);\n\n // Reference Instructions\n case OpCode::Ref__null:\n return runRefNullOp(StackMgr, Instr.getValType());\n case OpCode::Ref__is_null:\n return runRefIsNullOp(StackMgr.getTop());\n case OpCode::Ref__func:\n return runRefFuncOp(StackMgr, Instr.getTargetIndex());\n case OpCode::Ref__eq: {\n ValVariant Rhs = StackMgr.pop();\n return runRefEqOp(StackMgr.getTop(), Rhs);\n }\n case OpCode::Ref__as_non_null:\n return runRefAsNonNullOp(StackMgr.getTop().get(), Instr);\n\n // GC Instructions\n case OpCode::Struct__new:\n return runStructNewOp(StackMgr, Instr.getTargetIndex());\n case OpCode::Struct__new_default:\n return runStructNewOp(StackMgr, Instr.getTargetIndex(), true);\n case OpCode::Struct__get:\n case OpCode::Struct__get_u:\n return runStructGetOp(StackMgr.getTop(), Instr.getSourceIndex(),\n GetDstCompType(), Instr);\n case OpCode::Struct__get_s:\n return runStructGetOp(StackMgr.getTop(), Instr.getSourceIndex(),\n GetDstCompType(), Instr, true);\n case OpCode::Struct__set: {\n const ValVariant Val = StackMgr.pop();\n RefVariant StructRef = StackMgr.pop().get();\n return runStructSetOp(Val, StructRef, GetDstCompType(),\n Instr.getSourceIndex(), Instr);\n }\n case OpCode::Array__new:\n return runArrayNewOp(StackMgr, Instr.getTargetIndex(), 1,\n StackMgr.pop().get());\n case OpCode::Array__new_default:\n return 
runArrayNewOp(StackMgr, Instr.getTargetIndex(), 0,\n StackMgr.pop().get());\n case OpCode::Array__new_fixed:\n return runArrayNewOp(StackMgr, Instr.getTargetIndex(),\n Instr.getSourceIndex(), Instr.getSourceIndex());\n case OpCode::Array__new_data:\n return runArrayNewDataOp(\n StackMgr, *getDataInstByIdx(StackMgr, Instr.getSourceIndex()), Instr);\n case OpCode::Array__new_elem:\n return runArrayNewElemOp(\n...\n// Path: lib/executor/engine/refInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nnamespace {\nValVariant packVal(const ValType &Type, const ValVariant &Val) {\n if (Type.isPackType()) {\n switch (Type.getCode()) {\n case TypeCode::I8:\n return ValVariant(Val.get() & 0xFFU);\n case TypeCode::I16:\n return ValVariant(Val.get() & 0xFFFFU);\n default:\n assumingUnreachable();\n }\n }\n return Val;\n}\n\nValVariant unpackVal(const ValType &Type, const ValVariant &Val,\n bool IsSigned = false) {\n if (Type.isPackType()) {\n uint32_t Num = Val.get();\n switch (Type.getCode()) {\n case TypeCode::I8:\n if (IsSigned) {\n return static_cast(static_cast(Num));\n } else {\n return static_cast(static_cast(Num));\n }\n case TypeCode::I16:\n if (IsSigned) {\n return static_cast(static_cast(Num));\n } else {\n return static_cast(static_cast(Num));\n }\n default:\n assumingUnreachable();\n }\n }\n return Val;\n}\n\nstd::vector packVals(const ValType &Type,\n std::vector &&Vals) {\n for (uint32_t I = 0; I < Vals.size(); I++) {\n Vals[I] = packVal(Type, Vals[I]);\n }\n return std::move(Vals);\n}\n} // namespace\n\nExpect Executor::runRefNullOp(Runtime::StackManager &StackMgr,\n const ValType &Type) const noexcept {\n // A null reference is typed with the least type in its respective hierarchy.\n StackMgr.push(RefVariant(toBottomType(StackMgr, Type)));\n return {};\n}\n\nExpect Executor::runRefIsNullOp(ValVariant &Val) const noexcept {\n Val.emplace(Val.get().isNull() ? 1U : 0U);\n return {};\n}\n\nExpect Executor::runRefFuncOp(Runtime::StackManager &StackMgr,\n uint32_t Idx) const noexcept {\n const auto *FuncInst = getFuncInstByIdx(StackMgr, Idx);\n StackMgr.push(RefVariant(FuncInst->getDefType(), FuncInst));\n return {};\n}\n\nExpect Executor::runRefEqOp(ValVariant &Val1,\n const ValVariant &Val2) const noexcept {\n Val1.emplace(Val1.get().getPtr() ==\n Val2.get().getPtr()\n ? 1U\n : 0U);\n return {};\n}\n\nExpect\nExecutor::runRefAsNonNullOp(RefVariant &Ref,\n const AST::Instruction &Instr) const noexcept {\n if (Ref.isNull()) {\n spdlog::error(ErrCode::Value::CastNullToNonNull);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::CastNullToNonNull);\n }\n Ref.getType().toNonNullableRef();\n return {};\n}\n\nExpect Executor::runStructNewOp(Runtime::StackManager &StackMgr,\n const uint32_t DefIndex,\n bool IsDefault) const noexcept {\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n const auto &CompType =\n getDefTypeByIdx(StackMgr, DefIndex)->getCompositeType();\n uint32_t N = static_cast(CompType.getFieldTypes().size());\n std::vector Vals;\n if (IsDefault) {\n Vals.resize(N);\n for (uint32_t I = 0; I < N; I++) {\n const auto &VType = CompType.getFieldTypes()[I].getStorageType();\n Vals[I] = VType.isRefType()\n ? 
ValVariant(RefVariant(toBottomType(StackMgr, VType)))\n : ValVariant(static_cast(0));\n }\n } else {\n Vals = StackMgr.pop(N);\n for (uint32_t I = 0; I < N; I++) {\n Vals[I] = packVal(CompType.getFieldTypes()[I].getStorageType(), Vals[I]);\n }\n }\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newStruct(DefIndex, std::move(Vals));\n StackMgr.push(RefVariant(Inst->getDefType(), Inst));\n\n return {};\n}\n\nExpect Executor::runStructGetOp(ValVariant &Val, const uint32_t Idx,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr,\n bool IsSigned) const noexcept {\n const auto *Inst =\n Val.get().getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullStruct);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullStruct);\n }\n const auto &SType = CompType.getFieldTypes()[Idx].getStorageType();\n Val = unpackVal(SType, Inst->getField(Idx), IsSigned);\n return {};\n}\n\nExpect\nExecutor::runStructSetOp(const ValVariant &Val, const RefVariant &InstRef,\n const AST::CompositeType &CompType, uint32_t Idx,\n const AST::Instruction &Instr) const noexcept {\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullStruct);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullStruct);\n }\n const auto &SType = CompType.getFieldTypes()[Idx].getStorageType();\n Inst->getField(Idx) = packVal(SType, Val);\n return {};\n}\n\nExpect Executor::runArrayNewOp(Runtime::StackManager &StackMgr,\n const uint32_t DefIndex, uint32_t InitCnt,\n uint32_t ValCnt) const noexcept {\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n assuming(InitCnt == 0 || InitCnt == 1 || InitCnt == ValCnt);\n const auto &CompType =\n getDefTypeByIdx(StackMgr, DefIndex)->getCompositeType();\n const auto &VType = CompType.getFieldTypes()[0].getStorageType();\n if (InitCnt == 0) {\n auto InitVal = VType.isRefType()\n ? ValVariant(RefVariant(toBottomType(StackMgr, VType)))\n : ValVariant(static_cast(0));\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(DefIndex, ValCnt, InitVal);\n StackMgr.push(RefVariant(Inst->getDefType(), Inst));\n } else if (InitCnt == 1) {\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(DefIndex, ValCnt, packVal(VType, StackMgr.getTop()));\n StackMgr.getTop().emplace(Inst->getDefType(), Inst);\n } else {\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(DefIndex, packVals(VType, StackMgr.pop(ValCnt)));\n StackMgr.push(RefVariant(Inst->getDefType(), Inst));\n }\n return {};\n}\n\nExpect\nExecutor::runArrayNewDataOp(Runtime::StackManager &StackMgr,\n const Runtime::Instance::DataInstance &DataInst,\n const AST::Instruction &Instr) const noexcept {\n const uint32_t N = StackMgr.pop().get();\n const uint32_t S = StackMgr.getTop().get();\n const auto &CompType =\n getDefTypeByIdx(StackMgr, Instr.getTargetIndex())->getCompositeType();\n const uint32_t BSize =\n CompType.getFieldTypes()[0].getStorageType().getBitWidth() / 8;\n if (static_cast(S) + static_cast(N) * BSize >\n DataInst.getData().size()) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N * BSize,\n DataInst.getData().size() > 0\n ? 
static_cast(DataInst.getData().size() - 1)\n : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(Instr.getTargetIndex(), N, 0U);\n for (uint32_t Idx = 0; Idx < N; Idx++) {\n // The value has been packed.\n Inst->getData(Idx) = DataInst.loadValue(S + Idx * BSize, BSize);\n }\n StackMgr.getTop().emplace(Inst->getDefType(), Inst);\n return {};\n}\n\nExpect\nExecutor::runArrayNewElemOp(Runtime::StackManager &StackMgr,\n const Runtime::Instance::ElementInstance &ElemInst,\n const AST::Instruction &Instr) const noexcept {\n const uint32_t N = StackMgr.pop().get();\n const uint32_t S = StackMgr.getTop().get();\n const auto &CompType =\n getDefTypeByIdx(StackMgr, Instr.getTargetIndex())->getCompositeType();\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n auto ElemSrc = ElemInst.getRefs();\n if (static_cast(S) + static_cast(N) > ElemSrc.size()) {\n spdlog::error(ErrCode::Value::TableOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N,\n ElemSrc.size() > 0 ? static_cast(ElemSrc.size() - 1) : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::TableOutOfBounds);\n }\n std::vector Refs(ElemSrc.begin() + S, ElemSrc.begin() + S + N);\n /// TODO: The array and struct instances are owned by the module instance\n /// currently because of referring the defined types of the module instances.\n /// This may be changed after applying the garbage collection mechanism.\n auto *Inst =\n const_cast(StackMgr.getModule())\n ->newArray(Instr.getTargetIndex(), packVals(SType, std::move(Refs)));\n StackMgr.getTop().emplace(Inst->getDefType(), Inst);\n return {};\n}\n\nExpect\nExecutor::runArraySetOp(const ValVariant &Val, const uint32_t Idx,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr) const noexcept {\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (Idx >= Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(Idx, 1, Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n Inst->getData(Idx) = packVal(SType, Val);\n return {};\n}\n\nExpect Executor::runArrayGetOp(ValVariant &Val, const uint32_t Idx,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr,\n bool IsSigned) const noexcept {\n const auto *Inst =\n Val.get().getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (Idx >= Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(Idx, 1, Inst->getBoundIdx()));\n spdlog::error(\n 
ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n Val = unpackVal(SType, Inst->getData(Idx), IsSigned);\n return {};\n}\n\nExpect\nExecutor::runArrayLenOp(ValVariant &Val,\n const AST::Instruction &Instr) const noexcept {\n const auto *Inst =\n Val.get().getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n Val.emplace(Inst->getLength());\n return {};\n}\n\nExpect\nExecutor::runArrayFillOp(uint32_t N, const ValVariant &Val, uint32_t D,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const AST::Instruction &Instr) const noexcept {\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(D) + static_cast(N) > Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n auto Arr = Inst->getArray();\n std::fill(Arr.begin() + D, Arr.begin() + D + N, packVal(SType, Val));\n return {};\n}\n\nExpect\nExecutor::runArrayCopyOp(uint32_t N, uint32_t S, const RefVariant &SrcInstRef,\n uint32_t D, const RefVariant &DstInstRef,\n const AST::CompositeType &SrcCompType,\n const AST::CompositeType &DstCompType,\n const AST::Instruction &Instr) const noexcept {\n auto *SrcInst = SrcInstRef.getPtr();\n auto *DstInst = DstInstRef.getPtr();\n if (SrcInst == nullptr || DstInst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(S) + static_cast(N) >\n SrcInst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(S), N,\n SrcInst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n if (static_cast(D) + static_cast(N) >\n DstInst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n DstInst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n const auto &SrcSType = SrcCompType.getFieldTypes()[0].getStorageType();\n const auto &DstSType = DstCompType.getFieldTypes()[0].getStorageType();\n auto SrcArr = SrcInst->getArray();\n auto DstArr = DstInst->getArray();\n if (D <= S) {\n std::transform(SrcArr.begin() + S, SrcArr.begin() + S + N,\n DstArr.begin() + D, [&](const ValVariant &V) {\n return packVal(DstSType, unpackVal(SrcSType, V));\n });\n } else {\n std::transform(std::make_reverse_iterator(SrcArr.begin() + S + N),\n std::make_reverse_iterator(SrcArr.begin() + S),\n std::make_reverse_iterator(DstArr.begin() + D + N),\n [&](const ValVariant &V) {\n return 
packVal(DstSType, unpackVal(SrcSType, V));\n });\n }\n return {};\n}\n\nExpect\nExecutor::runArrayInitDataOp(uint32_t N, uint32_t S, uint32_t D,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const Runtime::Instance::DataInstance &DataInst,\n const AST::Instruction &Instr) const noexcept {\n const uint32_t BSize =\n CompType.getFieldTypes()[0].getStorageType().getBitWidth() / 8;\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(D) + static_cast(N) > Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n if (static_cast(S) + static_cast(N) * BSize >\n DataInst.getData().size()) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N * BSize,\n DataInst.getData().size() > 0\n ? static_cast(DataInst.getData().size() - 1)\n : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n for (uint32_t Off = 0; Off < N; Off++) {\n // The value has been packed.\n Inst->getData(D + Off) = DataInst.loadValue(S + Off * BSize, BSize);\n }\n return {};\n}\n\nExpect\nExecutor::runArrayInitElemOp(uint32_t N, uint32_t S, uint32_t D,\n const RefVariant &InstRef,\n const AST::CompositeType &CompType,\n const Runtime::Instance::ElementInstance &ElemInst,\n const AST::Instruction &Instr) const noexcept {\n auto ElemSrc = ElemInst.getRefs();\n auto *Inst = InstRef.getPtr();\n if (Inst == nullptr) {\n spdlog::error(ErrCode::Value::AccessNullArray);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullArray);\n }\n if (static_cast(D) + static_cast(N) > Inst->getLength()) {\n spdlog::error(ErrCode::Value::ArrayOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(static_cast(D), N,\n Inst->getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::ArrayOutOfBounds);\n }\n if (static_cast(S) + static_cast(N) > ElemSrc.size()) {\n spdlog::error(ErrCode::Value::TableOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n static_cast(S), N,\n ElemSrc.size() > 0 ? 
static_cast(ElemSrc.size() - 1) : 0U));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::TableOutOfBounds);\n }\n const auto &SType = CompType.getFieldTypes()[0].getStorageType();\n\n auto Arr = Inst->getArray();\n // The value has been packed.\n std::transform(ElemSrc.begin() + S, ElemSrc.begin() + S + N, Arr.begin() + D,\n [&](const RefVariant &V) { return packVal(SType, V); });\n return {};\n}\n\nExpect\nExecutor::runRefTestOp(const Runtime::Instance::ModuleInstance *ModInst,\n ValVariant &Val, const AST::Instruction &Instr,\n bool IsCast) const noexcept {\n // Copy the value type here due to handling the externalized case.\n auto &VT = Val.get().getType();\n if (VT.isExternalized()) {\n VT = ValType(TypeCode::Ref, TypeCode::ExternRef);\n }\n Span GotTypeList = ModInst->getTypeList();\n if (!VT.isAbsHeapType()) {\n auto *Inst =\n Val.get().getPtr();\n // Reference must not be nullptr here because the null references are typed\n // with the least abstract heap type.\n if (Inst->getModule()) {\n GotTypeList = Inst->getModule()->getTypeList();\n }\n }\n\n if (AST::TypeMatcher::matchType(ModInst->getTypeList(), Instr.getValType(),\n GotTypeList, VT)) {\n if (!IsCast) {\n Val.emplace(1U);\n }\n } else {\n if (IsCast) {\n spdlog::error(ErrCode::Value::CastFailed);\n spdlog::error(ErrInfo::InfoMismatch(Instr.getValType(), VT));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::CastFailed);\n } else {\n Val.emplace(0U);\n }\n }\n return {};\n}\n\nExpect Executor::runRefConvOp(RefVariant &Ref,\n TypeCode TCode) const noexcept {\n\n if (TCode == TypeCode::AnyRef) {\n // Internalize.\n if (Ref.isNull()) {\n Ref = RefVariant(ValType(TypeCode::RefNull, TypeCode::NullRef));\n } else {\n Ref.getType().setInternalized();\n if (Ref.getType().isExternRefType()) {\n Ref.getType() = ValType(TypeCode::Ref, TypeCode::AnyRef);\n }\n }\n } else {\n // Externalize.\n if (Ref.isNull()) {\n Ref = RefVariant(ValType(TypeCode::RefNull, TypeCode::NullExternRef));\n } else {\n // Use the externalize flag because the value type information should be\n // reserved when a reference being externalized and internalized.\n Ref.getType().setExternalized();\n }\n }\n return {};\n}\n\nExpect Executor::runRefI31Op(ValVariant &Val) const noexcept {\n uint32_t RefNum = (Val.get() & 0x7FFFFFFFU) | 0x80000000U;\n Val = RefVariant(ValType(TypeCode::Ref, TypeCode::I31Ref),\n reinterpret_cast(static_cast(RefNum)));\n return {};\n}\n\nExpect Executor::runI31GetOp(ValVariant &Val,\n const AST::Instruction &Instr,\n bool IsSigned) const noexcept {\n uint32_t RefNum = static_cast(\n reinterpret_cast(Val.get().getPtr()));\n if ((RefNum & 0x80000000U) == 0) {\n spdlog::error(ErrCode::Value::AccessNullI31);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::AccessNullI31);\n }\n RefNum &= 0x7FFFFFFFU;\n if (IsSigned) {\n RefNum |= ((RefNum & 0x40000000U) << 1);\n }\n Val.emplace(RefNum);\n return {};\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/memoryInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect\nExecutor::runMemorySizeOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst) {\n // Push SZ = page size to stack.\n 
StackMgr.push(MemInst.getPageSize());\n return {};\n}\n\nExpect\nExecutor::runMemoryGrowOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst) {\n // Pop N for growing page size.\n uint32_t &N = StackMgr.getTop().get();\n\n // Grow page and push result.\n const uint32_t CurrPageSize = static_cast(MemInst.getPageSize());\n if (MemInst.growPage(N)) {\n N = CurrPageSize;\n } else {\n N = static_cast(-1);\n }\n return {};\n}\n\nExpect Executor::runMemoryInitOp(\n Runtime::StackManager &StackMgr, Runtime::Instance::MemoryInstance &MemInst,\n Runtime::Instance::DataInstance &DataInst, const AST::Instruction &Instr) {\n // Pop the length, source, and destination from stack.\n uint32_t Len = StackMgr.pop().get();\n uint32_t Src = StackMgr.pop().get();\n uint32_t Dst = StackMgr.pop().get();\n\n // Replace mem[Dst : Dst + Len] with data[Src : Src + Len].\n if (auto Res = MemInst.setBytes(DataInst.getData(), Dst, Src, Len)) {\n return {};\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n }\n}\n\nExpect\nExecutor::runDataDropOp(Runtime::Instance::DataInstance &DataInst) {\n // Clear data instance.\n DataInst.clear();\n return {};\n}\n\nExpect\nExecutor::runMemoryCopyOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInstDst,\n Runtime::Instance::MemoryInstance &MemInstSrc,\n const AST::Instruction &Instr) {\n // Pop the length, source, and destination from stack.\n uint32_t Len = StackMgr.pop().get();\n uint32_t Src = StackMgr.pop().get();\n uint32_t Dst = StackMgr.pop().get();\n\n // Replace mem[Dst : Dst + Len] with mem[Src : Src + Len].\n if (auto Data = MemInstSrc.getBytes(Src, Len)) {\n if (auto Res = MemInstDst.setBytes(*Data, Dst, 0, Len)) {\n return {};\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n }\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Data);\n }\n}\n\nExpect\nExecutor::runMemoryFillOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst,\n const AST::Instruction &Instr) {\n // Pop the length, value, and offset from stack.\n uint32_t Len = StackMgr.pop().get();\n uint8_t Val = static_cast(StackMgr.pop().get());\n uint32_t Off = StackMgr.pop().get();\n\n // Fill data with Val.\n if (auto Res = MemInst.fillBytes(Val, Off, Len)) {\n return {};\n } else {\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n }\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/executor/engine/threadInstr.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"executor/executor.h\"\n\nnamespace WasmEdge {\nnamespace Executor {\n\nExpect\nExecutor::runAtomicNotifyOp(Runtime::StackManager &StackMgr,\n Runtime::Instance::MemoryInstance &MemInst,\n const AST::Instruction &Instr) {\n ValVariant RawCount = StackMgr.pop();\n ValVariant &RawAddress = StackMgr.getTop();\n\n uint32_t Address = RawAddress.get();\n\n if (Address >\n std::numeric_limits::max() - Instr.getMemoryOffset()) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n spdlog::error(ErrInfo::InfoBoundary(\n Address + static_cast(Instr.getMemoryOffset()),\n sizeof(uint32_t), MemInst.getBoundIdx()));\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return 
Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n Address += Instr.getMemoryOffset();\n\n if (Address % sizeof(uint32_t) != 0) {\n spdlog::error(ErrCode::Value::UnalignedAtomicAccess);\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(ErrCode::Value::UnalignedAtomicAccess);\n }\n\n uint32_t Count = RawCount.get();\n if (auto Res = atomicNotify(MemInst, Address, Count); unlikely(!Res)) {\n spdlog::error(Res.error());\n spdlog::error(\n ErrInfo::InfoInstruction(Instr.getOpCode(), Instr.getOffset()));\n return Unexpect(Res);\n } else {\n RawAddress.emplace(*Res);\n }\n return {};\n}\n\nExpect Executor::runMemoryFenceOp() {\n std::atomic_thread_fence(std::memory_order_release);\n return {};\n}\n\nExpect\nExecutor::atomicNotify(Runtime::Instance::MemoryInstance &MemInst,\n uint32_t Address, uint32_t Count) noexcept {\n // The error message should be handled by the caller, or the AOT mode will\n // produce the duplicated messages.\n if (auto *AtomicObj = MemInst.getPointer *>(Address);\n !AtomicObj) {\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n\n std::unique_lock Locker(WaiterMapMutex);\n uint32_t Total = 0;\n auto Range = WaiterMap.equal_range(Address);\n for (auto Iterator = Range.first; Total < Count && Iterator != Range.second;\n ++Iterator) {\n if (likely(&MemInst == Iterator->second.MemInst)) {\n Iterator->second.Cond.notify_all();\n ++Total;\n }\n }\n return Total;\n}\n\nvoid Executor::atomicNotifyAll() noexcept {\n std::unique_lock Locker(WaiterMapMutex);\n for (auto Iterator = WaiterMap.begin(); Iterator != WaiterMap.end();\n ++Iterator) {\n Iterator->second.Cond.notify_all();\n }\n}\n\n} // namespace Executor\n} // namespace WasmEdge\n\n// Path: lib/loader/loader.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/loader.h\"\n\n#include \"aot/version.h\"\n\n#include \n#include \n#include \n#include \n#include \n#include \n#include \n#include \n\nnamespace WasmEdge {\nnamespace Loader {\n\n// Load data from file path. 
See \"include/loader/loader.h\".\nExpect>\nLoader::loadFile(const std::filesystem::path &FilePath) {\n std::error_code EC;\n size_t FileSize = std::filesystem::file_size(FilePath, EC);\n if (EC) {\n spdlog::error(ErrCode::Value::IllegalPath);\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n\n std::ifstream Fin(FilePath, std::ios::in | std::ios::binary);\n if (!Fin) {\n spdlog::error(ErrCode::Value::IllegalPath);\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::IllegalPath);\n }\n\n std::vector Buf(FileSize);\n size_t Index = 0;\n while (FileSize > 0) {\n const uint32_t BlockSize = static_cast(\n std::min(FileSize, std::numeric_limits::max()));\n Fin.read(reinterpret_cast(Buf.data()) + Index, BlockSize);\n const uint32_t ReadCount = static_cast(Fin.gcount());\n if (ReadCount != BlockSize) {\n if (Fin.eof()) {\n spdlog::error(ErrCode::Value::UnexpectedEnd);\n spdlog::error(ErrInfo::InfoLoading(ReadCount));\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::UnexpectedEnd);\n } else {\n spdlog::error(ErrCode::Value::ReadError);\n spdlog::error(ErrInfo::InfoLoading(ReadCount));\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::ReadError);\n }\n }\n Index += static_cast(BlockSize);\n FileSize -= static_cast(BlockSize);\n }\n return Buf;\n}\n\nExpect,\n std::unique_ptr>>\nLoader::parseWasmUnit(const std::filesystem::path &FilePath) {\n std::lock_guard Lock(Mutex);\n // Set path and check the header.\n if (auto Res = FMgr.setPath(FilePath); !Res) {\n spdlog::error(Res.error());\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n\n switch (FMgr.getHeaderType()) {\n // Filter out the Windows .dll, MacOS .dylib, or Linux .so AOT compiled\n // shared-library-WASM.\n case FileMgr::FileHeader::ELF:\n case FileMgr::FileHeader::DLL:\n case FileMgr::FileHeader::MachO_32:\n case FileMgr::FileHeader::MachO_64: {\n // AOT compiled shared-library-WASM cases. 
Use ldmgr to load the module.\n WASMType = InputType::SharedLibrary;\n FMgr.reset();\n std::shared_ptr Library = std::make_shared();\n if (auto Res = Library->load(FilePath); !Res) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n if (auto Res = Library->getVersion()) {\n if (*Res != AOT::kBinaryVersion) {\n spdlog::error(ErrInfo::InfoMismatch(AOT::kBinaryVersion, *Res));\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(ErrCode::Value::MalformedVersion);\n }\n } else {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n\n std::unique_ptr Mod;\n if (auto Code = Library->getWasm()) {\n // Set the binary and load module.\n // Not to use parseModule() here to keep the `WASMType` value.\n if (auto Res = FMgr.setCode(*Code); !Res) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n if (auto Res = loadUnit()) {\n if (std::holds_alternative>(*Res)) {\n Mod = std::move(std::get>(*Res));\n }\n } else {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n } else {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Code);\n }\n if (!Conf.getRuntimeConfigure().isForceInterpreter()) {\n // If the configure is set to force interpreter mode, not to load the AOT\n // related data.\n if (auto Res = loadExecutable(*Mod, Library); unlikely(!Res)) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Res);\n }\n }\n return Mod;\n }\n default:\n // Universal WASM, WASM, or other cases. Load and parse the module directly.\n WASMType = InputType::WASM;\n auto Unit = loadUnit();\n if (!Unit) {\n spdlog::error(ErrInfo::InfoFile(FilePath));\n return Unexpect(Unit);\n }\n switch (Unit->index()) {\n case 0: // component\n return Unit;\n case 1: // module\n default: {\n auto Mod = std::move(std::get>(*Unit));\n if (!Conf.getRuntimeConfigure().isForceInterpreter()) {\n // If the configure is set to force interpreter mode, not to set the\n // symbol.\n if (auto &Symbol = Mod->getSymbol()) {\n *Symbol = IntrinsicsTable;\n }\n }\n return Mod;\n }\n }\n }\n}\n\nExpect,\n std::unique_ptr>>\nLoader::parseWasmUnit(Span Code) {\n std::lock_guard Lock(Mutex);\n if (auto Res = FMgr.setCode(Code); !Res) {\n return Unexpect(Res);\n }\n switch (FMgr.getHeaderType()) {\n // Filter out the Windows .dll, MacOS .dylib, or Linux .so AOT compiled\n // shared-library-WASM.\n case FileMgr::FileHeader::ELF:\n case FileMgr::FileHeader::DLL:\n case FileMgr::FileHeader::MachO_32:\n case FileMgr::FileHeader::MachO_64:\n spdlog::error(\"Might an invalid wasm file\");\n spdlog::error(ErrCode::Value::MalformedMagic);\n spdlog::error(\n \" The AOT compiled WASM shared library is not supported for loading \"\n \"from memory. Please use the universal WASM binary or pure WASM, or \"\n \"load the AOT compiled WASM shared library from file.\");\n return Unexpect(ErrCode::Value::MalformedMagic);\n default:\n break;\n }\n // For malformed header checking, handle in the module loading.\n WASMType = InputType::WASM;\n return loadUnit();\n}\n\n// Parse module from file path. See \"include/loader/loader.h\".\nExpect>\nLoader::parseModule(const std::filesystem::path &FilePath) {\n if (auto R = parseWasmUnit(FilePath)) {\n if (std::holds_alternative>(*R)) {\n return std::move(std::get>(*R));\n }\n return Unexpect(ErrCode::Value::MalformedVersion);\n } else {\n return Unexpect(R);\n }\n}\n\n// Parse module from byte code. 
See \"include/loader/loader.h\".\nExpect>\nLoader::parseModule(Span Code) {\n if (auto R = parseWasmUnit(Code)) {\n if (std::holds_alternative>(*R)) {\n return std::move(std::get>(*R));\n }\n return Unexpect(ErrCode::Value::MalformedVersion);\n } else {\n return Unexpect(R);\n }\n}\n\n// Serialize module into byte code. See \"include/loader/loader.h\".\nExpect> Loader::serializeModule(const AST::Module &Mod) {\n return Ser.serializeModule(Mod);\n}\n\n} // namespace Loader\n} // namespace WasmEdge\n\n// Path: lib/loader/aot_section.cpp\n// SPDX-License-Identifier: Apache-2.0\n// SPDX-FileCopyrightText: 2019-2022 Second State INC\n\n#include \"loader/aot_section.h\"\n#include \"common/log.h\"\n#include \"system/allocator.h\"\n\n#if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS\nextern \"C\" {\nextern void __register_frame(void *);\nextern void __deregister_frame(void *);\n}\n#endif\n\nnamespace {\n\ninline constexpr uint64_t roundDownPageBoundary(const uint64_t Value) {\n// ARM64 Mac has a special page size\n#if WASMEDGE_OS_MACOS && defined(__aarch64__)\n return Value & ~UINT64_C(16383);\n#else\n return Value & ~UINT64_C(4095);\n#endif\n}\ninline constexpr uint64_t roundUpPageBoundary(const uint64_t Value) {\n// ARM64 Mac has a special page size\n#if WASMEDGE_OS_MACOS && defined(__aarch64__)\n return roundDownPageBoundary(Value + UINT64_C(16383));\n#else\n return roundDownPageBoundary(Value + UINT64_C(4095));\n#endif\n}\n} // namespace\n\nnamespace WasmEdge::Loader {\n\nExpect AOTSection::load(const AST::AOTSection &AOTSec) noexcept {\n BinarySize = 0;\n for (const auto &Section : AOTSec.getSections()) {\n const auto Offset = std::get<1>(Section);\n const auto Size = std::get<2>(Section);\n BinarySize = std::max(BinarySize, Offset + Size);\n }\n BinarySize = roundUpPageBoundary(BinarySize);\n\n Binary = Allocator::allocate_chunk(BinarySize);\n if (unlikely(!Binary)) {\n spdlog::error(ErrCode::Value::MemoryOutOfBounds);\n return Unexpect(ErrCode::Value::MemoryOutOfBounds);\n }\n\n std::vector> ExecutableRanges;\n for (const auto &Section : AOTSec.getSections()) {\n const auto Offset = std::get<1>(Section);\n const auto Size = std::get<2>(Section);\n const auto &Content = std::get<3>(Section);\n if (Size > BinarySize || Offset > BinarySize ||\n Offset + Size > BinarySize || Content.size() > Size) {\n return Unexpect(ErrCode::Value::IntegerTooLarge);\n }\n std::copy(Content.begin(), Content.end(), Binary + Offset);\n switch (std::get<0>(Section)) {\n case 1: { // Text\n const auto O = roundDownPageBoundary(Offset);\n const auto S = roundUpPageBoundary(Size + (Offset - O));\n ExecutableRanges.emplace_back(Binary + O, S);\n break;\n }\n case 2: // Data\n break;\n case 3: // BSS\n break;\n#if WASMEDGE_OS_LINUX\n case 4: // EHFrame\n EHFrameAddress = reinterpret_cast(Binary + Offset);\n break;\n#elif WASMEDGE_OS_MACOS\n case 4: // EHFrame\n EHFrameAddress = reinterpret_cast(Binary + Offset);\n EHFrameSize = Size;\n break;\n#elif WASMEDGE_OS_WINDOWS\n case 4: // PData\n PDataAddress = reinterpret_cast(Binary + Offset);\n PDataSize =\n...\n", "needle_token_start": 15392, "needle_token_end": 15495, "code_context_ntokens": 16235, "output": ["```cpp\ninline constexpr uint64_t roundDownPageBoundary(const uint64_t Value) {\n // ARM64 Mac has a special page size\n#if WASMEDGE_OS_MACOS && defined(__aarch64__)\n return Value & ~UINT64_C(16383);\n#else\n return Value & ~UINT64_C(4095);\n#endif\n}\n```"]} +{"repo": "oatpp/oatpp", "name": "Pool", "language": "cpp", "path": "src/oatpp/core/provider/Pool.hpp", 
"position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this class is to manage a pool of resources, ensuring efficient resource allocation and deallocation, maintaining a limit on the maximum number of resources, and handling resource lifecycle based on a time-to-live (TTL) value.\n2. **Input**: The inputs include a provider for the resources, a maximum number of resources allowed in the pool, a maximum TTL for each resource, and an optional timeout for resource acquisition.\n3. **Output**: This class does not directly produce a conventional output but manages the lifecycle and availability of resources. It provides mechanisms to acquire and release resources, and to query the current count of active resources.\n4. **Procedure**: The procedure involves initializing the resource pool with specified parameters, managing resource allocation requests through an acquisition proxy, and handling errors during resource acquisition. It also includes a cleanup task to remove expired resources and a stop mechanism to clean up all resources when the pool is no longer needed.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " }\n\n async::Action handleError(oatpp::async::Error* error) override {\n {\n /* Careful!!! Using non-async lock */\n std::lock_guard guard(m_pool->m_lock);\n -- m_pool->m_counter;\n }\n return error;\n }\n\n };\n\n return GetCoroutine::startForResult(_this);\n\n }\n\npublic:\n\n static std::shared_ptr createShared(const std::shared_ptr>& provider,\n v_int64 maxResources,\n const std::chrono::duration& maxResourceTTL,\n const std::chrono::duration& timeout)\n {\n /* \"new\" is called directly to keep constructor private */\n auto ptr = std::shared_ptr(new PoolTemplate(provider, maxResources, maxResourceTTL.count(), timeout));\n startCleanupTask(ptr);\n return ptr;\n }\n\n virtual ~PoolTemplate() override {\n stop();\n }\n\n void stop() {\n\n {\n std::lock_guard guard(m_lock);\n\n if (!m_running) {\n return;\n }\n\n m_running = false;\n m_counter -= static_cast(m_bench.size());\n m_bench.clear();\n }\n\n m_condition.notify_all();\n m_waitList.notifyAll();\n\n {\n std::unique_lock guard(m_lock);\n while (!m_finished) {\n m_condition.wait(guard);\n }\n\n }\n\n m_provider->stop();\n\n }\n\n v_int64 getCounter() {\n std::lock_guard guard(m_lock);\n return m_counter;\n }\n\n};\n\n/**\n * Pool template class.\n * @tparam TProvider - base class for pool to inherit, ex.: ServerConnectionProvider.\n * @tparam TResource - abstract resource interface type, Ex.: `IOStream`. 
Must be the same as a return-type of Provider.\n * @tparam AcquisitionProxyImpl - implementation of &l:AcquisitionProxy;.\n */\ntemplate\nclass Pool :\n public TProvider,\n public std::enable_shared_from_this>,\n public PoolTemplate {\nprivate:\n typedef PoolTemplate TPool;\nprotected:\n\n /*\n * Protected Constructor.\n * @param provider\n * @param maxResources\n * @param maxResourceTTL\n * @param timeout\n */\n \nPool(const std::shared_ptr& provider,\n v_int64 maxResources,\n v_int64 maxResourceTTL,\n const std::chrono::duration& timeout = std::chrono::microseconds::zero()\n )\n : PoolTemplate(provider, maxResources, maxResourceTTL, timeout)\n {\n TProvider::m_properties = provider->getProperties();\n }\n\npublic:\n\n /**\n * Create shared Pool.\n * @param provider - resource provider.\n * @param maxResources - max resource count in the pool.\n * @param maxResourceTTL - max time-to-live for unused resource in the pool.\n * @param timeout - optional timeout on &l:Pool::get (); and &l:Pool::getAsync (); operations.\n * @return - `std::shared_ptr` of `Pool`.\n */\n static std::shared_ptr createShared(const std::shared_ptr& provider,\n v_int64 maxResources,\n const std::chrono::duration& maxResourceTTL,\n const std::chrono::duration& timeout = std::chrono::microseconds::zero())\n {\n /* \"new\" is called directly to keep constructor private */\n auto ptr = std::shared_ptr(new Pool(provider, maxResources, maxResourceTTL.count(), timeout));\n ptr->startCleanupTask(ptr);\n return ptr;\n }\n\n /**\n * Get resource.\n * @return\n */\n provider::ResourceHandle get() override {\n return TPool::get(this->shared_from_this());\n }\n\n /**\n * Get resource asynchronously.\n * @return\n */\n async::CoroutineStarterForResult&> getAsync() override {\n return TPool::getAsync(this->shared_from_this());\n }\n\n /**\n * Stop pool.
\n * *Note: call to stop() may block.*\n */\n void stop() override {\n TPool::stop();\n }\n\n /**\n * Get pool resource count. Both acquired and available.\n * @return\n */\n v_int64 getCounter() {\n return TPool::getCounter();\n }\n\n};\n\n}}\n\n#endif // oatpp_provider_Pool_hpp\n\n// Path: src/oatpp/core/base/ObjectHandle.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_base_ObjectHandle_hpp\n#define oatpp_base_ObjectHandle_hpp\n\n#include \"./Environment.hpp\"\n\nnamespace oatpp { namespace base {\n\ntemplate\nclass ObjectHandle {\nprivate:\n T* m_object;\n std::shared_ptr m_ptr;\npublic:\n\n ObjectHandle(T* object)\n : m_object(object)\n {}\n\n template\n ObjectHandle(const std::shared_ptr& sharedObject)\n : m_object(sharedObject.get())\n , m_ptr(sharedObject)\n {}\n\n std::shared_ptr getPtr() const {\n return m_ptr;\n }\n\n T* get() const {\n return m_object;\n }\n\n T* operator->() const {\n return m_object;\n }\n\n explicit operator bool() const {\n return m_object != nullptr;\n }\n\n};\n\n}}\n\n#endif //oatpp_base_ObjectHandle_hpp\n\n// Path: src/oatpp/core/async/utils/FastQueue.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_async_utils_FastQueue_hpp\n#define oatpp_async_utils_FastQueue_hpp\n\n#include \"oatpp/core/concurrency/SpinLock.hpp\"\n#include \"oatpp/core/base/Environment.hpp\"\n\nnamespace oatpp { namespace async { namespace utils {\n \ntemplate\nclass FastQueue {\npublic:\n \n FastQueue()\n : first(nullptr)\n , last(nullptr)\n , count(0)\n {}\n \n ~FastQueue(){\n clear();\n }\n\n FastQueue(const FastQueue &) = delete;\n\n FastQueue(FastQueue &&other) noexcept\n : FastQueue() {\n using std::swap;\n swap(first, other.first);\n swap(last, other.last);\n swap(count, other.count);\n }\n\n FastQueue &operator=(const FastQueue 
&) = delete;\n\n FastQueue &operator=(FastQueue &&other) noexcept {\n if (this != std::addressof(other)) {\n using std::swap;\n swap(first, other.first);\n swap(last, other.last);\n swap(count, other.count);\n }\n return *this;\n }\n\n \n T* first;\n T* last;\n v_int32 count{};\n\n v_int32 Count() {\n return count;\n }\n\n bool empty() {\n return count == 0;\n }\n \n void pushFront(T* entry) {\n entry->_ref = first;\n first = entry;\n if(last == nullptr) {\n last = first;\n }\n ++ count;\n }\n \n void pushBack(T* entry) {\n entry->_ref = nullptr;\n if(last == nullptr) {\n first = entry;\n last = entry;\n } else {\n last->_ref = entry;\n last = entry;\n }\n ++ count;\n }\n \n void round(){\n if(count > 1) {\n last->_ref = first;\n last = first;\n first = first->_ref;\n last->_ref = nullptr;\n }\n }\n \n T* popFront() {\n T* result = first;\n first = first->_ref;\n if(first == nullptr) {\n last = nullptr;\n }\n -- count;\n return result;\n }\n \n void popFrontNoData() {\n T* result = first;\n first = first->_ref;\n if(first == nullptr) {\n last = nullptr;\n }\n delete result;\n -- count;\n }\n\n static void moveAll(FastQueue& fromQueue, FastQueue& toQueue) {\n\n if(fromQueue.count > 0) {\n\n if (toQueue.last == nullptr) {\n toQueue.first = fromQueue.first;\n toQueue.last = fromQueue.last;\n } else {\n toQueue.last->_ref = fromQueue.first;\n toQueue.last = fromQueue.last;\n }\n\n toQueue.count += fromQueue.count;\n fromQueue.count = 0;\n\n fromQueue.first = nullptr;\n fromQueue.last = nullptr;\n\n }\n\n }\n\n void cutEntry(T* entry, T* prevEntry){\n\n if(prevEntry == nullptr) {\n popFront();\n } else {\n prevEntry->_ref = entry->_ref;\n -- count;\n if(prevEntry->_ref == nullptr) {\n last = prevEntry;\n }\n }\n\n }\n \n void clear() {\n T* curr = first;\n while (curr != nullptr) {\n T* next = curr->_ref;\n delete curr;\n curr = next;\n }\n first = nullptr;\n last = nullptr;\n count = 0;\n }\n \n};\n \n}}}\n\n#endif /* oatpp_async_utils_FastQueue_hpp */\n\n", "needle_token_start": 734, "needle_token_end": 853, "code_context_ntokens": 3291, "output": ["```cpp\ntemplate\nclass Pool :\n public TProvider,\n public std::enable_shared_from_this>,\n public PoolTemplate {\nprivate:\n typedef PoolTemplate TPool;\nprotected:\n\n /*\n * Protected Constructor.\n * @param provider\n * @param maxResources\n * @param maxResourceTTL\n * @param timeout\n */\n \nPool(const std::shared_ptr& provider,\n v_int64 maxResources,\n v_int64 maxResourceTTL,\n const std::chrono::duration& timeout = std::chrono::microseconds::zero()\n )\n : PoolTemplate(provider, maxResources, maxResourceTTL, timeout)\n {\n TProvider::m_properties = provider->getProperties();\n }\n\npublic:\n\n /**\n * Create shared Pool.\n * @param provider - resource provider.\n * @param maxResources - max resource count in the pool.\n * @param maxResourceTTL - max time-to-live for unused resource in the pool.\n * @param timeout - optional timeout on &l:Pool::get (); and &l:Pool::getAsync (); operations.\n * @return - `std::shared_ptr` of `Pool`.\n */\n static std::shared_ptr createShared(const std::shared_ptr& provider,\n v_int64 maxResources,\n const std::chrono::duration& maxResourceTTL,\n const std::chrono::duration& timeout = std::chrono::microseconds::zero())\n {\n /* \"new\" is called directly to keep constructor private */\n auto ptr = std::shared_ptr(new Pool(provider, maxResources, maxResourceTTL.count(), timeout));\n ptr->startCleanupTask(ptr);\n return ptr;\n }\n\n /**\n * Get resource.\n * @return\n */\n provider::ResourceHandle get() 
override {\n return TPool::get(this->shared_from_this());\n }\n\n /**\n * Get resource asynchronously.\n * @return\n */\n async::CoroutineStarterForResult&> getAsync() override {\n return TPool::getAsync(this->shared_from_this());\n }\n\n /**\n * Stop pool.
\n * *Note: call to stop() may block.*\n */\n void stop() override {\n TPool::stop();\n }\n\n /**\n * Get pool resource count. Both acquired and available.\n * @return\n */\n v_int64 getCounter() {\n return TPool::getCounter();\n }\n\n};\n```"]} +{"repo": "oatpp/oatpp", "name": "HandlerCoroutine", "language": "cpp", "path": "src/oatpp/web/server/api/ApiController.hpp", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: This template class is designed to handle HTTP requests in an asynchronous manner within a web server framework, enabling non-blocking operations that can scale to handle high loads efficiently.\n2. **Input**: It takes two parameters: a pointer to a controller object that manages the logic specific to the application, and a shared pointer to the incoming HTTP request.\n3. **Output**: It returns a shared pointer to an outgoing HTTP response after processing the request.\n4. **Procedure**: Upon instantiation, the class initializes with the provided controller and request. It then processes the incoming request asynchronously, leveraging the controller's logic to generate the appropriate HTTP response.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/web/server/api/Endpoint.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Endpoint.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace api {\n\nEndpoint::Info::Param::Param()\n : name(nullptr)\n , type(nullptr)\n{}\n\nEndpoint::Info::Param::Param(const oatpp::String& pName,\n...\n// Path: src/oatpp/web/server/api/ApiController.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n 
***************************************************************************/\n\n#ifndef oatpp_web_server_api_Controller_hpp\n#define oatpp_web_server_api_Controller_hpp\n\n#include \"./Endpoint.hpp\"\n\n#include \"oatpp/web/server/handler/AuthorizationHandler.hpp\"\n#include \"oatpp/web/server/handler/ErrorHandler.hpp\"\n#include \"oatpp/web/server/handler/AuthorizationHandler.hpp\"\n#include \"oatpp/web/protocol/http/incoming/Response.hpp\"\n#include \"oatpp/web/protocol/http/outgoing/Request.hpp\"\n#include \"oatpp/web/protocol/http/outgoing/ResponseFactory.hpp\"\n\n#include \"oatpp/core/utils/ConversionUtils.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace web { namespace server { namespace api {\n\n/**\n * Class responsible for implementation and management of endpoints.
\n * For details see [ApiController](https://oatpp.io/docs/components/api-controller/).\n */\nclass ApiController : public oatpp::base::Countable {\nprotected:\n typedef ApiController __ControllerType;\npublic:\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::ResponseFactory;.\n */\n typedef oatpp::web::protocol::http::outgoing::ResponseFactory ResponseFactory;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Request;.\n */\n typedef oatpp::web::protocol::http::incoming::Request IncomingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Request;.\n */\n typedef oatpp::web::protocol::http::outgoing::Request OutgoingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Response;.\n */\n typedef oatpp::web::protocol::http::incoming::Response IncomingResponse;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n typedef oatpp::web::protocol::http::outgoing::Response OutgoingResponse;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Status;.\n */\n typedef oatpp::web::protocol::http::Status Status;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Header;.\n */\n typedef oatpp::web::protocol::http::Header Header;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::QueryParams;.\n */\n typedef oatpp::web::protocol::http::QueryParams QueryParams;\n\n /**\n * Convenience typedef for &id:oatpp::web::server::HttpRequestHandler;.\n */\n typedef oatpp::web::server::HttpRequestHandler RequestHandler;\n\n /**\n * Convenience typedef for &id:oatpp::web::server::handler::AuthorizationHandler;.\n */\n typedef oatpp::web::server::handler::AuthorizationHandler AuthorizationHandler;\n \npublic:\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::ObjectMapper;.\n */\n typedef oatpp::data::mapping::ObjectMapper ObjectMapper;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::String;.\n */\n typedef oatpp::String String;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int8;.\n */\n typedef oatpp::Int8 Int8;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt8;.\n */\n typedef oatpp::UInt8 UInt8;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int16;.\n */\n typedef oatpp::Int16 Int16;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt16;.\n */\n typedef oatpp::UInt16 UInt16;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int32;.\n */\n typedef oatpp::Int32 Int32;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt32;.\n */\n typedef oatpp::UInt32 UInt32;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int64;.\n */\n typedef oatpp::Int64 Int64;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt64;.\n */\n typedef oatpp::UInt64 UInt64;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Float32;.\n */\n typedef oatpp::Float32 Float32;\n\n /**\n * Convenience typedef for &id:atpp::data::mapping::type::Float64;.\n */\n typedef oatpp::Float64 Float64;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Boolean;.\n */\n typedef oatpp::Boolean Boolean;\n\n /*\n * Convenience typedef for std::function()>.\n */\n typedef std::function()> EndpointInfoBuilder;\n\n template \n using Object = oatpp::Object;\n\n template \n using List = oatpp::List;\n\n template \n using Fields = 
oatpp::Fields;\n\n template \n using Enum = oatpp::data::mapping::type::Enum;\n\nprotected:\n \n /*\n * Endpoint Coroutine base class\n */\n template\n class HandlerCoroutine : public oatpp::async::CoroutineWithResult&> {\n public:\n \n \nHandlerCoroutine(ControllerT* pController, const std::shared_ptr& pRequest)\n : controller(pController)\n , request(pRequest)\n {}\n \n ControllerT* const controller;\n std::shared_ptr request;\n \n };\n \n /*\n * Handler which subscribes to specific URL in Router and delegates calls endpoints \n */\n template\n class Handler : public RequestHandler {\n public:\n typedef std::shared_ptr (T::*Method)(const std::shared_ptr&);\n typedef oatpp::async::CoroutineStarterForResult&>\n (T::*MethodAsync)(const std::shared_ptr&);\n\n private:\n\n class ErrorHandlingCoroutine : public oatpp::async::CoroutineWithResult&> {\n private:\n Handler* m_handler;\n std::shared_ptr m_request;\n public:\n\n ErrorHandlingCoroutine(Handler* handler, const std::shared_ptr& request)\n : m_handler(handler)\n , m_request(request)\n {}\n\n async::Action act() override {\n return (m_handler->m_controller->*m_handler->m_methodAsync)(m_request)\n .callbackTo(&ErrorHandlingCoroutine::onResponse);\n }\n\n async::Action onResponse(const std::shared_ptr& response) {\n return this->_return(response);\n }\n\n async::Action handleError(async::Error* error) override {\n auto eptr = std::make_exception_ptr(*error);\n auto response = m_handler->m_controller->m_errorHandler->handleError(eptr);\n return this->_return(response);\n }\n\n };\n\n private:\n T* m_controller;\n Method m_method;\n MethodAsync m_methodAsync;\n public:\n Handler(T* controller, Method method, MethodAsync methodAsync)\n : m_controller(controller)\n , m_method(method)\n , m_methodAsync(methodAsync)\n {}\n public:\n \n static std::shared_ptr createShared(T* controller, Method method, MethodAsync methodAsync){\n return std::make_shared(controller, method, methodAsync);\n }\n \n std::shared_ptr handle(const std::shared_ptr& request) override {\n\n if(m_method == nullptr) {\n if(m_methodAsync == nullptr) {\n throw protocol::http::HttpError(Status::CODE_500, \"[ApiController]: Error. Handler method is nullptr.\");\n }\n throw protocol::http::HttpError(Status::CODE_500, \"[ApiController]: Error. Non-async call to async endpoint.\");\n }\n\n try {\n return (m_controller->*m_method)(request);\n } catch (...) {\n auto response = m_controller->handleError(std::current_exception());\n if(response != nullptr) {\n return response;\n }\n\n throw;\n }\n \n }\n \n oatpp::async::CoroutineStarterForResult&>\n handleAsync(const std::shared_ptr& request) override {\n\n if(m_methodAsync == nullptr) {\n if(m_method == nullptr) {\n throw oatpp::web::protocol::http::HttpError(Status::CODE_500, \"[ApiController]: Error. Handler method is nullptr.\");\n }\n throw oatpp::web::protocol::http::HttpError(Status::CODE_500, \"[ApiController]: Error. 
Async call to non-async endpoint.\");\n }\n\n if(m_controller->m_errorHandler) {\n return ErrorHandlingCoroutine::startForResult(this, request);\n }\n\n return (m_controller->*m_methodAsync)(request);\n\n }\n\n Method setMethod(Method method) {\n auto prev = m_method;\n m_method = method;\n return prev;\n }\n\n Method getMethod() {\n return m_method;\n }\n\n MethodAsync setMethodAsync(MethodAsync methodAsync) {\n auto prev = m_methodAsync;\n m_methodAsync = methodAsync;\n return prev;\n }\n\n MethodAsync getMethodAsync() {\n return m_methodAsync;\n }\n \n };\n\nprotected:\n\n /*\n * Set endpoint info by endpoint name. (Endpoint name is the 'NAME' parameter of the ENDPOINT macro)\n * Info should be set before call to addEndpointsToRouter();\n */\n void setEndpointInfo(const std::string& endpointName, const std::shared_ptr& info);\n\n /*\n * Get endpoint info by endpoint name. (Endpoint name is the 'NAME' parameter of the ENDPOINT macro)\n */\n std::shared_ptr getEndpointInfo(const std::string& endpointName);\n\n /*\n * Set endpoint Request handler.\n * @param endpointName\n * @param handler\n */\n void setEndpointHandler(const std::string& endpointName, const std::shared_ptr& handler);\n\n /*\n * Get endpoint Request handler.\n * @param endpointName\n * @return\n */\n std::shared_ptr getEndpointHandler(const std::string& endpointName);\n \nprotected:\n Endpoints m_endpoints;\n std::shared_ptr m_errorHandler;\n std::shared_ptr m_defaultAuthorizationHandler;\n std::shared_ptr m_defaultObjectMapper;\n std::unordered_map> m_endpointInfo;\n std::unordered_map> m_endpointHandlers;\n const oatpp::String m_routerPrefix;\npublic:\n ApiController(const std::shared_ptr& defaultObjectMapper, const oatpp::String &routerPrefix = nullptr)\n : m_defaultObjectMapper(defaultObjectMapper)\n , m_routerPrefix(routerPrefix)\n {\n\n }\npublic:\n \n template\n static std::shared_ptr createEndpoint(Endpoints& endpoints,\n const std::shared_ptr>& handler,\n const EndpointInfoBuilder& infoBuilder)\n {\n auto endpoint = Endpoint::createShared(handler, infoBuilder);\n endpoints.append(endpoint);\n return endpoint;\n }\n \n /**\n * Get list of Endpoints created via ENDPOINT macro\n */\n const Endpoints& getEndpoints();\n\n /**\n * Set error handler to handle errors that occur during the endpoint's execution\n */\n void setErrorHandler(const std::shared_ptr& errorHandler);\n\n /**\n * Handle the exception using the registered ErrorHandler or if no handler has been set, uses the DefaultErrorHandler::handleError\n * @note Does not rethrow an exception anymore, OutgoingResponse has to be returned by the caller!\n * @note If this handler fails to handle the exception, it will be handled by the connection handlers ErrorHandler.\n */\n std::shared_ptr handleError(const std::exception_ptr& exceptionPtr) const;\n\n /**\n * [under discussion]\n * Set authorization handler to handle calls to handleAuthorization.\n * Must be called before controller is added to a router or swagger-doc if an endpoint uses the AUTHORIZATION macro\n */\n void setDefaultAuthorizationHandler(const std::shared_ptr& authorizationHandler);\n\n /**\n * Get authorization handler.\n * @return\n */\n std::shared_ptr getDefaultAuthorizationHandler();\n\n /**\n * [under discussion]\n * Do not use it directly. 
This method is under discussion.\n * Currently returns AuthorizationObject created by AuthorizationHandler or return DefaultAuthorizationObject by DefaultAuthorizationHandler if AuthorizationHandler is null\n */\n std::shared_ptr handleDefaultAuthorization(const String &authHeader) const;\n \n const std::shared_ptr& getDefaultObjectMapper() const;\n \n // Helper methods\n \n std::shared_ptr createResponse(const Status& status,\n const oatpp::String& str) const;\n \n std::shared_ptr createResponse(const Status& status) const;\n\n std::shared_ptr createDtoResponse(const Status& status,\n const oatpp::Void& dto,\n const std::shared_ptr& objectMapper) const;\n \n std::shared_ptr createDtoResponse(const Status& status,\n const oatpp::Void& dto) const;\n\npublic:\n\n template\n struct TypeInterpretation {\n\n static T fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) text;\n success = false;\n OATPP_LOGE(\"[oatpp::web::server::api::ApiController::TypeInterpretation::fromString()]\",\n \"Error. No conversion from '%s' to '%s' is defined.\", \"oatpp::String\", typeName->c_str())\n throw std::runtime_error(\"[oatpp::web::server::api::ApiController::TypeInterpretation::fromString()]: Error. \"\n \"No conversion from 'oatpp::String' to '\" + *typeName + \"' is defined. \"\n \"Please define type conversion.\");\n }\n\n };\n\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::String fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n success = true;\n return text;\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::Int8 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n //TODO: check the range and perhaps throw an exception if the variable doesn't fit\n return static_cast(utils::conversion::strToInt32(text, success));\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::UInt8 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n //TODO: check the range and perhaps throw an exception if the variable doesn't fit\n return static_cast(utils::conversion::strToUInt32(text, success));\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::Int16 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n //TODO: check the range and perhaps throw an exception if the variable doesn't fit\n return static_cast(utils::conversion::strToInt32(text, success));\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::UInt16 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n //TODO: check the range and perhaps throw an exception if the variable doesn't fit\n return static_cast(utils::conversion::strToUInt32(text, success));\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::Int32 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToInt32(text, success);\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::UInt32 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToUInt32(text, success);\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n 
static oatpp::Int64 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToInt64(text, success);\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::UInt64 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToUInt64(text, success);\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::Float32 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToFloat32(text, success);\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::Float64 fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToFloat64(text, success);\n }\n};\n\ntemplate<>\nstruct ApiController::TypeInterpretation {\n static oatpp::Boolean fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n (void) typeName;\n return utils::conversion::strToBool(text, success);\n }\n};\n\ntemplate\nstruct ApiController::TypeInterpretation > {\n\n typedef data::mapping::type::EnumObjectWrapper EnumOW;\n typedef typename I::UnderlyingTypeObjectWrapper UTOW;\n\n static EnumOW fromString(const oatpp::String& typeName, const oatpp::String& text, bool& success) {\n const auto& parsedValue = ApiController::TypeInterpretation::fromString(typeName, text, success);\n if(success) {\n data::mapping::type::EnumInterpreterError error = data::mapping::type::EnumInterpreterError::OK;\n const auto& result = I::fromInterpretation(parsedValue, error);\n if(error == data::mapping::type::EnumInterpreterError::OK) {\n return result.template cast();\n }\n success = false;\n }\n return nullptr;\n }\n\n};\n\n}}}}\n\n#endif /* oatpp_web_server_api_Controller_hpp */\n\n// Path: src/oatpp/web/server/api/ApiController.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"ApiController.hpp\"\n\n#include \"oatpp/web/server/handler/ErrorHandler.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace api {\n\nconst Endpoints& ApiController::getEndpoints() {\n return m_endpoints;\n}\n\nvoid ApiController::setEndpointInfo(const std::string& endpointName, const std::shared_ptr& info){\n m_endpointInfo[endpointName] = info;\n}\n\nstd::shared_ptr ApiController::getEndpointInfo(const std::string& endpointName) {\n return m_endpointInfo[endpointName];\n}\n\nvoid ApiController::setEndpointHandler(const std::string& endpointName, const std::shared_ptr& handler) 
{\n m_endpointHandlers[endpointName] = handler;\n}\n\nstd::shared_ptr ApiController::getEndpointHandler(const std::string& endpointName) {\n return m_endpointHandlers[endpointName];\n}\n\nvoid ApiController::setErrorHandler(const std::shared_ptr& errorHandler){\n m_errorHandler = errorHandler;\n if(!m_errorHandler) {\n m_errorHandler = handler::DefaultErrorHandler::createShared();\n }\n}\n\nstd::shared_ptr ApiController::handleError(const std::exception_ptr& exceptionPtr) const {\n\n if(m_errorHandler) {\n return m_errorHandler->handleError(exceptionPtr);\n }\n\n if(exceptionPtr) {\n std::rethrow_exception(exceptionPtr);\n }\n\n throw std::runtime_error(\"[oatpp::web::server::api::ApiController::handleError()]: Error. 'exceptionPtr' is not set.\");\n\n}\n\nvoid ApiController::setDefaultAuthorizationHandler(const std::shared_ptr& authorizationHandler){\n m_defaultAuthorizationHandler = authorizationHandler;\n}\n\nstd::shared_ptr ApiController::getDefaultAuthorizationHandler() {\n return m_defaultAuthorizationHandler;\n}\n\nstd::shared_ptr ApiController::handleDefaultAuthorization(const String &authHeader) const {\n if(m_defaultAuthorizationHandler) {\n return m_defaultAuthorizationHandler->handleAuthorization(authHeader);\n }\n // If Authorization is not setup on the server then it's 500\n throw oatpp::web::protocol::http::HttpError(Status::CODE_500, \"Authorization is not setup.\");\n}\n\nconst std::shared_ptr& ApiController::getDefaultObjectMapper() const {\n return m_defaultObjectMapper;\n}\n\n// Helper methods\n\nstd::shared_ptr ApiController::createResponse(const Status& status,\n const oatpp::String& str) const {\n return ResponseFactory::createResponse(status, str);\n}\n\nstd::shared_ptr ApiController::createResponse(const ApiController::Status &status) const {\n return ResponseFactory::createResponse(status);\n}\n\nstd::shared_ptr ApiController::createDtoResponse(const Status& status,\n const oatpp::Void& dto,\n const std::shared_ptr& objectMapper) const {\n return ResponseFactory::createResponse(status, dto, objectMapper);\n}\n\nstd::shared_ptr ApiController::createDtoResponse(const Status& status,\n const oatpp::Void& dto) const {\n return ResponseFactory::createResponse(status, dto, m_defaultObjectMapper);\n}\n\n}}}}\n\n// Path: src/oatpp/web/protocol/CommunicationError.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_protocol_CommunicationError_hpp\n#define oatpp_web_protocol_CommunicationError_hpp\n\n#include \"oatpp/core/IODefinitions.hpp\"\n\nnamespace oatpp { namespace web { namespace protocol {\n\n/**\n * Communication Error\n */\nclass CommunicationError : public std::runtime_error {\nprivate:\n 
oatpp::v_io_size m_ioStatus;\n oatpp::String m_message;\npublic:\n\n /**\n * Constructor.\n * @param ioStatus - I/O error. See &id:oatpp::v_io_size;.\n * @param message - error message.\n */\n CommunicationError(oatpp::v_io_size ioStatus, const oatpp::String& message);\n\n /**\n * Get I/O error. See &id:oatpp::v_io_size;.\n * @return &id:oatpp::v_io_size;.\n */\n oatpp::v_io_size getIOStatus();\n\n /**\n * Get error message.\n * @return - error message.\n */\n oatpp::String& getMessage();\n \n};\n\n/**\n * Protocol Error Info.\n */\ntemplate\nstruct ProtocolErrorInfo {\n\n /**\n * Default Constructor.\n */\n ProtocolErrorInfo()\n : ioStatus(0)\n {}\n\n /**\n * Constructor.\n * @param pIOStatus - I/O level error. See &id:oatpp::v_io_size;.\n * @param pStatus - configurable arbitrary data type.\n */\n ProtocolErrorInfo(oatpp::v_io_size pIOStatus, const Status& pStatus)\n : ioStatus(pIOStatus)\n , status(pStatus)\n {}\n\n /**\n * Get I/O level error.\n */\n oatpp::v_io_size ioStatus;\n\n /**\n * Configurable arbitrary data type.\n */\n Status status;\n\n};\n\n/**\n * Protocol Error template.\n * @tparam Status - arbitrary data type.\n */\ntemplate\nclass ProtocolError : public CommunicationError {\npublic:\n /**\n * Cenvenience typedef for ProtocolErrorInfo\n */\n typedef ProtocolErrorInfo Info;\nprivate:\n Info m_info;\npublic:\n\n /**\n * Constructor.\n * @param info - &l:ProtocolError::Info ;.\n * @param message - error message.\n */\n ProtocolError(const Info& info, const oatpp::String& message)\n : CommunicationError(info.ioStatus, message)\n , m_info(info)\n {}\n\n /**\n * Get error info.\n * @return - error info.\n */\n Info getInfo() {\n return m_info;\n }\n \n};\n\n\n/**\n * Protocol Error template.\n * @tparam Status - arbitrary data type.\n */\ntemplate\nclass AsyncProtocolError : public oatpp::AsyncIOError {\npublic:\n /**\n * Cenvenience typedef for ProtocolErrorInfo\n */\n typedef ProtocolErrorInfo Info;\nprivate:\n Info m_info;\n oatpp::String m_message;\npublic:\n\n /**\n * Constructor.\n * @param info - &l:ProtocolError::Info ;.\n * @param message - error message.\n */\n AsyncProtocolError(const Info& info, const oatpp::String& message)\n : oatpp::AsyncIOError(\"AsyncProtocolError\", info.ioStatus)\n , m_info(info)\n , m_message(message)\n {}\n\n /**\n * Error message.\n * @return - error message.\n */\n oatpp::String getMessage() {\n return m_message;\n }\n\n /**\n * Get error info.\n * @return - error info.\n */\n Info getInfo() {\n return m_info;\n }\n\n};\n \n}}}\n\n#endif /* oatpp_web_protocol_CommunicationError_hpp */\n\n// Path: src/oatpp/web/protocol/CommunicationError.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n 
***************************************************************************/\n\n#include \"CommunicationError.hpp\"\n\nnamespace oatpp { namespace web { namespace protocol {\n \nCommunicationError::CommunicationError(oatpp::v_io_size ioStatus, const oatpp::String& message)\n :std::runtime_error(*message)\n , m_ioStatus(ioStatus)\n , m_message(message)\n{}\n \noatpp::v_io_size CommunicationError::getIOStatus() {\n return m_ioStatus;\n}\n\noatpp::String& CommunicationError::getMessage(){\n return m_message;\n}\n \n}}}\n\n// Path: src/oatpp/web/protocol/http/Http.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_protocol_http_Http_hpp\n#define oatpp_web_protocol_http_Http_hpp\n\n#include \"oatpp/network/tcp/Connection.hpp\"\n\n#include \"oatpp/web/protocol/CommunicationError.hpp\"\n\n#include \"oatpp/core/parser/Caret.hpp\"\n#include \"oatpp/core/data/share/LazyStringMap.hpp\"\n#include \"oatpp/core/Types.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace web { namespace protocol { namespace http {\n\n/**\n * Typedef for headers map. 
Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMultimap;.\n */\ntypedef oatpp::data::share::LazyStringMultimap Headers;\n\n/**\n * Typedef for query parameters map.\n * For more info see &id:oatpp::data::share::LazyStringMultimap;.\n */\ntypedef oatpp::data::share::LazyStringMultimap QueryParams;\n\n/**\n * Http status.\n */\nclass Status{\npublic:\n\n /**\n * Continue.\n */\n static const Status CODE_100;// Continue\n\n /**\n * Switching Protocols.\n */\n static const Status CODE_101;// Switching\n\n /**\n * Processing.\n */\n static const Status CODE_102;// Processing\n\n /**\n * OK.\n */\n static const Status CODE_200;// OK\n\n /**\n * Created.\n */\n static const Status CODE_201;// Created\n\n /**\n * Accepted.\n */\n static const Status CODE_202;// Accepted\n\n /**\n * Non-Authoritative Information.\n */\n static const Status CODE_203;// Non-Authoritative Information\n\n /**\n * No Content.\n */\n static const Status CODE_204;// No Content\n\n /**\n * Reset Content.\n */\n static const Status CODE_205;// Reset Content\n\n /**\n * Partial Content.\n */\n static const Status CODE_206;// Partial Content\n\n /**\n * Multi-Status.\n */\n static const Status CODE_207;// Multi-Status\n\n /**\n * IM Used.\n */\n static const Status CODE_226;// IM Used\n\n /**\n * Multiple Choices.\n */\n static const Status CODE_300;// Multiple Choices\n\n /**\n * Moved Permanently.\n */\n static const Status CODE_301;// Moved Permanently\n\n /**\n * Moved Temporarily.\n */\n static const Status CODE_302;// Moved Temporarily\n\n /**\n * See Other.\n */\n static const Status CODE_303;// See Other\n\n /**\n * Not Modified.\n */\n static const Status CODE_304;// Not Modified\n\n /**\n * Use Proxy.\n */\n static const Status CODE_305;// Use Proxy\n\n /**\n * Reserved.\n */\n static const Status CODE_306;// Reserved\n\n /**\n * Temporary Redirect.\n */\n static const Status CODE_307;// Temporary Redirect\n\n /**\n * Bad Request.\n */\n static const Status CODE_400;// Bad Request\n\n /**\n * Unauthorized.\n */\n static const Status CODE_401;// Unauthorized\n\n /**\n * Payment Required.\n */\n static const Status CODE_402;// Payment Required\n\n /**\n * Forbidden.\n */\n static const Status CODE_403;// Forbidden\n\n /**\n * Not Found.\n */\n static const Status CODE_404;// Not Found\n\n /**\n * Method Not Allowed.\n */\n static const Status CODE_405;// Method Not Allowed\n\n /**\n * Not Acceptable.\n */\n static const Status CODE_406;// Not Acceptable\n\n /**\n * Proxy Authentication Required.\n */\n static const Status CODE_407;// Proxy Authentication Required\n\n /**\n * Request Timeout.\n */\n static const Status CODE_408;// Request Timeout\n\n /**\n * Conflict.\n */\n static const Status CODE_409;// Conflict\n\n /**\n * Gone\n */\n static const Status CODE_410;// Gone\n\n /**\n * Length Required.\n */\n static const Status CODE_411;// Length Required\n\n /**\n * Precondition Failed.\n */\n static const Status CODE_412;// Precondition Failed\n\n /**\n * Request Entity Too Large.\n */\n static const Status CODE_413;// Request Entity Too Large\n\n /**\n * Request-URI Too Large.\n */\n static const Status CODE_414;// Request-URI Too Large\n\n /**\n * Unsupported Media Type.\n */\n static const Status CODE_415;// Unsupported Media Type\n\n /**\n * Requested Range Not Satisfiable.\n */\n static const Status CODE_416;// Requested Range Not Satisfiable\n\n /**\n * Expectation Failed.\n */\n static const Status CODE_417;// Expectation Failed\n\n /**\n * I'm a Teapot (rfc7168 
2.3.3)\n */\n static const Status CODE_418;// I'm a teapot\n\n /**\n * Unprocessable Entity.\n */\n static const Status CODE_422;// Unprocessable Entity\n\n /**\n * Locked.\n */\n static const Status CODE_423;// Locked\n\n /**\n * Failed Dependency.\n */\n static const Status CODE_424;// Failed Dependency\n\n /**\n * Unordered Collection.\n */\n static const Status CODE_425;// Unordered Collection\n\n /**\n * Upgrade Required.\n */\n static const Status CODE_426;// Upgrade Required\n\n /**\n * Precondition Required.\n */\n static const Status CODE_428;// Precondition Required\n\n /**\n * Too Many Requests.\n */\n static const Status CODE_429;// Too Many Requests\n\n /**\n * Request Header Fields Too Large.\n */\n static const Status CODE_431;// Request Header Fields Too Large\n\n /**\n * Requested host unavailable.\n */\n static const Status CODE_434;// Requested host unavailable\n\n /**\n * Close connection withot sending headers.\n */\n static const Status CODE_444;// Close connection withot sending headers\n\n /**\n * Retry With.\n */\n static const Status CODE_449;// Retry With\n\n /**\n * Unavailable For Legal Reasons.\n */\n static const Status CODE_451;// Unavailable For Legal Reasons\n\n /**\n * Internal Server Error.\n */\n static const Status CODE_500;// Internal Server Error\n\n /**\n * Not Implemented.\n */\n static const Status CODE_501;// Not Implemented\n\n /**\n * Bad Gateway.\n */\n static const Status CODE_502;// Bad Gateway\n\n /**\n * Service Unavailable.\n */\n static const Status CODE_503;// Service Unavailable\n\n /**\n * Gateway Timeout.\n */\n static const Status CODE_504;// Gateway Timeout\n\n /**\n * HTTP Version Not Supported.\n */\n static const Status CODE_505;// HTTP Version Not Supported\n\n /**\n * Variant Also Negotiates.\n */\n static const Status CODE_506;// Variant Also Negotiates\n\n /**\n * Insufficient Storage.\n */\n static const Status CODE_507;// Insufficient Storage\n\n /**\n * Loop Detected.\n */\n static const Status CODE_508;// Loop Detected\n\n /**\n * Bandwidth Limit Exceeded.\n */\n static const Status CODE_509;// Bandwidth Limit Exceeded\n\n /**\n * Not Extended.\n */\n static const Status CODE_510;// Not Extended\n\n /**\n * Network Authentication Required.\n */\n static const Status CODE_511;// Network Authentication Required\n\n /**\n * Constructor.\n */\n Status()\n : code(0)\n , description(nullptr)\n {}\n\n /**\n * Constructor.\n * @param pCode - status code.\n * @param pDesc - description.\n */\n Status(v_int32 pCode, const char* pDesc)\n : code(pCode)\n , description(pDesc)\n {}\n\n /**\n * Status code.\n */\n v_int32 code;\n\n /**\n * Description.\n */\n const char* description;\n \n bool operator==(const Status& other) const {\n return this->code == other.code;\n }\n \n bool operator!=(const Status& other) const {\n return this->code != other.code;\n }\n \n};\n\n/**\n * HttpError extends &id:oatpp::web::protocol::ProtocolError;<&l:Status;>.\n */\nclass HttpError : public protocol::ProtocolError {\nprivate:\n Headers m_headers;\npublic:\n\n /**\n * Constructor.\n * @param info\n * @param message\n */\n HttpError(const Info& info, const oatpp::String& message)\n : protocol::ProtocolError(info, message)\n {}\n\n /**\n * Constructor.\n * @param status\n * @param message\n */\n HttpError(const Status& status, const oatpp::String& message)\n : protocol::ProtocolError(Info(0, status), message)\n {}\n\n /**\n * Constructor.\n * @param status\n * @param message\n * @param headers\n */\n HttpError(const Status& status, const oatpp::String& 
message, const Headers& headers)\n : protocol::ProtocolError(Info(0, status), message)\n , m_headers(headers)\n {}\n\n /**\n * Get headers\n * @return\n */\n const Headers& getHeaders() const {\n return m_headers;\n }\n \n};\n\n/**\n * Throw &l:HttpError; if assertion failed.\n * @param COND - boolean statement. If evaluates to false - throw error.\n * @param STATUS - &l:Status;.\n * @param MESSAGE - String message.\n */\n#define OATPP_ASSERT_HTTP(COND, STATUS, MESSAGE) \\\nif(!(COND)) { throw oatpp::web::protocol::http::HttpError(STATUS, MESSAGE); }\n\n/**\n * Collection of HTTP Header constants.\n */\nclass Header {\npublic:\n\n /**\n * Possible values for headers.\n */\n class Value {\n public:\n static const char* const CONNECTION_CLOSE;\n static const char* const CONNECTION_KEEP_ALIVE;\n static const char* const CONNECTION_UPGRADE;\n \n static const char* const SERVER;\n static const char* const USER_AGENT;\n \n static const char* const TRANSFER_ENCODING_CHUNKED;\n static const char* const CONTENT_TYPE_APPLICATION_JSON;\n\n static const char* const EXPECT_100_CONTINUE;\n };\npublic:\n static const char* const ACCEPT; // \"Accept\"\n static const char* const AUTHORIZATION; // \"Authorization\"\n static const char* const WWW_AUTHENTICATE; // \"WWW-Authenticate\"\n static const char* const CONNECTION; // \"Connection\"\n static const char* const TRANSFER_ENCODING; // \"Transfer-Encoding\"\n static const char* const CONTENT_ENCODING; // \"Content-Encoding\"\n static const char* const CONTENT_LENGTH; // \"Content-Length\"\n static const char* const CONTENT_TYPE; // \"Content-Type\"\n static const char* const CONTENT_RANGE; // \"Content-Range\"\n static const char* const RANGE; // \"Range\"\n static const char* const HOST; // \"Host\"\n static const char* const USER_AGENT; // \"User-Agent\"\n static const char* const SERVER; // \"Server\"\n static const char* const UPGRADE; // \"Upgrade\"\n static const char* const CORS_ORIGIN; // Access-Control-Allow-Origin\n static const char* const CORS_METHODS; // Access-Control-Allow-Methods\n static const char* const CORS_HEADERS; // Access-Control-Allow-Headers\n static const char* const CORS_MAX_AGE; // Access-Control-Max-Age\n static const char* const ACCEPT_ENCODING; // Accept-Encoding\n static const char* const EXPECT; // Expect\n};\n \nclass Range {\npublic:\n static const char* const UNIT_BYTES;\nprivate:\n Range()\n : units(nullptr)\n {}\npublic:\n \n Range(const oatpp::String& pUnits,\n v_int64 pStart,\n v_int64 pEnd)\n : units(pUnits)\n , start(pStart)\n , end(pEnd)\n {}\n \n oatpp::String units;\n v_int64 start;\n v_int64 end;\n \n oatpp::String toString() const;\n \n bool isValid() const {\n return units.get() != nullptr;\n }\n \n static Range parse(oatpp::parser::Caret& caret);\n static Range parse(const oatpp::String& str);\n \n};\n \nclass ContentRange {\npublic:\n static const char* const UNIT_BYTES;\nprivate:\n ContentRange()\n : units(nullptr)\n {}\npublic:\n \n ContentRange(const oatpp::String& pUnits,\n v_int64 pStart,\n v_int64 pEnd,\n v_int64 pSize,\n bool pIsSizeKnown)\n : units(pUnits)\n , start(pStart)\n , end(pEnd)\n , size(pSize)\n , isSizeKnown(pIsSizeKnown)\n {}\n \n oatpp::String units;\n v_int64 start;\n v_int64 end;\n v_int64 size;\n bool isSizeKnown;\n \n oatpp::String toString() const;\n \n bool isValid() const {\n return units.get() != nullptr;\n }\n \n static ContentRange parse(oatpp::parser::Caret& caret);\n static ContentRange parse(const oatpp::String& str);\n \n};\n\n/**\n * Struct representing HTTP request 
starting line.\n * Example request starting line: `GET /path/to/resource/ HTTP/1.1`.\n */\nstruct RequestStartingLine {\n /**\n * Method as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel method; // GET, POST ...\n\n /**\n * Path as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel path;\n\n /**\n * Protocol as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel protocol;\n};\n\n/**\n * Struct representing HTTP response starting line.\n * Example response starting line: `HTTP/1.1 200 OK`.\n */\nstruct ResponseStartingLine {\n /**\n * Protocol as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel protocol;\n\n /**\n * Status code as v_int32.\n */\n v_int32 statusCode;\n\n /**\n * Description as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel description;\n};\n\n/**\n * Data contained in the value of one header.\n */\nstruct HeaderValueData {\n\n /**\n * value tokens.\n */\n std::unordered_set tokens;\n\n /**\n * Title params.\n */\n std::unordered_map titleParams;\n\n /**\n * Get title parm value by key.\n * @param key\n * @return\n */\n oatpp::String getTitleParamValue(const data::share::StringKeyLabelCI& key) const;\n\n};\n\n/**\n * Oatpp Http parser.\n */\nclass Parser {\nprivate:\n static oatpp::data::share::StringKeyLabelCI parseHeaderNameLabel(const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret);\npublic:\n\n /**\n * Parse &l:RequestStartingLine;.\n * @param line - &l:RequestStartingLine;. Values will be set to line's fields.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:RequestStartingLine; fields. See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseRequestStartingLine(RequestStartingLine& line,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse &l:ResponseStartingLine;.\n * @param line - &l:ResponseStartingLine;. Values will be set to line's fields.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:ResponseStartingLine; fields. See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseResponseStartingLine(ResponseStartingLine& line,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse one header line. Example of the header line:\n * `\"Content-Type: application/json\\r\\n\"`.\n * @param headers - &l:Headers; map to put parsed header to.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:Headers; values. See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseOneHeader(Headers& headers,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse HTTP headers to &l:Headers; map.\n * @param headers - &l:Headers; map to put parsed headers to.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:Headers; values. 
See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseHeaders(Headers& headers,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse data that is contained in a one header.\n * @param data - out. parsed data.\n * @param headerValue - header value string.\n * @param separator - subvalues separator.\n */\n static void parseHeaderValueData(HeaderValueData& data, const oatpp::data::share::StringKeyLabel& headerValue, char separator);\n\n};\n\n/**\n * Http utils.\n */\nclass Utils {\npublic:\n\n /**\n * Write headers map to stream.\n * @param headers\n * @param stream\n */\n static void writeHeaders(const Headers& headers, data::stream::ConsistentOutputStream* stream);\n\n};\n \n}}}}\n\nnamespace std {\n \n template<>\n struct hash {\n \n typedef oatpp::web::protocol::http::Status argument_type;\n typedef v_uint64 result_type;\n \n result_type operator()(oatpp::web::protocol::http::Status const& s) const noexcept {\n return static_cast(s.code);\n }\n \n };\n}\n\n#endif /* oatpp_web_protocol_http_Http_hpp */\n\n// Path: src/oatpp/web/protocol/http/Http.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"./Http.hpp\"\n\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n#include \"oatpp/core/utils/ConversionUtils.hpp\"\n\nnamespace oatpp { namespace web { namespace protocol { namespace http {\n \nconst Status Status::CODE_100(100, \"Continue\");\nconst Status Status::CODE_101(101, \"Switching\");\nconst Status Status::CODE_102(102, \"Processing\");\n\n...\n", "needle_token_start": 2342, "needle_token_end": 2387, "code_context_ntokens": 15914, "output": ["```cpp\ntemplate \nclass HandlerCoroutine : public oatpp::async::CoroutineWithResult&> {\npublic:\n HandlerCoroutine(ControllerT* pController, const std::shared_ptr& pRequest)\n : controller(pController)\n , request(pRequest)\n {}\n \n ControllerT* const controller;\n std::shared_ptr request;\n};\n```"]} +{"repo": "oatpp/oatpp", "name": "ProtocolError", "language": "cpp", "path": "src/oatpp/web/protocol/CommunicationError.hpp", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: The purpose is to handle and report errors related to communication protocols by encapsulating both I/O status and a configurable status type, providing detailed context about the error.\n2. **Input**: The input includes an I/O status code indicating the nature of the I/O error and a configurable status type that can hold additional error information specific to the protocol.\n3. 
**Output**: There is no direct output; however, the class provides methods to retrieve the I/O status and the additional status information encapsulated within the error.\n4. **Procedure**: Upon instantiation, the class is initialized with the I/O status and the additional status. It inherits from a base error class, which takes the I/O status and a message as parameters. The class stores these inputs and makes them accessible through methods for further error handling or logging.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/web/server/api/ApiController.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_api_Controller_hpp\n#define oatpp_web_server_api_Controller_hpp\n\n#include \"./Endpoint.hpp\"\n\n#include \"oatpp/web/server/handler/AuthorizationHandler.hpp\"\n#include \"oatpp/web/server/handler/ErrorHandler.hpp\"\n#include \"oatpp/web/server/handler/AuthorizationHandler.hpp\"\n#include \"oatpp/web/protocol/http/incoming/Response.hpp\"\n#include \"oatpp/web/protocol/http/outgoing/Request.hpp\"\n#include \"oatpp/web/protocol/http/outgoing/ResponseFactory.hpp\"\n\n#include \"oatpp/core/utils/ConversionUtils.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace web { namespace server { namespace api {\n\n/**\n * Class responsible for implementation and management of endpoints.
\n * For details see [ApiController](https://oatpp.io/docs/components/api-controller/).\n */\nclass ApiController : public oatpp::base::Countable {\nprotected:\n typedef ApiController __ControllerType;\npublic:\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::ResponseFactory;.\n */\n typedef oatpp::web::protocol::http::outgoing::ResponseFactory ResponseFactory;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Request;.\n */\n typedef oatpp::web::protocol::http::incoming::Request IncomingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Request;.\n */\n typedef oatpp::web::protocol::http::outgoing::Request OutgoingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Response;.\n */\n typedef oatpp::web::protocol::http::incoming::Response IncomingResponse;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n typedef oatpp::web::protocol::http::outgoing::Response OutgoingResponse;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Status;.\n */\n typedef oatpp::web::protocol::http::Status Status;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Header;.\n */\n typedef oatpp::web::protocol::http::Header Header;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::QueryParams;.\n */\n typedef oatpp::web::protocol::http::QueryParams QueryParams;\n\n /**\n * Convenience typedef for &id:oatpp::web::server::HttpRequestHandler;.\n */\n typedef oatpp::web::server::HttpRequestHandler RequestHandler;\n\n /**\n * Convenience typedef for &id:oatpp::web::server::handler::AuthorizationHandler;.\n */\n typedef oatpp::web::server::handler::AuthorizationHandler AuthorizationHandler;\n \npublic:\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::ObjectMapper;.\n */\n typedef oatpp::data::mapping::ObjectMapper ObjectMapper;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::String;.\n */\n typedef oatpp::String String;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int8;.\n */\n typedef oatpp::Int8 Int8;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt8;.\n */\n typedef oatpp::UInt8 UInt8;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int16;.\n */\n typedef oatpp::Int16 Int16;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt16;.\n */\n typedef oatpp::UInt16 UInt16;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int32;.\n */\n typedef oatpp::Int32 Int32;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt32;.\n */\n typedef oatpp::UInt32 UInt32;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Int64;.\n */\n typedef oatpp::Int64 Int64;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::UInt64;.\n */\n typedef oatpp::UInt64 UInt64;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Float32;.\n */\n typedef oatpp::Float32 Float32;\n\n /**\n * Convenience typedef for &id:atpp::data::mapping::type::Float64;.\n */\n typedef oatpp::Float64 Float64;\n\n /**\n * Convenience typedef for &id:oatpp::data::mapping::type::Boolean;.\n...\n// Path: src/oatpp/web/server/api/ApiController.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n 
*\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"ApiController.hpp\"\n\n#include \"oatpp/web/server/handler/ErrorHandler.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace api {\n\nconst Endpoints& ApiController::getEndpoints() {\n return m_endpoints;\n}\n\nvoid ApiController::setEndpointInfo(const std::string& endpointName, const std::shared_ptr& info){\n m_endpointInfo[endpointName] = info;\n}\n\nstd::shared_ptr ApiController::getEndpointInfo(const std::string& endpointName) {\n return m_endpointInfo[endpointName];\n}\n\nvoid ApiController::setEndpointHandler(const std::string& endpointName, const std::shared_ptr& handler) {\n m_endpointHandlers[endpointName] = handler;\n}\n\nstd::shared_ptr ApiController::getEndpointHandler(const std::string& endpointName) {\n return m_endpointHandlers[endpointName];\n}\n\nvoid ApiController::setErrorHandler(const std::shared_ptr& errorHandler){\n m_errorHandler = errorHandler;\n if(!m_errorHandler) {\n m_errorHandler = handler::DefaultErrorHandler::createShared();\n }\n}\n\nstd::shared_ptr ApiController::handleError(const std::exception_ptr& exceptionPtr) const {\n\n if(m_errorHandler) {\n return m_errorHandler->handleError(exceptionPtr);\n }\n\n if(exceptionPtr) {\n std::rethrow_exception(exceptionPtr);\n }\n\n throw std::runtime_error(\"[oatpp::web::server::api::ApiController::handleError()]: Error. 
'exceptionPtr' is not set.\");\n\n}\n\nvoid ApiController::setDefaultAuthorizationHandler(const std::shared_ptr& authorizationHandler){\n m_defaultAuthorizationHandler = authorizationHandler;\n}\n\nstd::shared_ptr ApiController::getDefaultAuthorizationHandler() {\n return m_defaultAuthorizationHandler;\n}\n\nstd::shared_ptr ApiController::handleDefaultAuthorization(const String &authHeader) const {\n if(m_defaultAuthorizationHandler) {\n return m_defaultAuthorizationHandler->handleAuthorization(authHeader);\n }\n // If Authorization is not setup on the server then it's 500\n throw oatpp::web::protocol::http::HttpError(Status::CODE_500, \"Authorization is not setup.\");\n}\n\nconst std::shared_ptr& ApiController::getDefaultObjectMapper() const {\n return m_defaultObjectMapper;\n}\n\n// Helper methods\n\nstd::shared_ptr ApiController::createResponse(const Status& status,\n const oatpp::String& str) const {\n return ResponseFactory::createResponse(status, str);\n}\n\nstd::shared_ptr ApiController::createResponse(const ApiController::Status &status) const {\n return ResponseFactory::createResponse(status);\n}\n\nstd::shared_ptr ApiController::createDtoResponse(const Status& status,\n const oatpp::Void& dto,\n const std::shared_ptr& objectMapper) const {\n return ResponseFactory::createResponse(status, dto, objectMapper);\n}\n\nstd::shared_ptr ApiController::createDtoResponse(const Status& status,\n const oatpp::Void& dto) const {\n return ResponseFactory::createResponse(status, dto, m_defaultObjectMapper);\n}\n\n}}}}\n\n// Path: src/oatpp/web/protocol/CommunicationError.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_protocol_CommunicationError_hpp\n#define oatpp_web_protocol_CommunicationError_hpp\n\n#include \"oatpp/core/IODefinitions.hpp\"\n\nnamespace oatpp { namespace web { namespace protocol {\n\n/**\n * Communication Error\n */\nclass CommunicationError : public std::runtime_error {\nprivate:\n oatpp::v_io_size m_ioStatus;\n oatpp::String m_message;\npublic:\n\n /**\n * Constructor.\n * @param ioStatus - I/O error. See &id:oatpp::v_io_size;.\n * @param message - error message.\n */\n CommunicationError(oatpp::v_io_size ioStatus, const oatpp::String& message);\n\n /**\n * Get I/O error. See &id:oatpp::v_io_size;.\n * @return &id:oatpp::v_io_size;.\n */\n oatpp::v_io_size getIOStatus();\n\n /**\n * Get error message.\n * @return - error message.\n */\n oatpp::String& getMessage();\n \n};\n\n/**\n * Protocol Error Info.\n */\ntemplate\nstruct ProtocolErrorInfo {\n\n /**\n * Default Constructor.\n */\n ProtocolErrorInfo()\n : ioStatus(0)\n {}\n\n /**\n * Constructor.\n * @param pIOStatus - I/O level error. 
See &id:oatpp::v_io_size;.\n * @param pStatus - configurable arbitrary data type.\n */\n ProtocolErrorInfo(oatpp::v_io_size pIOStatus, const Status& pStatus)\n : ioStatus(pIOStatus)\n , status(pStatus)\n {}\n\n /**\n * Get I/O level error.\n */\n oatpp::v_io_size ioStatus;\n\n /**\n * Configurable arbitrary data type.\n */\n Status status;\n\n};\n\n/**\n * Protocol Error template.\n * @tparam Status - arbitrary data type.\n */\ntemplate\nclass ProtocolError : public CommunicationError {\npublic:\n /**\n * Cenvenience typedef for ProtocolErrorInfo\n */\n typedef ProtocolErrorInfo Info;\nprivate:\n Info m_info;\npublic:\n\n /**\n * Constructor.\n * @param info - &l:ProtocolError::Info ;.\n * @param message - error message.\n */\n \nProtocolError(const Info& info, const oatpp::String& message)\n : CommunicationError(info.ioStatus, message)\n , m_info(info)\n {}\n\n /**\n * Get error info.\n * @return - error info.\n */\n Info getInfo() {\n return m_info;\n }\n \n};\n\n\n/**\n * Protocol Error template.\n * @tparam Status - arbitrary data type.\n */\ntemplate\nclass AsyncProtocolError : public oatpp::AsyncIOError {\npublic:\n /**\n * Cenvenience typedef for ProtocolErrorInfo\n */\n typedef ProtocolErrorInfo Info;\nprivate:\n Info m_info;\n oatpp::String m_message;\npublic:\n\n /**\n * Constructor.\n * @param info - &l:ProtocolError::Info ;.\n * @param message - error message.\n */\n AsyncProtocolError(const Info& info, const oatpp::String& message)\n : oatpp::AsyncIOError(\"AsyncProtocolError\", info.ioStatus)\n , m_info(info)\n , m_message(message)\n {}\n\n /**\n * Error message.\n * @return - error message.\n */\n oatpp::String getMessage() {\n return m_message;\n }\n\n /**\n * Get error info.\n * @return - error info.\n */\n Info getInfo() {\n return m_info;\n }\n\n};\n \n}}}\n\n#endif /* oatpp_web_protocol_CommunicationError_hpp */\n\n// Path: src/oatpp/web/protocol/CommunicationError.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"CommunicationError.hpp\"\n\nnamespace oatpp { namespace web { namespace protocol {\n \nCommunicationError::CommunicationError(oatpp::v_io_size ioStatus, const oatpp::String& message)\n :std::runtime_error(*message)\n , m_ioStatus(ioStatus)\n , m_message(message)\n{}\n \noatpp::v_io_size CommunicationError::getIOStatus() {\n return m_ioStatus;\n}\n\noatpp::String& CommunicationError::getMessage(){\n return m_message;\n}\n \n}}}\n\n// Path: src/oatpp/web/protocol/http/Http.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n 
* Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_protocol_http_Http_hpp\n#define oatpp_web_protocol_http_Http_hpp\n\n#include \"oatpp/network/tcp/Connection.hpp\"\n\n#include \"oatpp/web/protocol/CommunicationError.hpp\"\n\n#include \"oatpp/core/parser/Caret.hpp\"\n#include \"oatpp/core/data/share/LazyStringMap.hpp\"\n#include \"oatpp/core/Types.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace web { namespace protocol { namespace http {\n\n/**\n * Typedef for headers map. Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMultimap;.\n */\ntypedef oatpp::data::share::LazyStringMultimap Headers;\n\n/**\n * Typedef for query parameters map.\n * For more info see &id:oatpp::data::share::LazyStringMultimap;.\n */\ntypedef oatpp::data::share::LazyStringMultimap QueryParams;\n\n/**\n * Http status.\n */\nclass Status{\npublic:\n\n /**\n * Continue.\n */\n static const Status CODE_100;// Continue\n\n /**\n * Switching Protocols.\n */\n static const Status CODE_101;// Switching\n\n /**\n * Processing.\n */\n static const Status CODE_102;// Processing\n\n /**\n * OK.\n */\n static const Status CODE_200;// OK\n\n /**\n * Created.\n */\n static const Status CODE_201;// Created\n\n /**\n * Accepted.\n */\n static const Status CODE_202;// Accepted\n\n /**\n * Non-Authoritative Information.\n */\n static const Status CODE_203;// Non-Authoritative Information\n\n /**\n * No Content.\n */\n static const Status CODE_204;// No Content\n\n /**\n * Reset Content.\n */\n static const Status CODE_205;// Reset Content\n\n /**\n * Partial Content.\n */\n static const Status CODE_206;// Partial Content\n\n /**\n * Multi-Status.\n */\n static const Status CODE_207;// Multi-Status\n\n /**\n * IM Used.\n */\n static const Status CODE_226;// IM Used\n\n /**\n * Multiple Choices.\n */\n static const Status CODE_300;// Multiple Choices\n\n /**\n * Moved Permanently.\n */\n static const Status CODE_301;// Moved Permanently\n\n /**\n * Moved Temporarily.\n */\n static const Status CODE_302;// Moved Temporarily\n\n /**\n * See Other.\n */\n static const Status CODE_303;// See Other\n\n /**\n * Not Modified.\n */\n static const Status CODE_304;// Not Modified\n\n /**\n * Use Proxy.\n */\n static const Status CODE_305;// Use Proxy\n\n /**\n * Reserved.\n */\n static const Status CODE_306;// Reserved\n\n /**\n * Temporary Redirect.\n */\n static const Status CODE_307;// Temporary Redirect\n\n /**\n * Bad Request.\n */\n static const Status CODE_400;// Bad Request\n\n /**\n * Unauthorized.\n */\n static const Status CODE_401;// Unauthorized\n\n /**\n * Payment Required.\n */\n static const Status CODE_402;// Payment Required\n\n /**\n * Forbidden.\n */\n static const Status CODE_403;// Forbidden\n\n /**\n * Not Found.\n */\n static const Status CODE_404;// Not Found\n\n /**\n * Method Not Allowed.\n */\n static 
const Status CODE_405;// Method Not Allowed\n\n /**\n * Not Acceptable.\n */\n static const Status CODE_406;// Not Acceptable\n\n /**\n * Proxy Authentication Required.\n */\n static const Status CODE_407;// Proxy Authentication Required\n\n /**\n * Request Timeout.\n */\n static const Status CODE_408;// Request Timeout\n\n /**\n * Conflict.\n */\n static const Status CODE_409;// Conflict\n\n /**\n * Gone\n */\n static const Status CODE_410;// Gone\n\n /**\n * Length Required.\n */\n static const Status CODE_411;// Length Required\n\n /**\n * Precondition Failed.\n */\n static const Status CODE_412;// Precondition Failed\n\n /**\n * Request Entity Too Large.\n */\n static const Status CODE_413;// Request Entity Too Large\n\n /**\n * Request-URI Too Large.\n */\n static const Status CODE_414;// Request-URI Too Large\n\n /**\n * Unsupported Media Type.\n */\n static const Status CODE_415;// Unsupported Media Type\n\n /**\n * Requested Range Not Satisfiable.\n */\n static const Status CODE_416;// Requested Range Not Satisfiable\n\n /**\n * Expectation Failed.\n */\n static const Status CODE_417;// Expectation Failed\n\n /**\n * I'm a Teapot (rfc7168 2.3.3)\n */\n static const Status CODE_418;// I'm a teapot\n\n /**\n * Unprocessable Entity.\n */\n static const Status CODE_422;// Unprocessable Entity\n\n /**\n * Locked.\n */\n static const Status CODE_423;// Locked\n\n /**\n * Failed Dependency.\n */\n static const Status CODE_424;// Failed Dependency\n\n /**\n * Unordered Collection.\n */\n static const Status CODE_425;// Unordered Collection\n\n /**\n * Upgrade Required.\n */\n static const Status CODE_426;// Upgrade Required\n\n /**\n * Precondition Required.\n */\n static const Status CODE_428;// Precondition Required\n\n /**\n * Too Many Requests.\n */\n static const Status CODE_429;// Too Many Requests\n\n /**\n * Request Header Fields Too Large.\n */\n static const Status CODE_431;// Request Header Fields Too Large\n\n /**\n * Requested host unavailable.\n */\n static const Status CODE_434;// Requested host unavailable\n\n /**\n * Close connection withot sending headers.\n */\n static const Status CODE_444;// Close connection withot sending headers\n\n /**\n * Retry With.\n */\n static const Status CODE_449;// Retry With\n\n /**\n * Unavailable For Legal Reasons.\n */\n static const Status CODE_451;// Unavailable For Legal Reasons\n\n /**\n * Internal Server Error.\n */\n static const Status CODE_500;// Internal Server Error\n\n /**\n * Not Implemented.\n */\n static const Status CODE_501;// Not Implemented\n\n /**\n * Bad Gateway.\n */\n static const Status CODE_502;// Bad Gateway\n\n /**\n * Service Unavailable.\n */\n static const Status CODE_503;// Service Unavailable\n\n /**\n * Gateway Timeout.\n */\n static const Status CODE_504;// Gateway Timeout\n\n /**\n * HTTP Version Not Supported.\n */\n static const Status CODE_505;// HTTP Version Not Supported\n\n /**\n * Variant Also Negotiates.\n */\n static const Status CODE_506;// Variant Also Negotiates\n\n /**\n * Insufficient Storage.\n */\n static const Status CODE_507;// Insufficient Storage\n\n /**\n * Loop Detected.\n */\n static const Status CODE_508;// Loop Detected\n\n /**\n * Bandwidth Limit Exceeded.\n */\n static const Status CODE_509;// Bandwidth Limit Exceeded\n\n /**\n * Not Extended.\n */\n static const Status CODE_510;// Not Extended\n\n /**\n * Network Authentication Required.\n */\n static const Status CODE_511;// Network Authentication Required\n\n /**\n * Constructor.\n */\n Status()\n : code(0)\n , 
description(nullptr)\n {}\n\n /**\n * Constructor.\n * @param pCode - status code.\n * @param pDesc - description.\n */\n Status(v_int32 pCode, const char* pDesc)\n : code(pCode)\n , description(pDesc)\n {}\n\n /**\n * Status code.\n */\n v_int32 code;\n\n /**\n * Description.\n */\n const char* description;\n \n bool operator==(const Status& other) const {\n return this->code == other.code;\n }\n \n bool operator!=(const Status& other) const {\n return this->code != other.code;\n }\n \n};\n\n/**\n * HttpError extends &id:oatpp::web::protocol::ProtocolError;<&l:Status;>.\n */\nclass HttpError : public protocol::ProtocolError {\nprivate:\n Headers m_headers;\npublic:\n\n /**\n * Constructor.\n * @param info\n * @param message\n */\n HttpError(const Info& info, const oatpp::String& message)\n : protocol::ProtocolError(info, message)\n {}\n\n /**\n * Constructor.\n * @param status\n * @param message\n */\n HttpError(const Status& status, const oatpp::String& message)\n : protocol::ProtocolError(Info(0, status), message)\n {}\n\n /**\n * Constructor.\n * @param status\n * @param message\n * @param headers\n */\n HttpError(const Status& status, const oatpp::String& message, const Headers& headers)\n : protocol::ProtocolError(Info(0, status), message)\n , m_headers(headers)\n {}\n\n /**\n * Get headers\n * @return\n */\n const Headers& getHeaders() const {\n return m_headers;\n }\n \n};\n\n/**\n * Throw &l:HttpError; if assertion failed.\n * @param COND - boolean statement. If evaluates to false - throw error.\n * @param STATUS - &l:Status;.\n * @param MESSAGE - String message.\n */\n#define OATPP_ASSERT_HTTP(COND, STATUS, MESSAGE) \\\nif(!(COND)) { throw oatpp::web::protocol::http::HttpError(STATUS, MESSAGE); }\n\n/**\n * Collection of HTTP Header constants.\n */\nclass Header {\npublic:\n\n /**\n * Possible values for headers.\n */\n class Value {\n public:\n static const char* const CONNECTION_CLOSE;\n static const char* const CONNECTION_KEEP_ALIVE;\n static const char* const CONNECTION_UPGRADE;\n \n static const char* const SERVER;\n static const char* const USER_AGENT;\n \n static const char* const TRANSFER_ENCODING_CHUNKED;\n static const char* const CONTENT_TYPE_APPLICATION_JSON;\n\n static const char* const EXPECT_100_CONTINUE;\n };\npublic:\n static const char* const ACCEPT; // \"Accept\"\n static const char* const AUTHORIZATION; // \"Authorization\"\n static const char* const WWW_AUTHENTICATE; // \"WWW-Authenticate\"\n static const char* const CONNECTION; // \"Connection\"\n static const char* const TRANSFER_ENCODING; // \"Transfer-Encoding\"\n static const char* const CONTENT_ENCODING; // \"Content-Encoding\"\n static const char* const CONTENT_LENGTH; // \"Content-Length\"\n static const char* const CONTENT_TYPE; // \"Content-Type\"\n static const char* const CONTENT_RANGE; // \"Content-Range\"\n static const char* const RANGE; // \"Range\"\n static const char* const HOST; // \"Host\"\n static const char* const USER_AGENT; // \"User-Agent\"\n static const char* const SERVER; // \"Server\"\n static const char* const UPGRADE; // \"Upgrade\"\n static const char* const CORS_ORIGIN; // Access-Control-Allow-Origin\n static const char* const CORS_METHODS; // Access-Control-Allow-Methods\n static const char* const CORS_HEADERS; // Access-Control-Allow-Headers\n static const char* const CORS_MAX_AGE; // Access-Control-Max-Age\n static const char* const ACCEPT_ENCODING; // Accept-Encoding\n static const char* const EXPECT; // Expect\n};\n \nclass Range {\npublic:\n static const char* const 
UNIT_BYTES;\nprivate:\n Range()\n : units(nullptr)\n {}\npublic:\n \n Range(const oatpp::String& pUnits,\n v_int64 pStart,\n v_int64 pEnd)\n : units(pUnits)\n , start(pStart)\n , end(pEnd)\n {}\n \n oatpp::String units;\n v_int64 start;\n v_int64 end;\n \n oatpp::String toString() const;\n \n bool isValid() const {\n return units.get() != nullptr;\n }\n \n static Range parse(oatpp::parser::Caret& caret);\n static Range parse(const oatpp::String& str);\n \n};\n \nclass ContentRange {\npublic:\n static const char* const UNIT_BYTES;\nprivate:\n ContentRange()\n : units(nullptr)\n {}\npublic:\n \n ContentRange(const oatpp::String& pUnits,\n v_int64 pStart,\n v_int64 pEnd,\n v_int64 pSize,\n bool pIsSizeKnown)\n : units(pUnits)\n , start(pStart)\n , end(pEnd)\n , size(pSize)\n , isSizeKnown(pIsSizeKnown)\n {}\n \n oatpp::String units;\n v_int64 start;\n v_int64 end;\n v_int64 size;\n bool isSizeKnown;\n \n oatpp::String toString() const;\n \n bool isValid() const {\n return units.get() != nullptr;\n }\n \n static ContentRange parse(oatpp::parser::Caret& caret);\n static ContentRange parse(const oatpp::String& str);\n \n};\n\n/**\n * Struct representing HTTP request starting line.\n * Example request starting line: `GET /path/to/resource/ HTTP/1.1`.\n */\nstruct RequestStartingLine {\n /**\n * Method as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel method; // GET, POST ...\n\n /**\n * Path as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel path;\n\n /**\n * Protocol as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel protocol;\n};\n\n/**\n * Struct representing HTTP response starting line.\n * Example response starting line: `HTTP/1.1 200 OK`.\n */\nstruct ResponseStartingLine {\n /**\n * Protocol as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel protocol;\n\n /**\n * Status code as v_int32.\n */\n v_int32 statusCode;\n\n /**\n * Description as &id:oatpp::data::share::StringKeyLabel;.\n */\n oatpp::data::share::StringKeyLabel description;\n};\n\n/**\n * Data contained in the value of one header.\n */\nstruct HeaderValueData {\n\n /**\n * value tokens.\n */\n std::unordered_set tokens;\n\n /**\n * Title params.\n */\n std::unordered_map titleParams;\n\n /**\n * Get title parm value by key.\n * @param key\n * @return\n */\n oatpp::String getTitleParamValue(const data::share::StringKeyLabelCI& key) const;\n\n};\n\n/**\n * Oatpp Http parser.\n */\nclass Parser {\nprivate:\n static oatpp::data::share::StringKeyLabelCI parseHeaderNameLabel(const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret);\npublic:\n\n /**\n * Parse &l:RequestStartingLine;.\n * @param line - &l:RequestStartingLine;. Values will be set to line's fields.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:RequestStartingLine; fields. See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseRequestStartingLine(RequestStartingLine& line,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse &l:ResponseStartingLine;.\n * @param line - &l:ResponseStartingLine;. Values will be set to line's fields.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:ResponseStartingLine; fields. 
See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseResponseStartingLine(ResponseStartingLine& line,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse one header line. Example of the header line:\n * `\"Content-Type: application/json\\r\\n\"`.\n * @param headers - &l:Headers; map to put parsed header to.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:Headers; values. See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseOneHeader(Headers& headers,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse HTTP headers to &l:Headers; map.\n * @param headers - &l:Headers; map to put parsed headers to.\n * @param headersText - `std::shared_ptr` to `std::string` needed as a \"memory handle\" for\n * &l:Headers; values. See &id:oatpp::data::share::MemoryLabel;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param error - out parameter &l:Status;.\n */\n static void parseHeaders(Headers& headers,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error);\n\n /**\n * Parse data that is contained in a one header.\n * @param data - out. parsed data.\n * @param headerValue - header value string.\n * @param separator - subvalues separator.\n */\n static void parseHeaderValueData(HeaderValueData& data, const oatpp::data::share::StringKeyLabel& headerValue, char separator);\n\n};\n\n/**\n * Http utils.\n */\nclass Utils {\npublic:\n\n /**\n * Write headers map to stream.\n * @param headers\n * @param stream\n */\n static void writeHeaders(const Headers& headers, data::stream::ConsistentOutputStream* stream);\n\n};\n \n}}}}\n\nnamespace std {\n \n template<>\n struct hash {\n \n typedef oatpp::web::protocol::http::Status argument_type;\n typedef v_uint64 result_type;\n \n result_type operator()(oatpp::web::protocol::http::Status const& s) const noexcept {\n return static_cast(s.code);\n }\n \n };\n}\n\n#endif /* oatpp_web_protocol_http_Http_hpp */\n\n// Path: src/oatpp/web/protocol/http/Http.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"./Http.hpp\"\n\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n#include \"oatpp/core/utils/ConversionUtils.hpp\"\n\nnamespace oatpp { namespace web { namespace protocol { namespace http {\n \nconst Status Status::CODE_100(100, \"Continue\");\nconst Status Status::CODE_101(101, \"Switching\");\nconst 
Status Status::CODE_102(102, \"Processing\");\n\nconst Status Status::CODE_200(200, \"OK\");\nconst Status Status::CODE_201(201, \"Created\");\nconst Status Status::CODE_202(202, \"Accepted\");\nconst Status Status::CODE_203(203, \"Non-Authoritative Information\");\nconst Status Status::CODE_204(204, \"No Content\");\nconst Status Status::CODE_205(205, \"Reset Content\");\nconst Status Status::CODE_206(206, \"Partial Content\");\nconst Status Status::CODE_207(207, \"Multi-Status\");\nconst Status Status::CODE_226(226, \"IM Used\");\n\nconst Status Status::CODE_300(300, \"Multiple Choices\");\nconst Status Status::CODE_301(301, \"Moved Permanently\");\nconst Status Status::CODE_302(302, \"Moved Temporarily\");\nconst Status Status::CODE_303(303, \"See Other\");\nconst Status Status::CODE_304(304, \"Not Modified\");\nconst Status Status::CODE_305(305, \"Use Proxy\");\nconst Status Status::CODE_306(306, \"Reserved\");\nconst Status Status::CODE_307(307, \"Temporary Redirect\");\n\nconst Status Status::CODE_400(400, \"Bad Request\");\nconst Status Status::CODE_401(401, \"Unauthorized\");\nconst Status Status::CODE_402(402, \"Payment Required\");\nconst Status Status::CODE_403(403, \"Forbidden\");\nconst Status Status::CODE_404(404, \"Not Found\");\nconst Status Status::CODE_405(405, \"Method Not Allowed\");\nconst Status Status::CODE_406(406, \"Not Acceptable\");\nconst Status Status::CODE_407(407, \"Proxy Authentication Required\");\nconst Status Status::CODE_408(408, \"Request Timeout\");\nconst Status Status::CODE_409(409, \"Conflict\");\nconst Status Status::CODE_410(410, \"Gone\");\nconst Status Status::CODE_411(411, \"Length Required\");\nconst Status Status::CODE_412(412, \"Precondition Failed\");\nconst Status Status::CODE_413(413, \"Request Entity Too Large\");\nconst Status Status::CODE_414(414, \"Request-URI Too Large\");\nconst Status Status::CODE_415(415, \"Unsupported Media Type\");\nconst Status Status::CODE_416(416, \"Requested Range Not Satisfiable\");\nconst Status Status::CODE_417(417, \"Expectation Failed\");\nconst Status Status::CODE_418(418, \"I'm a Teapot\");\nconst Status Status::CODE_422(422, \"Unprocessable Entity\");\nconst Status Status::CODE_423(423, \"Locked\");\nconst Status Status::CODE_424(424, \"Failed Dependency\");\nconst Status Status::CODE_425(425, \"Unordered Collection\");\nconst Status Status::CODE_426(426, \"Upgrade Required\");\nconst Status Status::CODE_428(428, \"Precondition Required\");\nconst Status Status::CODE_429(429, \"Too Many Requests\");\nconst Status Status::CODE_431(431, \"Request Header Fields Too Large\");\nconst Status Status::CODE_434(434, \"Requested host unavailable\");\nconst Status Status::CODE_444(444, \"Close connection withot sending headers\");\nconst Status Status::CODE_449(449, \"Retry With\");\nconst Status Status::CODE_451(451, \"Unavailable For Legal Reasons\");\n\nconst Status Status::CODE_500(500, \"Internal Server Error\");\nconst Status Status::CODE_501(501, \"Not Implemented\");\nconst Status Status::CODE_502(502, \"Bad Gateway\");\nconst Status Status::CODE_503(503, \"Service Unavailable\");\nconst Status Status::CODE_504(504, \"Gateway Timeout\");\nconst Status Status::CODE_505(505, \"HTTP Version Not Supported\");\nconst Status Status::CODE_506(506, \"Variant Also Negotiates\");\nconst Status Status::CODE_507(507, \"Insufficient Storage\");\nconst Status Status::CODE_508(508, \"Loop Detected\");\nconst Status Status::CODE_509(509, \"Bandwidth Limit Exceeded\");\nconst Status Status::CODE_510(510, \"Not 
Extended\");\nconst Status Status::CODE_511(511, \"Network Authentication Required\");\n\nconst char* const Header::Value::CONNECTION_CLOSE = \"close\";\nconst char* const Header::Value::CONNECTION_KEEP_ALIVE = \"keep-alive\";\nconst char* const Header::Value::CONNECTION_UPGRADE = \"Upgrade\";\n \nconst char* const Header::Value::SERVER = \"oatpp/\" OATPP_VERSION;\nconst char* const Header::Value::USER_AGENT = \"oatpp/\" OATPP_VERSION;\n \nconst char* const Header::Value::TRANSFER_ENCODING_CHUNKED = \"chunked\";\n \nconst char* const Header::Value::CONTENT_TYPE_APPLICATION_JSON = \"application/json\";\n\nconst char* const Header::Value::EXPECT_100_CONTINUE = \"100-continue\";\n \nconst char* const Header::ACCEPT = \"Accept\";\nconst char* const Header::AUTHORIZATION = \"Authorization\";\nconst char* const Header::WWW_AUTHENTICATE = \"WWW-Authenticate\";\nconst char* const Header::CONNECTION = \"Connection\";\nconst char* const Header::TRANSFER_ENCODING = \"Transfer-Encoding\";\nconst char* const Header::CONTENT_ENCODING = \"Content-Encoding\";\nconst char* const Header::CONTENT_LENGTH = \"Content-Length\";\nconst char* const Header::CONTENT_TYPE = \"Content-Type\";\nconst char* const Header::CONTENT_RANGE = \"Content-Range\";\nconst char* const Header::RANGE = \"Range\";\nconst char* const Header::HOST = \"Host\";\nconst char* const Header::USER_AGENT = \"User-Agent\";\nconst char* const Header::SERVER = \"Server\";\nconst char* const Header::UPGRADE = \"Upgrade\";\n\n\nconst char* const Header::CORS_ORIGIN = \"Access-Control-Allow-Origin\";\nconst char* const Header::CORS_METHODS = \"Access-Control-Allow-Methods\";\nconst char* const Header::CORS_HEADERS = \"Access-Control-Allow-Headers\";\nconst char* const Header::CORS_MAX_AGE = \"Access-Control-Max-Age\";\n\nconst char* const Header::ACCEPT_ENCODING = \"Accept-Encoding\";\n\nconst char* const Header::EXPECT = \"Expect\";\n\nconst char* const Range::UNIT_BYTES = \"bytes\";\nconst char* const ContentRange::UNIT_BYTES = \"bytes\";\n \noatpp::String Range::toString() const {\n data::stream::BufferOutputStream stream(256);\n stream.writeSimple(units->data(), static_cast(units->size()));\n stream.writeSimple(\"=\", 1);\n stream.writeAsString(start);\n stream.writeSimple(\"-\", 1);\n stream.writeAsString(end);\n return stream.toString();\n}\n\nRange Range::parse(oatpp::parser::Caret& caret) {\n\n auto unitsLabel = caret.putLabel();\n if(caret.findChar('=')) {\n unitsLabel.end();\n caret.inc();\n } else {\n caret.setError(\"'=' - expected\");\n return Range();\n }\n\n auto startLabel = caret.putLabel();\n if(caret.findChar('-')) {\n startLabel.end();\n caret.inc();\n } else {\n caret.setError(\"'-' - expected\");\n return Range();\n }\n\n auto endLabel = caret.putLabel();\n caret.findRN();\n endLabel.end();\n\n auto start = oatpp::utils::conversion::strToInt64(startLabel.getData());\n auto end = oatpp::utils::conversion::strToInt64(endLabel.getData());\n return Range(unitsLabel.toString(), start, end);\n \n}\n\nRange Range::parse(const oatpp::String& str) {\n oatpp::parser::Caret caret(str);\n return parse(caret);\n}\n\noatpp::String ContentRange::toString() const {\n data::stream::BufferOutputStream stream(256);\n stream.writeSimple(units->data(), static_cast(units->size()));\n stream.writeSimple(\" \", 1);\n stream.writeAsString(start);\n stream.writeSimple(\"-\", 1);\n stream.writeAsString(end);\n stream.writeSimple(\"/\", 1);\n if(isSizeKnown) {\n stream.writeAsString(size);\n } else {\n stream.writeSimple(\"*\", 1);\n }\n return 
stream.toString();\n}\n\nContentRange ContentRange::parse(oatpp::parser::Caret& caret) {\n\n auto unitsLabel = caret.putLabel();\n if(caret.findChar(' ')) {\n unitsLabel.end();\n caret.inc();\n } else {\n caret.setError(\"' ' - expected\");\n return ContentRange();\n }\n\n auto startLabel = caret.putLabel();\n if(caret.findChar('-')) {\n startLabel.end();\n caret.inc();\n } else {\n caret.setError(\"'-' - expected\");\n return ContentRange();\n }\n\n auto endLabel = caret.putLabel();\n if(caret.findChar('/')) {\n endLabel.end();\n caret.inc();\n } else {\n caret.setError(\"'/' - expected\");\n return ContentRange();\n }\n\n auto sizeLabel = caret.putLabel();\n caret.findRN();\n sizeLabel.end();\n \n v_int64 start = oatpp::utils::conversion::strToInt64(startLabel.getData());\n v_int64 end = oatpp::utils::conversion::strToInt64(endLabel.getData());\n v_int64 size = 0;\n bool isSizeKnown = false;\n if(sizeLabel.getData()[0] != '*') {\n isSizeKnown = true;\n size = oatpp::utils::conversion::strToInt64(sizeLabel.getData());\n }\n \n return ContentRange(unitsLabel.toString(), start, end, size, isSizeKnown);\n \n}\n\nContentRange ContentRange::parse(const oatpp::String& str) {\n oatpp::parser::Caret caret(str);\n return parse(caret);\n}\n\n\noatpp::String HeaderValueData::getTitleParamValue(const data::share::StringKeyLabelCI& key) const {\n auto it = titleParams.find(key);\n if(it != titleParams.end()) {\n return it->second.toString();\n }\n return nullptr;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Parser\n \noatpp::data::share::StringKeyLabelCI Parser::parseHeaderNameLabel(const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret) {\n const char* data = caret.getData();\n for(v_buff_size i = caret.getPosition(); i < caret.getDataSize(); i++) {\n v_char8 a = static_cast(data[i]);\n if(a == ':' || a == ' '){\n oatpp::data::share::StringKeyLabelCI label(headersText, &data[caret.getPosition()], i - caret.getPosition());\n caret.setPosition(i);\n return label;\n \n }\n }\n return oatpp::data::share::StringKeyLabelCI(nullptr, nullptr, 0);\n}\n \nvoid Parser::parseRequestStartingLine(RequestStartingLine& line,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error) {\n\n auto methodLabel = caret.putLabel();\n if(caret.findChar(' ')){\n line.method = oatpp::data::share::StringKeyLabel(headersText, methodLabel.getData(), methodLabel.getSize());\n caret.inc();\n } else {\n error = Status::CODE_400;\n return;\n }\n\n auto pathLabel = caret.putLabel();\n if(caret.findChar(' ')){\n line.path = oatpp::data::share::StringKeyLabel(headersText, pathLabel.getData(), pathLabel.getSize());\n caret.inc();\n } else {\n error = Status::CODE_400;\n return;\n }\n\n auto protocolLabel = caret.putLabel();\n if(caret.findRN()){\n line.protocol = oatpp::data::share::StringKeyLabel(headersText, protocolLabel.getData(), protocolLabel.getSize());\n caret.skipRN();\n } else {\n error = Status::CODE_400;\n return;\n }\n \n}\n \nvoid Parser::parseResponseStartingLine(ResponseStartingLine& line,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error) {\n\n auto protocolLabel = caret.putLabel();\n if(caret.findChar(' ')){\n line.protocol = oatpp::data::share::StringKeyLabel(headersText, protocolLabel.getData(), protocolLabel.getSize());\n caret.inc();\n } else {\n error = Status::CODE_400;\n return;\n }\n\n line.statusCode = static_cast(caret.parseInt());\n\n auto 
descriptionLabel = caret.putLabel();\n if(caret.findRN()){\n line.description = oatpp::data::share::StringKeyLabel(headersText, descriptionLabel.getData(), descriptionLabel.getSize());\n caret.skipRN();\n } else {\n error = Status::CODE_400;\n return;\n }\n \n}\n \nvoid Parser::parseOneHeader(Headers& headers,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error)\n{\n caret.skipChar(' ');\n auto name = parseHeaderNameLabel(headersText, caret);\n if(name.getData() != nullptr) {\n caret.skipChar(' ');\n if(!caret.canContinueAtChar(':', 1)) {\n error = Status::CODE_400;\n return;\n }\n caret.skipChar(' ');\n v_buff_size valuePos0 = caret.getPosition();\n caret.findRN();\n headers.put_LockFree(name, oatpp::data::share::StringKeyLabel(headersText, &caret.getData()[valuePos0], caret.getPosition() - valuePos0));\n caret.skipRN();\n } else {\n error = Status::CODE_431;\n return;\n }\n}\n\nvoid Parser::parseHeaders(Headers& headers,\n const std::shared_ptr& headersText,\n oatpp::parser::Caret& caret,\n Status& error)\n{\n \n while (!caret.isAtRN()) {\n parseOneHeader(headers, headersText, caret, error);\n if(error.code != 0) {\n return;\n }\n }\n \n caret.skipRN();\n \n}\n\nvoid Parser::parseHeaderValueData(HeaderValueData& data, const oatpp::data::share::StringKeyLabel& headerValue, char separator) {\n\n oatpp::parser::Caret caret(reinterpret_cast(headerValue.getData()), headerValue.getSize());\n\n const char charSet[5] = {' ', '=', separator, '\\r', '\\n'};\n const char charSet2[4] = {' ', separator, '\\r', '\\n'};\n\n while (caret.canContinue()) {\n\n caret.skipChar(' ');\n\n auto label = caret.putLabel();\n auto res = caret.findCharFromSet(charSet, 5);\n\n if (res == '=') {\n\n data::share::StringKeyLabelCI key(headerValue.getMemoryHandle(), label.getData(), label.getSize());\n caret.inc();\n\n if (caret.isAtChar('\"')) {\n label = caret.parseStringEnclosed('\"', '\"', '\\\\');\n } else if (caret.isAtChar('\\'')) {\n label = caret.parseStringEnclosed('\\'', '\\'', '\\\\');\n } else {\n label = caret.putLabel();\n caret.findCharFromSet(charSet2, 4);\n }\n\n data.titleParams[key] = data::share::StringKeyLabel(headerValue.getMemoryHandle(), label.getData(),\n label.getSize());\n\n } else {\n data.tokens.insert(\n data::share::StringKeyLabelCI(headerValue.getMemoryHandle(), label.getData(), label.getSize()));\n }\n\n if (caret.isAtCharFromSet(\"\\r\\n\", 2)) {\n break;\n } else if (caret.isAtChar(separator)) {\n caret.inc();\n }\n\n }\n\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Utils\n\nvoid Utils::writeHeaders(const Headers& headers, data::stream::ConsistentOutputStream* stream) {\n\n...\n", "needle_token_start": 3928, "needle_token_end": 3972, "code_context_ntokens": 15952, "output": ["```cpp\nclass CommunicationError : public std::runtime_error {\nprivate:\n oatpp::v_io_size m_ioStatus;\n oatpp::String m_message;\npublic:\n\n CommunicationError(oatpp::v_io_size ioStatus, const oatpp::String& message);\n\n oatpp::v_io_size getIOStatus();\n\n oatpp::String& getMessage();\n \n};\n```"]} +{"repo": "oatpp/oatpp", "name": "serializePrimitive", "language": "cpp", "path": "src/oatpp/parser/json/mapping/Serializer.hpp", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: The function is designed to convert a basic data type into a JSON-compatible string format and write it to an output stream. If the data is null, it writes \"null\" to the stream.\n2. 
**Input**: It accepts three parameters: a pointer to the serializer instance, a pointer to a consistent output stream where the data will be written, and a polymorphic wrapper around the data to be serialized.\n3. **Output**: There is no return value; the function writes the serialized data directly to the provided output stream.\n4. **Procedure**: The function first checks if the provided data is not null. If it is not null, it casts the data to its original type and writes it as a string to the output stream. If the data is null, it writes the string \"null\" to the output stream.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/parser/json/Utils.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Utils.hpp\"\n\n#include \"oatpp/encoding/Unicode.hpp\"\n#include \"oatpp/encoding/Hex.hpp\"\n\nnamespace oatpp { namespace parser { namespace json{\n\nv_buff_size Utils::calcEscapedStringSize(const char* data, v_buff_size size, v_buff_size& safeSize, v_uint32 flags) {\n v_buff_size result = 0;\n v_buff_size i = 0;\n safeSize = size;\n while (i < size) {\n v_char8 a = static_cast(data[i]);\n if(a < 32) {\n i ++;\n\n switch (a) {\n\n case '\\b':\n case '\\f':\n case '\\n':\n case '\\r':\n case '\\t': result += 2; break; // '\\n'\n\n default:\n result += 6; // '\\uFFFF' - 6 chars\n break;\n\n }\n\n } else if(a < 128){\n i ++;\n\n switch (a) {\n case '\\\"':\n case '\\\\': result += 2; break; // '\\/'\n\n case '/':\n result ++;\n if((flags & FLAG_ESCAPE_SOLIDUS) > 0) result ++;\n break;\n\n default:\n result ++;\n break;\n\n }\n\n } else {\n v_buff_size charSize = oatpp::encoding::Unicode::getUtf8CharSequenceLength(a);\n if(charSize != 0) {\n if(i + charSize > size) {\n safeSize = i;\n }\n i += charSize;\n if (!(flags & FLAG_ESCAPE_UTF8CHAR)) {\n result += charSize; // output as-is\n } else if(charSize < 4) {\n result += 6; // '\\uFFFF' - 6 chars\n } else if(charSize == 4) {\n result += 12; // '\\uFFFF\\uFFFF' - 12 chars surrogate pair\n } else {\n result += 11; // '\\u+FFFFFFFF' - 11 chars NOT JSON standard case\n }\n } else {\n // invalid char\n i ++;\n result ++;\n }\n }\n }\n return result;\n}\n\nv_buff_size Utils::calcUnescapedStringSize(const char* data, v_buff_size size, v_int64& errorCode, v_buff_size& errorPosition) {\n errorCode = 0;\n v_buff_size result = 0;\n v_buff_size i = 0;\n \n while (i < size) {\n v_char8 a = static_cast(data[i]);\n if(a == '\\\\'){\n \n if(i + 1 == size){\n 
errorCode = ERROR_CODE_INVALID_ESCAPED_CHAR;\n errorPosition = i;\n return 0;\n }\n \n v_char8 b = static_cast(data[i + 1]);\n \n if(b == '\"' || b == '\\\\' || b == '/' || b == 'b' || b == 'f' || b == 'n' || b == 'r' || b == 't'){\n result += 1;\n i += 2;\n } else if(b == 'u'){\n \n if(i + 6 > size){\n errorCode = ERROR_CODE_INVALID_ESCAPED_CHAR;\n errorPosition = i;\n return 0;\n }\n \n if(data[i + 2] == '+') { // not JSON standard case\n if(i + 11 > size){\n errorCode = ERROR_CODE_INVALID_ESCAPED_CHAR;\n errorPosition = i;\n return 0;\n }\n v_uint32 code;\n errorCode = encoding::Hex::readUInt32(&data[i + 3], code);\n if(errorCode != 0){\n errorPosition = i + 3;\n return 0;\n }\n i += 11;\n result += encoding::Unicode::getUtf8CharSequenceLengthForCode(code);\n } else {\n v_uint16 code;\n errorCode = encoding::Hex::readUInt16(&data[i + 2], code);\n if(errorCode != 0){\n errorPosition = i + 2;\n return 0;\n }\n \n if(code >= 0xD800 && code <= 0xDBFF){\n if(i + 12 > size){\n errorCode = ERROR_CODE_INVALID_SURROGATE_PAIR;\n errorPosition = i;\n return 0;\n }\n v_uint16 low;\n errorCode = encoding::Hex::readUInt16(&data[i + 8], low);\n if(errorCode != 0){\n errorPosition = i + 8;\n return 0;\n }\n \n if(low >= 0xDC00 && low <= 0xDFFF){\n v_uint32 bigCode = static_cast(encoding::Unicode::utf16SurrogatePairToCode(static_cast(code), static_cast(low)));\n i += 12;\n result += encoding::Unicode::getUtf8CharSequenceLengthForCode(bigCode);\n } else {\n errorCode = ERROR_CODE_INVALID_SURROGATE_PAIR;\n errorPosition = i;\n return 0;\n }\n \n } else {\n i += 6;\n result += encoding::Unicode::getUtf8CharSequenceLengthForCode(code);\n }\n }\n \n } else {\n errorCode = ERROR_CODE_INVALID_ESCAPED_CHAR;\n errorPosition = i;\n return 0;\n }\n \n } else {\n i ++;\n result ++;\n }\n \n }\n \n return result;\n}\n \nv_buff_size Utils::escapeUtf8Char(const char* sequence, p_char8 buffer){\n v_buff_size length;\n v_int32 code = oatpp::encoding::Unicode::encodeUtf8Char(sequence, length);\n if(code < 0x00010000) {\n buffer[0] = '\\\\';\n buffer[1] = 'u';\n oatpp::encoding::Hex::writeUInt16(v_uint16(code), &buffer[2]);\n return 6;\n } else if(code < 0x00200000) {\n v_int16 high;\n v_int16 low;\n oatpp::encoding::Unicode::codeToUtf16SurrogatePair(code, high, low);\n buffer[0] = '\\\\';\n buffer[1] = 'u';\n oatpp::encoding::Hex::writeUInt16(static_cast(high), &buffer[2]);\n buffer[6] = '\\\\';\n buffer[7] = 'u';\n oatpp::encoding::Hex::writeUInt16(static_cast(low), &buffer[8]);\n return 12;\n } else {\n buffer[0] = '\\\\';\n buffer[1] = 'u';\n buffer[2] = '+';\n oatpp::encoding::Hex::writeUInt32(static_cast(code), &buffer[2]);\n return 11;\n }\n}\n\noatpp::String Utils::escapeString(const char* data, v_buff_size size, v_uint32 flags) {\n v_buff_size safeSize;\n v_buff_size escapedSize = calcEscapedStringSize(data, size, safeSize, flags);\n if(escapedSize == size) {\n return String(data, size);\n }\n auto result = String(escapedSize);\n p_char8 resultData = reinterpret_cast(const_cast(result->data()));\n v_buff_size pos = 0;\n\n {\n v_buff_size i = 0;\n while (i < safeSize) {\n v_char8 a = static_cast(data[i]);\n if (a < 32) {\n\n switch (a) {\n\n case '\\b': resultData[pos] = '\\\\'; resultData[pos + 1] = 'b'; pos += 2; break;\n case '\\f': resultData[pos] = '\\\\'; resultData[pos + 1] = 'f'; pos += 2; break;\n case '\\n': resultData[pos] = '\\\\'; resultData[pos + 1] = 'n'; pos += 2; break;\n case '\\r': resultData[pos] = '\\\\'; resultData[pos + 1] = 'r'; pos += 2; break;\n case '\\t': resultData[pos] = '\\\\'; 
resultData[pos + 1] = 't'; pos += 2; break;\n\n default:\n resultData[pos] = '\\\\';\n resultData[pos + 1] = 'u';\n oatpp::encoding::Hex::writeUInt16(a, &resultData[pos + 2]);\n pos += 6;\n break;\n\n }\n\n i++;\n\n }\n else if (a < 128) {\n\n switch (a) {\n case '\\\"': resultData[pos] = '\\\\'; resultData[pos + 1] = '\"'; pos += 2; break;\n case '\\\\': resultData[pos] = '\\\\'; resultData[pos + 1] = '\\\\'; pos += 2; break;\n\n case '/':\n if((flags & FLAG_ESCAPE_SOLIDUS) > 0) {\n resultData[pos] = '\\\\';\n resultData[pos + 1] = '/';\n pos += 2;\n } else {\n resultData[pos] = static_cast(data[i]);\n pos++;\n }\n break;\n\n default:\n resultData[pos] = static_cast(data[i]);\n pos++;\n break;\n\n }\n\n i++;\n }\n else {\n v_buff_size charSize = oatpp::encoding::Unicode::getUtf8CharSequenceLength(a);\n if (charSize != 0) {\n if (!(flags & FLAG_ESCAPE_UTF8CHAR)) {\n std::memcpy(reinterpret_cast(&resultData[pos]), reinterpret_cast(const_cast(&data[i])), static_cast(charSize));\n pos += charSize;\n }\n else {\n pos += escapeUtf8Char(&data[i], &resultData[pos]);\n }\n i += charSize;\n }\n else {\n // invalid char\n resultData[pos] = static_cast(data[i]);\n i++;\n pos++;\n }\n }\n }\n }\n \n if(size > safeSize){\n for(v_buff_size i = pos; static_cast(i) < result->size(); i ++){\n resultData[i] = '?';\n }\n }\n \n return result;\n}\n\nvoid Utils::unescapeStringToBuffer(const char* data, v_buff_size size, p_char8 resultData){\n \n v_buff_size i = 0;\n v_buff_size pos = 0;\n \n while (i < size) {\n v_char8 a = static_cast(data[i]);\n \n if(a == '\\\\'){\n v_char8 b = static_cast(data[i + 1]);\n if(b != 'u'){\n switch (b) {\n case '\"': resultData[pos] = '\"'; pos ++; break;\n case '\\\\': resultData[pos] = '\\\\'; pos ++; break;\n case '/': resultData[pos] = '/'; pos ++; break;\n case 'b': resultData[pos] = '\\b'; pos ++; break;\n case 'f': resultData[pos] = '\\f'; pos ++; break;\n case 'n': resultData[pos] = '\\n'; pos ++; break;\n case 'r': resultData[pos] = '\\r'; pos ++; break;\n case 't': resultData[pos] = '\\t'; pos ++; break;\n default: break;\n }\n i += 2;\n } else {\n if(data[i + 2] == '+'){ // Not JSON standard case\n v_uint32 code;\n encoding::Hex::readUInt32(&data[i + 3], code);\n i += 11;\n pos += encoding::Unicode::decodeUtf8Char(static_cast(code), &resultData[pos]);\n } else {\n \n v_uint16 code;\n encoding::Hex::readUInt16(&data[i + 2], code);\n \n if(code >= 0xD800 && code <= 0xDBFF){\n v_uint16 low;\n encoding::Hex::readUInt16(&data[i + 8], low);\n v_uint32 bigCode = static_cast(encoding::Unicode::utf16SurrogatePairToCode(static_cast(code), static_cast(low)));\n pos += encoding::Unicode::decodeUtf8Char(static_cast(bigCode), &resultData[pos]);\n i += 12;\n } else {\n pos += encoding::Unicode::decodeUtf8Char(code, &resultData[pos]);\n i += 6;\n }\n \n }\n }\n } else {\n resultData[pos] = a;\n pos ++;\n i++;\n }\n \n }\n \n}\n \noatpp::String Utils::unescapeString(const char* data, v_buff_size size, v_int64& errorCode, v_buff_size& errorPosition) {\n \n v_buff_size unescapedSize = calcUnescapedStringSize(data, size, errorCode, errorPosition);\n if(errorCode != 0){\n return nullptr;\n }\n auto result = String(unescapedSize);\n if(unescapedSize == size) {\n std::memcpy(reinterpret_cast(const_cast(result->data())), data, static_cast(size));\n } else {\n unescapeStringToBuffer(data, size, reinterpret_cast(const_cast(result->data())));\n }\n return result;\n \n}\n \nstd::string Utils::unescapeStringToStdString(const char* data, v_buff_size size, v_int64& errorCode, v_buff_size& 
errorPosition){\n \n v_buff_size unescapedSize = calcUnescapedStringSize(data, size, errorCode, errorPosition);\n if(errorCode != 0){\n return \"\";\n }\n std::string result;\n result.resize(static_cast(unescapedSize));\n if(unescapedSize == size) {\n std::memcpy(reinterpret_cast(const_cast(result.data())), data, static_cast(size));\n } else {\n unescapeStringToBuffer(data, size, reinterpret_cast(const_cast(result.data())));\n }\n return result;\n \n}\n \nconst char* Utils::preparseString(ParsingCaret& caret, v_buff_size& size){\n \n if(caret.canContinueAtChar('\"', 1)){\n \n const char* data = caret.getData();\n v_buff_size pos = caret.getPosition();\n...\n// Path: src/oatpp/parser/json/mapping/Serializer.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_parser_json_mapping_Serializer_hpp\n#define oatpp_parser_json_mapping_Serializer_hpp\n\n#include \"oatpp/parser/json/Utils.hpp\"\n#include \"oatpp/parser/json/Beautifier.hpp\"\n#include \"oatpp/core/Types.hpp\"\n#include \n\nnamespace oatpp { namespace parser { namespace json { namespace mapping {\n\n/**\n * Json Serializer.\n * Serializes oatpp DTO object to json. See [Data Transfer Object(DTO) component](https://oatpp.io/docs/components/dto/).\n */\nclass Serializer {\npublic:\n typedef oatpp::data::mapping::type::Type Type;\n typedef oatpp::data::mapping::type::BaseObject::Property Property;\n typedef oatpp::data::mapping::type::BaseObject::Properties Properties;\n\n typedef oatpp::String String;\npublic:\n /**\n * Serializer config.\n */\n class Config : public oatpp::base::Countable {\n public:\n /**\n * Constructor.\n */\n Config()\n {}\n public:\n\n /**\n * Create shared config.\n * @return - `std::shared_ptr` to Config.\n */\n static std::shared_ptr createShared(){\n return std::make_shared();\n }\n\n /**\n * Include fields with value == nullptr into serialized json.\n * Field will still be included when field-info `required` is set to true and &id:alwaysIncludeRequired is set to true.\n */\n bool includeNullFields = true;\n\n /**\n * Always include required fields (set in in DTO_FIELD_INFO) even if they are `value == nullptr`\n */\n bool alwaysIncludeRequired = false;\n\n /**\n * Always include array or map elements, even if their value is `nullptr`.\n */\n bool alwaysIncludeNullCollectionElements = false;\n\n /**\n * If `true` - insert string `\"\"` in json field value in case unknown field found.\n * Fail if `false`.\n * Known types for this serializer are:
\n * (String, Int8, Int16, Int32, Int64, Float32, Float64, Boolean, DTOs, List, Fields).\n */\n bool throwOnUnknownTypes = true;\n\n /**\n * Use JSON Beautifier.\n */\n bool useBeautifier = false;\n\n /**\n * Beautifier Indent.\n */\n oatpp::String beautifierIndent = \" \";\n\n /**\n * Beautifier new line.\n */\n oatpp::String beautifierNewLine = \"\\n\";\n\n /**\n * Enable type interpretations.\n */\n std::vector enabledInterpretations = {};\n\n /**\n * Escape flags.\n */\n v_uint32 escapeFlags = json::Utils::FLAG_ESCAPE_ALL;\n\n };\npublic:\n typedef void (*SerializerMethod)(Serializer*,\n data::stream::ConsistentOutputStream*,\n const oatpp::Void&);\nprivate:\n\n template\n \nstatic void serializePrimitive(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph){\n (void) serializer;\n\n if(polymorph){\n stream->writeAsString(* static_cast(polymorph.get()));\n } else {\n stream->writeSimple(\"null\", 4);\n }\n }\n \n static void serializeString(oatpp::data::stream::ConsistentOutputStream* stream,\n const char* data,\n v_buff_size size,\n v_uint32 escapeFlags);\n\n static void serializeString(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph);\n\n static void serializeAny(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph);\n\n static void serializeEnum(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph);\n\n static void serializeCollection(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph);\n\n static void serializeMap(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph);\n\n static void serializeObject(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph);\n\n void serialize(data::stream::ConsistentOutputStream* stream, const oatpp::Void& polymorph);\n\nprivate:\n std::shared_ptr m_config;\n std::vector m_methods;\npublic:\n\n /**\n * Constructor.\n * @param config - serializer config.\n */\n Serializer(const std::shared_ptr& config = std::make_shared());\n\n /**\n * Set serializer method for type.\n * @param classId - &id:oatpp::data::mapping::type::ClassId;.\n * @param method - `typedef void (*SerializerMethod)(Serializer*, data::stream::ConsistentOutputStream*, const oatpp::Void&)`.\n */\n void setSerializerMethod(const data::mapping::type::ClassId& classId, SerializerMethod method);\n\n /**\n * Serialize object to stream.\n * @param stream - &id:oatpp::data::stream::ConsistentOutputStream;.\n * @param polymorph - DTO as &id:oatpp::Void;.\n */\n void serializeToStream(data::stream::ConsistentOutputStream* stream, const oatpp::Void& polymorph);\n\n /**\n * Get serializer config.\n * @return\n */\n const std::shared_ptr& getConfig();\n\n};\n\n}}}}\n\n#endif /* oatpp_parser_json_mapping_Serializer_hpp */\n\n// Path: src/oatpp/parser/json/mapping/Deserializer.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * 
http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_parser_json_mapping_Deserializer_hpp\n#define oatpp_parser_json_mapping_Deserializer_hpp\n\n#include \"oatpp/parser/json/Utils.hpp\"\n#include \"oatpp/core/parser/Caret.hpp\"\n#include \"oatpp/core/Types.hpp\"\n\n#include \n\nnamespace oatpp { namespace parser { namespace json { namespace mapping {\n\n/**\n * Json Deserializer.\n * Deserialize oatpp DTO object from json. See [Data Transfer Object(DTO) component](https://oatpp.io/docs/components/dto/).\n */\nclass Deserializer {\npublic:\n typedef oatpp::data::mapping::type::Type Type;\n typedef oatpp::data::mapping::type::BaseObject::Property Property;\n typedef oatpp::data::mapping::type::BaseObject::Properties Properties;\n\n typedef oatpp::String String;\n\npublic:\n\n /**\n * \"'{' - expected\"\n */\n static constexpr v_int32 ERROR_CODE_OBJECT_SCOPE_OPEN = 1;\n\n /**\n * \"'}' - expected\"\n */\n static constexpr v_int32 ERROR_CODE_OBJECT_SCOPE_CLOSE = 2;\n\n /**\n * \"Unknown field\"\n */\n static constexpr v_int32 ERROR_CODE_OBJECT_SCOPE_UNKNOWN_FIELD = 3;\n\n /**\n * \"':' - expected\"\n */\n static constexpr v_int32 ERROR_CODE_OBJECT_SCOPE_COLON_MISSING = 4;\n\n /**\n * \"'[' - expected\"\n */\n static constexpr v_int32 ERROR_CODE_ARRAY_SCOPE_OPEN = 5;\n\n /**\n * \"']' - expected\"\n */\n static constexpr v_int32 ERROR_CODE_ARRAY_SCOPE_CLOSE = 6;\n\n /**\n * \"'true' or 'false' - expected\"\n */\n static constexpr v_int32 ERROR_CODE_VALUE_BOOLEAN = 7;\n\npublic:\n\n /**\n * Deserializer config.\n */\n class Config : public oatpp::base::Countable {\n public:\n /**\n * Constructor.\n */\n Config()\n {}\n public:\n\n /**\n * Create shared Config.\n * @return - `std::shared_ptr` to Config.\n */\n static std::shared_ptr createShared(){\n return std::make_shared();\n }\n\n /**\n * Do not fail if unknown field is found in json.\n * \"unknown field\" is the one which is not present in DTO object class.\n */\n bool allowUnknownFields = true;\n\n /**\n * Enable type interpretations.\n */\n std::vector enabledInterpretations = {};\n\n };\n\npublic:\n typedef oatpp::Void (*DeserializerMethod)(Deserializer*, parser::Caret&, const Type* const);\nprivate:\n static void skipScope(oatpp::parser::Caret& caret, v_char8 charOpen, v_char8 charClose);\n static void skipString(oatpp::parser::Caret& caret);\n static void skipToken(oatpp::parser::Caret& caret);\n static void skipValue(oatpp::parser::Caret& caret);\nprivate:\n static const Type* guessNumberType(oatpp::parser::Caret& caret);\n static const Type* guessType(oatpp::parser::Caret& caret);\nprivate:\n\n template\n static oatpp::Void deserializeInt(Deserializer* deserializer, parser::Caret& caret, const Type* const type){\n\n (void) deserializer;\n (void) type;\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(T::Class::getType());\n } else {\n //TODO: shall we handle overflow cases like\n // oatpp::String json = \"128\";\n // auto value = jsonObjectMapper->readFromString(json); // UInt8 will overflow to -128\n return T(static_cast(caret.parseInt()));\n }\n\n }\n\n template\n static oatpp::Void 
deserializeUInt(Deserializer* deserializer, parser::Caret& caret, const Type* const type){\n\n (void) deserializer;\n (void) type;\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(T::Class::getType());\n } else {\n //TODO: shall we handle overflow cases like\n // oatpp::String json = \"256\";\n // auto value = jsonObjectMapper->readFromString(json); // UInt8 will overflow to 0\n return T(static_cast(caret.parseUnsignedInt()));\n }\n\n }\n\n static oatpp::Void deserializeFloat32(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n static oatpp::Void deserializeFloat64(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n static oatpp::Void deserializeBoolean(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n static oatpp::Void deserializeString(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n static oatpp::Void deserializeAny(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n static oatpp::Void deserializeEnum(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n\n static oatpp::Void deserializeCollection(Deserializer* deserializer, parser::Caret& caret, const Type* type);\n static oatpp::Void deserializeMap(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n\n static oatpp::Void deserializeObject(Deserializer* deserializer, parser::Caret& caret, const Type* const type);\n\nprivate:\n std::shared_ptr m_config;\n std::vector m_methods;\npublic:\n\n /**\n * Constructor.\n * @param config\n */\n Deserializer(const std::shared_ptr& config = std::make_shared());\n\n /**\n * Set deserializer method for type.\n * @param classId - &id:oatpp::data::mapping::type::ClassId;.\n * @param method - `typedef oatpp::Void (*DeserializerMethod)(Deserializer*, parser::Caret&, const Type* const)`.\n */\n void setDeserializerMethod(const data::mapping::type::ClassId& classId, DeserializerMethod method);\n\n /**\n * Deserialize text.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param type - &id:oatpp::data::mapping::type::Type;\n * @return - `oatpp::Void` over deserialized object.\n */\n oatpp::Void deserialize(parser::Caret& caret, const Type* const type);\n\n /**\n * Get deserializer config.\n * @return\n */\n const std::shared_ptr& getConfig();\n\n};\n\n}}}}\n\n#endif /* oatpp_parser_json_mapping_Deserializer_hpp */\n\n// Path: src/oatpp/parser/json/mapping/ObjectMapper.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_parser_json_mapping_ObjectMapper_hpp\n#define oatpp_parser_json_mapping_ObjectMapper_hpp\n\n#include \"./Serializer.hpp\"\n#include 
\"./Deserializer.hpp\"\n\n#include \"oatpp/core/data/mapping/ObjectMapper.hpp\"\n\nnamespace oatpp { namespace parser { namespace json { namespace mapping {\n\n/**\n * Json ObjectMapper. Serialized/Deserializes oatpp DTO objects to/from JSON.\n * See [Data Transfer Object(DTO) component](https://oatpp.io/docs/components/dto/).
\n * Extends &id:oatpp::base::Countable;, &id:oatpp::data::mapping::ObjectMapper;.\n */\nclass ObjectMapper : public oatpp::base::Countable, public oatpp::data::mapping::ObjectMapper {\nprivate:\n static Info& getMapperInfo() {\n static Info info(\"application/json\");\n return info;\n }\nprivate:\n std::shared_ptr m_serializer;\n std::shared_ptr m_deserializer;\npublic:\n /**\n * Constructor.\n * @param serializerConfig - &id:oatpp::parser::json::mapping::Serializer::Config;.\n * @param deserializerConfig - &id:oatpp::parser::json::mapping::Deserializer::Config;.\n */\n ObjectMapper(const std::shared_ptr& serializerConfig,\n const std::shared_ptr& deserializerConfig);\n\n /**\n * Constructor.\n * @param serializer\n * @param deserializer\n */\n ObjectMapper(const std::shared_ptr& serializer = std::make_shared(),\n const std::shared_ptr& deserializer = std::make_shared());\npublic:\n\n /**\n * Create shared ObjectMapper.\n * @param serializerConfig - &id:oatpp::parser::json::mapping::Serializer::Config;.\n * @param deserializerConfig - &id:oatpp::parser::json::mapping::Deserializer::Config;.\n * @return - `std::shared_ptr` to ObjectMapper.\n */\n static std::shared_ptr\n createShared(const std::shared_ptr& serializerConfig,\n const std::shared_ptr& deserializerConfig);\n\n /**\n * Create shared ObjectMapper.\n * @param serializer\n * @param deserializer\n * @return\n */\n static std::shared_ptr\n createShared(const std::shared_ptr& serializer = std::make_shared(),\n const std::shared_ptr& deserializer = std::make_shared());\n\n /**\n * Implementation of &id:oatpp::data::mapping::ObjectMapper::write;.\n * @param stream - stream to write serializerd data to &id:oatpp::data::stream::ConsistentOutputStream;.\n * @param variant - object to serialize &id:oatpp::Void;.\n */\n void write(data::stream::ConsistentOutputStream* stream, const oatpp::Void& variant) const override;\n\n /**\n * Implementation of &id:oatpp::data::mapping::ObjectMapper::read;.\n * @param caret - &id:oatpp::parser::Caret;.\n * @param type - type of resultant object &id:oatpp::data::mapping::type::Type;.\n * @return - &id:oatpp::Void; holding resultant object.\n */\n oatpp::Void read(oatpp::parser::Caret& caret, const oatpp::data::mapping::type::Type* const type) const override;\n\n\n /**\n * Get serializer.\n * @return\n */\n std::shared_ptr getSerializer();\n\n /**\n * Get deserializer.\n * @return\n */\n std::shared_ptr getDeserializer();\n \n};\n \n}}}}\n\n#endif /* oatpp_parser_json_mapping_ObjectMapper_hpp */\n\n// Path: src/oatpp/parser/json/mapping/ObjectMapper.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"ObjectMapper.hpp\"\n\nnamespace oatpp 
{ namespace parser { namespace json { namespace mapping {\n\nObjectMapper::ObjectMapper(const std::shared_ptr& serializerConfig,\n const std::shared_ptr& deserializerConfig)\n : data::mapping::ObjectMapper(getMapperInfo())\n , m_serializer(std::make_shared(serializerConfig))\n , m_deserializer(std::make_shared(deserializerConfig))\n{}\n\nObjectMapper::ObjectMapper(const std::shared_ptr& serializer,\n const std::shared_ptr& deserializer)\n : data::mapping::ObjectMapper(getMapperInfo())\n , m_serializer(serializer)\n , m_deserializer(deserializer)\n{}\n\nstd::shared_ptr ObjectMapper::createShared(const std::shared_ptr& serializerConfig,\n const std::shared_ptr& deserializerConfig){\n return std::make_shared(serializerConfig, deserializerConfig);\n}\n\nstd::shared_ptr ObjectMapper::createShared(const std::shared_ptr& serializer,\n const std::shared_ptr& deserializer){\n return std::make_shared(serializer, deserializer);\n}\n\nvoid ObjectMapper::write(data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& variant) const {\n m_serializer->serializeToStream(stream, variant);\n}\n\noatpp::Void ObjectMapper::read(oatpp::parser::Caret& caret,\n const oatpp::data::mapping::type::Type* const type) const {\n return m_deserializer->deserialize(caret, type);\n}\n\nstd::shared_ptr ObjectMapper::getSerializer() {\n return m_serializer;\n}\n\nstd::shared_ptr ObjectMapper::getDeserializer() {\n return m_deserializer;\n}\n\n}}}}\n\n// Path: src/oatpp/parser/json/mapping/Deserializer.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Deserializer.hpp\"\n\n#include \"oatpp/core/utils/ConversionUtils.hpp\"\n\nnamespace oatpp { namespace parser { namespace json { namespace mapping {\n\nDeserializer::Deserializer(const std::shared_ptr& config)\n : m_config(config)\n{\n\n m_methods.resize(static_cast(data::mapping::type::ClassId::getClassCount()), nullptr);\n\n setDeserializerMethod(data::mapping::type::__class::String::CLASS_ID, &Deserializer::deserializeString);\n setDeserializerMethod(data::mapping::type::__class::Any::CLASS_ID, &Deserializer::deserializeAny);\n\n setDeserializerMethod(data::mapping::type::__class::Int8::CLASS_ID, &Deserializer::deserializeInt);\n setDeserializerMethod(data::mapping::type::__class::UInt8::CLASS_ID, &Deserializer::deserializeUInt);\n\n setDeserializerMethod(data::mapping::type::__class::Int16::CLASS_ID, &Deserializer::deserializeInt);\n setDeserializerMethod(data::mapping::type::__class::UInt16::CLASS_ID, &Deserializer::deserializeUInt);\n\n setDeserializerMethod(data::mapping::type::__class::Int32::CLASS_ID, &Deserializer::deserializeInt);\n 
setDeserializerMethod(data::mapping::type::__class::UInt32::CLASS_ID, &Deserializer::deserializeUInt);\n\n setDeserializerMethod(data::mapping::type::__class::Int64::CLASS_ID, &Deserializer::deserializeInt);\n setDeserializerMethod(data::mapping::type::__class::UInt64::CLASS_ID, &Deserializer::deserializeUInt);\n\n setDeserializerMethod(data::mapping::type::__class::Float32::CLASS_ID, &Deserializer::deserializeFloat32);\n setDeserializerMethod(data::mapping::type::__class::Float64::CLASS_ID, &Deserializer::deserializeFloat64);\n setDeserializerMethod(data::mapping::type::__class::Boolean::CLASS_ID, &Deserializer::deserializeBoolean);\n\n setDeserializerMethod(data::mapping::type::__class::AbstractObject::CLASS_ID, &Deserializer::deserializeObject);\n setDeserializerMethod(data::mapping::type::__class::AbstractEnum::CLASS_ID, &Deserializer::deserializeEnum);\n\n setDeserializerMethod(data::mapping::type::__class::AbstractVector::CLASS_ID, &Deserializer::deserializeCollection);\n setDeserializerMethod(data::mapping::type::__class::AbstractList::CLASS_ID, &Deserializer::deserializeCollection);\n setDeserializerMethod(data::mapping::type::__class::AbstractUnorderedSet::CLASS_ID, &Deserializer::deserializeCollection);\n\n setDeserializerMethod(data::mapping::type::__class::AbstractPairList::CLASS_ID, &Deserializer::deserializeMap);\n setDeserializerMethod(data::mapping::type::__class::AbstractUnorderedMap::CLASS_ID, &Deserializer::deserializeMap);\n\n}\n\nvoid Deserializer::setDeserializerMethod(const data::mapping::type::ClassId& classId, DeserializerMethod method) {\n const v_uint32 id = static_cast(classId.id);\n if(id >= m_methods.size()) {\n m_methods.resize(id + 1, nullptr);\n }\n m_methods[id] = method;\n}\n\nvoid Deserializer::skipScope(oatpp::parser::Caret& caret, v_char8 charOpen, v_char8 charClose){\n\n const char* data = caret.getData();\n v_buff_size size = caret.getDataSize();\n v_buff_size pos = caret.getPosition();\n v_int32 scopeCounter = 0;\n\n bool isInString = false;\n\n while(pos < size){\n v_char8 a = static_cast(data[pos]);\n if(a == charOpen){\n if(!isInString){\n scopeCounter ++;\n }\n } else if(a == charClose){\n if(!isInString){\n scopeCounter --;\n if(scopeCounter == 0){\n caret.setPosition(pos + 1);\n return;\n }\n }\n } else if(a == '\"') {\n isInString = !isInString;\n } else if(a == '\\\\'){\n pos ++;\n }\n\n pos ++;\n\n }\n}\n\nvoid Deserializer::skipString(oatpp::parser::Caret& caret){\n const char* data = caret.getData();\n v_buff_size size = caret.getDataSize();\n v_buff_size pos = caret.getPosition();\n v_int32 scopeCounter = 0;\n while(pos < size){\n v_char8 a = static_cast(data[pos]);\n if(a == '\"'){\n scopeCounter ++;\n if(scopeCounter == 2) {\n caret.setPosition(pos + 1);\n return;\n }\n } else if(a == '\\\\'){\n pos ++;\n }\n pos ++;\n }\n}\n\nvoid Deserializer::skipToken(oatpp::parser::Caret& caret){\n const char* data = caret.getData();\n v_buff_size size = caret.getDataSize();\n v_buff_size pos = caret.getPosition();\n while(pos < size){\n v_char8 a = static_cast(data[pos]);\n if(a == ' ' || a == '\\t' || a == '\\n' || a == '\\r' || a == '\\b' || a == '\\f' ||\n a == '}' || a == ',' || a == ']') {\n caret.setPosition(pos);\n return;\n }\n pos ++;\n }\n}\n\nvoid Deserializer::skipValue(oatpp::parser::Caret& caret){\n if(caret.isAtChar('{')){\n skipScope(caret, '{', '}');\n } else if(caret.isAtChar('[')){\n skipScope(caret, '[', ']');\n } else if(caret.isAtChar('\"')){\n skipString(caret);\n } else {\n skipToken(caret);\n }\n}\n\noatpp::Void 
Deserializer::deserializeFloat32(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n (void) deserializer;\n (void) type;\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(Float32::Class::getType());\n } else {\n return Float32(caret.parseFloat32());\n }\n}\n\noatpp::Void Deserializer::deserializeFloat64(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n (void) deserializer;\n (void) type;\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(Float64::Class::getType());\n } else {\n return Float64(caret.parseFloat64());\n }\n\n}\n\noatpp::Void Deserializer::deserializeBoolean(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n (void) deserializer;\n (void) type;\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(Boolean::Class::getType());\n } else {\n if(caret.isAtText(\"true\", true)) {\n return Boolean(true);\n } else if(caret.isAtText(\"false\", true)) {\n return Boolean(false);\n } else {\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::readBooleanValue()]: Error. 'true' or 'false' - expected.\", ERROR_CODE_VALUE_BOOLEAN);\n return oatpp::Void(Boolean::Class::getType());\n }\n }\n\n}\n\noatpp::Void Deserializer::deserializeString(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n (void) deserializer;\n (void) type;\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(String::Class::getType());\n } else {\n return oatpp::Void(oatpp::parser::json::Utils::parseString(caret).getPtr(), String::Class::getType());\n }\n}\n\nconst data::mapping::type::Type* Deserializer::guessNumberType(oatpp::parser::Caret& caret) {\n if (!Utils::findDecimalSeparatorInCurrentNumber(caret)) {\n if (*caret.getCurrData() == '-') {\n return Int64::Class::getType();\n } else {\n return UInt64::Class::getType();\n }\n }\n return Float64::Class::getType();\n}\n\nconst data::mapping::type::Type* Deserializer::guessType(oatpp::parser::Caret& caret) {\n {\n parser::Caret::StateSaveGuard stateGuard(caret);\n v_char8 c = static_cast(*caret.getCurrData());\n switch (c) {\n case '\"':\n return String::Class::getType();\n case '{':\n return oatpp::Fields::Class::getType();\n case '[':\n return oatpp::List::Class::getType();\n case 't':\n if(caret.isAtText(\"true\")) return Boolean::Class::getType();\n break;\n case 'f':\n if(caret.isAtText(\"false\")) return Boolean::Class::getType();\n break;\n default:\n if (c == '-' || caret.isAtDigitChar()) {\n return guessNumberType(caret);\n }\n }\n }\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::guessType()]: Error. 
Can't guess type for oatpp::Any.\");\n return nullptr;\n}\n\noatpp::Void Deserializer::deserializeAny(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n (void) type;\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(Any::Class::getType());\n } else {\n const Type* const fieldType = guessType(caret);\n if(fieldType != nullptr) {\n auto fieldValue = deserializer->deserialize(caret, fieldType);\n auto anyHandle = std::make_shared(fieldValue.getPtr(), fieldValue.getValueType());\n return oatpp::Void(anyHandle, Any::Class::getType());\n }\n }\n return oatpp::Void(Any::Class::getType());\n}\n\noatpp::Void Deserializer::deserializeEnum(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n auto polymorphicDispatcher = static_cast(\n type->polymorphicDispatcher\n );\n\n data::mapping::type::EnumInterpreterError e = data::mapping::type::EnumInterpreterError::OK;\n const auto& value = deserializer->deserialize(caret, polymorphicDispatcher->getInterpretationType());\n if(caret.hasError()) {\n return nullptr;\n }\n const auto& result = polymorphicDispatcher->fromInterpretation(value, e);\n\n if(e == data::mapping::type::EnumInterpreterError::OK) {\n return result;\n }\n\n switch(e) {\n case data::mapping::type::EnumInterpreterError::CONSTRAINT_NOT_NULL:\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeEnum()]: Error. Enum constraint violated - 'NotNull'.\");\n break;\n case data::mapping::type::EnumInterpreterError::OK:\n case data::mapping::type::EnumInterpreterError::TYPE_MISMATCH_ENUM:\n case data::mapping::type::EnumInterpreterError::TYPE_MISMATCH_ENUM_VALUE:\n case data::mapping::type::EnumInterpreterError::ENTRY_NOT_FOUND:\n default:\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeEnum()]: Error. Can't deserialize Enum.\");\n }\n\n return nullptr;\n\n}\n\noatpp::Void Deserializer::deserializeCollection(Deserializer* deserializer, parser::Caret& caret, const Type* type) {\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(type);\n }\n\n if(caret.canContinueAtChar('[', 1)) {\n\n auto dispatcher = static_cast(type->polymorphicDispatcher);\n auto collection = dispatcher->createObject();\n\n auto itemType = dispatcher->getItemType();\n\n caret.skipBlankChars();\n\n while(!caret.isAtChar(']') && caret.canContinue()){\n\n caret.skipBlankChars();\n auto item = deserializer->deserialize(caret, itemType);\n if(caret.hasError()){\n return nullptr;\n }\n\n dispatcher->addItem(collection, item);\n caret.skipBlankChars();\n\n caret.canContinueAtChar(',', 1);\n\n }\n\n if(!caret.canContinueAtChar(']', 1)){\n if(!caret.hasError()){\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeCollection()]: Error. ']' - expected\", ERROR_CODE_ARRAY_SCOPE_CLOSE);\n }\n return nullptr;\n }\n\n return collection;\n\n } else {\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeCollection()]: Error. 
'[' - expected\", ERROR_CODE_ARRAY_SCOPE_OPEN);\n return nullptr;\n }\n\n}\n\noatpp::Void Deserializer::deserializeMap(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(type);\n }\n\n if(caret.canContinueAtChar('{', 1)) {\n\n auto dispatcher = static_cast(type->polymorphicDispatcher);\n auto map = dispatcher->createObject();\n\n auto keyType = dispatcher->getKeyType();\n if(keyType->classId != oatpp::String::Class::CLASS_ID){\n throw std::runtime_error(\"[oatpp::parser::json::mapping::Deserializer::deserializeMap()]: Invalid json map key. Key should be String\");\n }\n auto valueType = dispatcher->getValueType();\n\n caret.skipBlankChars();\n\n while (!caret.isAtChar('}') && caret.canContinue()) {\n\n caret.skipBlankChars();\n auto key = Utils::parseString(caret);\n if(caret.hasError()){\n return nullptr;\n }\n\n caret.skipBlankChars();\n if(!caret.canContinueAtChar(':', 1)){\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeMap()]: Error. ':' - expected\", ERROR_CODE_OBJECT_SCOPE_COLON_MISSING);\n return nullptr;\n }\n\n caret.skipBlankChars();\n\n auto item = deserializer->deserialize(caret, valueType);\n if(caret.hasError()){\n return nullptr;\n }\n dispatcher->addItem(map, key, item);\n\n caret.skipBlankChars();\n caret.canContinueAtChar(',', 1);\n\n }\n\n if(!caret.canContinueAtChar('}', 1)){\n if(!caret.hasError()){\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeMap()]: Error. '}' - expected\", ERROR_CODE_OBJECT_SCOPE_CLOSE);\n }\n return nullptr;\n }\n\n return map;\n\n } else {\n caret.setError(\"[oatpp::parser::json::mapping::Deserializer::deserializeMap()]: Error. '{' - expected\", ERROR_CODE_OBJECT_SCOPE_OPEN);\n }\n\n return nullptr;\n\n}\n\noatpp::Void Deserializer::deserializeObject(Deserializer* deserializer, parser::Caret& caret, const Type* const type) {\n\n if(caret.isAtText(\"null\", true)){\n return oatpp::Void(type);\n }\n\n if(caret.canContinueAtChar('{', 1)) {\n\n auto dispatcher = static_cast(type->polymorphicDispatcher);\n...\n", "needle_token_start": 5590, "needle_token_end": 5693, "code_context_ntokens": 16020, "output": ["```cpp\nstatic void serializePrimitive(Serializer* serializer,\n data::stream::ConsistentOutputStream* stream,\n const oatpp::Void& polymorph){\n (void) serializer;\n\n if(polymorph){\n stream->writeAsString(* static_cast(polymorph.get()));\n } else {\n stream->writeSimple(\"null\", 4);\n }\n }\n```"]} +{"repo": "oatpp/oatpp", "name": "DataHandle", "language": "cpp", "path": "src/oatpp/core/data/resource/InMemoryData.hpp", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: To encapsulate and manage a string of data within a resource context.\n2. **Input**: A string representing the data to be managed.\n3. **Output**: An instance that holds the provided string data.\n4. 
**Procedure**: Upon instantiation, the input string is stored within the instance, making it accessible for further operations or retrieval within the resource management context.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/core/data/stream/FileStream.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"FileStream.hpp\"\n\nnamespace oatpp { namespace data{ namespace stream {\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// FileInputStream\n\noatpp::data::stream::DefaultInitializedContext FileInputStream::DEFAULT_CONTEXT(data::stream::StreamType::STREAM_FINITE);\n\nFileInputStream::FileInputStream(FileInputStream&& other)\n : m_file(other.m_file)\n , m_ownsFile(other.m_ownsFile)\n , m_ioMode(other.m_ioMode)\n{\n other.m_file = nullptr;\n other.m_ownsFile = false;\n}\n\nFileInputStream::FileInputStream(std::FILE* file, bool ownsFile, const std::shared_ptr& captureData)\n : m_file(file)\n , m_ownsFile(ownsFile)\n , m_ioMode(IOMode::ASYNCHRONOUS)\n , m_capturedData(captureData)\n{}\n\nFileInputStream::FileInputStream(const char* filename, const std::shared_ptr& captureData)\n : FileInputStream(std::fopen(filename, \"rb\"), true, captureData)\n{\n if(!m_file) {\n OATPP_LOGE(\"[oatpp::data::stream::FileInputStream::FileInputStream(filename)]\", \"Error. Can't open file '%s'.\", filename)\n throw std::runtime_error(\"[oatpp::data::stream::FileInputStream::FileInputStream(filename)]: Error. 
Can't open file.\");\n }\n}\n\nFileInputStream::~FileInputStream() {\n this->close();\n}\n\nstd::FILE* FileInputStream::getFile() {\n return m_file;\n}\n\n...\n// Path: src/oatpp/core/data/stream/BufferStream.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_stream_BufferStream_hpp\n#define oatpp_data_stream_BufferStream_hpp\n\n#include \"Stream.hpp\"\n\nnamespace oatpp { namespace data{ namespace stream {\n\n/**\n * BufferOutputStream\n */\nclass BufferOutputStream : public ConsistentOutputStream {\npublic:\n static data::stream::DefaultInitializedContext DEFAULT_CONTEXT;\nprivate:\n p_char8 m_data;\n v_buff_size m_capacity;\n v_buff_size m_position;\n v_buff_size m_maxCapacity;\n IOMode m_ioMode;\nprivate:\n std::shared_ptr m_capturedData;\npublic:\n\n /**\n * Constructor.\n * @param growBytes\n * @param captureData - capture auxiliary data to not get deleted until it's done with the stream.\n */\n BufferOutputStream(v_buff_size initialCapacity = 2048, const std::shared_ptr& captureData = nullptr);\n\n /**\n * Virtual destructor.\n */\n ~BufferOutputStream() override;\n\n /**\n * Write `count` of bytes to stream.\n * @param data - data to write.\n * @param count - number of bytes to write.\n * @param action - async specific action. If action is NOT &id:oatpp::async::Action::TYPE_NONE;, then\n * caller MUST return this action on coroutine iteration.\n * @return - actual number of bytes written. &id:oatpp::v_io_size;.\n */\n v_io_size write(const void *data, v_buff_size count, async::Action& action) override;\n\n /**\n * Set stream I/O mode.\n * @throws\n */\n void setOutputStreamIOMode(IOMode ioMode) override;\n\n /**\n * Get stream I/O mode.\n * @return\n */\n IOMode getOutputStreamIOMode() override;\n\n /**\n * Get stream context.\n * @return\n */\n Context& getOutputStreamContext() override;\n\n /**\n * Reserve bytes for future writes.\n */\n void reserveBytesUpfront(v_buff_size count);\n\n /**\n * Get pointer to data.\n * @return - pointer to data.\n */\n p_char8 getData();\n\n /**\n * Get current capacity.\n * Capacity may change.\n * @return\n */\n v_buff_size getCapacity();\n\n /**\n * Get current data write position.\n * @return - current data write position.\n */\n v_buff_size getCurrentPosition();\n\n /**\n * Set current data write position.\n * @param position - data write position.\n */\n void setCurrentPosition(v_buff_size position);\n\n /**\n * Reset stream buffer and its capacity. 
Also reset write position.\n * @param initialCapacity\n */\n void reset(v_buff_size initialCapacity = 2048);\n\n /**\n * Copy data to &id:oatpp::String;.\n * @return\n */\n oatpp::String toString();\n\n /**\n * Create &id:oatpp::String; from part of buffer.\n * @param pos - starting position in buffer.\n * @param count - size of bytes to write to substring.\n * @return - &id:oatpp::String;\n */\n oatpp::String getSubstring(v_buff_size pos, v_buff_size count);\n\n /**\n * Write all bytes from buffer to stream.\n * @param stream - stream to flush all data to.\n * @return - actual amount of bytes flushed.\n */\n oatpp::v_io_size flushToStream(OutputStream* stream);\n\n /**\n * Write all bytes from buffer to stream in async manner.\n * @param _this - pointer to `this` buffer.\n * @param stream - stream to flush all data to.\n * @return - &id:oatpp::async::CoroutineStarter;.\n */\n static oatpp::async::CoroutineStarter flushToStreamAsync(const std::shared_ptr& _this, const std::shared_ptr& stream);\n\n};\n\n/**\n * BufferInputStream\n */\nclass BufferInputStream : public BufferedInputStream {\npublic:\n static data::stream::DefaultInitializedContext DEFAULT_CONTEXT;\nprivate:\n std::shared_ptr m_memoryHandle;\n p_char8 m_data;\n v_buff_size m_size;\n v_buff_size m_position;\n IOMode m_ioMode;\nprivate:\n std::shared_ptr m_capturedData;\npublic:\n\n /**\n * Constructor.\n * @param memoryHandle - buffer memory handle. May be nullptr.\n * @param data - pointer to buffer data.\n * @param size - size of the buffer.\n * @param captureData - capture auxiliary data to not get deleted until it's done with the stream.\n */\n BufferInputStream(const std::shared_ptr& memoryHandle,\n const void* data,\n v_buff_size size,\n const std::shared_ptr& captureData = nullptr);\n\n /**\n * Constructor.\n * @param data - buffer.\n * @param captureData - capture auxiliary data to not get deleted until it's done with the stream.\n */\n BufferInputStream(const oatpp::String& data, const std::shared_ptr& captureData = nullptr);\n\n /**\n * Reset stream data and set position to `0`.\n * @param memoryHandle - buffer memory handle. May be nullptr.\n * @param data - pointer to buffer data.\n * @param size - size of the buffer.\n * @param captureData - capture auxiliary data to not get deleted until it's done with the stream.\n */\n void reset(const std::shared_ptr& memoryHandle,\n p_char8 data,\n v_buff_size size,\n const std::shared_ptr& captureData = nullptr);\n\n\n /**\n * Same as `reset(nullptr, nullptr, 0);.`\n */\n void reset();\n\n /**\n * Read data from stream.
\n * It is a legal case if return result < count. Caller should handle this!\n * *Calls to this method are always NON-BLOCKING*\n * @param data - buffer to read data to.\n * @param count - size of the buffer.\n * @param action - async specific action. If action is NOT &id:oatpp::async::Action::TYPE_NONE;, then\n * caller MUST return this action on coroutine iteration.\n * @return - actual number of bytes read. 0 - designates end of the buffer.\n */\n v_io_size read(void *data, v_buff_size count, async::Action& action) override;\n\n /**\n * Set stream I/O mode.\n * @throws\n */\n void setInputStreamIOMode(IOMode ioMode) override;\n\n /**\n * Get stream I/O mode.\n * @return\n */\n IOMode getInputStreamIOMode() override;\n\n /**\n * Get stream context.\n * @return\n */\n Context& getInputStreamContext() override;\n\n /**\n * Get data memory handle.\n * @return - data memory handle.\n */\n std::shared_ptr getDataMemoryHandle();\n\n /**\n * Get pointer to data.\n * @return - pointer to data.\n */\n p_char8 getData();\n\n /**\n * Get data size.\n * @return - data size.\n */\n v_buff_size getDataSize();\n\n /**\n * Get current data read position.\n * @return - current data read position.\n */\n v_buff_size getCurrentPosition();\n\n /**\n * Set current data read position.\n * @param position - data read position.\n */\n void setCurrentPosition(v_buff_size position);\n\n /**\n * Peek up to count of bytes int he buffer\n * @param data\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size peek(void *data, v_buff_size count, async::Action& action) override;\n\n /**\n * Amount of bytes currently available to read from buffer.\n * @return &id:oatpp::v_io_size;.\n */\n v_io_size availableToRead() const override;\n\n /**\n * Commit read offset\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size commitReadOffset(v_buff_size count) override;\n\n};\n\n}}}\n\n#endif // oatpp_data_stream_BufferStream_hpp\n\n// Path: src/oatpp/core/data/stream/BufferStream.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"BufferStream.hpp\"\n\n#include \"oatpp/core/utils/Binary.hpp\"\n\nnamespace oatpp { namespace data{ namespace stream {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// BufferOutputStream\n\ndata::stream::DefaultInitializedContext BufferOutputStream::DEFAULT_CONTEXT(data::stream::StreamType::STREAM_INFINITE);\n\nBufferOutputStream::BufferOutputStream(v_buff_size initialCapacity, const std::shared_ptr& captureData)\n : m_data(new v_char8[initialCapacity])\n , m_capacity(initialCapacity)\n , m_position(0)\n , m_maxCapacity(-1)\n , 
m_ioMode(IOMode::ASYNCHRONOUS)\n , m_capturedData(captureData)\n{}\n\nBufferOutputStream::~BufferOutputStream() {\n m_capturedData.reset(); // reset capture data before deleting data.\n delete [] m_data;\n}\n\nv_io_size BufferOutputStream::write(const void *data, v_buff_size count, async::Action& action) {\n\n (void) action;\n\n reserveBytesUpfront(count);\n\n std::memcpy(m_data + m_position, data, static_cast(count));\n m_position += count;\n\n return count;\n\n}\n\nvoid BufferOutputStream::setOutputStreamIOMode(IOMode ioMode) {\n m_ioMode = ioMode;\n}\n\nIOMode BufferOutputStream::getOutputStreamIOMode() {\n return m_ioMode;\n}\n\nContext& BufferOutputStream::getOutputStreamContext() {\n return DEFAULT_CONTEXT;\n}\n\nvoid BufferOutputStream::reserveBytesUpfront(v_buff_size count) {\n\n v_buff_size capacityNeeded = m_position + count;\n\n if(capacityNeeded > m_capacity) {\n\n v_buff_size newCapacity = utils::Binary::nextP2(capacityNeeded);\n\n if(newCapacity < 0 || (m_maxCapacity > 0 && newCapacity > m_maxCapacity)) {\n newCapacity = m_maxCapacity;\n }\n\n if(newCapacity < capacityNeeded) {\n throw std::runtime_error(\"[oatpp::data::stream::BufferOutputStream::reserveBytesUpfront()]: Error. Unable to allocate requested memory.\");\n }\n\n p_char8 newData = new v_char8[newCapacity];\n\n std::memcpy(newData, m_data, static_cast(m_position));\n delete [] m_data;\n m_data = newData;\n m_capacity = newCapacity;\n\n }\n\n}\n\np_char8 BufferOutputStream::getData() {\n return m_data;\n}\n\n\nv_buff_size BufferOutputStream::getCapacity() {\n return m_capacity;\n}\n\n\nv_buff_size BufferOutputStream::getCurrentPosition() {\n return m_position;\n}\n\n\nvoid BufferOutputStream::setCurrentPosition(v_buff_size position) {\n m_position = position;\n}\n\nvoid BufferOutputStream::reset(v_buff_size initialCapacity) {\n delete [] m_data;\n m_data = new v_char8[initialCapacity];\n m_capacity = initialCapacity;\n m_position = 0;\n}\n\noatpp::String BufferOutputStream::toString() {\n return oatpp::String(reinterpret_cast(m_data), m_position);\n}\n\noatpp::String BufferOutputStream::getSubstring(v_buff_size pos, v_buff_size count) {\n if(pos + count <= m_position) {\n return oatpp::String(reinterpret_cast(m_data + pos), count);\n } else {\n return oatpp::String(reinterpret_cast(m_data + pos), m_position - pos);\n }\n}\n\noatpp::v_io_size BufferOutputStream::flushToStream(OutputStream* stream) {\n return stream->writeExactSizeDataSimple(m_data, m_position);\n}\n\noatpp::async::CoroutineStarter BufferOutputStream::flushToStreamAsync(const std::shared_ptr& _this, const std::shared_ptr& stream) {\n\n class WriteDataCoroutine : public oatpp::async::Coroutine {\n private:\n std::shared_ptr m_this;\n std::shared_ptr m_stream;\n data::buffer::InlineWriteData m_inlineData;\n public:\n\n WriteDataCoroutine(const std::shared_ptr& _this,\n const std::shared_ptr& stream)\n : m_this(_this)\n , m_stream(stream)\n {}\n\n Action act() override {\n if(m_inlineData.currBufferPtr == nullptr) {\n m_inlineData.currBufferPtr = m_this->m_data;\n m_inlineData.bytesLeft = m_this->m_position;\n }\n return m_stream.get()->writeExactSizeDataAsyncInline(m_inlineData, finish());\n }\n\n };\n\n return WriteDataCoroutine::start(_this, stream);\n\n}\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// BufferInputStream\n\ndata::stream::DefaultInitializedContext 
BufferInputStream::DEFAULT_CONTEXT(data::stream::StreamType::STREAM_FINITE);\n\nBufferInputStream::BufferInputStream(const std::shared_ptr& memoryHandle,\n const void* data,\n v_buff_size size,\n const std::shared_ptr& captureData)\n : m_memoryHandle(memoryHandle)\n , m_data(reinterpret_cast(const_cast(data)))\n , m_size(size)\n , m_position(0)\n , m_ioMode(IOMode::ASYNCHRONOUS)\n , m_capturedData(captureData)\n{}\n\nBufferInputStream::BufferInputStream(const oatpp::String& data, const std::shared_ptr& captureData)\n : BufferInputStream(data.getPtr(), reinterpret_cast(const_cast(data->data())), static_cast(data->size()), captureData)\n{}\n\nvoid BufferInputStream::reset(const std::shared_ptr& memoryHandle,\n p_char8 data,\n v_buff_size size,\n const std::shared_ptr& captureData)\n{\n m_memoryHandle = memoryHandle;\n m_data = data;\n m_size = size;\n m_position = 0;\n m_capturedData = captureData;\n}\n\nvoid BufferInputStream::reset() {\n m_memoryHandle = nullptr;\n m_data = nullptr;\n m_size = 0;\n m_position = 0;\n m_capturedData.reset();\n}\n\nv_io_size BufferInputStream::read(void *data, v_buff_size count, async::Action& action) {\n\n (void) action;\n\n v_buff_size desiredAmount = count;\n if(desiredAmount > m_size - m_position) {\n desiredAmount = m_size - m_position;\n }\n std::memcpy(data, &m_data[m_position], static_cast(desiredAmount));\n m_position += desiredAmount;\n return desiredAmount;\n}\n\nvoid BufferInputStream::setInputStreamIOMode(IOMode ioMode) {\n m_ioMode = ioMode;\n}\n\nIOMode BufferInputStream::getInputStreamIOMode() {\n return m_ioMode;\n}\n\nContext& BufferInputStream::getInputStreamContext() {\n return DEFAULT_CONTEXT;\n}\n\nstd::shared_ptr BufferInputStream::getDataMemoryHandle() {\n return m_memoryHandle;\n}\n\np_char8 BufferInputStream::getData() {\n return m_data;\n}\n\nv_buff_size BufferInputStream::getDataSize() {\n return m_size;\n}\n\nv_buff_size BufferInputStream::getCurrentPosition() {\n return m_position;\n}\n\nvoid BufferInputStream::setCurrentPosition(v_buff_size position) {\n m_position = position;\n}\n\nv_io_size BufferInputStream::peek(void *data, v_buff_size count, async::Action &action) {\n (void) action;\n\n v_buff_size desiredAmount = count;\n if(desiredAmount > m_size - m_position) {\n desiredAmount = m_size - m_position;\n }\n std::memcpy(data, &m_data[m_position], static_cast(desiredAmount));\n return desiredAmount;\n}\n\nv_io_size BufferInputStream::availableToRead() const {\n return m_size - m_position;\n}\n\nv_io_size BufferInputStream::commitReadOffset(v_buff_size count) {\n if(count > m_size - m_position) {\n count = m_size - m_position;\n }\n m_position += count;\n return count;\n}\n\n}}}\n\n// Path: src/oatpp/core/data/resource/Resource.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * 
limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_resource_Resource_hpp\n#define oatpp_data_resource_Resource_hpp\n\n#include \"oatpp/core/data/stream/Stream.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\n/**\n * Abstract data resource\n */\nclass Resource : public oatpp::base::Countable {\npublic:\n\n /**\n * virtual destructor.\n */\n virtual ~Resource() override = default;\n\n /**\n * Open output stream.\n * @return\n */\n virtual std::shared_ptr openOutputStream() = 0;\n\n /**\n * Open input stream.\n * @return\n */\n virtual std::shared_ptr openInputStream() = 0;\n\n /**\n * Get in-memory data if applicable.\n * @return - `&id:oatpp::String;` or `nullptr` if not applicable.\n */\n virtual oatpp::String getInMemoryData() = 0;\n\n /**\n * Get known data size if applicable.\n * @return - known size of the data. `-1` - if size is unknown.\n */\n virtual v_int64 getKnownSize() = 0;\n\n /**\n * Get resource location if applicable.
\n * location can be for example a file name.\n * @return - `&id:oatpp::String;` or `nullptr` if not applicable.\n */\n virtual oatpp::String getLocation() = 0;\n\n};\n\n}}}\n\n#endif //oatpp_data_resource_Resource_hpp\n\n// Path: src/oatpp/core/data/resource/InMemoryData.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_resource_InMemoryData_hpp\n#define oatpp_data_resource_InMemoryData_hpp\n\n#include \"./Resource.hpp\"\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\nclass InMemoryData : public Resource {\nprivate:\n\n struct DataHandle {\n\n oatpp::String data;\n\n \nDataHandle(const oatpp::String& pData)\n : data(pData)\n {}\n\n };\n\n struct OutputDataHandle {\n\n std::shared_ptr dataHandle;\n oatpp::String data;\n data::stream::BufferOutputStream* stream;\n\n ~OutputDataHandle();\n\n };\n\nprivate:\n std::shared_ptr m_handle;\npublic:\n\n /**\n * Default constructor.\n */\n InMemoryData() = default;\n\n /**\n * Constructor.\n * @param fullInMemoryDataname\n */\n InMemoryData(const oatpp::String& data);\n\n /**\n * Open output stream to an InMemoryData.
\n * NOT thread-safe.
\n * *Note: data is committed once stream is closed.*
\n * *Note: stream also captures data-handle. The data won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` to &id:oatpp::data::stream::OutputStream;.\n */\n std::shared_ptr openOutputStream() override;\n\n /**\n * Open input stream to an InMemoryData.
\n * NOT thread-safe.
\n * *Note: once the stream is open, subsequent writes through the output stream won't affect currently opened input streams.*<br>
\n * *Note: stream also captures file-handle. The data won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` &id:oatpp::data::stream::InputStream;.\n */\n std::shared_ptr openInputStream() override;\n\n /**\n * Get in-memory-data.\n * @return - always returns `nullptr`.\n */\n oatpp::String getInMemoryData() override;\n\n /**\n * Get size of an in-memory-data.\n * @return - size of the data in bytes.\n */\n v_int64 getKnownSize() override;\n\n /**\n * Not applicable.\n * @return - always returns `nullptr`.\n */\n oatpp::String getLocation() override;\n\n};\n\n}}}\n\n#endif //oatpp_data_resource_InMemoryData_hpp\n\n// Path: src/oatpp/core/data/resource/InMemoryData.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"InMemoryData.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\nInMemoryData::OutputDataHandle::~OutputDataHandle() {\n dataHandle->data = data + stream->toString();\n}\n\nInMemoryData::InMemoryData(const oatpp::String& data)\n : m_handle(std::make_shared(data))\n{}\n\nstd::shared_ptr InMemoryData::openOutputStream() {\n auto outputDataHandle = std::make_shared();\n if(!m_handle) {\n m_handle = std::make_shared(\"\");\n }\n if(!m_handle->data){\n m_handle->data = \"\";\n }\n outputDataHandle->dataHandle = m_handle;\n outputDataHandle->data = m_handle->data;\n\n auto stream = std::make_shared(1024, outputDataHandle);\n outputDataHandle->stream = stream.get();\n return stream;\n}\n\nstd::shared_ptr InMemoryData::openInputStream() {\n if(!m_handle) {\n m_handle = std::make_shared(\"\");\n }\n if(!m_handle->data){\n m_handle->data = \"\";\n }\n return std::make_shared(m_handle->data, m_handle);\n}\n\noatpp::String InMemoryData::getInMemoryData() {\n if(m_handle && m_handle->data) {\n return m_handle->data;\n }\n return nullptr;\n}\n\nv_int64 InMemoryData::getKnownSize() {\n if(m_handle && m_handle->data) {\n return static_cast(m_handle->data->size());\n }\n return 0;\n}\n\noatpp::String InMemoryData::getLocation() {\n return nullptr;\n}\n\n}}}\n\n// Path: src/oatpp/core/data/resource/File.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * 
distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_resource_File_hpp\n#define oatpp_data_resource_File_hpp\n\n#include \"./Resource.hpp\"\n#include \"oatpp/core/data/stream/Stream.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\n/**\n * File.\n * @extends - &id:oatpp::data::Resource;.\n */\nclass File : public Resource {\nprivate:\n \n struct FileHandle {\n\n oatpp::String fileName;\n\n FileHandle(const oatpp::String& fullFileName)\n : fileName(fullFileName)\n {}\n\n };\n\npublic:\n static oatpp::String concatDirAndName(const oatpp::String& dir, const oatpp::String& filename);\nprivate:\n std::shared_ptr m_handle;\npublic:\n\n /**\n * Default constructor.\n */\n File() = default;\n\n /**\n * Constructor.\n * @param fullFilename\n */\n File(const oatpp::String& fullFilename);\n\n /**\n * Constructor.\n * @param directory\n * @param filename\n */\n File(const oatpp::String& directory, const oatpp::String& filename);\n\n /**\n * Open output stream to a file.
\n * *Note: stream also captures file-handle. The file object won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` to &id:oatpp::data::stream::OutputStream;.\n */\n std::shared_ptr openOutputStream() override;\n\n /**\n * Open input stream to a temporary file.
\n * *Note: stream also captures file-handle. The file won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` &id:oatpp::data::stream::InputStream;.\n */\n std::shared_ptr openInputStream() override;\n\n /**\n * Not applicable.\n * @return - always returns `nullptr`.\n */\n oatpp::String getInMemoryData() override;\n\n /**\n * Not applicable.\n * @return - always returns `-1`.\n */\n v_int64 getKnownSize() override;\n\n /**\n * Get location where temporary data is stored.\n * @return - `&id:oatpp::String;`.\n */\n oatpp::String getLocation() override;\n\n};\n\n}}}\n\n#endif //oatpp_data_resource_File_hpp\n\n// Path: src/oatpp/core/data/resource/File.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"File.hpp\"\n\n#include \"oatpp/core/data/stream/FileStream.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\noatpp::String File::concatDirAndName(const oatpp::String& dir, const oatpp::String& filename) {\n if(dir && dir->size() > 0) {\n auto lastChar = dir->data()[dir->size() - 1];\n if(lastChar != '/' && lastChar != '\\\\') {\n return dir + \"/\" + filename;\n }\n return dir + filename;\n }\n return filename;\n}\n\nFile::File(const oatpp::String& fullFileName)\n : m_handle(std::make_shared(fullFileName))\n{}\n\nFile::File(const oatpp::String& tmpDirectory, const oatpp::String& tmpFileName)\n : m_handle(std::make_shared(concatDirAndName(tmpDirectory, tmpFileName)))\n{}\n\nstd::shared_ptr File::openOutputStream() {\n if(m_handle) {\n return std::make_shared(m_handle->fileName->c_str(), \"wb\", m_handle);\n }\n throw std::runtime_error(\"[oatpp::data::resource::File::openOutputStream()]: Error. FileHandle is NOT initialized.\");\n}\n\nstd::shared_ptr File::openInputStream() {\n if(m_handle) {\n return std::make_shared(m_handle->fileName->c_str(), m_handle);\n }\n throw std::runtime_error(\"[oatpp::data::resource::File::openInputStream()]: Error. 
FileHandle is NOT initialized.\");\n}\n\noatpp::String File::getInMemoryData() {\n return nullptr;\n}\n\nv_int64 File::getKnownSize() {\n return -1;\n}\n\noatpp::String File::getLocation() {\n if(m_handle) {\n return m_handle->fileName;\n }\n return nullptr;\n}\n\n}}}\n\n// Path: src/oatpp/core/data/resource/TemporaryFile.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_resource_TemporaryFile_hpp\n#define oatpp_data_resource_TemporaryFile_hpp\n\n#include \"./Resource.hpp\"\n#include \"oatpp/core/Types.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\n/**\n * Temporary file - the file which gets deleted when the destructor is called\n * (more precisely when all copies of the same `TemporaryFile` object deleted).
\n * The `TemporaryFile` object internally stores a `shared_ptr` to a file handle.\n * When the file handle is deleted, it also deletes the underlying file.<br>
\n * Thus it's safe to copy a `TemporaryFile` object, and you may treat a `TemporaryFile` object\n * as a shared_ptr to a temporary file.<br>
\n * @extends - &id:oatpp::data::Resource;.\n */\nclass TemporaryFile : public Resource {\nprivate:\n\n /*\n * Shared handle.\n * File is deleted on handle destroy.\n */\n struct FileHandle {\n\n oatpp::String fileName;\n\n FileHandle(const oatpp::String& fullFileName)\n : fileName(fullFileName)\n {}\n\n ~FileHandle();\n };\n\nprivate:\n static oatpp::String constructRandomFilename(const oatpp::String &dir, v_int32 randomWordSizeBytes, const oatpp::String &extension);\nprivate:\n std::shared_ptr m_handle;\npublic:\n\n /**\n * Default constructor.\n */\n TemporaryFile() = default;\n\n /**\n * Constructor.
\n * Create temporary file with a random name in the `tmpDirectory`.
\n * The actual file will be created only after first write to that file.
\n * Example of the generated random file name: `f7c6ecd44024ff31.tmp`.\n * @param tmpDirectory - directory where to create a temporary file.\n * @param randomWordSizeBytes - number of random bytes to generate file name.\n */\n TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes = 8);\n\n /**\n * Constructor.
\n * Create temporary file with the `tmpFileName` name in the `tmpDirectory`.
\n * @param tmpDirectory - directory where to create a temporary file.\n * @param tmpFileName - predefined name for the temporary file.\n */\n TemporaryFile(const oatpp::String& tmpDirectory, const oatpp::String& tmpFileName);\n\n /**\n * Constructor.
\n * Create temporary file with a random name and specified extension in the `tmpDirectory`.
\n * The actual file will be created only after first write to that file.
\n * Example of the generated random file name: `f7c6ecd44024ff31.txt`.\n * @param tmpDirectory - directory where to create a temporary file.\n * @param randomWordSizeBytes - number of random bytes to generate file name.\n * @param extension - extension of the temporary file, e.g. txt or .txt\n */\n TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes, const oatpp::String& extension);\n\n /**\n * Open output stream to a temporary file.
\n * *Note: stream also captures file-handle. The temporary file won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` to &id:oatpp::data::stream::OutputStream;.\n */\n std::shared_ptr openOutputStream() override;\n\n /**\n * Open input stream to a temporary file.
\n * *Note: stream also captures file-handle. The temporary file won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` &id:oatpp::data::stream::InputStream;.\n */\n std::shared_ptr openInputStream() override;\n\n /**\n * Not applicable.\n * @return - always returns `nullptr`.\n */\n oatpp::String getInMemoryData() override;\n\n /**\n * Not applicable.\n * @return - always returns `-1`.\n */\n v_int64 getKnownSize() override;\n\n /**\n * Get location where temporary data is stored.\n * @return - `&id:oatpp::String;`.\n */\n oatpp::String getLocation() override;\n\n /**\n * Move payload to a different file.
\n * @param fullFileName - full-file-name. Note: all the parent folders must exist.\n * @return - `true` - file was successfully moved, `false` - otherwise.\n */\n bool moveFile(const oatpp::String& fullFileName);\n\n};\n\n}}}\n\n#endif //oatpp_data_resource_TemporaryFile_hpp\n\n// Path: src/oatpp/core/data/resource/TemporaryFile.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"TemporaryFile.hpp\"\n\n#include \"./File.hpp\"\n\n#include \"oatpp/core/data/stream/FileStream.hpp\"\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n#include \"oatpp/encoding/Hex.hpp\"\n#include \"oatpp/core/utils/Random.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\nTemporaryFile::FileHandle::~FileHandle() {\n if(fileName) {\n std::remove(fileName->c_str());\n }\n}\n\noatpp::String TemporaryFile::constructRandomFilename(const oatpp::String &dir, v_int32 randomWordSizeBytes, const oatpp::String &extension) {\n\n std::unique_ptr buff(new v_char8[randomWordSizeBytes]);\n utils::random::Random::randomBytes(buff.get(), randomWordSizeBytes);\n data::stream::BufferOutputStream s(randomWordSizeBytes * 2 + 4);\n encoding::Hex::encode(&s, buff.get(), randomWordSizeBytes, encoding::Hex::ALPHABET_LOWER);\n if (extension->at(0) != '.') {\n s << \".\";\n }\n s << extension;\n\n return File::concatDirAndName(dir, s.toString());\n\n}\n\nTemporaryFile::TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes)\n : m_handle(std::make_shared(constructRandomFilename(tmpDirectory, randomWordSizeBytes, \"tmp\")))\n{}\n\nTemporaryFile::TemporaryFile(const oatpp::String& tmpDirectory, const oatpp::String& tmpFileName)\n : m_handle(std::make_shared(File::concatDirAndName(tmpDirectory, tmpFileName)))\n{}\n\nTemporaryFile::TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes, const oatpp::String& extension)\n : m_handle(std::make_shared(constructRandomFilename(tmpDirectory, randomWordSizeBytes, extension)))\n{}\n\nstd::shared_ptr TemporaryFile::openOutputStream() {\n if(m_handle) {\n return std::make_shared(m_handle->fileName->c_str(), \"wb\", m_handle);\n }\n throw std::runtime_error(\"[oatpp::data::resource::TemporaryFile::openOutputStream()]: Error. FileHandle is NOT initialized.\");\n}\n\nstd::shared_ptr TemporaryFile::openInputStream() {\n if(m_handle) {\n return std::make_shared(m_handle->fileName->c_str(), m_handle);\n }\n throw std::runtime_error(\"[oatpp::data::resource::TemporaryFile::openInputStream()]: Error. 
FileHandle is NOT initialized.\");\n}\n\noatpp::String TemporaryFile::getInMemoryData() {\n return nullptr;\n}\n\nv_int64 TemporaryFile::getKnownSize() {\n return -1;\n}\n\noatpp::String TemporaryFile::getLocation() {\n if(m_handle) {\n return m_handle->fileName;\n }\n return nullptr;\n}\n\nbool TemporaryFile::moveFile(const oatpp::String& fullFileName) {\n if(m_handle) {\n return std::rename(m_handle->fileName->c_str(), fullFileName->c_str()) == 0;\n }\n return false;\n}\n\n}}}\n\n// Path: src/oatpp/core/data/buffer/Processor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_buffer_Processor_hpp\n#define oatpp_data_buffer_Processor_hpp\n\n#include \"oatpp/core/IODefinitions.hpp\"\n#include \"oatpp/core/base/ObjectHandle.hpp\"\n#include \n\nnamespace oatpp { namespace data { namespace buffer {\n\n/**\n * Convenience structure for stream Async-Inline read operations.\n */\nstruct InlineReadData {\n\n /**\n * Pointer to current position in the buffer.\n */\n void* currBufferPtr;\n\n /**\n * Bytes left to read to the buffer.\n */\n v_buff_size bytesLeft;\n\n /**\n * Default constructor.\n */\n InlineReadData();\n\n /**\n * Constructor.\n * @param data\n * @param size\n */\n InlineReadData(void* data, v_buff_size size);\n\n /**\n * Set `currBufferPtr` and `bytesLeft` values.
\n * @param data - pointer to buffer to store read data.\n * @param size - size in bytes of the buffer.\n */\n void set(void* data, v_buff_size size);\n\n /**\n * Increase position in the read buffer by `amount` bytes.
\n * This will increase `currBufferPtr` and descrease `bytesLeft` values.\n * @param amount\n */\n void inc(v_buff_size amount);\n\n /**\n * Same as `inc(bytesLeft).`\n */\n void setEof();\n\n};\n\n/**\n * Convenience structure for stream Async-Inline write operations.\n */\nstruct InlineWriteData {\n\n /**\n * Pointer to current position in the buffer.\n */\n const void* currBufferPtr;\n\n /**\n * Bytes left to write from the buffer.\n */\n v_buff_size bytesLeft;\n\n /**\n * Default constructor.\n */\n InlineWriteData();\n\n /**\n * Constructor.\n * @param data\n * @param size\n */\n InlineWriteData(const void* data, v_buff_size size);\n\n /**\n * Set `currBufferPtr` and `bytesLeft` values.
\n * @param data - pointer to buffer containing data to be written.\n * @param size - size in bytes of the buffer.\n */\n void set(const void* data, v_buff_size size);\n\n /**\n * Increase position in the write buffer by `amount` bytes.
\n * This will increase `currBufferPtr` and descrease `bytesLeft` values.\n * @param amount\n */\n void inc(v_buff_size amount);\n\n /**\n * Same as `inc(bytesLeft).`\n */\n void setEof();\n\n};\n\n/**\n * Buffer processor.\n * Note: all processors are considered to be stateful.\n */\nclass Processor {\npublic:\n\n /**\n * Enum of processing errors.\n */\n enum Error : v_int32 {\n\n /**\n * No error.\n */\n OK = 0,\n\n /**\n * Caller must set fields of `dataIn` parameter.\n */\n PROVIDE_DATA_IN = 1,\n\n /**\n * Caller must read all the data from the `dataOut`.\n */\n FLUSH_DATA_OUT = 2,\n\n /**\n * Processing is finished.\n */\n FINISHED = 3\n\n //*********************************************//\n // Other values are processor-specific errors. //\n //*********************************************//\n };\n\npublic:\n\n /**\n * Default virtual destructor.\n */\n virtual ~Processor() = default;\n\n /**\n * If the client is using the input stream to read data and add it to the processor,\n * the client MAY ask the processor for a suggested read size.\n * @return - suggested read size.\n */\n virtual v_io_size suggestInputStreamReadSize() = 0;\n\n /**\n * Process data.\n * @param dataIn - data provided by client to processor. Input data. &id:data::buffer::InlineReadData;.\n * Set `dataIn` buffer pointer to `nullptr` to designate the end of input.\n * @param dataOut - data provided to client by processor. Output data. &id:data::buffer::InlineReadData;.\n * @return - &l:Processor::Error;.\n */\n virtual v_int32 iterate(data::buffer::InlineReadData& dataIn,\n data::buffer::InlineReadData& dataOut) = 0;\n\n};\n\n/**\n * Pipeline of buffer processors.\n */\nclass ProcessingPipeline : public Processor {\nprivate:\n std::vector> m_processors;\n std::vector m_intermediateData;\npublic:\n\n /**\n * Constructor.\n * @param m_processors - the array of processors defining the pipeline.\n */\n ProcessingPipeline(const std::vector>& m_processors);\n\n /**\n * If the client is using the input stream to read data and add it to the processor,\n * the client MAY ask the processor for a suggested read size.\n * @return - suggested read size.\n */\n v_io_size suggestInputStreamReadSize() override;\n\n /**\n * Process data.\n * @param dataIn - data provided by client to processor. Input data. &id:data::buffer::InlineReadData;.\n * Set `dataIn` buffer pointer to `nullptr` to designate the end of input.\n * @param dataOut - data provided to client by processor. Output data. 
&id:data::buffer::InlineReadData;.\n * @return - &l:Processor::Error;.\n */\n v_int32 iterate(data::buffer::InlineReadData& dataIn,\n data::buffer::InlineReadData& dataOut) override;\n\n};\n\n}}}\n\n#endif // oatpp_data_buffer_Processor_hpp\n\n// Path: src/oatpp/core/data/buffer/Processor.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Processor.hpp\"\n\nnamespace oatpp { namespace data{ namespace buffer {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// InlineReadData\n\nInlineReadData::InlineReadData()\n : currBufferPtr(nullptr)\n , bytesLeft(0)\n{}\n\nInlineReadData::InlineReadData(void* data, v_buff_size size)\n : currBufferPtr(data)\n , bytesLeft(size)\n{}\n\nvoid InlineReadData::set(void* data, v_buff_size size) {\n currBufferPtr = data;\n bytesLeft = size;\n}\n\nvoid InlineReadData::inc(v_buff_size amount) {\n currBufferPtr = &(reinterpret_cast(currBufferPtr))[amount];\n bytesLeft -= amount;\n}\n\nvoid InlineReadData::setEof() {\n currBufferPtr = &(reinterpret_cast(currBufferPtr))[bytesLeft];\n bytesLeft = 0;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// InlineWriteData\n\nInlineWriteData::InlineWriteData()\n : currBufferPtr(nullptr)\n , bytesLeft(0)\n{}\n\nInlineWriteData::InlineWriteData(const void* data, v_buff_size size)\n : currBufferPtr(data)\n , bytesLeft(size)\n{}\n\nvoid InlineWriteData::set(const void* data, v_buff_size size) {\n currBufferPtr = data;\n bytesLeft = size;\n}\n\nvoid InlineWriteData::inc(v_buff_size amount) {\n currBufferPtr = &(reinterpret_cast(const_cast(currBufferPtr)))[amount];\n bytesLeft -= amount;\n}\n\nvoid InlineWriteData::setEof() {\n currBufferPtr = &(reinterpret_cast(const_cast(currBufferPtr)))[bytesLeft];\n bytesLeft = 0;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n...\n", "needle_token_start": 7087, "needle_token_end": 7111, "code_context_ntokens": 15764, "output": ["```cpp\nclass InMemoryData : public Resource {\nprivate:\n struct DataHandle {\n oatpp::String data;\n DataHandle(const oatpp::String& pData) : data(pData) {}\n };\n std::shared_ptr m_handle;\npublic:\n InMemoryData() = default;\n InMemoryData(const oatpp::String& data);\n std::shared_ptr openOutputStream() override;\n std::shared_ptr openInputStream() override;\n oatpp::String getInMemoryData() override;\n v_int64 getKnownSize() override;\n oatpp::String getLocation() override;\n};\n```"]} +{"repo": "oatpp/oatpp", "name": "ListenerCall", "language": "cpp", "path": 
"src/oatpp/web/mime/multipart/StatefulParser.hpp", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this class is to manage the invocation of specific callback functions based on the type of event occurring during the parsing of multipart data streams. It determines whether headers or data chunks are being processed and triggers the appropriate event handlers.\n2. **Input**: This class takes no direct inputs when instantiated but uses internal state variables to manage the type of callback to invoke (either for headers or data chunks) and the associated data.\n3. **Output**: There is no direct output produced by instances of this class. Instead, it indirectly affects the system by triggering specific callback functions which handle the parsed data.\n4. **Procedure**: Upon instantiation, the class initializes with no scheduled callback. When the parsing process identifies a new part or data chunk, the class sets the appropriate callback type and data pointers. Depending on the internal state, it then triggers either the header processing or data processing callbacks at the appropriate times in the parsing lifecycle.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/encoding/Unicode.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Unicode.hpp\"\n\n#if defined(WIN32) || defined(_WIN32)\n #include \n#else\n #include \n#endif\n\nnamespace oatpp { namespace encoding {\n \nv_buff_size Unicode::getUtf8CharSequenceLength(v_char8 firstByte) {\n \n if(firstByte < 128){\n return 1;\n }\n \n if((firstByte | 192) != firstByte){\n return 0;\n }\n \n if((firstByte | 32) != firstByte){\n return 2;\n } else if((firstByte | 16) != firstByte){\n return 3;\n } else if((firstByte | 8) != firstByte){\n return 4;\n } else if((firstByte | 4) != firstByte){\n return 5;\n } else if((firstByte | 2) != firstByte){\n return 6;\n } else {\n return 0;\n }\n \n}\n \nv_buff_size Unicode::getUtf8CharSequenceLengthForCode(v_uint32 code){\n if(code < 128) {\n return 1;\n } else if(code < 0x00000800){\n return 2;\n } else if(code < 0x00010000){\n return 3;\n } else if(code < 0x00200000){\n return 4;\n } else if(code < 0x04000000){\n return 5;\n } else {\n return 6;\n }\n}\n \nv_int32 Unicode::encodeUtf8Char(const char* sequence, v_buff_size& length){\n v_char8 byte = static_cast(sequence[0]);\n if(byte > 127){\n v_int32 code;\n if((byte | 32) != byte){\n length = 2;\n code = 
((31 & byte) << 6) | (sequence[1] & 63);\n return code;\n } else if((byte | 16) != byte){\n code = (15 & byte) << 12;\n length = 3;\n } else if((byte | 8) != byte){\n length = 4;\n v_int32 value = *(reinterpret_cast(const_cast(sequence)));\n code = ((7 & byte) << 18) |\n (((value >> 24) & 0xFF) & 63) |\n (((value >> 16) & 0xFF) & 63) << 6 |\n (((value >> 8) & 0xFF) & 63) << 12;\n return code;\n } else if((byte | 4) != byte){\n code = (3 & byte) << 24;\n length = 5;\n } else if((byte | 2) != byte){\n code = (1 & byte) << 30;\n length = 6;\n } else {\n return -1;\n }\n \n v_char8 bitIndex = 0;\n for(v_buff_size i = length; i > 1; i--){\n code |= (sequence[i - 1] & 63) << bitIndex;\n bitIndex = static_cast(bitIndex + 6);\n }\n return code;\n } else {\n length = 1;\n return byte;\n }\n}\n \nv_buff_size Unicode::decodeUtf8Char(v_int32 signed_code, p_char8 buffer) {\n v_uint32 code = static_cast(signed_code);\n if(code >= 0x00000080 && code < 0x00000800){\n *(reinterpret_cast(buffer)) = static_cast(htons(((((code >> 6) & 31) | 192) << 8) | ((code & 63) | 128)));\n return 2;\n } else if(code >= 0x00000800 && code < 0x00010000){\n...\n// Path: src/oatpp/encoding/Url.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_encoding_Url_hpp\n#define oatpp_encoding_Url_hpp\n\n#include \"Hex.hpp\"\n#include \"oatpp/core/data/stream/Stream.hpp\"\n\nnamespace oatpp { namespace encoding {\n\nclass Url {\npublic:\n\n struct Config {\n\n bool spaceToPlus = false;\n const char* hexAlphabet = Hex::ALPHABET_UPPER;\n bool allowedChars[256];\n\n Config();\n\n void allowChar(v_char8 c);\n void allowCharRange(v_char8 from, v_char8 to);\n\n void disallowChar(v_char8 c);\n void disallowCharRange(v_char8 from, v_char8 to);\n\n };\n\npublic:\n\n static void encode(data::stream::ConsistentOutputStream* stream, const void* data, v_buff_size size, const Config& config);\n static void decode(data::stream::ConsistentOutputStream* stream, const void* data, v_buff_size size);\n\n static oatpp::String encode(const oatpp::String& data, const Config& config);\n static oatpp::String decode(const oatpp::String& data);\n\n};\n\n}}\n\n#endif //oatpp_encoding_Url_hpp\n\n// Path: src/oatpp/encoding/Url.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * 
http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Url.hpp\"\n\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n\nnamespace oatpp { namespace encoding {\n\nUrl::Config::Config() {\n\n disallowCharRange(0, 255);\n\n allowCharRange('0', '9');\n allowCharRange('a', 'z');\n allowCharRange('A', 'Z');\n\n allowChar('-');\n allowChar('.');\n allowChar('_');\n allowChar('~');\n\n}\n\nvoid Url::Config::allowChar(v_char8 c) {\n allowedChars[c] = true;\n}\n\nvoid Url::Config::allowCharRange(v_char8 from, v_char8 to) {\n for(v_int32 i = from; i <= to; i++) {\n allowedChars[i] = true;\n }\n}\n\nvoid Url::Config::disallowChar(v_char8 c) {\n allowedChars[c] = false;\n}\n\nvoid Url::Config::disallowCharRange(v_char8 from, v_char8 to) {\n for(v_int32 i = from; i <= to; i++) {\n allowedChars[i] = false;\n }\n}\n\nvoid Url::encode(data::stream::ConsistentOutputStream *stream, const void *data, v_buff_size size, const Config& config) {\n\n auto pdata = reinterpret_cast(data);\n\n for(v_buff_size i = 0; i < size; i++) {\n v_char8 c = static_cast(pdata[i]);\n if(config.allowedChars[c]) {\n stream->writeCharSimple(c);\n } else if(c == ' ' && config.spaceToPlus) {\n stream->writeCharSimple('+');\n } else {\n stream->writeCharSimple('%');\n Hex::encode(stream, pdata + i, 1, config.hexAlphabet);\n }\n }\n\n}\n\nvoid Url::decode(data::stream::ConsistentOutputStream* stream, const void* data, v_buff_size size) {\n\n auto pdata = reinterpret_cast(data);\n v_buff_size i = 0;\n\n while (i < size) {\n\n v_char8 c = static_cast(pdata[i]);\n if(c == '%') {\n if(size - i > 1) {\n Hex::decode(stream, pdata + i + 1, 2);\n i += 3;\n } else {\n break;\n }\n } else if (c == '+') {\n stream->writeCharSimple(' ');\n i ++;\n } else {\n stream->writeCharSimple(c);\n i ++;\n }\n\n }\n\n}\n\noatpp::String Url::encode(const oatpp::String& data, const Config& config) {\n if(!data) return nullptr;\n data::stream::BufferOutputStream stream(static_cast(data->size() * 3));\n encode(&stream, data->data(), static_cast(data->size()), config);\n return stream.toString();\n}\n\noatpp::String Url::decode(const oatpp::String& data) {\n if(!data) return nullptr;\n data::stream::BufferOutputStream stream(static_cast(data->size()));\n decode(&stream, data->data(), static_cast(data->size()));\n return stream.toString();\n}\n\n}}\n\n// Path: src/oatpp/algorithm/CRC.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language 
governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_algorithm_CRC_hpp\n#define oatpp_algorithm_CRC_hpp\n\n#include \"oatpp/core/base/Environment.hpp\"\n\n#include \"oatpp/encoding/Hex.hpp\"\n\nnamespace oatpp { namespace algorithm {\n\n/**\n * Implementation of CRC-32. Cyclic redundancy check algorithm.\n */\nclass CRC32 {\npublic:\n\n /**\n * Precalculated table\n */\n static const p_uint32 TABLE_04C11DB7;\npublic:\n\n static v_uint32 bitReverse(v_uint32 poly);\n \n /**\n * Generates v_uint32 table[256] for polynomial\n */\n static p_uint32 generateTable(v_uint32 poly);\n\n /**\n * Calculate CRC32 value for buffer of defined size\n * @param buffer\n * @param size\n * @param crc\n * @param initValue\n * @param xorOut\n * @param table\n * @return - CRC32 value (v_uint32)\n */\n static v_uint32 calc(const void *buffer, v_buff_size size, v_uint32 crc = 0, v_uint32 initValue = 0xFFFFFFFF, v_uint32 xorOut = 0xFFFFFFFF, p_uint32 table = TABLE_04C11DB7);\n \n};\n \n}}\n\n#endif /* oatpp_algorithm_CRC_hpp */\n\n// Path: src/oatpp/algorithm/CRC.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"CRC.hpp\"\n\nnamespace oatpp { namespace algorithm {\n \nconst p_uint32 CRC32::TABLE_04C11DB7 = generateTable(0x04C11DB7);\n \nv_uint32 CRC32::bitReverse(v_uint32 poly) {\n v_uint32 result = 0;\n for(v_int32 i = 0; i < 32; i ++) {\n if((poly & (1U << i)) > 0) {\n result |= (1U << (31 - i));\n }\n }\n return result;\n}\n \np_uint32 CRC32::generateTable(v_uint32 poly) {\n \n p_uint32 result = new v_uint32[256];\n v_uint32 polyReverse = bitReverse(poly);\n v_uint32 value;\n \n for(v_uint32 i = 0; i < 256; i++) {\n value = i;\n for (v_int32 bit = 0; bit < 8; bit++) {\n if (value & 1) {\n value = (value >> 1) ^ polyReverse;\n } else {\n value = (value >> 1);\n }\n }\n \n result[i] = value;\n \n }\n \n return result;\n \n}\n \nv_uint32 CRC32::calc(const void *buffer, v_buff_size size, v_uint32 crc, v_uint32 initValue, v_uint32 xorOut, p_uint32 table) {\n \n auto data = reinterpret_cast(buffer);\n crc = crc ^ initValue;\n \n for(v_buff_size i = 0; i < size; i++) {\n crc = table[(crc & 0xFF) ^ data[i]] ^ (crc >> 8);\n }\n \n return crc ^ xorOut;\n}\n \n}}\n\n// Path: src/oatpp/web/mime/multipart/Part.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n 
* you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_Part_hpp\n#define oatpp_web_mime_multipart_Part_hpp\n\n#include \"oatpp/core/data/share/LazyStringMap.hpp\"\n#include \"oatpp/core/data/resource/Resource.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n/**\n * One part of the multipart.\n */\nclass Part {\npublic:\n /**\n * Typedef for headers map. Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\nprivate:\n oatpp::String m_name;\n oatpp::String m_filename;\n Headers m_headers;\n std::shared_ptr m_payload;\nprivate:\n const char* m_tagName;\n std::shared_ptr m_tagObject;\npublic:\n\n /**\n * Default constructor.\n */\n Part() = default;\n\n /**\n * Constructor.\n * @param headers - headers of the part.\n * @param payload - part payload.\n */\n Part(const Headers& headers, const std::shared_ptr& payload = nullptr);\n\n /**\n * Set payload.\n * @param payload\n */\n void setPayload(const std::shared_ptr& payload);\n\n /**\n * Get payload.\n * @return\n */\n std::shared_ptr getPayload();\n\n /**\n * Get name of the part.\n * @return - name of the part.\n */\n oatpp::String getName() const;\n\n /**\n * Get filename of the part (if applicable).\n * @return - filename.\n */\n oatpp::String getFilename() const;\n\n /**\n * Get request's headers map.\n * @return Headers map\n */\n const Headers& getHeaders() const;\n\n /**\n * Get header value\n * @param headerName\n * @return header value\n */\n oatpp::String getHeader(const oatpp::data::share::StringKeyLabelCI& headerName) const;\n\n /**\n * Add http header.\n * @param key - &id:oatpp::data::share::StringKeyLabelCI;.\n * @param value - &id:oatpp::data::share::StringKeyLabel;.\n */\n void putHeader(const oatpp::data::share::StringKeyLabelCI& key, const oatpp::data::share::StringKeyLabel& value);\n\n /**\n * Add http header if not already exists.\n * @param key - &id:oatpp::data::share::StringKeyLabelCI;.\n * @param value - &id:oatpp::data::share::StringKeyLabel;.\n * @return - `true` if header was added.\n */\n bool putHeaderIfNotExists(const oatpp::data::share::StringKeyLabelCI& key, const oatpp::data::share::StringKeyLabel& value);\n\n /**\n * Tag-object - object used to associate some data with the Part.
\n * Ex.: used by &id:oatpp::web::mime::multipart::InMemoryPartReader;. to\n * associate intermediate buffer with the part.\n * @param tagName\n * @param tagObject\n */\n void setTag(const char* tagName, const std::shared_ptr& tagObject);\n\n /**\n * Get tag name.\n * @return\n */\n const char* getTagName();\n\n /**\n * Get tag object.\n * @return\n */\n std::shared_ptr getTagObject();\n\n /**\n * Clear the tag.\n */\n void clearTag();\n\n};\n\n}}}}\n\n#endif // oatpp_web_mime_multipart_Part_hpp\n\n// Path: src/oatpp/web/mime/multipart/Part.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Part.hpp\"\n\n#include \"oatpp/web/protocol/http/Http.hpp\"\n#include \"oatpp/core/parser/Caret.hpp\"\n\n#include \n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nPart::Part(const Headers &headers, const std::shared_ptr& payload)\n : m_headers(headers)\n , m_payload(payload)\n{\n\n auto contentDisposition = m_headers.getAsMemoryLabel(\"Content-Disposition\");\n if(contentDisposition) {\n\n oatpp::web::protocol::http::HeaderValueData valueData;\n oatpp::web::protocol::http::Parser::parseHeaderValueData(valueData, contentDisposition, ';');\n\n m_name = valueData.getTitleParamValue(\"name\");\n m_filename = valueData.getTitleParamValue(\"filename\");\n\n }\n\n}\n\nvoid Part::setPayload(const std::shared_ptr& payload) {\n m_payload = payload;\n}\n\nstd::shared_ptr Part::getPayload() {\n return m_payload;\n}\n\noatpp::String Part::getName() const {\n return m_name;\n}\n\n\noatpp::String Part::getFilename() const {\n return m_filename;\n}\n\n\nconst Part::Headers& Part::getHeaders() const {\n return m_headers;\n}\n\noatpp::String Part::getHeader(const oatpp::data::share::StringKeyLabelCI& headerName) const {\n return m_headers.get(headerName);\n}\n\nvoid Part::putHeader(const oatpp::data::share::StringKeyLabelCI& key, const oatpp::data::share::StringKeyLabel& value) {\n m_headers.put(key, value);\n}\n\nbool Part::putHeaderIfNotExists(const oatpp::data::share::StringKeyLabelCI& key, const oatpp::data::share::StringKeyLabel& value) {\n return m_headers.putIfNotExists(key, value);\n}\n\nvoid Part::setTag(const char* tagName, const std::shared_ptr& tagObject) {\n m_tagName = tagName;\n m_tagObject = tagObject;\n}\n\nconst char* Part::getTagName() {\n return m_tagName;\n}\n\nstd::shared_ptr Part::getTagObject() {\n return m_tagObject;\n}\n\nvoid Part::clearTag() {\n m_tagName = nullptr;\n m_tagObject.reset();\n}\n\n}}}}\n\n// Path: src/oatpp/web/mime/multipart/StatefulParser.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ 
_| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_StatefulParser_hpp\n#define oatpp_web_mime_multipart_StatefulParser_hpp\n\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n#include \"oatpp/core/data/share/LazyStringMap.hpp\"\n#include \"oatpp/core/Types.hpp\"\n\n#include \n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n/**\n * Stateful parser of multipart-data stream.\n * Parser designed to work with stream-like data in order to store minimum data in the memory.\n */\nclass StatefulParser {\nprivate:\n static constexpr v_int32 STATE_BOUNDARY = 0;\n static constexpr v_int32 STATE_AFTER_BOUNDARY = 1;\n static constexpr v_int32 STATE_HEADERS = 2;\n static constexpr v_int32 STATE_DATA = 3;\n static constexpr v_int32 STATE_DONE = 4;\nprivate:\n static constexpr v_uint32 HEADERS_SECTION_END = ('\\r' << 24) | ('\\n' << 16) | ('\\r' << 8) | ('\\n');\nprivate:\n /**\n * Typedef for headers map. Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\npublic:\n\n /**\n * Listener for parsed items.\n */\n class Listener {\n public:\n /**\n * Typedef for headers map. Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\n public:\n\n /**\n * Default virtual Destructor.\n */\n virtual ~Listener() = default;\n\n /**\n * Called on new part found in the stream.\n * Always called before `onPartData` events.\n * @param partHeaders - complete set of part headers.\n */\n virtual void onPartHeaders(const Headers& partHeaders) = 0;\n\n /**\n * Called on each new chunk of bytes parsed from the part body.\n * When all data of message is read, readMessage is called again with size == 0 to\n * indicate end of the part.\n * @param data - pointer to data.\n * @param size - size of the data in bytes.\n */\n virtual void onPartData(const char* data, v_buff_size size) = 0;\n\n };\n\npublic:\n\n /**\n * Async Listener for parsed items.\n */\n class AsyncListener {\n public:\n /**\n * Typedef for headers map. 
Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\n public:\n\n /**\n * Default virtual Destructor.\n */\n virtual ~AsyncListener() = default;\n\n /**\n * Called on new part found in the stream.\n * Always called before `onPartData` events.\n * @param partHeaders - complete set of part headers.\n */\n virtual async::CoroutineStarter onPartHeadersAsync(const Headers& partHeaders) = 0;\n\n /**\n * Called on each new chunk of bytes parsed from the part body.\n * When all data of message is read, readMessage is called again with size == 0 to\n * indicate end of the part.\n * @param data - pointer to data.\n * @param size - size of the data in bytes.\n */\n virtual async::CoroutineStarter onPartDataAsync(const char* data, v_buff_size size) = 0;\n\n };\n\nprivate:\n\n class ListenerCall {\n public:\n\n static constexpr v_int32 CALL_NONE = 0;\n static constexpr v_int32 CALL_ON_HEADERS = 1;\n static constexpr v_int32 CALL_ON_DATA = 2;\n\n public:\n\n \nListenerCall()\n : callType(CALL_NONE)\n , data(nullptr)\n , size(0)\n {}\n\n v_int32 callType;\n const char* data;\n v_io_size size;\n\n void setOnHeadersCall();\n void setOnDataCall(const char* pData, v_buff_size pSize);\n\n void call(StatefulParser* parser);\n async::CoroutineStarter callAsync(StatefulParser* parser);\n\n explicit operator bool() const;\n\n };\n\nprivate:\n\n v_int32 m_state;\n v_int64 m_currPartIndex;\n v_buff_size m_currBoundaryCharIndex;\n bool m_checkForBoundary;\n bool m_finishingBoundary;\n bool m_readingBody;\n\n v_uint32 m_headerSectionEndAccumulator;\n\n oatpp::String m_firstBoundarySample;\n oatpp::String m_nextBoundarySample;\n\n /*\n * Headers of the part are stored in the buffer and are parsed as one chunk.\n */\n data::stream::BufferOutputStream m_headersBuffer;\n\n /*\n * Max length of all headers per one part.\n * Default value = 4096 bytes.\n */\n v_buff_size m_maxPartHeadersSize;\n\n std::shared_ptr m_listener;\n std::shared_ptr m_asyncListener;\n\nprivate:\n\n void parseHeaders(Headers& headers);\n\nprivate:\n\n ListenerCall parseNext_Boundary(data::buffer::InlineWriteData& inlineData);\n void parseNext_AfterBoundary(data::buffer::InlineWriteData& inlineData);\n ListenerCall parseNext_Headers(data::buffer::InlineWriteData& inlineData);\n ListenerCall parseNext_Data(data::buffer::InlineWriteData& inlineData);\n\npublic:\n\n /**\n * Constructor.\n * @param boundary - value of multipart boundary.\n * @param listener - &l:StatefulParser::Listener;.\n * @param asyncListener - &l:StatefulParser::AsyncListener;.\n */\n StatefulParser(const oatpp::String& boundary,\n const std::shared_ptr& listener,\n const std::shared_ptr& asyncListener);\n\n /**\n * Parse next chunk of bytes.\n * @param inlineData - inline data.\n * @param action - Async Action in case Async Listener was provided in constructor.\n */\n void parseNext(data::buffer::InlineWriteData& inlineData, async::Action& action);\n\n /**\n * Check if parser done parsing data.\n * @return - `true` or `false`.\n */\n bool finished();\n\n};\n\n}}}}\n\n#endif // oatpp_web_mime_multipart_StatefulParser_hpp\n\n// Path: src/oatpp/web/mime/multipart/StatefulParser.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache 
License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"StatefulParser.hpp\"\n\n#include \"oatpp/web/protocol/http/Http.hpp\"\n\n#include \"oatpp/core/parser/Caret.hpp\"\n#include \"oatpp/core/parser/ParsingError.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// StatefulParser::ListenerCall\n\nvoid StatefulParser::ListenerCall::setOnHeadersCall() {\n callType = CALL_ON_HEADERS;\n data = nullptr;\n size = 0;\n}\n\nvoid StatefulParser::ListenerCall::setOnDataCall(const char* pData, v_buff_size pSize) {\n callType = CALL_ON_DATA;\n data = pData;\n size = pSize;\n}\n\nvoid StatefulParser::ListenerCall::call(StatefulParser* parser) {\n\n if(parser->m_listener) {\n\n switch(callType) {\n\n case CALL_ON_HEADERS:\n {\n Headers headers;\n parser->parseHeaders(headers);\n parser->m_listener->onPartHeaders(headers);\n }\n break;\n\n case CALL_ON_DATA:\n parser->m_listener->onPartData(data, size);\n break;\n\n default:\n break;\n\n }\n\n }\n\n}\n\nasync::CoroutineStarter StatefulParser::ListenerCall::callAsync(StatefulParser* parser) {\n\n if(parser->m_asyncListener) {\n\n switch(callType) {\n\n case CALL_ON_HEADERS: {\n Headers headers;\n parser->parseHeaders(headers);\n return parser->m_asyncListener->onPartHeadersAsync(headers);\n }\n\n case CALL_ON_DATA:\n return parser->m_asyncListener->onPartDataAsync(data, size);\n\n default:\n break;\n\n }\n\n }\n\n return nullptr;\n\n}\n\nStatefulParser::ListenerCall::operator bool() const {\n return callType != CALL_NONE;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// StatefulParser\n\nStatefulParser::StatefulParser(const oatpp::String& boundary,\n const std::shared_ptr& listener,\n const std::shared_ptr& asyncListener)\n : m_state(STATE_BOUNDARY)\n , m_currPartIndex(0)\n , m_currBoundaryCharIndex(0)\n , m_checkForBoundary(true)\n , m_finishingBoundary(false)\n , m_readingBody(false)\n , m_headerSectionEndAccumulator(0)\n , m_firstBoundarySample(\"--\" + boundary)\n , m_nextBoundarySample(\"\\r\\n--\" + boundary)\n , m_maxPartHeadersSize(4092)\n , m_listener(listener)\n , m_asyncListener(asyncListener)\n{}\n\nvoid StatefulParser::parseHeaders(Headers& headers) {\n\n m_currPartIndex ++;\n\n auto headersText = m_headersBuffer.toString();\n m_headersBuffer.setCurrentPosition(0);\n\n protocol::http::Status status;\n parser::Caret caret(headersText);\n\n protocol::http::Parser::parseHeaders(headers, headersText.getPtr(), caret, status);\n\n}\n\nStatefulParser::ListenerCall StatefulParser::parseNext_Boundary(data::buffer::InlineWriteData& inlineData) {\n\n ListenerCall result;\n auto data = inlineData.currBufferPtr;\n auto size = inlineData.bytesLeft;\n\n auto sampleData = m_nextBoundarySample->data();\n v_io_size sampleSize = 
static_cast(m_nextBoundarySample->size());\n\n if (m_currPartIndex == 0) {\n sampleData = m_firstBoundarySample->data();\n sampleSize = static_cast(m_firstBoundarySample->size());\n } else {\n sampleData = m_nextBoundarySample->data();\n sampleSize = static_cast(m_nextBoundarySample->size());\n }\n\n v_io_size checkSize = sampleSize - m_currBoundaryCharIndex;\n if(checkSize > size) {\n checkSize = size;\n }\n\n parser::Caret caret(reinterpret_cast(data), size);\n\n if(caret.isAtText(&sampleData[m_currBoundaryCharIndex], checkSize, true)) {\n\n m_currBoundaryCharIndex += caret.getPosition();\n\n if(m_currBoundaryCharIndex == sampleSize) {\n m_state = STATE_AFTER_BOUNDARY;\n m_currBoundaryCharIndex = 0;\n m_readingBody = false;\n if(m_currPartIndex > 0) {\n result.setOnDataCall(nullptr, 0);\n }\n }\n\n inlineData.inc(caret.getPosition());\n return result;\n\n } else if(m_readingBody) {\n\n if(m_currBoundaryCharIndex > 0) {\n result.setOnDataCall(sampleData, m_currBoundaryCharIndex);\n } else {\n m_checkForBoundary = false;\n }\n\n m_state = STATE_DATA;\n m_currBoundaryCharIndex = 0;\n\n return result;\n\n }\n\n throw std::runtime_error(\"[oatpp::web::mime::multipart::StatefulParser::parseNext_Boundary()]: Error. Invalid state.\");\n\n}\n\nvoid StatefulParser::parseNext_AfterBoundary(data::buffer::InlineWriteData& inlineData) {\n\n p_char8 data = reinterpret_cast(const_cast(inlineData.currBufferPtr));\n auto size = inlineData.bytesLeft;\n\n if(m_currBoundaryCharIndex == 0) {\n\n if(data[0] == '-') {\n m_finishingBoundary = true;\n } else if(data[0] != '\\r') {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::StatefulParser::parseNext_AfterBoundary()]: Error. Invalid char.\");\n }\n\n }\n\n if(size > 1 || m_currBoundaryCharIndex == 1) {\n\n if (m_finishingBoundary && data[1 - m_currBoundaryCharIndex] == '-') {\n auto result = 2 - m_currBoundaryCharIndex;\n m_state = STATE_DONE;\n m_currBoundaryCharIndex = 0;\n inlineData.inc(result);\n return;\n } else if (!m_finishingBoundary && data[1 - m_currBoundaryCharIndex] == '\\n') {\n auto result = 2 - m_currBoundaryCharIndex;\n m_state = STATE_HEADERS;\n m_currBoundaryCharIndex = 0;\n m_headerSectionEndAccumulator = 0;\n inlineData.inc(result);\n return;\n } else {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::StatefulParser::parseNext_AfterBoundary()]: Error. Invalid trailing char.\");\n }\n\n }\n\n m_currBoundaryCharIndex = 1;\n inlineData.inc(1);\n return;\n\n}\n\nStatefulParser::ListenerCall StatefulParser::parseNext_Headers(data::buffer::InlineWriteData& inlineData) {\n\n ListenerCall result;\n\n p_char8 data = reinterpret_cast(const_cast(inlineData.currBufferPtr));\n auto size = inlineData.bytesLeft;\n\n for(v_buff_size i = 0; i < size; i ++) {\n\n m_headerSectionEndAccumulator <<= 8;\n m_headerSectionEndAccumulator |= data[i];\n\n if(m_headerSectionEndAccumulator == HEADERS_SECTION_END) {\n\n if(m_headersBuffer.getCurrentPosition() + i > m_maxPartHeadersSize) {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::StatefulParser::parseNext_Headers()]: Error. Too large heades.\");\n }\n\n m_headersBuffer.writeSimple(data, i);\n\n result.setOnHeadersCall();\n\n m_state = STATE_DATA;\n m_checkForBoundary = true;\n\n inlineData.inc(i + 1);\n return result;\n\n }\n\n }\n\n if(m_headersBuffer.getCurrentPosition() + size > m_maxPartHeadersSize) {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::StatefulParser::parseNext_Headers()]: Error. 
Headers section is too large.\");\n }\n\n m_headersBuffer.writeSimple(data, size);\n\n inlineData.inc(size);\n\n return result;\n\n}\n\nStatefulParser::ListenerCall StatefulParser::parseNext_Data(data::buffer::InlineWriteData& inlineData) {\n\n ListenerCall result;\n\n const char* data = reinterpret_cast(inlineData.currBufferPtr);\n auto size = inlineData.bytesLeft;\n\n parser::Caret caret(data, size);\n\n bool rFound = caret.findChar('\\r');\n if(rFound && !m_checkForBoundary) {\n caret.inc();\n rFound = caret.findChar('\\r');\n }\n\n m_checkForBoundary = true;\n\n if(rFound) {\n if(caret.getPosition() > 0) {\n result.setOnDataCall(data, caret.getPosition());\n }\n m_state = STATE_BOUNDARY;\n m_readingBody = true;\n inlineData.inc(caret.getPosition());\n } else {\n result.setOnDataCall(data, size);\n inlineData.inc(size);\n }\n\n return result;\n\n}\n\nvoid StatefulParser::parseNext(data::buffer::InlineWriteData& inlineData, async::Action& action) {\n\n while(inlineData.bytesLeft > 0) {\n\n ListenerCall listenerCall;\n\n switch (m_state) {\n case STATE_BOUNDARY:\n listenerCall = parseNext_Boundary(inlineData);\n break;\n case STATE_AFTER_BOUNDARY:\n parseNext_AfterBoundary(inlineData);\n break;\n case STATE_HEADERS:\n listenerCall = parseNext_Headers(inlineData);\n break;\n case STATE_DATA:\n listenerCall = parseNext_Data(inlineData);\n break;\n case STATE_DONE:\n return;\n default:\n throw std::runtime_error(\"[oatpp::web::mime::multipart::StatefulParser::parseNext()]: Error. Invalid state.\");\n }\n\n if(listenerCall) {\n if(m_asyncListener) {\n action = listenerCall.callAsync(this).next(async::Action::createActionByType(async::Action::TYPE_REPEAT));\n break;\n }else {\n listenerCall.call(this);\n }\n }\n\n }\n\n}\n\nbool StatefulParser::finished() {\n return m_state == STATE_DONE;\n}\n\n}}}}\n\n// Path: src/oatpp/web/mime/multipart/Multipart.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_Multipart_hpp\n#define oatpp_web_mime_multipart_Multipart_hpp\n\n#include \"Part.hpp\"\n#include \n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n/**\n * Typedef for headers map. 
Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\ntypedef oatpp::data::share::LazyStringMultimap Headers;\n\n/**\n * Abstract Multipart.\n */\nclass Multipart {\nprivate:\n oatpp::String m_boundary;\npublic:\n\n /**\n * Constructor.\n * @param boundary - multipart boundary value.\n */\n Multipart(const oatpp::String& boundary);\n\n /**\n * Default virtual Destructor.\n */\n virtual ~Multipart() = default;\n\n /**\n * Get multipart boundary value.\n * @return - multipart boundary value.\n */\n oatpp::String getBoundary();\n\n /**\n * Read part-by-part from Multipart.\n * @return\n */\n virtual std::shared_ptr readNextPart(async::Action& action) = 0;\n\n /**\n * Write part-by-part to Multipart.\n * @param part\n */\n virtual void writeNextPart(const std::shared_ptr& part, async::Action& action) = 0;\n\n /**\n * Read part-by-part from Multipart.
\n * Call writeNextPart(...) and throw if `action.isNone() == false`.\n * @return\n */\n std::shared_ptr readNextPartSimple();\n\n /**\n * Write part-by-part to Multipart.\n * Call writeNextPartSimple(...) and throw if `action.isNone() == false`.\n * @param part\n */\n void writeNextPartSimple(const std::shared_ptr& part);\n\npublic:\n\n /**\n * Generate random boundary for Multipart object. Base64 encoded.\n * @param boundarySize - size in bytes of random vector.\n * @return - &id:oatpp::String;.\n */\n static oatpp::String generateRandomBoundary(v_int32 boundarySize = 15);\n\n /**\n * Parse boundary value from headers\n * @param headers\n * @return\n */\n static oatpp::String parseBoundaryFromHeaders(const Headers& requestHeaders);\n\n};\n\n\n\n}}}}\n\n\n#endif // oatpp_web_mime_multipart_Multipart_hpp\n\n// Path: src/oatpp/web/mime/multipart/PartReader.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_PartReader_hpp\n#define oatpp_web_mime_multipart_PartReader_hpp\n\n#include \"./Multipart.hpp\"\n#include \"oatpp/core/data/stream/Stream.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n/**\n * Abstract read handler of multipart parts.\n */\nclass PartReader {\npublic:\n\n /**\n * Default virtual destructor.\n */\n virtual ~PartReader() = default;\n\n /**\n * Called when new part headers are parsed and part object is created.\n * @param part\n */\n virtual void onNewPart(const std::shared_ptr& part) = 0;\n\n /**\n * Called on each new chunk of data is parsed for the multipart-part.
\n * When all data is read, called again with `data == nullptr && size == 0` to indicate end of the part.\n * @param part\n * @param data - pointer to buffer containing chunk data.\n * @param size - size of the buffer.\n */\n virtual void onPartData(const std::shared_ptr& part, const char* data, oatpp::v_io_size size) = 0;\n\n};\n\n/**\n * Abstract Async read handler of multipart parts.\n */\nclass AsyncPartReader {\npublic:\n\n /**\n * Default virtual destructor.\n */\n virtual ~AsyncPartReader() = default;\n\n /**\n * Called when new part headers are parsed and part object is created.\n * @param part\n * @return - &id:oatpp::async::CoroutineStarter;.\n */\n virtual async::CoroutineStarter onNewPartAsync(const std::shared_ptr& part) = 0;\n\n /**\n * Called on each new chunk of data is parsed for the multipart-part.
\n * When all data is read, called again with `data == nullptr && size == 0` to indicate end of the part.\n * @param part\n * @param data - pointer to buffer containing chunk data.\n * @param size - size of the buffer.\n * @return - &id:oatpp::async::CoroutineStarter;.\n */\n virtual async::CoroutineStarter onPartDataAsync(const std::shared_ptr& part, const char* data, oatpp::v_io_size size) = 0;\n\n};\n\n/**\n * Resource provider for `StreamPartReader`.\n */\nclass PartReaderResourceProvider {\npublic:\n\n /**\n * Default virtual destructor.\n */\n virtual ~PartReaderResourceProvider() = default;\n\n /**\n * Get data resource to write (save) part data in.\n * @param part\n * @return\n */\n virtual std::shared_ptr getResource(const std::shared_ptr& part) = 0;\n\n /**\n * Get data resource to write (save) part data in.\n * @param part\n * @param resource - put here pointer to obtained resource.\n * @return\n */\n virtual async::CoroutineStarter getResourceAsync(const std::shared_ptr& part,\n std::shared_ptr& resource) = 0;\n\n};\n\n/**\n * Part reader used in order to stream part data.\n */\nclass StreamPartReader : public PartReader {\nprivate:\n static const char* const TAG_NAME;\nprivate:\n\n class TagObject : public oatpp::base::Countable {\n public:\n v_io_size size = 0;\n std::shared_ptr resource;\n std::shared_ptr outputStream;\n };\n\nprivate:\n std::shared_ptr m_resourceProvider;\n v_io_size m_maxDataSize;\npublic:\n\n /**\n * Constructor.\n * @param resourceProvider\n * @param maxDataSize - use `-1` for no limit.\n */\n StreamPartReader(const std::shared_ptr& resourceProvider, v_io_size maxDataSize = -1);\n\n /**\n * Called when new part headers are parsed and part object is created.\n * @param part\n */\n void onNewPart(const std::shared_ptr& part) override;\n\n /**\n * Called on each new chunk of data is parsed for the multipart-part.
\n * When all data is read, called again with `data == nullptr && size == 0` to indicate end of the part.\n * @param part\n * @param data - pointer to buffer containing chunk data.\n * @param size - size of the buffer.\n */\n void onPartData(const std::shared_ptr& part, const char* data, oatpp::v_io_size size) override;\n\n};\n\n/**\n * Async part reader used in order to stream part data in Asynchronous manner.\n */\nclass AsyncStreamPartReader : public AsyncPartReader {\nprivate:\n static const char* const TAG_NAME;\nprivate:\n\n class TagObject : public oatpp::base::Countable {\n public:\n v_io_size size = 0;\n std::shared_ptr resource;\n std::shared_ptr outputStream;\n };\n\nprivate:\n std::shared_ptr m_resourceProvider;\n v_io_size m_maxDataSize;\npublic:\n\n /**\n * Constructor.\n * @param resourceProvider\n * @param maxDataSize - use `-1` for no limit.\n */\n AsyncStreamPartReader(const std::shared_ptr& resourceProvider, v_io_size maxDataSize = -1);\n\n /**\n * Called when new part headers are parsed and part object is created.\n * @param part\n * @return - &id:oatpp::async::CoroutineStarter;.\n */\n async::CoroutineStarter onNewPartAsync(const std::shared_ptr& part) override;\n\n /**\n * Called on each new chunk of data is parsed for the multipart-part.
\n * When all data is read, called again with `data == nullptr && size == 0` to indicate end of the part.\n * @param part\n * @param data - pointer to buffer containing chunk data.\n * @param size - size of the buffer.\n * @return - &id:oatpp::async::CoroutineStarter;.\n */\n async::CoroutineStarter onPartDataAsync(const std::shared_ptr& part, const char* data, oatpp::v_io_size size) override;\n\n};\n\n}}}}\n\n#endif //oatpp_web_mime_multipart_PartReader_hpp\n\n// Path: src/oatpp/web/mime/multipart/PartReader.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n...\n", "needle_token_start": 8672, "needle_token_end": 8706, "code_context_ntokens": 15778, "output": ["```cpp\nclass StatefulParser {\nprivate:\n static constexpr v_int32 STATE_BOUNDARY = 0;\n static constexpr v_int32 STATE_AFTER_BOUNDARY = 1;\n static constexpr v_int32 STATE_HEADERS = 2;\n static constexpr v_int32 STATE_DATA = 3;\n static constexpr v_int32 STATE_DONE = 4;\nprivate:\n static constexpr v_uint32 HEADERS_SECTION_END = ('\\r' << 24) | ('\\n' << 16) | ('\\r' << 8) | ('\\n');\nprivate:\n /**\n * Typedef for headers map. Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\npublic:\n\n /**\n * Listener for parsed items.\n */\n class Listener {\n public:\n /**\n * Typedef for headers map. Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\n public:\n\n /**\n * Default virtual Destructor.\n */\n virtual ~Listener() = default;\n\n /**\n * Called on new part found in the stream.\n * Always called before `onPartData` events.\n * @param partHeaders - complete set of part headers.\n */\n virtual void onPartHeaders(const Headers& partHeaders) = 0;\n\n /**\n * Called on each new chunk of bytes parsed from the part body.\n * When all data of message is read, readMessage is called again with size == 0 to\n * indicate end of the part.\n * @param data - pointer to data.\n * @param size - size of the data in bytes.\n */\n virtual void onPartData(const char* data, v_buff_size size) = 0;\n\n };\n\npublic:\n\n /**\n * Async Listener for parsed items.\n */\n class AsyncListener {\n public:\n /**\n * Typedef for headers map. 
Headers map key is case-insensitive.\n * For more info see &id:oatpp::data::share::LazyStringMap;.\n */\n typedef oatpp::data::share::LazyStringMultimap Headers;\n public:\n\n /**\n * Default virtual Destructor.\n */\n virtual ~AsyncListener() = default;\n\n /**\n * Called on new part found in the stream.\n * Always called before `onPartData` events.\n * @param partHeaders - complete set of part headers.\n */\n virtual async::CoroutineStarter onPartHeadersAsync(const Headers& partHeaders) = 0;\n\n /**\n * Called on each new chunk of bytes parsed from the part body.\n * When all data of message is read, readMessage is called again with size == 0 to\n * indicate end of the part.\n * @param data - pointer to data.\n * @param size - size of the data in bytes.\n */\n virtual async::CoroutineStarter onPartDataAsync(const char* data, v_buff_size size) = 0;\n\n };\n\nprivate:\n\n class ListenerCall {\n public:\n\n static constexpr v_int32 CALL_NONE = 0;\n static constexpr v_int32 CALL_ON_HEADERS = 1;\n static constexpr v_int32 CALL_ON_DATA = 2;\n\n public:\n\n \nListenerCall()\n : callType(CALL_NONE)\n , data(nullptr)\n , size(0)\n {}\n\n v_int32 callType;\n const char* data;\n v_io_size size;\n\n void setOnHeadersCall();\n void setOnDataCall(const char* pData, v_buff_size pSize);\n\n void call(StatefulParser* parser);\n async::CoroutineStarter callAsync(StatefulParser* parser);\n\n explicit operator bool() const;\n\n };\n\nprivate:\n\n v_int32 m_state;\n v_int64 m_currPartIndex;\n v_buff_size m_currBoundaryCharIndex;\n bool m_checkForBoundary;\n bool m_finishingBoundary;\n bool m_readingBody;\n\n v_uint32 m_headerSectionEndAccumulator;\n\n oatpp::String m_firstBoundarySample;\n oatpp::String m_nextBoundarySample;\n\n /*\n * Headers of the part are stored in the buffer and are parsed as one chunk.\n */\n data::stream::BufferOutputStream m_headersBuffer;\n\n /*\n * Max length of all headers per one part.\n * Default value = 4096 bytes.\n"]} +{"repo": "oatpp/oatpp", "name": "PairListObjectWrapper", "language": "cpp", "path": "src/oatpp/core/data/mapping/type/PairList.hpp", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To provide a wrapper around a list of key-value pairs, enabling object-oriented operations and type safety in a generic programming context.\n2. **Input**: Receives an initializer list of key-value pairs where keys and values are of templated types.\n3. **Output**: Creates an instance that encapsulates a list of key-value pairs, allowing for further manipulation and access in an object-oriented manner.\n4. 
**Procedure**: Initializes the encapsulated list using the provided initializer list of key-value pairs, leveraging the capabilities of smart pointers for memory management.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/core/data/mapping/type/Primitive.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_type_Primitive_hpp\n#define oatpp_data_type_Primitive_hpp\n\n#include \"./Type.hpp\"\n\n...\n// Path: src/oatpp/core/data/mapping/type/Enum.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_Enum_hpp\n#define oatpp_data_mapping_type_Enum_hpp\n\n#include \"./Any.hpp\"\n#include \"./Primitive.hpp\"\n#include \"oatpp/core/data/share/MemoryLabel.hpp\"\n\n#include \n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\n/**\n * Errors of enum interpretation.\n */\nenum class EnumInterpreterError : v_int32 {\n\n /**\n * The interpretation was successful.\n */\n OK = 0,\n\n /**\n * Wrong `Interpreter` is used to interpret the variable.
\n * This may also occur if for example:
\n * `oatpp::Enum` is passed to interpreter of `oatpp::Enum::NotNull`.\n */\n TYPE_MISMATCH_ENUM = 1,\n\n /**\n * Wrong &id:oatpp::data::mapping::type::Primitive; is passed to interpreter.\n */\n TYPE_MISMATCH_ENUM_VALUE = 2,\n\n /**\n * Interpreter constraint is violated.
\n * The constraint was set to `NotNull` but interpretation to/from `nullptr` is requested.\n */\n CONSTRAINT_NOT_NULL = 3,\n\n /**\n * Enum entry not found.\n */\n ENTRY_NOT_FOUND = 4,\n\n};\n\nnamespace __class {\n\n /**\n * Abstract Enum class.\n */\n class AbstractEnum {\n public:\n static const ClassId CLASS_ID;\n public:\n\n class PolymorphicDispatcher {\n public:\n\n PolymorphicDispatcher(bool pNotNull)\n : notNull(pNotNull)\n {}\n\n virtual ~PolymorphicDispatcher() = default;\n\n const bool notNull;\n\n virtual type::Void createObject() const = 0;\n\n virtual type::Void toInterpretation(const type::Void& enumValue, EnumInterpreterError& error) const = 0;\n virtual type::Void fromInterpretation(const type::Void& interValue, EnumInterpreterError& error) const = 0;\n virtual type::Type* getInterpretationType() const = 0;\n virtual std::vector getInterpretedEnum() const = 0;\n\n };\n\n };\n\n template\n class Enum;\n\n}\n\n/**\n * Enum value info.\n * @tparam T - underlying enum type.\n */\ntemplate\nstruct EnumValueInfo {\n /**\n * Entry value. T - enum type.\n */\n const T value;\n\n /**\n * Index of the entry.\n */\n const v_int32 index;\n\n /**\n * Name of the enum entry or name-qualifier, if qualifier was specified.
\n * &id:oatpp::data::share::StringKeyLabel;.\n */\n const data::share::StringKeyLabel name;\n \n /**\n * Description of the enum etry.
\n * &id:oatpp::data::share::StringKeyLabel;.\n */\n const data::share::StringKeyLabel description;\n};\n\ntemplate\nstruct EnumInfo {\npublic:\n const char* nameQualifier = nullptr;\n std::unordered_map> byName;\n std::unordered_map> byValue;\n std::vector> byIndex;\n};\n\ntemplate\nclass EnumObjectWrapper; // FWD\n\ntemplate\nclass EnumMeta {\n\n template\n friend class __class::Enum;\n\n template\n friend class EnumObjectWrapper;\n\npublic:\n typedef T EnumType;\nprotected:\n static EnumInfo* getInfo() {\n static EnumInfo info;\n return &info;\n }\n};\n\n/**\n * Enum interpreter `AsString`\n * @tparam T \n * @tparam notnull \n */\ntemplate\nclass EnumInterpreterAsString {\npublic:\n typedef String UnderlyingTypeObjectWrapper;\npublic:\n template \n using InterpreterType = EnumInterpreterAsString;\npublic:\n constexpr static bool notNull = notnull;\npublic:\n static Void toInterpretation(const Void& enumValue, EnumInterpreterError& error);\n static Void fromInterpretation(const Void& interValue, EnumInterpreterError& error);\n static Type* getInterpretationType();\n};\n\n/**\n * Enum interpreter `AsNumber`\n * @tparam T \n * @tparam notnull \n */\ntemplate\nclass EnumInterpreterAsNumber {\nprivate:\n typedef typename std::underlying_type::type EnumUnderlyingType;\npublic:\n typedef typename ObjectWrapperByUnderlyingType::ObjectWrapper UnderlyingTypeObjectWrapper;\npublic:\n template \n using InterpreterType = EnumInterpreterAsNumber;\npublic:\n constexpr static bool notNull = notnull;\npublic:\n static Void toInterpretation(const Void& enumValue, EnumInterpreterError& error);\n static Void fromInterpretation(const Void& interValue, EnumInterpreterError& error);\n static Type* getInterpretationType();\n};\n\n/**\n * Template class for `oatpp::Enum`.\n * @tparam T - enum type.\n * @tparam EnumInterpreter - enum interpreter.\n */\ntemplate\nclass EnumObjectWrapper : public ObjectWrapper>{\n template\n friend class EnumObjectWrapper;\npublic:\n typedef typename std::underlying_type::type UnderlyingEnumType;\n typedef T Z__EnumType;\n typedef __class::Enum EnumObjectClass;\n\n /**\n * Template parameter - `Interpreter`.\n */\n typedef EnumInterpreter Interpreter;\npublic:\n /**\n * Enum interpreted `AsString`.\n */\n typedef EnumObjectWrapper> AsString;\n\n /**\n * Enum interpreted `AsNumber`.\n */\n typedef EnumObjectWrapper> AsNumber;\n\n /**\n * Enum with `NotNull` interpretation constraint.\n */\n typedef EnumObjectWrapper> NotNull;\npublic:\n\n EnumObjectWrapper(const std::shared_ptr& ptr, const type::Type* const valueType)\n : type::ObjectWrapper(ptr, valueType)\n {}\n\n /**\n * Default constructor.\n */\n EnumObjectWrapper() {}\n\n /**\n * Nullptr constructor.\n */\n EnumObjectWrapper(std::nullptr_t) {}\n\n /**\n * Constructor.\n * @param ptr\n */\n EnumObjectWrapper(const std::shared_ptr& ptr)\n : type::ObjectWrapper(ptr)\n {}\n\n /**\n * Constructor.\n * @param ptr\n */\n EnumObjectWrapper(std::shared_ptr&& ptr)\n : type::ObjectWrapper(std::forward>(ptr))\n {}\n\n /**\n * Copy-constructor.\n * @tparam OtherInter\n * @param other\n */\n template\n EnumObjectWrapper(const EnumObjectWrapper& other)\n : type::ObjectWrapper(other.m_ptr)\n {}\n\n /**\n * Move-constructor.\n * @tparam OtherInter\n * @param other\n */\n template\n EnumObjectWrapper(EnumObjectWrapper&& other)\n : type::ObjectWrapper(std::move(other.m_ptr))\n {}\n\n inline EnumObjectWrapper& operator = (std::nullptr_t) {\n this->m_ptr.reset();\n return *this;\n }\n\n template\n inline EnumObjectWrapper& operator = (const 
EnumObjectWrapper& other) {\n this->m_ptr = other.m_ptr;\n return *this;\n }\n\n template\n inline EnumObjectWrapper& operator = (EnumObjectWrapper&& other) {\n this->m_ptr = std::move(other.m_ptr);\n return *this;\n }\n\npublic:\n\n /**\n * Constructor by value.\n * @param value\n */\n EnumObjectWrapper(T value)\n : type::ObjectWrapper(std::make_shared(value))\n {}\n\n EnumObjectWrapper& operator = (T value) {\n this->m_ptr = std::make_shared(value);\n return *this;\n }\n\n T operator*() const {\n return this->m_ptr.operator*();\n }\n\n template::value, void>::type\n >\n inline bool operator == (TP){\n return this->m_ptr.get() == nullptr;\n }\n\n template::value, void>::type\n >\n inline bool operator != (TP){\n return this->m_ptr.get() != nullptr;\n }\n\n template::value, void>::type\n >\n inline bool operator == (TP value) const {\n if(!this->m_ptr) return false;\n return *this->m_ptr == value;\n }\n\n template::value, void>::type\n >\n inline bool operator != (TP value) const {\n if(!this->m_ptr) return true;\n return *this->m_ptr != value;\n }\n\n template::value, void>::type\n >\n inline bool operator == (const TP &other) const {\n if(this->m_ptr.get() == other.m_ptr.get()) return true;\n if(!this->m_ptr || !other.m_ptr) return false;\n return *this->m_ptr == *other.m_ptr;\n }\n\n template::value, void>::type\n >\n inline bool operator != (const TP &other) const {\n return !operator == (other);\n }\n\n template::value, void>::type\n >\n inline operator TP() const {\n return *this->m_ptr;\n }\n\npublic:\n\n /**\n * Get &l:EnumValueInfo ; by name.\n * @param name - name or name-qualifier of the enum entry.\n * @return - &l:EnumValueInfo ;.\n * @throws - `std::runtime_error` if not found.\n */\n static const EnumValueInfo& getEntryByName(const String& name) {\n auto it = EnumMeta::getInfo()->byName.find(name);\n if(it != EnumMeta::getInfo()->byName.end()) {\n return it->second;\n }\n throw std::runtime_error(\"[oatpp::data::mapping::type::Enum::getEntryByName()]: Error. Entry not found.\");\n }\n\n /**\n * Get &l:EnumValueInfo ; by enum value.\n * @param value - enum value.\n * @return - &l:EnumValueInfo ;.\n * @throws - `std::runtime_error` if not found.\n */\n static const EnumValueInfo& getEntryByValue(T value) {\n auto it = EnumMeta::getInfo()->byValue.find(static_cast(value));\n if(it != EnumMeta::getInfo()->byValue.end()) {\n return it->second;\n }\n throw std::runtime_error(\"[oatpp::data::mapping::type::Enum::getEntryByValue()]: Error. Entry not found.\");\n }\n\n /**\n * Get &l:EnumValueInfo ; by integer value.\n * @param value - integer value.\n * @return - &l:EnumValueInfo ;.\n * @throws - `std::runtime_error` if not found.\n */\n static const EnumValueInfo& getEntryByUnderlyingValue(UnderlyingEnumType value) {\n auto it = EnumMeta::getInfo()->byValue.find(static_cast(value));\n if(it != EnumMeta::getInfo()->byValue.end()) {\n return it->second;\n }\n throw std::runtime_error(\"[oatpp::data::mapping::type::Enum::getEntryByUnderlyingValue()]: Error. Entry not found.\");\n }\n\n /**\n * Get &l:EnumValueInfo ; by index.\n * @param index - index of the entry in the enum.\n * @return - &l:EnumValueInfo ;.\n * @throws - `std::runtime_error` if not found.\n */\n static const EnumValueInfo& getEntryByIndex(v_int32 index) {\n if(index >= 0 && index < EnumMeta::getInfo()->byIndex.size()) {\n return EnumMeta::getInfo()->byIndex[index];\n }\n throw std::runtime_error(\"[oatpp::data::mapping::type::Enum::getEntryByIndex()]: Error. 
Entry not found.\");\n }\n\n /**\n * Get `std::vector` of &l:EnumValueInfo ;.\n * @return - `std::vector` of &l:EnumValueInfo ;.\n */\n static const std::vector>& getEntries() {\n return EnumMeta::getInfo()->byIndex;\n }\n\n};\n\n/**\n * Mapping-enabled Enum. See &l:EnumObjectWrapper;.\n */\ntemplate \nusing Enum = EnumObjectWrapper>;\n\ntemplate\nVoid EnumInterpreterAsString::toInterpretation(const Void& enumValue, EnumInterpreterError& error) {\n typedef EnumObjectWrapper> EnumOW;\n\n if(enumValue.getValueType() != EnumOW::Class::getType()) {\n error = EnumInterpreterError::TYPE_MISMATCH_ENUM;\n return Void(nullptr, String::Class::getType());\n }\n\n if(!enumValue) {\n if(notnull) {\n error = EnumInterpreterError::CONSTRAINT_NOT_NULL;\n return Void(nullptr, String::Class::getType());\n }\n return Void(nullptr, String::Class::getType());\n }\n\n const auto& ow = enumValue.template cast();\n const auto& entry = EnumOW::getEntryByValue(*ow);\n return entry.name.toString();\n}\n\ntemplate\nVoid EnumInterpreterAsString::fromInterpretation(const Void& interValue, EnumInterpreterError& error) {\n typedef EnumObjectWrapper> EnumOW;\n\n if(interValue.getValueType() != String::Class::getType()) {\n error = EnumInterpreterError::TYPE_MISMATCH_ENUM_VALUE;\n return Void(nullptr, EnumOW::Class::getType());\n }\n\n if(!interValue) {\n if(notnull) {\n error = EnumInterpreterError::CONSTRAINT_NOT_NULL;\n return Void(nullptr, EnumOW::Class::getType());\n }\n return Void(nullptr, EnumOW::Class::getType());\n }\n\n try {\n return EnumOW(EnumOW::getEntryByName(interValue.template cast()).value);\n } catch (const std::runtime_error&) { // TODO - add a specific error for this.\n error = EnumInterpreterError::ENTRY_NOT_FOUND;\n }\n return Void(nullptr, EnumOW::Class::getType());\n}\n\ntemplate\nType* EnumInterpreterAsString::getInterpretationType() {\n return String::Class::getType();\n}\n\ntemplate\nVoid EnumInterpreterAsNumber::toInterpretation(const Void& enumValue, EnumInterpreterError& error) {\n\n typedef EnumObjectWrapper> EnumOW;\n typedef typename std::underlying_type::type EnumUT;\n typedef typename ObjectWrapperByUnderlyingType::ObjectWrapper UTOW;\n\n if(enumValue.getValueType() != EnumOW::Class::getType()) {\n error = EnumInterpreterError::TYPE_MISMATCH_ENUM;\n return Void(nullptr, UTOW::Class::getType());\n }\n\n if(!enumValue) {\n if(notnull) {\n error = EnumInterpreterError::CONSTRAINT_NOT_NULL;\n return Void(nullptr, UTOW::Class::getType());\n }\n return Void(nullptr, UTOW::Class::getType());\n }\n\n const auto& ow = enumValue.template cast();\n return UTOW(static_cast(*ow));\n\n}\n\ntemplate\nVoid EnumInterpreterAsNumber::fromInterpretation(const Void& interValue, EnumInterpreterError& error) {\n typedef EnumObjectWrapper> EnumOW;\n\n typedef typename std::underlying_type::type EnumUT;\n typedef typename ObjectWrapperByUnderlyingType::ObjectWrapper OW;\n\n if(interValue.getValueType() != OW::Class::getType()) {\n error = EnumInterpreterError::TYPE_MISMATCH_ENUM_VALUE;\n return Void(nullptr, EnumOW::Class::getType());\n }\n\n if(!interValue) {\n if(notnull) {\n error = EnumInterpreterError::CONSTRAINT_NOT_NULL;\n return Void(nullptr, EnumOW::Class::getType());\n }\n return Void(nullptr, EnumOW::Class::getType());\n }\n\n try{\n const auto& entry = EnumOW::getEntryByUnderlyingValue(\n interValue.template cast()\n );\n return EnumOW(entry.value);\n } catch (const std::runtime_error&) { // TODO - add a specific error for this.\n error = EnumInterpreterError::ENTRY_NOT_FOUND;\n }\n return 
Void(nullptr, EnumOW::Class::getType());\n}\n\ntemplate\nType* EnumInterpreterAsNumber::getInterpretationType() {\n typedef typename std::underlying_type::type EnumUT;\n return ObjectWrapperByUnderlyingType::ObjectWrapper::Class::getType();\n}\n\nnamespace __class {\n\n template\n class Enum : public AbstractEnum {\n private:\n\n class PolymorphicDispatcher : public AbstractEnum::PolymorphicDispatcher {\n public:\n\n PolymorphicDispatcher()\n : AbstractEnum::PolymorphicDispatcher(Interpreter::notNull)\n {}\n\n type::Void createObject() const override {\n return type::Void(std::make_shared(), getType());\n }\n\n type::Void toInterpretation(const type::Void& enumValue, EnumInterpreterError& error) const override {\n return Interpreter::toInterpretation(enumValue, error);\n }\n\n type::Void fromInterpretation(const type::Void& interValue, EnumInterpreterError& error) const override {\n return Interpreter::fromInterpretation(interValue, error);\n }\n\n type::Type* getInterpretationType() const override {\n return Interpreter::getInterpretationType();\n }\n\n std::vector getInterpretedEnum() const override {\n\n typedef type::EnumObjectWrapper EnumOW;\n\n std::vector result({});\n\n for(const auto& e : EnumOW::getEntries()) {\n type::EnumInterpreterError error = type::EnumInterpreterError::OK;\n result.push_back(type::Any(toInterpretation(EnumOW(e.value), error)));\n if(error != type::EnumInterpreterError::OK) {\n throw std::runtime_error(\"[oatpp::data::mapping::type::__class::Enum::getInterpretedEnum()]: Unknown error.\");\n }\n }\n\n return result;\n\n }\n\n };\n\n private:\n\n static Type createType() {\n Type::Info info;\n info.nameQualifier = type::EnumMeta::getInfo()->nameQualifier;\n info.polymorphicDispatcher = new PolymorphicDispatcher();\n return Type(__class::AbstractEnum::CLASS_ID, info);\n }\n\n public:\n\n static Type* getType() {\n static Type type = createType();\n return &type;\n }\n\n };\n\n}\n\n}}}}\n\nnamespace std {\n\n template\n struct hash > {\n\n typedef oatpp::data::mapping::type::EnumObjectWrapper argument_type;\n typedef v_uint64 result_type;\n\n result_type operator()(argument_type const &e) const noexcept {\n if (e.get() == nullptr) return 0;\n return static_cast(*e);\n }\n\n };\n\n}\n\n#endif // oatpp_data_mapping_type_Enum_hpp\n\n// Path: src/oatpp/core/data/mapping/type/Map.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_Map_hpp\n#define oatpp_data_mapping_type_Map_hpp\n\n#include \"./Type.hpp\"\n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n/**\n * Abstract Map.
\n * Ex.: UnorderedMap, Fields.\n */\nclass Map {\npublic:\n\n /**\n * Iterator.\n */\n struct Iterator {\n\n /**\n * Default virtual destructor.\n */\n virtual ~Iterator() = default;\n\n /**\n * Get current item key.\n * @return\n */\n virtual type::Void getKey() = 0;\n\n /**\n * Get current item value.\n * @return\n */\n virtual type::Void getValue() = 0;\n\n /**\n * Iterate to next item.\n */\n virtual void next() = 0;\n\n /**\n * Check if iterator finished.\n * @return\n */\n virtual bool finished() = 0;\n\n };\n\npublic:\n\n /**\n * Polymorphic Dispatcher\n */\n class PolymorphicDispatcher {\n public:\n\n /**\n * Virtual destructor.\n */\n virtual ~PolymorphicDispatcher() = default;\n\n /**\n * Create Map.\n * @return\n */\n virtual type::Void createObject() const = 0;\n\n /**\n * Get type of map keys.\n * @return\n */\n virtual const type::Type* getKeyType() const = 0;\n\n /**\n * Get type of map values.\n * @return\n */\n virtual const type::Type* getValueType() const = 0;\n\n /**\n * Get map size.\n * @param object - map object.\n * @return - size of the map.\n */\n virtual v_int64 getMapSize(const type::Void& object) const = 0;\n\n /**\n * Add item.\n * @param object - Map.\n * @param key\n * @param value\n */\n virtual void addItem(const type::Void& object, const type::Void& key, const type::Void& value) const = 0;\n\n /**\n * Begin map iteration.\n * @param object - Map.\n * @return\n */\n virtual std::unique_ptr beginIteration(const type::Void& object) const = 0;\n\n };\n\n template\n struct Inserter {\n\n static void insert(ContainerType* c, const KeyType& k, const ValueType& v) {\n (*c)[k] = v;\n }\n\n };\n\n};\n\ntemplate\nclass StandardMap {\npublic:\n\n struct Iterator : public Map::Iterator {\n\n typename ContainerType::iterator iterator;\n typename ContainerType::iterator end;\n\n type::Void getKey() override {\n return iterator->first;\n }\n\n type::Void getValue() override {\n return iterator->second;\n }\n\n void next() override {\n std::advance(iterator, 1);\n }\n\n bool finished() override {\n return iterator == end;\n }\n\n };\n\npublic:\n\n class PolymorphicDispatcher : public Map::PolymorphicDispatcher {\n public:\n\n type::Void createObject() const override {\n return type::Void(std::make_shared(), Clazz::getType());\n }\n\n const type::Type* getKeyType() const override {\n const type::Type* mapType = Clazz::getType();\n return mapType->params[0];\n }\n\n const type::Type* getValueType() const override {\n const type::Type* mapType = Clazz::getType();\n return mapType->params[1];\n }\n\n v_int64 getMapSize(const type::Void& object) const override {\n ContainerType* map = static_cast(object.get());\n return static_cast(map->size());\n }\n\n void addItem(const type::Void& object, const type::Void& key, const type::Void& value) const override {\n ContainerType* map = static_cast(object.get());\n const auto& mapKey = key.template cast();\n const auto& mapValue = value.template cast();\n Map::Inserter::insert(map, mapKey, mapValue);\n }\n\n std::unique_ptr beginIteration(const type::Void& object) const override {\n ContainerType* map = static_cast(object.get());\n auto iterator = new Iterator();\n iterator->iterator = map->begin();\n iterator->end = map->end();\n return std::unique_ptr(iterator);\n }\n\n };\n\n};\n\ntemplate\nstruct Map::Inserter>, KeyType, ValueType> {\n static void insert(std::list>* c, const KeyType& k, const ValueType& v) {\n c->push_back({k, v});\n }\n};\n\n}\n\n}}}}\n\n#endif //oatpp_data_mapping_type_Map_hpp\n\n// Path: 
src/oatpp/core/data/mapping/type/UnorderedMap.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_UnorderedMap_hpp\n#define oatpp_data_mapping_type_UnorderedMap_hpp\n\n#include \"./Map.hpp\"\n#include \"./Type.hpp\"\n\n#include \n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n /**\n * Abstract Unordered Map class.\n */\n class AbstractUnorderedMap {\n public:\n /**\n * Class Id.\n */\n static const ClassId CLASS_ID;\n };\n\n template\n class UnorderedMap;\n\n}\n\n/**\n * `ObjectWrapper` for `std::unordered_map`\n * @tparam Key - Key `ObjectWrapper` type.\n * @tparam Value - Value `ObjectWrapper` type.\n * @tparam C - Class.\n */\ntemplate\nclass UnorderedMapObjectWrapper : public type::ObjectWrapper, C> {\npublic:\n typedef std::unordered_map TemplateObjectType;\n typedef C TemplateObjectClass;\npublic:\n\n OATPP_DEFINE_OBJECT_WRAPPER_DEFAULTS(UnorderedMapObjectWrapper, TemplateObjectType, TemplateObjectClass)\n\n UnorderedMapObjectWrapper(std::initializer_list> ilist)\n : type::ObjectWrapper(std::make_shared(ilist))\n {}\n\n static UnorderedMapObjectWrapper createShared() {\n return std::make_shared();\n }\n\n UnorderedMapObjectWrapper& operator = (std::initializer_list> ilist) {\n this->m_ptr = std::make_shared(ilist);\n return *this;\n }\n\n Value& operator[] (const Key& key) const {\n return this->m_ptr->operator [] (key);\n }\n\n TemplateObjectType& operator*() const {\n return this->m_ptr.operator*();\n }\n\n};\n\n/**\n * Mapping-Enables UnorderedMap. 
See &l:UnorderedMapObjectWrapper;.\n */\ntemplate\nusing UnorderedMap = UnorderedMapObjectWrapper>;\n\nnamespace __class {\n\n template\n class UnorderedMap : public AbstractUnorderedMap {\n private:\n\n static Type createType() {\n Type::Info info;\n info.params.push_back(Key::Class::getType());\n info.params.push_back(Value::Class::getType());\n info.polymorphicDispatcher =\n new typename __class::StandardMap, Key, Value, UnorderedMap>::PolymorphicDispatcher();\n info.isMap = true;\n return Type(__class::AbstractUnorderedMap::CLASS_ID, info);\n }\n\n public:\n\n static Type* getType() {\n static Type type = createType();\n return &type;\n }\n\n };\n\n}\n\n}}}}\n\n#endif // oatpp_data_mapping_type_UnorderedMap_hpp\n\n// Path: src/oatpp/core/data/mapping/type/PairList.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_PairList_hpp\n#define oatpp_data_mapping_type_PairList_hpp\n\n#include \"./Map.hpp\"\n#include \"./Type.hpp\"\n\n#include \n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n /**\n * Abstract PairList class.\n */\n class AbstractPairList {\n public:\n /**\n * Class id.\n */\n static const ClassId CLASS_ID;\n };\n\n template\n class PairList;\n\n}\n\n/**\n * `ObjectWrapper` over `std::list>`\n * @tparam Key - Key `ObjectWrapper` type.\n * @tparam Value - Value `ObjectWrapper` type.\n * @tparam C - Class.\n */\ntemplate\nclass PairListObjectWrapper : public type::ObjectWrapper>, C> {\npublic:\n typedef std::list> TemplateObjectType;\n typedef C TemplateObjectClass;\npublic:\n\n OATPP_DEFINE_OBJECT_WRAPPER_DEFAULTS(PairListObjectWrapper, TemplateObjectType, TemplateObjectClass)\n\n \nPairListObjectWrapper(std::initializer_list> ilist)\n : type::ObjectWrapper(std::make_shared(ilist))\n {}\n\n static PairListObjectWrapper createShared() {\n return std::make_shared();\n }\n\n PairListObjectWrapper& operator = (std::initializer_list> ilist) {\n this->m_ptr = std::make_shared(ilist);\n return *this;\n }\n\n std::pair& operator[] (v_buff_usize index) const {\n auto it = this->m_ptr->begin();\n std::advance(it, index);\n return *it;\n }\n \n Value& operator[] (const Key& key) const {\n auto& list = *(this->m_ptr.get());\n auto it = list.begin();\n while(it != list.end()) {\n if(it->first == key) {\n return it->second;\n }\n it ++;\n }\n list.push_back({key, nullptr});\n return list.back().second;\n }\n\n Value getValueByKey(const Key& key, const Value& defValue = nullptr) const {\n auto& list = *(this->m_ptr.get());\n auto it = list.begin();\n while(it != list.end()) {\n if(it->first == key) {\n return it->second;\n }\n it 
++;\n }\n return defValue;\n }\n\n TemplateObjectType& operator*() const {\n return this->m_ptr.operator*();\n }\n\n};\n\n/**\n * Mapping-Enables PairList. See &l:PairListObjectWrapper;.\n */\ntemplate\nusing PairList = PairListObjectWrapper>;\n\nnamespace __class {\n\ntemplate\nclass PairList : public AbstractPairList {\nprivate:\n\n static Type createType() {\n Type::Info info;\n info.params.push_back(Key::Class::getType());\n info.params.push_back(Value::Class::getType());\n info.polymorphicDispatcher =\n new typename __class::StandardMap>, Key, Value, PairList>::PolymorphicDispatcher();\n info.isMap = true;\n return Type(__class::AbstractPairList::CLASS_ID, info);\n }\n\npublic:\n\n static Type* getType() {\n static Type type = createType();\n return &type;\n }\n\n};\n\n}\n\n}}}}\n\n#endif // oatpp_data_mapping_type_PairList_hpp\n\n// Path: src/oatpp/core/data/mapping/type/Collection.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_Collection_hpp\n#define oatpp_data_mapping_type_Collection_hpp\n\n#include \"./Type.hpp\"\n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n/**\n * Abstract Collection.
\n * Ex.: Vector, List, Set.\n */\nclass Collection {\npublic:\n\n /**\n * Iterator.\n */\n struct Iterator {\n\n /**\n * Default virtual destructor.\n */\n virtual ~Iterator() = default;\n\n /**\n * Get current item.\n * @return\n */\n virtual type::Void get() = 0;\n\n /**\n * Iterate to next item.\n */\n virtual void next() = 0;\n\n /**\n * Check if iterator finished.\n * @return\n */\n virtual bool finished() = 0;\n\n };\n\npublic:\n\n /**\n * Polymorphic Dispatcher\n */\n class PolymorphicDispatcher {\n public:\n\n /**\n * Virtual destructor.\n */\n virtual ~PolymorphicDispatcher() = default;\n\n /**\n * Create Collection.\n * @return\n */\n virtual type::Void createObject() const = 0;\n\n /**\n * Get type of collection items.\n * @return\n */\n virtual const type::Type* getItemType() const = 0;\n\n /**\n * Get collection size.\n * @param object - collection.\n * @return - size of the collection (elements count).\n */\n virtual v_int64 getCollectionSize(const type::Void& object) const = 0;\n\n /**\n * Add item.\n * @param object - Collection.\n * @param item - Item to add.\n */\n virtual void addItem(const type::Void& object, const type::Void& item) const = 0;\n\n /**\n * Begin collection iteration.\n * @param object - Collection.\n * @return\n */\n virtual std::unique_ptr beginIteration(const type::Void& object) const = 0;\n\n };\n\n template\n struct Inserter {\n\n static void insert(ContainerType* c, const ItemType& i) {\n c->emplace_back(i);\n }\n\n };\n\n};\n\ntemplate\nclass StandardCollection {\npublic:\n\n struct Iterator : public Collection::Iterator {\n\n typename ContainerType::iterator iterator;\n typename ContainerType::iterator end;\n\n type::Void get() override {\n return *iterator;\n }\n\n void next() override {\n std::advance(iterator, 1);\n }\n\n bool finished() override {\n return iterator == end;\n }\n\n };\n\npublic:\n\n class PolymorphicDispatcher : public Collection::PolymorphicDispatcher {\n public:\n\n type::Void createObject() const override {\n return type::Void(std::make_shared(), Clazz::getType());\n }\n\n const type::Type* getItemType() const override {\n const type::Type* collectionType = Clazz::getType();\n return collectionType->params[0];\n }\n\n v_int64 getCollectionSize(const type::Void& object) const override {\n ContainerType* collection = static_cast(object.get());\n return static_cast(collection->size());\n }\n\n void addItem(const type::Void& object, const type::Void& item) const override {\n ContainerType* collection = static_cast(object.get());\n const auto& collectionItem = item.template cast();\n Collection::Inserter::insert(collection, collectionItem);\n }\n\n std::unique_ptr beginIteration(const type::Void& object) const override {\n ContainerType* collection = static_cast(object.get());\n auto iterator = new Iterator();\n iterator->iterator = collection->begin();\n iterator->end = collection->end();\n return std::unique_ptr(iterator);\n }\n\n };\n\n};\n\ntemplate\nstruct Collection::Inserter, ItemType> {\n static void insert(std::unordered_set* c, const ItemType& i) {\n c->emplace(i);\n }\n};\n\n}\n\n}}}}\n\n#endif //oatpp_data_mapping_type_Collection_hpp\n\n// Path: src/oatpp/core/data/mapping/type/List.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the 
\"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_List_hpp\n#define oatpp_data_mapping_type_List_hpp\n\n#include \"./Collection.hpp\"\n#include \"./Type.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n /**\n * Abstract list class.\n */\n class AbstractList {\n public:\n\n /**\n * Class Id.\n */\n static const ClassId CLASS_ID;\n\n };\n\n template\n class List;\n\n}\n\n/**\n * `ObjectWrapper` over `std::list`\n * @tparam T - Item `ObjectWrapper` type.\n * @tparam C - Class.\n */\ntemplate\nclass ListObjectWrapper : public type::ObjectWrapper, C> {\npublic:\n typedef std::list TemplateObjectType;\n typedef C TemplateObjectClass;\npublic:\n\n OATPP_DEFINE_OBJECT_WRAPPER_DEFAULTS(ListObjectWrapper, TemplateObjectType, TemplateObjectClass)\n\n ListObjectWrapper(std::initializer_list ilist)\n : type::ObjectWrapper(std::make_shared(ilist))\n {}\n\n static ListObjectWrapper createShared() {\n return std::make_shared();\n }\n\n ListObjectWrapper& operator = (std::initializer_list ilist) {\n this->m_ptr = std::make_shared(ilist);\n return *this;\n }\n\n T& operator[] (v_buff_usize index) const {\n auto it = this->m_ptr->begin();\n std::advance(it, index);\n return *it;\n }\n\n TemplateObjectType& operator*() const {\n return this->m_ptr.operator*();\n }\n\n};\n\n/**\n * Mapping-Enabled List. 
See - &l:ListObjectWrapper;.\n */\ntemplate\nusing List = ListObjectWrapper>;\n\ntypedef List AbstractList;\n\nnamespace __class {\n\n template\n class List : public AbstractList {\n private:\n\n static Type createType() {\n Type::Info info;\n info.params.push_back(T::Class::getType());\n info.polymorphicDispatcher = new typename StandardCollection, T, List>::PolymorphicDispatcher();\n info.isCollection = true;\n return Type(__class::AbstractList::CLASS_ID, info);\n }\n\n public:\n\n static Type* getType() {\n static Type type = createType();\n return &type;\n }\n\n };\n\n}\n\n}}}}\n\n#endif // oatpp_data_mapping_type_List_hpp\n\n// Path: src/oatpp/core/data/mapping/type/Vector.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_Vector_hpp\n#define oatpp_data_mapping_type_Vector_hpp\n\n#include \"./Collection.hpp\"\n#include \"./Type.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n /**\n * Abstract Vector Class.\n */\n class AbstractVector {\n public:\n /**\n * Class Id.\n */\n static const ClassId CLASS_ID;\n };\n\n template\n class Vector;\n\n}\n\n/**\n * `ObjectWrapper` over `std::vector`.\n * @tparam T - Item `ObjectWrapper` type.\n * @tparam C - Class.\n */\ntemplate\nclass VectorObjectWrapper : public type::ObjectWrapper, C> {\npublic:\n typedef std::vector TemplateObjectType;\n typedef C TemplateObjectClass;\npublic:\n\n OATPP_DEFINE_OBJECT_WRAPPER_DEFAULTS(VectorObjectWrapper, TemplateObjectType, TemplateObjectClass)\n\n VectorObjectWrapper(std::initializer_list ilist)\n : type::ObjectWrapper(std::make_shared(ilist))\n {}\n\n static VectorObjectWrapper createShared() {\n return std::make_shared();\n }\n\n VectorObjectWrapper& operator = (std::initializer_list ilist) {\n this->m_ptr = std::make_shared(ilist);\n return *this;\n }\n\n T& operator[] (v_buff_usize index) const {\n return this->m_ptr->operator [] (index);\n }\n\n TemplateObjectType& operator*() const {\n return this->m_ptr.operator*();\n }\n\n};\n\n/**\n * Mapping-enabled Vector. 
See &l:VectorObjectWrapper;.\n */\ntemplate\nusing Vector = VectorObjectWrapper>;\n\ntypedef Vector AbstractVector;\n\nnamespace __class {\n\n template\n class Vector : public AbstractVector {\n private:\n\n static Type createType() {\n Type::Info info;\n info.params.push_back(T::Class::getType());\n info.polymorphicDispatcher = new typename StandardCollection, T, Vector>::PolymorphicDispatcher();\n info.isCollection = true;\n return Type(__class::AbstractVector::CLASS_ID, info);\n }\n\n public:\n\n static Type* getType() {\n static Type type = createType();\n return &type;\n }\n\n };\n\n}\n\n}}}}\n\n#endif // oatpp_data_mapping_type_Vector_hpp\n\n// Path: src/oatpp/core/data/mapping/type/UnorderedSet.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_mapping_type_UnorderedSet_hpp\n#define oatpp_data_mapping_type_UnorderedSet_hpp\n\n#include \"./Collection.hpp\"\n#include \"./Type.hpp\"\n\n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nnamespace __class {\n\n /**\n * Abstract Unordered Set class.\n */\n class AbstractUnorderedSet {\n public:\n /**\n * Class Id.\n */\n static const ClassId CLASS_ID;\n };\n\n template\n class UnorderedSet;\n\n}\n\n/**\n * `ObjectWrapper` over `std::unordered_set`\n * @tparam T - Key `ObjectWrapper` type.\n * @tparam C - Class.\n */\ntemplate\nclass UnorderedSetObjectWrapper : public type::ObjectWrapper, C> {\npublic:\n typedef std::unordered_set TemplateObjectType;\n typedef C TemplateObjectClass;\npublic:\n\n OATPP_DEFINE_OBJECT_WRAPPER_DEFAULTS(UnorderedSetObjectWrapper, TemplateObjectType, TemplateObjectClass)\n\n UnorderedSetObjectWrapper(std::initializer_list ilist)\n : type::ObjectWrapper(std::make_shared(ilist))\n {}\n\n static UnorderedSetObjectWrapper createShared() {\n return std::make_shared();\n }\n\n UnorderedSetObjectWrapper& operator = (std::initializer_list ilist) {\n this->m_ptr = std::make_shared(ilist);\n return *this;\n }\n\n bool operator[] (const T& key) const {\n if(this->m_ptr) {\n auto it = this->m_ptr->find(key);\n return it != this->m_ptr->end();\n }\n return false;\n }\n\n TemplateObjectType& operator*() const {\n return this->m_ptr.operator*();\n }\n\n};\n\n/**\n * Mapping-Enabled UnorderedSet. 
See &l:UnorderedSetObjectWrapper;.\n */\ntemplate\nusing UnorderedSet = UnorderedSetObjectWrapper>;\n\ntypedef UnorderedSet AbstractUnorderedSet;\n\nnamespace __class {\n\ntemplate\nclass UnorderedSet : public AbstractUnorderedSet {\nprivate:\n\n static Type createType() {\n Type::Info info;\n info.params.push_back(T::Class::getType());\n info.polymorphicDispatcher = new typename StandardCollection, T, UnorderedSet>::PolymorphicDispatcher();\n info.isCollection = true;\n return Type(__class::AbstractUnorderedSet::CLASS_ID, info);\n }\n\npublic:\n\n static Type* getType() {\n static Type type = createType();\n return &type;\n }\n\n};\n\n}\n\n}}}}\n\n#endif // oatpp_data_mapping_type_UnorderedSet_hpp\n\n// Path: src/oatpp/core/data/mapping/type/Object.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n...\n", "needle_token_start": 10275, "needle_token_end": 10331, "code_context_ntokens": 15782, "output": ["```cpp\ntemplate\nclass PairListObjectWrapper : public type::ObjectWrapper>, C> {\npublic:\n typedef std::list> TemplateObjectType;\n typedef C TemplateObjectClass;\npublic:\n\n OATPP_DEFINE_OBJECT_WRAPPER_DEFAULTS(PairListObjectWrapper, TemplateObjectType, TemplateObjectClass)\n\n PairListObjectWrapper(std::initializer_list> ilist)\n : type::ObjectWrapper(std::make_shared(ilist))\n {}\n\n static PairListObjectWrapper createShared() {\n return std::make_shared();\n }\n\n PairListObjectWrapper& operator = (std::initializer_list> ilist) {\n this->m_ptr = std::make_shared(ilist);\n return *this;\n }\n\n std::pair& operator[] (v_buff_usize index) const {\n auto it = this->m_ptr->begin();\n std::advance(it, index);\n return *it;\n }\n \n Value& operator[] (const Key& key) const {\n auto& list = *(this->m_ptr.get());\n auto it = list.begin();\n while(it != list.end()) {\n if(it->first == key) {\n return it->second;\n }\n it ++;\n }\n list.push_back({key, nullptr});\n return list.back().second;\n }\n\n Value getValueByKey(const Key& key, const Value& defValue = nullptr) const {\n auto& list = *(this->m_ptr.get());\n auto it = list.begin();\n while(it != list.end()) {\n if(it->first == key) {\n return it->second;\n }\n it ++;\n }\n return defValue;\n }\n\n TemplateObjectType& operator*() const {\n return this->m_ptr.operator*();\n }\n\n};\n```"]} +{"repo": "oatpp/oatpp", "name": "FlushCoroutine", "language": "cpp", "path": "src/oatpp/core/data/buffer/FIFOBuffer.cpp", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: The purpose of this coroutine is to asynchronously transfer all available data from a buffer to an output stream until the buffer is empty.\n2. **Input**: The coroutine takes two inputs: a reference to a buffer from which data is read, and a shared pointer to an output stream where data is written.\n3. **Output**: The coroutine does not return a value directly but performs asynchronous I/O operations, resulting in data being written to the output stream.\n4. **Procedure**: The coroutine first checks if there is data available in the buffer. If data is available, it determines the segments of the buffer to write based on the current read and write positions. It then performs asynchronous writes to the output stream in one or two parts, depending on whether the data wraps around the end of the buffer. 
After writing, it resets the buffer's positions, effectively emptying it.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "...\n// Path: src/oatpp/core/data/resource/TemporaryFile.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_resource_TemporaryFile_hpp\n#define oatpp_data_resource_TemporaryFile_hpp\n\n#include \"./Resource.hpp\"\n#include \"oatpp/core/Types.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\n/**\n * Temporary file - the file which gets deleted when the destructor is called\n * (more precisely when all copies of the same `TemporaryFile` object deleted).
\n * The `TemporaryFile` object internally stores a `shared_ptr` to a file handle.\n * When file handle deleted it also deletes the underlying file.
\n * Thus it's safe to copy `TemporaryFile` object and you may treat `TemporaryFile` object\n * as a shared_ptr to a temporary file.
\n * @extends - &id:oatpp::data::Resource;.\n */\nclass TemporaryFile : public Resource {\nprivate:\n\n /*\n * Shared handle.\n * File is deleted on handle destroy.\n */\n struct FileHandle {\n\n oatpp::String fileName;\n\n FileHandle(const oatpp::String& fullFileName)\n : fileName(fullFileName)\n {}\n\n ~FileHandle();\n };\n\nprivate:\n static oatpp::String constructRandomFilename(const oatpp::String &dir, v_int32 randomWordSizeBytes, const oatpp::String &extension);\nprivate:\n std::shared_ptr m_handle;\npublic:\n\n /**\n * Default constructor.\n */\n TemporaryFile() = default;\n\n /**\n * Constructor.
\n * Create temporary file with a random name in the `tmpDirectory`.
\n * The actual file will be created only after first write to that file.
\n * Example of the generated random file name: `f7c6ecd44024ff31.tmp`.\n * @param tmpDirectory - directory where to create a temporary file.\n * @param randomWordSizeBytes - number of random bytes to generate file name.\n */\n TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes = 8);\n\n /**\n * Constructor.
\n * Create temporary file with the `tmpFileName` name in the `tmpDirectory`.
\n * @param tmpDirectory - directory where to create a temporary file.\n * @param tmpFileName - predefined name for the temporary file.\n */\n TemporaryFile(const oatpp::String& tmpDirectory, const oatpp::String& tmpFileName);\n\n /**\n * Constructor.
\n * Create temporary file with a random name and specified extension in the `tmpDirectory`.
\n * The actual file will be created only after first write to that file.
\n * Example of the generated random file name: `f7c6ecd44024ff31.txt`.\n * @param tmpDirectory - directory where to create a temporary file.\n * @param randomWordSizeBytes - number of random bytes to generate file name.\n * @param extension - extension of the temporary file, e.g. txt or .txt\n */\n TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes, const oatpp::String& extension);\n\n /**\n * Open output stream to a temporary file.
\n * *Note: stream also captures file-handle. The temporary file won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` to &id:oatpp::data::stream::OutputStream;.\n */\n std::shared_ptr openOutputStream() override;\n\n /**\n * Open input stream to a temporary file.
\n * *Note: stream also captures file-handle. The temporary file won't be deleted until the stream is deleted.*\n * @return - `std::shared_ptr` &id:oatpp::data::stream::InputStream;.\n */\n std::shared_ptr openInputStream() override;\n\n /**\n * Not applicable.\n * @return - always returns `nullptr`.\n */\n oatpp::String getInMemoryData() override;\n\n /**\n * Not applicable.\n * @return - always returns `-1`.\n */\n v_int64 getKnownSize() override;\n\n /**\n * Get location where temporary data is stored.\n * @return - `&id:oatpp::String;`.\n */\n oatpp::String getLocation() override;\n\n /**\n * Move payload to a different file.
\n * @param fullFileName - full-file-name. Note: all the parent folders must exist.\n * @return - `true` - file was successfully moved, `false` - otherwise.\n */\n bool moveFile(const oatpp::String& fullFileName);\n\n};\n\n}}}\n\n#endif //oatpp_data_resource_TemporaryFile_hpp\n\n// Path: src/oatpp/core/data/resource/TemporaryFile.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"TemporaryFile.hpp\"\n\n#include \"./File.hpp\"\n\n#include \"oatpp/core/data/stream/FileStream.hpp\"\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n#include \"oatpp/encoding/Hex.hpp\"\n#include \"oatpp/core/utils/Random.hpp\"\n\nnamespace oatpp { namespace data { namespace resource {\n\nTemporaryFile::FileHandle::~FileHandle() {\n if(fileName) {\n std::remove(fileName->c_str());\n }\n}\n\noatpp::String TemporaryFile::constructRandomFilename(const oatpp::String &dir, v_int32 randomWordSizeBytes, const oatpp::String &extension) {\n\n std::unique_ptr buff(new v_char8[randomWordSizeBytes]);\n utils::random::Random::randomBytes(buff.get(), randomWordSizeBytes);\n data::stream::BufferOutputStream s(randomWordSizeBytes * 2 + 4);\n encoding::Hex::encode(&s, buff.get(), randomWordSizeBytes, encoding::Hex::ALPHABET_LOWER);\n if (extension->at(0) != '.') {\n s << \".\";\n }\n s << extension;\n\n return File::concatDirAndName(dir, s.toString());\n\n}\n\nTemporaryFile::TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes)\n : m_handle(std::make_shared(constructRandomFilename(tmpDirectory, randomWordSizeBytes, \"tmp\")))\n{}\n\nTemporaryFile::TemporaryFile(const oatpp::String& tmpDirectory, const oatpp::String& tmpFileName)\n : m_handle(std::make_shared(File::concatDirAndName(tmpDirectory, tmpFileName)))\n{}\n\nTemporaryFile::TemporaryFile(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes, const oatpp::String& extension)\n : m_handle(std::make_shared(constructRandomFilename(tmpDirectory, randomWordSizeBytes, extension)))\n{}\n\nstd::shared_ptr TemporaryFile::openOutputStream() {\n if(m_handle) {\n return std::make_shared(m_handle->fileName->c_str(), \"wb\", m_handle);\n }\n throw std::runtime_error(\"[oatpp::data::resource::TemporaryFile::openOutputStream()]: Error. FileHandle is NOT initialized.\");\n}\n\nstd::shared_ptr TemporaryFile::openInputStream() {\n if(m_handle) {\n return std::make_shared(m_handle->fileName->c_str(), m_handle);\n }\n throw std::runtime_error(\"[oatpp::data::resource::TemporaryFile::openInputStream()]: Error. 
FileHandle is NOT initialized.\");\n}\n\noatpp::String TemporaryFile::getInMemoryData() {\n return nullptr;\n}\n\nv_int64 TemporaryFile::getKnownSize() {\n return -1;\n}\n\noatpp::String TemporaryFile::getLocation() {\n if(m_handle) {\n return m_handle->fileName;\n }\n return nullptr;\n}\n\nbool TemporaryFile::moveFile(const oatpp::String& fullFileName) {\n if(m_handle) {\n return std::rename(m_handle->fileName->c_str(), fullFileName->c_str()) == 0;\n }\n return false;\n}\n\n}}}\n\n// Path: src/oatpp/core/data/buffer/Processor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_buffer_Processor_hpp\n#define oatpp_data_buffer_Processor_hpp\n\n#include \"oatpp/core/IODefinitions.hpp\"\n#include \"oatpp/core/base/ObjectHandle.hpp\"\n#include \n\nnamespace oatpp { namespace data { namespace buffer {\n\n/**\n * Convenience structure for stream Async-Inline read operations.\n */\nstruct InlineReadData {\n\n /**\n * Pointer to current position in the buffer.\n */\n void* currBufferPtr;\n\n /**\n * Bytes left to read to the buffer.\n */\n v_buff_size bytesLeft;\n\n /**\n * Default constructor.\n */\n InlineReadData();\n\n /**\n * Constructor.\n * @param data\n * @param size\n */\n InlineReadData(void* data, v_buff_size size);\n\n /**\n * Set `currBufferPtr` and `bytesLeft` values.
\n * @param data - pointer to buffer to store read data.\n * @param size - size in bytes of the buffer.\n */\n void set(void* data, v_buff_size size);\n\n /**\n * Increase position in the read buffer by `amount` bytes.
\n * This will increase `currBufferPtr` and descrease `bytesLeft` values.\n * @param amount\n */\n void inc(v_buff_size amount);\n\n /**\n * Same as `inc(bytesLeft).`\n */\n void setEof();\n\n};\n\n/**\n * Convenience structure for stream Async-Inline write operations.\n */\nstruct InlineWriteData {\n\n /**\n * Pointer to current position in the buffer.\n */\n const void* currBufferPtr;\n\n /**\n * Bytes left to write from the buffer.\n */\n v_buff_size bytesLeft;\n\n /**\n * Default constructor.\n */\n InlineWriteData();\n\n /**\n * Constructor.\n * @param data\n * @param size\n */\n InlineWriteData(const void* data, v_buff_size size);\n\n /**\n * Set `currBufferPtr` and `bytesLeft` values.
\n * @param data - pointer to buffer containing data to be written.\n * @param size - size in bytes of the buffer.\n */\n void set(const void* data, v_buff_size size);\n\n /**\n * Increase position in the write buffer by `amount` bytes.
\n * This will increase `currBufferPtr` and descrease `bytesLeft` values.\n * @param amount\n */\n void inc(v_buff_size amount);\n\n /**\n * Same as `inc(bytesLeft).`\n */\n void setEof();\n\n};\n\n/**\n * Buffer processor.\n * Note: all processors are considered to be stateful.\n */\nclass Processor {\npublic:\n\n /**\n * Enum of processing errors.\n */\n enum Error : v_int32 {\n\n /**\n * No error.\n */\n OK = 0,\n\n /**\n * Caller must set fields of `dataIn` parameter.\n */\n PROVIDE_DATA_IN = 1,\n\n /**\n * Caller must read all the data from the `dataOut`.\n */\n FLUSH_DATA_OUT = 2,\n\n /**\n * Processing is finished.\n */\n FINISHED = 3\n\n //*********************************************//\n // Other values are processor-specific errors. //\n //*********************************************//\n };\n\npublic:\n\n /**\n * Default virtual destructor.\n */\n virtual ~Processor() = default;\n\n /**\n * If the client is using the input stream to read data and add it to the processor,\n * the client MAY ask the processor for a suggested read size.\n * @return - suggested read size.\n */\n virtual v_io_size suggestInputStreamReadSize() = 0;\n\n /**\n * Process data.\n * @param dataIn - data provided by client to processor. Input data. &id:data::buffer::InlineReadData;.\n * Set `dataIn` buffer pointer to `nullptr` to designate the end of input.\n * @param dataOut - data provided to client by processor. Output data. &id:data::buffer::InlineReadData;.\n * @return - &l:Processor::Error;.\n */\n virtual v_int32 iterate(data::buffer::InlineReadData& dataIn,\n data::buffer::InlineReadData& dataOut) = 0;\n\n};\n\n/**\n * Pipeline of buffer processors.\n */\nclass ProcessingPipeline : public Processor {\nprivate:\n std::vector> m_processors;\n std::vector m_intermediateData;\npublic:\n\n /**\n * Constructor.\n * @param m_processors - the array of processors defining the pipeline.\n */\n ProcessingPipeline(const std::vector>& m_processors);\n\n /**\n * If the client is using the input stream to read data and add it to the processor,\n * the client MAY ask the processor for a suggested read size.\n * @return - suggested read size.\n */\n v_io_size suggestInputStreamReadSize() override;\n\n /**\n * Process data.\n * @param dataIn - data provided by client to processor. Input data. &id:data::buffer::InlineReadData;.\n * Set `dataIn` buffer pointer to `nullptr` to designate the end of input.\n * @param dataOut - data provided to client by processor. Output data. 
&id:data::buffer::InlineReadData;.\n * @return - &l:Processor::Error;.\n */\n v_int32 iterate(data::buffer::InlineReadData& dataIn,\n data::buffer::InlineReadData& dataOut) override;\n\n};\n\n}}}\n\n#endif // oatpp_data_buffer_Processor_hpp\n\n// Path: src/oatpp/core/data/buffer/Processor.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Processor.hpp\"\n\nnamespace oatpp { namespace data{ namespace buffer {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// InlineReadData\n\nInlineReadData::InlineReadData()\n : currBufferPtr(nullptr)\n , bytesLeft(0)\n{}\n\nInlineReadData::InlineReadData(void* data, v_buff_size size)\n : currBufferPtr(data)\n , bytesLeft(size)\n{}\n\nvoid InlineReadData::set(void* data, v_buff_size size) {\n currBufferPtr = data;\n bytesLeft = size;\n}\n\nvoid InlineReadData::inc(v_buff_size amount) {\n currBufferPtr = &(reinterpret_cast(currBufferPtr))[amount];\n bytesLeft -= amount;\n}\n\nvoid InlineReadData::setEof() {\n currBufferPtr = &(reinterpret_cast(currBufferPtr))[bytesLeft];\n bytesLeft = 0;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// InlineWriteData\n\nInlineWriteData::InlineWriteData()\n : currBufferPtr(nullptr)\n , bytesLeft(0)\n{}\n\nInlineWriteData::InlineWriteData(const void* data, v_buff_size size)\n : currBufferPtr(data)\n , bytesLeft(size)\n{}\n\nvoid InlineWriteData::set(const void* data, v_buff_size size) {\n currBufferPtr = data;\n bytesLeft = size;\n}\n\nvoid InlineWriteData::inc(v_buff_size amount) {\n currBufferPtr = &(reinterpret_cast(const_cast(currBufferPtr)))[amount];\n bytesLeft -= amount;\n}\n\nvoid InlineWriteData::setEof() {\n currBufferPtr = &(reinterpret_cast(const_cast(currBufferPtr)))[bytesLeft];\n bytesLeft = 0;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// ProcessingPipeline\n\nProcessingPipeline::ProcessingPipeline(const std::vector>& processors)\n : m_processors(processors), m_intermediateData(processors.size() - 1)\n{\n}\n\nv_io_size ProcessingPipeline::suggestInputStreamReadSize() {\n return m_processors[0]->suggestInputStreamReadSize();\n}\n\nv_int32 ProcessingPipeline::iterate(data::buffer::InlineReadData& dataIn,\n data::buffer::InlineReadData& dataOut)\n{\n\n if(dataOut.bytesLeft > 0) {\n return Error::FLUSH_DATA_OUT;\n }\n\n size_t i = 0;\n v_int32 res = Error::OK;\n\n while(res == Error::OK) {\n\n auto& p = m_processors[i];\n\n data::buffer::InlineReadData* currDataIn = &dataIn;\n if(i > 0) {\n currDataIn = 
&m_intermediateData[i - 1];\n }\n\n data::buffer::InlineReadData* currDataOut = &dataOut;\n if(i < m_intermediateData.size()) {\n currDataOut = &m_intermediateData[i];\n }\n\n while(res == Error::OK) {\n res = p->iterate(*currDataIn, *currDataOut);\n }\n\n const size_t numOfProcessors = m_processors.size();\n\n switch (res) {\n case Error::PROVIDE_DATA_IN:\n if (i > 0) {\n i --;\n res = Error::OK;\n }\n break;\n\n\n case Error::FLUSH_DATA_OUT:\n if (i < numOfProcessors - 1) {\n i ++;\n res = Error::OK;\n }\n break;\n\n\n case Error::FINISHED:\n if (i < numOfProcessors - 1) {\n i ++;\n res = Error::OK;\n }\n break;\n\n default:\n break;\n }\n\n }\n\n return res;\n\n}\n\n}}}\n\n// Path: src/oatpp/core/data/buffer/FIFOBuffer.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_buffer_FIFOBuffer_hpp\n#define oatpp_data_buffer_FIFOBuffer_hpp\n\n#include \"oatpp/core/data/stream/Stream.hpp\"\n#include \"oatpp/core/IODefinitions.hpp\"\n#include \"oatpp/core/async/Coroutine.hpp\"\n#include \"oatpp/core/concurrency/SpinLock.hpp\"\n\nnamespace oatpp { namespace data { namespace buffer {\n\n/**\n * FIFO operations over the buffer\n * !FIFOBuffer is NOT an IOStream despite having similar APIs!\n */\nclass FIFOBuffer {\nprivate:\n p_char8 m_buffer;\n v_buff_size m_bufferSize;\n v_buff_size m_readPosition;\n v_buff_size m_writePosition;\n bool m_canRead;\npublic:\n\n /**\n * Constructor.\n * @param buffer - pointer to buffer used for reads/writes.\n * @param bufferSize - buffer size.\n * @param readPosition - initial read position in buffer.\n * @param writePosition - initial write position in buffer.\n * @param canRead - flag to resolve ambiguity when readPosition == writePosition. If(readPosition == writePosition && canRead) then\n * &l:FIFOBuffer::availableToRead (); returns buffer size, and &l:FIFOBuffer::availableToWrite (); returns 0.\n */\n FIFOBuffer(void* buffer, v_buff_size bufferSize,\n v_buff_size readPosition = 0, v_buff_size writePosition = 0,\n bool canRead = false);\n\n /**\n * Set read and write positions in buffer.\n * @param readPosition - read position in buffer.\n * @param writePosition - write position in buffer.\n * @param canRead - flag to resolve ambiguity when readPosition == writePosition. 
If(readPosition == writePosition && canRead) then\n * &l:FIFOBuffer::availableToRead (); returns buffer size, and &l:FIFOBuffer::availableToWrite (); returns 0.\n */\n void setBufferPosition(v_buff_size readPosition, v_buff_size writePosition, bool canRead);\n\n /**\n * Amount of bytes currently available to read from buffer.\n * @return &id:oatpp::v_io_size;.\n */\n v_io_size availableToRead() const;\n\n /**\n * Amount of buffer space currently available for data writes.\n * @return &id:oatpp::v_io_size;.\n */\n v_io_size availableToWrite() const;\n\n /**\n * Get FIFOBuffer size.\n * @return - FIFOBuffer size.\n */\n v_buff_size getBufferSize() const;\n\n /**\n * read up to count bytes from the buffer to data\n * @param data\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size read(void *data, v_buff_size count);\n\n /**\n * Peek up to count of bytes int he buffer\n * @param data\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size peek(void *data, v_buff_size count);\n\n /**\n * Commit read offset\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size commitReadOffset(v_buff_size count);\n\n /**\n * write up to count bytes from data to buffer\n * @param data\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size write(const void *data, v_buff_size count);\n\n /**\n * call read and then write bytes read to output stream\n * @param stream\n * @param count\n * @param action\n * @return [1..count], IOErrors.\n */\n v_io_size readAndWriteToStream(data::stream::WriteCallback* stream, v_buff_size count, async::Action& action);\n\n /**\n * call stream.read() and then write bytes read to buffer\n * @param stream\n * @param count\n * @param action\n * @return\n */\n v_io_size readFromStreamAndWrite(data::stream::ReadCallback* stream, v_buff_size count, async::Action& action);\n\n /**\n * flush all availableToRead bytes to stream\n * @param stream\n * @return\n */\n v_io_size flushToStream(data::stream::OutputStream* stream);\n\n /**\n * flush all availableToRead bytes to stream in asynchronous manner\n * @param stream - &id:data::stream::OutputStream;.\n * @return - &id:async::CoroutineStarter;.\n */\n async::CoroutineStarter flushToStreamAsync(const std::shared_ptr& stream);\n\n \n};\n\n/**\n * Same as FIFOBuffer + synchronization with SpinLock\n */\nclass SynchronizedFIFOBuffer {\nprivate:\n FIFOBuffer m_fifo;\n oatpp::concurrency::SpinLock m_lock;\npublic:\n\n /**\n * Constructor.\n * @param buffer - pointer to buffer used for reads/writes.\n * @param bufferSize - buffer size.\n * @param readPosition - initial read position in buffer.\n * @param writePosition - initial write position in buffer.\n * @param canRead - flag to resolve ambiguity when readPosition == writePosition. If(readPosition == writePosition && canRead) then\n * &l:SynchronizedFIFOBuffer::availableToRead (); returns buffer size, and &l:SynchronizedFIFOBuffer::availableToWrite (); returns 0.\n */\n SynchronizedFIFOBuffer(void* buffer, v_buff_size bufferSize,\n v_buff_size readPosition = 0, v_buff_size writePosition = 0,\n bool canRead = false);\n\n /**\n * Set read and write positions in buffer.\n * @param readPosition - read position in buffer.\n * @param writePosition - write position in buffer.\n * @param canRead - flag to resolve ambiguity when readPosition == writePosition. 
If(readPosition == writePosition && canRead) then\n * &l:SynchronizedFIFOBuffer::availableToRead (); returns buffer size, and &l:SynchronizedFIFOBuffer::availableToWrite (); returns 0.\n */\n void setBufferPosition(v_buff_size readPosition, v_buff_size writePosition, bool canRead);\n\n /**\n * Amount of bytes currently available to read from buffer.\n * @return &id:oatpp::v_io_size;.\n */\n v_io_size availableToRead();\n\n /**\n * Amount of buffer space currently available for data writes.\n * @return &id:oatpp::v_io_size;.\n */\n v_io_size availableToWrite();\n\n /**\n * read up to count bytes from the buffer to data\n * @param data\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size read(void *data, v_buff_size count);\n\n /**\n * write up to count bytes from data to buffer\n * @param data\n * @param count\n * @return [1..count], IOErrors.\n */\n v_io_size write(const void *data, v_buff_size count);\n\n /* No implementation of other methods */\n /* User should implement his own synchronization for other methods */\n\n};\n \n}}}\n\n#endif /* FIFOBuffer_hpp */\n\n// Path: src/oatpp/core/data/buffer/FIFOBuffer.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"FIFOBuffer.hpp\"\n#include \n\nnamespace oatpp { namespace data{ namespace buffer {\n\nFIFOBuffer::FIFOBuffer(void* buffer, v_buff_size bufferSize,\n v_buff_size readPosition, v_buff_size writePosition,\n bool canRead)\n : m_buffer(reinterpret_cast(buffer))\n , m_bufferSize(bufferSize)\n , m_readPosition(readPosition)\n , m_writePosition(writePosition)\n , m_canRead(canRead)\n{}\n\nvoid FIFOBuffer::setBufferPosition(v_buff_size readPosition, v_buff_size writePosition, bool canRead) {\n m_readPosition = readPosition;\n m_writePosition = writePosition;\n m_canRead = canRead;\n}\n\nv_io_size FIFOBuffer::availableToRead() const {\n if(!m_canRead) {\n return 0;\n }\n if(m_readPosition < m_writePosition) {\n return m_writePosition - m_readPosition;\n }\n return (m_bufferSize - m_readPosition + m_writePosition);\n}\n\nv_io_size FIFOBuffer::availableToWrite() const {\n if(m_canRead && m_writePosition == m_readPosition) {\n return 0;\n }\n if(m_writePosition < m_readPosition) {\n return m_readPosition - m_writePosition;\n }\n return (m_bufferSize - m_writePosition + m_readPosition);\n}\n\nv_buff_size FIFOBuffer::getBufferSize() const {\n return m_bufferSize;\n}\n\nv_io_size FIFOBuffer::read(void *data, v_buff_size count) {\n \n if(!m_canRead) {\n return IOError::RETRY_READ;\n }\n \n if(count == 0) {\n return 0;\n } else if(count < 0) {\n throw std::runtime_error(\"[oatpp::data::buffer::FIFOBuffer::read(...)]: count < 0\");\n }\n \n if(m_readPosition < 
m_writePosition) {\n auto size = m_writePosition - m_readPosition;\n if(size > count) {\n size = count;\n }\n std::memcpy(data, &m_buffer[m_readPosition], static_cast(size));\n m_readPosition += size;\n if(m_readPosition == m_writePosition) {\n m_canRead = false;\n }\n return size;\n }\n \n auto size = m_bufferSize - m_readPosition;\n \n if(size > count){\n std::memcpy(data, &m_buffer[m_readPosition], static_cast(count));\n m_readPosition += count;\n return count;\n }\n \n std::memcpy(data, &m_buffer[m_readPosition], static_cast(size));\n auto size2 = m_writePosition;\n if(size2 > count - size) {\n size2 = count - size;\n }\n \n std::memcpy(&(reinterpret_cast(data))[size], m_buffer, static_cast(size2));\n m_readPosition = size2;\n if(m_readPosition == m_writePosition) {\n m_canRead = false;\n }\n \n return (size + size2);\n \n}\n\nv_io_size FIFOBuffer::peek(void *data, v_buff_size count) {\n\n if(!m_canRead) {\n return IOError::RETRY_READ;\n }\n\n if(count == 0) {\n return 0;\n } else if(count < 0) {\n throw std::runtime_error(\"[oatpp::data::buffer::FIFOBuffer::peek(...)]: count < 0\");\n }\n\n if(m_readPosition < m_writePosition) {\n auto size = m_writePosition - m_readPosition;\n if(size > count) {\n size = count;\n }\n std::memcpy(data, &m_buffer[m_readPosition], static_cast(size));\n return size;\n }\n\n auto size = m_bufferSize - m_readPosition;\n\n if(size > count){\n std::memcpy(data, &m_buffer[m_readPosition], static_cast(count));\n return count;\n }\n\n std::memcpy(data, &m_buffer[m_readPosition], static_cast(size));\n auto size2 = m_writePosition;\n if(size2 > count - size) {\n size2 = count - size;\n }\n\n std::memcpy(&(reinterpret_cast(data))[size], m_buffer, static_cast(size2));\n\n return (size + size2);\n\n}\n\nv_io_size FIFOBuffer::commitReadOffset(v_buff_size count) {\n\n if(!m_canRead) {\n return IOError::RETRY_READ;\n }\n\n if(count == 0) {\n return 0;\n } else if(count < 0) {\n throw std::runtime_error(\"[oatpp::data::buffer::FIFOBuffer::commitReadOffset(...)]: count < 0\");\n }\n\n if(m_readPosition < m_writePosition) {\n auto size = m_writePosition - m_readPosition;\n if(size > count) {\n size = count;\n }\n m_readPosition += size;\n if(m_readPosition == m_writePosition) {\n m_canRead = false;\n }\n return size;\n }\n\n auto size = m_bufferSize - m_readPosition;\n\n if(size > count){\n m_readPosition += count;\n return count;\n }\n\n auto size2 = m_writePosition;\n if(size2 > count - size) {\n size2 = count - size;\n }\n\n m_readPosition = size2;\n if(m_readPosition == m_writePosition) {\n m_canRead = false;\n }\n\n return (size + size2);\n\n}\n\nv_io_size FIFOBuffer::write(const void *data, v_buff_size count) {\n \n if(m_canRead && m_writePosition == m_readPosition) {\n return IOError::RETRY_WRITE;\n }\n \n if(count == 0) {\n return 0;\n } else if(count < 0) {\n throw std::runtime_error(\"[oatpp::data::buffer::FIFOBuffer::write(...)]: count < 0\");\n } else {\n m_canRead = true;\n }\n \n if(m_writePosition < m_readPosition) {\n auto size = m_readPosition - m_writePosition;\n if(size > count) {\n size = count;\n }\n std::memcpy(&m_buffer[m_writePosition], data, static_cast(size));\n m_writePosition += size;\n return size;\n }\n \n auto size = m_bufferSize - m_writePosition;\n \n if(size > count){\n std::memcpy(&m_buffer[m_writePosition], data, static_cast(count));\n m_writePosition += count;\n return count;\n }\n \n std::memcpy(&m_buffer[m_writePosition], data, static_cast(size));\n auto size2 = m_readPosition;\n if(size2 > count - size) {\n size2 = count - size;\n 
}\n \n std::memcpy(m_buffer, &(reinterpret_cast(data))[size], static_cast(size2));\n m_writePosition = size2;\n \n return (size + size2);\n \n}\n\nv_io_size FIFOBuffer::readAndWriteToStream(data::stream::WriteCallback* stream, v_buff_size count, async::Action& action) {\n\n if(!m_canRead) {\n return IOError::RETRY_READ;\n }\n\n if(count == 0) {\n return 0;\n } else if(count < 0) {\n throw std::runtime_error(\"[oatpp::data::buffer::FIFOBuffer::readAndWriteToStream(...)]: count < 0\");\n }\n\n if(m_readPosition < m_writePosition) {\n auto size = m_writePosition - m_readPosition;\n if(size > count) {\n size = count;\n }\n auto bytesWritten = stream->write(&m_buffer[m_readPosition], size, action);\n if(bytesWritten > 0) {\n m_readPosition += bytesWritten;\n if (m_readPosition == m_writePosition) {\n m_canRead = false;\n }\n }\n return bytesWritten;\n }\n\n auto size = m_bufferSize - m_readPosition;\n\n /* DO NOT call stream.write() twice if size > count !!! */\n if(size > count){\n size = count;\n } else if(size == 0) {\n\n auto bytesWritten = stream->write(m_buffer, m_writePosition, action);\n if(bytesWritten > 0) {\n m_readPosition = bytesWritten;\n if (m_readPosition == m_writePosition) {\n m_canRead = false;\n }\n }\n return bytesWritten;\n\n }\n\n auto bytesWritten = stream->write(&m_buffer[m_readPosition], size, action);\n if(bytesWritten > 0) {\n m_readPosition += bytesWritten;\n }\n return bytesWritten;\n\n}\n\nv_io_size FIFOBuffer::readFromStreamAndWrite(data::stream::ReadCallback* stream, v_buff_size count, async::Action& action) {\n\n if(m_canRead && m_writePosition == m_readPosition) {\n return IOError::RETRY_WRITE;\n }\n\n if(count == 0) {\n return 0;\n } else if(count < 0) {\n throw std::runtime_error(\"[oatpp::data::buffer::FIFOBuffer::readFromStreamAndWrite(...)]: count < 0\");\n }\n\n if(m_writePosition < m_readPosition) {\n auto size = m_readPosition - m_writePosition;\n if(size > count) {\n size = count;\n }\n auto bytesRead = stream->read(&m_buffer[m_writePosition], size, action);\n if(bytesRead > 0) {\n m_writePosition += bytesRead;\n m_canRead = true;\n }\n return bytesRead;\n }\n\n auto size = m_bufferSize - m_writePosition;\n\n /* DO NOT call stream.read() twice if size > count !!! 
*/\n if(size > count){\n size = count;\n } else if(size == 0) {\n\n auto bytesRead = stream->read(m_buffer, m_readPosition, action);\n if(bytesRead > 0) {\n m_writePosition = bytesRead;\n m_canRead = true;\n }\n\n return bytesRead;\n\n }\n\n auto bytesRead = stream->read(&m_buffer[m_writePosition], size, action);\n if(bytesRead > 0) {\n m_writePosition += bytesRead;\n m_canRead = true;\n }\n\n return bytesRead;\n\n}\n\nv_io_size FIFOBuffer::flushToStream(data::stream::OutputStream* stream) {\n\n if(!m_canRead) {\n return 0;\n }\n\n v_io_size result = 0;\n\n if(m_readPosition < m_writePosition) {\n result = stream->writeExactSizeDataSimple(&m_buffer[m_readPosition], m_writePosition - m_readPosition);\n } else {\n result = stream->writeExactSizeDataSimple(&m_buffer[m_readPosition], m_bufferSize - m_readPosition);\n result += stream->writeExactSizeDataSimple(m_buffer, m_writePosition);\n }\n\n setBufferPosition(0, 0, false);\n\n return result;\n\n}\n\nasync::CoroutineStarter FIFOBuffer::flushToStreamAsync(const std::shared_ptr& stream)\n{\n\n class FlushCoroutine : public oatpp::async::Coroutine {\n private:\n FIFOBuffer* m_fifo;\n std::shared_ptr m_stream;\n private:\n data::buffer::InlineWriteData m_data1;\n data::buffer::InlineWriteData m_data2;\n public:\n\n \nFlushCoroutine(FIFOBuffer* fifo, const std::shared_ptr& stream)\n : m_fifo(fifo)\n , m_stream(stream)\n {}\n\n Action act() override {\n\n if(!m_fifo->m_canRead) {\n return finish();\n }\n\n if(m_fifo->m_readPosition < m_fifo->m_writePosition) {\n\n m_data1.set(&m_fifo->m_buffer[m_fifo->m_readPosition], m_fifo->m_writePosition - m_fifo->m_readPosition);\n return yieldTo(&FlushCoroutine::fullFlush);\n\n } else {\n\n m_data1.set(&m_fifo->m_buffer[m_fifo->m_readPosition], m_fifo->m_bufferSize - m_fifo->m_readPosition);\n m_data2.set(m_fifo->m_buffer, m_fifo->m_writePosition);\n return yieldTo(&FlushCoroutine::partialFlush1);\n\n }\n }\n\n Action fullFlush() {\n return m_stream->writeExactSizeDataAsyncInline(m_data1, yieldTo(&FlushCoroutine::beforeFinish));\n }\n\n Action partialFlush1() {\n return m_stream->writeExactSizeDataAsyncInline(m_data1, yieldTo(&FlushCoroutine::partialFlush2));\n }\n\n Action partialFlush2() {\n return m_stream->writeExactSizeDataAsyncInline(m_data2, yieldTo(&FlushCoroutine::beforeFinish));\n }\n\n Action beforeFinish() {\n m_fifo->setBufferPosition(0, 0, false);\n return finish();\n }\n\n };\n\n return FlushCoroutine::start(this, stream);\n\n}\n\n//////////////////////////////////////////////////////////////////////////////////////////\n// SynchronizedFIFOBuffer\n\nSynchronizedFIFOBuffer::SynchronizedFIFOBuffer(void* buffer, v_buff_size bufferSize,\n v_buff_size readPosition, v_buff_size writePosition,\n bool canRead)\n : m_fifo(buffer, bufferSize, readPosition, writePosition, canRead)\n{}\n\nvoid SynchronizedFIFOBuffer::setBufferPosition(v_buff_size readPosition, v_buff_size writePosition, bool canRead) {\n std::lock_guard lock(m_lock);\n m_fifo.setBufferPosition(readPosition, writePosition, canRead);\n}\n\nv_io_size SynchronizedFIFOBuffer::availableToRead() {\n std::lock_guard lock(m_lock);\n return m_fifo.availableToRead();\n}\n\nv_io_size SynchronizedFIFOBuffer::availableToWrite() {\n std::lock_guard lock(m_lock);\n return m_fifo.availableToWrite();\n}\n\nv_io_size SynchronizedFIFOBuffer::read(void *data, v_buff_size count) {\n std::lock_guard lock(m_lock);\n return m_fifo.read(data, count);\n}\n\nv_io_size SynchronizedFIFOBuffer::write(const void *data, v_buff_size count) {\n std::lock_guard 
lock(m_lock);\n return m_fifo.write(data, count);\n}\n\n}}}\n\n// Path: src/oatpp/core/data/buffer/IOBuffer.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_buffer_IOBuffer_hpp\n#define oatpp_data_buffer_IOBuffer_hpp\n\n#include \"oatpp/core/base/Countable.hpp\"\n\nnamespace oatpp { namespace data{ namespace buffer {\n\n/**\n * Predefined buffer implementation for I/O operations.\n * Allocates buffer bytes using &id:oatpp::base::memory::ThreadDistributedMemoryPool;.\n */\nclass IOBuffer : public oatpp::base::Countable {\npublic:\n /**\n * Buffer size constant.\n */\n static const v_buff_size BUFFER_SIZE;\nprivate:\n p_char8 m_entry;\npublic:\n /**\n * Constructor.\n */\n IOBuffer();\npublic:\n\n /**\n * Create shared IOBuffer.\n * @return\n */\n static std::shared_ptr createShared();\n\n /**\n * Virtual destructor.\n */\n ~IOBuffer() override;\n\n /**\n * Get pointer to buffer data.\n * @return\n */\n void* getData();\n\n /**\n * Get buffer size.\n * @return - should always return &l:IOBuffer::BUFFER_SIZE;.\n */\n v_buff_size getSize();\n \n};\n \n}}}\n\n#endif /* oatpp_data_buffer_IOBuffer_hpp */\n\n// Path: src/oatpp/core/data/buffer/IOBuffer.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"IOBuffer.hpp\"\n\nnamespace oatpp { namespace data { namespace buffer {\n\nconst v_buff_size IOBuffer::BUFFER_SIZE = 4096;\n\nIOBuffer::IOBuffer()\n : m_entry(new v_char8[BUFFER_SIZE])\n{}\n\nstd::shared_ptr IOBuffer::createShared(){\n return std::make_shared();\n}\n\nIOBuffer::~IOBuffer() {\n delete [] m_entry;\n}\n\nvoid* IOBuffer::getData(){\n return m_entry;\n}\n\nv_buff_size IOBuffer::getSize(){\n return BUFFER_SIZE;\n}\n \n}}}\n\n// Path: 
src/oatpp/core/data/mapping/type/Type.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_data_type_Type_hpp\n#define oatpp_data_type_Type_hpp\n\n#include \"oatpp/core/base/Countable.hpp\"\n#include \"oatpp/core/base/Environment.hpp\"\n\n#include \n#include \n#include \n#include \n\nnamespace oatpp { namespace data { namespace mapping { namespace type {\n\nclass Type; // FWD\n\n/**\n * Structure representing `ID` of the type class.\n */\nclass ClassId {\nprivate:\n static std::mutex& getClassMutex();\n static std::vector& getClassNames();\n static v_int32 registerClassName(const char* name);\npublic:\n /**\n * Get count of all type classes created.\n * @return\n */\n static int getClassCount();\n\n /**\n * Get registered class names.\n * @return\n */\n static std::vector getRegisteredClassNames();\npublic:\n\n /**\n * Constructor.\n * @param pName\n */\n ClassId(const char* pName);\n\n /**\n * Name of the type class.\n */\n const char* const name;\n\n /**\n * Integer ID of the type class.
\n * *Note: class type IDs are integer values incremented continuously from [0 to `getClassCount()`]*\n */\n const v_int32 id;\n\npublic:\n\n inline bool operator == (const ClassId& other) const {\n return id == other.id;\n }\n\n inline bool operator != (const ClassId& other) const {\n return id != other.id;\n }\n\n};\n\n\nnamespace __class {\n /**\n * Void Object Class.\n */\n class Void {\n public:\n /**\n * Class id.\n */\n static const ClassId CLASS_ID;\n\n /**\n * Get class type information.\n * @return - &l:Type;\n */\n static Type* getType();\n };\n\n}\n\nclass Void; // FWD\n\n/**\n * ObjectWrapper holds std::shared_ptr to object, object static type, plus object dynamic type information.\n * @tparam T - Object Type.\n * @tparam Clazz - Static type info.\n */\ntemplate \nclass ObjectWrapper {\n friend Void;\n template \n friend class ObjectWrapper;\nprotected:\n static void checkType(const Type* _this, const Type* other);\nprotected:\n std::shared_ptr m_ptr;\n const Type* m_valueType;\npublic:\n\n /**\n * Static object type\n */\n typedef T ObjectType;\n\n /**\n * Static object class information.\n */\n typedef Clazz Class;\npublic:\n\n ObjectWrapper(const std::shared_ptr& ptr)\n : m_ptr(ptr)\n , m_valueType(Class::getType())\n {}\n\n ObjectWrapper(const std::shared_ptr& ptr, const Type* const type)\n : m_ptr(ptr)\n , m_valueType(type)\n {}\n\n ObjectWrapper(std::shared_ptr&& ptr, const Type* const type)\n : m_ptr(std::move(ptr))\n , m_valueType(type)\n {}\n \npublic:\n\n ObjectWrapper()\n : m_valueType(Class::getType())\n {}\n\n ObjectWrapper(std::nullptr_t)\n : m_valueType(Class::getType())\n {}\n\n ObjectWrapper(const Type* const type)\n : m_valueType(type)\n {}\n\n ObjectWrapper(const ObjectWrapper& other)\n : m_ptr(other.m_ptr)\n , m_valueType(other.m_valueType)\n {}\n\n ObjectWrapper(ObjectWrapper&& other)\n : m_ptr(std::move(other.m_ptr))\n , m_valueType(other.m_valueType)\n {}\n\n template \n ObjectWrapper(const ObjectWrapper& other)\n : m_ptr(other.m_ptr)\n , m_valueType(other.m_valueType)\n {}\n\n template \n ObjectWrapper(ObjectWrapper&& other)\n : m_ptr(std::move(other.m_ptr))\n , m_valueType(other.m_valueType)\n {}\n\n inline ObjectWrapper& operator=(const ObjectWrapper& other){\n checkType(m_valueType, other.m_valueType);\n m_ptr = other.m_ptr;\n return *this;\n }\n\n inline ObjectWrapper& operator=(ObjectWrapper&& other){\n checkType(m_valueType, other.m_valueType);\n m_ptr = std::move(other.m_ptr);\n return *this;\n }\n\n template \n inline ObjectWrapper& operator=(const ObjectWrapper& other){\n checkType(m_valueType, other.m_valueType);\n m_ptr = other.m_ptr;\n return *this;\n }\n\n template \n inline ObjectWrapper& operator=(ObjectWrapper&& other){\n checkType(m_valueType, other.m_valueType);\n m_ptr = std::move(other.m_ptr);\n return *this;\n }\n\n template\n Wrapper cast() const;\n\n inline T* operator->() const {\n return m_ptr.operator->();\n }\n \n T* get() const {\n return m_ptr.get();\n }\n \n void resetPtr(const std::shared_ptr& ptr = nullptr) {\n m_ptr = ptr;\n }\n \n std::shared_ptr getPtr() const {\n return m_ptr;\n }\n\n inline bool operator == (std::nullptr_t) const {\n return m_ptr.get() == nullptr;\n }\n\n inline bool operator != (std::nullptr_t) const {\n return m_ptr.get() != nullptr;\n }\n...\n", "needle_token_start": 11876, "needle_token_end": 11927, "code_context_ntokens": 15839, "output": ["```cpp\nasync::CoroutineStarter FIFOBuffer::flushToStreamAsync(const std::shared_ptr& stream)\n{\n\n class FlushCoroutine : public oatpp::async::Coroutine 
{\n private:\n FIFOBuffer* m_fifo;\n std::shared_ptr m_stream;\n private:\n data::buffer::InlineWriteData m_data1;\n data::buffer::InlineWriteData m_data2;\n public:\n\n \nFlushCoroutine(FIFOBuffer* fifo, const std::shared_ptr& stream)\n : m_fifo(fifo)\n , m_stream(stream)\n {}\n\n Action act() override {\n\n if(!m_fifo->m_canRead) {\n return finish();\n }\n\n if(m_fifo->m_readPosition < m_fifo->m_writePosition) {\n\n m_data1.set(&m_fifo->m_buffer[m_fifo->m_readPosition], m_fifo->m_writePosition - m_fifo->m_readPosition);\n return yieldTo(&FlushCoroutine::fullFlush);\n\n } else {\n\n m_data1.set(&m_fifo->m_buffer[m_fifo->m_readPosition], m_fifo->m_bufferSize - m_fifo->m_readPosition);\n m_data2.set(m_fifo->m_buffer, m_fifo->m_writePosition);\n return yieldTo(&FlushCoroutine::partialFlush1);\n\n }\n }\n\n Action fullFlush() {\n return m_stream->writeExactSizeDataAsyncInline(m_data1, yieldTo(&FlushCoroutine::beforeFinish));\n }\n\n Action partialFlush1() {\n return m_stream->writeExactSizeDataAsyncInline(m_data1, yieldTo(&FlushCoroutine::partialFlush2));\n }\n\n Action partialFlush2() {\n return m_stream->writeExactSizeDataAsyncInline(m_data2, yieldTo(&FlushCoroutine::beforeFinish));\n }\n\n Action beforeFinish() {\n m_fifo->setBufferPosition(0, 0, false);\n return finish();\n }\n\n };\n\n return FlushCoroutine::start(this, stream);\n\n}\n```"]} +{"repo": "oatpp/oatpp", "name": "HttpRouterTemplate", "language": "cpp", "path": "src/oatpp/web/server/HttpRouter.hpp", "position_ratio": 0.85, "description": "\nFunction Description:\n1. **Purpose**: The primary function is to route HTTP requests to the appropriate handlers based on the HTTP method and the path pattern specified in the request.\n2. **Input**: The input includes an HTTP method and a path pattern.\n3. **Output**: The output is a handler that is associated with the given HTTP method and path pattern.\n4. **Procedure**: The procedure involves checking if a routing branch exists for the specified HTTP method. If it does not exist, a new branch is created. 
Then, the request is forwarded to the corresponding handler based on the path pattern within the specified method's branch.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/web/mime/multipart/Reader.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Reader.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// PartsParser\n\nPartsParser::PartsParser(Multipart* multipart)\n : m_multipart(multipart)\n{}\n\nPartsParser::PartsParser(Multipart* multipart, const PartReadersMap& readersMap)\n : m_readers(readersMap)\n , m_multipart(multipart)\n{}\n\nvoid PartsParser::onPartHeaders(const Headers& partHeaders) {\n\n m_currPart = std::make_shared(partHeaders);\n\n if(m_currPart->getName()) {\n auto it = m_readers.find(m_currPart->getName());\n if(it != m_readers.end()) {\n m_currReader = it->second;\n } else {\n m_currReader = m_defaultReader;\n }\n }\n\n if(m_currReader) {\n m_currReader->onNewPart(m_currPart);\n }\n\n}\n\nvoid PartsParser::onPartData(const char* data, v_buff_size size) {\n if(size > 0) {\n if(m_currReader) {\n m_currReader->onPartData(m_currPart, data, size);\n }\n } else {\n if(m_currReader) {\n m_currReader->onPartData(m_currPart, nullptr, 0);\n }\n m_multipart->writeNextPartSimple(m_currPart);\n }\n}\n\nvoid PartsParser::setPartReader(const oatpp::String& partName, const std::shared_ptr& reader) {\n m_readers[partName] = reader;\n}\n\nvoid PartsParser::setDefaultPartReader(const std::shared_ptr& reader) {\n m_defaultReader = reader;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// AsyncPartsParser\n\nAsyncPartsParser::AsyncPartsParser(Multipart* multipart)\n : m_multipart(multipart)\n{}\n\nAsyncPartsParser::AsyncPartsParser(Multipart* multipart, const AsyncPartReadersMap& readersMap)\n : m_readers(readersMap)\n , m_multipart(multipart)\n{}\n\nasync::CoroutineStarter AsyncPartsParser::onPartHeadersAsync(const Headers& partHeaders) {\n\n m_currPart = std::make_shared(partHeaders);\n\n if(m_currPart->getName()) {\n auto it = m_readers.find(m_currPart->getName());\n if(it != m_readers.end()) {\n m_currReader = it->second;\n } else {\n m_currReader = m_defaultReader;\n }\n }\n\n if(m_currReader) {\n return 
m_currReader->onNewPartAsync(m_currPart);\n }\n\n return nullptr;\n\n}\n\nasync::CoroutineStarter AsyncPartsParser::onPartDone(const std::shared_ptr& part) {\n\n class PutPartCoroutine : public async::Coroutine {\n private:\n Multipart* m_multipart;\n std::shared_ptr m_currReader;\n std::shared_ptr m_part;\n public:\n\n PutPartCoroutine(Multipart* multipart,\n const std::shared_ptr& currReader,\n const std::shared_ptr& part)\n : m_multipart(multipart)\n , m_currReader(currReader)\n , m_part(part)\n {}\n\n Action act() override {\n return m_currReader->onPartDataAsync(m_part, nullptr, 0).next(yieldTo(&PutPartCoroutine::putPart));\n }\n\n Action putPart() {\n async::Action action;\n m_multipart->writeNextPart(m_part, action);\n if(!action.isNone()) {\n return action;\n }\n return finish();\n }\n\n };\n\n return PutPartCoroutine::start(m_multipart, m_currReader, part);\n\n}\n\nasync::CoroutineStarter AsyncPartsParser::onPartDataAsync(const char* data, v_buff_size size) {\n if(size > 0) {\n if(m_currReader) {\n return m_currReader->onPartDataAsync(m_currPart, data, size);\n...\n// Path: src/oatpp/web/mime/multipart/FileProvider.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_FileProvider_hpp\n#define oatpp_web_mime_multipart_FileProvider_hpp\n\n#include \"PartReader.hpp\"\n#include \"Reader.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nclass FileProvider : public PartReaderResourceProvider {\nprivate:\n oatpp::String m_filename;\npublic:\n\n FileProvider(const oatpp::String& filename);\n\n std::shared_ptr getResource(const std::shared_ptr& part) override;\n\n async::CoroutineStarter getResourceAsync(const std::shared_ptr& part,\n std::shared_ptr& resource) override;\n\n};\n\n/**\n * Create file part reader.
\n * Reader will save part to a specified file.\n * @param filename - name of the file.\n * @param maxDataSize - max size of the received data. put `-1` for no-limit.\n * @return - `std::shared_ptr` to &id:oatpp::web::mime::multipart::PartReader;.\n */\nstd::shared_ptr createFilePartReader(const oatpp::String& filename, v_io_size maxDataSize = -1);\n\n/**\n * Create async file part reader.
\n * Reader will save part to a specified file.\n * @param filename - name of the file.\n * @param maxDataSize - max size of the received data. put `-1` for no-limit.\n * @return - `std::shared_ptr` to &id:oatpp::web::mime::multipart::AsyncPartReader;.\n */\nstd::shared_ptr createAsyncFilePartReader(const oatpp::String& filename, v_io_size maxDataSize = -1);\n\n}}}}\n\n#endif //oatpp_web_mime_multipart_FileProvider_hpp\n\n// Path: src/oatpp/web/mime/multipart/FileProvider.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"FileProvider.hpp\"\n\n#include \"oatpp/core/data/resource/File.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nFileProvider::FileProvider(const oatpp::String& filename)\n : m_filename(filename)\n{}\n\nstd::shared_ptr FileProvider::getResource(const std::shared_ptr& part) {\n (void)part;\n return std::make_shared(m_filename->c_str());\n}\n\nasync::CoroutineStarter FileProvider::getResourceAsync(const std::shared_ptr& part,\n std::shared_ptr& stream)\n{\n (void)part;\n stream = std::make_shared(m_filename->c_str());\n return nullptr;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Other functions\n\nstd::shared_ptr createFilePartReader(const oatpp::String& filename, v_io_size maxDataSize) {\n auto provider = std::make_shared(filename);\n auto reader = std::make_shared(provider, maxDataSize);\n return reader;\n}\n\nstd::shared_ptr createAsyncFilePartReader(const oatpp::String& filename, v_io_size maxDataSize) {\n auto provider = std::make_shared(filename);\n auto reader = std::make_shared(provider, maxDataSize);\n return reader;\n}\n\n}}}}\n\n// Path: src/oatpp/web/mime/multipart/TemporaryFileProvider.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n 
***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_TemporaryFileProvider_hpp\n#define oatpp_web_mime_multipart_TemporaryFileProvider_hpp\n\n#include \"PartReader.hpp\"\n#include \"Reader.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nclass TemporaryFileProvider : public PartReaderResourceProvider {\nprivate:\n oatpp::String m_tmpDirectory;\n v_int32 m_randomWordSizeBytes;\npublic:\n\n TemporaryFileProvider(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes = 8);\n\n std::shared_ptr getResource(const std::shared_ptr& part) override;\n\n async::CoroutineStarter getResourceAsync(const std::shared_ptr& part,\n std::shared_ptr& resource) override;\n\n};\n\n/**\n * Create part reader to a temporary file.\n * @param tmpDirectory - directory for temporary files.\n * @param randomWordSizeBytes - number of random bytes to generate file name.\n * @param maxDataSize - max size of the received data. put `-1` for no-limit.\n * @return - `std::shared_ptr` to &id:oatpp::web::mime::multipart::PartReader;.\n */\nstd::shared_ptr createTemporaryFilePartReader(const oatpp::String& tmpDirectory,\n v_int32 randomWordSizeBytes = 8,\n v_io_size maxDataSize = -1);\n\n/**\n * Create async part reader to a temporary file.\n * @param tmpDirectory - directory for temporary files.\n * @param randomWordSizeBytes - number of random bytes to generate file name.\n * @param maxDataSize - max size of the received data. put `-1` for no-limit.\n * @return - `std::shared_ptr` to &id:oatpp::web::mime::multipart::AsyncPartReader;.\n */\nstd::shared_ptr createAsyncTemporaryFilePartReader(const oatpp::String& tmpDirectory,\n v_int32 randomWordSizeBytes = 8,\n v_io_size maxDataSize = -1);\n\n}}}}\n\n#endif //oatpp_web_mime_multipart_TemporaryFileProvider_hpp\n\n// Path: src/oatpp/web/mime/multipart/TemporaryFileProvider.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"TemporaryFileProvider.hpp\"\n\n#include \"oatpp/core/data/resource/TemporaryFile.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nTemporaryFileProvider::TemporaryFileProvider(const oatpp::String& tmpDirectory, v_int32 randomWordSizeBytes)\n : m_tmpDirectory(tmpDirectory)\n , m_randomWordSizeBytes(randomWordSizeBytes)\n{}\n\nstd::shared_ptr TemporaryFileProvider::getResource(const std::shared_ptr& part) {\n (void)part;\n return std::make_shared(m_tmpDirectory, m_randomWordSizeBytes);\n}\n\nasync::CoroutineStarter TemporaryFileProvider::getResourceAsync(const std::shared_ptr& part,\n std::shared_ptr& stream)\n{\n (void)part;\n stream = 
std::make_shared(m_tmpDirectory, m_randomWordSizeBytes);\n return nullptr;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Other functions\n\nstd::shared_ptr createTemporaryFilePartReader(const oatpp::String& tmpDirectory,\n v_int32 randomWordSizeBytes,\n v_io_size maxDataSize)\n{\n auto provider = std::make_shared(tmpDirectory, randomWordSizeBytes);\n auto reader = std::make_shared(provider, maxDataSize);\n return reader;\n}\n\nstd::shared_ptr createAsyncTemporaryFilePartReader(const oatpp::String& tmpDirectory,\n v_int32 randomWordSizeBytes,\n v_io_size maxDataSize)\n{\n auto provider = std::make_shared(tmpDirectory, randomWordSizeBytes);\n auto reader = std::make_shared(provider, maxDataSize);\n return reader;\n}\n\n}}}}\n\n// Path: src/oatpp/web/mime/multipart/Multipart.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"Multipart.hpp\"\n\n#include \"oatpp/web/protocol/http/Http.hpp\"\n#include \"oatpp/encoding/Base64.hpp\"\n#include \"oatpp/core/utils/Random.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nMultipart::Multipart(const oatpp::String& boundary)\n : m_boundary(boundary)\n{}\n\noatpp::String Multipart::getBoundary() {\n return m_boundary;\n}\n\nstd::shared_ptr Multipart::readNextPartSimple() {\n async::Action action;\n auto result = readNextPart(action);\n if(!action.isNone()) {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::Multipart::readNextPartSimple()]. Error.\"\n \"Async method is called for non-async API.\");\n }\n return result;\n}\n\nvoid Multipart::writeNextPartSimple(const std::shared_ptr& part) {\n async::Action action;\n writeNextPart(part, action);\n if(!action.isNone()) {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::Multipart::writeNextPartSimple()]. 
Error.\"\n \"Async method is called for non-async API.\");\n }\n}\n\noatpp::String Multipart::generateRandomBoundary(v_int32 boundarySize) {\n std::unique_ptr buffer(new v_char8[boundarySize]);\n utils::random::Random::randomBytes(buffer.get(), boundarySize);\n return encoding::Base64::encode(buffer.get(), boundarySize, encoding::Base64::ALPHABET_BASE64_URL_SAFE);\n}\n\noatpp::String Multipart::parseBoundaryFromHeaders(const Headers& requestHeaders) {\n\n oatpp::String boundary;\n auto contentType = requestHeaders.getAsMemoryLabel(\"Content-Type\");\n\n if(contentType) {\n oatpp::web::protocol::http::HeaderValueData valueData;\n oatpp::web::protocol::http::Parser::parseHeaderValueData(valueData, contentType, ';');\n boundary = valueData.getTitleParamValue(\"boundary\");\n }\n\n return boundary;\n\n}\n\n}}}}\n\n// Path: src/oatpp/web/mime/multipart/PartList.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi ,\n * Matthias Haselmaier \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_mime_multipart_PartList_hpp\n#define oatpp_web_mime_multipart_PartList_hpp\n\n#include \"Multipart.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\n/**\n * Structure that holds Multipart parts in the `std::list`.\n */\nclass PartList : public Multipart {\nprivate:\n std::unordered_map>> m_namedParts;\n bool m_readIteratorInitialized;\n std::list> m_parts;\n std::list>::const_iterator m_iterator;\npublic:\n\n /**\n * Constructor.\n * @param boundary - multipart boundary value.\n */\n PartList(const oatpp::String& boundary);\n\n /**\n * Constructor.\n * @param requestHeaders - request headers. Headers must contain \"Content-Type\" header.\n */\n PartList(const Headers& requestHeaders);\n\n /**\n * Create Multipart object with random boundary.
\n * It will generate random vector of size `boundarySize` in bytes encoded in base64.\n * @param boundarySize - size of the random vecrot in bytes.\n * @return - `std::shared_ptr` to Multipart.\n */\n static std::shared_ptr createSharedWithRandomBoundary(v_int32 boundarySize = 15);\n\n /**\n * Read part-by-part from Multipart.\n * @return\n */\n std::shared_ptr readNextPart(async::Action& action) override;\n\n /**\n * Write part-by-part to Multipart.\n * @param part\n */\n void writeNextPart(const std::shared_ptr& part, async::Action& action) override;\n\n /**\n * Get part by name
\n * Returns the first part if multiple parts with same name exist.\n * Applicable to named parts only.\n * @param name - &id:oatpp::String;.\n * @return - `std::shared_ptr` to &id:oatpp::web::mime::multipart::Part;.\n */\n std::shared_ptr getNamedPart(const oatpp::String& name);\n\n /**\n * Get all parts by name
\n * Applicable to named parts only.\n * @param name - &id:oatpp::String;.\n * @return - `std::list` of `std::shared_ptr` to &id:oatpp::web::mime::multipart::Part;.\n */\n std::list> getNamedParts(const oatpp::String& name);\n\n /**\n * Get list of all parts.\n * @return - `std::list` of `std::shared_ptr` to &id:oatpp::web::mime::multipart::Part;.\n */\n const std::list>& getAllParts();\n\n /**\n * Get parts count.\n * @return - parts count.\n */\n v_int64 count();\n\n};\n\n}}}}\n\n#endif //oatpp_web_mime_multipart_PartList_hpp\n\n// Path: src/oatpp/web/mime/multipart/PartList.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi ,\n * Matthias Haselmaier \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"PartList.hpp\"\n\nnamespace oatpp { namespace web { namespace mime { namespace multipart {\n\nPartList::PartList(const oatpp::String& boundary)\n : Multipart(boundary)\n , m_readIteratorInitialized(false)\n{}\n\nPartList::PartList(const Headers& requestHeaders)\n : Multipart(parseBoundaryFromHeaders(requestHeaders))\n , m_readIteratorInitialized(false)\n{\n if(!getBoundary()) {\n throw std::runtime_error(\"[oatpp::web::mime::multipart::PartList::PartList]: Error. 
No 'boundary' value found in headers.\");\n }\n}\n\nstd::shared_ptr PartList::createSharedWithRandomBoundary(v_int32 boundarySize) {\n auto boundary = generateRandomBoundary(boundarySize);\n return std::make_shared(boundary);\n}\n\nstd::shared_ptr PartList::readNextPart(async::Action& action) {\n if(!m_readIteratorInitialized) {\n m_readIteratorInitialized = true;\n m_iterator = m_parts.begin();\n }\n if(m_iterator == m_parts.end()) {\n return nullptr;\n }\n return *m_iterator ++;\n}\n\nvoid PartList::writeNextPart(const std::shared_ptr& part, async::Action& action) {\n\n if(part->getName()) {\n m_namedParts[part->getName()].push_back(part);\n }\n\n m_parts.push_back(part);\n\n}\n\nstd::shared_ptr PartList::getNamedPart(const oatpp::String& name) {\n\n auto it = m_namedParts.find(name);\n if(it != m_namedParts.end()) {\n return it->second.front();\n }\n\n return nullptr;\n\n}\n\nstd::list> PartList::getNamedParts(const oatpp::String& name) {\n\n auto it = m_namedParts.find(name);\n if(it != m_namedParts.end()) {\n return it->second;\n }\n\n return std::list>{};\n\n}\n\nconst std::list>& PartList::getAllParts() {\n return m_parts;\n}\n\nv_int64 PartList::count() {\n return static_cast(m_parts.size());\n}\n\n}}}}\n\n// Path: src/oatpp/web/server/AsyncHttpConnectionHandler.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_AsyncHttpConnectionHandler_hpp\n#define oatpp_web_server_AsyncHttpConnectionHandler_hpp\n\n#include \"oatpp/web/server/HttpProcessor.hpp\"\n#include \"oatpp/network/ConnectionHandler.hpp\"\n#include \"oatpp/core/async/Executor.hpp\"\n#include \"oatpp/core/concurrency/SpinLock.hpp\"\n\n#include \n\nnamespace oatpp { namespace web { namespace server {\n\n/**\n * Asynchronous &id:oatpp::network::ConnectionHandler; for handling http communication.\n */\nclass AsyncHttpConnectionHandler : public base::Countable, public network::ConnectionHandler, public HttpProcessor::TaskProcessingListener {\nprotected:\n\n void onTaskStart(const provider::ResourceHandle& connection) override;\n void onTaskEnd(const provider::ResourceHandle& connection) override;\n\n void invalidateAllConnections();\n\nprivate:\n std::shared_ptr m_executor;\n std::shared_ptr m_components;\n std::atomic_bool m_continue;\n std::unordered_map> m_connections;\n oatpp::concurrency::SpinLock m_connectionsLock;\npublic:\n\n /**\n * Constructor.\n * @param components - &id:oatpp::web::server::HttpProcessor::Components;.\n * @param threadCount - number of threads.\n */\n AsyncHttpConnectionHandler(const std::shared_ptr& components, v_int32 threadCount = oatpp::async::Executor::VALUE_SUGGESTED);\n\n /**\n * Constructor.\n * @param 
components - &id:oatpp::web::server::HttpProcessor::Components;.\n * @param executor - &id:oatpp::async::Executor;.\n */\n AsyncHttpConnectionHandler(const std::shared_ptr& components, const std::shared_ptr& executor);\n\n /**\n * Constructor.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n * @param threadCount - number of threads.\n */\n AsyncHttpConnectionHandler(const std::shared_ptr& router, v_int32 threadCount = oatpp::async::Executor::VALUE_SUGGESTED)\n : AsyncHttpConnectionHandler(std::make_shared(router), threadCount)\n {}\n\n /**\n * Constructor.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n * @param executor - &id:oatpp::async::Executor;.\n */\n AsyncHttpConnectionHandler(const std::shared_ptr& router, const std::shared_ptr& executor)\n : AsyncHttpConnectionHandler(std::make_shared(router), executor)\n {}\n\n /**\n * Constructor.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n * @param config - &id:oatpp::web::server::HttpProcessor::Config;.\n * @param threadCount - number of threads.\n */\n AsyncHttpConnectionHandler(const std::shared_ptr& router,\n const std::shared_ptr& config,\n v_int32 threadCount = oatpp::async::Executor::VALUE_SUGGESTED)\n : AsyncHttpConnectionHandler(std::make_shared(router, config), threadCount)\n {}\n\n /**\n * Constructor.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n * @param config - &id:oatpp::web::server::HttpProcessor::Config;.\n * @param executor - &id:oatpp::async::Executor;.\n */\n AsyncHttpConnectionHandler(const std::shared_ptr& router,\n const std::shared_ptr& config,\n const std::shared_ptr& executor)\n : AsyncHttpConnectionHandler(std::make_shared(router, config), executor)\n {}\n\npublic:\n \n static std::shared_ptr createShared(const std::shared_ptr& router,\n v_int32 threadCount = oatpp::async::Executor::VALUE_SUGGESTED);\n \n static std::shared_ptr createShared(const std::shared_ptr& router,\n const std::shared_ptr& executor);\n \n static std::shared_ptr createShared(const std::shared_ptr& components,\n const std::shared_ptr& executor);\n\n static std::shared_ptr createShared(const std::shared_ptr& components,\n v_int32 threadCount = oatpp::async::Executor::VALUE_SUGGESTED);\n\n void setErrorHandler(const std::shared_ptr& errorHandler);\n\n /**\n * Add request interceptor. 
Request interceptors are called before routing happens.\n * If multiple interceptors set then the order of interception is the same as the order of calls to `addRequestInterceptor`.\n * @param interceptor - &id:oatpp::web::server::interceptor::RequestInterceptor;.\n */\n void addRequestInterceptor(const std::shared_ptr& interceptor);\n\n /**\n * Add response interceptor.\n * If multiple interceptors set then the order of interception is the same as the order of calls to `addResponseInterceptor`.\n * @param interceptor - &id:oatpp::web::server::interceptor::RequestInterceptor;.\n */\n void addResponseInterceptor(const std::shared_ptr& interceptor);\n\n \n void handleConnection(const provider::ResourceHandle& connection,\n const std::shared_ptr& params) override;\n\n /**\n * Will call m_executor.stop()\n */\n void stop() override;\n\n /**\n * Get connections count.\n * @return\n */\n v_uint64 getConnectionsCount();\n \n};\n \n}}}\n\n#endif /* oatpp_web_server_AsyncHttpConnectionHandler_hpp */\n\n\n// Path: src/oatpp/web/server/AsyncHttpConnectionHandler.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"./AsyncHttpConnectionHandler.hpp\"\n\nnamespace oatpp { namespace web { namespace server {\n\nvoid AsyncHttpConnectionHandler::onTaskStart(const provider::ResourceHandle& connection) {\n\n std::lock_guard lock(m_connectionsLock);\n m_connections.insert({reinterpret_cast(connection.object.get()), connection});\n\n if(!m_continue.load()) {\n connection.invalidator->invalidate(connection.object);\n }\n\n}\n\nvoid AsyncHttpConnectionHandler::onTaskEnd(const provider::ResourceHandle& connection) {\n std::lock_guard lock(m_connectionsLock);\n m_connections.erase(reinterpret_cast(connection.object.get()));\n}\n\nvoid AsyncHttpConnectionHandler::invalidateAllConnections() {\n std::lock_guard lock(m_connectionsLock);\n for(auto& c : m_connections) {\n const auto& handle = c.second;\n handle.invalidator->invalidate(handle.object);\n }\n}\n\nv_uint64 AsyncHttpConnectionHandler::getConnectionsCount() {\n std::lock_guard lock(m_connectionsLock);\n return m_connections.size();\n}\n\nAsyncHttpConnectionHandler::AsyncHttpConnectionHandler(const std::shared_ptr& components,\n v_int32 threadCount)\n : m_executor(std::make_shared(threadCount))\n , m_components(components)\n , m_continue(true)\n{\n m_executor->detach();\n}\n\nAsyncHttpConnectionHandler::AsyncHttpConnectionHandler(const std::shared_ptr& components,\n const std::shared_ptr& executor)\n : m_executor(executor)\n , m_components(components)\n , m_continue(true)\n{}\n\nstd::shared_ptr AsyncHttpConnectionHandler::createShared(const std::shared_ptr& router, v_int32 threadCount){\n 
return std::make_shared(router, threadCount);\n}\n\nstd::shared_ptr AsyncHttpConnectionHandler::createShared(const std::shared_ptr& router, const std::shared_ptr& executor){\n return std::make_shared(router, executor);\n}\n\nstd::shared_ptr AsyncHttpConnectionHandler::createShared(const std::shared_ptr& components, const std::shared_ptr& executor){\n return std::make_shared(components, executor);\n}\n\nstd::shared_ptr AsyncHttpConnectionHandler::createShared(const std::shared_ptr& components, v_int32 threadCount){\n return std::make_shared(components, threadCount);\n}\n\nvoid AsyncHttpConnectionHandler::setErrorHandler(const std::shared_ptr& errorHandler){\n m_components->errorHandler = errorHandler;\n if(!m_components->errorHandler) {\n m_components->errorHandler = handler::DefaultErrorHandler::createShared();\n }\n}\n\nvoid AsyncHttpConnectionHandler::addRequestInterceptor(const std::shared_ptr& interceptor) {\n m_components->requestInterceptors.push_back(interceptor);\n}\n\nvoid AsyncHttpConnectionHandler::addResponseInterceptor(const std::shared_ptr& interceptor) {\n m_components->responseInterceptors.push_back(interceptor);\n}\n\nvoid AsyncHttpConnectionHandler::handleConnection(const provider::ResourceHandle& connection,\n const std::shared_ptr& params)\n{\n\n (void)params;\n\n if (m_continue.load()) {\n\n connection.object->setOutputStreamIOMode(oatpp::data::stream::IOMode::ASYNCHRONOUS);\n connection.object->setInputStreamIOMode(oatpp::data::stream::IOMode::ASYNCHRONOUS);\n\n m_executor->execute(m_components, connection, this);\n\n }\n \n}\n\nvoid AsyncHttpConnectionHandler::stop() {\n m_continue.store(false);\n\n /* invalidate all connections */\n invalidateAllConnections();\n\n /* Wait until all connection-threads are done */\n while(getConnectionsCount() > 0) {\n std::this_thread::sleep_for(std::chrono::milliseconds(100));\n }\n}\n \n}}}\n\n// Path: src/oatpp/web/server/HttpRequestHandler.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_HttpRequestHandler_hpp\n#define oatpp_web_server_HttpRequestHandler_hpp\n\n#include \"oatpp/web/protocol/http/outgoing/ResponseFactory.hpp\"\n#include \"oatpp/web/protocol/http/outgoing/Response.hpp\"\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n\nnamespace oatpp { namespace web { namespace server {\n\n/**\n * HTTP request handler.\n */\nclass HttpRequestHandler {\npublic:\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Status;.\n */\n typedef oatpp::web::protocol::http::Status Status;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Header;.\n */\n typedef oatpp::web::protocol::http::Header Header;\n\n /**\n 
* Convenience typedef for &id:oatpp::web::protocol::http::Headers;.\n */\n typedef oatpp::web::protocol::http::Headers Headers;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::QueryParams;.\n */\n typedef oatpp::web::protocol::http::QueryParams QueryParams;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Request;.\n */\n typedef oatpp::web::protocol::http::incoming::Request IncomingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n typedef oatpp::web::protocol::http::outgoing::Response OutgoingResponse;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::ResponseFactory;.\n */\n typedef oatpp::web::protocol::http::outgoing::ResponseFactory ResponseFactory;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::HttpError;.\n */\n typedef oatpp::web::protocol::http::HttpError HttpError;\n\npublic:\n\n /**\n * Handle incoming http request.
\n * *Implement this method.*\n * @param request - incoming http request. &id:oatpp::web::protocol::http::incoming::Request;.\n * @return - outgoing http response. &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n virtual std::shared_ptr<OutgoingResponse> handle(const std::shared_ptr<IncomingRequest>& request) {\n (void)request;\n throw HttpError(Status::CODE_501, \"Endpoint not implemented.\");\n }\n\n /**\n * Handle incoming http request in Asynchronous manner.
\n * *Implement this method.*\n * @param request - &id:oatpp::web::protocol::http::incoming::Request;.\n * @return - &id:oatpp::async::CoroutineStarterForResult; of &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n virtual oatpp::async::CoroutineStarterForResult&>\n handleAsync(const std::shared_ptr& request) {\n (void)request;\n throw HttpError(Status::CODE_501, \"Asynchronous endpoint not implemented.\");\n }\n\n /**\n * You have to provide a definition for destructors, otherwise its undefined behaviour.\n */\n virtual ~HttpRequestHandler() = default;\n};\n\n}}}\n\n#endif // oatpp_web_server_HttpRequestHandler_hpp\n\n// Path: src/oatpp/web/server/HttpRouter.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_HttpRouter_hpp\n#define oatpp_web_server_HttpRouter_hpp\n\n#include \"./HttpRequestHandler.hpp\"\n\n#include \"oatpp/web/server/api/ApiController.hpp\"\n#include \"oatpp/web/server/api/Endpoint.hpp\"\n#include \"oatpp/web/url/mapping/Router.hpp\"\n\nnamespace oatpp { namespace web { namespace server {\n\n/**\n * HttpRouter is responsible for routing http requests by method and path-pattern.\n */\ntemplate\nclass HttpRouterTemplate : public oatpp::base::Countable {\nprivate:\n /**\n * Convenience typedef for &id:oatpp::data::share::StringKeyLabel;.\n */\n typedef data::share::StringKeyLabel StringKeyLabel;\npublic:\n\n /**\n * &id:oatpp::web::url::mapping::Router;\n */\n typedef web::url::mapping::Router BranchRouter;\n\n /**\n * Http method to &l:HttpRouter::BranchRouter; map.\n * Meaning that for each http method like [\"GET\", \"POST\", ...] there is a separate &l:HttpRouter::BranchRouter;.\n */\n typedef std::unordered_map> BranchMap;\nprotected:\n BranchMap m_branchMap;\nprotected:\n\n const std::shared_ptr& getBranch(const StringKeyLabel& name){\n auto it = m_branchMap.find(name);\n if(it == m_branchMap.end()){\n m_branchMap[name] = BranchRouter::createShared();\n }\n return m_branchMap[name];\n }\n\npublic:\n\n /**\n * Default Constructor.\n */\n \nHttpRouterTemplate() = default;\n\n /**\n * Create shared HttpRouter.\n * @return - `std::shared_ptr` to HttpRouter.\n */\n static std::shared_ptr createShared() {\n return std::make_shared();\n }\n\n /**\n * Route URL to Endpoint by method, and pathPattern.\n * @param method - http method like [\"GET\", \"POST\", etc.].\n * @param pathPattern - url path pattern. 
ex.: `\"/path/to/resource/with/{param1}/{param2}\"`.\n * @param endpoint - router endpoint.\n */\n void route(const oatpp::String& method, const oatpp::String& pathPattern, const RouterEndpoint& endpoint) {\n getBranch(method)->route(pathPattern, endpoint);\n }\n\n /**\n * Resolve http method and path to &id:oatpp::web::url::mapping::Router::Route;\n * @param method - http method like [\"GET\", \"POST\", etc.].\n * @param url - url path. \"Path\" part of url only.\n * @return - &id:oatpp::web::url::mapping::Router::Route;.\n */\n typename BranchRouter::Route getRoute(const StringKeyLabel& method, const StringKeyLabel& path){\n auto it = m_branchMap.find(method);\n if(it != m_branchMap.end()) {\n return m_branchMap[method]->getRoute(path);\n }\n return typename BranchRouter::Route();\n }\n\n /**\n * Print out all router mapping.\n */\n void logRouterMappings() {\n for(auto it : m_branchMap) {\n it.second->logRouterMappings(it.first);\n }\n }\n \n};\n\n/**\n * Default HttpRouter.\n */\nclass HttpRouter : public HttpRouterTemplate> {\nprivate:\n std::list> m_controllers;\npublic:\n\n /**\n * Create shared HttpRouter.\n * @return\n */\n static std::shared_ptr createShared();\n\n using HttpRouterTemplate::route;\n void route(const std::shared_ptr& endpoint);\n void route(const server::api::Endpoints& endpoints);\n\n /**\n * Add controller and route its' endpoints.\n * @param controller\n * @return - `std::shared_ptr` to the controller added.\n */\n std::shared_ptr addController(const std::shared_ptr& controller);\n\n};\n \n}}}\n\n#endif /* oatpp_web_server_HttpRouter_hpp */\n\n// Path: src/oatpp/web/server/HttpRouter.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"HttpRouter.hpp\"\n\nnamespace oatpp { namespace web { namespace server {\n\nstd::shared_ptr HttpRouter::createShared() {\n return std::make_shared();\n}\n\nvoid HttpRouter::route(const std::shared_ptr& endpoint) {\n route(endpoint->info()->method, endpoint->info()->path, endpoint->handler);\n}\n\nvoid HttpRouter::route(const server::api::Endpoints& endpoints) {\n for(auto& e : endpoints.list) {\n route(e);\n }\n}\n\nstd::shared_ptr HttpRouter::addController(const std::shared_ptr& controller) {\n m_controllers.push_back(controller);\n route(controller->getEndpoints());\n return controller;\n}\n\n}}}\n\n// Path: src/oatpp/web/server/interceptor/RequestInterceptor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache 
License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_interceptor_RequestInterceptor_hpp\n#define oatpp_web_server_interceptor_RequestInterceptor_hpp\n\n#include \"oatpp/web/protocol/http/outgoing/Response.hpp\"\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n#include \"oatpp/web/protocol/http/Http.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace interceptor {\n\n/**\n * RequestInterceptor.\n */\nclass RequestInterceptor {\npublic:\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Request;.\n */\n typedef oatpp::web::protocol::http::incoming::Request IncomingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n typedef oatpp::web::protocol::http::outgoing::Response OutgoingResponse;\npublic:\n \n /**\n * Default virtual destructor.\n */\n virtual ~RequestInterceptor() = default;\n \n /**\n *\n * This method should not do any \"heavy\" nor I/O operations\n * as it is used for both \"Simple\" and \"Async\" API\n * NOT FOR I/O operations!!!\n *\n * - return nullptr to continue.\n * - return OutgoingResponse to send response immediately\n *\n * possible usage ex: return 301 - redirect if needed\n *\n */\n virtual std::shared_ptr intercept(const std::shared_ptr& request) = 0;\n \n};\n \n}}}}\n\n#endif /* oatpp_web_server_interceptor_RequestInterceptor_hpp */\n\n// Path: src/oatpp/web/server/interceptor/ResponseInterceptor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_interceptor_ResponseInterceptor_hpp\n#define oatpp_web_server_interceptor_ResponseInterceptor_hpp\n\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n...\n", "needle_token_start": 13398, "needle_token_end": 13405, "code_context_ntokens": 15751, "output": ["```cpp\nvoid HttpRouterTemplate::route(const oatpp::String& method, const oatpp::String& pathPattern, const RouterEndpoint& endpoint) {\n getBranch(method)->route(pathPattern, endpoint);\n}\n```"]} +{"repo": "oatpp/oatpp", "name": "AuthorizationObject", "language": "cpp", "path": 
"src/oatpp/web/server/handler/AuthorizationHandler.hpp", "position_ratio": 0.95, "description": "\nFunction Description:\n1. **Purpose**: To serve as a base class for creating objects that manage and validate user access permissions within a web server environment.\n2. **Input**: Typically receives HTTP request data indirectly through methods implemented in subclasses.\n3. **Output**: Does not directly produce output but provides a framework for subclasses to generate authorization decisions, which may affect the HTTP response.\n4. **Procedure**: Acts as a foundational structure that other specific authorization classes extend. It initializes essential properties and functionalities needed for handling authorization but does not implement specific authorization logic itself.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: src/oatpp/web/server/interceptor/RequestInterceptor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_interceptor_RequestInterceptor_hpp\n#define oatpp_web_server_interceptor_RequestInterceptor_hpp\n\n#include \"oatpp/web/protocol/http/outgoing/Response.hpp\"\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n#include \"oatpp/web/protocol/http/Http.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace interceptor {\n\n/**\n * RequestInterceptor.\n */\nclass RequestInterceptor {\npublic:\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Request;.\n */\n typedef oatpp::web::protocol::http::incoming::Request IncomingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n...\n// Path: src/oatpp/web/server/interceptor/ResponseInterceptor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See 
the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_interceptor_ResponseInterceptor_hpp\n#define oatpp_web_server_interceptor_ResponseInterceptor_hpp\n\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n#include \"oatpp/web/protocol/http/outgoing/Response.hpp\"\n#include \"oatpp/web/protocol/http/Http.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace interceptor {\n\n/**\n * ResponseInterceptor.\n */\nclass ResponseInterceptor {\npublic:\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::incoming::Request;.\n */\n typedef oatpp::web::protocol::http::incoming::Request IncomingRequest;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n typedef oatpp::web::protocol::http::outgoing::Response OutgoingResponse;\npublic:\n\n /**\n * Default virtual destructor.\n */\n virtual ~ResponseInterceptor() = default;\n\n /**\n *\n * This method should not do any \"heavy\" nor I/O operations
\n * as it is used for both \"Simple\" and \"Async\" API
\n * NOT FOR I/O operations!!!
\n *
\n * - return the same response, or the new one.
\n * - do **NOT** return `nullptr`.\n *

\n * possible usage ex: add extra headers to the response.\n *\n * @param request - the corresponding request.\n * @param response - response to the request\n * @return - &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n virtual std::shared_ptr intercept(const std::shared_ptr& request,\n const std::shared_ptr& response) = 0;\n\n};\n\n}}}}\n\n#endif /* oatpp_web_server_interceptor_ResponseInterceptor_hpp */\n\n// Path: src/oatpp/web/server/handler/ErrorHandler.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_handler_ErrorHandler_hpp\n#define oatpp_web_server_handler_ErrorHandler_hpp\n\n#include \"oatpp/web/protocol/http/outgoing/Response.hpp\"\n#include \"oatpp/web/protocol/http/Http.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace handler {\n\n/**\n * Error Handler.\n */\nclass ErrorHandler {\npublic:\n /**\n * Convenience typedef for Headers.
\n * See &id:oatpp::web::protocol::http::Headers;\n */\n typedef web::protocol::http::Headers Headers;\npublic:\n /**\n * Virtual destructor since the class is meant to be derived from.\n * */\n virtual ~ErrorHandler() = default;\n\n /**\n * Implement this method!\n * @param error - &std::exception;.\n * @return - std::shared_ptr to &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n virtual std::shared_ptr handleError(const std::exception_ptr& exceptionPtr);\n\n /**\n * Implement this method!\n * @param status - &id:oatpp::web::protocol::http::Status;.\n * @param message - &id:oatpp::String;.\n * @param Headers - &id:oatpp::web::protocol::http::Headers;\n * @return - std::shared_ptr to &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n [[deprecated]]\n virtual\n std::shared_ptr\n handleError(const protocol::http::Status& status, const oatpp::String& message, const Headers& headers) = 0;\n\n /**\n * Convenience method to call `handleError` method with no headers.\n * @param status - &id:oatpp::web::protocol::http::Status;\n * @param message - &id:oatpp::String;.\n * @return - std::shared_ptr to &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n [[deprecated]]\n std::shared_ptr handleError(const protocol::http::Status& status, const oatpp::String& message);\n \n};\n\n/**\n * Default Error Handler.\n */\nclass DefaultErrorHandler : public oatpp::base::Countable, public ErrorHandler {\npublic:\n /**\n * Constructor.\n */\n DefaultErrorHandler() = default;\npublic:\n\n /**\n * Create shared DefaultErrorHandler.\n * @return - `std::shared_ptr` to DefaultErrorHandler.\n */\n static std::shared_ptr createShared() {\n return std::make_shared();\n }\n\n std::shared_ptr handleError(const std::exception_ptr& error) override;\n\n /**\n * Implementation of &l:ErrorHandler::handleError ();\n * @param status - &id:oatpp::web::protocol::http::Status;.\n * @param message - &id:oatpp::String;.\n * @return - &id:oatpp::web::protocol::http::outgoing::Response;.\n */\n std::shared_ptr\n handleError(const protocol::http::Status& status, const oatpp::String& message, const Headers& headers) override;\n\n};\n \n}}}}\n\n#endif /* oatpp_web_server_handler_ErrorHandler_hpp */\n\n// Path: src/oatpp/web/server/HttpProcessor.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_HttpProcessor_hpp\n#define oatpp_web_server_HttpProcessor_hpp\n\n#include \"./HttpRouter.hpp\"\n\n#include \"./interceptor/RequestInterceptor.hpp\"\n#include \"./interceptor/ResponseInterceptor.hpp\"\n#include \"./handler/ErrorHandler.hpp\"\n\n#include \"oatpp/web/protocol/http/encoding/ProviderCollection.hpp\"\n\n#include 
\"oatpp/web/protocol/http/incoming/RequestHeadersReader.hpp\"\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n\n#include \"oatpp/web/protocol/http/outgoing/Response.hpp\"\n#include \"oatpp/web/protocol/http/utils/CommunicationUtils.hpp\"\n\n#include \"oatpp/core/data/stream/StreamBufferedProxy.hpp\"\n#include \"oatpp/core/async/Processor.hpp\"\n\nnamespace oatpp { namespace web { namespace server {\n\n/**\n * HttpProcessor. Helper class to handle HTTP processing.\n */\nclass HttpProcessor {\npublic:\n typedef std::list> RequestInterceptors;\n typedef std::list> ResponseInterceptors;\n typedef web::protocol::http::incoming::RequestHeadersReader RequestHeadersReader;\n typedef protocol::http::utils::CommunicationUtils::ConnectionState ConnectionState;\n\n HttpProcessor() = default;\n\npublic:\n\n /**\n * Resource config per connection.\n */\n struct Config {\n\n /**\n * Buffer used to read headers in request. Initial size of the buffer.\n */\n v_buff_size headersInBufferInitial = 2048;\n\n /**\n * Buffer used to write headers in response. Initial size of the buffer.\n */\n v_buff_size headersOutBufferInitial = 2048;\n\n /**\n * Size of the chunk used for iterative-read of headers.\n */\n v_buff_size headersReaderChunkSize = 2048;\n\n /**\n * Maximum allowed size of requests headers. The overall size of all headers in the request.\n */\n v_buff_size headersReaderMaxSize = 4096;\n\n };\n\npublic:\n\n /**\n * Collection of components needed to serve http-connection.\n */\n struct Components {\n\n /**\n * Constructor.\n * @param pRouter\n * @param pContentEncodingProviders\n * @param pBodyDecoder\n * @param pErrorHandler\n * @param pRequestInterceptors\n * @param pConfig\n */\n Components(const std::shared_ptr& pRouter,\n const std::shared_ptr& pContentEncodingProviders,\n const std::shared_ptr& pBodyDecoder,\n const std::shared_ptr& pErrorHandler,\n const RequestInterceptors& pRequestInterceptors,\n const ResponseInterceptors& pResponseInterceptors,\n const std::shared_ptr& pConfig);\n\n /**\n * Constructor.\n * @param pRouter\n */\n Components(const std::shared_ptr& pRouter);\n\n /**\n * Constructor.\n * @param pRouter\n * @param pConfig\n */\n Components(const std::shared_ptr& pRouter, const std::shared_ptr& pConfig);\n\n /**\n * Router to route incoming requests. &id:oatpp::web::server::HttpRouter;.\n */\n std::shared_ptr router;\n\n /**\n * Content-encoding providers. &id:oatpp::web::protocol::encoding::ProviderCollection;.\n */\n std::shared_ptr contentEncodingProviders;\n\n /**\n * Body decoder. &id:oatpp::web::protocol::http::incoming::BodyDecoder;.\n */\n std::shared_ptr bodyDecoder;\n\n /**\n * Error handler. &id:oatpp::web::server::handler::ErrorHandler;.\n */\n std::shared_ptr errorHandler;\n\n /**\n * Collection of request interceptors. &id:oatpp::web::server::interceptor::RequestInterceptor;.\n */\n RequestInterceptors requestInterceptors;\n\n /**\n * Collection of request interceptors. &id:oatpp::web::server::interceptor::ResponseInterceptor;.\n */\n ResponseInterceptors responseInterceptors;\n\n /**\n * Resource allocation config. 
&l:HttpProcessor::Config;.\n */\n std::shared_ptr config;\n\n };\n\nprivate:\n\n struct ProcessingResources {\n\n ProcessingResources(const std::shared_ptr& pComponents,\n const provider::ResourceHandle& pConnection);\n\n std::shared_ptr components;\n provider::ResourceHandle connection;\n oatpp::data::stream::BufferOutputStream headersInBuffer;\n oatpp::data::stream::BufferOutputStream headersOutBuffer;\n RequestHeadersReader headersReader;\n std::shared_ptr inStream;\n\n };\n\n static\n std::shared_ptr\n processNextRequest(ProcessingResources& resources,\n const std::shared_ptr& request,\n ConnectionState& connectionState);\n static ConnectionState processNextRequest(ProcessingResources& resources);\n\npublic:\n\n /**\n * Listener of the connection processing task.\n */\n class TaskProcessingListener {\n public:\n virtual void onTaskStart(const provider::ResourceHandle& connection) = 0;\n virtual void onTaskEnd(const provider::ResourceHandle& connection) = 0;\n };\n\npublic:\n\n /**\n * Connection serving task.
\n * Usage example:
\n * `std::thread thread(&HttpProcessor::Task::run, HttpProcessor::Task(components, connection));`\n */\n class Task : public base::Countable {\n private:\n std::shared_ptr m_components;\n provider::ResourceHandle m_connection;\n TaskProcessingListener* m_taskListener;\n public:\n\n /**\n * Constructor.\n * @param components - &l:HttpProcessor::Components;.\n * @param connection - &id:oatpp::data::stream::IOStream;.\n */\n Task(const std::shared_ptr& components,\n const provider::ResourceHandle& connection,\n TaskProcessingListener* taskListener);\n\n Task(const Task&) = delete;\n Task &operator=(const Task&) = delete;\n\n /**\n * Move-Constructor to correclty count tasks;\n */\n Task(Task &&other);\n\n /**\n * Move-Assignment to correctly count tasks.\n * @param t\n * @return\n */\n Task &operator=(Task &&other);\n\n /**\n * Destructor, needed for counting.\n */\n ~Task() override;\n\n public:\n\n /**\n * Run loop.\n */\n void run();\n\n };\n \npublic:\n\n /**\n * Connection serving coroutiner - &id:oatpp::async::Coroutine;.\n */\n class Coroutine : public oatpp::async::Coroutine {\n private:\n std::shared_ptr m_components;\n provider::ResourceHandle m_connection;\n oatpp::data::stream::BufferOutputStream m_headersInBuffer;\n RequestHeadersReader m_headersReader;\n std::shared_ptr m_headersOutBuffer;\n std::shared_ptr m_inStream;\n ConnectionState m_connectionState;\n private:\n oatpp::web::server::HttpRouter::BranchRouter::Route m_currentRoute;\n std::shared_ptr m_currentRequest;\n std::shared_ptr m_currentResponse;\n TaskProcessingListener* m_taskListener;\n public:\n\n /**\n * Constructor.\n * @param components - &l:HttpProcessor::Components;.\n * @param connection - &id:oatpp::data::stream::IOStream;.\n */\n Coroutine(const std::shared_ptr& components,\n const provider::ResourceHandle& connection,\n TaskProcessingListener* taskListener);\n\n ~Coroutine() override;\n\n Action act() override;\n\n Action parseHeaders();\n \n Action onHeadersParsed(const RequestHeadersReader::Result& headersReadResult);\n \n Action onRequestFormed();\n Action onResponse(const std::shared_ptr& response);\n Action onResponseFormed();\n Action onRequestDone();\n \n Action handleError(Error* error) override;\n \n };\n \n};\n \n}}}\n\n#endif /* oatpp_web_server_HttpProcessor_hpp */\n\n// Path: src/oatpp/web/server/HttpProcessor.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"HttpProcessor.hpp\"\n\n#include \"oatpp/web/protocol/http/incoming/SimpleBodyDecoder.hpp\"\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n\nnamespace oatpp { namespace web { namespace server 
{\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Components\n\nHttpProcessor::Components::Components(const std::shared_ptr& pRouter,\n const std::shared_ptr& pContentEncodingProviders,\n const std::shared_ptr& pBodyDecoder,\n const std::shared_ptr& pErrorHandler,\n const RequestInterceptors& pRequestInterceptors,\n const ResponseInterceptors& pResponseInterceptors,\n const std::shared_ptr& pConfig)\n : router(pRouter)\n , contentEncodingProviders(pContentEncodingProviders)\n , bodyDecoder(pBodyDecoder)\n , errorHandler(pErrorHandler)\n , requestInterceptors(pRequestInterceptors)\n , responseInterceptors(pResponseInterceptors)\n , config(pConfig)\n{}\n\nHttpProcessor::Components::Components(const std::shared_ptr& pRouter)\n : Components(pRouter,\n nullptr,\n std::make_shared(),\n handler::DefaultErrorHandler::createShared(),\n {},\n {},\n std::make_shared())\n{}\n\nHttpProcessor::Components::Components(const std::shared_ptr& pRouter, const std::shared_ptr& pConfig)\n : Components(pRouter,\n nullptr,\n std::make_shared(),\n handler::DefaultErrorHandler::createShared(),\n {},\n {},\n pConfig)\n{}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Other\n\nHttpProcessor::ProcessingResources::ProcessingResources(const std::shared_ptr& pComponents,\n const provider::ResourceHandle& pConnection)\n : components(pComponents)\n , connection(pConnection)\n , headersInBuffer(components->config->headersInBufferInitial)\n , headersOutBuffer(components->config->headersOutBufferInitial)\n , headersReader(&headersInBuffer, components->config->headersReaderChunkSize, components->config->headersReaderMaxSize)\n , inStream(data::stream::InputStreamBufferedProxy::createShared(connection.object, std::make_shared(data::buffer::IOBuffer::BUFFER_SIZE, 0)))\n{}\n\nstd::shared_ptr\nHttpProcessor::processNextRequest(ProcessingResources& resources,\n const std::shared_ptr& request,\n ConnectionState& connectionState)\n{\n\n std::shared_ptr response;\n\n try{\n\n for(auto& interceptor : resources.components->requestInterceptors) {\n response = interceptor->intercept(request);\n if(response) {\n return response;\n }\n }\n\n auto route = resources.components->router->getRoute(request->getStartingLine().method, request->getStartingLine().path);\n\n if(!route) {\n\n data::stream::BufferOutputStream ss;\n ss << \"No mapping for HTTP-method: '\" << request->getStartingLine().method.toString()\n << \"', URL: '\" << request->getStartingLine().path.toString() << \"'\";\n\n connectionState = ConnectionState::CLOSING;\n oatpp::web::protocol::http::HttpError error(protocol::http::Status::CODE_404, ss.toString());\n auto ptr = std::make_exception_ptr(error);\n return resources.components->errorHandler->handleError(ptr);\n\n }\n\n request->setPathVariables(route.getMatchMap());\n return route.getEndpoint()->handle(request);\n\n } catch (...) 
{\n response = resources.components->errorHandler->handleError(std::current_exception());\n connectionState = ConnectionState::CLOSING;\n }\n\n return response;\n\n}\n\nHttpProcessor::ConnectionState HttpProcessor::processNextRequest(ProcessingResources& resources) {\n\n oatpp::web::protocol::http::HttpError::Info error;\n auto headersReadResult = resources.headersReader.readHeaders(resources.inStream.get(), error);\n\n if(error.ioStatus <= 0) {\n return ConnectionState::DEAD;\n }\n\n ConnectionState connectionState = ConnectionState::ALIVE;\n std::shared_ptr request;\n std::shared_ptr response;\n\n if(error.status.code != 0) {\n oatpp::web::protocol::http::HttpError httpError(error.status, \"Invalid Request Headers\");\n auto eptr = std::make_exception_ptr(httpError);\n response = resources.components->errorHandler->handleError(eptr);\n connectionState = ConnectionState::CLOSING;\n } else {\n\n request = protocol::http::incoming::Request::createShared(resources.connection.object,\n headersReadResult.startingLine,\n headersReadResult.headers,\n resources.inStream,\n resources.components->bodyDecoder);\n\n response = processNextRequest(resources, request, connectionState);\n\n try {\n\n for (auto& interceptor : resources.components->responseInterceptors) {\n response = interceptor->intercept(request, response);\n if (!response) {\n oatpp::web::protocol::http::HttpError httpError(protocol::http::Status::CODE_500, \"Response Interceptor returned an Invalid Response - 'null'\");\n auto eptr = std::make_exception_ptr(httpError);\n response = resources.components->errorHandler->handleError(eptr);\n connectionState = ConnectionState::CLOSING;\n }\n }\n\n } catch (...) {\n response = resources.components->errorHandler->handleError(std::current_exception());\n connectionState = ConnectionState::CLOSING;\n }\n\n response->putHeaderIfNotExists(protocol::http::Header::SERVER, protocol::http::Header::Value::SERVER);\n protocol::http::utils::CommunicationUtils::considerConnectionState(request, response, connectionState);\n\n switch(connectionState) {\n\n case ConnectionState::ALIVE :\n response->putHeaderIfNotExists(protocol::http::Header::CONNECTION, protocol::http::Header::Value::CONNECTION_KEEP_ALIVE);\n break;\n\n case ConnectionState::CLOSING:\n case ConnectionState::DEAD:\n response->putHeaderIfNotExists(protocol::http::Header::CONNECTION, protocol::http::Header::Value::CONNECTION_CLOSE);\n break;\n\n case ConnectionState::DELEGATED:\n default:\n break;\n\n }\n\n }\n\n auto contentEncoderProvider =\n protocol::http::utils::CommunicationUtils::selectEncoder(request, resources.components->contentEncodingProviders);\n\n response->send(resources.connection.object.get(), &resources.headersOutBuffer, contentEncoderProvider.get());\n\n /* Delegate connection handling to another handler only after the response is sent to the client */\n if(connectionState == ConnectionState::DELEGATED) {\n auto handler = response->getConnectionUpgradeHandler();\n if(handler) {\n handler->handleConnection(resources.connection, response->getConnectionUpgradeParameters());\n connectionState = ConnectionState::DELEGATED;\n } else {\n OATPP_LOGW(\"[oatpp::web::server::HttpProcessor::processNextRequest()]\", \"Warning. 
ConnectionUpgradeHandler not set!\")\n connectionState = ConnectionState::CLOSING;\n }\n }\n\n return connectionState;\n\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Task\n\nHttpProcessor::Task::Task(const std::shared_ptr& components,\n const provider::ResourceHandle& connection,\n TaskProcessingListener* taskListener)\n : m_components(components)\n , m_connection(connection)\n , m_taskListener(taskListener)\n{\n m_taskListener->onTaskStart(m_connection);\n}\n\nHttpProcessor::Task::Task(HttpProcessor::Task &&other)\n : m_components(std::move(other.m_components))\n , m_connection(std::move(other.m_connection))\n , m_taskListener(other.m_taskListener)\n{\n other.m_taskListener = nullptr;\n}\n\nHttpProcessor::Task::~Task() {\n if (m_taskListener != nullptr) {\n m_taskListener->onTaskEnd(m_connection);\n }\n}\n\nHttpProcessor::Task &HttpProcessor::Task::operator=(HttpProcessor::Task &&other) {\n m_components = std::move(other.m_components);\n m_connection = std::move(other.m_connection);\n m_taskListener = other.m_taskListener;\n other.m_taskListener = nullptr;\n return *this;\n}\n\nvoid HttpProcessor::Task::run(){\n\n m_connection.object->initContexts();\n\n ProcessingResources resources(m_components, m_connection);\n\n ConnectionState connectionState;\n\n try {\n\n do {\n\n connectionState = HttpProcessor::processNextRequest(resources);\n\n } while (connectionState == ConnectionState::ALIVE);\n\n } catch (...) {\n // DO NOTHING\n }\n\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// HttpProcessor::Coroutine\n\nHttpProcessor::Coroutine::Coroutine(const std::shared_ptr& components,\n const provider::ResourceHandle& connection,\n TaskProcessingListener* taskListener)\n : m_components(components)\n , m_connection(connection)\n , m_headersInBuffer(components->config->headersInBufferInitial)\n , m_headersReader(&m_headersInBuffer, components->config->headersReaderChunkSize, components->config->headersReaderMaxSize)\n , m_headersOutBuffer(std::make_shared(components->config->headersOutBufferInitial))\n , m_inStream(data::stream::InputStreamBufferedProxy::createShared(m_connection.object, std::make_shared(data::buffer::IOBuffer::BUFFER_SIZE, 0)))\n , m_connectionState(ConnectionState::ALIVE)\n , m_taskListener(taskListener)\n{\n m_taskListener->onTaskStart(m_connection);\n}\n\nHttpProcessor::Coroutine::~Coroutine() {\n m_taskListener->onTaskEnd(m_connection);\n}\n\nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::act() {\n return m_connection.object->initContextsAsync().next(yieldTo(&HttpProcessor::Coroutine::parseHeaders));\n}\n\nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::parseHeaders() {\n return m_headersReader.readHeadersAsync(m_inStream).callbackTo(&HttpProcessor::Coroutine::onHeadersParsed);\n}\n\noatpp::async::Action HttpProcessor::Coroutine::onHeadersParsed(const RequestHeadersReader::Result& headersReadResult) {\n\n m_currentRequest = protocol::http::incoming::Request::createShared(m_connection.object,\n headersReadResult.startingLine,\n headersReadResult.headers,\n m_inStream,\n m_components->bodyDecoder);\n\n for(auto& interceptor : m_components->requestInterceptors) {\n m_currentResponse = interceptor->intercept(m_currentRequest);\n if(m_currentResponse) {\n return yieldTo(&HttpProcessor::Coroutine::onResponseFormed);\n }\n }\n\n m_currentRoute = 
m_components->router->getRoute(headersReadResult.startingLine.method.toString(), headersReadResult.startingLine.path.toString());\n\n if(!m_currentRoute) {\n\n data::stream::BufferOutputStream ss;\n ss << \"No mapping for HTTP-method: '\" << headersReadResult.startingLine.method.toString()\n << \"', URL: '\" << headersReadResult.startingLine.path.toString() << \"'\";\n oatpp::web::protocol::http::HttpError error(protocol::http::Status::CODE_404, ss.toString());\n auto eptr = std::make_exception_ptr(error);\n m_currentResponse = m_components->errorHandler->handleError(eptr);\n m_connectionState = ConnectionState::CLOSING;\n return yieldTo(&HttpProcessor::Coroutine::onResponseFormed);\n }\n\n m_currentRequest->setPathVariables(m_currentRoute.getMatchMap());\n\n return yieldTo(&HttpProcessor::Coroutine::onRequestFormed);\n\n}\n\nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::onRequestFormed() {\n return m_currentRoute.getEndpoint()->handleAsync(m_currentRequest).callbackTo(&HttpProcessor::Coroutine::onResponse);\n}\n\nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::onResponse(const std::shared_ptr& response) {\n m_currentResponse = response;\n return yieldTo(&HttpProcessor::Coroutine::onResponseFormed);\n}\n \nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::onResponseFormed() {\n\n for(auto& interceptor : m_components->responseInterceptors) {\n m_currentResponse = interceptor->intercept(m_currentRequest, m_currentResponse);\n if(!m_currentResponse) {\n oatpp::web::protocol::http::HttpError error(protocol::http::Status::CODE_500, \"Response Interceptor returned an Invalid Response - 'null'\");\n auto eptr = std::make_exception_ptr(error);\n m_currentResponse = m_components->errorHandler->handleError(eptr);\n }\n }\n\n m_currentResponse->putHeaderIfNotExists(protocol::http::Header::SERVER, protocol::http::Header::Value::SERVER);\n oatpp::web::protocol::http::utils::CommunicationUtils::considerConnectionState(m_currentRequest, m_currentResponse, m_connectionState);\n\n switch(m_connectionState) {\n\n case ConnectionState::ALIVE :\n m_currentResponse->putHeaderIfNotExists(protocol::http::Header::CONNECTION, protocol::http::Header::Value::CONNECTION_KEEP_ALIVE);\n break;\n\n case ConnectionState::CLOSING:\n case ConnectionState::DEAD:\n m_currentResponse->putHeaderIfNotExists(protocol::http::Header::CONNECTION, protocol::http::Header::Value::CONNECTION_CLOSE);\n break;\n\n case ConnectionState::DELEGATED:\n default:\n break;\n\n }\n\n auto contentEncoderProvider =\n protocol::http::utils::CommunicationUtils::selectEncoder(m_currentRequest, m_components->contentEncodingProviders);\n\n return protocol::http::outgoing::Response::sendAsync(m_currentResponse, m_connection.object, m_headersOutBuffer, contentEncoderProvider)\n .next(yieldTo(&HttpProcessor::Coroutine::onRequestDone));\n\n}\n \nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::onRequestDone() {\n\n switch (m_connectionState) {\n case ConnectionState::ALIVE:\n return yieldTo(&HttpProcessor::Coroutine::parseHeaders);\n\n /* Delegate connection handling to another handler only after the response is sent to the client */\n case ConnectionState::DELEGATED: {\n auto handler = m_currentResponse->getConnectionUpgradeHandler();\n if(handler) {\n handler->handleConnection(m_connection, m_currentResponse->getConnectionUpgradeParameters());\n m_connectionState = ConnectionState::DELEGATED;\n } else {\n OATPP_LOGW(\"[oatpp::web::server::HttpProcessor::Coroutine::onResponseFormed()]\", \"Warning. 
ConnectionUpgradeHandler not set!\")\n m_connectionState = ConnectionState::CLOSING;\n }\n break;\n }\n\n case ConnectionState::CLOSING:\n case ConnectionState::DEAD:\n default:\n break;\n\n }\n \n return finish();\n\n}\n \nHttpProcessor::Coroutine::Action HttpProcessor::Coroutine::handleError(Error* error) {\n\n if(error) {\n\n if(error->is()) {\n auto aioe = static_cast(error);\n if(aioe->getCode() == oatpp::IOError::BROKEN_PIPE) {\n return aioe; // do not report BROKEN_PIPE error\n }\n }\n\n if(m_currentResponse) {\n //OATPP_LOGE(\"[oatpp::web::server::HttpProcessor::Coroutine::handleError()]\", \"Unhandled error. '%s'. Dropping connection\", error->what())\n return error;\n }\n\n oatpp::web::protocol::http::HttpError httpError(protocol::http::Status::CODE_500, error->what());\n auto eptr = std::make_exception_ptr(httpError);\n m_currentResponse = m_components->errorHandler->handleError(eptr);\n return yieldTo(&HttpProcessor::Coroutine::onResponseFormed);\n\n }\n\n return error;\n\n}\n \n}}}\n\n// Path: src/oatpp/web/server/HttpConnectionHandler.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_HttpConnectionHandler_hpp\n#define oatpp_web_server_HttpConnectionHandler_hpp\n\n#include \"oatpp/web/server/HttpProcessor.hpp\"\n#include \"oatpp/network/ConnectionHandler.hpp\"\n#include \"oatpp/core/concurrency/SpinLock.hpp\"\n\n#include \n\nnamespace oatpp { namespace web { namespace server {\n\n/**\n * Simple ConnectionHandler (&id:oatpp::network::ConnectionHandler;) for handling HTTP communication.
\n * Will create one thread per each connection to handle communication.\n */\nclass HttpConnectionHandler : public base::Countable, public network::ConnectionHandler, public HttpProcessor::TaskProcessingListener {\nprotected:\n\n void onTaskStart(const provider::ResourceHandle& connection) override;\n void onTaskEnd(const provider::ResourceHandle& connection) override;\n\n void invalidateAllConnections();\n\nprivate:\n std::shared_ptr m_components;\n std::atomic_bool m_continue;\n std::unordered_map> m_connections;\n oatpp::concurrency::SpinLock m_connectionsLock;\npublic:\n\n /**\n * Constructor.\n * @param components - &id:oatpp::web::server::HttpProcessor::Components;.\n */\n HttpConnectionHandler(const std::shared_ptr& components);\n\n /**\n * Constructor.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n */\n HttpConnectionHandler(const std::shared_ptr& router)\n : HttpConnectionHandler(std::make_shared(router))\n {}\n\n /**\n * Constructor.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n * @param config - &id:oatpp::web::server::HttpProcessor::Config;.\n */\n HttpConnectionHandler(const std::shared_ptr& router,\n const std::shared_ptr& config)\n : HttpConnectionHandler(std::make_shared(router, config))\n {}\n\npublic:\n\n /**\n * Create shared HttpConnectionHandler.\n * @param router - &id:oatpp::web::server::HttpRouter; to route incoming requests.\n * @return - `std::shared_ptr` to HttpConnectionHandler.\n */\n static std::shared_ptr createShared(const std::shared_ptr& router);\n\n /**\n * Set root error handler for all requests coming through this Connection Handler.\n * All unhandled errors will be handled by this error handler.\n * @param errorHandler - &id:oatpp::web::server::handler::ErrorHandler;.\n */\n void setErrorHandler(const std::shared_ptr& errorHandler);\n\n /**\n * Add request interceptor. 
Request interceptors are called before routing happens.\n * If multiple interceptors set then the order of interception is the same as the order of calls to `addRequestInterceptor`.\n * @param interceptor - &id:oatpp::web::server::interceptor::RequestInterceptor;.\n */\n void addRequestInterceptor(const std::shared_ptr& interceptor);\n\n /**\n * Add response interceptor.\n * If multiple interceptors set then the order of interception is the same as the order of calls to `addResponseInterceptor`.\n * @param interceptor - &id:oatpp::web::server::interceptor::RequestInterceptor;.\n */\n void addResponseInterceptor(const std::shared_ptr& interceptor);\n\n /**\n * Implementation of &id:oatpp::network::ConnectionHandler::handleConnection;.\n * @param connection - &id:oatpp::data::stream::IOStream; representing connection.\n */\n void handleConnection(const provider::ResourceHandle& connection,\n const std::shared_ptr& params) override;\n\n /**\n * Tell all worker threads to exit when done.\n */\n void stop() override;\n\n /**\n * Get connections count.\n * @return\n */\n v_uint64 getConnectionsCount();\n \n};\n \n}}}\n\n#endif /* oatpp_web_server_HttpConnectionHandler_hpp */\n\n// Path: src/oatpp/web/server/HttpConnectionHandler.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"./HttpConnectionHandler.hpp\"\n\n#include \"oatpp/web/protocol/http/incoming/Request.hpp\"\n#include \"oatpp/web/protocol/http/Http.hpp\"\n\n#include \"oatpp/core/concurrency/Thread.hpp\"\n\n#include \"oatpp/core/data/buffer/IOBuffer.hpp\"\n\n#include \"oatpp/core/data/stream/BufferStream.hpp\"\n#include \"oatpp/core/data/stream/StreamBufferedProxy.hpp\"\n\n\nnamespace oatpp { namespace web { namespace server {\n\nvoid HttpConnectionHandler::onTaskStart(const provider::ResourceHandle& connection) {\n\n std::lock_guard lock(m_connectionsLock);\n m_connections.insert({reinterpret_cast(connection.object.get()), connection});\n\n if(!m_continue.load()) {\n connection.invalidator->invalidate(connection.object);\n }\n\n}\n\nvoid HttpConnectionHandler::onTaskEnd(const provider::ResourceHandle& connection) {\n std::lock_guard lock(m_connectionsLock);\n m_connections.erase(reinterpret_cast(connection.object.get()));\n}\n\nvoid HttpConnectionHandler::invalidateAllConnections() {\n std::lock_guard lock(m_connectionsLock);\n for(auto& c : m_connections) {\n const auto& handle = c.second;\n handle.invalidator->invalidate(handle.object);\n }\n}\n\nv_uint64 HttpConnectionHandler::getConnectionsCount() {\n std::lock_guard lock(m_connectionsLock);\n return m_connections.size();\n}\n\nHttpConnectionHandler::HttpConnectionHandler(const std::shared_ptr& components)\n : 
m_components(components)\n , m_continue(true)\n{}\n\nstd::shared_ptr HttpConnectionHandler::createShared(const std::shared_ptr& router){\n return std::make_shared(router);\n}\n\nvoid HttpConnectionHandler::setErrorHandler(const std::shared_ptr& errorHandler){\n m_components->errorHandler = errorHandler;\n if(!m_components->errorHandler) {\n m_components->errorHandler = handler::DefaultErrorHandler::createShared();\n }\n}\n\nvoid HttpConnectionHandler::addRequestInterceptor(const std::shared_ptr& interceptor) {\n m_components->requestInterceptors.push_back(interceptor);\n}\n\nvoid HttpConnectionHandler::addResponseInterceptor(const std::shared_ptr& interceptor) {\n m_components->responseInterceptors.push_back(interceptor);\n}\n \nvoid HttpConnectionHandler::handleConnection(const provider::ResourceHandle& connection,\n const std::shared_ptr& params)\n{\n\n (void)params;\n\n if (m_continue.load()) {\n\n connection.object->setOutputStreamIOMode(oatpp::data::stream::IOMode::BLOCKING);\n connection.object->setInputStreamIOMode(oatpp::data::stream::IOMode::BLOCKING);\n\n /* Create working thread */\n std::thread thread(&HttpProcessor::Task::run, std::move(HttpProcessor::Task(m_components, connection, this)));\n\n /* Get hardware concurrency -1 in order to have 1cpu free of workers. */\n v_int32 concurrency = oatpp::concurrency::getHardwareConcurrency();\n if (concurrency > 1) {\n concurrency -= 1;\n }\n\n /* Set thread affinity group CPUs [0..cpu_count - 1]. Leave one cpu free of workers */\n oatpp::concurrency::setThreadAffinityToCpuRange(thread.native_handle(),\n 0,\n concurrency - 1 /* -1 because 0-based index */);\n\n thread.detach();\n }\n\n}\n\nvoid HttpConnectionHandler::stop() {\n m_continue.store(false);\n\n /* invalidate all connections */\n invalidateAllConnections();\n\n /* Wait until all connection-threads are done */\n while(getConnectionsCount() > 0) {\n std::this_thread::sleep_for(std::chrono::milliseconds(100));\n }\n}\n\n}}}\n\n// Path: src/oatpp/web/server/interceptor/AllowCorsGlobal.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_interceptor_AllowCorsGlobal_hpp\n#define oatpp_web_server_interceptor_AllowCorsGlobal_hpp\n\n#include \"oatpp/web/server/interceptor/ResponseInterceptor.hpp\"\n#include \"oatpp/web/server/interceptor/RequestInterceptor.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace interceptor {\n\nclass AllowOptionsGlobal : public RequestInterceptor {\npublic:\n std::shared_ptr intercept(const std::shared_ptr& request) override;\n};\n\nclass AllowCorsGlobal : public ResponseInterceptor {\nprivate:\n oatpp::String m_origin;\n oatpp::String m_methods;\n 
oatpp::String m_headers;\n oatpp::String m_maxAge;\npublic:\n\n AllowCorsGlobal(const oatpp::String &origin = \"*\",\n const oatpp::String &methods = \"GET, POST, OPTIONS\",\n const oatpp::String &headers = \"DNT, User-Agent, X-Requested-With, If-Modified-Since, Cache-Control, Content-Type, Range, Authorization\",\n const oatpp::String &maxAge = \"1728000\");\n\n std::shared_ptr intercept(const std::shared_ptr& request,\n const std::shared_ptr& response) override;\n\n};\n\n}}}}\n\n#endif // oatpp_web_server_interceptor_AllowCorsGlobal_hpp\n\n// Path: src/oatpp/web/server/interceptor/AllowCorsGlobal.cpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#include \"AllowCorsGlobal.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace interceptor {\n\nstd::shared_ptr AllowOptionsGlobal::intercept(const std::shared_ptr &request) {\n\n const auto &line = request->getStartingLine();\n\n if (line.method == \"OPTIONS\") {\n return OutgoingResponse::createShared(protocol::http::Status::CODE_204, nullptr);\n }\n\n return nullptr;\n\n}\n\nAllowCorsGlobal::AllowCorsGlobal(const oatpp::String &origin,\n const oatpp::String &methods,\n const oatpp::String &headers,\n const oatpp::String &maxAge)\n : m_origin(origin)\n , m_methods(methods)\n , m_headers(headers)\n , m_maxAge(maxAge)\n{}\n\nstd::shared_ptr AllowCorsGlobal::intercept(const std::shared_ptr& request,\n const std::shared_ptr& response)\n{\n response->putHeaderIfNotExists(protocol::http::Header::CORS_ORIGIN, m_origin);\n response->putHeaderIfNotExists(protocol::http::Header::CORS_METHODS, m_methods);\n response->putHeaderIfNotExists(protocol::http::Header::CORS_HEADERS, m_headers);\n response->putHeaderIfNotExists(protocol::http::Header::CORS_MAX_AGE, m_maxAge);\n return response;\n}\n\n}}}}\n\n// Path: src/oatpp/web/server/handler/AuthorizationHandler.hpp\n/***************************************************************************\n *\n * Project _____ __ ____ _ _\n * ( _ ) /__\\ (_ _)_| |_ _| |_\n * )(_)( /(__)\\ )( (_ _)(_ _)\n * (_____)(__)(__)(__) |_| |_|\n *\n *\n * Copyright 2018-present, Leonid Stryzhevskyi \n * Benedikt-Alexander Mokro\u00df \n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing 
permissions and\n * limitations under the License.\n *\n ***************************************************************************/\n\n#ifndef oatpp_web_server_handler_AuthorizationHandler_hpp\n#define oatpp_web_server_handler_AuthorizationHandler_hpp\n\n#include \n#include \"oatpp/web/protocol/http/Http.hpp\"\n#include \"oatpp/core/macro/codegen.hpp\"\n#include \"oatpp/core/data/mapping/type/Type.hpp\"\n\nnamespace oatpp { namespace web { namespace server { namespace handler {\n\n/**\n * The AuthorizationObject superclass, all AuthorizationObjects have to extend this class.\n */\nclass AuthorizationObject : public oatpp::base::Countable {\nprotected:\n A\nuthorizationObject() = default;\n};\n\n/**\n * Abstract Authorization Handler.\n */\nclass AuthorizationHandler {\npublic:\n /**\n * Convenience typedef for &l:AuthorizationObject;.\n */\n typedef oatpp::web::server::handler::AuthorizationObject AuthorizationObject;\n\n /**\n * Convenience typedef for &id:oatpp::data::stream::BufferOutputStream;.\n */\n typedef oatpp::data::stream::BufferOutputStream BufferOutputStream;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Headers;.\n */\n typedef oatpp::web::protocol::http::Headers Headers;\nprivate:\n oatpp::String m_scheme;\n oatpp::String m_realm;\npublic:\n\n /**\n * Constructor.\n * @param scheme - authorization type scheme. &id:oatpp::String;.\n * @param realm - realm. &id:oatpp::String;.\n */\n AuthorizationHandler(const oatpp::String& scheme, const oatpp::String& realm);\n\n /**\n * Default virtual destructor.\n */\n virtual ~AuthorizationHandler() = default;\n\n /**\n * Implement this method! Return nullptr if authorization should be denied.\n * @param header - `Authorization` header. &id:oatpp::String;.\n * @return - `std::shared_ptr` to &id:oatpp::web::server::handler::AuthorizationObject;.\n */\n virtual std::shared_ptr handleAuthorization(const oatpp::String& authorizationHeader) = 0;\n\n /**\n * Render WWW-Authenicate header value.
\n * Custom Authorization handlers may override this method in order to provide additional information.\n * @param stream - &id:oatpp::data::stream::BufferOutputStream;.\n */\n virtual void renderAuthenticateHeaderValue(BufferOutputStream& stream);\n\n /**\n * Add authorization error headers to the headers map.
\n * @param headers - &id:oatpp::web::protocol::http::Headers;.\n */\n virtual void addErrorResponseHeaders(Headers& headers);\n\n /**\n * Get authorization scheme.\n * @return\n */\n oatpp::String getScheme();\n\n /**\n * Get authorization realm.\n * @return\n */\n oatpp::String getRealm();\n\n};\n\n/**\n * Default Basic AuthorizationObject - Convenience object to enable Basic-Authorization without the need to implement anything.\n */\nclass DefaultBasicAuthorizationObject : public AuthorizationObject {\npublic:\n\n /**\n * User-Id. &id:oatpp::String;.\n */\n oatpp::String userId;\n\n /**\n * Password. &id:oatpp::String;.\n */\n oatpp::String password;\n\n};\n\n/**\n * AuthorizationHandler for Authorization Type `Basic`.
\n * See [RFC 7617](https://tools.ietf.org/html/rfc7617).
\n * Extend this class to implement Custom Basic Authorization.\n */\nclass BasicAuthorizationHandler : public AuthorizationHandler {\n...\n", "needle_token_start": 15039, "needle_token_end": 15049, "code_context_ntokens": 15841, "output": ["```cpp\nclass AuthorizationHandler {\npublic:\n /**\n * Convenience typedef for &l:AuthorizationObject;.\n */\n typedef oatpp::web::server::handler::AuthorizationObject AuthorizationObject;\n\n /**\n * Convenience typedef for &id:oatpp::data::stream::BufferOutputStream;.\n */\n typedef oatpp::data::stream::BufferOutputStream BufferOutputStream;\n\n /**\n * Convenience typedef for &id:oatpp::web::protocol::http::Headers;.\n */\n typedef oatpp::web::protocol::http::Headers Headers;\nprivate:\n oatpp::String m_scheme;\n oatpp::String m_realm;\npublic:\n\n /**\n * Constructor.\n * @param scheme - authorization type scheme. &id:oatpp::String;.\n * @param realm - realm. &id:oatpp::String;.\n */\n AuthorizationHandler(const oatpp::String& scheme, const oatpp::String& realm);\n\n /**\n * Default virtual destructor.\n */\n virtual ~AuthorizationHandler() = default;\n\n /**\n * Implement this method! Return nullptr if authorization should be denied.\n * @param header - `Authorization` header. &id:oatpp::String;.\n * @return - `std::shared_ptr` to &id:oatpp::web::server::handler::AuthorizationObject;.\n */\n virtual std::shared_ptr handleAuthorization(const oatpp::String& authorizationHeader) = 0;\n\n /**\n * Render WWW-Authenicate header value.
\n * Custom Authorization handlers may override this method in order to provide additional information.\n * @param stream - &id:oatpp::data::stream::BufferOutputStream;.\n */\n virtual void renderAuthenticateHeaderValue(BufferOutputStream& stream);\n\n /**\n * Add authorization error headers to the headers map.
\n * @param headers - &id:oatpp::web::protocol::http::Headers;.\n */\n virtual void addErrorResponseHeaders(Headers& headers);\n\n /**\n * Get authorization scheme.\n * @return\n */\n oatpp::String getScheme();\n\n /**\n * Get authorization realm.\n * @return\n */\n oatpp::String getRealm();\n\n};\n```"]} +{"repo": "google/gson", "name": "addTypeAdaptersForDate", "language": "java", "path": "gson/src/main/java/com/google/gson/GsonBuilder.java", "position_ratio": 0.05, "description": "\nFunction Description:\n1. **Purpose**: To configure serialization and deserialization adapters for date-related types based on specified patterns or styles.\n2. **Input**: A string pattern for dates, two integers representing date and time styles, and a list to which the adapters will be added.\n3. **Output**: There is no direct output as it modifies the list of factories by adding new type adapters.\n4. **Procedure**: The function checks if a non-empty date pattern is provided and creates date adapters accordingly. If no pattern is provided but styles are specified, it creates adapters based on these styles. It also checks for SQL types support and adds corresponding SQL date adapters if supported. These adapters are then added to the provided list of factories.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " * which no {@link TypeAdapter} has been registered, and for which no built-in Gson {@code\n * TypeAdapter} exists.\n *\n *

The created Gson instance might only use an access filter once for a class or its members\n * and cache the result. It is not guaranteed that the filter will be used again every time a\n * class or its members are accessed during serialization or deserialization.\n *\n * @param filter filter to add\n * @return a reference to this {@code GsonBuilder} object to fulfill the \"Builder\" pattern\n * @since 2.9.1\n */\n @CanIgnoreReturnValue\n public GsonBuilder addReflectionAccessFilter(ReflectionAccessFilter filter) {\n Objects.requireNonNull(filter);\n reflectionFilters.addFirst(filter);\n return this;\n }\n\n /**\n * Creates a {@link Gson} instance based on the current configuration. This method is free of\n * side-effects to this {@code GsonBuilder} instance and hence can be called multiple times.\n *\n * @return an instance of Gson configured with the options currently set in this builder\n */\n public Gson create() {\n List factories =\n new ArrayList<>(this.factories.size() + this.hierarchyFactories.size() + 3);\n factories.addAll(this.factories);\n Collections.reverse(factories);\n\n List hierarchyFactories = new ArrayList<>(this.hierarchyFactories);\n Collections.reverse(hierarchyFactories);\n factories.addAll(hierarchyFactories);\n\n addTypeAdaptersForDate(datePattern, dateStyle, timeStyle, factories);\n\n return new Gson(\n excluder,\n fieldNamingPolicy,\n new HashMap<>(instanceCreators),\n serializeNulls,\n complexMapKeySerialization,\n generateNonExecutableJson,\n escapeHtmlChars,\n formattingStyle,\n strictness,\n serializeSpecialFloatingPointValues,\n useJdkUnsafe,\n longSerializationPolicy,\n datePattern,\n dateStyle,\n timeStyle,\n new ArrayList<>(this.factories),\n new ArrayList<>(this.hierarchyFactories),\n factories,\n objectToNumberStrategy,\n numberToNumberStrategy,\n new ArrayList<>(reflectionFilters));\n }\n\n \nprivate static void addTypeAdaptersForDate(\n String datePattern, int dateStyle, int timeStyle, List factories) {\n TypeAdapterFactory dateAdapterFactory;\n boolean sqlTypesSupported = SqlTypesSupport.SUPPORTS_SQL_TYPES;\n TypeAdapterFactory sqlTimestampAdapterFactory = null;\n TypeAdapterFactory sqlDateAdapterFactory = null;\n\n if (datePattern != null && !datePattern.trim().isEmpty()) {\n dateAdapterFactory = DefaultDateTypeAdapter.DateType.DATE.createAdapterFactory(datePattern);\n\n if (sqlTypesSupported) {\n sqlTimestampAdapterFactory =\n SqlTypesSupport.TIMESTAMP_DATE_TYPE.createAdapterFactory(datePattern);\n sqlDateAdapterFactory = SqlTypesSupport.DATE_DATE_TYPE.createAdapterFactory(datePattern);\n }\n } else if (dateStyle != DateFormat.DEFAULT || timeStyle != DateFormat.DEFAULT) {\n dateAdapterFactory =\n DefaultDateTypeAdapter.DateType.DATE.createAdapterFactory(dateStyle, timeStyle);\n\n if (sqlTypesSupported) {\n sqlTimestampAdapterFactory =\n SqlTypesSupport.TIMESTAMP_DATE_TYPE.createAdapterFactory(dateStyle, timeStyle);\n sqlDateAdapterFactory =\n SqlTypesSupport.DATE_DATE_TYPE.createAdapterFactory(dateStyle, timeStyle);\n }\n } else {\n return;\n }\n\n factories.add(dateAdapterFactory);\n if (sqlTypesSupported) {\n factories.add(sqlTimestampAdapterFactory);\n factories.add(sqlDateAdapterFactory);\n }\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/stream/JsonScope.java\n/*\n * Copyright (C) 2010 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n 
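Note on the excerpt above: `addTypeAdaptersForDate` is the internal step in `create()` that turns the builder's date pattern/style fields into registered adapters (plus the `java.sql` variants when supported). As a usage-level illustration only — a minimal sketch assuming the public `GsonBuilder#setDateFormat(String)` setter, which is not shown in this excerpt — the behaviour it enables looks like this:

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.util.Date;

public class DateAdapterSketch {
  public static void main(String[] args) {
    // A non-empty pattern: create() hands it to addTypeAdaptersForDate, which
    // registers a java.util.Date adapter (and java.sql adapters when supported).
    Gson withPattern = new GsonBuilder()
        .setDateFormat("yyyy-MM-dd")
        .create();
    System.out.println(withPattern.toJson(new Date())); // e.g. "2024-01-31"

    // No pattern and default styles: addTypeAdaptersForDate returns early and
    // Gson falls back to its built-in default date representation.
    System.out.println(new Gson().toJson(new Date()));
  }
}
```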
*\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.stream;\n\n/**\n * Lexical scoping elements within a JSON reader or writer.\n *\n * @author Jesse Wilson\n * @since 1.6\n */\nfinal class JsonScope {\n private JsonScope() {}\n\n /** An array with no elements requires no separator before the next element. */\n static final int EMPTY_ARRAY = 1;\n\n /** An array with at least one value requires a separator before the next element. */\n static final int NONEMPTY_ARRAY = 2;\n\n /** An object with no name/value pairs requires no separator before the next element. */\n static final int EMPTY_OBJECT = 3;\n\n /** An object whose most recent element is a key. The next element must be a value. */\n static final int DANGLING_NAME = 4;\n\n /** An object with at least one name/value pair requires a separator before the next element. */\n static final int NONEMPTY_OBJECT = 5;\n\n /** No top-level value has been started yet. */\n static final int EMPTY_DOCUMENT = 6;\n\n /** A top-level value has already been started. */\n static final int NONEMPTY_DOCUMENT = 7;\n\n /** A document that's been closed and cannot be accessed. */\n static final int CLOSED = 8;\n}\n\n// Path: gson/src/main/java/com/google/gson/stream/JsonWriter.java\n/*\n * Copyright (C) 2010 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.stream;\n\nimport static com.google.gson.stream.JsonScope.DANGLING_NAME;\nimport static com.google.gson.stream.JsonScope.EMPTY_ARRAY;\nimport static com.google.gson.stream.JsonScope.EMPTY_DOCUMENT;\nimport static com.google.gson.stream.JsonScope.EMPTY_OBJECT;\nimport static com.google.gson.stream.JsonScope.NONEMPTY_ARRAY;\nimport static com.google.gson.stream.JsonScope.NONEMPTY_DOCUMENT;\nimport static com.google.gson.stream.JsonScope.NONEMPTY_OBJECT;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.FormattingStyle;\nimport com.google.gson.Gson;\nimport com.google.gson.GsonBuilder;\nimport com.google.gson.Strictness;\nimport java.io.Closeable;\nimport java.io.Flushable;\nimport java.io.IOException;\nimport java.io.Writer;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\nimport java.util.Arrays;\nimport java.util.Objects;\nimport java.util.concurrent.atomic.AtomicInteger;\nimport java.util.concurrent.atomic.AtomicLong;\nimport java.util.regex.Pattern;\n\n/**\n * Writes a JSON (RFC 8259) encoded value to a\n * stream, one token at a time. The stream includes both literal values (strings, numbers, booleans\n * and nulls) as well as the begin and end delimiters of objects and arrays.\n *\n *

Encoding JSON

\n *\n * To encode your data as JSON, create a new {@code JsonWriter}. Call methods on the writer as you\n * walk the structure's contents, nesting arrays and objects as necessary:\n *\n *
    \n *
  • To write arrays, first call {@link #beginArray()}. Write each of the\n * array's elements with the appropriate {@link #value} methods or by nesting other arrays and\n * objects. Finally close the array using {@link #endArray()}.\n *
  • To write objects, first call {@link #beginObject()}. Write each of the\n * object's properties by alternating calls to {@link #name} with the property's value. Write\n * property values with the appropriate {@link #value} method or by nesting other objects or\n * arrays. Finally close the object using {@link #endObject()}.\n *
\n *\n *

Configuration

\n *\n * The behavior of this writer can be customized with the following methods:\n *\n *
    \n *
  • {@link #setFormattingStyle(FormattingStyle)}, the default is {@link\n * FormattingStyle#COMPACT}\n *
  • {@link #setHtmlSafe(boolean)}, by default HTML characters are not escaped in the JSON\n * output\n *
  • {@link #setStrictness(Strictness)}, the default is {@link Strictness#LEGACY_STRICT}\n *
  • {@link #setSerializeNulls(boolean)}, by default {@code null} is serialized\n *
\n *\n * The default configuration of {@code JsonWriter} instances used internally by the {@link Gson}\n * class differs, and can be adjusted with the various {@link GsonBuilder} methods.\n *\n *
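As a hedged, illustrative sketch (not part of the quoted source) wiring together the four configuration methods listed above — all of the calls used here appear in the `JsonWriter` code later in this excerpt:

```java
import com.google.gson.FormattingStyle;
import com.google.gson.Strictness;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;
import java.io.StringWriter;

public class WriterConfigSketch {
  public static void main(String[] args) throws IOException {
    StringWriter out = new StringWriter();
    JsonWriter writer = new JsonWriter(out);

    writer.setFormattingStyle(FormattingStyle.PRETTY); // default is COMPACT
    writer.setHtmlSafe(true);                          // escape <, >, &, =, '
    writer.setStrictness(Strictness.STRICT);           // default is LEGACY_STRICT
    writer.setSerializeNulls(false);                   // drop members whose value is null

    writer.beginObject();
    writer.name("title").value("a < b"); // '<' is written as \u003c because of setHtmlSafe(true)
    writer.name("comment").nullValue();  // skipped entirely because serializeNulls is false
    writer.endObject();
    writer.close();

    System.out.println(out);
  }
}
```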

Example

\n *\n * Suppose we'd like to encode a stream of messages such as the following:\n *\n *
{@code\n * [\n *   {\n *     \"id\": 912345678901,\n *     \"text\": \"How do I stream JSON in Java?\",\n *     \"geo\": null,\n *     \"user\": {\n *       \"name\": \"json_newb\",\n *       \"followers_count\": 41\n *      }\n *   },\n *   {\n *     \"id\": 912345678902,\n *     \"text\": \"@json_newb just use JsonWriter!\",\n *     \"geo\": [50.454722, -104.606667],\n *     \"user\": {\n *       \"name\": \"jesse\",\n *       \"followers_count\": 2\n *     }\n *   }\n * ]\n * }
\n *\n * This code encodes the above structure:\n *\n *
{@code\n * public void writeJsonStream(OutputStream out, List messages) throws IOException {\n *   JsonWriter writer = new JsonWriter(new OutputStreamWriter(out, \"UTF-8\"));\n *   writer.setIndent(\"    \");\n *   writeMessagesArray(writer, messages);\n *   writer.close();\n * }\n *\n * public void writeMessagesArray(JsonWriter writer, List messages) throws IOException {\n *   writer.beginArray();\n *   for (Message message : messages) {\n *     writeMessage(writer, message);\n *   }\n *   writer.endArray();\n * }\n *\n * public void writeMessage(JsonWriter writer, Message message) throws IOException {\n *   writer.beginObject();\n *   writer.name(\"id\").value(message.getId());\n *   writer.name(\"text\").value(message.getText());\n *   if (message.getGeo() != null) {\n *     writer.name(\"geo\");\n *     writeDoublesArray(writer, message.getGeo());\n *   } else {\n *     writer.name(\"geo\").nullValue();\n *   }\n *   writer.name(\"user\");\n *   writeUser(writer, message.getUser());\n *   writer.endObject();\n * }\n *\n * public void writeUser(JsonWriter writer, User user) throws IOException {\n *   writer.beginObject();\n *   writer.name(\"name\").value(user.getName());\n *   writer.name(\"followers_count\").value(user.getFollowersCount());\n *   writer.endObject();\n * }\n *\n * public void writeDoublesArray(JsonWriter writer, List doubles) throws IOException {\n *   writer.beginArray();\n *   for (Double value : doubles) {\n *     writer.value(value);\n *   }\n *   writer.endArray();\n * }\n * }
\n *\n *

Each {@code JsonWriter} may be used to write a single JSON stream. Instances of this class are\n * not thread safe. Calls that would result in a malformed JSON string will fail with an {@link\n * IllegalStateException}.\n *\n * @author Jesse Wilson\n * @since 1.6\n */\npublic class JsonWriter implements Closeable, Flushable {\n\n // Syntax as defined by https://datatracker.ietf.org/doc/html/rfc8259#section-6\n private static final Pattern VALID_JSON_NUMBER_PATTERN =\n Pattern.compile(\"-?(?:0|[1-9][0-9]*)(?:\\\\.[0-9]+)?(?:[eE][-+]?[0-9]+)?\");\n\n /*\n * From RFC 8259, \"All Unicode characters may be placed within the\n * quotation marks except for the characters that must be escaped:\n * quotation mark, reverse solidus, and the control characters\n * (U+0000 through U+001F).\"\n *\n * We also escape '\\u2028' and '\\u2029', which JavaScript interprets as\n * newline characters. This prevents eval() from failing with a syntax\n * error. http://code.google.com/p/google-gson/issues/detail?id=341\n */\n private static final String[] REPLACEMENT_CHARS;\n private static final String[] HTML_SAFE_REPLACEMENT_CHARS;\n\n static {\n REPLACEMENT_CHARS = new String[128];\n for (int i = 0; i <= 0x1f; i++) {\n REPLACEMENT_CHARS[i] = String.format(\"\\\\u%04x\", i);\n }\n REPLACEMENT_CHARS['\"'] = \"\\\\\\\"\";\n REPLACEMENT_CHARS['\\\\'] = \"\\\\\\\\\";\n REPLACEMENT_CHARS['\\t'] = \"\\\\t\";\n REPLACEMENT_CHARS['\\b'] = \"\\\\b\";\n REPLACEMENT_CHARS['\\n'] = \"\\\\n\";\n REPLACEMENT_CHARS['\\r'] = \"\\\\r\";\n REPLACEMENT_CHARS['\\f'] = \"\\\\f\";\n HTML_SAFE_REPLACEMENT_CHARS = REPLACEMENT_CHARS.clone();\n HTML_SAFE_REPLACEMENT_CHARS['<'] = \"\\\\u003c\";\n HTML_SAFE_REPLACEMENT_CHARS['>'] = \"\\\\u003e\";\n HTML_SAFE_REPLACEMENT_CHARS['&'] = \"\\\\u0026\";\n HTML_SAFE_REPLACEMENT_CHARS['='] = \"\\\\u003d\";\n HTML_SAFE_REPLACEMENT_CHARS['\\''] = \"\\\\u0027\";\n }\n\n /** The JSON output destination */\n private final Writer out;\n\n private int[] stack = new int[32];\n private int stackSize = 0;\n\n {\n push(EMPTY_DOCUMENT);\n }\n\n private FormattingStyle formattingStyle;\n // These fields cache data derived from the formatting style, to avoid having to\n // re-evaluate it every time something is written\n private String formattedColon;\n private String formattedComma;\n private boolean usesEmptyNewlineAndIndent;\n\n private Strictness strictness = Strictness.LEGACY_STRICT;\n\n private boolean htmlSafe;\n\n private String deferredName;\n\n private boolean serializeNulls = true;\n\n /**\n * Creates a new instance that writes a JSON-encoded stream to {@code out}. For best performance,\n * ensure {@link Writer} is buffered; wrapping in {@link java.io.BufferedWriter BufferedWriter} if\n * necessary.\n */\n public JsonWriter(Writer out) {\n this.out = Objects.requireNonNull(out, \"out == null\");\n setFormattingStyle(FormattingStyle.COMPACT);\n }\n\n /**\n * Sets the indentation string to be repeated for each level of indentation in the encoded\n * document. If {@code indent.isEmpty()} the encoded document will be compact. Otherwise the\n * encoded document will be more human-readable.\n *\n *

This is a convenience method which overwrites any previously {@linkplain\n * #setFormattingStyle(FormattingStyle) set formatting style} with either {@link\n * FormattingStyle#COMPACT} if the given indent string is empty, or {@link FormattingStyle#PRETTY}\n * with the given indent if not empty.\n *\n * @param indent a string containing only whitespace.\n */\n public final void setIndent(String indent) {\n if (indent.isEmpty()) {\n setFormattingStyle(FormattingStyle.COMPACT);\n } else {\n setFormattingStyle(FormattingStyle.PRETTY.withIndent(indent));\n }\n }\n\n /**\n * Sets the formatting style to be used in the encoded document.\n *\n *

The formatting style specifies for example the indentation string to be repeated for each\n * level of indentation, or the newline style, to accommodate various OS styles.\n *\n * @param formattingStyle the formatting style to use, must not be {@code null}.\n * @since $next-version$\n */\n public final void setFormattingStyle(FormattingStyle formattingStyle) {\n this.formattingStyle = Objects.requireNonNull(formattingStyle);\n\n this.formattedComma = \",\";\n if (this.formattingStyle.usesSpaceAfterSeparators()) {\n this.formattedColon = \": \";\n\n // Only add space if no newline is written\n if (this.formattingStyle.getNewline().isEmpty()) {\n this.formattedComma = \", \";\n }\n } else {\n this.formattedColon = \":\";\n }\n\n this.usesEmptyNewlineAndIndent =\n this.formattingStyle.getNewline().isEmpty() && this.formattingStyle.getIndent().isEmpty();\n }\n\n /**\n * Returns the pretty printing style used by this writer.\n *\n * @return the {@code FormattingStyle} that will be used.\n * @since $next-version$\n */\n public final FormattingStyle getFormattingStyle() {\n return formattingStyle;\n }\n\n /**\n * Sets the strictness of this writer.\n *\n * @deprecated Please use {@link #setStrictness(Strictness)} instead. {@code\n * JsonWriter.setLenient(true)} should be replaced by {@code\n * JsonWriter.setStrictness(Strictness.LENIENT)} and {@code JsonWriter.setLenient(false)}\n * should be replaced by {@code JsonWriter.setStrictness(Strictness.LEGACY_STRICT)}.
\n * However, if you used {@code setLenient(false)} before, you might prefer {@link\n * Strictness#STRICT} now instead.\n * @param lenient whether this writer should be lenient. If true, the strictness is set to {@link\n * Strictness#LENIENT}. If false, the strictness is set to {@link Strictness#LEGACY_STRICT}.\n * @see #setStrictness(Strictness)\n */\n @Deprecated\n // Don't specify @InlineMe, so caller with `setLenient(false)` becomes aware of new\n // Strictness.STRICT\n @SuppressWarnings(\"InlineMeSuggester\")\n public final void setLenient(boolean lenient) {\n setStrictness(lenient ? Strictness.LENIENT : Strictness.LEGACY_STRICT);\n }\n\n /**\n * Returns true if the {@link Strictness} of this writer is equal to {@link Strictness#LENIENT}.\n *\n * @see JsonWriter#setStrictness(Strictness)\n */\n public boolean isLenient() {\n return strictness == Strictness.LENIENT;\n }\n\n /**\n * Configures how strict this writer is with regard to the syntax rules specified in RFC 8259. By default, {@link\n * Strictness#LEGACY_STRICT} is used.\n *\n *

\n *
{@link Strictness#STRICT} & {@link Strictness#LEGACY_STRICT}\n *
The behavior of these is currently identical. In these strictness modes, the writer only\n * writes JSON in accordance with RFC 8259.\n *
{@link Strictness#LENIENT}\n *
This mode relaxes the behavior of the writer to allow the writing of {@link\n * Double#isNaN() NaNs} and {@link Double#isInfinite() infinities}. It also allows writing\n * multiple top level values.\n *
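To make the distinction above concrete, here is an illustrative sketch (not part of the quoted source): with the default strictness the writer rejects non-finite doubles, while `Strictness.LENIENT` permits them, matching the checks in `value(double)` shown further down in this excerpt.

```java
import com.google.gson.Strictness;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;
import java.io.StringWriter;

public class StrictnessSketch {
  public static void main(String[] args) throws IOException {
    // Default strictness (LEGACY_STRICT): non-finite numbers are rejected.
    try (JsonWriter strict = new JsonWriter(new StringWriter())) {
      strict.beginArray();
      try {
        strict.value(Double.NaN); // throws IllegalArgumentException
      } catch (IllegalArgumentException expected) {
        System.out.println("rejected: " + expected.getMessage());
      }
      strict.value(1.0);
      strict.endArray();
    }

    // LENIENT: NaN and infinities may be written (the result is not strict JSON).
    StringWriter out = new StringWriter();
    JsonWriter lenient = new JsonWriter(out);
    lenient.setStrictness(Strictness.LENIENT);
    lenient.beginArray();
    lenient.value(Double.NaN);
    lenient.value(Double.POSITIVE_INFINITY);
    lenient.endArray();
    lenient.close();
    System.out.println(out); // [NaN,Infinity]
  }
}
```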
\n *\n * @param strictness the new strictness of this writer. May not be {@code null}.\n * @since $next-version$\n */\n public final void setStrictness(Strictness strictness) {\n this.strictness = Objects.requireNonNull(strictness);\n }\n\n /**\n * Returns the {@linkplain Strictness strictness} of this writer.\n *\n * @see #setStrictness(Strictness)\n * @since $next-version$\n */\n public final Strictness getStrictness() {\n return strictness;\n }\n\n /**\n * Configures this writer to emit JSON that's safe for direct inclusion in HTML and XML documents.\n * This escapes the HTML characters {@code <}, {@code >}, {@code &}, {@code =} and {@code '}\n * before writing them to the stream. Without this setting, your XML/HTML encoder should replace\n * these characters with the corresponding escape sequences.\n */\n public final void setHtmlSafe(boolean htmlSafe) {\n this.htmlSafe = htmlSafe;\n }\n\n /**\n * Returns true if this writer writes JSON that's safe for inclusion in HTML and XML documents.\n */\n public final boolean isHtmlSafe() {\n return htmlSafe;\n }\n\n /**\n * Sets whether object members are serialized when their value is null. This has no impact on\n * array elements. The default is true.\n */\n public final void setSerializeNulls(boolean serializeNulls) {\n this.serializeNulls = serializeNulls;\n }\n\n /**\n * Returns true if object members are serialized when their value is null. This has no impact on\n * array elements. The default is true.\n */\n public final boolean getSerializeNulls() {\n return serializeNulls;\n }\n\n /**\n * Begins encoding a new array. Each call to this method must be paired with a call to {@link\n * #endArray}.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter beginArray() throws IOException {\n writeDeferredName();\n return openScope(EMPTY_ARRAY, '[');\n }\n\n /**\n * Ends encoding the current array.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter endArray() throws IOException {\n return closeScope(EMPTY_ARRAY, NONEMPTY_ARRAY, ']');\n }\n\n /**\n * Begins encoding a new object. Each call to this method must be paired with a call to {@link\n * #endObject}.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter beginObject() throws IOException {\n writeDeferredName();\n return openScope(EMPTY_OBJECT, '{');\n }\n\n /**\n * Ends encoding the current object.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter endObject() throws IOException {\n return closeScope(EMPTY_OBJECT, NONEMPTY_OBJECT, '}');\n }\n\n /** Enters a new scope by appending any necessary whitespace and the given bracket. */\n @CanIgnoreReturnValue\n private JsonWriter openScope(int empty, char openBracket) throws IOException {\n beforeValue();\n push(empty);\n out.write(openBracket);\n return this;\n }\n\n /** Closes the current scope by appending any necessary whitespace and the given bracket. 
*/\n @CanIgnoreReturnValue\n private JsonWriter closeScope(int empty, int nonempty, char closeBracket) throws IOException {\n int context = peek();\n if (context != nonempty && context != empty) {\n throw new IllegalStateException(\"Nesting problem.\");\n }\n if (deferredName != null) {\n throw new IllegalStateException(\"Dangling name: \" + deferredName);\n }\n\n stackSize--;\n if (context == nonempty) {\n newline();\n }\n out.write(closeBracket);\n return this;\n }\n\n private void push(int newTop) {\n if (stackSize == stack.length) {\n stack = Arrays.copyOf(stack, stackSize * 2);\n }\n stack[stackSize++] = newTop;\n }\n\n /** Returns the value on the top of the stack. */\n private int peek() {\n if (stackSize == 0) {\n throw new IllegalStateException(\"JsonWriter is closed.\");\n }\n return stack[stackSize - 1];\n }\n\n /** Replace the value on the top of the stack with the given value. */\n private void replaceTop(int topOfStack) {\n stack[stackSize - 1] = topOfStack;\n }\n\n /**\n * Encodes the property name.\n *\n * @param name the name of the forthcoming value. May not be {@code null}.\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter name(String name) throws IOException {\n Objects.requireNonNull(name, \"name == null\");\n if (deferredName != null) {\n throw new IllegalStateException(\"Already wrote a name, expecting a value.\");\n }\n int context = peek();\n if (context != EMPTY_OBJECT && context != NONEMPTY_OBJECT) {\n throw new IllegalStateException(\"Please begin an object before writing a name.\");\n }\n deferredName = name;\n return this;\n }\n\n private void writeDeferredName() throws IOException {\n if (deferredName != null) {\n beforeName();\n string(deferredName);\n deferredName = null;\n }\n }\n\n /**\n * Encodes {@code value}.\n *\n * @param value the literal string value, or null to encode a null literal.\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter value(String value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n writeDeferredName();\n beforeValue();\n string(value);\n return this;\n }\n\n /**\n * Encodes {@code value}.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter value(boolean value) throws IOException {\n writeDeferredName();\n beforeValue();\n out.write(value ? \"true\" : \"false\");\n return this;\n }\n\n /**\n * Encodes {@code value}.\n *\n * @return this writer.\n * @since 2.7\n */\n @CanIgnoreReturnValue\n public JsonWriter value(Boolean value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n writeDeferredName();\n beforeValue();\n out.write(value ? 
\"true\" : \"false\");\n return this;\n }\n\n /**\n * Encodes {@code value}.\n *\n * @param value a finite value, or if {@link #setStrictness(Strictness) lenient}, also {@link\n * Float#isNaN() NaN} or {@link Float#isInfinite() infinity}.\n * @return this writer.\n * @throws IllegalArgumentException if the value is NaN or Infinity and this writer is not {@link\n * #setStrictness(Strictness) lenient}.\n * @since 2.9.1\n */\n @CanIgnoreReturnValue\n public JsonWriter value(float value) throws IOException {\n writeDeferredName();\n if (strictness != Strictness.LENIENT && (Float.isNaN(value) || Float.isInfinite(value))) {\n throw new IllegalArgumentException(\"Numeric values must be finite, but was \" + value);\n }\n beforeValue();\n out.append(Float.toString(value));\n return this;\n }\n\n /**\n * Encodes {@code value}.\n *\n * @param value a finite value, or if {@link #setStrictness(Strictness) lenient}, also {@link\n * Double#isNaN() NaN} or {@link Double#isInfinite() infinity}.\n * @return this writer.\n * @throws IllegalArgumentException if the value is NaN or Infinity and this writer is not {@link\n * #setStrictness(Strictness) lenient}.\n */\n @CanIgnoreReturnValue\n public JsonWriter value(double value) throws IOException {\n writeDeferredName();\n if (strictness != Strictness.LENIENT && (Double.isNaN(value) || Double.isInfinite(value))) {\n throw new IllegalArgumentException(\"Numeric values must be finite, but was \" + value);\n }\n beforeValue();\n out.append(Double.toString(value));\n return this;\n }\n\n /**\n * Encodes {@code value}.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter value(long value) throws IOException {\n writeDeferredName();\n beforeValue();\n out.write(Long.toString(value));\n return this;\n }\n\n /**\n * Encodes {@code value}. The value is written by directly writing the {@link Number#toString()}\n * result to JSON. 
Implementations must make sure that the result represents a valid JSON number.\n *\n * @param value a finite value, or if {@link #setStrictness(Strictness) lenient}, also {@link\n * Double#isNaN() NaN} or {@link Double#isInfinite() infinity}.\n * @return this writer.\n * @throws IllegalArgumentException if the value is NaN or Infinity and this writer is not {@link\n * #setStrictness(Strictness) lenient}; or if the {@code toString()} result is not a valid\n * JSON number.\n */\n @CanIgnoreReturnValue\n public JsonWriter value(Number value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n\n writeDeferredName();\n String string = value.toString();\n if (string.equals(\"-Infinity\") || string.equals(\"Infinity\") || string.equals(\"NaN\")) {\n if (strictness != Strictness.LENIENT) {\n throw new IllegalArgumentException(\"Numeric values must be finite, but was \" + string);\n }\n } else {\n Class numberClass = value.getClass();\n // Validate that string is valid before writing it directly to JSON output\n if (!isTrustedNumberType(numberClass)\n && !VALID_JSON_NUMBER_PATTERN.matcher(string).matches()) {\n throw new IllegalArgumentException(\n \"String created by \" + numberClass + \" is not a valid JSON number: \" + string);\n }\n }\n\n beforeValue();\n out.append(string);\n return this;\n }\n\n /**\n * Encodes {@code null}.\n *\n * @return this writer.\n */\n @CanIgnoreReturnValue\n public JsonWriter nullValue() throws IOException {\n if (deferredName != null) {\n if (serializeNulls) {\n writeDeferredName();\n } else {\n deferredName = null;\n return this; // skip the name and the value\n }\n }\n beforeValue();\n out.write(\"null\");\n return this;\n }\n\n /**\n * Writes {@code value} directly to the writer without quoting or escaping. 
This might not be\n * supported by all implementations, if not supported an {@code UnsupportedOperationException} is\n * thrown.\n *\n * @param value the literal string value, or null to encode a null literal.\n * @return this writer.\n * @throws UnsupportedOperationException if this writer does not support writing raw JSON values.\n * @since 2.4\n */\n @CanIgnoreReturnValue\n public JsonWriter jsonValue(String value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n writeDeferredName();\n beforeValue();\n out.append(value);\n return this;\n }\n\n /**\n * Ensures all buffered data is written to the underlying {@link Writer} and flushes that writer.\n */\n @Override\n public void flush() throws IOException {\n if (stackSize == 0) {\n throw new IllegalStateException(\"JsonWriter is closed.\");\n }\n out.flush();\n }\n\n /**\n * Flushes and closes this writer and the underlying {@link Writer}.\n *\n * @throws IOException if the JSON document is incomplete.\n */\n @Override\n public void close() throws IOException {\n out.close();\n\n int size = stackSize;\n if (size > 1 || (size == 1 && stack[size - 1] != NONEMPTY_DOCUMENT)) {\n throw new IOException(\"Incomplete document\");\n }\n stackSize = 0;\n }\n\n /**\n * Returns whether the {@code toString()} of {@code c} can be trusted to return a valid JSON\n * number.\n */\n private static boolean isTrustedNumberType(Class c) {\n // Note: Don't consider LazilyParsedNumber trusted because it could contain\n // an arbitrary malformed string\n return c == Integer.class\n || c == Long.class\n || c == Double.class\n || c == Float.class\n || c == Byte.class\n || c == Short.class\n || c == BigDecimal.class\n || c == BigInteger.class\n || c == AtomicInteger.class\n || c == AtomicLong.class;\n }\n\n private void string(String value) throws IOException {\n String[] replacements = htmlSafe ? HTML_SAFE_REPLACEMENT_CHARS : REPLACEMENT_CHARS;\n out.write('\\\"');\n int last = 0;\n int length = value.length();\n for (int i = 0; i < length; i++) {\n char c = value.charAt(i);\n String replacement;\n if (c < 128) {\n replacement = replacements[c];\n if (replacement == null) {\n continue;\n }\n } else if (c == '\\u2028') {\n replacement = \"\\\\u2028\";\n } else if (c == '\\u2029') {\n replacement = \"\\\\u2029\";\n } else {\n continue;\n }\n if (last < i) {\n out.write(value, last, i - last);\n }\n out.write(replacement);\n last = i + 1;\n }\n if (last < length) {\n out.write(value, last, length - last);\n }\n out.write('\\\"');\n }\n\n private void newline() throws IOException {\n if (usesEmptyNewlineAndIndent) {\n return;\n }\n\n out.write(formattingStyle.getNewline());\n for (int i = 1, size = stackSize; i < size; i++) {\n out.write(formattingStyle.getIndent());\n }\n }\n\n /**\n * Inserts any necessary separators and whitespace before a name. Also adjusts the stack to expect\n * the name's value.\n */\n private void beforeName() throws IOException {\n int context = peek();\n if (context == NONEMPTY_OBJECT) { // first in object\n out.write(formattedComma);\n } else if (context != EMPTY_OBJECT) { // not in an object!\n throw new IllegalStateException(\"Nesting problem.\");\n }\n newline();\n replaceTop(DANGLING_NAME);\n }\n\n /**\n * Inserts any necessary separators and whitespace before a literal value, inline array, or inline\n * object. 
Also adjusts the stack to expect either a closing bracket or another element.\n */\n @SuppressWarnings(\"fallthrough\")\n private void beforeValue() throws IOException {\n switch (peek()) {\n case NONEMPTY_DOCUMENT:\n if (strictness != Strictness.LENIENT) {\n throw new IllegalStateException(\"JSON must have only one top-level value.\");\n }\n // fall-through\n case EMPTY_DOCUMENT: // first in document\n replaceTop(NONEMPTY_DOCUMENT);\n break;\n\n case EMPTY_ARRAY: // first in array\n replaceTop(NONEMPTY_ARRAY);\n newline();\n break;\n\n case NONEMPTY_ARRAY: // another in array\n out.append(formattedComma);\n newline();\n break;\n\n case DANGLING_NAME: // value for name\n out.append(formattedColon);\n replaceTop(NONEMPTY_OBJECT);\n break;\n\n default:\n throw new IllegalStateException(\"Nesting problem.\");\n }\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/FormattingStyle.java\n/*\n * Copyright (C) 2022 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.gson.stream.JsonWriter;\nimport java.util.Objects;\n\n/**\n * A class used to control what the serialization output looks like.\n *\n *

It currently has the following configuration methods, but more methods might be added in the\n * future:\n *\n *

    \n *
  • {@link #withNewline(String)}\n *
  • {@link #withIndent(String)}\n *
  • {@link #withSpaceAfterSeparators(boolean)}\n *
\n *\n * @see GsonBuilder#setFormattingStyle(FormattingStyle)\n * @see JsonWriter#setFormattingStyle(FormattingStyle)\n * @see Wikipedia Newline article\n * @since $next-version$\n */\npublic class FormattingStyle {\n private final String newline;\n private final String indent;\n private final boolean spaceAfterSeparators;\n\n /**\n * The default compact formatting style:\n *\n *
    \n *
  • no newline\n *
  • no indent\n *
  • no space after {@code ','} and {@code ':'}\n *
\n */\n public static final FormattingStyle COMPACT = new FormattingStyle(\"\", \"\", false);\n\n /**\n * The default pretty printing formatting style:\n *\n *
    \n *
  • {@code \"\\n\"} as newline\n *
  • two spaces as indent\n *
  • a space between {@code ':'} and the subsequent value\n *
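An illustrative sketch (not part of the quoted source) contrasting the two presets described above and deriving a customized style with the `with*` methods listed earlier:

```java
import com.google.gson.FormattingStyle;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;
import java.io.StringWriter;

public class FormattingStyleSketch {
  static String write(FormattingStyle style) throws IOException {
    StringWriter out = new StringWriter();
    JsonWriter writer = new JsonWriter(out);
    writer.setFormattingStyle(style);
    writer.beginObject();
    writer.name("id").value(1);
    writer.name("name").value("demo");
    writer.endObject();
    writer.close();
    return out.toString();
  }

  public static void main(String[] args) throws IOException {
    System.out.println(write(FormattingStyle.COMPACT)); // {"id":1,"name":"demo"}
    System.out.println(write(FormattingStyle.PRETTY));  // newline + two-space indent

    // Derived style: Windows newlines and tab indentation.
    FormattingStyle custom = FormattingStyle.PRETTY
        .withNewline("\r\n")
        .withIndent("\t");
    System.out.println(write(custom));
  }
}
```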
\n */\n public static final FormattingStyle PRETTY = new FormattingStyle(\"\\n\", \" \", true);\n\n private FormattingStyle(String newline, String indent, boolean spaceAfterSeparators) {\n Objects.requireNonNull(newline, \"newline == null\");\n Objects.requireNonNull(indent, \"indent == null\");\n if (!newline.matches(\"[\\r\\n]*\")) {\n throw new IllegalArgumentException(\n \"Only combinations of \\\\n and \\\\r are allowed in newline.\");\n }\n if (!indent.matches(\"[ \\t]*\")) {\n throw new IllegalArgumentException(\n \"Only combinations of spaces and tabs are allowed in indent.\");\n }\n this.newline = newline;\n this.indent = indent;\n this.spaceAfterSeparators = spaceAfterSeparators;\n }\n\n /**\n * Creates a {@link FormattingStyle} with the specified newline setting.\n *\n *

It can be used to accommodate certain OS convention, for example hardcode {@code \"\\n\"} for\n * Linux and macOS, {@code \"\\r\\n\"} for Windows, or call {@link java.lang.System#lineSeparator()}\n * to match the current OS.\n *\n *

Only combinations of {@code \\n} and {@code \\r} are allowed.\n *\n * @param newline the string value that will be used as newline.\n * @return a newly created {@link FormattingStyle}\n */\n public FormattingStyle withNewline(String newline) {\n return new FormattingStyle(newline, this.indent, this.spaceAfterSeparators);\n }\n\n /**\n * Creates a {@link FormattingStyle} with the specified indent string.\n *\n *

Only combinations of spaces and tabs allowed in indent.\n *\n * @param indent the string value that will be used as indent.\n * @return a newly created {@link FormattingStyle}\n */\n public FormattingStyle withIndent(String indent) {\n return new FormattingStyle(this.newline, indent, this.spaceAfterSeparators);\n }\n\n /**\n * Creates a {@link FormattingStyle} which either uses a space after the separators {@code ','}\n * and {@code ':'} in the JSON output, or not.\n *\n *

This setting has no effect on the {@linkplain #withNewline(String) configured newline}. If a\n * non-empty newline is configured, it will always be added after {@code ','} and no space is\n * added after the {@code ','} in that case.\n *\n * @param spaceAfterSeparators whether to output a space after {@code ','} and {@code ':'}.\n * @return a newly created {@link FormattingStyle}\n */\n public FormattingStyle withSpaceAfterSeparators(boolean spaceAfterSeparators) {\n return new FormattingStyle(this.newline, this.indent, spaceAfterSeparators);\n }\n\n /**\n * Returns the string value that will be used as a newline.\n *\n * @return the newline value.\n */\n public String getNewline() {\n return this.newline;\n }\n\n /**\n * Returns the string value that will be used as indent.\n *\n * @return the indent value.\n */\n public String getIndent() {\n return this.indent;\n }\n\n /** Returns whether a space will be used after {@code ','} and {@code ':'}. */\n public boolean usesSpaceAfterSeparators() {\n return this.spaceAfterSeparators;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/annotations/JsonAdapter.java\n/*\n * Copyright (C) 2014 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.annotations;\n\nimport com.google.gson.Gson;\nimport com.google.gson.GsonBuilder;\nimport com.google.gson.InstanceCreator;\nimport com.google.gson.JsonDeserializer;\nimport com.google.gson.JsonSerializer;\nimport com.google.gson.TypeAdapter;\nimport com.google.gson.TypeAdapterFactory;\nimport java.lang.annotation.ElementType;\nimport java.lang.annotation.Retention;\nimport java.lang.annotation.RetentionPolicy;\nimport java.lang.annotation.Target;\n\n/**\n * An annotation that indicates the Gson {@link TypeAdapter} to use with a class or field.\n *\n *

Here is an example of how this annotation is used:\n *\n *

\n * @JsonAdapter(UserJsonAdapter.class)\n * public class User {\n *   public final String firstName, lastName;\n *\n *   private User(String firstName, String lastName) {\n *     this.firstName = firstName;\n *     this.lastName = lastName;\n *   }\n * }\n *\n * public class UserJsonAdapter extends TypeAdapter<User> {\n *   @Override public void write(JsonWriter out, User user) throws IOException {\n *     // implement write: combine firstName and lastName into name\n *     out.beginObject();\n *     out.name(\"name\");\n *     out.value(user.firstName + \" \" + user.lastName);\n *     out.endObject();\n *   }\n *\n *   @Override public User read(JsonReader in) throws IOException {\n *     // implement read: split name into firstName and lastName\n *     in.beginObject();\n *     in.nextName();\n *     String[] nameParts = in.nextString().split(\" \");\n *     in.endObject();\n *     return new User(nameParts[0], nameParts[1]);\n *   }\n * }\n * 
\n *\n * Since {@code User} class specified {@code UserJsonAdapter.class} in {@code @JsonAdapter}\n * annotation, it will automatically be invoked to serialize/deserialize {@code User} instances.\n *\n *

Here is an example of how to apply this annotation to a field.\n *\n *

\n * private static final class Gadget {\n *   @JsonAdapter(UserJsonAdapter.class)\n *   final User user;\n *\n *   Gadget(User user) {\n *     this.user = user;\n *   }\n * }\n * 
\n *\n * It's possible to specify different type adapters on a field, that field's type, and in the {@link\n * GsonBuilder}. Field annotations take precedence over {@code GsonBuilder}-registered type\n * adapters, which in turn take precedence over annotated types.\n *\n *

The class referenced by this annotation must be either a {@link TypeAdapter} or a {@link\n * TypeAdapterFactory}, or must implement one or both of {@link JsonDeserializer} or {@link\n * JsonSerializer}. Using {@link TypeAdapterFactory} makes it possible to delegate to the enclosing\n * {@link Gson} instance. By default the specified adapter will not be called for {@code null}\n * values; set {@link #nullSafe()} to {@code false} to let the adapter handle {@code null} values\n * itself.\n *\n *
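An illustrative sketch (class and field names invented for the example) of referencing a `JsonSerializer` from `@JsonAdapter` and opting out of null-safety as described above:

```java
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.google.gson.annotations.JsonAdapter;
import java.lang.reflect.Type;

public class JsonAdapterSketch {
  // Referenced class implements JsonSerializer; with nullSafe = false it is
  // also invoked for null field values and decides itself what to emit.
  static class PriceSerializer implements JsonSerializer<Double> {
    @Override
    public JsonElement serialize(Double src, Type typeOfSrc, JsonSerializationContext context) {
      return new JsonPrimitive(src == null ? "n/a" : src + " EUR");
    }
  }

  static class Order {
    @JsonAdapter(value = PriceSerializer.class, nullSafe = false)
    Double total = 19.5;
  }

  public static void main(String[] args) {
    System.out.println(new Gson().toJson(new Order())); // {"total":"19.5 EUR"}
  }
}
```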

The type adapter is created in the same way Gson creates instances of custom classes during\n * deserialization, that means:\n *\n *

    \n *
  1. If a custom {@link InstanceCreator} has been registered for the adapter class, it will be\n * used to create the instance\n *
  2. Otherwise, if the adapter class has a no-args constructor (regardless of which visibility),\n * it will be invoked to create the instance\n *
  3. Otherwise, JDK {@code Unsafe} will be used to create the instance; see {@link\n * GsonBuilder#disableJdkUnsafe()} for the unexpected side-effects this might have\n *
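An illustrative sketch of step 1 above (names invented for the example): an adapter class without a no-args constructor is still usable with `@JsonAdapter` once an `InstanceCreator` for it has been registered on the `GsonBuilder`:

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.InstanceCreator;
import com.google.gson.TypeAdapter;
import com.google.gson.annotations.JsonAdapter;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;

public class AdapterInstanceCreatorSketch {
  // No no-args constructor, so steps 2 and 3 would not produce a useful instance.
  static class PrefixedStringAdapter extends TypeAdapter<String> {
    private final String prefix;
    PrefixedStringAdapter(String prefix) { this.prefix = prefix; }

    @Override public void write(JsonWriter out, String value) throws IOException {
      out.value(prefix + value);
    }
    @Override public String read(JsonReader in) throws IOException {
      return in.nextString().substring(prefix.length());
    }
  }

  static class Holder {
    @JsonAdapter(PrefixedStringAdapter.class)
    String name = "gson";
  }

  public static void main(String[] args) {
    // Step 1: the registered InstanceCreator supplies the adapter instance.
    InstanceCreator<PrefixedStringAdapter> creator = type -> new PrefixedStringAdapter("demo:");
    Gson gson = new GsonBuilder()
        .registerTypeAdapter(PrefixedStringAdapter.class, creator)
        .create();
    System.out.println(gson.toJson(new Holder())); // {"name":"demo:gson"}
  }
}
```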
\n *\n *

{@code Gson} instances might cache the adapter they create for a {@code @JsonAdapter}\n * annotation. It is not guaranteed that a new adapter is created every time the annotated class or\n * field is serialized or deserialized.\n *\n * @since 2.3\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @author Jesse Wilson\n */\n// Note that the above example is taken from AdaptAnnotationTest.\n@Retention(RetentionPolicy.RUNTIME)\n@Target({ElementType.TYPE, ElementType.FIELD})\npublic @interface JsonAdapter {\n\n /**\n * Either a {@link TypeAdapter} or {@link TypeAdapterFactory}, or one or both of {@link\n * JsonDeserializer} or {@link JsonSerializer}.\n */\n Class value();\n\n /**\n * Whether the adapter referenced by {@link #value()} should be made {@linkplain\n * TypeAdapter#nullSafe() null-safe}.\n *\n *

If {@code true} (the default), it will be made null-safe and Gson will handle {@code null}\n * Java objects on serialization and JSON {@code null} on deserialization without calling the\n * adapter. If {@code false}, the adapter will have to handle the {@code null} values.\n */\n boolean nullSafe() default true;\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/ReflectionAccessFilterHelper.java\n/*\n * Copyright (C) 2022 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport com.google.gson.ReflectionAccessFilter;\nimport com.google.gson.ReflectionAccessFilter.FilterResult;\nimport java.lang.reflect.AccessibleObject;\nimport java.lang.reflect.Method;\nimport java.util.List;\n\n/** Internal helper class for {@link ReflectionAccessFilter}. */\npublic class ReflectionAccessFilterHelper {\n private ReflectionAccessFilterHelper() {}\n\n // Platform type detection is based on Moshi's Util.isPlatformType(Class)\n // See\n // https://github.com/square/moshi/blob/3c108919ee1cce88a433ffda04eeeddc0341eae7/moshi/src/main/java/com/squareup/moshi/internal/Util.java#L141\n\n public static boolean isJavaType(Class c) {\n return isJavaType(c.getName());\n }\n\n private static boolean isJavaType(String className) {\n return className.startsWith(\"java.\") || className.startsWith(\"javax.\");\n }\n\n public static boolean isAndroidType(Class c) {\n return isAndroidType(c.getName());\n }\n\n private static boolean isAndroidType(String className) {\n return className.startsWith(\"android.\")\n || className.startsWith(\"androidx.\")\n || isJavaType(className);\n }\n\n public static boolean isAnyPlatformType(Class c) {\n String className = c.getName();\n return isAndroidType(className) // Covers Android and Java\n || className.startsWith(\"kotlin.\")\n || className.startsWith(\"kotlinx.\")\n || className.startsWith(\"scala.\");\n }\n\n /**\n * Gets the result of applying all filters until the first one returns a result other than {@link\n * FilterResult#INDECISIVE}, or {@link FilterResult#ALLOW} if the list of filters is empty or all\n * returned {@code INDECISIVE}.\n */\n public static FilterResult getFilterResult(\n List reflectionFilters, Class c) {\n for (ReflectionAccessFilter filter : reflectionFilters) {\n FilterResult result = filter.check(c);\n if (result != FilterResult.INDECISIVE) {\n return result;\n }\n }\n return FilterResult.ALLOW;\n }\n\n /** See {@link AccessibleObject#canAccess(Object)} (Java >= 9) */\n public static boolean canAccess(AccessibleObject accessibleObject, Object object) {\n return AccessChecker.INSTANCE.canAccess(accessibleObject, object);\n }\n\n private abstract static class AccessChecker {\n public static final AccessChecker INSTANCE;\n\n static {\n AccessChecker accessChecker = null;\n // TODO: Ideally should use Multi-Release JAR for this version specific code\n if (JavaVersion.isJava9OrLater()) {\n try {\n final Method canAccessMethod =\n AccessibleObject.class.getDeclaredMethod(\"canAccess\", 
Object.class);\n accessChecker =\n new AccessChecker() {\n @Override\n public boolean canAccess(AccessibleObject accessibleObject, Object object) {\n try {\n return (Boolean) canAccessMethod.invoke(accessibleObject, object);\n } catch (Exception e) {\n throw new RuntimeException(\"Failed invoking canAccess\", e);\n }\n }\n };\n } catch (NoSuchMethodException ignored) {\n // OK: will assume everything is accessible\n }\n }\n\n if (accessChecker == null) {\n accessChecker =\n new AccessChecker() {\n @Override\n public boolean canAccess(AccessibleObject accessibleObject, Object object) {\n // Cannot determine whether object can be accessed, so assume it can be accessed\n return true;\n }\n };\n }\n INSTANCE = accessChecker;\n }\n\n public abstract boolean canAccess(AccessibleObject accessibleObject, Object object);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/ReflectionAccessFilter.java\n/*\n * Copyright (C) 2022 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.gson.internal.ReflectionAccessFilterHelper;\nimport java.lang.reflect.AccessibleObject;\n\n/**\n * Filter for determining whether reflection based serialization and deserialization is allowed for\n * a class.\n *\n *

A filter can be useful in multiple scenarios, for example when upgrading to newer Java\n * versions which use the Java Platform Module System (JPMS). A filter then allows to {@linkplain\n * FilterResult#BLOCK_INACCESSIBLE prevent making inaccessible members accessible}, even if the used\n * Java version might still allow illegal access (but logs a warning), or if {@code java} command\n * line arguments are used to open the inaccessible packages to other parts of the application. This\n * interface defines some convenience filters for this task, such as {@link\n * #BLOCK_INACCESSIBLE_JAVA}.\n *\n *

A filter can also be useful to prevent mixing model classes of a project with other non-model\n * classes; the filter could {@linkplain FilterResult#BLOCK_ALL block all reflective access} to\n * non-model classes.\n *\n *
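An illustrative sketch of both scenarios above (the model package name is invented for the example): registering the built-in `BLOCK_INACCESSIBLE_JAVA` convenience filter plus a custom filter that blocks reflective access outside the application's model classes:

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.ReflectionAccessFilter;
import com.google.gson.ReflectionAccessFilter.FilterResult;

public class AccessFilterSketch {
  // Stands in for "some class that is not part of the application's model".
  static class NotAModelClass {
    int x = 1;
  }

  public static void main(String[] args) {
    Gson gson = new GsonBuilder()
        // Built-in filter: fail instead of forcing open inaccessible members
        // of java.* / javax.* classes.
        .addReflectionAccessFilter(ReflectionAccessFilter.BLOCK_INACCESSIBLE_JAVA)
        // Custom filter: reflection is only allowed for the model package.
        .addReflectionAccessFilter(rawClass ->
            rawClass.getName().startsWith("com.example.model.")
                ? FilterResult.ALLOW
                : FilterResult.BLOCK_ALL)
        .create();

    try {
      gson.toJson(new NotAModelClass());
    } catch (RuntimeException blocked) {
      // Reflective access was blocked; a TypeAdapter would have to be
      // registered for NotAModelClass instead.
      System.out.println("blocked: " + blocked.getMessage());
    }
  }
}
```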

A reflection access filter is similar to an {@link ExclusionStrategy} with the major\n * difference that a filter will cause an exception to be thrown when access is disallowed while an\n * exclusion strategy just skips fields and classes.\n *\n * @see GsonBuilder#addReflectionAccessFilter(ReflectionAccessFilter)\n * @since 2.9.1\n */\npublic interface ReflectionAccessFilter {\n /**\n * Result of a filter check.\n *\n * @since 2.9.1\n */\n enum FilterResult {\n /**\n * Reflection access for the class is allowed.\n *\n *

Note that this does not affect the Java access checks in any way, it only permits Gson to\n * try using reflection for a class. The Java runtime might still deny such access.\n */\n ALLOW,\n /**\n * The filter is indecisive whether reflection access should be allowed. The next registered\n * filter will be consulted to get the result. If there is no next filter, this result acts like\n * {@link #ALLOW}.\n */\n INDECISIVE,\n /**\n * Blocks reflection access if a member of the class is not accessible by default and would have\n * to be made accessible. This is unaffected by any {@code java} command line arguments being\n * used to make packages accessible, or by module declaration directives which open the\n * complete module or certain packages for reflection and will consider such packages\n * inaccessible.\n *\n *

Note that this only works for Java 9 and higher, for older Java versions its\n * functionality will be limited and it might behave like {@link #ALLOW}. Access checks are only\n * performed as defined by the Java Language Specification (JLS 11\n * §6.6), restrictions imposed by a {@link SecurityManager} are not considered.\n *\n *

This result type is mainly intended to help enforce the access checks of the Java Platform\n * Module System. It allows detecting illegal access, even if the used Java version would only\n * log a warning, or is configured to open packages for reflection using command line arguments.\n *\n * @see AccessibleObject#canAccess(Object)\n */\n BLOCK_INACCESSIBLE,\n /**\n * Blocks all reflection access for the class. Other means for serializing and deserializing the\n * class, such as a {@link TypeAdapter}, have to be used.\n */\n BLOCK_ALL\n }\n\n /**\n * Blocks all reflection access to members of standard Java classes which are not accessible by\n * default. However, reflection access is still allowed for classes for which all fields are\n * accessible and which have an accessible no-args constructor (or for which an {@link\n * InstanceCreator} has been registered).\n *\n *

If this filter encounters a class other than a standard Java class it returns {@link\n * FilterResult#INDECISIVE}.\n *\n *

This filter is mainly intended to help enforcing the access checks of Java Platform Module\n * System. It allows detecting illegal access, even if the used Java version would only log a\n * warning, or is configured to open packages for reflection. However, this filter only works\n * for Java 9 and higher, when using an older Java version its functionality will be limited.\n *\n *

Note that this filter might not cover all standard Java classes. Currently only classes in a\n * {@code java.*} or {@code javax.*} package are considered. The set of detected classes might be\n * expanded in the future without prior notice.\n *\n * @see FilterResult#BLOCK_INACCESSIBLE\n */\n ReflectionAccessFilter BLOCK_INACCESSIBLE_JAVA =\n new ReflectionAccessFilter() {\n @Override\n public FilterResult check(Class rawClass) {\n return ReflectionAccessFilterHelper.isJavaType(rawClass)\n ? FilterResult.BLOCK_INACCESSIBLE\n : FilterResult.INDECISIVE;\n }\n\n @Override\n public String toString() {\n return \"ReflectionAccessFilter#BLOCK_INACCESSIBLE_JAVA\";\n }\n };\n\n /**\n * Blocks all reflection access to members of standard Java classes.\n *\n *

If this filter encounters a class other than a standard Java class it returns {@link\n * FilterResult#INDECISIVE}.\n *\n *

This filter is mainly intended to prevent depending on implementation details of the Java\n * platform and to help applications prepare for upgrading to the Java Platform Module System.\n *\n *

Note that this filter might not cover all standard Java classes. Currently only classes in a\n * {@code java.*} or {@code javax.*} package are considered. The set of detected classes might be\n * expanded in the future without prior notice.\n *\n * @see #BLOCK_INACCESSIBLE_JAVA\n * @see FilterResult#BLOCK_ALL\n */\n ReflectionAccessFilter BLOCK_ALL_JAVA =\n new ReflectionAccessFilter() {\n @Override\n...\n", "needle_token_start": 598, "needle_token_end": 986, "code_context_ntokens": 15904, "output": ["```java\nprivate static void addTypeAdaptersForDate(\n String datePattern, int dateStyle, int timeStyle, List factories) {\n TypeAdapterFactory dateAdapterFactory;\n boolean sqlTypesSupported = SqlTypesSupport.SUPPORTS_SQL_TYPES;\n TypeAdapterFactory sqlTimestampAdapterFactory = null;\n TypeAdapterFactory sqlDateAdapterFactory = null;\n\n if (datePattern != null && !datePattern.trim().isEmpty()) {\n dateAdapterFactory = DefaultDateTypeAdapter.DateType.DATE.createAdapterFactory(datePattern);\n\n if (sqlTypesSupported) {\n sqlTimestampAdapterFactory =\n SqlTypesSupport.TIMESTAMP_DATE_TYPE.createAdapterFactory(datePattern);\n sqlDateAdapterFactory = SqlTypesSupport.DATE_DATE_TYPE.createAdapterFactory(datePattern);\n }\n } else if (dateStyle != DateFormat.DEFAULT || timeStyle != DateFormat.DEFAULT) {\n dateAdapterFactory =\n DefaultDateTypeAdapter.DateType.DATE.createAdapterFactory(dateStyle, timeStyle);\n\n if (sqlTypesSupported) {\n sqlTimestampAdapterFactory =\n SqlTypesSupport.TIMESTAMP_DATE_TYPE.createAdapterFactory(dateStyle, timeStyle);\n sqlDateAdapterFactory =\n SqlTypesSupport.DATE_DATE_TYPE.createAdapterFactory(dateStyle, timeStyle);\n }\n } else {\n return;\n }\n\n factories.add(dateAdapterFactory);\n if (sqlTypesSupported) {\n factories.add(sqlTimestampAdapterFactory);\n factories.add(sqlDateAdapterFactory);\n }\n }\n```"]} +{"repo": "google/gson", "name": "parseString", "language": "java", "path": "gson/src/main/java/com/google/gson/JsonParser.java", "position_ratio": 0.15, "description": "\nFunction Description:\n1. **Purpose**: To convert a JSON formatted string into a structured parse tree of elements that can be easily manipulated within a program.\n2. **Input**: A string formatted in JSON.\n3. **Output**: A structured parse tree of elements that represent the JSON data.\n4. **Procedure**: The function reads the input string using a reader tailored for JSON, processes it in a lenient mode allowing some flexibility in JSON syntax, and constructs a parse tree that represents all the elements and structures in the JSON string. 
If the input string is not valid JSON, it throws an exception indicating the syntax error.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: gson/src/main/java/com/google/gson/ToNumberPolicy.java\n/*\n * Copyright (C) 2021 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.gson.internal.LazilyParsedNumber;\nimport com.google.gson.internal.NumberLimits;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.MalformedJsonException;\n...\n// Path: gson/src/main/java/com/google/gson/JsonStreamParser.java\n/*\n * Copyright (C) 2009 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson;\n\nimport com.google.gson.internal.Streams;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.MalformedJsonException;\nimport java.io.IOException;\nimport java.io.Reader;\nimport java.io.StringReader;\nimport java.util.Iterator;\nimport java.util.NoSuchElementException;\n\n/**\n * A streaming parser that allows reading of multiple {@link JsonElement}s from the specified reader\n * asynchronously. The JSON data is parsed in lenient mode, see also {@link\n * JsonReader#setStrictness(Strictness)}.\n *\n *

This class is conditionally thread-safe (see Item 70, Effective Java second edition). To\n * properly use this class across multiple threads, you will need to add some external\n * synchronization. For example:\n *\n *

\n * JsonStreamParser parser = new JsonStreamParser(\"['first'] {'second':10} 'third'\");\n * JsonElement element;\n * synchronized (parser) {  // synchronize on an object shared by threads\n *   if (parser.hasNext()) {\n *     element = parser.next();\n *   }\n * }\n * 
\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @since 1.4\n */\npublic final class JsonStreamParser implements Iterator {\n private final JsonReader parser;\n private final Object lock;\n\n /**\n * @param json The string containing JSON elements concatenated to each other.\n * @since 1.4\n */\n public JsonStreamParser(String json) {\n this(new StringReader(json));\n }\n\n /**\n * @param reader The data stream containing JSON elements concatenated to each other.\n * @since 1.4\n */\n public JsonStreamParser(Reader reader) {\n parser = new JsonReader(reader);\n parser.setStrictness(Strictness.LENIENT);\n lock = new Object();\n }\n\n /**\n * Returns the next available {@link JsonElement} on the reader. Throws a {@link\n * NoSuchElementException} if no element is available.\n *\n * @return the next available {@code JsonElement} on the reader.\n * @throws JsonParseException if the incoming stream is malformed JSON.\n * @throws NoSuchElementException if no {@code JsonElement} is available.\n * @since 1.4\n */\n @Override\n public JsonElement next() throws JsonParseException {\n if (!hasNext()) {\n throw new NoSuchElementException();\n }\n\n try {\n return Streams.parse(parser);\n } catch (StackOverflowError e) {\n throw new JsonParseException(\"Failed parsing JSON source to Json\", e);\n } catch (OutOfMemoryError e) {\n throw new JsonParseException(\"Failed parsing JSON source to Json\", e);\n }\n }\n\n /**\n * Returns true if a {@link JsonElement} is available on the input for consumption\n *\n * @return true if a {@link JsonElement} is available on the input, false otherwise\n * @throws JsonParseException if the incoming stream is malformed JSON.\n * @since 1.4\n */\n @Override\n public boolean hasNext() {\n synchronized (lock) {\n try {\n return parser.peek() != JsonToken.END_DOCUMENT;\n } catch (MalformedJsonException e) {\n throw new JsonSyntaxException(e);\n } catch (IOException e) {\n throw new JsonIOException(e);\n }\n }\n }\n\n /**\n * This optional {@link Iterator} method is not relevant for stream parsing and hence is not\n * implemented.\n *\n * @since 1.4\n */\n @Override\n public void remove() {\n throw new UnsupportedOperationException();\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/package-info.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/**\n * This package provides the {@link com.google.gson.Gson} class to convert Json to Java and\n * vice-versa.\n *\n *
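                                  A small, single-threaded sketch of the JsonStreamParser API shown above; no synchronization is needed when the parser is not shared across threads (the demo class name and input string are illustrative):

                                  ```java
                                  import com.google.gson.JsonElement;
                                  import com.google.gson.JsonStreamParser;

                                  public class StreamParserDemo {
                                    public static void main(String[] args) {
                                      // Several top-level JSON values concatenated in one lenient stream
                                      JsonStreamParser parser = new JsonStreamParser("{\"a\": 1} [2, 3] \"four\"");
                                      while (parser.hasNext()) {
                                        JsonElement element = parser.next();
                                        System.out.println(element);
                                      }
                                    }
                                  }
                                  ```
                                  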

The primary class to use is {@link com.google.gson.Gson} which can be constructed with {@code\n * new Gson()} (using default settings) or by using {@link com.google.gson.GsonBuilder} (to\n * configure various options such as using versioning and so on).\n *\n * @author Inderjeet Singh, Joel Leitch\n */\npackage com.google.gson;\n\n// Path: gson/src/main/java/com/google/gson/JsonParser.java\n/*\n * Copyright (C) 2009 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson;\n\nimport com.google.errorprone.annotations.InlineMe;\nimport com.google.gson.internal.Streams;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.MalformedJsonException;\nimport java.io.IOException;\nimport java.io.Reader;\nimport java.io.StringReader;\n\n/**\n * A parser to parse JSON into a parse tree of {@link JsonElement}s.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @since 1.3\n */\npublic final class JsonParser {\n /**\n * @deprecated No need to instantiate this class, use the static methods instead.\n */\n @Deprecated\n public JsonParser() {}\n\n /**\n * Parses the specified JSON string into a parse tree. An exception is thrown if the JSON string\n * has multiple top-level JSON elements, or if there is trailing data.\n *\n *

The JSON string is parsed in {@linkplain JsonReader#setStrictness(Strictness) lenient mode}.\n *\n * @param json JSON text\n * @return a parse tree of {@link JsonElement}s corresponding to the specified JSON\n * @throws JsonParseException if the specified text is not valid JSON\n * @since 2.8.6\n */\n \npublic static JsonElement parseString(String json) throws JsonSyntaxException {\n return parseReader(new StringReader(json));\n }\n\n /**\n * Parses the complete JSON string provided by the reader into a parse tree. An exception is\n * thrown if the JSON string has multiple top-level JSON elements, or if there is trailing data.\n *\n *

The JSON data is parsed in {@linkplain JsonReader#setStrictness(Strictness) lenient mode}.\n *\n * @param reader JSON text\n * @return a parse tree of {@link JsonElement}s corresponding to the specified JSON\n * @throws JsonParseException if there is an IOException or if the specified text is not valid\n * JSON\n * @since 2.8.6\n */\n public static JsonElement parseReader(Reader reader) throws JsonIOException, JsonSyntaxException {\n try {\n JsonReader jsonReader = new JsonReader(reader);\n JsonElement element = parseReader(jsonReader);\n if (!element.isJsonNull() && jsonReader.peek() != JsonToken.END_DOCUMENT) {\n throw new JsonSyntaxException(\"Did not consume the entire document.\");\n }\n return element;\n } catch (MalformedJsonException e) {\n throw new JsonSyntaxException(e);\n } catch (IOException e) {\n throw new JsonIOException(e);\n } catch (NumberFormatException e) {\n throw new JsonSyntaxException(e);\n }\n }\n\n /**\n * Returns the next value from the JSON stream as a parse tree. Unlike the other {@code parse}\n * methods, no exception is thrown if the JSON data has multiple top-level JSON elements, or if\n * there is trailing data.\n *\n *

The JSON data is parsed in {@linkplain JsonReader#setStrictness(Strictness) lenient mode},\n * regardless of the strictness setting of the provided reader. The strictness setting of the\n * reader is restored once this method returns.\n *\n * @throws JsonParseException if there is an IOException or if the specified text is not valid\n * JSON\n * @since 2.8.6\n */\n public static JsonElement parseReader(JsonReader reader)\n throws JsonIOException, JsonSyntaxException {\n Strictness strictness = reader.getStrictness();\n reader.setStrictness(Strictness.LENIENT);\n try {\n return Streams.parse(reader);\n } catch (StackOverflowError e) {\n throw new JsonParseException(\"Failed parsing JSON source: \" + reader + \" to Json\", e);\n } catch (OutOfMemoryError e) {\n throw new JsonParseException(\"Failed parsing JSON source: \" + reader + \" to Json\", e);\n } finally {\n reader.setStrictness(strictness);\n }\n }\n\n /**\n * @deprecated Use {@link JsonParser#parseString}\n */\n @Deprecated\n @InlineMe(replacement = \"JsonParser.parseString(json)\", imports = \"com.google.gson.JsonParser\")\n public JsonElement parse(String json) throws JsonSyntaxException {\n return parseString(json);\n }\n\n /**\n * @deprecated Use {@link JsonParser#parseReader(Reader)}\n */\n @Deprecated\n @InlineMe(replacement = \"JsonParser.parseReader(json)\", imports = \"com.google.gson.JsonParser\")\n public JsonElement parse(Reader json) throws JsonIOException, JsonSyntaxException {\n return parseReader(json);\n }\n\n /**\n * @deprecated Use {@link JsonParser#parseReader(JsonReader)}\n */\n @Deprecated\n @InlineMe(replacement = \"JsonParser.parseReader(json)\", imports = \"com.google.gson.JsonParser\")\n public JsonElement parse(JsonReader json) throws JsonIOException, JsonSyntaxException {\n return parseReader(json);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/stream/package-info.java\n/*\n * Copyright (C) 2021 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/** This package provides classes for processing JSON in an efficient streaming way. 
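                                  A brief sketch of the parseString API documented above (the demo class name is hypothetical); malformed input surfaces as an unchecked JsonSyntaxException:

                                  ```java
                                  import com.google.gson.JsonObject;
                                  import com.google.gson.JsonParser;
                                  import com.google.gson.JsonSyntaxException;

                                  public class ParseStringDemo {
                                    public static void main(String[] args) {
                                      // Parse a JSON string into a tree and navigate it
                                      JsonObject obj =
                                          JsonParser.parseString("{\"name\": \"gson\", \"stars\": 5}").getAsJsonObject();
                                      System.out.println(obj.get("name").getAsString()); // gson
                                      System.out.println(obj.get("stars").getAsInt());   // 5

                                      try {
                                        JsonParser.parseString("{\"a\": "); // truncated document
                                      } catch (JsonSyntaxException e) {
                                        System.out.println("not valid JSON: " + e.getMessage());
                                      }
                                    }
                                  }
                                  ```
                                  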
*/\npackage com.google.gson.stream;\n\n// Path: gson/src/main/java/com/google/gson/reflect/package-info.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/**\n * This package provides utility classes for finding type information for generic types.\n *\n * @author Inderjeet Singh, Joel Leitch\n */\npackage com.google.gson.reflect;\n\n// Path: gson/src/main/java/com/google/gson/internal/package-info.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/**\n * Do NOT use any class in this package as they are meant for internal use in Gson. These classes\n * will very likely change incompatibly in future versions. You have been warned.\n *\n * @author Inderjeet Singh, Joel Leitch, Jesse Wilson\n */\npackage com.google.gson.internal;\n\n// Path: gson/src/main/java/com/google/gson/internal/bind/TypeAdapterRuntimeTypeWrapper.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson.internal.bind;\n\nimport com.google.gson.Gson;\nimport com.google.gson.TypeAdapter;\nimport com.google.gson.reflect.TypeToken;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.lang.reflect.Type;\nimport java.lang.reflect.TypeVariable;\n\nfinal class TypeAdapterRuntimeTypeWrapper extends TypeAdapter {\n private final Gson context;\n private final TypeAdapter delegate;\n private final Type type;\n\n TypeAdapterRuntimeTypeWrapper(Gson context, TypeAdapter delegate, Type type) {\n this.context = context;\n this.delegate = delegate;\n this.type = type;\n }\n\n @Override\n public T read(JsonReader in) throws IOException {\n return delegate.read(in);\n }\n\n @Override\n public void write(JsonWriter out, T value) throws IOException {\n // Order of preference for choosing type adapters\n // First preference: a type adapter registered for the runtime type\n 
// Second preference: a type adapter registered for the declared type\n // Third preference: reflective type adapter for the runtime type\n // (if it is a subclass of the declared type)\n // Fourth preference: reflective type adapter for the declared type\n\n TypeAdapter chosen = delegate;\n Type runtimeType = getRuntimeTypeIfMoreSpecific(type, value);\n if (runtimeType != type) {\n @SuppressWarnings(\"unchecked\")\n TypeAdapter runtimeTypeAdapter =\n (TypeAdapter) context.getAdapter(TypeToken.get(runtimeType));\n // For backward compatibility only check ReflectiveTypeAdapterFactory.Adapter here but not any\n // other wrapping adapters, see\n // https://github.com/google/gson/pull/1787#issuecomment-1222175189\n if (!(runtimeTypeAdapter instanceof ReflectiveTypeAdapterFactory.Adapter)) {\n // The user registered a type adapter for the runtime type, so we will use that\n chosen = runtimeTypeAdapter;\n } else if (!isReflective(delegate)) {\n // The user registered a type adapter for Base class, so we prefer it over the\n // reflective type adapter for the runtime type\n chosen = delegate;\n } else {\n // Use the type adapter for runtime type\n chosen = runtimeTypeAdapter;\n }\n }\n chosen.write(out, value);\n }\n\n /**\n * Returns whether the type adapter uses reflection.\n *\n * @param typeAdapter the type adapter to check.\n */\n private static boolean isReflective(TypeAdapter typeAdapter) {\n // Run this in loop in case multiple delegating adapters are nested\n while (typeAdapter instanceof SerializationDelegatingTypeAdapter) {\n TypeAdapter delegate =\n ((SerializationDelegatingTypeAdapter) typeAdapter).getSerializationDelegate();\n // Break if adapter does not delegate serialization\n if (delegate == typeAdapter) {\n break;\n }\n typeAdapter = delegate;\n }\n\n return typeAdapter instanceof ReflectiveTypeAdapterFactory.Adapter;\n }\n\n /** Finds a compatible runtime type if it is more specific */\n private static Type getRuntimeTypeIfMoreSpecific(Type type, Object value) {\n if (value != null && (type instanceof Class || type instanceof TypeVariable)) {\n type = value.getClass();\n }\n return type;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/annotations/package-info.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/**\n * This package provides annotations that can be used with {@link com.google.gson.Gson}.\n *\n * @author Inderjeet Singh, Joel Leitch\n */\npackage com.google.gson.annotations;\n\n", "needle_token_start": 2332, "needle_token_end": 2362, "code_context_ntokens": 5389, "output": ["```java\n/**\n * Parses the specified JSON string into a parse tree. An exception is thrown if the JSON string\n * has multiple top-level JSON elements, or if there is trailing data.\n *\n *

The JSON string is parsed in {@linkplain JsonReader#setStrictness(Strictness) lenient mode}.\n *\n * @param json JSON text\n * @return a parse tree of {@link JsonElement}s corresponding to the specified JSON\n * @throws JsonParseException if the specified text is not valid JSON\n * @since 2.8.6\n */\npublic static JsonElement parseString(String json) throws JsonSyntaxException {\n return parseReader(new StringReader(json));\n}\n```"]} +{"repo": "google/gson", "name": "newBuilder", "language": "java", "path": "gson/src/main/java/com/google/gson/Gson.java", "position_ratio": 0.25, "description": "\nFunction Description:\n1. **Purpose**: To create a new builder instance that retains all custom settings and factory configurations from the current instance.\n2. **Input**: None.\n3. **Output**: A new instance of a builder with all the configurations and customizations copied from the existing instance.\n4. **Procedure**: The method initializes a new builder instance by passing the current instance to its constructor, ensuring that all customizations and settings are preserved in the new instance.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": " *

  • When using an existing {@code JsonWriter}, manually apply the writer settings of this\n * {@code Gson} instance listed by {@link #newJsonWriter(Writer)}.
    \n * Otherwise, when not using an existing {@code JsonWriter}, use {@link\n * #newJsonWriter(Writer)} to construct one.\n *
  • Call {@link TypeAdapter#write(JsonWriter, Object)}\n * \n *\n *

    Deserialization

                                  
    1. Use {@link #getAdapter(Class)} to obtain the adapter for the type to be deserialized\n *
    2. When using an existing {@code JsonReader}, manually apply the reader settings of this\n * {@code Gson} instance listed by {@link #newJsonReader(Reader)}.
      \n * Otherwise, when not using an existing {@code JsonReader}, use {@link\n * #newJsonReader(Reader)} to construct one.\n *
    3. Call {@link TypeAdapter#read(JsonReader)}\n *
    4. Call {@link JsonReader#peek()} and verify that the result is {@link JsonToken#END_DOCUMENT}\n * to make sure there is no trailing data\n *
    \n *\n * Note that the {@code JsonReader} created this way is only 'legacy strict', it mostly adheres to\n * the JSON specification but allows small deviations. See {@link\n * JsonReader#setStrictness(Strictness)} for details.\n *\n * @see TypeToken\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @author Jesse Wilson\n */\npublic final class Gson {\n\n static final boolean DEFAULT_JSON_NON_EXECUTABLE = false;\n // Strictness of `null` is the legacy mode where some Gson APIs are always lenient\n static final Strictness DEFAULT_STRICTNESS = null;\n static final FormattingStyle DEFAULT_FORMATTING_STYLE = FormattingStyle.COMPACT;\n static final boolean DEFAULT_ESCAPE_HTML = true;\n static final boolean DEFAULT_SERIALIZE_NULLS = false;\n static final boolean DEFAULT_COMPLEX_MAP_KEYS = false;\n static final boolean DEFAULT_SPECIALIZE_FLOAT_VALUES = false;\n static final boolean DEFAULT_USE_JDK_UNSAFE = true;\n static final String DEFAULT_DATE_PATTERN = null;\n static final FieldNamingStrategy DEFAULT_FIELD_NAMING_STRATEGY = FieldNamingPolicy.IDENTITY;\n static final ToNumberStrategy DEFAULT_OBJECT_TO_NUMBER_STRATEGY = ToNumberPolicy.DOUBLE;\n static final ToNumberStrategy DEFAULT_NUMBER_TO_NUMBER_STRATEGY =\n ToNumberPolicy.LAZILY_PARSED_NUMBER;\n\n private static final String JSON_NON_EXECUTABLE_PREFIX = \")]}'\\n\";\n\n /**\n * This thread local guards against reentrant calls to {@link #getAdapter(TypeToken)}. In certain\n * object graphs, creating an adapter for a type may recursively require an adapter for the same\n * type! Without intervention, the recursive lookup would stack overflow. We cheat by returning a\n * proxy type adapter, {@link FutureTypeAdapter}, which is wired up once the initial adapter has\n * been created.\n *\n *
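                                  The serialization and deserialization steps listed above might look like the following sketch, which uses getAdapter, newJsonWriter and newJsonReader as described; the demo class and the String[] payload are illustrative assumptions:

                                  ```java
                                  import com.google.gson.Gson;
                                  import com.google.gson.TypeAdapter;
                                  import com.google.gson.stream.JsonReader;
                                  import com.google.gson.stream.JsonToken;
                                  import com.google.gson.stream.JsonWriter;
                                  import java.io.IOException;
                                  import java.io.StringReader;
                                  import java.io.StringWriter;

                                  public class StreamingAdapterDemo {
                                    public static void main(String[] args) throws IOException {
                                      Gson gson = new Gson();
                                      TypeAdapter adapter = gson.getAdapter(String[].class);

                                      // Serialization: let the Gson instance create a writer with its own settings
                                      StringWriter out = new StringWriter();
                                      JsonWriter jsonWriter = gson.newJsonWriter(out);
                                      adapter.write(jsonWriter, new String[] {"a", "b"});
                                      jsonWriter.flush();
                                      System.out.println(out); // ["a","b"]

                                      // Deserialization: read the value back and verify there is no trailing data
                                      JsonReader jsonReader = gson.newJsonReader(new StringReader(out.toString()));
                                      String[] values = adapter.read(jsonReader);
                                      if (jsonReader.peek() != JsonToken.END_DOCUMENT) {
                                        throw new IllegalStateException("unexpected trailing data");
                                      }
                                      System.out.println(values.length); // 2
                                    }
                                  }
                                  ```
                                  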

    The map stores the type adapters for ongoing {@code getAdapter} calls, with the type token\n * provided to {@code getAdapter} as key and either {@code FutureTypeAdapter} or a regular {@code\n * TypeAdapter} as value.\n */\n @SuppressWarnings(\"ThreadLocalUsage\")\n private final ThreadLocal, TypeAdapter>> threadLocalAdapterResults =\n new ThreadLocal<>();\n\n private final ConcurrentMap, TypeAdapter> typeTokenCache =\n new ConcurrentHashMap<>();\n\n private final ConstructorConstructor constructorConstructor;\n private final JsonAdapterAnnotationTypeAdapterFactory jsonAdapterFactory;\n\n final List factories;\n\n final Excluder excluder;\n final FieldNamingStrategy fieldNamingStrategy;\n final Map> instanceCreators;\n final boolean serializeNulls;\n final boolean complexMapKeySerialization;\n final boolean generateNonExecutableJson;\n final boolean htmlSafe;\n final FormattingStyle formattingStyle;\n final Strictness strictness;\n final boolean serializeSpecialFloatingPointValues;\n final boolean useJdkUnsafe;\n final String datePattern;\n final int dateStyle;\n final int timeStyle;\n final LongSerializationPolicy longSerializationPolicy;\n final List builderFactories;\n final List builderHierarchyFactories;\n final ToNumberStrategy objectToNumberStrategy;\n final ToNumberStrategy numberToNumberStrategy;\n final List reflectionFilters;\n\n /**\n * Constructs a Gson object with default configuration. The default configuration has the\n * following settings:\n *\n *

                                  
    • The JSON generated by {@code toJson} methods is in compact representation. This means\n * that all the unneeded white-space is removed. You can change this behavior with {@link\n * GsonBuilder#setPrettyPrinting()}.\n *
    • When the JSON generated contains more than one line, the kind of newline and indent to\n * use can be configured with {@link GsonBuilder#setFormattingStyle(FormattingStyle)}.\n *
    • The generated JSON omits all the fields that are null. Note that nulls in arrays are kept\n * as is since an array is an ordered list. Moreover, if a field is not null, but its\n * generated JSON is empty, the field is kept. You can configure Gson to serialize null\n * values by setting {@link GsonBuilder#serializeNulls()}.\n *
    • Gson provides default serialization and deserialization for Enums, {@link Map}, {@link\n * java.net.URL}, {@link java.net.URI}, {@link java.util.Locale}, {@link java.util.Date},\n * {@link java.math.BigDecimal}, and {@link java.math.BigInteger} classes. If you would\n * prefer to change the default representation, you can do so by registering a type adapter\n * through {@link GsonBuilder#registerTypeAdapter(Type, Object)}.\n *
    • The default Date format is same as {@link java.text.DateFormat#DEFAULT}. This format\n * ignores the millisecond portion of the date during serialization. You can change this by\n * invoking {@link GsonBuilder#setDateFormat(int, int)} or {@link\n * GsonBuilder#setDateFormat(String)}.\n *
    • By default, Gson ignores the {@link com.google.gson.annotations.Expose} annotation. You\n * can enable Gson to serialize/deserialize only those fields marked with this annotation\n * through {@link GsonBuilder#excludeFieldsWithoutExposeAnnotation()}.\n *
    • By default, Gson ignores the {@link com.google.gson.annotations.Since} annotation. You\n * can enable Gson to use this annotation through {@link GsonBuilder#setVersion(double)}.\n *
    • The default field naming policy for the output JSON is same as in Java. So, a Java class\n * field {@code versionNumber} will be output as {@code \"versionNumber\"} in JSON. The same\n * rules are applied for mapping incoming JSON to the Java classes. You can change this\n * policy through {@link GsonBuilder#setFieldNamingPolicy(FieldNamingPolicy)}.\n *
    • By default, Gson excludes {@code transient} or {@code static} fields from consideration\n * for serialization and deserialization. You can change this behavior through {@link\n * GsonBuilder#excludeFieldsWithModifiers(int...)}.\n *
    • No explicit strictness is set. You can change this by calling {@link\n * GsonBuilder#setStrictness(Strictness)}.\n *
    \n */\n public Gson() {\n this(\n Excluder.DEFAULT,\n DEFAULT_FIELD_NAMING_STRATEGY,\n Collections.>emptyMap(),\n DEFAULT_SERIALIZE_NULLS,\n DEFAULT_COMPLEX_MAP_KEYS,\n DEFAULT_JSON_NON_EXECUTABLE,\n DEFAULT_ESCAPE_HTML,\n DEFAULT_FORMATTING_STYLE,\n DEFAULT_STRICTNESS,\n DEFAULT_SPECIALIZE_FLOAT_VALUES,\n DEFAULT_USE_JDK_UNSAFE,\n LongSerializationPolicy.DEFAULT,\n DEFAULT_DATE_PATTERN,\n DateFormat.DEFAULT,\n DateFormat.DEFAULT,\n Collections.emptyList(),\n Collections.emptyList(),\n Collections.emptyList(),\n DEFAULT_OBJECT_TO_NUMBER_STRATEGY,\n DEFAULT_NUMBER_TO_NUMBER_STRATEGY,\n Collections.emptyList());\n }\n\n Gson(\n Excluder excluder,\n FieldNamingStrategy fieldNamingStrategy,\n Map> instanceCreators,\n boolean serializeNulls,\n boolean complexMapKeySerialization,\n boolean generateNonExecutableGson,\n boolean htmlSafe,\n FormattingStyle formattingStyle,\n Strictness strictness,\n boolean serializeSpecialFloatingPointValues,\n boolean useJdkUnsafe,\n LongSerializationPolicy longSerializationPolicy,\n String datePattern,\n int dateStyle,\n int timeStyle,\n List builderFactories,\n List builderHierarchyFactories,\n List factoriesToBeAdded,\n ToNumberStrategy objectToNumberStrategy,\n ToNumberStrategy numberToNumberStrategy,\n List reflectionFilters) {\n this.excluder = excluder;\n this.fieldNamingStrategy = fieldNamingStrategy;\n this.instanceCreators = instanceCreators;\n this.constructorConstructor =\n new ConstructorConstructor(instanceCreators, useJdkUnsafe, reflectionFilters);\n this.serializeNulls = serializeNulls;\n this.complexMapKeySerialization = complexMapKeySerialization;\n this.generateNonExecutableJson = generateNonExecutableGson;\n this.htmlSafe = htmlSafe;\n this.formattingStyle = formattingStyle;\n this.strictness = strictness;\n this.serializeSpecialFloatingPointValues = serializeSpecialFloatingPointValues;\n this.useJdkUnsafe = useJdkUnsafe;\n this.longSerializationPolicy = longSerializationPolicy;\n this.datePattern = datePattern;\n this.dateStyle = dateStyle;\n this.timeStyle = timeStyle;\n this.builderFactories = builderFactories;\n this.builderHierarchyFactories = builderHierarchyFactories;\n this.objectToNumberStrategy = objectToNumberStrategy;\n this.numberToNumberStrategy = numberToNumberStrategy;\n this.reflectionFilters = reflectionFilters;\n\n List factories = new ArrayList<>();\n\n // built-in type adapters that cannot be overridden\n factories.add(TypeAdapters.JSON_ELEMENT_FACTORY);\n factories.add(ObjectTypeAdapter.getFactory(objectToNumberStrategy));\n\n // the excluder must precede all adapters that handle user-defined types\n factories.add(excluder);\n\n // users' type adapters\n factories.addAll(factoriesToBeAdded);\n\n // type adapters for basic platform types\n factories.add(TypeAdapters.STRING_FACTORY);\n factories.add(TypeAdapters.INTEGER_FACTORY);\n factories.add(TypeAdapters.BOOLEAN_FACTORY);\n factories.add(TypeAdapters.BYTE_FACTORY);\n factories.add(TypeAdapters.SHORT_FACTORY);\n TypeAdapter longAdapter = longAdapter(longSerializationPolicy);\n factories.add(TypeAdapters.newFactory(long.class, Long.class, longAdapter));\n factories.add(\n TypeAdapters.newFactory(\n double.class, Double.class, doubleAdapter(serializeSpecialFloatingPointValues)));\n factories.add(\n TypeAdapters.newFactory(\n float.class, Float.class, floatAdapter(serializeSpecialFloatingPointValues)));\n factories.add(NumberTypeAdapter.getFactory(numberToNumberStrategy));\n factories.add(TypeAdapters.ATOMIC_INTEGER_FACTORY);\n 
factories.add(TypeAdapters.ATOMIC_BOOLEAN_FACTORY);\n factories.add(TypeAdapters.newFactory(AtomicLong.class, atomicLongAdapter(longAdapter)));\n factories.add(\n TypeAdapters.newFactory(AtomicLongArray.class, atomicLongArrayAdapter(longAdapter)));\n factories.add(TypeAdapters.ATOMIC_INTEGER_ARRAY_FACTORY);\n factories.add(TypeAdapters.CHARACTER_FACTORY);\n factories.add(TypeAdapters.STRING_BUILDER_FACTORY);\n factories.add(TypeAdapters.STRING_BUFFER_FACTORY);\n factories.add(TypeAdapters.newFactory(BigDecimal.class, TypeAdapters.BIG_DECIMAL));\n factories.add(TypeAdapters.newFactory(BigInteger.class, TypeAdapters.BIG_INTEGER));\n // Add adapter for LazilyParsedNumber because user can obtain it from Gson and then try to\n // serialize it again\n factories.add(\n TypeAdapters.newFactory(LazilyParsedNumber.class, TypeAdapters.LAZILY_PARSED_NUMBER));\n factories.add(TypeAdapters.URL_FACTORY);\n factories.add(TypeAdapters.URI_FACTORY);\n factories.add(TypeAdapters.UUID_FACTORY);\n factories.add(TypeAdapters.CURRENCY_FACTORY);\n factories.add(TypeAdapters.LOCALE_FACTORY);\n factories.add(TypeAdapters.INET_ADDRESS_FACTORY);\n factories.add(TypeAdapters.BIT_SET_FACTORY);\n factories.add(DefaultDateTypeAdapter.DEFAULT_STYLE_FACTORY);\n factories.add(TypeAdapters.CALENDAR_FACTORY);\n\n if (SqlTypesSupport.SUPPORTS_SQL_TYPES) {\n factories.add(SqlTypesSupport.TIME_FACTORY);\n factories.add(SqlTypesSupport.DATE_FACTORY);\n factories.add(SqlTypesSupport.TIMESTAMP_FACTORY);\n }\n\n factories.add(ArrayTypeAdapter.FACTORY);\n factories.add(TypeAdapters.CLASS_FACTORY);\n\n // type adapters for composite and user-defined types\n factories.add(new CollectionTypeAdapterFactory(constructorConstructor));\n factories.add(new MapTypeAdapterFactory(constructorConstructor, complexMapKeySerialization));\n this.jsonAdapterFactory = new JsonAdapterAnnotationTypeAdapterFactory(constructorConstructor);\n factories.add(jsonAdapterFactory);\n factories.add(TypeAdapters.ENUM_FACTORY);\n factories.add(\n new ReflectiveTypeAdapterFactory(\n constructorConstructor,\n fieldNamingStrategy,\n excluder,\n jsonAdapterFactory,\n reflectionFilters));\n\n this.factories = Collections.unmodifiableList(factories);\n }\n\n /**\n * Returns a new GsonBuilder containing all custom factories and configuration used by the current\n * instance.\n *\n * @return a GsonBuilder instance.\n * @since 2.8.3\n */\n \npublic GsonBuilder newBuilder() {\n return new GsonBuilder(this);\n }\n\n /**\n * @deprecated This method by accident exposes an internal Gson class; it might be removed in a\n * future version.\n */\n @Deprecated\n public Excluder excluder() {\n return excluder;\n }\n\n /**\n * Returns the field naming strategy used by this Gson instance.\n *\n * @see GsonBuilder#setFieldNamingStrategy(FieldNamingStrategy)\n */\n public FieldNamingStrategy fieldNamingStrategy() {\n return fieldNamingStrategy;\n }\n\n /**\n * Returns whether this Gson instance is serializing JSON object properties with {@code null}\n * values, or just omits them.\n *\n * @see GsonBuilder#serializeNulls()\n */\n public boolean serializeNulls() {\n return serializeNulls;\n }\n\n /**\n * Returns whether this Gson instance produces JSON output which is HTML-safe, that means all HTML\n * characters are escaped.\n *\n * @see GsonBuilder#disableHtmlEscaping()\n */\n public boolean htmlSafe() {\n return htmlSafe;\n }\n\n private TypeAdapter doubleAdapter(boolean serializeSpecialFloatingPointValues) {\n if (serializeSpecialFloatingPointValues) {\n return TypeAdapters.DOUBLE;\n 
}\n return new TypeAdapter() {\n @Override\n public Double read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n return in.nextDouble();\n }\n\n @Override\n public void write(JsonWriter out, Number value) throws IOException {\n if (value == null) {\n out.nullValue();\n return;\n }\n double doubleValue = value.doubleValue();\n checkValidFloatingPoint(doubleValue);\n out.value(doubleValue);\n }\n };\n }\n\n private TypeAdapter floatAdapter(boolean serializeSpecialFloatingPointValues) {\n if (serializeSpecialFloatingPointValues) {\n return TypeAdapters.FLOAT;\n }\n return new TypeAdapter() {\n @Override\n public Float read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n return (float) in.nextDouble();\n }\n\n @Override\n public void write(JsonWriter out, Number value) throws IOException {\n if (value == null) {\n out.nullValue();\n return;\n }\n float floatValue = value.floatValue();\n checkValidFloatingPoint(floatValue);\n // For backward compatibility don't call `JsonWriter.value(float)` because that method has\n // been newly added and not all custom JsonWriter implementations might override it yet\n Number floatNumber = value instanceof Float ? value : floatValue;\n out.value(floatNumber);\n }\n };\n }\n\n static void checkValidFloatingPoint(double value) {\n if (Double.isNaN(value) || Double.isInfinite(value)) {\n throw new IllegalArgumentException(\n value\n + \" is not a valid double value as per JSON specification. To override this\"\n + \" behavior, use GsonBuilder.serializeSpecialFloatingPointValues() method.\");\n }\n }\n\n private static TypeAdapter longAdapter(LongSerializationPolicy longSerializationPolicy) {\n if (longSerializationPolicy == LongSerializationPolicy.DEFAULT) {\n return TypeAdapters.LONG;\n }\n return new TypeAdapter() {\n @Override\n public Number read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n return in.nextLong();\n }\n\n @Override\n public void write(JsonWriter out, Number value) throws IOException {\n if (value == null) {\n out.nullValue();\n return;\n }\n out.value(value.toString());\n }\n };\n }\n\n private static TypeAdapter atomicLongAdapter(final TypeAdapter longAdapter) {\n return new TypeAdapter() {\n @Override\n public void write(JsonWriter out, AtomicLong value) throws IOException {\n longAdapter.write(out, value.get());\n }\n\n @Override\n public AtomicLong read(JsonReader in) throws IOException {\n Number value = longAdapter.read(in);\n return new AtomicLong(value.longValue());\n }\n }.nullSafe();\n }\n\n private static TypeAdapter atomicLongArrayAdapter(\n final TypeAdapter longAdapter) {\n return new TypeAdapter() {\n @Override\n public void write(JsonWriter out, AtomicLongArray value) throws IOException {\n out.beginArray();\n for (int i = 0, length = value.length(); i < length; i++) {\n longAdapter.write(out, value.get(i));\n }\n out.endArray();\n }\n\n @Override\n public AtomicLongArray read(JsonReader in) throws IOException {\n List list = new ArrayList<>();\n in.beginArray();\n while (in.hasNext()) {\n long value = longAdapter.read(in).longValue();\n list.add(value);\n }\n in.endArray();\n int length = list.size();\n AtomicLongArray array = new AtomicLongArray(length);\n for (int i = 0; i < length; ++i) {\n array.set(i, list.get(i));\n }\n return array;\n }\n }.nullSafe();\n }\n\n /**\n * Returns the type adapter for {@code type}.\n *\n *
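                                  A minimal sketch of the newBuilder() method defined above, which copies every setting and registered factory from an existing instance (the class name and chosen options are illustrative):

                                  ```java
                                  import com.google.gson.Gson;
                                  import com.google.gson.GsonBuilder;

                                  public class NewBuilderDemo {
                                    public static void main(String[] args) {
                                      Gson base = new GsonBuilder().serializeNulls().create();

                                      // newBuilder() carries over the configuration of `base`,
                                      // so only the additional option has to be set here.
                                      Gson pretty = base.newBuilder().setPrettyPrinting().create();

                                      System.out.println(base.serializeNulls());   // true
                                      System.out.println(pretty.serializeNulls()); // still true
                                    }
                                  }
                                  ```
                                  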

    When calling this method concurrently from multiple threads and requesting an adapter for\n * the same type this method may return different {@code TypeAdapter} instances. However, that\n * should normally not be an issue because {@code TypeAdapter} implementations are supposed to be\n * stateless.\n *\n * @throws IllegalArgumentException if this Gson instance cannot serialize and deserialize {@code\n * type}.\n */\n public TypeAdapter getAdapter(TypeToken type) {\n Objects.requireNonNull(type, \"type must not be null\");\n TypeAdapter cached = typeTokenCache.get(type);\n if (cached != null) {\n @SuppressWarnings(\"unchecked\")\n TypeAdapter adapter = (TypeAdapter) cached;\n return adapter;\n }\n\n Map, TypeAdapter> threadCalls = threadLocalAdapterResults.get();\n boolean isInitialAdapterRequest = false;\n if (threadCalls == null) {\n threadCalls = new HashMap<>();\n threadLocalAdapterResults.set(threadCalls);\n isInitialAdapterRequest = true;\n } else {\n // the key and value type parameters always agree\n @SuppressWarnings(\"unchecked\")\n TypeAdapter ongoingCall = (TypeAdapter) threadCalls.get(type);\n if (ongoingCall != null) {\n return ongoingCall;\n }\n }\n\n TypeAdapter candidate = null;\n try {\n FutureTypeAdapter call = new FutureTypeAdapter<>();\n threadCalls.put(type, call);\n\n for (TypeAdapterFactory factory : factories) {\n candidate = factory.create(this, type);\n if (candidate != null) {\n call.setDelegate(candidate);\n // Replace future adapter with actual adapter\n threadCalls.put(type, candidate);\n break;\n }\n }\n } finally {\n if (isInitialAdapterRequest) {\n threadLocalAdapterResults.remove();\n }\n }\n\n if (candidate == null) {\n throw new IllegalArgumentException(\n \"GSON (\" + GsonBuildConfig.VERSION + \") cannot handle \" + type);\n }\n\n if (isInitialAdapterRequest) {\n /*\n * Publish resolved adapters to all threads\n * Can only do this for the initial request because cyclic dependency TypeA -> TypeB -> TypeA\n * would otherwise publish adapter for TypeB which uses not yet resolved adapter for TypeA\n * See https://github.com/google/gson/issues/625\n */\n typeTokenCache.putAll(threadCalls);\n }\n return candidate;\n }\n\n /**\n * Returns the type adapter for {@code type}.\n *\n * @throws IllegalArgumentException if this Gson instance cannot serialize and deserialize {@code\n * type}.\n */\n public TypeAdapter getAdapter(Class type) {\n return getAdapter(TypeToken.get(type));\n }\n\n /**\n * This method is used to get an alternate type adapter for the specified type. This is used to\n * access a type adapter that is overridden by a {@link TypeAdapterFactory} that you may have\n * registered. This feature is typically used when you want to register a type adapter that does a\n * little bit of work but then delegates further processing to the Gson default type adapter. Here\n * is an example:\n *\n *

    Let's say we want to write a type adapter that counts the number of objects being read from\n * or written to JSON. We can achieve this by writing a type adapter factory that uses the {@code\n * getDelegateAdapter} method:\n *\n *

    {@code\n   * class StatsTypeAdapterFactory implements TypeAdapterFactory {\n   *   public int numReads = 0;\n   *   public int numWrites = 0;\n   *   public  TypeAdapter create(Gson gson, TypeToken type) {\n   *     final TypeAdapter delegate = gson.getDelegateAdapter(this, type);\n   *     return new TypeAdapter() {\n   *       public void write(JsonWriter out, T value) throws IOException {\n   *         ++numWrites;\n   *         delegate.write(out, value);\n   *       }\n   *       public T read(JsonReader in) throws IOException {\n   *         ++numReads;\n   *         return delegate.read(in);\n   *       }\n   *     };\n   *   }\n   * }\n   * }
    \n *\n * This factory can now be used like this:\n *\n *
    {@code\n   * StatsTypeAdapterFactory stats = new StatsTypeAdapterFactory();\n   * Gson gson = new GsonBuilder().registerTypeAdapterFactory(stats).create();\n   * // Call gson.toJson() and fromJson methods on objects\n   * System.out.println(\"Num JSON reads: \" + stats.numReads);\n   * System.out.println(\"Num JSON writes: \" + stats.numWrites);\n   * }
    \n *\n * Note that this call will skip all factories registered before {@code skipPast}. In case of\n * multiple TypeAdapterFactories registered it is up to the caller of this function to ensure that\n * the order of registration does not prevent this method from reaching a factory they would\n * expect to reply from this call. Note that since you can not override the type adapter factories\n * for some types, see {@link GsonBuilder#registerTypeAdapter(Type, Object)}, our stats factory\n * will not count the number of instances of those types that will be read or written.\n *\n *

    If {@code skipPast} is a factory which has neither been registered on the {@link\n * GsonBuilder} nor specified with the {@link JsonAdapter @JsonAdapter} annotation on a class,\n * then this method behaves as if {@link #getAdapter(TypeToken)} had been called. This also means\n * that for fields with {@code @JsonAdapter} annotation this method behaves normally like {@code\n * getAdapter} (except for corner cases where a custom {@link InstanceCreator} is used to create\n * an instance of the factory).\n *\n * @param skipPast The type adapter factory that needs to be skipped while searching for a\n * matching type adapter. In most cases, you should just pass this (the type adapter\n * factory from where {@code getDelegateAdapter} method is being invoked).\n * @param type Type for which the delegate adapter is being searched for.\n * @since 2.2\n */\n public TypeAdapter getDelegateAdapter(TypeAdapterFactory skipPast, TypeToken type) {\n Objects.requireNonNull(skipPast, \"skipPast must not be null\");\n Objects.requireNonNull(type, \"type must not be null\");\n\n if (jsonAdapterFactory.isClassJsonAdapterFactory(type, skipPast)) {\n skipPast = jsonAdapterFactory;\n }\n\n boolean skipPastFound = false;\n for (TypeAdapterFactory factory : factories) {\n if (!skipPastFound) {\n if (factory == skipPast) {\n skipPastFound = true;\n }\n continue;\n }\n\n TypeAdapter candidate = factory.create(this, type);\n if (candidate != null) {\n return candidate;\n }\n }\n\n if (skipPastFound) {\n throw new IllegalArgumentException(\"GSON cannot serialize or deserialize \" + type);\n } else {\n // Probably a factory from @JsonAdapter on a field\n return getAdapter(type);\n }\n }\n\n /**\n * This method serializes the specified object into its equivalent representation as a tree of\n * {@link JsonElement}s. This method should be used when the specified object is not a generic\n * type. This method uses {@link Class#getClass()} to get the type for the specified object, but\n * the {@code getClass()} loses the generic type information because of the Type Erasure feature\n * of Java. Note that this method works fine if any of the object fields are of generic type, just\n * the object itself should not be of a generic type. If the object is of generic type, use {@link\n * #toJsonTree(Object, Type)} instead.\n *\n * @param src the object for which JSON representation is to be created\n * @return JSON representation of {@code src}.\n * @since 1.4\n * @see #toJsonTree(Object, Type)\n */\n public JsonElement toJsonTree(Object src) {\n if (src == null) {\n return JsonNull.INSTANCE;\n }\n return toJsonTree(src, src.getClass());\n }\n\n /**\n * This method serializes the specified object, including those of generic types, into its\n * equivalent representation as a tree of {@link JsonElement}s. This method must be used if the\n * specified object is a generic type. For non-generic objects, use {@link #toJsonTree(Object)}\n * instead.\n *\n * @param src the object for which JSON representation is to be created\n * @param typeOfSrc The specific genericized type of src. You can obtain this type by using the\n * {@link com.google.gson.reflect.TypeToken} class. For example, to get the type for {@code\n * Collection}, you should use:\n *

    \n   * Type typeOfSrc = new TypeToken<Collection<Foo>>(){}.getType();\n   * 
    \n *\n * @return JSON representation of {@code src}.\n * @since 1.4\n * @see #toJsonTree(Object)\n */\n public JsonElement toJsonTree(Object src, Type typeOfSrc) {\n JsonTreeWriter writer = new JsonTreeWriter();\n toJson(src, typeOfSrc, writer);\n return writer.get();\n }\n\n /**\n * This method serializes the specified object into its equivalent JSON representation. This\n * method should be used when the specified object is not a generic type. This method uses {@link\n * Class#getClass()} to get the type for the specified object, but the {@code getClass()} loses\n * the generic type information because of the Type Erasure feature of Java. Note that this method\n * works fine if any of the object fields are of generic type, just the object itself should not\n * be of a generic type. If the object is of generic type, use {@link #toJson(Object, Type)}\n * instead. If you want to write out the object to a {@link Writer}, use {@link #toJson(Object,\n * Appendable)} instead.\n *\n * @param src the object for which JSON representation is to be created\n * @return JSON representation of {@code src}.\n * @see #toJson(Object, Appendable)\n * @see #toJson(Object, Type)\n */\n public String toJson(Object src) {\n if (src == null) {\n return toJson(JsonNull.INSTANCE);\n }\n return toJson(src, src.getClass());\n }\n\n /**\n * This method serializes the specified object, including those of generic types, into its\n * equivalent JSON representation. This method must be used if the specified object is a generic\n * type. For non-generic objects, use {@link #toJson(Object)} instead. If you want to write out\n * the object to a {@link Appendable}, use {@link #toJson(Object, Type, Appendable)} instead.\n *\n * @param src the object for which JSON representation is to be created\n * @param typeOfSrc The specific genericized type of src. You can obtain this type by using the\n * {@link com.google.gson.reflect.TypeToken} class. For example, to get the type for {@code\n * Collection}, you should use:\n *
    \n   * Type typeOfSrc = new TypeToken<Collection<Foo>>(){}.getType();\n   * 
    \n *\n * @return JSON representation of {@code src}.\n * @see #toJson(Object, Type, Appendable)\n * @see #toJson(Object)\n */\n public String toJson(Object src, Type typeOfSrc) {\n StringWriter writer = new StringWriter();\n toJson(src, typeOfSrc, writer);\n return writer.toString();\n }\n\n /**\n * This method serializes the specified object into its equivalent JSON representation and writes\n * it to the writer. This method should be used when the specified object is not a generic type.\n * This method uses {@link Class#getClass()} to get the type for the specified object, but the\n * {@code getClass()} loses the generic type information because of the Type Erasure feature of\n * Java. Note that this method works fine if any of the object fields are of generic type, just\n * the object itself should not be of a generic type. If the object is of generic type, use {@link\n * #toJson(Object, Type, Appendable)} instead.\n *\n * @param src the object for which JSON representation is to be created\n * @param writer Writer to which the JSON representation needs to be written\n * @throws JsonIOException if there was a problem writing to the writer\n * @since 1.2\n * @see #toJson(Object)\n * @see #toJson(Object, Type, Appendable)\n */\n public void toJson(Object src, Appendable writer) throws JsonIOException {\n if (src != null) {\n toJson(src, src.getClass(), writer);\n } else {\n toJson(JsonNull.INSTANCE, writer);\n }\n }\n\n /**\n * This method serializes the specified object, including those of generic types, into its\n * equivalent JSON representation and writes it to the writer. This method must be used if the\n * specified object is a generic type. For non-generic objects, use {@link #toJson(Object,\n * Appendable)} instead.\n *\n * @param src the object for which JSON representation is to be created\n * @param typeOfSrc The specific genericized type of src. You can obtain this type by using the\n * {@link com.google.gson.reflect.TypeToken} class. For example, to get the type for {@code\n * Collection}, you should use:\n *
    \n   * Type typeOfSrc = new TypeToken<Collection<Foo>>(){}.getType();\n   * 
    \n *\n * @param writer Writer to which the JSON representation of src needs to be written\n * @throws JsonIOException if there was a problem writing to the writer\n * @since 1.2\n * @see #toJson(Object, Type)\n * @see #toJson(Object, Appendable)\n */\n public void toJson(Object src, Type typeOfSrc, Appendable writer) throws JsonIOException {\n try {\n JsonWriter jsonWriter = newJsonWriter(Streams.writerForAppendable(writer));\n toJson(src, typeOfSrc, jsonWriter);\n } catch (IOException e) {\n throw new JsonIOException(e);\n }\n }\n\n /**\n * Writes the JSON representation of {@code src} of type {@code typeOfSrc} to {@code writer}.\n *\n *
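                                  A short sketch of serializing a generic type with a TypeToken, as the javadoc above recommends (the demo class name is made up):

                                  ```java
                                  import com.google.gson.Gson;
                                  import com.google.gson.reflect.TypeToken;
                                  import java.lang.reflect.Type;
                                  import java.util.Arrays;
                                  import java.util.List;

                                  public class GenericToJsonDemo {
                                    public static void main(String[] args) {
                                      Gson gson = new Gson();
                                      List values = Arrays.asList("a", "b", "c");

                                      // A TypeToken captures the generic type that values.getClass() would erase
                                      Type listOfString = new TypeToken>() {}.getType();

                                      System.out.println(gson.toJson(values, listOfString));                   // ["a","b","c"]
                                      System.out.println(gson.toJsonTree(values, listOfString).isJsonArray()); // true
                                    }
                                  }
                                  ```
                                  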

    If the {@code Gson} instance has an {@linkplain GsonBuilder#setStrictness(Strictness)\n * explicit strictness setting}, this setting will be used for writing the JSON regardless of the\n * {@linkplain JsonWriter#getStrictness() strictness} of the provided {@link JsonWriter}. For\n * legacy reasons, if the {@code Gson} instance has no explicit strictness setting and the writer\n * does not have the strictness {@link Strictness#STRICT}, the JSON will be written in {@link\n * Strictness#LENIENT} mode.
    \n * Note that in all cases the old strictness setting of the writer will be restored when this\n * method returns.\n *\n *

    The 'HTML-safe' and 'serialize {@code null}' settings of this {@code Gson} instance\n * (configured by the {@link GsonBuilder}) are applied, and the original settings of the writer\n * are restored once this method returns.\n *\n * @param src the object for which JSON representation is to be created\n * @param typeOfSrc the type of the object to be written\n * @param writer Writer to which the JSON representation of src needs to be written\n * @throws JsonIOException if there was a problem writing to the writer\n */\n public void toJson(Object src, Type typeOfSrc, JsonWriter writer) throws JsonIOException {\n @SuppressWarnings(\"unchecked\")\n TypeAdapter adapter = (TypeAdapter) getAdapter(TypeToken.get(typeOfSrc));\n\n Strictness oldStrictness = writer.getStrictness();\n if (this.strictness != null) {\n writer.setStrictness(this.strictness);\n } else if (writer.getStrictness() != Strictness.STRICT) {\n writer.setStrictness(Strictness.LENIENT);\n }\n\n boolean oldHtmlSafe = writer.isHtmlSafe();\n boolean oldSerializeNulls = writer.getSerializeNulls();\n\n writer.setHtmlSafe(htmlSafe);\n writer.setSerializeNulls(serializeNulls);\n try {\n adapter.write(writer, src);\n } catch (IOException e) {\n throw new JsonIOException(e);\n } catch (AssertionError e) {\n throw new AssertionError(\n \"AssertionError (GSON \" + GsonBuildConfig.VERSION + \"): \" + e.getMessage(), e);\n } finally {\n writer.setStrictness(oldStrictness);\n writer.setHtmlSafe(oldHtmlSafe);\n writer.setSerializeNulls(oldSerializeNulls);\n }\n }\n\n /**\n * Converts a tree of {@link JsonElement}s into its equivalent JSON representation.\n *\n * @param jsonElement root of a tree of {@link JsonElement}s\n * @return JSON String representation of the tree.\n * @since 1.4\n */\n public String toJson(JsonElement jsonElement) {\n StringWriter writer = new StringWriter();\n toJson(jsonElement, writer);\n return writer.toString();\n }\n\n /**\n * Writes out the equivalent JSON for a tree of {@link JsonElement}s.\n *\n * @param jsonElement root of a tree of {@link JsonElement}s\n * @param writer Writer to which the JSON representation needs to be written\n * @throws JsonIOException if there was a problem writing to the writer\n * @since 1.4\n */\n public void toJson(JsonElement jsonElement, Appendable writer) throws JsonIOException {\n try {\n JsonWriter jsonWriter = newJsonWriter(Streams.writerForAppendable(writer));\n toJson(jsonElement, jsonWriter);\n } catch (IOException e) {\n throw new JsonIOException(e);\n }\n }\n\n /**\n * Writes the JSON for {@code jsonElement} to {@code writer}.\n *\n *

    If the {@code Gson} instance has an {@linkplain GsonBuilder#setStrictness(Strictness)\n * explicit strictness setting}, this setting will be used for writing the JSON regardless of the\n * {@linkplain JsonWriter#getStrictness() strictness} of the provided {@link JsonWriter}. For\n * legacy reasons, if the {@code Gson} instance has no explicit strictness setting and the writer\n * does not have the strictness {@link Strictness#STRICT}, the JSON will be written in {@link\n * Strictness#LENIENT} mode.
    \n * Note that in all cases the old strictness setting of the writer will be restored when this\n * method returns.\n *\n *

    The 'HTML-safe' and 'serialize {@code null}' settings of this {@code Gson} instance\n * (configured by the {@link GsonBuilder}) are applied, and the original settings of the writer\n * are restored once this method returns.\n *\n * @param jsonElement the JSON element to be written\n * @param writer the JSON writer to which the provided element will be written\n * @throws JsonIOException if there was a problem writing to the writer\n */\n public void toJson(JsonElement jsonElement, JsonWriter writer) throws JsonIOException {\n Strictness oldStrictness = writer.getStrictness();\n boolean oldHtmlSafe = writer.isHtmlSafe();\n boolean oldSerializeNulls = writer.getSerializeNulls();\n\n writer.setHtmlSafe(htmlSafe);\n writer.setSerializeNulls(serializeNulls);\n\n if (this.strictness != null) {\n writer.setStrictness(this.strictness);\n } else if (writer.getStrictness() != Strictness.STRICT) {\n writer.setStrictness(Strictness.LENIENT);\n }\n\n try {\n Streams.write(jsonElement, writer);\n } catch (IOException e) {\n throw new JsonIOException(e);\n } catch (AssertionError e) {\n throw new AssertionError(\n \"AssertionError (GSON \" + GsonBuildConfig.VERSION + \"): \" + e.getMessage(), e);\n } finally {\n writer.setStrictness(oldStrictness);\n writer.setHtmlSafe(oldHtmlSafe);\n writer.setSerializeNulls(oldSerializeNulls);\n }\n }\n\n /**\n * Returns a new JSON writer configured for the settings on this Gson instance.\n *\n *
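                                      A minimal sketch (not part of the quoted Gson source; the class name and sample values are illustrative) of the save-and-restore behaviour described above: while `toJson(Object, Type, JsonWriter)` runs, the caller's writer is temporarily switched to lenient mode for legacy reasons, and its original strictness is restored once the call returns.

                                      ```java
                                      import com.google.gson.Gson;
                                      import com.google.gson.Strictness;
                                      import com.google.gson.stream.JsonWriter;
                                      import java.io.StringWriter;

                                      public class StrictnessRestoreSketch {
                                        public static void main(String[] args) {
                                          Gson gson = new Gson(); // no explicit strictness configured on the Gson instance
                                          StringWriter out = new StringWriter();
                                          JsonWriter writer = new JsonWriter(out);
                                          writer.setStrictness(Strictness.LEGACY_STRICT);

                                          // Written leniently for legacy reasons, then the writer's own setting is restored.
                                          gson.toJson(new int[] {1, 2, 3}, int[].class, writer);

                                          System.out.println(out);                    // [1,2,3]
                                          System.out.println(writer.getStrictness()); // LEGACY_STRICT
                                        }
                                      }
                                      ```
                                      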

    The following settings are considered:\n *\n *

      \n *
    • {@link GsonBuilder#disableHtmlEscaping()}\n *
    • {@link GsonBuilder#generateNonExecutableJson()}\n *
    • {@link GsonBuilder#serializeNulls()}\n *
    • {@link GsonBuilder#setStrictness(Strictness)}. If no {@linkplain\n * GsonBuilder#setStrictness(Strictness) explicit strictness has been set} the created\n * writer will have a strictness of {@link Strictness#LEGACY_STRICT}. Otherwise, the\n * strictness of the {@code Gson} instance will be used for the created writer.\n *
    • {@link GsonBuilder#setPrettyPrinting()}\n *
    • {@link GsonBuilder#setFormattingStyle(FormattingStyle)}\n *
    \n */\n public JsonWriter newJsonWriter(Writer writer) throws IOException {\n if (generateNonExecutableJson) {\n writer.write(JSON_NON_EXECUTABLE_PREFIX);\n }\n JsonWriter jsonWriter = new JsonWriter(writer);\n jsonWriter.setFormattingStyle(formattingStyle);\n jsonWriter.setHtmlSafe(htmlSafe);\n jsonWriter.setStrictness(strictness == null ? Strictness.LEGACY_STRICT : strictness);\n jsonWriter.setSerializeNulls(serializeNulls);\n return jsonWriter;\n }\n\n /**\n * Returns a new JSON reader configured for the settings on this Gson instance.\n *\n *
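                                      A small usage sketch (not from the quoted source; the class name is illustrative) showing how the GsonBuilder settings listed above end up on the writer returned by `newJsonWriter(Writer)`, including the `LEGACY_STRICT` default when no explicit strictness was configured.

                                      ```java
                                      import com.google.gson.Gson;
                                      import com.google.gson.GsonBuilder;
                                      import com.google.gson.Strictness;
                                      import com.google.gson.stream.JsonWriter;
                                      import java.io.IOException;
                                      import java.io.StringWriter;

                                      public class NewJsonWriterSketch {
                                        public static void main(String[] args) throws IOException {
                                          Gson gson = new GsonBuilder()
                                              .serializeNulls()    // carried over to the writer below
                                              .setPrettyPrinting() // selects the formatting style
                                              .create();

                                          JsonWriter writer = gson.newJsonWriter(new StringWriter());

                                          // No explicit strictness was set, so the writer defaults to LEGACY_STRICT.
                                          System.out.println(writer.getStrictness() == Strictness.LEGACY_STRICT); // true
                                          System.out.println(writer.getSerializeNulls());                          // true
                                        }
                                      }
                                      ```
                                      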

    The following settings are considered:\n *\n *

      \n *
    • {@link GsonBuilder#setStrictness(Strictness)}. If no {@linkplain\n * GsonBuilder#setStrictness(Strictness) explicit strictness has been set} the created\n * reader will have a strictness of {@link Strictness#LEGACY_STRICT}. Otherwise, the\n * strictness of the {@code Gson} instance will be used for the created reader.\n *
    \n */\n public JsonReader newJsonReader(Reader reader) {\n JsonReader jsonReader = new JsonReader(reader);\n jsonReader.setStrictness(strictness == null ? Strictness.LEGACY_STRICT : strictness);\n return jsonReader;\n }\n\n /**\n * This method deserializes the specified JSON into an object of the specified class. It is not\n * suitable to use if the specified class is a generic type since it will not have the generic\n * type information because of the Type Erasure feature of Java. Therefore, this method should not\n * be used if the desired type is a generic type. Note that this method works fine if any of the\n * fields of the specified object are generics, just the object itself should not be a generic\n * type. For the cases when the object is of generic type, invoke {@link #fromJson(String,\n * TypeToken)}. If you have the JSON in a {@link Reader} instead of a String, use {@link\n * #fromJson(Reader, Class)} instead.\n *\n *
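                                      The reader-side counterpart, sketched under the same assumptions: with no explicit strictness on the Gson instance, `newJsonReader(Reader)` hands back a reader in `LEGACY_STRICT` mode.

                                      ```java
                                      import com.google.gson.Gson;
                                      import com.google.gson.stream.JsonReader;
                                      import java.io.StringReader;

                                      public class NewJsonReaderSketch {
                                        public static void main(String[] args) {
                                          Gson gson = new Gson();
                                          JsonReader reader = gson.newJsonReader(new StringReader("{\"a\":1}"));
                                          System.out.println(reader.getStrictness()); // LEGACY_STRICT
                                        }
                                      }
                                      ```
                                      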

    An exception is thrown if the JSON string has multiple top-level JSON elements, or if there\n * is trailing data. Use {@link #fromJson(JsonReader, Type)} if this behavior is not desired.\n *\n * @param the type of the desired object\n * @param json the string from which the object is to be deserialized\n * @param classOfT the class of T\n * @return an object of type T from the string. Returns {@code null} if {@code json} is {@code\n * null} or if {@code json} is empty.\n * @throws JsonSyntaxException if json is not a valid representation for an object of type\n * classOfT\n * @see #fromJson(Reader, Class)\n * @see #fromJson(String, TypeToken)\n */\n public T fromJson(String json, Class classOfT) throws JsonSyntaxException {\n T object = fromJson(json, TypeToken.get(classOfT));\n return Primitives.wrap(classOfT).cast(object);\n }\n\n /**\n * This method deserializes the specified JSON into an object of the specified type. This method\n * is useful if the specified object is a generic type. For non-generic objects, use {@link\n * #fromJson(String, Class)} instead. If you have the JSON in a {@link Reader} instead of a\n * String, use {@link #fromJson(Reader, Type)} instead.\n *\n *
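                                      A hypothetical example of the non-generic overload described above, `fromJson(String, Class)`; the `Person` class and the payload are illustrative, not taken from the quoted source.

                                      ```java
                                      import com.google.gson.Gson;

                                      public class FromJsonClassSketch {
                                        // Plain data class with a default constructor, suitable for reflective binding.
                                        static class Person {
                                          String name;
                                          int age;
                                        }

                                        public static void main(String[] args) {
                                          Gson gson = new Gson();
                                          Person p = gson.fromJson("{\"name\":\"Ada\",\"age\":36}", Person.class);
                                          System.out.println(p.name + " " + p.age); // Ada 36
                                        }
                                      }
                                      ```
                                      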

    Since {@code Type} is not parameterized by T, this method is not type-safe and should be\n * used carefully. If you are creating the {@code Type} from a {@link TypeToken}, prefer using\n * {@link #fromJson(String, TypeToken)} instead since its return type is based on the {@code\n * TypeToken} and is therefore more type-safe.\n *\n *

    An exception is thrown if the JSON string has multiple top-level JSON elements, or if there\n * is trailing data. Use {@link #fromJson(JsonReader, Type)} if this behavior is not desired.\n *\n * @param the type of the desired object\n * @param json the string from which the object is to be deserialized\n * @param typeOfT The specific genericized type of src\n * @return an object of type T from the string. Returns {@code null} if {@code json} is {@code\n * null} or if {@code json} is empty.\n * @throws JsonSyntaxException if json is not a valid representation for an object of type typeOfT\n * @see #fromJson(Reader, Type)\n * @see #fromJson(String, Class)\n * @see #fromJson(String, TypeToken)\n */\n @SuppressWarnings({\"unchecked\", \"TypeParameterUnusedInFormals\"})\n public T fromJson(String json, Type typeOfT) throws JsonSyntaxException {\n return (T) fromJson(json, TypeToken.get(typeOfT));\n }\n\n /**\n * This method deserializes the specified JSON into an object of the specified type. This method\n * is useful if the specified object is a generic type. For non-generic objects, use {@link\n * #fromJson(String, Class)} instead. If you have the JSON in a {@link Reader} instead of a\n * String, use {@link #fromJson(Reader, TypeToken)} instead.\n *\n *

    An exception is thrown if the JSON string has multiple top-level JSON elements, or if there\n * is trailing data. Use {@link #fromJson(JsonReader, TypeToken)} if this behavior is not desired.\n *\n * @param the type of the desired object\n * @param json the string from which the object is to be deserialized\n * @param typeOfT The specific genericized type of src. You should create an anonymous subclass of\n * {@code TypeToken} with the specific generic type arguments. For example, to get the type\n * for {@code Collection}, you should use:\n *

                                          \n   * new TypeToken>(){}\n   * 
                                      
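
                                      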
    \n *\n * @return an object of type T from the string. Returns {@code null} if {@code json} is {@code\n * null} or if {@code json} is empty.\n * @throws JsonSyntaxException if json is not a valid representation for an object of the type\n * typeOfT\n * @see #fromJson(Reader, TypeToken)\n * @see #fromJson(String, Class)\n * @since 2.10\n */\n public T fromJson(String json, TypeToken typeOfT) throws JsonSyntaxException {\n if (json == null) {\n return null;\n }\n StringReader reader = new StringReader(json);\n return fromJson(reader, typeOfT);\n }\n\n /**\n * This method deserializes the JSON read from the specified reader into an object of the\n * specified class. It is not suitable to use if the specified class is a generic type since it\n * will not have the generic type information because of the Type Erasure feature of Java.\n * Therefore, this method should not be used if the desired type is a generic type. Note that this\n * method works fine if any of the fields of the specified object are generics, just the object\n * itself should not be a generic type. For the cases when the object is of generic type, invoke\n * {@link #fromJson(Reader, TypeToken)}. If you have the JSON in a String form instead of a {@link\n * Reader}, use {@link #fromJson(String, Class)} instead.\n *\n *
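                                      A sketch (class name and payload are illustrative) of the `TypeToken`-based overload documented above, which is the type-safe way to deserialize generic types such as `List`.

                                      ```java
                                      import com.google.gson.Gson;
                                      import com.google.gson.reflect.TypeToken;
                                      import java.util.List;

                                      public class FromJsonTypeTokenSketch {
                                        public static void main(String[] args) {
                                          Gson gson = new Gson();
                                          // The anonymous TypeToken subclass captures the full generic type at runtime.
                                          List names =
                                              gson.fromJson("[\"a\",\"b\",\"c\"]", new TypeToken>() {});
                                          System.out.println(names); // [a, b, c]
                                        }
                                      }
                                      ```
                                      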

    An exception is thrown if the JSON data has multiple top-level JSON elements, or if there is\n * trailing data. Use {@link #fromJson(JsonReader, Type)} if this behavior is not desired.\n *\n * @param the type of the desired object\n * @param json the reader producing the JSON from which the object is to be deserialized.\n * @param classOfT the class of T\n * @return an object of type T from the Reader. Returns {@code null} if {@code json} is at EOF.\n * @throws JsonIOException if there was a problem reading from the Reader\n * @throws JsonSyntaxException if json is not a valid representation for an object of type typeOfT\n * @since 1.2\n * @see #fromJson(String, Class)\n * @see #fromJson(Reader, TypeToken)\n */\n public T fromJson(Reader json, Class classOfT)\n throws JsonSyntaxException, JsonIOException {\n T object = fromJson(json, TypeToken.get(classOfT));\n return Primitives.wrap(classOfT).cast(object);\n }\n\n /**\n * This method deserializes the JSON read from the specified reader into an object of the\n * specified type. This method is useful if the specified object is a generic type. For\n * non-generic objects, use {@link #fromJson(Reader, Class)} instead. If you have the JSON in a\n * String form instead of a {@link Reader}, use {@link #fromJson(String, Type)} instead.\n *\n *

    Since {@code Type} is not parameterized by T, this method is not type-safe and should be\n * used carefully. If you are creating the {@code Type} from a {@link TypeToken}, prefer using\n * {@link #fromJson(Reader, TypeToken)} instead since its return type is based on the {@code\n * TypeToken} and is therefore more type-safe.\n *\n *

    An exception is thrown if the JSON data has multiple top-level JSON elements, or if there is\n * trailing data. Use {@link #fromJson(JsonReader, Type)} if this behavior is not desired.\n *\n * @param the type of the desired object\n * @param json the reader producing JSON from which the object is to be deserialized\n * @param typeOfT The specific genericized type of src\n * @return an object of type T from the Reader. Returns {@code null} if {@code json} is at EOF.\n * @throws JsonIOException if there was a problem reading from the Reader\n * @throws JsonSyntaxException if json is not a valid representation for an object of type typeOfT\n * @since 1.2\n * @see #fromJson(String, Type)\n * @see #fromJson(Reader, Class)\n * @see #fromJson(Reader, TypeToken)\n */\n @SuppressWarnings({\"unchecked\", \"TypeParameterUnusedInFormals\"})\n public T fromJson(Reader json, Type typeOfT) throws JsonIOException, JsonSyntaxException {\n return (T) fromJson(json, TypeToken.get(typeOfT));\n }\n\n /**\n * This method deserializes the JSON read from the specified reader into an object of the\n * specified type. This method is useful if the specified object is a generic type. For\n * non-generic objects, use {@link #fromJson(Reader, Class)} instead. If you have the JSON in a\n * String form instead of a {@link Reader}, use {@link #fromJson(String, TypeToken)} instead.\n *\n *

    An exception is thrown if the JSON data has multiple top-level JSON elements, or if there is\n * trailing data. Use {@link #fromJson(JsonReader, TypeToken)} if this behavior is not desired.\n *\n * @param the type of the desired object\n * @param json the reader producing JSON from which the object is to be deserialized\n * @param typeOfT The specific genericized type of src. You should create an anonymous subclass of\n * {@code TypeToken} with the specific generic type arguments. For example, to get the type\n * for {@code Collection}, you should use:\n *

                                          \n   * new TypeToken>(){}\n   * 
                                      
    \n *\n * @return an object of type T from the Reader. Returns {@code null} if {@code json} is at EOF.\n * @throws JsonIOException if there was a problem reading from the Reader\n * @throws JsonSyntaxException if json is not a valid representation for an object of type of\n * typeOfT\n * @see #fromJson(String, TypeToken)\n * @see #fromJson(Reader, Class)\n * @since 2.10\n */\n public T fromJson(Reader json, TypeToken typeOfT)\n throws JsonIOException, JsonSyntaxException {\n JsonReader jsonReader = newJsonReader(json);\n T object = fromJson(jsonReader, typeOfT);\n assertFullConsumption(object, jsonReader);\n return object;\n }\n\n // fromJson(JsonReader, Class) is unfortunately missing and cannot be added now without breaking\n // source compatibility in certain cases, see\n // https://github.com/google/gson/pull/1700#discussion_r973764414\n\n /**\n * Reads the next JSON value from {@code reader} and converts it to an object of type {@code\n * typeOfT}. Returns {@code null}, if the {@code reader} is at EOF.\n *\n *
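                                      A minimal sketch (not from the quoted source) of the behaviour called out above for the String/Reader overloads: exactly one top-level value is consumed, and trailing data is rejected with a `JsonSyntaxException`.

                                      ```java
                                      import com.google.gson.Gson;
                                      import com.google.gson.JsonSyntaxException;
                                      import com.google.gson.reflect.TypeToken;
                                      import java.io.StringReader;
                                      import java.util.List;

                                      public class TrailingDataSketch {
                                        public static void main(String[] args) {
                                          Gson gson = new Gson();
                                          try {
                                            List values = gson.fromJson(
                                                new StringReader("[1, 2, 3] \"trailing\""), new TypeToken>() {});
                                            System.out.println(values); // not reached: the trailing token is rejected
                                          } catch (JsonSyntaxException expected) {
                                            System.out.println("rejected: " + expected.getMessage());
                                          }
                                        }
                                      }
                                      ```
                                      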

    Since {@code Type} is not parameterized by T, this method is not type-safe and should be\n * used carefully. If you are creating the {@code Type} from a {@link TypeToken}, prefer using\n * {@link #fromJson(JsonReader, TypeToken)} instead since its return type is based on the {@code\n * TypeToken} and is therefore more type-safe. If the provided type is a {@code Class} the {@code\n * TypeToken} can be created with {@link TypeToken#get(Class)}.\n *\n *

    Unlike the other {@code fromJson} methods, no exception is thrown if the JSON data has\n * multiple top-level JSON elements, or if there is trailing data.\n *\n *

                                          If the {@code Gson} instance has an {@linkplain GsonBuilder#setStrictness(Strictness)\n * explicit strictness setting}, this setting will be used for reading the JSON regardless of the\n * {@linkplain JsonReader#getStrictness() strictness} of the provided {@link JsonReader}. For\n * legacy reasons, if the {@code Gson} instance has no explicit strictness setting and the reader\n * does not have the strictness {@link Strictness#STRICT}, the JSON will be read in {@link\n * Strictness#LENIENT} mode.
                                      
    \n * Note that in all cases the old strictness setting of the reader will be restored when this\n * method returns.\n *\n * @param the type of the desired object\n * @param reader the reader whose next JSON value should be deserialized\n * @param typeOfT The specific genericized type of src\n * @return an object of type T from the JsonReader. Returns {@code null} if {@code reader} is at\n * EOF.\n * @throws JsonIOException if there was a problem reading from the JsonReader\n * @throws JsonSyntaxException if json is not a valid representation for an object of type typeOfT\n * @see #fromJson(Reader, Type)\n * @see #fromJson(JsonReader, TypeToken)\n */\n @SuppressWarnings({\"unchecked\", \"TypeParameterUnusedInFormals\"})\n public T fromJson(JsonReader reader, Type typeOfT)\n throws JsonIOException, JsonSyntaxException {\n return (T) fromJson(reader, TypeToken.get(typeOfT));\n }\n\n /**\n * Reads the next JSON value from {@code reader} and converts it to an object of type {@code\n * typeOfT}. Returns {@code null}, if the {@code reader} is at EOF. This method is useful if the\n * specified object is a generic type. For non-generic objects, {@link #fromJson(JsonReader,\n * Type)} can be called, or {@link TypeToken#get(Class)} can be used to create the type token.\n *\n *
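                                      A sketch of the contrast described above (class name and payload are illustrative): the `JsonReader` overloads read one value per call and tolerate further top-level values, which makes them suitable for streams of concatenated JSON documents.

                                      ```java
                                      import com.google.gson.Gson;
                                      import com.google.gson.Strictness;
                                      import com.google.gson.reflect.TypeToken;
                                      import com.google.gson.stream.JsonReader;
                                      import com.google.gson.stream.JsonToken;
                                      import java.io.IOException;
                                      import java.io.StringReader;
                                      import java.util.Map;

                                      public class StreamingFromJsonSketch {
                                        public static void main(String[] args) throws IOException {
                                          Gson gson = new Gson();
                                          JsonReader reader = new JsonReader(new StringReader("{\"id\":1} {\"id\":2}"));
                                          reader.setStrictness(Strictness.LENIENT); // several top-level values in one stream

                                          while (reader.peek() != JsonToken.END_DOCUMENT) {
                                            Map value =
                                                gson.fromJson(reader, new TypeToken>() {});
                                            System.out.println(value.get("id")); // prints 1, then 2
                                          }
                                        }
                                      }
                                      ```
                                      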

    Unlike the other {@code fromJson} methods, no exception is thrown if the JSON data has\n * multiple top-level JSON elements, or if there is trailing data.\n *\n *

                                          If the {@code Gson} instance has an {@linkplain GsonBuilder#setStrictness(Strictness)\n * explicit strictness setting}, this setting will be used for reading the JSON regardless of the\n * {@linkplain JsonReader#getStrictness() strictness} of the provided {@link JsonReader}. For\n * legacy reasons, if the {@code Gson} instance has no explicit strictness setting and the reader\n * does not have the strictness {@link Strictness#STRICT}, the JSON will be read in {@link\n * Strictness#LENIENT} mode.
                                      
    \n * Note that in all cases the old strictness setting of the reader will be restored when this\n * method returns.\n *\n * @param the type of the desired object\n * @param reader the reader whose next JSON value should be deserialized\n * @param typeOfT The specific genericized type of src. You should create an anonymous subclass of\n * {@code TypeToken} with the specific generic type arguments. For example, to get the type\n * for {@code Collection}, you should use:\n *

                                          \n   * new TypeToken>(){}\n   * 
                                      
    \n *\n * @return an object of type T from the JsonReader. Returns {@code null} if {@code reader} is at\n * EOF.\n * @throws JsonIOException if there was a problem reading from the JsonReader\n * @throws JsonSyntaxException if json is not a valid representation for an object of the type\n * typeOfT\n * @see #fromJson(Reader, TypeToken)\n * @see #fromJson(JsonReader, Type)\n * @since 2.10\n */\n public T fromJson(JsonReader reader, TypeToken typeOfT)\n throws JsonIOException, JsonSyntaxException {\n boolean isEmpty = true;\n Strictness oldStrictness = reader.getStrictness();\n\n if (this.strictness != null) {\n reader.setStrictness(this.strictness);\n } else if (reader.getStrictness() != Strictness.STRICT) {\n reader.setStrictness(Strictness.LENIENT);\n }\n\n try {\n JsonToken unused = reader.peek();\n isEmpty = false;\n TypeAdapter typeAdapter = getAdapter(typeOfT);\n return typeAdapter.read(reader);\n } catch (EOFException e) {\n /*\n * For compatibility with JSON 1.5 and earlier, we return null for empty\n * documents instead of throwing.\n */\n if (isEmpty) {\n return null;\n }\n throw new JsonSyntaxException(e);\n } catch (IllegalStateException e) {\n throw new JsonSyntaxException(e);\n } catch (IOException e) {\n // TODO(inder): Figure out whether it is indeed right to rethrow this as JsonSyntaxException\n throw new JsonSyntaxException(e);\n } catch (AssertionError e) {\n throw new AssertionError(\n \"AssertionError (GSON \" + GsonBuildConfig.VERSION + \"): \" + e.getMessage(), e);\n } finally {\n reader.setStrictness(oldStrictness);\n }\n }\n\n /**\n * This method deserializes the JSON read from the specified parse tree into an object of the\n * specified type. It is not suitable to use if the specified class is a generic type since it\n * will not have the generic type information because of the Type Erasure feature of Java.\n * Therefore, this method should not be used if the desired type is a generic type. Note that this\n * method works fine if any of the fields of the specified object are generics, just the object\n * itself should not be a generic type. For the cases when the object is of generic type, invoke\n * {@link #fromJson(JsonElement, TypeToken)}.\n *\n * @param the type of the desired object\n * @param json the root of the parse tree of {@link JsonElement}s from which the object is to be\n * deserialized\n * @param classOfT The class of T\n * @return an object of type T from the JSON. Returns {@code null} if {@code json} is {@code null}\n * or if {@code json} is empty.\n...\n", "needle_token_start": 4029, "needle_token_end": 4050, "code_context_ntokens": 16261, "output": ["```java\n/**\n * Returns a new GsonBuilder containing all custom factories and configuration used by the current\n * instance.\n *\n * @return a GsonBuilder instance.\n * @since 2.8.3\n */\npublic GsonBuilder newBuilder() {\n return new GsonBuilder(this);\n}\n```"]} +{"repo": "google/gson", "name": "getAsJsonNull", "language": "java", "path": "gson/src/main/java/com/google/gson/JsonElement.java", "position_ratio": 0.35, "description": "\nFunction Description:\n1. **Purpose**: To retrieve the current element as a specific null type representation if it is indeed that type, ensuring type safety before casting.\n2. **Input**: No parameters are required as it operates on the current instance of the element.\n3. **Output**: Returns the element as a null type representation if it is valid; otherwise, it throws an exception.\n4. 
**Procedure**: The method first checks if the current element is of the null type using a type-checking method. If the check passes, the element is cast to the null type and returned. If the check fails, an exception is thrown indicating that the element is not of the expected null type.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: gson/src/main/java/com/google/gson/internal/bind/TypeAdapters.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal.bind;\n\nimport com.google.gson.Gson;\nimport com.google.gson.JsonArray;\nimport com.google.gson.JsonElement;\nimport com.google.gson.JsonIOException;\nimport com.google.gson.JsonNull;\nimport com.google.gson.JsonObject;\nimport com.google.gson.JsonPrimitive;\nimport com.google.gson.JsonSyntaxException;\nimport com.google.gson.TypeAdapter;\nimport com.google.gson.TypeAdapterFactory;\nimport com.google.gson.annotations.SerializedName;\nimport com.google.gson.internal.LazilyParsedNumber;\nimport com.google.gson.internal.NumberLimits;\nimport com.google.gson.internal.TroubleshootingGuide;\nimport com.google.gson.reflect.TypeToken;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.lang.reflect.AccessibleObject;\nimport java.lang.reflect.Field;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\nimport java.net.InetAddress;\nimport java.net.URI;\nimport java.net.URISyntaxException;\nimport java.net.URL;\nimport java.security.AccessController;\nimport java.security.PrivilegedAction;\nimport java.util.ArrayDeque;\nimport java.util.ArrayList;\nimport java.util.BitSet;\nimport java.util.Calendar;\nimport java.util.Currency;\nimport java.util.Deque;\nimport java.util.GregorianCalendar;\nimport java.util.HashMap;\nimport java.util.List;\nimport java.util.Locale;\nimport java.util.Map;\nimport java.util.StringTokenizer;\nimport java.util.UUID;\nimport java.util.concurrent.atomic.AtomicBoolean;\nimport java.util.concurrent.atomic.AtomicInteger;\nimport java.util.concurrent.atomic.AtomicIntegerArray;\n\n/** Type adapters for basic types. */\npublic final class TypeAdapters {\n private TypeAdapters() {\n throw new UnsupportedOperationException();\n }\n\n @SuppressWarnings(\"rawtypes\")\n public static final TypeAdapter CLASS =\n new TypeAdapter() {\n @Override\n public void write(JsonWriter out, Class value) throws IOException {\n throw new UnsupportedOperationException(\n \"Attempted to serialize java.lang.Class: \"\n + value.getName()\n + \". 
Forgot to register a type adapter?\"\n + \"\\nSee \"\n + TroubleshootingGuide.createUrl(\"java-lang-class-unsupported\"));\n }\n\n @Override\n public Class read(JsonReader in) throws IOException {\n throw new UnsupportedOperationException(\n \"Attempted to deserialize a java.lang.Class. Forgot to register a type adapter?\"\n + \"\\nSee \"\n + TroubleshootingGuide.createUrl(\"java-lang-class-unsupported\"));\n }\n }.nullSafe();\n\n public static final TypeAdapterFactory CLASS_FACTORY = newFactory(Class.class, CLASS);\n\n public static final TypeAdapter BIT_SET =\n new TypeAdapter() {\n @Override\n public BitSet read(JsonReader in) throws IOException {\n BitSet bitset = new BitSet();\n in.beginArray();\n int i = 0;\n JsonToken tokenType = in.peek();\n while (tokenType != JsonToken.END_ARRAY) {\n boolean set;\n switch (tokenType) {\n case NUMBER:\n case STRING:\n int intValue = in.nextInt();\n if (intValue == 0) {\n set = false;\n } else if (intValue == 1) {\n set = true;\n } else {\n throw new JsonSyntaxException(\n \"Invalid bitset value \"\n + intValue\n + \", expected 0 or 1; at path \"\n + in.getPreviousPath());\n }\n break;\n case BOOLEAN:\n set = in.nextBoolean();\n break;\n default:\n throw new JsonSyntaxException(\n \"Invalid bitset value type: \" + tokenType + \"; at path \" + in.getPath());\n }\n if (set) {\n bitset.set(i);\n }\n ++i;\n tokenType = in.peek();\n }\n in.endArray();\n return bitset;\n }\n\n @Override\n public void write(JsonWriter out, BitSet src) throws IOException {\n out.beginArray();\n for (int i = 0, length = src.length(); i < length; i++) {\n int value = src.get(i) ? 1 : 0;\n out.value(value);\n }\n out.endArray();\n }\n }.nullSafe();\n\n public static final TypeAdapterFactory BIT_SET_FACTORY = newFactory(BitSet.class, BIT_SET);\n\n public static final TypeAdapter BOOLEAN =\n new TypeAdapter() {\n @Override\n public Boolean read(JsonReader in) throws IOException {\n JsonToken peek = in.peek();\n if (peek == JsonToken.NULL) {\n in.nextNull();\n return null;\n } else if (peek == JsonToken.STRING) {\n // support strings for compatibility with GSON 1.7\n return Boolean.parseBoolean(in.nextString());\n }\n return in.nextBoolean();\n }\n\n @Override\n public void write(JsonWriter out, Boolean value) throws IOException {\n out.value(value);\n }\n };\n\n /**\n * Writes a boolean as a string. Useful for map keys, where booleans aren't otherwise permitted.\n */\n public static final TypeAdapter BOOLEAN_AS_STRING =\n new TypeAdapter() {\n @Override\n public Boolean read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n return Boolean.valueOf(in.nextString());\n }\n\n @Override\n public void write(JsonWriter out, Boolean value) throws IOException {\n out.value(value == null ? 
\"null\" : value.toString());\n }\n };\n\n public static final TypeAdapterFactory BOOLEAN_FACTORY =\n newFactory(boolean.class, Boolean.class, BOOLEAN);\n\n public static final TypeAdapter BYTE =\n new TypeAdapter() {\n @Override\n public Number read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n\n int intValue;\n try {\n intValue = in.nextInt();\n } catch (NumberFormatException e) {\n throw new JsonSyntaxException(e);\n }\n // Allow up to 255 to support unsigned values\n if (intValue > 255 || intValue < Byte.MIN_VALUE) {\n throw new JsonSyntaxException(\n \"Lossy conversion from \" + intValue + \" to byte; at path \" + in.getPreviousPath());\n }\n return (byte) intValue;\n }\n\n @Override\n public void write(JsonWriter out, Number value) throws IOException {\n if (value == null) {\n out.nullValue();\n } else {\n out.value(value.byteValue());\n }\n }\n };\n\n public static final TypeAdapterFactory BYTE_FACTORY = newFactory(byte.class, Byte.class, BYTE);\n\n public static final TypeAdapter SHORT =\n new TypeAdapter() {\n @Override\n public Number read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n\n int intValue;\n try {\n intValue = in.nextInt();\n } catch (NumberFormatException e) {\n throw new JsonSyntaxException(e);\n }\n // Allow up to 65535 to support unsigned values\n...\n// Path: gson/src/main/java/com/google/gson/stream/MalformedJsonException.java\n/*\n * Copyright (C) 2010 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.stream;\n\nimport com.google.gson.Strictness;\nimport java.io.IOException;\n\n/**\n * Thrown when a reader encounters malformed JSON. 
Some syntax errors can be ignored by using {@link\n * Strictness#LENIENT} for {@link JsonReader#setStrictness(Strictness)}.\n */\npublic final class MalformedJsonException extends IOException {\n private static final long serialVersionUID = 1L;\n\n public MalformedJsonException(String msg) {\n super(msg);\n }\n\n public MalformedJsonException(String msg, Throwable throwable) {\n super(msg, throwable);\n }\n\n public MalformedJsonException(Throwable throwable) {\n super(throwable);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/Streams.java\n/*\n * Copyright (C) 2010 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport com.google.gson.JsonElement;\nimport com.google.gson.JsonIOException;\nimport com.google.gson.JsonNull;\nimport com.google.gson.JsonParseException;\nimport com.google.gson.JsonSyntaxException;\nimport com.google.gson.internal.bind.TypeAdapters;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.JsonWriter;\nimport com.google.gson.stream.MalformedJsonException;\nimport java.io.EOFException;\nimport java.io.IOException;\nimport java.io.Writer;\nimport java.util.Objects;\n\n/** Reads and writes GSON parse trees over streams. */\npublic final class Streams {\n private Streams() {\n throw new UnsupportedOperationException();\n }\n\n /** Takes a reader in any state and returns the next value as a JsonElement. */\n public static JsonElement parse(JsonReader reader) throws JsonParseException {\n boolean isEmpty = true;\n try {\n JsonToken unused = reader.peek();\n isEmpty = false;\n return TypeAdapters.JSON_ELEMENT.read(reader);\n } catch (EOFException e) {\n /*\n * For compatibility with JSON 1.5 and earlier, we return a JsonNull for\n * empty documents instead of throwing.\n */\n if (isEmpty) {\n return JsonNull.INSTANCE;\n }\n // The stream ended prematurely so it is likely a syntax error.\n throw new JsonSyntaxException(e);\n } catch (MalformedJsonException e) {\n throw new JsonSyntaxException(e);\n } catch (IOException e) {\n throw new JsonIOException(e);\n } catch (NumberFormatException e) {\n throw new JsonSyntaxException(e);\n }\n }\n\n /** Writes the JSON element to the writer, recursively. */\n public static void write(JsonElement element, JsonWriter writer) throws IOException {\n TypeAdapters.JSON_ELEMENT.write(writer, element);\n }\n\n public static Writer writerForAppendable(Appendable appendable) {\n return appendable instanceof Writer ? (Writer) appendable : new AppendableWriter(appendable);\n }\n\n /** Adapts an {@link Appendable} so it can be passed anywhere a {@link Writer} is used. 
*/\n private static final class AppendableWriter extends Writer {\n private final Appendable appendable;\n private final CurrentWrite currentWrite = new CurrentWrite();\n\n AppendableWriter(Appendable appendable) {\n this.appendable = appendable;\n }\n\n @SuppressWarnings(\"UngroupedOverloads\") // this is intentionally ungrouped, see comment below\n @Override\n public void write(char[] chars, int offset, int length) throws IOException {\n currentWrite.setChars(chars);\n appendable.append(currentWrite, offset, offset + length);\n }\n\n @Override\n public void flush() {}\n\n @Override\n public void close() {}\n\n // Override these methods for better performance\n // They would otherwise unnecessarily create Strings or char arrays\n\n @Override\n public void write(int i) throws IOException {\n appendable.append((char) i);\n }\n\n @Override\n public void write(String str, int off, int len) throws IOException {\n // Appendable.append turns null -> \"null\", which is not desired here\n Objects.requireNonNull(str);\n appendable.append(str, off, off + len);\n }\n\n @Override\n public Writer append(CharSequence csq) throws IOException {\n appendable.append(csq);\n return this;\n }\n\n @Override\n public Writer append(CharSequence csq, int start, int end) throws IOException {\n appendable.append(csq, start, end);\n return this;\n }\n\n /** A mutable char sequence pointing at a single char[]. */\n private static class CurrentWrite implements CharSequence {\n private char[] chars;\n private String cachedString;\n\n void setChars(char[] chars) {\n this.chars = chars;\n this.cachedString = null;\n }\n\n @Override\n public int length() {\n return chars.length;\n }\n\n @Override\n public char charAt(int i) {\n return chars[i];\n }\n\n @Override\n public CharSequence subSequence(int start, int end) {\n return new String(chars, start, end - start);\n }\n\n // Must return string representation to satisfy toString() contract\n @Override\n public String toString() {\n if (cachedString == null) {\n cachedString = new String(chars);\n }\n return cachedString;\n }\n }\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonElement.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.internal.Streams;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.io.StringWriter;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\n\n/**\n * A class representing an element of JSON. It could either be a {@link JsonObject}, a {@link\n * JsonArray}, a {@link JsonPrimitive} or a {@link JsonNull}.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic abstract class JsonElement {\n /**\n * @deprecated Creating custom {@code JsonElement} subclasses is highly discouraged and can lead\n * to undefined behavior.
    \n * This constructor is only kept for backward compatibility.\n */\n @Deprecated\n public JsonElement() {}\n\n /**\n * Returns a deep copy of this element. Immutable elements like primitives and nulls are not\n * copied.\n *\n * @since 2.8.2\n */\n public abstract JsonElement deepCopy();\n\n /**\n * Provides a check for verifying if this element is a JSON array or not.\n *\n * @return true if this element is of type {@link JsonArray}, false otherwise.\n */\n public boolean isJsonArray() {\n return this instanceof JsonArray;\n }\n\n /**\n * Provides a check for verifying if this element is a JSON object or not.\n *\n * @return true if this element is of type {@link JsonObject}, false otherwise.\n */\n public boolean isJsonObject() {\n return this instanceof JsonObject;\n }\n\n /**\n * Provides a check for verifying if this element is a primitive or not.\n *\n * @return true if this element is of type {@link JsonPrimitive}, false otherwise.\n */\n public boolean isJsonPrimitive() {\n return this instanceof JsonPrimitive;\n }\n\n /**\n * Provides a check for verifying if this element represents a null value or not.\n *\n * @return true if this element is of type {@link JsonNull}, false otherwise.\n * @since 1.2\n */\n public boolean isJsonNull() {\n return this instanceof JsonNull;\n }\n\n /**\n * Convenience method to get this element as a {@link JsonObject}. If this element is of some\n * other type, an {@link IllegalStateException} will result. Hence it is best to use this method\n * after ensuring that this element is of the desired type by calling {@link #isJsonObject()}\n * first.\n *\n * @return this element as a {@link JsonObject}.\n * @throws IllegalStateException if this element is of another type.\n */\n public JsonObject getAsJsonObject() {\n if (isJsonObject()) {\n return (JsonObject) this;\n }\n throw new IllegalStateException(\"Not a JSON Object: \" + this);\n }\n\n /**\n * Convenience method to get this element as a {@link JsonArray}. If this element is of some other\n * type, an {@link IllegalStateException} will result. Hence it is best to use this method after\n * ensuring that this element is of the desired type by calling {@link #isJsonArray()} first.\n *\n * @return this element as a {@link JsonArray}.\n * @throws IllegalStateException if this element is of another type.\n */\n public JsonArray getAsJsonArray() {\n if (isJsonArray()) {\n return (JsonArray) this;\n }\n throw new IllegalStateException(\"Not a JSON Array: \" + this);\n }\n\n /**\n * Convenience method to get this element as a {@link JsonPrimitive}. If this element is of some\n * other type, an {@link IllegalStateException} will result. Hence it is best to use this method\n * after ensuring that this element is of the desired type by calling {@link #isJsonPrimitive()}\n * first.\n *\n * @return this element as a {@link JsonPrimitive}.\n * @throws IllegalStateException if this element is of another type.\n */\n public JsonPrimitive getAsJsonPrimitive() {\n if (isJsonPrimitive()) {\n return (JsonPrimitive) this;\n }\n throw new IllegalStateException(\"Not a JSON Primitive: \" + this);\n }\n\n /**\n * Convenience method to get this element as a {@link JsonNull}. If this element is of some other\n * type, an {@link IllegalStateException} will result. 
Hence it is best to use this method after\n * ensuring that this element is of the desired type by calling {@link #isJsonNull()} first.\n *\n * @return this element as a {@link JsonNull}.\n * @throws IllegalStateException if this element is of another type.\n * @since 1.2\n */\n \n@CanIgnoreReturnValue // When this method is used only to verify that the value is JsonNull\n public JsonNull getAsJsonNull() {\n if (isJsonNull()) {\n return (JsonNull) this;\n }\n throw new IllegalStateException(\"Not a JSON Null: \" + this);\n }\n\n /**\n * Convenience method to get this element as a boolean value.\n *\n * @return this element as a primitive boolean value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public boolean getAsBoolean() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a {@link Number}.\n *\n * @return this element as a {@link Number}.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}, or cannot be converted to a number.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public Number getAsNumber() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a string value.\n *\n * @return this element as a string value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public String getAsString() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a primitive double value.\n *\n * @return this element as a primitive double value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if the value contained is not a valid double.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public double getAsDouble() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a primitive float value.\n *\n * @return this element as a primitive float value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if the value contained is not a valid float.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public float getAsFloat() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a primitive long value.\n *\n * @return this element as a primitive long value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if the value contained is not a valid long.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} 
but contains\n * more than a single element.\n */\n public long getAsLong() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a primitive integer value.\n *\n * @return this element as a primitive integer value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if the value contained is not a valid integer.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public int getAsInt() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a primitive byte value.\n *\n * @return this element as a primitive byte value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if the value contained is not a valid byte.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n * @since 1.3\n */\n public byte getAsByte() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get the first character of the string value of this element.\n *\n * @return the first character of the string value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}, or if its string value is empty.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n * @since 1.3\n * @deprecated This method is misleading, as it does not get this element as a char but rather as\n * a string's first character.\n */\n @Deprecated\n public char getAsCharacter() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a {@link BigDecimal}.\n *\n * @return this element as a {@link BigDecimal}.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if this element is not a valid {@link BigDecimal}.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n * @since 1.2\n */\n public BigDecimal getAsBigDecimal() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a {@link BigInteger}.\n *\n * @return this element as a {@link BigInteger}.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if this element is not a valid {@link BigInteger}.\n * @throws IllegalStateException if this element is of the type {@link JsonArray} but contains\n * more than a single element.\n * @since 1.2\n */\n public BigInteger getAsBigInteger() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /**\n * Convenience method to get this element as a primitive short value.\n *\n * @return this element as a primitive short value.\n * @throws UnsupportedOperationException if this element is not a {@link JsonPrimitive} or {@link\n * JsonArray}.\n * @throws NumberFormatException if the value contained is not a valid short.\n * @throws IllegalStateException if 
this element is of the type {@link JsonArray} but contains\n * more than a single element.\n */\n public short getAsShort() {\n throw new UnsupportedOperationException(getClass().getSimpleName());\n }\n\n /** Returns a String representation of this element. */\n @Override\n public String toString() {\n try {\n StringWriter stringWriter = new StringWriter();\n JsonWriter jsonWriter = new JsonWriter(stringWriter);\n // Make writer lenient because toString() must not fail, even if for example JsonPrimitive\n // contains NaN\n jsonWriter.setStrictness(Strictness.LENIENT);\n Streams.write(this, jsonWriter);\n return stringWriter.toString();\n } catch (IOException e) {\n throw new AssertionError(e);\n }\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/bind/JsonTreeReader.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal.bind;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.JsonArray;\nimport com.google.gson.JsonElement;\nimport com.google.gson.JsonNull;\nimport com.google.gson.JsonObject;\nimport com.google.gson.JsonPrimitive;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.MalformedJsonException;\nimport java.io.IOException;\nimport java.io.Reader;\nimport java.util.Arrays;\nimport java.util.Iterator;\nimport java.util.Map;\n\n/**\n * This reader walks the elements of a JsonElement as if it was coming from a character stream.\n *\n * @author Jesse Wilson\n */\npublic final class JsonTreeReader extends JsonReader {\n private static final Reader UNREADABLE_READER =\n new Reader() {\n @Override\n public int read(char[] buffer, int offset, int count) {\n throw new AssertionError();\n }\n\n @Override\n public void close() {\n throw new AssertionError();\n }\n };\n private static final Object SENTINEL_CLOSED = new Object();\n\n /*\n * The nesting stack. Using a manual array rather than an ArrayList saves 20%.\n */\n private Object[] stack = new Object[32];\n private int stackSize = 0;\n\n /*\n * The path members. It corresponds directly to stack: At indices where the\n * stack contains an object (EMPTY_OBJECT, DANGLING_NAME or NONEMPTY_OBJECT),\n * pathNames contains the name at this scope. Where it contains an array\n * (EMPTY_ARRAY, NONEMPTY_ARRAY) pathIndices contains the current index in\n * that array. 
Otherwise the value is undefined, and we take advantage of that\n * by incrementing pathIndices when doing so isn't useful.\n */\n private String[] pathNames = new String[32];\n private int[] pathIndices = new int[32];\n\n public JsonTreeReader(JsonElement element) {\n super(UNREADABLE_READER);\n push(element);\n }\n\n @Override\n public void beginArray() throws IOException {\n expect(JsonToken.BEGIN_ARRAY);\n JsonArray array = (JsonArray) peekStack();\n push(array.iterator());\n pathIndices[stackSize - 1] = 0;\n }\n\n @Override\n public void endArray() throws IOException {\n expect(JsonToken.END_ARRAY);\n popStack(); // empty iterator\n popStack(); // array\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n }\n\n @Override\n public void beginObject() throws IOException {\n expect(JsonToken.BEGIN_OBJECT);\n JsonObject object = (JsonObject) peekStack();\n push(object.entrySet().iterator());\n }\n\n @Override\n public void endObject() throws IOException {\n expect(JsonToken.END_OBJECT);\n pathNames[stackSize - 1] = null; // Free the last path name so that it can be garbage collected\n popStack(); // empty iterator\n popStack(); // object\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n }\n\n @Override\n public boolean hasNext() throws IOException {\n JsonToken token = peek();\n return token != JsonToken.END_OBJECT\n && token != JsonToken.END_ARRAY\n && token != JsonToken.END_DOCUMENT;\n }\n\n @Override\n public JsonToken peek() throws IOException {\n if (stackSize == 0) {\n return JsonToken.END_DOCUMENT;\n }\n\n Object o = peekStack();\n if (o instanceof Iterator) {\n boolean isObject = stack[stackSize - 2] instanceof JsonObject;\n Iterator iterator = (Iterator) o;\n if (iterator.hasNext()) {\n if (isObject) {\n return JsonToken.NAME;\n } else {\n push(iterator.next());\n return peek();\n }\n } else {\n return isObject ? JsonToken.END_OBJECT : JsonToken.END_ARRAY;\n }\n } else if (o instanceof JsonObject) {\n return JsonToken.BEGIN_OBJECT;\n } else if (o instanceof JsonArray) {\n return JsonToken.BEGIN_ARRAY;\n } else if (o instanceof JsonPrimitive) {\n JsonPrimitive primitive = (JsonPrimitive) o;\n if (primitive.isString()) {\n return JsonToken.STRING;\n } else if (primitive.isBoolean()) {\n return JsonToken.BOOLEAN;\n } else if (primitive.isNumber()) {\n return JsonToken.NUMBER;\n } else {\n throw new AssertionError();\n }\n } else if (o instanceof JsonNull) {\n return JsonToken.NULL;\n } else if (o == SENTINEL_CLOSED) {\n throw new IllegalStateException(\"JsonReader is closed\");\n } else {\n throw new MalformedJsonException(\n \"Custom JsonElement subclass \" + o.getClass().getName() + \" is not supported\");\n }\n }\n\n private Object peekStack() {\n return stack[stackSize - 1];\n }\n\n @CanIgnoreReturnValue\n private Object popStack() {\n Object result = stack[--stackSize];\n stack[stackSize] = null;\n return result;\n }\n\n private void expect(JsonToken expected) throws IOException {\n if (peek() != expected) {\n throw new IllegalStateException(\n \"Expected \" + expected + \" but was \" + peek() + locationString());\n }\n }\n\n private String nextName(boolean skipName) throws IOException {\n expect(JsonToken.NAME);\n Iterator i = (Iterator) peekStack();\n Map.Entry entry = (Map.Entry) i.next();\n String result = (String) entry.getKey();\n pathNames[stackSize - 1] = skipName ? 
\"\" : result;\n push(entry.getValue());\n return result;\n }\n\n @Override\n public String nextName() throws IOException {\n return nextName(false);\n }\n\n @Override\n public String nextString() throws IOException {\n JsonToken token = peek();\n if (token != JsonToken.STRING && token != JsonToken.NUMBER) {\n throw new IllegalStateException(\n \"Expected \" + JsonToken.STRING + \" but was \" + token + locationString());\n }\n String result = ((JsonPrimitive) popStack()).getAsString();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n return result;\n }\n\n @Override\n public boolean nextBoolean() throws IOException {\n expect(JsonToken.BOOLEAN);\n boolean result = ((JsonPrimitive) popStack()).getAsBoolean();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n return result;\n }\n\n @Override\n public void nextNull() throws IOException {\n expect(JsonToken.NULL);\n popStack();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n }\n\n @Override\n public double nextDouble() throws IOException {\n JsonToken token = peek();\n if (token != JsonToken.NUMBER && token != JsonToken.STRING) {\n throw new IllegalStateException(\n \"Expected \" + JsonToken.NUMBER + \" but was \" + token + locationString());\n }\n double result = ((JsonPrimitive) peekStack()).getAsDouble();\n if (!isLenient() && (Double.isNaN(result) || Double.isInfinite(result))) {\n throw new MalformedJsonException(\"JSON forbids NaN and infinities: \" + result);\n }\n popStack();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n return result;\n }\n\n @Override\n public long nextLong() throws IOException {\n JsonToken token = peek();\n if (token != JsonToken.NUMBER && token != JsonToken.STRING) {\n throw new IllegalStateException(\n \"Expected \" + JsonToken.NUMBER + \" but was \" + token + locationString());\n }\n long result = ((JsonPrimitive) peekStack()).getAsLong();\n popStack();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n return result;\n }\n\n @Override\n public int nextInt() throws IOException {\n JsonToken token = peek();\n if (token != JsonToken.NUMBER && token != JsonToken.STRING) {\n throw new IllegalStateException(\n \"Expected \" + JsonToken.NUMBER + \" but was \" + token + locationString());\n }\n int result = ((JsonPrimitive) peekStack()).getAsInt();\n popStack();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n return result;\n }\n\n JsonElement nextJsonElement() throws IOException {\n final JsonToken peeked = peek();\n if (peeked == JsonToken.NAME\n || peeked == JsonToken.END_ARRAY\n || peeked == JsonToken.END_OBJECT\n || peeked == JsonToken.END_DOCUMENT) {\n throw new IllegalStateException(\"Unexpected \" + peeked + \" when reading a JsonElement.\");\n }\n final JsonElement element = (JsonElement) peekStack();\n skipValue();\n return element;\n }\n\n @Override\n public void close() throws IOException {\n stack = new Object[] {SENTINEL_CLOSED};\n stackSize = 1;\n }\n\n @Override\n public void skipValue() throws IOException {\n JsonToken peeked = peek();\n switch (peeked) {\n case NAME:\n @SuppressWarnings(\"unused\")\n String unused = nextName(true);\n break;\n case END_ARRAY:\n endArray();\n break;\n case END_OBJECT:\n endObject();\n break;\n case END_DOCUMENT:\n // Do nothing\n break;\n default:\n popStack();\n if (stackSize > 0) {\n pathIndices[stackSize - 1]++;\n }\n break;\n }\n }\n\n @Override\n public String toString() {\n return getClass().getSimpleName() + locationString();\n }\n\n public void promoteNameToValue() throws IOException {\n 
expect(JsonToken.NAME);\n Iterator i = (Iterator) peekStack();\n Map.Entry entry = (Map.Entry) i.next();\n push(entry.getValue());\n push(new JsonPrimitive((String) entry.getKey()));\n }\n\n private void push(Object newTop) {\n if (stackSize == stack.length) {\n int newLength = stackSize * 2;\n stack = Arrays.copyOf(stack, newLength);\n pathIndices = Arrays.copyOf(pathIndices, newLength);\n pathNames = Arrays.copyOf(pathNames, newLength);\n }\n stack[stackSize++] = newTop;\n }\n\n private String getPath(boolean usePreviousPath) {\n StringBuilder result = new StringBuilder().append('$');\n for (int i = 0; i < stackSize; i++) {\n if (stack[i] instanceof JsonArray) {\n if (++i < stackSize && stack[i] instanceof Iterator) {\n int pathIndex = pathIndices[i];\n // If index is last path element it points to next array element; have to decrement\n // `- 1` covers case where iterator for next element is on stack\n // `- 2` covers case where peek() already pushed next element onto stack\n if (usePreviousPath && pathIndex > 0 && (i == stackSize - 1 || i == stackSize - 2)) {\n pathIndex--;\n }\n result.append('[').append(pathIndex).append(']');\n }\n } else if (stack[i] instanceof JsonObject) {\n if (++i < stackSize && stack[i] instanceof Iterator) {\n result.append('.');\n if (pathNames[i] != null) {\n result.append(pathNames[i]);\n }\n }\n }\n }\n return result.toString();\n }\n\n @Override\n public String getPath() {\n return getPath(false);\n }\n\n @Override\n public String getPreviousPath() {\n return getPath(true);\n }\n\n private String locationString() {\n return \" at path \" + getPath();\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/bind/JsonTreeWriter.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal.bind;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.JsonArray;\nimport com.google.gson.JsonElement;\nimport com.google.gson.JsonNull;\nimport com.google.gson.JsonObject;\nimport com.google.gson.JsonPrimitive;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.io.Writer;\nimport java.util.ArrayList;\nimport java.util.List;\nimport java.util.Objects;\n\n/** This writer creates a JsonElement. */\npublic final class JsonTreeWriter extends JsonWriter {\n private static final Writer UNWRITABLE_WRITER =\n new Writer() {\n @Override\n public void write(char[] buffer, int offset, int counter) {\n throw new AssertionError();\n }\n\n @Override\n public void flush() {\n throw new AssertionError();\n }\n\n @Override\n public void close() {\n throw new AssertionError();\n }\n };\n\n /** Added to the top of the stack when this writer is closed to cause following ops to fail. */\n private static final JsonPrimitive SENTINEL_CLOSED = new JsonPrimitive(\"closed\");\n\n /** The JsonElements and JsonArrays under modification, outermost to innermost. 
*/\n private final List stack = new ArrayList<>();\n\n /** The name for the next JSON object value. If non-null, the top of the stack is a JsonObject. */\n private String pendingName;\n\n /** the JSON element constructed by this writer. */\n private JsonElement product = JsonNull.INSTANCE; // TODO: is this really what we want?;\n\n public JsonTreeWriter() {\n super(UNWRITABLE_WRITER);\n }\n\n /** Returns the top level object produced by this writer. */\n public JsonElement get() {\n if (!stack.isEmpty()) {\n throw new IllegalStateException(\"Expected one JSON element but was \" + stack);\n }\n return product;\n }\n\n private JsonElement peek() {\n return stack.get(stack.size() - 1);\n }\n\n private void put(JsonElement value) {\n if (pendingName != null) {\n if (!value.isJsonNull() || getSerializeNulls()) {\n JsonObject object = (JsonObject) peek();\n object.add(pendingName, value);\n }\n pendingName = null;\n } else if (stack.isEmpty()) {\n product = value;\n } else {\n JsonElement element = peek();\n if (element instanceof JsonArray) {\n ((JsonArray) element).add(value);\n } else {\n throw new IllegalStateException();\n }\n }\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter beginArray() throws IOException {\n JsonArray array = new JsonArray();\n put(array);\n stack.add(array);\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter endArray() throws IOException {\n if (stack.isEmpty() || pendingName != null) {\n throw new IllegalStateException();\n }\n JsonElement element = peek();\n if (element instanceof JsonArray) {\n stack.remove(stack.size() - 1);\n return this;\n }\n throw new IllegalStateException();\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter beginObject() throws IOException {\n JsonObject object = new JsonObject();\n put(object);\n stack.add(object);\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter endObject() throws IOException {\n if (stack.isEmpty() || pendingName != null) {\n throw new IllegalStateException();\n }\n JsonElement element = peek();\n if (element instanceof JsonObject) {\n stack.remove(stack.size() - 1);\n return this;\n }\n throw new IllegalStateException();\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter name(String name) throws IOException {\n Objects.requireNonNull(name, \"name == null\");\n if (stack.isEmpty() || pendingName != null) {\n throw new IllegalStateException(\"Did not expect a name\");\n }\n JsonElement element = peek();\n if (element instanceof JsonObject) {\n pendingName = name;\n return this;\n }\n throw new IllegalStateException(\"Please begin an object before writing a name.\");\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter value(String value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n put(new JsonPrimitive(value));\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter value(boolean value) throws IOException {\n put(new JsonPrimitive(value));\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter value(Boolean value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n put(new JsonPrimitive(value));\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter value(float value) throws IOException {\n if (!isLenient() && (Float.isNaN(value) || Float.isInfinite(value))) {\n throw new IllegalArgumentException(\"JSON forbids NaN and infinities: \" + value);\n }\n put(new JsonPrimitive(value));\n return this;\n }\n\n 
@CanIgnoreReturnValue\n @Override\n public JsonWriter value(double value) throws IOException {\n if (!isLenient() && (Double.isNaN(value) || Double.isInfinite(value))) {\n throw new IllegalArgumentException(\"JSON forbids NaN and infinities: \" + value);\n }\n put(new JsonPrimitive(value));\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter value(long value) throws IOException {\n put(new JsonPrimitive(value));\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter value(Number value) throws IOException {\n if (value == null) {\n return nullValue();\n }\n\n if (!isLenient()) {\n double d = value.doubleValue();\n if (Double.isNaN(d) || Double.isInfinite(d)) {\n throw new IllegalArgumentException(\"JSON forbids NaN and infinities: \" + value);\n }\n }\n\n put(new JsonPrimitive(value));\n return this;\n }\n\n @CanIgnoreReturnValue\n @Override\n public JsonWriter nullValue() throws IOException {\n put(JsonNull.INSTANCE);\n return this;\n }\n\n @Override\n public JsonWriter jsonValue(String value) throws IOException {\n throw new UnsupportedOperationException();\n }\n\n @Override\n public void flush() throws IOException {}\n\n @Override\n public void close() throws IOException {\n if (!stack.isEmpty()) {\n throw new IOException(\"Incomplete document\");\n }\n stack.add(SENTINEL_CLOSED);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/TypeAdapter.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.gson.internal.bind.JsonTreeReader;\nimport com.google.gson.internal.bind.JsonTreeWriter;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.io.Reader;\nimport java.io.StringReader;\nimport java.io.StringWriter;\nimport java.io.Writer;\n\n/**\n * Converts Java objects to and from JSON.\n *\n *

          

          Defining a type's JSON form

          
    \n *\n * By default Gson converts application classes to JSON using its built-in type adapters. If Gson's\n * default JSON conversion isn't appropriate for a type, extend this class to customize the\n * conversion. Here's an example of a type adapter for an (X,Y) coordinate point:\n *\n *
    {@code\n * public class PointAdapter extends TypeAdapter {\n *   public Point read(JsonReader reader) throws IOException {\n *     if (reader.peek() == JsonToken.NULL) {\n *       reader.nextNull();\n *       return null;\n *     }\n *     String xy = reader.nextString();\n *     String[] parts = xy.split(\",\");\n *     int x = Integer.parseInt(parts[0]);\n *     int y = Integer.parseInt(parts[1]);\n *     return new Point(x, y);\n *   }\n *   public void write(JsonWriter writer, Point value) throws IOException {\n *     if (value == null) {\n *       writer.nullValue();\n *       return;\n *     }\n *     String xy = value.getX() + \",\" + value.getY();\n *     writer.value(xy);\n *   }\n * }\n * }
    \n *\n * With this type adapter installed, Gson will convert {@code Points} to JSON as strings like {@code\n * \"5,8\"} rather than objects like {@code {\"x\":5,\"y\":8}}. In this case the type adapter binds a rich\n * Java class to a compact JSON value.\n *\n *
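 A minimal usage sketch for an adapter of this kind, assuming the illustrative Point and PointAdapter classes from the Javadoc example above (they are not defined elsewhere in this file): ```java // Point and PointAdapter are the hypothetical classes from the Javadoc example above. Gson gson = new GsonBuilder() .registerTypeAdapter(Point.class, new PointAdapter()) .create(); String json = gson.toJson(new Point(5, 8)); // "\"5,8\"" rather than {"x":5,"y":8} Point roundTripped = gson.fromJson(json, Point.class); ``` 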

    The {@link #read(JsonReader) read()} method must read exactly one value and {@link\n * #write(JsonWriter,Object) write()} must write exactly one value. For primitive types this is\n * means readers should make exactly one call to {@code nextBoolean()}, {@code nextDouble()}, {@code\n * nextInt()}, {@code nextLong()}, {@code nextString()} or {@code nextNull()}. Writers should make\n * exactly one call to one of {@code value()} or {@code nullValue()}. For arrays, type adapters\n * should start with a call to {@code beginArray()}, convert all elements, and finish with a call to\n * {@code endArray()}. For objects, they should start with {@code beginObject()}, convert the\n * object, and finish with {@code endObject()}. Failing to convert a value or converting too many\n * values may cause the application to crash.\n *\n *
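 As an illustration of the "exactly one value" contract for containers, a hedged sketch of an adapter for a JSON array of strings, bracketing the per-element calls with beginArray()/endArray(); this adapter is an assumption for illustration, not part of the Gson sources quoted here: ```java import com.google.gson.TypeAdapter; import com.google.gson.stream.JsonReader; import com.google.gson.stream.JsonWriter; import java.io.IOException; import java.util.ArrayList; import java.util.List; // Illustrative only: converts List to and from a JSON array. public class StringListAdapter extends TypeAdapter> { @Override public void write(JsonWriter out, List value) throws IOException { // A production adapter should also handle null (or be wrapped with nullSafe()). out.beginArray(); // start exactly one JSON array for (String s : value) { out.value(s); // exactly one value() call per element } out.endArray(); // close the same array } @Override public List read(JsonReader in) throws IOException { List result = new ArrayList<>(); in.beginArray(); while (in.hasNext()) { result.add(in.nextString()); } in.endArray(); return result; } } ``` 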

    Type adapters should be prepared to read null from the stream and write it to the stream.\n * Alternatively, they should use {@link #nullSafe()} method while registering the type adapter with\n * Gson. If your {@code Gson} instance has been configured to {@link GsonBuilder#serializeNulls()},\n * these nulls will be written to the final document. Otherwise the value (and the corresponding\n * name when writing to a JSON object) will be omitted automatically. In either case your type\n * adapter must handle null.\n *\n *
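 A hedged sketch of the registration-time null handling mentioned here, again using the illustrative PointAdapter from the example above: ```java // Let Gson handle nulls by wrapping the adapter when registering it. Gson gson = new GsonBuilder() .serializeNulls() // also write JSON nulls for null members instead of omitting them .registerTypeAdapter(Point.class, new PointAdapter().nullSafe()) .create(); String json = gson.toJson((Point) null); // "null" Point p = gson.fromJson("null", Point.class); // null; PointAdapter.read() is never invoked ``` 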

    Type adapters should be stateless and thread-safe, otherwise the thread-safety guarantees of\n * {@link Gson} might not apply.\n *\n *

    To use a custom type adapter with Gson, you must register it with a {@link\n * GsonBuilder}:\n *\n *

    {@code\n * GsonBuilder builder = new GsonBuilder();\n * builder.registerTypeAdapter(Point.class, new PointAdapter());\n * // if PointAdapter didn't check for nulls in its read/write methods, you should instead use\n * // builder.registerTypeAdapter(Point.class, new PointAdapter().nullSafe());\n * ...\n * Gson gson = builder.create();\n * }
    \n *\n * @since 2.1\n */\n// non-Javadoc:\n//\n//


          JSON Conversion

          \n// 

    A type adapter registered with Gson is automatically invoked while serializing\n// or deserializing JSON. However, you can also use type adapters directly to serialize\n// and deserialize JSON. Here is an example for deserialization:

    {@code\n//   String json = \"{'origin':'0,0','points':['1,2','3,4']}\";\n//   TypeAdapter graphAdapter = gson.getAdapter(Graph.class);\n//   Graph graph = graphAdapter.fromJson(json);\n// }
    \n// And an example for serialization:
    {@code\n//   Graph graph = new Graph(...);\n//   TypeAdapter graphAdapter = gson.getAdapter(Graph.class);\n//   String json = graphAdapter.toJson(graph);\n// }
    \n//\n//
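 A hedged sketch of this direct use of an adapter, assuming a Graph class like the one in the commented example (hypothetical); the thrown/non-thrown exceptions follow the fromJson(String)/toJson(T) signatures shown later in this file: ```java static Graph parseGraph(Gson gson, String json) throws IOException { TypeAdapter adapter = gson.getAdapter(Graph.class); return adapter.fromJson(json); // read directly, bypassing Gson#fromJson } static String writeGraph(Gson gson, Graph graph) { return gson.getAdapter(Graph.class).toJson(graph); // toJson(T) wraps IOExceptions itself } ``` 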

    Type adapters are type-specific. For example, a {@code\n// TypeAdapter} can convert {@code Date} instances to JSON and JSON to\n// instances of {@code Date}, but cannot convert any other types.\n//\npublic abstract class TypeAdapter {\n\n public TypeAdapter() {}\n\n /**\n * Writes one JSON value (an array, object, string, number, boolean or null) for {@code value}.\n *\n * @param value the Java object to write. May be null.\n */\n public abstract void write(JsonWriter out, T value) throws IOException;\n\n /**\n * Converts {@code value} to a JSON document and writes it to {@code out}.\n *\n *

    A {@link JsonWriter} with default configuration is used for writing the JSON data. To\n * customize this behavior, create a {@link JsonWriter}, configure it and then use {@link\n * #write(JsonWriter, Object)} instead.\n *\n * @param value the Java object to convert. May be {@code null}.\n * @since 2.2\n */\n public final void toJson(Writer out, T value) throws IOException {\n JsonWriter writer = new JsonWriter(out);\n write(writer, value);\n }\n\n /**\n * Converts {@code value} to a JSON document.\n *\n *

    A {@link JsonWriter} with default configuration is used for writing the JSON data. To\n * customize this behavior, create a {@link JsonWriter}, configure it and then use {@link\n * #write(JsonWriter, Object)} instead.\n *\n * @throws JsonIOException wrapping {@code IOException}s thrown by {@link #write(JsonWriter,\n * Object)}\n * @param value the Java object to convert. May be {@code null}.\n * @since 2.2\n */\n public final String toJson(T value) {\n StringWriter stringWriter = new StringWriter();\n try {\n toJson(stringWriter, value);\n } catch (IOException e) {\n throw new JsonIOException(e);\n }\n return stringWriter.toString();\n }\n\n /**\n * Converts {@code value} to a JSON tree.\n *\n * @param value the Java object to convert. May be {@code null}.\n * @return the converted JSON tree. May be {@link JsonNull}.\n * @throws JsonIOException wrapping {@code IOException}s thrown by {@link #write(JsonWriter,\n * Object)}\n * @since 2.2\n */\n public final JsonElement toJsonTree(T value) {\n try {\n JsonTreeWriter jsonWriter = new JsonTreeWriter();\n write(jsonWriter, value);\n return jsonWriter.get();\n } catch (IOException e) {\n throw new JsonIOException(e);\n }\n }\n\n /**\n * Reads one JSON value (an array, object, string, number, boolean or null) and converts it to a\n * Java object. Returns the converted object.\n *\n * @return the converted Java object. May be {@code null}.\n */\n public abstract T read(JsonReader in) throws IOException;\n\n /**\n * Converts the JSON document in {@code in} to a Java object.\n *\n *
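 A hedged sketch of the customization path described above: configure a JsonWriter yourself and call write(JsonWriter, Object) instead of toJson(...); it relies only on the StringWriter/JsonWriter types already imported in this file: ```java static String toPrettyJson(TypeAdapter adapter, T value) throws IOException { StringWriter out = new StringWriter(); JsonWriter writer = new JsonWriter(out); writer.setIndent(" "); // non-default configuration: pretty printing writer.setSerializeNulls(false); adapter.write(writer, value); // use write(JsonWriter, Object) directly return out.toString(); } ``` 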

    A {@link JsonReader} with default configuration (that is with {@link\n * Strictness#LEGACY_STRICT} as strictness) is used for reading the JSON data. To customize this\n * behavior, create a {@link JsonReader}, configure it and then use {@link #read(JsonReader)}\n * instead.\n *\n *

    No exception is thrown if the JSON data has multiple top-level JSON elements, or if there is\n * trailing data.\n *\n * @return the converted Java object. May be {@code null}.\n * @since 2.2\n */\n public final T fromJson(Reader in) throws IOException {\n JsonReader reader = new JsonReader(in);\n return read(reader);\n }\n\n /**\n * Converts the JSON document in {@code json} to a Java object.\n *\n *

    A {@link JsonReader} with default configuration (that is with {@link\n * Strictness#LEGACY_STRICT} as strictness) is used for reading the JSON data. To customize this\n * behavior, create a {@link JsonReader}, configure it and then use {@link #read(JsonReader)}\n * instead.\n *\n *

    No exception is thrown if the JSON data has multiple top-level JSON elements, or if there is\n * trailing data.\n *\n * @return the converted Java object. May be {@code null}.\n * @since 2.2\n */\n public final T fromJson(String json) throws IOException {\n return fromJson(new StringReader(json));\n }\n\n /**\n * Converts {@code jsonTree} to a Java object.\n *\n * @param jsonTree the JSON element to convert. May be {@link JsonNull}.\n * @return the converted Java object. May be {@code null}.\n * @throws JsonIOException wrapping {@code IOException}s thrown by {@link #read(JsonReader)}\n * @since 2.2\n */\n public final T fromJsonTree(JsonElement jsonTree) {\n try {\n JsonReader jsonReader = new JsonTreeReader(jsonTree);\n return read(jsonReader);\n } catch (IOException e) {\n throw new JsonIOException(e);\n }\n }\n\n /**\n * This wrapper method is used to make a type adapter null tolerant. In general, a type adapter is\n * required to handle nulls in write and read methods. Here is how this is typically done:
    \n *\n...\n", "needle_token_start": 5553, "needle_token_end": 5626, "code_context_ntokens": 16030, "output": ["```java\n/**\n * Returns a deep copy of this element. Immutable elements like primitives and nulls are not\n * copied.\n *\n * @since 2.8.2\n */\npublic abstract JsonElement deepCopy();\n```"]} +{"repo": "google/gson", "name": "nonNull", "language": "java", "path": "gson/src/main/java/com/google/gson/internal/NonNullElementWrapperList.java", "position_ratio": 0.45, "description": "\nFunction Description:\n1. **Purpose**: The function ensures that the provided element is not null, throwing an exception if it is.\n2. **Input**: A single element of a generic type.\n3. **Output**: The same element if it is not null; otherwise, it throws a `NullPointerException`.\n4. **Procedure**: The function checks if the input element is null. If it is, it throws a `NullPointerException` with a specific error message. If the element is not null, it simply returns the element.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: gson/src/main/java/com/google/gson/FieldNamingPolicy.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport java.lang.reflect.Field;\nimport java.util.Locale;\n\n/**\n * An enumeration that defines a few standard naming conventions for JSON field names. This\n * enumeration should be used in conjunction with {@link com.google.gson.GsonBuilder} to configure a\n * {@link com.google.gson.Gson} instance to properly translate Java field names into the desired\n * JSON field names.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic enum FieldNamingPolicy implements FieldNamingStrategy {\n\n /** Using this naming policy with Gson will ensure that the field name is unchanged. */\n IDENTITY() {\n @Override\n public String translateName(Field f) {\n return f.getName();\n }\n },\n\n /**\n * Using this naming policy with Gson will ensure that the first \"letter\" of the Java field name\n * is capitalized when serialized to its JSON form.\n *\n *

    Here are a few examples of the form \"Java Field Name\" ---> \"JSON Field Name\":\n *\n *

      \n *
    • someFieldName ---> SomeFieldName\n *
    • _someFieldName ---> _SomeFieldName\n *
    \n */\n UPPER_CAMEL_CASE() {\n @Override\n public String translateName(Field f) {\n return upperCaseFirstLetter(f.getName());\n }\n },\n\n /**\n * Using this naming policy with Gson will ensure that the first \"letter\" of the Java field name\n * is capitalized when serialized to its JSON form and the words will be separated by a space.\n *\n *

    Here are a few examples of the form \"Java Field Name\" ---> \"JSON Field Name\":\n *\n *

      \n *
    • someFieldName ---> Some Field Name\n *
    • _someFieldName ---> _Some Field Name\n *
    \n *\n * @since 1.4\n */\n UPPER_CAMEL_CASE_WITH_SPACES() {\n @Override\n public String translateName(Field f) {\n return upperCaseFirstLetter(separateCamelCase(f.getName(), ' '));\n }\n },\n\n /**\n * Using this naming policy with Gson will modify the Java Field name from its camel cased form to\n * an upper case field name where each word is separated by an underscore (_).\n *\n *

    Here are a few examples of the form \"Java Field Name\" ---> \"JSON Field Name\":\n *\n *

      \n *
    • someFieldName ---> SOME_FIELD_NAME\n *
    • _someFieldName ---> _SOME_FIELD_NAME\n *
    • aStringField ---> A_STRING_FIELD\n *
    • aURL ---> A_U_R_L\n *
    \n *\n * @since 2.9.0\n */\n UPPER_CASE_WITH_UNDERSCORES() {\n @Override\n public String translateName(Field f) {\n...\n// Path: gson/src/main/java/com/google/gson/LongSerializationPolicy.java\n/*\n * Copyright (C) 2009 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\n/**\n * Defines the expected format for a {@code long} or {@code Long} type when it is serialized.\n *\n * @since 1.3\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic enum LongSerializationPolicy {\n /**\n * This is the \"default\" serialization policy that will output a {@code Long} object as a JSON\n * number. For example, assume an object has a long field named \"f\" then the serialized output\n * would be: {@code {\"f\":123}}\n *\n *

    A {@code null} value is serialized as {@link JsonNull}.\n */\n DEFAULT() {\n @Override\n public JsonElement serialize(Long value) {\n if (value == null) {\n return JsonNull.INSTANCE;\n }\n return new JsonPrimitive(value);\n }\n },\n\n /**\n * Serializes a long value as a quoted string. For example, assume an object has a long field\n * named \"f\" then the serialized output would be: {@code {\"f\":\"123\"}}\n *\n *

    A {@code null} value is serialized as {@link JsonNull}.\n */\n STRING() {\n @Override\n public JsonElement serialize(Long value) {\n if (value == null) {\n return JsonNull.INSTANCE;\n }\n return new JsonPrimitive(value.toString());\n }\n };\n\n /**\n * Serialize this {@code value} using this serialization policy.\n *\n * @param value the long value to be serialized into a {@link JsonElement}\n * @return the serialized version of {@code value}\n */\n public abstract JsonElement serialize(Long value);\n}\n\n// Path: gson/src/main/java/com/google/gson/Strictness.java\npackage com.google.gson;\n\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonWriter;\n\n/**\n * Modes that indicate how strictly a JSON {@linkplain JsonReader reader} or {@linkplain JsonWriter\n * writer} follows the syntax laid out in the RFC\n * 8259 JSON specification.\n *\n *
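 A hedged sketch of configuring the two policies shown above (field naming and long serialization) on a GsonBuilder, using only options that appear in this context: ```java Gson gson = new GsonBuilder() .setFieldNamingPolicy(FieldNamingPolicy.UPPER_CAMEL_CASE) // someFieldName -> SomeFieldName .setLongSerializationPolicy(LongSerializationPolicy.STRING) // {"f":"123"} instead of {"f":123} .create(); ``` 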

    You can look at {@link JsonReader#setStrictness(Strictness)} to see how the strictness affects\n * the {@link JsonReader} and you can look at {@link JsonWriter#setStrictness(Strictness)} to see\n * how the strictness affects the {@link JsonWriter}.\n *\n * @see JsonReader#setStrictness(Strictness)\n * @see JsonWriter#setStrictness(Strictness)\n * @since $next-version$\n */\npublic enum Strictness {\n /** Allow large deviations from the JSON specification. */\n LENIENT,\n\n /** Allow certain small deviations from the JSON specification for legacy reasons. */\n LEGACY_STRICT,\n\n /** Strict compliance with the JSON specification. */\n STRICT\n}\n\n// Path: gson/src/main/java/com/google/gson/annotations/Since.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.annotations;\n\nimport com.google.gson.GsonBuilder;\nimport java.lang.annotation.Documented;\nimport java.lang.annotation.ElementType;\nimport java.lang.annotation.Retention;\nimport java.lang.annotation.RetentionPolicy;\nimport java.lang.annotation.Target;\n\n/**\n * An annotation that indicates the version number since a member or a type has been present. This\n * annotation is useful to manage versioning of your JSON classes for a web-service.\n *\n *

    This annotation has no effect unless you build {@link com.google.gson.Gson} with a {@code\n * GsonBuilder} and invoke the {@link GsonBuilder#setVersion(double)} method.\n *\n *

    Here is an example of how this annotation is meant to be used:\n *\n *

    \n * public class User {\n *   private String firstName;\n *   private String lastName;\n *   @Since(1.0) private String emailAddress;\n *   @Since(1.0) private String password;\n *   @Since(1.1) private Address address;\n * }\n * 
    \n *\n *

    If you created Gson with {@code new Gson()}, the {@code toJson()} and {@code fromJson()}\n * methods will use all the fields for serialization and deserialization. However, if you created\n * Gson with {@code Gson gson = new GsonBuilder().setVersion(1.0).create()} then the {@code\n * toJson()} and {@code fromJson()} methods of Gson will exclude the {@code address} field since\n * it's version number is set to {@code 1.1}.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @see GsonBuilder#setVersion(double)\n * @see Until\n */\n@Documented\n@Retention(RetentionPolicy.RUNTIME)\n@Target({ElementType.FIELD, ElementType.TYPE})\npublic @interface Since {\n /**\n * The value indicating a version number since this member or type has been present. The number is\n * inclusive; annotated elements will be included if {@code gsonVersion >= value}.\n */\n double value();\n}\n\n// Path: gson/src/main/java/com/google/gson/annotations/Until.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.annotations;\n\nimport com.google.gson.GsonBuilder;\nimport java.lang.annotation.Documented;\nimport java.lang.annotation.ElementType;\nimport java.lang.annotation.Retention;\nimport java.lang.annotation.RetentionPolicy;\nimport java.lang.annotation.Target;\n\n/**\n * An annotation that indicates the version number until a member or a type should be present.\n * Basically, if Gson is created with a version number that is equal to or exceeds the value stored\n * in the {@code Until} annotation then the field will be ignored from the JSON output. This\n * annotation is useful to manage versioning of your JSON classes for a web-service.\n *\n *

    This annotation has no effect unless you build {@link com.google.gson.Gson} with a {@code\n * GsonBuilder} and invoke the {@link GsonBuilder#setVersion(double)} method.\n *\n *

    Here is an example of how this annotation is meant to be used:\n *\n *

    \n * public class User {\n *   private String firstName;\n *   private String lastName;\n *   @Until(1.1) private String emailAddress;\n *   @Until(1.1) private String password;\n * }\n * 
    \n *\n *

    If you created Gson with {@code new Gson()}, the {@code toJson()} and {@code fromJson()}\n * methods will use all the fields for serialization and deserialization. However, if you created\n * Gson with {@code Gson gson = new GsonBuilder().setVersion(1.2).create()} then the {@code\n * toJson()} and {@code fromJson()} methods of Gson will exclude the {@code emailAddress} and {@code\n * password} fields from the example above, because the version number passed to the GsonBuilder,\n * {@code 1.2}, exceeds the version number set on the {@code Until} annotation, {@code 1.1}, for\n * those fields.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @see GsonBuilder#setVersion(double)\n * @see Since\n * @since 1.3\n */\n@Documented\n@Retention(RetentionPolicy.RUNTIME)\n@Target({ElementType.FIELD, ElementType.TYPE})\npublic @interface Until {\n\n /**\n * The value indicating a version number until this member or type should be included. The number\n * is exclusive; annotated elements will be included if {@code gsonVersion < value}.\n */\n double value();\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/$Gson$Preconditions.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport java.util.Objects;\n\n/**\n * A simple utility class used to check method Preconditions.\n *\n *
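 A hedged sketch combining @Since/@Until with GsonBuilder#setVersion, mirroring the User examples above; the field values are placeholders: ```java class User { String firstName = "John"; String lastName = "Doe"; @Since(1.0) @Until(1.1) String emailAddress = "john@doe.example"; // only for versions in [1.0, 1.1) @Since(1.1) String address = "unknown"; } // Illustrative usage: Gson versioned = new GsonBuilder().setVersion(1.0).create(); String json = versioned.toJson(new User()); // emailAddress is written (1.0 >= 1.0 and 1.0 < 1.1); address is skipped (1.0 < 1.1). ``` 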

    \n * public long divideBy(long value) {\n *   Preconditions.checkArgument(value != 0);\n *   return this.value / value;\n * }\n * 
    \n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class $Gson$Preconditions {\n private $Gson$Preconditions() {\n throw new UnsupportedOperationException();\n }\n\n /**\n * @deprecated This is an internal Gson method. Use {@link Objects#requireNonNull(Object)}\n * instead.\n */\n // Only deprecated for now because external projects might be using this by accident\n @Deprecated\n public static T checkNotNull(T obj) {\n if (obj == null) {\n throw new NullPointerException();\n }\n return obj;\n }\n\n public static void checkArgument(boolean condition) {\n if (!condition) {\n throw new IllegalArgumentException();\n }\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/ExclusionStrategy.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\n/**\n * A strategy (or policy) definition that is used to decide whether or not a field or class should\n * be serialized or deserialized as part of the JSON output/input.\n *\n *

 The following are a few examples that show how you can use this exclusion mechanism. 

    Exclude fields and objects based on a particular class type:\n *\n *

    \n * private static class SpecificClassExclusionStrategy implements ExclusionStrategy {\n *   private final Class<?> excludedThisClass;\n *\n *   public SpecificClassExclusionStrategy(Class<?> excludedThisClass) {\n *     this.excludedThisClass = excludedThisClass;\n *   }\n *\n *   public boolean shouldSkipClass(Class<?> clazz) {\n *     return excludedThisClass.equals(clazz);\n *   }\n *\n *   public boolean shouldSkipField(FieldAttributes f) {\n *     return excludedThisClass.equals(f.getDeclaredClass());\n *   }\n * }\n * 
    \n *\n *

    Excludes fields and objects based on a particular annotation:\n *\n *

    \n * public @interface FooAnnotation {\n *   // some implementation here\n * }\n *\n * // Excludes any field (or class) that is tagged with an \"@FooAnnotation\"\n * private static class FooAnnotationExclusionStrategy implements ExclusionStrategy {\n *   public boolean shouldSkipClass(Class<?> clazz) {\n *     return clazz.getAnnotation(FooAnnotation.class) != null;\n *   }\n *\n *   public boolean shouldSkipField(FieldAttributes f) {\n *     return f.getAnnotation(FooAnnotation.class) != null;\n *   }\n * }\n * 
    \n *\n *

    Now if you want to configure {@code Gson} to use a user defined exclusion strategy, then the\n * {@code GsonBuilder} is required. The following is an example of how you can use the {@code\n * GsonBuilder} to configure Gson to use one of the above samples:\n *\n *

    \n * ExclusionStrategy excludeStrings = new UserDefinedExclusionStrategy(String.class);\n * Gson gson = new GsonBuilder()\n *     .setExclusionStrategies(excludeStrings)\n *     .create();\n * 
    \n *\n *

    For certain model classes, you may only want to serialize a field, but exclude it for\n * deserialization. To do that, you can write an {@code ExclusionStrategy} as per normal; however,\n * you would register it with the {@link\n * GsonBuilder#addDeserializationExclusionStrategy(ExclusionStrategy)} method. For example:\n *\n *

    \n * ExclusionStrategy excludeStrings = new UserDefinedExclusionStrategy(String.class);\n * Gson gson = new GsonBuilder()\n *     .addDeserializationExclusionStrategy(excludeStrings)\n *     .create();\n * 
    \n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @see GsonBuilder#setExclusionStrategies(ExclusionStrategy...)\n * @see GsonBuilder#addDeserializationExclusionStrategy(ExclusionStrategy)\n * @see GsonBuilder#addSerializationExclusionStrategy(ExclusionStrategy)\n * @since 1.4\n */\npublic interface ExclusionStrategy {\n\n /**\n * Decides if a field should be skipped during serialization or deserialization.\n *\n * @param f the field object that is under test\n * @return true if the field should be ignored; otherwise false\n */\n public boolean shouldSkipField(FieldAttributes f);\n\n /**\n * Decides if a class should be serialized or deserialized\n *\n * @param clazz the class object that is under test\n * @return true if the class should be ignored; otherwise false\n */\n public boolean shouldSkipClass(Class clazz);\n}\n\n// Path: gson/src/main/java/com/google/gson/FieldAttributes.java\n/*\n * Copyright (C) 2009 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport java.lang.annotation.Annotation;\nimport java.lang.reflect.Field;\nimport java.lang.reflect.Type;\nimport java.util.Arrays;\nimport java.util.Collection;\nimport java.util.Objects;\n\n/**\n * A data object that stores attributes of a field.\n *\n *
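 A hedged sketch of wiring an annotation-based strategy (like the FooAnnotationExclusionStrategy example above, which is illustrative) into Gson, including the deserialization-only variant: ```java ExclusionStrategy skipFoo = new FooAnnotationExclusionStrategy(); // from the example above // Applied to both serialization and deserialization: Gson gson = new GsonBuilder() .setExclusionStrategies(skipFoo) .create(); // Applied only when reading JSON; the annotated fields are still written: Gson readOnlyExclusions = new GsonBuilder() .addDeserializationExclusionStrategy(skipFoo) .create(); ``` 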

    This class is immutable; therefore, it can be safely shared across threads.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @since 1.4\n */\npublic final class FieldAttributes {\n private final Field field;\n\n /**\n * Constructs a Field Attributes object from the {@code f}.\n *\n * @param f the field to pull attributes from\n */\n public FieldAttributes(Field f) {\n this.field = Objects.requireNonNull(f);\n }\n\n /**\n * Gets the declaring Class that contains this field\n *\n * @return the declaring class that contains this field\n */\n public Class getDeclaringClass() {\n return field.getDeclaringClass();\n }\n\n /**\n * Gets the name of the field\n *\n * @return the name of the field\n */\n public String getName() {\n return field.getName();\n }\n\n /**\n * Returns the declared generic type of the field.\n *\n *

    For example, assume the following class definition:\n *\n *

    \n   * public class Foo {\n   *   private String bar;\n   *   private List<String> red;\n   * }\n   *\n   * Type listParameterizedType = new TypeToken<List<String>>() {}.getType();\n   * 
    \n *\n *

    This method would return {@code String.class} for the {@code bar} field and {@code\n * listParameterizedType} for the {@code red} field.\n *\n * @return the specific type declared for this field\n */\n public Type getDeclaredType() {\n return field.getGenericType();\n }\n\n /**\n * Returns the {@code Class} object that was declared for this field.\n *\n *

    For example, assume the following class definition:\n *\n *

    \n   * public class Foo {\n   *   private String bar;\n   *   private List<String> red;\n   * }\n   * 
    \n *\n *

    This method would return {@code String.class} for the {@code bar} field and {@code\n * List.class} for the {@code red} field.\n *\n * @return the specific class object that was declared for the field\n */\n public Class getDeclaredClass() {\n return field.getType();\n }\n\n /**\n * Returns the {@code T} annotation object from this field if it exists; otherwise returns {@code\n * null}.\n *\n * @param annotation the class of the annotation that will be retrieved\n * @return the annotation instance if it is bound to the field; otherwise {@code null}\n */\n public T getAnnotation(Class annotation) {\n return field.getAnnotation(annotation);\n }\n\n /**\n * Returns the annotations that are present on this field.\n *\n * @return an array of all the annotations set on the field\n * @since 1.4\n */\n public Collection getAnnotations() {\n return Arrays.asList(field.getAnnotations());\n }\n\n /**\n * Returns {@code true} if the field is defined with the {@code modifier}.\n *\n *

    This method is meant to be called as:\n *\n *

    \n   * boolean hasPublicModifier = fieldAttribute.hasModifier(java.lang.reflect.Modifier.PUBLIC);\n   * 
    \n *\n * @see java.lang.reflect.Modifier\n */\n public boolean hasModifier(int modifier) {\n return (field.getModifiers() & modifier) != 0;\n }\n\n @Override\n public String toString() {\n return field.toString();\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/NonNullElementWrapperList.java\n/*\n * Copyright (C) 2018 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport java.util.AbstractList;\nimport java.util.ArrayList;\nimport java.util.Collection;\nimport java.util.List;\nimport java.util.Objects;\nimport java.util.RandomAccess;\n\n/**\n * {@link List} which wraps another {@code List} but prevents insertion of {@code null} elements.\n * Methods which only perform checks with the element argument (e.g. {@link #contains(Object)}) do\n * not throw exceptions for {@code null} arguments.\n */\npublic class NonNullElementWrapperList extends AbstractList implements RandomAccess {\n // Explicitly specify ArrayList as type to guarantee that delegate implements RandomAccess\n private final ArrayList delegate;\n\n @SuppressWarnings(\"NonApiType\")\n public NonNullElementWrapperList(ArrayList delegate) {\n this.delegate = Objects.requireNonNull(delegate);\n }\n\n @Override\n public E get(int index) {\n return delegate.get(index);\n }\n\n @Override\n public int size() {\n return delegate.size();\n }\n\n \nprivate E nonNull(E element) {\n if (element == null) {\n throw new NullPointerException(\"Element must be non-null\");\n }\n return element;\n }\n\n @Override\n public E set(int index, E element) {\n return delegate.set(index, nonNull(element));\n }\n\n @Override\n public void add(int index, E element) {\n delegate.add(index, nonNull(element));\n }\n\n @Override\n public E remove(int index) {\n return delegate.remove(index);\n }\n\n /* The following methods are overridden because their default implementation is inefficient */\n\n @Override\n public void clear() {\n delegate.clear();\n }\n\n @SuppressWarnings(\"UngroupedOverloads\") // this is intentionally ungrouped, see comment above\n @Override\n public boolean remove(Object o) {\n return delegate.remove(o);\n }\n\n @Override\n public boolean removeAll(Collection c) {\n return delegate.removeAll(c);\n }\n\n @Override\n public boolean retainAll(Collection c) {\n return delegate.retainAll(c);\n }\n\n @Override\n public boolean contains(Object o) {\n return delegate.contains(o);\n }\n\n @Override\n public int indexOf(Object o) {\n return delegate.indexOf(o);\n }\n\n @Override\n public int lastIndexOf(Object o) {\n return delegate.lastIndexOf(o);\n }\n\n @Override\n public Object[] toArray() {\n return delegate.toArray();\n }\n\n @Override\n public T[] toArray(T[] a) {\n return delegate.toArray(a);\n }\n\n @Override\n public boolean equals(Object o) {\n return delegate.equals(o);\n }\n\n @Override\n public int hashCode() {\n return delegate.hashCode();\n }\n\n // TODO: Once Gson targets Java 8 also override List.sort\n}\n\n// Path: 
gson/src/main/java/com/google/gson/JsonArray.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.internal.NonNullElementWrapperList;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\nimport java.util.ArrayList;\nimport java.util.Iterator;\nimport java.util.List;\n\n/**\n * A class representing an array type in JSON. An array is a list of {@link JsonElement}s each of\n * which can be of a different type. This is an ordered list, meaning that the order in which\n * elements are added is preserved. This class does not support {@code null} values. If {@code null}\n * is provided as element argument to any of the methods, it is converted to a {@link JsonNull}.\n *\n *

    {@code JsonArray} only implements the {@link Iterable} interface but not the {@link List}\n * interface. A {@code List} view of it can be obtained with {@link #asList()}.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonArray extends JsonElement implements Iterable {\n private final ArrayList elements;\n\n /** Creates an empty JsonArray. */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonArray() {\n elements = new ArrayList<>();\n }\n\n /**\n * Creates an empty JsonArray with the desired initial capacity.\n *\n * @param capacity initial capacity.\n * @throws IllegalArgumentException if the {@code capacity} is negative\n * @since 2.8.1\n */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonArray(int capacity) {\n elements = new ArrayList<>(capacity);\n }\n\n /**\n * Creates a deep copy of this element and all its children.\n *\n * @since 2.8.2\n */\n @Override\n public JsonArray deepCopy() {\n if (!elements.isEmpty()) {\n JsonArray result = new JsonArray(elements.size());\n for (JsonElement element : elements) {\n result.add(element.deepCopy());\n }\n return result;\n }\n return new JsonArray();\n }\n\n /**\n * Adds the specified boolean to self.\n *\n * @param bool the boolean that needs to be added to the array.\n * @since 2.4\n */\n public void add(Boolean bool) {\n elements.add(bool == null ? JsonNull.INSTANCE : new JsonPrimitive(bool));\n }\n\n /**\n * Adds the specified character to self.\n *\n * @param character the character that needs to be added to the array.\n * @since 2.4\n */\n public void add(Character character) {\n elements.add(character == null ? JsonNull.INSTANCE : new JsonPrimitive(character));\n }\n\n /**\n * Adds the specified number to self.\n *\n * @param number the number that needs to be added to the array.\n * @since 2.4\n */\n public void add(Number number) {\n elements.add(number == null ? JsonNull.INSTANCE : new JsonPrimitive(number));\n }\n\n /**\n * Adds the specified string to self.\n *\n * @param string the string that needs to be added to the array.\n * @since 2.4\n */\n public void add(String string) {\n elements.add(string == null ? JsonNull.INSTANCE : new JsonPrimitive(string));\n }\n\n /**\n * Adds the specified element to self.\n *\n * @param element the element that needs to be added to the array.\n */\n public void add(JsonElement element) {\n if (element == null) {\n element = JsonNull.INSTANCE;\n }\n elements.add(element);\n }\n\n /**\n * Adds all the elements of the specified array to self.\n *\n * @param array the array whose elements need to be added to the array.\n */\n public void addAll(JsonArray array) {\n elements.addAll(array.elements);\n }\n\n /**\n * Replaces the element at the specified position in this array with the specified element.\n *\n * @param index index of the element to replace\n * @param element element to be stored at the specified position\n * @return the element previously at the specified position\n * @throws IndexOutOfBoundsException if the specified index is outside the array bounds\n */\n @CanIgnoreReturnValue\n public JsonElement set(int index, JsonElement element) {\n return elements.set(index, element == null ? JsonNull.INSTANCE : element);\n }\n\n /**\n * Removes the first occurrence of the specified element from this array, if it is present. 
If the\n * array does not contain the element, it is unchanged.\n *\n * @param element element to be removed from this array, if present\n * @return true if this array contained the specified element, false otherwise\n * @since 2.3\n */\n @CanIgnoreReturnValue\n public boolean remove(JsonElement element) {\n return elements.remove(element);\n }\n\n /**\n * Removes the element at the specified position in this array. Shifts any subsequent elements to\n * the left (subtracts one from their indices). Returns the element that was removed from the\n * array.\n *\n * @param index index the index of the element to be removed\n * @return the element previously at the specified position\n * @throws IndexOutOfBoundsException if the specified index is outside the array bounds\n * @since 2.3\n */\n @CanIgnoreReturnValue\n public JsonElement remove(int index) {\n return elements.remove(index);\n }\n\n /**\n * Returns true if this array contains the specified element.\n *\n * @return true if this array contains the specified element.\n * @param element whose presence in this array is to be tested\n * @since 2.3\n */\n public boolean contains(JsonElement element) {\n return elements.contains(element);\n }\n\n /**\n * Returns the number of elements in the array.\n *\n * @return the number of elements in the array.\n */\n public int size() {\n return elements.size();\n }\n\n /**\n * Returns true if the array is empty.\n *\n * @return true if the array is empty.\n * @since 2.8.7\n */\n public boolean isEmpty() {\n return elements.isEmpty();\n }\n\n /**\n * Returns an iterator to navigate the elements of the array. Since the array is an ordered list,\n * the iterator navigates the elements in the order they were inserted.\n *\n * @return an iterator to navigate the elements of the array.\n */\n @Override\n public Iterator iterator() {\n return elements.iterator();\n }\n\n /**\n * Returns the i-th element of the array.\n *\n * @param i the index of the element that is being sought.\n * @return the element present at the i-th index.\n * @throws IndexOutOfBoundsException if {@code i} is negative or greater than or equal to the\n * {@link #size()} of the array.\n */\n public JsonElement get(int i) {\n return elements.get(i);\n }\n\n private JsonElement getAsSingleElement() {\n int size = elements.size();\n if (size == 1) {\n return elements.get(0);\n }\n throw new IllegalStateException(\"Array must have size 1, but has size \" + size);\n }\n\n /**\n * Convenience method to get this array as a {@link Number} if it contains a single element. This\n * method calls {@link JsonElement#getAsNumber()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a number if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public Number getAsNumber() {\n return getAsSingleElement().getAsNumber();\n }\n\n /**\n * Convenience method to get this array as a {@link String} if it contains a single element. 
This\n * method calls {@link JsonElement#getAsString()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a String if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public String getAsString() {\n return getAsSingleElement().getAsString();\n }\n\n /**\n * Convenience method to get this array as a double if it contains a single element. This method\n * calls {@link JsonElement#getAsDouble()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a double if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public double getAsDouble() {\n return getAsSingleElement().getAsDouble();\n }\n\n /**\n * Convenience method to get this array as a {@link BigDecimal} if it contains a single element.\n * This method calls {@link JsonElement#getAsBigDecimal()} on the element, therefore any of the\n * exceptions declared by that method can occur.\n *\n * @return this element as a {@link BigDecimal} if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n * @since 1.2\n */\n @Override\n public BigDecimal getAsBigDecimal() {\n return getAsSingleElement().getAsBigDecimal();\n }\n\n /**\n * Convenience method to get this array as a {@link BigInteger} if it contains a single element.\n * This method calls {@link JsonElement#getAsBigInteger()} on the element, therefore any of the\n * exceptions declared by that method can occur.\n *\n * @return this element as a {@link BigInteger} if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n * @since 1.2\n */\n @Override\n public BigInteger getAsBigInteger() {\n return getAsSingleElement().getAsBigInteger();\n }\n\n /**\n * Convenience method to get this array as a float if it contains a single element. This method\n * calls {@link JsonElement#getAsFloat()} on the element, therefore any of the exceptions declared\n * by that method can occur.\n *\n * @return this element as a float if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public float getAsFloat() {\n return getAsSingleElement().getAsFloat();\n }\n\n /**\n * Convenience method to get this array as a long if it contains a single element. This method\n * calls {@link JsonElement#getAsLong()} on the element, therefore any of the exceptions declared\n * by that method can occur.\n *\n * @return this element as a long if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public long getAsLong() {\n return getAsSingleElement().getAsLong();\n }\n\n /**\n * Convenience method to get this array as an integer if it contains a single element. This method\n * calls {@link JsonElement#getAsInt()} on the element, therefore any of the exceptions declared\n * by that method can occur.\n *\n * @return this element as an integer if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public int getAsInt() {\n return getAsSingleElement().getAsInt();\n }\n\n /**\n * Convenience method to get this array as a primitive byte if it contains a single element. 
This\n * method calls {@link JsonElement#getAsByte()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a primitive byte if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public byte getAsByte() {\n return getAsSingleElement().getAsByte();\n }\n\n /**\n * Convenience method to get this array as a character if it contains a single element. This\n * method calls {@link JsonElement#getAsCharacter()} on the element, therefore any of the\n * exceptions declared by that method can occur.\n *\n * @return this element as a primitive short if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n * @deprecated This method is misleading, as it does not get this element as a char but rather as\n * a string's first character.\n */\n @Deprecated\n @Override\n public char getAsCharacter() {\n return getAsSingleElement().getAsCharacter();\n }\n\n /**\n * Convenience method to get this array as a primitive short if it contains a single element. This\n * method calls {@link JsonElement#getAsShort()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a primitive short if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public short getAsShort() {\n return getAsSingleElement().getAsShort();\n }\n\n /**\n * Convenience method to get this array as a boolean if it contains a single element. This method\n * calls {@link JsonElement#getAsBoolean()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a boolean if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public boolean getAsBoolean() {\n return getAsSingleElement().getAsBoolean();\n }\n\n /**\n * Returns a mutable {@link List} view of this {@code JsonArray}. Changes to the {@code List} are\n * visible in this {@code JsonArray} and the other way around.\n *\n *
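 A hedged sketch of the live List view described here; the rejection of null elements comes from the NonNullElementWrapperList nonNull(...) check shown earlier in this context: ```java JsonArray array = new JsonArray(); array.add("a"); List view = array.asList(); // live view backed by the JsonArray view.add(new JsonPrimitive("b")); // visible through the array array.add("c"); // visible through the view view.set(0, JsonNull.INSTANCE); // JSON null must be passed explicitly // view.add(null); // would throw NullPointerException ``` 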

    The {@code List} does not permit {@code null} elements. Unlike {@code JsonArray}'s {@code\n * null} handling, a {@link NullPointerException} is thrown when trying to add {@code null}. Use\n * {@link JsonNull} for JSON null values.\n *\n * @return mutable {@code List} view\n * @since 2.10\n */\n public List asList() {\n return new NonNullElementWrapperList<>(elements);\n }\n\n /**\n * Returns whether the other object is equal to this. This method only considers the other object\n * to be equal if it is an instance of {@code JsonArray} and has equal elements in the same order.\n */\n @Override\n public boolean equals(Object o) {\n return (o == this) || (o instanceof JsonArray && ((JsonArray) o).elements.equals(elements));\n }\n\n /**\n * Returns the hash code of this array. This method calculates the hash code based on the elements\n * of this array.\n */\n @Override\n public int hashCode() {\n return elements.hashCode();\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonIOException.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson;\n\n/**\n * This exception is raised when Gson was unable to read an input stream or write to one.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonIOException extends JsonParseException {\n private static final long serialVersionUID = 1L;\n\n public JsonIOException(String msg) {\n super(msg);\n }\n\n public JsonIOException(String msg, Throwable cause) {\n super(msg, cause);\n }\n\n /**\n * Creates exception with the specified cause. 
Consider using {@link #JsonIOException(String,\n * Throwable)} instead if you can describe what happened.\n *\n * @param cause root exception that caused this exception to be thrown.\n */\n public JsonIOException(Throwable cause) {\n super(cause);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonSyntaxException.java\n/*\n * Copyright (C) 2010 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson;\n\n/**\n * This exception is raised when Gson attempts to read (or write) a malformed JSON element.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonSyntaxException extends JsonParseException {\n\n private static final long serialVersionUID = 1L;\n\n public JsonSyntaxException(String msg) {\n super(msg);\n }\n\n public JsonSyntaxException(String msg, Throwable cause) {\n super(msg, cause);\n }\n\n /**\n * Creates exception with the specified cause. Consider using {@link #JsonSyntaxException(String,\n * Throwable)} instead if you can describe what actually happened.\n *\n * @param cause root exception that caused this exception to be thrown.\n */\n public JsonSyntaxException(Throwable cause) {\n super(cause);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonPrimitive.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.gson.internal.LazilyParsedNumber;\nimport com.google.gson.internal.NumberLimits;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\nimport java.util.Objects;\n\n/**\n * A class representing a JSON primitive value. 
A primitive value is either a String, a Java\n * primitive, or a Java primitive wrapper type.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonPrimitive extends JsonElement {\n\n private final Object value;\n\n /**\n * Create a primitive containing a boolean value.\n *\n * @param bool the value to create the primitive with.\n */\n // \"deprecation\" suppression for superclass constructor\n // \"UnnecessaryBoxedVariable\" Error Prone warning is correct since method does not accept\n // null, but cannot be changed anymore since this is public API\n @SuppressWarnings({\"deprecation\", \"UnnecessaryBoxedVariable\"})\n public JsonPrimitive(Boolean bool) {\n value = Objects.requireNonNull(bool);\n }\n\n /**\n * Create a primitive containing a {@link Number}.\n *\n * @param number the value to create the primitive with.\n */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonPrimitive(Number number) {\n value = Objects.requireNonNull(number);\n }\n\n /**\n * Create a primitive containing a String value.\n *\n * @param string the value to create the primitive with.\n */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonPrimitive(String string) {\n value = Objects.requireNonNull(string);\n }\n\n /**\n * Create a primitive containing a character. The character is turned into a one character String\n * since JSON only supports String.\n *\n * @param c the value to create the primitive with.\n */\n // \"deprecation\" suppression for superclass constructor\n // \"UnnecessaryBoxedVariable\" Error Prone warning is correct since method does not accept\n // null, but cannot be changed anymore since this is public API\n @SuppressWarnings({\"deprecation\", \"UnnecessaryBoxedVariable\"})\n public JsonPrimitive(Character c) {\n // convert characters to strings since in JSON, characters are represented as a single\n // character string\n value = Objects.requireNonNull(c).toString();\n }\n\n /**\n * Returns the same value as primitives are immutable.\n *\n * @since 2.8.2\n */\n @Override\n public JsonPrimitive deepCopy() {\n return this;\n }\n\n /**\n * Check whether this primitive contains a boolean value.\n *\n * @return true if this primitive contains a boolean value, false otherwise.\n */\n public boolean isBoolean() {\n return value instanceof Boolean;\n }\n\n /**\n * Convenience method to get this element as a boolean value. If this primitive {@linkplain\n * #isBoolean() is not a boolean}, the string value is parsed using {@link\n * Boolean#parseBoolean(String)}. This means {@code \"true\"} (ignoring case) is considered {@code\n * true} and any other value is considered {@code false}.\n */\n @Override\n public boolean getAsBoolean() {\n if (isBoolean()) {\n return (Boolean) value;\n }\n // Check to see if the value as a String is \"true\" in any case.\n return Boolean.parseBoolean(getAsString());\n }\n\n /**\n * Check whether this primitive contains a Number.\n *\n * @return true if this primitive contains a Number, false otherwise.\n */\n public boolean isNumber() {\n return value instanceof Number;\n }\n\n /**\n * Convenience method to get this element as a {@link Number}. 
If this primitive {@linkplain\n * #isString() is a string}, a lazily parsed {@code Number} is constructed which parses the string\n * when any of its methods are called (which can lead to a {@link NumberFormatException}).\n *\n * @throws UnsupportedOperationException if this primitive is neither a number nor a string.\n */\n @Override\n public Number getAsNumber() {\n if (value instanceof Number) {\n return (Number) value;\n } else if (value instanceof String) {\n return new LazilyParsedNumber((String) value);\n }\n throw new UnsupportedOperationException(\"Primitive is neither a number nor a string\");\n }\n\n /**\n * Check whether this primitive contains a String value.\n *\n * @return true if this primitive contains a String value, false otherwise.\n */\n public boolean isString() {\n return value instanceof String;\n }\n\n // Don't add Javadoc, inherit it from super implementation; no exceptions are thrown here\n @Override\n public String getAsString() {\n if (value instanceof String) {\n return (String) value;\n } else if (isNumber()) {\n return getAsNumber().toString();\n } else if (isBoolean()) {\n return ((Boolean) value).toString();\n }\n throw new AssertionError(\"Unexpected value type: \" + value.getClass());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public double getAsDouble() {\n return isNumber() ? getAsNumber().doubleValue() : Double.parseDouble(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public BigDecimal getAsBigDecimal() {\n return value instanceof BigDecimal\n ? (BigDecimal) value\n : NumberLimits.parseBigDecimal(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public BigInteger getAsBigInteger() {\n return value instanceof BigInteger\n ? (BigInteger) value\n : isIntegral(this)\n ? BigInteger.valueOf(this.getAsNumber().longValue())\n : NumberLimits.parseBigInteger(this.getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public float getAsFloat() {\n return isNumber() ? getAsNumber().floatValue() : Float.parseFloat(getAsString());\n }\n\n /**\n * Convenience method to get this element as a primitive long.\n *\n * @return this element as a primitive long.\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public long getAsLong() {\n return isNumber() ? getAsNumber().longValue() : Long.parseLong(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public short getAsShort() {\n return isNumber() ? getAsNumber().shortValue() : Short.parseShort(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public int getAsInt() {\n return isNumber() ? getAsNumber().intValue() : Integer.parseInt(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public byte getAsByte() {\n return isNumber() ? getAsNumber().byteValue() : Byte.parseByte(getAsString());\n }\n\n /**\n * @throws UnsupportedOperationException if the string value of this primitive is empty.\n * @deprecated This method is misleading, as it does not get this element as a char but rather as\n * a string's first character.\n */\n @Deprecated\n @Override\n public char getAsCharacter() {\n String s = getAsString();\n if (s.isEmpty()) {\n throw new UnsupportedOperationException(\"String value is empty\");\n } else {\n return s.charAt(0);\n }\n }\n\n /** Returns the hash code of this object. 
*/\n @Override\n public int hashCode() {\n if (value == null) {\n return 31;\n }\n // Using recommended hashing algorithm from Effective Java for longs and doubles\n if (isIntegral(this)) {\n long value = getAsNumber().longValue();\n return (int) (value ^ (value >>> 32));\n }\n if (value instanceof Number) {\n long value = Double.doubleToLongBits(getAsNumber().doubleValue());\n return (int) (value ^ (value >>> 32));\n }\n return value.hashCode();\n }\n\n /**\n * Returns whether the other object is equal to this. This method only considers the other object\n * to be equal if it is an instance of {@code JsonPrimitive} and has an equal value.\n */\n @Override\n public boolean equals(Object obj) {\n if (this == obj) {\n return true;\n }\n if (obj == null || getClass() != obj.getClass()) {\n return false;\n }\n JsonPrimitive other = (JsonPrimitive) obj;\n if (value == null) {\n return other.value == null;\n }\n if (isIntegral(this) && isIntegral(other)) {\n return (this.value instanceof BigInteger || other.value instanceof BigInteger)\n ? this.getAsBigInteger().equals(other.getAsBigInteger())\n : this.getAsNumber().longValue() == other.getAsNumber().longValue();\n }\n if (value instanceof Number && other.value instanceof Number) {\n if (value instanceof BigDecimal && other.value instanceof BigDecimal) {\n // Uses compareTo to ignore scale of values, e.g. `0` and `0.00` should be considered equal\n return this.getAsBigDecimal().compareTo(other.getAsBigDecimal()) == 0;\n }\n\n double thisAsDouble = this.getAsDouble();\n double otherAsDouble = other.getAsDouble();\n // Don't use Double.compare(double, double) because that considers -0.0 and +0.0 not equal\n return (thisAsDouble == otherAsDouble)\n || (Double.isNaN(thisAsDouble) && Double.isNaN(otherAsDouble));\n }\n return value.equals(other.value);\n }\n\n /**\n * Returns true if the specified number is an integral type (Long, Integer, Short, Byte,\n * BigInteger)\n */\n private static boolean isIntegral(JsonPrimitive primitive) {\n if (primitive.value instanceof Number) {\n Number number = (Number) primitive.value;\n return number instanceof BigInteger\n || number instanceof Long\n || number instanceof Integer\n || number instanceof Short\n || number instanceof Byte;\n }\n return false;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/$Gson$Types.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport static com.google.gson.internal.$Gson$Preconditions.checkArgument;\nimport static java.util.Objects.requireNonNull;\n\nimport java.io.Serializable;\nimport java.lang.reflect.Array;\nimport java.lang.reflect.GenericArrayType;\nimport java.lang.reflect.GenericDeclaration;\nimport java.lang.reflect.Modifier;\nimport java.lang.reflect.ParameterizedType;\nimport java.lang.reflect.Type;\nimport java.lang.reflect.TypeVariable;\nimport java.lang.reflect.WildcardType;\nimport java.util.Arrays;\nimport java.util.Collection;\nimport 
java.util.HashMap;\nimport java.util.Map;\nimport java.util.NoSuchElementException;\nimport java.util.Objects;\nimport java.util.Properties;\n\n/**\n * Static methods for working with types.\n *\n * @author Bob Lee\n * @author Jesse Wilson\n */\npublic final class $Gson$Types {\n...\n", "needle_token_start": 6981, "needle_token_end": 7025, "code_context_ntokens": 15800, "output": ["```java\nprivate E nonNull(E element) {\n if (element == null) {\n throw new NullPointerException(\"Element must be non-null\");\n }\n return element;\n}\n```"]} +{"repo": "google/gson", "name": "rotateLeft", "language": "java", "path": "gson/src/main/java/com/google/gson/internal/LinkedTreeMap.java", "position_ratio": 0.55, "description": "\nFunction Description:\n1. **Purpose**: To adjust the structure of a binary tree by promoting the right child of a given node to be the new root of the subtree, which helps maintain or restore the balance of the tree after modifications such as insertions or deletions.\n2. **Input**: A node in a binary tree that is currently unbalanced due to a heavier right subtree.\n3. **Output**: There is no direct output returned by the function; however, it modifies the structure of the tree to balance the heights of the subtrees.\n4. **Procedure**: \n - The function identifies the right child of the given node and its children.\n - It then reassigns the left child of the right child to be the right child of the original node.\n - The original node is set as the left child of its former right child, effectively making the right child the new root of the subtree.\n - The parent-child relationships are updated to reflect these changes.\n - Finally, the heights of the affected nodes are recalculated to ensure they accurately represent the heights of their respective subtrees post-rotation.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: gson/src/main/java/module-info.java\n/*\n * Copyright (C) 2018 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/**\n * Defines the Gson serialization/deserialization API.\n *\n * @since 2.8.6\n */\nmodule com.google.gson {\n exports com.google.gson;\n exports com.google.gson.annotations;\n exports com.google.gson.reflect;\n exports com.google.gson.stream;\n\n // Dependency on Error Prone Annotations\n requires static com.google.errorprone.annotations;\n\n // Optional dependency on java.sql\n requires static java.sql;\n\n // Optional dependency on jdk.unsupported for JDK's sun.misc.Unsafe\n requires static jdk.unsupported;\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonParseException.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy 
of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\n/**\n * This exception is raised if there is a serious issue that occurs during parsing of a Json string.\n * One of the main usages for this class is for the Gson infrastructure. If the incoming Json is\n * bad/malicious, an instance of this exception is raised.\n *\n *

    This exception is a {@link RuntimeException} because it is exposed to the client. Using a\n * {@link RuntimeException} avoids bad coding practices on the client side where they catch the\n * exception and do nothing. It is often the case that you want to blow up if there is a parsing\n * error (i.e. often clients do not know how to recover from a {@link JsonParseException}.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic class JsonParseException extends RuntimeException {\n static final long serialVersionUID = -4086729973971783390L;\n\n /**\n * Creates exception with the specified message. If you are wrapping another exception, consider\n * using {@link #JsonParseException(String, Throwable)} instead.\n *\n * @param msg error message describing a possible cause of this exception.\n */\n public JsonParseException(String msg) {\n super(msg);\n }\n\n /**\n * Creates exception with the specified message and cause.\n *\n * @param msg error message describing what happened.\n * @param cause root exception that caused this exception to be thrown.\n */\n public JsonParseException(String msg, Throwable cause) {\n super(msg, cause);\n }\n\n /**\n * Creates exception with the specified cause. Consider using {@link #JsonParseException(String,\n * Throwable)} instead if you can describe what happened.\n *\n * @param cause root exception that caused this exception to be thrown.\n */\n public JsonParseException(Throwable cause) {\n super(cause);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/InstanceCreator.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport java.lang.reflect.Type;\n\n/**\n * This interface is implemented to create instances of a class that does not define a no-args\n * constructor. If you can modify the class, you should instead add a private, or public no-args\n * constructor. However, that is not possible for library classes, such as JDK classes, or a\n * third-party library that you do not have source-code of. In such cases, you should define an\n * instance creator for the class. Implementations of this interface should be registered with\n * {@link GsonBuilder#registerTypeAdapter(Type, Object)} method before Gson will be able to use\n * them.\n *\n *

    Let us look at an example where defining an InstanceCreator might be useful. The {@code Id}\n * class defined below does not have a default no-args constructor.\n *\n *

    \n * public class Id<T> {\n *   private final Class<T> clazz;\n *   private final long value;\n *   public Id(Class<T> clazz, long value) {\n *     this.clazz = clazz;\n *     this.value = value;\n *   }\n * }\n * 
    \n *\n *

    If Gson encounters an object of type {@code Id} during deserialization, it will throw an\n * exception. The easiest way to solve this problem will be to add a (public or private) no-args\n * constructor as follows:\n *\n *

    \n * private Id() {\n *   this(Object.class, 0L);\n * }\n * 
    \n *\n *

    However, let us assume that the developer does not have access to the source-code of the\n * {@code Id} class, or does not want to define a no-args constructor for it. The developer can\n * solve this problem by defining an {@code InstanceCreator} for {@code Id}:\n *\n *

    \n * class IdInstanceCreator implements InstanceCreator<Id> {\n *   public Id createInstance(Type type) {\n *     return new Id(Object.class, 0L);\n *   }\n * }\n * 
    \n *\n *

    Note that it does not matter what the fields of the created instance contain since Gson will\n * overwrite them with the deserialized values specified in JSON. You should also ensure that a\n * new object is returned, not a common object since its fields will be overwritten. The\n * developer will need to register {@code IdInstanceCreator} with Gson as follows:\n *\n *

    \n * Gson gson = new GsonBuilder().registerTypeAdapter(Id.class, new IdInstanceCreator()).create();\n * 
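For reference, the fragments in the Javadoc above can be assembled into one runnable sketch. `Id` and `IdInstanceCreator` mirror the snippet in the comment; the JSON input `{"value":20}`, the class names, and the `main` wrapper are assumptions added for illustration and are not part of the Gson documentation.

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.InstanceCreator;
import java.lang.reflect.Type;

public class InstanceCreatorDemo {
  // A class without a no-args constructor, as in the Javadoc example above.
  static class Id<T> {
    private final Class<T> clazz;
    private final long value;

    Id(Class<T> clazz, long value) {
      this.clazz = clazz;
      this.value = value;
    }
  }

  // Supplies a throwaway instance; Gson then overwrites its fields from the JSON.
  static class IdInstanceCreator implements InstanceCreator<Id<?>> {
    @Override
    public Id<?> createInstance(Type type) {
      return new Id<>(Object.class, 0L);
    }
  }

  public static void main(String[] args) {
    Gson gson =
        new GsonBuilder().registerTypeAdapter(Id.class, new IdInstanceCreator()).create();
    // Assumed input: only the "value" member is present in the JSON.
    Id<?> id = gson.fromJson("{\"value\":20}", Id.class);
    System.out.println(id.value); // 20
  }
}
```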
    \n *\n * @param the type of object that will be created by this implementation.\n * @see GsonBuilder#registerTypeAdapter(Type, Object)\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic interface InstanceCreator {\n\n /**\n * Gson invokes this call-back method during deserialization to create an instance of the\n * specified type. The fields of the returned instance are overwritten with the data present in\n * the JSON. Since the prior contents of the object are destroyed and overwritten, do not return\n * an instance that is useful elsewhere. In particular, do not return a common instance, always\n * use {@code new} to create a new instance.\n *\n * @param type the parameterized T represented as a {@link Type}.\n * @return a default object instance of type T.\n */\n public T createInstance(Type type);\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonNull.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\n/**\n * A class representing a JSON {@code null} value.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @since 1.2\n */\npublic final class JsonNull extends JsonElement {\n /**\n * Singleton for {@code JsonNull}.\n *\n * @since 1.8\n */\n public static final JsonNull INSTANCE = new JsonNull();\n\n /**\n * Creates a new {@code JsonNull} object.\n *\n * @deprecated Deprecated since Gson version 1.8, use {@link #INSTANCE} instead.\n */\n @Deprecated\n public JsonNull() {\n // Do nothing\n }\n\n /**\n * Returns the same instance since it is an immutable value.\n *\n * @since 2.8.2\n */\n @Override\n public JsonNull deepCopy() {\n return INSTANCE;\n }\n\n /** All instances of {@code JsonNull} have the same hash code since they are indistinguishable. */\n @Override\n public int hashCode() {\n return JsonNull.class.hashCode();\n }\n\n /** All instances of {@code JsonNull} are considered equal. 
*/\n @Override\n public boolean equals(Object other) {\n return other instanceof JsonNull;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/LinkedTreeMap.java\n/*\n * Copyright (C) 2010 The Android Open Source Project\n * Copyright (C) 2012 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport java.io.IOException;\nimport java.io.InvalidObjectException;\nimport java.io.ObjectInputStream;\nimport java.io.ObjectStreamException;\nimport java.io.Serializable;\nimport java.util.AbstractMap;\nimport java.util.AbstractSet;\nimport java.util.Comparator;\nimport java.util.ConcurrentModificationException;\nimport java.util.Iterator;\nimport java.util.LinkedHashMap;\nimport java.util.NoSuchElementException;\nimport java.util.Objects;\nimport java.util.Set;\n\n/**\n * A map of comparable keys to values. Unlike {@code TreeMap}, this class uses insertion order for\n * iteration order. Comparison order is only used as an optimization for efficient insertion and\n * removal.\n *\n *
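A short sketch of the insertion-order behaviour described above. Note that `LinkedTreeMap` lives in Gson's `internal` package, so this only illustrates the documented behaviour rather than suggesting a direct dependency on the class; the key/value data is made up.

```java
import com.google.gson.internal.LinkedTreeMap;

public class LinkedTreeMapOrderDemo {
  public static void main(String[] args) {
    LinkedTreeMap<String, Integer> map = new LinkedTreeMap<>();
    map.put("banana", 2);
    map.put("apple", 1);
    map.put("cherry", 3);

    // Iteration follows insertion order, not the comparator's sorted order.
    System.out.println(map.keySet()); // [banana, apple, cherry]
  }
}
```

This is also the map that backs `JsonObject` members (see the `new LinkedTreeMap<>(false)` field further below), which is why `JsonObject` preserves the order in which properties were added.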

    This implementation was derived from Android 4.1's TreeMap class.\n */\n@SuppressWarnings(\"serial\") // ignore warning about missing serialVersionUID\npublic final class LinkedTreeMap extends AbstractMap implements Serializable {\n @SuppressWarnings({\"unchecked\", \"rawtypes\"}) // to avoid Comparable>>\n private static final Comparator NATURAL_ORDER =\n new Comparator() {\n @Override\n public int compare(Comparable a, Comparable b) {\n return a.compareTo(b);\n }\n };\n\n private final Comparator comparator;\n private final boolean allowNullValues;\n Node root;\n int size = 0;\n int modCount = 0;\n\n // Used to preserve iteration order\n final Node header;\n\n /**\n * Create a natural order, empty tree map whose keys must be mutually comparable and non-null, and\n * whose values can be {@code null}.\n */\n @SuppressWarnings(\"unchecked\") // unsafe! this assumes K is comparable\n public LinkedTreeMap() {\n this((Comparator) NATURAL_ORDER, true);\n }\n\n /**\n * Create a natural order, empty tree map whose keys must be mutually comparable and non-null.\n *\n * @param allowNullValues whether {@code null} is allowed as entry value\n */\n @SuppressWarnings(\"unchecked\") // unsafe! this assumes K is comparable\n public LinkedTreeMap(boolean allowNullValues) {\n this((Comparator) NATURAL_ORDER, allowNullValues);\n }\n\n /**\n * Create a tree map ordered by {@code comparator}. This map's keys may only be null if {@code\n * comparator} permits.\n *\n * @param comparator the comparator to order elements with, or {@code null} to use the natural\n * ordering.\n * @param allowNullValues whether {@code null} is allowed as entry value\n */\n // unsafe! if comparator is null, this assumes K is comparable\n @SuppressWarnings({\"unchecked\", \"rawtypes\"})\n public LinkedTreeMap(Comparator comparator, boolean allowNullValues) {\n this.comparator = comparator != null ? comparator : (Comparator) NATURAL_ORDER;\n this.allowNullValues = allowNullValues;\n this.header = new Node<>(allowNullValues);\n }\n\n @Override\n public int size() {\n return size;\n }\n\n @Override\n public V get(Object key) {\n Node node = findByObject(key);\n return node != null ? node.value : null;\n }\n\n @Override\n public boolean containsKey(Object key) {\n return findByObject(key) != null;\n }\n\n @CanIgnoreReturnValue\n @Override\n public V put(K key, V value) {\n if (key == null) {\n throw new NullPointerException(\"key == null\");\n }\n if (value == null && !allowNullValues) {\n throw new NullPointerException(\"value == null\");\n }\n Node created = find(key, true);\n V result = created.value;\n created.value = value;\n return result;\n }\n\n @Override\n public void clear() {\n root = null;\n size = 0;\n modCount++;\n\n // Clear iteration order\n Node header = this.header;\n header.next = header.prev = header;\n }\n\n @Override\n public V remove(Object key) {\n Node node = removeInternalByKey(key);\n return node != null ? node.value : null;\n }\n\n /**\n * Returns the node at or adjacent to the given key, creating it if requested.\n *\n * @throws ClassCastException if {@code key} and the tree's keys aren't mutually comparable.\n */\n Node find(K key, boolean create) {\n Comparator comparator = this.comparator;\n Node nearest = root;\n int comparison = 0;\n\n if (nearest != null) {\n // Micro-optimization: avoid polymorphic calls to Comparator.compare().\n @SuppressWarnings(\"unchecked\") // Throws a ClassCastException below if there's trouble.\n Comparable comparableKey =\n (comparator == NATURAL_ORDER) ? 
(Comparable) key : null;\n\n while (true) {\n comparison =\n (comparableKey != null)\n ? comparableKey.compareTo(nearest.key)\n : comparator.compare(key, nearest.key);\n\n // We found the requested key.\n if (comparison == 0) {\n return nearest;\n }\n\n // If it exists, the key is in a subtree. Go deeper.\n Node child = (comparison < 0) ? nearest.left : nearest.right;\n if (child == null) {\n break;\n }\n\n nearest = child;\n }\n }\n\n // The key doesn't exist in this tree.\n if (!create) {\n return null;\n }\n\n // Create the node and add it to the tree or the table.\n Node header = this.header;\n Node created;\n if (nearest == null) {\n // Check that the value is comparable if we didn't do any comparisons.\n if (comparator == NATURAL_ORDER && !(key instanceof Comparable)) {\n throw new ClassCastException(key.getClass().getName() + \" is not Comparable\");\n }\n created = new Node<>(allowNullValues, nearest, key, header, header.prev);\n root = created;\n } else {\n created = new Node<>(allowNullValues, nearest, key, header, header.prev);\n if (comparison < 0) { // nearest.key is higher\n nearest.left = created;\n } else { // comparison > 0, nearest.key is lower\n nearest.right = created;\n }\n rebalance(nearest, true);\n }\n size++;\n modCount++;\n\n return created;\n }\n\n @SuppressWarnings(\"unchecked\")\n Node findByObject(Object key) {\n try {\n return key != null ? find((K) key, false) : null;\n } catch (ClassCastException e) {\n return null;\n }\n }\n\n /**\n * Returns this map's entry that has the same key and value as {@code entry}, or null if this map\n * has no such entry.\n *\n *

    This method uses the comparator for key equality rather than {@code equals}. If this map's\n * comparator isn't consistent with equals (such as {@code String.CASE_INSENSITIVE_ORDER}), then\n * {@code remove()} and {@code contains()} will violate the collections API.\n */\n Node findByEntry(Entry entry) {\n Node mine = findByObject(entry.getKey());\n boolean valuesEqual = mine != null && equal(mine.value, entry.getValue());\n return valuesEqual ? mine : null;\n }\n\n private static boolean equal(Object a, Object b) {\n return Objects.equals(a, b);\n }\n\n /**\n * Removes {@code node} from this tree, rearranging the tree's structure as necessary.\n *\n * @param unlink true to also unlink this node from the iteration linked list.\n */\n void removeInternal(Node node, boolean unlink) {\n if (unlink) {\n node.prev.next = node.next;\n node.next.prev = node.prev;\n }\n\n Node left = node.left;\n Node right = node.right;\n Node originalParent = node.parent;\n if (left != null && right != null) {\n\n /*\n * To remove a node with both left and right subtrees, move an\n * adjacent node from one of those subtrees into this node's place.\n *\n * Removing the adjacent node may change this node's subtrees. This\n * node may no longer have two subtrees once the adjacent node is\n * gone!\n */\n\n Node adjacent = (left.height > right.height) ? left.last() : right.first();\n removeInternal(adjacent, false); // takes care of rebalance and size--\n\n int leftHeight = 0;\n left = node.left;\n if (left != null) {\n leftHeight = left.height;\n adjacent.left = left;\n left.parent = adjacent;\n node.left = null;\n }\n\n int rightHeight = 0;\n right = node.right;\n if (right != null) {\n rightHeight = right.height;\n adjacent.right = right;\n right.parent = adjacent;\n node.right = null;\n }\n\n adjacent.height = Math.max(leftHeight, rightHeight) + 1;\n replaceInParent(node, adjacent);\n return;\n } else if (left != null) {\n replaceInParent(node, left);\n node.left = null;\n } else if (right != null) {\n replaceInParent(node, right);\n node.right = null;\n } else {\n replaceInParent(node, null);\n }\n\n rebalance(originalParent, false);\n size--;\n modCount++;\n }\n\n Node removeInternalByKey(Object key) {\n Node node = findByObject(key);\n if (node != null) {\n removeInternal(node, true);\n }\n return node;\n }\n\n @SuppressWarnings(\"ReferenceEquality\")\n private void replaceInParent(Node node, Node replacement) {\n Node parent = node.parent;\n node.parent = null;\n if (replacement != null) {\n replacement.parent = parent;\n }\n\n if (parent != null) {\n if (parent.left == node) {\n parent.left = replacement;\n } else {\n assert parent.right == node;\n parent.right = replacement;\n }\n } else {\n root = replacement;\n }\n }\n\n /**\n * Rebalances the tree by making any AVL rotations necessary between the newly-unbalanced node and\n * the tree's root.\n *\n * @param insert true if the node was unbalanced by an insert; false if it was by a removal.\n */\n private void rebalance(Node unbalanced, boolean insert) {\n for (Node node = unbalanced; node != null; node = node.parent) {\n Node left = node.left;\n Node right = node.right;\n int leftHeight = left != null ? left.height : 0;\n int rightHeight = right != null ? right.height : 0;\n\n int delta = leftHeight - rightHeight;\n if (delta == -2) {\n Node rightLeft = right.left;\n Node rightRight = right.right;\n int rightRightHeight = rightRight != null ? rightRight.height : 0;\n int rightLeftHeight = rightLeft != null ? 
rightLeft.height : 0;\n\n int rightDelta = rightLeftHeight - rightRightHeight;\n if (rightDelta == -1 || (rightDelta == 0 && !insert)) {\n rotateLeft(node); // AVL right right\n } else {\n assert (rightDelta == 1);\n rotateRight(right); // AVL right left\n rotateLeft(node);\n }\n if (insert) {\n break; // no further rotations will be necessary\n }\n\n } else if (delta == 2) {\n Node leftLeft = left.left;\n Node leftRight = left.right;\n int leftRightHeight = leftRight != null ? leftRight.height : 0;\n int leftLeftHeight = leftLeft != null ? leftLeft.height : 0;\n\n int leftDelta = leftLeftHeight - leftRightHeight;\n if (leftDelta == 1 || (leftDelta == 0 && !insert)) {\n rotateRight(node); // AVL left left\n } else {\n assert (leftDelta == -1);\n rotateLeft(left); // AVL left right\n rotateRight(node);\n }\n if (insert) {\n break; // no further rotations will be necessary\n }\n\n } else if (delta == 0) {\n node.height = leftHeight + 1; // leftHeight == rightHeight\n if (insert) {\n break; // the insert caused balance, so rebalancing is done!\n }\n\n } else {\n assert (delta == -1 || delta == 1);\n node.height = Math.max(leftHeight, rightHeight) + 1;\n if (!insert) {\n break; // the height hasn't changed, so rebalancing is done!\n }\n }\n }\n }\n\n /** Rotates the subtree so that its root's right child is the new root. */\n \nprivate void rotateLeft(Node root) {\n Node left = root.left;\n Node pivot = root.right;\n Node pivotLeft = pivot.left;\n Node pivotRight = pivot.right;\n\n // move the pivot's left child to the root's right\n root.right = pivotLeft;\n if (pivotLeft != null) {\n pivotLeft.parent = root;\n }\n\n replaceInParent(root, pivot);\n\n // move the root to the pivot's left\n pivot.left = root;\n root.parent = pivot;\n\n // fix heights\n root.height =\n Math.max(left != null ? left.height : 0, pivotLeft != null ? pivotLeft.height : 0) + 1;\n pivot.height = Math.max(root.height, pivotRight != null ? pivotRight.height : 0) + 1;\n }\n\n /** Rotates the subtree so that its root's left child is the new root. */\n private void rotateRight(Node root) {\n Node pivot = root.left;\n Node right = root.right;\n Node pivotLeft = pivot.left;\n Node pivotRight = pivot.right;\n\n // move the pivot's right child to the root's left\n root.left = pivotRight;\n if (pivotRight != null) {\n pivotRight.parent = root;\n }\n\n replaceInParent(root, pivot);\n\n // move the root to the pivot's right\n pivot.right = root;\n root.parent = pivot;\n\n // fixup heights\n root.height =\n Math.max(right != null ? right.height : 0, pivotRight != null ? pivotRight.height : 0) + 1;\n pivot.height = Math.max(root.height, pivotLeft != null ? pivotLeft.height : 0) + 1;\n }\n\n private EntrySet entrySet;\n private KeySet keySet;\n\n @Override\n public Set> entrySet() {\n EntrySet result = entrySet;\n return result != null ? result : (entrySet = new EntrySet());\n }\n\n @Override\n public Set keySet() {\n KeySet result = keySet;\n return result != null ? 
result : (keySet = new KeySet());\n }\n\n static final class Node implements Entry {\n Node parent;\n Node left;\n Node right;\n Node next;\n Node prev;\n final K key;\n final boolean allowNullValue;\n V value;\n int height;\n\n /** Create the header entry */\n Node(boolean allowNullValue) {\n key = null;\n this.allowNullValue = allowNullValue;\n next = prev = this;\n }\n\n /** Create a regular entry */\n Node(boolean allowNullValue, Node parent, K key, Node next, Node prev) {\n this.parent = parent;\n this.key = key;\n this.allowNullValue = allowNullValue;\n this.height = 1;\n this.next = next;\n this.prev = prev;\n prev.next = this;\n next.prev = this;\n }\n\n @Override\n public K getKey() {\n return key;\n }\n\n @Override\n public V getValue() {\n return value;\n }\n\n @Override\n public V setValue(V value) {\n if (value == null && !allowNullValue) {\n throw new NullPointerException(\"value == null\");\n }\n V oldValue = this.value;\n this.value = value;\n return oldValue;\n }\n\n @Override\n public boolean equals(Object o) {\n if (o instanceof Entry) {\n Entry other = (Entry) o;\n return (key == null ? other.getKey() == null : key.equals(other.getKey()))\n && (value == null ? other.getValue() == null : value.equals(other.getValue()));\n }\n return false;\n }\n\n @Override\n public int hashCode() {\n return (key == null ? 0 : key.hashCode()) ^ (value == null ? 0 : value.hashCode());\n }\n\n @Override\n public String toString() {\n return key + \"=\" + value;\n }\n\n /** Returns the first node in this subtree. */\n public Node first() {\n Node node = this;\n Node child = node.left;\n while (child != null) {\n node = child;\n child = node.left;\n }\n return node;\n }\n\n /** Returns the last node in this subtree. */\n public Node last() {\n Node node = this;\n Node child = node.right;\n while (child != null) {\n node = child;\n child = node.right;\n }\n return node;\n }\n }\n\n private abstract class LinkedTreeMapIterator implements Iterator {\n Node next = header.next;\n Node lastReturned = null;\n int expectedModCount = modCount;\n\n LinkedTreeMapIterator() {}\n\n @Override\n @SuppressWarnings(\"ReferenceEquality\")\n public final boolean hasNext() {\n return next != header;\n }\n\n @SuppressWarnings(\"ReferenceEquality\")\n final Node nextNode() {\n Node e = next;\n if (e == header) {\n throw new NoSuchElementException();\n }\n if (modCount != expectedModCount) {\n throw new ConcurrentModificationException();\n }\n next = e.next;\n return lastReturned = e;\n }\n\n @Override\n public final void remove() {\n if (lastReturned == null) {\n throw new IllegalStateException();\n }\n removeInternal(lastReturned, true);\n lastReturned = null;\n expectedModCount = modCount;\n }\n }\n\n class EntrySet extends AbstractSet> {\n @Override\n public int size() {\n return size;\n }\n\n @Override\n public Iterator> iterator() {\n return new LinkedTreeMapIterator>() {\n @Override\n public Entry next() {\n return nextNode();\n }\n };\n }\n\n @Override\n public boolean contains(Object o) {\n return o instanceof Entry && findByEntry((Entry) o) != null;\n }\n\n @Override\n public boolean remove(Object o) {\n if (!(o instanceof Entry)) {\n return false;\n }\n\n Node node = findByEntry((Entry) o);\n if (node == null) {\n return false;\n }\n removeInternal(node, true);\n return true;\n }\n\n @Override\n public void clear() {\n LinkedTreeMap.this.clear();\n }\n }\n\n final class KeySet extends AbstractSet {\n @Override\n public int size() {\n return size;\n }\n\n @Override\n public Iterator iterator() {\n return 
new LinkedTreeMapIterator() {\n @Override\n public K next() {\n return nextNode().key;\n }\n };\n }\n\n @Override\n public boolean contains(Object o) {\n return containsKey(o);\n }\n\n @Override\n public boolean remove(Object key) {\n return removeInternalByKey(key) != null;\n }\n\n @Override\n public void clear() {\n LinkedTreeMap.this.clear();\n }\n }\n\n /**\n * If somebody is unlucky enough to have to serialize one of these, serialize it as a\n * LinkedHashMap so that they won't need Gson on the other side to deserialize it. Using\n * serialization defeats our DoS defence, so most apps shouldn't use it.\n */\n private Object writeReplace() throws ObjectStreamException {\n return new LinkedHashMap<>(this);\n }\n\n private void readObject(ObjectInputStream in) throws IOException {\n // Don't permit directly deserializing this class; writeReplace() should have written a\n // replacement\n throw new InvalidObjectException(\"Deserialization is unsupported\");\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonObject.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.internal.LinkedTreeMap;\nimport java.util.Map;\nimport java.util.Set;\n\n/**\n * A class representing an object type in Json. An object consists of name-value pairs where names\n * are strings, and values are any other type of {@link JsonElement}. This allows for a creating a\n * tree of JsonElements. The member elements of this object are maintained in order they were added.\n * This class does not support {@code null} values. If {@code null} is provided as value argument to\n * any of the methods, it is converted to a {@link JsonNull}.\n *\n *
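A small sketch of the null handling just described: passing `null` to `addProperty` stores a `JsonNull` member rather than dropping the key. The property names and values are arbitrary.

```java
import com.google.gson.JsonObject;

public class JsonObjectNullDemo {
  public static void main(String[] args) {
    JsonObject obj = new JsonObject();
    obj.addProperty("name", (String) null); // converted to JsonNull, not rejected
    obj.addProperty("count", 3);

    System.out.println(obj.get("name").isJsonNull()); // true
    System.out.println(obj); // {"name":null,"count":3}
  }
}
```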

    {@code JsonObject} does not implement the {@link Map} interface, but a {@code Map} view of it\n * can be obtained with {@link #asMap()}.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonObject extends JsonElement {\n private final LinkedTreeMap members = new LinkedTreeMap<>(false);\n\n /** Creates an empty JsonObject. */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonObject() {}\n\n /**\n * Creates a deep copy of this element and all its children.\n *\n * @since 2.8.2\n */\n @Override\n public JsonObject deepCopy() {\n JsonObject result = new JsonObject();\n for (Map.Entry entry : members.entrySet()) {\n result.add(entry.getKey(), entry.getValue().deepCopy());\n }\n return result;\n }\n\n /**\n * Adds a member, which is a name-value pair, to self. The name must be a String, but the value\n * can be an arbitrary {@link JsonElement}, thereby allowing you to build a full tree of\n * JsonElements rooted at this node.\n *\n * @param property name of the member.\n * @param value the member object.\n */\n public void add(String property, JsonElement value) {\n members.put(property, value == null ? JsonNull.INSTANCE : value);\n }\n\n /**\n * Removes the {@code property} from this object.\n *\n * @param property name of the member that should be removed.\n * @return the {@link JsonElement} object that is being removed, or {@code null} if no member with\n * this name exists.\n * @since 1.3\n */\n @CanIgnoreReturnValue\n public JsonElement remove(String property) {\n return members.remove(property);\n }\n\n /**\n * Convenience method to add a string member. The specified value is converted to a {@link\n * JsonPrimitive} of String.\n *\n * @param property name of the member.\n * @param value the string value associated with the member.\n */\n public void addProperty(String property, String value) {\n add(property, value == null ? JsonNull.INSTANCE : new JsonPrimitive(value));\n }\n\n /**\n * Convenience method to add a number member. The specified value is converted to a {@link\n * JsonPrimitive} of Number.\n *\n * @param property name of the member.\n * @param value the number value associated with the member.\n */\n public void addProperty(String property, Number value) {\n add(property, value == null ? JsonNull.INSTANCE : new JsonPrimitive(value));\n }\n\n /**\n * Convenience method to add a boolean member. The specified value is converted to a {@link\n * JsonPrimitive} of Boolean.\n *\n * @param property name of the member.\n * @param value the boolean value associated with the member.\n */\n public void addProperty(String property, Boolean value) {\n add(property, value == null ? JsonNull.INSTANCE : new JsonPrimitive(value));\n }\n\n /**\n * Convenience method to add a char member. The specified value is converted to a {@link\n * JsonPrimitive} of Character.\n *\n * @param property name of the member.\n * @param value the char value associated with the member.\n */\n public void addProperty(String property, Character value) {\n add(property, value == null ? JsonNull.INSTANCE : new JsonPrimitive(value));\n }\n\n /**\n * Returns a set of members of this object. 
The set is ordered, and the order is in which the\n * elements were added.\n *\n * @return a set of members of this object.\n */\n public Set> entrySet() {\n return members.entrySet();\n }\n\n /**\n * Returns a set of members key values.\n *\n * @return a set of member keys as Strings\n * @since 2.8.1\n */\n public Set keySet() {\n return members.keySet();\n }\n\n /**\n * Returns the number of key/value pairs in the object.\n *\n * @return the number of key/value pairs in the object.\n * @since 2.7\n */\n public int size() {\n return members.size();\n }\n\n /**\n * Returns true if the number of key/value pairs in the object is zero.\n *\n * @return true if the number of key/value pairs in the object is zero.\n * @since 2.10.1\n */\n public boolean isEmpty() {\n return members.size() == 0;\n }\n\n /**\n * Convenience method to check if a member with the specified name is present in this object.\n *\n * @param memberName name of the member that is being checked for presence.\n * @return true if there is a member with the specified name, false otherwise.\n */\n public boolean has(String memberName) {\n return members.containsKey(memberName);\n }\n\n /**\n * Returns the member with the specified name.\n *\n * @param memberName name of the member that is being requested.\n * @return the member matching the name, or {@code null} if no such member exists.\n */\n public JsonElement get(String memberName) {\n return members.get(memberName);\n }\n\n /**\n * Convenience method to get the specified member as a {@link JsonPrimitive}.\n *\n * @param memberName name of the member being requested.\n * @return the {@code JsonPrimitive} corresponding to the specified member, or {@code null} if no\n * member with this name exists.\n * @throws ClassCastException if the member is not of type {@code JsonPrimitive}.\n */\n public JsonPrimitive getAsJsonPrimitive(String memberName) {\n return (JsonPrimitive) members.get(memberName);\n }\n\n /**\n * Convenience method to get the specified member as a {@link JsonArray}.\n *\n * @param memberName name of the member being requested.\n * @return the {@code JsonArray} corresponding to the specified member, or {@code null} if no\n * member with this name exists.\n * @throws ClassCastException if the member is not of type {@code JsonArray}.\n */\n public JsonArray getAsJsonArray(String memberName) {\n return (JsonArray) members.get(memberName);\n }\n\n /**\n * Convenience method to get the specified member as a {@link JsonObject}.\n *\n * @param memberName name of the member being requested.\n * @return the {@code JsonObject} corresponding to the specified member, or {@code null} if no\n * member with this name exists.\n * @throws ClassCastException if the member is not of type {@code JsonObject}.\n */\n public JsonObject getAsJsonObject(String memberName) {\n return (JsonObject) members.get(memberName);\n }\n\n /**\n * Returns a mutable {@link Map} view of this {@code JsonObject}. Changes to the {@code Map} are\n * visible in this {@code JsonObject} and the other way around.\n *\n *
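A brief sketch of the two-way view described above, assuming the 2.10+ `asMap()` API is available; the member names are arbitrary.

```java
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.util.Map;

public class JsonObjectAsMapDemo {
  public static void main(String[] args) {
    JsonObject obj = new JsonObject();
    obj.addProperty("a", 1);

    Map<String, JsonElement> view = obj.asMap();
    view.put("b", new JsonPrimitive(2)); // mutate through the Map view...

    System.out.println(obj.has("b")); // ...and the change is visible here: true

    obj.remove("a"); // ...and changes to the JsonObject show up in the view
    System.out.println(view.containsKey("a")); // false
  }
}
```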

    The {@code Map} does not permit {@code null} keys or values. Unlike {@code JsonObject}'s\n * {@code null} handling, a {@link NullPointerException} is thrown when trying to add {@code\n * null}. Use {@link JsonNull} for JSON null values.\n *\n * @return mutable {@code Map} view\n * @since 2.10\n */\n public Map asMap() {\n // It is safe to expose the underlying map because it disallows null keys and values\n return members;\n }\n\n /**\n * Returns whether the other object is equal to this. This method only considers the other object\n * to be equal if it is an instance of {@code JsonObject} and has equal members, ignoring order.\n */\n @Override\n public boolean equals(Object o) {\n return (o == this) || (o instanceof JsonObject && ((JsonObject) o).members.equals(members));\n }\n\n /**\n * Returns the hash code of this object. This method calculates the hash code based on the members\n * of this object, ignoring order.\n */\n @Override\n public int hashCode() {\n return members.hashCode();\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonSerializationContext.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport java.lang.reflect.Type;\n\n/**\n * Context for serialization that is passed to a custom serializer during invocation of its {@link\n * JsonSerializer#serialize(Object, Type, JsonSerializationContext)} method.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic interface JsonSerializationContext {\n\n /**\n * Invokes default serialization on the specified object.\n *\n * @param src the object that needs to be serialized.\n * @return a tree of {@link JsonElement}s corresponding to the serialized form of {@code src}.\n */\n public JsonElement serialize(Object src);\n\n /**\n * Invokes default serialization on the specified object passing the specific type information. It\n * should never be invoked on the element received as a parameter of the {@link\n * JsonSerializer#serialize(Object, Type, JsonSerializationContext)} method. 
Doing so will result\n * in an infinite loop since Gson will in-turn call the custom serializer again.\n *\n * @param src the object that needs to be serialized.\n * @param typeOfSrc the actual genericized type of src object.\n * @return a tree of {@link JsonElement}s corresponding to the serialized form of {@code src}.\n */\n public JsonElement serialize(Object src, Type typeOfSrc);\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonSerializer.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport java.lang.reflect.Type;\n\n/**\n * Interface representing a custom serializer for JSON. You should write a custom serializer, if you\n * are not happy with the default serialization done by Gson. You will also need to register this\n * serializer through {@link com.google.gson.GsonBuilder#registerTypeAdapter(Type, Object)}.\n *\n *

    Let us look at example where defining a serializer will be useful. The {@code Id} class\n * defined below has two fields: {@code clazz} and {@code value}.\n *\n *

    \n * public class Id<T> {\n *   private final Class<T> clazz;\n *   private final long value;\n *\n *   public Id(Class<T> clazz, long value) {\n *     this.clazz = clazz;\n *     this.value = value;\n *   }\n *\n *   public long getValue() {\n *     return value;\n *   }\n * }\n * 
    \n *\n *

    The default serialization of {@code Id(com.foo.MyObject.class, 20L)} will be \n * {\"clazz\":\"com.foo.MyObject\",\"value\":20}. Suppose, you just want the output to be the value\n * instead, which is {@code 20} in this case. You can achieve that by writing a custom serializer:\n *\n *

    \n * class IdSerializer implements JsonSerializer<Id> {\n *   public JsonElement serialize(Id id, Type typeOfId, JsonSerializationContext context) {\n *     return new JsonPrimitive(id.getValue());\n *   }\n * }\n * 
    \n *\n *

    You will also need to register {@code IdSerializer} with Gson as follows:\n *\n *

    \n * Gson gson = new GsonBuilder().registerTypeAdapter(Id.class, new IdSerializer()).create();\n * 
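Assembled end to end, the serializer example above behaves as the earlier paragraph describes: the whole object collapses to its numeric value. `String.class` stands in for the `com.foo.MyObject` class mentioned in the prose and is only an illustrative assumption, as is the `main` wrapper.

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;

public class IdSerializerDemo {
  static class Id<T> {
    private final Class<T> clazz;
    private final long value;

    Id(Class<T> clazz, long value) {
      this.clazz = clazz;
      this.value = value;
    }

    long getValue() {
      return value;
    }
  }

  // Writes an Id as just its numeric value, as in the Javadoc snippet above.
  static class IdSerializer implements JsonSerializer<Id<?>> {
    @Override
    public JsonElement serialize(Id<?> id, Type typeOfId, JsonSerializationContext context) {
      return new JsonPrimitive(id.getValue());
    }
  }

  public static void main(String[] args) {
    Gson gson = new GsonBuilder().registerTypeAdapter(Id.class, new IdSerializer()).create();
    System.out.println(gson.toJson(new Id<>(String.class, 20L))); // 20
  }
}
```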
    \n *\n *

    Serializers should be stateless and thread-safe, otherwise the thread-safety guarantees of\n * {@link Gson} might not apply.\n *\n *

    New applications should prefer {@link TypeAdapter}, whose streaming API is more efficient than\n * this interface's tree API.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n * @param type for which the serializer is being registered. It is possible that a serializer\n * may be asked to serialize a specific generic type of the T.\n */\npublic interface JsonSerializer {\n\n /**\n * Gson invokes this call-back method during serialization when it encounters a field of the\n * specified type.\n *\n *

    In the implementation of this call-back method, you should consider invoking {@link\n * JsonSerializationContext#serialize(Object, Type)} method to create JsonElements for any\n * non-trivial field of the {@code src} object. However, you should never invoke it on the {@code\n * src} object itself since that will cause an infinite loop (Gson will call your call-back method\n * again).\n *\n * @param src the object that needs to be converted to Json.\n * @param typeOfSrc the actual type (fully genericized version) of the source object.\n * @return a JsonElement corresponding to the specified object.\n */\n public JsonElement serialize(T src, Type typeOfSrc, JsonSerializationContext context);\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/NumberLimits.java\npackage com.google.gson.internal;\n\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\n\n/**\n * This class enforces limits on numbers parsed from JSON to avoid potential performance problems\n * when extremely large numbers are used.\n */\npublic class NumberLimits {\n private NumberLimits() {}\n\n private static final int MAX_NUMBER_STRING_LENGTH = 10_000;\n\n private static void checkNumberStringLength(String s) {\n if (s.length() > MAX_NUMBER_STRING_LENGTH) {\n throw new NumberFormatException(\"Number string too large: \" + s.substring(0, 30) + \"...\");\n }\n }\n\n public static BigDecimal parseBigDecimal(String s) throws NumberFormatException {\n checkNumberStringLength(s);\n BigDecimal decimal = new BigDecimal(s);\n\n // Cast to long to avoid issues with abs when value is Integer.MIN_VALUE\n if (Math.abs((long) decimal.scale()) >= 10_000) {\n throw new NumberFormatException(\"Number has unsupported scale: \" + s);\n }\n return decimal;\n }\n\n public static BigInteger parseBigInteger(String s) throws NumberFormatException {\n checkNumberStringLength(s);\n return new BigInteger(s);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/LazilyParsedNumber.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson.internal;\n\nimport java.io.IOException;\nimport java.io.InvalidObjectException;\nimport java.io.ObjectInputStream;\nimport java.io.ObjectStreamException;\nimport java.math.BigDecimal;\n\n/**\n * This class holds a number value that is lazily converted to a specific number type\n *\n * @author Inderjeet Singh\n */\n@SuppressWarnings(\"serial\") // ignore warning about missing serialVersionUID\npublic final class LazilyParsedNumber extends Number {\n private final String value;\n\n /**\n * @param value must not be null\n */\n public LazilyParsedNumber(String value) {\n this.value = value;\n }\n\n private BigDecimal asBigDecimal() {\n return NumberLimits.parseBigDecimal(value);\n }\n\n @Override\n public int intValue() {\n try {\n return Integer.parseInt(value);\n } catch (NumberFormatException e) {\n try {\n return (int) Long.parseLong(value);\n } catch (NumberFormatException nfe) {\n return asBigDecimal().intValue();\n 
}\n }\n }\n\n @Override\n public long longValue() {\n try {\n return Long.parseLong(value);\n } catch (NumberFormatException e) {\n return asBigDecimal().longValue();\n }\n }\n\n @Override\n public float floatValue() {\n return Float.parseFloat(value);\n }\n\n @Override\n public double doubleValue() {\n return Double.parseDouble(value);\n }\n\n @Override\n public String toString() {\n...\n", "needle_token_start": 6506, "needle_token_end": 6748, "code_context_ntokens": 13775, "output": ["```java\nprivate void rotateLeft(Node root) {\n Node left = root.left;\n Node pivot = root.right;\n Node pivotLeft = pivot.left;\n Node pivotRight = pivot.right;\n\n // move the pivot's left child to the root's right\n root.right = pivotLeft;\n if (pivotLeft != null) {\n pivotLeft.parent = root;\n }\n\n replaceInParent(root, pivot);\n\n // move the root to the pivot's left\n pivot.left = root;\n root.parent = pivot;\n\n // fix heights\n root.height =\n Math.max(left != null ? left.height : 0, pivotLeft != null ? pivotLeft.height : 0) + 1;\n pivot.height = Math.max(root.height, pivotRight != null ? pivotRight.height : 0) + 1;\n }\n```"]} +{"repo": "google/gson", "name": "getArrayComponentType", "language": "java", "path": "gson/src/main/java/com/google/gson/internal/$Gson$Types.java", "position_ratio": 0.65, "description": "\nFunction Description:\n1. **Purpose**: To retrieve the type of elements contained within an array type, whether it is a generic or non-generic array.\n2. **Input**: A `Type` object representing an array.\n3. **Output**: A `Type` object representing the component type of the input array.\n4. **Procedure**: The function first checks if the input type is a generic array. If it is, it extracts the generic component type. If the input type is a standard array, it retrieves the component type directly. If the input is not an array type, the function throws a `ClassCastException`.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: gson/src/main/java/com/google/gson/JsonArray.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.errorprone.annotations.CanIgnoreReturnValue;\nimport com.google.gson.internal.NonNullElementWrapperList;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\nimport java.util.ArrayList;\nimport java.util.Iterator;\nimport java.util.List;\n\n/**\n * A class representing an array type in JSON. An array is a list of {@link JsonElement}s each of\n * which can be of a different type. This is an ordered list, meaning that the order in which\n * elements are added is preserved. This class does not support {@code null} values. 
If {@code null}\n * is provided as element argument to any of the methods, it is converted to a {@link JsonNull}.\n *\n *

    {@code JsonArray} only implements the {@link Iterable} interface but not the {@link List}\n * interface. A {@code List} view of it can be obtained with {@link #asList()}.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonArray extends JsonElement implements Iterable {\n private final ArrayList elements;\n\n /** Creates an empty JsonArray. */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonArray() {\n elements = new ArrayList<>();\n }\n\n /**\n * Creates an empty JsonArray with the desired initial capacity.\n *\n * @param capacity initial capacity.\n * @throws IllegalArgumentException if the {@code capacity} is negative\n * @since 2.8.1\n */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonArray(int capacity) {\n elements = new ArrayList<>(capacity);\n }\n\n /**\n * Creates a deep copy of this element and all its children.\n *\n * @since 2.8.2\n */\n @Override\n public JsonArray deepCopy() {\n if (!elements.isEmpty()) {\n JsonArray result = new JsonArray(elements.size());\n for (JsonElement element : elements) {\n result.add(element.deepCopy());\n }\n return result;\n }\n return new JsonArray();\n }\n\n /**\n * Adds the specified boolean to self.\n *\n * @param bool the boolean that needs to be added to the array.\n * @since 2.4\n */\n public void add(Boolean bool) {\n elements.add(bool == null ? JsonNull.INSTANCE : new JsonPrimitive(bool));\n }\n\n /**\n * Adds the specified character to self.\n *\n * @param character the character that needs to be added to the array.\n * @since 2.4\n */\n public void add(Character character) {\n elements.add(character == null ? JsonNull.INSTANCE : new JsonPrimitive(character));\n }\n\n /**\n * Adds the specified number to self.\n *\n * @param number the number that needs to be added to the array.\n * @since 2.4\n */\n public void add(Number number) {\n elements.add(number == null ? JsonNull.INSTANCE : new JsonPrimitive(number));\n }\n\n /**\n * Adds the specified string to self.\n *\n * @param string the string that needs to be added to the array.\n * @since 2.4\n */\n public void add(String string) {\n elements.add(string == null ? JsonNull.INSTANCE : new JsonPrimitive(string));\n }\n\n /**\n * Adds the specified element to self.\n *\n * @param element the element that needs to be added to the array.\n */\n public void add(JsonElement element) {\n if (element == null) {\n element = JsonNull.INSTANCE;\n }\n elements.add(element);\n }\n\n /**\n * Adds all the elements of the specified array to self.\n *\n * @param array the array whose elements need to be added to the array.\n */\n public void addAll(JsonArray array) {\n elements.addAll(array.elements);\n }\n\n /**\n * Replaces the element at the specified position in this array with the specified element.\n *\n * @param index index of the element to replace\n * @param element element to be stored at the specified position\n * @return the element previously at the specified position\n * @throws IndexOutOfBoundsException if the specified index is outside the array bounds\n */\n @CanIgnoreReturnValue\n public JsonElement set(int index, JsonElement element) {\n return elements.set(index, element == null ? JsonNull.INSTANCE : element);\n }\n\n /**\n * Removes the first occurrence of the specified element from this array, if it is present. 
If the\n * array does not contain the element, it is unchanged.\n *\n * @param element element to be removed from this array, if present\n * @return true if this array contained the specified element, false otherwise\n * @since 2.3\n */\n @CanIgnoreReturnValue\n public boolean remove(JsonElement element) {\n return elements.remove(element);\n }\n\n /**\n * Removes the element at the specified position in this array. Shifts any subsequent elements to\n * the left (subtracts one from their indices). Returns the element that was removed from the\n * array.\n *\n * @param index index the index of the element to be removed\n * @return the element previously at the specified position\n * @throws IndexOutOfBoundsException if the specified index is outside the array bounds\n * @since 2.3\n */\n @CanIgnoreReturnValue\n public JsonElement remove(int index) {\n return elements.remove(index);\n }\n\n /**\n * Returns true if this array contains the specified element.\n *\n * @return true if this array contains the specified element.\n * @param element whose presence in this array is to be tested\n * @since 2.3\n */\n public boolean contains(JsonElement element) {\n return elements.contains(element);\n }\n\n /**\n * Returns the number of elements in the array.\n *\n * @return the number of elements in the array.\n */\n public int size() {\n return elements.size();\n }\n\n /**\n * Returns true if the array is empty.\n *\n * @return true if the array is empty.\n * @since 2.8.7\n */\n public boolean isEmpty() {\n return elements.isEmpty();\n }\n\n /**\n * Returns an iterator to navigate the elements of the array. Since the array is an ordered list,\n * the iterator navigates the elements in the order they were inserted.\n *\n * @return an iterator to navigate the elements of the array.\n */\n @Override\n public Iterator iterator() {\n return elements.iterator();\n }\n\n /**\n * Returns the i-th element of the array.\n *\n * @param i the index of the element that is being sought.\n * @return the element present at the i-th index.\n * @throws IndexOutOfBoundsException if {@code i} is negative or greater than or equal to the\n * {@link #size()} of the array.\n */\n public JsonElement get(int i) {\n return elements.get(i);\n }\n\n private JsonElement getAsSingleElement() {\n int size = elements.size();\n if (size == 1) {\n return elements.get(0);\n }\n throw new IllegalStateException(\"Array must have size 1, but has size \" + size);\n }\n\n /**\n * Convenience method to get this array as a {@link Number} if it contains a single element. This\n * method calls {@link JsonElement#getAsNumber()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a number if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public Number getAsNumber() {\n return getAsSingleElement().getAsNumber();\n }\n\n /**\n * Convenience method to get this array as a {@link String} if it contains a single element. 
This\n * method calls {@link JsonElement#getAsString()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a String if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public String getAsString() {\n return getAsSingleElement().getAsString();\n }\n\n /**\n * Convenience method to get this array as a double if it contains a single element. This method\n * calls {@link JsonElement#getAsDouble()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a double if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public double getAsDouble() {\n return getAsSingleElement().getAsDouble();\n }\n\n /**\n * Convenience method to get this array as a {@link BigDecimal} if it contains a single element.\n * This method calls {@link JsonElement#getAsBigDecimal()} on the element, therefore any of the\n * exceptions declared by that method can occur.\n *\n * @return this element as a {@link BigDecimal} if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n * @since 1.2\n */\n @Override\n public BigDecimal getAsBigDecimal() {\n return getAsSingleElement().getAsBigDecimal();\n }\n\n /**\n * Convenience method to get this array as a {@link BigInteger} if it contains a single element.\n * This method calls {@link JsonElement#getAsBigInteger()} on the element, therefore any of the\n * exceptions declared by that method can occur.\n *\n * @return this element as a {@link BigInteger} if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n * @since 1.2\n */\n @Override\n public BigInteger getAsBigInteger() {\n return getAsSingleElement().getAsBigInteger();\n }\n\n /**\n * Convenience method to get this array as a float if it contains a single element. This method\n * calls {@link JsonElement#getAsFloat()} on the element, therefore any of the exceptions declared\n * by that method can occur.\n *\n * @return this element as a float if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public float getAsFloat() {\n return getAsSingleElement().getAsFloat();\n }\n\n /**\n * Convenience method to get this array as a long if it contains a single element. This method\n * calls {@link JsonElement#getAsLong()} on the element, therefore any of the exceptions declared\n * by that method can occur.\n *\n * @return this element as a long if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public long getAsLong() {\n return getAsSingleElement().getAsLong();\n }\n\n /**\n * Convenience method to get this array as an integer if it contains a single element. This method\n * calls {@link JsonElement#getAsInt()} on the element, therefore any of the exceptions declared\n * by that method can occur.\n *\n * @return this element as an integer if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public int getAsInt() {\n return getAsSingleElement().getAsInt();\n }\n\n /**\n * Convenience method to get this array as a primitive byte if it contains a single element. 
This\n * method calls {@link JsonElement#getAsByte()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n *\n * @return this element as a primitive byte if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n */\n @Override\n public byte getAsByte() {\n return getAsSingleElement().getAsByte();\n }\n\n /**\n * Convenience method to get this array as a character if it contains a single element. This\n * method calls {@link JsonElement#getAsCharacter()} on the element, therefore any of the\n * exceptions declared by that method can occur.\n *\n * @return this element as a primitive short if it is single element array.\n * @throws IllegalStateException if the array is empty or has more than one element.\n * @deprecated This method is misleading, as it does not get this element as a char but rather as\n * a string's first character.\n */\n @Deprecated\n @Override\n public char getAsCharacter() {\n return getAsSingleElement().getAsCharacter();\n }\n\n /**\n * Convenience method to get this array as a primitive short if it contains a single element. This\n * method calls {@link JsonElement#getAsShort()} on the element, therefore any of the exceptions\n * declared by that method can occur.\n...\n// Path: gson/src/main/java/com/google/gson/JsonIOException.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson;\n\n/**\n * This exception is raised when Gson was unable to read an input stream or write to one.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonIOException extends JsonParseException {\n private static final long serialVersionUID = 1L;\n\n public JsonIOException(String msg) {\n super(msg);\n }\n\n public JsonIOException(String msg, Throwable cause) {\n super(msg, cause);\n }\n\n /**\n * Creates exception with the specified cause. 
Consider using {@link #JsonIOException(String,\n * Throwable)} instead if you can describe what happened.\n *\n * @param cause root exception that caused this exception to be thrown.\n */\n public JsonIOException(Throwable cause) {\n super(cause);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonSyntaxException.java\n/*\n * Copyright (C) 2010 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.google.gson;\n\n/**\n * This exception is raised when Gson attempts to read (or write) a malformed JSON element.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonSyntaxException extends JsonParseException {\n\n private static final long serialVersionUID = 1L;\n\n public JsonSyntaxException(String msg) {\n super(msg);\n }\n\n public JsonSyntaxException(String msg, Throwable cause) {\n super(msg, cause);\n }\n\n /**\n * Creates exception with the specified cause. Consider using {@link #JsonSyntaxException(String,\n * Throwable)} instead if you can describe what actually happened.\n *\n * @param cause root exception that caused this exception to be thrown.\n */\n public JsonSyntaxException(Throwable cause) {\n super(cause);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/JsonPrimitive.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport com.google.gson.internal.LazilyParsedNumber;\nimport com.google.gson.internal.NumberLimits;\nimport java.math.BigDecimal;\nimport java.math.BigInteger;\nimport java.util.Objects;\n\n/**\n * A class representing a JSON primitive value. 
A primitive value is either a String, a Java\n * primitive, or a Java primitive wrapper type.\n *\n * @author Inderjeet Singh\n * @author Joel Leitch\n */\npublic final class JsonPrimitive extends JsonElement {\n\n private final Object value;\n\n /**\n * Create a primitive containing a boolean value.\n *\n * @param bool the value to create the primitive with.\n */\n // \"deprecation\" suppression for superclass constructor\n // \"UnnecessaryBoxedVariable\" Error Prone warning is correct since method does not accept\n // null, but cannot be changed anymore since this is public API\n @SuppressWarnings({\"deprecation\", \"UnnecessaryBoxedVariable\"})\n public JsonPrimitive(Boolean bool) {\n value = Objects.requireNonNull(bool);\n }\n\n /**\n * Create a primitive containing a {@link Number}.\n *\n * @param number the value to create the primitive with.\n */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonPrimitive(Number number) {\n value = Objects.requireNonNull(number);\n }\n\n /**\n * Create a primitive containing a String value.\n *\n * @param string the value to create the primitive with.\n */\n @SuppressWarnings(\"deprecation\") // superclass constructor\n public JsonPrimitive(String string) {\n value = Objects.requireNonNull(string);\n }\n\n /**\n * Create a primitive containing a character. The character is turned into a one character String\n * since JSON only supports String.\n *\n * @param c the value to create the primitive with.\n */\n // \"deprecation\" suppression for superclass constructor\n // \"UnnecessaryBoxedVariable\" Error Prone warning is correct since method does not accept\n // null, but cannot be changed anymore since this is public API\n @SuppressWarnings({\"deprecation\", \"UnnecessaryBoxedVariable\"})\n public JsonPrimitive(Character c) {\n // convert characters to strings since in JSON, characters are represented as a single\n // character string\n value = Objects.requireNonNull(c).toString();\n }\n\n /**\n * Returns the same value as primitives are immutable.\n *\n * @since 2.8.2\n */\n @Override\n public JsonPrimitive deepCopy() {\n return this;\n }\n\n /**\n * Check whether this primitive contains a boolean value.\n *\n * @return true if this primitive contains a boolean value, false otherwise.\n */\n public boolean isBoolean() {\n return value instanceof Boolean;\n }\n\n /**\n * Convenience method to get this element as a boolean value. If this primitive {@linkplain\n * #isBoolean() is not a boolean}, the string value is parsed using {@link\n * Boolean#parseBoolean(String)}. This means {@code \"true\"} (ignoring case) is considered {@code\n * true} and any other value is considered {@code false}.\n */\n @Override\n public boolean getAsBoolean() {\n if (isBoolean()) {\n return (Boolean) value;\n }\n // Check to see if the value as a String is \"true\" in any case.\n return Boolean.parseBoolean(getAsString());\n }\n\n /**\n * Check whether this primitive contains a Number.\n *\n * @return true if this primitive contains a Number, false otherwise.\n */\n public boolean isNumber() {\n return value instanceof Number;\n }\n\n /**\n * Convenience method to get this element as a {@link Number}. 
If this primitive {@linkplain\n * #isString() is a string}, a lazily parsed {@code Number} is constructed which parses the string\n * when any of its methods are called (which can lead to a {@link NumberFormatException}).\n *\n * @throws UnsupportedOperationException if this primitive is neither a number nor a string.\n */\n @Override\n public Number getAsNumber() {\n if (value instanceof Number) {\n return (Number) value;\n } else if (value instanceof String) {\n return new LazilyParsedNumber((String) value);\n }\n throw new UnsupportedOperationException(\"Primitive is neither a number nor a string\");\n }\n\n /**\n * Check whether this primitive contains a String value.\n *\n * @return true if this primitive contains a String value, false otherwise.\n */\n public boolean isString() {\n return value instanceof String;\n }\n\n // Don't add Javadoc, inherit it from super implementation; no exceptions are thrown here\n @Override\n public String getAsString() {\n if (value instanceof String) {\n return (String) value;\n } else if (isNumber()) {\n return getAsNumber().toString();\n } else if (isBoolean()) {\n return ((Boolean) value).toString();\n }\n throw new AssertionError(\"Unexpected value type: \" + value.getClass());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public double getAsDouble() {\n return isNumber() ? getAsNumber().doubleValue() : Double.parseDouble(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public BigDecimal getAsBigDecimal() {\n return value instanceof BigDecimal\n ? (BigDecimal) value\n : NumberLimits.parseBigDecimal(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public BigInteger getAsBigInteger() {\n return value instanceof BigInteger\n ? (BigInteger) value\n : isIntegral(this)\n ? BigInteger.valueOf(this.getAsNumber().longValue())\n : NumberLimits.parseBigInteger(this.getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public float getAsFloat() {\n return isNumber() ? getAsNumber().floatValue() : Float.parseFloat(getAsString());\n }\n\n /**\n * Convenience method to get this element as a primitive long.\n *\n * @return this element as a primitive long.\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public long getAsLong() {\n return isNumber() ? getAsNumber().longValue() : Long.parseLong(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public short getAsShort() {\n return isNumber() ? getAsNumber().shortValue() : Short.parseShort(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public int getAsInt() {\n return isNumber() ? getAsNumber().intValue() : Integer.parseInt(getAsString());\n }\n\n /**\n * @throws NumberFormatException {@inheritDoc}\n */\n @Override\n public byte getAsByte() {\n return isNumber() ? getAsNumber().byteValue() : Byte.parseByte(getAsString());\n }\n\n /**\n * @throws UnsupportedOperationException if the string value of this primitive is empty.\n * @deprecated This method is misleading, as it does not get this element as a char but rather as\n * a string's first character.\n */\n @Deprecated\n @Override\n public char getAsCharacter() {\n String s = getAsString();\n if (s.isEmpty()) {\n throw new UnsupportedOperationException(\"String value is empty\");\n } else {\n return s.charAt(0);\n }\n }\n\n /** Returns the hash code of this object. 
*/\n @Override\n public int hashCode() {\n if (value == null) {\n return 31;\n }\n // Using recommended hashing algorithm from Effective Java for longs and doubles\n if (isIntegral(this)) {\n long value = getAsNumber().longValue();\n return (int) (value ^ (value >>> 32));\n }\n if (value instanceof Number) {\n long value = Double.doubleToLongBits(getAsNumber().doubleValue());\n return (int) (value ^ (value >>> 32));\n }\n return value.hashCode();\n }\n\n /**\n * Returns whether the other object is equal to this. This method only considers the other object\n * to be equal if it is an instance of {@code JsonPrimitive} and has an equal value.\n */\n @Override\n public boolean equals(Object obj) {\n if (this == obj) {\n return true;\n }\n if (obj == null || getClass() != obj.getClass()) {\n return false;\n }\n JsonPrimitive other = (JsonPrimitive) obj;\n if (value == null) {\n return other.value == null;\n }\n if (isIntegral(this) && isIntegral(other)) {\n return (this.value instanceof BigInteger || other.value instanceof BigInteger)\n ? this.getAsBigInteger().equals(other.getAsBigInteger())\n : this.getAsNumber().longValue() == other.getAsNumber().longValue();\n }\n if (value instanceof Number && other.value instanceof Number) {\n if (value instanceof BigDecimal && other.value instanceof BigDecimal) {\n // Uses compareTo to ignore scale of values, e.g. `0` and `0.00` should be considered equal\n return this.getAsBigDecimal().compareTo(other.getAsBigDecimal()) == 0;\n }\n\n double thisAsDouble = this.getAsDouble();\n double otherAsDouble = other.getAsDouble();\n // Don't use Double.compare(double, double) because that considers -0.0 and +0.0 not equal\n return (thisAsDouble == otherAsDouble)\n || (Double.isNaN(thisAsDouble) && Double.isNaN(otherAsDouble));\n }\n return value.equals(other.value);\n }\n\n /**\n * Returns true if the specified number is an integral type (Long, Integer, Short, Byte,\n * BigInteger)\n */\n private static boolean isIntegral(JsonPrimitive primitive) {\n if (primitive.value instanceof Number) {\n Number number = (Number) primitive.value;\n return number instanceof BigInteger\n || number instanceof Long\n || number instanceof Integer\n || number instanceof Short\n || number instanceof Byte;\n }\n return false;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/$Gson$Types.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal;\n\nimport static com.google.gson.internal.$Gson$Preconditions.checkArgument;\nimport static java.util.Objects.requireNonNull;\n\nimport java.io.Serializable;\nimport java.lang.reflect.Array;\nimport java.lang.reflect.GenericArrayType;\nimport java.lang.reflect.GenericDeclaration;\nimport java.lang.reflect.Modifier;\nimport java.lang.reflect.ParameterizedType;\nimport java.lang.reflect.Type;\nimport java.lang.reflect.TypeVariable;\nimport java.lang.reflect.WildcardType;\nimport java.util.Arrays;\nimport java.util.Collection;\nimport 
java.util.HashMap;\nimport java.util.Map;\nimport java.util.NoSuchElementException;\nimport java.util.Objects;\nimport java.util.Properties;\n\n/**\n * Static methods for working with types.\n *\n * @author Bob Lee\n * @author Jesse Wilson\n */\npublic final class $Gson$Types {\n static final Type[] EMPTY_TYPE_ARRAY = new Type[] {};\n\n private $Gson$Types() {\n throw new UnsupportedOperationException();\n }\n\n /**\n * Returns a new parameterized type, applying {@code typeArguments} to {@code rawType} and\n * enclosed by {@code ownerType}.\n *\n * @return a {@link java.io.Serializable serializable} parameterized type.\n */\n public static ParameterizedType newParameterizedTypeWithOwner(\n Type ownerType, Type rawType, Type... typeArguments) {\n return new ParameterizedTypeImpl(ownerType, rawType, typeArguments);\n }\n\n /**\n * Returns an array type whose elements are all instances of {@code componentType}.\n *\n * @return a {@link java.io.Serializable serializable} generic array type.\n */\n public static GenericArrayType arrayOf(Type componentType) {\n return new GenericArrayTypeImpl(componentType);\n }\n\n /**\n * Returns a type that represents an unknown type that extends {@code bound}. For example, if\n * {@code bound} is {@code CharSequence.class}, this returns {@code ? extends CharSequence}. If\n * {@code bound} is {@code Object.class}, this returns {@code ?}, which is shorthand for {@code ?\n * extends Object}.\n */\n public static WildcardType subtypeOf(Type bound) {\n Type[] upperBounds;\n if (bound instanceof WildcardType) {\n upperBounds = ((WildcardType) bound).getUpperBounds();\n } else {\n upperBounds = new Type[] {bound};\n }\n return new WildcardTypeImpl(upperBounds, EMPTY_TYPE_ARRAY);\n }\n\n /**\n * Returns a type that represents an unknown supertype of {@code bound}. For example, if {@code\n * bound} is {@code String.class}, this returns {@code ? super String}.\n */\n public static WildcardType supertypeOf(Type bound) {\n Type[] lowerBounds;\n if (bound instanceof WildcardType) {\n lowerBounds = ((WildcardType) bound).getLowerBounds();\n } else {\n lowerBounds = new Type[] {bound};\n }\n return new WildcardTypeImpl(new Type[] {Object.class}, lowerBounds);\n }\n\n /**\n * Returns a type that is functionally equal but not necessarily equal according to {@link\n * Object#equals(Object) Object.equals()}. The returned type is {@link java.io.Serializable}.\n */\n public static Type canonicalize(Type type) {\n if (type instanceof Class) {\n Class c = (Class) type;\n return c.isArray() ? 
new GenericArrayTypeImpl(canonicalize(c.getComponentType())) : c;\n\n } else if (type instanceof ParameterizedType) {\n ParameterizedType p = (ParameterizedType) type;\n return new ParameterizedTypeImpl(\n p.getOwnerType(), p.getRawType(), p.getActualTypeArguments());\n\n } else if (type instanceof GenericArrayType) {\n GenericArrayType g = (GenericArrayType) type;\n return new GenericArrayTypeImpl(g.getGenericComponentType());\n\n } else if (type instanceof WildcardType) {\n WildcardType w = (WildcardType) type;\n return new WildcardTypeImpl(w.getUpperBounds(), w.getLowerBounds());\n\n } else {\n // type is either serializable as-is or unsupported\n return type;\n }\n }\n\n public static Class getRawType(Type type) {\n if (type instanceof Class) {\n // type is a normal class.\n return (Class) type;\n\n } else if (type instanceof ParameterizedType) {\n ParameterizedType parameterizedType = (ParameterizedType) type;\n\n // getRawType() returns Type instead of Class; that seems to be an API mistake,\n // see https://bugs.openjdk.org/browse/JDK-8250659\n Type rawType = parameterizedType.getRawType();\n checkArgument(rawType instanceof Class);\n return (Class) rawType;\n\n } else if (type instanceof GenericArrayType) {\n Type componentType = ((GenericArrayType) type).getGenericComponentType();\n return Array.newInstance(getRawType(componentType), 0).getClass();\n\n } else if (type instanceof TypeVariable) {\n // we could use the variable's bounds, but that won't work if there are multiple.\n // having a raw type that's more general than necessary is okay\n return Object.class;\n\n } else if (type instanceof WildcardType) {\n Type[] bounds = ((WildcardType) type).getUpperBounds();\n // Currently the JLS only permits one bound for wildcards so using first bound is safe\n assert bounds.length == 1;\n return getRawType(bounds[0]);\n\n } else {\n String className = type == null ? \"null\" : type.getClass().getName();\n throw new IllegalArgumentException(\n \"Expected a Class, ParameterizedType, or GenericArrayType, but <\"\n + type\n + \"> is of type \"\n + className);\n }\n }\n\n private static boolean equal(Object a, Object b) {\n return Objects.equals(a, b);\n }\n\n /** Returns true if {@code a} and {@code b} are equal. 
*/\n public static boolean equals(Type a, Type b) {\n if (a == b) {\n // also handles (a == null && b == null)\n return true;\n\n } else if (a instanceof Class) {\n // Class already specifies equals().\n return a.equals(b);\n\n } else if (a instanceof ParameterizedType) {\n if (!(b instanceof ParameterizedType)) {\n return false;\n }\n\n // TODO: save a .clone() call\n ParameterizedType pa = (ParameterizedType) a;\n ParameterizedType pb = (ParameterizedType) b;\n return equal(pa.getOwnerType(), pb.getOwnerType())\n && pa.getRawType().equals(pb.getRawType())\n && Arrays.equals(pa.getActualTypeArguments(), pb.getActualTypeArguments());\n\n } else if (a instanceof GenericArrayType) {\n if (!(b instanceof GenericArrayType)) {\n return false;\n }\n\n GenericArrayType ga = (GenericArrayType) a;\n GenericArrayType gb = (GenericArrayType) b;\n return equals(ga.getGenericComponentType(), gb.getGenericComponentType());\n\n } else if (a instanceof WildcardType) {\n if (!(b instanceof WildcardType)) {\n return false;\n }\n\n WildcardType wa = (WildcardType) a;\n WildcardType wb = (WildcardType) b;\n return Arrays.equals(wa.getUpperBounds(), wb.getUpperBounds())\n && Arrays.equals(wa.getLowerBounds(), wb.getLowerBounds());\n\n } else if (a instanceof TypeVariable) {\n if (!(b instanceof TypeVariable)) {\n return false;\n }\n TypeVariable va = (TypeVariable) a;\n TypeVariable vb = (TypeVariable) b;\n return Objects.equals(va.getGenericDeclaration(), vb.getGenericDeclaration())\n && va.getName().equals(vb.getName());\n\n } else {\n // This isn't a type we support. Could be a generic array type, wildcard type, etc.\n return false;\n }\n }\n\n public static String typeToString(Type type) {\n return type instanceof Class ? ((Class) type).getName() : type.toString();\n }\n\n /**\n * Returns the generic supertype for {@code supertype}. For example, given a class {@code\n * IntegerSet}, the result for when supertype is {@code Set.class} is {@code Set} and the\n * result when the supertype is {@code Collection.class} is {@code Collection}.\n */\n private static Type getGenericSupertype(Type context, Class rawType, Class supertype) {\n if (supertype == rawType) {\n return context;\n }\n\n // we skip searching through interfaces if unknown is an interface\n if (supertype.isInterface()) {\n Class[] interfaces = rawType.getInterfaces();\n for (int i = 0, length = interfaces.length; i < length; i++) {\n if (interfaces[i] == supertype) {\n return rawType.getGenericInterfaces()[i];\n } else if (supertype.isAssignableFrom(interfaces[i])) {\n return getGenericSupertype(rawType.getGenericInterfaces()[i], interfaces[i], supertype);\n }\n }\n }\n\n // check our supertypes\n if (!rawType.isInterface()) {\n while (rawType != Object.class) {\n Class rawSupertype = rawType.getSuperclass();\n if (rawSupertype == supertype) {\n return rawType.getGenericSuperclass();\n } else if (supertype.isAssignableFrom(rawSupertype)) {\n return getGenericSupertype(rawType.getGenericSuperclass(), rawSupertype, supertype);\n }\n rawType = rawSupertype;\n }\n }\n\n // we can't resolve this further\n return supertype;\n }\n\n /**\n * Returns the generic form of {@code supertype}. 
For example, if this is {@code\n * ArrayList}, this returns {@code Iterable} given the input {@code\n * Iterable.class}.\n *\n * @param supertype a superclass of, or interface implemented by, this.\n */\n private static Type getSupertype(Type context, Class contextRawType, Class supertype) {\n if (context instanceof WildcardType) {\n // Wildcards are useless for resolving supertypes. As the upper bound has the same raw type,\n // use it instead\n Type[] bounds = ((WildcardType) context).getUpperBounds();\n // Currently the JLS only permits one bound for wildcards so using first bound is safe\n assert bounds.length == 1;\n context = bounds[0];\n }\n checkArgument(supertype.isAssignableFrom(contextRawType));\n return resolve(\n context,\n contextRawType,\n $Gson$Types.getGenericSupertype(context, contextRawType, supertype));\n }\n\n /**\n * Returns the component type of this array type.\n *\n * @throws ClassCastException if this type is not an array.\n */\n \npublic static Type getArrayComponentType(Type array) {\n return array instanceof GenericArrayType\n ? ((GenericArrayType) array).getGenericComponentType()\n : ((Class) array).getComponentType();\n }\n\n /**\n * Returns the element type of this collection type.\n *\n * @throws IllegalArgumentException if this type is not a collection.\n */\n public static Type getCollectionElementType(Type context, Class contextRawType) {\n Type collectionType = getSupertype(context, contextRawType, Collection.class);\n\n if (collectionType instanceof ParameterizedType) {\n return ((ParameterizedType) collectionType).getActualTypeArguments()[0];\n }\n return Object.class;\n }\n\n /**\n * Returns a two element array containing this map's key and value types in positions 0 and 1\n * respectively.\n */\n public static Type[] getMapKeyAndValueTypes(Type context, Class contextRawType) {\n /*\n * Work around a problem with the declaration of java.util.Properties. That\n * class should extend Hashtable, but it's declared to\n * extend Hashtable.\n */\n if (context == Properties.class) {\n return new Type[] {String.class, String.class}; // TODO: test subclasses of Properties!\n }\n\n Type mapType = getSupertype(context, contextRawType, Map.class);\n // TODO: strip wildcards?\n if (mapType instanceof ParameterizedType) {\n ParameterizedType mapParameterizedType = (ParameterizedType) mapType;\n return mapParameterizedType.getActualTypeArguments();\n }\n return new Type[] {Object.class, Object.class};\n }\n\n public static Type resolve(Type context, Class contextRawType, Type toResolve) {\n\n return resolve(context, contextRawType, toResolve, new HashMap, Type>());\n }\n\n private static Type resolve(\n Type context,\n Class contextRawType,\n Type toResolve,\n Map, Type> visitedTypeVariables) {\n // this implementation is made a little more complicated in an attempt to avoid object-creation\n TypeVariable resolving = null;\n while (true) {\n if (toResolve instanceof TypeVariable) {\n TypeVariable typeVariable = (TypeVariable) toResolve;\n Type previouslyResolved = visitedTypeVariables.get(typeVariable);\n if (previouslyResolved != null) {\n // cannot reduce due to infinite recursion\n return (previouslyResolved == Void.TYPE) ? 
toResolve : previouslyResolved;\n }\n\n // Insert a placeholder to mark the fact that we are in the process of resolving this type\n visitedTypeVariables.put(typeVariable, Void.TYPE);\n if (resolving == null) {\n resolving = typeVariable;\n }\n\n toResolve = resolveTypeVariable(context, contextRawType, typeVariable);\n if (toResolve == typeVariable) {\n break;\n }\n\n } else if (toResolve instanceof Class && ((Class) toResolve).isArray()) {\n Class original = (Class) toResolve;\n Type componentType = original.getComponentType();\n Type newComponentType =\n resolve(context, contextRawType, componentType, visitedTypeVariables);\n toResolve = equal(componentType, newComponentType) ? original : arrayOf(newComponentType);\n break;\n\n } else if (toResolve instanceof GenericArrayType) {\n GenericArrayType original = (GenericArrayType) toResolve;\n Type componentType = original.getGenericComponentType();\n Type newComponentType =\n resolve(context, contextRawType, componentType, visitedTypeVariables);\n toResolve = equal(componentType, newComponentType) ? original : arrayOf(newComponentType);\n break;\n\n } else if (toResolve instanceof ParameterizedType) {\n ParameterizedType original = (ParameterizedType) toResolve;\n Type ownerType = original.getOwnerType();\n Type newOwnerType = resolve(context, contextRawType, ownerType, visitedTypeVariables);\n boolean changed = !equal(newOwnerType, ownerType);\n\n Type[] args = original.getActualTypeArguments();\n for (int t = 0, length = args.length; t < length; t++) {\n Type resolvedTypeArgument =\n resolve(context, contextRawType, args[t], visitedTypeVariables);\n if (!equal(resolvedTypeArgument, args[t])) {\n if (!changed) {\n args = args.clone();\n changed = true;\n }\n args[t] = resolvedTypeArgument;\n }\n }\n\n toResolve =\n changed\n ? 
newParameterizedTypeWithOwner(newOwnerType, original.getRawType(), args)\n : original;\n break;\n\n } else if (toResolve instanceof WildcardType) {\n WildcardType original = (WildcardType) toResolve;\n Type[] originalLowerBound = original.getLowerBounds();\n Type[] originalUpperBound = original.getUpperBounds();\n\n if (originalLowerBound.length == 1) {\n Type lowerBound =\n resolve(context, contextRawType, originalLowerBound[0], visitedTypeVariables);\n if (lowerBound != originalLowerBound[0]) {\n toResolve = supertypeOf(lowerBound);\n break;\n }\n } else if (originalUpperBound.length == 1) {\n Type upperBound =\n resolve(context, contextRawType, originalUpperBound[0], visitedTypeVariables);\n if (upperBound != originalUpperBound[0]) {\n toResolve = subtypeOf(upperBound);\n break;\n }\n }\n toResolve = original;\n break;\n\n } else {\n break;\n }\n }\n // ensure that any in-process resolution gets updated with the final result\n if (resolving != null) {\n visitedTypeVariables.put(resolving, toResolve);\n }\n return toResolve;\n }\n\n private static Type resolveTypeVariable(\n Type context, Class contextRawType, TypeVariable unknown) {\n Class declaredByRaw = declaringClassOf(unknown);\n\n // we can't reduce this further\n if (declaredByRaw == null) {\n return unknown;\n }\n\n Type declaredBy = getGenericSupertype(context, contextRawType, declaredByRaw);\n if (declaredBy instanceof ParameterizedType) {\n int index = indexOf(declaredByRaw.getTypeParameters(), unknown);\n return ((ParameterizedType) declaredBy).getActualTypeArguments()[index];\n }\n\n return unknown;\n }\n\n private static int indexOf(Object[] array, Object toFind) {\n for (int i = 0, length = array.length; i < length; i++) {\n if (toFind.equals(array[i])) {\n return i;\n }\n }\n throw new NoSuchElementException();\n }\n\n /**\n * Returns the declaring class of {@code typeVariable}, or {@code null} if it was not declared by\n * a class.\n */\n private static Class declaringClassOf(TypeVariable typeVariable) {\n GenericDeclaration genericDeclaration = typeVariable.getGenericDeclaration();\n return genericDeclaration instanceof Class ? (Class) genericDeclaration : null;\n }\n\n static void checkNotPrimitive(Type type) {\n checkArgument(!(type instanceof Class) || !((Class) type).isPrimitive());\n }\n\n /**\n * Whether an {@linkplain ParameterizedType#getOwnerType() owner type} must be specified when\n * constructing a {@link ParameterizedType} for {@code rawType}.\n *\n *

    Note that this method might not require an owner type for all cases where Java reflection\n * would create parameterized types with owner type.\n */\n public static boolean requiresOwnerType(Type rawType) {\n if (rawType instanceof Class) {\n Class rawTypeAsClass = (Class) rawType;\n return !Modifier.isStatic(rawTypeAsClass.getModifiers())\n && rawTypeAsClass.getDeclaringClass() != null;\n }\n return false;\n }\n\n // Here and below we put @SuppressWarnings(\"serial\") on fields of type `Type`. Recent Java\n // compilers complain that the declared type is not Serializable. But in this context we go out of\n // our way to ensure that the Type in question is either Class (which is serializable) or one of\n // the nested Type implementations here (which are also serializable).\n private static final class ParameterizedTypeImpl implements ParameterizedType, Serializable {\n @SuppressWarnings(\"serial\")\n private final Type ownerType;\n\n @SuppressWarnings(\"serial\")\n private final Type rawType;\n\n @SuppressWarnings(\"serial\")\n private final Type[] typeArguments;\n\n public ParameterizedTypeImpl(Type ownerType, Type rawType, Type... typeArguments) {\n // TODO: Should this enforce that rawType is a Class? See JDK implementation of\n // the ParameterizedType interface and https://bugs.openjdk.org/browse/JDK-8250659\n requireNonNull(rawType);\n if (ownerType == null && requiresOwnerType(rawType)) {\n throw new IllegalArgumentException(\"Must specify owner type for \" + rawType);\n }\n\n this.ownerType = ownerType == null ? null : canonicalize(ownerType);\n this.rawType = canonicalize(rawType);\n this.typeArguments = typeArguments.clone();\n for (int t = 0, length = this.typeArguments.length; t < length; t++) {\n requireNonNull(this.typeArguments[t]);\n checkNotPrimitive(this.typeArguments[t]);\n this.typeArguments[t] = canonicalize(this.typeArguments[t]);\n }\n }\n\n @Override\n public Type[] getActualTypeArguments() {\n return typeArguments.clone();\n }\n\n @Override\n public Type getRawType() {\n return rawType;\n }\n\n @Override\n public Type getOwnerType() {\n return ownerType;\n }\n\n @Override\n public boolean equals(Object other) {\n return other instanceof ParameterizedType\n && $Gson$Types.equals(this, (ParameterizedType) other);\n }\n\n private static int hashCodeOrZero(Object o) {\n return o != null ? 
o.hashCode() : 0;\n }\n\n @Override\n public int hashCode() {\n return Arrays.hashCode(typeArguments) ^ rawType.hashCode() ^ hashCodeOrZero(ownerType);\n }\n\n @Override\n public String toString() {\n int length = typeArguments.length;\n if (length == 0) {\n return typeToString(rawType);\n }\n\n StringBuilder stringBuilder = new StringBuilder(30 * (length + 1));\n stringBuilder\n .append(typeToString(rawType))\n .append(\"<\")\n .append(typeToString(typeArguments[0]));\n for (int i = 1; i < length; i++) {\n stringBuilder.append(\", \").append(typeToString(typeArguments[i]));\n }\n return stringBuilder.append(\">\").toString();\n }\n\n private static final long serialVersionUID = 0;\n }\n\n private static final class GenericArrayTypeImpl implements GenericArrayType, Serializable {\n @SuppressWarnings(\"serial\")\n private final Type componentType;\n\n public GenericArrayTypeImpl(Type componentType) {\n requireNonNull(componentType);\n this.componentType = canonicalize(componentType);\n }\n\n @Override\n public Type getGenericComponentType() {\n return componentType;\n }\n\n @Override\n public boolean equals(Object o) {\n return o instanceof GenericArrayType && $Gson$Types.equals(this, (GenericArrayType) o);\n }\n\n @Override\n public int hashCode() {\n return componentType.hashCode();\n }\n\n @Override\n public String toString() {\n return typeToString(componentType) + \"[]\";\n }\n\n private static final long serialVersionUID = 0;\n }\n\n /**\n * The WildcardType interface supports multiple upper bounds and multiple lower bounds. We only\n * support what the target Java version supports - at most one bound, see also\n * https://bugs.openjdk.java.net/browse/JDK-8250660. If a lower bound is set, the upper bound must\n * be Object.class.\n */\n private static final class WildcardTypeImpl implements WildcardType, Serializable {\n @SuppressWarnings(\"serial\")\n private final Type upperBound;\n\n @SuppressWarnings(\"serial\")\n private final Type lowerBound;\n\n public WildcardTypeImpl(Type[] upperBounds, Type[] lowerBounds) {\n checkArgument(lowerBounds.length <= 1);\n checkArgument(upperBounds.length == 1);\n\n if (lowerBounds.length == 1) {\n requireNonNull(lowerBounds[0]);\n checkNotPrimitive(lowerBounds[0]);\n checkArgument(upperBounds[0] == Object.class);\n this.lowerBound = canonicalize(lowerBounds[0]);\n this.upperBound = Object.class;\n\n } else {\n requireNonNull(upperBounds[0]);\n checkNotPrimitive(upperBounds[0]);\n this.lowerBound = null;\n this.upperBound = canonicalize(upperBounds[0]);\n }\n }\n\n @Override\n public Type[] getUpperBounds() {\n return new Type[] {upperBound};\n }\n\n @Override\n public Type[] getLowerBounds() {\n return lowerBound != null ? new Type[] {lowerBound} : EMPTY_TYPE_ARRAY;\n }\n\n @Override\n public boolean equals(Object other) {\n return other instanceof WildcardType && $Gson$Types.equals(this, (WildcardType) other);\n }\n\n @Override\n public int hashCode() {\n // this equals Arrays.hashCode(getLowerBounds()) ^ Arrays.hashCode(getUpperBounds());\n return (lowerBound != null ? 31 + lowerBound.hashCode() : 1) ^ (31 + upperBound.hashCode());\n }\n\n @Override\n public String toString() {\n if (lowerBound != null) {\n return \"? super \" + typeToString(lowerBound);\n } else if (upperBound == Object.class) {\n return \"?\";\n } else {\n return \"? 
extends \" + typeToString(upperBound);\n }\n }\n\n private static final long serialVersionUID = 0;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/TroubleshootingGuide.java\npackage com.google.gson.internal;\n\npublic class TroubleshootingGuide {\n private TroubleshootingGuide() {}\n\n /** Creates a URL referring to the specified troubleshooting section. */\n public static String createUrl(String id) {\n return \"https://github.com/google/gson/blob/main/Troubleshooting.md#\" + id;\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/reflect/TypeToken.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.reflect;\n\nimport com.google.gson.internal.$Gson$Types;\nimport com.google.gson.internal.TroubleshootingGuide;\nimport java.lang.reflect.GenericArrayType;\nimport java.lang.reflect.ParameterizedType;\nimport java.lang.reflect.Type;\nimport java.lang.reflect.TypeVariable;\nimport java.lang.reflect.WildcardType;\nimport java.util.HashMap;\nimport java.util.Map;\nimport java.util.Objects;\n\n/**\n * Represents a generic type {@code T}. Java doesn't yet provide a way to represent generic types,\n * so this class does. Forces clients to create a subclass of this class which enables retrieval the\n * type information even at runtime.\n *\n *

    For example, to create a type literal for {@code List}, you can create an empty\n * anonymous class:\n *\n *

    {@code TypeToken> list = new TypeToken>() {};}\n *\n *

    Capturing a type variable as type argument of an anonymous {@code TypeToken} subclass is not\n * allowed, for example {@code TypeToken>}. Due to type erasure the runtime type of a type\n * variable is not available to Gson and therefore it cannot provide the functionality one might\n * expect. This would give a false sense of type-safety at compile time and could lead to an\n * unexpected {@code ClassCastException} at runtime.\n *\n *

    If the type arguments of the parameterized type are only available at runtime, for example\n * when you want to create a {@code List} based on a {@code Class} representing the element\n * type, the method {@link #getParameterized(Type, Type...)} can be used.\n *\n * @author Bob Lee\n * @author Sven Mawson\n * @author Jesse Wilson\n */\npublic class TypeToken {\n private final Class rawType;\n private final Type type;\n private final int hashCode;\n\n /**\n * Constructs a new type literal. Derives represented class from type parameter.\n *\n *

    Clients create an empty anonymous subclass. Doing so embeds the type parameter in the\n * anonymous class's type hierarchy so we can reconstitute it at runtime despite erasure, for\n * example:\n *\n *

    {@code new TypeToken>() {}}\n *\n * @throws IllegalArgumentException If the anonymous {@code TypeToken} subclass captures a type\n * variable, for example {@code TypeToken>}. See the {@code TypeToken} class\n * documentation for more details.\n */\n @SuppressWarnings(\"unchecked\")\n protected TypeToken() {\n this.type = getTypeTokenTypeArgument();\n this.rawType = (Class) $Gson$Types.getRawType(type);\n this.hashCode = type.hashCode();\n }\n\n /** Unsafe. Constructs a type literal manually. */\n @SuppressWarnings(\"unchecked\")\n private TypeToken(Type type) {\n this.type = $Gson$Types.canonicalize(Objects.requireNonNull(type));\n this.rawType = (Class) $Gson$Types.getRawType(this.type);\n this.hashCode = this.type.hashCode();\n }\n\n private static boolean isCapturingTypeVariablesForbidden() {\n return !Objects.equals(System.getProperty(\"gson.allowCapturingTypeVariables\"), \"true\");\n }\n\n /**\n * Verifies that {@code this} is an instance of a direct subclass of TypeToken and returns the\n * type argument for {@code T} in {@link $Gson$Types#canonicalize canonical form}.\n */\n private Type getTypeTokenTypeArgument() {\n Type superclass = getClass().getGenericSuperclass();\n if (superclass instanceof ParameterizedType) {\n ParameterizedType parameterized = (ParameterizedType) superclass;\n if (parameterized.getRawType() == TypeToken.class) {\n Type typeArgument = $Gson$Types.canonicalize(parameterized.getActualTypeArguments()[0]);\n\n if (isCapturingTypeVariablesForbidden()) {\n verifyNoTypeVariable(typeArgument);\n }\n return typeArgument;\n }\n }\n // Check for raw TypeToken as superclass\n else if (superclass == TypeToken.class) {\n throw new IllegalStateException(\n \"TypeToken must be created with a type argument: new TypeToken<...>() {}; When using code\"\n + \" shrinkers (ProGuard, R8, ...) make sure that generic signatures are preserved.\"\n + \"\\nSee \"\n + TroubleshootingGuide.createUrl(\"type-token-raw\"));\n }\n\n // User created subclass of subclass of TypeToken\n throw new IllegalStateException(\"Must only create direct subclasses of TypeToken\");\n }\n\n private static void verifyNoTypeVariable(Type type) {\n if (type instanceof TypeVariable) {\n TypeVariable typeVariable = (TypeVariable) type;\n throw new IllegalArgumentException(\n \"TypeToken type argument must not contain a type variable; captured type variable \"\n + typeVariable.getName()\n + \" declared by \"\n + typeVariable.getGenericDeclaration()\n + \"\\nSee \"\n + TroubleshootingGuide.createUrl(\"typetoken-type-variable\"));\n } else if (type instanceof GenericArrayType) {\n verifyNoTypeVariable(((GenericArrayType) type).getGenericComponentType());\n } else if (type instanceof ParameterizedType) {\n ParameterizedType parameterizedType = (ParameterizedType) type;\n Type ownerType = parameterizedType.getOwnerType();\n if (ownerType != null) {\n verifyNoTypeVariable(ownerType);\n }\n\n for (Type typeArgument : parameterizedType.getActualTypeArguments()) {\n verifyNoTypeVariable(typeArgument);\n }\n } else if (type instanceof WildcardType) {\n WildcardType wildcardType = (WildcardType) type;\n for (Type bound : wildcardType.getLowerBounds()) {\n verifyNoTypeVariable(bound);\n }\n for (Type bound : wildcardType.getUpperBounds()) {\n verifyNoTypeVariable(bound);\n }\n } else if (type == null) {\n // Occurs in Eclipse IDE and certain Java versions (e.g. 
Java 11.0.18) when capturing type\n // variable declared by method of local class, see\n...\n", "needle_token_start": 10412, "needle_token_end": 10464, "code_context_ntokens": 16061, "output": ["```java\n/**\n * Returns the component type of this array type.\n *\n * @throws ClassCastException if this type is not an array.\n */\npublic static Type getArrayComponentType(Type array) {\n return array instanceof GenericArrayType\n ? ((GenericArrayType) array).getGenericComponentType()\n : ((Class) array).getComponentType();\n}\n```"]} +{"repo": "google/gson", "name": "setExclusionStrategies", "language": "java", "path": "gson/src/main/java/com/google/gson/GsonBuilder.java", "position_ratio": 0.75, "description": "\nFunction Description:\n1. **Purpose**: To configure a JSON serialization and deserialization library to selectively ignore certain fields or classes based on specified rules, enhancing control over the output and input process.\n2. **Input**: An array of strategy objects that define conditions under which fields or classes should be skipped during serialization and deserialization.\n3. **Output**: A reference to the builder object itself, allowing for method chaining in accordance with the builder pattern.\n4. **Procedure**: The method iterates over each provided strategy object, applying these to the internal configuration. Each strategy is checked to determine if a field or class should be excluded during the serialization or deserialization process. If any strategy suggests exclusion, the respective field or class is skipped.\n\n", "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:", "template": "instruction\ncode_context\ndescription\ninstruction", "code_context": "// Path: gson/src/main/java/com/google/gson/JsonDeserializer.java\n/*\n * Copyright (C) 2008 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson;\n\nimport java.lang.reflect.Type;\n\n/**\n * Interface representing a custom deserializer for JSON. You should write a custom deserializer, if\n * you are not happy with the default deserialization done by Gson. You will also need to register\n * this deserializer through {@link GsonBuilder#registerTypeAdapter(Type, Object)}.\n *\n *

    Let us look at example where defining a deserializer will be useful. The {@code Id} class\n * defined below has two fields: {@code clazz} and {@code value}.\n *\n *

    \n * public class Id<T> {\n *   private final Class<T> clazz;\n *   private final long value;\n *   public Id(Class<T> clazz, long value) {\n *     this.clazz = clazz;\n *     this.value = value;\n *   }\n *   public long getValue() {\n *     return value;\n *   }\n * }\n * 
    \n *\n *

    The default deserialization of {@code Id(com.foo.MyObject.class, 20L)} will require the JSON\n * string to be {\"clazz\":\"com.foo.MyObject\",\"value\":20}. Suppose, you already know the\n * type of the field that the {@code Id} will be deserialized into, and hence just want to\n * deserialize it from a JSON string {@code 20}. You can achieve that by writing a custom\n * deserializer:\n *\n *
<pre>\n * class IdDeserializer implements JsonDeserializer<Id> {\n   public Id deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context)\n       throws JsonParseException {\n     long idValue = json.getAsJsonPrimitive().getAsLong();\n     return new Id((Class) typeOfT, idValue);\n   }\n }\n * </pre>
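For completeness, a hypothetical sketch of the registration step that the Javadoc above refers to; it assumes the `Id` and `IdDeserializer` classes from the snippets above are available on the classpath, and it simplifies the generic `Id<T>` to a raw `Id` for brevity.

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;

public class IdDeserializerRegistration {
  public static void main(String[] args) {
    // Register the custom deserializer for Id via registerTypeAdapter, as the
    // interface documentation describes. Id and IdDeserializer are the classes
    // sketched in the Javadoc above, not part of Gson itself.
    Gson gson = new GsonBuilder()
        .registerTypeAdapter(Id.class, new IdDeserializer())
        .create();

    // A bare JSON number is now enough to produce an Id instance.
    Id id = gson.fromJson("20", Id.class);
    System.out.println(id.getValue()); // 20
  }
}
```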
    \n *\n...\n// Path: gson/src/main/java/com/google/gson/internal/bind/TreeTypeAdapter.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal.bind;\n\nimport com.google.gson.Gson;\nimport com.google.gson.JsonDeserializationContext;\nimport com.google.gson.JsonDeserializer;\nimport com.google.gson.JsonElement;\nimport com.google.gson.JsonParseException;\nimport com.google.gson.JsonSerializationContext;\nimport com.google.gson.JsonSerializer;\nimport com.google.gson.TypeAdapter;\nimport com.google.gson.TypeAdapterFactory;\nimport com.google.gson.internal.$Gson$Preconditions;\nimport com.google.gson.internal.Streams;\nimport com.google.gson.reflect.TypeToken;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.lang.reflect.Type;\n\n/**\n * Adapts a Gson 1.x tree-style adapter as a streaming TypeAdapter. Since the tree adapter may be\n * serialization-only or deserialization-only, this class has a facility to look up a delegate type\n * adapter on demand.\n */\npublic final class TreeTypeAdapter extends SerializationDelegatingTypeAdapter {\n private final JsonSerializer serializer;\n private final JsonDeserializer deserializer;\n final Gson gson;\n private final TypeToken typeToken;\n\n /**\n * Only intended as {@code skipPast} for {@link Gson#getDelegateAdapter(TypeAdapterFactory,\n * TypeToken)}, must not be used in any other way.\n */\n private final TypeAdapterFactory skipPastForGetDelegateAdapter;\n\n private final GsonContextImpl context = new GsonContextImpl();\n private final boolean nullSafe;\n\n /**\n * The delegate is lazily created because it may not be needed, and creating it may fail. 
Field\n * has to be {@code volatile} because {@link Gson} guarantees to be thread-safe.\n */\n private volatile TypeAdapter delegate;\n\n public TreeTypeAdapter(\n JsonSerializer serializer,\n JsonDeserializer deserializer,\n Gson gson,\n TypeToken typeToken,\n TypeAdapterFactory skipPast,\n boolean nullSafe) {\n this.serializer = serializer;\n this.deserializer = deserializer;\n this.gson = gson;\n this.typeToken = typeToken;\n this.skipPastForGetDelegateAdapter = skipPast;\n this.nullSafe = nullSafe;\n }\n\n public TreeTypeAdapter(\n JsonSerializer serializer,\n JsonDeserializer deserializer,\n Gson gson,\n TypeToken typeToken,\n TypeAdapterFactory skipPast) {\n this(serializer, deserializer, gson, typeToken, skipPast, true);\n }\n\n @Override\n public T read(JsonReader in) throws IOException {\n if (deserializer == null) {\n return delegate().read(in);\n }\n JsonElement value = Streams.parse(in);\n if (nullSafe && value.isJsonNull()) {\n return null;\n }\n return deserializer.deserialize(value, typeToken.getType(), context);\n }\n\n @Override\n public void write(JsonWriter out, T value) throws IOException {\n if (serializer == null) {\n delegate().write(out, value);\n return;\n }\n if (nullSafe && value == null) {\n out.nullValue();\n return;\n }\n JsonElement tree = serializer.serialize(value, typeToken.getType(), context);\n Streams.write(tree, out);\n }\n\n private TypeAdapter delegate() {\n // A race might lead to `delegate` being assigned by multiple threads but the last assignment\n // will stick\n TypeAdapter d = delegate;\n return d != null\n ? d\n : (delegate = gson.getDelegateAdapter(skipPastForGetDelegateAdapter, typeToken));\n }\n\n /**\n * Returns the type adapter which is used for serialization. Returns {@code this} if this {@code\n * TreeTypeAdapter} has a {@link #serializer}; otherwise returns the delegate.\n */\n @Override\n public TypeAdapter getSerializationDelegate() {\n return serializer != null ? this : delegate();\n }\n\n /** Returns a new factory that will match each type against {@code exactType}. */\n public static TypeAdapterFactory newFactory(TypeToken exactType, Object typeAdapter) {\n return new SingleTypeFactory(typeAdapter, exactType, false, null);\n }\n\n /** Returns a new factory that will match each type and its raw type against {@code exactType}. */\n public static TypeAdapterFactory newFactoryWithMatchRawType(\n TypeToken exactType, Object typeAdapter) {\n // only bother matching raw types if exact type is a raw type\n boolean matchRawType = exactType.getType() == exactType.getRawType();\n return new SingleTypeFactory(typeAdapter, exactType, matchRawType, null);\n }\n\n /**\n * Returns a new factory that will match each type's raw type for assignability to {@code\n * hierarchyType}.\n */\n public static TypeAdapterFactory newTypeHierarchyFactory(\n Class hierarchyType, Object typeAdapter) {\n return new SingleTypeFactory(typeAdapter, null, false, hierarchyType);\n }\n\n private static final class SingleTypeFactory implements TypeAdapterFactory {\n private final TypeToken exactType;\n private final boolean matchRawType;\n private final Class hierarchyType;\n private final JsonSerializer serializer;\n private final JsonDeserializer deserializer;\n\n SingleTypeFactory(\n Object typeAdapter, TypeToken exactType, boolean matchRawType, Class hierarchyType) {\n serializer = typeAdapter instanceof JsonSerializer ? (JsonSerializer) typeAdapter : null;\n deserializer =\n typeAdapter instanceof JsonDeserializer ? 
(JsonDeserializer) typeAdapter : null;\n $Gson$Preconditions.checkArgument(serializer != null || deserializer != null);\n this.exactType = exactType;\n this.matchRawType = matchRawType;\n this.hierarchyType = hierarchyType;\n }\n\n @SuppressWarnings(\"unchecked\") // guarded by typeToken.equals() call\n @Override\n public TypeAdapter create(Gson gson, TypeToken type) {\n boolean matches =\n exactType != null\n ? exactType.equals(type) || (matchRawType && exactType.getType() == type.getRawType())\n : hierarchyType.isAssignableFrom(type.getRawType());\n return matches\n ? new TreeTypeAdapter<>(\n (JsonSerializer) serializer, (JsonDeserializer) deserializer, gson, type, this)\n : null;\n }\n }\n\n private final class GsonContextImpl\n implements JsonSerializationContext, JsonDeserializationContext {\n @Override\n public JsonElement serialize(Object src) {\n return gson.toJsonTree(src);\n }\n\n @Override\n public JsonElement serialize(Object src, Type typeOfSrc) {\n return gson.toJsonTree(src, typeOfSrc);\n }\n\n @Override\n @SuppressWarnings({\"unchecked\", \"TypeParameterUnusedInFormals\"})\n public R deserialize(JsonElement json, Type typeOfT) throws JsonParseException {\n return gson.fromJson(json, typeOfT);\n }\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/sql/SqlDateTypeAdapter.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal.sql;\n\nimport com.google.gson.Gson;\nimport com.google.gson.JsonSyntaxException;\nimport com.google.gson.TypeAdapter;\nimport com.google.gson.TypeAdapterFactory;\nimport com.google.gson.reflect.TypeToken;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.text.DateFormat;\nimport java.text.ParseException;\nimport java.text.SimpleDateFormat;\nimport java.util.Date;\nimport java.util.TimeZone;\n\n/**\n * Adapter for java.sql.Date. Although this class appears stateless, it is not. DateFormat captures\n * its time zone and locale when it is created, which gives this class state. DateFormat isn't\n * thread safe either, so this class has to synchronize its read and write methods.\n */\n@SuppressWarnings(\"JavaUtilDate\")\nfinal class SqlDateTypeAdapter extends TypeAdapter {\n static final TypeAdapterFactory FACTORY =\n new TypeAdapterFactory() {\n @SuppressWarnings(\"unchecked\") // we use a runtime check to make sure the 'T's equal\n @Override\n public TypeAdapter create(Gson gson, TypeToken typeToken) {\n return typeToken.getRawType() == java.sql.Date.class\n ? 
(TypeAdapter) new SqlDateTypeAdapter()\n : null;\n }\n };\n\n private final DateFormat format = new SimpleDateFormat(\"MMM d, yyyy\");\n\n private SqlDateTypeAdapter() {}\n\n @Override\n public java.sql.Date read(JsonReader in) throws IOException {\n if (in.peek() == JsonToken.NULL) {\n in.nextNull();\n return null;\n }\n String s = in.nextString();\n synchronized (this) {\n TimeZone originalTimeZone = format.getTimeZone(); // Save the original time zone\n try {\n Date utilDate = format.parse(s);\n return new java.sql.Date(utilDate.getTime());\n } catch (ParseException e) {\n throw new JsonSyntaxException(\n \"Failed parsing '\" + s + \"' as SQL Date; at path \" + in.getPreviousPath(), e);\n } finally {\n format.setTimeZone(originalTimeZone); // Restore the original time zone after parsing\n }\n }\n }\n\n @Override\n public void write(JsonWriter out, java.sql.Date value) throws IOException {\n if (value == null) {\n out.nullValue();\n return;\n }\n String dateString;\n synchronized (this) {\n dateString = format.format(value);\n }\n out.value(dateString);\n }\n}\n\n// Path: gson/src/main/java/com/google/gson/internal/sql/SqlTimeTypeAdapter.java\n/*\n * Copyright (C) 2011 Google Inc.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.google.gson.internal.sql;\n\nimport com.google.gson.Gson;\nimport com.google.gson.JsonSyntaxException;\nimport com.google.gson.TypeAdapter;\nimport com.google.gson.TypeAdapterFactory;\nimport com.google.gson.reflect.TypeToken;\nimport com.google.gson.stream.JsonReader;\nimport com.google.gson.stream.JsonToken;\nimport com.google.gson.stream.JsonWriter;\nimport java.io.IOException;\nimport java.sql.Time;\nimport java.text.DateFormat;\nimport java.text.ParseException;\nimport java.text.SimpleDateFormat;\nimport java.util.Date;\nimport java.util.TimeZone;\n\n/**\n * Adapter for java.sql.Time. Although this class appears stateless, it is not. DateFormat captures\n * its time zone and locale when it is created, which gives this class state. DateFormat isn't\n * thread safe either, so this class has to synchronize its read and write methods.\n */\n@SuppressWarnings(\"JavaUtilDate\")\nfinal class SqlTimeTypeAdapter extends TypeAdapter